From e559eb97891404157894cafa10d6de09e73c64ce Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 3 Jan 2021 21:47:56 +0000 Subject: [PATCH 001/264] Initial implementation of Prob U-Net using pytorch --- platipy/imaging/cnn/prob_unet.py | 277 +++++++++++++++++++++++++++++++ platipy/imaging/cnn/unet.py | 145 ++++++++++++++++ requirements.txt | 2 + 3 files changed, 424 insertions(+) create mode 100644 platipy/imaging/cnn/prob_unet.py create mode 100644 platipy/imaging/cnn/unet.py diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py new file mode 100644 index 00000000..2d2c329d --- /dev/null +++ b/platipy/imaging/cnn/prob_unet.py @@ -0,0 +1,277 @@ +# Copyright 2020 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Parts of this work are derived from: +# https://github.com/stefanknegt/Probabilistic-Unet-Pytorch +# which is released under the Apache Licence 2.0 + +# pylint: disable=invalid-name + +import torch +from torch.distributions import Normal, Independent, kl + +import numpy as np + +from platipy.imaging.cnn.unet import UNet, Conv, init_weights + + +class Encoder(torch.nn.Module): + def __init__(self, input_channels, filters_per_layer=[64 * (2 ** x) for x in range(5)]): + super(Encoder, self).__init__() + + layers = [] + for idx, layer_filters in enumerate(filters_per_layer): + + input_filters = input_channels if idx == 0 else output_filters + output_filters = layer_filters + + down_sample = 0 if idx == 0 else -2 + + layers.append(Conv(input_filters, output_filters, up_down_sample=down_sample)) + + self.layers = torch.nn.Sequential(*layers) + + self.layers.apply(init_weights) + + def forward(self, x): + + return self.layers(x) + + +class AxisAlignedConvGaussian(torch.nn.Module): + def __init__( + self, input_channels, filters_per_layer=[64 * (2 ** x) for x in range(5)], latent_dim=2 + ): + + super(AxisAlignedConvGaussian, self).__init__() + + self.latent_dim = latent_dim + + self.encoder = Encoder(input_channels, filters_per_layer) + self.final = torch.nn.Conv2d(filters_per_layer[-1], 2 * self.latent_dim, (1, 1), stride=1) + + self.final.apply(init_weights) + + def forward(self, img, seg=None): + + x = img + if seg is not None: + seg = torch.unsqueeze(seg, dim=1) + x = torch.cat((img, seg), dim=1) + + encoding = self.encoder(x) + + # We only want the mean of the resulting hxw image + encoding = torch.mean(encoding, dim=2, keepdim=True) + encoding = torch.mean(encoding, dim=3, keepdim=True) + + # Convert encoding to 2 x latent dim and split up for mu and log_sigma + mu_log_sigma = self.final(encoding) + + # We squeeze the second dimension twice, since otherwise it won't work when batch size is + # equal to 1 + mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) + 
mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) + + mu = mu_log_sigma[:, : self.latent_dim] + log_sigma = mu_log_sigma[:, self.latent_dim :] + + # This is a multivariate normal with diagonal covariance matrix sigma + # https://github.com/pytorch/pytorch/pull/11178 + dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1) + return dist + + +class Fcomb(torch.nn.Module): + """ + A function composed of no_convs_fcomb times a 1x1 convolution that combines the sample taken + from the latent space, and output of the UNet (the feature map) by concatenating them along + their channel axis. + """ + + def __init__(self, filters_per_layer, latent_dim, num_classes, no_convs_fcomb): + super(Fcomb, self).__init__() + + layers = [] + + # Decoder of N x a 1x1 convolution followed by a ReLU activation function except for the + # last layer + layers.append( + torch.nn.Conv2d(filters_per_layer[0] + latent_dim, filters_per_layer[0], kernel_size=1) + ) + layers.append(torch.nn.ReLU(inplace=True)) + + for _ in range(no_convs_fcomb - 2): + layers.append( + torch.nn.Conv2d(filters_per_layer[0], filters_per_layer[0], kernel_size=1) + ) + layers.append(torch.nn.ReLU(inplace=True)) + + self.layers = torch.nn.Sequential(*layers) + + self.last_layer = torch.nn.Conv2d(filters_per_layer[0], num_classes, kernel_size=1) + + self.layers.apply(init_weights) + self.last_layer.apply(init_weights) + + def tile(self, a, dim, n_tile): + """ + This function is taken form PyTorch forum and mimics the behavior of tf.tile. 
+ Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3 + """ + init_dim = a.size(dim) + repeat_idx = [1] * a.dim() + repeat_idx[dim] = n_tile + a = a.repeat(*(repeat_idx)) + order_index = torch.LongTensor( + np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) + ).cuda() + return torch.index_select(a, dim, order_index) + + def forward(self, feature_map, z): + + z = torch.unsqueeze(z, 2) + z = self.tile(z, 2, feature_map.shape[2]) + z = torch.unsqueeze(z, 3) + z = self.tile(z, 3, feature_map.shape[3]) + + # Concatenate the feature map (output of the UNet) and the sample taken from the latent + # space + feature_map = torch.cat((feature_map, z), dim=1) + output = self.layers(feature_map) + return self.last_layer(output) + + +class ProbabilisticUnet(torch.nn.Module): + """ + A probabilistic UNet implementation + (https://papers.nips.cc/paper/2018/file/473447ac58e1cd7e96172575f48dca3b-Paper.pdf) + + input_channels: the number of channels in the image (1 for greyscale and 3 for RGB) + num_classes: the number of classes to predict + num_filters: is a list consisint of the amount of filters layer + latent_dim: dimension of the latent space + no_cons_per_block: no convs per block in the (convolutional) encoder of prior and posterior + """ + + def __init__( + self, + input_channels=1, + num_classes=2, + filters_per_layer=[64 * (2 ** x) for x in range(5)], + latent_dim=6, + no_convs_fcomb=4, + beta=1.0, + ): + super(ProbabilisticUnet, self).__init__() + + self.no_convs_per_block = 3 + self.no_convs_fcomb = no_convs_fcomb + self.initializers = {"w": "he_normal", "b": "normal"} + self.beta = beta + self.z_prior_sample = 0 + + self.unet = UNet(input_channels, num_classes, filters_per_layer) + self.prior = AxisAlignedConvGaussian(input_channels, filters_per_layer, latent_dim) + self.posterior = AxisAlignedConvGaussian(input_channels + 1, filters_per_layer, latent_dim) + self.fcomb = Fcomb(filters_per_layer, latent_dim, num_classes, 
no_convs_fcomb) + + self.posterior_latent_space = None + self.prior_latent_space = None + self.unet_features = None + + def forward(self, img, seg, training=True): + """ + Construct prior latent space for patch and run patch through UNet, + in case training is True also construct posterior latent space + """ + if training: + self.posterior_latent_space = self.posterior.forward(img, seg=seg) + self.prior_latent_space = self.prior.forward(img) + self.unet_features = self.unet.forward(img) + + def sample(self, testing=False, use_mean=False, sample_x_stddev_from_mean=None): + """ + Sample a segmentation by reconstructing from a prior sample + and combining this with UNet features + """ + + if testing: + if use_mean: + z_prior = self.prior_latent_space.base_dist.loc + elif not sample_x_stddev_from_mean is None: + z_prior = self.prior_latent_space.base_dist.loc + ( + self.prior_latent_space.base_dist.scale * sample_x_stddev_from_mean + ) + else: + z_prior = self.prior_latent_space.sample() + self.z_prior_sample = z_prior + else: + z_prior = self.prior_latent_space.rsample() + self.z_prior_sample = z_prior + + return self.fcomb.forward(self.unet_features, z_prior) + + def reconstruct(self, use_posterior_mean=False, z_posterior=None): + """ + Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet + feature map + + use_posterior_mean: use posterior_mean instead of sampling z_q + """ + if use_posterior_mean: + z_posterior = self.posterior_latent_space.mean + else: + if z_posterior is None: + z_posterior = self.posterior_latent_space.rsample() + return self.fcomb.forward(self.unet_features, z_posterior) + + def kl_divergence(self, analytic=True, z_posterior=None): + """ + Calculate the KL divergence between the posterior and prior KL(Q||P) + analytic: calculate KL analytically or via sampling from the posterior + """ + if analytic: + kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) + else: + if z_posterior is 
None: + z_posterior = self.posterior_latent_space.rsample() + log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior) + log_prior_prob = self.prior_latent_space.log_prob(z_posterior) + kl_div = log_posterior_prob - log_prior_prob + return kl_div + + def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): + """ + Calculate the evidence lower bound of the log-likelihood of P(Y|X) + """ + + criterion = torch.nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) + z_posterior = self.posterior_latent_space.rsample() + + kl_div = torch.mean(self.kl_divergence(analytic=analytic_kl, z_posterior=z_posterior)) + + # Here we use the posterior sample sampled above + reconstruction = self.reconstruct( + use_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior + ) + + segm = torch.unsqueeze(segm, dim=1) + not_seg = segm.logical_not().type("torch.cuda.LongTensor") + segm = torch.cat((not_seg, segm), dim=1).type("torch.cuda.FloatTensor") + reconstruction_loss = criterion(input=reconstruction, target=segm) + reconstruction_loss = torch.sum(reconstruction_loss) + # mean_reconstruction_loss = torch.mean(reconstruction_loss) + + return -(reconstruction_loss + self.beta * kl_div) diff --git a/platipy/imaging/cnn/unet.py b/platipy/imaging/cnn/unet.py new file mode 100644 index 00000000..1dd70bbb --- /dev/null +++ b/platipy/imaging/cnn/unet.py @@ -0,0 +1,145 @@ +# Copyright 2020 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Parts of this work are derived from: +# https://github.com/stefanknegt/Probabilistic-Unet-Pytorch +# which is released under the Apache Licence 2.0 + +# pylint: disable=invalid-name + +import torch +from torch import nn + + +def l2_regularisation(m): + l2_reg = None + + for W in m.parameters(): + if l2_reg is None: + l2_reg = W.norm(2) + else: + l2_reg = l2_reg + W.norm(2) + return l2_reg + + +def truncated_normal_(tensor, mean=0, std=1): + size = tensor.shape + tmp = tensor.new_empty(size + (4,)).normal_() + valid = (tmp < 2) & (tmp > -2) + ind = valid.max(-1, keepdim=True)[1] + tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) + tensor.data.mul_(std).add_(mean) + + +def init_weights(m): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu") + truncated_normal_(m.bias, mean=0, std=0.001) + + +class Conv(torch.nn.Module): + def __init__(self, input_channels, output_channels, up_down_sample=0): + + super(Conv, self).__init__() + + self.pre_op = None + size_and_stride = abs(up_down_sample) + if up_down_sample < 0: + self.pre_op = nn.MaxPool2d(kernel_size=size_and_stride, stride=size_and_stride) + elif up_down_sample > 0: + self.pre_op = nn.ConvTranspose2d( + input_channels, + output_channels, + kernel_size=size_and_stride, + stride=size_and_stride, + ) + + layers = [] + layers.append( + nn.Conv2d( + in_channels=input_channels, out_channels=output_channels, kernel_size=3, padding=1 + ) + ) + layers.append(nn.ReLU(inplace=True)) + layers.append( + nn.Conv2d( + in_channels=output_channels, out_channels=output_channels, kernel_size=3, padding=1 + ) + ) + layers.append(nn.ReLU(inplace=True)) + self.layers = nn.Sequential(*layers) + + self.layers.apply(init_weights) + + def forward(self, x, concat=None): + + if not self.pre_op is None: + x = self.pre_op(x) + + if not concat is 
None: + x = torch.cat([x, concat], 1) + + return self.layers(x) + + +class UNet(nn.Module): + def __init__( + self, + input_channels=1, + output_classes=2, + filters_per_layer=[64 * (2 ** x) for x in range(5)], + final_layer=True, + ): + + super(UNet, self).__init__() + + self.encoder = nn.ModuleList() + for idx, layer_filters in enumerate(filters_per_layer): + input_filters = input_channels if idx == 0 else output_filters + output_filters = layer_filters + down_sample = 0 if idx == 0 else -2 + + self.encoder.append(Conv(input_filters, output_filters, up_down_sample=down_sample)) + + reversed_filters = list(reversed(filters_per_layer)) + self.decoder = nn.ModuleList() + for idx, layer_filters in enumerate(reversed_filters): + + if idx == len(reversed_filters) - 1: + continue + + input_filters = layer_filters + output_filters = reversed_filters[idx + 1] + + self.decoder.append(Conv(input_filters, output_filters, up_down_sample=2)) + + self.final = None + if final_layer: + self.final = nn.Conv2d(filters_per_layer[0], output_classes, kernel_size=1) + + def forward(self, x): + + blocks = [] + for idx, enc in enumerate(self.encoder): + x = enc(x) + if idx != len(self.encoder) - 1: + blocks.append(x) + + for idx, dec in enumerate(self.decoder): + x = dec(x, concat=blocks[-idx - 1]) + + if self.final: + return self.final(x) + + return x diff --git a/requirements.txt b/requirements.txt index 44f30ce7..db77de0c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,5 @@ gunicorn requests pandas pymedphys >= 0.35.0 +torch == 1.7.1 +torchvision == 0.8.2 From 2128a8b2b50214f76d8c8e1672ec2b59583d28a4 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 3 Jan 2021 22:33:35 +0000 Subject: [PATCH 002/264] Disable final layer for prob UNet --- platipy/imaging/cnn/prob_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 2d2c329d..0e35f063 100644 --- 
a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -182,7 +182,7 @@ def __init__( self.beta = beta self.z_prior_sample = 0 - self.unet = UNet(input_channels, num_classes, filters_per_layer) + self.unet = UNet(input_channels, num_classes, filters_per_layer, final_layer=False) self.prior = AxisAlignedConvGaussian(input_channels, filters_per_layer, latent_dim) self.posterior = AxisAlignedConvGaussian(input_channels + 1, filters_per_layer, latent_dim) self.fcomb = Fcomb(filters_per_layer, latent_dim, num_classes, no_convs_fcomb) From d5f53f605ce55d9eac4d97f3455b8e10abffa073 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 4 Jan 2021 09:47:33 +1100 Subject: [PATCH 003/264] Remove cuda specific code --- platipy/imaging/cnn/prob_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 0e35f063..905e559b 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -268,8 +268,8 @@ def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): ) segm = torch.unsqueeze(segm, dim=1) - not_seg = segm.logical_not().type("torch.cuda.LongTensor") - segm = torch.cat((not_seg, segm), dim=1).type("torch.cuda.FloatTensor") + not_seg = segm.logical_not() + segm = torch.cat((not_seg, segm), dim=1).float() reconstruction_loss = criterion(input=reconstruction, target=segm) reconstruction_loss = torch.sum(reconstruction_loss) # mean_reconstruction_loss = torch.mean(reconstruction_loss) From 4361589d6a20559fde6573ab0647cf19cc5fac65 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 3 Jan 2021 23:01:12 +0000 Subject: [PATCH 004/264] Replace tile function --- platipy/imaging/cnn/prob_unet.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 905e559b..d2e38408 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ 
b/platipy/imaging/cnn/prob_unet.py @@ -125,26 +125,10 @@ def __init__(self, filters_per_layer, latent_dim, num_classes, no_convs_fcomb): self.layers.apply(init_weights) self.last_layer.apply(init_weights) - def tile(self, a, dim, n_tile): - """ - This function is taken form PyTorch forum and mimics the behavior of tf.tile. - Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3 - """ - init_dim = a.size(dim) - repeat_idx = [1] * a.dim() - repeat_idx[dim] = n_tile - a = a.repeat(*(repeat_idx)) - order_index = torch.LongTensor( - np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) - ).cuda() - return torch.index_select(a, dim, order_index) - def forward(self, feature_map, z): - z = torch.unsqueeze(z, 2) - z = self.tile(z, 2, feature_map.shape[2]) - z = torch.unsqueeze(z, 3) - z = self.tile(z, 3, feature_map.shape[3]) + z = torch.unsqueeze(z, 2).expand(-1, -1, feature_map.shape[2], -1) + z = torch.unsqueeze(z, 3).expand(-1, -1, -1, feature_map.shape[3], -1) # Concatenate the feature map (output of the UNet) and the sample taken from the latent # space From c0628a96f2cacb06504ac6917056a7d45b154651 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 3 Jan 2021 23:02:49 +0000 Subject: [PATCH 005/264] Correct for 2d --- platipy/imaging/cnn/prob_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index d2e38408..0327b66f 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -127,8 +127,8 @@ def __init__(self, filters_per_layer, latent_dim, num_classes, no_convs_fcomb): def forward(self, feature_map, z): - z = torch.unsqueeze(z, 2).expand(-1, -1, feature_map.shape[2], -1) - z = torch.unsqueeze(z, 3).expand(-1, -1, -1, feature_map.shape[3], -1) + z = torch.unsqueeze(z, 2).expand(-1, -1, feature_map.shape[2]) + z = torch.unsqueeze(z, 3).expand(-1, -1, -1, feature_map.shape[3]) # Concatenate the feature map 
(output of the UNet) and the sample taken from the latent # space From b70dd02e9f8a209fbd5473261af568b10853623d Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 4 Jan 2021 02:41:30 +0000 Subject: [PATCH 006/264] Add init py --- platipy/imaging/cnn/__init__.py | 0 platipy/imaging/cnn/prob_unet.py | 2 -- 2 files changed, 2 deletions(-) create mode 100644 platipy/imaging/cnn/__init__.py diff --git a/platipy/imaging/cnn/__init__.py b/platipy/imaging/cnn/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 0327b66f..65dcaefc 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -21,8 +21,6 @@ import torch from torch.distributions import Normal, Independent, kl -import numpy as np - from platipy.imaging.cnn.unet import UNet, Conv, init_weights From 2b482b699ab1049b0f4a958f9e035c024dcb5bdc Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 4 Jan 2021 21:15:01 +0000 Subject: [PATCH 007/264] Remove deprecated args --- platipy/imaging/cnn/prob_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 65dcaefc..64036d3e 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -239,7 +239,7 @@ def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): Calculate the evidence lower bound of the log-likelihood of P(Y|X) """ - criterion = torch.nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) + criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") z_posterior = self.posterior_latent_space.rsample() kl_div = torch.mean(self.kl_divergence(analytic=analytic_kl, z_posterior=z_posterior)) From d9e4f49789124fde8906904f8f211fc926173c8c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 18 Jan 2021 03:55:19 +0000 Subject: [PATCH 008/264] Work on prob net --- .vscode/settings.json | 3 ++- 
requirements.txt | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 697119ae..abdb5f10 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -11,7 +11,8 @@ "--disable=W0105", "--init-hook", "import sys; sys.path.append('.')", - "--extension-pkg-whitelist=vtk" + "--extension-pkg-whitelist=vtk", + "--generated-members=numpy.* ,torch.*" ], "python.formatting.blackArgs": [ "--line-length", diff --git a/requirements.txt b/requirements.txt index db77de0c..d6f3f543 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ pandas pymedphys >= 0.35.0 torch == 1.7.1 torchvision == 0.8.2 +seaborn From 4a5e5e34a4cb31da5c0b4678289774340ff8f094 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 26 Mar 2021 02:06:19 +0000 Subject: [PATCH 009/264] Initial code of pytorch implementation of hierarchical probabilistic UNet --- platipy/imaging/cnn/hierarchical_prob_unet.py | 831 ++++++++++++++++++ platipy/imaging/cnn/test_hpunet.py | 158 ++++ 2 files changed, 989 insertions(+) create mode 100644 platipy/imaging/cnn/hierarchical_prob_unet.py create mode 100644 platipy/imaging/cnn/test_hpunet.py diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py new file mode 100644 index 00000000..6df6eac1 --- /dev/null +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -0,0 +1,831 @@ +# Copyright 2020 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# This code is adapted from +# https://github.com/deepmind/deepmind-research/tree/5cf55efe1f1748ebdd33cb69223b0df6bcc88e6a/hierarchical_probabilistic_unet +# which is released under the Apache Licence 2.0 + +# pylint: disable=invalid-name + +import math +import torch + + +class ResBlock(torch.nn.Module): + """A residual block""" + + def __init__( + self, + input_channels, + output_channels, + n_down_channels=None, + activation_fn=torch.nn.ReLU, + convs_per_block=3, + ): + """Create a residual block + + Args: + input_channels (int): The number of input channels to the block + output_channels (int): The number of output channels from the block + n_down_channels (int, optional): The number of intermediate cahnnels within the block. + Defaults to the same as the number of output channels. + activation_fn (torch.nn.Module, optional): The activation function to apply. Defaults + to torch.nn.ReLU. + convs_per_block (int, optional): The number of convolutions to perform within the + block. Defaults to 3. + """ + + super(ResBlock, self).__init__() + + self._activation_fn = activation_fn(inplace=True) + + # Set the number of intermediate channels that we compress to. 
+ if n_down_channels is None: + n_down_channels = output_channels + + layers = [] + in_channels = input_channels + for c in range(convs_per_block): + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, out_channels=n_down_channels, kernel_size=3, padding=1 + ) + ) + + if c < convs_per_block - 1: + layers.append(activation_fn(inplace=True)) + + in_channels = n_down_channels + + if not n_down_channels == output_channels: + resize_outgoing = torch.nn.Conv2d( + in_channels=n_down_channels, out_channels=output_channels, kernel_size=1, padding=0 + ) + layers.append(resize_outgoing) + + self._layers = torch.nn.Sequential(*layers) + + self._resize_skip = None + + if not input_channels == output_channels: + self._resize_skip = torch.nn.Conv2d( + in_channels=input_channels, out_channels=output_channels, kernel_size=1, padding=0 + ) + + def forward(self, input_features): + + # Pre-activate the inputs. + skip = input_features + residual = self._activation_fn(input_features) + + for layer in self._layers: + residual = layer(residual) + + if not self._resize_skip is None: + skip = self._resize_skip(skip) + + return skip + residual + + +def resize_up(input_features, scale=2): + """Resize the the input to upsample + + Args: + input_features (torch.Tensor): The Tensor to upsize + scale (int, optional): The scale used to upsize. Defaults to 2. + + Returns: + torch.Tensor: The upsized Tensor + """ + _, _, size_x, size_y = input_features.shape + new_size_x = int(round(size_x * scale)) + new_size_y = int(round(size_y * scale)) + return torch.nn.functional.interpolate(input_features, size=[new_size_x, new_size_y]) + + +def resize_down(input_features, scale=2): + """Resize the the input to downsample + + Args: + input_features (torch.Tensor): The Tensor to downsize + scale (int, optional): The scale used to downsize. Defaults to 2. 
+ + Returns: + torch.Tensor: The downsized Tensor + """ + return torch.nn.AvgPool2d(kernel_size=scale, stride=scale, padding=0)(input_features) + + +def softmax_cross_entropy_with_logits(target, logits): + """Computes the softmax_cross_entropy_with_logits to replicate the equivalent function in + tensorflow. + + From https://gist.github.com/tejaskhot/cf3d087ce4708c422e68b3b747494b9f + + Args: + target (torch.Tensor): The target tensor + logits (torch.Tensor): Log probabilities + + Returns: + torch.Tensor: Tensor containing the softmax cross entropy loss + """ + + loss = torch.sum(-target * torch.nn.functional.log_softmax(logits, -1), -1) + return loss.mean() + + +def _sample_gumbel(shape): + """Transforms a uniform random variable to be standard Gumbel distributed. + + Args: + shape (tuple): The shape of the data + + Returns: + torch.Tensor: Standard Gumbel distribution + """ + + eps = 1e-20 + return -torch.log(-torch.log(torch.rand(shape) + eps) + eps) + + +def _topk_mask(score, k): + """Returns a mask for the top-k elements in score. + + Args: + score (torch.Tensor): The tensor of score values + k (float): The value of k (0-1) + + Returns: + torch.Tensor: The mask + """ + + _, indices = torch.topk(score, k) + return torch.scatter_add(torch.zeros(score.shape), 0, indices, torch.ones(k)) + + +def ce_loss(logits, labels, mask=None, top_k_percentage=None, deterministic=False): + """Computes the cross-entropy loss. + Optionally a mask and a top-k percentage for the used pixels can be specified. + The top-k mask can be produced deterministically or sampled. + + Args: + logits (torch.Tensor): A tensor of shape (b,num_classes,h,w) + labels (torch.Tensor): A tensor of shape (b,num_classes,h,w) + mask (torch.Tensor, optional): None or a tensor of shape (b,h,w). Defaults to None. + top_k_percentage (float, optional): None or a float in (0.,1.]. If None, a standard + cross-entropy loss is calculated. Defaults to None. 
+ deterministic (bool, optional): A Boolean indicating whether or not to produce the + prospective top-k mask deterministically. Defaults to + False. + + Returns: + dict: A dictionary holding the mean and the pixelwise sum of the loss for the + batch as well as the employed loss mask. + """ + + num_classes = logits.shape[1] + + y_flat = torch.reshape(logits, (-1, num_classes)) + t_flat = torch.reshape(labels, (-1, num_classes)) + if mask is None: + mask = torch.ones(t_flat.shape[0]) + else: + assert ( + mask.shape.as_list()[:3] == labels.shape.as_list()[:3] + ), "The loss mask shape differs from the target shape: {} vs. {}.".format( + mask.shape.as_list(), labels.shape.as_list()[:3] + ) + mask = torch.reshape(mask, (-1,), name="reshape_mask") + + n_pixels_in_batch = y_flat.shape[0] + xe = softmax_cross_entropy_with_logits(t_flat, y_flat) + + if top_k_percentage is not None: + assert 0.0 < top_k_percentage <= 1.0 + k_pixels = math.floor(n_pixels_in_batch * top_k_percentage) + + # stopgrad_xe = tf.stop_gradient(xe) + norm_xe = xe / xe.sum() + + if deterministic: + + score = norm_xe.log() + else: + # Use the Gumbel trick to sample the top-k pixels, equivalent to sampling + # from a categorical distribution over pixels whose probabilities are + # given by the normalized cross-entropy loss values. This is done by + # adding Gumbel noise to the logarithmic normalized cross-entropy loss + # (followed by choosing the top-k pixels). 
+ score = norm_xe.log() + _sample_gumbel(norm_xe.shape) + + score = score + mask.log() + top_k_mask = _topk_mask(score, k_pixels) + mask = mask * top_k_mask + + # Calculate batch-averages for the sum and mean of the loss + batch_size = labels.shape[0] + xe = torch.reshape(xe, (batch_size, -1)) + mask = torch.reshape(mask, (batch_size, -1)) + ce_sum_per_instance = torch.sum(mask * xe, 1) + ce_sum = torch.mean(ce_sum_per_instance, 0) + ce_mean = torch.sum(mask * xe) / torch.sum(mask) + + return {"mean": ce_mean, "sum": ce_sum, "mask": mask} + + +class _HierarchicalCore(torch.nn.Module): + """A U-Net encoder-decoder with a full encoder and a truncated decoder. + The truncated decoder is interleaved with the hierarchical latent space and + has as many levels as there are levels in the hierarchy plus one additional + level. + """ + + def __init__( + self, + latent_dims, + input_channels, + channels_per_block, + down_channels_per_block=None, + activation_fn=torch.nn.ReLU, + convs_per_block=3, + blocks_per_level=3, + ): + """Initializes a HierarchicalCore. + + Args: + latent_dims (list): List of integers specifying the dimensions of the latents at + each scale. The length of the list indicates the number of U-Net + decoder scales that have latents. + input_channels (int): The number of input channels. + channels_per_block (list): A list of integers specifying the number of output + channels for each encoder block. + down_channels_per_block (list, optional): A list of integers specifying the number of + intermediate channels for each encoder block + or None. If None, the intermediate channels + are chosen equal to channels_per_block. + Defaults to None. + activation_fn (torch.nn.Module, optional): A callable activation function. Defaults to + torch.nn.ReLU. + convs_per_block (int, optional): An integer specifying the number of convolutional + layers. Defaults to 3. + blocks_per_level (int, optional): An integer specifying the number of residual blocks + per level. 
Defaults to 3. + """ + + super(_HierarchicalCore, self).__init__() + + self._latent_dims = latent_dims + self._input_channels = input_channels + self._channels_per_block = channels_per_block + self._activation_fn = activation_fn + self._convs_per_block = convs_per_block + self._blocks_per_level = blocks_per_level + if down_channels_per_block is None: + self._down_channels_per_block = channels_per_block + else: + self._down_channels_per_block = down_channels_per_block + + num_levels = len(self._channels_per_block) + self._num_latent_levels = len(self._latent_dims) + + # Iterate the descending levels in the U-Net encoder. + self.encoder_layers = torch.nn.ModuleList() + in_channels = input_channels + for level in range(num_levels): + # Iterate the residual blocks in each level. + layer = [] + for _ in range(self._blocks_per_level): + layer.append( + ResBlock( + in_channels, + channels_per_block[level], + n_down_channels=self._down_channels_per_block[level], + activation_fn=self._activation_fn, + convs_per_block=self._convs_per_block, + ) + ) + in_channels = channels_per_block[level] + + self.encoder_layers.append(torch.nn.Sequential(*layer)) + + # Iterate the ascending levels in the (truncated) U-Net decoder. 
+ self.decoder_layers = torch.nn.ModuleList() + self._mu_logsigma_blocks = torch.nn.ModuleList() + + for level in range(self._num_latent_levels): + + latent_dim = latent_dims[level] + + mu_logsigma_block = torch.nn.Conv2d( + channels_per_block[::-1][level], 2 * latent_dim, kernel_size=1, padding=0 + ) + + self._mu_logsigma_blocks.append(mu_logsigma_block) + + decoder_in_channels = ( + channels_per_block[::-1][level + 1] + channels_per_block[::-1][level] + ) + latent_dim + layer = [] + for _ in range(self._blocks_per_level): + layer.append( + ResBlock( + decoder_in_channels, + channels_per_block[::-1][level + 1], + n_down_channels=self._down_channels_per_block[::-1][level + 1], + activation_fn=self._activation_fn, + convs_per_block=self._convs_per_block, + ) + ) + decoder_in_channels = channels_per_block[::-1][level + 1] + + self.decoder_layers.append(torch.nn.Sequential(*layer)) + + def forward(self, inputs, mean=False, z_q=None): + """Forward pass to sample from the module as specified. + + Args: + inputs: + mean: + z_q: + Returns: + + Args: + inputs (torch.Tensor): A tensor of shape (b,c,h,w). When using the module as a prior + the `inputs` tensor should be a batch of images. When using it + as a posterior the tensor should be a (batched) concatentation + of images and segmentations. + mean (bool|list, optional): A boolean or a list of booleans. If a boolean, it specifies + whether or not to use the distributions' means in ALL + latent scales. If a list, each bool therein specifies + whether or not to use the scale's mean. If False, the + latents of the scale are sampled. Defaults to False. + z_q (list, optional): None or a list of tensors. If not None, z_q provides external + latents to be used instead of sampling them. This is used to + employ posterior latents in the prior during training. Therefore, + if z_q is not None, the value of `mean` is ignored. 
If z_q is + None, either the distributions mean is used (in case `mean` for + the respective scale is True) or else a sample from the + distribution is drawn. Defaults to None. + + Returns: + dict: A Dictionary holding the output feature map of the truncated U-Net decoder under + key 'decoder_features', a list of the U-Net encoder features produced at the end of + each encoder scale under key 'encoder_outputs', a list of the predicted distributions + at each scale under key 'distributions', a list of the used latents at each scale under + the key 'used_latents'. + """ + + encoder_features = inputs + encoder_outputs = [] + num_levels = len(self._channels_per_block) + num_latent_levels = len(self._latent_dims) + if isinstance(mean, bool): + mean = [mean] * self._num_latent_levels + distributions = [] + used_latents = [] + + # Iterate the descending levels in the U-Net encoder. + for level, encoder_layer in enumerate(self.encoder_layers): + encoder_features = encoder_layer(encoder_features) + encoder_outputs.append(encoder_features) + if not level == num_levels - 1: + encoder_features = resize_down(encoder_features, scale=2) + + # Iterate the ascending levels in the (truncated) U-Net decoder. + decoder_features = encoder_outputs[-1] + for level in range(num_latent_levels): + + # Predict a Gaussian distribution for each pixel in the feature map. + latent_dim = self._latent_dims[level] + mu_logsigma = self._mu_logsigma_blocks[level](decoder_features) + + mu = mu_logsigma[:, :latent_dim] + log_sigma = mu_logsigma[:, latent_dim:] + + dist = torch.distributions.Independent( + torch.distributions.Normal(loc=mu, scale=torch.exp(log_sigma)), 1 + ) + distributions.append(dist) + + # Get the latents to condition on. + if z_q is not None: + z = z_q[level] + elif mean[level]: + z = dist.base_dist.loc + else: + z = dist.sample() + + used_latents.append(z) + + # Concat and upsample the latents with the previous features. 
+ decoder_output_lo = torch.cat([z, decoder_features], axis=1) + decoder_output_hi = resize_up(decoder_output_lo, scale=2) + decoder_features = torch.cat( + [decoder_output_hi, encoder_outputs[::-1][level + 1]], axis=1 + ) + decoder_features = self.decoder_layers[level](decoder_features) + + return { + "decoder_features": decoder_features, + "encoder_features": encoder_outputs, + "distributions": distributions, + "used_latents": used_latents, + } + + +class _StitchingDecoder(torch.nn.Module): + """A module that completes the truncated U-Net decoder. + Using the output of the HierarchicalCore this module fills in the missing + decoder levels such that together the two form a symmetric U-Net. + """ + + def __init__( + self, + latent_dims, + channels_per_block, + num_classes, + down_channels_per_block=None, + activation_fn=torch.nn.ReLU, + convs_per_block=3, + blocks_per_level=3, + ): + """Initializes a StichtingDecoder. + + Args: + latent_dims (list): List of integers specifying the dimensions of the latents at each + scale. The length of the list indicates the number of U-Net decoder + scales that have latents. + channels_per_block (list): A list of integers specifying the number of output channels + for each encoder block. + num_classes (int): The number of segmentation classes. + down_channels_per_block ([type], optional): A list of integers specifying the number of + intermediate channels for each encoder + block. If None, the intermediate channels + are chosen equal to channels_per_block. + Defaults to None. + activation_fn (torch.nn.Module, optional): A callable activation function.Defaults to + torch.nn.ReLU. + initializers ([type], optional): [description]. Defaults to None. + regularizers ([type], optional): [description]. Defaults to None. + convs_per_block (int, optional): An integer specifying the number of convolutional + layers. Defaults to 3. + blocks_per_level (int, optional): An integer specifying the number of residual blocks + per level. 
Defaults to 3. + """ + super(_StitchingDecoder, self).__init__() + self._latent_dims = latent_dims + self._channels_per_block = channels_per_block + self._num_classes = num_classes + self._activation_fn = activation_fn + self._convs_per_block = convs_per_block + self._blocks_per_level = blocks_per_level + if down_channels_per_block is None: + down_channels_per_block = channels_per_block + self._down_channels_per_block = down_channels_per_block + + num_latents = len(self._latent_dims) + self._start_level = num_latents + 1 + self._num_levels = len(self._channels_per_block) + + self.decoder_layers = torch.nn.ModuleList() + for level in range(self._start_level, self._num_levels, 1): + + decoder_in_channels = ( + channels_per_block[::-1][level - 1] + channels_per_block[::-1][level] + ) + layer = [] + for _ in range(self._blocks_per_level): + layer.append( + ResBlock( + decoder_in_channels, + channels_per_block[::-1][level], + n_down_channels=self._down_channels_per_block[::-1][level], + activation_fn=self._activation_fn, + convs_per_block=self._convs_per_block, + ) + ) + decoder_in_channels = channels_per_block[::-1][level] + + self.decoder_layers.append(torch.nn.Sequential(*layer)) + + self.final_layer = torch.nn.Conv2d( + decoder_in_channels, self._num_classes, kernel_size=1, padding=0 + ) + + def forward(self, encoder_features, decoder_features): + """Forward pass through the stiching decoder + + Args: + encoder_features (torch.Tensor): Tensor of encoder features + decoder_features (dict): Tensor of decoder features + + Returns: + torch.Tensor: The stiched output + """ + + for level in range(len(self.decoder_layers)): + enc_level = self._start_level + level + decoder_features = resize_up(decoder_features, scale=2) + decoder_features = torch.cat( + [decoder_features, encoder_features[::-1][enc_level]], axis=1 + ) + decoder_features = self.decoder_layers[level](decoder_features) + + return self.final_layer(decoder_features) + + +class 
HierarchicalProbabilisticUnet(torch.nn.Module): + """A hierarchical probabilistic UNet implementation: https://arxiv.org/abs/1905.13077""" + + def __init__( + self, + input_channels=1, + num_classes=2, + channels_per_block=None, + down_channels_per_block=None, + latent_dims=(1, 1, 1, 1), + convs_per_block=3, + blocks_per_level=3, + loss_kwargs=None, + ): + """Initialize the Hierarchical Probabilistic UNet + + Args: + input_channels (int, optional): The number of channels in the image (1 for + greyscale and 3 for RGB). Defaults to 1. + num_classes (int, optional): The number of classes to predict. Defaults to 2. + channels_per_block (list, optional): A list of channels to use in blocks of each + layer the amount of filters layer. Defaults + to None. + down_channels_per_block (list, optional): [description]. Defaults to None. + latent_dims (tuple, optional): The number of latent dimensions at each layer. + Defaults to (1, 1, 1, 1). + convs_per_block (int, optional): An integer specifying the number of convolutional + layers. Defaults to 3. Defaults to 3. + blocks_per_level (int, optional): An integer specifying the number of residual + blocks per level. Defaults to 3. + loss_kwargs (dict, optional): Dictionary of argument used by loss function. + Defaults to None. 
+ """ + super(HierarchicalProbabilisticUnet, self).__init__() + + base_channels = 24 + default_channels_per_block = ( + base_channels, + 2 * base_channels, + 4 * base_channels, + 8 * base_channels, + 8 * base_channels, + 8 * base_channels, + 8 * base_channels, + 8 * base_channels, + ) + if channels_per_block is None: + channels_per_block = default_channels_per_block + if down_channels_per_block is None: + down_channels_per_block = [int(i / 2) for i in default_channels_per_block] + + self._prior = _HierarchicalCore( + input_channels=input_channels, + latent_dims=latent_dims, + channels_per_block=channels_per_block, + down_channels_per_block=down_channels_per_block, + convs_per_block=convs_per_block, + blocks_per_level=blocks_per_level, + ) + + self._posterior = _HierarchicalCore( + input_channels=input_channels + num_classes, + latent_dims=latent_dims, + channels_per_block=channels_per_block, + down_channels_per_block=down_channels_per_block, + convs_per_block=convs_per_block, + blocks_per_level=blocks_per_level, + ) + + self._f_comb = _StitchingDecoder( + latent_dims=latent_dims, + channels_per_block=channels_per_block, + num_classes=num_classes, + down_channels_per_block=down_channels_per_block, + convs_per_block=convs_per_block, + blocks_per_level=blocks_per_level, + ) + + self._cache = None + + if loss_kwargs is None: + self._loss_kwargs = { + "type": "elbo", + "top_k_percentage": 0.02, + "deterministic_top_k": False, + "kappa": 0.05, + "decay": 0.99, + "rate": 1e-2, + "beta": 1.0, + } + else: + self._loss_kwargs = loss_kwargs + + # if self._loss_kwargs["type"] == "geco": + # self._moving_average = ExponentialMovingAverage( + # model=self, decay=self._loss_kwargs["decay"] + # ) + # self._lagmul = geco_utils.LagrangeMultiplier(rate=self._loss_kwargs["rate"]) + + self._q_sample = None + self._q_sample_mean = None + self._p_sample = None + self._p_sample_z_q = None + self._p_sample_z_q_mean = None + + def forward(self, img, seg): + """Inserts all ops used during 
training into the graph exactly once. The first time this + method is called given the input pair (img, seg) all ops relevant for training are inserted + into the graph. Calling this method more than once does not re-insert the modules into the + graph (memoization), thus preventing multiple forward passes of submodules for the same + inputs. + + Args: + img (torch.Tensor): A tensor of shape (b, c, h, w). + seg (torch.Tensor): A tensor of shape (b, num_classes, h, w). + """ + + inputs = (img, seg) + + if self._cache == inputs: + # No need to recompute + return + + input_tensor = torch.cat([img, seg], axis=1) + self._q_sample = self._posterior(input_tensor, mean=False) + self._q_sample_mean = self._posterior(input_tensor, mean=True) + self._p_sample = self._prior(img, mean=False, z_q=None) + self._p_sample_z_q = self._prior(img, z_q=self._q_sample["used_latents"]) + self._p_sample_z_q_mean = self._prior(img, z_q=self._q_sample_mean["used_latents"]) + self._cache = inputs + + def sample(self, img, mean=False, z_q=None): + """Sample a segmentation from the prior, given an input image. + + Args: + img (torch.Tensor): A tensor of shape (b, c, h, w). + mean (bool, optional): A boolean or a list of booleans. If a boolean, it specifies + whether or not to use the distributions' means in ALL latent + scales. If a list, each bool therein specifies whether or not to + use the scale's mean. If False, the latents of the scale are + sampled. Defaults to False. + z_q (list, optional): If not None, z_q provides external latents to be used instead of + sampling them. This is used to employ posterior latents in the + prior during training. Therefore, if z_q is not None, the value + of `mean` is ignored. If z_q is None, either the distributions + mean is used (in case `mean` for the respective scale is True) or + else a sample from the distribution is drawn. Defaults to None. + + Returns: + torch.Tensor: A segmentation tensor of shape (b, num_classes, h, w). 
+ """ + + prior_out = self._prior(img, mean, z_q) + encoder_features = prior_out["encoder_features"] + decoder_features = prior_out["decoder_features"] + return self._f_comb(encoder_features, decoder_features) + + def reconstruct(self, img, seg, mean=False): + """Reconstruct a segmentation using the posterior. + + Args: + img ([torch.Tensor): A tensor of shape (b, c, h, w). + seg (torch.Tensor): A tensor of shape (b, num_classes, h, w). + mean (bool, optional): A boolean, specifying whether to sample from the full hierarchy + of the posterior or use the posterior means at each scale of the + hierarchy. Defaults to False. + + Returns: + torch.Tensor: A segmentation tensor of shape (b,num_classes,h,w). + """ + + self.forward(img, seg) + if mean: + prior_out = self._p_sample_z_q_mean + else: + prior_out = self._p_sample_z_q + encoder_features = prior_out["encoder_features"] + decoder_features = prior_out["decoder_features"] + return self._f_comb(encoder_features, decoder_features) + + def kl(self, img, seg): + """Kullback-Leibler divergence between the posterior and the prior. + + Args: + img (torch.Tensor): A tensor of shape (b, c, h, w). + seg (torch.Tensor): A tensor of shape (b, num_classes, h, w). + + Returns: + dict: A dictionary with keys indexing the hierarchy's levels and corresponding + values holding the KL-term for each level (per batch). + """ + self.forward(img, seg) + posterior_out = self._q_sample + prior_out = self._p_sample_z_q + + q_dists = posterior_out["distributions"] + p_dists = prior_out["distributions"] + + kl = {} + for level, (p, q) in enumerate(zip(p_dists, q_dists)): + kl_per_pixel = torch.distributions.kl.kl_divergence(p, q) + kl_per_instance = torch.sum(kl_per_pixel, [1, 2]) + kl[level] = torch.mean(kl_per_instance) + + return kl + + def rec_loss(self, img, seg, mask=None, top_k_percentage=None, deterministic=True): + """Cross-entropy reconstruction loss employed in the ELBO-/ GECO-objective. 
+ + Args: + img (torch.Tensor): A tensor of shape (b, c, h, w). + seg (torch.Tensor): A tensor of shape (b, num_classes, h, w). + mask (torch.Tensor, optional): A mask of shape (b, h, w) or None. If None no pixels are + masked in the loss. Defaults to None. + top_k_percentage (float, optional): None or a float in (0.,1.]. If None, a standard + cross-entropy loss is calculated. Defaults to None. + deterministic (bool, optional): A Boolean indicating whether or not to produce the + prospective top-k mask deterministically. Defaults to + True. + + Returns: + dict: A dictionary holding the mean and the pixelwise sum of the loss for the + batch as well as the employed loss mask. + """ + reconstruction = self.reconstruct(img, seg, mean=False) + return ce_loss(reconstruction, seg, mask, top_k_percentage, deterministic) + + def loss(self, img, seg, mask=None): + """The full training objective, either ELBO or GECO. + + Args: + img (torch.Tensor): A tensor of shape (b, c, h, w). + seg (torch.Tensor): A tensor of shape (b, num_classes, h, w). + mask (torch.Tensor, optional): A mask of shape (b, h, w) or None. If None no pixels are + masked in the loss. Defaults to None. + + Raises: + NotImplementedError: Raised if loss function supplied isn't implemented yet. + + Returns: + dict: A dictionary holding the loss (with key 'loss') and the tensorboard summaries + (with key 'summaries'). + """ + summaries = {} + top_k_percentage = self._loss_kwargs["top_k_percentage"] + deterministic = self._loss_kwargs["deterministic_top_k"] + rec_loss = self.rec_loss(img, seg, mask, top_k_percentage, deterministic) + + kl_dict = self.kl(img, seg) + kl_sum = torch.sum(torch.stack([kl for _, kl in kl_dict.items()], axis=-1)) + + summaries["rec_loss_mean"] = rec_loss["mean"] + summaries["rec_loss_sum"] = rec_loss["sum"] + summaries["kl_sum"] = kl_sum + for level, kl in kl_dict.items(): + summaries["kl_{}".format(level)] = kl + + # Set up a regular ELBO objective. 
+ if self._loss_kwargs["type"] == "elbo": + loss = rec_loss["sum"] + self._loss_kwargs["beta"] * kl_sum + summaries["elbo_loss"] = loss + + # TODO Still need to implement geco + # Set up a GECO objective (ELBO with a reconstruction constraint). + # elif self._loss_kwargs["type"] == "geco": + # ma_rec_loss = self._moving_average(rec_loss["sum"]) + # mask_sum_per_instance = torch.sum(rec_loss["mask"], -1) + # num_valid_pixels = torch.mean(mask_sum_per_instance) + # reconstruction_threshold = self._loss_kwargs["kappa"] * num_valid_pixels + + # rec_constraint = ma_rec_loss - reconstruction_threshold + # lagmul = self._lagmul(rec_constraint) + # loss = lagmul * rec_constraint + kl_sum + + # summaries["geco_loss"] = loss + # summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels + # summaries["num_valid_pixels"] = num_valid_pixels + # summaries["lagmul"] = lagmul + else: + raise NotImplementedError( + "Loss type {} not implemeted!".format(self._loss_kwargs["type"]) + ) + + return dict(supervised_loss=loss, summaries=summaries) diff --git a/platipy/imaging/cnn/test_hpunet.py b/platipy/imaging/cnn/test_hpunet.py new file mode 100644 index 00000000..1dca7888 --- /dev/null +++ b/platipy/imaging/cnn/test_hpunet.py @@ -0,0 +1,158 @@ +import torch + +from platipy.imaging.cnn.hierarchical_prob_unet import ( + _HierarchicalCore, + ResBlock, + HierarchicalProbabilisticUnet, +) + +base_channels = 24 +default_channels_per_block = [ + base_channels, + 2 * base_channels, + 4 * base_channels, + 8 * base_channels, + # 8 * base_channels, + # 8 * base_channels, + # 8 * base_channels, + # 8 * base_channels, +] +# default_channels_per_block = [ +# base_channels, +# 2 * base_channels, +# 4 * base_channels, +# 8 * base_channels, +# ] +channels_per_block = default_channels_per_block +down_channels_per_block = [int(i / 2) for i in default_channels_per_block] +a = _HierarchicalCore([8, 6, 2], 1, channels_per_block, down_channels_per_block) +b = ResBlock(1, 3, base_channels) +d = 
ResBlock(3, 3, base_channels * 2) +c = torch.rand([1, 1, 256, 256]) + +fg = torch.ones(c.shape) +bg = torch.zeros(c.shape) +labels = torch.cat([fg, bg], axis=1) +# a(b) +# print(a) + +hpunet = HierarchicalProbabilisticUnet() +output = hpunet.sample(c) +print(output.shape) +output = hpunet.reconstruct(c, labels) +print(output.shape) +loss = hpunet.loss(c, labels) +print(loss) + + +_NUM_CLASSES = 2 +_BATCH_SIZE = 2 +_SPATIAL_SHAPE = [32, 32] +_CHANNELS_PER_BLOCK = [5, 7, 9, 11, 13] +_IMAGE_SHAPE = [_BATCH_SIZE] + [1] + _SPATIAL_SHAPE +_BOTTLENECK_SIZE = _SPATIAL_SHAPE[0] // 2 ** (len(_CHANNELS_PER_BLOCK) - 1) +_SEGMENTATION_SHAPE = [_BATCH_SIZE] + [_NUM_CLASSES] + _SPATIAL_SHAPE +_LATENT_DIMS = [3, 2, 1] +# _INITIALIZERS = { +# "w": tf.orthogonal_initializer(gain=1.0, seed=None), +# "b": tf.truncated_normal_initializer(stddev=0.001), +# } + + +def _get_placeholders(): + """Returns placeholders for the image and segmentation.""" + img = torch.rand(_IMAGE_SHAPE) + seg = torch.rand(_SEGMENTATION_SHAPE) + return img, seg + + +def test_shape_of_sample(): + hpu_net = HierarchicalProbabilisticUnet( + latent_dims=_LATENT_DIMS, + channels_per_block=_CHANNELS_PER_BLOCK, + num_classes=_NUM_CLASSES, + # initializers=_INITIALIZERS, + ) + img, _ = _get_placeholders() + sample = hpu_net.sample(img) + + assert list(sample.shape) == _SEGMENTATION_SHAPE + + +def test_shape_of_reconstruction(): + hpu_net = HierarchicalProbabilisticUnet( + latent_dims=_LATENT_DIMS, + channels_per_block=_CHANNELS_PER_BLOCK, + num_classes=_NUM_CLASSES, + # initializers=_INITIALIZERS, + ) + img, seg = _get_placeholders() + reconstruction = hpu_net.reconstruct(img, seg) + assert list(reconstruction.shape) == _SEGMENTATION_SHAPE + + +def test_shapes_in_prior(): + hpu_net = HierarchicalProbabilisticUnet( + latent_dims=_LATENT_DIMS, + channels_per_block=_CHANNELS_PER_BLOCK, + num_classes=_NUM_CLASSES, + # initializers=_INITIALIZERS, + ) + img, _ = _get_placeholders() + prior_out = hpu_net._prior(img) + 
distributions = prior_out["distributions"] + latents = prior_out["used_latents"] + encoder_features = prior_out["encoder_features"] + decoder_features = prior_out["decoder_features"] + + # Test number of latent disctributions. + assert len(distributions) == len(_LATENT_DIMS) + + # Test shapes of latent scales. + for level in range(len(_LATENT_DIMS)): + latent_spatial_shape = _BOTTLENECK_SIZE * 2 ** level + latent_shape = [ + _BATCH_SIZE, + _LATENT_DIMS[level], + latent_spatial_shape, + latent_spatial_shape, + ] + assert list(latents[level].shape) == latent_shape + + # Test encoder shapes. + for level in range(len(_CHANNELS_PER_BLOCK)): + spatial_shape = _SPATIAL_SHAPE[0] // 2 ** level + feature_shape = [_BATCH_SIZE, _CHANNELS_PER_BLOCK[level], spatial_shape, spatial_shape] + + assert list(encoder_features[level].shape) == feature_shape + + # Test decoder shape. + start_level = len(_LATENT_DIMS) + latent_spatial_shape = _BOTTLENECK_SIZE * 2 ** start_level + latent_shape = [ + _BATCH_SIZE, + _CHANNELS_PER_BLOCK[::-1][start_level], + latent_spatial_shape, + latent_spatial_shape, + ] + + assert list(decoder_features.shape) == latent_shape + + +def test_shape_of_kl(): + hpu_net = HierarchicalProbabilisticUnet( + latent_dims=_LATENT_DIMS, + channels_per_block=_CHANNELS_PER_BLOCK, + num_classes=_NUM_CLASSES, + # initializers=_INITIALIZERS, + ) + img, seg = _get_placeholders() + kl_dict = hpu_net.kl(img, seg) + assert len(kl_dict) == len(_LATENT_DIMS) + + +test_shape_of_sample() +test_shape_of_reconstruction() +test_shapes_in_prior() +test_shape_of_kl() +# if __name__ == "__main__": From 26dc438308d07659b5f66923577f85c6cab3c56a Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 26 Mar 2021 21:41:34 +0000 Subject: [PATCH 010/264] Put mask on same device as other tensors --- platipy/imaging/cnn/hierarchical_prob_unet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 
6df6eac1..43d7f7c5 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -200,6 +200,7 @@ def ce_loss(logits, labels, mask=None, top_k_percentage=None, deterministic=Fals t_flat = torch.reshape(labels, (-1, num_classes)) if mask is None: mask = torch.ones(t_flat.shape[0]) + mask = mask.to(logits.device) else: assert ( mask.shape.as_list()[:3] == labels.shape.as_list()[:3] From c28257a388cde0084864be8dbc03fdc5909d09b6 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 26 Mar 2021 21:43:21 +0000 Subject: [PATCH 011/264] More tensors to correct device --- platipy/imaging/cnn/hierarchical_prob_unet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 43d7f7c5..130bbefd 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -171,7 +171,9 @@ def _topk_mask(score, k): """ _, indices = torch.topk(score, k) - return torch.scatter_add(torch.zeros(score.shape), 0, indices, torch.ones(k)) + zeros = torch.zeros(score.shape).to(score.device) + ones = torch.ones(k).to(score.device) + return torch.scatter_add(zeros, 0, indices, ones) def ce_loss(logits, labels, mask=None, top_k_percentage=None, deterministic=False): From a49e7ff4471dbfbf279696c82da92781f97d2692 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 28 Mar 2021 21:24:39 +0000 Subject: [PATCH 012/264] Correct loss function --- platipy/imaging/cnn/hierarchical_prob_unet.py | 7 +++---- platipy/imaging/cnn/test_hpunet.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 130bbefd..0d044b8e 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -141,8 +141,7 @@ def softmax_cross_entropy_with_logits(target, 
logits): torch.Tensor: Tensor containing the softmax cross entropy loss """ - loss = torch.sum(-target * torch.nn.functional.log_softmax(logits, -1), -1) - return loss.mean() + return torch.sum(-target * torch.nn.functional.log_softmax(logits, -1), -1) def _sample_gumbel(shape): @@ -238,8 +237,8 @@ def ce_loss(logits, labels, mask=None, top_k_percentage=None, deterministic=Fals # Calculate batch-averages for the sum and mean of the loss batch_size = labels.shape[0] - xe = torch.reshape(xe, (batch_size, -1)) - mask = torch.reshape(mask, (batch_size, -1)) + xe = torch.reshape(xe, (batch_size, int(xe.numel() / batch_size))) + mask = torch.reshape(mask, (batch_size, int(mask.numel() / batch_size))) ce_sum_per_instance = torch.sum(mask * xe, 1) ce_sum = torch.mean(ce_sum_per_instance, 0) ce_mean = torch.sum(mask * xe) / torch.sum(mask) diff --git a/platipy/imaging/cnn/test_hpunet.py b/platipy/imaging/cnn/test_hpunet.py index 1dca7888..3d5f212e 100644 --- a/platipy/imaging/cnn/test_hpunet.py +++ b/platipy/imaging/cnn/test_hpunet.py @@ -28,7 +28,7 @@ a = _HierarchicalCore([8, 6, 2], 1, channels_per_block, down_channels_per_block) b = ResBlock(1, 3, base_channels) d = ResBlock(3, 3, base_channels * 2) -c = torch.rand([1, 1, 256, 256]) +c = torch.rand([3, 1, 256, 256]) fg = torch.ones(c.shape) bg = torch.zeros(c.shape) From 389ca09681518afb450651dea0ea32784cc3a8de Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 28 Mar 2021 21:27:06 +0000 Subject: [PATCH 013/264] Move sample gumbel to correct device --- platipy/imaging/cnn/hierarchical_prob_unet.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 0d044b8e..2322d61f 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -229,6 +229,8 @@ def ce_loss(logits, labels, mask=None, top_k_percentage=None, deterministic=Fals # given by the normalized cross-entropy 
loss values. This is done by # adding Gumbel noise to the logarithmic normalized cross-entropy loss # (followed by choosing the top-k pixels). + sg = _sample_gumbel(norm_xe.shape) + sg = sg.to(logits.device) score = norm_xe.log() + _sample_gumbel(norm_xe.shape) score = score + mask.log() From 06b2c87479951add59718c258b5ae4bb1905c69e Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 28 Mar 2021 21:28:16 +0000 Subject: [PATCH 014/264] Correct using correct sg --- platipy/imaging/cnn/hierarchical_prob_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 2322d61f..53a09485 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -231,7 +231,7 @@ def ce_loss(logits, labels, mask=None, top_k_percentage=None, deterministic=Fals # (followed by choosing the top-k pixels). sg = _sample_gumbel(norm_xe.shape) sg = sg.to(logits.device) - score = norm_xe.log() + _sample_gumbel(norm_xe.shape) + score = norm_xe.log() + sg score = score + mask.log() top_k_mask = _topk_mask(score, k_pixels) From 20ba8843a3f7c4c58758145c88757f69ada23419 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 28 Mar 2021 21:38:03 +0000 Subject: [PATCH 015/264] remove inplace from activation fn --- platipy/imaging/cnn/hierarchical_prob_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 53a09485..fd15eb48 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -48,7 +48,7 @@ def __init__( super(ResBlock, self).__init__() - self._activation_fn = activation_fn(inplace=True) + self._activation_fn = activation_fn() # Set the number of intermediate channels that we compress to. 
if n_down_channels is None: @@ -64,7 +64,7 @@ def __init__( ) if c < convs_per_block - 1: - layers.append(activation_fn(inplace=True)) + layers.append(activation_fn()) in_channels = n_down_channels From b0ecab1a5f5c824d59c9b757f34b12303b1a07a5 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 28 Mar 2021 21:41:45 +0000 Subject: [PATCH 016/264] Cache inputs --- platipy/imaging/cnn/hierarchical_prob_unet.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index fd15eb48..a6e126f9 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -666,19 +666,18 @@ def forward(self, img, seg): seg (torch.Tensor): A tensor of shape (b, num_classes, h, w). """ - inputs = (img, seg) + input_tensor = torch.cat([img, seg], axis=1) - if self._cache == inputs: + if self._cache == input_tensor: # No need to recompute return - input_tensor = torch.cat([img, seg], axis=1) self._q_sample = self._posterior(input_tensor, mean=False) self._q_sample_mean = self._posterior(input_tensor, mean=True) self._p_sample = self._prior(img, mean=False, z_q=None) self._p_sample_z_q = self._prior(img, z_q=self._q_sample["used_latents"]) self._p_sample_z_q_mean = self._prior(img, z_q=self._q_sample_mean["used_latents"]) - self._cache = inputs + self._cache = input_tensor def sample(self, img, mean=False, z_q=None): """Sample a segmentation from the prior, given an input image. 
From 781bd6a88d95516b3887a3958306bda939e9085a Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 28 Mar 2021 21:52:02 +0000 Subject: [PATCH 017/264] Correct cache --- platipy/imaging/cnn/hierarchical_prob_unet.py | 2 +- platipy/imaging/cnn/test_hpunet.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index a6e126f9..c9776f68 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -668,7 +668,7 @@ def forward(self, img, seg): input_tensor = torch.cat([img, seg], axis=1) - if self._cache == input_tensor: + if not self._cache is None and torch.equal(self._cache, input_tensor): # No need to recompute return diff --git a/platipy/imaging/cnn/test_hpunet.py b/platipy/imaging/cnn/test_hpunet.py index 3d5f212e..0cc61ab9 100644 --- a/platipy/imaging/cnn/test_hpunet.py +++ b/platipy/imaging/cnn/test_hpunet.py @@ -42,6 +42,7 @@ output = hpunet.reconstruct(c, labels) print(output.shape) loss = hpunet.loss(c, labels) +loss = hpunet.loss(c, labels) print(loss) From c584130fcf6a4c5a131a84a395efe4c0223edb49 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 28 Mar 2021 22:39:51 +0000 Subject: [PATCH 018/264] Fix a wrong indentation --- platipy/imaging/cnn/hierarchical_prob_unet.py | 56 +++++++++---------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index c9776f68..b995732c 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -358,12 +358,6 @@ def __init__( def forward(self, inputs, mean=False, z_q=None): """Forward pass to sample from the module as specified. - Args: - inputs: - mean: - z_q: - Returns: - Args: inputs (torch.Tensor): A tensor of shape (b,c,h,w). 
When using the module as a prior the `inputs` tensor should be a batch of images. When using it @@ -805,30 +799,30 @@ def loss(self, img, seg, mask=None): for level, kl in kl_dict.items(): summaries["kl_{}".format(level)] = kl - # Set up a regular ELBO objective. - if self._loss_kwargs["type"] == "elbo": - loss = rec_loss["sum"] + self._loss_kwargs["beta"] * kl_sum - summaries["elbo_loss"] = loss - - # TODO Still need to implement geco - # Set up a GECO objective (ELBO with a reconstruction constraint). - # elif self._loss_kwargs["type"] == "geco": - # ma_rec_loss = self._moving_average(rec_loss["sum"]) - # mask_sum_per_instance = torch.sum(rec_loss["mask"], -1) - # num_valid_pixels = torch.mean(mask_sum_per_instance) - # reconstruction_threshold = self._loss_kwargs["kappa"] * num_valid_pixels - - # rec_constraint = ma_rec_loss - reconstruction_threshold - # lagmul = self._lagmul(rec_constraint) - # loss = lagmul * rec_constraint + kl_sum - - # summaries["geco_loss"] = loss - # summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels - # summaries["num_valid_pixels"] = num_valid_pixels - # summaries["lagmul"] = lagmul - else: - raise NotImplementedError( - "Loss type {} not implemeted!".format(self._loss_kwargs["type"]) - ) + # Set up a regular ELBO objective. + if self._loss_kwargs["type"] == "elbo": + loss = rec_loss["sum"] + self._loss_kwargs["beta"] * kl_sum + summaries["elbo_loss"] = loss + + # TODO Still need to implement geco + # Set up a GECO objective (ELBO with a reconstruction constraint). 
+ # elif self._loss_kwargs["type"] == "geco": + # ma_rec_loss = self._moving_average(rec_loss["sum"]) + # mask_sum_per_instance = torch.sum(rec_loss["mask"], -1) + # num_valid_pixels = torch.mean(mask_sum_per_instance) + # reconstruction_threshold = self._loss_kwargs["kappa"] * num_valid_pixels + + # rec_constraint = ma_rec_loss - reconstruction_threshold + # lagmul = self._lagmul(rec_constraint) + # loss = lagmul * rec_constraint + kl_sum + + # summaries["geco_loss"] = loss + # summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels + # summaries["num_valid_pixels"] = num_valid_pixels + # summaries["lagmul"] = lagmul + else: + raise NotImplementedError( + "Loss type {} not implemeted!".format(self._loss_kwargs["type"]) + ) return dict(supervised_loss=loss, summaries=summaries) From 32095e5c912050bc94906d042dbc1d1e41eab167 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 30 Mar 2021 08:21:39 +0000 Subject: [PATCH 019/264] Update hprobnet loss function --- platipy/imaging/cnn/hierarchical_prob_unet.py | 108 ++++++++++-------- platipy/imaging/cnn/test_hpunet.py | 2 +- 2 files changed, 63 insertions(+), 47 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index b995732c..af183609 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -197,53 +197,61 @@ def ce_loss(logits, labels, mask=None, top_k_percentage=None, deterministic=Fals num_classes = logits.shape[1] - y_flat = torch.reshape(logits, (-1, num_classes)) - t_flat = torch.reshape(labels, (-1, num_classes)) - if mask is None: - mask = torch.ones(t_flat.shape[0]) - mask = mask.to(logits.device) - else: - assert ( - mask.shape.as_list()[:3] == labels.shape.as_list()[:3] - ), "The loss mask shape differs from the target shape: {} vs. 
{}.".format( - mask.shape.as_list(), labels.shape.as_list()[:3] - ) - mask = torch.reshape(mask, (-1,), name="reshape_mask") - - n_pixels_in_batch = y_flat.shape[0] - xe = softmax_cross_entropy_with_logits(t_flat, y_flat) - - if top_k_percentage is not None: - assert 0.0 < top_k_percentage <= 1.0 - k_pixels = math.floor(n_pixels_in_batch * top_k_percentage) - - # stopgrad_xe = tf.stop_gradient(xe) - norm_xe = xe / xe.sum() - - if deterministic: - - score = norm_xe.log() - else: - # Use the Gumbel trick to sample the top-k pixels, equivalent to sampling - # from a categorical distribution over pixels whose probabilities are - # given by the normalized cross-entropy loss values. This is done by - # adding Gumbel noise to the logarithmic normalized cross-entropy loss - # (followed by choosing the top-k pixels). - sg = _sample_gumbel(norm_xe.shape) - sg = sg.to(logits.device) - score = norm_xe.log() + sg - - score = score + mask.log() - top_k_mask = _topk_mask(score, k_pixels) - mask = mask * top_k_mask + # y_flat = torch.reshape(logits, (-1, num_classes)) + # t_flat = torch.reshape(labels, (-1, num_classes)) + + # print(y_flat.shape) + # print(t_flat.shape) + # if mask is None: + # mask = torch.ones(t_flat.shape[0]) + # mask = mask.to(logits.device) + # else: + # assert ( + # mask.shape.as_list()[:3] == labels.shape.as_list()[:3] + # ), "The loss mask shape differs from the target shape: {} vs. 
{}.".format( + # mask.shape.as_list(), labels.shape.as_list()[:3] + # ) + # mask = torch.reshape(mask, (-1,)) + + # n_pixels_in_batch = y_flat.shape[0] + criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") + xe = criterion(input=logits, target=labels) + + print(xe) + # xe = softmax_cross_entropy_with_logits(t_flat, y_flat) + + # if top_k_percentage is not None: + # assert 0.0 < top_k_percentage <= 1.0 + # k_pixels = math.floor(n_pixels_in_batch * top_k_percentage) + + # # stopgrad_xe = tf.stop_gradient(xe) + # norm_xe = xe / xe.sum() + + # if deterministic: + + # score = norm_xe.log() + # else: + # # Use the Gumbel trick to sample the top-k pixels, equivalent to sampling + # # from a categorical distribution over pixels whose probabilities are + # # given by the normalized cross-entropy loss values. This is done by + # # adding Gumbel noise to the logarithmic normalized cross-entropy loss + # # (followed by choosing the top-k pixels). + # sg = _sample_gumbel(norm_xe.shape) + # sg = sg.to(logits.device) + # score = norm_xe.log() + sg + + # score = score + mask.log() + # top_k_mask = _topk_mask(score, k_pixels) + # mask = mask * top_k_mask # Calculate batch-averages for the sum and mean of the loss - batch_size = labels.shape[0] - xe = torch.reshape(xe, (batch_size, int(xe.numel() / batch_size))) - mask = torch.reshape(mask, (batch_size, int(mask.numel() / batch_size))) - ce_sum_per_instance = torch.sum(mask * xe, 1) - ce_sum = torch.mean(ce_sum_per_instance, 0) - ce_mean = torch.sum(mask * xe) / torch.sum(mask) + # batch_size = labels.shape[0] + # xe = torch.reshape(xe, (batch_size, int(xe.numel() / batch_size))) + # mask = torch.reshape(mask, (batch_size, int(mask.numel() / batch_size))) + # ce_sum_per_instance = torch.sum(mask * xe, 1) + # ce_sum = torch.mean(ce_sum_per_instance, 0) + # ce_mean = torch.sum(mask * xe) / torch.sum(mask) + xe = torch.sum(xe) return {"mean": ce_mean, "sum": ce_sum, "mask": mask} @@ -767,7 +775,15 @@ def rec_loss(self, img, seg, 
mask=None, top_k_percentage=None, deterministic=Tru batch as well as the employed loss mask. """ reconstruction = self.reconstruct(img, seg, mean=False) - return ce_loss(reconstruction, seg, mask, top_k_percentage, deterministic) + + criterion = torch.nn.BCEWithLogitsLoss(reduction="none") + reconstruction_loss = criterion(input=reconstruction, target=seg) + reconstruction_loss_sum = torch.sum(reconstruction_loss) + reconstruction_loss_mean = torch.mean(reconstruction_loss) + + return {"mean": reconstruction_loss_mean, "sum": reconstruction_loss_sum} + + # return ce_loss(reconstruction, seg, mask, top_k_percentage, deterministic) def loss(self, img, seg, mask=None): """The full training objective, either ELBO or GECO. diff --git a/platipy/imaging/cnn/test_hpunet.py b/platipy/imaging/cnn/test_hpunet.py index 0cc61ab9..1de62ceb 100644 --- a/platipy/imaging/cnn/test_hpunet.py +++ b/platipy/imaging/cnn/test_hpunet.py @@ -28,7 +28,7 @@ a = _HierarchicalCore([8, 6, 2], 1, channels_per_block, down_channels_per_block) b = ResBlock(1, 3, base_channels) d = ResBlock(3, 3, base_channels * 2) -c = torch.rand([3, 1, 256, 256]) +c = torch.rand([3, 1, 128, 128]) fg = torch.ones(c.shape) bg = torch.zeros(c.shape) From c9984bf64636d99e2fba3b889b826eb48f35c1a5 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 30 Mar 2021 21:42:47 +0000 Subject: [PATCH 020/264] Cleaning up hprob unet code --- platipy/imaging/cnn/hierarchical_prob_unet.py | 157 +----------------- 1 file changed, 5 insertions(+), 152 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index af183609..d20acaa7 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -18,7 +18,6 @@ # pylint: disable=invalid-name -import math import torch @@ -127,135 +126,6 @@ def resize_down(input_features, scale=2): return torch.nn.AvgPool2d(kernel_size=scale, stride=scale, padding=0)(input_features) -def 
softmax_cross_entropy_with_logits(target, logits): - """Computes the softmax_cross_entropy_with_logits to replicate the equivalent function in - tensorflow. - - From https://gist.github.com/tejaskhot/cf3d087ce4708c422e68b3b747494b9f - - Args: - target (torch.Tensor): The target tensor - logits (torch.Tensor): Log probabilities - - Returns: - torch.Tensor: Tensor containing the softmax cross entropy loss - """ - - return torch.sum(-target * torch.nn.functional.log_softmax(logits, -1), -1) - - -def _sample_gumbel(shape): - """Transforms a uniform random variable to be standard Gumbel distributed. - - Args: - shape (tuple): The shape of the data - - Returns: - torch.Tensor: Standard Gumbel distribution - """ - - eps = 1e-20 - return -torch.log(-torch.log(torch.rand(shape) + eps) + eps) - - -def _topk_mask(score, k): - """Returns a mask for the top-k elements in score. - - Args: - score (torch.Tensor): The tensor of score values - k (float): The value of k (0-1) - - Returns: - torch.Tensor: The mask - """ - - _, indices = torch.topk(score, k) - zeros = torch.zeros(score.shape).to(score.device) - ones = torch.ones(k).to(score.device) - return torch.scatter_add(zeros, 0, indices, ones) - - -def ce_loss(logits, labels, mask=None, top_k_percentage=None, deterministic=False): - """Computes the cross-entropy loss. - Optionally a mask and a top-k percentage for the used pixels can be specified. - The top-k mask can be produced deterministically or sampled. - - Args: - logits (torch.Tensor): A tensor of shape (b,num_classes,h,w) - labels (torch.Tensor): A tensor of shape (b,num_classes,h,w) - mask (torch.Tensor, optional): None or a tensor of shape (b,h,w). Defaults to None. - top_k_percentage (float, optional): None or a float in (0.,1.]. If None, a standard - cross-entropy loss is calculated. Defaults to None. - deterministic (bool, optional): A Boolean indicating whether or not to produce the - prospective top-k mask deterministically. Defaults to - False. 
- - Returns: - dict: A dictionary holding the mean and the pixelwise sum of the loss for the - batch as well as the employed loss mask. - """ - - num_classes = logits.shape[1] - - # y_flat = torch.reshape(logits, (-1, num_classes)) - # t_flat = torch.reshape(labels, (-1, num_classes)) - - # print(y_flat.shape) - # print(t_flat.shape) - # if mask is None: - # mask = torch.ones(t_flat.shape[0]) - # mask = mask.to(logits.device) - # else: - # assert ( - # mask.shape.as_list()[:3] == labels.shape.as_list()[:3] - # ), "The loss mask shape differs from the target shape: {} vs. {}.".format( - # mask.shape.as_list(), labels.shape.as_list()[:3] - # ) - # mask = torch.reshape(mask, (-1,)) - - # n_pixels_in_batch = y_flat.shape[0] - criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") - xe = criterion(input=logits, target=labels) - - print(xe) - # xe = softmax_cross_entropy_with_logits(t_flat, y_flat) - - # if top_k_percentage is not None: - # assert 0.0 < top_k_percentage <= 1.0 - # k_pixels = math.floor(n_pixels_in_batch * top_k_percentage) - - # # stopgrad_xe = tf.stop_gradient(xe) - # norm_xe = xe / xe.sum() - - # if deterministic: - - # score = norm_xe.log() - # else: - # # Use the Gumbel trick to sample the top-k pixels, equivalent to sampling - # # from a categorical distribution over pixels whose probabilities are - # # given by the normalized cross-entropy loss values. This is done by - # # adding Gumbel noise to the logarithmic normalized cross-entropy loss - # # (followed by choosing the top-k pixels). 
- # sg = _sample_gumbel(norm_xe.shape) - # sg = sg.to(logits.device) - # score = norm_xe.log() + sg - - # score = score + mask.log() - # top_k_mask = _topk_mask(score, k_pixels) - # mask = mask * top_k_mask - - # Calculate batch-averages for the sum and mean of the loss - # batch_size = labels.shape[0] - # xe = torch.reshape(xe, (batch_size, int(xe.numel() / batch_size))) - # mask = torch.reshape(mask, (batch_size, int(mask.numel() / batch_size))) - # ce_sum_per_instance = torch.sum(mask * xe, 1) - # ce_sum = torch.mean(ce_sum_per_instance, 0) - # ce_mean = torch.sum(mask * xe) / torch.sum(mask) - xe = torch.sum(xe) - - return {"mean": ce_mean, "sum": ce_sum, "mask": mask} - - class _HierarchicalCore(torch.nn.Module): """A U-Net encoder-decoder with a full encoder and a truncated decoder. The truncated decoder is interleaved with the hierarchical latent space and @@ -634,8 +504,6 @@ def __init__( if loss_kwargs is None: self._loss_kwargs = { "type": "elbo", - "top_k_percentage": 0.02, - "deterministic_top_k": False, "kappa": 0.05, "decay": 0.99, "rate": 1e-2, @@ -756,23 +624,15 @@ def kl(self, img, seg): return kl - def rec_loss(self, img, seg, mask=None, top_k_percentage=None, deterministic=True): + def rec_loss(self, img, seg): """Cross-entropy reconstruction loss employed in the ELBO-/ GECO-objective. Args: img (torch.Tensor): A tensor of shape (b, c, h, w). seg (torch.Tensor): A tensor of shape (b, num_classes, h, w). - mask (torch.Tensor, optional): A mask of shape (b, h, w) or None. If None no pixels are - masked in the loss. Defaults to None. - top_k_percentage (float, optional): None or a float in (0.,1.]. If None, a standard - cross-entropy loss is calculated. Defaults to None. - deterministic (bool, optional): A Boolean indicating whether or not to produce the - prospective top-k mask deterministically. Defaults to - True. Returns: - dict: A dictionary holding the mean and the pixelwise sum of the loss for the - batch as well as the employed loss mask. 
+ dict: A dictionary holding the mean and the pixelwise sum of the loss """ reconstruction = self.reconstruct(img, seg, mean=False) @@ -783,28 +643,21 @@ def rec_loss(self, img, seg, mask=None, top_k_percentage=None, deterministic=Tru return {"mean": reconstruction_loss_mean, "sum": reconstruction_loss_sum} - # return ce_loss(reconstruction, seg, mask, top_k_percentage, deterministic) - - def loss(self, img, seg, mask=None): + def loss(self, img, seg): """The full training objective, either ELBO or GECO. Args: img (torch.Tensor): A tensor of shape (b, c, h, w). seg (torch.Tensor): A tensor of shape (b, num_classes, h, w). - mask (torch.Tensor, optional): A mask of shape (b, h, w) or None. If None no pixels are - masked in the loss. Defaults to None. Raises: NotImplementedError: Raised if loss function supplied isn't implemented yet. Returns: - dict: A dictionary holding the loss (with key 'loss') and the tensorboard summaries - (with key 'summaries'). + dict: A dictionary holding the loss (with key 'loss') """ summaries = {} - top_k_percentage = self._loss_kwargs["top_k_percentage"] - deterministic = self._loss_kwargs["deterministic_top_k"] - rec_loss = self.rec_loss(img, seg, mask, top_k_percentage, deterministic) + rec_loss = self.rec_loss(img, seg) kl_dict = self.kl(img, seg) kl_sum = torch.sum(torch.stack([kl for _, kl in kl_dict.items()], axis=-1)) From 536192b2ae5a22216e27c659ff7ed2103f479adf Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 31 Mar 2021 07:23:01 +0000 Subject: [PATCH 021/264] Work on extending hprob net to 3d --- platipy/imaging/cnn/hierarchical_prob_unet.py | 92 +++++++++++++++---- platipy/imaging/cnn/test_hpunet.py | 20 ++-- 2 files changed, 86 insertions(+), 26 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index d20acaa7..53b0fc5d 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -21,6 +21,16 @@ 
import torch +def conv_nd(ndims=2, **kwargs): + + if ndims == 2: + return torch.nn.Conv2d(**kwargs) + elif ndims == 3: + return torch.nn.Conv3d(**kwargs) + + raise NotImplementedError("Only 2 or 3 dimensions are supported") + + class ResBlock(torch.nn.Module): """A residual block""" @@ -31,6 +41,7 @@ def __init__( n_down_channels=None, activation_fn=torch.nn.ReLU, convs_per_block=3, + ndims=2, ): """Create a residual block @@ -57,8 +68,12 @@ def __init__( in_channels = input_channels for c in range(convs_per_block): layers.append( - torch.nn.Conv2d( - in_channels=in_channels, out_channels=n_down_channels, kernel_size=3, padding=1 + conv_nd( + ndims=ndims, + in_channels=in_channels, + out_channels=n_down_channels, + kernel_size=3, + padding=1, ) ) @@ -68,8 +83,12 @@ def __init__( in_channels = n_down_channels if not n_down_channels == output_channels: - resize_outgoing = torch.nn.Conv2d( - in_channels=n_down_channels, out_channels=output_channels, kernel_size=1, padding=0 + resize_outgoing = conv_nd( + ndims=ndims, + in_channels=n_down_channels, + out_channels=output_channels, + kernel_size=1, + padding=0, ) layers.append(resize_outgoing) @@ -78,8 +97,12 @@ def __init__( self._resize_skip = None if not input_channels == output_channels: - self._resize_skip = torch.nn.Conv2d( - in_channels=input_channels, out_channels=output_channels, kernel_size=1, padding=0 + self._resize_skip = conv_nd( + ndims=ndims, + in_channels=input_channels, + out_channels=output_channels, + kernel_size=1, + padding=0, ) def forward(self, input_features): @@ -107,10 +130,18 @@ def resize_up(input_features, scale=2): Returns: torch.Tensor: The upsized Tensor """ - _, _, size_x, size_y = input_features.shape - new_size_x = int(round(size_x * scale)) - new_size_y = int(round(size_y * scale)) - return torch.nn.functional.interpolate(input_features, size=[new_size_x, new_size_y]) + + input_shape = input_features.shape + size_x = input_shape[2] + size_y = input_shape[3] + + new_size = 
[int(round(size_x * scale)), int(round(size_y * scale))] + + if len(input_shape) == 5: + size_z = input_shape[4] + new_size = new_size + [int(round(size_z * scale))] + + return torch.nn.functional.interpolate(input_features, size=new_size) def resize_down(input_features, scale=2): @@ -123,7 +154,10 @@ def resize_down(input_features, scale=2): Returns: torch.Tensor: The downsized Tensor """ - return torch.nn.AvgPool2d(kernel_size=scale, stride=scale, padding=0)(input_features) + if input_features.ndim == 5: + return torch.nn.AvgPool3d(kernel_size=scale, stride=scale, padding=0)(input_features) + else: + return torch.nn.AvgPool2d(kernel_size=scale, stride=scale, padding=0)(input_features) class _HierarchicalCore(torch.nn.Module): @@ -142,6 +176,7 @@ def __init__( activation_fn=torch.nn.ReLU, convs_per_block=3, blocks_per_level=3, + ndims=2, ): """Initializes a HierarchicalCore. @@ -195,6 +230,7 @@ def __init__( n_down_channels=self._down_channels_per_block[level], activation_fn=self._activation_fn, convs_per_block=self._convs_per_block, + ndims=ndims, ) ) in_channels = channels_per_block[level] @@ -209,8 +245,12 @@ def __init__( latent_dim = latent_dims[level] - mu_logsigma_block = torch.nn.Conv2d( - channels_per_block[::-1][level], 2 * latent_dim, kernel_size=1, padding=0 + mu_logsigma_block = conv_nd( + ndims=ndims, + in_channels=channels_per_block[::-1][level], + out_channels=2 * latent_dim, + kernel_size=1, + padding=0, ) self._mu_logsigma_blocks.append(mu_logsigma_block) @@ -227,10 +267,11 @@ def __init__( n_down_channels=self._down_channels_per_block[::-1][level + 1], activation_fn=self._activation_fn, convs_per_block=self._convs_per_block, + ndims=ndims, ) ) decoder_in_channels = channels_per_block[::-1][level + 1] - + print(channels_per_block[::-1][level + 1]) self.decoder_layers.append(torch.nn.Sequential(*layer)) def forward(self, inputs, mean=False, z_q=None): @@ -335,6 +376,7 @@ def __init__( activation_fn=torch.nn.ReLU, convs_per_block=3, 
blocks_per_level=3, + ndims=2, ): """Initializes a StichtingDecoder. @@ -375,11 +417,13 @@ def __init__( self._num_levels = len(self._channels_per_block) self.decoder_layers = torch.nn.ModuleList() + decoder_in_channels = None for level in range(self._start_level, self._num_levels, 1): decoder_in_channels = ( channels_per_block[::-1][level - 1] + channels_per_block[::-1][level] ) + layer = [] for _ in range(self._blocks_per_level): layer.append( @@ -389,15 +433,21 @@ def __init__( n_down_channels=self._down_channels_per_block[::-1][level], activation_fn=self._activation_fn, convs_per_block=self._convs_per_block, + ndims=ndims, ) ) - decoder_in_channels = channels_per_block[::-1][level] self.decoder_layers.append(torch.nn.Sequential(*layer)) - self.final_layer = torch.nn.Conv2d( - decoder_in_channels, self._num_classes, kernel_size=1, padding=0 - ) + decoder_in_channels = channels_per_block[::-1][self._num_levels - 1] + + self.final_layer = conv_nd( + ndims=ndims, + in_channels=decoder_in_channels, + out_channels=self._num_classes, + kernel_size=1, + padding=0, + ) def forward(self, encoder_features, decoder_features): """Forward pass through the stiching decoder @@ -434,6 +484,7 @@ def __init__( convs_per_block=3, blocks_per_level=3, loss_kwargs=None, + ndims=2, ): """Initialize the Hierarchical Probabilistic UNet @@ -470,7 +521,7 @@ def __init__( if channels_per_block is None: channels_per_block = default_channels_per_block if down_channels_per_block is None: - down_channels_per_block = [int(i / 2) for i in default_channels_per_block] + down_channels_per_block = [int(i / 2) for i in channels_per_block] self._prior = _HierarchicalCore( input_channels=input_channels, @@ -479,6 +530,7 @@ def __init__( down_channels_per_block=down_channels_per_block, convs_per_block=convs_per_block, blocks_per_level=blocks_per_level, + ndims=ndims, ) self._posterior = _HierarchicalCore( @@ -488,6 +540,7 @@ def __init__( down_channels_per_block=down_channels_per_block, 
convs_per_block=convs_per_block, blocks_per_level=blocks_per_level, + ndims=ndims, ) self._f_comb = _StitchingDecoder( @@ -497,6 +550,7 @@ def __init__( down_channels_per_block=down_channels_per_block, convs_per_block=convs_per_block, blocks_per_level=blocks_per_level, + ndims=ndims, ) self._cache = None diff --git a/platipy/imaging/cnn/test_hpunet.py b/platipy/imaging/cnn/test_hpunet.py index 1de62ceb..3a58b113 100644 --- a/platipy/imaging/cnn/test_hpunet.py +++ b/platipy/imaging/cnn/test_hpunet.py @@ -10,8 +10,8 @@ default_channels_per_block = [ base_channels, 2 * base_channels, - 4 * base_channels, - 8 * base_channels, + # 4 * base_channels, + # 8 * base_channels, # 8 * base_channels, # 8 * base_channels, # 8 * base_channels, @@ -23,12 +23,16 @@ # 4 * base_channels, # 8 * base_channels, # ] + +latent_dims = [8, 6, 2] +latent_dims = [2] + channels_per_block = default_channels_per_block down_channels_per_block = [int(i / 2) for i in default_channels_per_block] -a = _HierarchicalCore([8, 6, 2], 1, channels_per_block, down_channels_per_block) -b = ResBlock(1, 3, base_channels) -d = ResBlock(3, 3, base_channels * 2) -c = torch.rand([3, 1, 128, 128]) +# a = _HierarchicalCore(latent_dims, 1, channels_per_block, down_channels_per_block) +# b = ResBlock(1, 3, base_channels, ndims=3) +# d = ResBlock(3, 3, base_channels * 2, ndims=3) +c = torch.rand([3, 1, 32, 32]) fg = torch.ones(c.shape) bg = torch.zeros(c.shape) @@ -36,7 +40,9 @@ # a(b) # print(a) -hpunet = HierarchicalProbabilisticUnet() +hpunet = HierarchicalProbabilisticUnet( + ndims=2, channels_per_block=channels_per_block, latent_dims=[1] +) output = hpunet.sample(c) print(output.shape) output = hpunet.reconstruct(c, labels) From 773f75c3a38c0578fb986ef9ee42339ee615ef8f Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 31 Mar 2021 20:37:56 +0000 Subject: [PATCH 022/264] Fixes for 3d hprob unet --- platipy/imaging/cnn/hierarchical_prob_unet.py | 21 +++++++++++++++++-- platipy/imaging/cnn/test_hpunet.py | 17 
+-------------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 53b0fc5d..eaf7a448 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -22,6 +22,17 @@ def conv_nd(ndims=2, **kwargs): + """Generate a 2D or 3D convolution + + Args: + ndims (int, optional): 2 or 3 dimensions. Defaults to 2. + + Raises: + NotImplementedError: Raised if ndims is not in 2 or 3 dimensions. + + Returns: + torch.nn.Conv: The convolution. + """ if ndims == 2: return torch.nn.Conv2d(**kwargs) @@ -54,6 +65,7 @@ def __init__( to torch.nn.ReLU. convs_per_block (int, optional): The number of convolutions to perform within the block. Defaults to 3. + ndims (int, optional): Specify whether to use 2 or 3 dimensions. Defaults to 2. """ super(ResBlock, self).__init__() @@ -198,6 +210,7 @@ def __init__( layers. Defaults to 3. blocks_per_level (int, optional): An integer specifying the number of residual blocks per level. Defaults to 3. + ndims (int, optional): Specify whether to use 2 or 3 dimensions. Defaults to 2. """ super(_HierarchicalCore, self).__init__() @@ -271,7 +284,7 @@ def __init__( ) ) decoder_in_channels = channels_per_block[::-1][level + 1] - print(channels_per_block[::-1][level + 1]) + self.decoder_layers.append(torch.nn.Sequential(*layer)) def forward(self, inputs, mean=False, z_q=None): @@ -400,6 +413,7 @@ def __init__( layers. Defaults to 3. blocks_per_level (int, optional): An integer specifying the number of residual blocks per level. Defaults to 3. + ndims (int, optional): Specify whether to use 2 or 3 dimensions. Defaults to 2. 
""" super(_StitchingDecoder, self).__init__() self._latent_dims = latent_dims @@ -436,10 +450,12 @@ def __init__( ndims=ndims, ) ) + decoder_in_channels = channels_per_block[::-1][level] self.decoder_layers.append(torch.nn.Sequential(*layer)) - decoder_in_channels = channels_per_block[::-1][self._num_levels - 1] + if decoder_in_channels is None: + decoder_in_channels = channels_per_block[::-1][self._num_levels - 1] self.final_layer = conv_nd( ndims=ndims, @@ -504,6 +520,7 @@ def __init__( blocks per level. Defaults to 3. loss_kwargs (dict, optional): Dictionary of argument used by loss function. Defaults to None. + ndims (int, optional): Specify whether to use 2 or 3 dimensions. Defaults to 2. """ super(HierarchicalProbabilisticUnet, self).__init__() diff --git a/platipy/imaging/cnn/test_hpunet.py b/platipy/imaging/cnn/test_hpunet.py index 3a58b113..b2fe021e 100644 --- a/platipy/imaging/cnn/test_hpunet.py +++ b/platipy/imaging/cnn/test_hpunet.py @@ -29,20 +29,13 @@ channels_per_block = default_channels_per_block down_channels_per_block = [int(i / 2) for i in default_channels_per_block] -# a = _HierarchicalCore(latent_dims, 1, channels_per_block, down_channels_per_block) -# b = ResBlock(1, 3, base_channels, ndims=3) -# d = ResBlock(3, 3, base_channels * 2, ndims=3) c = torch.rand([3, 1, 32, 32]) fg = torch.ones(c.shape) bg = torch.zeros(c.shape) labels = torch.cat([fg, bg], axis=1) -# a(b) -# print(a) -hpunet = HierarchicalProbabilisticUnet( - ndims=2, channels_per_block=channels_per_block, latent_dims=[1] -) +hpunet = HierarchicalProbabilisticUnet(channels_per_block=channels_per_block, latent_dims=[1]) output = hpunet.sample(c) print(output.shape) output = hpunet.reconstruct(c, labels) @@ -60,10 +53,6 @@ _BOTTLENECK_SIZE = _SPATIAL_SHAPE[0] // 2 ** (len(_CHANNELS_PER_BLOCK) - 1) _SEGMENTATION_SHAPE = [_BATCH_SIZE] + [_NUM_CLASSES] + _SPATIAL_SHAPE _LATENT_DIMS = [3, 2, 1] -# _INITIALIZERS = { -# "w": tf.orthogonal_initializer(gain=1.0, seed=None), -# "b": 
tf.truncated_normal_initializer(stddev=0.001), -# } def _get_placeholders(): @@ -78,7 +67,6 @@ def test_shape_of_sample(): latent_dims=_LATENT_DIMS, channels_per_block=_CHANNELS_PER_BLOCK, num_classes=_NUM_CLASSES, - # initializers=_INITIALIZERS, ) img, _ = _get_placeholders() sample = hpu_net.sample(img) @@ -91,7 +79,6 @@ def test_shape_of_reconstruction(): latent_dims=_LATENT_DIMS, channels_per_block=_CHANNELS_PER_BLOCK, num_classes=_NUM_CLASSES, - # initializers=_INITIALIZERS, ) img, seg = _get_placeholders() reconstruction = hpu_net.reconstruct(img, seg) @@ -103,7 +90,6 @@ def test_shapes_in_prior(): latent_dims=_LATENT_DIMS, channels_per_block=_CHANNELS_PER_BLOCK, num_classes=_NUM_CLASSES, - # initializers=_INITIALIZERS, ) img, _ = _get_placeholders() prior_out = hpu_net._prior(img) @@ -151,7 +137,6 @@ def test_shape_of_kl(): latent_dims=_LATENT_DIMS, channels_per_block=_CHANNELS_PER_BLOCK, num_classes=_NUM_CLASSES, - # initializers=_INITIALIZERS, ) img, seg = _get_placeholders() kl_dict = hpu_net.kl(img, seg) From d2389adc7d0f8888d3fd96f53c50afdf774b3d6e Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 14 Apr 2021 08:15:49 +0000 Subject: [PATCH 023/264] Add init weights code --- platipy/imaging/cnn/hierarchical_prob_unet.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index eaf7a448..1ce3fd97 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -20,6 +20,25 @@ import torch +def truncated_normal_(tensor, mean=0, std=1): + size = tensor.shape + tmp = tensor.new_empty(size + (4,)).normal_() + valid = (tmp < 2) & (tmp > -2) + ind = valid.max(-1, keepdim=True)[1] + tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) + tensor.data.mul_(std).add_(mean) + + +def init_weights(m): + if ( + isinstance(m, torch.nn.Conv2d) + or isinstance(m, torch.nn.ConvTranspose2d) + or isinstance(m, 
torch.nn.Conv3d) + or isinstance(m, torch.nn.ConvTranspose3d) + ): + torch.nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu") + truncated_normal_(m.bias, mean=0, std=0.001) + def conv_nd(ndims=2, **kwargs): """Generate a 2D or 3D convolution @@ -105,6 +124,7 @@ def __init__( layers.append(resize_outgoing) self._layers = torch.nn.Sequential(*layers) + self._layers.apply(init_weights) self._resize_skip = None @@ -116,6 +136,7 @@ def __init__( kernel_size=1, padding=0, ) + self._resize_skip.apply(init_weights) def forward(self, input_features): @@ -250,6 +271,8 @@ def __init__( self.encoder_layers.append(torch.nn.Sequential(*layer)) + self.encoder_layers.apply(init_weights) + # Iterate the ascending levels in the (truncated) U-Net decoder. self.decoder_layers = torch.nn.ModuleList() self._mu_logsigma_blocks = torch.nn.ModuleList() @@ -287,6 +310,9 @@ def __init__( self.decoder_layers.append(torch.nn.Sequential(*layer)) + self._mu_logsigma_blocks.apply(init_weights) + self.decoder_layers.apply(init_weights) + def forward(self, inputs, mean=False, z_q=None): """Forward pass to sample from the module as specified. 
@@ -453,6 +479,7 @@ def __init__( decoder_in_channels = channels_per_block[::-1][level] self.decoder_layers.append(torch.nn.Sequential(*layer)) + self.decoder_layers.apply(init_weights) if decoder_in_channels is None: decoder_in_channels = channels_per_block[::-1][self._num_levels - 1] @@ -464,6 +491,7 @@ def __init__( kernel_size=1, padding=0, ) + self.final_layer.apply(init_weights) def forward(self, encoder_features, decoder_features): """Forward pass through the stiching decoder From 9f02f66c90453269614d83ed54c0415541ce6c6d Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 14 Apr 2021 21:31:09 +0000 Subject: [PATCH 024/264] Init prob dist weights to zero --- platipy/imaging/cnn/hierarchical_prob_unet.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 1ce3fd97..f57e3bde 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -20,6 +20,7 @@ import torch + def truncated_normal_(tensor, mean=0, std=1): size = tensor.shape tmp = tensor.new_empty(size + (4,)).normal_() @@ -40,6 +41,17 @@ def init_weights(m): truncated_normal_(m.bias, mean=0, std=0.001) +def init_zeros(m): + if ( + isinstance(m, torch.nn.Conv2d) + or isinstance(m, torch.nn.ConvTranspose2d) + or isinstance(m, torch.nn.Conv3d) + or isinstance(m, torch.nn.ConvTranspose3d) + ): + torch.nn.init.zeros_(m.weight) + truncated_normal_(m.bias, mean=0, std=0.1) + + def conv_nd(ndims=2, **kwargs): """Generate a 2D or 3D convolution @@ -310,7 +322,7 @@ def __init__( self.decoder_layers.append(torch.nn.Sequential(*layer)) - self._mu_logsigma_blocks.apply(init_weights) + self._mu_logsigma_blocks.apply(init_zeros) self.decoder_layers.apply(init_weights) def forward(self, inputs, mean=False, z_q=None): From 9344d7f33975a5148f8a454a67e40b16a10d65c7 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 4 May 2021 00:52:04 
+0000 Subject: [PATCH 025/264] Allow sample certain distance from mean --- platipy/imaging/cnn/hierarchical_prob_unet.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index f57e3bde..cc860b87 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -325,7 +325,7 @@ def __init__( self._mu_logsigma_blocks.apply(init_zeros) self.decoder_layers.apply(init_weights) - def forward(self, inputs, mean=False, z_q=None): + def forward(self, inputs, mean=False, std_devs_from_mean=0.0, z_q=None): """Forward pass to sample from the module as specified. Args: @@ -338,6 +338,9 @@ def forward(self, inputs, mean=False, z_q=None): latent scales. If a list, each bool therein specifies whether or not to use the scale's mean. If False, the latents of the scale are sampled. Defaults to False. + std_devs_from_mean (float|list, optional): A float or list of floats describing how far + from the mean should be sampled. Only at + scales where mean is True. Defaults to 0. if mean is True Defaults to None. z_q (list, optional): None or a list of tensors. If not None, z_q provides external latents to be used instead of sampling them. This is used to employ posterior latents in the prior during training. 
Therefore, @@ -358,8 +361,16 @@ def forward(self, inputs, mean=False, z_q=None): encoder_outputs = [] num_levels = len(self._channels_per_block) num_latent_levels = len(self._latent_dims) + if isinstance(mean, bool): mean = [mean] * self._num_latent_levels + + if isinstance(std_devs_from_mean, int): + std_devs_from_mean = float(std_devs_from_mean) + + if isinstance(std_devs_from_mean, float): + std_devs_from_mean = [std_devs_from_mean] * self._num_latent_levels + distributions = [] used_latents = [] @@ -390,7 +401,7 @@ def forward(self, inputs, mean=False, z_q=None): if z_q is not None: z = z_q[level] elif mean[level]: - z = dist.base_dist.loc + z = dist.base_dist.loc + (dist.base_dist.scale * std_devs_from_mean[level]) else: z = dist.sample() From 2358d0c7d760429d83d4659490562d6535b273b7 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 4 May 2021 00:53:43 +0000 Subject: [PATCH 026/264] Correct comment --- platipy/imaging/cnn/hierarchical_prob_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index cc860b87..ba01a158 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -340,7 +340,7 @@ def forward(self, inputs, mean=False, std_devs_from_mean=0.0, z_q=None): latents of the scale are sampled. Defaults to False. std_devs_from_mean (float|list, optional): A float or list of floats describing how far from the mean should be sampled. Only at - scales where mean is True. Defaults to 0. if mean is True Defaults to None. + scales where mean is True. Defaults to 0. z_q (list, optional): None or a list of tensors. If not None, z_q provides external latents to be used instead of sampling them. This is used to employ posterior latents in the prior during training. 
Therefore, From 9e7c02ce62132479b46986cdd7660516a6e22174 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 4 May 2021 04:36:40 +0000 Subject: [PATCH 027/264] Allow sample std devs from mean --- platipy/imaging/cnn/hierarchical_prob_unet.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index ba01a158..592eed3c 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -671,7 +671,7 @@ def forward(self, img, seg): self._p_sample_z_q_mean = self._prior(img, z_q=self._q_sample_mean["used_latents"]) self._cache = input_tensor - def sample(self, img, mean=False, z_q=None): + def sample(self, img, mean=False, std_devs_from_mean=0.0, z_q=None): """Sample a segmentation from the prior, given an input image. Args: @@ -681,6 +681,9 @@ def sample(self, img, mean=False, z_q=None): scales. If a list, each bool therein specifies whether or not to use the scale's mean. If False, the latents of the scale are sampled. Defaults to False. + std_devs_from_mean (float|list, optional): A float or list of floats describing how far + from the mean should be sampled. Only at + scales where mean is True. Defaults to 0. z_q (list, optional): If not None, z_q provides external latents to be used instead of sampling them. This is used to employ posterior latents in the prior during training. Therefore, if z_q is not None, the value @@ -692,7 +695,7 @@ def sample(self, img, mean=False, z_q=None): torch.Tensor: A segmentation tensor of shape (b, num_classes, h, w). 
""" - prior_out = self._prior(img, mean, z_q) + prior_out = self._prior(img, mean, std_devs_from_mean, z_q) encoder_features = prior_out["encoder_features"] decoder_features = prior_out["decoder_features"] return self._f_comb(encoder_features, decoder_features) From 6745ccdb2039c392208567f53414e23557e304f1 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 11 May 2021 00:42:18 +0000 Subject: [PATCH 028/264] Working on GECO --- platipy/imaging/cnn/hierarchical_prob_unet.py | 94 +++++++++++++++---- 1 file changed, 74 insertions(+), 20 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 592eed3c..90a7a5b4 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -73,6 +73,63 @@ def conv_nd(ndims=2, **kwargs): raise NotImplementedError("Only 2 or 3 dimensions are supported") +class ExponentialMovingAverage(torch.nn.Module): + """Maintains an exponential moving average for a value. + Note this module uses debiasing by default. If you don't want this please use + an alternative implementation. + This module keeps track of a hidden exponential moving average that is + initialized as a vector of zeros which is then normalized to give the average. + This gives us a moving average which isn't biased towards either zero or the + initial value. Reference (https://arxiv.org/pdf/1412.6980.pdf) + Initially: + hidden_0 = 0 + Then iteratively: + hidden_i = (hidden_{i-1} - value) * (1 - decay) + average_i = hidden_i / (1 - decay^i) + Attributes: + average: Variable holding average. Note that this is None until the first + value is passed. + """ + + def __init__(self, decay): + """Creates a debiased moving average module. + Args: + decay: The decay to use. Note values close to 1 result in a slow decay + whereas values close to 0 result in faster decay, tracking the input + values more closely. + name: Name of the module. 
+ """ + super(ExponentialMovingAverage, self).__init__() + + self._decay = decay + self._counter = torch.Tensor(0, requires_grad=False) + + self._hidden = None + self._average = None + + def forward(self, value): + """Applies EMA to the value given.""" + + # Initialise if not already + if self._hidden is None: + self._hidden = torch.Tensor(torch.zeros(value.shape), requires_grad=False) + self._average = torch.Tensor(torch.zeros(value.shape), requires_grad=False) + + # self._counter.assign_add(1) + self._counter += 1 + counter = self._counter.type(value.type()) + self._hidden -= (self._hidden - value) * (1 - self._decay) + self._average = self._hidden / (1.0 - torch.pow(self._decay, counter)) + + return self._average + + def reset(self): + """Resets the EMA.""" + self._counter = torch.zeros(self._contour.shape) + self._hidden = torch.zeros(self._hidden.shape) + self._average = torch.zeros(self._average.shape) + + class ResBlock(torch.nn.Module): """A residual block""" @@ -634,11 +691,9 @@ def __init__( else: self._loss_kwargs = loss_kwargs - # if self._loss_kwargs["type"] == "geco": - # self._moving_average = ExponentialMovingAverage( - # model=self, decay=self._loss_kwargs["decay"] - # ) - # self._lagmul = geco_utils.LagrangeMultiplier(rate=self._loss_kwargs["rate"]) + if self._loss_kwargs["type"] == "geco": + self._moving_average = ExponentialMovingAverage(decay=self._loss_kwargs["decay"]) + self._lagmul = geco_utils.LagrangeMultiplier(rate=self._loss_kwargs["rate"]) self._q_sample = None self._q_sample_mean = None @@ -798,22 +853,21 @@ def loss(self, img, seg): loss = rec_loss["sum"] + self._loss_kwargs["beta"] * kl_sum summaries["elbo_loss"] = loss - # TODO Still need to implement geco # Set up a GECO objective (ELBO with a reconstruction constraint). 
- # elif self._loss_kwargs["type"] == "geco": - # ma_rec_loss = self._moving_average(rec_loss["sum"]) - # mask_sum_per_instance = torch.sum(rec_loss["mask"], -1) - # num_valid_pixels = torch.mean(mask_sum_per_instance) - # reconstruction_threshold = self._loss_kwargs["kappa"] * num_valid_pixels - - # rec_constraint = ma_rec_loss - reconstruction_threshold - # lagmul = self._lagmul(rec_constraint) - # loss = lagmul * rec_constraint + kl_sum - - # summaries["geco_loss"] = loss - # summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels - # summaries["num_valid_pixels"] = num_valid_pixels - # summaries["lagmul"] = lagmul + elif self._loss_kwargs["type"] == "geco": + ma_rec_loss = self._moving_average(rec_loss["sum"]) + mask_sum_per_instance = torch.sum(rec_loss["mask"], -1) + num_valid_pixels = torch.mean(mask_sum_per_instance) + reconstruction_threshold = self._loss_kwargs["kappa"] * num_valid_pixels + + rec_constraint = ma_rec_loss - reconstruction_threshold + lagmul = self._lagmul(rec_constraint) + loss = lagmul * rec_constraint + kl_sum + + summaries["geco_loss"] = loss + summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels + summaries["num_valid_pixels"] = num_valid_pixels + summaries["lagmul"] = lagmul else: raise NotImplementedError( "Loss type {} not implemeted!".format(self._loss_kwargs["type"]) From a8ad3f23a66f499e21a076b55f7f2ce447ee7fb9 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 20 May 2021 01:23:28 +0000 Subject: [PATCH 029/264] Implement GECO loss function --- platipy/imaging/cnn/hierarchical_prob_unet.py | 32 ++++++++++++++++--- platipy/imaging/cnn/test_hpunet.py | 14 +++++++- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 90a7a5b4..ad3e2253 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -102,7 +102,7 @@ def __init__(self, decay): 
super(ExponentialMovingAverage, self).__init__() self._decay = decay - self._counter = torch.Tensor(0, requires_grad=False) + self._counter = torch.zeros(1, requires_grad=False) self._hidden = None self._average = None @@ -112,8 +112,8 @@ def forward(self, value): # Initialise if not already if self._hidden is None: - self._hidden = torch.Tensor(torch.zeros(value.shape), requires_grad=False) - self._average = torch.Tensor(torch.zeros(value.shape), requires_grad=False) + self._hidden = torch.zeros(value.shape, requires_grad=False) + self._average = torch.zeros(value.shape, requires_grad=False) # self._counter.assign_add(1) self._counter += 1 @@ -130,6 +130,26 @@ def reset(self): self._average = torch.zeros(self._average.shape) +class LagrangeMultiplier(torch.nn.Module): + def __init__(self, rate): + super(LagrangeMultiplier, self).__init__() + self._rate = rate + self._softplus = torch.nn.Softplus() + self._lambda_var = None + + def forward(self, value): + + if not self._lambda_var: + self._lambda_var = torch.ones(value.shape, requires_grad=True) + + lag_multiplier = self._softplus(self._lambda_var) ** 2 + lag_multiplier.retain_grad() + if lag_multiplier.grad: + lag_multiplier.grad *= self._rate + + return lag_multiplier + + class ResBlock(torch.nn.Module): """A residual block""" @@ -693,7 +713,7 @@ def __init__( if self._loss_kwargs["type"] == "geco": self._moving_average = ExponentialMovingAverage(decay=self._loss_kwargs["decay"]) - self._lagmul = geco_utils.LagrangeMultiplier(rate=self._loss_kwargs["rate"]) + self._lagmul = LagrangeMultiplier(rate=self._loss_kwargs["rate"]) self._q_sample = None self._q_sample_mean = None @@ -821,7 +841,9 @@ def rec_loss(self, img, seg): reconstruction_loss_sum = torch.sum(reconstruction_loss) reconstruction_loss_mean = torch.mean(reconstruction_loss) - return {"mean": reconstruction_loss_mean, "sum": reconstruction_loss_sum} + mask = torch.ones(torch.numel(img)) + + return {"mean": reconstruction_loss_mean, "sum": 
reconstruction_loss_sum, "mask": mask} def loss(self, img, seg): """The full training objective, either ELBO or GECO. diff --git a/platipy/imaging/cnn/test_hpunet.py b/platipy/imaging/cnn/test_hpunet.py index b2fe021e..7562b41d 100644 --- a/platipy/imaging/cnn/test_hpunet.py +++ b/platipy/imaging/cnn/test_hpunet.py @@ -35,7 +35,19 @@ bg = torch.zeros(c.shape) labels = torch.cat([fg, bg], axis=1) -hpunet = HierarchicalProbabilisticUnet(channels_per_block=channels_per_block, latent_dims=[1]) +hpunet = HierarchicalProbabilisticUnet( + channels_per_block=channels_per_block, + latent_dims=[1], + loss_kwargs={ + "type": "geco", + "top_k_percentage": 0.02, + "deterministic_top_k": False, + "kappa": 0.05, + "decay": 0.99, + "rate": 1e-2, + "beta": 5, + }, +) output = hpunet.sample(c) print(output.shape) output = hpunet.reconstruct(c, labels) From 7a54a6b6ec0e804644ca6d8be42dc4e94c8e0c25 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 20 May 2021 01:27:43 +0000 Subject: [PATCH 030/264] Update prob dist equation --- platipy/imaging/cnn/hierarchical_prob_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index ad3e2253..ec5d0707 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -478,7 +478,7 @@ def forward(self, inputs, mean=False, std_devs_from_mean=0.0, z_q=None): if z_q is not None: z = z_q[level] elif mean[level]: - z = dist.base_dist.loc + (dist.base_dist.scale * std_devs_from_mean[level]) + z = dist.mean + (dist.base_dist.stddev * std_devs_from_mean[level]) else: z = dist.sample() From bf9242a9253aaa8ab95584549509822a87048184 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 20 May 2021 01:37:04 +0000 Subject: [PATCH 031/264] Add lagrange to model as parameter --- platipy/imaging/cnn/hierarchical_prob_unet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index ec5d0707..bb0434d2 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -140,7 +140,7 @@ def __init__(self, rate): def forward(self, value): if not self._lambda_var: - self._lambda_var = torch.ones(value.shape, requires_grad=True) + self._lambda_var = torch.nn.Parameter(torch.ones(value.shape, requires_grad=True)) lag_multiplier = self._softplus(self._lambda_var) ** 2 lag_multiplier.retain_grad() @@ -713,6 +713,7 @@ def __init__( if self._loss_kwargs["type"] == "geco": self._moving_average = ExponentialMovingAverage(decay=self._loss_kwargs["decay"]) + self._moving_average.to() self._lagmul = LagrangeMultiplier(rate=self._loss_kwargs["rate"]) self._q_sample = None From 7d68fc7beed8579f2f81355fcac1f238b45ff1c9 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 20 May 2021 01:41:29 +0000 Subject: [PATCH 032/264] Add other params to model --- platipy/imaging/cnn/hierarchical_prob_unet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index bb0434d2..e29276de 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -102,7 +102,7 @@ def __init__(self, decay): super(ExponentialMovingAverage, self).__init__() self._decay = decay - self._counter = torch.zeros(1, requires_grad=False) + self._counter = torch.nn.Parameter(torch.zeros(1, requires_grad=False)) self._hidden = None self._average = None @@ -112,8 +112,8 @@ def forward(self, value): # Initialise if not already if self._hidden is None: - self._hidden = torch.zeros(value.shape, requires_grad=False) - self._average = torch.zeros(value.shape, requires_grad=False) + self._hidden = torch.nn.Parameter(torch.zeros(value.shape, requires_grad=False)) + self._average = 
torch.nn.Parameter(torch.zeros(value.shape, requires_grad=False)) # self._counter.assign_add(1) self._counter += 1 From 5698ddb8de52e4a3d373b032852cb91be76b062a Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 20 May 2021 01:56:59 +0000 Subject: [PATCH 033/264] Init before forward pass --- platipy/imaging/cnn/hierarchical_prob_unet.py | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index e29276de..a35a8ab8 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -102,20 +102,13 @@ def __init__(self, decay): super(ExponentialMovingAverage, self).__init__() self._decay = decay - self._counter = torch.nn.Parameter(torch.zeros(1, requires_grad=False)) + self._counter = torch.zeros(1, requires_grad=False) - self._hidden = None - self._average = None + self._hidden = torch.zeros(1, requires_grad=False) + self._average = torch.zeros(1, requires_grad=False) def forward(self, value): """Applies EMA to the value given.""" - - # Initialise if not already - if self._hidden is None: - self._hidden = torch.nn.Parameter(torch.zeros(value.shape, requires_grad=False)) - self._average = torch.nn.Parameter(torch.zeros(value.shape, requires_grad=False)) - - # self._counter.assign_add(1) self._counter += 1 counter = self._counter.type(value.type()) self._hidden -= (self._hidden - value) * (1 - self._decay) @@ -135,12 +128,9 @@ def __init__(self, rate): super(LagrangeMultiplier, self).__init__() self._rate = rate self._softplus = torch.nn.Softplus() - self._lambda_var = None - - def forward(self, value): + self._lambda_var = torch.ones(1, requires_grad=True) - if not self._lambda_var: - self._lambda_var = torch.nn.Parameter(torch.ones(value.shape, requires_grad=True)) + def forward(self): lag_multiplier = self._softplus(self._lambda_var) ** 2 lag_multiplier.retain_grad() @@ -713,7 +703,6 @@ 
def __init__( if self._loss_kwargs["type"] == "geco": self._moving_average = ExponentialMovingAverage(decay=self._loss_kwargs["decay"]) - self._moving_average.to() self._lagmul = LagrangeMultiplier(rate=self._loss_kwargs["rate"]) self._q_sample = None @@ -884,7 +873,7 @@ def loss(self, img, seg): reconstruction_threshold = self._loss_kwargs["kappa"] * num_valid_pixels rec_constraint = ma_rec_loss - reconstruction_threshold - lagmul = self._lagmul(rec_constraint) + lagmul = self._lagmul() loss = lagmul * rec_constraint + kl_sum summaries["geco_loss"] = loss From 59dcb67d54d6db78aa922e3289b682eef484a8c4 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 20 May 2021 02:11:24 +0000 Subject: [PATCH 034/264] Register buffers and parameters --- platipy/imaging/cnn/hierarchical_prob_unet.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index a35a8ab8..fb46af6a 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -103,9 +103,13 @@ def __init__(self, decay): self._decay = decay self._counter = torch.zeros(1, requires_grad=False) + self.register_buffer("moving_avg_counter", self._counter) self._hidden = torch.zeros(1, requires_grad=False) + self.register_buffer("moving_avg_hidden", self._hidden) + self._average = torch.zeros(1, requires_grad=False) + self.register_buffer("moving_avg_average", self._average) def forward(self, value): """Applies EMA to the value given.""" @@ -128,7 +132,8 @@ def __init__(self, rate): super(LagrangeMultiplier, self).__init__() self._rate = rate self._softplus = torch.nn.Softplus() - self._lambda_var = torch.ones(1, requires_grad=True) + self._lambda_var = torch.nn.Parameter(torch.ones(1, requires_grad=True)) + self.register_parameter("lagrange_multiplier", self._lambda_var) def forward(self): From aead3037bc9c7b2f5003efb1f4de13bea583471b Mon Sep 17 00:00:00 2001 
From: Phillip Chlap Date: Thu, 20 May 2021 03:24:27 +0000 Subject: [PATCH 035/264] Correct register buffer --- platipy/imaging/cnn/hierarchical_prob_unet.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index fb46af6a..2bc74453 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -97,19 +97,14 @@ def __init__(self, decay): decay: The decay to use. Note values close to 1 result in a slow decay whereas values close to 0 result in faster decay, tracking the input values more closely. - name: Name of the module. """ super(ExponentialMovingAverage, self).__init__() self._decay = decay - self._counter = torch.zeros(1, requires_grad=False) - self.register_buffer("moving_avg_counter", self._counter) - self._hidden = torch.zeros(1, requires_grad=False) - self.register_buffer("moving_avg_hidden", self._hidden) - - self._average = torch.zeros(1, requires_grad=False) - self.register_buffer("moving_avg_average", self._average) + self.register_buffer("_counter", torch.zeros(1, requires_grad=False)) + self.register_buffer("_hidden", torch.zeros(1, requires_grad=False)) + self.register_buffer("_average", torch.zeros(1, requires_grad=False)) def forward(self, value): """Applies EMA to the value given.""" From 087f5383ed7fdced405a6c82abf6e2a14bafcc08 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 20 May 2021 03:40:59 +0000 Subject: [PATCH 036/264] Don't modify in place --- platipy/imaging/cnn/hierarchical_prob_unet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 2bc74453..ce1cddf4 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -108,9 +108,9 @@ def __init__(self, decay): def forward(self, value): 
"""Applies EMA to the value given.""" - self._counter += 1 + self._counter = self._counter + 1 counter = self._counter.type(value.type()) - self._hidden -= (self._hidden - value) * (1 - self._decay) + self._hidden = self._hidden - (self._hidden - value) * (1 - self._decay) self._average = self._hidden / (1.0 - torch.pow(self._decay, counter)) return self._average @@ -135,7 +135,7 @@ def forward(self): lag_multiplier = self._softplus(self._lambda_var) ** 2 lag_multiplier.retain_grad() if lag_multiplier.grad: - lag_multiplier.grad *= self._rate + lag_multiplier.grad = lag_multiplier.grad * self._rate return lag_multiplier From 3b2f4aad71cddb5f818c951751cc9285788aed91 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 20 May 2021 04:24:35 +0000 Subject: [PATCH 037/264] Clean the EMA --- platipy/imaging/cnn/hierarchical_prob_unet.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index ce1cddf4..d644327d 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -103,17 +103,16 @@ def __init__(self, decay): self._decay = decay self.register_buffer("_counter", torch.zeros(1, requires_grad=False)) + self._counter = 1 self.register_buffer("_hidden", torch.zeros(1, requires_grad=False)) - self.register_buffer("_average", torch.zeros(1, requires_grad=False)) def forward(self, value): """Applies EMA to the value given.""" - self._counter = self._counter + 1 - counter = self._counter.type(value.type()) + + self._counter += self._counter self._hidden = self._hidden - (self._hidden - value) * (1 - self._decay) - self._average = self._hidden / (1.0 - torch.pow(self._decay, counter)) - return self._average + return self._hidden / (1.0 - torch.pow(self._decay, self._counter)) def reset(self): """Resets the EMA.""" From 6306ec4621171066b1d83ce1fa277773c4839bfe Mon Sep 17 00:00:00 2001 From: Phillip Chlap 
Date: Thu, 20 May 2021 04:57:35 +0000 Subject: [PATCH 038/264] Correct issue --- platipy/imaging/cnn/hierarchical_prob_unet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index d644327d..e89270ef 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -102,7 +102,6 @@ def __init__(self, decay): self._decay = decay - self.register_buffer("_counter", torch.zeros(1, requires_grad=False)) self._counter = 1 self.register_buffer("_hidden", torch.zeros(1, requires_grad=False)) From c1f0c0c44b0ef1a2cb9ad684f3d00a3d1fd56cc7 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 20 May 2021 04:59:20 +0000 Subject: [PATCH 039/264] Fix issue --- platipy/imaging/cnn/hierarchical_prob_unet.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index e89270ef..2f6a7bf0 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -102,16 +102,15 @@ def __init__(self, decay): self._decay = decay - self._counter = 1 + self.register_buffer("_counter", torch.zeros(1, requires_grad=False)) self.register_buffer("_hidden", torch.zeros(1, requires_grad=False)) def forward(self, value): """Applies EMA to the value given.""" - - self._counter += self._counter + self._counter = self._counter + 1 + counter = self._counter.type(value.type()) self._hidden = self._hidden - (self._hidden - value) * (1 - self._decay) - - return self._hidden / (1.0 - torch.pow(self._decay, self._counter)) + return self._hidden / (1.0 - torch.pow(self._decay, counter)) def reset(self): """Resets the EMA.""" From 7bf48051522b699a556cc3d2e86d2fa8a18d2a7c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 20 May 2021 05:47:30 +0000 Subject: [PATCH 040/264] Experiment with new Constraint class 
--- platipy/imaging/cnn/hierarchical_prob_unet.py | 107 ++++++++++++++++-- 1 file changed, 95 insertions(+), 12 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 2f6a7bf0..b21dc886 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -137,6 +137,87 @@ def forward(self): return lag_multiplier +# From https://github.com/eelcovdw/pytorch-constrained-opt/blob/master/constraint.py +class Constraint(torch.nn.Module): + def __init__( + self, + bound, + relation, + name=None, + multiplier_act=torch.nn.functional.softplus, + alpha=0.0, + start_val=0.0, + ): + """ + Adds a constraint to a loss function by turning the loss into a lagrangian. + Alpha is used for a moving average as described in [1]. + Note that this is similar as using an optimizer with momentum. + [1] Rezende, Danilo Jimenez, and Fabio Viola. + "Taming vaes." arXiv preprint arXiv:1810.00597 (2018). + Args: + bound: Constraint bound. + relation (str): relation of constraint, + using naming convention from operator module (eq, le, ge). + Defaults to 'ge'. + name (str, optional): Constraint name + multiplier_act (optional): When using inequality relations, + an activation function is used to force the multiplier to be positive. + I've experimented with ReLU, abs and softplus, softplus seems the most stable. + Defaults to F.softplus. + alpha (float, optional): alpha of moving average, as in [1]. + If alpha=0, no moving average is used. + start_val (float, optional): Start value of multiplier. If an activation function + is used the true start value might be different, because this is pre-activation. 
+ """ + super().__init__() + self.name = name + if isinstance(bound, (int, float)): + self.bound = torch.Tensor([bound]) + elif isinstance(bound, list): + self.bound = torch.Tensor(bound) + else: + self.bound = bound + + if relation in {"ge", "le", "eq"}: + self.relation = relation + else: + raise ValueError("Unknown relation: {}".format(relation)) + + if self.relation == "eq" and multiplier_act is not None: + print( + "WARNING using an activation that maps to R+ with an equality \ + constraint turns it into an inequality constraint" + ) + + self._multiplier = torch.nn.Parameter(torch.full((len(self.bound),), start_val)) + self._act = multiplier_act + + self.alpha = alpha + self.avg_value = None + + @property + def multiplier(self): + if self._act is not None: + return self._act(self._multiplier) + return self._multiplier + + def forward(self, value): + # Apply moving average, defined in [1] + if self.alpha > 0: + if self.avg_value is None: + self.avg_value = value.detach().mean(0) + else: + self.avg_value = ( + self.avg_value * self.alpha + value.detach() * (1 - self.alpha) + ).mean(0) + value = value + (self.avg_value.unsqueeze(0) - value).detach() + if self.relation in {"ge", "eq"}: + loss = self.bound.to(value.device) - value + elif self.relation == "le": + loss = value - self.bound.to(value.device) + return loss * self.multiplier + + class ResBlock(torch.nn.Module): """A residual block""" @@ -699,8 +780,9 @@ def __init__( self._loss_kwargs = loss_kwargs if self._loss_kwargs["type"] == "geco": - self._moving_average = ExponentialMovingAverage(decay=self._loss_kwargs["decay"]) - self._lagmul = LagrangeMultiplier(rate=self._loss_kwargs["rate"]) + # self._moving_average = ExponentialMovingAverage(decay=self._loss_kwargs["decay"]) + # self._lagmul = LagrangeMultiplier(rate=self._loss_kwargs["rate"]) + self._rec_constraint = Constraint(0.02, "le", alpha=0.5) # rec_loss <= 0.02 self._q_sample = None self._q_sample_mean = None @@ -864,19 +946,20 @@ def loss(self, img, 
seg): # Set up a GECO objective (ELBO with a reconstruction constraint). elif self._loss_kwargs["type"] == "geco": - ma_rec_loss = self._moving_average(rec_loss["sum"]) - mask_sum_per_instance = torch.sum(rec_loss["mask"], -1) - num_valid_pixels = torch.mean(mask_sum_per_instance) - reconstruction_threshold = self._loss_kwargs["kappa"] * num_valid_pixels + # ma_rec_loss = self._moving_average(rec_loss["sum"]) + # mask_sum_per_instance = torch.sum(rec_loss["mask"], -1) + # num_valid_pixels = torch.mean(mask_sum_per_instance) + # reconstruction_threshold = self._loss_kwargs["kappa"] * num_valid_pixels + + # rec_constraint = ma_rec_loss - reconstruction_threshold + # lagmul = self._lagmul() - rec_constraint = ma_rec_loss - reconstruction_threshold - lagmul = self._lagmul() - loss = lagmul * rec_constraint + kl_sum + loss = (rec_loss["sum"] + kl_sum) * self._rec_constraint(rec_loss["sum"]) summaries["geco_loss"] = loss - summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels - summaries["num_valid_pixels"] = num_valid_pixels - summaries["lagmul"] = lagmul + # summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels + # summaries["num_valid_pixels"] = num_valid_pixels + # summaries["lagmul"] = lagmul else: raise NotImplementedError( "Loss type {} not implemeted!".format(self._loss_kwargs["type"]) From 94fb15d4dbae7092c1898cdb002c41734d2bd743 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 21 May 2021 16:51:10 +1000 Subject: [PATCH 041/264] Scale bound by numel --- platipy/imaging/cnn/hierarchical_prob_unet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index b21dc886..1ca47b41 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -201,7 +201,7 @@ def multiplier(self): return self._act(self._multiplier) return self._multiplier - def forward(self, value): + def 
forward(self, value, numel): # Apply moving average, defined in [1] if self.alpha > 0: if self.avg_value is None: @@ -212,9 +212,9 @@ def forward(self, value): ).mean(0) value = value + (self.avg_value.unsqueeze(0) - value).detach() if self.relation in {"ge", "eq"}: - loss = self.bound.to(value.device) - value + loss = self.bound.to(value.device) * numel * - value elif self.relation == "le": - loss = value - self.bound.to(value.device) + loss = value - self.bound.to(value.device) * numel return loss * self.multiplier @@ -954,7 +954,7 @@ def loss(self, img, seg): # rec_constraint = ma_rec_loss - reconstruction_threshold # lagmul = self._lagmul() - loss = (rec_loss["sum"] + kl_sum) * self._rec_constraint(rec_loss["sum"]) + loss = (rec_loss["sum"] + kl_sum) * self._rec_constraint(rec_loss["sum"], seg.numel()) summaries["geco_loss"] = loss # summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels From 83c317aa6610d92155c6057487ab9b8606df6fa2 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 22 May 2021 08:47:51 +1000 Subject: [PATCH 042/264] Adjust loss --- platipy/imaging/cnn/hierarchical_prob_unet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 1ca47b41..540dd420 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -954,8 +954,9 @@ def loss(self, img, seg): # rec_constraint = ma_rec_loss - reconstruction_threshold # lagmul = self._lagmul() - loss = (rec_loss["sum"] + kl_sum) * self._rec_constraint(rec_loss["sum"], seg.numel()) + #loss = (rec_loss["sum"] + kl_sum) * self._rec_constraint(rec_loss["sum"], img[0,0,:].numel()) + loss = self._rec_constraint(rec_loss["sum"], img[0,0,:].numel()) + kl_sum summaries["geco_loss"] = loss # summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels # summaries["num_valid_pixels"] = num_valid_pixels From 
d267d5e9592df121a6a6f046d5127fed55022d4e Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 20 May 2021 22:57:25 +0000 Subject: [PATCH 043/264] Update lagrange multiplier --- platipy/imaging/cnn/hierarchical_prob_unet.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 540dd420..26126c78 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -212,10 +212,10 @@ def forward(self, value, numel): ).mean(0) value = value + (self.avg_value.unsqueeze(0) - value).detach() if self.relation in {"ge", "eq"}: - loss = self.bound.to(value.device) * numel * - value + loss = self.bound.to(value.device) * numel * -value elif self.relation == "le": loss = value - self.bound.to(value.device) * numel - return loss * self.multiplier + return {"updated_loss": loss * self.multiplier, "multiplier": self.multiplier} class ResBlock(torch.nn.Module): @@ -954,13 +954,14 @@ def loss(self, img, seg): # rec_constraint = ma_rec_loss - reconstruction_threshold # lagmul = self._lagmul() - #loss = (rec_loss["sum"] + kl_sum) * self._rec_constraint(rec_loss["sum"], img[0,0,:].numel()) + # loss = (rec_loss["sum"] + kl_sum) * self._rec_constraint(rec_loss["sum"], img[0,0,:].numel()) - loss = self._rec_constraint(rec_loss["sum"], img[0,0,:].numel()) + kl_sum + rec_loss_weighted = self._rec_constraint(rec_loss["sum"], img[0, 0, :].numel()) + loss = rec_loss_weighted["updated_loss"] + kl_sum summaries["geco_loss"] = loss # summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels # summaries["num_valid_pixels"] = num_valid_pixels - # summaries["lagmul"] = lagmul + summaries["lagmul"] = rec_loss_weighted["multiplier"] else: raise NotImplementedError( "Loss type {} not implemeted!".format(self._loss_kwargs["type"]) From a580c3b9909856346b5f64b2865b7f3427343882 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: 
Fri, 21 May 2021 00:10:02 +0000 Subject: [PATCH 044/264] Reimplement GECO --- platipy/imaging/cnn/hierarchical_prob_unet.py | 182 +++--------------- 1 file changed, 23 insertions(+), 159 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 26126c78..1ae084ed 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -73,151 +73,6 @@ def conv_nd(ndims=2, **kwargs): raise NotImplementedError("Only 2 or 3 dimensions are supported") -class ExponentialMovingAverage(torch.nn.Module): - """Maintains an exponential moving average for a value. - Note this module uses debiasing by default. If you don't want this please use - an alternative implementation. - This module keeps track of a hidden exponential moving average that is - initialized as a vector of zeros which is then normalized to give the average. - This gives us a moving average which isn't biased towards either zero or the - initial value. Reference (https://arxiv.org/pdf/1412.6980.pdf) - Initially: - hidden_0 = 0 - Then iteratively: - hidden_i = (hidden_{i-1} - value) * (1 - decay) - average_i = hidden_i / (1 - decay^i) - Attributes: - average: Variable holding average. Note that this is None until the first - value is passed. - """ - - def __init__(self, decay): - """Creates a debiased moving average module. - Args: - decay: The decay to use. Note values close to 1 result in a slow decay - whereas values close to 0 result in faster decay, tracking the input - values more closely. 
- """ - super(ExponentialMovingAverage, self).__init__() - - self._decay = decay - - self.register_buffer("_counter", torch.zeros(1, requires_grad=False)) - self.register_buffer("_hidden", torch.zeros(1, requires_grad=False)) - - def forward(self, value): - """Applies EMA to the value given.""" - self._counter = self._counter + 1 - counter = self._counter.type(value.type()) - self._hidden = self._hidden - (self._hidden - value) * (1 - self._decay) - return self._hidden / (1.0 - torch.pow(self._decay, counter)) - - def reset(self): - """Resets the EMA.""" - self._counter = torch.zeros(self._contour.shape) - self._hidden = torch.zeros(self._hidden.shape) - self._average = torch.zeros(self._average.shape) - - -class LagrangeMultiplier(torch.nn.Module): - def __init__(self, rate): - super(LagrangeMultiplier, self).__init__() - self._rate = rate - self._softplus = torch.nn.Softplus() - self._lambda_var = torch.nn.Parameter(torch.ones(1, requires_grad=True)) - self.register_parameter("lagrange_multiplier", self._lambda_var) - - def forward(self): - - lag_multiplier = self._softplus(self._lambda_var) ** 2 - lag_multiplier.retain_grad() - if lag_multiplier.grad: - lag_multiplier.grad = lag_multiplier.grad * self._rate - - return lag_multiplier - - -# From https://github.com/eelcovdw/pytorch-constrained-opt/blob/master/constraint.py -class Constraint(torch.nn.Module): - def __init__( - self, - bound, - relation, - name=None, - multiplier_act=torch.nn.functional.softplus, - alpha=0.0, - start_val=0.0, - ): - """ - Adds a constraint to a loss function by turning the loss into a lagrangian. - Alpha is used for a moving average as described in [1]. - Note that this is similar as using an optimizer with momentum. - [1] Rezende, Danilo Jimenez, and Fabio Viola. - "Taming vaes." arXiv preprint arXiv:1810.00597 (2018). - Args: - bound: Constraint bound. - relation (str): relation of constraint, - using naming convention from operator module (eq, le, ge). - Defaults to 'ge'. 
- name (str, optional): Constraint name - multiplier_act (optional): When using inequality relations, - an activation function is used to force the multiplier to be positive. - I've experimented with ReLU, abs and softplus, softplus seems the most stable. - Defaults to F.softplus. - alpha (float, optional): alpha of moving average, as in [1]. - If alpha=0, no moving average is used. - start_val (float, optional): Start value of multiplier. If an activation function - is used the true start value might be different, because this is pre-activation. - """ - super().__init__() - self.name = name - if isinstance(bound, (int, float)): - self.bound = torch.Tensor([bound]) - elif isinstance(bound, list): - self.bound = torch.Tensor(bound) - else: - self.bound = bound - - if relation in {"ge", "le", "eq"}: - self.relation = relation - else: - raise ValueError("Unknown relation: {}".format(relation)) - - if self.relation == "eq" and multiplier_act is not None: - print( - "WARNING using an activation that maps to R+ with an equality \ - constraint turns it into an inequality constraint" - ) - - self._multiplier = torch.nn.Parameter(torch.full((len(self.bound),), start_val)) - self._act = multiplier_act - - self.alpha = alpha - self.avg_value = None - - @property - def multiplier(self): - if self._act is not None: - return self._act(self._multiplier) - return self._multiplier - - def forward(self, value, numel): - # Apply moving average, defined in [1] - if self.alpha > 0: - if self.avg_value is None: - self.avg_value = value.detach().mean(0) - else: - self.avg_value = ( - self.avg_value * self.alpha + value.detach() * (1 - self.alpha) - ).mean(0) - value = value + (self.avg_value.unsqueeze(0) - value).detach() - if self.relation in {"ge", "eq"}: - loss = self.bound.to(value.device) * numel * -value - elif self.relation == "le": - loss = value - self.bound.to(value.device) * numel - return {"updated_loss": loss * self.multiplier, "multiplier": self.multiplier} - - class 
ResBlock(torch.nn.Module): """A residual block""" @@ -781,8 +636,9 @@ def __init__( if self._loss_kwargs["type"] == "geco": # self._moving_average = ExponentialMovingAverage(decay=self._loss_kwargs["decay"]) - # self._lagmul = LagrangeMultiplier(rate=self._loss_kwargs["rate"]) - self._rec_constraint = Constraint(0.02, "le", alpha=0.5) # rec_loss <= 0.02 + # self._geco_loss = GECOLoss(target_ratio, alpha=0.5) + self._ema = None + self.register_buffer("_multiplier", torch.zeros(1, requires_grad=False)) self._q_sample = None self._q_sample_mean = None @@ -947,21 +803,29 @@ def loss(self, img, seg): # Set up a GECO objective (ELBO with a reconstruction constraint). elif self._loss_kwargs["type"] == "geco": # ma_rec_loss = self._moving_average(rec_loss["sum"]) - # mask_sum_per_instance = torch.sum(rec_loss["mask"], -1) - # num_valid_pixels = torch.mean(mask_sum_per_instance) - # reconstruction_threshold = self._loss_kwargs["kappa"] * num_valid_pixels - - # rec_constraint = ma_rec_loss - reconstruction_threshold - # lagmul = self._lagmul() - - # loss = (rec_loss["sum"] + kl_sum) * self._rec_constraint(rec_loss["sum"], img[0,0,:].numel()) + if self._ema is None: + self._ema = rec_loss["sum"].detach().mean(0) + else: + alpha = self._loss_kwargs["alpha"] + self._ema = (self._ema * alpha + rec_loss["sum"].detach() * (1 - alpha)).mean(0) + + mask_sum_per_instance = torch.sum(rec_loss["mask"], -1) + num_valid_pixels = torch.mean(mask_sum_per_instance) + reconstruction_threshold = self._loss_kwargs["kappa"] * num_valid_pixels + rec_constraint = self._ema - reconstruction_threshold + + speed = 1 + if rec_constraint > 0: + speed = 2 + self._multiplier = (torch.exp(speed * rec_constraint) * self._multiplier).clamp( + 1e-5, 1e5 + ) + loss = rec_loss["sum"] * self._multiplier + self._loss_kwargs["beta"] * kl_sum - rec_loss_weighted = self._rec_constraint(rec_loss["sum"], img[0, 0, :].numel()) - loss = rec_loss_weighted["updated_loss"] + kl_sum summaries["geco_loss"] = loss # 
summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels - # summaries["num_valid_pixels"] = num_valid_pixels - summaries["lagmul"] = rec_loss_weighted["multiplier"] + summaries["num_valid_pixels"] = num_valid_pixels + summaries["lagmul"] = self._multiplier else: raise NotImplementedError( "Loss type {} not implemeted!".format(self._loss_kwargs["type"]) From bd4f7dc6ae3d01cd679ffa1909fd8e82e0a42c05 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 21 May 2021 00:17:41 +0000 Subject: [PATCH 045/264] Use hard coded alpha for now --- platipy/imaging/cnn/hierarchical_prob_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 1ae084ed..00885e30 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -806,8 +806,8 @@ def loss(self, img, seg): if self._ema is None: self._ema = rec_loss["sum"].detach().mean(0) else: - alpha = self._loss_kwargs["alpha"] - self._ema = (self._ema * alpha + rec_loss["sum"].detach() * (1 - alpha)).mean(0) + # alpha = self._loss_kwargs["alpha"] + self._ema = (self._ema * 0.5 + rec_loss["sum"].detach() * (1 - 0.5)).mean(0) mask_sum_per_instance = torch.sum(rec_loss["mask"], -1) num_valid_pixels = torch.mean(mask_sum_per_instance) From c8ef33ab69b2369181e9d66a721d257dd5c10faf Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 21 May 2021 00:17:50 +0000 Subject: [PATCH 046/264] Don't use exp --- platipy/imaging/cnn/hierarchical_prob_unet.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 00885e30..8c105a05 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -813,13 +813,10 @@ def loss(self, img, seg): num_valid_pixels = torch.mean(mask_sum_per_instance) 
reconstruction_threshold = self._loss_kwargs["kappa"] * num_valid_pixels rec_constraint = self._ema - reconstruction_threshold - speed = 1 if rec_constraint > 0: speed = 2 - self._multiplier = (torch.exp(speed * rec_constraint) * self._multiplier).clamp( - 1e-5, 1e5 - ) + self._multiplier = (speed * rec_constraint * self._multiplier).clamp(1e-5, 1e5) loss = rec_loss["sum"] * self._multiplier + self._loss_kwargs["beta"] * kl_sum summaries["geco_loss"] = loss From d0dd9cd408197a68982483f203bbb9ce71acf64c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 29 May 2021 21:07:14 +1000 Subject: [PATCH 047/264] Update to prob Unet --- platipy/imaging/cnn/prob_unet.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 64036d3e..98ad9ddc 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -88,6 +88,7 @@ def forward(self, img, seg=None): # This is a multivariate normal with diagonal covariance matrix sigma # https://github.com/pytorch/pytorch/pull/11178 dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1) + return dist @@ -193,6 +194,9 @@ def sample(self, testing=False, use_mean=False, sample_x_stddev_from_mean=None): if use_mean: z_prior = self.prior_latent_space.base_dist.loc elif not sample_x_stddev_from_mean is None: + if isinstance(sample_x_stddev_from_mean, list): + sample_x_stddev_from_mean = torch.Tensor(sample_x_stddev_from_mean) + sample_x_stddev_from_mean = sample_x_stddev_from_mean.to(self.prior_latent_space.base_dist.stddev.device) z_prior = self.prior_latent_space.base_dist.loc + ( self.prior_latent_space.base_dist.scale * sample_x_stddev_from_mean ) @@ -256,4 +260,6 @@ def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): reconstruction_loss = torch.sum(reconstruction_loss) # mean_reconstruction_loss = torch.mean(reconstruction_loss) - return -(reconstruction_loss + self.beta * kl_div) + return 
{"loss": -(reconstruction_loss + self.beta * kl_div), + "rec_loss": reconstruction_loss, + "kl_div": kl_div} From 68a5b0bdba56f96a565ac8c2cc0aa895b6869ad0 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 19 Jun 2021 01:28:19 +0000 Subject: [PATCH 048/264] Implement GECO for prob unet --- platipy/imaging/cnn/prob_unet.py | 63 ++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 98ad9ddc..e5ae7e85 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -174,6 +174,9 @@ def __init__( self.prior_latent_space = None self.unet_features = None + self._moving_avg = None + self.register_buffer("_lambda", torch.zeros(1, requires_grad=False)) + def forward(self, img, seg, training=True): """ Construct prior latent space for patch and run patch through UNet, @@ -196,7 +199,9 @@ def sample(self, testing=False, use_mean=False, sample_x_stddev_from_mean=None): elif not sample_x_stddev_from_mean is None: if isinstance(sample_x_stddev_from_mean, list): sample_x_stddev_from_mean = torch.Tensor(sample_x_stddev_from_mean) - sample_x_stddev_from_mean = sample_x_stddev_from_mean.to(self.prior_latent_space.base_dist.stddev.device) + sample_x_stddev_from_mean = sample_x_stddev_from_mean.to( + self.prior_latent_space.base_dist.stddev.device + ) z_prior = self.prior_latent_space.base_dist.loc + ( self.prior_latent_space.base_dist.scale * sample_x_stddev_from_mean ) @@ -260,6 +265,56 @@ def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): reconstruction_loss = torch.sum(reconstruction_loss) # mean_reconstruction_loss = torch.mean(reconstruction_loss) - return {"loss": -(reconstruction_loss + self.beta * kl_div), - "rec_loss": reconstruction_loss, - "kl_div": kl_div} + return { + "loss": -(reconstruction_loss + self.beta * kl_div), + "rec_loss": reconstruction_loss, + "kl_div": kl_div, + } + + def geco(self, segm, 
analytic_kl=True, reconstruct_posterior_mean=False): + """ + Calculate the evidence lower bound of the log-likelihood of P(Y|X) + """ + + criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") + z_posterior = self.posterior_latent_space.rsample() + + kl_div = torch.mean(self.kl_divergence(analytic=analytic_kl, z_posterior=z_posterior)) + + # Here we use the posterior sample sampled above + reconstruction = self.reconstruct( + use_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior + ) + + segm = torch.unsqueeze(segm, dim=1) + not_seg = segm.logical_not() + segm = torch.cat((not_seg, segm), dim=1).float() + reconstruction_loss = criterion(input=reconstruction, target=segm) + reconstruction_loss = torch.sum(reconstruction_loss) + # mean_reconstruction_loss = torch.mean(reconstruction_loss) + + num_pixels = reconstruction.numel() + reconstruction_threshold = 0.02 * num_pixels + rec_constraint = self._moving_avg - reconstruction_threshold + + loss = self._lambda * rec_constraint + kl_div + + with torch.no_grad(): + if self._moving_avg is None: + self._moving_avg = reconstruction_loss.detach().mean(0) + else: + self._moving_avg = ( + self._moving_avg * 0.5 + reconstruction_loss.detach() * (1 - 0.5) + ).mean(0) + speed = 1 + self._lambda = (speed * self._moving_avg * self._lambda).clamp(1e-5, 1e5) + + return { + "loss": -loss, + "rec_loss": reconstruction_loss, + "kl_div": kl_div, + "lambda": self._lambda, + "moving_avg": self._moving_avg, + "reconstruction_threshold": reconstruction_threshold, + "rec_constraint": rec_constraint, + } From 2ef525f1495466faca9da3a8fdbaaa47fa32844b Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 19 Jun 2021 02:40:37 +0000 Subject: [PATCH 049/264] Correction to GECO --- platipy/imaging/cnn/prob_unet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index e5ae7e85..5cd91e3d 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ 
b/platipy/imaging/cnn/prob_unet.py @@ -295,7 +295,7 @@ def geco(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): num_pixels = reconstruction.numel() reconstruction_threshold = 0.02 * num_pixels - rec_constraint = self._moving_avg - reconstruction_threshold + rec_constraint = reconstruction_loss - reconstruction_threshold loss = self._lambda * rec_constraint + kl_div @@ -308,7 +308,6 @@ def geco(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): ).mean(0) speed = 1 self._lambda = (speed * self._moving_avg * self._lambda).clamp(1e-5, 1e5) - return { "loss": -loss, "rec_loss": reconstruction_loss, From 6b7d8c8c1d6d1af7153c3e1cf6bfdb67d9dbe599 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 28 Jun 2021 09:50:32 +1000 Subject: [PATCH 050/264] hprob --- platipy/imaging/cnn/hierarchical_prob_unet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 8c105a05..31d248b9 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -802,6 +802,8 @@ def loss(self, img, seg): # Set up a GECO objective (ELBO with a reconstruction constraint). 
elif self._loss_kwargs["type"] == "geco": + + loss = rec_loss["sum"] * self._multiplier + self._loss_kwargs["beta"] * kl_sum # ma_rec_loss = self._moving_average(rec_loss["sum"]) if self._ema is None: self._ema = rec_loss["sum"].detach().mean(0) @@ -816,8 +818,8 @@ def loss(self, img, seg): speed = 1 if rec_constraint > 0: speed = 2 - self._multiplier = (speed * rec_constraint * self._multiplier).clamp(1e-5, 1e5) - loss = rec_loss["sum"] * self._multiplier + self._loss_kwargs["beta"] * kl_sum + self._multiplier = (speed * rec_constraint * self._multiplier).clamp(1e-5, 1e2) + # loss = rec_loss["sum"] * self._multiplier + self._loss_kwargs["beta"] * kl_sum summaries["geco_loss"] = loss # summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels From 5a4ed42b46ec538719d9545f4e8148de529cb543 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 29 Jun 2021 04:16:34 +0000 Subject: [PATCH 051/264] Use sum for loss --- platipy/imaging/cnn/prob_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 5cd91e3d..d9fe008c 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -248,7 +248,7 @@ def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): Calculate the evidence lower bound of the log-likelihood of P(Y|X) """ - criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") + criterion = torch.nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) z_posterior = self.posterior_latent_space.rsample() kl_div = torch.mean(self.kl_divergence(analytic=analytic_kl, z_posterior=z_posterior)) @@ -276,7 +276,7 @@ def geco(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): Calculate the evidence lower bound of the log-likelihood of P(Y|X) """ - criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") + criterion = torch.nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) z_posterior = 
self.posterior_latent_space.rsample() kl_div = torch.mean(self.kl_divergence(analytic=analytic_kl, z_posterior=z_posterior)) From 7ac76291827a8c1a2176340a124257f914404ff4 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 29 Jun 2021 04:19:33 +0000 Subject: [PATCH 052/264] Allow configure kappa --- platipy/imaging/cnn/prob_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index d9fe008c..c6671dcc 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -271,7 +271,7 @@ def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): "kl_div": kl_div, } - def geco(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): + def geco(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, kappa=0.02): """ Calculate the evidence lower bound of the log-likelihood of P(Y|X) """ @@ -294,7 +294,7 @@ def geco(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): # mean_reconstruction_loss = torch.mean(reconstruction_loss) num_pixels = reconstruction.numel() - reconstruction_threshold = 0.02 * num_pixels + reconstruction_threshold = kappa * num_pixels rec_constraint = reconstruction_loss - reconstruction_threshold loss = self._lambda * rec_constraint + kl_div From 1afbec48fe3bae3bd60839dc5ebbc2ee190ba80b Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 29 Jun 2021 22:42:13 +0000 Subject: [PATCH 053/264] Correct geco --- platipy/imaging/cnn/prob_unet.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index c6671dcc..2772e1b3 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -177,7 +177,7 @@ def __init__( self._moving_avg = None self.register_buffer("_lambda", torch.zeros(1, requires_grad=False)) - def forward(self, img, seg, training=True): + def forward(self, img, 
seg=None, training=False): """ Construct prior latent space for patch and run patch through UNet, in case training is True also construct posterior latent space @@ -291,7 +291,6 @@ def geco(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, kappa=0 segm = torch.cat((not_seg, segm), dim=1).float() reconstruction_loss = criterion(input=reconstruction, target=segm) reconstruction_loss = torch.sum(reconstruction_loss) - # mean_reconstruction_loss = torch.mean(reconstruction_loss) num_pixels = reconstruction.numel() reconstruction_threshold = kappa * num_pixels @@ -301,13 +300,11 @@ def geco(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, kappa=0 with torch.no_grad(): if self._moving_avg is None: - self._moving_avg = reconstruction_loss.detach().mean(0) + self._moving_avg = rec_constraint.detach() else: - self._moving_avg = ( - self._moving_avg * 0.5 + reconstruction_loss.detach() * (1 - 0.5) - ).mean(0) + self._moving_avg = self._moving_avg * 0.5 + rec_constraint.detach() * (1 - 0.5) speed = 1 - self._lambda = (speed * self._moving_avg * self._lambda).clamp(1e-5, 1e5) + self._lambda = (speed * torch.exp(self._moving_avg) * self._lambda).clamp(1e-5, 1e5) return { "loss": -loss, "rec_loss": reconstruction_loss, From 62417aa66064fa6aef4e936e6974cccb8aee68ec Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 29 Jun 2021 22:52:23 +0000 Subject: [PATCH 054/264] Update git attributes with binary --- .gitattributes | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.gitattributes b/.gitattributes index ed7419e8..985fc46b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -4,6 +4,12 @@ $.dcm binary *.tar.gz binary +*.png binary +*.jpeg binary +*.jpg binary +*.ttf binary +*.pickle binary +*.woff binary *.ipynb filter=nbstripout *.ipynb diff=ipynb From 85df680991f53523553f37307f53c678e1292956 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 29 Jun 2021 22:53:04 +0000 Subject: [PATCH 055/264] update git attributes --- .gitattributes | 
5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.gitattributes b/.gitattributes index 985fc46b..c6172540 100644 --- a/.gitattributes +++ b/.gitattributes @@ -9,7 +9,4 @@ $.dcm binary *.jpg binary *.ttf binary *.pickle binary -*.woff binary - -*.ipynb filter=nbstripout -*.ipynb diff=ipynb +*.woff binary \ No newline at end of file From bb157ed24bfa8ad69f2ca0efc0b45029cca5ea2c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 29 Jun 2021 22:53:58 +0000 Subject: [PATCH 056/264] update git attributes --- .gitattributes | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitattributes b/.gitattributes index c6172540..985fc46b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -9,4 +9,7 @@ $.dcm binary *.jpg binary *.ttf binary *.pickle binary -*.woff binary \ No newline at end of file +*.woff binary + +*.ipynb filter=nbstripout +*.ipynb diff=ipynb From cdf4064845fdaea48df3ccc16208fe78fba8f5fa Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 30 Jun 2021 05:21:35 +0000 Subject: [PATCH 057/264] Converting to pytorch-lightning --- .pylintrc | 2 +- platipy/imaging/cnn/dataset.py | 177 ++++++++++++++++++++++++ platipy/imaging/cnn/prob_unet.py | 2 - platipy/imaging/cnn/pseudo_generator.py | 54 ++++++++ platipy/imaging/cnn/test_sampler.py | 27 ++++ platipy/imaging/cnn/train.py | 150 ++++++++++++++++++++ 6 files changed, 409 insertions(+), 3 deletions(-) create mode 100644 platipy/imaging/cnn/dataset.py create mode 100644 platipy/imaging/cnn/pseudo_generator.py create mode 100644 platipy/imaging/cnn/test_sampler.py create mode 100644 platipy/imaging/cnn/train.py diff --git a/.pylintrc b/.pylintrc index 7f7aa386..f9ecee68 100644 --- a/.pylintrc +++ b/.pylintrc @@ -443,7 +443,7 @@ contextmanager-decorators=contextlib.contextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. 
-generated-members= +generated-members=torch.* # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py new file mode 100644 index 00000000..cd265ea8 --- /dev/null +++ b/platipy/imaging/cnn/dataset.py @@ -0,0 +1,177 @@ +from pathlib import Path + +import torch + +import SimpleITK as sitk + +from imgaug import augmenters as iaa +from imgaug.augmentables.segmaps import SegmentationMapsOnImage + +from loguru import logger + + +def preprocess_image(img, crop_to_mm=256): + + img = sitk.Normalize(img) + + new_spacing = sitk.VectorDouble(3) + new_spacing[0] = 1.0 + new_spacing[1] = 1.0 + new_spacing[2] = img.GetSpacing()[2] + + new_size = sitk.VectorUInt32(3) + new_size[0] = int(img.GetSize()[0] * img.GetSpacing()[0]) + new_size[1] = int(img.GetSize()[1] * img.GetSpacing()[1]) + new_size[2] = int(img.GetSize()[2]) + + if new_size[0] < crop_to_mm: + new_size[0] = crop_to_mm + + if new_size[1] < crop_to_mm: + new_size[1] = crop_to_mm + + img = sitk.Resample( + img, + new_size, + sitk.Transform(), + sitk.sitkLinear, + img.GetOrigin(), + new_spacing, + img.GetDirection(), + -1, + img.GetPixelID(), + ) + + center_x = img.GetSize()[0] / 2 + x_from = int(center_x - crop_to_mm / 2) + x_to = x_from + crop_to_mm + + center_y = img.GetSize()[1] / 2 + y_from = int(center_y - crop_to_mm / 2) + y_to = y_from + crop_to_mm + + img = img[x_from:x_to, y_from:y_to, :] + + return img + + +def resample_mask_to_image(img, mask): + + return sitk.Resample( + mask, + img, + sitk.Transform(), + sitk.sitkNearestNeighbor, + 0, + mask.GetPixelID(), + ) + + +def prepare_transforms(): + + sometimes = lambda aug: iaa.Sometimes(0.5, aug) + + seq = iaa.Sequential( + [ + sometimes( + iaa.Affine( + scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, + translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, + rotate=(-15, 15), + shear=(-8, 8), + 
cval=-1, + ) + ), + # execute 0 to 2 of the following (less important) augmenters per image + # don't execute all of them, as that would often be way too strong + iaa.SomeOf( + (0, 2), + [ + iaa.OneOf( + [ + iaa.GaussianBlur((0, 1.5)), + iaa.AverageBlur(k=(3, 5)), + ] + ), + sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.1))), + ], + random_order=True, + ), + ], + random_order=True, + ) + + return seq + + +class NiftiDataset(torch.utils.data.Dataset): + """PyTorch Dataset for processing Nifti data""" + + def __init__(self, data, working_dir): + """Prepare a dataset from Nifti images/labels + + Args: + data (list): List of dict's where each item contains keys: "image" and "label". Values + are paths to the Nifti file. "label" may be a list where each item is a path to one + observer. + working_dir (str|path): Working directory where to write prepared files. + """ + + self.data = data + self.transforms = prepare_transforms() + self.slices = [] + self.working_dir = Path(working_dir) + + self.img_dir = working_dir.joinpath("img") + self.mask_dir = working_dir.joinpath("mask") + + self.img_dir.mkdir(exist_ok=True, parents=True) + self.mask_dir.mkdir(exist_ok=True, parents=True) + + for case in data: + case_id = case["id"] + img_path = str(case["image"]) + + structure_paths = case["label"] + if isinstance(structure_paths, (str, Path)): + structure_paths = [structure_paths] + + img_file = self.img_dir.joinpath(f"{case_id}.nii.gz") + img = sitk.ReadImage(img_path) + + if not img_file.exists(): + img = preprocess_image(img) + sitk.WriteImage(img, str(img_file)) + + for obs, structure_path in enumerate(structure_paths): + structure_path = str(structure_path) + mask = sitk.ReadImage(structure_path) + mask = resample_mask_to_image(img, mask) + mask_file = self.mask_dir.joinpath(f"{case_id}_{obs}.nii.gz") + sitk.WriteImage(mask, str(mask_file)) + + for z_slice in range(img.GetSize()[2]): + for obs, mask in enumerate(structure_paths): + mask_file = 
self.mask_dir.joinpath(f"{case_id}_{obs}.nii.gz") + self.slices.append({"z": z_slice, "image": img_file, "mask": mask_file}) + + def __len__(self): + return len(self.slices) + + def __getitem__(self, index): + + img_file = self.slices[index]["image"] + mask_file = self.slices[index]["mask"] + z_slice = self.slices[index]["z"] + + img = sitk.GetArrayFromImage(sitk.ReadImage(str(img_file))[:, :, z_slice]) + mask = sitk.GetArrayFromImage(sitk.ReadImage(str(mask_file))[:, :, z_slice]) + + segmap = SegmentationMapsOnImage(mask, shape=mask.shape) + img, mask = self.transforms(image=img, segmentation_maps=segmap) + mask = mask.get_arr() + + img = torch.FloatTensor(img) + mask = torch.LongTensor(mask) + + return img.unsqueeze(0), mask diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 2772e1b3..efda1973 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -16,8 +16,6 @@ # https://github.com/stefanknegt/Probabilistic-Unet-Pytorch # which is released under the Apache Licence 2.0 -# pylint: disable=invalid-name - import torch from torch.distributions import Normal, Independent, kl diff --git a/platipy/imaging/cnn/pseudo_generator.py b/platipy/imaging/cnn/pseudo_generator.py new file mode 100644 index 00000000..429dc4f2 --- /dev/null +++ b/platipy/imaging/cnn/pseudo_generator.py @@ -0,0 +1,54 @@ +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import SimpleITK as sitk + +import random + +from platipy.imaging.generation.image import insert_sphere +from platipy.imaging import ImageVisualiser + + +def generate_pseudo_data(data_dir="data"): + + test_data_directory = Path(data_dir) + image_directory = test_data_directory.joinpath("images") + label_directory = test_data_directory.joinpath("labels") + slice_directory = test_data_directory.joinpath("slices") + + image_directory.mkdir(parents=True, exist_ok=True) + label_directory.mkdir(parents=True, exist_ok=True) + 
slice_directory.mkdir(parents=True, exist_ok=True) + + for case, sphere_rad in enumerate(range(10, 30)): + + ct_arr = np.ones((80, 128, 128)) * -1000 + mask_arr = np.zeros((80, 128, 128)) + + xpos = random.randint(50, 80) + ypos = random.randint(50, 80) + + ct_arr = insert_sphere(ct_arr, sp_radius=sphere_rad, sp_centre=(30, ypos, xpos)) + ct = sitk.GetImageFromArray(ct_arr) + sitk.WriteImage(ct, str(image_directory.joinpath(f"{case}.nii.gz"))) + + vis = ImageVisualiser(ct, cut=(30, ypos, xpos)) + masks = {} + + for obs_id, obs in enumerate(range(-4, 5, 2)): + obs_rad = sphere_rad + obs + + mask_arr = insert_sphere(mask_arr, sp_radius=obs_rad, sp_centre=(30, ypos, xpos)) + + mask = sitk.GetImageFromArray(mask_arr) + mask.CopyInformation(ct) + mask = sitk.Cast(mask, sitk.sitkUInt8) + sitk.WriteImage(mask, str(label_directory.joinpath(f"{case}_{obs_id}.nii.gz"))) + + masks[f"obs_{obs_id}_{obs_rad}"] = mask + + vis.add_contour(masks) + fig = vis.show() + plt.savefig(slice_directory.joinpath(f"{case}.png")) + plt.close() diff --git a/platipy/imaging/cnn/test_sampler.py b/platipy/imaging/cnn/test_sampler.py new file mode 100644 index 00000000..39a3c4eb --- /dev/null +++ b/platipy/imaging/cnn/test_sampler.py @@ -0,0 +1,27 @@ +import random + +from torch.utils.data import BatchSampler +from torch.utils.data import Sampler + + +class ObserverSampler(Sampler): + def __init__(self, data_source, num_observers): + self.data_source = data_source + self.num_observers = num_observers + + def __iter__(self): + indices = list(range(int(len(self.data_source) / self.num_observers))) + random.shuffle(indices) + for i in indices: + for o in range(self.num_observers): + yield i * self.num_observers + o + + def __len__(self): + return len(self.data_source) + + +print( + list( + BatchSampler(ObserverSampler(["x" for x in range(50)], 5), batch_size=10, drop_last=False) + ) +) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py new file mode 100644 index 
00000000..3aa0d881 --- /dev/null +++ b/platipy/imaging/cnn/train.py @@ -0,0 +1,150 @@ +from pathlib import Path + +import torch +import pytorch_lightning as pl + +from argparse import ArgumentParser + + +from platipy.imaging.cnn.prob_unet import ProbabilisticUnet +from platipy.imaging.cnn.unet import l2_regularisation +from platipy.imaging.cnn.dataset import NiftiDataset + + +class ProbUNet(pl.LightningModule): + def __init__(self): + super().__init__() + self.prob_unet = ProbabilisticUnet() + + def forward(self, x): + self.prob_unet.forward(x, None, False) + return x + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=1e-5, weight_decay=0) + return optimizer + + def training_step(self, batch, batch_idx): + x, y = batch + self.prob_unet.forward(x, y, training=True) + elbo = self.prob_unet.elbo(y, analytic_kl=True) + reg_loss = ( + l2_regularisation(self.prob_unet.posterior) + + l2_regularisation(self.prob_unet.prior) + + l2_regularisation(self.prob_unet.fcomb.layers) + ) + loss = -elbo["loss"] + 1e-5 * reg_loss + + print(loss) + + self.log("elbo_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + self.log( + "rec_loss", elbo["rec_loss"], on_step=True, on_epoch=True, prog_bar=True, logger=True + ) + self.log( + "kl_loss", elbo["kl_div"], on_step=True, on_epoch=True, prog_bar=True, logger=True + ) + + return loss + + def validation_step(self, batch, batch_idx): + # x, y = batch + # y_hat = self(x) + # val_loss = F.cross_entropy(y_hat, y) + val_loss = 1 + return val_loss + + +class ProbUNetDataModule(pl.LightningDataModule): + def __init__( + self, + data_dir: str = "./data", + working_dir: str = "./working", + train_cases=[], + validation_cases=[], + test_cases=[], + ): + super().__init__() + self.data_dir = Path(data_dir) + self.working_dir = Path(working_dir) + + self.train_cases = train_cases + self.validation_cases = validation_cases + + def prepare_data(self): + pass + + def setup(self, stage=None): + + 
train_data = [ + { + "id": case, + "image": self.data_dir.joinpath("images", f"{case}.nii.gz"), + "label": [p for p in self.data_dir.joinpath("labels").glob(f"{case}_*.nii.gz")], + } + for case in self.train_cases + ] + + validation_data = [ + { + "id": case, + "image": self.data_dir.joinpath("images", f"{case}.nii.gz"), + "label": [p for p in self.data_dir.joinpath("labels").glob(f"{case}_*.nii.gz")], + } + for case in self.validation_cases + ] + + self.training_set = NiftiDataset(train_data, self.working_dir.joinpath("train")) + print(len(self.training_set)) + self.validation_set = NiftiDataset( + validation_data, self.working_dir.joinpath("validation") + ) + + def train_dataloader(self): + return torch.utils.data.DataLoader( + self.training_set, + # batch_sampler=BatchSampler( + # ObserverSampler(train_set, 5), batch_size=params["batch_size"], drop_last=False + # ), + # num_workers=params["num_workers"], + batch_size=5, + shuffle=True, + num_workers=4, + ) + + def val_dataloader(self): + return torch.utils.data.DataLoader( + self.validation_set, + # batch_sampler=BatchSampler( + # ObserverSampler(train_set, 5), batch_size=params["batch_size"], drop_last=False + # ), + # num_workers=params["num_workers"], + batch_size=5, + shuffle=True, + num_workers=4, + ) + + +def main(args): + + pl.seed_everything(args.seed, workers=True) + + data_module = ProbUNetDataModule( + data_dir="./data", + working_dir="./working", + train_cases=[c for c in range(15)], + validation_cases=[c for c in range(15, 20)], + ) + + prob_unet = ProbUNet() + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit(prob_unet, data_module) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") + args = parser.parse_args() + + main(args) From 669eb572cbb92ed2f017fcc247af3ac053eb40ff Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 1 Jul 2021 00:00:04 
+0000 Subject: [PATCH 058/264] Work on migrating to pytorch lightning --- platipy/imaging/cnn/dataset.py | 13 +++-- platipy/imaging/cnn/prob_unet.py | 47 ++++++++------- platipy/imaging/cnn/pseudo_generator.py | 2 +- .../cnn/{test_sampler.py => sampler.py} | 7 --- platipy/imaging/cnn/train.py | 57 +++++++++++++------ 5 files changed, 78 insertions(+), 48 deletions(-) rename platipy/imaging/cnn/{test_sampler.py => sampler.py} (82%) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index cd265ea8..d64994fe 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -10,7 +10,7 @@ from loguru import logger -def preprocess_image(img, crop_to_mm=256): +def preprocess_image(img, crop_to_mm=128): img = sitk.Normalize(img) @@ -163,9 +163,14 @@ def __getitem__(self, index): img_file = self.slices[index]["image"] mask_file = self.slices[index]["mask"] z_slice = self.slices[index]["z"] - - img = sitk.GetArrayFromImage(sitk.ReadImage(str(img_file))[:, :, z_slice]) - mask = sitk.GetArrayFromImage(sitk.ReadImage(str(mask_file))[:, :, z_slice]) + + img = sitk.ReadImage(str(img_file)) + img = sitk.GetArrayFromImage(img) + img = img[z_slice, :, :] + + mask = sitk.ReadImage(str(mask_file)) + mask = sitk.GetArrayFromImage(mask) + mask = mask[z_slice, :, :] segmap = SegmentationMapsOnImage(mask, shape=mask.shape) img, mask = self.transforms(image=img, segmentation_maps=segmap) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index efda1973..ebae3b7e 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -241,17 +241,13 @@ def kl_divergence(self, analytic=True, z_posterior=None): kl_div = log_posterior_prob - log_prior_prob return kl_div - def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): - """ - Calculate the evidence lower bound of the log-likelihood of P(Y|X) - """ + def reconstruction_loss(self, segm, reconstruct_posterior_mean=False, 
z_posterior=None): criterion = torch.nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) - z_posterior = self.posterior_latent_space.rsample() - kl_div = torch.mean(self.kl_divergence(analytic=analytic_kl, z_posterior=z_posterior)) + if z_posterior is None: + z_posterior = self.posterior_latent_space.rsample() - # Here we use the posterior sample sampled above reconstruction = self.reconstruct( use_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior ) @@ -261,10 +257,25 @@ def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): segm = torch.cat((not_seg, segm), dim=1).float() reconstruction_loss = criterion(input=reconstruction, target=segm) reconstruction_loss = torch.sum(reconstruction_loss) - # mean_reconstruction_loss = torch.mean(reconstruction_loss) + + return reconstruction_loss + + def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): + """ + Calculate the evidence lower bound of the log-likelihood of P(Y|X) + """ + + z_posterior = self.posterior_latent_space.rsample() + + kl_div = torch.mean(self.kl_divergence(analytic=analytic_kl, z_posterior=z_posterior)) + + # Here we use the posterior sample sampled above + reconstruction_loss = self.reconstruction_loss( + segm, reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior + ) return { - "loss": -(reconstruction_loss + self.beta * kl_div), + "loss": reconstruction_loss + self.beta * kl_div, "rec_loss": reconstruction_loss, "kl_div": kl_div, } @@ -274,27 +285,23 @@ def geco(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, kappa=0 Calculate the evidence lower bound of the log-likelihood of P(Y|X) """ - criterion = torch.nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) z_posterior = self.posterior_latent_space.rsample() kl_div = torch.mean(self.kl_divergence(analytic=analytic_kl, z_posterior=z_posterior)) # Here we use the posterior sample sampled above - reconstruction = 
self.reconstruct( - use_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior + reconstruction_loss = self.reconstruction_loss( + segm, reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior ) - segm = torch.unsqueeze(segm, dim=1) - not_seg = segm.logical_not() - segm = torch.cat((not_seg, segm), dim=1).float() - reconstruction_loss = criterion(input=reconstruction, target=segm) - reconstruction_loss = torch.sum(reconstruction_loss) - - num_pixels = reconstruction.numel() + num_pixels = segm.numel() reconstruction_threshold = kappa * num_pixels rec_constraint = reconstruction_loss - reconstruction_threshold - loss = self._lambda * rec_constraint + kl_div + loss = ( + self._lambda * rec_constraint # pylint: disable=access-member-before-definition + + kl_div + ) with torch.no_grad(): if self._moving_avg is None: diff --git a/platipy/imaging/cnn/pseudo_generator.py b/platipy/imaging/cnn/pseudo_generator.py index 429dc4f2..434cdd2a 100644 --- a/platipy/imaging/cnn/pseudo_generator.py +++ b/platipy/imaging/cnn/pseudo_generator.py @@ -49,6 +49,6 @@ def generate_pseudo_data(data_dir="data"): masks[f"obs_{obs_id}_{obs_rad}"] = mask vis.add_contour(masks) - fig = vis.show() + vis.show() plt.savefig(slice_directory.joinpath(f"{case}.png")) plt.close() diff --git a/platipy/imaging/cnn/test_sampler.py b/platipy/imaging/cnn/sampler.py similarity index 82% rename from platipy/imaging/cnn/test_sampler.py rename to platipy/imaging/cnn/sampler.py index 39a3c4eb..43cddd22 100644 --- a/platipy/imaging/cnn/test_sampler.py +++ b/platipy/imaging/cnn/sampler.py @@ -18,10 +18,3 @@ def __iter__(self): def __len__(self): return len(self.data_source) - - -print( - list( - BatchSampler(ObserverSampler(["x" for x in range(50)], 5), batch_size=10, drop_last=False) - ) -) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 3aa0d881..4b58c5cc 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -1,4 +1,6 @@ 
from pathlib import Path +import numpy as np +from scipy.optimize import linear_sum_assignment import torch import pytorch_lightning as pl @@ -9,6 +11,7 @@ from platipy.imaging.cnn.prob_unet import ProbabilisticUnet from platipy.imaging.cnn.unet import l2_regularisation from platipy.imaging.cnn.dataset import NiftiDataset +from platipy.imaging.cnn.sampler import ObserverSampler class ProbUNet(pl.LightningModule): @@ -33,9 +36,7 @@ def training_step(self, batch, batch_idx): + l2_regularisation(self.prob_unet.prior) + l2_regularisation(self.prob_unet.fcomb.layers) ) - loss = -elbo["loss"] + 1e-5 * reg_loss - - print(loss) + loss = elbo["loss"] + 1e-5 * reg_loss self.log("elbo_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) self.log( @@ -48,11 +49,38 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): - # x, y = batch - # y_hat = self(x) - # val_loss = F.cross_entropy(y_hat, y) - val_loss = 1 - return val_loss + with torch.set_grad_enabled(False): + x, y = batch + self.prob_unet.forward(x) + sample = self.prob_unet.sample(testing=True) + mean = self.prob_unet.sample(testing=True, use_mean=True) + + criterion = torch.nn.BCEWithLogitsLoss( + size_average=False, reduce=False, reduction=None + ) + + onehot = torch.nn.functional.one_hot(y, 2).transpose(1, 3).float() + + observers = y.shape[0] + sim_matrix = np.zeros((observers, observers)) + for i in range(observers): + for j in range(observers): + rec_loss = criterion(input=sample[i], target=onehot[j]) + rec_loss = torch.sum(rec_loss) + sim_matrix[i, j] = rec_loss.item() + + row_idx, col_idx = linear_sum_assignment(sim_matrix) + + matched_val = sim_matrix[row_idx, col_idx].mean() + self.log( + "matched_val", matched_val, on_step=True, on_epoch=True, prog_bar=True, logger=True + ) + + mean_val = criterion(input=mean, target=onehot) + mean_val = torch.sum(rec_loss).item() + self.log("mean_val", mean_val, on_step=True, on_epoch=True, prog_bar=True, 
logger=True) + + return matched_val class ProbUNetDataModule(pl.LightningDataModule): @@ -115,12 +143,9 @@ def train_dataloader(self): def val_dataloader(self): return torch.utils.data.DataLoader( self.validation_set, - # batch_sampler=BatchSampler( - # ObserverSampler(train_set, 5), batch_size=params["batch_size"], drop_last=False - # ), - # num_workers=params["num_workers"], - batch_size=5, - shuffle=True, + batch_sampler=torch.utils.data.BatchSampler( + ObserverSampler(self.validation_set, 5), batch_size=5, drop_last=False + ), num_workers=4, ) @@ -132,8 +157,8 @@ def main(args): data_module = ProbUNetDataModule( data_dir="./data", working_dir="./working", - train_cases=[c for c in range(15)], - validation_cases=[c for c in range(15, 20)], + train_cases=[c for c in range(2)], + validation_cases=[c for c in range(15, 16)], ) prob_unet = ProbUNet() From cd53ecc5bc535fa8434fe3e21efe6a43936414ab Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 1 Jul 2021 06:59:14 +0000 Subject: [PATCH 059/264] Adding CometML to pytorch lightning --- platipy/imaging/cnn/prob_unet.py | 85 ++++++++++---------- platipy/imaging/cnn/train.py | 132 +++++++++++++++++++++++++------ 2 files changed, 147 insertions(+), 70 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index ebae3b7e..827e373a 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -153,14 +153,14 @@ def __init__( filters_per_layer=[64 * (2 ** x) for x in range(5)], latent_dim=6, no_convs_fcomb=4, - beta=1.0, + loss_type="elbo", + loss_params={"beta": 1}, ): super(ProbabilisticUnet, self).__init__() self.no_convs_per_block = 3 self.no_convs_fcomb = no_convs_fcomb self.initializers = {"w": "he_normal", "b": "normal"} - self.beta = beta self.z_prior_sample = 0 self.unet = UNet(input_channels, num_classes, filters_per_layer, final_layer=False) @@ -168,6 +168,9 @@ def __init__( self.posterior = AxisAlignedConvGaussian(input_channels + 1, 
filters_per_layer, latent_dim) self.fcomb = Fcomb(filters_per_layer, latent_dim, num_classes, no_convs_fcomb) + self.loss_type = loss_type + self.loss_params = loss_params + self.posterior_latent_space = None self.prior_latent_space = None self.unet_features = None @@ -256,11 +259,11 @@ def reconstruction_loss(self, segm, reconstruct_posterior_mean=False, z_posterio not_seg = segm.logical_not() segm = torch.cat((not_seg, segm), dim=1).float() reconstruction_loss = criterion(input=reconstruction, target=segm) - reconstruction_loss = torch.sum(reconstruction_loss) + reconstruction_loss = torch.mean(reconstruction_loss) return reconstruction_loss - def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): + def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): """ Calculate the evidence lower bound of the log-likelihood of P(Y|X) """ @@ -274,48 +277,42 @@ def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): segm, reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior ) - return { - "loss": reconstruction_loss + self.beta * kl_div, - "rec_loss": reconstruction_loss, - "kl_div": kl_div, - } + if self.loss_type == "elbo": - def geco(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, kappa=0.02): - """ - Calculate the evidence lower bound of the log-likelihood of P(Y|X) - """ + return { + "loss": reconstruction_loss + self.loss_params["beta"] * kl_div, + "rec_loss": reconstruction_loss, + "kl_div": kl_div, + } + elif self.loss_type == "geco": - z_posterior = self.posterior_latent_space.rsample() + num_pixels = segm.numel() + reconstruction_threshold = self.loss_params["kappa"] # * num_pixels + rec_constraint = reconstruction_loss - reconstruction_threshold - kl_div = torch.mean(self.kl_divergence(analytic=analytic_kl, z_posterior=z_posterior)) - - # Here we use the posterior sample sampled above - reconstruction_loss = self.reconstruction_loss( - segm, 
reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior - ) - - num_pixels = segm.numel() - reconstruction_threshold = kappa * num_pixels - rec_constraint = reconstruction_loss - reconstruction_threshold + loss = ( + self._lambda * rec_constraint # pylint: disable=access-member-before-definition + + kl_div + ) - loss = ( - self._lambda * rec_constraint # pylint: disable=access-member-before-definition - + kl_div - ) + with torch.no_grad(): + if self._moving_avg is None: + self._moving_avg = rec_constraint.detach() + else: + self._moving_avg = self._moving_avg * 0.5 + rec_constraint.detach() * (1 - 0.5) + speed = 1 + self._lambda = (speed * torch.exp(self._moving_avg) * self._lambda).clamp( + self.loss_params["clamp_rec"][0], self.loss_params["clamp_rec"][1] + ) + return { + "loss": loss, + "rec_loss": reconstruction_loss, + "kl_div": kl_div, + "lambda": self._lambda, + "moving_avg": self._moving_avg, + "reconstruction_threshold": reconstruction_threshold, + "rec_constraint": rec_constraint, + } - with torch.no_grad(): - if self._moving_avg is None: - self._moving_avg = rec_constraint.detach() - else: - self._moving_avg = self._moving_avg * 0.5 + rec_constraint.detach() * (1 - 0.5) - speed = 1 - self._lambda = (speed * torch.exp(self._moving_avg) * self._lambda).clamp(1e-5, 1e5) - return { - "loss": -loss, - "rec_loss": reconstruction_loss, - "kl_div": kl_div, - "lambda": self._lambda, - "moving_avg": self._moving_avg, - "reconstruction_threshold": reconstruction_threshold, - "rec_constraint": rec_constraint, - } + else: + raise NotImplementedError("Loss must be 'elbo' or 'geco'") diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 4b58c5cc..c858779f 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -1,13 +1,18 @@ +import os +import math + from pathlib import Path import numpy as np from scipy.optimize import linear_sum_assignment +from comet_ml import Experiment +from 
pytorch_lightning.loggers import CometLogger + import torch import pytorch_lightning as pl from argparse import ArgumentParser - from platipy.imaging.cnn.prob_unet import ProbabilisticUnet from platipy.imaging.cnn.unet import l2_regularisation from platipy.imaging.cnn.dataset import NiftiDataset @@ -15,37 +20,79 @@ class ProbUNet(pl.LightningModule): - def __init__(self): + def __init__( + self, + **kwargs, + ): super().__init__() - self.prob_unet = ProbabilisticUnet() + + self.save_hyperparameters() + + loss_params = {"beta": self.hparams.beta} + + if self.hparams.loss_type == "geco": + loss_params = {"kappa": self.hparams.kappa, "clamp_rec": self.hparams.clamp_rec} + + self.prob_unet = ProbabilisticUnet( + self.hparams.input_channels, + self.hparams.num_classes, + self.hparams.filters_per_layer, + self.hparams.latent_dim, + self.hparams.no_convs_fcomb, + self.hparams.loss_type, + loss_params, + ) + + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("Probabilistic UNet") + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--input_channels", type=int, default=1) + parser.add_argument("--num_classes", type=int, default=2) + parser.add_argument( + "--filters_per_layer", nargs="+", type=int, default=[64 * (2 ** x) for x in range(5)] + ) + parser.add_argument("--latent_dim", type=int, default=6) + parser.add_argument("--no_convs_fcomb", type=int, default=4) + parser.add_argument("--loss_type", type=str, default="elbo") + parser.add_argument("--beta", type=float, default=1.0) + parser.add_argument("--kappa", type=float, default=0.02) + parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) + return parent_parser def forward(self, x): self.prob_unet.forward(x, None, False) return x def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=1e-5, weight_decay=0) + optimizer = torch.optim.Adam( + self.parameters(), 
lr=self.hparams.learning_rate, weight_decay=0 + ) return optimizer def training_step(self, batch, batch_idx): x, y = batch self.prob_unet.forward(x, y, training=True) - elbo = self.prob_unet.elbo(y, analytic_kl=True) + loss = self.prob_unet.loss(y, analytic_kl=True) reg_loss = ( l2_regularisation(self.prob_unet.posterior) + l2_regularisation(self.prob_unet.prior) + l2_regularisation(self.prob_unet.fcomb.layers) ) - loss = elbo["loss"] + 1e-5 * reg_loss - - self.log("elbo_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) - self.log( - "rec_loss", elbo["rec_loss"], on_step=True, on_epoch=True, prog_bar=True, logger=True - ) + training_loss = loss["loss"] + 1e-5 * reg_loss self.log( - "kl_loss", elbo["kl_div"], on_step=True, on_epoch=True, prog_bar=True, logger=True + "training_loss", training_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True ) + for k in loss: + self.log( + k, + loss[k], + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) return loss def validation_step(self, batch, batch_idx): @@ -88,22 +135,43 @@ def __init__( self, data_dir: str = "./data", working_dir: str = "./working", - train_cases=[], - validation_cases=[], - test_cases=[], + fold=0, + k_folds=5, + **kwargs, ): super().__init__() self.data_dir = Path(data_dir) self.working_dir = Path(working_dir) - self.train_cases = train_cases - self.validation_cases = validation_cases + self.fold = fold + self.k_folds = k_folds + + self.train_cases = [] + self.validation_cases = [] + + print(f"Training fold {self.fold}") - def prepare_data(self): - pass + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("Data Loader") + parser.add_argument("--data_dir", type=str, default="./data") + parser.add_argument("--fold", type=int, default=0) + parser.add_argument("--k_folds", type=int, default=5) + + return parent_parser def setup(self, stage=None): + cases = [p.name.replace(".nii.gz", "") for p in 
self.data_dir.joinpath("images").glob("*")] + cases.sort() + cases_per_fold = math.ceil(len(cases) / self.k_folds) + for f in range(self.k_folds): + + if self.fold == f: + self.validation_cases = cases[f * cases_per_fold : (f + 1) * cases_per_fold] + else: + self.train_cases += cases[f * cases_per_fold : (f + 1) * cases_per_fold] + train_data = [ { "id": case, @@ -123,7 +191,6 @@ def setup(self, stage=None): ] self.training_set = NiftiDataset(train_data, self.working_dir.joinpath("train")) - print(len(self.training_set)) self.validation_set = NiftiDataset( validation_data, self.working_dir.joinpath("validation") ) @@ -154,22 +221,35 @@ def main(args): pl.seed_everything(args.seed, workers=True) - data_module = ProbUNetDataModule( - data_dir="./data", - working_dir="./working", - train_cases=[c for c in range(2)], - validation_cases=[c for c in range(15, 16)], + args.working_dir = Path(args.working_dir) + args.working_dir = args.working_dir.joinpath(args.experiment) + + comet_logger = CometLogger( + api_key=os.environ["COMET_API_KEY"], + workspace=os.environ["COMET_WORKSPACE"], + project_name=os.environ["COMET_PROJECT"], + experiment_name=args.experiment, + save_dir=args.working_dir, ) - prob_unet = ProbUNet() + dict_args = vars(args) + + data_module = ProbUNetDataModule(**dict_args) + + prob_unet = ProbUNet(**dict_args) trainer = pl.Trainer.from_argparse_args(args) + trainer.logger = comet_logger trainer.fit(prob_unet, data_module) if __name__ == "__main__": parser = ArgumentParser() + parser = ProbUNet.add_model_specific_args(parser) + parser = ProbUNetDataModule.add_model_specific_args(parser) parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") + parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") + parser.add_argument("--working_dir", type=str, default="./working") args = parser.parse_args() main(args) From 
1e322584a3e857dc7d475fabccfefa561b081668 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 1 Jul 2021 23:56:13 +0000 Subject: [PATCH 060/264] Correction to param passing --- platipy/imaging/cnn/train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index c858779f..848e3e05 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -5,7 +5,7 @@ import numpy as np from scipy.optimize import linear_sum_assignment -from comet_ml import Experiment +import comet_ml from pytorch_lightning.loggers import CometLogger import torch @@ -28,7 +28,10 @@ def __init__( self.save_hyperparameters() - loss_params = {"beta": self.hparams.beta} + loss_params = None + + if self.hparams.loss_type == "elbo": + loss_params = {"beta": self.hparams.beta} if self.hparams.loss_type == "geco": loss_params = {"kappa": self.hparams.kappa, "clamp_rec": self.hparams.clamp_rec} From a8d774006686a88025524906694dbbcaee21574a Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 1 Jul 2021 23:57:04 +0000 Subject: [PATCH 061/264] Able to gen pseudo data directly --- platipy/imaging/cnn/pseudo_generator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/platipy/imaging/cnn/pseudo_generator.py b/platipy/imaging/cnn/pseudo_generator.py index 434cdd2a..bf3a78f2 100644 --- a/platipy/imaging/cnn/pseudo_generator.py +++ b/platipy/imaging/cnn/pseudo_generator.py @@ -52,3 +52,7 @@ def generate_pseudo_data(data_dir="data"): vis.show() plt.savefig(slice_directory.joinpath(f"{case}.png")) plt.close() + + +if __name__ == "__main__": + generate_pseudo_data() From 0cead6f28c6f588070e87b3e87d1c9fd59bdeb65 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 2 Jul 2021 00:16:46 +0000 Subject: [PATCH 062/264] Allow configuration of batch size --- platipy/imaging/cnn/train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py 
b/platipy/imaging/cnn/train.py index 848e3e05..324bb7e0 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -140,6 +140,7 @@ def __init__( working_dir: str = "./working", fold=0, k_folds=5, + batch_size=5, **kwargs, ): super().__init__() @@ -152,6 +153,8 @@ def __init__( self.train_cases = [] self.validation_cases = [] + self.batch_size = batch_size + print(f"Training fold {self.fold}") @staticmethod @@ -160,6 +163,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--data_dir", type=str, default="./data") parser.add_argument("--fold", type=int, default=0) parser.add_argument("--k_folds", type=int, default=5) + parser.add_argument("--batch_size", type=int, default=5) return parent_parser @@ -205,7 +209,7 @@ def train_dataloader(self): # ObserverSampler(train_set, 5), batch_size=params["batch_size"], drop_last=False # ), # num_workers=params["num_workers"], - batch_size=5, + batch_size=self.batch_size, shuffle=True, num_workers=4, ) From 53cf69f333f8c01fb500e696ab3d3dd6e4566e09 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 2 Jul 2021 00:22:41 +0000 Subject: [PATCH 063/264] Allow config num workers --- platipy/imaging/cnn/train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 324bb7e0..696ab0ca 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -141,6 +141,7 @@ def __init__( fold=0, k_folds=5, batch_size=5, + num_workers=4, **kwargs, ): super().__init__() @@ -154,6 +155,7 @@ def __init__( self.validation_cases = [] self.batch_size = batch_size + self.num_workers = num_workers print(f"Training fold {self.fold}") @@ -164,6 +166,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--fold", type=int, default=0) parser.add_argument("--k_folds", type=int, default=5) parser.add_argument("--batch_size", type=int, default=5) + parser.add_argument("--num_workers", type=int, default=4) 
return parent_parser @@ -211,7 +214,7 @@ def train_dataloader(self): # num_workers=params["num_workers"], batch_size=self.batch_size, shuffle=True, - num_workers=4, + num_workers=self.num_workers, ) def val_dataloader(self): @@ -220,7 +223,7 @@ def val_dataloader(self): batch_sampler=torch.utils.data.BatchSampler( ObserverSampler(self.validation_set, 5), batch_size=5, drop_last=False ), - num_workers=4, + num_workers=self.num_workers, ) From 8ae7a5e3e0721a6376e01abb28b4909b98f6235e Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 2 Jul 2021 01:22:52 +0000 Subject: [PATCH 064/264] Save slices as npy --- platipy/imaging/cnn/dataset.py | 69 +++++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index d64994fe..7e101e59 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -1,5 +1,7 @@ from pathlib import Path +import numpy as np + import torch import SimpleITK as sitk @@ -125,6 +127,8 @@ def __init__(self, data, working_dir): self.img_dir = working_dir.joinpath("img") self.mask_dir = working_dir.joinpath("mask") + self.data_exists = self.img_dir.exists() + self.img_dir.mkdir(exist_ok=True, parents=True) self.mask_dir.mkdir(exist_ok=True, parents=True) @@ -136,24 +140,58 @@ def __init__(self, data, working_dir): if isinstance(structure_paths, (str, Path)): structure_paths = [structure_paths] - img_file = self.img_dir.joinpath(f"{case_id}.nii.gz") + existing_images = [i for i in self.img_dir.glob(f"{case_id}_*.npy")] + if len(existing_images) > 0: + logger.debug(f"Image for case already exist: {case_id}") + + for z_slice in range(len(existing_images)): + img_file = self.img_dir.joinpath(f"{case_id}_{z_slice}.npy") + + for obs in range(len(structure_paths)): + mask_file = self.mask_dir.joinpath(f"{case_id}_{obs}_{z_slice}.npy") + self.slices.append( + { + "z": z_slice, + "image": img_file, + "mask": mask_file, + "case": case_id, 
+ "observer": obs, + } + ) + + continue + img = sitk.ReadImage(img_path) - if not img_file.exists(): - img = preprocess_image(img) - sitk.WriteImage(img, str(img_file)) + img = preprocess_image(img) - for obs, structure_path in enumerate(structure_paths): - structure_path = str(structure_path) - mask = sitk.ReadImage(structure_path) - mask = resample_mask_to_image(img, mask) - mask_file = self.mask_dir.joinpath(f"{case_id}_{obs}.nii.gz") - sitk.WriteImage(mask, str(mask_file)) + observers = [] + for obs, structure_path in enumerate(structure_paths): + structure_path = str(structure_path) + mask = sitk.ReadImage(structure_path) + mask = resample_mask_to_image(img, mask) + observers.append(mask) for z_slice in range(img.GetSize()[2]): - for obs, mask in enumerate(structure_paths): - mask_file = self.mask_dir.joinpath(f"{case_id}_{obs}.nii.gz") - self.slices.append({"z": z_slice, "image": img_file, "mask": mask_file}) + + img_slice = img[:, :, z_slice] + img_file = self.img_dir.joinpath(f"{case_id}_{z_slice}.npy") + np.save(img_file, sitk.GetArrayFromImage(img_slice)) + + for obs, mask in enumerate(observers): + + mask_slice = mask[:, :, z_slice] + mask_file = self.mask_dir.joinpath(f"{case_id}_{obs}_{z_slice}.npy") + np.save(mask_file, sitk.GetArrayFromImage(mask_slice)) + self.slices.append( + { + "z": z_slice, + "image": img_file, + "mask": mask_file, + "case": case_id, + "observer": obs, + } + ) def __len__(self): return len(self.slices) @@ -162,15 +200,12 @@ def __getitem__(self, index): img_file = self.slices[index]["image"] mask_file = self.slices[index]["mask"] - z_slice = self.slices[index]["z"] - + img = sitk.ReadImage(str(img_file)) img = sitk.GetArrayFromImage(img) - img = img[z_slice, :, :] mask = sitk.ReadImage(str(mask_file)) mask = sitk.GetArrayFromImage(mask) - mask = mask[z_slice, :, :] segmap = SegmentationMapsOnImage(mask, shape=mask.shape) img, mask = self.transforms(image=img, segmentation_maps=segmap) From 
6fff9e6271fee5fb08ff255a9b51a7c4c445a33c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 2 Jul 2021 01:25:53 +0000 Subject: [PATCH 065/264] correct reading slice --- platipy/imaging/cnn/dataset.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 7e101e59..af74bd24 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -198,14 +198,8 @@ def __len__(self): def __getitem__(self, index): - img_file = self.slices[index]["image"] - mask_file = self.slices[index]["mask"] - - img = sitk.ReadImage(str(img_file)) - img = sitk.GetArrayFromImage(img) - - mask = sitk.ReadImage(str(mask_file)) - mask = sitk.GetArrayFromImage(mask) + img = np.load(self.slices[index]["image"]) + mask = np.load(self.slices[index]["mask"]) segmap = SegmentationMapsOnImage(mask, shape=mask.shape) img, mask = self.transforms(image=img, segmentation_maps=segmap) From e0d154876c88e15a591ea36f106427dd281cdec3 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 2 Jul 2021 05:31:51 +0000 Subject: [PATCH 066/264] Working on validation step --- platipy/imaging/cnn/dataset.py | 10 ++- platipy/imaging/cnn/train.py | 149 +++++++++++++++++++++++++-------- 2 files changed, 122 insertions(+), 37 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index af74bd24..98fd0abf 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -208,4 +208,12 @@ def __getitem__(self, index): img = torch.FloatTensor(img) mask = torch.LongTensor(mask) - return img.unsqueeze(0), mask + return ( + img.unsqueeze(0), + mask, + { + "case": self.slices[index]["case"], + "observer": self.slices[index]["observer"], + "z": self.slices[index]["z"], + }, + ) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 696ab0ca..97e6ff61 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -1,7 +1,10 
@@ import os import math +import tempfile +import shutil from pathlib import Path +import SimpleITK as sitk import numpy as np from scipy.optimize import linear_sum_assignment @@ -13,6 +16,8 @@ from argparse import ArgumentParser +from torch._C import NoneType + from platipy.imaging.cnn.prob_unet import ProbabilisticUnet from platipy.imaging.cnn.unet import l2_regularisation from platipy.imaging.cnn.dataset import NiftiDataset @@ -46,6 +51,8 @@ def __init__( loss_params, ) + self.validation_directory = None + @staticmethod def add_model_specific_args(parent_parser): parser = parent_parser.add_argument_group("Probabilistic UNet") @@ -74,7 +81,7 @@ def configure_optimizers(self): return optimizer def training_step(self, batch, batch_idx): - x, y = batch + x, y, _ = batch self.prob_unet.forward(x, y, training=True) loss = self.prob_unet.loss(y, analytic_kl=True) reg_loss = ( @@ -99,38 +106,108 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): + + if self.validation_directory is None: + self.validation_directory = Path(tempfile.mkdtemp()) + print(self.validation_directory) + with torch.set_grad_enabled(False): - x, y = batch - self.prob_unet.forward(x) - sample = self.prob_unet.sample(testing=True) - mean = self.prob_unet.sample(testing=True, use_mean=True) + x, y, info = batch - criterion = torch.nn.BCEWithLogitsLoss( - size_average=False, reduce=False, reduction=None - ) + for s in range(y.shape[0]): - onehot = torch.nn.functional.one_hot(y, 2).transpose(1, 3).float() + img_file = self.validation_directory.joinpath( + f"img_{info['case'][s]}_{info['z'][s]}.npy" + ) + np.save(img_file, x[0].unsqueeze(0).numpy()) - observers = y.shape[0] - sim_matrix = np.zeros((observers, observers)) - for i in range(observers): - for j in range(observers): - rec_loss = criterion(input=sample[i], target=onehot[j]) - rec_loss = torch.sum(rec_loss) - sim_matrix[i, j] = rec_loss.item() + mask_file = self.validation_directory.joinpath( 
+ f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" + ) + np.save(mask_file, y[0].unsqueeze(0).numpy()) - row_idx, col_idx = linear_sum_assignment(sim_matrix) + self.prob_unet.forward(x[s].unsqueeze(0)) + sample = self.prob_unet.sample(testing=True) + sample_file = self.validation_directory.joinpath( + f"sample_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" + ) + np.save(sample_file, sample.numpy()) - matched_val = sim_matrix[row_idx, col_idx].mean() - self.log( - "matched_val", matched_val, on_step=True, on_epoch=True, prog_bar=True, logger=True - ) + mean = self.prob_unet.sample(testing=True, use_mean=True) + mean_file = self.validation_directory.joinpath( + f"mean_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" + ) + np.save(mean_file, mean.numpy()) + + def validation_epoch_end(self, validation_step_outputs): + + print(validation_step_outputs) + + cases = {} + for mask in self.validation_directory.glob("mask_*.nii.gz"): + + parts = mask.name.replace(".nii.gz", "").split("_") + case = parts[0] + z = parts[1] + observer = parts[2] + + if not case in cases: + cases[case] = {"slices": z, "observers": [observer]} + else: + if z > cases[case]["slices"]: + cases[case]["slices"] = z + cases[case]["observers"].append(observer) + + for case in cases: + + img_arrs = [] + for z in range(cases[case]["slices"] + 1): + img_file = self.validation_directory.joinpath(f"img_{case}_{z}.npy") + img_arrs.append(np.load(img_file)) + + img_arr = np.stack(img_arrs) + + sitk.WriteImage(sitk.GetImageFromArray(img_arr), f"test_{case}.nii.gz") + + # for pred in validation_step_outputs: + # # do something with a pred + # print(pred) + + # break + + # criterion = torch.nn.BCEWithLogitsLoss( + # size_average=False, reduce=False, reduction=None + # ) + + # onehot = torch.nn.functional.one_hot(y, 2).transpose(1, 3).float() + + # observers = y.shape[0] + # sim_matrix = np.zeros((observers, observers)) + # for i in range(observers): + # for j in 
range(observers): + # rec_loss = criterion(input=sample[i], target=onehot[j]) + # rec_loss = torch.sum(rec_loss) + # sim_matrix[i, j] = rec_loss.item() + + # row_idx, col_idx = linear_sum_assignment(sim_matrix) + + # matched_val = sim_matrix[row_idx, col_idx].mean() + # self.log( + # "matched_val", + # matched_val, + # on_step=False, + # on_epoch=True, + # prog_bar=False, + # logger=True, + # ) - mean_val = criterion(input=mean, target=onehot) - mean_val = torch.sum(rec_loss).item() - self.log("mean_val", mean_val, on_step=True, on_epoch=True, prog_bar=True, logger=True) + # mean_val = criterion(input=mean, target=onehot) + # mean_val = torch.sum(rec_loss).item() + # self.log( + # "mean_val", mean_val, on_step=False, on_epoch=True, prog_bar=False, logger=True + # ) - return matched_val + # shutil.rmtree(self.validation_directory) class ProbUNetDataModule(pl.LightningDataModule): @@ -240,6 +317,7 @@ def main(args): project_name=os.environ["COMET_PROJECT"], experiment_name=args.experiment, save_dir=args.working_dir, + offline=True, ) dict_args = vars(args) @@ -249,17 +327,16 @@ def main(args): prob_unet = ProbUNet(**dict_args) trainer = pl.Trainer.from_argparse_args(args) trainer.logger = comet_logger - trainer.fit(prob_unet, data_module) + trainer.fit(prob_unet, data_module) # pylint: disable=no-member if __name__ == "__main__": - parser = ArgumentParser() - parser = ProbUNet.add_model_specific_args(parser) - parser = ProbUNetDataModule.add_model_specific_args(parser) - parser = pl.Trainer.add_argparse_args(parser) - parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") - parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") - parser.add_argument("--working_dir", type=str, default="./working") - args = parser.parse_args() - - main(args) + arg_parser = ArgumentParser() + arg_parser = ProbUNet.add_model_specific_args(arg_parser) + arg_parser = ProbUNetDataModule.add_model_specific_args(arg_parser) + 
arg_parser = pl.Trainer.add_argparse_args(arg_parser) + arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") + arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") + arg_parser.add_argument("--working_dir", type=str, default="./working") + + main(arg_parser.parse_args()) From e4d99c8276d98deafc1936fc66468dbce2c72298 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 2 Jul 2021 17:24:02 +1000 Subject: [PATCH 067/264] Working on validation --- platipy/imaging/cnn/dataset.py | 13 +++-- platipy/imaging/cnn/train.py | 98 ++++++++++++++++++++++++++-------- 2 files changed, 85 insertions(+), 26 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 98fd0abf..23e1f74a 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -109,7 +109,7 @@ def prepare_transforms(): class NiftiDataset(torch.utils.data.Dataset): """PyTorch Dataset for processing Nifti data""" - def __init__(self, data, working_dir): + def __init__(self, data, working_dir, augment_on_the_fly=True): """Prepare a dataset from Nifti images/labels Args: @@ -120,7 +120,9 @@ def __init__(self, data, working_dir): """ self.data = data - self.transforms = prepare_transforms() + self.transforms = None + if augment_on_the_fly: + self.transforms = prepare_transforms() self.slices = [] self.working_dir = Path(working_dir) @@ -201,9 +203,10 @@ def __getitem__(self, index): img = np.load(self.slices[index]["image"]) mask = np.load(self.slices[index]["mask"]) - segmap = SegmentationMapsOnImage(mask, shape=mask.shape) - img, mask = self.transforms(image=img, segmentation_maps=segmap) - mask = mask.get_arr() + if self.transforms: + segmap = SegmentationMapsOnImage(mask, shape=mask.shape) + img, mask = self.transforms(image=img, segmentation_maps=segmap) + mask = mask.get_arr() img = torch.FloatTensor(img) mask = torch.LongTensor(mask) diff --git 
a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 97e6ff61..ed7319e8 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -23,6 +23,8 @@ from platipy.imaging.cnn.dataset import NiftiDataset from platipy.imaging.cnn.sampler import ObserverSampler +from platipy.imaging import ImageVisualiser +from platipy.imaging.label.utils import get_com class ProbUNet(pl.LightningModule): def __init__( @@ -119,55 +121,109 @@ def validation_step(self, batch, batch_idx): img_file = self.validation_directory.joinpath( f"img_{info['case'][s]}_{info['z'][s]}.npy" ) - np.save(img_file, x[0].unsqueeze(0).numpy()) + np.save(img_file, x[0].squeeze(0).cpu().numpy()) mask_file = self.validation_directory.joinpath( f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" ) - np.save(mask_file, y[0].unsqueeze(0).numpy()) + np.save(mask_file, y[s].squeeze(0).cpu().numpy()) self.prob_unet.forward(x[s].unsqueeze(0)) sample = self.prob_unet.sample(testing=True) sample_file = self.validation_directory.joinpath( f"sample_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" ) - np.save(sample_file, sample.numpy()) + sample = np.argmax(sample.squeeze(0).cpu().numpy(), axis=0) + np.save(sample_file, sample) mean = self.prob_unet.sample(testing=True, use_mean=True) mean_file = self.validation_directory.joinpath( - f"mean_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" + f"mean_{info['case'][s]}_{info['z'][s]}.npy" ) - np.save(mean_file, mean.numpy()) + mean = np.argmax(mean.squeeze(0).cpu().numpy(), axis=0) + np.save(mean_file, mean) + + return info def validation_epoch_end(self, validation_step_outputs): - print(validation_step_outputs) + print(self.validation_directory) cases = {} - for mask in self.validation_directory.glob("mask_*.nii.gz"): - - parts = mask.name.replace(".nii.gz", "").split("_") - case = parts[0] - z = parts[1] - observer = parts[2] + for info in validation_step_outputs: - if not case in cases: - cases[case] = 
{"slices": z, "observers": [observer]} - else: - if z > cases[case]["slices"]: - cases[case]["slices"] = z - cases[case]["observers"].append(observer) + for case, z, observer in zip(info["case"], info["z"], info["observer"]): + if not case in cases: + cases[case] = {"slices": z.item(), "observers": [observer.item()]} + else: + if z.item() > cases[case]["slices"]: + cases[case]["slices"] = z.item() + if not observer in cases[case]["observers"]: + cases[case]["observers"].append(observer.item()) for case in cases: img_arrs = [] + mean_arrs = [] + slices = [] for z in range(cases[case]["slices"] + 1): img_file = self.validation_directory.joinpath(f"img_{case}_{z}.npy") - img_arrs.append(np.load(img_file)) + mean_file = self.validation_directory.joinpath(f"mean_{case}_{z}.npy") + if img_file.exists(): + img_arrs.append(np.load(img_file)) + mean_arrs.append(np.load(mean_file)) + slices.append(z) + + if len(slices) < 5: + # Likely initial sanity check + continue img_arr = np.stack(img_arrs) + print(img_arr.min()) + print(img_arr.max()) + img = sitk.GetImageFromArray(img_arr) + sitk.WriteImage(img, f"test_{case}.nii.gz") + + mean_arr = np.stack(mean_arrs) + mean = sitk.GetImageFromArray(mean_arr) + mean = sitk.Cast(mean, sitk.sitkUInt8) + sitk.WriteImage(mean, f"test_mean_{case}_{observer}.nii.gz") + + obs_dict = {} + pred_dict = {} + for observer in cases[case]["observers"]: + mask_arrs = [] + sample_arrs = [] + for z in slices: + mask_file = self.validation_directory.joinpath(f"mask_{case}_{z}_{observer}.npy") + sample_file = self.validation_directory.joinpath(f"sample_{case}_{z}_{observer}.npy") + + mask_arrs.append(np.load(mask_file)) + sample_arrs.append(np.load(sample_file)) + + mask_arr = np.stack(mask_arrs) + mask = sitk.GetImageFromArray(mask_arr) + mask = sitk.Cast(mask, sitk.sitkUInt8) + sitk.WriteImage(mask, f"test_mask_{case}_{observer}.nii.gz") + obs_dict[f"manual_{observer}"] = mask + + sample_arr = np.stack(sample_arrs) + sample = 
sitk.GetImageFromArray(sample_arr) + sample = sitk.Cast(sample, sitk.sitkUInt8) + sitk.WriteImage(sample, f"test_sample_{case}_{observer}.nii.gz") + pred_dict[f"auto_{observer}"] = sample + + img_vis = ImageVisualiser(img, cut=get_com(mask), figure_size_in=16, window=[img_arr.min(), img_arr.max()]) + + # color_dict = {str(i): [0.5, 0.5, 0.5] for i, m in enumerate(observers)} + contour_dict = {**obs_dict, **pred_dict} + + img_vis.add_contour(contour_dict)#, color=color_dict) + fig = img_vis.show() + figure_path = f"valid_{case}.png" + fig.savefig(figure_path, dpi=300) + - sitk.WriteImage(sitk.GetImageFromArray(img_arr), f"test_{case}.nii.gz") # for pred in validation_step_outputs: # # do something with a pred @@ -279,7 +335,7 @@ def setup(self, stage=None): self.training_set = NiftiDataset(train_data, self.working_dir.joinpath("train")) self.validation_set = NiftiDataset( - validation_data, self.working_dir.joinpath("validation") + validation_data, self.working_dir.joinpath("validation"), False ) def train_dataloader(self): From e20d4b1fa5a8c366937fc76f632fae02fb0f9b2e Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 2 Jul 2021 07:28:44 +0000 Subject: [PATCH 068/264] Log images --- platipy/imaging/cnn/train.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index ed7319e8..1f4386df 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -26,6 +26,7 @@ from platipy.imaging import ImageVisualiser from platipy.imaging.label.utils import get_com + class ProbUNet(pl.LightningModule): def __init__( self, @@ -179,8 +180,6 @@ def validation_epoch_end(self, validation_step_outputs): continue img_arr = np.stack(img_arrs) - print(img_arr.min()) - print(img_arr.max()) img = sitk.GetImageFromArray(img_arr) sitk.WriteImage(img, f"test_{case}.nii.gz") @@ -195,8 +194,12 @@ def validation_epoch_end(self, validation_step_outputs): mask_arrs = [] 
sample_arrs = [] for z in slices: - mask_file = self.validation_directory.joinpath(f"mask_{case}_{z}_{observer}.npy") - sample_file = self.validation_directory.joinpath(f"sample_{case}_{z}_{observer}.npy") + mask_file = self.validation_directory.joinpath( + f"mask_{case}_{z}_{observer}.npy" + ) + sample_file = self.validation_directory.joinpath( + f"sample_{case}_{z}_{observer}.npy" + ) mask_arrs.append(np.load(mask_file)) sample_arrs.append(np.load(sample_file)) @@ -213,17 +216,19 @@ def validation_epoch_end(self, validation_step_outputs): sitk.WriteImage(sample, f"test_sample_{case}_{observer}.nii.gz") pred_dict[f"auto_{observer}"] = sample - img_vis = ImageVisualiser(img, cut=get_com(mask), figure_size_in=16, window=[img_arr.min(), img_arr.max()]) + img_vis = ImageVisualiser( + img, cut=get_com(mask), figure_size_in=16, window=[img_arr.min(), img_arr.max()] + ) # color_dict = {str(i): [0.5, 0.5, 0.5] for i, m in enumerate(observers)} contour_dict = {**obs_dict, **pred_dict} - img_vis.add_contour(contour_dict)#, color=color_dict) + img_vis.add_contour(contour_dict) # , color=color_dict) fig = img_vis.show() figure_path = f"valid_{case}.png" fig.savefig(figure_path, dpi=300) - + self.logger.experiment.add_image(figure_path) # for pred in validation_step_outputs: # # do something with a pred From 7596e410e7f6754a5189976bd0caa1389b102271 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 3 Jul 2021 11:36:54 +1000 Subject: [PATCH 069/264] Work on prob unet --- platipy/imaging/cnn/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 1f4386df..2be5def4 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -228,7 +228,7 @@ def validation_epoch_end(self, validation_step_outputs): figure_path = f"valid_{case}.png" fig.savefig(figure_path, dpi=300) - self.logger.experiment.add_image(figure_path) + self.logger.experiment.log_image(figure_path) # for 
pred in validation_step_outputs: # # do something with a pred @@ -378,7 +378,7 @@ def main(args): project_name=os.environ["COMET_PROJECT"], experiment_name=args.experiment, save_dir=args.working_dir, - offline=True, +# offline=True, ) dict_args = vars(args) From 3692acec1e8e1023ae1bce626dc60d6497befa78 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 3 Jul 2021 02:17:10 +0000 Subject: [PATCH 070/264] Update to pseudo generator --- platipy/imaging/cnn/pseudo_generator.py | 19 +++++++--- platipy/imaging/cnn/train.py | 46 +++++++++++++++++++------ 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/platipy/imaging/cnn/pseudo_generator.py b/platipy/imaging/cnn/pseudo_generator.py index bf3a78f2..af493912 100644 --- a/platipy/imaging/cnn/pseudo_generator.py +++ b/platipy/imaging/cnn/pseudo_generator.py @@ -23,14 +23,24 @@ def generate_pseudo_data(data_dir="data"): for case, sphere_rad in enumerate(range(10, 30)): - ct_arr = np.ones((80, 128, 128)) * -1000 - mask_arr = np.zeros((80, 128, 128)) - xpos = random.randint(50, 80) ypos = random.randint(50, 80) - ct_arr = insert_sphere(ct_arr, sp_radius=sphere_rad, sp_centre=(30, ypos, xpos)) + mask_arr = np.zeros((80, 128, 128)) + mask_arr = insert_sphere(mask_arr, sp_radius=sphere_rad, sp_centre=(30, ypos, xpos)) + + mask = sitk.GetImageFromArray(mask_arr) + mask = sitk.Cast(mask, sitk.sitkUInt8) + mask = sitk.BinaryNot(mask) + + ct = sitk.SignedMaurerDistanceMap(mask) + + ct_arr = sitk.GetArrayFromImage(ct) + ct_arr[ct_arr < -10] = -1000 + ct_arr[ct_arr > 20] = 100 + ct = sitk.GetImageFromArray(ct_arr) + sitk.WriteImage(ct, str(image_directory.joinpath(f"{case}.nii.gz"))) vis = ImageVisualiser(ct, cut=(30, ypos, xpos)) @@ -39,6 +49,7 @@ def generate_pseudo_data(data_dir="data"): for obs_id, obs in enumerate(range(-4, 5, 2)): obs_rad = sphere_rad + obs + mask_arr = np.zeros((80, 128, 128)) mask_arr = insert_sphere(mask_arr, sp_radius=obs_rad, sp_centre=(30, ypos, xpos)) mask = 
sitk.GetImageFromArray(mask_arr) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 2be5def4..98225080 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -371,15 +371,34 @@ def main(args): args.working_dir = Path(args.working_dir) args.working_dir = args.working_dir.joinpath(args.experiment) - - comet_logger = CometLogger( - api_key=os.environ["COMET_API_KEY"], - workspace=os.environ["COMET_WORKSPACE"], - project_name=os.environ["COMET_PROJECT"], - experiment_name=args.experiment, - save_dir=args.working_dir, -# offline=True, - ) + args.default_root_dir = str(args.working_dir) + + comet_api_key = None + comet_workspace = None + comet_project = None + + if args.comet_api_key: + comet_api_key = args.comet_api_key + comet_workspace = args.comet_workspace + comet_project = args.comet_project + + if comet_api_key is None: + if "COMET_API_KEY" in os.environ: + comet_api_key = os.environ["COMET_API_KEY"] + if "COMET_WORKSPACE" in os.environ: + comet_workspace = os.environ["COMET_WORKSPACE"] + if "COMET_PROJECT" in os.environ: + comet_project = os.environ["COMET_PROJECT"] + + if comet_api_key is not None: + comet_logger = CometLogger( + api_key=comet_api_key, + workspace=comet_workspace, + project_name=comet_project, + experiment_name=args.experiment, + save_dir=args.working_dir, + offline=True, + ) dict_args = vars(args) @@ -387,7 +406,10 @@ def main(args): prob_unet = ProbUNet(**dict_args) trainer = pl.Trainer.from_argparse_args(args) - trainer.logger = comet_logger + + if comet_api_key is not None: + trainer.logger = comet_logger + trainer.fit(prob_unet, data_module) # pylint: disable=no-member @@ -399,5 +421,9 @@ def main(args): arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") arg_parser.add_argument("--working_dir", type=str, default="./working") + 
arg_parser.add_argument("--offline", type=bool, default=False) + arg_parser.add_argument("--comet_api_key", type=str, default=None) + arg_parser.add_argument("--comet_workspace", type=str, default=None) + arg_parser.add_argument("--comet_project", type=str, default=None) main(arg_parser.parse_args()) From 6c4eada2d6995b4b8be0b385d647aa00b9d0bfd7 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 3 Jul 2021 02:46:23 +0000 Subject: [PATCH 071/264] validation metrics --- platipy/imaging/cnn/train.py | 137 +++++++++++++++++++++++------------ 1 file changed, 92 insertions(+), 45 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 98225080..6323f63a 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -27,6 +27,46 @@ from platipy.imaging.label.utils import get_com +def post_process(pred): + + # Take only the largest componenet + labelled_image = sitk.ConnectedComponent(pred) + label_shape_filter = sitk.LabelShapeStatisticsImageFilter() + label_shape_filter.Execute(labelled_image) + label_indices = label_shape_filter.GetLabels() + voxel_counts = [label_shape_filter.GetNumberOfPixels(i) for i in label_indices] + if len(voxel_counts) > 0: + largest_component_label = label_indices[np.argmax(voxel_counts)] + largest_component_image = labelled_image == largest_component_label + pred = sitk.Cast(largest_component_image, sitk.sitkUInt8) + + # Fill any holes in the structure + pred = sitk.BinaryMorphologicalClosing(pred, (5, 5, 5)) + pred = sitk.BinaryFillhole(pred) + + return pred + + +def get_metrics(target, pred): + + result = {} + lomif = sitk.LabelOverlapMeasuresImageFilter() + lomif.Execute(target, pred) + result["JI"] = lomif.GetJaccardCoefficient() + result["DSC"] = lomif.GetDiceCoefficient() + + if sitk.GetArrayFromImage(pred).sum() == 0: + result["HD"] = 1000 + result["ASD"] = 100 + else: + hdif = sitk.HausdorffDistanceImageFilter() + hdif.Execute(target, pred) + result["HD"] = 
hdif.GetHausdorffDistance() + result["ASD"] = hdif.GetAverageHausdorffDistance() + + return result + + class ProbUNet(pl.LightningModule): def __init__( self, @@ -148,8 +188,6 @@ def validation_step(self, batch, batch_idx): def validation_epoch_end(self, validation_step_outputs): - print(self.validation_directory) - cases = {} for info in validation_step_outputs: @@ -162,6 +200,9 @@ def validation_epoch_end(self, validation_step_outputs): cases[case]["slices"] = z.item() if not observer in cases[case]["observers"]: cases[case]["observers"].append(observer.item()) + + metrics = ["JI", "DSC", "HD", "ASD"] + result = {"probnet": {k: [] for k in metrics}, "unet": {k: [] for k in metrics}} for case in cases: img_arrs = [] @@ -181,15 +222,18 @@ def validation_epoch_end(self, validation_step_outputs): img_arr = np.stack(img_arrs) img = sitk.GetImageFromArray(img_arr) - sitk.WriteImage(img, f"test_{case}.nii.gz") + # sitk.WriteImage(img, f"test_{case}.nii.gz") mean_arr = np.stack(mean_arrs) mean = sitk.GetImageFromArray(mean_arr) mean = sitk.Cast(mean, sitk.sitkUInt8) - sitk.WriteImage(mean, f"test_mean_{case}_{observer}.nii.gz") + mean = post_process(mean) + # sitk.WriteImage(mean, f"val_mean_{case}_mean.nii.gz") obs_dict = {} pred_dict = {} + observers = [] + samples = [] for observer in cases[case]["observers"]: mask_arrs = [] sample_arrs = [] @@ -207,13 +251,16 @@ def validation_epoch_end(self, validation_step_outputs): mask_arr = np.stack(mask_arrs) mask = sitk.GetImageFromArray(mask_arr) mask = sitk.Cast(mask, sitk.sitkUInt8) - sitk.WriteImage(mask, f"test_mask_{case}_{observer}.nii.gz") + # sitk.WriteImage(mask, f"val_mask_{case}_{observer}.nii.gz") + mask.append(mask) obs_dict[f"manual_{observer}"] = mask sample_arr = np.stack(sample_arrs) sample = sitk.GetImageFromArray(sample_arr) sample = sitk.Cast(sample, sitk.sitkUInt8) - sitk.WriteImage(sample, f"test_sample_{case}_{observer}.nii.gz") + sample = post_process(sample) + # sitk.WriteImage(sample, 
f"val_sample_{case}_{observer}.nii.gz") + samples.append(sample) pred_dict[f"auto_{observer}"] = sample img_vis = ImageVisualiser( @@ -222,6 +269,7 @@ def validation_epoch_end(self, validation_step_outputs): # color_dict = {str(i): [0.5, 0.5, 0.5] for i, m in enumerate(observers)} contour_dict = {**obs_dict, **pred_dict} + contour_dict["mean"] = mean img_vis.add_contour(contour_dict) # , color=color_dict) fig = img_vis.show() @@ -230,44 +278,43 @@ def validation_epoch_end(self, validation_step_outputs): self.logger.experiment.log_image(figure_path) - # for pred in validation_step_outputs: - # # do something with a pred - # print(pred) - - # break - - # criterion = torch.nn.BCEWithLogitsLoss( - # size_average=False, reduce=False, reduction=None - # ) - - # onehot = torch.nn.functional.one_hot(y, 2).transpose(1, 3).float() - - # observers = y.shape[0] - # sim_matrix = np.zeros((observers, observers)) - # for i in range(observers): - # for j in range(observers): - # rec_loss = criterion(input=sample[i], target=onehot[j]) - # rec_loss = torch.sum(rec_loss) - # sim_matrix[i, j] = rec_loss.item() - - # row_idx, col_idx = linear_sum_assignment(sim_matrix) - - # matched_val = sim_matrix[row_idx, col_idx].mean() - # self.log( - # "matched_val", - # matched_val, - # on_step=False, - # on_epoch=True, - # prog_bar=False, - # logger=True, - # ) - - # mean_val = criterion(input=mean, target=onehot) - # mean_val = torch.sum(rec_loss).item() - # self.log( - # "mean_val", mean_val, on_step=False, on_epoch=True, prog_bar=False, logger=True - # ) - + sim = {k: np.zeros((len(observers), len(samples))) for k in metrics} + msim = {k: np.zeros((len(observers), len(samples))) for k in metrics} + for sid, samp in enumerate(samples): + for oid, obs in enumerate(observers): + sample_metrics = get_metrics(obs, samp) + mean_metrics = get_metrics(obs, mean) + + for k in sample_metrics: + sim[k][sid, oid] = sample_metrics[k] + msim[k][sid, oid] = mean_metrics[k] + + for k in sim: + + val = 
sim[k] + if not k.endswith("D"): + val = -val + row_idx, col_idx = linear_sum_assignment(val) + prob_unet_mean = sim[k][row_idx, col_idx].mean() + result["prob"][k].append(prob_unet_mean) + + val = msim[k] + if not k.endswith("D"): + val = -val + row_idx, col_idx = linear_sum_assignment(val) + unet_mean = msim[k][row_idx, col_idx].mean() + result["unet"][k].append(unet_mean) + + for t in result: + for m in result[t]: + self.log( + f"val_{t}_{m}", + np.array(result[t][m]).mean(), + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) # shutil.rmtree(self.validation_directory) @@ -397,7 +444,7 @@ def main(args): project_name=comet_project, experiment_name=args.experiment, save_dir=args.working_dir, - offline=True, + offline=args.offline, ) dict_args = vars(args) From e31763c4b5bad07c609bb7078261ca92389aa47e Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 3 Jul 2021 06:54:22 +0000 Subject: [PATCH 072/264] Log metrics during validation --- platipy/imaging/cnn/prob_unet.py | 2 +- platipy/imaging/cnn/train.py | 86 +++++++++++++++++--------------- 2 files changed, 48 insertions(+), 40 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 827e373a..48d5c26a 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -246,7 +246,7 @@ def kl_divergence(self, analytic=True, z_posterior=None): def reconstruction_loss(self, segm, reconstruct_posterior_mean=False, z_posterior=None): - criterion = torch.nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) + criterion = torch.nn.BCEWithLogitsLoss(reduction="none") if z_posterior is None: z_posterior = self.posterior_latent_space.rsample() diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 6323f63a..2b500832 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -202,7 +202,10 @@ def validation_epoch_end(self, validation_step_outputs): 
cases[case]["observers"].append(observer.item()) metrics = ["JI", "DSC", "HD", "ASD"] - result = {"probnet": {k: [] for k in metrics}, "unet": {k: [] for k in metrics}} + computed_metrics = { + **{f"probnet_{m}": [] for m in metrics}, + **{f"unet_{m}": [] for m in metrics}, + } for case in cases: img_arrs = [] @@ -252,7 +255,7 @@ def validation_epoch_end(self, validation_step_outputs): mask = sitk.GetImageFromArray(mask_arr) mask = sitk.Cast(mask, sitk.sitkUInt8) # sitk.WriteImage(mask, f"val_mask_{case}_{observer}.nii.gz") - mask.append(mask) + observers.append(mask) obs_dict[f"manual_{observer}"] = mask sample_arr = np.stack(sample_arrs) @@ -278,43 +281,48 @@ def validation_epoch_end(self, validation_step_outputs): self.logger.experiment.log_image(figure_path) - sim = {k: np.zeros((len(observers), len(samples))) for k in metrics} - msim = {k: np.zeros((len(observers), len(samples))) for k in metrics} - for sid, samp in enumerate(samples): - for oid, obs in enumerate(observers): - sample_metrics = get_metrics(obs, samp) - mean_metrics = get_metrics(obs, mean) - - for k in sample_metrics: - sim[k][sid, oid] = sample_metrics[k] - msim[k][sid, oid] = mean_metrics[k] - - for k in sim: - - val = sim[k] - if not k.endswith("D"): - val = -val - row_idx, col_idx = linear_sum_assignment(val) - prob_unet_mean = sim[k][row_idx, col_idx].mean() - result["prob"][k].append(prob_unet_mean) - - val = msim[k] - if not k.endswith("D"): - val = -val - row_idx, col_idx = linear_sum_assignment(val) - unet_mean = msim[k][row_idx, col_idx].mean() - result["unet"][k].append(unet_mean) - - for t in result: - for m in result[t]: - self.log( - f"val_{t}_{m}", - np.array(result[t][m]).mean(), - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - ) + sim = {k: np.zeros((len(observers), len(samples))) for k in metrics} + msim = {k: np.zeros((len(observers), len(samples))) for k in metrics} + for sid, samp in enumerate(samples): + for oid, obs in enumerate(observers): + 
sample_metrics = get_metrics(obs, samp) + mean_metrics = get_metrics(obs, mean) + + for k in sample_metrics: + sim[k][sid, oid] = sample_metrics[k] + msim[k][sid, oid] = mean_metrics[k] + + result = {"probnet": {k: [] for k in metrics}, "unet": {k: [] for k in metrics}} + for k in sim: + + val = sim[k] + if not k.endswith("D"): + val = -val + row_idx, col_idx = linear_sum_assignment(val) + prob_unet_mean = sim[k][row_idx, col_idx].mean() + result["probnet"][k].append(prob_unet_mean) + + val = msim[k] + if not k.endswith("D"): + val = -val + row_idx, col_idx = linear_sum_assignment(val) + unet_mean = msim[k][row_idx, col_idx].mean() + result["unet"][k].append(unet_mean) + + for t in result: + for m in result[t]: + computed_metrics[f"{t}_{m}"].append(np.array(result[t][m]).mean()) + + for cm in computed_metrics: + self.log( + cm, + np.array(computed_metrics[cm]).mean(), + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + # shutil.rmtree(self.validation_directory) From 7c45cd77fd37947ec5221073fcea05ffab456e99 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 4 Jul 2021 11:48:59 +1000 Subject: [PATCH 073/264] updates to vis --- platipy/imaging/cnn/train.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 2b500832..8fc6a166 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -17,6 +17,7 @@ from argparse import ArgumentParser from torch._C import NoneType +import matplotlib.pyplot as plt from platipy.imaging.cnn.prob_unet import ProbabilisticUnet from platipy.imaging.cnn.unet import l2_regularisation @@ -96,6 +97,8 @@ def __init__( self.validation_directory = None + self.stddevs = np.linspace(-2,2,5) + @staticmethod def add_model_specific_args(parent_parser): parser = parent_parser.add_argument_group("Probabilistic UNet") @@ -121,7 +124,12 @@ def configure_optimizers(self): optimizer = torch.optim.Adam( 
self.parameters(), lr=self.hparams.learning_rate, weight_decay=0 ) - return optimizer + + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, lr_lambda=[lambda epoch: 0.99 ** (epoch)] + ) + + return [optimizer], [scheduler] def training_step(self, batch, batch_idx): x, y, _ = batch @@ -134,7 +142,7 @@ def training_step(self, batch, batch_idx): ) training_loss = loss["loss"] + 1e-5 * reg_loss self.log( - "training_loss", training_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True + "training_loss", training_loss, on_step=True, on_epoch=False, prog_bar=True, logger=True ) for k in loss: @@ -142,7 +150,7 @@ def training_step(self, batch, batch_idx): k, loss[k], on_step=True, - on_epoch=True, + on_epoch=False, prog_bar=True, logger=True, ) @@ -170,7 +178,7 @@ def validation_step(self, batch, batch_idx): np.save(mask_file, y[s].squeeze(0).cpu().numpy()) self.prob_unet.forward(x[s].unsqueeze(0)) - sample = self.prob_unet.sample(testing=True) + sample = self.prob_unet.sample(testing=True, use_mean=False, sample_x_stddev_from_mean=self.stddevs[info["observer"][s]]) sample_file = self.validation_directory.joinpath( f"sample_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" ) @@ -189,6 +197,7 @@ def validation_step(self, batch, batch_idx): def validation_epoch_end(self, validation_step_outputs): cases = {} + cmap = plt.cm.get_cmap('Set2') for info in validation_step_outputs: for case, z, observer in zip(info["case"], info["z"], info["observer"]): @@ -235,6 +244,7 @@ def validation_epoch_end(self, validation_step_outputs): obs_dict = {} pred_dict = {} + color_dict = {} observers = [] samples = [] for observer in cases[case]["observers"]: @@ -257,6 +267,7 @@ def validation_epoch_end(self, validation_step_outputs): # sitk.WriteImage(mask, f"val_mask_{case}_{observer}.nii.gz") observers.append(mask) obs_dict[f"manual_{observer}"] = mask + color_dict[f"manual_{observer}"] = [0.5, 0.5, 0.5] sample_arr = np.stack(sample_arrs) sample = 
sitk.GetImageFromArray(sample_arr) @@ -264,17 +275,19 @@ def validation_epoch_end(self, validation_step_outputs): sample = post_process(sample) # sitk.WriteImage(sample, f"val_sample_{case}_{observer}.nii.gz") samples.append(sample) - pred_dict[f"auto_{observer}"] = sample + pred_dict[f"auto_{self.stddevs[observer]}"] = sample + color_dict[f"auto_{self.stddevs[observer]}"] = cmap(observer/5) img_vis = ImageVisualiser( img, cut=get_com(mask), figure_size_in=16, window=[img_arr.min(), img_arr.max()] ) - # color_dict = {str(i): [0.5, 0.5, 0.5] for i, m in enumerate(observers)} + #color_dict = {str(i): [0.5, 0.5, 0.5] for i, m in enumerate(observers)} contour_dict = {**obs_dict, **pred_dict} - contour_dict["mean"] = mean + contour_dict["auto_mean"] = mean + color_dict["auto_mean"] = [0.0, 0.0, 0.0] - img_vis.add_contour(contour_dict) # , color=color_dict) + img_vis.add_contour(contour_dict, color=color_dict) fig = img_vis.show() figure_path = f"valid_{case}.png" fig.savefig(figure_path, dpi=300) @@ -319,7 +332,7 @@ def validation_epoch_end(self, validation_step_outputs): np.array(computed_metrics[cm]).mean(), on_step=False, on_epoch=True, - prog_bar=True, + prog_bar=False, logger=True, ) From fda33a3b81323959d596cf1323cb1016a91e6eda Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 4 Jul 2021 06:26:09 +0000 Subject: [PATCH 074/264] Loss on top k percentage --- .gitignore | 1 + platipy/imaging/cnn/prob_unet.py | 72 +++++++++++++++++++++++++++++--- 2 files changed, 68 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index e69b1fa8..2f076351 100644 --- a/.gitignore +++ b/.gitignore @@ -140,6 +140,7 @@ platipy/*/tests/data testing/ converted/ **/data +**/working **/tcia **/nifti_output diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 48d5c26a..1aecf7c7 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -244,7 +244,23 @@ def kl_divergence(self, analytic=True, 
z_posterior=None): kl_div = log_posterior_prob - log_prior_prob return kl_div - def reconstruction_loss(self, segm, reconstruct_posterior_mean=False, z_posterior=None): + def topk_mask(self, score, k): + """Returns a mask for the top-k elements in score.""" + + values, _ = torch.topk(score, 1, axis=1) + _, indices = torch.topk(values, k, axis=0) + return torch.scatter_add( + torch.zeros(score.shape[0]), 0, indices.reshape(-1), torch.ones(score.shape[0]) + ) + + def reconstruction_loss( + self, + segm, + reconstruct_posterior_mean=False, + z_posterior=None, + mask=None, + top_k_percentage=None, + ): criterion = torch.nn.BCEWithLogitsLoss(reduction="none") @@ -258,10 +274,49 @@ def reconstruction_loss(self, segm, reconstruct_posterior_mean=False, z_posterio segm = torch.unsqueeze(segm, dim=1) not_seg = segm.logical_not() segm = torch.cat((not_seg, segm), dim=1).float() - reconstruction_loss = criterion(input=reconstruction, target=segm) - reconstruction_loss = torch.mean(reconstruction_loss) - return reconstruction_loss + ##### + num_classes = reconstruction.shape[1] + y_flat = torch.reshape(reconstruction, (-1, num_classes)) + t_flat = torch.reshape(segm, (-1, num_classes)) + if mask is None: + mask = torch.ones(torch.reshape(t_flat, (-1, 2)).shape[0]) + else: + assert ( + mask.shape == segm.shape + ), f"The loss mask shape differs from the target shape: {mask.shape} vs. {segm.shape}." 
+ mask = torch.reshape(segm, (-1,)) + + n_pixels_in_batch = y_flat.shape[0] + xe = criterion(input=y_flat, target=t_flat) + top_k_percentage = 0.02 + if top_k_percentage is not None: + + assert 0.0 < top_k_percentage <= 1.0 + k_pixels = int(n_pixels_in_batch * top_k_percentage) + + with torch.no_grad(): + norm_xe = xe / torch.sum(xe) + deterministic = True + if deterministic: + score = torch.log(norm_xe) + else: + # TODO Gumbel trick + raise NotImplementedError("Still need to implement Gumbel trick") + + score = score + torch.log(mask.unsqueeze(1).repeat((1, num_classes))) + top_k_mask = self.topk_mask(score, k_pixels) + mask = mask * top_k_mask + + batch_size = segm.shape[0] + xe = torch.reshape(xe, shape=(batch_size, -1)) + mask = mask.repeat((1, num_classes)) + mask = torch.reshape(mask, shape=(batch_size, -1)) + + ce_sum_per_instance = torch.sum(mask * xe, axis=1) + ce_sum = torch.mean(ce_sum_per_instance, axis=0) + + return ce_sum def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): """ @@ -273,8 +328,15 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): kl_div = torch.mean(self.kl_divergence(analytic=analytic_kl, z_posterior=z_posterior)) # Here we use the posterior sample sampled above + top_k_percentage = None + if "top_k_percentage" in self.loss_params: + top_k_percentage = self.loss_params["top_k_percentage"] + reconstruction_loss = self.reconstruction_loss( - segm, reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior + segm, + reconstruct_posterior_mean=reconstruct_posterior_mean, + z_posterior=z_posterior, + top_k_percentage=top_k_percentage, ) if self.loss_type == "elbo": From 5da5ba470f74b49a3a902aa3922292d62704bfd2 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 4 Jul 2021 06:31:42 +0000 Subject: [PATCH 075/264] minor corrections --- platipy/imaging/cnn/prob_unet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/prob_unet.py 
b/platipy/imaging/cnn/prob_unet.py index 1aecf7c7..b00717b1 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -286,10 +286,11 @@ def reconstruction_loss( mask.shape == segm.shape ), f"The loss mask shape differs from the target shape: {mask.shape} vs. {segm.shape}." mask = torch.reshape(segm, (-1,)) + mask = mask.to(y_flat.device) n_pixels_in_batch = y_flat.shape[0] xe = criterion(input=y_flat, target=t_flat) - top_k_percentage = 0.02 + if top_k_percentage is not None: assert 0.0 < top_k_percentage <= 1.0 @@ -306,6 +307,7 @@ def reconstruction_loss( score = score + torch.log(mask.unsqueeze(1).repeat((1, num_classes))) top_k_mask = self.topk_mask(score, k_pixels) + top_k_mask = top_k_mask.to(y_flat.device) mask = mask * top_k_mask batch_size = segm.shape[0] From 3077b8df24c10cde77da652539d32d06232aa6a2 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 5 Jul 2021 10:25:03 +1000 Subject: [PATCH 076/264] Recale CT data properly --- platipy/imaging/cnn/dataset.py | 5 ++++- platipy/imaging/cnn/train.py | 10 +++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 23e1f74a..d2b4aa43 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -14,7 +14,10 @@ def preprocess_image(img, crop_to_mm=128): - img = sitk.Normalize(img) + img = sitk.Cast(img, sitk.sitkFloat32) + img = sitk.IntensityWindowing( + img, windowMinimum=-100.0, windowMaximum=100.0, outputMinimum=-1.0, outputMaximum=1.0 + ) new_spacing = sitk.VectorDouble(3) new_spacing[0] = 1.0 diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 8fc6a166..ab777ad8 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -2,6 +2,7 @@ import math import tempfile import shutil +import random from pathlib import Path import SimpleITK as sitk @@ -83,7 +84,7 @@ def __init__( loss_params = {"beta": self.hparams.beta} if 
self.hparams.loss_type == "geco": - loss_params = {"kappa": self.hparams.kappa, "clamp_rec": self.hparams.clamp_rec} + loss_params = {"kappa": self.hparams.kappa, "clamp_rec": self.hparams.clamp_rec, "top_k_percentage": self.hparams.top_k_percentage} self.prob_unet = ProbabilisticUnet( self.hparams.input_channels, @@ -114,6 +115,8 @@ def add_model_specific_args(parent_parser): parser.add_argument("--beta", type=float, default=1.0) parser.add_argument("--kappa", type=float, default=0.02) parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) + parser.add_argument("--top_k_percentage", type=float, default=None) + return parent_parser def forward(self, x): @@ -279,10 +282,9 @@ def validation_epoch_end(self, validation_step_outputs): color_dict[f"auto_{self.stddevs[observer]}"] = cmap(observer/5) img_vis = ImageVisualiser( - img, cut=get_com(mask), figure_size_in=16, window=[img_arr.min(), img_arr.max()] + img, cut=get_com(mask), figure_size_in=16, window=[-1.0, 1.0] ) - #color_dict = {str(i): [0.5, 0.5, 0.5] for i, m in enumerate(observers)} contour_dict = {**obs_dict, **pred_dict} contour_dict["auto_mean"] = mean color_dict["auto_mean"] = [0.0, 0.0, 0.0] @@ -291,6 +293,7 @@ def validation_epoch_end(self, validation_step_outputs): fig = img_vis.show() figure_path = f"valid_{case}.png" fig.savefig(figure_path, dpi=300) + plt.close("all") self.logger.experiment.log_image(figure_path) @@ -380,6 +383,7 @@ def setup(self, stage=None): cases = [p.name.replace(".nii.gz", "") for p in self.data_dir.joinpath("images").glob("*")] cases.sort() + random.shuffle(cases) # will be consistent for same value of 'seed everything' cases_per_fold = math.ceil(len(cases) / self.k_folds) for f in range(self.k_folds): From 2856692f943529ce2b7345ab099399b76976e229 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 10 Jul 2021 01:49:45 +0000 Subject: [PATCH 077/264] Able to parse from json file --- platipy/imaging/cnn/train.py | 84 
+++++++++++++++++++++++++++++------- 1 file changed, 69 insertions(+), 15 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index ab777ad8..afdd5fa5 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -1,8 +1,23 @@ +# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys import os import math import tempfile -import shutil import random +import json from pathlib import Path import SimpleITK as sitk @@ -17,7 +32,6 @@ from argparse import ArgumentParser -from torch._C import NoneType import matplotlib.pyplot as plt from platipy.imaging.cnn.prob_unet import ProbabilisticUnet @@ -84,7 +98,11 @@ def __init__( loss_params = {"beta": self.hparams.beta} if self.hparams.loss_type == "geco": - loss_params = {"kappa": self.hparams.kappa, "clamp_rec": self.hparams.clamp_rec, "top_k_percentage": self.hparams.top_k_percentage} + loss_params = { + "kappa": self.hparams.kappa, + "clamp_rec": self.hparams.clamp_rec, + "top_k_percentage": self.hparams.top_k_percentage, + } self.prob_unet = ProbabilisticUnet( self.hparams.input_channels, @@ -98,7 +116,7 @@ def __init__( self.validation_directory = None - self.stddevs = np.linspace(-2,2,5) + self.stddevs = np.linspace(-2, 2, 5) @staticmethod def add_model_specific_args(parent_parser): @@ -145,7 +163,12 @@ def training_step(self, batch, batch_idx): ) training_loss = loss["loss"] + 1e-5 * 
reg_loss self.log( - "training_loss", training_loss, on_step=True, on_epoch=False, prog_bar=True, logger=True + "training_loss", + training_loss, + on_step=True, + on_epoch=False, + prog_bar=True, + logger=True, ) for k in loss: @@ -181,7 +204,11 @@ def validation_step(self, batch, batch_idx): np.save(mask_file, y[s].squeeze(0).cpu().numpy()) self.prob_unet.forward(x[s].unsqueeze(0)) - sample = self.prob_unet.sample(testing=True, use_mean=False, sample_x_stddev_from_mean=self.stddevs[info["observer"][s]]) + sample = self.prob_unet.sample( + testing=True, + use_mean=False, + sample_x_stddev_from_mean=self.stddevs[info["observer"][s]], + ) sample_file = self.validation_directory.joinpath( f"sample_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" ) @@ -200,7 +227,7 @@ def validation_step(self, batch, batch_idx): def validation_epoch_end(self, validation_step_outputs): cases = {} - cmap = plt.cm.get_cmap('Set2') + cmap = plt.cm.get_cmap("Set2") for info in validation_step_outputs: for case, z, observer in zip(info["case"], info["z"], info["observer"]): @@ -279,7 +306,7 @@ def validation_epoch_end(self, validation_step_outputs): # sitk.WriteImage(sample, f"val_sample_{case}_{observer}.nii.gz") samples.append(sample) pred_dict[f"auto_{self.stddevs[observer]}"] = sample - color_dict[f"auto_{self.stddevs[observer]}"] = cmap(observer/5) + color_dict[f"auto_{self.stddevs[observer]}"] = cmap(observer / 5) img_vis = ImageVisualiser( img, cut=get_com(mask), figure_size_in=16, window=[-1.0, 1.0] @@ -347,6 +374,9 @@ def __init__( self, data_dir: str = "./data", working_dir: str = "./working", + case_glob="images/*.nii.gz", + image_glob="images/{case}.nii.gz", + label_glob="labels/{case}_*.nii.gz", fold=0, k_folds=5, batch_size=5, @@ -357,6 +387,10 @@ def __init__( self.data_dir = Path(data_dir) self.working_dir = Path(working_dir) + self.case_glob = case_glob + self.image_glob = image_glob + self.label_glob = label_glob + self.fold = fold self.k_folds = k_folds @@ 
-366,6 +400,9 @@ def __init__( self.batch_size = batch_size self.num_workers = num_workers + self.training_set = None + self.validation_set = None + print(f"Training fold {self.fold}") @staticmethod @@ -376,14 +413,17 @@ def add_model_specific_args(parent_parser): parser.add_argument("--k_folds", type=int, default=5) parser.add_argument("--batch_size", type=int, default=5) parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") + parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") + parser.add_argument("--label_glob", type=str, default="labels/{case}_*.nii.gz") return parent_parser def setup(self, stage=None): - cases = [p.name.replace(".nii.gz", "") for p in self.data_dir.joinpath("images").glob("*")] + cases = [p.name.replace(".nii.gz", "") for p in self.data_dir.glob(self.case_glob)] cases.sort() - random.shuffle(cases) # will be consistent for same value of 'seed everything' + random.shuffle(cases) # will be consistent for same value of 'seed everything' cases_per_fold = math.ceil(len(cases) / self.k_folds) for f in range(self.k_folds): @@ -395,20 +435,22 @@ def setup(self, stage=None): train_data = [ { "id": case, - "image": self.data_dir.joinpath("images", f"{case}.nii.gz"), - "label": [p for p in self.data_dir.joinpath("labels").glob(f"{case}_*.nii.gz")], + "image": self.data_dir.joinpath(self.image_glob.format(case=case)), + "label": [p for p in self.data_dir.glob(self.label_glob.format(case=case))], } for case in self.train_cases ] + print(train_data) validation_data = [ { "id": case, - "image": self.data_dir.joinpath("images", f"{case}.nii.gz"), - "label": [p for p in self.data_dir.joinpath("labels").glob(f"{case}_*.nii.gz")], + "image": self.data_dir.joinpath(self.image_glob.format(case=case)), + "label": [p for p in self.data_dir.glob(self.label_glob.format(case=case))], } for case in self.validation_cases ] + print(validation_data) self.training_set = 
NiftiDataset(train_data, self.working_dir.joinpath("train")) self.validation_set = NiftiDataset( @@ -486,6 +528,18 @@ def main(args): if __name__ == "__main__": + + args = None + if len(sys.argv) == 2: + # Check if JSON file parsed, if so read arguments from there... + if sys.argv[-1].endswith(".json"): + with open(sys.argv[-1], "r") as f: + params = json.load(f) + args = [] + for k in params: + args.append(f"--{k}") + args.append(params[k]) + arg_parser = ArgumentParser() arg_parser = ProbUNet.add_model_specific_args(arg_parser) arg_parser = ProbUNetDataModule.add_model_specific_args(arg_parser) @@ -498,4 +552,4 @@ def main(args): arg_parser.add_argument("--comet_workspace", type=str, default=None) arg_parser.add_argument("--comet_project", type=str, default=None) - main(arg_parser.parse_args()) + main(arg_parser.parse_args(args)) From d1fbb87fa11ccefc930ec20fde6dcbb43c7d63c3 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 10 Jul 2021 05:20:52 +0000 Subject: [PATCH 078/264] Make crop to mm configurable --- platipy/imaging/cnn/dataset.py | 4 ++-- platipy/imaging/cnn/train.py | 27 ++++++++++++++++++--------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index d2b4aa43..c44db831 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -112,7 +112,7 @@ def prepare_transforms(): class NiftiDataset(torch.utils.data.Dataset): """PyTorch Dataset for processing Nifti data""" - def __init__(self, data, working_dir, augment_on_the_fly=True): + def __init__(self, data, working_dir, augment_on_the_fly=True, crop_to_mm=128): """Prepare a dataset from Nifti images/labels Args: @@ -168,7 +168,7 @@ def __init__(self, data, working_dir, augment_on_the_fly=True): img = sitk.ReadImage(img_path) - img = preprocess_image(img) + img = preprocess_image(img, crop_to_mm=crop_to_mm) observers = [] for obs, structure_path in enumerate(structure_paths): diff --git 
a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index afdd5fa5..4fcd4e6f 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -381,6 +381,8 @@ def __init__( k_folds=5, batch_size=5, num_workers=4, + crop_to_mm=128, + num_observers=5, **kwargs, ): super().__init__() @@ -399,6 +401,8 @@ def __init__( self.batch_size = batch_size self.num_workers = num_workers + self.crop_to_mm = crop_to_mm + self.num_observers = num_observers self.training_set = None self.validation_set = None @@ -416,6 +420,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") parser.add_argument("--label_glob", type=str, default="labels/{case}_*.nii.gz") + parser.add_argument("--crop_to_mm", type=int, default=128) return parent_parser @@ -432,6 +437,9 @@ def setup(self, stage=None): else: self.train_cases += cases[f * cases_per_fold : (f + 1) * cases_per_fold] + print(f"Training cases: {self.train_cases}") + print(f"Validation cases: {self.validation_cases}") + train_data = [ { "id": case, @@ -440,7 +448,6 @@ def setup(self, stage=None): } for case in self.train_cases ] - print(train_data) validation_data = [ { @@ -450,20 +457,20 @@ def setup(self, stage=None): } for case in self.validation_cases ] - print(validation_data) - self.training_set = NiftiDataset(train_data, self.working_dir.joinpath("train")) + self.training_set = NiftiDataset( + train_data, self.working_dir.joinpath("train"), crop_to_mm=self.crop_to_mm + ) self.validation_set = NiftiDataset( - validation_data, self.working_dir.joinpath("validation"), False + validation_data, + self.working_dir.joinpath("validation"), + augment_on_the_fly=False, + crop_to_mm=self.crop_to_mm, ) def train_dataloader(self): return torch.utils.data.DataLoader( self.training_set, - # batch_sampler=BatchSampler( - # ObserverSampler(train_set, 5), 
batch_size=params["batch_size"], drop_last=False - # ), - # num_workers=params["num_workers"], batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, @@ -473,7 +480,9 @@ def val_dataloader(self): return torch.utils.data.DataLoader( self.validation_set, batch_sampler=torch.utils.data.BatchSampler( - ObserverSampler(self.validation_set, 5), batch_size=5, drop_last=False + ObserverSampler(self.validation_set, self.num_observers), + batch_size=self.num_observers, + drop_last=False, ), num_workers=self.num_workers, ) From b6e5ecfc7cae5f7a68154bcc1ab9961dd31d504e Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 11 Jul 2021 12:18:04 +1000 Subject: [PATCH 079/264] Updates to prob unet training --- platipy/imaging/cnn/prob_unet.py | 23 ++++++++++++++++------- platipy/imaging/cnn/train.py | 19 +++++++++++-------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index b00717b1..05f436b1 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -250,7 +250,7 @@ def topk_mask(self, score, k): values, _ = torch.topk(score, 1, axis=1) _, indices = torch.topk(values, k, axis=0) return torch.scatter_add( - torch.zeros(score.shape[0]), 0, indices.reshape(-1), torch.ones(score.shape[0]) + torch.zeros(score.shape[0]).to(score.device), 0, indices.reshape(-1), torch.ones(score.shape[0]).to(score.device) ) def reconstruction_loss( @@ -318,7 +318,7 @@ def reconstruction_loss( ce_sum_per_instance = torch.sum(mask * xe, axis=1) ce_sum = torch.mean(ce_sum_per_instance, axis=0) - return ce_sum + return ce_sum, mask def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): """ @@ -334,7 +334,7 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): if "top_k_percentage" in self.loss_params: top_k_percentage = self.loss_params["top_k_percentage"] - reconstruction_loss = self.reconstruction_loss( + reconstruction_loss, mask 
= self.reconstruction_loss( segm, reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior, @@ -350,8 +350,9 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): } elif self.loss_type == "geco": - num_pixels = segm.numel() - reconstruction_threshold = self.loss_params["kappa"] # * num_pixels + num_pixels = mask.sum().item() + batch_size = segm.shape[0] + reconstruction_threshold = (self.loss_params["kappa"] * num_pixels) / batch_size rec_constraint = reconstruction_loss - reconstruction_threshold loss = ( @@ -360,14 +361,22 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): ) with torch.no_grad(): + + rc = rec_constraint.detach() if self._moving_avg is None: - self._moving_avg = rec_constraint.detach() + self._moving_avg = rc else: - self._moving_avg = self._moving_avg * 0.5 + rec_constraint.detach() * (1 - 0.5) + self._moving_avg = self._moving_avg * 0.5 + rc * (1 - 0.5) speed = 1 + self._moving_avg = self._moving_avg.clamp( + -25, 25 + ) self._lambda = (speed * torch.exp(self._moving_avg) * self._lambda).clamp( self.loss_params["clamp_rec"][0], self.loss_params["clamp_rec"][1] ) +# self._lambda = (speed * self._moving_avg * self._lambda).clamp( +# self.loss_params["clamp_rec"][0], self.loss_params["clamp_rec"][1] +# ) return { "loss": loss, "rec_loss": reconstruction_loss, diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 4fcd4e6f..9556b9ac 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -116,7 +116,7 @@ def __init__( self.validation_directory = None - self.stddevs = np.linspace(-2, 2, 5) + self.stddevs = np.linspace(-2, 2, self.hparams.num_observers) @staticmethod def add_model_specific_args(parent_parser): @@ -153,7 +153,9 @@ def configure_optimizers(self): return [optimizer], [scheduler] def training_step(self, batch, batch_idx): + x, y, _ = batch + self.prob_unet.forward(x, y, training=True) loss = self.prob_unet.loss(y, 
analytic_kl=True) reg_loss = ( @@ -207,7 +209,7 @@ def validation_step(self, batch, batch_idx): sample = self.prob_unet.sample( testing=True, use_mean=False, - sample_x_stddev_from_mean=self.stddevs[info["observer"][s]], + sample_x_stddev_from_mean=self.stddevs[s], ) sample_file = self.validation_directory.joinpath( f"sample_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" @@ -277,7 +279,7 @@ def validation_epoch_end(self, validation_step_outputs): color_dict = {} observers = [] samples = [] - for observer in cases[case]["observers"]: + for idx, observer in enumerate(cases[case]["observers"]): mask_arrs = [] sample_arrs = [] for z in slices: @@ -305,8 +307,8 @@ def validation_epoch_end(self, validation_step_outputs): sample = post_process(sample) # sitk.WriteImage(sample, f"val_sample_{case}_{observer}.nii.gz") samples.append(sample) - pred_dict[f"auto_{self.stddevs[observer]}"] = sample - color_dict[f"auto_{self.stddevs[observer]}"] = cmap(observer / 5) + pred_dict[f"auto_{self.stddevs[idx]}"] = sample + color_dict[f"auto_{self.stddevs[idx]}"] = cmap(observer / 5) img_vis = ImageVisualiser( img, cut=get_com(mask), figure_size_in=16, window=[-1.0, 1.0] @@ -426,7 +428,7 @@ def add_model_specific_args(parent_parser): def setup(self, stage=None): - cases = [p.name.replace(".nii.gz", "") for p in self.data_dir.glob(self.case_glob)] + cases = [p.name.replace(".nii.gz", "") for p in self.data_dir.glob(self.case_glob) if not p.name.startswith(".")] cases.sort() random.shuffle(cases) # will be consistent for same value of 'seed everything' cases_per_fold = math.ceil(len(cases) / self.k_folds) @@ -453,7 +455,7 @@ def setup(self, stage=None): { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), - "label": [p for p in self.data_dir.glob(self.label_glob.format(case=case))], + "label": [p for p in self.data_dir.glob(self.label_glob.format(case=case)) if not "edited" in p.name], } for case in self.validation_cases ] @@ -547,7 +549,7 @@ def 
main(args): args = [] for k in params: args.append(f"--{k}") - args.append(params[k]) + args.append(str(params[k])) arg_parser = ArgumentParser() arg_parser = ProbUNet.add_model_specific_args(arg_parser) @@ -556,6 +558,7 @@ def main(args): arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") arg_parser.add_argument("--working_dir", type=str, default="./working") + arg_parser.add_argument("--num_observers", type=int, default=5) arg_parser.add_argument("--offline", type=bool, default=False) arg_parser.add_argument("--comet_api_key", type=str, default=None) arg_parser.add_argument("--comet_workspace", type=str, default=None) From 62755dcf73150781c6e166deb76fcb00140f6f92 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 11 Jul 2021 02:48:14 +0000 Subject: [PATCH 080/264] Adapt range for CT --- platipy/imaging/cnn/dataset.py | 2 +- platipy/imaging/cnn/train.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index c44db831..ea092744 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -16,7 +16,7 @@ def preprocess_image(img, crop_to_mm=128): img = sitk.Cast(img, sitk.sitkFloat32) img = sitk.IntensityWindowing( - img, windowMinimum=-100.0, windowMaximum=100.0, outputMinimum=-1.0, outputMaximum=1.0 + img, windowMinimum=-500.0, windowMaximum=500.0, outputMinimum=-1.0, outputMaximum=1.0 ) new_spacing = sitk.VectorDouble(3) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 9556b9ac..9a0796fb 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -428,7 +428,11 @@ def add_model_specific_args(parent_parser): def setup(self, stage=None): - cases = [p.name.replace(".nii.gz", "") for p in self.data_dir.glob(self.case_glob) if not p.name.startswith(".")] + 
cases = [ + p.name.replace(".nii.gz", "") + for p in self.data_dir.glob(self.case_glob) + if not p.name.startswith(".") + ] cases.sort() random.shuffle(cases) # will be consistent for same value of 'seed everything' cases_per_fold = math.ceil(len(cases) / self.k_folds) @@ -455,7 +459,11 @@ def setup(self, stage=None): { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), - "label": [p for p in self.data_dir.glob(self.label_glob.format(case=case)) if not "edited" in p.name], + "label": [ + p + for p in self.data_dir.glob(self.label_glob.format(case=case)) + if not "edited" in p.name + ], } for case in self.validation_cases ] @@ -549,7 +557,12 @@ def main(args): args = [] for k in params: args.append(f"--{k}") - args.append(str(params[k])) + + if isinstance(params[k], list): + for x in params[k]: + args.append(str(x)) + else: + args.append(str(params[k])) arg_parser = ArgumentParser() arg_parser = ProbUNet.add_model_specific_args(arg_parser) From 94efb7fa2500fd0fc1d2a1ee70977d69d1c72905 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 17 Jul 2021 07:23:58 +1000 Subject: [PATCH 081/264] Adjustments to GECO --- platipy/imaging/cnn/prob_unet.py | 39 ++++++++++++++++---------------- platipy/imaging/cnn/train.py | 8 ++++++- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 05f436b1..264f36c1 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -260,6 +260,7 @@ def reconstruction_loss( z_posterior=None, mask=None, top_k_percentage=None, + deterministic=True ): criterion = torch.nn.BCEWithLogitsLoss(reduction="none") @@ -298,7 +299,6 @@ def reconstruction_loss( with torch.no_grad(): norm_xe = xe / torch.sum(xe) - deterministic = True if deterministic: score = torch.log(norm_xe) else: @@ -317,8 +317,9 @@ def reconstruction_loss( ce_sum_per_instance = torch.sum(mask * xe, axis=1) ce_sum = torch.mean(ce_sum_per_instance, 
axis=0) + ce_mean = torch.sum(mask * xe) / torch.sum(mask) - return ce_sum, mask + return ce_sum, ce_mean, mask def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): """ @@ -334,7 +335,7 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): if "top_k_percentage" in self.loss_params: top_k_percentage = self.loss_params["top_k_percentage"] - reconstruction_loss, mask = self.reconstruction_loss( + reconstruction_loss, rec_loss_mean, mask = self.reconstruction_loss( segm, reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior, @@ -353,27 +354,27 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): num_pixels = mask.sum().item() batch_size = segm.shape[0] reconstruction_threshold = (self.loss_params["kappa"] * num_pixels) / batch_size - rec_constraint = reconstruction_loss - reconstruction_threshold - - loss = ( - self._lambda * rec_constraint # pylint: disable=access-member-before-definition - + kl_div - ) + reconstruction_threshold = self.loss_params["kappa"] with torch.no_grad(): - rc = rec_constraint.detach() + rl = rec_loss_mean.detach() + moving_avg_factor = 0.8 if self._moving_avg is None: - self._moving_avg = rc + self._moving_avg = rl else: - self._moving_avg = self._moving_avg * 0.5 + rc * (1 - 0.5) - speed = 1 - self._moving_avg = self._moving_avg.clamp( - -25, 25 - ) - self._lambda = (speed * torch.exp(self._moving_avg) * self._lambda).clamp( - self.loss_params["clamp_rec"][0], self.loss_params["clamp_rec"][1] + self._moving_avg = self._moving_avg * moving_avg_factor + rl * (1 - moving_avg_factor) + + rc = self._moving_avg - reconstruction_threshold + lambda_lower = self.loss_params["clamp_rec"][0] + lambda_upper = self.loss_params["clamp_rec"][1] + self._lambda = (torch.exp(rc) * self._lambda).clamp( + lambda_lower, lambda_upper ) + loss = ( + self._lambda * reconstruction_loss # pylint: disable=access-member-before-definition + + kl_div + ) # self._lambda = (speed * 
self._moving_avg * self._lambda).clamp( # self.loss_params["clamp_rec"][0], self.loss_params["clamp_rec"][1] # ) @@ -384,7 +385,7 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): "lambda": self._lambda, "moving_avg": self._moving_avg, "reconstruction_threshold": reconstruction_threshold, - "rec_constraint": rec_constraint, + "rec_constraint": rc, } else: diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 9a0796fb..9b29b7fd 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -29,6 +29,7 @@ import torch import pytorch_lightning as pl +from pytorch_lightning.callbacks import LearningRateMonitor from argparse import ArgumentParser @@ -174,6 +175,7 @@ def training_step(self, batch, batch_idx): ) for k in loss: + if k == "loss": continue self.log( k, loss[k], @@ -311,7 +313,7 @@ def validation_epoch_end(self, validation_step_outputs): color_dict[f"auto_{self.stddevs[idx]}"] = cmap(observer / 5) img_vis = ImageVisualiser( - img, cut=get_com(mask), figure_size_in=16, window=[-1.0, 1.0] + img, cut=get_com(mask), figure_size_in=16, window=[-0.4, 0.8] ) contour_dict = {**obs_dict, **pred_dict} @@ -538,11 +540,15 @@ def main(args): data_module = ProbUNetDataModule(**dict_args) prob_unet = ProbUNet(**dict_args) + trainer = pl.Trainer.from_argparse_args(args) if comet_api_key is not None: trainer.logger = comet_logger + lr_monitor = LearningRateMonitor(logging_interval='step') + trainer.callbacks = [lr_monitor] + trainer.fit(prob_unet, data_module) # pylint: disable=no-member From e0fcd72f78b0bf828683f50149f90b48fff7f2bb Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 17 Jul 2021 03:56:11 +0000 Subject: [PATCH 082/264] Allow using pre-augmented data --- platipy/imaging/cnn/train.py | 62 +++++++++++++++++++-- platipy/imaging/visualisation/visualiser.py | 8 +-- 2 files changed, 60 insertions(+), 10 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 
9b29b7fd..949b7ab0 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -30,6 +30,7 @@ import torch import pytorch_lightning as pl from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.callbacks.progress import ProgressBar from argparse import ArgumentParser @@ -175,7 +176,8 @@ def training_step(self, batch, batch_idx): ) for k in loss: - if k == "loss": continue + if k == "loss": + continue self.log( k, loss[k], @@ -268,7 +270,7 @@ def validation_epoch_end(self, validation_step_outputs): img_arr = np.stack(img_arrs) img = sitk.GetImageFromArray(img_arr) - # sitk.WriteImage(img, f"test_{case}.nii.gz") + sitk.WriteImage(img, f"test_{case}.nii.gz") mean_arr = np.stack(mean_arrs) mean = sitk.GetImageFromArray(mean_arr) @@ -313,7 +315,7 @@ def validation_epoch_end(self, validation_step_outputs): color_dict[f"auto_{self.stddevs[idx]}"] = cmap(observer / 5) img_vis = ImageVisualiser( - img, cut=get_com(mask), figure_size_in=16, window=[-0.4, 0.8] + img, cut=get_com(mask), figure_size_in=16, window=[-1.0, 1.0] ) contour_dict = {**obs_dict, **pred_dict} @@ -326,7 +328,11 @@ def validation_epoch_end(self, validation_step_outputs): fig.savefig(figure_path, dpi=300) plt.close("all") - self.logger.experiment.log_image(figure_path) + try: + self.logger.experiment.log_image(figure_path) + except AttributeError: + # Likely offline mode + pass sim = {k: np.zeros((len(observers), len(samples))) for k in metrics} msim = {k: np.zeros((len(observers), len(samples))) for k in metrics} @@ -377,10 +383,14 @@ class ProbUNetDataModule(pl.LightningDataModule): def __init__( self, data_dir: str = "./data", + augmented_dir: str = None, working_dir: str = "./working", case_glob="images/*.nii.gz", image_glob="images/{case}.nii.gz", label_glob="labels/{case}_*.nii.gz", + augmented_case_glob="{case}/*", + augmented_image_glob="images/{augmented_case}.nii.gz", + augmented_label_glob="labels/{augmented_case}_*.nii.gz", fold=0, k_folds=5, 
batch_size=5, @@ -391,11 +401,17 @@ def __init__( ): super().__init__() self.data_dir = Path(data_dir) + self.augmented_dir = None + if augmented_dir is not None: + self.augmented_dir = Path(augmented_dir) self.working_dir = Path(working_dir) self.case_glob = case_glob self.image_glob = image_glob self.label_glob = label_glob + self.augmented_case_glob = augmented_case_glob + self.augmented_image_glob = augmented_image_glob + self.augmented_label_glob = augmented_label_glob self.fold = fold self.k_folds = k_folds @@ -417,6 +433,7 @@ def __init__( def add_model_specific_args(parent_parser): parser = parent_parser.add_argument_group("Data Loader") parser.add_argument("--data_dir", type=str, default="./data") + parser.add_argument("--augmented_dir", type=str, default=None) parser.add_argument("--fold", type=int, default=0) parser.add_argument("--k_folds", type=int, default=5) parser.add_argument("--batch_size", type=int, default=5) @@ -424,6 +441,9 @@ def add_model_specific_args(parent_parser): parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") parser.add_argument("--label_glob", type=str, default="labels/{case}_*.nii.gz") + parser.add_argument("--augmented_case_glob", type=str, default=None) + parser.add_argument("--augmented_image_glob", type=str, default=None) + parser.add_argument("--augmented_label_glob", type=str, default=None) parser.add_argument("--crop_to_mm", type=int, default=128) return parent_parser @@ -457,6 +477,35 @@ def setup(self, stage=None): for case in self.train_cases ] + # If a directory with augmented data is specified, use that for training as well + if self.augmented_dir is not None: + + for case in self.train_cases: + augmented_cases = [ + p.name.replace(".nii.gz", "") + for p in self.augmented_dir.glob(self.augmented_case_glob) + if not p.name.startswith(".") + ] + train_data += [ + { + "id": case, + "image": self.augmented_dir.joinpath( + 
self.augmented_image_glob.format( + case=case, augmented_case=augmented_case + ) + ), + "label": [ + p + for p in self.augmented_dir.glob( + self.augmented_label_glob.format( + case=case, augmented_case=augmented_case + ) + ) + ], + } + for augmented_case in augmented_cases + ] + validation_data = [ { "id": case, @@ -546,8 +595,9 @@ def main(args): if comet_api_key is not None: trainer.logger = comet_logger - lr_monitor = LearningRateMonitor(logging_interval='step') - trainer.callbacks = [lr_monitor] + lr_monitor = LearningRateMonitor(logging_interval="step") + # bar = ProgressBar() + trainer.callbacks.append(lr_monitor) trainer.fit(prob_unet, data_module) # pylint: disable=no-member diff --git a/platipy/imaging/visualisation/visualiser.py b/platipy/imaging/visualisation/visualiser.py index a682ddcb..3f3e2c8b 100644 --- a/platipy/imaging/visualisation/visualiser.py +++ b/platipy/imaging/visualisation/visualiser.py @@ -520,7 +520,7 @@ def _display_slice(self): interpolation="none", origin={"normal": "upper", "reversed": "lower"}[self.__origin], cmap=self.__colormap, - clim=(self.__window[0], self.__window[0] + self.__window[1]), + clim=self.__window, ) cor_view = ax_cor.imshow( cor_img, @@ -528,7 +528,7 @@ def _display_slice(self): aspect=asp, interpolation="none", cmap=self.__colormap, - clim=(self.__window[0], self.__window[0] + self.__window[1]), + clim=self.__window, ) sag_view = ax_sag.imshow( sag_img, @@ -536,7 +536,7 @@ def _display_slice(self): aspect=asp, interpolation="none", cmap=self.__colormap, - clim=(self.__window[0], self.__window[0] + self.__window[1]), + clim=self.__window, ) ax_ax.axis("off") @@ -616,7 +616,7 @@ def _display_slice(self): interpolation="none", origin=org, cmap=self.__colormap, - clim=(self.__window[0], self.__window[0] + self.__window[1]), + clim=self.__window, ) ax.axis("off") From 7d7863dd9162a3bbee66c970e4f6c07e0b62273e Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 17 Jul 2021 15:59:25 +1000 Subject: [PATCH 083/264] 
Corrections to using augmented data --- platipy/imaging/cnn/train.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 949b7ab0..0a3b846c 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -401,9 +401,7 @@ def __init__( ): super().__init__() self.data_dir = Path(data_dir) - self.augmented_dir = None - if augmented_dir is not None: - self.augmented_dir = Path(augmented_dir) + self.augmented_dir = augmented_dir self.working_dir = Path(working_dir) self.case_glob = case_glob @@ -481,22 +479,25 @@ def setup(self, stage=None): if self.augmented_dir is not None: for case in self.train_cases: + + case_aug_dir = Path(self.augmented_dir.format(case=case)) augmented_cases = [ p.name.replace(".nii.gz", "") - for p in self.augmented_dir.glob(self.augmented_case_glob) + for p in case_aug_dir.glob(self.augmented_case_glob.format(case=case)) if not p.name.startswith(".") ] + print(augmented_cases) train_data += [ { - "id": case, - "image": self.augmented_dir.joinpath( + "id": f"{case}_{augmented_case}", + "image": case_aug_dir.joinpath( self.augmented_image_glob.format( case=case, augmented_case=augmented_case ) ), "label": [ p - for p in self.augmented_dir.glob( + for p in case_aug_dir.glob( self.augmented_label_glob.format( case=case, augmented_case=augmented_case ) @@ -505,7 +506,7 @@ def setup(self, stage=None): } for augmented_case in augmented_cases ] - + print(train_data) validation_data = [ { "id": case, From 2ec283db7b034ae417f70122e5b3978360d107a8 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 19 Jul 2021 09:43:23 +1000 Subject: [PATCH 084/264] read data correctly when already generated --- platipy/imaging/cnn/dataset.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index ea092744..44db4f09 100644 --- a/platipy/imaging/cnn/dataset.py +++ 
b/platipy/imaging/cnn/dataset.py @@ -1,3 +1,4 @@ +import re from pathlib import Path import numpy as np @@ -149,11 +150,17 @@ def __init__(self, data, working_dir, augment_on_the_fly=True, crop_to_mm=128): if len(existing_images) > 0: logger.debug(f"Image for case already exist: {case_id}") - for z_slice in range(len(existing_images)): + for img_path in existing_images: + z_matches = re.findall(f"{case_id}_([0-9]*)\.npy", img_path.name) + if len(z_matches) == 0: continue + z_slice = int(z_matches[0]) + img_file = self.img_dir.joinpath(f"{case_id}_{z_slice}.npy") + assert img_file.exists() for obs in range(len(structure_paths)): mask_file = self.mask_dir.joinpath(f"{case_id}_{obs}_{z_slice}.npy") + assert mask_file.exists() self.slices.append( { "z": z_slice, From 6fbb31769cb8fee7e16784459bc6e2bbcd8989d9 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 18 Jul 2021 23:47:33 +0000 Subject: [PATCH 085/264] Correct top_k_percentage mask --- platipy/imaging/cnn/prob_unet.py | 43 +++++++++++++++++++------------- platipy/imaging/cnn/train.py | 5 +++- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 264f36c1..f6c3293e 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -250,7 +250,10 @@ def topk_mask(self, score, k): values, _ = torch.topk(score, 1, axis=1) _, indices = torch.topk(values, k, axis=0) return torch.scatter_add( - torch.zeros(score.shape[0]).to(score.device), 0, indices.reshape(-1), torch.ones(score.shape[0]).to(score.device) + torch.zeros(score.shape[0]).to(score.device), + 0, + indices.reshape(-1), + torch.ones(score.shape[0]).to(score.device), ) def reconstruction_loss( @@ -260,7 +263,7 @@ def reconstruction_loss( z_posterior=None, mask=None, top_k_percentage=None, - deterministic=True + deterministic=True, ): criterion = torch.nn.BCEWithLogitsLoss(reduction="none") @@ -278,10 +281,11 @@ def reconstruction_loss( ##### 
num_classes = reconstruction.shape[1] - y_flat = torch.reshape(reconstruction, (-1, num_classes)) - t_flat = torch.reshape(segm, (-1, num_classes)) + y_flat = torch.transpose(reconstruction, 1, -1).reshape((-1, num_classes)) + t_flat = torch.transpose(segm, 1, -1).reshape((-1, num_classes)) + n_pixels_in_batch = y_flat.shape[0] if mask is None: - mask = torch.ones(torch.reshape(t_flat, (-1, 2)).shape[0]) + mask = torch.ones(n_pixels_in_batch) else: assert ( mask.shape == segm.shape @@ -289,7 +293,6 @@ def reconstruction_loss( mask = torch.reshape(segm, (-1,)) mask = mask.to(y_flat.device) - n_pixels_in_batch = y_flat.shape[0] xe = criterion(input=y_flat, target=t_flat) if top_k_percentage is not None: @@ -306,14 +309,17 @@ def reconstruction_loss( raise NotImplementedError("Still need to implement Gumbel trick") score = score + torch.log(mask.unsqueeze(1).repeat((1, num_classes))) + top_k_mask = self.topk_mask(score, k_pixels) top_k_mask = top_k_mask.to(y_flat.device) mask = mask * top_k_mask batch_size = segm.shape[0] - xe = torch.reshape(xe, shape=(batch_size, -1)) - mask = mask.repeat((1, num_classes)) - mask = torch.reshape(mask, shape=(batch_size, -1)) + xe = xe.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + mask = mask.unsqueeze(1).repeat((1, num_classes)) + mask = ( + mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + ) ce_sum_per_instance = torch.sum(mask * xe, axis=1) ce_sum = torch.mean(ce_sum_per_instance, axis=0) @@ -330,11 +336,11 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): kl_div = torch.mean(self.kl_divergence(analytic=analytic_kl, z_posterior=z_posterior)) - # Here we use the posterior sample sampled above top_k_percentage = None if "top_k_percentage" in self.loss_params: top_k_percentage = self.loss_params["top_k_percentage"] + # Here we use the posterior sample sampled above reconstruction_loss, rec_loss_mean, mask = 
self.reconstruction_loss( segm, reconstruct_posterior_mean=reconstruct_posterior_mean, @@ -363,21 +369,22 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): if self._moving_avg is None: self._moving_avg = rl else: - self._moving_avg = self._moving_avg * moving_avg_factor + rl * (1 - moving_avg_factor) + self._moving_avg = self._moving_avg * moving_avg_factor + rl * ( + 1 - moving_avg_factor + ) rc = self._moving_avg - reconstruction_threshold lambda_lower = self.loss_params["clamp_rec"][0] lambda_upper = self.loss_params["clamp_rec"][1] - self._lambda = (torch.exp(rc) * self._lambda).clamp( - lambda_lower, lambda_upper - ) + self._lambda = (torch.exp(rc) * self._lambda).clamp(lambda_lower, lambda_upper) loss = ( - self._lambda * reconstruction_loss # pylint: disable=access-member-before-definition + self._lambda + * reconstruction_loss # pylint: disable=access-member-before-definition + kl_div ) -# self._lambda = (speed * self._moving_avg * self._lambda).clamp( -# self.loss_params["clamp_rec"][0], self.loss_params["clamp_rec"][1] -# ) + # self._lambda = (speed * self._moving_avg * self._lambda).clamp( + # self.loss_params["clamp_rec"][0], self.loss_params["clamp_rec"][1] + # ) return { "loss": loss, "rec_loss": reconstruction_loss, diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 0a3b846c..d0fe8fe6 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -97,7 +97,10 @@ def __init__( loss_params = None if self.hparams.loss_type == "elbo": - loss_params = {"beta": self.hparams.beta} + loss_params = { + "beta": self.hparams.beta, + "top_k_percentage": self.hparams.top_k_percentage, + } if self.hparams.loss_type == "geco": loss_params = { From c19809bdc11a4b587d1e08591f3734775b739be3 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 19 Jul 2021 00:09:25 +0000 Subject: [PATCH 086/264] Making target spacing configurable --- platipy/imaging/cnn/dataset.py | 28 
+++++++++++++--------------- platipy/imaging/cnn/train.py | 15 ++++++++++++--- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 44db4f09..38b9eec4 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -13,22 +13,17 @@ from loguru import logger -def preprocess_image(img, crop_to_mm=128): +def preprocess_image(img, spacing=[1, 1, 1], crop_to_mm=128): img = sitk.Cast(img, sitk.sitkFloat32) img = sitk.IntensityWindowing( img, windowMinimum=-500.0, windowMaximum=500.0, outputMinimum=-1.0, outputMaximum=1.0 ) - new_spacing = sitk.VectorDouble(3) - new_spacing[0] = 1.0 - new_spacing[1] = 1.0 - new_spacing[2] = img.GetSpacing()[2] - new_size = sitk.VectorUInt32(3) - new_size[0] = int(img.GetSize()[0] * img.GetSpacing()[0]) - new_size[1] = int(img.GetSize()[1] * img.GetSpacing()[1]) - new_size[2] = int(img.GetSize()[2]) + new_size[0] = int(img.GetSize()[0] * (img.GetSpacing()[0] / spacing[0])) + new_size[1] = int(img.GetSize()[1] * (img.GetSpacing()[1] / spacing[1])) + new_size[2] = int(img.GetSize()[2] * (img.GetSpacing()[2] / spacing[2])) if new_size[0] < crop_to_mm: new_size[0] = crop_to_mm @@ -42,7 +37,7 @@ def preprocess_image(img, crop_to_mm=128): sitk.Transform(), sitk.sitkLinear, img.GetOrigin(), - new_spacing, + spacing, img.GetDirection(), -1, img.GetPixelID(), @@ -89,7 +84,6 @@ def prepare_transforms(): ) ), # execute 0 to 2 of the following (less important) augmenters per image - # don't execute all of them, as that would often be way too strong iaa.SomeOf( (0, 2), [ @@ -113,7 +107,9 @@ def prepare_transforms(): class NiftiDataset(torch.utils.data.Dataset): """PyTorch Dataset for processing Nifti data""" - def __init__(self, data, working_dir, augment_on_the_fly=True, crop_to_mm=128): + def __init__( + self, data, working_dir, augment_on_the_fly=True, spacing=[1, 1, 1], crop_to_mm=128 + ): """Prepare a dataset from Nifti images/labels Args: @@ 
-151,8 +147,9 @@ def __init__(self, data, working_dir, augment_on_the_fly=True, crop_to_mm=128): logger.debug(f"Image for case already exist: {case_id}") for img_path in existing_images: - z_matches = re.findall(f"{case_id}_([0-9]*)\.npy", img_path.name) - if len(z_matches) == 0: continue + z_matches = re.findall(fr"{case_id}_([0-9]*)\.npy", img_path.name) + if len(z_matches) == 0: + continue z_slice = int(z_matches[0]) img_file = self.img_dir.joinpath(f"{case_id}_{z_slice}.npy") @@ -173,9 +170,10 @@ def __init__(self, data, working_dir, augment_on_the_fly=True, crop_to_mm=128): continue + logger.debug(f"Generating images for case: {case_id}") img = sitk.ReadImage(img_path) - img = preprocess_image(img, crop_to_mm=crop_to_mm) + img = preprocess_image(img, spacing=spacing, crop_to_mm=crop_to_mm) observers = [] for obs, structure_path in enumerate(structure_paths): diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index d0fe8fe6..d3c29833 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -273,12 +273,13 @@ def validation_epoch_end(self, validation_step_outputs): img_arr = np.stack(img_arrs) img = sitk.GetImageFromArray(img_arr) - sitk.WriteImage(img, f"test_{case}.nii.gz") + img.SetSpacing(self.hparams.spacing) mean_arr = np.stack(mean_arrs) mean = sitk.GetImageFromArray(mean_arr) mean = sitk.Cast(mean, sitk.sitkUInt8) mean = post_process(mean) + mean.CopyInformation(img) # sitk.WriteImage(mean, f"val_mean_{case}_mean.nii.gz") obs_dict = {} @@ -303,6 +304,7 @@ def validation_epoch_end(self, validation_step_outputs): mask_arr = np.stack(mask_arrs) mask = sitk.GetImageFromArray(mask_arr) mask = sitk.Cast(mask, sitk.sitkUInt8) + mask.CopyInformation(img) # sitk.WriteImage(mask, f"val_mask_{case}_{observer}.nii.gz") observers.append(mask) obs_dict[f"manual_{observer}"] = mask @@ -312,6 +314,7 @@ def validation_epoch_end(self, validation_step_outputs): sample = sitk.GetImageFromArray(sample_arr) sample = 
sitk.Cast(sample, sitk.sitkUInt8) sample = post_process(sample) + sample.CopyInformation(img) # sitk.WriteImage(sample, f"val_sample_{case}_{observer}.nii.gz") samples.append(sample) pred_dict[f"auto_{self.stddevs[idx]}"] = sample @@ -400,6 +403,7 @@ def __init__( num_workers=4, crop_to_mm=128, num_observers=5, + spacing=[1, 1, 1], **kwargs, ): super().__init__() @@ -424,6 +428,7 @@ def __init__( self.num_workers = num_workers self.crop_to_mm = crop_to_mm self.num_observers = num_observers + self.spacing = spacing self.training_set = None self.validation_set = None @@ -509,7 +514,7 @@ def setup(self, stage=None): } for augmented_case in augmented_cases ] - print(train_data) + validation_data = [ { "id": case, @@ -524,7 +529,10 @@ def setup(self, stage=None): ] self.training_set = NiftiDataset( - train_data, self.working_dir.joinpath("train"), crop_to_mm=self.crop_to_mm + train_data, + self.working_dir.joinpath("train"), + spacing=self.spacing, + crop_to_mm=self.crop_to_mm, ) self.validation_set = NiftiDataset( validation_data, @@ -632,6 +640,7 @@ def main(args): arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") arg_parser.add_argument("--working_dir", type=str, default="./working") arg_parser.add_argument("--num_observers", type=int, default=5) + arg_parser.add_argument("--spacing", type=list, default=[1, 1, 1]) arg_parser.add_argument("--offline", type=bool, default=False) arg_parser.add_argument("--comet_api_key", type=str, default=None) arg_parser.add_argument("--comet_workspace", type=str, default=None) From cac5578d7d57906d1eb16376ce95eaabaae18680 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 20 Jul 2021 05:29:01 +0000 Subject: [PATCH 087/264] Improvements to deformable augmentation tool --- platipy/imaging/generation/augment.py | 228 ++++++++++++++++++++++---- 1 file changed, 198 insertions(+), 30 deletions(-) diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py 
index b6fd7455..8ed82136 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -16,8 +16,16 @@ from collections.abc import Iterable import random +from pathlib import Path + +from argparse import ArgumentParser + import SimpleITK as sitk +from loguru import logger + +from platipy.imaging import ImageVisualiser + from platipy.imaging.generation.dvf import ( generate_field_shift, generate_field_expand, @@ -83,39 +91,21 @@ def apply_augmentation(image, augmentation, masks=[]): return image_deformed, dvf -def generate_random_augmentation(ct_image, masks): - - random.shuffle(masks) - # mask_count = len(masks) - # masks = masks[: random.randint(2, 5)] - - # print(len(masks)) - augmentation_types = [ - { - "class": ShiftAugment, - "args": {"vector_shift": [(-10, 10), (10, 10), (-10, 10)], "gaussian_smooth": (3, 5)}, - }, - { - "class": ContractAugment, - "args": { - "vector_contract": [(0, 10), (0, 10), (0, 10)], - "gaussian_smooth": (3, 5), - "bone_mask": True, - }, - }, - { - "class": ExpandAugment, - "args": { - "vector_expand": [(0, 10), (0, 10), (0, 10)], - "gaussian_smooth": (3, 5), - "bone_mask": True, - }, - }, - ] +def generate_random_augmentation(ct_image, masks, augmentation_types): augmentation = [] + + probabilities = [a["probability"] for a in augmentation_types] + prob_total = sum(probabilities) + prob_none = 1.0 - prob_total + if prob_none < 0: + prob_none = 0 + for mask in masks: - aug = random.choice(augmentation_types) + aug = random.choices(augmentation_types + [None], weights=probabilities + [prob_none])[0] + + if aug is None: + continue aug_class = aug["class"] aug_args = {} @@ -203,3 +193,181 @@ def augment(self): gaussian_smooth=self.gaussian_smooth, ) return transform, dvf + + +def augment_data(args): + + random.seed(args.seed) + + augmentation_types = [] + + if args.enable_shift: + augmentation_types.append( + { + "class": ShiftAugment, + "args": { + "vector_shift": [ + tuple(args.shift_x_range), + 
tuple(args.shift_y_range), + tuple(args.shift_z_range), + ], + "gaussian_smooth": tuple(args.shift_smooth_range), + }, + "probability": args.shift_probability, + } + ) + + if args.enable_contract: + augmentation_types.append( + { + "class": ContractAugment, + "args": { + "vector_contract": [ + tuple(args.contract_x_range), + tuple(args.contract_y_range), + tuple(args.contract_z_range), + ], + "gaussian_smooth": tuple(args.contract_smooth_range), + "bone_mask": args.contract_bone_mask, + }, + "probability": args.contract_probability, + } + ) + + if args.enable_expand: + augmentation_types.append( + { + "class": ExpandAugment, + "args": { + "vector_expand": [ + tuple(args.expand_x_range), + tuple(args.expand_y_range), + tuple(args.expand_z_range), + ], + "gaussian_smooth": tuple(args.expand_smooth_range), + "bone_mask": args.expand_bone_mask, + }, + "probability": args.expand_probability, + } + ) + + data_dir = Path(args.data_dir) + output_dir = Path(args.output_dir) + + cases = [ + p.name.replace(".nii.gz", "") + for p in data_dir.glob(args.case_glob) + if not p.name.startswith(".") + ] + cases.sort() + + data = { + case: { + "image": data_dir.joinpath(args.image_glob.format(case=case)), + "label": [p for p in data_dir.glob(args.label_glob.format(case=case))], + } + for case in cases + } + + for case in cases: + + logger.info(f"Augmenting for case: {case}") + + ct_image = sitk.ReadImage(str(data[case]["image"])) + + # Get list of structures to generate augmentations off + all_masks = [] + all_names = [] + for structure_path in data[case]["label"]: + + mask = sitk.ReadImage(str(structure_path)) + + all_masks.append(mask) + all_names.append(structure_path.name.replace(".nii.gz", "")) + + # Generate 10 random augmentations per case + for i in range(args.augmentations_per_case): + + logger.debug("Generating augmentation") + + augmented_case_path = output_dir.joinpath(case, f"augment_{i}") + augmented_case_path.mkdir(exist_ok=True, parents=True) + + augmentation = 
generate_random_augmentation(ct_image, all_masks, augmentation_types) + + dvf = None + + if len(augmentation) == 0: + logger.debug( + "No augmentations generated, generated image won't differ from original" + ) + + augmented_image = ct_image + augmented_masks = all_masks + else: + + logger.debug("Applying augmentation") + augmented_image, augmented_masks, dvf = apply_augmentation( + ct_image, augmentation, masks=all_masks + ) + + augmented_image_path = augmented_case_path.joinpath("CT.nii.gz") + sitk.WriteImage(augmented_image, str(augmented_image_path)) + + vis = ImageVisualiser(image=ct_image, figure_size_in=6) + vis.add_comparison_overlay(augmented_image) + if dvf is not None: + vis.add_vector_overlay(dvf, arrow_scale=1, subsample=(4, 12, 12)) + for mask_name, mask, augmented_mask in zip(all_names, all_masks, augmented_masks): + vis.add_contour({f"{mask_name}": mask, f"{mask_name} (augmented)": augmented_mask}) + + logger.debug(f"Applying augmentation to mask: {mask_name}") + augmented_mask_path = augmented_case_path.joinpath(f"{mask_name}.nii.gz") + sitk.WriteImage(augmented_mask, str(augmented_mask_path)) + + fig = vis.show() + + figure_path = augmented_case_path.joinpath("aug.png") + fig.savefig(figure_path, bbox_inches="tight") + + +if __name__ == "__main__": + + arg_parser = ArgumentParser() + arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") + arg_parser.add_argument("--data_dir", type=str, default="./data") + arg_parser.add_argument("--output_dir", type=str, default="./augment") + arg_parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") + arg_parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") + arg_parser.add_argument("--label_glob", type=str, default="labels/{case}_*.nii.gz") + arg_parser.add_argument( + "--augmentations_per_case", + type=int, + default=10, + help="How many augmented images per case to generate", + ) + + arg_parser.add_argument("--enable_shift", 
type=bool, default=True) + arg_parser.add_argument("--shift_x_range", nargs="+", type=int, default=[-10, 10]) + arg_parser.add_argument("--shift_y_range", nargs="+", type=int, default=[-10, 10]) + arg_parser.add_argument("--shift_z_range", nargs="+", type=int, default=[-10, 10]) + arg_parser.add_argument("--shift_smooth_range", nargs="+", type=int, default=[3, 5]) + arg_parser.add_argument("--shift_probability", type=float, default=0.5) + + arg_parser.add_argument("--enable_expand", type=bool, default=True) + arg_parser.add_argument("--expand_x_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--expand_y_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--expand_z_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--expand_smooth_range", nargs="+", type=int, default=[3, 5]) + arg_parser.add_argument("--expand_bone_mask", type=bool, default=True) + arg_parser.add_argument("--expand_probability", type=float, default=0.5) + + arg_parser.add_argument("--enable_contract", type=bool, default=True) + arg_parser.add_argument("--contract_x_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--contract_y_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--contract_z_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--contract_smooth_range", nargs="+", type=int, default=[3, 5]) + arg_parser.add_argument("--contract_bone_mask", type=bool, default=True) + arg_parser.add_argument("--contract_probability", type=float, default=0.5) + + augment_data(arg_parser.parse_args()) From fbd1c03f0ac379f233dc1fc0ec13be5492b14383 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 25 Jul 2021 02:15:05 +0000 Subject: [PATCH 088/264] compute loss in contour area --- platipy/imaging/cnn/dataset.py | 98 +++++++++++++++++++++++++------- platipy/imaging/cnn/prob_unet.py | 10 +++- platipy/imaging/cnn/train.py | 16 +++--- 3 files changed, 95 insertions(+), 29 
deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 38b9eec4..b35071ba 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -1,4 +1,5 @@ import re +import copy from pathlib import Path import numpy as np @@ -13,6 +14,40 @@ from loguru import logger +def get_union_mask(masks): + + union_mask = copy.copy(masks[0]) + for mask in masks[1:]: + union_mask += mask + + return sitk.Cast(union_mask > 0, sitk.sitkUInt8) + + +def get_intersection_mask(masks): + + intersection_mask = copy.copy(masks[0]) + for mask in masks[1:]: + intersection_mask += mask + + return sitk.Cast(intersection_mask == len(masks), sitk.sitkUInt8) + + +def get_contour_mask(masks, kernel=5): + + if not hasattr(kernel, "__iter__"): + kernel = (kernel,) * 3 + + union_mask = get_union_mask(masks) + intersection_mask = get_intersection_mask(masks) + + union_mask = sitk.BinaryDilate(union_mask, np.abs(kernel).astype(int).tolist(), sitk.sitkBall) + intersection_mask = sitk.BinaryErode( + intersection_mask, np.abs(kernel).astype(int).tolist(), sitk.sitkBall + ) + + return union_mask - intersection_mask + + def preprocess_image(img, spacing=[1, 1, 1], crop_to_mm=128): img = sitk.Cast(img, sitk.sitkFloat32) @@ -127,12 +162,12 @@ def __init__( self.working_dir = Path(working_dir) self.img_dir = working_dir.joinpath("img") - self.mask_dir = working_dir.joinpath("mask") - - self.data_exists = self.img_dir.exists() + self.label_dir = working_dir.joinpath("label") + self.contour_mask_dir = working_dir.joinpath("contour_mask") self.img_dir.mkdir(exist_ok=True, parents=True) - self.mask_dir.mkdir(exist_ok=True, parents=True) + self.label_dir.mkdir(exist_ok=True, parents=True) + self.contour_mask_dir.mkdir(exist_ok=True, parents=True) for case in data: case_id = case["id"] @@ -155,14 +190,18 @@ def __init__( img_file = self.img_dir.joinpath(f"{case_id}_{z_slice}.npy") assert img_file.exists() + contour_mask_file = 
self.contour_mask_dir.joinpath(f"{case_id}_{z_slice}.npy") + assert contour_mask_file.exists() + for obs in range(len(structure_paths)): - mask_file = self.mask_dir.joinpath(f"{case_id}_{obs}_{z_slice}.npy") - assert mask_file.exists() + label_file = self.label_dir.joinpath(f"{case_id}_{obs}_{z_slice}.npy") + assert label_file.exists() self.slices.append( { "z": z_slice, "image": img_file, - "mask": mask_file, + "label": label_file, + "contour_mask": contour_mask_file, "case": case_id, "observer": obs, } @@ -178,26 +217,36 @@ def __init__( observers = [] for obs, structure_path in enumerate(structure_paths): structure_path = str(structure_path) - mask = sitk.ReadImage(structure_path) - mask = resample_mask_to_image(img, mask) - observers.append(mask) + label = sitk.ReadImage(structure_path) + label = resample_mask_to_image(img, label) + observers.append(label) + + contour_mask = get_contour_mask(observers) + sitk.WriteImage(contour_mask, f"cm_{case['id']}.nii.gz") for z_slice in range(img.GetSize()[2]): + # Save the image slice img_slice = img[:, :, z_slice] img_file = self.img_dir.joinpath(f"{case_id}_{z_slice}.npy") np.save(img_file, sitk.GetArrayFromImage(img_slice)) - for obs, mask in enumerate(observers): + # Save the contour mask slice + contour_mask_slice = contour_mask[:, :, z_slice] + contour_mask_file = self.contour_mask_dir.joinpath(f"{case_id}_{z_slice}.npy") + np.save(contour_mask_file, sitk.GetArrayFromImage(contour_mask_slice)) + + for obs, label in enumerate(observers): - mask_slice = mask[:, :, z_slice] - mask_file = self.mask_dir.joinpath(f"{case_id}_{obs}_{z_slice}.npy") - np.save(mask_file, sitk.GetArrayFromImage(mask_slice)) + label_slice = label[:, :, z_slice] + label_file = self.label_dir.joinpath(f"{case_id}_{obs}_{z_slice}.npy") + np.save(label_file, sitk.GetArrayFromImage(label_slice)) self.slices.append( { "z": z_slice, "image": img_file, - "mask": mask_file, + "label": label_file, + "contour_mask": contour_mask_file, "case": case_id, 
"observer": obs, } @@ -209,19 +258,26 @@ def __len__(self): def __getitem__(self, index): img = np.load(self.slices[index]["image"]) - mask = np.load(self.slices[index]["mask"]) + label = np.load(self.slices[index]["label"]) + contour_mask = np.load(self.slices[index]["contour_mask"]) if self.transforms: - segmap = SegmentationMapsOnImage(mask, shape=mask.shape) - img, mask = self.transforms(image=img, segmentation_maps=segmap) - mask = mask.get_arr() + seg_arr = np.concatenate( + (np.expand_dims(label, 2), np.expand_dims(contour_mask, 2)), 2 + ) + segmap = SegmentationMapsOnImage(seg_arr, shape=label.shape) + img, seg = self.transforms(image=img, segmentation_maps=segmap) + label = seg.get_arr()[:, :, 0].squeeze() + contour_mask = seg.get_arr()[:, :, 1].squeeze() img = torch.FloatTensor(img) - mask = torch.LongTensor(mask) + label = torch.LongTensor(label) + contour_mask = torch.LongTensor(contour_mask) return ( img.unsqueeze(0), - mask, + label, + contour_mask, { "case": self.slices[index]["case"], "observer": self.slices[index]["observer"], diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index f6c3293e..40fee81f 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -327,7 +327,7 @@ def reconstruction_loss( return ce_sum, ce_mean, mask - def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): + def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, mask=None): """ Calculate the evidence lower bound of the log-likelihood of P(Y|X) """ @@ -340,12 +340,20 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False): if "top_k_percentage" in self.loss_params: top_k_percentage = self.loss_params["top_k_percentage"] + loss_mask = None + if self.loss_params["contour_loss_lambda_threshold"]: + if ( + self._lambda <= self.loss_params["contour_loss_lambda_threshold"] + ): # pylint: disable=access-member-before-definition + loss_mask = mask + # Here we use 
the posterior sample sampled above reconstruction_loss, rec_loss_mean, mask = self.reconstruction_loss( segm, reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior, top_k_percentage=top_k_percentage, + mask=loss_mask, ) if self.loss_type == "elbo": diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index d3c29833..eb422fcc 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -99,16 +99,17 @@ def __init__( if self.hparams.loss_type == "elbo": loss_params = { "beta": self.hparams.beta, - "top_k_percentage": self.hparams.top_k_percentage, } if self.hparams.loss_type == "geco": loss_params = { "kappa": self.hparams.kappa, "clamp_rec": self.hparams.clamp_rec, - "top_k_percentage": self.hparams.top_k_percentage, } + loss_params["top_k_percentage"] = self.hparams.top_k_percentage + loss_params["contour_loss_lambda_threshold"] = self.hparams.contour_loss_lambda_threshold + self.prob_unet = ProbabilisticUnet( self.hparams.input_channels, self.hparams.num_classes, @@ -139,6 +140,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--kappa", type=float, default=0.02) parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) parser.add_argument("--top_k_percentage", type=float, default=None) + parser.add_argument("--contour_loss_lambda_threshold", type=float, default=None) return parent_parser @@ -157,12 +159,12 @@ def configure_optimizers(self): return [optimizer], [scheduler] - def training_step(self, batch, batch_idx): + def training_step(self, batch, _): - x, y, _ = batch + x, y, m, _ = batch self.prob_unet.forward(x, y, training=True) - loss = self.prob_unet.loss(y, analytic_kl=True) + loss = self.prob_unet.loss(y, analytic_kl=True, mask=m) reg_loss = ( l2_regularisation(self.prob_unet.posterior) + l2_regularisation(self.prob_unet.prior) @@ -198,7 +200,7 @@ def validation_step(self, batch, batch_idx): print(self.validation_directory) with 
torch.set_grad_enabled(False): - x, y, info = batch + x, y, _, info = batch for s in range(y.shape[0]): @@ -494,7 +496,7 @@ def setup(self, stage=None): for p in case_aug_dir.glob(self.augmented_case_glob.format(case=case)) if not p.name.startswith(".") ] - print(augmented_cases) + train_data += [ { "id": f"{case}_{augmented_case}", From aad9eb01de5bbe7aafb621b1d0ca997927a5363e Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 25 Jul 2021 12:17:59 +1000 Subject: [PATCH 089/264] send config to comet ml --- platipy/imaging/cnn/train.py | 23 +++++++++++++++++------ platipy/imaging/generation/augment.py | 26 ++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index d3c29833..81ed1fd3 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -127,6 +127,7 @@ def __init__( def add_model_specific_args(parent_parser): parser = parent_parser.add_argument_group("Probabilistic UNet") parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--lr_lambda", type=float, default=0.99) parser.add_argument("--input_channels", type=int, default=1) parser.add_argument("--num_classes", type=int, default=2) parser.add_argument( @@ -152,7 +153,7 @@ def configure_optimizers(self): ) scheduler = torch.optim.lr_scheduler.LambdaLR( - optimizer, lr_lambda=[lambda epoch: 0.99 ** (epoch)] + optimizer, lr_lambda=[lambda epoch: self.hparams.lr_lambda ** (epoch)] ) return [optimizer], [scheduler] @@ -538,6 +539,7 @@ def setup(self, stage=None): validation_data, self.working_dir.joinpath("validation"), augment_on_the_fly=False, + spacing=self.spacing, crop_to_mm=self.crop_to_mm, ) @@ -561,7 +563,7 @@ def val_dataloader(self): ) -def main(args): +def main(args, config_json_path=None): pl.seed_everything(args.seed, workers=True) @@ -595,6 +597,8 @@ def main(args): save_dir=args.working_dir, offline=args.offline, ) + if config_json_path: + 
comet_logger.experiment.log_code(config_json_path) dict_args = vars(args) @@ -602,7 +606,10 @@ def main(args): prob_unet = ProbUNet(**dict_args) - trainer = pl.Trainer.from_argparse_args(args) + if args.resume_from is not None: + trainer = pl.Trainer(resume_from_checkpoint=args.resume_from) + else: + trainer = pl.Trainer.from_argparse_args(args) if comet_api_key is not None: trainer.logger = comet_logger @@ -617,10 +624,12 @@ def main(args): if __name__ == "__main__": args = None + config_json_path = None if len(sys.argv) == 2: # Check if JSON file parsed, if so read arguments from there... if sys.argv[-1].endswith(".json"): - with open(sys.argv[-1], "r") as f: + config_json_path = sys.argv[-1] + with open(config_json_path, "r") as f: params = json.load(f) args = [] for k in params: @@ -636,14 +645,16 @@ def main(args): arg_parser = ProbUNet.add_model_specific_args(arg_parser) arg_parser = ProbUNetDataModule.add_model_specific_args(arg_parser) arg_parser = pl.Trainer.add_argparse_args(arg_parser) + arg_parser.add_argument("--config", type=str, default=None, help="JSON file with parameters to load") arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") arg_parser.add_argument("--working_dir", type=str, default="./working") arg_parser.add_argument("--num_observers", type=int, default=5) - arg_parser.add_argument("--spacing", type=list, default=[1, 1, 1]) + arg_parser.add_argument("--spacing", nargs="+", type=int, default=[1,1,1]) arg_parser.add_argument("--offline", type=bool, default=False) arg_parser.add_argument("--comet_api_key", type=str, default=None) arg_parser.add_argument("--comet_workspace", type=str, default=None) arg_parser.add_argument("--comet_project", type=str, default=None) + arg_parser.add_argument("--resume_from", type=str, default=None) - main(arg_parser.parse_args(args)) + main(arg_parser.parse_args(args), 
config_json_path=config_json_path) diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index 8ed82136..c3eb5cd7 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -21,6 +21,7 @@ from argparse import ArgumentParser import SimpleITK as sitk +import numpy as np from loguru import logger @@ -37,6 +38,7 @@ from platipy.imaging.registration.utils import apply_transform +from platipy.imaging.utils.lung import detect_holes def apply_augmentation(image, augmentation, masks=[]): @@ -275,6 +277,27 @@ def augment_data(args): ct_image = sitk.ReadImage(str(data[case]["image"])) + if args.enable_fill_holes: + + label_image, labels = detect_holes(ct_image) + + for label in labels[1:]: # Skip first hole since likely air around body + + + if random.random() > args.fill_probability: continue + + hole = label_image == label["label"] + hole_dilate = sitk.BinaryDilate(hole, (2,2,2), sitk.sitkBall) + contour_points = sitk.BinaryContour(hole_dilate) + fill_value = np.median(sitk.GetArrayFromImage(ct_image)[sitk.GetArrayFromImage(contour_points)==1]) + + ct_arr = sitk.GetArrayFromImage(ct_image) + ct_arr[sitk.GetArrayFromImage(hole_dilate)==1] = fill_value + ct_filled = sitk.GetImageFromArray(ct_arr) + ct_filled.CopyInformation(ct_image) + + ct_image = ct_filled + # Get list of structures to generate augmentations off all_masks = [] all_names = [] @@ -370,4 +393,7 @@ def augment_data(args): arg_parser.add_argument("--contract_bone_mask", type=bool, default=True) arg_parser.add_argument("--contract_probability", type=float, default=0.5) + arg_parser.add_argument("--enable_fill_holes", type=bool, default=True) + arg_parser.add_argument("--fill_probability", type=float, default=0.2) + augment_data(arg_parser.parse_args()) From 803ddecb2f52584c3a718a1604ce591a2e618b4f Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 27 Jul 2021 09:15:43 +1000 Subject: [PATCH 090/264] train prob unet rec initially --- 
platipy/imaging/cnn/dataset.py | 3 +-- platipy/imaging/cnn/prob_unet.py | 38 +++++++++++++++++++++++--------- platipy/imaging/cnn/train.py | 8 ++++++- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index b35071ba..e8b3e150 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -222,7 +222,6 @@ def __init__( observers.append(label) contour_mask = get_contour_mask(observers) - sitk.WriteImage(contour_mask, f"cm_{case['id']}.nii.gz") for z_slice in range(img.GetSize()[2]): @@ -272,7 +271,7 @@ def __getitem__(self, index): img = torch.FloatTensor(img) label = torch.LongTensor(label) - contour_mask = torch.LongTensor(contour_mask) + contour_mask = torch.FloatTensor(contour_mask) return ( img.unsqueeze(0), diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 40fee81f..078f00ce 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -158,6 +158,7 @@ def __init__( ): super(ProbabilisticUnet, self).__init__() + self.num_classes = num_classes self.no_convs_per_block = 3 self.no_convs_fcomb = no_convs_fcomb self.initializers = {"w": "he_normal", "b": "normal"} @@ -264,6 +265,8 @@ def reconstruction_loss( mask=None, top_k_percentage=None, deterministic=True, + weight_mask=None, + weight_mask_weighting=0.0 ): criterion = torch.nn.BCEWithLogitsLoss(reduction="none") @@ -284,13 +287,13 @@ def reconstruction_loss( y_flat = torch.transpose(reconstruction, 1, -1).reshape((-1, num_classes)) t_flat = torch.transpose(segm, 1, -1).reshape((-1, num_classes)) n_pixels_in_batch = y_flat.shape[0] - if mask is None: + if mask is None or mask.sum() == 0: mask = torch.ones(n_pixels_in_batch) else: - assert ( - mask.shape == segm.shape - ), f"The loss mask shape differs from the target shape: {mask.shape} vs. {segm.shape}." 
- mask = torch.reshape(segm, (-1,)) +# assert ( +# mask.shape == segm.shape +# ), f"The loss mask shape differs from the target shape: {mask.shape} vs. {segm.shape}." + mask = torch.reshape(mask, (-1,)) mask = mask.to(y_flat.device) xe = criterion(input=y_flat, target=t_flat) @@ -321,13 +324,21 @@ def reconstruction_loss( mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) ) + if weight_mask is not None: + weight_mask = torch.reshape(weight_mask, (-1,)) + weight_mask = weight_mask.unsqueeze(1).repeat((1, num_classes)) + weight_mask = ( + weight_mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + ) + mask = mask + (weight_mask * weight_mask_weighting) + ce_sum_per_instance = torch.sum(mask * xe, axis=1) ce_sum = torch.mean(ce_sum_per_instance, axis=0) ce_mean = torch.sum(mask * xe) / torch.sum(mask) return ce_sum, ce_mean, mask - def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, mask=None): + def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, mask=None, use_max_lambda=False): """ Calculate the evidence lower bound of the log-likelihood of P(Y|X) """ @@ -343,9 +354,10 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, mask=No loss_mask = None if self.loss_params["contour_loss_lambda_threshold"]: if ( - self._lambda <= self.loss_params["contour_loss_lambda_threshold"] + self._lambda <= self.loss_params["contour_loss_lambda_threshold"] and not use_max_lambda ): # pylint: disable=access-member-before-definition loss_mask = mask +# loss_mask = loss_mask.unsqueeze(1).repeat((1, self.num_classes, 1, 1)) # Here we use the posterior sample sampled above reconstruction_loss, rec_loss_mean, mask = self.reconstruction_loss( @@ -354,6 +366,8 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, mask=No z_posterior=z_posterior, top_k_percentage=top_k_percentage, mask=loss_mask, + weight_mask=mask, + 
weight_mask_weighting=self.loss_params["contour_loss_weight"] ) if self.loss_type == "elbo": @@ -384,15 +398,17 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, mask=No rc = self._moving_avg - reconstruction_threshold lambda_lower = self.loss_params["clamp_rec"][0] lambda_upper = self.loss_params["clamp_rec"][1] - self._lambda = (torch.exp(rc) * self._lambda).clamp(lambda_lower, lambda_upper) + if use_max_lambda: + self._lambda = torch.Tensor([lambda_upper]).to(rc.device) + else: + self._lambda = (torch.exp(rc) * self._lambda).clamp(lambda_lower, lambda_upper) + loss = ( self._lambda * reconstruction_loss # pylint: disable=access-member-before-definition + kl_div ) - # self._lambda = (speed * self._moving_avg * self._lambda).clamp( - # self.loss_params["clamp_rec"][0], self.loss_params["clamp_rec"][1] - # ) + return { "loss": loss, "rec_loss": reconstruction_loss, diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 60c69891..f3acdaba 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -109,6 +109,7 @@ def __init__( loss_params["top_k_percentage"] = self.hparams.top_k_percentage loss_params["contour_loss_lambda_threshold"] = self.hparams.contour_loss_lambda_threshold + loss_params["contour_loss_weight"] = self.hparams.contour_loss_weight self.prob_unet = ProbabilisticUnet( self.hparams.input_channels, @@ -142,6 +143,8 @@ def add_model_specific_args(parent_parser): parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) parser.add_argument("--top_k_percentage", type=float, default=None) parser.add_argument("--contour_loss_lambda_threshold", type=float, default=None) + parser.add_argument("--contour_loss_weight", type=float, default=0.0) + parser.add_argument("--epochs_all_rec", type=int, default=0) return parent_parser @@ -165,7 +168,10 @@ def training_step(self, batch, _): x, y, m, _ = batch self.prob_unet.forward(x, y, training=True) - loss = 
self.prob_unet.loss(y, analytic_kl=True, mask=m) + + use_max_lambda = self.current_epoch < self.hparams.epochs_all_rec + + loss = self.prob_unet.loss(y, analytic_kl=True, mask=m, use_max_lambda=use_max_lambda) reg_loss = ( l2_regularisation(self.prob_unet.posterior) + l2_regularisation(self.prob_unet.prior) From a490e97cb3e21c79e52c22a372e30476bd5382c9 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 31 Jul 2021 01:04:41 +0000 Subject: [PATCH 091/264] Don't separate out directories for working files --- .pylintrc | 2 +- platipy/imaging/cnn/dataset.py | 20 +---------- platipy/imaging/cnn/train.py | 20 +++++------ platipy/imaging/label/utils.py | 63 ++++++++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 30 deletions(-) diff --git a/.pylintrc b/.pylintrc index f9ecee68..821348c0 100644 --- a/.pylintrc +++ b/.pylintrc @@ -443,7 +443,7 @@ contextmanager-decorators=contextlib.contextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. -generated-members=torch.* +generated-members=torch.*,pytorch_lightning.* # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). 
diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index e8b3e150..e525d36d 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -1,5 +1,4 @@ import re -import copy from pathlib import Path import numpy as np @@ -13,24 +12,7 @@ from loguru import logger - -def get_union_mask(masks): - - union_mask = copy.copy(masks[0]) - for mask in masks[1:]: - union_mask += mask - - return sitk.Cast(union_mask > 0, sitk.sitkUInt8) - - -def get_intersection_mask(masks): - - intersection_mask = copy.copy(masks[0]) - for mask in masks[1:]: - intersection_mask += mask - - return sitk.Cast(intersection_mask == len(masks), sitk.sitkUInt8) - +from platipy.imaging.label.utils import get_union_mask, get_intersection_mask def get_contour_mask(masks, kernel=5): diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index f3acdaba..b258f0f1 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -24,13 +24,12 @@ import numpy as np from scipy.optimize import linear_sum_assignment -import comet_ml +import comet_ml # pylint: disable=unused-import from pytorch_lightning.loggers import CometLogger import torch import pytorch_lightning as pl from pytorch_lightning.callbacks import LearningRateMonitor -from pytorch_lightning.callbacks.progress import ProgressBar from argparse import ArgumentParser @@ -539,13 +538,13 @@ def setup(self, stage=None): self.training_set = NiftiDataset( train_data, - self.working_dir.joinpath("train"), + self.working_dir, spacing=self.spacing, crop_to_mm=self.crop_to_mm, ) self.validation_set = NiftiDataset( validation_data, - self.working_dir.joinpath("validation"), + self.working_dir, augment_on_the_fly=False, spacing=self.spacing, crop_to_mm=self.crop_to_mm, @@ -623,10 +622,9 @@ def main(args, config_json_path=None): trainer.logger = comet_logger lr_monitor = LearningRateMonitor(logging_interval="step") - # bar = ProgressBar() trainer.callbacks.append(lr_monitor) - 
trainer.fit(prob_unet, data_module) # pylint: disable=no-member + trainer.fit(prob_unet, data_module) if __name__ == "__main__": @@ -644,8 +642,8 @@ def main(args, config_json_path=None): args.append(f"--{k}") if isinstance(params[k], list): - for x in params[k]: - args.append(str(x)) + for s in params[k]: + args.append(str(s)) else: args.append(str(params[k])) @@ -653,12 +651,14 @@ def main(args, config_json_path=None): arg_parser = ProbUNet.add_model_specific_args(arg_parser) arg_parser = ProbUNetDataModule.add_model_specific_args(arg_parser) arg_parser = pl.Trainer.add_argparse_args(arg_parser) - arg_parser.add_argument("--config", type=str, default=None, help="JSON file with parameters to load") + arg_parser.add_argument( + "--config", type=str, default=None, help="JSON file with parameters to load" + ) arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") arg_parser.add_argument("--working_dir", type=str, default="./working") arg_parser.add_argument("--num_observers", type=int, default=5) - arg_parser.add_argument("--spacing", nargs="+", type=int, default=[1,1,1]) + arg_parser.add_argument("--spacing", nargs="+", type=int, default=[1, 1, 1]) arg_parser.add_argument("--offline", type=bool, default=False) arg_parser.add_argument("--comet_api_key", type=str, default=None) arg_parser.add_argument("--comet_workspace", type=str, default=None) diff --git a/platipy/imaging/label/utils.py b/platipy/imaging/label/utils.py index 8648c109..16ae4cd0 100644 --- a/platipy/imaging/label/utils.py +++ b/platipy/imaging/label/utils.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy +from pathlib import Path + import SimpleITK as sitk import numpy as np @@ -218,3 +221,63 @@ def binary_decode_image(binary_encoded_img): continue return structure_list + + +def get_union_mask(masks): + """Get the union mask + + Args: + masks (list|dict): A list or dictionary of masks given as SimpleITK.Image or path to mask + file. + + Raises: + ValueError: Raised if masks provided is empty. + + Returns: + SimpleITK.Image: The union mask + """ + + if isinstance(masks, dict): + masks = [masks[k] for k in masks] + + if len(masks) > 0: + raise ValueError("Masks must not be empty") + + if isinstance(masks[0], str, Path): + masks = [sitk.ReadImage(str(m)) for m in masks] + + union_mask = copy.copy(masks[0]) + for mask in masks[1:]: + union_mask += mask + + return sitk.Cast(union_mask > 0, sitk.sitkUInt8) + + +def get_intersection_mask(masks): + """Get the intersection mask + + Args: + masks (list|dict): A list or dictionary of masks given as SimpleITK.Image or path to mask + file. + + Raises: + ValueError: Raised if masks provided is empty. 
+ + Returns: + SimpleITK.Image: The intersection mask + """ + + if isinstance(masks, dict): + masks = [masks[k] for k in masks] + + if len(masks) > 0: + raise ValueError("Masks must not be empty") + + if isinstance(masks[0], str, Path): + masks = [sitk.ReadImage(str(m)) for m in masks] + + intersection_mask = copy.copy(masks[0]) + for mask in masks[1:]: + intersection_mask += mask + + return sitk.Cast(intersection_mask == len(masks), sitk.sitkUInt8) From 2f1ed501f9082c792f8241faa57164be80d0d371 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 31 Jul 2021 01:16:48 +0000 Subject: [PATCH 092/264] Correct check in union intersection masks --- platipy/imaging/label/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/label/utils.py b/platipy/imaging/label/utils.py index 16ae4cd0..b54a99c5 100644 --- a/platipy/imaging/label/utils.py +++ b/platipy/imaging/label/utils.py @@ -240,7 +240,7 @@ def get_union_mask(masks): if isinstance(masks, dict): masks = [masks[k] for k in masks] - if len(masks) > 0: + if len(masks) == 0: raise ValueError("Masks must not be empty") if isinstance(masks[0], str, Path): @@ -270,7 +270,7 @@ def get_intersection_mask(masks): if isinstance(masks, dict): masks = [masks[k] for k in masks] - if len(masks) > 0: + if len(masks) == 0: raise ValueError("Masks must not be empty") if isinstance(masks[0], str, Path): From 0517f983a0f6734010e81354aca54edaf88a2754 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 31 Jul 2021 01:24:15 +0000 Subject: [PATCH 093/264] make contour mask kernel size configurable --- platipy/imaging/cnn/dataset.py | 11 +++++++++-- platipy/imaging/cnn/train.py | 3 +++ platipy/imaging/label/utils.py | 4 ++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index e525d36d..a50fa867 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -14,6 +14,7 @@ from 
platipy.imaging.label.utils import get_union_mask, get_intersection_mask + def get_contour_mask(masks, kernel=5): if not hasattr(kernel, "__iter__"): @@ -125,7 +126,13 @@ class NiftiDataset(torch.utils.data.Dataset): """PyTorch Dataset for processing Nifti data""" def __init__( - self, data, working_dir, augment_on_the_fly=True, spacing=[1, 1, 1], crop_to_mm=128 + self, + data, + working_dir, + augment_on_the_fly=True, + spacing=[1, 1, 1], + crop_to_mm=128, + contour_mask_kernel=5, ): """Prepare a dataset from Nifti images/labels @@ -203,7 +210,7 @@ def __init__( label = resample_mask_to_image(img, label) observers.append(label) - contour_mask = get_contour_mask(observers) + contour_mask = get_contour_mask(observers, kernel=contour_mask_kernel) for z_slice in range(img.GetSize()[2]): diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index b258f0f1..7c635e3c 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -459,6 +459,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--augmented_image_glob", type=str, default=None) parser.add_argument("--augmented_label_glob", type=str, default=None) parser.add_argument("--crop_to_mm", type=int, default=128) + parser.add_argument("--contour_mask_kernel", type=int, default=5) return parent_parser @@ -541,6 +542,7 @@ def setup(self, stage=None): self.working_dir, spacing=self.spacing, crop_to_mm=self.crop_to_mm, + contour_mask_kernel=self.contour_mask_kernel ) self.validation_set = NiftiDataset( validation_data, @@ -548,6 +550,7 @@ def setup(self, stage=None): augment_on_the_fly=False, spacing=self.spacing, crop_to_mm=self.crop_to_mm, + contour_mask_kernel=self.contour_mask_kernel ) def train_dataloader(self): diff --git a/platipy/imaging/label/utils.py b/platipy/imaging/label/utils.py index b54a99c5..3de03085 100644 --- a/platipy/imaging/label/utils.py +++ b/platipy/imaging/label/utils.py @@ -243,7 +243,7 @@ def get_union_mask(masks): if len(masks) == 0: raise 
ValueError("Masks must not be empty") - if isinstance(masks[0], str, Path): + if isinstance(masks[0], (str, Path)): masks = [sitk.ReadImage(str(m)) for m in masks] union_mask = copy.copy(masks[0]) @@ -273,7 +273,7 @@ def get_intersection_mask(masks): if len(masks) == 0: raise ValueError("Masks must not be empty") - if isinstance(masks[0], str, Path): + if isinstance(masks[0], (str, Path)): masks = [sitk.ReadImage(str(m)) for m in masks] intersection_mask = copy.copy(masks[0]) From 086210818312db11262da66606bd286b2803258b Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 31 Jul 2021 11:29:23 +1000 Subject: [PATCH 094/264] Add contour_mask_kernel to class attributes --- platipy/imaging/cnn/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 7c635e3c..a7a3a90a 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -412,6 +412,7 @@ def __init__( crop_to_mm=128, num_observers=5, spacing=[1, 1, 1], + contour_mask_kernel=3, **kwargs, ): super().__init__() @@ -437,6 +438,7 @@ def __init__( self.crop_to_mm = crop_to_mm self.num_observers = num_observers self.spacing = spacing + self.contour_mask_kernel = contour_mask_kernel self.training_set = None self.validation_set = None From b36593322da0db3a20ba839be12b41c31a530bac Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 13 Aug 2021 01:40:37 +0000 Subject: [PATCH 095/264] Extend probnet to 3D --- .pylintrc | 9 +- platipy/imaging/cnn/prob_unet.py | 124 ++++++++++++++++++++++------ platipy/imaging/cnn/train.py | 22 +++-- platipy/imaging/cnn/unet.py | 136 +++++++++++++++++++++++++++---- requirements-dl.txt | 4 + 5 files changed, 240 insertions(+), 55 deletions(-) create mode 100644 requirements-dl.txt diff --git a/.pylintrc b/.pylintrc index 821348c0..68118e59 100644 --- a/.pylintrc +++ b/.pylintrc @@ -139,10 +139,11 @@ disable=print-statement, deprecated-sys-function, exception-escape, comprehension-escape, - C0330, - 
C0114, - W0102, - W0105 + bad-continuation, + missing-module-docstring, + # pointless-string-statement, + dangerous-default-value, + arguments-differ # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 078f00ce..e9c24176 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -19,11 +19,16 @@ import torch from torch.distributions import Normal, Independent, kl -from platipy.imaging.cnn.unet import UNet, Conv, init_weights +from platipy.imaging.cnn.unet import UNet, Conv, init_weights, conv_nd class Encoder(torch.nn.Module): - def __init__(self, input_channels, filters_per_layer=[64 * (2 ** x) for x in range(5)]): + """Encoder part of the probabilistic UNet + """ + + def __init__( + self, input_channels, filters_per_layer=[64 * (2 ** x) for x in range(5)], ndims=2 + ): super(Encoder, self).__init__() layers = [] @@ -34,7 +39,9 @@ def __init__(self, input_channels, filters_per_layer=[64 * (2 ** x) for x in ran down_sample = 0 if idx == 0 else -2 - layers.append(Conv(input_filters, output_filters, up_down_sample=down_sample)) + layers.append( + Conv(input_filters, output_filters, up_down_sample=down_sample, ndims=ndims) + ) self.layers = torch.nn.Sequential(*layers) @@ -47,19 +54,42 @@ def forward(self, x): class AxisAlignedConvGaussian(torch.nn.Module): def __init__( - self, input_channels, filters_per_layer=[64 * (2 ** x) for x in range(5)], latent_dim=2 + self, + input_channels, + filters_per_layer=[64 * (2 ** x) for x in range(5)], + latent_dim=2, + ndims=2, ): super(AxisAlignedConvGaussian, self).__init__() self.latent_dim = latent_dim - self.encoder = Encoder(input_channels, filters_per_layer) - self.final = torch.nn.Conv2d(filters_per_layer[-1], 2 * self.latent_dim, (1, 1), stride=1) + self.encoder = Encoder(input_channels, filters_per_layer, 
ndims=ndims) + + self.final = conv_nd( + in_channels=filters_per_layer[-1], + out_channels=2 * self.latent_dim, + kernel_size=1, + stride=1, + ndims=ndims + ) + + self.ndims = ndims self.final.apply(init_weights) def forward(self, img, seg=None): + """Forward pass through the network + + Args: + img (torch.Tensor): The image to be passed through. + seg (torch.Tensor, optional): The segmentation mask to use in the case of the prior + network. Defaults to None. + + Returns: + torch.distributions.distribution.Distribution: The distribution output + """ x = img if seg is not None: @@ -71,6 +101,8 @@ def forward(self, img, seg=None): # We only want the mean of the resulting hxw image encoding = torch.mean(encoding, dim=2, keepdim=True) encoding = torch.mean(encoding, dim=3, keepdim=True) + if self.ndims == 3: + encoding = torch.mean(encoding, dim=4, keepdim=True) # Convert encoding to 2 x latent dim and split up for mu and log_sigma mu_log_sigma = self.final(encoding) @@ -79,6 +111,8 @@ def forward(self, img, seg=None): # equal to 1 mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) + if self.ndims == 3: + mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) mu = mu_log_sigma[:, : self.latent_dim] log_sigma = mu_log_sigma[:, self.latent_dim :] @@ -97,7 +131,7 @@ class Fcomb(torch.nn.Module): their channel axis. 
""" - def __init__(self, filters_per_layer, latent_dim, num_classes, no_convs_fcomb): + def __init__(self, filters_per_layer, latent_dim, num_classes, no_convs_fcomb, ndims=2): super(Fcomb, self).__init__() layers = [] @@ -105,27 +139,43 @@ def __init__(self, filters_per_layer, latent_dim, num_classes, no_convs_fcomb): # Decoder of N x a 1x1 convolution followed by a ReLU activation function except for the # last layer layers.append( - torch.nn.Conv2d(filters_per_layer[0] + latent_dim, filters_per_layer[0], kernel_size=1) + conv_nd( + in_channels=filters_per_layer[0] + latent_dim, + out_channels=filters_per_layer[0], + kernel_size=1, + ndims=ndims, + ) ) layers.append(torch.nn.ReLU(inplace=True)) for _ in range(no_convs_fcomb - 2): layers.append( - torch.nn.Conv2d(filters_per_layer[0], filters_per_layer[0], kernel_size=1) + conv_nd( + in_channels=filters_per_layer[0], + out_channels=filters_per_layer[0], + kernel_size=1, + ndims=ndims, + ) ) layers.append(torch.nn.ReLU(inplace=True)) self.layers = torch.nn.Sequential(*layers) - self.last_layer = torch.nn.Conv2d(filters_per_layer[0], num_classes, kernel_size=1) + self.last_layer = conv_nd( + in_channels=filters_per_layer[0], out_channels=num_classes, kernel_size=1, ndims=ndims + ) self.layers.apply(init_weights) self.last_layer.apply(init_weights) + self.ndims = ndims + def forward(self, feature_map, z): z = torch.unsqueeze(z, 2).expand(-1, -1, feature_map.shape[2]) z = torch.unsqueeze(z, 3).expand(-1, -1, -1, feature_map.shape[3]) + if self.ndims == 3: + z = torch.unsqueeze(z, 4).expand(-1, -1, -1, -1, feature_map.shape[4]) # Concatenate the feature map (output of the UNet) and the sample taken from the latent # space @@ -155,6 +205,7 @@ def __init__( no_convs_fcomb=4, loss_type="elbo", loss_params={"beta": 1}, + ndims=2, ): super(ProbabilisticUnet, self).__init__() @@ -164,10 +215,16 @@ def __init__( self.initializers = {"w": "he_normal", "b": "normal"} self.z_prior_sample = 0 - self.unet = UNet(input_channels, 
num_classes, filters_per_layer, final_layer=False) - self.prior = AxisAlignedConvGaussian(input_channels, filters_per_layer, latent_dim) - self.posterior = AxisAlignedConvGaussian(input_channels + 1, filters_per_layer, latent_dim) - self.fcomb = Fcomb(filters_per_layer, latent_dim, num_classes, no_convs_fcomb) + self.unet = UNet( + input_channels, num_classes, filters_per_layer, final_layer=False, ndims=ndims + ) + self.prior = AxisAlignedConvGaussian( + input_channels, filters_per_layer, latent_dim, ndims=ndims + ) + self.posterior = AxisAlignedConvGaussian( + input_channels + 1, filters_per_layer, latent_dim, ndims=ndims + ) + self.fcomb = Fcomb(filters_per_layer, latent_dim, num_classes, no_convs_fcomb, ndims=ndims) self.loss_type = loss_type self.loss_params = loss_params @@ -266,7 +323,7 @@ def reconstruction_loss( top_k_percentage=None, deterministic=True, weight_mask=None, - weight_mask_weighting=0.0 + weight_mask_weighting=0.0, ): criterion = torch.nn.BCEWithLogitsLoss(reduction="none") @@ -290,9 +347,9 @@ def reconstruction_loss( if mask is None or mask.sum() == 0: mask = torch.ones(n_pixels_in_batch) else: -# assert ( -# mask.shape == segm.shape -# ), f"The loss mask shape differs from the target shape: {mask.shape} vs. {segm.shape}." + # assert ( + # mask.shape == segm.shape + # ), f"The loss mask shape differs from the target shape: {mask.shape} vs. {segm.shape}." 
mask = torch.reshape(mask, (-1,)) mask = mask.to(y_flat.device) @@ -328,7 +385,9 @@ def reconstruction_loss( weight_mask = torch.reshape(weight_mask, (-1,)) weight_mask = weight_mask.unsqueeze(1).repeat((1, num_classes)) weight_mask = ( - weight_mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + weight_mask.reshape((batch_size, -1, num_classes)) + .transpose(-1, 1) + .reshape((batch_size, -1)) ) mask = mask + (weight_mask * weight_mask_weighting) @@ -338,7 +397,14 @@ def reconstruction_loss( return ce_sum, ce_mean, mask - def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, mask=None, use_max_lambda=False): + def loss( + self, + segm, + analytic_kl=True, + reconstruct_posterior_mean=False, + mask=None, + use_max_lambda=False, + ): """ Calculate the evidence lower bound of the log-likelihood of P(Y|X) """ @@ -354,10 +420,12 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, mask=No loss_mask = None if self.loss_params["contour_loss_lambda_threshold"]: if ( - self._lambda <= self.loss_params["contour_loss_lambda_threshold"] and not use_max_lambda - ): # pylint: disable=access-member-before-definition + self._lambda # pylint: disable=access-member-before-definition + <= self.loss_params["contour_loss_lambda_threshold"] + and not use_max_lambda + ): loss_mask = mask -# loss_mask = loss_mask.unsqueeze(1).repeat((1, self.num_classes, 1, 1)) + # loss_mask = loss_mask.unsqueeze(1).repeat((1, self.num_classes, 1, 1)) # Here we use the posterior sample sampled above reconstruction_loss, rec_loss_mean, mask = self.reconstruction_loss( @@ -367,7 +435,7 @@ def loss(self, segm, analytic_kl=True, reconstruct_posterior_mean=False, mask=No top_k_percentage=top_k_percentage, mask=loss_mask, weight_mask=mask, - weight_mask_weighting=self.loss_params["contour_loss_weight"] + weight_mask_weighting=self.loss_params["contour_loss_weight"], ) if self.loss_type == "elbo": @@ -399,9 +467,13 @@ def loss(self, 
segm, analytic_kl=True, reconstruct_posterior_mean=False, mask=No lambda_lower = self.loss_params["clamp_rec"][0] lambda_upper = self.loss_params["clamp_rec"][1] if use_max_lambda: - self._lambda = torch.Tensor([lambda_upper]).to(rc.device) + self._lambda = torch.Tensor( # pylint: disable=attribute-defined-outside-init + [lambda_upper] + ).to(rc.device) else: - self._lambda = (torch.exp(rc) * self._lambda).clamp(lambda_lower, lambda_upper) + self._lambda = ( # pylint: disable=attribute-defined-outside-init + torch.exp(rc) * self._lambda + ).clamp(lambda_lower, lambda_upper) loss = ( self._lambda diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index a7a3a90a..f245f325 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -118,6 +118,7 @@ def __init__( self.hparams.no_convs_fcomb, self.hparams.loss_type, loss_params, + self.hparams.ndims, ) self.validation_directory = None @@ -413,6 +414,7 @@ def __init__( num_observers=5, spacing=[1, 1, 1], contour_mask_kernel=3, + ndims=2, **kwargs, ): super().__init__() @@ -443,6 +445,8 @@ def __init__( self.training_set = None self.validation_set = None + self.ndims = ndims + print(f"Training fold {self.fold}") @staticmethod @@ -462,6 +466,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--augmented_label_glob", type=str, default=None) parser.add_argument("--crop_to_mm", type=int, default=128) parser.add_argument("--contour_mask_kernel", type=int, default=5) + parser.add_argument("--ndims", type=int, default=2) return parent_parser @@ -542,9 +547,11 @@ def setup(self, stage=None): self.training_set = NiftiDataset( train_data, self.working_dir, + augment_on_the_fly=False, spacing=self.spacing, crop_to_mm=self.crop_to_mm, - contour_mask_kernel=self.contour_mask_kernel + contour_mask_kernel=self.contour_mask_kernel, + ndims=self.ndims, ) self.validation_set = NiftiDataset( validation_data, @@ -552,7 +559,8 @@ def setup(self, stage=None): 
augment_on_the_fly=False, spacing=self.spacing, crop_to_mm=self.crop_to_mm, - contour_mask_kernel=self.contour_mask_kernel + contour_mask_kernel=self.contour_mask_kernel, + ndims=self.ndims, ) def train_dataloader(self): @@ -643,14 +651,14 @@ def main(args, config_json_path=None): with open(config_json_path, "r") as f: params = json.load(f) args = [] - for k in params: - args.append(f"--{k}") + for key in params: + args.append(f"--{key}") - if isinstance(params[k], list): - for s in params[k]: + if isinstance(params[key], list): + for s in params[key]: args.append(str(s)) else: - args.append(str(params[k])) + args.append(str(params[key])) arg_parser = ArgumentParser() arg_parser = ProbUNet.add_model_specific_args(arg_parser) diff --git a/platipy/imaging/cnn/unet.py b/platipy/imaging/cnn/unet.py index 1dd70bbb..f41adf93 100644 --- a/platipy/imaging/cnn/unet.py +++ b/platipy/imaging/cnn/unet.py @@ -22,6 +22,38 @@ from torch import nn +def conv_nd(ndims=2, **kwargs): + """Generate a 2D or 3D convolution + + Args: + ndims (int, optional): 2 or 3 dimensions. Defaults to 2. + + Raises: + NotImplementedError: Raised if ndims is not in 2 or 3 dimensions. + + Returns: + torch.nn.Conv: The convolution. 
+ """ + + if ndims == 2: + return torch.nn.Conv2d(**kwargs) + elif ndims == 3: + return torch.nn.Conv3d(**kwargs) + + raise NotImplementedError("Only 2 or 3 dimensions are supported") + + +def init_weights(m): + if ( + isinstance(m, torch.nn.Conv2d) + or isinstance(m, torch.nn.ConvTranspose2d) + or isinstance(m, torch.nn.Conv3d) + or isinstance(m, torch.nn.ConvTranspose3d) + ): + torch.nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu") + truncated_normal_(m.bias, mean=0, std=0.001) + + def l2_regularisation(m): l2_reg = None @@ -33,6 +65,17 @@ def l2_regularisation(m): return l2_reg +def init_zeros(m): + if ( + isinstance(m, torch.nn.Conv2d) + or isinstance(m, torch.nn.ConvTranspose2d) + or isinstance(m, torch.nn.Conv3d) + or isinstance(m, torch.nn.ConvTranspose3d) + ): + torch.nn.init.zeros_(m.weight) + truncated_normal_(m.bias, mean=0, std=0.1) + + def truncated_normal_(tensor, mean=0, std=1): size = tensor.shape tmp = tensor.new_empty(size + (4,)).normal_() @@ -42,39 +85,88 @@ def truncated_normal_(tensor, mean=0, std=1): tensor.data.mul_(std).add_(mean) -def init_weights(m): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): - nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu") - truncated_normal_(m.bias, mean=0, std=0.001) +def resize_down_func(scale=2, ndims=2): + """Returns function to resize the input to downsample + + Args: + scale (int, optional): The scale used to downsize. Defaults to 2. + ndims (int, optional): Number of dimensions (2 or 3). Defaults to 2. 
+ + Returns: + function: The downsize function + """ + if ndims == 3: + return torch.nn.MaxPool3d(kernel_size=scale, stride=scale, padding=0) + elif ndims == 2: + return torch.nn.MaxPool2d(kernel_size=scale, stride=scale, padding=0) + + raise NotImplementedError() + + +def resize_up_func(in_channels, out_channels, scale=2, ndims=2): + """Return function to resize the input to upsample + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + scale (int, optional): The scale used to upsize. Defaults to 2. + ndims (int, optional): Number of dimensions (2 or 3). Defaults to 2. + + Raises: + NotImplementedError: Only supports 2d or 3d + + Returns: + function: The upsize function + """ + if ndims == 3: + return torch.nn.ConvTranspose3d( + in_channels, + out_channels, + kernel_size=scale, + stride=scale, + ) + elif ndims == 2: + return torch.nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=scale, + stride=scale, + ) + raise NotImplementedError() class Conv(torch.nn.Module): - def __init__(self, input_channels, output_channels, up_down_sample=0): + def __init__(self, input_channels, output_channels, up_down_sample=0, ndims=2): super(Conv, self).__init__() self.pre_op = None size_and_stride = abs(up_down_sample) if up_down_sample < 0: - self.pre_op = nn.MaxPool2d(kernel_size=size_and_stride, stride=size_and_stride) + self.pre_op = resize_down_func(size_and_stride, ndims=ndims) elif up_down_sample > 0: - self.pre_op = nn.ConvTranspose2d( - input_channels, - output_channels, - kernel_size=size_and_stride, - stride=size_and_stride, + self.pre_op = resize_up_func( + input_channels, output_channels, size_and_stride, ndims=ndims ) layers = [] layers.append( - nn.Conv2d( - in_channels=input_channels, out_channels=output_channels, kernel_size=3, padding=1 + conv_nd( + ndims=ndims, + in_channels=input_channels, + out_channels=output_channels, + kernel_size=3, + padding=1, ) ) layers.append(nn.ReLU(inplace=True)) 
layers.append( - nn.Conv2d( - in_channels=output_channels, out_channels=output_channels, kernel_size=3, padding=1 + conv_nd( + ndims=ndims, + in_channels=output_channels, + out_channels=output_channels, + kernel_size=3, + padding=1, ) ) layers.append(nn.ReLU(inplace=True)) @@ -100,6 +192,7 @@ def __init__( output_classes=2, filters_per_layer=[64 * (2 ** x) for x in range(5)], final_layer=True, + ndims=2, ): super(UNet, self).__init__() @@ -110,7 +203,9 @@ def __init__( output_filters = layer_filters down_sample = 0 if idx == 0 else -2 - self.encoder.append(Conv(input_filters, output_filters, up_down_sample=down_sample)) + self.encoder.append( + Conv(input_filters, output_filters, up_down_sample=down_sample, ndims=ndims) + ) reversed_filters = list(reversed(filters_per_layer)) self.decoder = nn.ModuleList() @@ -122,11 +217,16 @@ def __init__( input_filters = layer_filters output_filters = reversed_filters[idx + 1] - self.decoder.append(Conv(input_filters, output_filters, up_down_sample=2)) + self.decoder.append(Conv(input_filters, output_filters, up_down_sample=2, ndims=ndims)) self.final = None if final_layer: - self.final = nn.Conv2d(filters_per_layer[0], output_classes, kernel_size=1) + self.final = conv_nd( + ndims=ndims, + in_channels=filters_per_layer[0], + out_channels=output_classes, + kernel_size=1, + ) def forward(self, x): diff --git a/requirements-dl.txt b/requirements-dl.txt new file mode 100644 index 00000000..3e711455 --- /dev/null +++ b/requirements-dl.txt @@ -0,0 +1,4 @@ +pytorch-lightning >= 1.3.7 +torch >= 1.9.0 +comet-ml >= 3.12.2 +imgaug >= 0.4.0 \ No newline at end of file From 0db479b6483680293acae1e25e6ea72ac484eec0 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 13 Aug 2021 11:46:37 +1000 Subject: [PATCH 096/264] Prepare nnUnet service example --- examples/experimental/nnunet_service.ipynb | 88 +++++++++++----------- 1 file changed, 46 insertions(+), 42 deletions(-) diff --git a/examples/experimental/nnunet_service.ipynb 
b/examples/experimental/nnunet_service.ipynb index 1f8bbc71..b7116437 100644 --- a/examples/experimental/nnunet_service.ipynb +++ b/examples/experimental/nnunet_service.ipynb @@ -26,19 +26,24 @@ "sys.path.append(\"../../..\")\n", "\n", "import os\n", + "from pathlib import Path\n", + "\n", + "import SimpleITK as sitk \n", + "import time\n", "\n", "from platipy.backend.client import PlatiPyClient\n", "from platipy.imaging.tests.data import get_lung_nifti\n", + "from platipy.imaging import ImageVisualiser\n", + "from platipy.imaging.label.utils import get_com\n", "\n", "from loguru import logger\n", "\n", "host = \"127.0.0.1\" # Set the host name or IP of the server running the service here\n", - "host = \"10.55.72.183\"\n", "port = 8001 # Set the port the service was configured to run on here\n", "\n", "api_key = 'XXX' # Put API key here\n", - "api_key = \"fc1858e6-4432-47a4-b3b6-6df0ff652c38\"\n", - "algorithm_name = \"nnUNet Segmentation\" # The name of the algorithm, in this case it should be left as is\n", + "\n", + "algorithm_name = \"nnUNet Service\" # The name of the algorithm, in this case it should be left as is\n", "\n", "log_level = \"INFO\" # Choose an appropriate level of logging output: \"DEBUG\" or \"INFO\"\n", "\n", @@ -121,9 +126,9 @@ "metadata": {}, "outputs": [], "source": [ - "pat_id = list(images.keys())[0]\n", - "ct_file = os.path.join(images[pat_id], \"CT.nii.gz\")\n", - "data_object = client.add_data_object(dataset, file_path=ct_file)" + "images = [i for i in lung_data.glob(\"*/IMAGES/*.nii.gz\")]\n", + "ct_image = str(images[0])\n", + "data_object = client.add_data_object(dataset, file_path=ct_image)" ] }, { @@ -159,38 +164,11 @@ "metadata": {}, "outputs": [], "source": [ - "images" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "atlas_cases = list(images.keys())[1:]\n", - "atlas_path = os.path.dirname(images[atlas_cases[0]])\n", - "\n", "settings = 
client.get_default_settings()\n", - "\n", - "# Atlas settings\n", - "settings[\"atlasSettings\"][\"atlasPath\"] = atlas_path\n", - "settings[\"atlasSettings\"][\"atlasStructures\"] = [\"Heart\",\"Lung_L\",\"Lung_R\"]\n", - "settings[\"atlasSettings\"][\"atlasIdList\"] = atlas_cases\n", - "settings[\"atlasSettings\"][\"atlasImageFormat\"] = '{0}/CT.nii.gz'\n", - "settings[\"atlasSettings\"][\"atlasLabelFormat\"] = '{0}/Struct_{1}.nii.gz' \n", - "\n", - "# Run the DIR a bit more than default\n", - "settings['deformableSettings']['iterationStaging'] = [75,50,50]\n", - "\n", - "# Run the IAR using the heart\n", - "settings[\"IARSettings\"][\"referenceStructure\"] = 'Lung_L' \n", - "\n", - "# Set the threshold\n", - "settings['labelFusionSettings'][\"optimalThreshold\"] = {\"Heart\":0.5, \"Lung_L\": 0.5, \"Lung_R\": 0.5}\n", - "\n", - "# No vessels\n", - "settings['vesselSpliningSettings']['vesselNameList'] = []" + "settings['task'] = \"Task200_ClinicalHeart\"\n", + "settings['config'] = \"3d_lowres\"\n", + "settings['trainer'] = \"nnUNetTrainerHeart\"\n", + "settings['clean_sup_slices'] = True" ] }, { @@ -210,8 +188,13 @@ }, "outputs": [], "source": [ + "start = time.time()\n", + "\n", "for status in client.run_algorithm(dataset, config=settings):\n", - " print('.', end='')" + " print('.', end='')\n", + "\n", + "end = time.time()\n", + "print(f\"Took {end - start:.1f} seconds\")" ] }, { @@ -229,10 +212,30 @@ "metadata": {}, "outputs": [], "source": [ - "output_directory = os.path.join(\".\", \"results\", pat_id)\n", + "output_directory = os.path.join(\".\", \"results\")\n", "client.download_output_objects(dataset, output_path=output_directory)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Display the results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "heart = sitk.ReadImage(str([s for s in Path(output_directory).glob(\"*\")][0]))\n", + "\n", + "vis = 
ImageVisualiser(sitk.ReadImage(str(ct_image)), cut=get_com(heart))\n", + "vis.add_contour({\"Heart\": heart})\n", + "fig=vis.show()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -246,7 +249,8 @@ "hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90" }, "kernelspec": { - "display_name": "Python 3.6.9 64-bit", + "display_name": "Python 3", + "language": "python", "name": "python3" }, "language_info": { @@ -259,9 +263,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.8.8" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } From 810fe5ab0f1a86a4e2f740772e0f77cefe514340 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 13 Aug 2021 01:50:28 +0000 Subject: [PATCH 097/264] Add updated dataset code --- platipy/imaging/cnn/dataset.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index a50fa867..d4b63e45 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -16,6 +16,15 @@ def get_contour_mask(masks, kernel=5): + """Returns a mask around the region where observer masks don't agree + + Args: + masks (list): List of observer masks (as sitk.Image) + kernel (int, optional): The size of the kernel to dilate the contour of. Defaults to 5. 
+ + Returns: + sitk.Image: The resulting contour mask + """ if not hasattr(kernel, "__iter__"): kernel = (kernel,) * 3 @@ -133,6 +142,7 @@ def __init__( spacing=[1, 1, 1], crop_to_mm=128, contour_mask_kernel=5, + ndims=2, ): """Prepare a dataset from Nifti images/labels @@ -149,6 +159,7 @@ def __init__( self.transforms = prepare_transforms() self.slices = [] self.working_dir = Path(working_dir) + self.ndims = ndims self.img_dir = working_dir.joinpath("img") self.label_dir = working_dir.joinpath("label") @@ -212,21 +223,34 @@ def __init__( contour_mask = get_contour_mask(observers, kernel=contour_mask_kernel) - for z_slice in range(img.GetSize()[2]): + z_range = range(img.GetSize()[2]) + if ndims == 3: + z_range = range(1) + for z_slice in z_range: # Save the image slice - img_slice = img[:, :, z_slice] + if ndims == 2: + img_slice = img[:, :, z_slice] + else: + img_slice = img + img_file = self.img_dir.joinpath(f"{case_id}_{z_slice}.npy") np.save(img_file, sitk.GetArrayFromImage(img_slice)) # Save the contour mask slice - contour_mask_slice = contour_mask[:, :, z_slice] + if ndims == 2: + contour_mask_slice = contour_mask[:, :, z_slice] + else: + contour_mask_slice = contour_mask contour_mask_file = self.contour_mask_dir.joinpath(f"{case_id}_{z_slice}.npy") np.save(contour_mask_file, sitk.GetArrayFromImage(contour_mask_slice)) for obs, label in enumerate(observers): - label_slice = label[:, :, z_slice] + if ndims == 2: + label_slice = label[:, :, z_slice] + else: + label_slice = label label_file = self.label_dir.joinpath(f"{case_id}_{obs}_{z_slice}.npy") np.save(label_file, sitk.GetArrayFromImage(label_slice)) self.slices.append( From 3762bc324d44adbbb1802e9e0ee710bab6e85774 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 14 Aug 2021 07:29:51 +0000 Subject: [PATCH 098/264] Add localiser network --- platipy/imaging/cnn/dataload.py | 204 ++++++++++++++++++ platipy/imaging/cnn/dataset.py | 9 +- platipy/imaging/cnn/localise.py | 363 
++++++++++++++++++++++++++++++++ platipy/imaging/cnn/train.py | 196 +---------------- 4 files changed, 578 insertions(+), 194 deletions(-) create mode 100644 platipy/imaging/cnn/dataload.py create mode 100644 platipy/imaging/cnn/localise.py diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py new file mode 100644 index 00000000..256b2829 --- /dev/null +++ b/platipy/imaging/cnn/dataload.py @@ -0,0 +1,204 @@ +import random +import math +from pathlib import Path + +import torch + +import pytorch_lightning as pl + +from platipy.imaging.cnn.dataset import NiftiDataset +from platipy.imaging.cnn.sampler import ObserverSampler + + +class UNetDataModule(pl.LightningDataModule): + def __init__( + self, + data_dir: str = "./data", + augmented_dir: str = None, + working_dir: str = "./working", + case_glob="images/*.nii.gz", + image_glob="images/{case}.nii.gz", + label_glob="labels/{case}_*.nii.gz", + augmented_case_glob="{case}/*", + augmented_image_glob="images/{augmented_case}.nii.gz", + augmented_label_glob="labels/{augmented_case}_*.nii.gz", + augment_on_fly=True, + fold=0, + k_folds=5, + batch_size=5, + num_workers=4, + crop_to_mm=128, + num_observers=5, + spacing=[1, 1, 1], + contour_mask_kernel=3, + ndims=2, + **kwargs, + ): + super().__init__() + self.data_dir = Path(data_dir) + self.augmented_dir = augmented_dir + self.working_dir = Path(working_dir) + + self.case_glob = case_glob + self.image_glob = image_glob + self.label_glob = label_glob + self.augmented_case_glob = augmented_case_glob + self.augmented_image_glob = augmented_image_glob + self.augmented_label_glob = augmented_label_glob + + self.augment_on_fly = augment_on_fly + self.fold = fold + self.k_folds = k_folds + + self.train_cases = [] + self.validation_cases = [] + + self.batch_size = batch_size + self.num_workers = num_workers + self.crop_to_mm = crop_to_mm + self.num_observers = num_observers + self.spacing = spacing + self.contour_mask_kernel = contour_mask_kernel + + 
print(self.spacing) + + self.training_set = None + self.validation_set = None + + self.ndims = ndims + + print(f"Training fold {self.fold}") + + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("Data Loader") + parser.add_argument("--data_dir", type=str, default="./data") + parser.add_argument("--augmented_dir", type=str, default=None) + parser.add_argument("--augment_onfly", type=bool, default=True) + parser.add_argument("--fold", type=int, default=0) + parser.add_argument("--k_folds", type=int, default=5) + parser.add_argument("--batch_size", type=int, default=5) + parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") + parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") + parser.add_argument("--label_glob", type=str, default="labels/{case}_*.nii.gz") + parser.add_argument("--augmented_case_glob", type=str, default=None) + parser.add_argument("--augmented_image_glob", type=str, default=None) + parser.add_argument("--augmented_label_glob", type=str, default=None) + parser.add_argument("--crop_to_mm", type=int, default=128) + parser.add_argument("--contour_mask_kernel", type=int, default=5) + parser.add_argument("--ndims", type=int, default=2) + + return parent_parser + + def setup(self, stage=None): + + cases = [ + p.name.replace(".nii.gz", "") + for p in self.data_dir.glob(self.case_glob) + if not p.name.startswith(".") + ] + cases.sort() + random.shuffle(cases) # will be consistent for same value of 'seed everything' + cases_per_fold = math.ceil(len(cases) / self.k_folds) + for f in range(self.k_folds): + + if self.fold == f: + self.validation_cases = cases[f * cases_per_fold : (f + 1) * cases_per_fold] + else: + self.train_cases += cases[f * cases_per_fold : (f + 1) * cases_per_fold] + + print(f"Training cases: {self.train_cases}") + print(f"Validation cases: {self.validation_cases}") + + train_data = [ + 
{ + "id": case, + "image": self.data_dir.joinpath(self.image_glob.format(case=case)), + "label": [p for p in self.data_dir.glob(self.label_glob.format(case=case))], + } + for case in self.train_cases + ] + + # If a directory with augmented data is specified, use that for training as well + if self.augmented_dir is not None: + + for case in self.train_cases: + + case_aug_dir = Path(self.augmented_dir.format(case=case)) + augmented_cases = [ + p.name.replace(".nii.gz", "") + for p in case_aug_dir.glob(self.augmented_case_glob.format(case=case)) + if not p.name.startswith(".") + ] + + train_data += [ + { + "id": f"{case}_{augmented_case}", + "image": case_aug_dir.joinpath( + self.augmented_image_glob.format( + case=case, augmented_case=augmented_case + ) + ), + "label": [ + p + for p in case_aug_dir.glob( + self.augmented_label_glob.format( + case=case, augmented_case=augmented_case + ) + ) + ], + } + for augmented_case in augmented_cases + ] + + validation_data = [ + { + "id": case, + "image": self.data_dir.joinpath(self.image_glob.format(case=case)), + "label": [ + p + for p in self.data_dir.glob(self.label_glob.format(case=case)) + if not "edited" in p.name + ], + } + for case in self.validation_cases + ] + + self.training_set = NiftiDataset( + train_data, + self.working_dir, + augment_on_the_fly=self.augment_on_fly, + spacing=self.spacing, + crop_to_mm=self.crop_to_mm, + contour_mask_kernel=self.contour_mask_kernel, + ndims=self.ndims, + ) + self.validation_set = NiftiDataset( + validation_data, + self.working_dir, + augment_on_the_fly=False, + spacing=self.spacing, + crop_to_mm=self.crop_to_mm, + contour_mask_kernel=self.contour_mask_kernel, + ndims=self.ndims, + ) + + def train_dataloader(self): + return torch.utils.data.DataLoader( + self.training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + + def val_dataloader(self): + return torch.utils.data.DataLoader( + self.validation_set, + 
batch_sampler=torch.utils.data.BatchSampler( + ObserverSampler(self.validation_set, self.num_observers), + batch_size=self.num_observers, + drop_last=False, + ), + num_workers=self.num_workers, + ) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index d4b63e45..13e1c3ca 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -140,8 +140,9 @@ def __init__( working_dir, augment_on_the_fly=True, spacing=[1, 1, 1], - crop_to_mm=128, + crop_to_mm=None, contour_mask_kernel=5, + combine_observers=None, ndims=2, ): """Prepare a dataset from Nifti images/labels @@ -223,6 +224,12 @@ def __init__( contour_mask = get_contour_mask(observers, kernel=contour_mask_kernel) + if combine_observers == "union": + observers = [get_union_mask(observers)] + + if combine_observers == "intersection": + observers = [get_intersection_mask(observers)] + z_range = range(img.GetSize()[2]) if ndims == 3: z_range = range(1) diff --git a/platipy/imaging/cnn/localise.py b/platipy/imaging/cnn/localise.py new file mode 100644 index 00000000..2e3c35d7 --- /dev/null +++ b/platipy/imaging/cnn/localise.py @@ -0,0 +1,363 @@ +# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import os +import tempfile +import json + +from pathlib import Path +import SimpleITK as sitk +import numpy as np + +import comet_ml # pylint: disable=unused-import +from pytorch_lightning.loggers import CometLogger + +import torch +import pytorch_lightning as pl +from pytorch_lightning.callbacks import LearningRateMonitor + +from argparse import ArgumentParser + +import matplotlib.pyplot as plt + +from platipy.imaging.cnn.unet import UNet +from platipy.imaging.cnn.dataload import UNetDataModule + +from platipy.imaging import ImageVisualiser +from platipy.imaging.label.utils import get_com + + +def post_process(pred): + + # Take only the largest component + labelled_image = sitk.ConnectedComponent(pred) + label_shape_filter = sitk.LabelShapeStatisticsImageFilter() + label_shape_filter.Execute(labelled_image) + label_indices = label_shape_filter.GetLabels() + voxel_counts = [label_shape_filter.GetNumberOfPixels(i) for i in label_indices] + if len(voxel_counts) > 0: + largest_component_label = label_indices[np.argmax(voxel_counts)] + largest_component_image = labelled_image == largest_component_label + pred = sitk.Cast(largest_component_image, sitk.sitkUInt8) + + # Fill any holes in the structure + pred = sitk.BinaryMorphologicalClosing(pred, (5, 5, 5)) + pred = sitk.BinaryFillhole(pred) + + return pred + + +def get_metrics(target, pred): + + result = {} + lomif = sitk.LabelOverlapMeasuresImageFilter() + lomif.Execute(target, pred) + result["JI"] = lomif.GetJaccardCoefficient() + result["DSC"] = lomif.GetDiceCoefficient() + + if sitk.GetArrayFromImage(pred).sum() == 0: + result["HD"] = 1000 + result["ASD"] = 100 + else: + hdif = sitk.HausdorffDistanceImageFilter() + hdif.Execute(target, pred) + result["HD"] = hdif.GetHausdorffDistance() + result["ASD"] = hdif.GetAverageHausdorffDistance() + + return result + + +class LocaliseUNet(pl.LightningModule): + def __init__( + self, + **kwargs, + ): + super().__init__() + + self.save_hyperparameters() + + 
self.unet = UNet( + self.hparams.input_channels, + self.hparams.num_classes, + filters_per_layer=[32, 64, 128], + final_layer=True, + ) + + self.validation_directory = None + + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("Localize UNet") + parser.add_argument("--learning_rate", type=float, default=1e-3) + parser.add_argument("--input_channels", type=int, default=1) + parser.add_argument("--num_classes", type=int, default=2) + + return parent_parser + + def forward(self, x): + return self.unet.forward(x) + + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.hparams.learning_rate, weight_decay=0 + ) + + # scheduler = torch.optim.lr_scheduler.LambdaLR( + # optimizer, lr_lambda=[lambda epoch: self.hparams.lr_lambda ** (epoch)] + # ) + + return optimizer + + def training_step(self, batch, _): + + x, y, _, _ = batch + + pred = self.unet.forward(x) + + criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") + + y = torch.unsqueeze(y, dim=1) + not_y = y.logical_not() + y = torch.cat((not_y, y), dim=1).float() + + loss = criterion(input=pred, target=y) + + return loss + + def validation_step(self, batch, batch_idx): + + if self.validation_directory is None: + self.validation_directory = Path(tempfile.mkdtemp()) + print(self.validation_directory) + + with torch.set_grad_enabled(False): + x, y, _, info = batch + + for s in range(y.shape[0]): + + img_file = self.validation_directory.joinpath( + f"img_{info['case'][s]}_{info['z'][s]}.npy" + ) + np.save(img_file, x[0].squeeze(0).cpu().numpy()) + + mask_file = self.validation_directory.joinpath( + f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" + ) + np.save(mask_file, y[s].squeeze(0).cpu().numpy()) + + pred = self.unet.forward(x[s].unsqueeze(0)) + pred_file = self.validation_directory.joinpath( + f"pred_{info['case'][s]}_{info['z'][s]}.npy" + ) + pred = np.argmax(pred.squeeze(0).cpu().numpy(), axis=0) + 
np.save(pred_file, pred) + + return info + + def validation_epoch_end(self, validation_step_outputs): + + cases = {} + for info in validation_step_outputs: + + for case, z, observer in zip(info["case"], info["z"], info["observer"]): + + if not case in cases: + cases[case] = {"slices": z.item(), "observers": [observer.item()]} + else: + if z.item() > cases[case]["slices"]: + cases[case]["slices"] = z.item() + if not observer in cases[case]["observers"]: + cases[case]["observers"].append(observer.item()) + + metrics = {"JI": [], "DSC": [], "HD": [], "ASD": []} + for case in cases: + + img_arrs = [] + pred_arrs = [] + slices = [] + for z in range(cases[case]["slices"] + 1): + img_file = self.validation_directory.joinpath(f"img_{case}_{z}.npy") + pred_file = self.validation_directory.joinpath(f"pred_{case}_{z}.npy") + if img_file.exists(): + img_arrs.append(np.load(img_file)) + pred_arrs.append(np.load(pred_file)) + slices.append(z) + + if len(slices) < 5: + # Likely initial sanity check + continue + + img_arr = np.stack(img_arrs) + img = sitk.GetImageFromArray(img_arr) + img.SetSpacing(self.hparams.spacing) + + pred_arr = np.stack(pred_arrs) + pred = sitk.GetImageFromArray(pred_arr) + pred = sitk.Cast(pred, sitk.sitkUInt8) + pred = post_process(pred) + pred.CopyInformation(img) + sitk.WriteImage(pred, f"val_pred_{case}.nii.gz") + + color_dict = {} + obs_dict = {} + + img_vis = ImageVisualiser( + img, cut=get_com(pred), figure_size_in=16, window=[-1.0, 1.0] + ) + + for _, observer in enumerate(cases[case]["observers"]): + mask_arrs = [] + for z in slices: + mask_file = self.validation_directory.joinpath( + f"mask_{case}_{z}_{observer}.npy" + ) + + mask_arrs.append(np.load(mask_file)) + + mask_arr = np.stack(mask_arrs) + mask = sitk.GetImageFromArray(mask_arr) + mask = sitk.Cast(mask, sitk.sitkUInt8) + mask.CopyInformation(img) + obs_dict[f"manual_{observer}"] = mask + color_dict[f"manual_{observer}"] = [0.7, 0.2, 0.2] + + contour_dict = {**obs_dict} + 
contour_dict["pred"] = pred + color_dict["pred"] = [0.2, 0.4, 0.8] + + img_vis.add_contour(contour_dict, color=color_dict) + fig = img_vis.show() + figure_path = f"valid_{case}.png" + fig.savefig(figure_path, dpi=300) + plt.close("all") + + try: + self.logger.experiment.log_image(figure_path) + except AttributeError: + # Likely offline mode + pass + + case_metrics = get_metrics(pred, mask) + for m in case_metrics: + metrics[m].append(case_metrics[m]) + + for m in metrics: + self.log( + m, + np.array(metrics[m]).mean(), + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + +def main(args, config_json_path=None): + + pl.seed_everything(args.seed, workers=True) + + args.working_dir = Path(args.working_dir) + args.working_dir = args.working_dir.joinpath(args.experiment) + args.default_root_dir = str(args.working_dir) + + comet_api_key = None + comet_workspace = None + comet_project = None + + if args.comet_api_key: + comet_api_key = args.comet_api_key + comet_workspace = args.comet_workspace + comet_project = args.comet_project + + if comet_api_key is None: + if "COMET_API_KEY" in os.environ: + comet_api_key = os.environ["COMET_API_KEY"] + if "COMET_WORKSPACE" in os.environ: + comet_workspace = os.environ["COMET_WORKSPACE"] + if "COMET_PROJECT" in os.environ: + comet_project = os.environ["COMET_PROJECT"] + + if comet_api_key is not None: + comet_logger = CometLogger( + api_key=comet_api_key, + workspace=comet_workspace, + project_name=comet_project, + experiment_name=args.experiment, + save_dir=args.working_dir, + offline=args.offline, + ) + if config_json_path: + comet_logger.experiment.log_code(config_json_path) + + dict_args = vars(args) + + data_module = UNetDataModule(**dict_args) + + prob_unet = LocaliseUNet(**dict_args) + + if args.resume_from is not None: + trainer = pl.Trainer(resume_from_checkpoint=args.resume_from) + else: + trainer = pl.Trainer.from_argparse_args(args) + + if comet_api_key is not None: + trainer.logger = comet_logger + 
+ lr_monitor = LearningRateMonitor(logging_interval="step") + trainer.callbacks.append(lr_monitor) + + trainer.fit(prob_unet, data_module) + + +if __name__ == "__main__": + + args = None + config_json_path = None + if len(sys.argv) == 2: + # Check if JSON file parsed, if so read arguments from there... + if sys.argv[-1].endswith(".json"): + config_json_path = sys.argv[-1] + with open(config_json_path, "r") as f: + params = json.load(f) + args = [] + for key in params: + args.append(f"--{key}") + + if isinstance(params[key], list): + for s in params[key]: + args.append(str(s)) + else: + args.append(str(params[key])) + + arg_parser = ArgumentParser() + arg_parser = LocaliseUNet.add_model_specific_args(arg_parser) + arg_parser = UNetDataModule.add_model_specific_args(arg_parser) + arg_parser = pl.Trainer.add_argparse_args(arg_parser) + arg_parser.add_argument( + "--config", type=str, default=None, help="JSON file with parameters to load" + ) + arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") + arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") + arg_parser.add_argument("--working_dir", type=str, default="./working") + arg_parser.add_argument("--num_observers", type=int, default=5) + arg_parser.add_argument("--spacing", nargs="+", type=int, default=[3, 3, 3]) + arg_parser.add_argument("--offline", type=bool, default=False) + arg_parser.add_argument("--comet_api_key", type=str, default=None) + arg_parser.add_argument("--comet_workspace", type=str, default=None) + arg_parser.add_argument("--comet_project", type=str, default=None) + arg_parser.add_argument("--resume_from", type=str, default=None) + arg_parser.add_argument("--combine_observers", type=str, default="union") + + main(arg_parser.parse_args(args), config_json_path=config_json_path) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index f245f325..250324d2 100644 --- a/platipy/imaging/cnn/train.py +++ 
b/platipy/imaging/cnn/train.py @@ -37,8 +37,7 @@ from platipy.imaging.cnn.prob_unet import ProbabilisticUnet from platipy.imaging.cnn.unet import l2_regularisation -from platipy.imaging.cnn.dataset import NiftiDataset -from platipy.imaging.cnn.sampler import ObserverSampler +from platipy.imaging.cnn.dataload import UNetDataModule from platipy.imaging import ImageVisualiser from platipy.imaging.label.utils import get_com @@ -394,195 +393,6 @@ def validation_epoch_end(self, validation_step_outputs): # shutil.rmtree(self.validation_directory) -class ProbUNetDataModule(pl.LightningDataModule): - def __init__( - self, - data_dir: str = "./data", - augmented_dir: str = None, - working_dir: str = "./working", - case_glob="images/*.nii.gz", - image_glob="images/{case}.nii.gz", - label_glob="labels/{case}_*.nii.gz", - augmented_case_glob="{case}/*", - augmented_image_glob="images/{augmented_case}.nii.gz", - augmented_label_glob="labels/{augmented_case}_*.nii.gz", - fold=0, - k_folds=5, - batch_size=5, - num_workers=4, - crop_to_mm=128, - num_observers=5, - spacing=[1, 1, 1], - contour_mask_kernel=3, - ndims=2, - **kwargs, - ): - super().__init__() - self.data_dir = Path(data_dir) - self.augmented_dir = augmented_dir - self.working_dir = Path(working_dir) - - self.case_glob = case_glob - self.image_glob = image_glob - self.label_glob = label_glob - self.augmented_case_glob = augmented_case_glob - self.augmented_image_glob = augmented_image_glob - self.augmented_label_glob = augmented_label_glob - - self.fold = fold - self.k_folds = k_folds - - self.train_cases = [] - self.validation_cases = [] - - self.batch_size = batch_size - self.num_workers = num_workers - self.crop_to_mm = crop_to_mm - self.num_observers = num_observers - self.spacing = spacing - self.contour_mask_kernel = contour_mask_kernel - - self.training_set = None - self.validation_set = None - - self.ndims = ndims - - print(f"Training fold {self.fold}") - - @staticmethod - def 
add_model_specific_args(parent_parser): - parser = parent_parser.add_argument_group("Data Loader") - parser.add_argument("--data_dir", type=str, default="./data") - parser.add_argument("--augmented_dir", type=str, default=None) - parser.add_argument("--fold", type=int, default=0) - parser.add_argument("--k_folds", type=int, default=5) - parser.add_argument("--batch_size", type=int, default=5) - parser.add_argument("--num_workers", type=int, default=4) - parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") - parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") - parser.add_argument("--label_glob", type=str, default="labels/{case}_*.nii.gz") - parser.add_argument("--augmented_case_glob", type=str, default=None) - parser.add_argument("--augmented_image_glob", type=str, default=None) - parser.add_argument("--augmented_label_glob", type=str, default=None) - parser.add_argument("--crop_to_mm", type=int, default=128) - parser.add_argument("--contour_mask_kernel", type=int, default=5) - parser.add_argument("--ndims", type=int, default=2) - - return parent_parser - - def setup(self, stage=None): - - cases = [ - p.name.replace(".nii.gz", "") - for p in self.data_dir.glob(self.case_glob) - if not p.name.startswith(".") - ] - cases.sort() - random.shuffle(cases) # will be consistent for same value of 'seed everything' - cases_per_fold = math.ceil(len(cases) / self.k_folds) - for f in range(self.k_folds): - - if self.fold == f: - self.validation_cases = cases[f * cases_per_fold : (f + 1) * cases_per_fold] - else: - self.train_cases += cases[f * cases_per_fold : (f + 1) * cases_per_fold] - - print(f"Training cases: {self.train_cases}") - print(f"Validation cases: {self.validation_cases}") - - train_data = [ - { - "id": case, - "image": self.data_dir.joinpath(self.image_glob.format(case=case)), - "label": [p for p in self.data_dir.glob(self.label_glob.format(case=case))], - } - for case in self.train_cases - ] - - # If a directory with 
augmented data is specified, use that for training as well - if self.augmented_dir is not None: - - for case in self.train_cases: - - case_aug_dir = Path(self.augmented_dir.format(case=case)) - augmented_cases = [ - p.name.replace(".nii.gz", "") - for p in case_aug_dir.glob(self.augmented_case_glob.format(case=case)) - if not p.name.startswith(".") - ] - - train_data += [ - { - "id": f"{case}_{augmented_case}", - "image": case_aug_dir.joinpath( - self.augmented_image_glob.format( - case=case, augmented_case=augmented_case - ) - ), - "label": [ - p - for p in case_aug_dir.glob( - self.augmented_label_glob.format( - case=case, augmented_case=augmented_case - ) - ) - ], - } - for augmented_case in augmented_cases - ] - - validation_data = [ - { - "id": case, - "image": self.data_dir.joinpath(self.image_glob.format(case=case)), - "label": [ - p - for p in self.data_dir.glob(self.label_glob.format(case=case)) - if not "edited" in p.name - ], - } - for case in self.validation_cases - ] - - self.training_set = NiftiDataset( - train_data, - self.working_dir, - augment_on_the_fly=False, - spacing=self.spacing, - crop_to_mm=self.crop_to_mm, - contour_mask_kernel=self.contour_mask_kernel, - ndims=self.ndims, - ) - self.validation_set = NiftiDataset( - validation_data, - self.working_dir, - augment_on_the_fly=False, - spacing=self.spacing, - crop_to_mm=self.crop_to_mm, - contour_mask_kernel=self.contour_mask_kernel, - ndims=self.ndims, - ) - - def train_dataloader(self): - return torch.utils.data.DataLoader( - self.training_set, - batch_size=self.batch_size, - shuffle=True, - num_workers=self.num_workers, - ) - - def val_dataloader(self): - return torch.utils.data.DataLoader( - self.validation_set, - batch_sampler=torch.utils.data.BatchSampler( - ObserverSampler(self.validation_set, self.num_observers), - batch_size=self.num_observers, - drop_last=False, - ), - num_workers=self.num_workers, - ) - - def main(args, config_json_path=None): pl.seed_everything(args.seed, 
workers=True) @@ -622,7 +432,7 @@ def main(args, config_json_path=None): dict_args = vars(args) - data_module = ProbUNetDataModule(**dict_args) + data_module = UNetDataModule(**dict_args) prob_unet = ProbUNet(**dict_args) @@ -662,7 +472,7 @@ def main(args, config_json_path=None): arg_parser = ArgumentParser() arg_parser = ProbUNet.add_model_specific_args(arg_parser) - arg_parser = ProbUNetDataModule.add_model_specific_args(arg_parser) + arg_parser = UNetDataModule.add_model_specific_args(arg_parser) arg_parser = pl.Trainer.add_argparse_args(arg_parser) arg_parser.add_argument( "--config", type=str, default=None, help="JSON file with parameters to load" From 6eebade836a6b5e5071986ac8a18a5d728e89d27 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 14 Aug 2021 17:33:36 +1000 Subject: [PATCH 099/264] Trainer to support 3d --- platipy/imaging/cnn/train.py | 66 ++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 22 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index f245f325..f5941d57 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -263,28 +263,37 @@ def validation_epoch_end(self, validation_step_outputs): **{f"probnet_{m}": [] for m in metrics}, **{f"unet_{m}": [] for m in metrics}, } + for case in cases: img_arrs = [] mean_arrs = [] slices = [] - for z in range(cases[case]["slices"] + 1): - img_file = self.validation_directory.joinpath(f"img_{case}_{z}.npy") - mean_file = self.validation_directory.joinpath(f"mean_{case}_{z}.npy") - if img_file.exists(): - img_arrs.append(np.load(img_file)) - mean_arrs.append(np.load(mean_file)) - slices.append(z) - - if len(slices) < 5: + + if self.hparams.ndims == 2: + for z in range(cases[case]["slices"] + 1): + img_file = self.validation_directory.joinpath(f"img_{case}_{z}.npy") + mean_file = self.validation_directory.joinpath(f"mean_{case}_{z}.npy") + if img_file.exists(): + img_arrs.append(np.load(img_file)) + 
mean_arrs.append(np.load(mean_file)) + slices.append(z) + + #if len(slices) < 5: # Likely initial sanity check - continue + # continue + + img_arr = np.stack(img_arrs) + mean_arr = np.stack(mean_arrs) - img_arr = np.stack(img_arrs) + else: + img_file = self.validation_directory.joinpath(f"img_{case}_0.npy") + mean_file = self.validation_directory.joinpath(f"mean_{case}_0.npy") + img_arr = np.load(img_file) + mean_arr = np.load(mean_file) img = sitk.GetImageFromArray(img_arr) img.SetSpacing(self.hparams.spacing) - mean_arr = np.stack(mean_arrs) mean = sitk.GetImageFromArray(mean_arr) mean = sitk.Cast(mean, sitk.sitkUInt8) mean = post_process(mean) @@ -297,34 +306,47 @@ def validation_epoch_end(self, validation_step_outputs): observers = [] samples = [] for idx, observer in enumerate(cases[case]["observers"]): - mask_arrs = [] - sample_arrs = [] - for z in slices: + + if self.hparams.ndims == 2: + mask_arrs = [] + sample_arrs = [] + for z in slices: + mask_file = self.validation_directory.joinpath( + f"mask_{case}_{z}_{observer}.npy" + ) + sample_file = self.validation_directory.joinpath( + f"sample_{case}_{z}_{observer}.npy" + ) + + mask_arrs.append(np.load(mask_file)) + sample_arrs.append(np.load(sample_file)) + + mask_arr = np.stack(mask_arrs) + sample_arr = np.stack(sample_arrs) + + else: mask_file = self.validation_directory.joinpath( f"mask_{case}_{z}_{observer}.npy" ) sample_file = self.validation_directory.joinpath( f"sample_{case}_{z}_{observer}.npy" ) + mask_arr = np.load(mask_file) + sample_arr = np.load(sample_file) - mask_arrs.append(np.load(mask_file)) - sample_arrs.append(np.load(sample_file)) - - mask_arr = np.stack(mask_arrs) mask = sitk.GetImageFromArray(mask_arr) mask = sitk.Cast(mask, sitk.sitkUInt8) mask.CopyInformation(img) - # sitk.WriteImage(mask, f"val_mask_{case}_{observer}.nii.gz") + sitk.WriteImage(mask, f"val_mask_{case}_{observer}.nii.gz") observers.append(mask) obs_dict[f"manual_{observer}"] = mask color_dict[f"manual_{observer}"] = 
[0.5, 0.5, 0.5] - sample_arr = np.stack(sample_arrs) sample = sitk.GetImageFromArray(sample_arr) sample = sitk.Cast(sample, sitk.sitkUInt8) sample = post_process(sample) sample.CopyInformation(img) - # sitk.WriteImage(sample, f"val_sample_{case}_{observer}.nii.gz") + sitk.WriteImage(sample, f"val_sample_{case}_{observer}.nii.gz") samples.append(sample) pred_dict[f"auto_{self.stddevs[idx]}"] = sample color_dict[f"auto_{self.stddevs[idx]}"] = cmap(observer / 5) From f9fca18e7943ddd766c7f0f6a7f752ba8e22514c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 15 Aug 2021 16:39:05 +1000 Subject: [PATCH 100/264] run inference on localise model --- platipy/imaging/cnn/dataload.py | 35 ++++++++++++++++++++++--- platipy/imaging/cnn/dataset.py | 24 ++++++++++-------- platipy/imaging/cnn/localise.py | 45 +++++++++++++++++++++++++++++++-- 3 files changed, 88 insertions(+), 16 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index 256b2829..0a1e762f 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -52,6 +52,7 @@ def __init__( self.train_cases = [] self.validation_cases = [] + self.test_cases = [] self.batch_size = batch_size self.num_workers = num_workers @@ -64,6 +65,7 @@ def __init__( self.training_set = None self.validation_set = None + self.test_set = None self.ndims = ndims @@ -104,18 +106,22 @@ def setup(self, stage=None): for f in range(self.k_folds): if self.fold == f: - self.validation_cases = cases[f * cases_per_fold : (f + 1) * cases_per_fold] + self.val_test_cases = cases[f * cases_per_fold : (f + 1) * cases_per_fold] + + self.validation_cases = self.val_test_cases[:int(len(self.val_test_cases)/2)] + self.test_cases = self.val_test_cases[int(len(self.val_test_cases)/2):] else: self.train_cases += cases[f * cases_per_fold : (f + 1) * cases_per_fold] print(f"Training cases: {self.train_cases}") print(f"Validation cases: {self.validation_cases}") + print(f"Testing cases: 
{self.test_cases}") train_data = [ { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), - "label": [p for p in self.data_dir.glob(self.label_glob.format(case=case))], + "label": [p for p in self.data_dir.glob(self.label_glob.format(case=case)) if not "_OLD" in p.name], } for case in self.train_cases ] @@ -147,6 +153,7 @@ def setup(self, stage=None): case=case, augmented_case=augmented_case ) ) + if not "_OLD" in p.name ], } for augmented_case in augmented_cases @@ -159,12 +166,25 @@ def setup(self, stage=None): "label": [ p for p in self.data_dir.glob(self.label_glob.format(case=case)) - if not "edited" in p.name + if not "_OLD" in p.name ], } for case in self.validation_cases ] + test_data = [ + { + "id": case, + "image": self.data_dir.joinpath(self.image_glob.format(case=case)), + "label": [ + p + for p in self.data_dir.glob(self.label_glob.format(case=case)) + if not "_OLD" in p.name + ], + } + for case in self.test_cases + ] + self.training_set = NiftiDataset( train_data, self.working_dir, @@ -183,6 +203,15 @@ def setup(self, stage=None): contour_mask_kernel=self.contour_mask_kernel, ndims=self.ndims, ) + self.test_set = NiftiDataset( + test_data, + self.working_dir, + augment_on_the_fly=False, + spacing=self.spacing, + crop_to_mm=self.crop_to_mm, + contour_mask_kernel=self.contour_mask_kernel, + ndims=self.ndims, + ) def train_dataloader(self): return torch.utils.data.DataLoader( diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 13e1c3ca..8975026e 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -52,11 +52,12 @@ def preprocess_image(img, spacing=[1, 1, 1], crop_to_mm=128): new_size[1] = int(img.GetSize()[1] * (img.GetSpacing()[1] / spacing[1])) new_size[2] = int(img.GetSize()[2] * (img.GetSpacing()[2] / spacing[2])) - if new_size[0] < crop_to_mm: - new_size[0] = crop_to_mm + if crop_to_mm: + if new_size[0] < crop_to_mm: + new_size[0] = crop_to_mm - if new_size[1] < 
crop_to_mm: - new_size[1] = crop_to_mm + if new_size[1] < crop_to_mm: + new_size[1] = crop_to_mm img = sitk.Resample( img, @@ -70,15 +71,16 @@ def preprocess_image(img, spacing=[1, 1, 1], crop_to_mm=128): img.GetPixelID(), ) - center_x = img.GetSize()[0] / 2 - x_from = int(center_x - crop_to_mm / 2) - x_to = x_from + crop_to_mm + if crop_to_mm: + center_x = img.GetSize()[0] / 2 + x_from = int(center_x - crop_to_mm / 2) + x_to = x_from + crop_to_mm - center_y = img.GetSize()[1] / 2 - y_from = int(center_y - crop_to_mm / 2) - y_to = y_from + crop_to_mm + center_y = img.GetSize()[1] / 2 + y_from = int(center_y - crop_to_mm / 2) + y_to = y_from + crop_to_mm - img = img[x_from:x_to, y_from:y_to, :] + img = img[x_from:x_to, y_from:y_to, :] return img diff --git a/platipy/imaging/cnn/localise.py b/platipy/imaging/cnn/localise.py index 2e3c35d7..486dac2b 100644 --- a/platipy/imaging/cnn/localise.py +++ b/platipy/imaging/cnn/localise.py @@ -26,7 +26,7 @@ import torch import pytorch_lightning as pl -from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from argparse import ArgumentParser @@ -34,6 +34,7 @@ from platipy.imaging.cnn.unet import UNet from platipy.imaging.cnn.dataload import UNetDataModule +from platipy.imaging.cnn.dataset import preprocess_image from platipy.imaging import ImageVisualiser from platipy.imaging.label.utils import get_com @@ -120,6 +121,30 @@ def configure_optimizers(self): return optimizer + def infer(self, img): + + pp_img = preprocess_image(img, spacing=self.hparams.spacing, crop_to_mm=self.hparams.crop_to_mm) + + preds = [] + for z in range(pp_img.GetSize()[2]): + x = sitk.GetArrayFromImage(pp_img[:,:, z]) + x = torch.Tensor(x) + x = x.unsqueeze(0) + x = x.unsqueeze(0) + y = self(x) + y = y.squeeze(0) + y = np.argmax(y.cpu().detach().numpy(), axis=0) + preds.append(y) + + pred = sitk.GetImageFromArray(np.stack(preds)) + pred = sitk.Cast(pred, sitk.sitkUInt8) + + 
pred.CopyInformation(pp_img) + pred = post_process(pred) + pred = sitk.Resample(pred, img, sitk.Transform(), sitk.sitkNearestNeighbor) + + return pred + def training_step(self, batch, _): x, y, _, _ = batch @@ -213,6 +238,10 @@ def validation_epoch_end(self, validation_step_outputs): color_dict = {} obs_dict = {} + try: + get_com(pred) + except: + continue img_vis = ImageVisualiser( img, cut=get_com(pred), figure_size_in=16, window=[-1.0, 1.0] ) @@ -270,7 +299,8 @@ def main(args, config_json_path=None): args.working_dir = Path(args.working_dir) args.working_dir = args.working_dir.joinpath(args.experiment) - args.default_root_dir = str(args.working_dir) + args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") + args.default_root_dir = str(args.fold_dir) comet_api_key = None comet_workspace = None @@ -318,6 +348,17 @@ def main(args, config_json_path=None): lr_monitor = LearningRateMonitor(logging_interval="step") trainer.callbacks.append(lr_monitor) + # Save the best model + checkpoint_callback = ModelCheckpoint( + monitor="DSC", + dirpath=args.default_root_dir, + filename="localise-{epoch:02d}-{DSC:.2f}", + save_top_k=1, + mode="max", + ) + + trainer.callbacks.append(checkpoint_callback) + trainer.fit(prob_unet, data_module) From a88a94835ab5f3aac57645624f809e8930143c32 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 15 Aug 2021 07:15:37 +0000 Subject: [PATCH 101/264] Use localise network to preprocess data --- platipy/imaging/cnn/dataload.py | 24 +++++++++++++++++++----- platipy/imaging/cnn/dataset.py | 18 ++++++++++++++++++ platipy/imaging/cnn/localise.py | 31 +++++++++++++------------------ platipy/imaging/cnn/train.py | 12 +++++------- 4 files changed, 55 insertions(+), 30 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index 0a1e762f..7b9610f3 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -11,6 +11,8 @@ class UNetDataModule(pl.LightningDataModule): + """PyTorch 
data module to training UNets""" + def __init__( self, data_dir: str = "./data", @@ -31,6 +33,8 @@ def __init__( num_observers=5, spacing=[1, 1, 1], contour_mask_kernel=3, + crop_using_localise_model=None, + localise_voxel_grid_size=[100, 100, 100], ndims=2, **kwargs, ): @@ -61,7 +65,8 @@ def __init__( self.spacing = spacing self.contour_mask_kernel = contour_mask_kernel - print(self.spacing) + self.crop_using_localise_model = crop_using_localise_model + self.localise_voxel_grid_size = localise_voxel_grid_size self.training_set = None self.validation_set = None @@ -73,6 +78,7 @@ def __init__( @staticmethod def add_model_specific_args(parent_parser): + """Add arguments used for Data module""" parser = parent_parser.add_argument_group("Data Loader") parser.add_argument("--data_dir", type=str, default="./data") parser.add_argument("--augmented_dir", type=str, default=None) @@ -89,6 +95,10 @@ def add_model_specific_args(parent_parser): parser.add_argument("--augmented_label_glob", type=str, default=None) parser.add_argument("--crop_to_mm", type=int, default=128) parser.add_argument("--contour_mask_kernel", type=int, default=5) + parser.add_argument("--crop_using_localise_model", type=str, default=None) + parser.add_argument( + "--localise_voxel_grid_size", nargs="+", type=int, default=[100, 100, 100] + ) parser.add_argument("--ndims", type=int, default=2) return parent_parser @@ -106,10 +116,10 @@ def setup(self, stage=None): for f in range(self.k_folds): if self.fold == f: - self.val_test_cases = cases[f * cases_per_fold : (f + 1) * cases_per_fold] + val_test_cases = cases[f * cases_per_fold : (f + 1) * cases_per_fold] - self.validation_cases = self.val_test_cases[:int(len(self.val_test_cases)/2)] - self.test_cases = self.val_test_cases[int(len(self.val_test_cases)/2):] + self.validation_cases = val_test_cases[: int(len(val_test_cases) / 2)] + self.test_cases = val_test_cases[int(len(val_test_cases) / 2) :] else: self.train_cases += cases[f * cases_per_fold : (f + 1) 
* cases_per_fold] @@ -121,7 +131,11 @@ def setup(self, stage=None): { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), - "label": [p for p in self.data_dir.glob(self.label_glob.format(case=case)) if not "_OLD" in p.name], + "label": [ + p + for p in self.data_dir.glob(self.label_glob.format(case=case)) + if not "_OLD" in p.name + ], } for case in self.train_cases ] diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 8975026e..dd68879f 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -12,7 +12,9 @@ from loguru import logger +from platipy.imaging.cnn.localise import LocaliseUNet from platipy.imaging.label.utils import get_union_mask, get_intersection_mask +from platipy.imaging.utils.crop import label_to_roi, crop_to_roi def get_contour_mask(masks, kernel=5): @@ -143,6 +145,8 @@ def __init__( augment_on_the_fly=True, spacing=[1, 1, 1], crop_to_mm=None, + crop_using_localise_model=None, + localise_voxel_grid_size=[100, 100, 100], contour_mask_kernel=5, combine_observers=None, ndims=2, @@ -156,6 +160,9 @@ def __init__( working_dir (str|path): Working directory where to write prepared files. 
""" + if crop_to_mm is not None and crop_using_localise_model is not None: + raise AttributeError("Only one of crop_to_mm or crop_using_localise_model may be set") + self.data = data self.transforms = None if augment_on_the_fly: @@ -217,6 +224,17 @@ def __init__( img = preprocess_image(img, spacing=spacing, crop_to_mm=crop_to_mm) + if crop_using_localise_model: + localise_model = LocaliseUNet.load_from_checkpoint(crop_using_localise_model) + localise_model.eval() + localise_pred = localise_model.infer(img) + localise_pred = resample_mask_to_image(img, localise_pred) + + size, index = label_to_roi(localise_pred) + index = [i - int((100 - s) / 2) for i, s in zip(index, size)] + size = localise_voxel_grid_size + img = crop_to_roi(img, size, index) + observers = [] for obs, structure_path in enumerate(structure_paths): structure_path = str(structure_path) diff --git a/platipy/imaging/cnn/localise.py b/platipy/imaging/cnn/localise.py index 486dac2b..323b2073 100644 --- a/platipy/imaging/cnn/localise.py +++ b/platipy/imaging/cnn/localise.py @@ -26,7 +26,7 @@ import torch import pytorch_lightning as pl -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.callbacks import ModelCheckpoint from argparse import ArgumentParser @@ -115,19 +115,17 @@ def configure_optimizers(self): self.parameters(), lr=self.hparams.learning_rate, weight_decay=0 ) - # scheduler = torch.optim.lr_scheduler.LambdaLR( - # optimizer, lr_lambda=[lambda epoch: self.hparams.lr_lambda ** (epoch)] - # ) - return optimizer def infer(self, img): - pp_img = preprocess_image(img, spacing=self.hparams.spacing, crop_to_mm=self.hparams.crop_to_mm) + pp_img = preprocess_image( + img, spacing=self.hparams.spacing, crop_to_mm=self.hparams.crop_to_mm + ) preds = [] for z in range(pp_img.GetSize()[2]): - x = sitk.GetArrayFromImage(pp_img[:,:, z]) + x = sitk.GetArrayFromImage(pp_img[:, :, z]) x = torch.Tensor(x) x = x.unsqueeze(0) x = x.unsqueeze(0) @@ -161,7 +159,7 @@ 
def training_step(self, batch, _): return loss - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, _): if self.validation_directory is None: self.validation_directory = Path(tempfile.mkdtemp()) @@ -238,13 +236,13 @@ def validation_epoch_end(self, validation_step_outputs): color_dict = {} obs_dict = {} + com = None try: - get_com(pred) + com = get_com(pred) except: - continue - img_vis = ImageVisualiser( - img, cut=get_com(pred), figure_size_in=16, window=[-1.0, 1.0] - ) + com = [int(i / 2) for i in pred.GetSize()] + + img_vis = ImageVisualiser(img, cut=com, figure_size_in=16, window=[-1.0, 1.0]) for _, observer in enumerate(cases[case]["observers"]): mask_arrs = [] @@ -345,9 +343,6 @@ def main(args, config_json_path=None): if comet_api_key is not None: trainer.logger = comet_logger - lr_monitor = LearningRateMonitor(logging_interval="step") - trainer.callbacks.append(lr_monitor) - # Save the best model checkpoint_callback = ModelCheckpoint( monitor="DSC", @@ -377,8 +372,8 @@ def main(args, config_json_path=None): args.append(f"--{key}") if isinstance(params[key], list): - for s in params[key]: - args.append(str(s)) + for list_val in params[key]: + args.append(str(list_val)) else: args.append(str(params[key])) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 65ed757b..e5241fa9 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -14,9 +14,7 @@ import sys import os -import math import tempfile -import random import json from pathlib import Path @@ -199,7 +197,7 @@ def training_step(self, batch, _): ) return loss - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, _): if self.validation_directory is None: self.validation_directory = Path(tempfile.mkdtemp()) @@ -278,9 +276,9 @@ def validation_epoch_end(self, validation_step_outputs): mean_arrs.append(np.load(mean_file)) slices.append(z) - #if len(slices) < 5: + # if len(slices) < 5: # Likely initial 
sanity check - # continue + # continue img_arr = np.stack(img_arrs) mean_arr = np.stack(mean_arrs) @@ -487,8 +485,8 @@ def main(args, config_json_path=None): args.append(f"--{key}") if isinstance(params[key], list): - for s in params[key]: - args.append(str(s)) + for list_val in params[key]: + args.append(str(list_val)) else: args.append(str(params[key])) From b19d76d6e16acda122e64181082ed0f12fb6565c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 15 Aug 2021 17:36:39 +1000 Subject: [PATCH 102/264] Load checkpoint --- platipy/imaging/cnn/dataload.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index 7b9610f3..acf2d417 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -199,12 +199,25 @@ def setup(self, stage=None): for case in self.test_cases ] + crop_to_mm = None + localise_model_path = None + if self.crop_using_localise_model: + localise_model_path = Path(self.crop_using_localise_model.format(fold=self.fold)) + if localise_model_path.name.isdir(): + localise_model_path = next(localise_model_path.glob("*.ckpt")) + + logger.info(f"Using localise model: {localise_model_path}") + else: + crop_to_mm = self.crop_to_mm + self.training_set = NiftiDataset( train_data, self.working_dir, augment_on_the_fly=self.augment_on_fly, spacing=self.spacing, - crop_to_mm=self.crop_to_mm, + crop_to_mm=crop_to_mm, + crop_using_localise_model=localise_model_path, + localise_voxel_grid_size=self.localise_voxel_grid_size, contour_mask_kernel=self.contour_mask_kernel, ndims=self.ndims, ) @@ -213,7 +226,9 @@ def setup(self, stage=None): self.working_dir, augment_on_the_fly=False, spacing=self.spacing, - crop_to_mm=self.crop_to_mm, + crop_to_mm=crop_to_mm, + crop_using_localise_model=localise_model_path, + localise_voxel_grid_size=self.localise_voxel_grid_size, contour_mask_kernel=self.contour_mask_kernel, ndims=self.ndims, ) @@ -222,7 
+237,9 @@ def setup(self, stage=None): self.working_dir, augment_on_the_fly=False, spacing=self.spacing, - crop_to_mm=self.crop_to_mm, + crop_to_mm=crop_to_mm, + crop_using_localise_model=localise_model_path, + localise_voxel_grid_size=self.localise_voxel_grid_size, contour_mask_kernel=self.contour_mask_kernel, ndims=self.ndims, ) From dd0aee3ec558b89e97aff6757c16c6fa683fdc40 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 15 Aug 2021 07:42:25 +0000 Subject: [PATCH 103/264] Separate out localise net and train --- platipy/imaging/cnn/dataset.py | 2 +- .../cnn/{localise.py => localise_net.py} | 115 --------------- platipy/imaging/cnn/train_localise.py | 138 ++++++++++++++++++ 3 files changed, 139 insertions(+), 116 deletions(-) rename platipy/imaging/cnn/{localise.py => localise_net.py} (69%) create mode 100644 platipy/imaging/cnn/train_localise.py diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index dd68879f..5e1b5653 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -12,7 +12,7 @@ from loguru import logger -from platipy.imaging.cnn.localise import LocaliseUNet +from platipy.imaging.cnn.localise_net import LocaliseUNet from platipy.imaging.label.utils import get_union_mask, get_intersection_mask from platipy.imaging.utils.crop import label_to_roi, crop_to_roi diff --git a/platipy/imaging/cnn/localise.py b/platipy/imaging/cnn/localise_net.py similarity index 69% rename from platipy/imaging/cnn/localise.py rename to platipy/imaging/cnn/localise_net.py index 323b2073..0435a227 100644 --- a/platipy/imaging/cnn/localise.py +++ b/platipy/imaging/cnn/localise_net.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import sys -import os import tempfile -import json from pathlib import Path import SimpleITK as sitk @@ -26,14 +23,10 @@ import torch import pytorch_lightning as pl -from pytorch_lightning.callbacks import ModelCheckpoint - -from argparse import ArgumentParser import matplotlib.pyplot as plt from platipy.imaging.cnn.unet import UNet -from platipy.imaging.cnn.dataload import UNetDataModule from platipy.imaging.cnn.dataset import preprocess_image from platipy.imaging import ImageVisualiser @@ -289,111 +282,3 @@ def validation_epoch_end(self, validation_step_outputs): prog_bar=False, logger=True, ) - - -def main(args, config_json_path=None): - - pl.seed_everything(args.seed, workers=True) - - args.working_dir = Path(args.working_dir) - args.working_dir = args.working_dir.joinpath(args.experiment) - args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") - args.default_root_dir = str(args.fold_dir) - - comet_api_key = None - comet_workspace = None - comet_project = None - - if args.comet_api_key: - comet_api_key = args.comet_api_key - comet_workspace = args.comet_workspace - comet_project = args.comet_project - - if comet_api_key is None: - if "COMET_API_KEY" in os.environ: - comet_api_key = os.environ["COMET_API_KEY"] - if "COMET_WORKSPACE" in os.environ: - comet_workspace = os.environ["COMET_WORKSPACE"] - if "COMET_PROJECT" in os.environ: - comet_project = os.environ["COMET_PROJECT"] - - if comet_api_key is not None: - comet_logger = CometLogger( - api_key=comet_api_key, - workspace=comet_workspace, - project_name=comet_project, - experiment_name=args.experiment, - save_dir=args.working_dir, - offline=args.offline, - ) - if config_json_path: - comet_logger.experiment.log_code(config_json_path) - - dict_args = vars(args) - - data_module = UNetDataModule(**dict_args) - - prob_unet = LocaliseUNet(**dict_args) - - if args.resume_from is not None: - trainer = pl.Trainer(resume_from_checkpoint=args.resume_from) - else: - trainer = 
pl.Trainer.from_argparse_args(args) - - if comet_api_key is not None: - trainer.logger = comet_logger - - # Save the best model - checkpoint_callback = ModelCheckpoint( - monitor="DSC", - dirpath=args.default_root_dir, - filename="localise-{epoch:02d}-{DSC:.2f}", - save_top_k=1, - mode="max", - ) - - trainer.callbacks.append(checkpoint_callback) - - trainer.fit(prob_unet, data_module) - - -if __name__ == "__main__": - - args = None - config_json_path = None - if len(sys.argv) == 2: - # Check if JSON file parsed, if so read arguments from there... - if sys.argv[-1].endswith(".json"): - config_json_path = sys.argv[-1] - with open(config_json_path, "r") as f: - params = json.load(f) - args = [] - for key in params: - args.append(f"--{key}") - - if isinstance(params[key], list): - for list_val in params[key]: - args.append(str(list_val)) - else: - args.append(str(params[key])) - - arg_parser = ArgumentParser() - arg_parser = LocaliseUNet.add_model_specific_args(arg_parser) - arg_parser = UNetDataModule.add_model_specific_args(arg_parser) - arg_parser = pl.Trainer.add_argparse_args(arg_parser) - arg_parser.add_argument( - "--config", type=str, default=None, help="JSON file with parameters to load" - ) - arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") - arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") - arg_parser.add_argument("--working_dir", type=str, default="./working") - arg_parser.add_argument("--num_observers", type=int, default=5) - arg_parser.add_argument("--spacing", nargs="+", type=int, default=[3, 3, 3]) - arg_parser.add_argument("--offline", type=bool, default=False) - arg_parser.add_argument("--comet_api_key", type=str, default=None) - arg_parser.add_argument("--comet_workspace", type=str, default=None) - arg_parser.add_argument("--comet_project", type=str, default=None) - arg_parser.add_argument("--resume_from", type=str, default=None) - 
arg_parser.add_argument("--combine_observers", type=str, default="union") - - main(arg_parser.parse_args(args), config_json_path=config_json_path) diff --git a/platipy/imaging/cnn/train_localise.py b/platipy/imaging/cnn/train_localise.py new file mode 100644 index 00000000..3c39a5a9 --- /dev/null +++ b/platipy/imaging/cnn/train_localise.py @@ -0,0 +1,138 @@ +# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import json + +from pathlib import Path + +import comet_ml # pylint: disable=unused-import +from pytorch_lightning.loggers import CometLogger + +import pytorch_lightning as pl +from pytorch_lightning.callbacks import ModelCheckpoint + +from argparse import ArgumentParser + +from platipy.imaging.cnn.localise_net import LocaliseUNet +from platipy.imaging.cnn.dataload import UNetDataModule + + +def main(args, config_json_path=None): + + pl.seed_everything(args.seed, workers=True) + + args.working_dir = Path(args.working_dir) + args.working_dir = args.working_dir.joinpath(args.experiment) + args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") + args.default_root_dir = str(args.fold_dir) + + comet_api_key = None + comet_workspace = None + comet_project = None + + if args.comet_api_key: + comet_api_key = args.comet_api_key + comet_workspace = args.comet_workspace + comet_project = args.comet_project + + if comet_api_key is None: + if "COMET_API_KEY" in 
os.environ: + comet_api_key = os.environ["COMET_API_KEY"] + if "COMET_WORKSPACE" in os.environ: + comet_workspace = os.environ["COMET_WORKSPACE"] + if "COMET_PROJECT" in os.environ: + comet_project = os.environ["COMET_PROJECT"] + + if comet_api_key is not None: + comet_logger = CometLogger( + api_key=comet_api_key, + workspace=comet_workspace, + project_name=comet_project, + experiment_name=args.experiment, + save_dir=args.working_dir, + offline=args.offline, + ) + if config_json_path: + comet_logger.experiment.log_code(config_json_path) + + dict_args = vars(args) + + data_module = UNetDataModule(**dict_args) + + prob_unet = LocaliseUNet(**dict_args) + + if args.resume_from is not None: + trainer = pl.Trainer(resume_from_checkpoint=args.resume_from) + else: + trainer = pl.Trainer.from_argparse_args(args) + + if comet_api_key is not None: + trainer.logger = comet_logger + + # Save the best model + checkpoint_callback = ModelCheckpoint( + monitor="DSC", + dirpath=args.default_root_dir, + filename="localise-{epoch:02d}-{DSC:.2f}", + save_top_k=1, + mode="max", + ) + + trainer.callbacks.append(checkpoint_callback) + + trainer.fit(prob_unet, data_module) + + +if __name__ == "__main__": + + args = None + config_json_path = None + if len(sys.argv) == 2: + # Check if JSON file parsed, if so read arguments from there... 
+ if sys.argv[-1].endswith(".json"): + config_json_path = sys.argv[-1] + with open(config_json_path, "r") as f: + params = json.load(f) + args = [] + for key in params: + args.append(f"--{key}") + + if isinstance(params[key], list): + for list_val in params[key]: + args.append(str(list_val)) + else: + args.append(str(params[key])) + + arg_parser = ArgumentParser() + arg_parser = LocaliseUNet.add_model_specific_args(arg_parser) + arg_parser = UNetDataModule.add_model_specific_args(arg_parser) + arg_parser = pl.Trainer.add_argparse_args(arg_parser) + arg_parser.add_argument( + "--config", type=str, default=None, help="JSON file with parameters to load" + ) + arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") + arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") + arg_parser.add_argument("--working_dir", type=str, default="./working") + arg_parser.add_argument("--num_observers", type=int, default=5) + arg_parser.add_argument("--spacing", nargs="+", type=int, default=[3, 3, 3]) + arg_parser.add_argument("--offline", type=bool, default=False) + arg_parser.add_argument("--comet_api_key", type=str, default=None) + arg_parser.add_argument("--comet_workspace", type=str, default=None) + arg_parser.add_argument("--comet_project", type=str, default=None) + arg_parser.add_argument("--resume_from", type=str, default=None) + arg_parser.add_argument("--combine_observers", type=str, default="union") + + main(arg_parser.parse_args(args), config_json_path=config_json_path) From 7e2ae2ab714b85c3188bfd01eb1cbebca9d076ff Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 15 Aug 2021 07:52:01 +0000 Subject: [PATCH 104/264] Separate out utils --- platipy/imaging/cnn/dataload.py | 2 + platipy/imaging/cnn/dataset.py | 85 +---------------- platipy/imaging/cnn/localise_net.py | 46 +--------- platipy/imaging/cnn/train.py | 46 +--------- platipy/imaging/cnn/train_localise.py | 3 +- 
platipy/imaging/cnn/utils.py | 126 ++++++++++++++++++++++++++ 6 files changed, 137 insertions(+), 171 deletions(-) create mode 100644 platipy/imaging/cnn/utils.py diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index acf2d417..25e0dfc0 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -2,6 +2,8 @@ import math from pathlib import Path +from loguru import logger + import torch import pytorch_lightning as pl diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 5e1b5653..0f73f16a 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -13,90 +13,9 @@ from loguru import logger from platipy.imaging.cnn.localise_net import LocaliseUNet -from platipy.imaging.label.utils import get_union_mask, get_intersection_mask +from platipy.imaging.cnn.utils import preprocess_image, resample_mask_to_image, get_contour_mask from platipy.imaging.utils.crop import label_to_roi, crop_to_roi - - -def get_contour_mask(masks, kernel=5): - """Returns a mask around the region where observer masks don't agree - - Args: - masks (list): List of observer masks (as sitk.Image) - kernel (int, optional): The size of the kernal to dilate the contour of. Defaults to 5. 
- - Returns: - sitk.Image: The resulting contour mask - """ - - if not hasattr(kernel, "__iter__"): - kernel = (kernel,) * 3 - - union_mask = get_union_mask(masks) - intersection_mask = get_intersection_mask(masks) - - union_mask = sitk.BinaryDilate(union_mask, np.abs(kernel).astype(int).tolist(), sitk.sitkBall) - intersection_mask = sitk.BinaryErode( - intersection_mask, np.abs(kernel).astype(int).tolist(), sitk.sitkBall - ) - - return union_mask - intersection_mask - - -def preprocess_image(img, spacing=[1, 1, 1], crop_to_mm=128): - - img = sitk.Cast(img, sitk.sitkFloat32) - img = sitk.IntensityWindowing( - img, windowMinimum=-500.0, windowMaximum=500.0, outputMinimum=-1.0, outputMaximum=1.0 - ) - - new_size = sitk.VectorUInt32(3) - new_size[0] = int(img.GetSize()[0] * (img.GetSpacing()[0] / spacing[0])) - new_size[1] = int(img.GetSize()[1] * (img.GetSpacing()[1] / spacing[1])) - new_size[2] = int(img.GetSize()[2] * (img.GetSpacing()[2] / spacing[2])) - - if crop_to_mm: - if new_size[0] < crop_to_mm: - new_size[0] = crop_to_mm - - if new_size[1] < crop_to_mm: - new_size[1] = crop_to_mm - - img = sitk.Resample( - img, - new_size, - sitk.Transform(), - sitk.sitkLinear, - img.GetOrigin(), - spacing, - img.GetDirection(), - -1, - img.GetPixelID(), - ) - - if crop_to_mm: - center_x = img.GetSize()[0] / 2 - x_from = int(center_x - crop_to_mm / 2) - x_to = x_from + crop_to_mm - - center_y = img.GetSize()[1] / 2 - y_from = int(center_y - crop_to_mm / 2) - y_to = y_from + crop_to_mm - - img = img[x_from:x_to, y_from:y_to, :] - - return img - - -def resample_mask_to_image(img, mask): - - return sitk.Resample( - mask, - img, - sitk.Transform(), - sitk.sitkNearestNeighbor, - 0, - mask.GetPixelID(), - ) +from platipy.imaging.label.utils import get_union_mask, get_intersection_mask def prepare_transforms(): diff --git a/platipy/imaging/cnn/localise_net.py b/platipy/imaging/cnn/localise_net.py index 0435a227..a7414a12 100644 --- a/platipy/imaging/cnn/localise_net.py +++ 
b/platipy/imaging/cnn/localise_net.py @@ -27,52 +27,12 @@ import matplotlib.pyplot as plt from platipy.imaging.cnn.unet import UNet -from platipy.imaging.cnn.dataset import preprocess_image +from platipy.imaging.cnn.utils import preprocess_image, postprocess_mask, get_metrics from platipy.imaging import ImageVisualiser from platipy.imaging.label.utils import get_com -def post_process(pred): - - # Take only the largest componenet - labelled_image = sitk.ConnectedComponent(pred) - label_shape_filter = sitk.LabelShapeStatisticsImageFilter() - label_shape_filter.Execute(labelled_image) - label_indices = label_shape_filter.GetLabels() - voxel_counts = [label_shape_filter.GetNumberOfPixels(i) for i in label_indices] - if len(voxel_counts) > 0: - largest_component_label = label_indices[np.argmax(voxel_counts)] - largest_component_image = labelled_image == largest_component_label - pred = sitk.Cast(largest_component_image, sitk.sitkUInt8) - - # Fill any holes in the structure - pred = sitk.BinaryMorphologicalClosing(pred, (5, 5, 5)) - pred = sitk.BinaryFillhole(pred) - - return pred - - -def get_metrics(target, pred): - - result = {} - lomif = sitk.LabelOverlapMeasuresImageFilter() - lomif.Execute(target, pred) - result["JI"] = lomif.GetJaccardCoefficient() - result["DSC"] = lomif.GetDiceCoefficient() - - if sitk.GetArrayFromImage(pred).sum() == 0: - result["HD"] = 1000 - result["ASD"] = 100 - else: - hdif = sitk.HausdorffDistanceImageFilter() - hdif.Execute(target, pred) - result["HD"] = hdif.GetHausdorffDistance() - result["ASD"] = hdif.GetAverageHausdorffDistance() - - return result - - class LocaliseUNet(pl.LightningModule): def __init__( self, @@ -131,7 +91,7 @@ def infer(self, img): pred = sitk.Cast(pred, sitk.sitkUInt8) pred.CopyInformation(pp_img) - pred = post_process(pred) + pred = postprocess_mask(pred) pred = sitk.Resample(pred, img, sitk.Transform(), sitk.sitkNearestNeighbor) return pred @@ -222,7 +182,7 @@ def validation_epoch_end(self, 
validation_step_outputs): pred_arr = np.stack(pred_arrs) pred = sitk.GetImageFromArray(pred_arr) pred = sitk.Cast(pred, sitk.sitkUInt8) - pred = post_process(pred) + pred = postprocess_mask(pred) pred.CopyInformation(img) sitk.WriteImage(pred, f"val_pred_{case}.nii.gz") diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index e5241fa9..5ac0ed97 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -36,51 +36,11 @@ from platipy.imaging.cnn.prob_unet import ProbabilisticUnet from platipy.imaging.cnn.unet import l2_regularisation from platipy.imaging.cnn.dataload import UNetDataModule +from platipy.imaging.cnn.utils import postprocess_mask, get_metrics from platipy.imaging import ImageVisualiser from platipy.imaging.label.utils import get_com - -def post_process(pred): - - # Take only the largest componenet - labelled_image = sitk.ConnectedComponent(pred) - label_shape_filter = sitk.LabelShapeStatisticsImageFilter() - label_shape_filter.Execute(labelled_image) - label_indices = label_shape_filter.GetLabels() - voxel_counts = [label_shape_filter.GetNumberOfPixels(i) for i in label_indices] - if len(voxel_counts) > 0: - largest_component_label = label_indices[np.argmax(voxel_counts)] - largest_component_image = labelled_image == largest_component_label - pred = sitk.Cast(largest_component_image, sitk.sitkUInt8) - - # Fill any holes in the structure - pred = sitk.BinaryMorphologicalClosing(pred, (5, 5, 5)) - pred = sitk.BinaryFillhole(pred) - - return pred - - -def get_metrics(target, pred): - - result = {} - lomif = sitk.LabelOverlapMeasuresImageFilter() - lomif.Execute(target, pred) - result["JI"] = lomif.GetJaccardCoefficient() - result["DSC"] = lomif.GetDiceCoefficient() - - if sitk.GetArrayFromImage(pred).sum() == 0: - result["HD"] = 1000 - result["ASD"] = 100 - else: - hdif = sitk.HausdorffDistanceImageFilter() - hdif.Execute(target, pred) - result["HD"] = hdif.GetHausdorffDistance() - result["ASD"] = 
hdif.GetAverageHausdorffDistance() - - return result - - class ProbUNet(pl.LightningModule): def __init__( self, @@ -293,7 +253,7 @@ def validation_epoch_end(self, validation_step_outputs): mean = sitk.GetImageFromArray(mean_arr) mean = sitk.Cast(mean, sitk.sitkUInt8) - mean = post_process(mean) + mean = postprocess_mask(mean) mean.CopyInformation(img) # sitk.WriteImage(mean, f"val_mean_{case}_mean.nii.gz") @@ -341,7 +301,7 @@ def validation_epoch_end(self, validation_step_outputs): sample = sitk.GetImageFromArray(sample_arr) sample = sitk.Cast(sample, sitk.sitkUInt8) - sample = post_process(sample) + sample = postprocess_mask(sample) sample.CopyInformation(img) sitk.WriteImage(sample, f"val_sample_{case}_{observer}.nii.gz") samples.append(sample) diff --git a/platipy/imaging/cnn/train_localise.py b/platipy/imaging/cnn/train_localise.py index 3c39a5a9..49c64ecb 100644 --- a/platipy/imaging/cnn/train_localise.py +++ b/platipy/imaging/cnn/train_localise.py @@ -17,6 +17,7 @@ import json from pathlib import Path +from argparse import ArgumentParser import comet_ml # pylint: disable=unused-import from pytorch_lightning.loggers import CometLogger @@ -24,8 +25,6 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import ModelCheckpoint -from argparse import ArgumentParser - from platipy.imaging.cnn.localise_net import LocaliseUNet from platipy.imaging.cnn.dataload import UNetDataModule diff --git a/platipy/imaging/cnn/utils.py b/platipy/imaging/cnn/utils.py new file mode 100644 index 00000000..295ddcc0 --- /dev/null +++ b/platipy/imaging/cnn/utils.py @@ -0,0 +1,126 @@ +import numpy as np +import SimpleITK as sitk + +from platipy.imaging.label.utils import get_union_mask, get_intersection_mask + + +def get_contour_mask(masks, kernel=5): + """Returns a mask around the region where observer masks don't agree + + Args: + masks (list): List of observer masks (as sitk.Image) + kernel (int, optional): The size of the kernal to dilate the contour of. Defaults to 5. 
+ + Returns: + sitk.Image: The resulting contour mask + """ + + if not hasattr(kernel, "__iter__"): + kernel = (kernel,) * 3 + + union_mask = get_union_mask(masks) + intersection_mask = get_intersection_mask(masks) + + union_mask = sitk.BinaryDilate(union_mask, np.abs(kernel).astype(int).tolist(), sitk.sitkBall) + intersection_mask = sitk.BinaryErode( + intersection_mask, np.abs(kernel).astype(int).tolist(), sitk.sitkBall + ) + + return union_mask - intersection_mask + + +def preprocess_image(img, spacing=[1, 1, 1], crop_to_mm=128): + + img = sitk.Cast(img, sitk.sitkFloat32) + img = sitk.IntensityWindowing( + img, windowMinimum=-500.0, windowMaximum=500.0, outputMinimum=-1.0, outputMaximum=1.0 + ) + + new_size = sitk.VectorUInt32(3) + new_size[0] = int(img.GetSize()[0] * (img.GetSpacing()[0] / spacing[0])) + new_size[1] = int(img.GetSize()[1] * (img.GetSpacing()[1] / spacing[1])) + new_size[2] = int(img.GetSize()[2] * (img.GetSpacing()[2] / spacing[2])) + + if crop_to_mm: + if new_size[0] < crop_to_mm: + new_size[0] = crop_to_mm + + if new_size[1] < crop_to_mm: + new_size[1] = crop_to_mm + + img = sitk.Resample( + img, + new_size, + sitk.Transform(), + sitk.sitkLinear, + img.GetOrigin(), + spacing, + img.GetDirection(), + -1, + img.GetPixelID(), + ) + + if crop_to_mm: + center_x = img.GetSize()[0] / 2 + x_from = int(center_x - crop_to_mm / 2) + x_to = x_from + crop_to_mm + + center_y = img.GetSize()[1] / 2 + y_from = int(center_y - crop_to_mm / 2) + y_to = y_from + crop_to_mm + + img = img[x_from:x_to, y_from:y_to, :] + + return img + + +def resample_mask_to_image(img, mask): + + return sitk.Resample( + mask, + img, + sitk.Transform(), + sitk.sitkNearestNeighbor, + 0, + mask.GetPixelID(), + ) + + +def postprocess_mask(pred): + + # Take only the largest componenet + labelled_image = sitk.ConnectedComponent(pred) + label_shape_filter = sitk.LabelShapeStatisticsImageFilter() + label_shape_filter.Execute(labelled_image) + label_indices = label_shape_filter.GetLabels() 
+ voxel_counts = [label_shape_filter.GetNumberOfPixels(i) for i in label_indices] + if len(voxel_counts) > 0: + largest_component_label = label_indices[np.argmax(voxel_counts)] + largest_component_image = labelled_image == largest_component_label + pred = sitk.Cast(largest_component_image, sitk.sitkUInt8) + + # Fill any holes in the structure + pred = sitk.BinaryMorphologicalClosing(pred, (5, 5, 5)) + pred = sitk.BinaryFillhole(pred) + + return pred + + +def get_metrics(target, pred): + + result = {} + lomif = sitk.LabelOverlapMeasuresImageFilter() + lomif.Execute(target, pred) + result["JI"] = lomif.GetJaccardCoefficient() + result["DSC"] = lomif.GetDiceCoefficient() + + if sitk.GetArrayFromImage(pred).sum() == 0: + result["HD"] = 1000 + result["ASD"] = 100 + else: + hdif = sitk.HausdorffDistanceImageFilter() + hdif.Execute(target, pred) + result["HD"] = hdif.GetHausdorffDistance() + result["ASD"] = hdif.GetAverageHausdorffDistance() + + return result From e18291d312b5cb7b79e6144ad87b5f7485dfc298 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 16 Aug 2021 22:19:04 +0000 Subject: [PATCH 105/264] Move out some util functions --- platipy/imaging/cnn/hierarchical_prob_unet.py | 53 +------------------ 1 file changed, 1 insertion(+), 52 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 31d248b9..40b7b186 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -20,58 +20,7 @@ import torch - -def truncated_normal_(tensor, mean=0, std=1): - size = tensor.shape - tmp = tensor.new_empty(size + (4,)).normal_() - valid = (tmp < 2) & (tmp > -2) - ind = valid.max(-1, keepdim=True)[1] - tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) - tensor.data.mul_(std).add_(mean) - - -def init_weights(m): - if ( - isinstance(m, torch.nn.Conv2d) - or isinstance(m, torch.nn.ConvTranspose2d) - or isinstance(m, torch.nn.Conv3d) - or isinstance(m, 
torch.nn.ConvTranspose3d) - ): - torch.nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu") - truncated_normal_(m.bias, mean=0, std=0.001) - - -def init_zeros(m): - if ( - isinstance(m, torch.nn.Conv2d) - or isinstance(m, torch.nn.ConvTranspose2d) - or isinstance(m, torch.nn.Conv3d) - or isinstance(m, torch.nn.ConvTranspose3d) - ): - torch.nn.init.zeros_(m.weight) - truncated_normal_(m.bias, mean=0, std=0.1) - - -def conv_nd(ndims=2, **kwargs): - """Generate a 2D or 3D convolution - - Args: - ndims (int, optional): 2 or 3 dimensions. Defaults to 2. - - Raises: - NotImplementedError: Raised if ndims is not in 2 or 3 dimensions. - - Returns: - torch.nn.Conv: The convolution. - """ - - if ndims == 2: - return torch.nn.Conv2d(**kwargs) - elif ndims == 3: - return torch.nn.Conv3d(**kwargs) - - raise NotImplementedError("Only 2 or 3 dimensions are supported") - +from .unet import init_weights, init_zeros, conv_nd class ResBlock(torch.nn.Module): """A residual block""" From f96abb4f315fba8a51bf44ec57e81ed53efe8a8f Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 16 Aug 2021 22:21:50 +0000 Subject: [PATCH 106/264] Crop to roi for synth DVF generation --- platipy/imaging/generation/dvf.py | 41 ++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/platipy/imaging/generation/dvf.py b/platipy/imaging/generation/dvf.py index 142405a6..63c9cca5 100644 --- a/platipy/imaging/generation/dvf.py +++ b/platipy/imaging/generation/dvf.py @@ -25,13 +25,15 @@ fast_symmetric_forces_demons_registration, ) +from platipy.imaging.utils.crop import label_to_roi, crop_to_roi -def generate_field_shift(mask_image, vector_shift=(10, 10, 10), gaussian_smooth=5): + +def generate_field_shift(mask, vector_shift=(10, 10, 10), gaussian_smooth=5): """ Shifts (moves) a structure defined using a binary mask. Args: - mask_image ([SimpleITK.Image]): The binary mask to shift. + mask ([SimpleITK.Image]): The binary mask to shift. 
vector_shift (tuple, optional): The displacement vector applied to the entire binary mask. Convention: (+/-, +/-, +/-) = (sup/inf, post/ant, left/right) shift. @@ -45,9 +47,15 @@ def generate_field_shift(mask_image, vector_shift=(10, 10, 10), gaussian_smooth= [SimpleITK.DisplacementFieldTransform]: The transform representing the shift. [SimpleITK.Image]: The displacement vector field representing the shift. """ + + mask_full = mask + + size, index = label_to_roi(mask, expansion_mm=[x + 5 for x in vector_shift]) + mask = crop_to_roi(mask, size, index) + # Define array # Used for image array manipulations - mask_image_arr = sitk.GetArrayFromImage(mask_image) + mask_image_arr = sitk.GetArrayFromImage(mask) # The template deformation field # Used to generate transforms @@ -56,14 +64,14 @@ def generate_field_shift(mask_image, vector_shift=(10, 10, 10), gaussian_smooth= dvf_template = sitk.GetImageFromArray(dvf_arr) # Copy image information - dvf_template.CopyInformation(mask_image) + dvf_template.CopyInformation(mask) dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64)) - mask_image_shift = apply_transform( - mask_image, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor + mask_shift = apply_transform( + mask, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor ) - dvf_template = sitk.Mask(dvf_template, mask_image | mask_image_shift) + dvf_template = sitk.Mask(dvf_template, mask | mask_shift) # smooth if np.any(gaussian_smooth): @@ -73,12 +81,15 @@ def generate_field_shift(mask_image, vector_shift=(10, 10, 10), gaussian_smooth= dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth) + # Resample back to original image + dvf_template = sitk.Resample(dvf_template, mask_full) dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64)) - mask_image_shift = apply_transform( - mask_image, transform=dvf_tfm, default_value=0, 
interpolator=sitk.sitkNearestNeighbor + + mask_shift = apply_transform( + mask_full, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor ) - return mask_image_shift, dvf_tfm, dvf_template + return mask_shift, dvf_tfm, dvf_template def generate_field_asymmetric_contract( @@ -229,7 +240,13 @@ def generate_field_expand( [SimpleITK.Image]: The displacement vector field representing the expansion. """ + mask_full = mask + + size, index = label_to_roi(mask, expansion_mm=[expand + gaussian_smooth] * 3) + mask = crop_to_roi(mask, size, index) + if bone_mask is not False: + bone_mask = sitk.Resample(bone_mask, mask, sitk.Transform(), sitk.sitkNearestNeighbor) mask_original = mask + bone_mask else: mask_original = mask @@ -299,10 +316,12 @@ def generate_field_expand( dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth) + # Resample back to original image + dvf_template = sitk.Resample(dvf_template, mask_full) dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64)) mask_symmetric_expand = apply_transform( - mask, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor + mask_full, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor ) return mask_symmetric_expand, dvf_tfm, dvf_template From 8d50c50b3be2d3b829a4f3ecc81ffc253787c6f3 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 16 Aug 2021 22:49:21 +0000 Subject: [PATCH 107/264] handle int or list for expand properly --- platipy/imaging/generation/dvf.py | 57 ++++++++++++++++++------------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/platipy/imaging/generation/dvf.py b/platipy/imaging/generation/dvf.py index 63c9cca5..ca944a5d 100644 --- a/platipy/imaging/generation/dvf.py +++ b/platipy/imaging/generation/dvf.py @@ -50,7 +50,16 @@ def generate_field_shift(mask, vector_shift=(10, 10, 10), gaussian_smooth=5): mask_full = mask - size, index = label_to_roi(mask, expansion_mm=[x + 5 
for x in vector_shift]) + roi_expand = [x + 5 for x in vector_shift] + + if np.any(gaussian_smooth): + + if not hasattr(gaussian_smooth, "__iter__"): + gaussian_smooth = (gaussian_smooth,) * 3 + + roi_expand = [x + y for x, y in zip(roi_expand, gaussian_smooth)] + + size, index = label_to_roi(mask, expansion_mm=roi_expand) mask = crop_to_roi(mask, size, index) # Define array @@ -75,10 +84,6 @@ def generate_field_shift(mask, vector_shift=(10, 10, 10), gaussian_smooth=5): # smooth if np.any(gaussian_smooth): - - if not hasattr(gaussian_smooth, "__iter__"): - gaussian_smooth = (gaussian_smooth,) * 3 - dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth) # Resample back to original image @@ -223,25 +228,36 @@ def generate_field_expand( dilation kernel. Args: - mask ([SimpleITK.Image]): The binary mask to expand. - bone_mask ([SimpleITK.Image, optional]): A binary mask defining regions where we expect - restricted deformations. - vector_asymmetric_extend (int |tuple, optional): The expansion vector applied to the entire - binary mask. - Convention: (z,y,x) size of expansion kernel. - Defined in millimetres. - Defaults to 3. + mask (SimpleITK.Image): The binary mask to expand. + bone_mask (SimpleITK.Image, optional): A binary mask defining regions where we expect + restricted deformations. + expand (int |tuple, optional): The expansion vector applied to the entire binary mask. + Convention: (z,y,x) size of expansion kernel. + Defined in millimetres. + Defaults to 3. gaussian_smooth (int | list, optional): Scale of a Gaussian kernel used to smooth the - deformation vector field. Defaults to 5. + deformation vector field. Defaults to 5. Returns: - [SimpleITK.Image]: The binary mask following the expansion. - [SimpleITK.DisplacementFieldTransform]: The transform representing the expansion. - [SimpleITK.Image]: The displacement vector field representing the expansion. + SimpleITK.Image: The binary mask following the expansion. 
+ SimpleITK.DisplacementFieldTransform: The transform representing the expansion. + SimpleITK.Image: The displacement vector field representing the expansion. """ mask_full = mask + if not hasattr(expand, "__iter__"): + expand = (expand,) * 3 + + roi_expand = expand + + if np.any(gaussian_smooth): + + if not hasattr(gaussian_smooth, "__iter__"): + gaussian_smooth = (gaussian_smooth,) * 3 + + roi_expand = [x + y for x, y in zip(roi_expand, gaussian_smooth)] + size, index = label_to_roi(mask, expansion_mm=[expand + gaussian_smooth] * 3) mask = crop_to_roi(mask, size, index) @@ -252,9 +268,6 @@ def generate_field_expand( mask_original = mask # Use binary erosion to create a smaller volume - if not hasattr(expand, "__iter__"): - expand = (expand,) * 3 - expand = np.array(expand) # Convert voxels to millimetres @@ -310,10 +323,6 @@ def generate_field_expand( # smooth if np.any(gaussian_smooth): - - if not hasattr(gaussian_smooth, "__iter__"): - gaussian_smooth = (gaussian_smooth,) * 3 - dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth) # Resample back to original image From 97b9cf89b1ecbe73b7d0d6d073bf2f5769da56b9 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 16 Aug 2021 22:54:55 +0000 Subject: [PATCH 108/264] Use roi expand --- platipy/imaging/generation/dvf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/generation/dvf.py b/platipy/imaging/generation/dvf.py index ca944a5d..494bf6a4 100644 --- a/platipy/imaging/generation/dvf.py +++ b/platipy/imaging/generation/dvf.py @@ -258,7 +258,7 @@ def generate_field_expand( roi_expand = [x + y for x, y in zip(roi_expand, gaussian_smooth)] - size, index = label_to_roi(mask, expansion_mm=[expand + gaussian_smooth] * 3) + size, index = label_to_roi(mask, expansion_mm=roi_expand) mask = crop_to_roi(mask, size, index) if bone_mask is not False: From 2acff8e56da2236379d1c6f848cb17c2eb8231e4 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 18 Aug 2021 
08:02:23 +1000 Subject: [PATCH 109/264] Run localise model as first step of prob unet --- platipy/imaging/cnn/dataload.py | 2 +- platipy/imaging/cnn/dataset.py | 12 +++++++++--- platipy/imaging/cnn/train.py | 2 +- platipy/imaging/generation/augment.py | 2 ++ 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index 25e0dfc0..0413e625 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -205,7 +205,7 @@ def setup(self, stage=None): localise_model_path = None if self.crop_using_localise_model: localise_model_path = Path(self.crop_using_localise_model.format(fold=self.fold)) - if localise_model_path.name.isdir(): + if localise_model_path.is_dir(): localise_model_path = next(localise_model_path.glob("*.ckpt")) logger.info(f"Using localise model: {localise_model_path}") diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 0f73f16a..f1da53f5 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -141,18 +141,24 @@ def __init__( logger.debug(f"Generating images for case: {case_id}") img = sitk.ReadImage(img_path) - img = preprocess_image(img, spacing=spacing, crop_to_mm=crop_to_mm) - if crop_using_localise_model: localise_model = LocaliseUNet.load_from_checkpoint(crop_using_localise_model) localise_model.eval() localise_pred = localise_model.infer(img) + print(localise_pred.GetSize()) + img = preprocess_image(img, spacing=spacing, crop_to_mm=crop_to_mm) localise_pred = resample_mask_to_image(img, localise_pred) size, index = label_to_roi(localise_pred) - index = [i - int((100 - s) / 2) for i, s in zip(index, size)] + print(size) + print(index) + index = [i - int((g - s) / 2) for i, s, g in zip(index, size, localise_voxel_grid_size)] size = localise_voxel_grid_size + print(size) + print(index) img = crop_to_roi(img, size, index) + else: + img = preprocess_image(img, spacing=spacing, crop_to_mm=crop_to_mm) 
observers = [] for obs, structure_path in enumerate(structure_paths): diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 5ac0ed97..11b11a1c 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -461,7 +461,7 @@ def main(args, config_json_path=None): arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") arg_parser.add_argument("--working_dir", type=str, default="./working") arg_parser.add_argument("--num_observers", type=int, default=5) - arg_parser.add_argument("--spacing", nargs="+", type=int, default=[1, 1, 1]) + arg_parser.add_argument("--spacing", nargs="+", type=float, default=[1, 1, 1]) arg_parser.add_argument("--offline", type=bool, default=False) arg_parser.add_argument("--comet_api_key", type=str, default=None) arg_parser.add_argument("--comet_workspace", type=str, default=None) diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index c3eb5cd7..ba414edc 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -25,6 +25,7 @@ from loguru import logger +import matplotlib.pyplot as plt from platipy.imaging import ImageVisualiser from platipy.imaging.generation.dvf import ( @@ -352,6 +353,7 @@ def augment_data(args): figure_path = augmented_case_path.joinpath("aug.png") fig.savefig(figure_path, bbox_inches="tight") + plt.close() if __name__ == "__main__": From 708e7257797b0bb08d89b3fe58bca6d1e10b1460 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 18 Aug 2021 04:34:20 +0000 Subject: [PATCH 110/264] More control over image preprocessing --- platipy/imaging/cnn/dataload.py | 40 +++++++++++------- platipy/imaging/cnn/dataset.py | 33 ++++++++------- platipy/imaging/cnn/utils.py | 75 ++++++++++++++++++++++++++------- 3 files changed, 102 insertions(+), 46 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index 0413e625..6338bca6 100644 --- 
a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -31,7 +31,9 @@ def __init__( k_folds=5, batch_size=5, num_workers=4, - crop_to_mm=128, + crop_to_grid_size_xy=128, + intensity_scaling="window", + intensity_window=[-500, 500], num_observers=5, spacing=[1, 1, 1], contour_mask_kernel=3, @@ -62,9 +64,11 @@ def __init__( self.batch_size = batch_size self.num_workers = num_workers - self.crop_to_mm = crop_to_mm + self.crop_to_grid_size_xy = crop_to_grid_size_xy self.num_observers = num_observers self.spacing = spacing + self.intensity_scaling = intensity_scaling + self.intensity_window = intensity_window self.contour_mask_kernel = contour_mask_kernel self.crop_using_localise_model = crop_using_localise_model @@ -84,7 +88,7 @@ def add_model_specific_args(parent_parser): parser = parent_parser.add_argument_group("Data Loader") parser.add_argument("--data_dir", type=str, default="./data") parser.add_argument("--augmented_dir", type=str, default=None) - parser.add_argument("--augment_onfly", type=bool, default=True) + parser.add_argument("--augment_on_fly", type=bool, default=True) parser.add_argument("--fold", type=int, default=0) parser.add_argument("--k_folds", type=int, default=5) parser.add_argument("--batch_size", type=int, default=5) @@ -95,7 +99,9 @@ def add_model_specific_args(parent_parser): parser.add_argument("--augmented_case_glob", type=str, default=None) parser.add_argument("--augmented_image_glob", type=str, default=None) parser.add_argument("--augmented_label_glob", type=str, default=None) - parser.add_argument("--crop_to_mm", type=int, default=128) + parser.add_argument("--crop_to_grid_size_xy", type=int, default=128) + parser.add_argument("--intensity_scaling", type=str, default="window") + parser.add_argument("--intensity_window", nargs="+", type=int, default=[-500, 500]) parser.add_argument("--contour_mask_kernel", type=int, default=5) parser.add_argument("--crop_using_localise_model", type=str, default=None) 
parser.add_argument( @@ -201,7 +207,7 @@ def setup(self, stage=None): for case in self.test_cases ] - crop_to_mm = None + crop_to_grid_size = None localise_model_path = None if self.crop_using_localise_model: localise_model_path = Path(self.crop_using_localise_model.format(fold=self.fold)) @@ -209,40 +215,44 @@ def setup(self, stage=None): localise_model_path = next(localise_model_path.glob("*.ckpt")) logger.info(f"Using localise model: {localise_model_path}") + crop_to_grid_size = self.localise_voxel_grid_size else: - crop_to_mm = self.crop_to_mm + crop_to_grid_size = self.crop_to_grid_size_xy self.training_set = NiftiDataset( train_data, self.working_dir, - augment_on_the_fly=self.augment_on_fly, + augment_on_fly=self.augment_on_fly, spacing=self.spacing, - crop_to_mm=crop_to_mm, + crop_to_grid_size=crop_to_grid_size, crop_using_localise_model=localise_model_path, - localise_voxel_grid_size=self.localise_voxel_grid_size, contour_mask_kernel=self.contour_mask_kernel, + intensity_scaling=self.intensity_scaling, + intensity_window=self.intensity_window, ndims=self.ndims, ) self.validation_set = NiftiDataset( validation_data, self.working_dir, - augment_on_the_fly=False, + augment_on_fly=False, spacing=self.spacing, - crop_to_mm=crop_to_mm, + crop_to_grid_size=crop_to_grid_size, crop_using_localise_model=localise_model_path, - localise_voxel_grid_size=self.localise_voxel_grid_size, contour_mask_kernel=self.contour_mask_kernel, + intensity_scaling=self.intensity_scaling, + intensity_window=self.intensity_window, ndims=self.ndims, ) self.test_set = NiftiDataset( test_data, self.working_dir, - augment_on_the_fly=False, + augment_on_fly=False, spacing=self.spacing, - crop_to_mm=crop_to_mm, + crop_to_grid_size=crop_to_grid_size, crop_using_localise_model=localise_model_path, - localise_voxel_grid_size=self.localise_voxel_grid_size, contour_mask_kernel=self.contour_mask_kernel, + intensity_scaling=self.intensity_scaling, + intensity_window=self.intensity_window, 
ndims=self.ndims, ) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index f1da53f5..5b9b016c 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -61,13 +61,14 @@ def __init__( self, data, working_dir, - augment_on_the_fly=True, + augment_on_fly=True, spacing=[1, 1, 1], - crop_to_mm=None, crop_using_localise_model=None, - localise_voxel_grid_size=[100, 100, 100], + crop_to_grid_size=128, contour_mask_kernel=5, combine_observers=None, + intensity_scaling="window", + intensity_window=[-500, 500], ndims=2, ): """Prepare a dataset from Nifti images/labels @@ -79,12 +80,9 @@ def __init__( working_dir (str|path): Working directory where to write prepared files. """ - if crop_to_mm is not None and crop_using_localise_model is not None: - raise AttributeError("Only one of crop_to_mm or crop_using_localise_model may be set") - self.data = data self.transforms = None - if augment_on_the_fly: + if augment_on_fly: self.transforms = prepare_transforms() self.slices = [] self.working_dir = Path(working_dir) @@ -145,20 +143,23 @@ def __init__( localise_model = LocaliseUNet.load_from_checkpoint(crop_using_localise_model) localise_model.eval() localise_pred = localise_model.infer(img) - print(localise_pred.GetSize()) - img = preprocess_image(img, spacing=spacing, crop_to_mm=crop_to_mm) + + img = preprocess_image(img, spacing=spacing, crop_to_grid_size_xy=None) localise_pred = resample_mask_to_image(img, localise_pred) size, index = label_to_roi(localise_pred) - print(size) - print(index) - index = [i - int((g - s) / 2) for i, s, g in zip(index, size, localise_voxel_grid_size)] - size = localise_voxel_grid_size - print(size) - print(index) + + if not hasattr(crop_to_grid_size, "__iter__"): + crop_to_grid_size = (crop_to_grid_size,) * 3 + + index = [i - int((g - s) / 2) for i, s, g in zip(index, size, crop_to_grid_size)] + size = crop_to_grid_size + img = crop_to_roi(img, size, index) else: - img = preprocess_image(img, 
spacing=spacing, crop_to_mm=crop_to_mm) + img = preprocess_image( + img, spacing=spacing, crop_to_grid_size_xy=crop_to_grid_size, intensity_scaling=intensity_scaling, intensity_window=intensity_window + ) observers = [] for obs, structure_path in enumerate(structure_paths): diff --git a/platipy/imaging/cnn/utils.py b/platipy/imaging/cnn/utils.py index 295ddcc0..ba7c0a04 100644 --- a/platipy/imaging/cnn/utils.py +++ b/platipy/imaging/cnn/utils.py @@ -29,24 +29,60 @@ def get_contour_mask(masks, kernel=5): return union_mask - intersection_mask -def preprocess_image(img, spacing=[1, 1, 1], crop_to_mm=128): +def preprocess_image( + img, + spacing=[1, 1, 1], + crop_to_grid_size_xy=128, + intensity_scaling="window", + intensity_window=[-500, 500], +): + """Preprocess an image to prepare it for use in a CNN. - img = sitk.Cast(img, sitk.sitkFloat32) - img = sitk.IntensityWindowing( - img, windowMinimum=-500.0, windowMaximum=500.0, outputMinimum=-1.0, outputMaximum=1.0 - ) + Args: + img (sitk.Image): [description] + spacing (list, optional): [description]. Defaults to [1, 1, 1]. + crop_to_grid_size_xy (int|list, optional): Crop to the center grid of this size in x and y + direction. May be int value which will be use for both x and y size. Or a list containing + two int values for x and y. Defaults to 128. + intensity_scaling (str, optional): How to scale the intensity values. Should be one of + 'norm' (center mean and unit variance), 'window' (map window [min max] to [-1 1]), 'none' + (no intensity scaling applied). Defaults to "window". + intensity_window (list, optional): List with min and max values to be used when + intensity_scaling is 'window'. Not used otherwise. Defaults to [-500, 500]. + + Returns: + sitk.Image: The preprocessed image. 
+ """ + + if intensity_scaling == "norm": + img = sitk.Normalize(img) + elif intensity_scaling == "window": + img = sitk.Cast(img, sitk.sitkFloat32) + img = sitk.IntensityWindowing( + img, + windowMinimum=intensity_window[0], + windowMaximum=intensity_window[1], + outputMinimum=-1.0, + outputMaximum=1.0, + ) + elif intensity_scaling != "none" and intensity_scaling is not None: + raise ValueError("intensity_scaling should be one of: 'norm', 'window', 'none'") new_size = sitk.VectorUInt32(3) new_size[0] = int(img.GetSize()[0] * (img.GetSpacing()[0] / spacing[0])) new_size[1] = int(img.GetSize()[1] * (img.GetSpacing()[1] / spacing[1])) new_size[2] = int(img.GetSize()[2] * (img.GetSpacing()[2] / spacing[2])) - if crop_to_mm: - if new_size[0] < crop_to_mm: - new_size[0] = crop_to_mm + if crop_to_grid_size_xy: + + if not hasattr(crop_to_grid_size_xy, "__iter__"): + crop_to_grid_size_xy = (crop_to_grid_size_xy,) * 2 - if new_size[1] < crop_to_mm: - new_size[1] = crop_to_mm + if new_size[0] < crop_to_grid_size_xy[0]: + new_size[0] = crop_to_grid_size_xy[0] + + if new_size[1] < crop_to_grid_size_xy[1]: + new_size[1] = crop_to_grid_size_xy[1] img = sitk.Resample( img, @@ -60,14 +96,14 @@ def preprocess_image(img, spacing=[1, 1, 1], crop_to_mm=128): img.GetPixelID(), ) - if crop_to_mm: + if crop_to_grid_size_xy: center_x = img.GetSize()[0] / 2 - x_from = int(center_x - crop_to_mm / 2) - x_to = x_from + crop_to_mm + x_from = int(center_x - crop_to_grid_size_xy[0] / 2) + x_to = x_from + crop_to_grid_size_xy[0] center_y = img.GetSize()[1] / 2 - y_from = int(center_y - crop_to_mm / 2) - y_to = y_from + crop_to_mm + y_from = int(center_y - crop_to_grid_size_xy[1] / 2) + y_to = y_from + crop_to_grid_size_xy[1] img = img[x_from:x_to, y_from:y_to, :] @@ -75,6 +111,15 @@ def preprocess_image(img, spacing=[1, 1, 1], crop_to_mm=128): def resample_mask_to_image(img, mask): + """Repsample a mask to the space of the image supplied. + + Args: + img (sitk.Image): Image to sample to space of. 
+ mask (sitk.Image): Mask to resample. + + Returns: + sitk.Image: The resampled mask. + """ return sitk.Resample( mask, From a82ac42db21e593e862f186a6af327fbbd27b902 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 18 Aug 2021 04:38:09 +0000 Subject: [PATCH 111/264] Ensure min size of expand roi --- platipy/imaging/generation/dvf.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/platipy/imaging/generation/dvf.py b/platipy/imaging/generation/dvf.py index 494bf6a4..951256b2 100644 --- a/platipy/imaging/generation/dvf.py +++ b/platipy/imaging/generation/dvf.py @@ -16,6 +16,8 @@ import numpy as np import SimpleITK as sitk +from loguru import logger + from platipy.imaging.registration.utils import ( apply_transform, convert_mask_to_reg_structure, @@ -59,6 +61,9 @@ def generate_field_shift(mask, vector_shift=(10, 10, 10), gaussian_smooth=5): roi_expand = [x + y for x, y in zip(roi_expand, gaussian_smooth)] + # Make sure the expansion meets a minimum size (1cm) + roi_expand = [max(e, 10) for e in roi_expand] + size, index = label_to_roi(mask, expansion_mm=roi_expand) mask = crop_to_roi(mask, size, index) @@ -258,6 +263,9 @@ def generate_field_expand( roi_expand = [x + y for x, y in zip(roi_expand, gaussian_smooth)] + # Make sure the expansion meets a minimum size (1cm) + roi_expand = [max(e, 10) for e in roi_expand] + size, index = label_to_roi(mask, expansion_mm=roi_expand) mask = crop_to_roi(mask, size, index) @@ -279,17 +287,17 @@ def generate_field_expand( # If all negative: erode if np.all(np.array(expand) <= 0): - print("All factors negative: shrinking only.") + logger.debug("All factors negative: shrinking only.") mask_expand = sitk.BinaryErode(mask, np.abs(expand).astype(int).tolist(), sitk.sitkBall) # If all positive: dilate elif np.all(np.array(expand) >= 0): - print("All factors positive: expansion only.") + logger.debug("All factors positive: expansion only.") mask_expand = sitk.BinaryDilate(mask, 
np.abs(expand).astype(int).tolist(), sitk.sitkBall) # Otherwise: sequential operations else: - print("Mixed factors: shrinking and expansion.") + logger.debug("Mixed factors: shrinking and expansion.") expansion_kernel = expand * (expand > 0) shrink_kernel = expand * (expand < 0) From 4b46e99477a1fc1e56cb726f478b21621d6032d1 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 19 Aug 2021 13:44:18 +1000 Subject: [PATCH 112/264] Update to augmentation script --- platipy/imaging/generation/augment.py | 92 +++++++++++++++++---------- 1 file changed, 57 insertions(+), 35 deletions(-) diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index ba414edc..0a409902 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -26,7 +26,7 @@ from loguru import logger import matplotlib.pyplot as plt -from platipy.imaging import ImageVisualiser +#from platipy.imaging import ImageVisualiser from platipy.imaging.generation.dvf import ( generate_field_shift, @@ -55,23 +55,27 @@ def apply_augmentation(image, augmentation, masks=[]): "DeformableAugment's" ) - transforms = [] + #transforms = [] + transform = None dvf = None for aug in augmentation: if not isinstance(aug, DeformableAugment): raise AttributeError("Each augmentation must be of type DeformableAugment") + logger.debug(str(aug)) tfm, field = aug.augment() - transforms.append(tfm) + #transforms.append(tfm) if dvf is None: dvf = field + transform = tfm else: dvf += field + transform = sitk.CompositeTransform([transform, tfm]) - transform = sitk.CompositeTransform(transforms) - del transforms + #transform = sitk.CompositeTransform(transforms) + #del transforms image_deformed = apply_transform( image, @@ -158,6 +162,9 @@ def augment(self): ) return transform, dvf + def __str__(self): + return f"Shift with vector: {self.vector_shift}, gauss: {self.gaussian_smooth}" + class ExpandAugment(DeformableAugment): def __init__(self, mask, vector_expand=(10, 
10, 10), gaussian_smooth=5, bone_mask=False): @@ -178,12 +185,15 @@ def augment(self): return transform, dvf + def __str__(self): + return f"Expand with vector: {self.vector_expand}, smooth: {self.gaussian_smooth}" + class ContractAugment(DeformableAugment): def __init__(self, mask, vector_contract=(10, 10, 10), gaussian_smooth=5, bone_mask=False): self.mask = mask - self.contract = [int(-x / s) for x, s in zip(vector_contract, mask.GetSpacing())] + self.vector_contract = [int(-x / s) for x, s in zip(vector_contract, mask.GetSpacing())] self.gaussian_smooth = gaussian_smooth self.bone_mask = bone_mask @@ -192,11 +202,14 @@ def augment(self): _, transform, dvf = generate_field_expand( self.mask, bone_mask=self.bone_mask, - expand=self.contract, + expand=self.vector_contract, gaussian_smooth=self.gaussian_smooth, ) return transform, dvf + def __str__(self): + return f"Contract with vector: {self.vector_contract}, smooth: {self.gaussian_smooth}" + def augment_data(args): @@ -280,26 +293,12 @@ def augment_data(args): if args.enable_fill_holes: + logger.debug("Finding holes") + ct_image = sitk.ReadImage(str(data[case]["image"])) label_image, labels = detect_holes(ct_image) - for label in labels[1:]: # Skip first hole since likely air around body - - - if random.random() > args.fill_probability: continue - - hole = label_image == label["label"] - hole_dilate = sitk.BinaryDilate(hole, (2,2,2), sitk.sitkBall) - contour_points = sitk.BinaryContour(hole_dilate) - fill_value = np.median(sitk.GetArrayFromImage(ct_image)[sitk.GetArrayFromImage(contour_points)==1]) - - ct_arr = sitk.GetArrayFromImage(ct_image) - ct_arr[sitk.GetArrayFromImage(hole_dilate)==1] = fill_value - ct_filled = sitk.GetImageFromArray(ct_arr) - ct_filled.CopyInformation(ct_image) - - ct_image = ct_filled - # Get list of structures to generate augmentations off + logger.debug("Collecting structures") all_masks = [] all_names = [] for structure_path in data[case]["label"]: @@ -309,14 +308,37 @@ def 
augment_data(args): all_masks.append(mask) all_names.append(structure_path.name.replace(".nii.gz", "")) - # Generate 10 random augmentations per case + # Generate x random augmentations per case for i in range(args.augmentations_per_case): - logger.debug("Generating augmentation") + logger.debug(f"Generating augmentation {i}") + + ct_image = sitk.ReadImage(str(data[case]["image"])) + + if args.enable_fill_holes: + + logger.debug("Filling holes") + + for label in labels[1:]: # Skip first hole since likely air around body + + if random.random() > args.fill_probability: continue + + hole = label_image == label["label"] + hole_dilate = sitk.BinaryDilate(hole, (2,2,2), sitk.sitkBall) + contour_points = sitk.BinaryContour(hole_dilate) + fill_value = np.median(sitk.GetArrayFromImage(ct_image)[sitk.GetArrayFromImage(contour_points)==1]) + + ct_arr = sitk.GetArrayFromImage(ct_image) + ct_arr[sitk.GetArrayFromImage(hole_dilate)==1] = fill_value + ct_filled = sitk.GetImageFromArray(ct_arr) + ct_filled.CopyInformation(ct_image) + + ct_image = ct_filled augmented_case_path = output_dir.joinpath(case, f"augment_{i}") augmented_case_path.mkdir(exist_ok=True, parents=True) + logger.debug("Generating random augmentations") augmentation = generate_random_augmentation(ct_image, all_masks, augmentation_types) dvf = None @@ -338,22 +360,22 @@ def augment_data(args): augmented_image_path = augmented_case_path.joinpath("CT.nii.gz") sitk.WriteImage(augmented_image, str(augmented_image_path)) - vis = ImageVisualiser(image=ct_image, figure_size_in=6) - vis.add_comparison_overlay(augmented_image) - if dvf is not None: - vis.add_vector_overlay(dvf, arrow_scale=1, subsample=(4, 12, 12)) + #vis = ImageVisualiser(image=ct_image, figure_size_in=6) + #vis.add_comparison_overlay(augmented_image) + #if dvf is not None: + # vis.add_vector_overlay(dvf, arrow_scale=1, subsample=(4, 12, 12)) for mask_name, mask, augmented_mask in zip(all_names, all_masks, augmented_masks): - 
vis.add_contour({f"{mask_name}": mask, f"{mask_name} (augmented)": augmented_mask}) + #vis.add_contour({f"{mask_name}": mask, f"{mask_name} (augmented)": augmented_mask}) logger.debug(f"Applying augmentation to mask: {mask_name}") augmented_mask_path = augmented_case_path.joinpath(f"{mask_name}.nii.gz") sitk.WriteImage(augmented_mask, str(augmented_mask_path)) - fig = vis.show() + #fig = vis.show() - figure_path = augmented_case_path.joinpath("aug.png") - fig.savefig(figure_path, bbox_inches="tight") - plt.close() + #figure_path = augmented_case_path.joinpath("aug.png") + #fig.savefig(figure_path, bbox_inches="tight") + #plt.close() if __name__ == "__main__": From 3c88ee0141adb3054af344dbd033362b2a005cb3 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 19 Aug 2021 04:00:45 +0000 Subject: [PATCH 113/264] Use cropping during augmentation --- platipy/imaging/generation/augment.py | 68 ++++++++++++++++++--------- 1 file changed, 46 insertions(+), 22 deletions(-) diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index 0a409902..c100104e 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -26,7 +26,7 @@ from loguru import logger import matplotlib.pyplot as plt -#from platipy.imaging import ImageVisualiser +from platipy.imaging import ImageVisualiser from platipy.imaging.generation.dvf import ( generate_field_shift, @@ -40,6 +40,9 @@ from platipy.imaging.registration.utils import apply_transform from platipy.imaging.utils.lung import detect_holes +from platipy.imaging.label.utils import get_union_mask +from platipy.imaging.utils.crop import label_to_roi, crop_to_roi + def apply_augmentation(image, augmentation, masks=[]): @@ -55,7 +58,7 @@ def apply_augmentation(image, augmentation, masks=[]): "DeformableAugment's" ) - #transforms = [] + # transforms = [] transform = None dvf = None for aug in augmentation: @@ -65,7 +68,7 @@ def apply_augmentation(image, augmentation, masks=[]): 
logger.debug(str(aug)) tfm, field = aug.augment() - #transforms.append(tfm) + # transforms.append(tfm) if dvf is None: dvf = field @@ -74,8 +77,8 @@ def apply_augmentation(image, augmentation, masks=[]): dvf += field transform = sitk.CompositeTransform([transform, tfm]) - #transform = sitk.CompositeTransform(transforms) - #del transforms + # transform = sitk.CompositeTransform(transforms) + # del transforms image_deformed = apply_transform( image, @@ -289,13 +292,13 @@ def augment_data(args): logger.info(f"Augmenting for case: {case}") - ct_image = sitk.ReadImage(str(data[case]["image"])) + ct_image_original = sitk.ReadImage(str(data[case]["image"])) if args.enable_fill_holes: logger.debug("Finding holes") - ct_image = sitk.ReadImage(str(data[case]["image"])) - label_image, labels = detect_holes(ct_image) + ct_image_original = sitk.ReadImage(str(data[case]["image"])) + label_image, labels = detect_holes(ct_image_original) # Get list of structures to generate augmentations off logger.debug("Collecting structures") @@ -308,28 +311,41 @@ def augment_data(args): all_masks.append(mask) all_names.append(structure_path.name.replace(".nii.gz", "")) + logger.debug("Cropping to regions around all structures") + union_mask = get_union_mask(all_masks) + size, index = label_to_roi(union_mask, expansion_mm=[25, 25, 25]) + + for m, mask in enumerate(all_masks): + all_masks[m] = crop_to_roi(mask, size, index) + # Generate x random augmentations per case for i in range(args.augmentations_per_case): logger.debug(f"Generating augmentation {i}") ct_image = sitk.ReadImage(str(data[case]["image"])) + ct_image = crop_to_roi(ct_image, size, index) if args.enable_fill_holes: logger.debug("Filling holes") - for label in labels[1:]: # Skip first hole since likely air around body + for label in labels[1:]: # Skip first hole since likely air around body - if random.random() > args.fill_probability: continue + if random.random() > args.fill_probability: + continue hole = label_image == 
label["label"] - hole_dilate = sitk.BinaryDilate(hole, (2,2,2), sitk.sitkBall) + hole_dilate = sitk.BinaryDilate(hole, (2, 2, 2), sitk.sitkBall) contour_points = sitk.BinaryContour(hole_dilate) - fill_value = np.median(sitk.GetArrayFromImage(ct_image)[sitk.GetArrayFromImage(contour_points)==1]) + fill_value = np.median( + sitk.GetArrayFromImage(ct_image)[ + sitk.GetArrayFromImage(contour_points) == 1 + ] + ) ct_arr = sitk.GetArrayFromImage(ct_image) - ct_arr[sitk.GetArrayFromImage(hole_dilate)==1] = fill_value + ct_arr[sitk.GetArrayFromImage(hole_dilate) == 1] = fill_value ct_filled = sitk.GetImageFromArray(ct_arr) ct_filled.CopyInformation(ct_image) @@ -358,24 +374,32 @@ def augment_data(args): ) augmented_image_path = augmented_case_path.joinpath("CT.nii.gz") + ct_image_original[ + index[0] : index[0] + size[0], + index[1] : index[1] + size[1], + index[2] : index[2] + size[2], + ] = augmented_image sitk.WriteImage(augmented_image, str(augmented_image_path)) - #vis = ImageVisualiser(image=ct_image, figure_size_in=6) - #vis.add_comparison_overlay(augmented_image) - #if dvf is not None: - # vis.add_vector_overlay(dvf, arrow_scale=1, subsample=(4, 12, 12)) + vis = ImageVisualiser(image=ct_image, figure_size_in=6) + vis.add_comparison_overlay(augmented_image) + if dvf is not None: + vis.add_vector_overlay(dvf, arrow_scale=1, subsample=(4, 12, 12)) for mask_name, mask, augmented_mask in zip(all_names, all_masks, augmented_masks): - #vis.add_contour({f"{mask_name}": mask, f"{mask_name} (augmented)": augmented_mask}) + vis.add_contour({f"{mask_name}": mask, f"{mask_name} (augmented)": augmented_mask}) logger.debug(f"Applying augmentation to mask: {mask_name}") augmented_mask_path = augmented_case_path.joinpath(f"{mask_name}.nii.gz") + augmented_mask = sitk.Resample( + augmented_mask, ct_image_original, sitk.Transform, sitk.sitkNearestNeighbor + ) sitk.WriteImage(augmented_mask, str(augmented_mask_path)) - #fig = vis.show() + fig = vis.show() - #figure_path = 
augmented_case_path.joinpath("aug.png") - #fig.savefig(figure_path, bbox_inches="tight") - #plt.close() + figure_path = augmented_case_path.joinpath("aug.png") + fig.savefig(figure_path, bbox_inches="tight") + plt.close() if __name__ == "__main__": From 9ef72661869908b83de48280c6d4209ef46def4e Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 19 Aug 2021 04:03:08 +0000 Subject: [PATCH 114/264] crop before fill holes --- platipy/imaging/generation/augment.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index c100104e..4b5a046e 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -294,12 +294,6 @@ def augment_data(args): ct_image_original = sitk.ReadImage(str(data[case]["image"])) - if args.enable_fill_holes: - - logger.debug("Finding holes") - ct_image_original = sitk.ReadImage(str(data[case]["image"])) - label_image, labels = detect_holes(ct_image_original) - # Get list of structures to generate augmentations off logger.debug("Collecting structures") all_masks = [] @@ -314,10 +308,16 @@ def augment_data(args): logger.debug("Cropping to regions around all structures") union_mask = get_union_mask(all_masks) size, index = label_to_roi(union_mask, expansion_mm=[25, 25, 25]) + ct_image = crop_to_roi(ct_image_original, size, index) for m, mask in enumerate(all_masks): all_masks[m] = crop_to_roi(mask, size, index) + if args.enable_fill_holes: + + logger.debug("Finding holes") + label_image, labels = detect_holes(ct_image) + # Generate x random augmentations per case for i in range(args.augmentations_per_case): From eb1ed88a656859f2f80056a72f269d2d4ade27c5 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 22 Aug 2021 12:15:07 +1000 Subject: [PATCH 115/264] Fixes and stuff --- platipy/imaging/cnn/dataload.py | 35 +++++++++++++++++++-------- platipy/imaging/cnn/dataset.py | 10 ++++++-- 
platipy/imaging/cnn/localise_net.py | 19 ++++++++------- platipy/imaging/cnn/prob_unet.py | 5 ++-- platipy/imaging/cnn/train.py | 12 ++++++++- platipy/imaging/cnn/train_localise.py | 1 + platipy/imaging/cnn/utils.py | 5 +++- platipy/imaging/generation/augment.py | 14 ++++++----- 8 files changed, 70 insertions(+), 31 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index 6338bca6..a34696ac 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -39,6 +39,7 @@ def __init__( contour_mask_kernel=3, crop_using_localise_model=None, localise_voxel_grid_size=[100, 100, 100], + validation_sampler="observer", # observer or batch ndims=2, **kwargs, ): @@ -54,6 +55,7 @@ def __init__( self.augmented_image_glob = augmented_image_glob self.augmented_label_glob = augmented_label_glob + print(augment_on_fly) self.augment_on_fly = augment_on_fly self.fold = fold self.k_folds = k_folds @@ -77,6 +79,7 @@ def __init__( self.training_set = None self.validation_set = None self.test_set = None + self.validation_sampler = validation_sampler self.ndims = ndims @@ -219,10 +222,14 @@ def setup(self, stage=None): else: crop_to_grid_size = self.crop_to_grid_size_xy + augment_on_fly = self.augment_on_fly + if self.ndims == 3: + augment_on_fly = False + self.training_set = NiftiDataset( train_data, self.working_dir, - augment_on_fly=self.augment_on_fly, + augment_on_fly=augment_on_fly, spacing=self.spacing, crop_to_grid_size=crop_to_grid_size, crop_using_localise_model=localise_model_path, @@ -265,12 +272,20 @@ def train_dataloader(self): ) def val_dataloader(self): - return torch.utils.data.DataLoader( - self.validation_set, - batch_sampler=torch.utils.data.BatchSampler( - ObserverSampler(self.validation_set, self.num_observers), - batch_size=self.num_observers, - drop_last=False, - ), - num_workers=self.num_workers, - ) + if self.validation_sampler == "observer": + return torch.utils.data.DataLoader( + self.validation_set, 
+ batch_sampler=torch.utils.data.BatchSampler( + ObserverSampler(self.validation_set, self.num_observers), + batch_size=self.num_observers, + drop_last=False, + ), + num_workers=self.num_workers, + ) + else: + return torch.utils.data.DataLoader( + self.validation_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 5b9b016c..efe533d7 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -153,7 +153,13 @@ def __init__( crop_to_grid_size = (crop_to_grid_size,) * 3 index = [i - int((g - s) / 2) for i, s, g in zip(index, size, crop_to_grid_size)] + size = crop_to_grid_size + img_size = img.GetSize() + for i in range(3): + if index[i] + size[i] >= img_size[i]: + index[i] = img_size[i] - size[i] - 1 + if index[i] < 0: index[i] = 0 img = crop_to_roi(img, size, index) else: @@ -205,7 +211,7 @@ def __init__( else: label_slice = label label_file = self.label_dir.joinpath(f"{case_id}_{obs}_{z_slice}.npy") - np.save(label_file, sitk.GetArrayFromImage(label_slice)) + np.save(label_file, sitk.GetArrayFromImage(label_slice).astype(np.int8)) self.slices.append( { "z": z_slice, @@ -236,7 +242,7 @@ def __getitem__(self, index): contour_mask = seg.get_arr()[:, :, 1].squeeze() img = torch.FloatTensor(img) - label = torch.LongTensor(label) + label = torch.IntTensor(label) contour_mask = torch.FloatTensor(contour_mask) return ( diff --git a/platipy/imaging/cnn/localise_net.py b/platipy/imaging/cnn/localise_net.py index a7414a12..b1f43fb3 100644 --- a/platipy/imaging/cnn/localise_net.py +++ b/platipy/imaging/cnn/localise_net.py @@ -73,7 +73,7 @@ def configure_optimizers(self): def infer(self, img): pp_img = preprocess_image( - img, spacing=self.hparams.spacing, crop_to_mm=self.hparams.crop_to_mm + img, spacing=self.hparams.spacing, crop_to_grid_size_xy=self.hparams.crop_to_mm ) preds = [] @@ -189,14 +189,6 @@ def validation_epoch_end(self, 
validation_step_outputs): color_dict = {} obs_dict = {} - com = None - try: - com = get_com(pred) - except: - com = [int(i / 2) for i in pred.GetSize()] - - img_vis = ImageVisualiser(img, cut=com, figure_size_in=16, window=[-1.0, 1.0]) - for _, observer in enumerate(cases[case]["observers"]): mask_arrs = [] for z in slices: @@ -213,6 +205,15 @@ def validation_epoch_end(self, validation_step_outputs): obs_dict[f"manual_{observer}"] = mask color_dict[f"manual_{observer}"] = [0.7, 0.2, 0.2] + com = None + try: + com = get_com(mask) + except: + com = [int(i / 2) for i in mask.GetSize()] + + img_vis = ImageVisualiser(img, cut=com, figure_size_in=16) + img_vis.set_limits_from_label(mask, expansion=[0, 0, 0]) + contour_dict = {**obs_dict} contour_dict["pred"] = pred color_dict["pred"] = [0.2, 0.4, 0.8] diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index e9c24176..c8d47626 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -233,8 +233,9 @@ def __init__( self.prior_latent_space = None self.unet_features = None - self._moving_avg = None - self.register_buffer("_lambda", torch.zeros(1, requires_grad=False)) + if self.loss_type == "geco": + self._moving_avg = None + self.register_buffer("_lambda", torch.zeros(1, requires_grad=False)) def forward(self, img, seg=None, training=False): """ diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 11b11a1c..5a5ce356 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -27,7 +27,7 @@ import torch import pytorch_lightning as pl -from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from argparse import ArgumentParser @@ -427,6 +427,16 @@ def main(args, config_json_path=None): lr_monitor = LearningRateMonitor(logging_interval="step") trainer.callbacks.append(lr_monitor) + # Save the best model + checkpoint_callback = ModelCheckpoint( 
+ monitor="probnet_DSC", + dirpath=args.default_root_dir, + filename="probunet-{epoch:02d}-{DSC:.2f}", + save_top_k=1, + mode="max", + ) + trainer.callbacks.append(checkpoint_callback) + trainer.fit(prob_unet, data_module) diff --git a/platipy/imaging/cnn/train_localise.py b/platipy/imaging/cnn/train_localise.py index 49c64ecb..e4932856 100644 --- a/platipy/imaging/cnn/train_localise.py +++ b/platipy/imaging/cnn/train_localise.py @@ -68,6 +68,7 @@ def main(args, config_json_path=None): comet_logger.experiment.log_code(config_json_path) dict_args = vars(args) + dict_args["validation_sampler"] = "batch" data_module = UNetDataModule(**dict_args) diff --git a/platipy/imaging/cnn/utils.py b/platipy/imaging/cnn/utils.py index ba7c0a04..2c1f5813 100644 --- a/platipy/imaging/cnn/utils.py +++ b/platipy/imaging/cnn/utils.py @@ -54,10 +54,10 @@ def preprocess_image( sitk.Image: The preprocessed image. """ + img = sitk.Cast(img, sitk.sitkFloat32) if intensity_scaling == "norm": img = sitk.Normalize(img) elif intensity_scaling == "window": - img = sitk.Cast(img, sitk.sitkFloat32) img = sitk.IntensityWindowing( img, windowMinimum=intensity_window[0], @@ -162,6 +162,9 @@ def get_metrics(target, pred): if sitk.GetArrayFromImage(pred).sum() == 0: result["HD"] = 1000 result["ASD"] = 100 + elif sitk.GetArrayFromImage(target).sum() == 0: + result["HD"] = 1000 + result["ASD"] = 100 else: hdif = sitk.HausdorffDistanceImageFilter() hdif.Execute(target, pred) diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index 4b5a046e..338de085 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -89,12 +89,14 @@ def apply_augmentation(image, augmentation, masks=[]): masks_deformed = [] for mask in masks: - masks_deformed.append( - apply_transform( - mask, transform=transform, default_value=0, interpolator=sitk.sitkNearestNeighbor - ) + def_mask = apply_transform( + mask, transform=transform, default_value=0, 
interpolator=sitk.sitkNearestNeighbor ) + def_mask = sitk.BinaryMorphologicalClosing(def_mask, [3, 3, 3]) + + masks_deformed.append(def_mask) + if masks: return image_deformed, masks_deformed, dvf @@ -379,7 +381,7 @@ def augment_data(args): index[1] : index[1] + size[1], index[2] : index[2] + size[2], ] = augmented_image - sitk.WriteImage(augmented_image, str(augmented_image_path)) + sitk.WriteImage(ct_image_original, str(augmented_image_path)) vis = ImageVisualiser(image=ct_image, figure_size_in=6) vis.add_comparison_overlay(augmented_image) @@ -391,7 +393,7 @@ def augment_data(args): logger.debug(f"Applying augmentation to mask: {mask_name}") augmented_mask_path = augmented_case_path.joinpath(f"{mask_name}.nii.gz") augmented_mask = sitk.Resample( - augmented_mask, ct_image_original, sitk.Transform, sitk.sitkNearestNeighbor + augmented_mask, ct_image_original, sitk.Transform(), sitk.sitkNearestNeighbor ) sitk.WriteImage(augmented_mask, str(augmented_mask_path)) From 1da91ae37b7aa6eefb847879103ad21203679cc3 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 23 Aug 2021 22:13:59 +0000 Subject: [PATCH 116/264] Work on prob net infer --- platipy/imaging/cnn/dataload.py | 11 +-- platipy/imaging/cnn/dataset.py | 36 +++------- platipy/imaging/cnn/localise_net.py | 1 - platipy/imaging/cnn/train.py | 99 ++++++++++++++++++++++++++- platipy/imaging/cnn/utils.py | 101 ++++++++++++++++++++++++---- 5 files changed, 204 insertions(+), 44 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index a34696ac..0c68c9e6 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -81,6 +81,9 @@ def __init__( self.test_set = None self.validation_sampler = validation_sampler + self.validation_data = [] + self.test_data = [] + self.ndims = ndims print(f"Training fold {self.fold}") @@ -184,7 +187,7 @@ def setup(self, stage=None): for augmented_case in augmented_cases ] - validation_data = [ + self.validation_data = [ { 
"id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), @@ -197,7 +200,7 @@ def setup(self, stage=None): for case in self.validation_cases ] - test_data = [ + self.test_data = [ { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), @@ -239,7 +242,7 @@ def setup(self, stage=None): ndims=self.ndims, ) self.validation_set = NiftiDataset( - validation_data, + self.validation_data, self.working_dir, augment_on_fly=False, spacing=self.spacing, @@ -251,7 +254,7 @@ def setup(self, stage=None): ndims=self.ndims, ) self.test_set = NiftiDataset( - test_data, + self.test_data, self.working_dir, augment_on_fly=False, spacing=self.spacing, diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index efe533d7..03bfdbcc 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -12,9 +12,7 @@ from loguru import logger -from platipy.imaging.cnn.localise_net import LocaliseUNet from platipy.imaging.cnn.utils import preprocess_image, resample_mask_to_image, get_contour_mask -from platipy.imaging.utils.crop import label_to_roi, crop_to_roi from platipy.imaging.label.utils import get_union_mask, get_intersection_mask @@ -140,31 +138,19 @@ def __init__( img = sitk.ReadImage(img_path) if crop_using_localise_model: - localise_model = LocaliseUNet.load_from_checkpoint(crop_using_localise_model) - localise_model.eval() - localise_pred = localise_model.infer(img) - - img = preprocess_image(img, spacing=spacing, crop_to_grid_size_xy=None) - localise_pred = resample_mask_to_image(img, localise_pred) - - size, index = label_to_roi(localise_pred) - - if not hasattr(crop_to_grid_size, "__iter__"): - crop_to_grid_size = (crop_to_grid_size,) * 3 - - index = [i - int((g - s) / 2) for i, s, g in zip(index, size, crop_to_grid_size)] - - size = crop_to_grid_size - img_size = img.GetSize() - for i in range(3): - if index[i] + size[i] >= img_size[i]: - index[i] = img_size[i] - size[i] - 1 - if index[i] < 0: 
index[i] = 0 - - img = crop_to_roi(img, size, index) + crop_using_localise_model( + img, + crop_using_localise_model, + spacing=spacing, + crop_to_grid_size=crop_to_grid_size, + ) else: img = preprocess_image( - img, spacing=spacing, crop_to_grid_size_xy=crop_to_grid_size, intensity_scaling=intensity_scaling, intensity_window=intensity_window + img, + spacing=spacing, + crop_to_grid_size_xy=crop_to_grid_size, + intensity_scaling=intensity_scaling, + intensity_window=intensity_window, ) observers = [] diff --git a/platipy/imaging/cnn/localise_net.py b/platipy/imaging/cnn/localise_net.py index b1f43fb3..5812d357 100644 --- a/platipy/imaging/cnn/localise_net.py +++ b/platipy/imaging/cnn/localise_net.py @@ -19,7 +19,6 @@ import numpy as np import comet_ml # pylint: disable=unused-import -from pytorch_lightning.loggers import CometLogger import torch import pytorch_lightning as pl diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 5a5ce356..8207e904 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -36,11 +36,17 @@ from platipy.imaging.cnn.prob_unet import ProbabilisticUnet from platipy.imaging.cnn.unet import l2_regularisation from platipy.imaging.cnn.dataload import UNetDataModule -from platipy.imaging.cnn.utils import postprocess_mask, get_metrics +from platipy.imaging.cnn.utils import ( + preprocess_image, + postprocess_mask, + get_metrics, + crop_using_localise_model, +) from platipy.imaging import ImageVisualiser from platipy.imaging.label.utils import get_com + class ProbUNet(pl.LightningModule): def __init__( self, @@ -120,6 +126,95 @@ def configure_optimizers(self): return [optimizer], [scheduler] + def infer( + self, img, num_samples=1, sample_strategy="mean", latent_dim=True, spaced_range=[-1.5, 1.5] + ): + # sample strategy in "mean", "random", "spaced" + + if not hasattr(latent_dim, "__iter__"): + latent_dim = [ + latent_dim, + ] * self.hparams.latent_dim + + if sample_strategy == "mean": + samples = 
[{"name": "mean", "std_dev_from_mean": [0.0] * len(latent_dim), "preds": []}] + elif sample_strategy == "random": + samples = [ + { + "name": f"random_{i}", + "std_dev_from_mean": torch.Tensor( + [np.random.normal(0, 1.0, 1)[0] if d else 0.0 for d in latent_dim] + ), + "preds": [], + } + for i in range(num_samples) + ] + elif sample_strategy == "spaced": + samples = [ + { + "name": f"spaced_{s}", + "std_dev_from_mean": torch.Tensor([s if d else 0.0 for d in latent_dim]), + "preds": [], + } + for s in np.linspace(spaced_range[0], spaced_range[1], num_samples) + ] + + with torch.no_grad(): + + if self.hparams.crop_using_localise_model: + localise_path = self.hparams.crop_using_localise_model.format( + fold=self.hparams.fold + ) + img = crop_using_localise_model( + img, + localise_path, + spacing=self.hparams.spacing, + crop_to_grid_size=self.hparams.localise_voxel_grid_size, + ) + else: + img = preprocess_image( + img, + spacing=self.hparams.spacing, + crop_to_grid_size_xy=self.hparams.crop_to_grid_size, + intensity_scaling=self.hparams.intensity_scaling, + intensity_window=self.hparams.intensity_window, + ) + + img_arr = sitk.GetArrayFromImage(img) + for z in range(img_arr.shape[0]): + + x = torch.Tensor(img_arr[z, :, :]) + x = x.unsqueeze(0) + x = x.unsqueeze(0) + self.prob_unet.forward(x) + + for sample in samples: + if sample["name"] == "mean": + y = self.prob_unet.sample(testing=True, use_mean=True) + else: + y = self.prob_unet.sample( + testing=True, + use_mean=False, + sample_x_stddev_from_mean=sample["std_dev_from_mean"], + ) + + y = y.squeeze(0) + y = np.argmax(y.cpu().detach().numpy(), axis=0) + sample["preds"].append(y) + + result = {} + for sample in samples: + pred = sitk.GetImageFromArray(np.stack(sample["preds"])) + pred = sitk.Cast(pred, sitk.sitkUInt8) + + pred.CopyInformation(img) + pred = postprocess_mask(pred) + pred = sitk.Resample(pred, img, sitk.Transform(), sitk.sitkNearestNeighbor) + + result[sample["name"]] = pred + + return result + def 
training_step(self, batch, _): x, y, m, _ = batch @@ -309,7 +404,7 @@ def validation_epoch_end(self, validation_step_outputs): color_dict[f"auto_{self.stddevs[idx]}"] = cmap(observer / 5) img_vis = ImageVisualiser( - img, cut=get_com(mask), figure_size_in=16, window=[-1.0, 1.0] + img, cut=get_com(mask), figure_size_in=16, window=[-0.3, 1.0] ) contour_dict = {**obs_dict, **pred_dict} diff --git a/platipy/imaging/cnn/utils.py b/platipy/imaging/cnn/utils.py index 2c1f5813..2b2bb9e5 100644 --- a/platipy/imaging/cnn/utils.py +++ b/platipy/imaging/cnn/utils.py @@ -1,7 +1,29 @@ +# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from pathlib import Path + import numpy as np import SimpleITK as sitk +from platipy.imaging.cnn.localise_net import LocaliseUNet from platipy.imaging.label.utils import get_union_mask, get_intersection_mask +from platipy.imaging.utils.crop import label_to_roi, crop_to_roi +from platipy.imaging.label.comparison import ( + compute_metric_dsc, + compute_metric_hd, + compute_metric_masd, +) def get_contour_mask(masks, kernel=5): @@ -110,6 +132,55 @@ def preprocess_image( return img +def crop_using_localise_model( + img, localise_model, spacing=[1, 1, 1], crop_to_grid_size=[100, 100, 100] +): + """Crops an image using a LocaliseUNet + + Args: + img (SimpleITK.Image): The image to crop + localise_model (str|Path|LocaliseUNet): The LocaliseUNet or path to checkpoint of + LocaliseUNet. + spacing (list, optional): The image spacing (mm) to resample to. Defaults to [1,1,1]. + crop_to_grid_size (list, optional): The size of the grid to crop to. Defaults to + [100,100,100]. + + Returns: + SimpleITK.Image: The cropped image. 
+ """ + + if isinstance(localise_model, str): + localise_model = Path(localise_model) + + if isinstance(localise_model, Path): + if localise_model.is_dir(): + # Find the first actual model checkpoint in this directory + localise_model = next(localise_model.glob("*.ckpt")) + + localise_model = LocaliseUNet.load_from_checkpoint(localise_model) + + localise_model.eval() + localise_pred = localise_model.infer(img) + + img = preprocess_image(img, spacing=spacing, crop_to_grid_size_xy=None) + localise_pred = resample_mask_to_image(img, localise_pred) + size, index = label_to_roi(localise_pred) + + if not hasattr(crop_to_grid_size, "__iter__"): + crop_to_grid_size = (crop_to_grid_size,) * 3 + + index = [i - int((g - s) / 2) for i, s, g in zip(index, size, crop_to_grid_size)] + size = crop_to_grid_size + img_size = img.GetSize() + for i in range(3): + if index[i] + size[i] >= img_size[i]: + index[i] = img_size[i] - size[i] - 1 + if index[i] < 0: + index[i] = 0 + + return crop_to_roi(img, size, index) + + def resample_mask_to_image(img, mask): """Repsample a mask to the space of the image supplied. 
@@ -132,6 +203,14 @@ def resample_mask_to_image(img, mask): def postprocess_mask(pred): + """Perform postprocessing on a generated auto-segmentation + + Args: + pred (sitk.Image): The predicted mask + + Returns: + sitk.Image: The postprocessed mask + """ # Take only the largest componenet labelled_image = sitk.ConnectedComponent(pred) @@ -154,21 +233,19 @@ def postprocess_mask(pred): def get_metrics(target, pred): result = {} - lomif = sitk.LabelOverlapMeasuresImageFilter() - lomif.Execute(target, pred) - result["JI"] = lomif.GetJaccardCoefficient() - result["DSC"] = lomif.GetDiceCoefficient() + result["DSC"] = compute_metric_dsc(target, pred) - if sitk.GetArrayFromImage(pred).sum() == 0: - result["HD"] = 1000 - result["ASD"] = 100 - elif sitk.GetArrayFromImage(target).sum() == 0: + target_pixels = sitk.GetArrayFromImage(target).sum() + pred_pixels = sitk.GetArrayFromImage(pred).sum() + + if pred_pixels == 0 and target_pixels == 0: + result["HD"] = 0 + result["ASD"] = 0 + elif pred_pixels == 0 or target_pixels == 0: result["HD"] = 1000 result["ASD"] = 100 else: - hdif = sitk.HausdorffDistanceImageFilter() - hdif.Execute(target, pred) - result["HD"] = hdif.GetHausdorffDistance() - result["ASD"] = hdif.GetAverageHausdorffDistance() + result["HD"] = compute_metric_hd(target, pred) + result["ASD"] = compute_metric_masd(target, pred) return result From 8f8e57e6b74886cb660af01879104364eb68f7df Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 24 Aug 2021 23:58:30 +0000 Subject: [PATCH 117/264] Add validation function --- platipy/imaging/cnn/train.py | 259 +++++++++++++++++------------------ 1 file changed, 128 insertions(+), 131 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 8207e904..94027870 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -44,7 +44,7 @@ ) from platipy.imaging import ImageVisualiser -from platipy.imaging.label.utils import get_com +from platipy.imaging.label.utils import 
get_com, get_union_mask, get_intersection_mask class ProbUNet(pl.LightningModule): @@ -127,7 +127,13 @@ def configure_optimizers(self): return [optimizer], [scheduler] def infer( - self, img, num_samples=1, sample_strategy="mean", latent_dim=True, spaced_range=[-1.5, 1.5] + self, + img, + num_samples=1, + sample_strategy="mean", + latent_dim=True, + spaced_range=[-1.5, 1.5], + preprocess=True, ): # sample strategy in "mean", "random", "spaced" @@ -161,24 +167,25 @@ def infer( with torch.no_grad(): - if self.hparams.crop_using_localise_model: - localise_path = self.hparams.crop_using_localise_model.format( - fold=self.hparams.fold - ) - img = crop_using_localise_model( - img, - localise_path, - spacing=self.hparams.spacing, - crop_to_grid_size=self.hparams.localise_voxel_grid_size, - ) - else: - img = preprocess_image( - img, - spacing=self.hparams.spacing, - crop_to_grid_size_xy=self.hparams.crop_to_grid_size, - intensity_scaling=self.hparams.intensity_scaling, - intensity_window=self.hparams.intensity_window, - ) + if preprocess: + if self.hparams.crop_using_localise_model: + localise_path = self.hparams.crop_using_localise_model.format( + fold=self.hparams.fold + ) + img = crop_using_localise_model( + img, + localise_path, + spacing=self.hparams.spacing, + crop_to_grid_size=self.hparams.localise_voxel_grid_size, + ) + else: + img = preprocess_image( + img, + spacing=self.hparams.spacing, + crop_to_grid_size_xy=self.hparams.crop_to_grid_size, + intensity_scaling=self.hparams.intensity_scaling, + intensity_window=self.hparams.intensity_window, + ) img_arr = sitk.GetArrayFromImage(img) for z in range(img_arr.shape[0]): @@ -215,6 +222,92 @@ def infer( return result + def validate(self, img, manual_observers, samples, mean, matching_type="best"): + + metrics = {"DSC": "max", "HD": "min", "ASD": "min"} + + contour_cmap = "coolwarm" + + intersection_mask = get_intersection_mask(manual_observers) + union_mask = get_union_mask(manual_observers) + vis = 
ImageVisualiser(img, cut=get_com(union_mask), window=[-200, 700]) + vis.add_contour( + intersection_mask, name="intersection", color=[0.13, 0.67, 0.275], linewidth=3 + ) + vis.add_contour(union_mask, name="union", color=[0.13, 0.67, 0.275], linewidth=3) + vis.add_contour( + samples, + show_legend=False, + linewidth=1.5, + color={ + s: c + for s, c in zip( + samples, plt.cm.get_cmap(contour_cmap)(np.linspace(0, 1, len(samples))) + ) + }, + ) + vis.add_contour( + mean, color=plt.cm.get_cmap(contour_cmap)(0.5), linewidth=3, show_legend=False + ) + vis.add_contour( + manual_observers, color=[0.13, 0.67, 0.275], linewidth=0.5, show_legend=False + ) + + vis.set_limits_from_label(union_mask, expansion=30) + + fig = vis.show() + + first_obs = manual_observers[list(manual_observers.keys())[0]] + for s in samples: + samples[s] = sitk.Resample( + samples[s], first_obs, sitk.Transform(), sitk.sitkNearestNeighbor + ) + mean["mean"] = sitk.Resample( + mean["mean"], first_obs, sitk.Transform(), sitk.sitkNearestNeighbor + ) + + sim = {k: np.zeros((len(samples), len(manual_observers))) for k in metrics} + msim = {k: np.zeros((len(samples), len(manual_observers))) for k in metrics} + for sid, samp in enumerate(samples): + for oid, obs in enumerate(manual_observers): + sample_metrics = get_metrics(manual_observers[obs], samples[samp]) + mean_metrics = get_metrics(manual_observers[obs], mean["mean"]) + + for k in sample_metrics: + sim[k][sid, oid] = sample_metrics[k] + msim[k][sid, oid] = mean_metrics[k] + + result = {"probnet": {k: [] for k in metrics}, "unet": {k: [] for k in metrics}} + for k in sim: + + val = sim[k] + if matching_type == "hungarian": + if metrics[k] == "max": + val = -val + row_idx, col_idx = linear_sum_assignment(val) + prob_unet_mean = sim[k][row_idx, col_idx].mean() + else: + if metrics[k] == "max": + prob_unet_mean = val.max() + else: + prob_unet_mean = val.min() + result["probnet"][k].append(prob_unet_mean) + + val = msim[k] + if matching_type == "hungarian": 
+ if metrics[k] == "max": + val = -val + row_idx, col_idx = linear_sum_assignment(val) + unet_mean = msim[k][row_idx, col_idx].mean() + else: + if metrics[k] == "max": + unet_mean = val.max() + else: + unet_mean = val.min() + result["unet"][k].append(unet_mean) + + return result, fig + def training_step(self, batch, _): x, y, m, _ = batch @@ -256,7 +349,6 @@ def validation_step(self, batch, _): if self.validation_directory is None: self.validation_directory = Path(tempfile.mkdtemp()) - print(self.validation_directory) with torch.set_grad_enabled(False): x, y, _, info = batch @@ -273,31 +365,11 @@ def validation_step(self, batch, _): ) np.save(mask_file, y[s].squeeze(0).cpu().numpy()) - self.prob_unet.forward(x[s].unsqueeze(0)) - sample = self.prob_unet.sample( - testing=True, - use_mean=False, - sample_x_stddev_from_mean=self.stddevs[s], - ) - sample_file = self.validation_directory.joinpath( - f"sample_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" - ) - sample = np.argmax(sample.squeeze(0).cpu().numpy(), axis=0) - np.save(sample_file, sample) - - mean = self.prob_unet.sample(testing=True, use_mean=True) - mean_file = self.validation_directory.joinpath( - f"mean_{info['case'][s]}_{info['z'][s]}.npy" - ) - mean = np.argmax(mean.squeeze(0).cpu().numpy(), axis=0) - np.save(mean_file, mean) - return info def validation_epoch_end(self, validation_step_outputs): cases = {} - cmap = plt.cm.get_cmap("Set2") for info in validation_step_outputs: for case, z, observer in zip(info["case"], info["z"], info["observer"]): @@ -310,7 +382,7 @@ def validation_epoch_end(self, validation_step_outputs): if not observer in cases[case]["observers"]: cases[case]["observers"].append(observer.item()) - metrics = ["JI", "DSC", "HD", "ASD"] + metrics = ["DSC", "HD", "ASD"] computed_metrics = { **{f"probnet_{m}": [] for m in metrics}, **{f"unet_{m}": [] for m in metrics}, @@ -319,141 +391,66 @@ def validation_epoch_end(self, validation_step_outputs): for case in cases: img_arrs = 
[] - mean_arrs = [] slices = [] if self.hparams.ndims == 2: for z in range(cases[case]["slices"] + 1): img_file = self.validation_directory.joinpath(f"img_{case}_{z}.npy") - mean_file = self.validation_directory.joinpath(f"mean_{case}_{z}.npy") if img_file.exists(): img_arrs.append(np.load(img_file)) - mean_arrs.append(np.load(mean_file)) slices.append(z) - # if len(slices) < 5: - # Likely initial sanity check - # continue - img_arr = np.stack(img_arrs) - mean_arr = np.stack(mean_arrs) else: img_file = self.validation_directory.joinpath(f"img_{case}_0.npy") - mean_file = self.validation_directory.joinpath(f"mean_{case}_0.npy") img_arr = np.load(img_file) - mean_arr = np.load(mean_file) img = sitk.GetImageFromArray(img_arr) img.SetSpacing(self.hparams.spacing) - mean = sitk.GetImageFromArray(mean_arr) - mean = sitk.Cast(mean, sitk.sitkUInt8) - mean = postprocess_mask(mean) - mean.CopyInformation(img) - # sitk.WriteImage(mean, f"val_mean_{case}_mean.nii.gz") + mean = self.infer(img, num_samples=1, sample_strategy="mean", preprocess=True) + samples = self.infer( + img, + sample_strategy="spaced", + num_samples=5, + spaced_range=[-1.5, 1.5], + preprocess=True, + ) - obs_dict = {} - pred_dict = {} - color_dict = {} - observers = [] - samples = [] - for idx, observer in enumerate(cases[case]["observers"]): + observers = {} + for _, observer in enumerate(cases[case]["observers"]): if self.hparams.ndims == 2: mask_arrs = [] - sample_arrs = [] for z in slices: mask_file = self.validation_directory.joinpath( f"mask_{case}_{z}_{observer}.npy" ) - sample_file = self.validation_directory.joinpath( - f"sample_{case}_{z}_{observer}.npy" - ) mask_arrs.append(np.load(mask_file)) - sample_arrs.append(np.load(sample_file)) mask_arr = np.stack(mask_arrs) - sample_arr = np.stack(sample_arrs) else: mask_file = self.validation_directory.joinpath( f"mask_{case}_{z}_{observer}.npy" ) - sample_file = self.validation_directory.joinpath( - f"sample_{case}_{z}_{observer}.npy" - ) mask_arr = 
np.load(mask_file) - sample_arr = np.load(sample_file) mask = sitk.GetImageFromArray(mask_arr) mask = sitk.Cast(mask, sitk.sitkUInt8) mask.CopyInformation(img) - sitk.WriteImage(mask, f"val_mask_{case}_{observer}.nii.gz") - observers.append(mask) - obs_dict[f"manual_{observer}"] = mask - color_dict[f"manual_{observer}"] = [0.5, 0.5, 0.5] - - sample = sitk.GetImageFromArray(sample_arr) - sample = sitk.Cast(sample, sitk.sitkUInt8) - sample = postprocess_mask(sample) - sample.CopyInformation(img) - sitk.WriteImage(sample, f"val_sample_{case}_{observer}.nii.gz") - samples.append(sample) - pred_dict[f"auto_{self.stddevs[idx]}"] = sample - color_dict[f"auto_{self.stddevs[idx]}"] = cmap(observer / 5) - - img_vis = ImageVisualiser( - img, cut=get_com(mask), figure_size_in=16, window=[-0.3, 1.0] - ) + observers[f"manual_{observer}"] = mask - contour_dict = {**obs_dict, **pred_dict} - contour_dict["auto_mean"] = mean - color_dict["auto_mean"] = [0.0, 0.0, 0.0] + result, fig = self.validate(img, observers, samples, mean, matching_type="best") - img_vis.add_contour(contour_dict, color=color_dict) - fig = img_vis.show() figure_path = f"valid_{case}.png" fig.savefig(figure_path, dpi=300) plt.close("all") - try: - self.logger.experiment.log_image(figure_path) - except AttributeError: - # Likely offline mode - pass - - sim = {k: np.zeros((len(observers), len(samples))) for k in metrics} - msim = {k: np.zeros((len(observers), len(samples))) for k in metrics} - for sid, samp in enumerate(samples): - for oid, obs in enumerate(observers): - sample_metrics = get_metrics(obs, samp) - mean_metrics = get_metrics(obs, mean) - - for k in sample_metrics: - sim[k][sid, oid] = sample_metrics[k] - msim[k][sid, oid] = mean_metrics[k] - - result = {"probnet": {k: [] for k in metrics}, "unet": {k: [] for k in metrics}} - for k in sim: - - val = sim[k] - if not k.endswith("D"): - val = -val - row_idx, col_idx = linear_sum_assignment(val) - prob_unet_mean = sim[k][row_idx, col_idx].mean() - 
result["probnet"][k].append(prob_unet_mean) - - val = msim[k] - if not k.endswith("D"): - val = -val - row_idx, col_idx = linear_sum_assignment(val) - unet_mean = msim[k][row_idx, col_idx].mean() - result["unet"][k].append(unet_mean) - for t in result: - for m in result[t]: - computed_metrics[f"{t}_{m}"].append(np.array(result[t][m]).mean()) + for m in metrics: + computed_metrics[f"{t}_{m}"]+=result[t][m] for cm in computed_metrics: self.log( From c6095d3c9388ba03c3c030eb2654bc84354c1d2f Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 25 Aug 2021 10:38:26 +1000 Subject: [PATCH 118/264] Add coarse dropout to augmentation --- platipy/imaging/cnn/dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 03bfdbcc..fe72db30 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -45,6 +45,9 @@ def prepare_transforms(): ], random_order=True, ), + sometimes(iaa.CoarseDropout( + (0.03, 0.15), size_percent=(0.02, 0.1) + )) ], random_order=True, ) From 23c37016b11de9569acf93b5eb05ab807ee5afdb Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 25 Aug 2021 10:40:08 +1000 Subject: [PATCH 119/264] Correct issue in train --- platipy/imaging/cnn/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 94027870..f6df0b1b 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -408,13 +408,13 @@ def validation_epoch_end(self, validation_step_outputs): img = sitk.GetImageFromArray(img_arr) img.SetSpacing(self.hparams.spacing) - mean = self.infer(img, num_samples=1, sample_strategy="mean", preprocess=True) + mean = self.infer(img, num_samples=1, sample_strategy="mean", preprocess=False) samples = self.infer( img, sample_strategy="spaced", num_samples=5, spaced_range=[-1.5, 1.5], - preprocess=True, + preprocess=False, ) observers = {} From 
58b1e4cfbb6c37cadeeda98294801055231aa981 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 25 Aug 2021 00:50:07 +0000 Subject: [PATCH 120/264] Correct circular import --- platipy/imaging/cnn/dataset.py | 71 ++++++++++++++++++++++++++++++++-- platipy/imaging/cnn/utils.py | 52 ------------------------- 2 files changed, 67 insertions(+), 56 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index fe72db30..643c4879 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -1,3 +1,17 @@ +# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import re from pathlib import Path @@ -14,6 +28,57 @@ from platipy.imaging.cnn.utils import preprocess_image, resample_mask_to_image, get_contour_mask from platipy.imaging.label.utils import get_union_mask, get_intersection_mask +from platipy.imaging.cnn.localise_net import LocaliseUNet +from platipy.imaging.utils.crop import label_to_roi, crop_to_roi + + +def crop_img_using_localise_model( + img, localise_model, spacing=[1, 1, 1], crop_to_grid_size=[100, 100, 100] +): + """Crops an image using a LocaliseUNet + + Args: + img (SimpleITK.Image): The image to crop + localise_model (str|Path|LocaliseUNet): The LocaliseUNet or path to checkpoint of + LocaliseUNet. + spacing (list, optional): The image spacing (mm) to resample to. Defaults to [1,1,1]. 
+ crop_to_grid_size (list, optional): The size of the grid to crop to. Defaults to + [100,100,100]. + + Returns: + SimpleITK.Image: The cropped image. + """ + + if isinstance(localise_model, str): + localise_model = Path(localise_model) + + if isinstance(localise_model, Path): + if localise_model.is_dir(): + # Find the first actual model checkpoint in this directory + localise_model = next(localise_model.glob("*.ckpt")) + + localise_model = LocaliseUNet.load_from_checkpoint(localise_model) + + localise_model.eval() + localise_pred = localise_model.infer(img) + + img = preprocess_image(img, spacing=spacing, crop_to_grid_size_xy=None) + localise_pred = resample_mask_to_image(img, localise_pred) + size, index = label_to_roi(localise_pred) + + if not hasattr(crop_to_grid_size, "__iter__"): + crop_to_grid_size = (crop_to_grid_size,) * 3 + + index = [i - int((g - s) / 2) for i, s, g in zip(index, size, crop_to_grid_size)] + size = crop_to_grid_size + img_size = img.GetSize() + for i in range(3): + if index[i] + size[i] >= img_size[i]: + index[i] = img_size[i] - size[i] - 1 + if index[i] < 0: + index[i] = 0 + + return crop_to_roi(img, size, index) def prepare_transforms(): @@ -45,9 +110,7 @@ def prepare_transforms(): ], random_order=True, ), - sometimes(iaa.CoarseDropout( - (0.03, 0.15), size_percent=(0.02, 0.1) - )) + sometimes(iaa.CoarseDropout((0.03, 0.15), size_percent=(0.02, 0.1))), ], random_order=True, ) @@ -141,7 +204,7 @@ def __init__( img = sitk.ReadImage(img_path) if crop_using_localise_model: - crop_using_localise_model( + crop_img_using_localise_model( img, crop_using_localise_model, spacing=spacing, diff --git a/platipy/imaging/cnn/utils.py b/platipy/imaging/cnn/utils.py index 2b2bb9e5..c2d4231b 100644 --- a/platipy/imaging/cnn/utils.py +++ b/platipy/imaging/cnn/utils.py @@ -11,14 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path import numpy as np import SimpleITK as sitk -from platipy.imaging.cnn.localise_net import LocaliseUNet from platipy.imaging.label.utils import get_union_mask, get_intersection_mask -from platipy.imaging.utils.crop import label_to_roi, crop_to_roi from platipy.imaging.label.comparison import ( compute_metric_dsc, compute_metric_hd, @@ -132,55 +129,6 @@ def preprocess_image( return img -def crop_using_localise_model( - img, localise_model, spacing=[1, 1, 1], crop_to_grid_size=[100, 100, 100] -): - """Crops an image using a LocaliseUNet - - Args: - img (SimpleITK.Image): The image to crop - localise_model (str|Path|LocaliseUNet): The LocaliseUNet or path to checkpoint of - LocaliseUNet. - spacing (list, optional): The image spacing (mm) to resample to. Defaults to [1,1,1]. - crop_to_grid_size (list, optional): The size of the grid to crop to. Defaults to - [100,100,100]. - - Returns: - SimpleITK.Image: The cropped image. 
- """ - - if isinstance(localise_model, str): - localise_model = Path(localise_model) - - if isinstance(localise_model, Path): - if localise_model.is_dir(): - # Find the first actual model checkpoint in this directory - localise_model = next(localise_model.glob("*.ckpt")) - - localise_model = LocaliseUNet.load_from_checkpoint(localise_model) - - localise_model.eval() - localise_pred = localise_model.infer(img) - - img = preprocess_image(img, spacing=spacing, crop_to_grid_size_xy=None) - localise_pred = resample_mask_to_image(img, localise_pred) - size, index = label_to_roi(localise_pred) - - if not hasattr(crop_to_grid_size, "__iter__"): - crop_to_grid_size = (crop_to_grid_size,) * 3 - - index = [i - int((g - s) / 2) for i, s, g in zip(index, size, crop_to_grid_size)] - size = crop_to_grid_size - img_size = img.GetSize() - for i in range(3): - if index[i] + size[i] >= img_size[i]: - index[i] = img_size[i] - size[i] - 1 - if index[i] < 0: - index[i] = 0 - - return crop_to_roi(img, size, index) - - def resample_mask_to_image(img, mask): """Repsample a mask to the space of the image supplied. 
From aee092b75211abe6918a367e496c063c694ea584 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 25 Aug 2021 13:19:20 +1000 Subject: [PATCH 121/264] Able to run on gpu --- platipy/imaging/cnn/train.py | 27 ++++++++++++++++++--------- platipy/imaging/cnn/utils.py | 12 ++++++------ 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index f6df0b1b..0d1d8468 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -36,11 +36,11 @@ from platipy.imaging.cnn.prob_unet import ProbabilisticUnet from platipy.imaging.cnn.unet import l2_regularisation from platipy.imaging.cnn.dataload import UNetDataModule +from platipy.imaging.cnn.dataset import crop_img_using_localise_model from platipy.imaging.cnn.utils import ( preprocess_image, postprocess_mask, - get_metrics, - crop_using_localise_model, + get_metrics ) from platipy.imaging import ImageVisualiser @@ -143,14 +143,14 @@ def infer( ] * self.hparams.latent_dim if sample_strategy == "mean": - samples = [{"name": "mean", "std_dev_from_mean": [0.0] * len(latent_dim), "preds": []}] + samples = [{"name": "mean", "std_dev_from_mean": torch.Tensor([0.0] * len(latent_dim)).to(self.device), "preds": []}] elif sample_strategy == "random": samples = [ { "name": f"random_{i}", "std_dev_from_mean": torch.Tensor( [np.random.normal(0, 1.0, 1)[0] if d else 0.0 for d in latent_dim] - ), + ).to(self.device), "preds": [], } for i in range(num_samples) @@ -159,7 +159,7 @@ def infer( samples = [ { "name": f"spaced_{s}", - "std_dev_from_mean": torch.Tensor([s if d else 0.0 for d in latent_dim]), + "std_dev_from_mean": torch.Tensor([s if d else 0.0 for d in latent_dim]).to(self.device), "preds": [], } for s in np.linspace(spaced_range[0], spaced_range[1], num_samples) @@ -172,7 +172,7 @@ def infer( localise_path = self.hparams.crop_using_localise_model.format( fold=self.hparams.fold ) - img = crop_using_localise_model( + img = 
crop_img_using_localise_model( img, localise_path, spacing=self.hparams.spacing, @@ -188,9 +188,13 @@ def infer( ) img_arr = sitk.GetArrayFromImage(img) - for z in range(img_arr.shape[0]): + if self.hparams.ndims == 2: + slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])] + else: + slices = [img_arr] + for i in slices: - x = torch.Tensor(img_arr[z, :, :]) + x = torch.Tensor(i).to(self.device) x = x.unsqueeze(0) x = x.unsqueeze(0) self.prob_unet.forward(x) @@ -211,7 +215,12 @@ def infer( result = {} for sample in samples: - pred = sitk.GetImageFromArray(np.stack(sample["preds"])) + + pred_arr = sample["preds"][0] + if len(sample["preds"]) > 1: + pred_arr = np.stack(sample["preds"]) + + pred = sitk.GetImageFromArray(pred_arr) pred = sitk.Cast(pred, sitk.sitkUInt8) pred.CopyInformation(img) diff --git a/platipy/imaging/cnn/utils.py b/platipy/imaging/cnn/utils.py index c2d4231b..f2ed1dcd 100644 --- a/platipy/imaging/cnn/utils.py +++ b/platipy/imaging/cnn/utils.py @@ -183,17 +183,17 @@ def get_metrics(target, pred): result = {} result["DSC"] = compute_metric_dsc(target, pred) - target_pixels = sitk.GetArrayFromImage(target).sum() - pred_pixels = sitk.GetArrayFromImage(pred).sum() + target_pixels = sitk.GetArrayFromImage(target) + pred_pixels = sitk.GetArrayFromImage(pred) - if pred_pixels == 0 and target_pixels == 0: + if pred_pixels.max() == 0 and target_pixels.max() == 0: result["HD"] = 0 result["ASD"] = 0 - elif pred_pixels == 0 or target_pixels == 0: + elif pred_pixels.max() == 0 or target_pixels.max() == 0 or pred_pixels.min() == 1 or target_pixels.min() == 1: result["HD"] = 1000 result["ASD"] = 100 else: - result["HD"] = compute_metric_hd(target, pred) - result["ASD"] = compute_metric_masd(target, pred) + result["HD"] = compute_metric_hd(target, pred, auto_crop=False) + result["ASD"] = compute_metric_masd(target, pred, auto_crop=False) return result From c1cbf634d60604eab8d17e8d3dc2698f41b51433 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 26 
Aug 2021 09:18:42 +1000 Subject: [PATCH 122/264] Correct visualisation --- platipy/imaging/cnn/train.py | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 0d1d8468..625b4785 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -239,14 +239,21 @@ def validate(self, img, manual_observers, samples, mean, matching_type="best"): intersection_mask = get_intersection_mask(manual_observers) union_mask = get_union_mask(manual_observers) - vis = ImageVisualiser(img, cut=get_com(union_mask), window=[-200, 700]) + vis = ImageVisualiser(img, cut=get_com(union_mask), window=[-200, 700], figure_size_in=16) + + vis.add_contour( + mean, color=plt.cm.get_cmap(contour_cmap)(0.5), linewidth=3, show_legend=False + ) + vis.add_contour( + manual_observers, color=[0.13, 0.67, 0.275], linewidth=0.5, show_legend=False + ) + vis.add_contour( intersection_mask, name="intersection", color=[0.13, 0.67, 0.275], linewidth=3 ) vis.add_contour(union_mask, name="union", color=[0.13, 0.67, 0.275], linewidth=3) vis.add_contour( samples, - show_legend=False, linewidth=1.5, color={ s: c @@ -255,26 +262,11 @@ def validate(self, img, manual_observers, samples, mean, matching_type="best"): ) }, ) - vis.add_contour( - mean, color=plt.cm.get_cmap(contour_cmap)(0.5), linewidth=3, show_legend=False - ) - vis.add_contour( - manual_observers, color=[0.13, 0.67, 0.275], linewidth=0.5, show_legend=False - ) vis.set_limits_from_label(union_mask, expansion=30) fig = vis.show() - first_obs = manual_observers[list(manual_observers.keys())[0]] - for s in samples: - samples[s] = sitk.Resample( - samples[s], first_obs, sitk.Transform(), sitk.sitkNearestNeighbor - ) - mean["mean"] = sitk.Resample( - mean["mean"], first_obs, sitk.Transform(), sitk.sitkNearestNeighbor - ) - sim = {k: np.zeros((len(samples), len(manual_observers))) for k in metrics} msim = {k: 
np.zeros((len(samples), len(manual_observers))) for k in metrics} for sid, samp in enumerate(samples): @@ -457,6 +449,12 @@ def validation_epoch_end(self, validation_step_outputs): fig.savefig(figure_path, dpi=300) plt.close("all") + try: + self.logger.experiment.log_image(figure_path) + except AttributeError: + # Likely offline mode + pass + for t in result: for m in metrics: computed_metrics[f"{t}_{m}"]+=result[t][m] @@ -532,7 +530,7 @@ def main(args, config_json_path=None): checkpoint_callback = ModelCheckpoint( monitor="probnet_DSC", dirpath=args.default_root_dir, - filename="probunet-{epoch:02d}-{DSC:.2f}", + filename="probunet-{fold}-{epoch:02d}-{probnet_DSC:.2f}", save_top_k=1, mode="max", ) From dd1001599a94e97d3dbe604972b63a9595fb2a0e Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 27 Aug 2021 23:42:34 +0000 Subject: [PATCH 123/264] 3d augmentations --- platipy/imaging/cnn/dataload.py | 2 - platipy/imaging/cnn/dataset.py | 244 ++++++++++++++++++++++++++++++-- 2 files changed, 236 insertions(+), 10 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index 0c68c9e6..cab2349a 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -226,8 +226,6 @@ def setup(self, stage=None): crop_to_grid_size = self.crop_to_grid_size_xy augment_on_fly = self.augment_on_fly - if self.ndims == 3: - augment_on_fly = False self.training_set = NiftiDataset( train_data, diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 643c4879..b7c00eda 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -26,12 +26,215 @@ from loguru import logger +import math +import random +from scipy.ndimage import affine_transform +from scipy.ndimage.filters import gaussian_filter, median_filter + from platipy.imaging.cnn.utils import preprocess_image, resample_mask_to_image, get_contour_mask from platipy.imaging.label.utils import get_union_mask, get_intersection_mask 
from platipy.imaging.cnn.localise_net import LocaliseUNet from platipy.imaging.utils.crop import label_to_roi, crop_to_roi +class GaussianNoise: + def __init__(self, mu=0.0, sigma=0.0, probability=1.0): + + self.mu = mu + self.sigma = sigma + self.probability = probability + + if not hasattr(self.mu, "__iter__"): + self.mu = (self.mu,) * 2 + + if not hasattr(self.sigma, "__iter__"): + self.sigma = (self.sigma,) * 2 + + def apply(self, img, masks=[]): + + if random.random() > self.probability: + # Don't augment this time + return img, masks + + mean = random.uniform(self.mu[0], self.mu[1]) + sigma = random.uniform(self.sigma[0], self.sigma[1]) + + gaussian = np.random.normal(mean, sigma, img.shape) + return img + gaussian, masks + + +class GaussianBlur: + def __init__(self, sigma=0.0, probability=1.0): + + self.sigma = sigma + self.probability = probability + + if not hasattr(self.sigma, "__iter__"): + self.sigma = (self.sigma,) * 2 + + def apply(self, img, masks=[]): + + if random.random() > self.probability: + # Don't augment this time + return img, masks + + sigma = random.uniform(self.sigma[0], self.sigma[1]) + + return gaussian_filter(img, sigma=sigma), masks + + +class MedianBlur: + def __init__(self, size=1.0, probability=1.0): + + self.size = size + self.probability = probability + + if not hasattr(self.size, "__iter__"): + self.size = (self.size,) * 2 + + def apply(self, img, masks=[]): + + if random.random() > self.probability: + # Don't augment this time + return img, masks + + size = random.uniform(self.size[0], self.size[1]) + + return median_filter(img, size=size), masks + + +DIMS = ["ax", "cor", "sag"] + + +class Affine: + def __init__( + self, + scale={"ax": 1.0, "cor": 1.0, "sag": 1.0}, + translate_percent={"ax": 0.0, "cor": 0.0, "sag": 0.0}, + rotate={"ax": 0.0, "cor": 0.0, "sag": 0.0}, + shear={"ax": 0.0, "cor": 0.0, "sag": 0.0}, + mode="constant", + cval=-1, + probability=1.0, + ): + + self.scale = scale + self.translate_percent = 
translate_percent + self.rotate = rotate + self.shear = shear + self.probability = probability + + for d in self.rotate: + if not hasattr(self.rotate[d], "__iter__"): + self.rotate[d] = (self.rotate[d],) * 2 + + for d in self.scale: + if not hasattr(self.scale[d], "__iter__"): + self.scale[d] = (self.scale[d],) * 2 + + for d in self.translate_percent: + if not hasattr(self.translate_percent[d], "__iter__"): + self.translate_percent[d] = (self.translate_percent[d],) * 2 + + for d in self.shear: + if not hasattr(self.shear[d], "__iter__"): + self.shear[d] = (self.shear[d],) * 2 + + for d in self.scale: + if not hasattr(self.scale[d], "__iter__"): + self.scale[d] = (self.scale[d],) * 2 + + def get_rot(self, theta, d): + if d == "ax": + return np.matrix( + [ + [1, 0, 0, 0], + [0, math.cos(theta), -math.sin(theta), 0], + [0, math.sin(theta), math.cos(theta), 0], + [0, 0, 0, 1], + ] + ) + + if d == "cor": + return np.matrix( + [ + [math.cos(theta), 0, math.sin(theta), 0], + [0, 1, 0, 0], + [-math.sin(theta), 0, math.cos(theta), 0], + [0, 0, 0, 1], + ] + ) + + if d == "sag": + return np.matrix( + [ + [math.cos(theta), -math.sin(theta), 0, 0], + [math.sin(theta), math.cos(theta), 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ] + ) + + def get_shear(self, shear): + + mat = np.identity(4) + mat[0, 1] = shear[1] + mat[0, 2] = shear[2] + mat[1, 0] = shear[0] + mat[1, 2] = shear[2] + mat[2, 0] = shear[0] + mat[2, 1] = shear[1] + + return mat + + def apply(self, img, masks=[]): + + if random.random() > self.probability: + # Don't augment this time + return img, masks + + deg_to_rad = math.pi / 180 + + t_prerot = np.identity(4) + t_postrot = np.identity(4) + for i, d in enumerate(DIMS): + t_prerot[i, -1] = -img.shape[i] / 2 + t_postrot[i, -1] = img.shape[i] / 2 + + t = t_postrot + + for i, d in enumerate(DIMS): + t = t * self.get_rot( + random.uniform(self.rotate[d][0], self.rotate[d][1]) * deg_to_rad, d + ) + + for i, d in enumerate(DIMS): + scale = np.identity(4) + scale[i, i] = 1 / 
random.uniform(self.scale[d][0], self.scale[d][1]) + t = t * scale + + shear = [] + for i, d in enumerate(DIMS): + shear.append(random.uniform(self.shear[d][0], self.shear[d][1])) + + t = t * self.get_shear(shear) + + t = t * t_prerot + + for i, d in enumerate(DIMS): + trans = [p * img.shape[i] for p in self.translate_percent[d]] + translation = np.identity(4) + translation[i, -1] = random.uniform(trans[0], trans[1]) + t = t * translation + + augmented_image = affine_transform(img, t, mode="mirror") + augmented_masks = [] + for mask in masks: + augmented_masks.append(affine_transform(mask, t, mode="nearest")) + + return augmented_image, augmented_masks + + def crop_img_using_localise_model( img, localise_model, spacing=[1, 1, 1], crop_to_grid_size=[100, 100, 100] ): @@ -81,6 +284,21 @@ def crop_img_using_localise_model( return crop_to_roi(img, size, index) +def prepare_3d_transforms(): + affine_aug = Affine( + translate_percent={"ax": [-0.1, 0.1], "cor": [-0.1, 0.1], "sag": [-0.1, 0.1]}, + rotate={"ax": [-10.0, 10.0], "cor": [-10.0, 10.0], "sag": [-10.0, 10.0]}, + scale={"ax": [0.8, 1.2], "cor": [0.8, 1.2], "sag": [0.8, 1.2]}, + shear={"ax": [0.0, 0.2], "cor": [0.0, 0.2], "sag": [0.0, 0.2]}, + probability=0.5, + ) + gaussian_blur = GaussianBlur(sigma=[0.0, 1.0], probability=0.33) + median_blur = MedianBlur(size=[1, 3], probability=0.5) + gaussian_noise = GaussianNoise(sigma=[0, 0.2], probability=0.5) + + return [affine_aug, gaussian_blur, median_blur, gaussian_noise] + + def prepare_transforms(): sometimes = lambda aug: iaa.Sometimes(0.5, aug) @@ -147,7 +365,10 @@ def __init__( self.data = data self.transforms = None if augment_on_fly: - self.transforms = prepare_transforms() + if self.ndims == 2: + self.transforms = prepare_transforms() + else: + self.transforms = prepare_3d_transforms() self.slices = [] self.working_dir = Path(working_dir) self.ndims = ndims @@ -285,13 +506,20 @@ def __getitem__(self, index): contour_mask = 
np.load(self.slices[index]["contour_mask"]) if self.transforms: - seg_arr = np.concatenate( - (np.expand_dims(label, 2), np.expand_dims(contour_mask, 2)), 2 - ) - segmap = SegmentationMapsOnImage(seg_arr, shape=label.shape) - img, seg = self.transforms(image=img, segmentation_maps=segmap) - label = seg.get_arr()[:, :, 0].squeeze() - contour_mask = seg.get_arr()[:, :, 1].squeeze() + if self.ndims == 2: + seg_arr = np.concatenate( + (np.expand_dims(label, 2), np.expand_dims(contour_mask, 2)), 2 + ) + segmap = SegmentationMapsOnImage(seg_arr, shape=label.shape) + img, seg = self.transforms(image=img, segmentation_maps=segmap) + label = seg.get_arr()[:, :, 0].squeeze() + contour_mask = seg.get_arr()[:, :, 1].squeeze() + else: + masks = [label, contour_mask] + for aug in self.transforms: + img, masks = aug.apply(img, masks) + label = masks[0] + contour_mask = masks[1] img = torch.FloatTensor(img) label = torch.IntTensor(label) From e70b968200a048dd2b257afe5a058b2b7d5c9227 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 27 Aug 2021 23:45:46 +0000 Subject: [PATCH 124/264] correct bug --- platipy/imaging/cnn/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index b7c00eda..8e69c9ce 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -364,6 +364,7 @@ def __init__( self.data = data self.transforms = None + self.ndims = ndims if augment_on_fly: if self.ndims == 2: self.transforms = prepare_transforms() @@ -371,7 +372,6 @@ def __init__( self.transforms = prepare_3d_transforms() self.slices = [] self.working_dir = Path(working_dir) - self.ndims = ndims self.img_dir = working_dir.joinpath("img") self.label_dir = working_dir.joinpath("label") From 4363bc5e54685b82b30f47eb9c52a583f30b0ab2 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 27 Aug 2021 23:48:15 +0000 Subject: [PATCH 125/264] Correct median blur --- platipy/imaging/cnn/dataset.py | 2 
+- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 8e69c9ce..cda1b1fe 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -98,7 +98,7 @@ def apply(self, img, masks=[]): # Don't augment this time return img, masks - size = random.uniform(self.size[0], self.size[1]) + size = random.randint(self.size[0], self.size[1]) return median_filter(img, size=size), masks From abdc1221ac073498b0f0a357a52c7ce1a08581b3 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 1 Sep 2021 08:31:36 +1000 Subject: [PATCH 126/264] Adjust window level in validation --- platipy/imaging/cnn/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 625b4785..75c670d3 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -231,7 +231,7 @@ def infer( return result - def validate(self, img, manual_observers, samples, mean, matching_type="best"): + def validate(self, img, manual_observers, samples, mean, matching_type="best", window=[-0.5,1.0]): metrics = {"DSC": "max", "HD": "min", "ASD": "min"} @@ -239,7 +239,7 @@ def validate(self, img, manual_observers, samples, mean, matching_type="best"): intersection_mask = get_intersection_mask(manual_observers) union_mask = get_union_mask(manual_observers) - vis = ImageVisualiser(img, cut=get_com(union_mask), window=[-200, 700], figure_size_in=16) + vis = ImageVisualiser(img, cut=get_com(union_mask), figure_size_in=16, window=window) vis.add_contour( mean, color=plt.cm.get_cmap(contour_cmap)(0.5), linewidth=3, show_legend=False From 4a703fefe3efd141318bbe078096426bbf89e761 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 31 Aug 2021 23:22:04 +0000 Subject: [PATCH 127/264] Build contour loss into loss function --- platipy/imaging/cnn/prob_unet.py | 104 ++++++++++++++++++------------- platipy/imaging/cnn/train.py | 28 +++++---- 
2 files changed, 76 insertions(+), 56 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index c8d47626..8f5c3c39 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -23,8 +23,7 @@ class Encoder(torch.nn.Module): - """Encoder part of the probabilistic UNet - """ + """Encoder part of the probabilistic UNet""" def __init__( self, input_channels, filters_per_layer=[64 * (2 ** x) for x in range(5)], ndims=2 @@ -72,7 +71,7 @@ def __init__( out_channels=2 * self.latent_dim, kernel_size=1, stride=1, - ndims=ndims + ndims=ndims, ) self.ndims = ndims @@ -234,8 +233,9 @@ def __init__( self.unet_features = None if self.loss_type == "geco": - self._moving_avg = None - self.register_buffer("_lambda", torch.zeros(1, requires_grad=False)) + self._rec_moving_avg = None + self._contour_moving_avg = None + self.register_buffer("_lambda", torch.zeros(2, requires_grad=False)) def forward(self, img, seg=None, training=False): """ @@ -323,8 +323,6 @@ def reconstruction_loss( mask=None, top_k_percentage=None, deterministic=True, - weight_mask=None, - weight_mask_weighting=0.0, ): criterion = torch.nn.BCEWithLogitsLoss(reduction="none") @@ -382,16 +380,6 @@ def reconstruction_loss( mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) ) - if weight_mask is not None: - weight_mask = torch.reshape(weight_mask, (-1,)) - weight_mask = weight_mask.unsqueeze(1).repeat((1, num_classes)) - weight_mask = ( - weight_mask.reshape((batch_size, -1, num_classes)) - .transpose(-1, 1) - .reshape((batch_size, -1)) - ) - mask = mask + (weight_mask * weight_mask_weighting) - ce_sum_per_instance = torch.sum(mask * xe, axis=1) ce_sum = torch.mean(ce_sum_per_instance, axis=0) ce_mean = torch.sum(mask * xe) / torch.sum(mask) @@ -418,14 +406,14 @@ def loss( if "top_k_percentage" in self.loss_params: top_k_percentage = self.loss_params["top_k_percentage"] - loss_mask = None - if 
self.loss_params["contour_loss_lambda_threshold"]: - if ( - self._lambda # pylint: disable=access-member-before-definition - <= self.loss_params["contour_loss_lambda_threshold"] - and not use_max_lambda - ): - loss_mask = mask + # loss_mask = None + # if self.loss_params["contour_loss_lambda_threshold"]: + # if ( + # self._lambda # pylint: disable=access-member-before-definition + # <= self.loss_params["contour_loss_lambda_threshold"] + # and not use_max_lambda + # ): + # loss_mask = mask # loss_mask = loss_mask.unsqueeze(1).repeat((1, self.num_classes, 1, 1)) # Here we use the posterior sample sampled above @@ -434,9 +422,6 @@ def loss( reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior, top_k_percentage=top_k_percentage, - mask=loss_mask, - weight_mask=mask, - weight_mask_weighting=self.loss_params["contour_loss_weight"], ) if self.loss_type == "elbo": @@ -448,48 +433,77 @@ def loss( } elif self.loss_type == "geco": - num_pixels = mask.sum().item() - batch_size = segm.shape[0] - reconstruction_threshold = (self.loss_params["kappa"] * num_pixels) / batch_size reconstruction_threshold = self.loss_params["kappa"] + contour_threshold = None + if "kappa_contour" in self.loss_params: + contour_threshold = self.loss_params["kappa_contour"] + + contour_loss, contour_loss_mean, mask = self.reconstruction_loss( + segm, + reconstruct_posterior_mean=reconstruct_posterior_mean, + z_posterior=z_posterior, + top_k_percentage=top_k_percentage, + mask=mask, + ) + with torch.no_grad(): - rl = rec_loss_mean.detach() moving_avg_factor = 0.8 - if self._moving_avg is None: - self._moving_avg = rl + + rl = rec_loss_mean.detach() + if self._rec_moving_avg is None: + self._rec_moving_avg = rl else: - self._moving_avg = self._moving_avg * moving_avg_factor + rl * ( + self._rec_moving_avg = self._rec_moving_avg * moving_avg_factor + rl * ( 1 - moving_avg_factor ) - rc = self._moving_avg - reconstruction_threshold + rc = self._rec_moving_avg - 
reconstruction_threshold + + if contour_threshold: + cl = contour_loss_mean.detach() + if self._contour_moving_avg is None: + self._contour_moving_avg = rl + else: + self._contour_moving_avg = self._moving_avg * moving_avg_factor + cl * ( + 1 - moving_avg_factor + ) + + cc = self._contour_moving_avg - contour_threshold + lambda_lower = self.loss_params["clamp_rec"][0] lambda_upper = self.loss_params["clamp_rec"][1] if use_max_lambda: - self._lambda = torch.Tensor( # pylint: disable=attribute-defined-outside-init - [lambda_upper] - ).to(rc.device) + self._lambda[0] = lambda_upper + self._lambda[1] = lambda_upper else: self._lambda = ( # pylint: disable=attribute-defined-outside-init - torch.exp(rc) * self._lambda + torch.exp(torch.Tensor([rc, cc])) * self._lambda ).clamp(lambda_lower, lambda_upper) + # If not using the contour loss part, set lambda of that to zero + if contour_threshold is None: + self._lambda[1] = 0.0 + + # pylint: disable=access-member-before-definition loss = ( - self._lambda - * reconstruction_loss # pylint: disable=access-member-before-definition - + kl_div + (self._lambda[0] * reconstruction_loss) + (self._lambda[1] * contour_loss) + kl_div ) return { "loss": loss, "rec_loss": reconstruction_loss, + "contour_loss": contour_loss, "kl_div": kl_div, - "lambda": self._lambda, - "moving_avg": self._moving_avg, + "lambda_rec": self._lambda[0], + "lambda_contour": self._lambda[1], + "moving_avg_rec": self._rec_moving_avg, + "moving_avg_contour": self._contour_moving_avg, "reconstruction_threshold": reconstruction_threshold, + "contour_threshold": contour_threshold, "rec_constraint": rc, + "contour_constraint": cc, } else: diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 625b4785..35b0d6d3 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -37,11 +37,7 @@ from platipy.imaging.cnn.unet import l2_regularisation from platipy.imaging.cnn.dataload import UNetDataModule from 
platipy.imaging.cnn.dataset import crop_img_using_localise_model -from platipy.imaging.cnn.utils import ( - preprocess_image, - postprocess_mask, - get_metrics -) +from platipy.imaging.cnn.utils import preprocess_image, postprocess_mask, get_metrics from platipy.imaging import ImageVisualiser from platipy.imaging.label.utils import get_com, get_union_mask, get_intersection_mask @@ -67,6 +63,7 @@ def __init__( loss_params = { "kappa": self.hparams.kappa, "clamp_rec": self.hparams.clamp_rec, + "kappa_contour": self.hparams.kappa_contour, } loss_params["top_k_percentage"] = self.hparams.top_k_percentage @@ -103,10 +100,11 @@ def add_model_specific_args(parent_parser): parser.add_argument("--loss_type", type=str, default="elbo") parser.add_argument("--beta", type=float, default=1.0) parser.add_argument("--kappa", type=float, default=0.02) + parser.add_argument("--kappa_contour", type=float, default=None) parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) parser.add_argument("--top_k_percentage", type=float, default=None) parser.add_argument("--contour_loss_lambda_threshold", type=float, default=None) - parser.add_argument("--contour_loss_weight", type=float, default=0.0) + parser.add_argument("--contour_loss_weight", type=float, default=0.0) # no longer used parser.add_argument("--epochs_all_rec", type=int, default=0) return parent_parser @@ -143,7 +141,13 @@ def infer( ] * self.hparams.latent_dim if sample_strategy == "mean": - samples = [{"name": "mean", "std_dev_from_mean": torch.Tensor([0.0] * len(latent_dim)).to(self.device), "preds": []}] + samples = [ + { + "name": "mean", + "std_dev_from_mean": torch.Tensor([0.0] * len(latent_dim)).to(self.device), + "preds": [], + } + ] elif sample_strategy == "random": samples = [ { @@ -159,7 +163,9 @@ def infer( samples = [ { "name": f"spaced_{s}", - "std_dev_from_mean": torch.Tensor([s if d else 0.0 for d in latent_dim]).to(self.device), + "std_dev_from_mean": torch.Tensor([s if d else 0.0 for d 
in latent_dim]).to( + self.device + ), "preds": [], } for s in np.linspace(spaced_range[0], spaced_range[1], num_samples) @@ -189,9 +195,9 @@ def infer( img_arr = sitk.GetArrayFromImage(img) if self.hparams.ndims == 2: - slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])] + slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])] else: - slices = [img_arr] + slices = [img_arr] for i in slices: x = torch.Tensor(i).to(self.device) @@ -457,7 +463,7 @@ def validation_epoch_end(self, validation_step_outputs): for t in result: for m in metrics: - computed_metrics[f"{t}_{m}"]+=result[t][m] + computed_metrics[f"{t}_{m}"] += result[t][m] for cm in computed_metrics: self.log( From a09c1c9aa7d38402368d91b87564f8a48e12eee6 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 31 Aug 2021 23:45:08 +0000 Subject: [PATCH 128/264] Correct issue --- platipy/imaging/cnn/prob_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 8f5c3c39..745b2d5f 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -417,7 +417,7 @@ def loss( # loss_mask = loss_mask.unsqueeze(1).repeat((1, self.num_classes, 1, 1)) # Here we use the posterior sample sampled above - reconstruction_loss, rec_loss_mean, mask = self.reconstruction_loss( + reconstruction_loss, rec_loss_mean, _ = self.reconstruction_loss( segm, reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior, @@ -439,7 +439,7 @@ def loss( if "kappa_contour" in self.loss_params: contour_threshold = self.loss_params["kappa_contour"] - contour_loss, contour_loss_mean, mask = self.reconstruction_loss( + contour_loss, contour_loss_mean, _ = self.reconstruction_loss( segm, reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior, From 9b393961b2f42b7dd6b05b1d675b7f7a3da1e6cb Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 31 Aug 2021 23:47:11 +0000 Subject: 
[PATCH 129/264] Correct another issue --- platipy/imaging/cnn/prob_unet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 745b2d5f..f96c02f2 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -461,6 +461,7 @@ def loss( rc = self._rec_moving_avg - reconstruction_threshold + cc = 0 if contour_threshold: cl = contour_loss_mean.detach() if self._contour_moving_avg is None: From 6d3132d2b040c692d6f96cab0f6a19cf3ae45be7 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 31 Aug 2021 23:50:26 +0000 Subject: [PATCH 130/264] Ensure moved to gpu --- platipy/imaging/cnn/prob_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index f96c02f2..e0400f8d 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -480,7 +480,7 @@ def loss( self._lambda[1] = lambda_upper else: self._lambda = ( # pylint: disable=attribute-defined-outside-init - torch.exp(torch.Tensor([rc, cc])) * self._lambda + torch.exp(torch.Tensor([rc, cc]).to(rc.device)) * self._lambda ).clamp(lambda_lower, lambda_upper) # If not using the contour loss part, set lambda of that to zero From 586ca471fe5f38c701241c0f8e2133942a6b380d Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 31 Aug 2021 23:54:06 +0000 Subject: [PATCH 131/264] Ensure contour loss can be turned off --- platipy/imaging/cnn/prob_unet.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index e0400f8d..fa2f9b00 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -436,7 +436,10 @@ def loss( reconstruction_threshold = self.loss_params["kappa"] contour_threshold = None - if "kappa_contour" in self.loss_params: + if ( + "kappa_contour" in self.loss_params + and 
self.loss_params["kappa_contour"] is not None + ): contour_threshold = self.loss_params["kappa_contour"] contour_loss, contour_loss_mean, _ = self.reconstruction_loss( From 25108ffdadb70191b6e0acdf5e7fa6708fe20fa6 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 31 Aug 2021 23:58:18 +0000 Subject: [PATCH 132/264] Separate results for contour loss --- platipy/imaging/cnn/prob_unet.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index fa2f9b00..8b98ac94 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -486,29 +486,28 @@ def loss( torch.exp(torch.Tensor([rc, cc]).to(rc.device)) * self._lambda ).clamp(lambda_lower, lambda_upper) - # If not using the contour loss part, set lambda of that to zero - if contour_threshold is None: - self._lambda[1] = 0.0 - # pylint: disable=access-member-before-definition - loss = ( - (self._lambda[0] * reconstruction_loss) + (self._lambda[1] * contour_loss) + kl_div - ) + loss = (self._lambda[0] * reconstruction_loss) + kl_div - return { + result = { "loss": loss, "rec_loss": reconstruction_loss, - "contour_loss": contour_loss, "kl_div": kl_div, "lambda_rec": self._lambda[0], - "lambda_contour": self._lambda[1], - "moving_avg_rec": self._rec_moving_avg, - "moving_avg_contour": self._contour_moving_avg, + "moving_avg": self._rec_moving_avg, "reconstruction_threshold": reconstruction_threshold, - "contour_threshold": contour_threshold, "rec_constraint": rc, - "contour_constraint": cc, } + if contour_threshold is not None: + loss = loss + (self._lambda[1] * contour_loss) + result["contour_loss"] = contour_loss + result["contour_threshold"] = contour_threshold + result["contour_constraint"] = cc + result["moving_avg_contour"] = self._contour_moving_avg + result["lambda_contour"] = self._lambda[1] + + return result + else: raise NotImplementedError("Loss must be 'elbo' or 'geco'") 
From 94b7d07aee3664825aec36c4f3b5f6bc76a6bc24 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 1 Sep 2021 00:07:31 +0000 Subject: [PATCH 133/264] Correct issue --- platipy/imaging/cnn/prob_unet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 8b98ac94..dee3e693 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -470,8 +470,9 @@ def loss( if self._contour_moving_avg is None: self._contour_moving_avg = rl else: - self._contour_moving_avg = self._moving_avg * moving_avg_factor + cl * ( - 1 - moving_avg_factor + self._contour_moving_avg = ( + self._contour_moving_avg * moving_avg_factor + + cl * (1 - moving_avg_factor) ) cc = self._contour_moving_avg - contour_threshold From 744bba93bc15724505eb44415ce0b55799c2393e Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 1 Sep 2021 06:23:59 +0000 Subject: [PATCH 134/264] Able to compute loss in multiple masks --- platipy/imaging/cnn/prob_unet.py | 145 ++++++++++++++++++++----------- 1 file changed, 96 insertions(+), 49 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index dee3e693..e49cebb0 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -315,6 +315,52 @@ def topk_mask(self, score, k): torch.ones(score.shape[0]).to(score.device), ) + def prepare_mask( + self, + mask, + top_k_percentage, + deterministic, + num_classes, + device, + batch_size, + n_pixels_in_batch, + xe, + ): + if mask is None or mask.sum() == 0: + mask = torch.ones(n_pixels_in_batch) + else: + # assert ( + # mask.shape == segm.shape + # ), f"The loss mask shape differs from the target shape: {mask.shape} vs. {segm.shape}." 
+ mask = torch.reshape(mask, (-1,)) + mask = mask.to(device) + + if top_k_percentage is not None: + + assert 0.0 < top_k_percentage <= 1.0 + k_pixels = int(n_pixels_in_batch * top_k_percentage) + + with torch.no_grad(): + norm_xe = xe / torch.sum(xe) + if deterministic: + score = torch.log(norm_xe) + else: + # TODO Gumbel trick + raise NotImplementedError("Still need to implement Gumbel trick") + + score = score + torch.log(mask.unsqueeze(1).repeat((1, num_classes))) + + top_k_mask = self.topk_mask(score, k_pixels) + top_k_mask = top_k_mask.to(device) + mask = mask * top_k_mask + + mask = mask.unsqueeze(1).repeat((1, num_classes)) + mask = ( + mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + ) + + return mask + def reconstruction_loss( self, segm, @@ -343,41 +389,44 @@ def reconstruction_loss( y_flat = torch.transpose(reconstruction, 1, -1).reshape((-1, num_classes)) t_flat = torch.transpose(segm, 1, -1).reshape((-1, num_classes)) n_pixels_in_batch = y_flat.shape[0] - if mask is None or mask.sum() == 0: - mask = torch.ones(n_pixels_in_batch) - else: - # assert ( - # mask.shape == segm.shape - # ), f"The loss mask shape differs from the target shape: {mask.shape} vs. {segm.shape}." 
- mask = torch.reshape(mask, (-1,)) - mask = mask.to(y_flat.device) + batch_size = segm.shape[0] xe = criterion(input=y_flat, target=t_flat) + xe = xe.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) - if top_k_percentage is not None: - - assert 0.0 < top_k_percentage <= 1.0 - k_pixels = int(n_pixels_in_batch * top_k_percentage) - - with torch.no_grad(): - norm_xe = xe / torch.sum(xe) - if deterministic: - score = torch.log(norm_xe) - else: - # TODO Gumbel trick - raise NotImplementedError("Still need to implement Gumbel trick") - - score = score + torch.log(mask.unsqueeze(1).repeat((1, num_classes))) - - top_k_mask = self.topk_mask(score, k_pixels) - top_k_mask = top_k_mask.to(y_flat.device) - mask = mask * top_k_mask + # If multiple masks supplied, compute a loss for each mask + if hasattr(mask, "__iter__"): + ce_sums = [] + ce_means = [] + masks = [] + for this_mask in masks: + this_mask = self.prepare_mask( + this_mask, + top_k_percentage, + deterministic, + num_classes, + y_flat.device, + batch_size, + n_pixels_in_batch, + xe, + ) - batch_size = segm.shape[0] - xe = xe.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) - mask = mask.unsqueeze(1).repeat((1, num_classes)) - mask = ( - mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + ce_sum_per_instance = torch.sum(this_mask * xe, axis=1) + ce_sums.append(torch.mean(ce_sum_per_instance, axis=0)) + ce_means.append(torch.sum(this_mask * xe) / torch.sum(this_mask)) + masks.append(this_mask) + + return ce_sums, ce_means, masks + + mask = self.prepare_mask( + mask, + top_k_percentage, + deterministic, + num_classes, + y_flat.device, + batch_size, + n_pixels_in_batch, + xe, ) ce_sum_per_instance = torch.sum(mask * xe, axis=1) @@ -416,14 +465,29 @@ def loss( # loss_mask = mask # loss_mask = loss_mask.unsqueeze(1).repeat((1, self.num_classes, 1, 1)) + loss_mask = None + reconstruction_threshold = 
self.loss_params["kappa"] + contour_threshold = None + if "kappa_contour" in self.loss_params and self.loss_params["kappa_contour"] is not None: + loss_mask = [None, mask] + contour_threshold = self.loss_params["kappa_contour"] + # Here we use the posterior sample sampled above - reconstruction_loss, rec_loss_mean, _ = self.reconstruction_loss( + rec_loss, rec_loss_mean, _ = self.reconstruction_loss( segm, reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior, top_k_percentage=top_k_percentage, + mask=loss_mask, ) + # If using contour mask in loss, we get back those in a list. Unpack here. + if contour_threshold: + contour_loss = rec_loss[1] + contour_loss_mean = rec_loss_mean[1] + reconstruction_loss = rec_loss[0] + rec_loss_mean = rec_loss_mean[0] + if self.loss_type == "elbo": return { @@ -433,23 +497,6 @@ def loss( } elif self.loss_type == "geco": - reconstruction_threshold = self.loss_params["kappa"] - - contour_threshold = None - if ( - "kappa_contour" in self.loss_params - and self.loss_params["kappa_contour"] is not None - ): - contour_threshold = self.loss_params["kappa_contour"] - - contour_loss, contour_loss_mean, _ = self.reconstruction_loss( - segm, - reconstruct_posterior_mean=reconstruct_posterior_mean, - z_posterior=z_posterior, - top_k_percentage=top_k_percentage, - mask=mask, - ) - with torch.no_grad(): moving_avg_factor = 0.8 From e7078b9ff41a63a9edeb0772eb31cb47bca9f2df Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 1 Sep 2021 10:45:47 +0000 Subject: [PATCH 135/264] Correction to converting image --- platipy/imaging/cnn/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index cda1b1fe..f62b7ba8 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -425,7 +425,7 @@ def __init__( img = sitk.ReadImage(img_path) if crop_using_localise_model: - crop_img_using_localise_model( + img = 
crop_img_using_localise_model( img, crop_using_localise_model, spacing=spacing, From 94948cd9e6dc37c646c09fa9fda0827eca8c3deb Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 4 Sep 2021 18:26:13 +1000 Subject: [PATCH 136/264] Fix to prob unet --- platipy/imaging/cnn/prob_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index e49cebb0..088f4139 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -399,7 +399,7 @@ def reconstruction_loss( ce_sums = [] ce_means = [] masks = [] - for this_mask in masks: + for this_mask in mask: this_mask = self.prepare_mask( this_mask, top_k_percentage, From c4f280b8202c0f75bdc15e98572d4798ae8c6d46 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 6 Sep 2021 09:06:36 +1000 Subject: [PATCH 137/264] Update to fold path --- platipy/imaging/cnn/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 9bc58927..cdb18329 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -484,7 +484,9 @@ def main(args, config_json_path=None): args.working_dir = Path(args.working_dir) args.working_dir = args.working_dir.joinpath(args.experiment) - args.default_root_dir = str(args.working_dir) + #args.default_root_dir = str(args.working_dir) + args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") + args.default_root_dir = str(args.fold_dir) comet_api_key = None comet_workspace = None @@ -536,7 +538,7 @@ def main(args, config_json_path=None): checkpoint_callback = ModelCheckpoint( monitor="probnet_DSC", dirpath=args.default_root_dir, - filename="probunet-{fold}-{epoch:02d}-{probnet_DSC:.2f}", + filename="probunet-{epoch:02d}-{probnet_DSC:.2f}", save_top_k=1, mode="max", ) From 546cbb4bce0268c04e6c9808a090b5c301c76714 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 5 Sep 2021 23:25:43 +0000 
Subject: [PATCH 138/264] Use mean to compute rec loss --- platipy/imaging/cnn/dataload.py | 3 +-- platipy/imaging/cnn/prob_unet.py | 6 +++--- platipy/imaging/cnn/train.py | 6 ++++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index cab2349a..4b4d92d9 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -39,7 +39,7 @@ def __init__( contour_mask_kernel=3, crop_using_localise_model=None, localise_voxel_grid_size=[100, 100, 100], - validation_sampler="observer", # observer or batch + validation_sampler="observer", # observer or batch ndims=2, **kwargs, ): @@ -55,7 +55,6 @@ def __init__( self.augmented_image_glob = augmented_image_glob self.augmented_label_glob = augmented_label_glob - print(augment_on_fly) self.augment_on_fly = augment_on_fly self.fold = fold self.k_folds = k_folds diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 088f4139..83d66416 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -473,7 +473,7 @@ def loss( contour_threshold = self.loss_params["kappa_contour"] # Here we use the posterior sample sampled above - rec_loss, rec_loss_mean, _ = self.reconstruction_loss( + _, rec_loss_mean, _ = self.reconstruction_loss( segm, reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior, @@ -483,9 +483,9 @@ def loss( # If using contour mask in loss, we get back those in a list. Unpack here. 
if contour_threshold: - contour_loss = rec_loss[1] + contour_loss = rec_loss_mean[1] contour_loss_mean = rec_loss_mean[1] - reconstruction_loss = rec_loss[0] + reconstruction_loss = rec_loss_mean[0] rec_loss_mean = rec_loss_mean[0] if self.loss_type == "elbo": diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index cdb18329..bf17c907 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -237,7 +237,9 @@ def infer( return result - def validate(self, img, manual_observers, samples, mean, matching_type="best", window=[-0.5,1.0]): + def validate( + self, img, manual_observers, samples, mean, matching_type="best", window=[-0.5, 1.0] + ): metrics = {"DSC": "max", "HD": "min", "ASD": "min"} @@ -484,7 +486,7 @@ def main(args, config_json_path=None): args.working_dir = Path(args.working_dir) args.working_dir = args.working_dir.joinpath(args.experiment) - #args.default_root_dir = str(args.working_dir) + # args.default_root_dir = str(args.working_dir) args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") args.default_root_dir = str(args.fold_dir) From 5e3f2adeab057dab18187b41556ccbed9f8bd685 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 8 Sep 2021 07:52:31 +0000 Subject: [PATCH 139/264] Correction to include contour loss --- platipy/imaging/cnn/prob_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 83d66416..de16ffec 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -548,7 +548,7 @@ def loss( } if contour_threshold is not None: - loss = loss + (self._lambda[1] * contour_loss) + result["loss"] = result["loss"] + (self._lambda[1] * contour_loss) result["contour_loss"] = contour_loss result["contour_threshold"] = contour_threshold result["contour_constraint"] = cc From dfa6d46faf832fbbc6e7ec55e440baeda9f10359 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 15 Sep 2021 
04:02:38 +0000 Subject: [PATCH 140/264] Hierarchical probabilistic UNet work --- platipy/imaging/cnn/hierarchical_prob_unet.py | 311 +++++++++++++----- platipy/imaging/cnn/prob_unet.py | 57 +--- platipy/imaging/cnn/train.py | 49 ++- 3 files changed, 274 insertions(+), 143 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 40b7b186..26c974bc 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -22,6 +22,7 @@ from .unet import init_weights, init_zeros, conv_nd + class ResBlock(torch.nn.Module): """A residual block""" @@ -494,12 +495,13 @@ def __init__( self, input_channels=1, num_classes=2, - channels_per_block=None, + filters_per_layer=None, down_channels_per_block=None, latent_dims=(1, 1, 1, 1), convs_per_block=3, blocks_per_level=3, - loss_kwargs=None, + loss_type="elbo", + loss_params={"beta": 1}, ndims=2, ): """Initialize the Hierarchical Probabilistic UNet @@ -508,7 +510,7 @@ def __init__( input_channels (int, optional): The number of channels in the image (1 for greyscale and 3 for RGB). Defaults to 1. num_classes (int, optional): The number of classes to predict. Defaults to 2. - channels_per_block (list, optional): A list of channels to use in blocks of each + filters_per_layer (list, optional): A list of channels to use in blocks of each layer the amount of filters layer. Defaults to None. down_channels_per_block (list, optional): [description]. Defaults to None. 
@@ -525,7 +527,7 @@ def __init__( super(HierarchicalProbabilisticUnet, self).__init__() base_channels = 24 - default_channels_per_block = ( + default_filters_per_layer = ( base_channels, 2 * base_channels, 4 * base_channels, @@ -535,15 +537,15 @@ def __init__( 8 * base_channels, 8 * base_channels, ) - if channels_per_block is None: - channels_per_block = default_channels_per_block + if filters_per_layer is None: + filters_per_layer = default_filters_per_layer if down_channels_per_block is None: - down_channels_per_block = [int(i / 2) for i in channels_per_block] + down_channels_per_block = [int(i / 2) for i in filters_per_layer] self._prior = _HierarchicalCore( input_channels=input_channels, latent_dims=latent_dims, - channels_per_block=channels_per_block, + channels_per_block=filters_per_layer, down_channels_per_block=down_channels_per_block, convs_per_block=convs_per_block, blocks_per_level=blocks_per_level, @@ -553,7 +555,7 @@ def __init__( self._posterior = _HierarchicalCore( input_channels=input_channels + num_classes, latent_dims=latent_dims, - channels_per_block=channels_per_block, + channels_per_block=filters_per_layer, down_channels_per_block=down_channels_per_block, convs_per_block=convs_per_block, blocks_per_level=blocks_per_level, @@ -562,7 +564,7 @@ def __init__( self._f_comb = _StitchingDecoder( latent_dims=latent_dims, - channels_per_block=channels_per_block, + channels_per_block=filters_per_layer, num_classes=num_classes, down_channels_per_block=down_channels_per_block, convs_per_block=convs_per_block, @@ -572,22 +574,13 @@ def __init__( self._cache = None - if loss_kwargs is None: - self._loss_kwargs = { - "type": "elbo", - "kappa": 0.05, - "decay": 0.99, - "rate": 1e-2, - "beta": 1.0, - } - else: - self._loss_kwargs = loss_kwargs + self.loss_type = loss_type + self.loss_params = loss_params - if self._loss_kwargs["type"] == "geco": - # self._moving_average = ExponentialMovingAverage(decay=self._loss_kwargs["decay"]) - # self._geco_loss = 
GECOLoss(target_ratio, alpha=0.5) - self._ema = None - self.register_buffer("_multiplier", torch.zeros(1, requires_grad=False)) + if self.loss_type == "geco": + self._rec_moving_avg = None + self._contour_moving_avg = None + self.register_buffer("_lambda", torch.zeros(2, requires_grad=False)) self._q_sample = None self._q_sample_mean = None @@ -698,28 +691,131 @@ def kl(self, img, seg): return kl - def rec_loss(self, img, seg): - """Cross-entropy reconstruction loss employed in the ELBO-/ GECO-objective. + def topk_mask(self, score, k): + """Returns a mask for the top-k elements in score.""" - Args: - img (torch.Tensor): A tensor of shape (b, c, h, w). - seg (torch.Tensor): A tensor of shape (b, num_classes, h, w). + values, _ = torch.topk(score, 1, axis=0) + _, indices = torch.topk(values, k, axis=1) + return torch.scatter_add( + torch.zeros(score.shape[1]).to(score.device), + 0, + indices.reshape(-1), + torch.ones(score.shape[1]).to(score.device), + ) - Returns: - dict: A dictionary holding the mean and the pixelwise sum of the loss - """ - reconstruction = self.reconstruct(img, seg, mean=False) + def prepare_mask( + self, + mask, + top_k_percentage, + deterministic, + num_classes, + device, + batch_size, + n_pixels_in_batch, + xe, + ): + if mask is None or mask.sum() == 0: + mask = torch.ones(n_pixels_in_batch) + else: + # assert ( + # mask.shape == segm.shape + # ), f"The loss mask shape differs from the target shape: {mask.shape} vs. {segm.shape}." 
+ mask = torch.reshape(mask, (-1,)) + mask = mask.to(device) + mask = mask.repeat((1, num_classes)) + + if top_k_percentage is not None: + + assert 0.0 < top_k_percentage <= 1.0 + k_pixels = int(n_pixels_in_batch * top_k_percentage) + + with torch.no_grad(): + norm_xe = xe / torch.sum(xe) + if deterministic: + score = torch.log(norm_xe) + else: + # TODO Gumbel trick + raise NotImplementedError("Still need to implement Gumbel trick") + # score = score + torch.log(mask) + score = score + torch.log(mask) + + top_k_mask = self.topk_mask(score, k_pixels) + top_k_mask = top_k_mask.to(device) + mask = mask * top_k_mask + + mask = mask.repeat(batch_size, 1) + + return mask + + def reconstruction_loss( + self, + img, + segm, + mask=None, + top_k_percentage=None, + deterministic=True, + ): criterion = torch.nn.BCEWithLogitsLoss(reduction="none") - reconstruction_loss = criterion(input=reconstruction, target=seg) - reconstruction_loss_sum = torch.sum(reconstruction_loss) - reconstruction_loss_mean = torch.mean(reconstruction_loss) - mask = torch.ones(torch.numel(img)) + reconstruction = self.reconstruct(img, segm) + + # segm = torch.unsqueeze(segm, dim=1) + # not_seg = segm.logical_not() + # segm = torch.cat((not_seg, segm), dim=1).float() + + ##### + num_classes = reconstruction.shape[1] + y_flat = torch.transpose(reconstruction, 1, -1).reshape((-1, num_classes)) + t_flat = torch.transpose(segm, 1, -1).reshape((-1, num_classes)) + batch_size = segm.shape[0] + + xe = criterion(input=y_flat, target=t_flat) + xe = xe.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + n_pixels_in_batch = int(xe.shape[1] / num_classes) + + # If multiple masks supplied, compute a loss for each mask + if hasattr(mask, "__iter__"): + ce_sums = [] + ce_means = [] + masks = [] + for this_mask in mask: + this_mask = self.prepare_mask( + this_mask, + top_k_percentage, + deterministic, + num_classes, + y_flat.device, + batch_size, + n_pixels_in_batch, + xe, + ) + + 
ce_sum_per_instance = torch.sum(this_mask * xe, axis=1) + ce_sums.append(torch.mean(ce_sum_per_instance, axis=0)) + ce_means.append(torch.sum(this_mask * xe) / torch.sum(this_mask)) + masks.append(this_mask) + + return ce_sums, ce_means, masks + + mask = self.prepare_mask( + mask, + top_k_percentage, + deterministic, + num_classes, + y_flat.device, + batch_size, + n_pixels_in_batch, + xe, + ) - return {"mean": reconstruction_loss_mean, "sum": reconstruction_loss_sum, "mask": mask} + ce_sum_per_instance = torch.sum(mask * xe, axis=1) + ce_sum = torch.mean(ce_sum_per_instance, axis=0) + ce_mean = torch.sum(mask * xe) / torch.sum(mask) - def loss(self, img, seg): + return ce_sum, ce_mean, mask + + def loss(self, img, seg, mask=None): """The full training objective, either ELBO or GECO. Args: @@ -732,51 +828,104 @@ def loss(self, img, seg): Returns: dict: A dictionary holding the loss (with key 'loss') """ - summaries = {} - rec_loss = self.rec_loss(img, seg) - + kl_summaries = {} kl_dict = self.kl(img, seg) kl_sum = torch.sum(torch.stack([kl for _, kl in kl_dict.items()], axis=-1)) - - summaries["rec_loss_mean"] = rec_loss["mean"] - summaries["rec_loss_sum"] = rec_loss["sum"] - summaries["kl_sum"] = kl_sum for level, kl in kl_dict.items(): - summaries["kl_{}".format(level)] = kl + kl_summaries[f"kl_{level}"] = kl + + top_k_percentage = None + if "top_k_percentage" in self.loss_params: + top_k_percentage = self.loss_params["top_k_percentage"] + + loss_mask = None + reconstruction_threshold = self.loss_params["kappa"] + contour_threshold = None + if "kappa_contour" in self.loss_params and self.loss_params["kappa_contour"] is not None: + loss_mask = [None, mask] + contour_threshold = self.loss_params["kappa_contour"] + + # Here we use the posterior sample sampled above + _, rec_loss_mean, _ = self.reconstruction_loss( + img, + seg, + top_k_percentage=top_k_percentage, + mask=loss_mask, + ) - # Set up a regular ELBO objective. 
- if self._loss_kwargs["type"] == "elbo": - loss = rec_loss["sum"] + self._loss_kwargs["beta"] * kl_sum - summaries["elbo_loss"] = loss + # If using contour mask in loss, we get back those in a list. Unpack here. + if contour_threshold: + contour_loss = rec_loss_mean[1] + contour_loss_mean = rec_loss_mean[1] + reconstruction_loss = rec_loss_mean[0] + rec_loss_mean = rec_loss_mean[0] + else: + reconstruction_loss = rec_loss_mean - # Set up a GECO objective (ELBO with a reconstruction constraint). - elif self._loss_kwargs["type"] == "geco": + if self.loss_type == "elbo": - loss = rec_loss["sum"] * self._multiplier + self._loss_kwargs["beta"] * kl_sum - # ma_rec_loss = self._moving_average(rec_loss["sum"]) - if self._ema is None: - self._ema = rec_loss["sum"].detach().mean(0) - else: - # alpha = self._loss_kwargs["alpha"] - self._ema = (self._ema * 0.5 + rec_loss["sum"].detach() * (1 - 0.5)).mean(0) - - mask_sum_per_instance = torch.sum(rec_loss["mask"], -1) - num_valid_pixels = torch.mean(mask_sum_per_instance) - reconstruction_threshold = self._loss_kwargs["kappa"] * num_valid_pixels - rec_constraint = self._ema - reconstruction_threshold - speed = 1 - if rec_constraint > 0: - speed = 2 - self._multiplier = (speed * rec_constraint * self._multiplier).clamp(1e-5, 1e2) - # loss = rec_loss["sum"] * self._multiplier + self._loss_kwargs["beta"] * kl_sum - - summaries["geco_loss"] = loss - # summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels - summaries["num_valid_pixels"] = num_valid_pixels - summaries["lagmul"] = self._multiplier - else: - raise NotImplementedError( - "Loss type {} not implemeted!".format(self._loss_kwargs["type"]) - ) + return { + "loss": reconstruction_loss + self.loss_params["beta"] * kl_sum, + "rec_loss": reconstruction_loss, + "kl_div": kl_sum, + } + elif self.loss_type == "geco": + + with torch.no_grad(): + + moving_avg_factor = 0.8 + + rl = rec_loss_mean.detach() + if self._rec_moving_avg is None: + self._rec_moving_avg = rl + else: + 
self._rec_moving_avg = self._rec_moving_avg * moving_avg_factor + rl * ( + 1 - moving_avg_factor + ) - return dict(supervised_loss=loss, summaries=summaries) + rc = self._rec_moving_avg - reconstruction_threshold + + cc = 0 + if contour_threshold: + cl = contour_loss_mean.detach() + if self._contour_moving_avg is None: + self._contour_moving_avg = rl + else: + self._contour_moving_avg = ( + self._contour_moving_avg * moving_avg_factor + + cl * (1 - moving_avg_factor) + ) + + cc = self._contour_moving_avg - contour_threshold + + lambda_lower = self.loss_params["clamp_rec"][0] + lambda_upper = self.loss_params["clamp_rec"][1] + self._lambda = ( # pylint: disable=attribute-defined-outside-init + torch.exp(torch.Tensor([rc, cc]).to(rc.device)) * self._lambda + ).clamp(lambda_lower, lambda_upper) + + # pylint: disable=access-member-before-definition + loss = (self._lambda[0] * reconstruction_loss) + kl_sum + + result = { + "loss": loss, + "rec_loss": reconstruction_loss, + "kl_div": kl_sum, + "lambda_rec": self._lambda[0], + "moving_avg": self._rec_moving_avg, + "reconstruction_threshold": reconstruction_threshold, + "rec_constraint": rc, + } + + if contour_threshold is not None: + result["loss"] = result["loss"] + (self._lambda[1] * contour_loss) + result["contour_loss"] = contour_loss + result["contour_threshold"] = contour_threshold + result["contour_constraint"] = cc + result["moving_avg_contour"] = self._contour_moving_avg + result["lambda_contour"] = self._lambda[1] + + return result + + else: + raise NotImplementedError("Loss must be 'elbo' or 'geco'") diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index de16ffec..ae0173eb 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -288,19 +288,13 @@ def reconstruct(self, use_posterior_mean=False, z_posterior=None): z_posterior = self.posterior_latent_space.rsample() return self.fcomb.forward(self.unet_features, z_posterior) - def kl_divergence(self, 
analytic=True, z_posterior=None): + def kl_divergence(self): """ Calculate the KL divergence between the posterior and prior KL(Q||P) - analytic: calculate KL analytically or via sampling from the posterior """ - if analytic: - kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) - else: - if z_posterior is None: - z_posterior = self.posterior_latent_space.rsample() - log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior) - log_prior_prob = self.prior_latent_space.log_prob(z_posterior) - kl_div = log_posterior_prob - log_prior_prob + + kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) + return kl_div def topk_mask(self, score, k): @@ -364,7 +358,6 @@ def prepare_mask( def reconstruction_loss( self, segm, - reconstruct_posterior_mean=False, z_posterior=None, mask=None, top_k_percentage=None, @@ -376,13 +369,7 @@ def reconstruction_loss( if z_posterior is None: z_posterior = self.posterior_latent_space.rsample() - reconstruction = self.reconstruct( - use_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior - ) - - segm = torch.unsqueeze(segm, dim=1) - not_seg = segm.logical_not() - segm = torch.cat((not_seg, segm), dim=1).float() + reconstruction = self.reconstruct(use_posterior_mean=False, z_posterior=z_posterior) ##### num_classes = reconstruction.shape[1] @@ -435,36 +422,19 @@ def reconstruction_loss( return ce_sum, ce_mean, mask - def loss( - self, - segm, - analytic_kl=True, - reconstruct_posterior_mean=False, - mask=None, - use_max_lambda=False, - ): + def loss(self, segm, mask=None): """ Calculate the evidence lower bound of the log-likelihood of P(Y|X) """ z_posterior = self.posterior_latent_space.rsample() - kl_div = torch.mean(self.kl_divergence(analytic=analytic_kl, z_posterior=z_posterior)) + kl_div = torch.mean(self.kl_divergence()) top_k_percentage = None if "top_k_percentage" in self.loss_params: top_k_percentage = self.loss_params["top_k_percentage"] - # loss_mask = 
None - # if self.loss_params["contour_loss_lambda_threshold"]: - # if ( - # self._lambda # pylint: disable=access-member-before-definition - # <= self.loss_params["contour_loss_lambda_threshold"] - # and not use_max_lambda - # ): - # loss_mask = mask - # loss_mask = loss_mask.unsqueeze(1).repeat((1, self.num_classes, 1, 1)) - loss_mask = None reconstruction_threshold = self.loss_params["kappa"] contour_threshold = None @@ -475,7 +445,6 @@ def loss( # Here we use the posterior sample sampled above _, rec_loss_mean, _ = self.reconstruction_loss( segm, - reconstruct_posterior_mean=reconstruct_posterior_mean, z_posterior=z_posterior, top_k_percentage=top_k_percentage, mask=loss_mask, @@ -487,6 +456,8 @@ def loss( contour_loss_mean = rec_loss_mean[1] reconstruction_loss = rec_loss_mean[0] rec_loss_mean = rec_loss_mean[0] + else: + reconstruction_loss = rec_loss_mean if self.loss_type == "elbo": @@ -526,13 +497,9 @@ def loss( lambda_lower = self.loss_params["clamp_rec"][0] lambda_upper = self.loss_params["clamp_rec"][1] - if use_max_lambda: - self._lambda[0] = lambda_upper - self._lambda[1] = lambda_upper - else: - self._lambda = ( # pylint: disable=attribute-defined-outside-init - torch.exp(torch.Tensor([rc, cc]).to(rc.device)) * self._lambda - ).clamp(lambda_lower, lambda_upper) + self._lambda = ( # pylint: disable=attribute-defined-outside-init + torch.exp(torch.Tensor([rc, cc]).to(rc.device)) * self._lambda + ).clamp(lambda_lower, lambda_upper) # pylint: disable=access-member-before-definition loss = (self._lambda[0] * reconstruction_loss) + kl_div diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index bf17c907..58c63729 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -34,6 +34,7 @@ import matplotlib.pyplot as plt from platipy.imaging.cnn.prob_unet import ProbabilisticUnet +from platipy.imaging.cnn.hierarchical_prob_unet import HierarchicalProbabilisticUnet from platipy.imaging.cnn.unet import l2_regularisation 
from platipy.imaging.cnn.dataload import UNetDataModule from platipy.imaging.cnn.dataset import crop_img_using_localise_model @@ -70,16 +71,27 @@ def __init__( loss_params["contour_loss_lambda_threshold"] = self.hparams.contour_loss_lambda_threshold loss_params["contour_loss_weight"] = self.hparams.contour_loss_weight - self.prob_unet = ProbabilisticUnet( - self.hparams.input_channels, - self.hparams.num_classes, - self.hparams.filters_per_layer, - self.hparams.latent_dim, - self.hparams.no_convs_fcomb, - self.hparams.loss_type, - loss_params, - self.hparams.ndims, - ) + if self.hparams.prob_type == "prob": + self.prob_unet = ProbabilisticUnet( + self.hparams.input_channels, + self.hparams.num_classes, + self.hparams.filters_per_layer, + self.hparams.latent_dim, + self.hparams.no_convs_fcomb, + self.hparams.loss_type, + loss_params, + self.hparams.ndims, + ) + elif self.hparams.prob_type == "hierarchical": + self.prob_unet = HierarchicalProbabilisticUnet( + self.hparams.input_channels, + self.hparams.num_classes, + self.hparams.filters_per_layer, + self.hparams.down_channels_per_block, + [self.hparams.latent_dim] * (len(self.hparams.filters_per_layer) - 1), + loss_params, + self.hparams.ndims, + ) self.validation_directory = None @@ -88,13 +100,13 @@ def __init__( @staticmethod def add_model_specific_args(parent_parser): parser = parent_parser.add_argument_group("Probabilistic UNet") + parser.add_argument("--prob_type", type=str, default="prob") parser.add_argument("--learning_rate", type=float, default=1e-5) parser.add_argument("--lr_lambda", type=float, default=0.99) parser.add_argument("--input_channels", type=int, default=1) parser.add_argument("--num_classes", type=int, default=2) - parser.add_argument( - "--filters_per_layer", nargs="+", type=int, default=[64 * (2 ** x) for x in range(5)] - ) + parser.add_argument("--down_channels_per_block", nargs="+", type=int, default=None) + parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) 
parser.add_argument("--latent_dim", type=int, default=6) parser.add_argument("--no_convs_fcomb", type=int, default=4) parser.add_argument("--loss_type", type=str, default="elbo") @@ -105,7 +117,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--top_k_percentage", type=float, default=None) parser.add_argument("--contour_loss_lambda_threshold", type=float, default=None) parser.add_argument("--contour_loss_weight", type=float, default=0.0) # no longer used - parser.add_argument("--epochs_all_rec", type=int, default=0) + parser.add_argument("--epochs_all_rec", type=int, default=0) # no longer used return parent_parser @@ -321,11 +333,14 @@ def training_step(self, batch, _): x, y, m, _ = batch - self.prob_unet.forward(x, y, training=True) + # Add background layer for one-hot encoding + y = torch.unsqueeze(y, dim=1) + not_y = y.logical_not() + y = torch.cat((not_y, y), dim=1).float() - use_max_lambda = self.current_epoch < self.hparams.epochs_all_rec + self.prob_unet.forward(x, y, training=True) - loss = self.prob_unet.loss(y, analytic_kl=True, mask=m, use_max_lambda=use_max_lambda) + loss = self.prob_unet.loss(y, mask=m) reg_loss = ( l2_regularisation(self.prob_unet.posterior) + l2_regularisation(self.prob_unet.prior) From 38c7e2adc31ad4afc3a5f8c4b8cc30ddb55e7109 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 15 Sep 2021 14:07:33 +1000 Subject: [PATCH 141/264] format visualisation --- platipy/imaging/cnn/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index bf17c907..16a6bea9 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -162,7 +162,7 @@ def infer( elif sample_strategy == "spaced": samples = [ { - "name": f"spaced_{s}", + "name": f"spaced_{s:.2f}", "std_dev_from_mean": torch.Tensor([s if d else 0.0 for d in latent_dim]).to( self.device ), @@ -233,7 +233,7 @@ def infer( pred = postprocess_mask(pred) pred = 
sitk.Resample(pred, img, sitk.Transform(), sitk.sitkNearestNeighbor) - result[sample["name"]] = pred + result[sample['name']] = pred return result From 05ca2b446c3729d0a856e7891e068882c05eae89 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 15 Sep 2021 06:34:21 +0000 Subject: [PATCH 142/264] Update to hpunet test --- platipy/imaging/cnn/test_hpunet.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/platipy/imaging/cnn/test_hpunet.py b/platipy/imaging/cnn/test_hpunet.py index 7562b41d..928f96cd 100644 --- a/platipy/imaging/cnn/test_hpunet.py +++ b/platipy/imaging/cnn/test_hpunet.py @@ -36,15 +36,17 @@ labels = torch.cat([fg, bg], axis=1) hpunet = HierarchicalProbabilisticUnet( - channels_per_block=channels_per_block, + filters_per_layer=channels_per_block, latent_dims=[1], - loss_kwargs={ - "type": "geco", - "top_k_percentage": 0.02, + loss_type="geco", + loss_params={ + # "top_k_percentage": 0.02, + "top_k_percentage": None, "deterministic_top_k": False, "kappa": 0.05, "decay": 0.99, "rate": 1e-2, + "clamp_rec": [0.001, 10000], "beta": 5, }, ) @@ -77,7 +79,7 @@ def _get_placeholders(): def test_shape_of_sample(): hpu_net = HierarchicalProbabilisticUnet( latent_dims=_LATENT_DIMS, - channels_per_block=_CHANNELS_PER_BLOCK, + filters_per_layer=_CHANNELS_PER_BLOCK, num_classes=_NUM_CLASSES, ) img, _ = _get_placeholders() @@ -89,7 +91,7 @@ def test_shape_of_sample(): def test_shape_of_reconstruction(): hpu_net = HierarchicalProbabilisticUnet( latent_dims=_LATENT_DIMS, - channels_per_block=_CHANNELS_PER_BLOCK, + filters_per_layer=_CHANNELS_PER_BLOCK, num_classes=_NUM_CLASSES, ) img, seg = _get_placeholders() @@ -100,7 +102,7 @@ def test_shape_of_reconstruction(): def test_shapes_in_prior(): hpu_net = HierarchicalProbabilisticUnet( latent_dims=_LATENT_DIMS, - channels_per_block=_CHANNELS_PER_BLOCK, + filters_per_layer=_CHANNELS_PER_BLOCK, num_classes=_NUM_CLASSES, ) img, _ = _get_placeholders() @@ -147,7 +149,7 @@ def 
test_shapes_in_prior(): def test_shape_of_kl(): hpu_net = HierarchicalProbabilisticUnet( latent_dims=_LATENT_DIMS, - channels_per_block=_CHANNELS_PER_BLOCK, + filters_per_layer=_CHANNELS_PER_BLOCK, num_classes=_NUM_CLASSES, ) img, seg = _get_placeholders() From dd81b37d624040abe6681485d80705ed1f6d4d62 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 18 Sep 2021 15:44:30 +1000 Subject: [PATCH 143/264] Fix double argument --- platipy/imaging/cnn/train.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 728fff57..cbbc7e11 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -105,8 +105,11 @@ def add_model_specific_args(parent_parser): parser.add_argument("--lr_lambda", type=float, default=0.99) parser.add_argument("--input_channels", type=int, default=1) parser.add_argument("--num_classes", type=int, default=2) + parser.add_argument( + "--filters_per_layer", nargs="+", type=int, default=[64 * (2 ** x) for x in range(5)] + ) parser.add_argument("--down_channels_per_block", nargs="+", type=int, default=None) - parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) +# parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) parser.add_argument("--latent_dim", type=int, default=6) parser.add_argument("--no_convs_fcomb", type=int, default=4) parser.add_argument("--loss_type", type=str, default="elbo") @@ -185,6 +188,8 @@ def infer( with torch.no_grad(): + print(img.GetSize()) + if preprocess: if self.hparams.crop_using_localise_model: localise_path = self.hparams.crop_using_localise_model.format( @@ -205,6 +210,8 @@ def infer( intensity_window=self.hparams.intensity_window, ) + print(img.GetSize()) + img_arr = sitk.GetArrayFromImage(img) if self.hparams.ndims == 2: slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])] @@ -227,6 +234,7 @@ def infer( 
sample_x_stddev_from_mean=sample["std_dev_from_mean"], ) + print(y.shape) y = y.squeeze(0) y = np.argmax(y.cpu().detach().numpy(), axis=0) sample["preds"].append(y) @@ -237,7 +245,8 @@ def infer( pred_arr = sample["preds"][0] if len(sample["preds"]) > 1: pred_arr = np.stack(sample["preds"]) - + print(pred_arr.shape) + print(img.GetSize()) pred = sitk.GetImageFromArray(pred_arr) pred = sitk.Cast(pred, sitk.sitkUInt8) From e537466da920efe67087ad079837d9bf9448c8e8 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 18 Sep 2021 08:07:14 +0000 Subject: [PATCH 144/264] Working on hprobunet --- platipy/imaging/cnn/dataload.py | 8 +++- platipy/imaging/cnn/prob_unet.py | 16 ++++--- platipy/imaging/cnn/pseudo_generator.py | 27 +++++++---- platipy/imaging/cnn/train.py | 60 +++++++++++++++---------- 4 files changed, 70 insertions(+), 41 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index 4b4d92d9..ef5f89df 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -126,13 +126,17 @@ def setup(self, stage=None): cases.sort() random.shuffle(cases) # will be consistent for same value of 'seed everything' cases_per_fold = math.ceil(len(cases) / self.k_folds) + print(cases_per_fold) for f in range(self.k_folds): if self.fold == f: val_test_cases = cases[f * cases_per_fold : (f + 1) * cases_per_fold] - self.validation_cases = val_test_cases[: int(len(val_test_cases) / 2)] - self.test_cases = val_test_cases[int(len(val_test_cases) / 2) :] + if len(val_test_cases) == 1: + self.validation_cases = val_test_cases + else: + self.validation_cases = val_test_cases[: int(len(val_test_cases) / 2)] + self.test_cases = val_test_cases[int(len(val_test_cases) / 2) :] else: self.train_cases += cases[f * cases_per_fold : (f + 1) * cases_per_fold] diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index ae0173eb..9ce9a92e 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ 
b/platipy/imaging/cnn/prob_unet.py @@ -92,7 +92,7 @@ def forward(self, img, seg=None): x = img if seg is not None: - seg = torch.unsqueeze(seg, dim=1) + # seg = torch.unsqueeze(seg, dim=1) x = torch.cat((img, seg), dim=1) encoding = self.encoder(x) @@ -221,7 +221,7 @@ def __init__( input_channels, filters_per_layer, latent_dim, ndims=ndims ) self.posterior = AxisAlignedConvGaussian( - input_channels + 1, filters_per_layer, latent_dim, ndims=ndims + input_channels + num_classes, filters_per_layer, latent_dim, ndims=ndims ) self.fcomb = Fcomb(filters_per_layer, latent_dim, num_classes, no_convs_fcomb, ndims=ndims) @@ -436,11 +436,15 @@ def loss(self, segm, mask=None): top_k_percentage = self.loss_params["top_k_percentage"] loss_mask = None - reconstruction_threshold = self.loss_params["kappa"] contour_threshold = None - if "kappa_contour" in self.loss_params and self.loss_params["kappa_contour"] is not None: - loss_mask = [None, mask] - contour_threshold = self.loss_params["kappa_contour"] + if self.loss_type == "geco": + reconstruction_threshold = self.loss_params["kappa"] + if ( + "kappa_contour" in self.loss_params + and self.loss_params["kappa_contour"] is not None + ): + loss_mask = [None, mask] + contour_threshold = self.loss_params["kappa_contour"] # Here we use the posterior sample sampled above _, rec_loss_mean, _ = self.reconstruction_loss( diff --git a/platipy/imaging/cnn/pseudo_generator.py b/platipy/imaging/cnn/pseudo_generator.py index af493912..5312ed64 100644 --- a/platipy/imaging/cnn/pseudo_generator.py +++ b/platipy/imaging/cnn/pseudo_generator.py @@ -10,9 +10,14 @@ from platipy.imaging import ImageVisualiser -def generate_pseudo_data(data_dir="data"): +def generate_pseudo_data(data_dir="data", cases=5, size=(24, 32, 32)): test_data_directory = Path(data_dir) + + if test_data_directory.exists(): + print("Data directory already exists, won't regenerate") + return + image_directory = test_data_directory.joinpath("images") label_directory = 
test_data_directory.joinpath("labels") slice_directory = test_data_directory.joinpath("slices") @@ -21,13 +26,15 @@ def generate_pseudo_data(data_dir="data"): label_directory.mkdir(parents=True, exist_ok=True) slice_directory.mkdir(parents=True, exist_ok=True) - for case, sphere_rad in enumerate(range(10, 30)): + for case, sphere_rad in enumerate(range(5, 5 + cases)): - xpos = random.randint(50, 80) - ypos = random.randint(50, 80) + xpos = random.randint(6, 24) + ypos = random.randint(6, 24) - mask_arr = np.zeros((80, 128, 128)) - mask_arr = insert_sphere(mask_arr, sp_radius=sphere_rad, sp_centre=(30, ypos, xpos)) + mask_arr = np.zeros(size) + mask_arr = insert_sphere( + mask_arr, sp_radius=sphere_rad, sp_centre=(int(size[0] / 2), ypos, xpos) + ) mask = sitk.GetImageFromArray(mask_arr) mask = sitk.Cast(mask, sitk.sitkUInt8) @@ -43,14 +50,16 @@ def generate_pseudo_data(data_dir="data"): sitk.WriteImage(ct, str(image_directory.joinpath(f"{case}.nii.gz"))) - vis = ImageVisualiser(ct, cut=(30, ypos, xpos)) + vis = ImageVisualiser(ct, cut=(int(size[0] / 2), ypos, xpos)) masks = {} for obs_id, obs in enumerate(range(-4, 5, 2)): obs_rad = sphere_rad + obs - mask_arr = np.zeros((80, 128, 128)) - mask_arr = insert_sphere(mask_arr, sp_radius=obs_rad, sp_centre=(30, ypos, xpos)) + mask_arr = np.zeros(size) + mask_arr = insert_sphere( + mask_arr, sp_radius=obs_rad, sp_centre=(int(size[0] / 2), ypos, xpos) + ) mask = sitk.GetImageFromArray(mask_arr) mask.CopyInformation(ct) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index cbbc7e11..049f1cf6 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -84,13 +84,15 @@ def __init__( ) elif self.hparams.prob_type == "hierarchical": self.prob_unet = HierarchicalProbabilisticUnet( - self.hparams.input_channels, - self.hparams.num_classes, - self.hparams.filters_per_layer, - self.hparams.down_channels_per_block, - [self.hparams.latent_dim] * (len(self.hparams.filters_per_layer) - 1), - 
loss_params, - self.hparams.ndims, + input_channels=self.hparams.input_channels, + num_classes=self.hparams.num_classes, + filters_per_layer=self.hparams.filters_per_layer, + down_channels_per_block=self.hparams.down_channels_per_block, + latent_dims=[self.hparams.latent_dim] * (len(self.hparams.filters_per_layer) - 1), + convs_per_block=self.hparams.convs_per_block, + loss_type=self.hparams.loss_type, + loss_params=loss_params, + ndims=self.hparams.ndims, ) self.validation_directory = None @@ -109,9 +111,10 @@ def add_model_specific_args(parent_parser): "--filters_per_layer", nargs="+", type=int, default=[64 * (2 ** x) for x in range(5)] ) parser.add_argument("--down_channels_per_block", nargs="+", type=int, default=None) -# parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) + # parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) parser.add_argument("--latent_dim", type=int, default=6) parser.add_argument("--no_convs_fcomb", type=int, default=4) + parser.add_argument("--convs_per_block", type=int, default=3) parser.add_argument("--loss_type", type=str, default="elbo") parser.add_argument("--beta", type=float, default=1.0) parser.add_argument("--kappa", type=float, default=0.02) @@ -188,8 +191,6 @@ def infer( with torch.no_grad(): - print(img.GetSize()) - if preprocess: if self.hparams.crop_using_localise_model: localise_path = self.hparams.crop_using_localise_model.format( @@ -210,8 +211,6 @@ def infer( intensity_window=self.hparams.intensity_window, ) - print(img.GetSize()) - img_arr = sitk.GetArrayFromImage(img) if self.hparams.ndims == 2: slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])] @@ -222,19 +221,31 @@ def infer( x = torch.Tensor(i).to(self.device) x = x.unsqueeze(0) x = x.unsqueeze(0) - self.prob_unet.forward(x) + + if self.hparams.prob_type == "prob": + self.prob_unet.forward(x) for sample in samples: - if sample["name"] == "mean": - y = self.prob_unet.sample(testing=True, use_mean=True) 
+ + if self.hparams.prob_type == "prob": + if sample["name"] == "mean": + y = self.prob_unet.sample(testing=True, use_mean=True) + else: + y = self.prob_unet.sample( + testing=True, + use_mean=False, + sample_x_stddev_from_mean=sample["std_dev_from_mean"], + ) else: - y = self.prob_unet.sample( - testing=True, - use_mean=False, - sample_x_stddev_from_mean=sample["std_dev_from_mean"], - ) + if sample["name"] == "mean": + y = self.prob_unet.sample(x, mean=True) + else: + y = self.prob_unet.sample( + x, + mean=False, + std_devs_from_mean=sample["std_dev_from_mean"], + ) - print(y.shape) y = y.squeeze(0) y = np.argmax(y.cpu().detach().numpy(), axis=0) sample["preds"].append(y) @@ -243,10 +254,11 @@ def infer( for sample in samples: pred_arr = sample["preds"][0] + + if self.hparams.ndims == 2: + pred_arr = np.expand_dims(pred_arr, 0) if len(sample["preds"]) > 1: pred_arr = np.stack(sample["preds"]) - print(pred_arr.shape) - print(img.GetSize()) pred = sitk.GetImageFromArray(pred_arr) pred = sitk.Cast(pred, sitk.sitkUInt8) @@ -254,7 +266,7 @@ def infer( pred = postprocess_mask(pred) pred = sitk.Resample(pred, img, sitk.Transform(), sitk.sitkNearestNeighbor) - result[sample['name']] = pred + result[sample["name"]] = pred return result From fe4586924d66766cd44c84d0ecb9f2dd10382b9f Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 21 Sep 2021 16:27:11 +1000 Subject: [PATCH 145/264] Separate clamp for contour loss and rec loss --- platipy/imaging/cnn/prob_unet.py | 9 ++++++++- platipy/imaging/cnn/train.py | 3 ++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 9ce9a92e..fac4b8ff 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -430,6 +430,7 @@ def loss(self, segm, mask=None): z_posterior = self.posterior_latent_space.rsample() kl_div = torch.mean(self.kl_divergence()) + kl_div = torch.clamp(kl_div,0.0,100.0) top_k_percentage = None if 
"top_k_percentage" in self.loss_params: @@ -501,9 +502,15 @@ def loss(self, segm, mask=None): lambda_lower = self.loss_params["clamp_rec"][0] lambda_upper = self.loss_params["clamp_rec"][1] + lambda_lower_contour = self.loss_params["clamp_contour"][0] + lambda_upper_contour = self.loss_params["clamp_contour"][1] + self._lambda = ( # pylint: disable=attribute-defined-outside-init torch.exp(torch.Tensor([rc, cc]).to(rc.device)) * self._lambda - ).clamp(lambda_lower, lambda_upper) + ) + + self._lambda[0] = self._lambda[0].clamp(lambda_lower, lambda_upper) + self._lambda[1] = self._lambda[1].clamp(lambda_lower_contour, lambda_upper_contour) # pylint: disable=access-member-before-definition loss = (self._lambda[0] * reconstruction_loss) + kl_div diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 049f1cf6..e6729c1e 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -64,6 +64,7 @@ def __init__( loss_params = { "kappa": self.hparams.kappa, "clamp_rec": self.hparams.clamp_rec, + "clamp_contour": self.hparams.clamp_contour, "kappa_contour": self.hparams.kappa_contour, } @@ -111,7 +112,6 @@ def add_model_specific_args(parent_parser): "--filters_per_layer", nargs="+", type=int, default=[64 * (2 ** x) for x in range(5)] ) parser.add_argument("--down_channels_per_block", nargs="+", type=int, default=None) - # parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) parser.add_argument("--latent_dim", type=int, default=6) parser.add_argument("--no_convs_fcomb", type=int, default=4) parser.add_argument("--convs_per_block", type=int, default=3) @@ -120,6 +120,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--kappa", type=float, default=0.02) parser.add_argument("--kappa_contour", type=float, default=None) parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) + parser.add_argument("--clamp_contour", nargs="+", type=float, default=[1e-3, 1e3]) 
parser.add_argument("--top_k_percentage", type=float, default=None) parser.add_argument("--contour_loss_lambda_threshold", type=float, default=None) parser.add_argument("--contour_loss_weight", type=float, default=0.0) # no longer used From 26530a6e3a24c6fbd04c9800defc1e31439b45fb Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 21 Sep 2021 06:37:16 +0000 Subject: [PATCH 146/264] Tests for prob UNet --- platipy/imaging/cnn/hierarchical_prob_unet.py | 10 +- platipy/imaging/cnn/train.py | 6 +- platipy/imaging/tests/test_probunet.py | 187 ++++++++++++++++++ 3 files changed, 200 insertions(+), 3 deletions(-) create mode 100644 platipy/imaging/tests/test_probunet.py diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 26c974bc..b4d32578 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -656,7 +656,7 @@ def reconstruct(self, img, seg, mean=False): torch.Tensor: A segmentation tensor of shape (b,num_classes,h,w). 
""" - self.forward(img, seg) + # self.forward(img, seg) if mean: prior_out = self._p_sample_z_q_mean else: @@ -900,9 +900,15 @@ def loss(self, img, seg, mask=None): lambda_lower = self.loss_params["clamp_rec"][0] lambda_upper = self.loss_params["clamp_rec"][1] + lambda_lower_contour = self.loss_params["clamp_contour"][0] + lambda_upper_contour = self.loss_params["clamp_contour"][1] + self._lambda = ( # pylint: disable=attribute-defined-outside-init torch.exp(torch.Tensor([rc, cc]).to(rc.device)) * self._lambda - ).clamp(lambda_lower, lambda_upper) + ) + + self._lambda[0] = self._lambda[0].clamp(lambda_lower, lambda_upper) + self._lambda[1] = self._lambda[1].clamp(lambda_lower_contour, lambda_upper_contour) # pylint: disable=access-member-before-definition loss = (self._lambda[0] * reconstruction_loss) + kl_sum diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index e6729c1e..0221fb7e 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -360,7 +360,11 @@ def training_step(self, batch, _): not_y = y.logical_not() y = torch.cat((not_y, y), dim=1).float() - self.prob_unet.forward(x, y, training=True) + # self.prob_unet.forward(x, y, training=True) + if self.hparams.prob_type == "prob": + self.prob_unet.forward(x, y, training=True) + else: + self.prob_unet.forward(x, y) loss = self.prob_unet.loss(y, mask=m) reg_loss = ( diff --git a/platipy/imaging/tests/test_probunet.py b/platipy/imaging/tests/test_probunet.py new file mode 100644 index 00000000..d9379391 --- /dev/null +++ b/platipy/imaging/tests/test_probunet.py @@ -0,0 +1,187 @@ +# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=redefined-outer-name,missing-function-docstring + +from argparse import ArgumentParser + +import pytest + +import pytorch_lightning as pl +from platipy.imaging.cnn.train import main, ProbUNet, UNetDataModule +from platipy.imaging.cnn.pseudo_generator import generate_pseudo_data + + +@pytest.fixture +def trainer_arg_parser(): + + generate_pseudo_data() + + arg_parser = ArgumentParser() + arg_parser = ProbUNet.add_model_specific_args(arg_parser) + arg_parser = UNetDataModule.add_model_specific_args(arg_parser) + arg_parser = pl.Trainer.add_argparse_args(arg_parser) + arg_parser.add_argument( + "--config", type=str, default=None, help="JSON file with parameters to load" + ) + arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") + arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") + arg_parser.add_argument("--working_dir", type=str, default="./working") + arg_parser.add_argument("--num_observers", type=int, default=5) + arg_parser.add_argument("--spacing", nargs="+", type=float, default=[1, 1, 1]) + arg_parser.add_argument("--offline", type=bool, default=False) + arg_parser.add_argument("--comet_api_key", type=str, default=None) + arg_parser.add_argument("--comet_workspace", type=str, default=None) + arg_parser.add_argument("--comet_project", type=str, default=None) + arg_parser.add_argument("--resume_from", type=str, default=None) + return arg_parser + + +def test_prob_unet_2d_elbo(trainer_arg_parser): + + args = trainer_arg_parser.parse_args( + [ + 
"--working_dir", + "test_prob_unet_2d_elbo", + "--num_workers", + "1", + "--limit_train_batches", + "0.01", + "--loss_type", + "elbo", + "--prob_type", + "prob", + "--max_epochs", + "1", + "--ndims", + "2", + "--filters_per_layer", + "2", + "4", + ] + ) + + main(args) + + +def test_prob_unet_3d_elbo(trainer_arg_parser): + + args = trainer_arg_parser.parse_args( + [ + "--working_dir", + "test_prob_unet_3d_elbo", + "--num_workers", + "1", + "--limit_train_batches", + "0.05", + "--loss_type", + "elbo", + "--prob_type", + "prob", + "--max_epochs", + "1", + "--ndims", + "3", + "--filters_per_layer", + "2", + "4", + "--batch_size", + "1", + ] + ) + + main(args) + + +def test_prob_unet_2d_geco(trainer_arg_parser): + + args = trainer_arg_parser.parse_args( + [ + "--working_dir", + "test_prob_unet_2d_geco", + "--num_workers", + "1", + "--limit_train_batches", + "0.01", + "--loss_type", + "geco", + "--prob_type", + "prob", + "--max_epochs", + "1", + "--ndims", + "2", + "--filters_per_layer", + "2", + "4", + ] + ) + + main(args) + + +def test_prob_unet_3d_geco(trainer_arg_parser): + + args = trainer_arg_parser.parse_args( + [ + "--working_dir", + "test_prob_unet_3d_geco", + "--num_workers", + "1", + "--limit_train_batches", + "0.05", + "--loss_type", + "geco", + "--prob_type", + "prob", + "--max_epochs", + "1", + "--ndims", + "3", + "--filters_per_layer", + "2", + "4", + "--batch_size", + "1", + ] + ) + + main(args) + + +def test_hierarchical_prob_unet_2d_geco(trainer_arg_parser): + + args = trainer_arg_parser.parse_args( + [ + "--working_dir", + "test_prob_unet_2d_geco", + "--num_workers", + "1", + "--limit_train_batches", + "0.01", + "--loss_type", + "geco", + "--prob_type", + "hierarchical", + "--max_epochs", + "1", + "--ndims", + "2", + "--filters_per_layer", + "2", + "4", + ] + ) + + main(args) From 410ffd43d3867272f1997b2892e9b6b746e9a4a9 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 21 Sep 2021 06:37:27 +0000 Subject: [PATCH 147/264] Update to gitignore --- 
.gitignore | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 2f076351..b7918abc 100644 --- a/.gitignore +++ b/.gitignore @@ -145,4 +145,9 @@ converted/ **/nifti_output # Don't include html docs in repo -docs/site/ \ No newline at end of file +docs/site/ + +*.npy +*.nii.gz + +test_prob*/ \ No newline at end of file From 8369bc606a379c5782e76440060e59fa61afda00 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 22 Sep 2021 08:20:18 +1000 Subject: [PATCH 148/264] Correct call for hprob to loss --- platipy/imaging/cnn/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 0221fb7e..9d21b3b1 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -366,7 +366,10 @@ def training_step(self, batch, _): else: self.prob_unet.forward(x, y) - loss = self.prob_unet.loss(y, mask=m) + if self.hparams.prob_type == "prob": + loss = self.prob_unet.loss(y, mask=m) + else: + loss = self.prob_unet.loss(x, y, mask=m) reg_loss = ( l2_regularisation(self.prob_unet.posterior) + l2_regularisation(self.prob_unet.prior) From c3f38162553ef81bb72f46c07e105757a3d629fd Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 22 Sep 2021 00:45:38 +0000 Subject: [PATCH 149/264] Corrections to hprob net --- platipy/imaging/cnn/hierarchical_prob_unet.py | 56 +++++++++---------- platipy/imaging/cnn/prob_unet.py | 2 +- platipy/imaging/tests/test_probunet.py | 39 +++++++++++-- 3 files changed, 62 insertions(+), 35 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index b4d32578..978d21dd 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -429,7 +429,7 @@ def __init__( self._start_level = num_latents + 1 self._num_levels = len(self._channels_per_block) - self.decoder_layers = torch.nn.ModuleList() + 
self.layers = torch.nn.ModuleList() decoder_in_channels = None for level in range(self._start_level, self._num_levels, 1): @@ -451,8 +451,8 @@ def __init__( ) decoder_in_channels = channels_per_block[::-1][level] - self.decoder_layers.append(torch.nn.Sequential(*layer)) - self.decoder_layers.apply(init_weights) + self.layers.append(torch.nn.Sequential(*layer)) + self.layers.apply(init_weights) if decoder_in_channels is None: decoder_in_channels = channels_per_block[::-1][self._num_levels - 1] @@ -477,13 +477,13 @@ def forward(self, encoder_features, decoder_features): torch.Tensor: The stiched output """ - for level in range(len(self.decoder_layers)): + for level in range(len(self.layers)): enc_level = self._start_level + level decoder_features = resize_up(decoder_features, scale=2) decoder_features = torch.cat( [decoder_features, encoder_features[::-1][enc_level]], axis=1 ) - decoder_features = self.decoder_layers[level](decoder_features) + decoder_features = self.layers[level](decoder_features) return self.final_layer(decoder_features) @@ -542,7 +542,7 @@ def __init__( if down_channels_per_block is None: down_channels_per_block = [int(i / 2) for i in filters_per_layer] - self._prior = _HierarchicalCore( + self.prior = _HierarchicalCore( input_channels=input_channels, latent_dims=latent_dims, channels_per_block=filters_per_layer, @@ -552,7 +552,7 @@ def __init__( ndims=ndims, ) - self._posterior = _HierarchicalCore( + self.posterior = _HierarchicalCore( input_channels=input_channels + num_classes, latent_dims=latent_dims, channels_per_block=filters_per_layer, @@ -562,7 +562,7 @@ def __init__( ndims=ndims, ) - self._f_comb = _StitchingDecoder( + self.fcomb = _StitchingDecoder( latent_dims=latent_dims, channels_per_block=filters_per_layer, num_classes=num_classes, @@ -606,11 +606,11 @@ def forward(self, img, seg): # No need to recompute return - self._q_sample = self._posterior(input_tensor, mean=False) - self._q_sample_mean = self._posterior(input_tensor, 
mean=True) - self._p_sample = self._prior(img, mean=False, z_q=None) - self._p_sample_z_q = self._prior(img, z_q=self._q_sample["used_latents"]) - self._p_sample_z_q_mean = self._prior(img, z_q=self._q_sample_mean["used_latents"]) + self._q_sample = self.posterior(input_tensor, mean=False) + self._q_sample_mean = self.posterior(input_tensor, mean=True) + self._p_sample = self.prior(img, mean=False, z_q=None) + self._p_sample_z_q = self.prior(img, z_q=self._q_sample["used_latents"]) + self._p_sample_z_q_mean = self.prior(img, z_q=self._q_sample_mean["used_latents"]) self._cache = input_tensor def sample(self, img, mean=False, std_devs_from_mean=0.0, z_q=None): @@ -637,10 +637,10 @@ def sample(self, img, mean=False, std_devs_from_mean=0.0, z_q=None): torch.Tensor: A segmentation tensor of shape (b, num_classes, h, w). """ - prior_out = self._prior(img, mean, std_devs_from_mean, z_q) + prior_out = self.prior(img, mean, std_devs_from_mean, z_q) encoder_features = prior_out["encoder_features"] decoder_features = prior_out["decoder_features"] - return self._f_comb(encoder_features, decoder_features) + return self.fcomb(encoder_features, decoder_features) def reconstruct(self, img, seg, mean=False): """Reconstruct a segmentation using the posterior. @@ -663,7 +663,7 @@ def reconstruct(self, img, seg, mean=False): prior_out = self._p_sample_z_q encoder_features = prior_out["encoder_features"] decoder_features = prior_out["decoder_features"] - return self._f_comb(encoder_features, decoder_features) + return self.fcomb(encoder_features, decoder_features) def kl(self, img, seg): """Kullback-Leibler divergence between the posterior and the prior. 
@@ -694,13 +694,13 @@ def kl(self, img, seg): def topk_mask(self, score, k): """Returns a mask for the top-k elements in score.""" - values, _ = torch.topk(score, 1, axis=0) - _, indices = torch.topk(values, k, axis=1) + values, _ = torch.topk(score, 1, axis=1) + _, indices = torch.topk(values, k, axis=0) return torch.scatter_add( - torch.zeros(score.shape[1]).to(score.device), + torch.zeros(score.shape[0]).to(score.device), 0, indices.reshape(-1), - torch.ones(score.shape[1]).to(score.device), + torch.ones(score.shape[0]).to(score.device), ) def prepare_mask( @@ -722,7 +722,6 @@ def prepare_mask( # ), f"The loss mask shape differs from the target shape: {mask.shape} vs. {segm.shape}." mask = torch.reshape(mask, (-1,)) mask = mask.to(device) - mask = mask.repeat((1, num_classes)) if top_k_percentage is not None: @@ -736,14 +735,17 @@ def prepare_mask( else: # TODO Gumbel trick raise NotImplementedError("Still need to implement Gumbel trick") - # score = score + torch.log(mask) - score = score + torch.log(mask) + + score = score + torch.log(mask.unsqueeze(1).repeat((1, num_classes))) top_k_mask = self.topk_mask(score, k_pixels) top_k_mask = top_k_mask.to(device) mask = mask * top_k_mask - mask = mask.repeat(batch_size, 1) + mask = mask.unsqueeze(1).repeat((1, num_classes)) + mask = ( + mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + ) return mask @@ -760,19 +762,15 @@ def reconstruction_loss( reconstruction = self.reconstruct(img, segm) - # segm = torch.unsqueeze(segm, dim=1) - # not_seg = segm.logical_not() - # segm = torch.cat((not_seg, segm), dim=1).float() - ##### num_classes = reconstruction.shape[1] y_flat = torch.transpose(reconstruction, 1, -1).reshape((-1, num_classes)) t_flat = torch.transpose(segm, 1, -1).reshape((-1, num_classes)) + n_pixels_in_batch = y_flat.shape[0] batch_size = segm.shape[0] xe = criterion(input=y_flat, target=t_flat) xe = xe.reshape((batch_size, -1, num_classes)).transpose(-1, 
1).reshape((batch_size, -1)) - n_pixels_in_batch = int(xe.shape[1] / num_classes) # If multiple masks supplied, compute a loss for each mask if hasattr(mask, "__iter__"): diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index fac4b8ff..58ae0946 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -430,7 +430,7 @@ def loss(self, segm, mask=None): z_posterior = self.posterior_latent_space.rsample() kl_div = torch.mean(self.kl_divergence()) - kl_div = torch.clamp(kl_div,0.0,100.0) + kl_div = torch.clamp(kl_div, 0.0, 100.0) top_k_percentage = None if "top_k_percentage" in self.loss_params: diff --git a/platipy/imaging/tests/test_probunet.py b/platipy/imaging/tests/test_probunet.py index d9379391..e39b3996 100644 --- a/platipy/imaging/tests/test_probunet.py +++ b/platipy/imaging/tests/test_probunet.py @@ -57,7 +57,7 @@ def test_prob_unet_2d_elbo(trainer_arg_parser): "--num_workers", "1", "--limit_train_batches", - "0.01", + "0.025", "--loss_type", "elbo", "--prob_type", @@ -84,7 +84,7 @@ def test_prob_unet_3d_elbo(trainer_arg_parser): "--num_workers", "1", "--limit_train_batches", - "0.05", + "0.1", "--loss_type", "elbo", "--prob_type", @@ -113,7 +113,7 @@ def test_prob_unet_2d_geco(trainer_arg_parser): "--num_workers", "1", "--limit_train_batches", - "0.01", + "0.025", "--loss_type", "geco", "--prob_type", @@ -140,7 +140,7 @@ def test_prob_unet_3d_geco(trainer_arg_parser): "--num_workers", "1", "--limit_train_batches", - "0.05", + "0.1", "--loss_type", "geco", "--prob_type", @@ -169,7 +169,7 @@ def test_hierarchical_prob_unet_2d_geco(trainer_arg_parser): "--num_workers", "1", "--limit_train_batches", - "0.01", + "0.025", "--loss_type", "geco", "--prob_type", @@ -185,3 +185,32 @@ def test_hierarchical_prob_unet_2d_geco(trainer_arg_parser): ) main(args) + + +def test_hierarchical_prob_unet_2d_geco_contour(trainer_arg_parser): + + args = trainer_arg_parser.parse_args( + [ + "--working_dir", + 
"test_prob_unet_2d_geco", + "--num_workers", + "1", + "--limit_train_batches", + "0.025", + "--loss_type", + "geco", + "--prob_type", + "hierarchical", + "--max_epochs", + "1", + "--ndims", + "2", + "--filters_per_layer", + "2", + "4", + "--kappa_contour", + "0.01", + ] + ) + + main(args) From 40f30ffc7d865a248a6a1fb56e1943fec6a88613 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 22 Sep 2021 01:00:52 +0000 Subject: [PATCH 150/264] Correction to train --- platipy/imaging/cnn/train.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 9d21b3b1..65891f0c 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -370,15 +370,19 @@ def training_step(self, batch, _): loss = self.prob_unet.loss(y, mask=m) else: loss = self.prob_unet.loss(x, y, mask=m) - reg_loss = ( - l2_regularisation(self.prob_unet.posterior) - + l2_regularisation(self.prob_unet.prior) - + l2_regularisation(self.prob_unet.fcomb.layers) - ) - training_loss = loss["loss"] + 1e-5 * reg_loss + + training_loss = loss["loss"] + + if self.hparams.prob_type == "prob": + reg_loss = ( + l2_regularisation(self.prob_unet.posterior) + + l2_regularisation(self.prob_unet.prior) + + l2_regularisation(self.prob_unet.fcomb.layers) + ) + training_loss = training_loss + 1e-5 * reg_loss self.log( "training_loss", - training_loss, + training_loss.detach(), on_step=True, on_epoch=False, prog_bar=True, @@ -390,7 +394,7 @@ def training_step(self, batch, _): continue self.log( k, - loss[k], + loss[k].detach() if isinstance(loss[k], torch.Tensor) else loss[k], on_step=True, on_epoch=False, prog_bar=True, From 781bf4956c695babd0bbabd841e170eb47d4a664 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Oct 2021 09:34:58 +1100 Subject: [PATCH 151/264] Work on hierarchical probabilistic UNet --- platipy/imaging/cnn/hierarchical_prob_unet.py | 74 +++++++++++-------- platipy/imaging/cnn/train.py | 
49 ++++++++---- 2 files changed, 77 insertions(+), 46 deletions(-) diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py index 978d21dd..9679a049 100644 --- a/platipy/imaging/cnn/hierarchical_prob_unet.py +++ b/platipy/imaging/cnn/hierarchical_prob_unet.py @@ -32,7 +32,7 @@ def __init__( output_channels, n_down_channels=None, activation_fn=torch.nn.ReLU, - convs_per_block=3, + convs_per_block=2, ndims=2, ): """Create a residual block @@ -45,7 +45,7 @@ def __init__( activation_fn (torch.nn.Module, optional): The activation function to apply. Defaults to torch.nn.ReLU. convs_per_block (int, optional): The number of convolutions to perform within the - block. Defaults to 3. + block. Defaults to 2. ndims (int, optional): Specify whether to use 2 or 3 dimensions. Defaults to 2. """ @@ -86,7 +86,7 @@ def __init__( layers.append(resize_outgoing) self._layers = torch.nn.Sequential(*layers) - self._layers.apply(init_weights) +# self._layers.apply(init_weights) self._resize_skip = None @@ -98,7 +98,7 @@ def __init__( kernel_size=1, padding=0, ) - self._resize_skip.apply(init_weights) + # self._resize_skip.apply(init_weights) def forward(self, input_features): @@ -169,8 +169,8 @@ def __init__( channels_per_block, down_channels_per_block=None, activation_fn=torch.nn.ReLU, - convs_per_block=3, - blocks_per_level=3, + convs_per_block=2, + blocks_per_level=1, ndims=2, ): """Initializes a HierarchicalCore. @@ -190,9 +190,9 @@ def __init__( activation_fn (torch.nn.Module, optional): A callable activation function. Defaults to torch.nn.ReLU. convs_per_block (int, optional): An integer specifying the number of convolutional - layers. Defaults to 3. + layers. Defaults to 2. blocks_per_level (int, optional): An integer specifying the number of residual blocks - per level. Defaults to 3. + per level. Defaults to 1. ndims (int, optional): Specify whether to use 2 or 3 dimensions. Defaults to 2. 
""" @@ -233,7 +233,7 @@ def __init__( self.encoder_layers.append(torch.nn.Sequential(*layer)) - self.encoder_layers.apply(init_weights) + # self.encoder_layers.apply(init_weights) # Iterate the ascending levels in the (truncated) U-Net decoder. self.decoder_layers = torch.nn.ModuleList() @@ -272,8 +272,8 @@ def __init__( self.decoder_layers.append(torch.nn.Sequential(*layer)) - self._mu_logsigma_blocks.apply(init_zeros) - self.decoder_layers.apply(init_weights) + # self._mu_logsigma_blocks.apply(init_zeros) + # self.decoder_layers.apply(init_weights) def forward(self, inputs, mean=False, std_devs_from_mean=0.0, z_q=None): """Forward pass to sample from the module as specified. @@ -339,8 +339,8 @@ def forward(self, inputs, mean=False, std_devs_from_mean=0.0, z_q=None): latent_dim = self._latent_dims[level] mu_logsigma = self._mu_logsigma_blocks[level](decoder_features) - mu = mu_logsigma[:, :latent_dim] - log_sigma = mu_logsigma[:, latent_dim:] + mu = mu_logsigma[:, :latent_dim].clamp(-1000, 1000) + log_sigma = mu_logsigma[:, latent_dim:].clamp(-10, 10) dist = torch.distributions.Independent( torch.distributions.Normal(loc=mu, scale=torch.exp(log_sigma)), 1 @@ -386,8 +386,8 @@ def __init__( num_classes, down_channels_per_block=None, activation_fn=torch.nn.ReLU, - convs_per_block=3, - blocks_per_level=3, + convs_per_block=2, + blocks_per_level=1, ndims=2, ): """Initializes a StichtingDecoder. @@ -409,9 +409,9 @@ def __init__( initializers ([type], optional): [description]. Defaults to None. regularizers ([type], optional): [description]. Defaults to None. convs_per_block (int, optional): An integer specifying the number of convolutional - layers. Defaults to 3. + layers. Defaults to 2. blocks_per_level (int, optional): An integer specifying the number of residual blocks - per level. Defaults to 3. + per level. Defaults to 1. ndims (int, optional): Specify whether to use 2 or 3 dimensions. Defaults to 2. 
""" super(_StitchingDecoder, self).__init__() @@ -452,7 +452,7 @@ def __init__( decoder_in_channels = channels_per_block[::-1][level] self.layers.append(torch.nn.Sequential(*layer)) - self.layers.apply(init_weights) + # self.layers.apply(init_weights) if decoder_in_channels is None: decoder_in_channels = channels_per_block[::-1][self._num_levels - 1] @@ -464,7 +464,7 @@ def __init__( kernel_size=1, padding=0, ) - self.final_layer.apply(init_weights) + # self.final_layer.apply(init_weights) def forward(self, encoder_features, decoder_features): """Forward pass through the stiching decoder @@ -498,8 +498,8 @@ def __init__( filters_per_layer=None, down_channels_per_block=None, latent_dims=(1, 1, 1, 1), - convs_per_block=3, - blocks_per_level=3, + convs_per_block=2, + blocks_per_level=1, loss_type="elbo", loss_params={"beta": 1}, ndims=2, @@ -517,9 +517,9 @@ def __init__( latent_dims (tuple, optional): The number of latent dimensions at each layer. Defaults to (1, 1, 1, 1). convs_per_block (int, optional): An integer specifying the number of convolutional - layers. Defaults to 3. Defaults to 3. + layers. Defaults to 3. Defaults to 2. blocks_per_level (int, optional): An integer specifying the number of residual - blocks per level. Defaults to 3. + blocks per level. Defaults to 1. loss_kwargs (dict, optional): Dictionary of argument used by loss function. Defaults to None. ndims (int, optional): Specify whether to use 2 or 3 dimensions. Defaults to 2. 
@@ -571,6 +571,7 @@ def __init__( blocks_per_level=blocks_per_level, ndims=ndims, ) + self.ndims = ndims self._cache = None @@ -686,7 +687,14 @@ def kl(self, img, seg): kl = {} for level, (p, q) in enumerate(zip(p_dists, q_dists)): kl_per_pixel = torch.distributions.kl.kl_divergence(p, q) - kl_per_instance = torch.sum(kl_per_pixel, [1, 2]) + + if self.ndims == 2: + kl_per_instance = torch.sum(kl_per_pixel, [1, 2]) + else: + kl_per_instance = torch.sum(kl_per_pixel, [1, 2, 3]) + + kl_clamp = img.shape[2:].numel() * 10 + kl_per_instance = kl_per_instance.clamp(0, kl_clamp) kl[level] = torch.mean(kl_per_instance) return kl @@ -837,7 +845,8 @@ def loss(self, img, seg, mask=None): top_k_percentage = self.loss_params["top_k_percentage"] loss_mask = None - reconstruction_threshold = self.loss_params["kappa"] + if "kappa" in self.loss_params: + reconstruction_threshold = self.loss_params["kappa"] contour_threshold = None if "kappa_contour" in self.loss_params and self.loss_params["kappa_contour"] is not None: loss_mask = [None, mask] @@ -871,7 +880,7 @@ def loss(self, img, seg, mask=None): with torch.no_grad(): - moving_avg_factor = 0.8 + moving_avg_factor = 0.5 rl = rec_loss_mean.detach() if self._rec_moving_avg is None: @@ -901,12 +910,14 @@ def loss(self, img, seg, mask=None): lambda_lower_contour = self.loss_params["clamp_contour"][0] lambda_upper_contour = self.loss_params["clamp_contour"][1] - self._lambda = ( # pylint: disable=attribute-defined-outside-init - torch.exp(torch.Tensor([rc, cc]).to(rc.device)) * self._lambda - ) + self._lambda[0] = (torch.exp(rc) * self._lambda[0]).clamp(lambda_lower, lambda_upper) + if self._lambda[0].isnan(): self._lambda[0] = lambda_upper + if contour_threshold: + lambda_lower_contour = self.loss_params["clamp_contour"][0] + lambda_upper_contour = self.loss_params["clamp_contour"][1] - self._lambda[0] = self._lambda[0].clamp(lambda_lower, lambda_upper) - self._lambda[1] = self._lambda[1].clamp(lambda_lower_contour, 
lambda_upper_contour) + self._lambda[1] = (torch.exp(cc) * self._lambda[1]).clamp(lambda_lower_contour, lambda_upper_contour) + if self._lambda[1].isnan(): self._lambda[1] = lambda_upper_contour # pylint: disable=access-member-before-definition loss = (self._lambda[0] * reconstruction_loss) + kl_sum @@ -928,6 +939,7 @@ def loss(self, img, seg, mask=None): result["contour_constraint"] = cc result["moving_avg_contour"] = self._contour_moving_avg result["lambda_contour"] = self._lambda[1] + result = {**result, **kl_summaries} return result diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 65891f0c..50588c0b 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -91,14 +91,16 @@ def __init__( down_channels_per_block=self.hparams.down_channels_per_block, latent_dims=[self.hparams.latent_dim] * (len(self.hparams.filters_per_layer) - 1), convs_per_block=self.hparams.convs_per_block, + blocks_per_level=self.hparams.blocks_per_level, loss_type=self.hparams.loss_type, loss_params=loss_params, ndims=self.hparams.ndims, ) self.validation_directory = None + self.kl_div = None - self.stddevs = np.linspace(-2, 2, self.hparams.num_observers) + self.stddevs = np.linspace(-3, 3, self.hparams.num_observers) @staticmethod def add_model_specific_args(parent_parser): @@ -114,7 +116,8 @@ def add_model_specific_args(parent_parser): parser.add_argument("--down_channels_per_block", nargs="+", type=int, default=None) parser.add_argument("--latent_dim", type=int, default=6) parser.add_argument("--no_convs_fcomb", type=int, default=4) - parser.add_argument("--convs_per_block", type=int, default=3) + parser.add_argument("--convs_per_block", type=int, default=2) + parser.add_argument("--blocks_per_level", type=int, default=1) parser.add_argument("--loss_type", type=str, default="elbo") parser.add_argument("--beta", type=float, default=1.0) parser.add_argument("--kappa", type=float, default=0.02) @@ -179,6 +182,8 @@ def infer( for i in 
range(num_samples) ] elif sample_strategy == "spaced": + if self.hparams.prob_type == "hierarchical": + latent_dim = [True] * (len(self.hparams.filters_per_layer) - 1) samples = [ { "name": f"spaced_{s:.2f}", @@ -243,7 +248,7 @@ def infer( else: y = self.prob_unet.sample( x, - mean=False, + mean=True, std_devs_from_mean=sample["std_dev_from_mean"], ) @@ -389,6 +394,8 @@ def training_step(self, batch, _): logger=True, ) + self.kl_div = loss["kl_div"].detach().cpu() + for k in loss: if k == "loss": continue @@ -464,15 +471,18 @@ def validation_epoch_end(self, validation_step_outputs): img_arr = np.load(img_file) img = sitk.GetImageFromArray(img_arr) img.SetSpacing(self.hparams.spacing) - - mean = self.infer(img, num_samples=1, sample_strategy="mean", preprocess=False) - samples = self.infer( - img, - sample_strategy="spaced", - num_samples=5, - spaced_range=[-1.5, 1.5], - preprocess=False, - ) + try: + mean = self.infer(img, num_samples=1, sample_strategy="mean", preprocess=False) + samples = self.infer( + img, + sample_strategy="spaced", + num_samples=5, + spaced_range=[-1.5, 1.5], + preprocess=False, + ) + except Exception as e: + print(f"ERROR DURING VALIDATION INFERENCE: {e}") + return observers = {} for _, observer in enumerate(cases[case]["observers"]): @@ -499,7 +509,11 @@ def validation_epoch_end(self, validation_step_outputs): mask.CopyInformation(img) observers[f"manual_{observer}"] = mask - result, fig = self.validate(img, observers, samples, mean, matching_type="best") + try: + result, fig = self.validate(img, observers, samples, mean, matching_type="best") + except Exception as e: + print(f"ERROR DURING VALIDATION VALIDATE: {e}") + return figure_path = f"valid_{case}.png" fig.savefig(figure_path, dpi=300) @@ -515,6 +529,11 @@ def validation_epoch_end(self, validation_step_outputs): for m in metrics: computed_metrics[f"{t}_{m}"] += result[t][m] + if self.kl_div: + p = np.array(computed_metrics["probnet_DSC"]).mean() + u = 
np.array(computed_metrics["unet_DSC"]).mean() + computed_metrics["scaled_DSC"] = ((p + u) / 2) + (p-u) - self.kl_div + for cm in computed_metrics: self.log( cm, @@ -586,9 +605,9 @@ def main(args, config_json_path=None): # Save the best model checkpoint_callback = ModelCheckpoint( - monitor="probnet_DSC", + monitor="scaled_DSC", dirpath=args.default_root_dir, - filename="probunet-{epoch:02d}-{probnet_DSC:.2f}", + filename="probunet-{epoch:02d}-{scaled_DSC:.2f}", save_top_k=1, mode="max", ) From 840cf9c2b84837677c67637a2ee483309415fdcf Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 6 Oct 2021 23:42:19 +0000 Subject: [PATCH 152/264] Add dropout to prob unet --- platipy/imaging/cnn/unet.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/unet.py b/platipy/imaging/cnn/unet.py index f41adf93..378a35aa 100644 --- a/platipy/imaging/cnn/unet.py +++ b/platipy/imaging/cnn/unet.py @@ -43,6 +43,27 @@ def conv_nd(ndims=2, **kwargs): raise NotImplementedError("Only 2 or 3 dimensions are supported") +def dropout_nd(ndims=2, **kwargs): + """Get a 2D or 3D dropout layer + + Args: + ndims (int, optional): 2 or 3 dimensions. Defaults to 2. + + Raises: + NotImplementedError: Raised if ndims is not in 2 or 3 dimensions. 
+ + Returns: + torch.nn.Dropout: The dropout layer + """ + + if ndims == 2: + return torch.nn.Dropout2d(**kwargs) + elif ndims == 3: + return torch.nn.Dropout3d(**kwargs) + + raise NotImplementedError("Only 2 or 3 dimensions are supported") + + def init_weights(m): if ( isinstance(m, torch.nn.Conv2d) @@ -136,7 +157,9 @@ def resize_up_func(in_channels, out_channels, scale=2, ndims=2): class Conv(torch.nn.Module): - def __init__(self, input_channels, output_channels, up_down_sample=0, ndims=2): + def __init__( + self, input_channels, output_channels, up_down_sample=0, dropout_probability=0.2, ndims=2 + ): super(Conv, self).__init__() @@ -160,6 +183,7 @@ def __init__(self, input_channels, output_channels, up_down_sample=0, ndims=2): ) ) layers.append(nn.ReLU(inplace=True)) + layers.append(dropout_nd(ndims=ndims, p=dropout_probability)) layers.append( conv_nd( ndims=ndims, @@ -169,6 +193,7 @@ def __init__(self, input_channels, output_channels, up_down_sample=0, ndims=2): padding=1, ) ) + layers.append(dropout_nd(ndims=ndims, p=dropout_probability)) layers.append(nn.ReLU(inplace=True)) self.layers = nn.Sequential(*layers) From 507f6f70e9e6ea6baf9701eeabcd821341184fa0 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 13 Oct 2021 08:19:38 +1100 Subject: [PATCH 153/264] Add dropout --- platipy/imaging/cnn/prob_unet.py | 26 ++++++++++++++------------ platipy/imaging/cnn/unet.py | 14 ++++++++------ 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 58ae0946..e49ddf09 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -39,7 +39,7 @@ def __init__( down_sample = 0 if idx == 0 else -2 layers.append( - Conv(input_filters, output_filters, up_down_sample=down_sample, ndims=ndims) + Conv(input_filters, output_filters, up_down_sample=down_sample, ndims=ndims, dropout_probability=0.2) ) self.layers = torch.nn.Sequential(*layers) @@ -113,8 +113,9 @@ def 
forward(self, img, seg=None): if self.ndims == 3: mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) - mu = mu_log_sigma[:, : self.latent_dim] - log_sigma = mu_log_sigma[:, self.latent_dim :] + + mu = mu_log_sigma[:, :self.latent_dim].clamp(-1000, 1000) + log_sigma = mu_log_sigma[:, self.latent_dim:].clamp(-10, 10) # This is a multivariate normal with diagonal covariance matrix sigma # https://github.com/pytorch/pytorch/pull/11178 @@ -205,6 +206,7 @@ def __init__( loss_type="elbo", loss_params={"beta": 1}, ndims=2, + dropout_probability=0.2, ): super(ProbabilisticUnet, self).__init__() @@ -215,7 +217,7 @@ def __init__( self.z_prior_sample = 0 self.unet = UNet( - input_channels, num_classes, filters_per_layer, final_layer=False, ndims=ndims + input_channels, num_classes, filters_per_layer, final_layer=False, dropout_probability=dropout_probability, ndims=ndims ) self.prior = AxisAlignedConvGaussian( input_channels, filters_per_layer, latent_dim, ndims=ndims @@ -430,7 +432,7 @@ def loss(self, segm, mask=None): z_posterior = self.posterior_latent_space.rsample() kl_div = torch.mean(self.kl_divergence()) - kl_div = torch.clamp(kl_div, 0.0, 100.0) + #kl_div = torch.clamp(kl_div, 0.0, 100.0) top_k_percentage = None if "top_k_percentage" in self.loss_params: @@ -502,15 +504,15 @@ def loss(self, segm, mask=None): lambda_lower = self.loss_params["clamp_rec"][0] lambda_upper = self.loss_params["clamp_rec"][1] - lambda_lower_contour = self.loss_params["clamp_contour"][0] - lambda_upper_contour = self.loss_params["clamp_contour"][1] - self._lambda = ( # pylint: disable=attribute-defined-outside-init - torch.exp(torch.Tensor([rc, cc]).to(rc.device)) * self._lambda - ) + self._lambda[0] = (torch.exp(rc) * self._lambda[0]).clamp(lambda_lower, lambda_upper) + if self._lambda[0].isnan(): self._lambda[0] = lambda_upper + if contour_threshold: + lambda_lower_contour = self.loss_params["clamp_contour"][0] + lambda_upper_contour = self.loss_params["clamp_contour"][1] - self._lambda[0] = 
self._lambda[0].clamp(lambda_lower, lambda_upper) - self._lambda[1] = self._lambda[1].clamp(lambda_lower_contour, lambda_upper_contour) + self._lambda[1] = (torch.exp(cc) * self._lambda[1]).clamp(lambda_lower_contour, lambda_upper_contour) + if self._lambda[1].isnan(): self._lambda[1] = lambda_upper_contour # pylint: disable=access-member-before-definition loss = (self._lambda[0] * reconstruction_loss) + kl_div diff --git a/platipy/imaging/cnn/unet.py b/platipy/imaging/cnn/unet.py index 378a35aa..6b674c1e 100644 --- a/platipy/imaging/cnn/unet.py +++ b/platipy/imaging/cnn/unet.py @@ -74,7 +74,6 @@ def init_weights(m): torch.nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu") truncated_normal_(m.bias, mean=0, std=0.001) - def l2_regularisation(m): l2_reg = None @@ -158,7 +157,7 @@ def resize_up_func(in_channels, out_channels, scale=2, ndims=2): class Conv(torch.nn.Module): def __init__( - self, input_channels, output_channels, up_down_sample=0, dropout_probability=0.2, ndims=2 + self, input_channels, output_channels, up_down_sample=0, dropout_probability=None, ndims=2 ): super(Conv, self).__init__() @@ -183,7 +182,8 @@ def __init__( ) ) layers.append(nn.ReLU(inplace=True)) - layers.append(dropout_nd(ndims=ndims, p=dropout_probability)) + if dropout_probability: + layers.append(dropout_nd(ndims=ndims, p=dropout_probability)) layers.append( conv_nd( ndims=ndims, @@ -193,7 +193,8 @@ def __init__( padding=1, ) ) - layers.append(dropout_nd(ndims=ndims, p=dropout_probability)) + if dropout_probability: + layers.append(dropout_nd(ndims=ndims, p=dropout_probability)) layers.append(nn.ReLU(inplace=True)) self.layers = nn.Sequential(*layers) @@ -218,6 +219,7 @@ def __init__( filters_per_layer=[64 * (2 ** x) for x in range(5)], final_layer=True, ndims=2, + dropout_probability=None ): super(UNet, self).__init__() @@ -229,7 +231,7 @@ def __init__( down_sample = 0 if idx == 0 else -2 self.encoder.append( - Conv(input_filters, output_filters, 
up_down_sample=down_sample, ndims=ndims) + Conv(input_filters, output_filters, up_down_sample=down_sample, dropout_probability=dropout_probability, ndims=ndims) ) reversed_filters = list(reversed(filters_per_layer)) @@ -242,7 +244,7 @@ def __init__( input_filters = layer_filters output_filters = reversed_filters[idx + 1] - self.decoder.append(Conv(input_filters, output_filters, up_down_sample=2, ndims=ndims)) + self.decoder.append(Conv(input_filters, output_filters, up_down_sample=2, dropout_probability=dropout_probability, ndims=ndims)) self.final = None if final_layer: From 4bb8d865400da4ae652f42094709944af7598a8c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 9 Nov 2021 00:19:50 +0000 Subject: [PATCH 154/264] Working towards supporting multiple structures --- platipy/imaging/cnn/dataload.py | 87 +++++++++----- platipy/imaging/cnn/dataset.py | 152 +++++++++++++++--------- platipy/imaging/cnn/localise_net.py | 9 +- platipy/imaging/cnn/pseudo_generator.py | 52 +++++--- platipy/imaging/cnn/train.py | 38 +++--- 5 files changed, 224 insertions(+), 114 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index ef5f89df..82350b55 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -20,12 +20,14 @@ def __init__( data_dir: str = "./data", augmented_dir: str = None, working_dir: str = "./working", + structures=["a", "b", "c"], + observers=["0", "1", "2", "3", "4"], case_glob="images/*.nii.gz", image_glob="images/{case}.nii.gz", - label_glob="labels/{case}_*.nii.gz", + label_glob="labels/{case}_{structure}_*.nii.gz", augmented_case_glob="{case}/*", augmented_image_glob="images/{augmented_case}.nii.gz", - augmented_label_glob="labels/{augmented_case}_*.nii.gz", + augmented_label_glob="labels/{augmented_case}_{structure}_*.nii.gz", augment_on_fly=True, fold=0, k_folds=5, @@ -71,6 +73,8 @@ def __init__( self.intensity_scaling = intensity_scaling self.intensity_window = intensity_window 
self.contour_mask_kernel = contour_mask_kernel + self.structures = structures + self.observers = observers self.crop_using_localise_model = crop_using_localise_model self.localise_voxel_grid_size = localise_voxel_grid_size @@ -98,9 +102,13 @@ def add_model_specific_args(parent_parser): parser.add_argument("--k_folds", type=int, default=5) parser.add_argument("--batch_size", type=int, default=5) parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument("--structures", nargs="+", type=str, default=["a", "b", "c"]) + parser.add_argument("--observers", nargs="+", type=str, default=["0", "1", "2", "3", "4"]) parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") - parser.add_argument("--label_glob", type=str, default="labels/{case}_*.nii.gz") + parser.add_argument( + "--label_glob", type=str, default="labels/{case}_{structure}_{observer}.nii.gz" + ) parser.add_argument("--augmented_case_glob", type=str, default=None) parser.add_argument("--augmented_image_glob", type=str, default=None) parser.add_argument("--augmented_label_glob", type=str, default=None) @@ -126,7 +134,7 @@ def setup(self, stage=None): cases.sort() random.shuffle(cases) # will be consistent for same value of 'seed everything' cases_per_fold = math.ceil(len(cases) / self.k_folds) - print(cases_per_fold) + for f in range(self.k_folds): if self.fold == f: @@ -148,15 +156,23 @@ def setup(self, stage=None): { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), - "label": [ - p - for p in self.data_dir.glob(self.label_glob.format(case=case)) - if not "_OLD" in p.name - ], + "observers": { + observer: { + structure: self.data_dir.joinpath( + self.label_glob.format( + case=case, structure=structure, observer=observer + ) + ) + for structure in self.structures + } + for observer in self.observers + }, } for case in self.train_cases ] + print(train_data) + # If a 
directory with augmented data is specified, use that for training as well if self.augmented_dir is not None: @@ -177,15 +193,20 @@ def setup(self, stage=None): case=case, augmented_case=augmented_case ) ), - "label": [ - p - for p in case_aug_dir.glob( - self.augmented_label_glob.format( - case=case, augmented_case=augmented_case + "observers": { + observer: { + structure: case_aug_dir.joinpath( + self.augmented_label_glob.format( + case=case, + augmented_case=augmented_case, + structure=structure, + observer=observer + ) ) - ) - if not "_OLD" in p.name - ], + for structure in self.structures + } + for observer in self.observers + }, } for augmented_case in augmented_cases ] @@ -194,11 +215,17 @@ def setup(self, stage=None): { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), - "label": [ - p - for p in self.data_dir.glob(self.label_glob.format(case=case)) - if not "_OLD" in p.name - ], + "observers": { + observer: { + structure: self.data_dir.joinpath( + self.label_glob.format( + case=case, structure=structure, observer=observer + ) + ) + for structure in self.structures + } + for observer in self.observers + }, } for case in self.validation_cases ] @@ -207,11 +234,17 @@ def setup(self, stage=None): { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), - "label": [ - p - for p in self.data_dir.glob(self.label_glob.format(case=case)) - if not "_OLD" in p.name - ], + "observers": { + observer: { + structure: self.data_dir.joinpath( + self.label_glob.format( + case=case, structure=structure, observer=observer + ) + ) + for structure in self.structures + } + for observer in self.observers + }, } for case in self.test_cases ] diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index f62b7ba8..d2b0f250 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -385,10 +385,6 @@ def __init__( case_id = case["id"] img_path = str(case["image"]) - structure_paths = 
case["label"] - if isinstance(structure_paths, (str, Path)): - structure_paths = [structure_paths] - existing_images = [i for i in self.img_dir.glob(f"{case_id}_*.npy")] if len(existing_images) > 0: logger.debug(f"Image for case already exist: {case_id}") @@ -402,18 +398,29 @@ def __init__( img_file = self.img_dir.joinpath(f"{case_id}_{z_slice}.npy") assert img_file.exists() - contour_mask_file = self.contour_mask_dir.joinpath(f"{case_id}_{z_slice}.npy") - assert contour_mask_file.exists() + for obs in case["observers"]: + + labels = [] + contour_mask_files = [] + for structure in case["observers"][obs]: + label_file = self.label_dir.joinpath( + f"{case_id}_{structure}_{obs}_{z_slice}.npy" + ) + assert label_file.exists() + labels.append(label_file) + + contour_mask_file = self.contour_mask_dir.joinpath( + f"{case_id}_{structure}_{z_slice}.npy" + ) + assert contour_mask_file.exists() + contour_mask_files.append(contour_mask_file) - for obs in range(len(structure_paths)): - label_file = self.label_dir.joinpath(f"{case_id}_{obs}_{z_slice}.npy") - assert label_file.exists() self.slices.append( { "z": z_slice, "image": img_file, - "label": label_file, - "contour_mask": contour_mask_file, + "labels": labels, + "contour_masks": contour_mask_files, "case": case_id, "observer": obs, } @@ -440,20 +447,45 @@ def __init__( intensity_window=intensity_window, ) - observers = [] - for obs, structure_path in enumerate(structure_paths): - structure_path = str(structure_path) - label = sitk.ReadImage(structure_path) - label = resample_mask_to_image(img, label) - observers.append(label) - - contour_mask = get_contour_mask(observers, kernel=contour_mask_kernel) + observers = {} + structure_names = [] + for obs in case["observers"]: + observers[obs] = {} + for structure in case["observers"][obs]: + structure_names.append(structure) + structure_path = str(case["observers"][obs][structure]) + label = sitk.ReadImage(structure_path) + label = resample_mask_to_image(img, label) + 
observers[obs][structure] = label + + contour_masks = {} + for structure in structure_names: + contour_masks[structure] = get_contour_mask( + [observers[obs][structure] for obs in case["observers"]], + kernel=contour_mask_kernel, + ) - if combine_observers == "union": - observers = [get_union_mask(observers)] + if combine_observers: + updated_observers = {"": {}} + for structure in structure_names: + if combine_observers == "union": + updated_observers[""][structure] = [ + get_union_mask( + [observers[obs][structure] for obs in case["observers"]] + ) + ] + elif combine_observers == "intersection": + updated_observers[""][structure] = [ + get_intersection_mask( + [observers[obs][structure] for obs in case["observers"]] + ) + ] + else: + raise NotImplementedError( + "combine_observers should be 'union' or 'intersection'" + ) - if combine_observers == "intersection": - observers = [get_intersection_mask(observers)] + observers = updated_observers z_range = range(img.GetSize()[2]) if ndims == 3: @@ -470,27 +502,38 @@ def __init__( np.save(img_file, sitk.GetArrayFromImage(img_slice)) # Save the contour mask slice - if ndims == 2: - contour_mask_slice = contour_mask[:, :, z_slice] - else: - contour_mask_slice = contour_mask - contour_mask_file = self.contour_mask_dir.joinpath(f"{case_id}_{z_slice}.npy") - np.save(contour_mask_file, sitk.GetArrayFromImage(contour_mask_slice)) - - for obs, label in enumerate(observers): - + contour_masks = [] + for structure in structure_names: if ndims == 2: - label_slice = label[:, :, z_slice] + contour_mask_slice = contour_masks[structure][:, :, z_slice] else: - label_slice = label - label_file = self.label_dir.joinpath(f"{case_id}_{obs}_{z_slice}.npy") - np.save(label_file, sitk.GetArrayFromImage(label_slice).astype(np.int8)) + contour_mask_slice = contour_masks[structure] + contour_mask_file = self.contour_mask_dir.joinpath( + f"{case_id}_{structure}_{z_slice}.npy" + ) + np.save(contour_mask_file, 
sitk.GetArrayFromImage(contour_mask_slice)) + contour_masks.append() + + for obs in observers: + + labels = [] + for structure in structure_names: + if ndims == 2: + label_slice = observers[obs][structure][:, :, z_slice] + else: + label_slice = observers[obs][structure] + label_file = self.label_dir.joinpath( + f"{case_id}_{structure}_{obs}_{z_slice}.npy" + ) + np.save(label_file, sitk.GetArrayFromImage(label_slice).astype(np.int8)) + labels.append(label_file) + self.slices.append( { "z": z_slice, "image": img_file, - "label": label_file, - "contour_mask": contour_mask_file, + "labels": labels, + "contour_masks": contour_masks, "case": case_id, "observer": obs, } @@ -502,36 +545,39 @@ def __len__(self): def __getitem__(self, index): img = np.load(self.slices[index]["image"]) - label = np.load(self.slices[index]["label"]) - contour_mask = np.load(self.slices[index]["contour_mask"]) + labels = [np.load(label_file) for label_file in self.slices[index]["labels"]] + contour_masks = [ + np.load(contour_mask_file) for contour_mask_file in self.slices[index]["contour_masks"] + ] if self.transforms: + masks = labels + contour_masks if self.ndims == 2: - seg_arr = np.concatenate( - (np.expand_dims(label, 2), np.expand_dims(contour_mask, 2)), 2 - ) - segmap = SegmentationMapsOnImage(seg_arr, shape=label.shape) + seg_arr = np.concatenate([np.expand_dims(m, 2) for m in masks], 2) + segmap = SegmentationMapsOnImage(seg_arr, shape=labels[0].shape) img, seg = self.transforms(image=img, segmentation_maps=segmap) - label = seg.get_arr()[:, :, 0].squeeze() - contour_mask = seg.get_arr()[:, :, 1].squeeze() + for idx, _ in enumerate(labels): + labels[idx] = seg.get_arr()[:, :, idx].squeeze() + contour_masks = seg.get_arr()[:, :, int(len(contour_masks) / 2) :].squeeze() else: - masks = [label, contour_mask] for aug in self.transforms: img, masks = aug.apply(img, masks) - label = masks[0] - contour_mask = masks[1] + labels = masks[: int(len(labels) / 2)] + contour_masks = 
masks[int(len(contour_masks) / 2) :] img = torch.FloatTensor(img) - label = torch.IntTensor(label) - contour_mask = torch.FloatTensor(contour_mask) + label = torch.IntTensor(np.concatenate([np.expand_dims(l, 0) for l in labels], 0)) + contour_mask = torch.FloatTensor( + np.concatenate([np.expand_dims(l, 0) for l in contour_masks], 0) + ) return ( img.unsqueeze(0), label, contour_mask, { - "case": self.slices[index]["case"], - "observer": self.slices[index]["observer"], + "case": str(self.slices[index]["case"]), + "observer": str(self.slices[index]["observer"]), "z": self.slices[index]["z"], }, ) diff --git a/platipy/imaging/cnn/localise_net.py b/platipy/imaging/cnn/localise_net.py index 5812d357..6d153ca2 100644 --- a/platipy/imaging/cnn/localise_net.py +++ b/platipy/imaging/cnn/localise_net.py @@ -103,7 +103,11 @@ def training_step(self, batch, _): criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") + # Take the max of all structure to combine into one big structure to localise + y = y.max(axis=1).values y = torch.unsqueeze(y, dim=1) + + # Add a background for the localise UNet not_y = y.logical_not() y = torch.cat((not_y, y), dim=1).float() @@ -119,6 +123,7 @@ def validation_step(self, batch, _): with torch.set_grad_enabled(False): x, y, _, info = batch + y = y.max(axis=1).values for s in range(y.shape[0]): @@ -149,12 +154,12 @@ def validation_epoch_end(self, validation_step_outputs): for case, z, observer in zip(info["case"], info["z"], info["observer"]): if not case in cases: - cases[case] = {"slices": z.item(), "observers": [observer.item()]} + cases[case] = {"slices": z.item(), "observers": [observer]} else: if z.item() > cases[case]["slices"]: cases[case]["slices"] = z.item() if not observer in cases[case]["observers"]: - cases[case]["observers"].append(observer.item()) + cases[case]["observers"].append(observer) metrics = {"JI": [], "DSC": [], "HD": [], "ASD": []} for case in cases: diff --git a/platipy/imaging/cnn/pseudo_generator.py 
b/platipy/imaging/cnn/pseudo_generator.py index 5312ed64..0af3058a 100644 --- a/platipy/imaging/cnn/pseudo_generator.py +++ b/platipy/imaging/cnn/pseudo_generator.py @@ -10,7 +10,16 @@ from platipy.imaging import ImageVisualiser -def generate_pseudo_data(data_dir="data", cases=5, size=(24, 32, 32)): +def generate_pseudo_data(data_dir="data", cases=5, size=(24, 32, 32), structures=["a", "b", "c"]): + """Generates some Pseudo data to use for testing the CNN code + + Args: + data_dir (str, optional): Directory in which to store pseudo data. Defaults to "data". + cases (int, optional): Number of cases to generate data for. Defaults to 5. + size (tuple, optional): The size of the generated images. Defaults to (24, 32, 32). + structures (list, optional): A list of structure names to generate. Defaults to + ["a", "b", "c"]. + """ test_data_directory = Path(data_dir) @@ -53,20 +62,33 @@ def generate_pseudo_data(data_dir="data", cases=5, size=(24, 32, 32)): vis = ImageVisualiser(ct, cut=(int(size[0] / 2), ypos, xpos)) masks = {} - for obs_id, obs in enumerate(range(-4, 5, 2)): - obs_rad = sphere_rad + obs - - mask_arr = np.zeros(size) - mask_arr = insert_sphere( - mask_arr, sp_radius=obs_rad, sp_centre=(int(size[0] / 2), ypos, xpos) - ) - - mask = sitk.GetImageFromArray(mask_arr) - mask.CopyInformation(ct) - mask = sitk.Cast(mask, sitk.sitkUInt8) - sitk.WriteImage(mask, str(label_directory.joinpath(f"{case}_{obs_id}.nii.gz"))) - - masks[f"obs_{obs_id}_{obs_rad}"] = mask + for struct_id, structure in enumerate(structures): + + x_shift = y_shift = 0 + if struct_id > 0: + if struct_id % 2 == 0: + x_shift = struct_id + else: + y_shift = struct_id + + for obs_id, obs in enumerate(range(-4, 5, 2)): + obs_rad = sphere_rad + obs + + mask_arr = np.zeros(size) + mask_arr = insert_sphere( + mask_arr, + sp_radius=obs_rad, + sp_centre=(int(size[0] / 2), ypos + y_shift, xpos + x_shift), + ) + + mask = sitk.GetImageFromArray(mask_arr) + mask.CopyInformation(ct) + mask = sitk.Cast(mask, 
sitk.sitkUInt8) + sitk.WriteImage( + mask, str(label_directory.joinpath(f"{case}_{structure}_{obs_id}.nii.gz")) + ) + + masks[f"struct_{structure}_obs_{obs_id}_{obs_rad}"] = mask vis.add_contour(masks) vis.show() diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 50588c0b..f3b2ff78 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -253,7 +253,7 @@ def infer( ) y = y.squeeze(0) - y = np.argmax(y.cpu().detach().numpy(), axis=0) + # y = np.argmax(y.cpu().detach().numpy(), axis=0) sample["preds"].append(y) result = {} @@ -265,14 +265,17 @@ def infer( pred_arr = np.expand_dims(pred_arr, 0) if len(sample["preds"]) > 1: pred_arr = np.stack(sample["preds"]) - pred = sitk.GetImageFromArray(pred_arr) - pred = sitk.Cast(pred, sitk.sitkUInt8) - pred.CopyInformation(img) - pred = postprocess_mask(pred) - pred = sitk.Resample(pred, img, sitk.Transform(), sitk.sitkNearestNeighbor) + for idx, structure in enumerate(self.hparams.structures): + pred = sitk.GetImageFromArray(pred_arr[idx]) + pred = pred > 0.5 # Threshold softmax at 0.5 + pred = sitk.Cast(pred, sitk.sitkUInt8) - result[sample["name"]] = pred + pred.CopyInformation(img) + pred = postprocess_mask(pred) + pred = sitk.Resample(pred, img, sitk.Transform(), sitk.sitkNearestNeighbor) + + result[sample["name"]][structure] = pred return result @@ -416,7 +419,6 @@ def validation_step(self, batch, _): with torch.set_grad_enabled(False): x, y, _, info = batch - for s in range(y.shape[0]): img_file = self.validation_directory.joinpath( @@ -427,7 +429,7 @@ def validation_step(self, batch, _): mask_file = self.validation_directory.joinpath( f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" ) - np.save(mask_file, y[s].squeeze(0).cpu().numpy()) + np.save(mask_file, y[s].cpu().numpy()) return info @@ -439,12 +441,12 @@ def validation_epoch_end(self, validation_step_outputs): for case, z, observer in zip(info["case"], info["z"], info["observer"]): if not case in 
cases: - cases[case] = {"slices": z.item(), "observers": [observer.item()]} + cases[case] = {"slices": z.item(), "observers": [observer]} else: if z.item() > cases[case]["slices"]: cases[case]["slices"] = z.item() if not observer in cases[case]["observers"]: - cases[case]["observers"].append(observer.item()) + cases[case]["observers"].append(observer) metrics = ["DSC", "HD", "ASD"] computed_metrics = { @@ -496,7 +498,7 @@ def validation_epoch_end(self, validation_step_outputs): mask_arrs.append(np.load(mask_file)) - mask_arr = np.stack(mask_arrs) + mask_arr = np.stack(mask_arrs, axis=1) else: mask_file = self.validation_directory.joinpath( @@ -504,10 +506,12 @@ def validation_epoch_end(self, validation_step_outputs): ) mask_arr = np.load(mask_file) - mask = sitk.GetImageFromArray(mask_arr) - mask = sitk.Cast(mask, sitk.sitkUInt8) - mask.CopyInformation(img) - observers[f"manual_{observer}"] = mask + observers[f"manual_{observer}"] = {} + for idx, structure in enumerate(self.hparams.structures): + mask = sitk.GetImageFromArray(mask_arr[idx]) + mask = sitk.Cast(mask, sitk.sitkUInt8) + mask.CopyInformation(img) + observers[f"manual_{observer}"][structure] = mask try: result, fig = self.validate(img, observers, samples, mean, matching_type="best") @@ -532,7 +536,7 @@ def validation_epoch_end(self, validation_step_outputs): if self.kl_div: p = np.array(computed_metrics["probnet_DSC"]).mean() u = np.array(computed_metrics["unet_DSC"]).mean() - computed_metrics["scaled_DSC"] = ((p + u) / 2) + (p-u) - self.kl_div + computed_metrics["scaled_DSC"] = ((p + u) / 2) + (p - u) - self.kl_div for cm in computed_metrics: self.log( From 882ee8d626e98a46f9441d15d4bae432a65b2ee0 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 10 Nov 2021 08:24:46 +1100 Subject: [PATCH 155/264] Localise Net support multiple structures use case --- platipy/imaging/cnn/dataset.py | 36 +++++++++++++++++++---------- platipy/imaging/cnn/localise_net.py | 9 ++++---- platipy/imaging/cnn/utils.py | 2 
++ 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index d2b0f250..f65cdd03 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -406,8 +406,11 @@ def __init__( label_file = self.label_dir.joinpath( f"{case_id}_{structure}_{obs}_{z_slice}.npy" ) - assert label_file.exists() - labels.append(label_file) + if label_file.exists(): + labels.append(label_file) + else: + print(label_file) + labels.append(None) contour_mask_file = self.contour_mask_dir.joinpath( f"{case_id}_{structure}_{z_slice}.npy" @@ -453,15 +456,18 @@ def __init__( observers[obs] = {} for structure in case["observers"][obs]: structure_names.append(structure) - structure_path = str(case["observers"][obs][structure]) - label = sitk.ReadImage(structure_path) - label = resample_mask_to_image(img, label) + structure_path = case["observers"][obs][structure] + + label = None + if structure_path.exists(): + label = sitk.ReadImage(str(structure_path)) + label = resample_mask_to_image(img, label) observers[obs][structure] = label contour_masks = {} for structure in structure_names: contour_masks[structure] = get_contour_mask( - [observers[obs][structure] for obs in case["observers"]], + [observers[obs][structure] for obs in case["observers"] if observers[obs][structure] is not None], kernel=contour_mask_kernel, ) @@ -471,13 +477,13 @@ def __init__( if combine_observers == "union": updated_observers[""][structure] = [ get_union_mask( - [observers[obs][structure] for obs in case["observers"]] + [observers[obs][structure] for obs in case["observers"] if observers[obs][structure] is not None] ) ] elif combine_observers == "intersection": updated_observers[""][structure] = [ get_intersection_mask( - [observers[obs][structure] for obs in case["observers"]] + [observers[obs][structure] for obs in case["observers"] if observers[obs][structure] is not None] ) ] else: @@ -502,7 +508,7 @@ def __init__( 
np.save(img_file, sitk.GetArrayFromImage(img_slice)) # Save the contour mask slice - contour_masks = [] + cmasks = [] for structure in structure_names: if ndims == 2: contour_mask_slice = contour_masks[structure][:, :, z_slice] @@ -512,12 +518,16 @@ def __init__( f"{case_id}_{structure}_{z_slice}.npy" ) np.save(contour_mask_file, sitk.GetArrayFromImage(contour_mask_slice)) - contour_masks.append() + cmasks.append(contour_mask_file) for obs in observers: labels = [] for structure in structure_names: + + if observers[obs][structure] is None: + labels.append(None) + continue if ndims == 2: label_slice = observers[obs][structure][:, :, z_slice] else: @@ -533,7 +543,7 @@ def __init__( "z": z_slice, "image": img_file, "labels": labels, - "contour_masks": contour_masks, + "contour_masks": cmasks, "case": case_id, "observer": obs, } @@ -545,7 +555,7 @@ def __len__(self): def __getitem__(self, index): img = np.load(self.slices[index]["image"]) - labels = [np.load(label_file) for label_file in self.slices[index]["labels"]] + labels = [np.load(label_file) if label_file else np.zeros(img.shape, dtype=np.ushort) for label_file in self.slices[index]["labels"]] contour_masks = [ np.load(contour_mask_file) for contour_mask_file in self.slices[index]["contour_masks"] ] @@ -570,6 +580,7 @@ def __getitem__(self, index): contour_mask = torch.FloatTensor( np.concatenate([np.expand_dims(l, 0) for l in contour_masks], 0) ) + label_present = [label is not None for label in self.slices[index]["labels"]] return ( img.unsqueeze(0), @@ -578,6 +589,7 @@ def __getitem__(self, index): { "case": str(self.slices[index]["case"]), "observer": str(self.slices[index]["observer"]), + "label_present": label_present, "z": self.slices[index]["z"], }, ) diff --git a/platipy/imaging/cnn/localise_net.py b/platipy/imaging/cnn/localise_net.py index 6d153ca2..348f6f1e 100644 --- a/platipy/imaging/cnn/localise_net.py +++ b/platipy/imaging/cnn/localise_net.py @@ -43,7 +43,7 @@ def __init__( self.unet = UNet( 
self.hparams.input_channels, - self.hparams.num_classes, + 2, # num_classes is always 2 for the localise net, just separating forground from background filters_per_layer=[32, 64, 128], final_layer=True, ) @@ -55,7 +55,6 @@ def add_model_specific_args(parent_parser): parser = parent_parser.add_argument_group("Localize UNet") parser.add_argument("--learning_rate", type=float, default=1e-3) parser.add_argument("--input_channels", type=int, default=1) - parser.add_argument("--num_classes", type=int, default=2) return parent_parser @@ -130,7 +129,7 @@ def validation_step(self, batch, _): img_file = self.validation_directory.joinpath( f"img_{info['case'][s]}_{info['z'][s]}.npy" ) - np.save(img_file, x[0].squeeze(0).cpu().numpy()) + np.save(img_file, x[s].squeeze(0).cpu().numpy()) mask_file = self.validation_directory.joinpath( f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" @@ -141,6 +140,7 @@ def validation_step(self, batch, _): pred_file = self.validation_directory.joinpath( f"pred_{info['case'][s]}_{info['z'][s]}.npy" ) + pred = np.argmax(pred.squeeze(0).cpu().numpy(), axis=0) np.save(pred_file, pred) @@ -216,7 +216,7 @@ def validation_epoch_end(self, validation_step_outputs): com = [int(i / 2) for i in mask.GetSize()] img_vis = ImageVisualiser(img, cut=com, figure_size_in=16) - img_vis.set_limits_from_label(mask, expansion=[0, 0, 0]) + #img_vis.set_limits_from_label(mask, expansion=[0, 0, 0]) contour_dict = {**obs_dict} contour_dict["pred"] = pred @@ -246,4 +246,5 @@ def validation_epoch_end(self, validation_step_outputs): on_epoch=True, prog_bar=False, logger=True, + batch_size=self.hparams.batch_size ) diff --git a/platipy/imaging/cnn/utils.py b/platipy/imaging/cnn/utils.py index f2ed1dcd..b863b0c6 100644 --- a/platipy/imaging/cnn/utils.py +++ b/platipy/imaging/cnn/utils.py @@ -126,6 +126,8 @@ def preprocess_image( img = img[x_from:x_to, y_from:y_to, :] + sitk.WriteImage(img, "tmp.nii.gz") + return img From 
b5a4683b79efd3cdf96929cc07b89f746601b847 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 10 Nov 2021 04:25:22 +0000 Subject: [PATCH 156/264] Work towards supporting multiple structures in prob UNet --- platipy/imaging/cnn/dataset.py | 32 ++++++++++++++++++------- platipy/imaging/cnn/localise_net.py | 14 ++++++++--- platipy/imaging/cnn/pseudo_generator.py | 2 +- platipy/imaging/cnn/train.py | 14 ++++++----- 4 files changed, 43 insertions(+), 19 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index f65cdd03..fcfba346 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -409,7 +409,6 @@ def __init__( if label_file.exists(): labels.append(label_file) else: - print(label_file) labels.append(None) contour_mask_file = self.contour_mask_dir.joinpath( @@ -467,7 +466,11 @@ def __init__( contour_masks = {} for structure in structure_names: contour_masks[structure] = get_contour_mask( - [observers[obs][structure] for obs in case["observers"] if observers[obs][structure] is not None], + [ + observers[obs][structure] + for obs in case["observers"] + if observers[obs][structure] is not None + ], kernel=contour_mask_kernel, ) @@ -477,13 +480,21 @@ def __init__( if combine_observers == "union": updated_observers[""][structure] = [ get_union_mask( - [observers[obs][structure] for obs in case["observers"] if observers[obs][structure] is not None] + [ + observers[obs][structure] + for obs in case["observers"] + if observers[obs][structure] is not None + ] ) ] elif combine_observers == "intersection": updated_observers[""][structure] = [ get_intersection_mask( - [observers[obs][structure] for obs in case["observers"] if observers[obs][structure] is not None] + [ + observers[obs][structure] + for obs in case["observers"] + if observers[obs][structure] is not None + ] ) ] else: @@ -555,7 +566,10 @@ def __len__(self): def __getitem__(self, index): img = np.load(self.slices[index]["image"]) - labels = 
[np.load(label_file) if label_file else np.zeros(img.shape, dtype=np.ushort) for label_file in self.slices[index]["labels"]] + labels = [ + np.load(label_file) if label_file else np.zeros(img.shape, dtype=np.ushort) + for label_file in self.slices[index]["labels"] + ] contour_masks = [ np.load(contour_mask_file) for contour_mask_file in self.slices[index]["contour_masks"] ] @@ -568,15 +582,15 @@ def __getitem__(self, index): img, seg = self.transforms(image=img, segmentation_maps=segmap) for idx, _ in enumerate(labels): labels[idx] = seg.get_arr()[:, :, idx].squeeze() - contour_masks = seg.get_arr()[:, :, int(len(contour_masks) / 2) :].squeeze() + contour_masks[idx] = seg.get_arr()[:, :, len(labels) + idx].squeeze() else: for aug in self.transforms: img, masks = aug.apply(img, masks) - labels = masks[: int(len(labels) / 2)] - contour_masks = masks[int(len(contour_masks) / 2) :] + labels = masks[: len(labels)] + contour_masks = masks[len(contour_masks) :] img = torch.FloatTensor(img) - label = torch.IntTensor(np.concatenate([np.expand_dims(l, 0) for l in labels], 0)) + label = torch.FloatTensor(np.concatenate([np.expand_dims(l, 0) for l in labels], 0)) contour_mask = torch.FloatTensor( np.concatenate([np.expand_dims(l, 0) for l in contour_masks], 0) ) diff --git a/platipy/imaging/cnn/localise_net.py b/platipy/imaging/cnn/localise_net.py index 348f6f1e..1afc5496 100644 --- a/platipy/imaging/cnn/localise_net.py +++ b/platipy/imaging/cnn/localise_net.py @@ -33,6 +33,14 @@ class LocaliseUNet(pl.LightningModule): + """Provides a Localisation UNet which can be used find the general area of an image where the + structures of interest are positioned. Would usually be used as a preprocessing step to another + network. + + The Localise UNet operates only on 2D slices. It only uses two classes: foreground and + background. 
+ """ + def __init__( self, **kwargs, @@ -43,7 +51,7 @@ def __init__( self.unet = UNet( self.hparams.input_channels, - 2, # num_classes is always 2 for the localise net, just separating forground from background + 2, # num_classes is always 2 for the localise net foreground from background) filters_per_layer=[32, 64, 128], final_layer=True, ) @@ -216,7 +224,7 @@ def validation_epoch_end(self, validation_step_outputs): com = [int(i / 2) for i in mask.GetSize()] img_vis = ImageVisualiser(img, cut=com, figure_size_in=16) - #img_vis.set_limits_from_label(mask, expansion=[0, 0, 0]) + # img_vis.set_limits_from_label(mask, expansion=[0, 0, 0]) contour_dict = {**obs_dict} contour_dict["pred"] = pred @@ -246,5 +254,5 @@ def validation_epoch_end(self, validation_step_outputs): on_epoch=True, prog_bar=False, logger=True, - batch_size=self.hparams.batch_size + batch_size=self.hparams.batch_size, ) diff --git a/platipy/imaging/cnn/pseudo_generator.py b/platipy/imaging/cnn/pseudo_generator.py index 0af3058a..18420322 100644 --- a/platipy/imaging/cnn/pseudo_generator.py +++ b/platipy/imaging/cnn/pseudo_generator.py @@ -1,10 +1,10 @@ +import random from pathlib import Path import matplotlib.pyplot as plt import numpy as np import SimpleITK as sitk -import random from platipy.imaging.generation.image import insert_sphere from platipy.imaging import ImageVisualiser diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index f3b2ff78..8dda0389 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -75,7 +75,7 @@ def __init__( if self.hparams.prob_type == "prob": self.prob_unet = ProbabilisticUnet( self.hparams.input_channels, - self.hparams.num_classes, + len(self.hparams.structures), self.hparams.filters_per_layer, self.hparams.latent_dim, self.hparams.no_convs_fcomb, @@ -86,7 +86,7 @@ def __init__( elif self.hparams.prob_type == "hierarchical": self.prob_unet = HierarchicalProbabilisticUnet( input_channels=self.hparams.input_channels, - 
num_classes=self.hparams.num_classes, + num_classes=len(self.hparams.structures), filters_per_layer=self.hparams.filters_per_layer, down_channels_per_block=self.hparams.down_channels_per_block, latent_dims=[self.hparams.latent_dim] * (len(self.hparams.filters_per_layer) - 1), @@ -109,7 +109,6 @@ def add_model_specific_args(parent_parser): parser.add_argument("--learning_rate", type=float, default=1e-5) parser.add_argument("--lr_lambda", type=float, default=0.99) parser.add_argument("--input_channels", type=int, default=1) - parser.add_argument("--num_classes", type=int, default=2) parser.add_argument( "--filters_per_layer", nargs="+", type=int, default=[64 * (2 ** x) for x in range(5)] ) @@ -362,11 +361,14 @@ def validate( def training_step(self, batch, _): x, y, m, _ = batch + # print(x.shape) + # print(y.shape) + # print(m.shape) # Add background layer for one-hot encoding - y = torch.unsqueeze(y, dim=1) - not_y = y.logical_not() - y = torch.cat((not_y, y), dim=1).float() + # y = torch.unsqueeze(y, dim=1) + # not_y = y.logical_not() + # y = torch.cat((not_y, y), dim=1).float() # self.prob_unet.forward(x, y, training=True) if self.hparams.prob_type == "prob": From ea31bd8f13cb3407d2441c949cc0612a569f6b51 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 10 Nov 2021 06:40:21 +0000 Subject: [PATCH 157/264] Working on supporting multiple structures --- platipy/imaging/cnn/train.py | 161 +++++++++++++++++++++-------------- 1 file changed, 95 insertions(+), 66 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 8dda0389..df1ca091 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -253,7 +253,7 @@ def infer( y = y.squeeze(0) # y = np.argmax(y.cpu().detach().numpy(), axis=0) - sample["preds"].append(y) + sample["preds"].append(y.cpu().detach().numpy()) result = {} for sample in samples: @@ -283,78 +283,107 @@ def validate( ): metrics = {"DSC": "max", "HD": "min", "ASD": "min"} + result = {} - 
contour_cmap = "coolwarm" - - intersection_mask = get_intersection_mask(manual_observers) - union_mask = get_union_mask(manual_observers) - vis = ImageVisualiser(img, cut=get_com(union_mask), figure_size_in=16, window=window) + contour_cmaps = ["RdPu", "YlOrRd", "GnBu"] + structures = self.hparmas.structures - vis.add_contour( - mean, color=plt.cm.get_cmap(contour_cmap)(0.5), linewidth=3, show_legend=False - ) - vis.add_contour( - manual_observers, color=[0.13, 0.67, 0.275], linewidth=0.5, show_legend=False + vis = ImageVisualiser( + img, cut=get_com(mean["mean"][structures[0]]), figure_size_in=16, window=window ) - vis.add_contour( - intersection_mask, name="intersection", color=[0.13, 0.67, 0.275], linewidth=3 - ) - vis.add_contour(union_mask, name="union", color=[0.13, 0.67, 0.275], linewidth=3) - vis.add_contour( - samples, - linewidth=1.5, - color={ - s: c - for s, c in zip( - samples, plt.cm.get_cmap(contour_cmap)(np.linspace(0, 1, len(samples))) - ) - }, - ) + mean_contours = {} + for idx, structure in enumerate(structures): - vis.set_limits_from_label(union_mask, expansion=30) + color_map = plt.cm.get_cmap(contour_cmaps[idx % len(len(structures))]) + mean_contours[f"mean_{structure}"] = mean["mean"][structure] - fig = vis.show() + vis.add_contour(mean_contours, color=color_map(0.35), linewidth=3, show_legend=False) - sim = {k: np.zeros((len(samples), len(manual_observers))) for k in metrics} - msim = {k: np.zeros((len(samples), len(manual_observers))) for k in metrics} - for sid, samp in enumerate(samples): - for oid, obs in enumerate(manual_observers): - sample_metrics = get_metrics(manual_observers[obs], samples[samp]) - mean_metrics = get_metrics(manual_observers[obs], mean["mean"]) - - for k in sample_metrics: - sim[k][sid, oid] = sample_metrics[k] - msim[k][sid, oid] = mean_metrics[k] - - result = {"probnet": {k: [] for k in metrics}, "unet": {k: [] for k in metrics}} - for k in sim: - - val = sim[k] - if matching_type == "hungarian": - if metrics[k] 
== "max": - val = -val - row_idx, col_idx = linear_sum_assignment(val) - prob_unet_mean = sim[k][row_idx, col_idx].mean() - else: - if metrics[k] == "max": - prob_unet_mean = val.max() + manual_color = color_map(0.9) + + manual_observers_struct = { + f"{man_struct}_{structure}": manual_observers[man_struct][structure] + for man_struct in manual_observers + } + + vis.add_contour( + manual_observers_struct, color=manual_color, linewidth=0.5, show_legend=False + ) + + intersection_mask = get_intersection_mask(manual_observers_struct) + union_mask = get_union_mask(manual_observers_struct) + + vis.add_contour( + intersection_mask, name="intersection", color=manual_color, linewidth=3 + ) + vis.add_contour(union_mask, name="union", color=manual_color, linewidth=3) + + samples_struct = { + f"{sample_struct}_{structure}": samples[sample_struct][structure] + for sample_struct in samples + } + vis.add_contour( + samples_struct, + linewidth=1.5, + color={ + s: c + for s, c in zip( + samples_struct, color_map(np.linspace(0.1, 0.7, len(samples_struct))) + ) + }, + ) + + # vis.set_limits_from_label(union_mask, expansion=30) + + sim = { + k: np.zeros((len(samples_struct), len(manual_observers_struct))) for k in metrics + } + msim = { + k: np.zeros((len(samples_struct), len(manual_observers_struct))) for k in metrics + } + for sid, samp in enumerate(samples_struct): + for oid, obs in enumerate(manual_observers_struct): + sample_metrics = get_metrics( + manual_observers_struct[obs], samples_struct[samp] + ) + mean_metrics = get_metrics(manual_observers_struct[obs], mean["mean"]) + + for k in sample_metrics: + sim[k][sid, oid] = sample_metrics[k] + msim[k][sid, oid] = mean_metrics[k] + + result[f"probnet_{structure}"] = {k: [] for k in metrics} + result[f"unet_{structure}"] = {k: [] for k in metrics} + for k in sim: + + val = sim[k] + if matching_type == "hungarian": + if metrics[k] == "max": + val = -val + row_idx, col_idx = linear_sum_assignment(val) + prob_unet_mean = 
sim[k][row_idx, col_idx].mean() else: - prob_unet_mean = val.min() - result["probnet"][k].append(prob_unet_mean) - - val = msim[k] - if matching_type == "hungarian": - if metrics[k] == "max": - val = -val - row_idx, col_idx = linear_sum_assignment(val) - unet_mean = msim[k][row_idx, col_idx].mean() - else: - if metrics[k] == "max": - unet_mean = val.max() + if metrics[k] == "max": + prob_unet_mean = val.max() + else: + prob_unet_mean = val.min() + result[f"probnet_{structure}"][k].append(prob_unet_mean) + + val = msim[k] + if matching_type == "hungarian": + if metrics[k] == "max": + val = -val + row_idx, col_idx = linear_sum_assignment(val) + unet_mean = msim[k][row_idx, col_idx].mean() else: - unet_mean = val.min() - result["unet"][k].append(unet_mean) + if metrics[k] == "max": + unet_mean = val.max() + else: + unet_mean = val.min() + result[f"unet_{structure}"][k].append(unet_mean) + + fig = vis.show() return result, fig @@ -426,7 +455,7 @@ def validation_step(self, batch, _): img_file = self.validation_directory.joinpath( f"img_{info['case'][s]}_{info['z'][s]}.npy" ) - np.save(img_file, x[0].squeeze(0).cpu().numpy()) + np.save(img_file, x[s].squeeze(0).cpu().numpy()) mask_file = self.validation_directory.joinpath( f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" @@ -441,11 +470,11 @@ def validation_epoch_end(self, validation_step_outputs): for info in validation_step_outputs: for case, z, observer in zip(info["case"], info["z"], info["observer"]): - if not case in cases: cases[case] = {"slices": z.item(), "observers": [observer]} else: if z.item() > cases[case]["slices"]: + cases[case]["slices"] = z.item() if not observer in cases[case]["observers"]: cases[case]["observers"].append(observer) From 6e0d76a631261b8c067270ba5a9ba7f259fff180 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 11 Nov 2021 13:54:51 +1100 Subject: [PATCH 158/264] Update attribute name --- platipy/imaging/cnn/localise_net.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/platipy/imaging/cnn/localise_net.py b/platipy/imaging/cnn/localise_net.py index 1afc5496..01204483 100644 --- a/platipy/imaging/cnn/localise_net.py +++ b/platipy/imaging/cnn/localise_net.py @@ -79,7 +79,7 @@ def configure_optimizers(self): def infer(self, img): pp_img = preprocess_image( - img, spacing=self.hparams.spacing, crop_to_grid_size_xy=self.hparams.crop_to_mm + img, spacing=self.hparams.spacing, crop_to_grid_size_xy=self.hparams.crop_to_grid_size_xy ) preds = [] From 7b4a2739d09b945c7eaa746412befd7626d29b34 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 11 Nov 2021 14:00:24 +1100 Subject: [PATCH 159/264] Fix issues with validating multiple structures --- platipy/imaging/cnn/train.py | 49 ++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index df1ca091..3bf48ffc 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -217,6 +217,7 @@ def infer( ) img_arr = sitk.GetArrayFromImage(img) + if self.hparams.ndims == 2: slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])] else: @@ -261,9 +262,11 @@ def infer( pred_arr = sample["preds"][0] if self.hparams.ndims == 2: - pred_arr = np.expand_dims(pred_arr, 0) + pred_arr = np.expand_dims(pred_arr, 1) if len(sample["preds"]) > 1: - pred_arr = np.stack(sample["preds"]) + pred_arr = np.stack(sample["preds"], axis=1) + + result[sample["name"]] = {} for idx, structure in enumerate(self.hparams.structures): pred = sitk.GetImageFromArray(pred_arr[idx]) @@ -286,16 +289,19 @@ def validate( result = {} contour_cmaps = ["RdPu", "YlOrRd", "GnBu"] - structures = self.hparmas.structures + structures = self.hparams.structures - vis = ImageVisualiser( - img, cut=get_com(mean["mean"][structures[0]]), figure_size_in=16, window=window - ) + try: + cut = get_com(mean["mean"][structures[0]]) + except ValueError: + cut = [int(i/2) for i in img.GetSize()][::-1] + + 
vis = ImageVisualiser(img, cut=cut, figure_size_in=16, window=window) mean_contours = {} for idx, structure in enumerate(structures): - color_map = plt.cm.get_cmap(contour_cmaps[idx % len(len(structures))]) + color_map = plt.cm.get_cmap(contour_cmaps[idx % len(structures)]) mean_contours[f"mean_{structure}"] = mean["mean"][structure] vis.add_contour(mean_contours, color=color_map(0.35), linewidth=3, show_legend=False) @@ -315,9 +321,9 @@ def validate( union_mask = get_union_mask(manual_observers_struct) vis.add_contour( - intersection_mask, name="intersection", color=manual_color, linewidth=3 + intersection_mask, name=f"intersection_{structure}", color=manual_color, linewidth=3 ) - vis.add_contour(union_mask, name="union", color=manual_color, linewidth=3) + vis.add_contour(union_mask, name=f"union_{structure}", color=manual_color, linewidth=3) samples_struct = { f"{sample_struct}_{structure}": samples[sample_struct][structure] @@ -347,7 +353,7 @@ def validate( sample_metrics = get_metrics( manual_observers_struct[obs], samples_struct[samp] ) - mean_metrics = get_metrics(manual_observers_struct[obs], mean["mean"]) + mean_metrics = get_metrics(manual_observers_struct[obs], mean_contours[f"mean_{structure}"]) for k in sample_metrics: sim[k][sid, oid] = sample_metrics[k] @@ -481,8 +487,8 @@ def validation_epoch_end(self, validation_step_outputs): metrics = ["DSC", "HD", "ASD"] computed_metrics = { - **{f"probnet_{m}": [] for m in metrics}, - **{f"unet_{m}": [] for m in metrics}, + **{f"probnet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, + **{f"unet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, } for case in cases: @@ -544,11 +550,11 @@ def validation_epoch_end(self, validation_step_outputs): mask.CopyInformation(img) observers[f"manual_{observer}"][structure] = mask - try: - result, fig = self.validate(img, observers, samples, mean, matching_type="best") - except Exception as e: - print(f"ERROR DURING VALIDATION VALIDATE: {e}") - 
return + #try: + result, fig = self.validate(img, observers, samples, mean, matching_type="best") + #except Exception as e: + # print(f"ERROR DURING VALIDATION VALIDATE: {e}") + # return figure_path = f"valid_{case}.png" fig.savefig(figure_path, dpi=300) @@ -565,8 +571,13 @@ def validation_epoch_end(self, validation_step_outputs): computed_metrics[f"{t}_{m}"] += result[t][m] if self.kl_div: - p = np.array(computed_metrics["probnet_DSC"]).mean() - u = np.array(computed_metrics["unet_DSC"]).mean() + p = u = 0 + for s in self.hparams.structures: + p += np.array(computed_metrics[f"probnet_{s}_DSC"]).mean() + u += np.array(computed_metrics[f"unet_{s}_DSC"]).mean() + + p /= len(self.hparams.structures) + u /= len(self.hparams.structures) computed_metrics["scaled_DSC"] = ((p + u) / 2) + (p - u) - self.kl_div for cm in computed_metrics: From e34fee6817c2ab698030e1006166cf7677c5034c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 15 Nov 2021 10:53:29 +1100 Subject: [PATCH 160/264] Add sigmoid to output --- platipy/imaging/cnn/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 3bf48ffc..d38e0759 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -254,6 +254,7 @@ def infer( y = y.squeeze(0) # y = np.argmax(y.cpu().detach().numpy(), axis=0) + y = torch.sigmoid(y) sample["preds"].append(y.cpu().detach().numpy()) result = {} From 89450a4096996a24d4ed34ac016fdd59d187da00 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 17 Nov 2021 09:10:12 +1100 Subject: [PATCH 161/264] Trac pos weight --- platipy/imaging/cnn/prob_unet.py | 54 ++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index e49ddf09..d8509bba 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -239,6 +239,8 @@ def __init__( self._contour_moving_avg = None 
self.register_buffer("_lambda", torch.zeros(2, requires_grad=False)) + self.register_buffer("_pos_weight", torch.ones(num_classes, requires_grad=False)) + def forward(self, img, seg=None, training=False): """ Construct prior latent space for patch and run patch through UNet, @@ -324,33 +326,33 @@ def prepare_mask( ): if mask is None or mask.sum() == 0: mask = torch.ones(n_pixels_in_batch) - else: - # assert ( - # mask.shape == segm.shape - # ), f"The loss mask shape differs from the target shape: {mask.shape} vs. {segm.shape}." - mask = torch.reshape(mask, (-1,)) - mask = mask.to(device) + mask = mask.to(device) - if top_k_percentage is not None: + if top_k_percentage is not None: - assert 0.0 < top_k_percentage <= 1.0 - k_pixels = int(n_pixels_in_batch * top_k_percentage) + assert 0.0 < top_k_percentage <= 1.0 + k_pixels = int(n_pixels_in_batch * top_k_percentage) - with torch.no_grad(): - norm_xe = xe / torch.sum(xe) - if deterministic: - score = torch.log(norm_xe) - else: - # TODO Gumbel trick - raise NotImplementedError("Still need to implement Gumbel trick") + with torch.no_grad(): + norm_xe = xe / torch.sum(xe) + if deterministic: + score = torch.log(norm_xe) + else: + # TODO Gumbel trick + raise NotImplementedError("Still need to implement Gumbel trick") + + score = score + torch.log(mask.unsqueeze(1).repeat((1, num_classes))) - score = score + torch.log(mask.unsqueeze(1).repeat((1, num_classes))) + top_k_mask = self.topk_mask(score, k_pixels) + top_k_mask = top_k_mask.to(device) + mask = mask * top_k_mask + + mask = mask.unsqueeze(1).repeat((1, num_classes)) + + else: + mask = torch.reshape(mask, (-1,)) - top_k_mask = self.topk_mask(score, k_pixels) - top_k_mask = top_k_mask.to(device) - mask = mask * top_k_mask - mask = mask.unsqueeze(1).repeat((1, num_classes)) mask = ( mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) ) @@ -366,8 +368,6 @@ def reconstruction_loss( deterministic=True, ): - criterion = 
torch.nn.BCEWithLogitsLoss(reduction="none") - if z_posterior is None: z_posterior = self.posterior_latent_space.rsample() @@ -380,6 +380,14 @@ def reconstruction_loss( n_pixels_in_batch = y_flat.shape[0] batch_size = segm.shape[0] + pos_class_count = t_flat.sum(axis=0)/batch_size + neg_class_count = torch.logical_not(t_flat).sum(axis=0)/batch_size + self._pos_weight = self._pos_weight * 0.5 + pos_class_count/neg_class_count * 0.5 + print(pos_class_count) + print(neg_class_count) + print(self._pos_weight) + + criterion = torch.nn.BCEWithLogitsLoss(reduction="none", pos_weight=self._pos_weight) xe = criterion(input=y_flat, target=t_flat) xe = xe.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) From 4da54a28e311fdd1416a82db23433dac785c0b0d Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 20 Dec 2021 15:49:29 +1100 Subject: [PATCH 162/264] Update to multiple structure training --- platipy/imaging/cnn/prob_unet.py | 5 +---- platipy/imaging/cnn/train.py | 3 --- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index d8509bba..ecd8e60c 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -382,10 +382,7 @@ def reconstruction_loss( pos_class_count = t_flat.sum(axis=0)/batch_size neg_class_count = torch.logical_not(t_flat).sum(axis=0)/batch_size - self._pos_weight = self._pos_weight * 0.5 + pos_class_count/neg_class_count * 0.5 - print(pos_class_count) - print(neg_class_count) - print(self._pos_weight) + self._pos_weight = self._pos_weight * 0.5 + (neg_class_count/pos_class_count).clamp(0, 10000) * 0.5 criterion = torch.nn.BCEWithLogitsLoss(reduction="none", pos_weight=self._pos_weight) xe = criterion(input=y_flat, target=t_flat) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index d38e0759..f0598f8e 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -397,9 +397,6 @@ 
def validate( def training_step(self, batch, _): x, y, m, _ = batch - # print(x.shape) - # print(y.shape) - # print(m.shape) # Add background layer for one-hot encoding # y = torch.unsqueeze(y, dim=1) From 1ae54f4830b3fec8836602c49998f3ee4c38d660 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 18 Jan 2022 05:46:48 +0000 Subject: [PATCH 163/264] Augmentation of fail cases using DVF --- platipy/imaging/generation/dvf.py | 81 +++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/platipy/imaging/generation/dvf.py b/platipy/imaging/generation/dvf.py index 142405a6..af832c4d 100644 --- a/platipy/imaging/generation/dvf.py +++ b/platipy/imaging/generation/dvf.py @@ -25,6 +25,8 @@ fast_symmetric_forces_demons_registration, ) +from platipy.imaging.label.utils import get_com + def generate_field_shift(mask_image, vector_shift=(10, 10, 10), gaussian_smooth=5): """ @@ -397,3 +399,82 @@ def generate_field_radial_bend( ) return reference_image_bend, dvf_tfm, dvf_template + + +def expand_mask_towards_target( + mask_image, target_image, expand_mag=20, gaussian_smooth=5, dvf_overlap_into_mask=3 +): + """Generate a deformation vector field to expand a mask towards a target mask. Can be useful to + manipulate structures for augmentation of fail cases for automated contour QA work. + + Args: + mask_image (sitk.Image): The mask of the structure to manipulate + target_image (sitk.Image): The mask of the target structure to expand towards + expand_mag (int, optional): The magnitude of the expansion in mm. Defaults to 20. + gaussian_smooth (int, optional): Scale of a Gaussian kernel used to smooth the + deformation vector field. Defaults to 5. + dvf_overlap_into_mask (int, optional): Defines how far the deformation field overlaps + into the mask image. Affects how much of the structure is deformed. Defaults to 3. + + Returns: + [SimpleITK.Image]: The binary mask following the expansion.
+ [SimpleITK.DisplacementFieldTransform]: The transform representing the expansion. + [SimpleITK.Image]: The displacement vector field representing the expansion. + """ + + # Remove any potential overlap between the target and the mask + target_image = sitk.MaskNegated(target_image, mask_image) + + # Determine the vector to expand the mask towards + mask_com = get_com(mask_image, as_int=False, real_coords=True) + target_com = get_com(target_image, as_int=False, real_coords=True) + + expand_vec = np.array([p - q for p, q in zip(target_com, mask_com)]) + expand_vec = expand_vec / np.linalg.norm(expand_vec) + + mask_image_arr = sitk.GetArrayFromImage(mask_image) + + # Compute the distance map from the target to every other voxel + dist_map = sitk.SignedMaurerDistanceMap(target_image, squaredDistance=False) + dist_map_arr = sitk.GetArrayFromImage(dist_map) + dist_map_arr[dist_map_arr < 0] = 0 + + # Manipulate the distance map so that only voxel within the range of dvf_overlap_into_mask are + # kept + dist_from_mask_to_target = dist_map_arr[mask_image_arr > 0].min() + max_mask_dist = dist_map_arr[mask_image_arr > 0].max() + dist_map_arr[dist_map_arr > max_mask_dist] = max_mask_dist + dist_map_arr[dist_map_arr > dist_from_mask_to_target + dvf_overlap_into_mask] = ( + dist_from_mask_to_target + dvf_overlap_into_mask + ) + + dvf_weight = np.zeros(dist_map_arr.shape) + dvf_weight[dist_map_arr < dist_from_mask_to_target + dvf_overlap_into_mask] = 1 + dvf_weight = np.tile(np.expand_dims(dvf_weight, axis=3), [1, 1, 1, 3]) + + # The template deformation field + # Used to generate transforms + dvf_arr = np.zeros(mask_image_arr.shape + (3,)) + dvf_arr = dvf_arr - np.array([[[expand_vec * expand_mag]]]) + + # Weight the deformation vectors by the manipulated distance map + dvf_arr = dvf_arr * dvf_weight + dvf_template = sitk.GetImageFromArray(dvf_arr) + + # Copy image information + dvf_template.CopyInformation(mask_image) + + if np.any(gaussian_smooth): + + if not 
hasattr(gaussian_smooth, "__iter__"): + gaussian_smooth = (gaussian_smooth,) * 3 + + dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth) + + dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64)) + + mask_image_expanded = apply_transform( + mask_image, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor + ) + + return mask_image_expanded, dvf_tfm, dvf_template From 42bfc46f6d92fcc32d90a7f1b370008e3d2eaf5f Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 1 Mar 2022 10:12:56 +1100 Subject: [PATCH 164/264] Work on DVF augmentation --- .gitattributes | 3 + .gitignore | 8 +- .pylintrc | 11 +- platipy/imaging/generation/augment.py | 324 ++++++++++++++++++++++---- platipy/imaging/generation/dvf.py | 110 ++++++--- 5 files changed, 373 insertions(+), 83 deletions(-) diff --git a/.gitattributes b/.gitattributes index e8fc248b..985fc46b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -5,7 +5,10 @@ $.dcm binary *.tar.gz binary *.png binary +*.jpeg binary +*.jpg binary *.ttf binary +*.pickle binary *.woff binary *.ipynb filter=nbstripout diff --git a/.gitignore b/.gitignore index e69b1fa8..b7918abc 100644 --- a/.gitignore +++ b/.gitignore @@ -140,8 +140,14 @@ platipy/*/tests/data testing/ converted/ **/data +**/working **/tcia **/nifti_output # Don't include html docs in repo -docs/site/ \ No newline at end of file +docs/site/ + +*.npy +*.nii.gz + +test_prob*/ \ No newline at end of file diff --git a/.pylintrc b/.pylintrc index 7f7aa386..68118e59 100644 --- a/.pylintrc +++ b/.pylintrc @@ -139,10 +139,11 @@ disable=print-statement, deprecated-sys-function, exception-escape, comprehension-escape, - C0330, - C0114, - W0102, - W0105 + bad-continuation, + missing-module-docstring, + # pointless-string-statement, + dangerous-default-value, + arguments-differ # Enable the message, report, category or checker with the given id(s). 
You can # either give multiple identifier separated by comma (,) or put this option @@ -443,7 +444,7 @@ contextmanager-decorators=contextlib.contextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. -generated-members= +generated-members=torch.*,pytorch_lightning.* # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index b6fd7455..338de085 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -16,7 +16,17 @@ from collections.abc import Iterable import random +from pathlib import Path + +from argparse import ArgumentParser + import SimpleITK as sitk +import numpy as np + +from loguru import logger + +import matplotlib.pyplot as plt +from platipy.imaging import ImageVisualiser from platipy.imaging.generation.dvf import ( generate_field_shift, @@ -29,6 +39,10 @@ from platipy.imaging.registration.utils import apply_transform +from platipy.imaging.utils.lung import detect_holes +from platipy.imaging.label.utils import get_union_mask +from platipy.imaging.utils.crop import label_to_roi, crop_to_roi + def apply_augmentation(image, augmentation, masks=[]): @@ -44,23 +58,27 @@ def apply_augmentation(image, augmentation, masks=[]): "DeformableAugment's" ) - transforms = [] + # transforms = [] + transform = None dvf = None for aug in augmentation: if not isinstance(aug, DeformableAugment): raise AttributeError("Each augmentation must be of type DeformableAugment") + logger.debug(str(aug)) tfm, field = aug.augment() - transforms.append(tfm) + # transforms.append(tfm) if dvf is None: dvf = field + transform = tfm else: dvf += field + transform = sitk.CompositeTransform([transform, tfm]) - transform = 
sitk.CompositeTransform(transforms) - del transforms + # transform = sitk.CompositeTransform(transforms) + # del transforms image_deformed = apply_transform( image, @@ -71,51 +89,35 @@ def apply_augmentation(image, augmentation, masks=[]): masks_deformed = [] for mask in masks: - masks_deformed.append( - apply_transform( - mask, transform=transform, default_value=0, interpolator=sitk.sitkNearestNeighbor - ) + def_mask = apply_transform( + mask, transform=transform, default_value=0, interpolator=sitk.sitkNearestNeighbor ) + def_mask = sitk.BinaryMorphologicalClosing(def_mask, [3, 3, 3]) + + masks_deformed.append(def_mask) + if masks: return image_deformed, masks_deformed, dvf return image_deformed, dvf -def generate_random_augmentation(ct_image, masks): - - random.shuffle(masks) - # mask_count = len(masks) - # masks = masks[: random.randint(2, 5)] - - # print(len(masks)) - augmentation_types = [ - { - "class": ShiftAugment, - "args": {"vector_shift": [(-10, 10), (10, 10), (-10, 10)], "gaussian_smooth": (3, 5)}, - }, - { - "class": ContractAugment, - "args": { - "vector_contract": [(0, 10), (0, 10), (0, 10)], - "gaussian_smooth": (3, 5), - "bone_mask": True, - }, - }, - { - "class": ExpandAugment, - "args": { - "vector_expand": [(0, 10), (0, 10), (0, 10)], - "gaussian_smooth": (3, 5), - "bone_mask": True, - }, - }, - ] +def generate_random_augmentation(ct_image, masks, augmentation_types): augmentation = [] + + probabilities = [a["probability"] for a in augmentation_types] + prob_total = sum(probabilities) + prob_none = 1.0 - prob_total + if prob_none < 0: + prob_none = 0 + for mask in masks: - aug = random.choice(augmentation_types) + aug = random.choices(augmentation_types + [None], weights=probabilities + [prob_none])[0] + + if aug is None: + continue aug_class = aug["class"] aug_args = {} @@ -165,6 +167,9 @@ def augment(self): ) return transform, dvf + def __str__(self): + return f"Shift with vector: {self.vector_shift}, gauss: {self.gaussian_smooth}" + class 
ExpandAugment(DeformableAugment): def __init__(self, mask, vector_expand=(10, 10, 10), gaussian_smooth=5, bone_mask=False): @@ -185,12 +190,15 @@ def augment(self): return transform, dvf + def __str__(self): + return f"Expand with vector: {self.vector_expand}, smooth: {self.gaussian_smooth}" + class ContractAugment(DeformableAugment): def __init__(self, mask, vector_contract=(10, 10, 10), gaussian_smooth=5, bone_mask=False): self.mask = mask - self.contract = [int(-x / s) for x, s in zip(vector_contract, mask.GetSpacing())] + self.vector_contract = [int(-x / s) for x, s in zip(vector_contract, mask.GetSpacing())] self.gaussian_smooth = gaussian_smooth self.bone_mask = bone_mask @@ -199,7 +207,243 @@ def augment(self): _, transform, dvf = generate_field_expand( self.mask, bone_mask=self.bone_mask, - expand=self.contract, + expand=self.vector_contract, gaussian_smooth=self.gaussian_smooth, ) return transform, dvf + + def __str__(self): + return f"Contract with vector: {self.vector_contract}, smooth: {self.gaussian_smooth}" + + +def augment_data(args): + + random.seed(args.seed) + + augmentation_types = [] + + if args.enable_shift: + augmentation_types.append( + { + "class": ShiftAugment, + "args": { + "vector_shift": [ + tuple(args.shift_x_range), + tuple(args.shift_y_range), + tuple(args.shift_z_range), + ], + "gaussian_smooth": tuple(args.shift_smooth_range), + }, + "probability": args.shift_probability, + } + ) + + if args.enable_contract: + augmentation_types.append( + { + "class": ContractAugment, + "args": { + "vector_contract": [ + tuple(args.contract_x_range), + tuple(args.contract_y_range), + tuple(args.contract_z_range), + ], + "gaussian_smooth": tuple(args.contract_smooth_range), + "bone_mask": args.contract_bone_mask, + }, + "probability": args.contract_probability, + } + ) + + if args.enable_expand: + augmentation_types.append( + { + "class": ExpandAugment, + "args": { + "vector_expand": [ + tuple(args.expand_x_range), + tuple(args.expand_y_range), + 
tuple(args.expand_z_range), + ], + "gaussian_smooth": tuple(args.expand_smooth_range), + "bone_mask": args.expand_bone_mask, + }, + "probability": args.expand_probability, + } + ) + + data_dir = Path(args.data_dir) + output_dir = Path(args.output_dir) + + cases = [ + p.name.replace(".nii.gz", "") + for p in data_dir.glob(args.case_glob) + if not p.name.startswith(".") + ] + cases.sort() + + data = { + case: { + "image": data_dir.joinpath(args.image_glob.format(case=case)), + "label": [p for p in data_dir.glob(args.label_glob.format(case=case))], + } + for case in cases + } + + for case in cases: + + logger.info(f"Augmenting for case: {case}") + + ct_image_original = sitk.ReadImage(str(data[case]["image"])) + + # Get list of structures to generate augmentations off + logger.debug("Collecting structures") + all_masks = [] + all_names = [] + for structure_path in data[case]["label"]: + + mask = sitk.ReadImage(str(structure_path)) + + all_masks.append(mask) + all_names.append(structure_path.name.replace(".nii.gz", "")) + + logger.debug("Cropping to regions around all structures") + union_mask = get_union_mask(all_masks) + size, index = label_to_roi(union_mask, expansion_mm=[25, 25, 25]) + ct_image = crop_to_roi(ct_image_original, size, index) + + for m, mask in enumerate(all_masks): + all_masks[m] = crop_to_roi(mask, size, index) + + if args.enable_fill_holes: + + logger.debug("Finding holes") + label_image, labels = detect_holes(ct_image) + + # Generate x random augmentations per case + for i in range(args.augmentations_per_case): + + logger.debug(f"Generating augmentation {i}") + + ct_image = sitk.ReadImage(str(data[case]["image"])) + ct_image = crop_to_roi(ct_image, size, index) + + if args.enable_fill_holes: + + logger.debug("Filling holes") + + for label in labels[1:]: # Skip first hole since likely air around body + + if random.random() > args.fill_probability: + continue + + hole = label_image == label["label"] + hole_dilate = sitk.BinaryDilate(hole, (2, 2, 2), 
sitk.sitkBall) + contour_points = sitk.BinaryContour(hole_dilate) + fill_value = np.median( + sitk.GetArrayFromImage(ct_image)[ + sitk.GetArrayFromImage(contour_points) == 1 + ] + ) + + ct_arr = sitk.GetArrayFromImage(ct_image) + ct_arr[sitk.GetArrayFromImage(hole_dilate) == 1] = fill_value + ct_filled = sitk.GetImageFromArray(ct_arr) + ct_filled.CopyInformation(ct_image) + + ct_image = ct_filled + + augmented_case_path = output_dir.joinpath(case, f"augment_{i}") + augmented_case_path.mkdir(exist_ok=True, parents=True) + + logger.debug("Generating random augmentations") + augmentation = generate_random_augmentation(ct_image, all_masks, augmentation_types) + + dvf = None + + if len(augmentation) == 0: + logger.debug( + "No augmentations generated, generated image won't differ from original" + ) + + augmented_image = ct_image + augmented_masks = all_masks + else: + + logger.debug("Applying augmentation") + augmented_image, augmented_masks, dvf = apply_augmentation( + ct_image, augmentation, masks=all_masks + ) + + augmented_image_path = augmented_case_path.joinpath("CT.nii.gz") + ct_image_original[ + index[0] : index[0] + size[0], + index[1] : index[1] + size[1], + index[2] : index[2] + size[2], + ] = augmented_image + sitk.WriteImage(ct_image_original, str(augmented_image_path)) + + vis = ImageVisualiser(image=ct_image, figure_size_in=6) + vis.add_comparison_overlay(augmented_image) + if dvf is not None: + vis.add_vector_overlay(dvf, arrow_scale=1, subsample=(4, 12, 12)) + for mask_name, mask, augmented_mask in zip(all_names, all_masks, augmented_masks): + vis.add_contour({f"{mask_name}": mask, f"{mask_name} (augmented)": augmented_mask}) + + logger.debug(f"Applying augmentation to mask: {mask_name}") + augmented_mask_path = augmented_case_path.joinpath(f"{mask_name}.nii.gz") + augmented_mask = sitk.Resample( + augmented_mask, ct_image_original, sitk.Transform(), sitk.sitkNearestNeighbor + ) + sitk.WriteImage(augmented_mask, str(augmented_mask_path)) + + fig = 
vis.show() + + figure_path = augmented_case_path.joinpath("aug.png") + fig.savefig(figure_path, bbox_inches="tight") + plt.close() + + +if __name__ == "__main__": + + arg_parser = ArgumentParser() + arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") + arg_parser.add_argument("--data_dir", type=str, default="./data") + arg_parser.add_argument("--output_dir", type=str, default="./augment") + arg_parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") + arg_parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") + arg_parser.add_argument("--label_glob", type=str, default="labels/{case}_*.nii.gz") + arg_parser.add_argument( + "--augmentations_per_case", + type=int, + default=10, + help="How many augmented images per case to generate", + ) + + arg_parser.add_argument("--enable_shift", type=bool, default=True) + arg_parser.add_argument("--shift_x_range", nargs="+", type=int, default=[-10, 10]) + arg_parser.add_argument("--shift_y_range", nargs="+", type=int, default=[-10, 10]) + arg_parser.add_argument("--shift_z_range", nargs="+", type=int, default=[-10, 10]) + arg_parser.add_argument("--shift_smooth_range", nargs="+", type=int, default=[3, 5]) + arg_parser.add_argument("--shift_probability", type=float, default=0.5) + + arg_parser.add_argument("--enable_expand", type=bool, default=True) + arg_parser.add_argument("--expand_x_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--expand_y_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--expand_z_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--expand_smooth_range", nargs="+", type=int, default=[3, 5]) + arg_parser.add_argument("--expand_bone_mask", type=bool, default=True) + arg_parser.add_argument("--expand_probability", type=float, default=0.5) + + arg_parser.add_argument("--enable_contract", type=bool, default=True) + arg_parser.add_argument("--contract_x_range", 
nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--contract_y_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--contract_z_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--contract_smooth_range", nargs="+", type=int, default=[3, 5]) + arg_parser.add_argument("--contract_bone_mask", type=bool, default=True) + arg_parser.add_argument("--contract_probability", type=float, default=0.5) + + arg_parser.add_argument("--enable_fill_holes", type=bool, default=True) + arg_parser.add_argument("--fill_probability", type=float, default=0.2) + + augment_data(arg_parser.parse_args()) diff --git a/platipy/imaging/generation/dvf.py b/platipy/imaging/generation/dvf.py index 142405a6..951256b2 100644 --- a/platipy/imaging/generation/dvf.py +++ b/platipy/imaging/generation/dvf.py @@ -16,6 +16,8 @@ import numpy as np import SimpleITK as sitk +from loguru import logger + from platipy.imaging.registration.utils import ( apply_transform, convert_mask_to_reg_structure, @@ -25,13 +27,15 @@ fast_symmetric_forces_demons_registration, ) +from platipy.imaging.utils.crop import label_to_roi, crop_to_roi + -def generate_field_shift(mask_image, vector_shift=(10, 10, 10), gaussian_smooth=5): +def generate_field_shift(mask, vector_shift=(10, 10, 10), gaussian_smooth=5): """ Shifts (moves) a structure defined using a binary mask. Args: - mask_image ([SimpleITK.Image]): The binary mask to shift. + mask ([SimpleITK.Image]): The binary mask to shift. vector_shift (tuple, optional): The displacement vector applied to the entire binary mask. Convention: (+/-, +/-, +/-) = (sup/inf, post/ant, left/right) shift. @@ -45,9 +49,27 @@ def generate_field_shift(mask_image, vector_shift=(10, 10, 10), gaussian_smooth= [SimpleITK.DisplacementFieldTransform]: The transform representing the shift. [SimpleITK.Image]: The displacement vector field representing the shift. 
""" + + mask_full = mask + + roi_expand = [x + 5 for x in vector_shift] + + if np.any(gaussian_smooth): + + if not hasattr(gaussian_smooth, "__iter__"): + gaussian_smooth = (gaussian_smooth,) * 3 + + roi_expand = [x + y for x, y in zip(roi_expand, gaussian_smooth)] + + # Make sure the expansion meets a minimum size (1cm) + roi_expand = [max(e, 10) for e in roi_expand] + + size, index = label_to_roi(mask, expansion_mm=roi_expand) + mask = crop_to_roi(mask, size, index) + # Define array # Used for image array manipulations - mask_image_arr = sitk.GetArrayFromImage(mask_image) + mask_image_arr = sitk.GetArrayFromImage(mask) # The template deformation field # Used to generate transforms @@ -56,29 +78,28 @@ def generate_field_shift(mask_image, vector_shift=(10, 10, 10), gaussian_smooth= dvf_template = sitk.GetImageFromArray(dvf_arr) # Copy image information - dvf_template.CopyInformation(mask_image) + dvf_template.CopyInformation(mask) dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64)) - mask_image_shift = apply_transform( - mask_image, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor + mask_shift = apply_transform( + mask, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor ) - dvf_template = sitk.Mask(dvf_template, mask_image | mask_image_shift) + dvf_template = sitk.Mask(dvf_template, mask | mask_shift) # smooth if np.any(gaussian_smooth): - - if not hasattr(gaussian_smooth, "__iter__"): - gaussian_smooth = (gaussian_smooth,) * 3 - dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth) + # Resample back to original image + dvf_template = sitk.Resample(dvf_template, mask_full) dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64)) - mask_image_shift = apply_transform( - mask_image, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor + + mask_shift = apply_transform( + mask_full, transform=dvf_tfm, 
default_value=0, interpolator=sitk.sitkNearestNeighbor ) - return mask_image_shift, dvf_tfm, dvf_template + return mask_shift, dvf_tfm, dvf_template def generate_field_asymmetric_contract( @@ -212,32 +233,49 @@ def generate_field_expand( dilation kernel. Args: - mask ([SimpleITK.Image]): The binary mask to expand. - bone_mask ([SimpleITK.Image, optional]): A binary mask defining regions where we expect - restricted deformations. - vector_asymmetric_extend (int |tuple, optional): The expansion vector applied to the entire - binary mask. - Convention: (z,y,x) size of expansion kernel. - Defined in millimetres. - Defaults to 3. + mask (SimpleITK.Image): The binary mask to expand. + bone_mask (SimpleITK.Image, optional): A binary mask defining regions where we expect + restricted deformations. + expand (int |tuple, optional): The expansion vector applied to the entire binary mask. + Convention: (z,y,x) size of expansion kernel. + Defined in millimetres. + Defaults to 3. gaussian_smooth (int | list, optional): Scale of a Gaussian kernel used to smooth the - deformation vector field. Defaults to 5. + deformation vector field. Defaults to 5. Returns: - [SimpleITK.Image]: The binary mask following the expansion. - [SimpleITK.DisplacementFieldTransform]: The transform representing the expansion. - [SimpleITK.Image]: The displacement vector field representing the expansion. + SimpleITK.Image: The binary mask following the expansion. + SimpleITK.DisplacementFieldTransform: The transform representing the expansion. + SimpleITK.Image: The displacement vector field representing the expansion. 
""" + mask_full = mask + + if not hasattr(expand, "__iter__"): + expand = (expand,) * 3 + + roi_expand = expand + + if np.any(gaussian_smooth): + + if not hasattr(gaussian_smooth, "__iter__"): + gaussian_smooth = (gaussian_smooth,) * 3 + + roi_expand = [x + y for x, y in zip(roi_expand, gaussian_smooth)] + + # Make sure the expansion meets a minimum size (1cm) + roi_expand = [max(e, 10) for e in roi_expand] + + size, index = label_to_roi(mask, expansion_mm=roi_expand) + mask = crop_to_roi(mask, size, index) + if bone_mask is not False: + bone_mask = sitk.Resample(bone_mask, mask, sitk.Transform(), sitk.sitkNearestNeighbor) mask_original = mask + bone_mask else: mask_original = mask # Use binary erosion to create a smaller volume - if not hasattr(expand, "__iter__"): - expand = (expand,) * 3 - expand = np.array(expand) # Convert voxels to millimetres @@ -249,17 +287,17 @@ def generate_field_expand( # If all negative: erode if np.all(np.array(expand) <= 0): - print("All factors negative: shrinking only.") + logger.debug("All factors negative: shrinking only.") mask_expand = sitk.BinaryErode(mask, np.abs(expand).astype(int).tolist(), sitk.sitkBall) # If all positive: dilate elif np.all(np.array(expand) >= 0): - print("All factors positive: expansion only.") + logger.debug("All factors positive: expansion only.") mask_expand = sitk.BinaryDilate(mask, np.abs(expand).astype(int).tolist(), sitk.sitkBall) # Otherwise: sequential operations else: - print("Mixed factors: shrinking and expansion.") + logger.debug("Mixed factors: shrinking and expansion.") expansion_kernel = expand * (expand > 0) shrink_kernel = expand * (expand < 0) @@ -293,16 +331,14 @@ def generate_field_expand( # smooth if np.any(gaussian_smooth): - - if not hasattr(gaussian_smooth, "__iter__"): - gaussian_smooth = (gaussian_smooth,) * 3 - dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth) + # Resample back to original image + dvf_template = sitk.Resample(dvf_template, mask_full) 
dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64)) mask_symmetric_expand = apply_transform( - mask, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor + mask_full, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor ) return mask_symmetric_expand, dvf_tfm, dvf_template From de27b844fe97c3ce2b3004c1c34bc634719a5084 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 3 Mar 2022 15:51:26 +1100 Subject: [PATCH 165/264] Add contract from target function --- platipy/imaging/generation/dvf.py | 79 +++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/platipy/imaging/generation/dvf.py b/platipy/imaging/generation/dvf.py index 63b9aa04..08ab4346 100644 --- a/platipy/imaging/generation/dvf.py +++ b/platipy/imaging/generation/dvf.py @@ -513,3 +513,82 @@ def expand_mask_towards_target( ) return mask_image_expanded, dvf_tfm, dvf_template + + +def contract_mask_away_from_target( + mask_image, target_image, contract_mag=20, gaussian_smooth=5 +): + """Generate a deformation vector field to contract a mask away from a target mask. Can be useful to + manipulate structures for augmentation of fail cases for automated contour QA work. + + Args: + mask_image (sitk.Image): The mask of the structure of manipulate + target_image (sitk.Image): The mask of the target structure to expand towards + contract_mag (int, optional): The magnitude of the contraction in mm. Defaults to 20. + gaussian_smooth (int, optional): Scale of a Gaussian kernel used to smooth the + deformation vector field.. Defaults to 5. + + Returns: + [SimpleITK.Image]: The binary mask following the contraction. + [SimpleITK.DisplacementFieldTransform]: The transform representing the contraction. + [SimpleITK.Image]: The displacement vector field representing the contraction. 
+ """ + + # Remove any potential overlap between the target and the mask + target_image = sitk.MaskNegated(target_image, mask_image) + dvf_overlap_into_mask = contract_mag + 5 + + # Determine the vector to expand the mask towards + mask_com = get_com(mask_image, as_int=False, real_coords=True) + target_com = get_com(target_image, as_int=False, real_coords=True) + + expand_vec = np.array([q - p for p, q in zip(target_com, mask_com)]) + expand_vec = expand_vec / np.linalg.norm(expand_vec) + + mask_image_arr = sitk.GetArrayFromImage(mask_image) + + # Compute the distance map from the target to every other voxel + dist_map = sitk.SignedMaurerDistanceMap(target_image, squaredDistance=False) + dist_map_arr = sitk.GetArrayFromImage(dist_map) + dist_map_arr[dist_map_arr < 0] = 0 + + # Manipulate the distance map so that only voxel within the range of dvf_overlap_into_mask are + # kept + dist_from_mask_to_target = dist_map_arr[mask_image_arr > 0].min() + max_mask_dist = dist_map_arr[mask_image_arr > 0].max() + dist_map_arr[dist_map_arr > max_mask_dist] = max_mask_dist + dist_map_arr[dist_map_arr > dist_from_mask_to_target + dvf_overlap_into_mask] = ( + dist_from_mask_to_target + dvf_overlap_into_mask + ) + + dvf_weight = np.zeros(dist_map_arr.shape) + dvf_weight[dist_map_arr < dist_from_mask_to_target + dvf_overlap_into_mask] = 1 + dvf_weight = np.tile(np.expand_dims(dvf_weight, axis=3), [1, 1, 1, 3]) + + # The template deformation field + # Used to generate transforms + dvf_arr = np.zeros(mask_image_arr.shape + (3,)) + dvf_arr = dvf_arr - np.array([[[expand_vec * contract_mag]]]) + + # Weight the deformation vectors by the manipulated distance map + dvf_arr = dvf_arr * dvf_weight + dvf_template = sitk.GetImageFromArray(dvf_arr) + + # Copy image information + dvf_template.CopyInformation(mask_image) + + if np.any(gaussian_smooth): + + if not hasattr(gaussian_smooth, "__iter__"): + gaussian_smooth = (gaussian_smooth,) * 3 + + dvf_template = 
sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth) + + dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64)) + + mask_image_expanded = apply_transform( + mask_image, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor + ) + + return mask_image_expanded, dvf_tfm, dvf_template + From 9f99fe9e31f1f94bb1dc5322439946f80852e770 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 27 Apr 2022 07:15:19 +1000 Subject: [PATCH 166/264] Corrections --- platipy/imaging/cnn/localise_net.py | 5 +++-- platipy/imaging/cnn/prob_unet.py | 9 +++------ platipy/imaging/cnn/train.py | 11 ++++++----- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/platipy/imaging/cnn/localise_net.py b/platipy/imaging/cnn/localise_net.py index 01204483..82c85ec5 100644 --- a/platipy/imaging/cnn/localise_net.py +++ b/platipy/imaging/cnn/localise_net.py @@ -221,7 +221,8 @@ def validation_epoch_end(self, validation_step_outputs): try: com = get_com(mask) except: - com = [int(i / 2) for i in mask.GetSize()] + com = [int(i / 2) for i in img.GetSize()] + print(com) img_vis = ImageVisualiser(img, cut=com, figure_size_in=16) # img_vis.set_limits_from_label(mask, expansion=[0, 0, 0]) @@ -254,5 +255,5 @@ def validation_epoch_end(self, validation_step_outputs): on_epoch=True, prog_bar=False, logger=True, - batch_size=self.hparams.batch_size, +# batch_size=self.hparams.batch_size, ) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index d8509bba..022e50c6 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -347,11 +347,10 @@ def prepare_mask( top_k_mask = top_k_mask.to(device) mask = mask * top_k_mask - mask = mask.unsqueeze(1).repeat((1, num_classes)) - else: mask = torch.reshape(mask, (-1,)) + mask = mask.unsqueeze(1).repeat((1, num_classes)) mask = ( mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) @@ -367,6 +366,7 @@ def 
reconstruction_loss( top_k_percentage=None, deterministic=True, ): + criterion = torch.nn.BCEWithLogitsLoss(reduction="none") if z_posterior is None: z_posterior = self.posterior_latent_space.rsample() @@ -383,11 +383,8 @@ def reconstruction_loss( pos_class_count = t_flat.sum(axis=0)/batch_size neg_class_count = torch.logical_not(t_flat).sum(axis=0)/batch_size self._pos_weight = self._pos_weight * 0.5 + pos_class_count/neg_class_count * 0.5 - print(pos_class_count) - print(neg_class_count) - print(self._pos_weight) - criterion = torch.nn.BCEWithLogitsLoss(reduction="none", pos_weight=self._pos_weight) + # criterion = torch.nn.BCEWithLogitsLoss(reduction="none", pos_weight=self._pos_weight) xe = criterion(input=y_flat, target=t_flat) xe = xe.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index d38e0759..a3465273 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -75,7 +75,7 @@ def __init__( if self.hparams.prob_type == "prob": self.prob_unet = ProbabilisticUnet( self.hparams.input_channels, - len(self.hparams.structures), + len(self.hparams.structures) + 1, # Add 1 to num classes for background class self.hparams.filters_per_layer, self.hparams.latent_dim, self.hparams.no_convs_fcomb, @@ -270,7 +270,7 @@ def infer( result[sample["name"]] = {} for idx, structure in enumerate(self.hparams.structures): - pred = sitk.GetImageFromArray(pred_arr[idx]) + pred = sitk.GetImageFromArray(pred_arr[idx+1]) # Skip the background pred = pred > 0.5 # Threshold softmax at 0.5 pred = sitk.Cast(pred, sitk.sitkUInt8) @@ -402,9 +402,10 @@ def training_step(self, batch, _): # print(m.shape) # Add background layer for one-hot encoding - # y = torch.unsqueeze(y, dim=1) - # not_y = y.logical_not() - # y = torch.cat((not_y, y), dim=1).float() + #y = torch.unsqueeze(y, dim=1) + not_y = y.max(axis=1).values.logical_not() + not_y = torch.unsqueeze(not_y, 
dim=1) + y = torch.cat((not_y, y), dim=1).float() # self.prob_unet.forward(x, y, training=True) if self.hparams.prob_type == "prob": From 200922c3a403c7c8fcfcc430c967f50ad1debb1c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 30 May 2022 14:02:03 +1000 Subject: [PATCH 167/264] change default dropout --- platipy/imaging/cnn/prob_unet.py | 4 ++-- platipy/imaging/cnn/train.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 385567c6..29d9b6e0 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -39,7 +39,7 @@ def __init__( down_sample = 0 if idx == 0 else -2 layers.append( - Conv(input_filters, output_filters, up_down_sample=down_sample, ndims=ndims, dropout_probability=0.2) + Conv(input_filters, output_filters, up_down_sample=down_sample, ndims=ndims, dropout_probability=None) ) self.layers = torch.nn.Sequential(*layers) @@ -206,7 +206,7 @@ def __init__( loss_type="elbo", loss_params={"beta": 1}, ndims=2, - dropout_probability=0.2, + dropout_probability=None, ): super(ProbabilisticUnet, self).__init__() diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index c0c8e44a..18a73bd1 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -446,7 +446,7 @@ def training_step(self, batch, _): prog_bar=True, logger=True, ) - return loss + return training_loss def validation_step(self, batch, _): From a3f1ed8239eb363a66c68c9a727da30a567277d8 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 6 Jun 2022 12:00:51 +0000 Subject: [PATCH 168/264] Add code to process LIDC dataset --- platipy/imaging/cnn/lidc_dataset.py | 226 +++++ platipy/imaging/cnn/train_lidc.py | 1358 +++++++++++++++++++++++++++ 2 files changed, 1584 insertions(+) create mode 100644 platipy/imaging/cnn/lidc_dataset.py create mode 100644 platipy/imaging/cnn/train_lidc.py diff --git a/platipy/imaging/cnn/lidc_dataset.py 
b/platipy/imaging/cnn/lidc_dataset.py new file mode 100644 index 00000000..96b517b4 --- /dev/null +++ b/platipy/imaging/cnn/lidc_dataset.py @@ -0,0 +1,226 @@ +# Copyright 2022 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +import pickle +import os + +import SimpleITK as sitk +import numpy as np +import matplotlib.pyplot as plt + +import torch + +from imgaug import augmenters as iaa +from imgaug.augmentables.segmaps import SegmentationMapsOnImage + +from platipy.imaging import ImageVisualiser + +def prepare_transforms(): + + sometimes = lambda aug: iaa.Sometimes(0.5, aug) + + seq = iaa.Sequential( + [ + sometimes( + iaa.Affine( + scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, + translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, + rotate=(-15, 15), + shear=(-8, 8), + cval=-1, + ) + ), + # execute 0 to 2 of the following (less important) augmenters per image + iaa.SomeOf( + (0, 2), + [ + iaa.OneOf( + [ + iaa.GaussianBlur((0, 1.5)), + iaa.AverageBlur(k=(3, 5)), + ] + ), + sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.1))), + ], + random_order=True, + ), + sometimes(iaa.CoarseDropout((0.03, 0.15), size_percent=(0.02, 0.1))), + ], + random_order=True, + ) + + return seq + + +class LIDCDataset(torch.utils.data.Dataset): + """PyTorch Dataset for processing LIDC data""" + + def __init__( + self, + working_dir, + case_ids=None, + pickle_path="lidc.pickle", + augment_on_fly=True, 
+ ): + """Prepare a dataset from Nifti images/labels + + Args: + data (list): List of dict's where each item contains keys: "image" and "label". Values + are paths to the Nifti file. "label" may be a list where each item is a path to one + observer. + working_dir (str|path): Working directory where to write prepared files. + """ + + self.transforms = None + if augment_on_fly: + self.transforms = prepare_transforms() + self.slices = [] + self.working_dir = Path(working_dir) + + self.img_dir = self.working_dir.joinpath("img") + self.label_dir = self.working_dir.joinpath("label") + self.contour_mask_dir = self.working_dir.joinpath("contour_mask") + self.snap_dir = self.working_dir.joinpath("snapshots") + + self.img_dir.mkdir(exist_ok=True, parents=True) + self.label_dir.mkdir(exist_ok=True, parents=True) + self.contour_mask_dir.mkdir(exist_ok=True, parents=True) + self.snap_dir.mkdir(exist_ok=True, parents=True) + + # If data doesn't already exist, unpickle data and place into directory + if len(list(self.img_dir.glob("*"))) == 0: + pickle_path = Path(pickle_path) + + max_bytes = 2**31 - 1 + data = {} + + print("Loading file", pickle_path) + bytes_in = bytearray(0) + input_size = os.path.getsize(pickle_path) + with open(pickle_path, 'rb') as f_in: + for _ in range(0, input_size, max_bytes): + bytes_in += f_in.read(max_bytes) + new_data = pickle.loads(bytes_in) + data.update(new_data) + + for k,i in data.items(): + + pat_id = k.split("_")[0] + slice_id = k.split("_")[1].replace("slice", "") + + i["pixel_spacing"] = [float(a) for a in i["pixel_spacing"]] + + img_file = self.img_dir.joinpath(f"{pat_id}_{slice_id}.npy") + np.save(img_file, i["image"]) + + intersection = None + union = None + vis = ImageVisualiser(sitk.GetImageFromArray(np.expand_dims(i["image"], axis=0)), axis="z", window=[0,1]) + for obs, mask in enumerate(i["masks"]): + + vis.add_contour(sitk.GetImageFromArray(np.expand_dims(mask, axis=0)), name=f"{obs}") + label_file = 
self.label_dir.joinpath(f"{pat_id}_{slice_id}_{obs}.npy") + np.save(label_file, mask) + + mask = mask.astype(int) + + if intersection is None: + intersection = np.copy(mask) + else: + intersection += mask + + if union is None: + union = np.copy(mask) + else: + union += mask + + intersection[intersection>1] = 1 + union[union Date: Tue, 7 Jun 2022 16:41:53 +1000 Subject: [PATCH 169/264] zero if both masks empty --- platipy/imaging/cnn/train_lidc.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/platipy/imaging/cnn/train_lidc.py b/platipy/imaging/cnn/train_lidc.py index d0fa2905..cd9dbbb0 100644 --- a/platipy/imaging/cnn/train_lidc.py +++ b/platipy/imaging/cnn/train_lidc.py @@ -1221,12 +1221,15 @@ def validation_step(self, batch, _): self.prob_unet.forward(x) pred_y = self.prob_unet.sample(testing=True) + pred_y = pred_y.to("cpu") + y = y.to("cpu") # Intersection over Union (also known as Jaccard Index) jaccard = JaccardIndex(num_classes=2) term_1 = 0 for i in range(n): for j in range(m): + if pred_y[i,:,:,:].sum() == 0 and y[j,:,:,:].sum() == 0: continue iou = jaccard(pred_y[i,:,:,:].unsqueeze(0), y[j,:,:,:].unsqueeze(0)) term_1 += 1 - iou term_1 = term_1 * (2/(m*n)) @@ -1234,6 +1237,7 @@ def validation_step(self, batch, _): term_2 = 0 for i in range(n): for j in range(n): + if pred_y[i,:,:,:].sum() == 0 and pred_y[j,:,:,:].sum() == 0: continue iou = jaccard(pred_y[i,:,:,:].unsqueeze(0), pred_y[j,:,:,:].unsqueeze(0).argmax(1)) term_2 += 1 - iou term_2 = term_2 * (1/(n*n)) @@ -1241,12 +1245,14 @@ def validation_step(self, batch, _): term_3 = 0 for i in range(m): for j in range(m): + if y[i,:,:,:].sum() == 0 and y[j,:,:,:].sum() == 0: continue iou = jaccard(y[i,:,:,:].unsqueeze(0), y[j,:,:,:].unsqueeze(0)) term_3 += 1 - iou term_3 = term_3 * (1/(m*m)) D_ged = term_1 - term_2 - term_3 + self.log("GED", D_ged) return D_ged @@ -1308,11 +1314,11 @@ def main(args, config_json_path=None): # Save the best model checkpoint_callback = 
ModelCheckpoint( - monitor="scaled_DSC", + monitor="GED", dirpath=args.default_root_dir, - filename="probunet-{epoch:02d}-{scaled_DSC:.2f}", + filename="probunet-{epoch:02d}-{GED:.2f}", save_top_k=1, - mode="max", + mode="min", ) trainer.callbacks.append(checkpoint_callback) From 3f5ff14322ec51a11c6c7f44a399c016d4ba7ca1 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 7 Jun 2022 07:23:41 +0000 Subject: [PATCH 170/264] Visualise image on validate --- platipy/imaging/cnn/train_lidc.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train_lidc.py b/platipy/imaging/cnn/train_lidc.py index cd9dbbb0..9681f0bc 100644 --- a/platipy/imaging/cnn/train_lidc.py +++ b/platipy/imaging/cnn/train_lidc.py @@ -1024,12 +1024,12 @@ def setup(self, stage=None): ) self.validation_set = LIDCDataset( self.working_dir, - augment_on_fly=augment_on_fly, + augment_on_fly=False, case_ids=self.validation_cases ) self.test_set = LIDCDataset( self.working_dir, - augment_on_fly=augment_on_fly, + augment_on_fly=False, case_ids=self.test_cases ) @@ -1217,6 +1217,7 @@ def validation_step(self, batch, _): # Image will be same for all in batch x = x[0, :, :, :].unsqueeze(0) + vis = ImageVisualiser(sitk.GetImageFromArray(x[0,:,:,:]), axis="z") x = x.repeat(m, 1, 1, 1) self.prob_unet.forward(x) @@ -1252,6 +1253,25 @@ def validation_step(self, batch, _): D_ged = term_1 - term_2 - term_3 + contours = {} + for o in range(n): + contours[f"obs_{o}"] = sitk.GetImageFromArray(y[o,:,:,:]) + for mm in range(m): + contours[f"sample_{mm}"] = sitk.GetImageFromArray(pred_y[j,:,:,:].argmax(1).unsqueeze(0)) + + vis.add_contour(contours, colormap=plt.cm.get_cmap("cool")) + vis.show() + + figure_path = "valid.png" + plt.savefig(figure_path, dpi=300) + plt.close("all") + + try: + self.logger.experiment.log_image(figure_path) + except AttributeError: + # Likely offline mode + pass + self.log("GED", D_ged) return D_ged From 
7f2bb0fa45cfcb0c2bf9de06a08032412e608fa4 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 7 Jun 2022 20:42:48 +1000 Subject: [PATCH 171/264] Fix dataloading --- platipy/imaging/cnn/lidc_dataset.py | 25 +++++++++--------- platipy/imaging/cnn/train_lidc.py | 39 +++++++++++++++++++++++++++-- 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/platipy/imaging/cnn/lidc_dataset.py b/platipy/imaging/cnn/lidc_dataset.py index 96b517b4..bd834a12 100644 --- a/platipy/imaging/cnn/lidc_dataset.py +++ b/platipy/imaging/cnn/lidc_dataset.py @@ -171,24 +171,23 @@ def __init__( plt.savefig(fig_path) plt.close(fig) - for img in self.img_dir.glob("*"): - case_and_slice = img.name.replace(".npy", "") - contour_mask = self.contour_mask_dir.joinpath(img.name) - - for label in self.label_dir.glob(f"{case_and_slice}_*.npy"): + for case in case_ids: + for label in self.label_dir.glob(f"{case}_*.npy"): case_id, z_slice, obs = label.name.replace(".npy", "").split("_") - if case_ids is not None and not case_id in case_ids: - continue + img = self.img_dir.joinpath(f"{case_id}_{z_slice}.npy") + contour_mask = self.contour_mask_dir.joinpath(f"{case_id}_{z_slice}.npy") + assert img.exists() + assert contour_mask.exists() self.slices.append( { - "z": z_slice, - "image": img, - "label": label, - "contour_mask": contour_mask, - "case": case_id, - "observer": obs, + "z": z_slice, + "image": img, + "label": label, + "contour_mask": contour_mask, + "case": case_id, + "observer": obs, } ) diff --git a/platipy/imaging/cnn/train_lidc.py b/platipy/imaging/cnn/train_lidc.py index 9681f0bc..72def412 100644 --- a/platipy/imaging/cnn/train_lidc.py +++ b/platipy/imaging/cnn/train_lidc.py @@ -938,7 +938,40 @@ LIDC_PAT_IDS = [ "208177797605474151106520124306", "234289514191030145998276287188", - "255763746952382642554143442493" + "255763746952382642554143442493", + '879950425588844557177426831791', + '880369233375230345237001511006', + '886180838786633773936677813818', + 
'886430358468567310311007723004', + '887916493746193939407481623391', + '888021904600511420323095129935', + '888517498954149177086283916722', + '891182989185983545761655978623', + '897705953598294772269569489281', + '899082900417573006084750602123', + '899353684702041035700102438716', + '901103510796290218917651903823', + '902763751470786794946912471631', + '906824386019316813789030483695', + '908741193082513651836950434578', + '911801947447849749468305764840', + '915736860556289509455669530049', + '915986308688735366393353350740', + '919002813906622793381125049416', + '921287276013810837359841315530', + '924939006160714533549353726515', + '925679448263116681180549137562', + '927075281189608119735911336961', + '931660023131522836511470299550', + '934701751347399243333120058853', + '939103340398727679812199945201', + '952440288393343800284327753087', + '965523656856760127560055059644', + '969325292841504805778529336047', + '971476961920773226447199844576', + '988068515766013782236551550185', + '989440509183467842001314342301', + '995561512722026805270815340218', ] class LIDCDataModule(pl.LightningDataModule): @@ -1022,11 +1055,13 @@ def setup(self, stage=None): augment_on_fly=augment_on_fly, case_ids=self.train_cases ) + print(f"Training Set Size: {len(self.training_set)}") self.validation_set = LIDCDataset( self.working_dir, augment_on_fly=False, case_ids=self.validation_cases ) + print(f"Validation Set Size: {len(self.validation_set)}") self.test_set = LIDCDataset( self.working_dir, augment_on_fly=False, @@ -1217,7 +1252,7 @@ def validation_step(self, batch, _): # Image will be same for all in batch x = x[0, :, :, :].unsqueeze(0) - vis = ImageVisualiser(sitk.GetImageFromArray(x[0,:,:,:]), axis="z") + vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0,:,:,:]), axis="z") x = x.repeat(m, 1, 1, 1) self.prob_unet.forward(x) From c2537a0870049df082cbe12baa54b4a56078ab42 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 13 Jun 2022 20:16:03 +1000 Subject: 
[PATCH 172/264] LIDC tweaks --- platipy/imaging/cnn/lidc_dataset.py | 2 +- platipy/imaging/cnn/train_lidc.py | 88 +++++++++++++++-------------- 2 files changed, 46 insertions(+), 44 deletions(-) diff --git a/platipy/imaging/cnn/lidc_dataset.py b/platipy/imaging/cnn/lidc_dataset.py index bd834a12..7a314be0 100644 --- a/platipy/imaging/cnn/lidc_dataset.py +++ b/platipy/imaging/cnn/lidc_dataset.py @@ -39,7 +39,7 @@ def prepare_transforms(): translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, rotate=(-15, 15), shear=(-8, 8), - cval=-1, + cval=0, ) ), # execute 0 to 2 of the following (less important) augmenters per image diff --git a/platipy/imaging/cnn/train_lidc.py b/platipy/imaging/cnn/train_lidc.py index 72def412..7fa72dd3 100644 --- a/platipy/imaging/cnn/train_lidc.py +++ b/platipy/imaging/cnn/train_lidc.py @@ -935,44 +935,44 @@ '998144378008088787705870980497' ] -LIDC_PAT_IDS = [ - "208177797605474151106520124306", - "234289514191030145998276287188", - "255763746952382642554143442493", - '879950425588844557177426831791', - '880369233375230345237001511006', - '886180838786633773936677813818', - '886430358468567310311007723004', - '887916493746193939407481623391', - '888021904600511420323095129935', - '888517498954149177086283916722', - '891182989185983545761655978623', - '897705953598294772269569489281', - '899082900417573006084750602123', - '899353684702041035700102438716', - '901103510796290218917651903823', - '902763751470786794946912471631', - '906824386019316813789030483695', - '908741193082513651836950434578', - '911801947447849749468305764840', - '915736860556289509455669530049', - '915986308688735366393353350740', - '919002813906622793381125049416', - '921287276013810837359841315530', - '924939006160714533549353726515', - '925679448263116681180549137562', - '927075281189608119735911336961', - '931660023131522836511470299550', - '934701751347399243333120058853', - '939103340398727679812199945201', - '952440288393343800284327753087', - 
'965523656856760127560055059644', - '969325292841504805778529336047', - '971476961920773226447199844576', - '988068515766013782236551550185', - '989440509183467842001314342301', - '995561512722026805270815340218', -] +#LIDC_PAT_IDS = [ +# "208177797605474151106520124306", +# "234289514191030145998276287188", +# "255763746952382642554143442493", +# '879950425588844557177426831791', +# '880369233375230345237001511006', +# '886180838786633773936677813818', +# '886430358468567310311007723004', +# '887916493746193939407481623391', +# '888021904600511420323095129935', +# '888517498954149177086283916722', +# '891182989185983545761655978623', +# '897705953598294772269569489281', +# '899082900417573006084750602123', +# '899353684702041035700102438716', +# '901103510796290218917651903823', +# '902763751470786794946912471631', +# '906824386019316813789030483695', +# '908741193082513651836950434578', +# '911801947447849749468305764840', +# '915736860556289509455669530049', +# '915986308688735366393353350740', +# '919002813906622793381125049416', +# '921287276013810837359841315530', +# '924939006160714533549353726515', +# '925679448263116681180549137562', +# '927075281189608119735911336961', +# '931660023131522836511470299550', +# '934701751347399243333120058853', +# '939103340398727679812199945201', +# '952440288393343800284327753087', +# '965523656856760127560055059644', +# '969325292841504805778529336047', +# '971476961920773226447199844576', +# '988068515766013782236551550185', +# '989440509183467842001314342301', +# '995561512722026805270815340218', +#] class LIDCDataModule(pl.LightningDataModule): """PyTorch data module to load LIDC data""" @@ -1000,7 +1000,7 @@ def __init__( self.augment_on_fly = augment_on_fly self.batch_size = batch_size self.num_workers = num_workers - + self.training_set = None self.validation_set = None self.test_set = None @@ -1198,7 +1198,6 @@ def training_step(self, batch, _): not_y = torch.unsqueeze(not_y, dim=1) y = torch.cat((not_y, y), 
dim=1).float() - # self.prob_unet.forward(x, y, training=True) if self.hparams.prob_type == "prob": self.prob_unet.forward(x, y, training=True) else: @@ -1240,7 +1239,7 @@ def training_step(self, batch, _): prog_bar=True, logger=True, ) - return loss + return training_loss def validation_step(self, batch, _): @@ -1292,11 +1291,14 @@ def validation_step(self, batch, _): for o in range(n): contours[f"obs_{o}"] = sitk.GetImageFromArray(y[o,:,:,:]) for mm in range(m): - contours[f"sample_{mm}"] = sitk.GetImageFromArray(pred_y[j,:,:,:].argmax(1).unsqueeze(0)) + samp_pred = pred_y[j,:,:,:] + samp_pred = samp_pred.argmax(0) + samp_pred = samp_pred.unsqueeze(0) + contours[f"sample_{mm}"] = sitk.GetImageFromArray(samp_pred) vis.add_contour(contours, colormap=plt.cm.get_cmap("cool")) vis.show() - + figure_path = "valid.png" plt.savefig(figure_path, dpi=300) plt.close("all") From 1edbc18369fd9f64468aef6287de84343d73ada7 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 3 Aug 2022 13:50:25 +1000 Subject: [PATCH 173/264] Correct issue in training step --- platipy/imaging/cnn/train_lidc.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/platipy/imaging/cnn/train_lidc.py b/platipy/imaging/cnn/train_lidc.py index 7fa72dd3..0cff3c08 100644 --- a/platipy/imaging/cnn/train_lidc.py +++ b/platipy/imaging/cnn/train_lidc.py @@ -1040,6 +1040,7 @@ def setup(self, stage=None): self.validation_cases = val_test_cases else: self.validation_cases = val_test_cases[: int(len(val_test_cases) / 2)] + self.validation_cases = val_test_cases[: 5] self.test_cases = val_test_cases[int(len(val_test_cases) / 2) :] else: self.train_cases += cases[f * cases_per_fold : (f + 1) * cases_per_fold] @@ -1194,7 +1195,7 @@ def training_step(self, batch, _): # Add background layer for one-hot encoding #y = torch.unsqueeze(y, dim=1) - not_y = y.max(axis=1).values.logical_not() + not_y = 1 - y.max(axis=1).values not_y = torch.unsqueeze(not_y, dim=1) y = torch.cat((not_y, y), 
dim=1).float() @@ -1257,6 +1258,9 @@ def validation_step(self, batch, _): pred_y = self.prob_unet.sample(testing=True) pred_y = pred_y.to("cpu") + + pred_y = pred_y.argmax(1) + pred_y = pred_y.unsqueeze(1) y = y.to("cpu") # Intersection over Union (also known as Jaccard Index) @@ -1273,7 +1277,7 @@ def validation_step(self, batch, _): for i in range(n): for j in range(n): if pred_y[i,:,:,:].sum() == 0 and pred_y[j,:,:,:].sum() == 0: continue - iou = jaccard(pred_y[i,:,:,:].unsqueeze(0), pred_y[j,:,:,:].unsqueeze(0).argmax(1)) + iou = jaccard(pred_y[i,:,:,:].unsqueeze(0), pred_y[j,:,:,:].unsqueeze(0)) term_2 += 1 - iou term_2 = term_2 * (1/(n*n)) @@ -1291,7 +1295,7 @@ def validation_step(self, batch, _): for o in range(n): contours[f"obs_{o}"] = sitk.GetImageFromArray(y[o,:,:,:]) for mm in range(m): - samp_pred = pred_y[j,:,:,:] + samp_pred = pred_y[mm,:,:,:] samp_pred = samp_pred.argmax(0) samp_pred = samp_pred.unsqueeze(0) contours[f"sample_{mm}"] = sitk.GetImageFromArray(samp_pred) From 9f1c295840822cfba334473a5f2af428050ab7af Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 4 Aug 2022 08:25:00 +1000 Subject: [PATCH 174/264] Work on argmax --- platipy/imaging/cnn/train_lidc.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/platipy/imaging/cnn/train_lidc.py b/platipy/imaging/cnn/train_lidc.py index 0cff3c08..ad4c8668 100644 --- a/platipy/imaging/cnn/train_lidc.py +++ b/platipy/imaging/cnn/train_lidc.py @@ -1258,8 +1258,11 @@ def validation_step(self, batch, _): pred_y = self.prob_unet.sample(testing=True) pred_y = pred_y.to("cpu") + print(f"{pred_y[0,:,:,:].min()} {pred_y[0,:,:,:].max()}") + pred_y = torch.sigmoid(pred_y) + print(f"{pred_y[0,:,:,:].min()} {pred_y[0,:,:,:].max()}") - pred_y = pred_y.argmax(1) + pred_y = pred_y[:,1,:,:] > 0.5 pred_y = pred_y.unsqueeze(1) y = y.to("cpu") @@ -1295,9 +1298,9 @@ def validation_step(self, batch, _): for o in range(n): contours[f"obs_{o}"] = sitk.GetImageFromArray(y[o,:,:,:]) for mm 
in range(m): - samp_pred = pred_y[mm,:,:,:] - samp_pred = samp_pred.argmax(0) - samp_pred = samp_pred.unsqueeze(0) + samp_pred = pred_y[mm,:,:,:].float() + #samp_pred = samp_pred.argmax(0) + #samp_pred = samp_pred.unsqueeze(0) contours[f"sample_{mm}"] = sitk.GetImageFromArray(samp_pred) vis.add_contour(contours, colormap=plt.cm.get_cmap("cool")) From 259f5c7f8fa81f26f04c1c51319e7518f780f13a Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 3 Aug 2022 22:42:27 +0000 Subject: [PATCH 175/264] Correct argmax --- platipy/imaging/cnn/train_lidc.py | 45 ++++++++++++++++++------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/platipy/imaging/cnn/train_lidc.py b/platipy/imaging/cnn/train_lidc.py index ad4c8668..7b350453 100644 --- a/platipy/imaging/cnn/train_lidc.py +++ b/platipy/imaging/cnn/train_lidc.py @@ -1256,14 +1256,19 @@ def validation_step(self, batch, _): x = x.repeat(m, 1, 1, 1) self.prob_unet.forward(x) - pred_y = self.prob_unet.sample(testing=True) - pred_y = pred_y.to("cpu") - print(f"{pred_y[0,:,:,:].min()} {pred_y[0,:,:,:].max()}") - pred_y = torch.sigmoid(pred_y) - print(f"{pred_y[0,:,:,:].min()} {pred_y[0,:,:,:].max()}") - - pred_y = pred_y[:,1,:,:] > 0.5 - pred_y = pred_y.unsqueeze(1) + py = self.prob_unet.sample(testing=True) + py = py.to("cpu") + # print(f"{pred_y[0,:,:,:].min()} {pred_y[0,:,:,:].max()}") + # pred_y = torch.sigmoid(pred_y) + # print(f"{pred_y[0,:,:,:].min()} {pred_y[0,:,:,:].max()}") + + # pred_y = pred_y[:,1,:,:] > 0.5 + # pred_y = pred_y.unsqueeze(1) + pred_y = torch.zeros(py[:,0,:].shape).int() + for b in range(py.shape[0]): + pred_y[b] = py[b,:].argmax(0).int() + + y = y.squeeze(1) y = y.to("cpu") # Intersection over Union (also known as Jaccard Index) @@ -1271,24 +1276,27 @@ def validation_step(self, batch, _): term_1 = 0 for i in range(n): for j in range(m): - if pred_y[i,:,:,:].sum() == 0 and y[j,:,:,:].sum() == 0: continue - iou = jaccard(pred_y[i,:,:,:].unsqueeze(0), y[j,:,:,:].unsqueeze(0)) + if 
pred_y[i].sum() + y[j].sum() == 0: + continue + iou = jaccard(pred_y[i], y[j]) term_1 += 1 - iou term_1 = term_1 * (2/(m*n)) term_2 = 0 for i in range(n): for j in range(n): - if pred_y[i,:,:,:].sum() == 0 and pred_y[j,:,:,:].sum() == 0: continue - iou = jaccard(pred_y[i,:,:,:].unsqueeze(0), pred_y[j,:,:,:].unsqueeze(0)) + if pred_y[i].sum() + pred_y[j].sum() == 0: + continue + iou = jaccard(pred_y[i], pred_y[j]) term_2 += 1 - iou term_2 = term_2 * (1/(n*n)) term_3 = 0 for i in range(m): for j in range(m): - if y[i,:,:,:].sum() == 0 and y[j,:,:,:].sum() == 0: continue - iou = jaccard(y[i,:,:,:].unsqueeze(0), y[j,:,:,:].unsqueeze(0)) + if y[i].sum() + y[j].sum() == 0: + continue + iou = jaccard(y[i], y[j]) term_3 += 1 - iou term_3 = term_3 * (1/(m*m)) @@ -1296,11 +1304,12 @@ def validation_step(self, batch, _): contours = {} for o in range(n): - contours[f"obs_{o}"] = sitk.GetImageFromArray(y[o,:,:,:]) + obs_y = y[o].float() + obs_y = obs_y.unsqueeze(0) + contours[f"obs_{o}"] = sitk.GetImageFromArray(obs_y) for mm in range(m): - samp_pred = pred_y[mm,:,:,:].float() - #samp_pred = samp_pred.argmax(0) - #samp_pred = samp_pred.unsqueeze(0) + samp_pred = pred_y[mm].float() + samp_pred = samp_pred.unsqueeze(0) contours[f"sample_{mm}"] = sitk.GetImageFromArray(samp_pred) vis.add_contour(contours, colormap=plt.cm.get_cmap("cool")) From e75e8b9ccf453c4ba61ba1281aa5e49156d49d36 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 8 Aug 2022 10:21:35 +1000 Subject: [PATCH 176/264] Change to rec loss summation --- platipy/imaging/cnn/prob_unet.py | 26 +++++++++++++++----------- platipy/imaging/cnn/train_lidc.py | 15 ++++++++++++++- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 29d9b6e0..08dec353 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -366,7 +366,8 @@ def reconstruction_loss( top_k_percentage=None, deterministic=True, ): - criterion = 
torch.nn.BCEWithLogitsLoss(reduction="none") +# criterion = torch.nn.BCEWithLogitsLoss(reduction="none") + criterion = torch.nn.BCEWithLogitsLoss(size_average = False, reduce=False, reduction=None) if z_posterior is None: z_posterior = self.posterior_latent_space.rsample() @@ -429,7 +430,7 @@ def reconstruction_loss( return ce_sum, ce_mean, mask - def loss(self, segm, mask=None): + def loss(self, segm, mask=None, beta=None): """ Calculate the evidence lower bound of the log-likelihood of P(Y|X) """ @@ -455,7 +456,7 @@ def loss(self, segm, mask=None): contour_threshold = self.loss_params["kappa_contour"] # Here we use the posterior sample sampled above - _, rec_loss_mean, _ = self.reconstruction_loss( + rl_sum, rec_loss_mean, _ = self.reconstruction_loss( segm, z_posterior=z_posterior, top_k_percentage=top_k_percentage, @@ -464,19 +465,22 @@ def loss(self, segm, mask=None): # If using contour mask in loss, we get back those in a list. Unpack here. if contour_threshold: - contour_loss = rec_loss_mean[1] - contour_loss_mean = rec_loss_mean[1] - reconstruction_loss = rec_loss_mean[0] - rec_loss_mean = rec_loss_mean[0] + contour_loss = rl_sum[1] +# contour_loss_mean = rl_sum[1] + reconstruction_loss = rl_sum[0] + # rec_loss_mean = rl_sum[0] else: - reconstruction_loss = rec_loss_mean + reconstruction_loss = rl_sum if self.loss_type == "elbo": + if beta==None: + beta = self.loss_params["beta"] return { - "loss": reconstruction_loss + self.loss_params["beta"] * kl_div, + "loss": reconstruction_loss + beta * kl_div, "rec_loss": reconstruction_loss, "kl_div": kl_div, + "beta": beta } elif self.loss_type == "geco": @@ -484,7 +488,7 @@ def loss(self, segm, mask=None): moving_avg_factor = 0.8 - rl = rec_loss_mean.detach() + rl = reconstruction_loss.detach() if self._rec_moving_avg is None: self._rec_moving_avg = rl else: @@ -496,7 +500,7 @@ def loss(self, segm, mask=None): cc = 0 if contour_threshold: - cl = contour_loss_mean.detach() + cl = contour_loss.detach() if 
self._contour_moving_avg is None: self._contour_moving_avg = rl else: diff --git a/platipy/imaging/cnn/train_lidc.py b/platipy/imaging/cnn/train_lidc.py index 7b350453..ddf268f5 100644 --- a/platipy/imaging/cnn/train_lidc.py +++ b/platipy/imaging/cnn/train_lidc.py @@ -1040,7 +1040,7 @@ def setup(self, stage=None): self.validation_cases = val_test_cases else: self.validation_cases = val_test_cases[: int(len(val_test_cases) / 2)] - self.validation_cases = val_test_cases[: 5] + # self.validation_cases = val_test_cases[: 5] self.test_cases = val_test_cases[int(len(val_test_cases) / 2) :] else: self.train_cases += cases[f * cases_per_fold : (f + 1) * cases_per_fold] @@ -1188,6 +1188,17 @@ def configure_optimizers(self): return [optimizer], [scheduler] + def frange_cycle_linear(self, start, stop, n_epoch, n_cycle=4, ratio=0.5): + L = np.ones(n_epoch) + period = n_epoch/n_cycle + step = (stop-start)/(period*ratio) # linear schedule + for c in range(n_cycle): + v , i = start , 0 + while v <= stop and (int(i+c*period) < n_epoch): + L[int(i+c*period)] = v + v += step + i += 1 + return L def training_step(self, batch, _): @@ -1205,6 +1216,8 @@ def training_step(self, batch, _): self.prob_unet.forward(x, y) if self.hparams.prob_type == "prob": + beta_vals = self.frange_cycle_linear(0.0, 0.01, 100, 4, 1.0) +# loss = self.prob_unet.loss(y, mask=m, beta=beta_vals[self.current_epoch]) loss = self.prob_unet.loss(y, mask=m) else: loss = self.prob_unet.loss(x, y, mask=m) From 28246401fa934d1ac4d720487a2b434b9a46fc6b Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 8 Aug 2022 00:44:25 +0000 Subject: [PATCH 177/264] Used GED for main train function --- platipy/imaging/cnn/train.py | 350 ++++++++++++++++++++++------------- 1 file changed, 217 insertions(+), 133 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 18a73bd1..0b90d196 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -399,8 +399,7 @@ def 
training_step(self, batch, _): x, y, m, _ = batch # Add background layer for one-hot encoding - #y = torch.unsqueeze(y, dim=1) - not_y = y.max(axis=1).values.logical_not() + not_y = 1 - y.max(axis=1).values not_y = torch.unsqueeze(not_y, dim=1) y = torch.cat((not_y, y), dim=1).float() @@ -450,113 +449,79 @@ def training_step(self, batch, _): def validation_step(self, batch, _): - if self.validation_directory is None: - self.validation_directory = Path(tempfile.mkdtemp()) + n = 4 + m = 4 with torch.set_grad_enabled(False): x, y, _, info = batch - for s in range(y.shape[0]): - img_file = self.validation_directory.joinpath( - f"img_{info['case'][s]}_{info['z'][s]}.npy" - ) - np.save(img_file, x[s].squeeze(0).cpu().numpy()) - - mask_file = self.validation_directory.joinpath( - f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" - ) - np.save(mask_file, y[s].cpu().numpy()) - - return info - - def validation_epoch_end(self, validation_step_outputs): - - cases = {} - for info in validation_step_outputs: - - for case, z, observer in zip(info["case"], info["z"], info["observer"]): - if not case in cases: - cases[case] = {"slices": z.item(), "observers": [observer]} - else: - if z.item() > cases[case]["slices"]: - - cases[case]["slices"] = z.item() - if not observer in cases[case]["observers"]: - cases[case]["observers"].append(observer) - - metrics = ["DSC", "HD", "ASD"] - computed_metrics = { - **{f"probnet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, - **{f"unet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, - } - - for case in cases: - - img_arrs = [] - slices = [] - - if self.hparams.ndims == 2: - for z in range(cases[case]["slices"] + 1): - img_file = self.validation_directory.joinpath(f"img_{case}_{z}.npy") - if img_file.exists(): - img_arrs.append(np.load(img_file)) - slices.append(z) - - img_arr = np.stack(img_arrs) - - else: - img_file = self.validation_directory.joinpath(f"img_{case}_0.npy") - img_arr = 
np.load(img_file) - img = sitk.GetImageFromArray(img_arr) - img.SetSpacing(self.hparams.spacing) - try: - mean = self.infer(img, num_samples=1, sample_strategy="mean", preprocess=False) - samples = self.infer( - img, - sample_strategy="spaced", - num_samples=5, - spaced_range=[-1.5, 1.5], - preprocess=False, - ) - except Exception as e: - print(f"ERROR DURING VALIDATION INFERENCE: {e}") - return - - observers = {} - for _, observer in enumerate(cases[case]["observers"]): - - if self.hparams.ndims == 2: - mask_arrs = [] - for z in slices: - mask_file = self.validation_directory.joinpath( - f"mask_{case}_{z}_{observer}.npy" - ) - - mask_arrs.append(np.load(mask_file)) - - mask_arr = np.stack(mask_arrs, axis=1) - - else: - mask_file = self.validation_directory.joinpath( - f"mask_{case}_{z}_{observer}.npy" - ) - mask_arr = np.load(mask_file) - - observers[f"manual_{observer}"] = {} - for idx, structure in enumerate(self.hparams.structures): - mask = sitk.GetImageFromArray(mask_arr[idx]) - mask = sitk.Cast(mask, sitk.sitkUInt8) - mask.CopyInformation(img) - observers[f"manual_{observer}"][structure] = mask - - #try: - result, fig = self.validate(img, observers, samples, mean, matching_type="best") - #except Exception as e: - # print(f"ERROR DURING VALIDATION VALIDATE: {e}") - # return - - figure_path = f"valid_{case}.png" - fig.savefig(figure_path, dpi=300) + # Image will be same for all in batch + x = x[0, :, :, :].unsqueeze(0) + vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0,:,:,:]), axis="z") + x = x.repeat(m, 1, 1, 1) + self.prob_unet.forward(x) + + py = self.prob_unet.sample(testing=True) + py = py.to("cpu") + # print(f"{pred_y[0,:,:,:].min()} {pred_y[0,:,:,:].max()}") + # pred_y = torch.sigmoid(pred_y) + # print(f"{pred_y[0,:,:,:].min()} {pred_y[0,:,:,:].max()}") + + # pred_y = pred_y[:,1,:,:] > 0.5 + # pred_y = pred_y.unsqueeze(1) + pred_y = torch.zeros(py[:,0,:].shape).int() + for b in range(py.shape[0]): + pred_y[b] = py[b,:].argmax(0).int() + + y = 
y.squeeze(1) + y = y.to("cpu") + + # Intersection over Union (also known as Jaccard Index) + jaccard = JaccardIndex(num_classes=2) + term_1 = 0 + for i in range(n): + for j in range(m): + if pred_y[i].sum() + y[j].sum() == 0: + continue + iou = jaccard(pred_y[i], y[j]) + term_1 += 1 - iou + term_1 = term_1 * (2/(m*n)) + + term_2 = 0 + for i in range(n): + for j in range(n): + if pred_y[i].sum() + pred_y[j].sum() == 0: + continue + iou = jaccard(pred_y[i], pred_y[j]) + term_2 += 1 - iou + term_2 = term_2 * (1/(n*n)) + + term_3 = 0 + for i in range(m): + for j in range(m): + if y[i].sum() + y[j].sum() == 0: + continue + iou = jaccard(y[i], y[j]) + term_3 += 1 - iou + term_3 = term_3 * (1/(m*m)) + + D_ged = term_1 - term_2 - term_3 + + contours = {} + for o in range(n): + obs_y = y[o].float() + obs_y = obs_y.unsqueeze(0) + contours[f"obs_{o}"] = sitk.GetImageFromArray(obs_y) + for mm in range(m): + samp_pred = pred_y[mm].float() + samp_pred = samp_pred.unsqueeze(0) + contours[f"sample_{mm}"] = sitk.GetImageFromArray(samp_pred) + + vis.add_contour(contours, colormap=plt.cm.get_cmap("cool")) + vis.show() + + figure_path = "valid.png" + plt.savefig(figure_path, dpi=300) plt.close("all") try: @@ -565,31 +530,150 @@ def validation_epoch_end(self, validation_step_outputs): # Likely offline mode pass - for t in result: - for m in metrics: - computed_metrics[f"{t}_{m}"] += result[t][m] - - if self.kl_div: - p = u = 0 - for s in self.hparams.structures: - p += np.array(computed_metrics[f"probnet_{s}_DSC"]).mean() - u += np.array(computed_metrics[f"unet_{s}_DSC"]).mean() - - p /= len(self.hparams.structures) - u /= len(self.hparams.structures) - computed_metrics["scaled_DSC"] = ((p + u) / 2) + (p - u) - self.kl_div - - for cm in computed_metrics: - self.log( - cm, - np.array(computed_metrics[cm]).mean(), - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - ) - - # shutil.rmtree(self.validation_directory) + self.log("GED", D_ged) + return D_ged + # def 
validation_step(self, batch, _): + + # if self.validation_directory is None: + # self.validation_directory = Path(tempfile.mkdtemp()) + + # with torch.set_grad_enabled(False): + # x, y, _, info = batch + # for s in range(y.shape[0]): + + # img_file = self.validation_directory.joinpath( + # f"img_{info['case'][s]}_{info['z'][s]}.npy" + # ) + # np.save(img_file, x[s].squeeze(0).cpu().numpy()) + + # mask_file = self.validation_directory.joinpath( + # f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" + # ) + # np.save(mask_file, y[s].cpu().numpy()) + + # return info + + # def validation_epoch_end(self, validation_step_outputs): + + # cases = {} + # for info in validation_step_outputs: + + # for case, z, observer in zip(info["case"], info["z"], info["observer"]): + # if not case in cases: + # cases[case] = {"slices": z.item(), "observers": [observer]} + # else: + # if z.item() > cases[case]["slices"]: + + # cases[case]["slices"] = z.item() + # if not observer in cases[case]["observers"]: + # cases[case]["observers"].append(observer) + + # metrics = ["DSC", "HD", "ASD"] + # computed_metrics = { + # **{f"probnet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, + # **{f"unet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, + # } + + # for case in cases: + + # img_arrs = [] + # slices = [] + + # if self.hparams.ndims == 2: + # for z in range(cases[case]["slices"] + 1): + # img_file = self.validation_directory.joinpath(f"img_{case}_{z}.npy") + # if img_file.exists(): + # img_arrs.append(np.load(img_file)) + # slices.append(z) + + # img_arr = np.stack(img_arrs) + + # else: + # img_file = self.validation_directory.joinpath(f"img_{case}_0.npy") + # img_arr = np.load(img_file) + # img = sitk.GetImageFromArray(img_arr) + # img.SetSpacing(self.hparams.spacing) + # try: + # mean = self.infer(img, num_samples=1, sample_strategy="mean", preprocess=False) + # samples = self.infer( + # img, + # sample_strategy="spaced", + # 
num_samples=5, + # spaced_range=[-1.5, 1.5], + # preprocess=False, + # ) + # except Exception as e: + # print(f"ERROR DURING VALIDATION INFERENCE: {e}") + # return + + # observers = {} + # for _, observer in enumerate(cases[case]["observers"]): + + # if self.hparams.ndims == 2: + # mask_arrs = [] + # for z in slices: + # mask_file = self.validation_directory.joinpath( + # f"mask_{case}_{z}_{observer}.npy" + # ) + + # mask_arrs.append(np.load(mask_file)) + + # mask_arr = np.stack(mask_arrs, axis=1) + + # else: + # mask_file = self.validation_directory.joinpath( + # f"mask_{case}_{z}_{observer}.npy" + # ) + # mask_arr = np.load(mask_file) + + # observers[f"manual_{observer}"] = {} + # for idx, structure in enumerate(self.hparams.structures): + # mask = sitk.GetImageFromArray(mask_arr[idx]) + # mask = sitk.Cast(mask, sitk.sitkUInt8) + # mask.CopyInformation(img) + # observers[f"manual_{observer}"][structure] = mask + + # #try: + # result, fig = self.validate(img, observers, samples, mean, matching_type="best") + # #except Exception as e: + # # print(f"ERROR DURING VALIDATION VALIDATE: {e}") + # # return + + # figure_path = f"valid_{case}.png" + # fig.savefig(figure_path, dpi=300) + # plt.close("all") + + # try: + # self.logger.experiment.log_image(figure_path) + # except AttributeError: + # # Likely offline mode + # pass + + # for t in result: + # for m in metrics: + # computed_metrics[f"{t}_{m}"] += result[t][m] + + # if self.kl_div: + # p = u = 0 + # for s in self.hparams.structures: + # p += np.array(computed_metrics[f"probnet_{s}_DSC"]).mean() + # u += np.array(computed_metrics[f"unet_{s}_DSC"]).mean() + + # p /= len(self.hparams.structures) + # u /= len(self.hparams.structures) + # computed_metrics["scaled_DSC"] = ((p + u) / 2) + (p - u) - self.kl_div + + # for cm in computed_metrics: + # self.log( + # cm, + # np.array(computed_metrics[cm]).mean(), + # on_step=False, + # on_epoch=True, + # prog_bar=False, + # logger=True, + # ) + + # # 
shutil.rmtree(self.validation_directory) def main(args, config_json_path=None): @@ -650,11 +734,11 @@ def main(args, config_json_path=None): # Save the best model checkpoint_callback = ModelCheckpoint( - monitor="scaled_DSC", + monitor="GED", dirpath=args.default_root_dir, - filename="probunet-{epoch:02d}-{scaled_DSC:.2f}", + filename="probunet-{epoch:02d}-{GED:.2f}", save_top_k=1, - mode="max", + mode="min", ) trainer.callbacks.append(checkpoint_callback) From 4880422e778b2be0137c1f95fd8e609027ce4fd5 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 10 Aug 2022 08:19:14 +1000 Subject: [PATCH 178/264] Add missing import --- platipy/imaging/cnn/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 0b90d196..efd4bcc4 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -24,6 +24,7 @@ import comet_ml # pylint: disable=unused-import from pytorch_lightning.loggers import CometLogger +from torchmetrics import JaccardIndex import torch import pytorch_lightning as pl @@ -474,6 +475,7 @@ def validation_step(self, batch, _): pred_y[b] = py[b,:].argmax(0).int() y = y.squeeze(1) + y = y.int() y = y.to("cpu") # Intersection over Union (also known as Jaccard Index) From 35508d8c1267c63a9b2662596315f0e1bef17326 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 25 Aug 2022 16:57:43 +1000 Subject: [PATCH 179/264] A few prob UNET corrections --- platipy/imaging/cnn/prob_unet.py | 64 ++++--- platipy/imaging/cnn/train.py | 284 +++++++++++++++---------------- 2 files changed, 182 insertions(+), 166 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 08dec353..aa2cfa65 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -26,7 +26,7 @@ class Encoder(torch.nn.Module): """Encoder part of the probabilistic UNet""" def __init__( - self, input_channels, filters_per_layer=[64 * (2 ** x) for x in 
range(5)], ndims=2 + self, input_channels, filters_per_layer=[64 * (2**x) for x in range(5)], ndims=2 ): super(Encoder, self).__init__() @@ -39,7 +39,13 @@ def __init__( down_sample = 0 if idx == 0 else -2 layers.append( - Conv(input_filters, output_filters, up_down_sample=down_sample, ndims=ndims, dropout_probability=None) + Conv( + input_filters, + output_filters, + up_down_sample=down_sample, + ndims=ndims, + dropout_probability=None, + ) ) self.layers = torch.nn.Sequential(*layers) @@ -55,7 +61,7 @@ class AxisAlignedConvGaussian(torch.nn.Module): def __init__( self, input_channels, - filters_per_layer=[64 * (2 ** x) for x in range(5)], + filters_per_layer=[64 * (2**x) for x in range(5)], latent_dim=2, ndims=2, ): @@ -113,9 +119,8 @@ def forward(self, img, seg=None): if self.ndims == 3: mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) - - mu = mu_log_sigma[:, :self.latent_dim].clamp(-1000, 1000) - log_sigma = mu_log_sigma[:, self.latent_dim:].clamp(-10, 10) + mu = mu_log_sigma[:, : self.latent_dim].clamp(-1000, 1000) + log_sigma = mu_log_sigma[:, self.latent_dim :].clamp(-10, 10) # This is a multivariate normal with diagonal covariance matrix sigma # https://github.com/pytorch/pytorch/pull/11178 @@ -200,7 +205,7 @@ def __init__( self, input_channels=1, num_classes=2, - filters_per_layer=[64 * (2 ** x) for x in range(5)], + filters_per_layer=[64 * (2**x) for x in range(5)], latent_dim=6, no_convs_fcomb=4, loss_type="elbo", @@ -217,7 +222,12 @@ def __init__( self.z_prior_sample = 0 self.unet = UNet( - input_channels, num_classes, filters_per_layer, final_layer=False, dropout_probability=dropout_probability, ndims=ndims + input_channels, + num_classes, + filters_per_layer, + final_layer=False, + dropout_probability=dropout_probability, + ndims=ndims, ) self.prior = AxisAlignedConvGaussian( input_channels, filters_per_layer, latent_dim, ndims=ndims @@ -366,8 +376,8 @@ def reconstruction_loss( top_k_percentage=None, deterministic=True, ): -# criterion = 
torch.nn.BCEWithLogitsLoss(reduction="none") - criterion = torch.nn.BCEWithLogitsLoss(size_average = False, reduce=False, reduction=None) + # criterion = torch.nn.BCEWithLogitsLoss(reduction="none") + criterion = torch.nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) if z_posterior is None: z_posterior = self.posterior_latent_space.rsample() @@ -381,9 +391,11 @@ def reconstruction_loss( n_pixels_in_batch = y_flat.shape[0] batch_size = segm.shape[0] - pos_class_count = t_flat.sum(axis=0)/batch_size - neg_class_count = torch.logical_not(t_flat).sum(axis=0)/batch_size - self._pos_weight = self._pos_weight * 0.5 + (neg_class_count/pos_class_count).clamp(0, 10000) * 0.5 + pos_class_count = t_flat.sum(axis=0) / batch_size + neg_class_count = torch.logical_not(t_flat).sum(axis=0) / batch_size + self._pos_weight = ( + self._pos_weight * 0.5 + (neg_class_count / pos_class_count).clamp(0, 10000) * 0.5 + ) # criterion = torch.nn.BCEWithLogitsLoss(reduction="none", pos_weight=self._pos_weight) xe = criterion(input=y_flat, target=t_flat) @@ -438,7 +450,7 @@ def loss(self, segm, mask=None, beta=None): z_posterior = self.posterior_latent_space.rsample() kl_div = torch.mean(self.kl_divergence()) - #kl_div = torch.clamp(kl_div, 0.0, 100.0) + # kl_div = torch.clamp(kl_div, 0.0, 100.0) top_k_percentage = None if "top_k_percentage" in self.loss_params: @@ -466,27 +478,27 @@ def loss(self, segm, mask=None, beta=None): # If using contour mask in loss, we get back those in a list. Unpack here. 
if contour_threshold: contour_loss = rl_sum[1] -# contour_loss_mean = rl_sum[1] + # contour_loss_mean = rl_sum[1] reconstruction_loss = rl_sum[0] - # rec_loss_mean = rl_sum[0] + # rec_loss_mean = rl_sum[0] else: reconstruction_loss = rl_sum if self.loss_type == "elbo": - if beta==None: + if beta == None: beta = self.loss_params["beta"] return { "loss": reconstruction_loss + beta * kl_div, "rec_loss": reconstruction_loss, "kl_div": kl_div, - "beta": beta + "beta": beta, } elif self.loss_type == "geco": with torch.no_grad(): - moving_avg_factor = 0.8 + moving_avg_factor = 0.5 rl = reconstruction_loss.detach() if self._rec_moving_avg is None: @@ -514,14 +526,20 @@ def loss(self, segm, mask=None, beta=None): lambda_lower = self.loss_params["clamp_rec"][0] lambda_upper = self.loss_params["clamp_rec"][1] - self._lambda[0] = (torch.exp(rc) * self._lambda[0]).clamp(lambda_lower, lambda_upper) - if self._lambda[0].isnan(): self._lambda[0] = lambda_upper + # self._lambda[0] = (torch.exp(rc) * self._lambda[0]).clamp(lambda_lower, lambda_upper) + self._lambda[0] = (rc * self._lambda[0]).clamp(lambda_lower, lambda_upper) + if self._lambda[0].isnan(): + self._lambda[0] = lambda_upper if contour_threshold: lambda_lower_contour = self.loss_params["clamp_contour"][0] lambda_upper_contour = self.loss_params["clamp_contour"][1] - self._lambda[1] = (torch.exp(cc) * self._lambda[1]).clamp(lambda_lower_contour, lambda_upper_contour) - if self._lambda[1].isnan(): self._lambda[1] = lambda_upper_contour + # self._lambda[1] = (torch.exp(cc) * self._lambda[1]).clamp(lambda_lower_contour, lambda_upper_contour) + self._lambda[1] = (cc * self._lambda[1]).clamp( + lambda_lower_contour, lambda_upper_contour + ) + if self._lambda[1].isnan(): + self._lambda[1] = lambda_upper_contour # pylint: disable=access-member-before-definition loss = (self._lambda[0] * reconstruction_loss) + kl_div diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index efd4bcc4..b78626e7 100644 --- 
a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -450,6 +450,23 @@ def training_step(self, batch, _): def validation_step(self, batch, _): + if self.validation_directory is None: + self.validation_directory = Path(tempfile.mkdtemp()) + + with torch.set_grad_enabled(False): + x, y, _, info = batch + for s in range(y.shape[0]): + + img_file = self.validation_directory.joinpath( + f"img_{info['case'][s]}_{info['z'][s]}.npy" + ) + np.save(img_file, x[s].squeeze(0).cpu().numpy()) + + mask_file = self.validation_directory.joinpath( + f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" + ) + np.save(mask_file, y[s].cpu().numpy()) + n = 4 m = 4 @@ -533,149 +550,130 @@ def validation_step(self, batch, _): pass self.log("GED", D_ged) - return D_ged - # def validation_step(self, batch, _): - - # if self.validation_directory is None: - # self.validation_directory = Path(tempfile.mkdtemp()) - - # with torch.set_grad_enabled(False): - # x, y, _, info = batch - # for s in range(y.shape[0]): - - # img_file = self.validation_directory.joinpath( - # f"img_{info['case'][s]}_{info['z'][s]}.npy" - # ) - # np.save(img_file, x[s].squeeze(0).cpu().numpy()) - - # mask_file = self.validation_directory.joinpath( - # f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" - # ) - # np.save(mask_file, y[s].cpu().numpy()) - - # return info - - # def validation_epoch_end(self, validation_step_outputs): - - # cases = {} - # for info in validation_step_outputs: - - # for case, z, observer in zip(info["case"], info["z"], info["observer"]): - # if not case in cases: - # cases[case] = {"slices": z.item(), "observers": [observer]} - # else: - # if z.item() > cases[case]["slices"]: - - # cases[case]["slices"] = z.item() - # if not observer in cases[case]["observers"]: - # cases[case]["observers"].append(observer) - - # metrics = ["DSC", "HD", "ASD"] - # computed_metrics = { - # **{f"probnet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, - 
# **{f"unet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, - # } - - # for case in cases: - - # img_arrs = [] - # slices = [] - - # if self.hparams.ndims == 2: - # for z in range(cases[case]["slices"] + 1): - # img_file = self.validation_directory.joinpath(f"img_{case}_{z}.npy") - # if img_file.exists(): - # img_arrs.append(np.load(img_file)) - # slices.append(z) - - # img_arr = np.stack(img_arrs) - - # else: - # img_file = self.validation_directory.joinpath(f"img_{case}_0.npy") - # img_arr = np.load(img_file) - # img = sitk.GetImageFromArray(img_arr) - # img.SetSpacing(self.hparams.spacing) - # try: - # mean = self.infer(img, num_samples=1, sample_strategy="mean", preprocess=False) - # samples = self.infer( - # img, - # sample_strategy="spaced", - # num_samples=5, - # spaced_range=[-1.5, 1.5], - # preprocess=False, - # ) - # except Exception as e: - # print(f"ERROR DURING VALIDATION INFERENCE: {e}") - # return - - # observers = {} - # for _, observer in enumerate(cases[case]["observers"]): - - # if self.hparams.ndims == 2: - # mask_arrs = [] - # for z in slices: - # mask_file = self.validation_directory.joinpath( - # f"mask_{case}_{z}_{observer}.npy" - # ) - - # mask_arrs.append(np.load(mask_file)) - - # mask_arr = np.stack(mask_arrs, axis=1) - - # else: - # mask_file = self.validation_directory.joinpath( - # f"mask_{case}_{z}_{observer}.npy" - # ) - # mask_arr = np.load(mask_file) - - # observers[f"manual_{observer}"] = {} - # for idx, structure in enumerate(self.hparams.structures): - # mask = sitk.GetImageFromArray(mask_arr[idx]) - # mask = sitk.Cast(mask, sitk.sitkUInt8) - # mask.CopyInformation(img) - # observers[f"manual_{observer}"][structure] = mask - - # #try: - # result, fig = self.validate(img, observers, samples, mean, matching_type="best") - # #except Exception as e: - # # print(f"ERROR DURING VALIDATION VALIDATE: {e}") - # # return - - # figure_path = f"valid_{case}.png" - # fig.savefig(figure_path, dpi=300) - # plt.close("all") - - 
# try: - # self.logger.experiment.log_image(figure_path) - # except AttributeError: - # # Likely offline mode - # pass - - # for t in result: - # for m in metrics: - # computed_metrics[f"{t}_{m}"] += result[t][m] - - # if self.kl_div: - # p = u = 0 - # for s in self.hparams.structures: - # p += np.array(computed_metrics[f"probnet_{s}_DSC"]).mean() - # u += np.array(computed_metrics[f"unet_{s}_DSC"]).mean() - - # p /= len(self.hparams.structures) - # u /= len(self.hparams.structures) - # computed_metrics["scaled_DSC"] = ((p + u) / 2) + (p - u) - self.kl_div - - # for cm in computed_metrics: - # self.log( - # cm, - # np.array(computed_metrics[cm]).mean(), - # on_step=False, - # on_epoch=True, - # prog_bar=False, - # logger=True, - # ) - - # # shutil.rmtree(self.validation_directory) + + return info + + def validation_epoch_end(self, validation_step_outputs): + + cases = {} + for info in validation_step_outputs: + + for case, z, observer in zip(info["case"], info["z"], info["observer"]): + if not case in cases: + cases[case] = {"slices": z.item(), "observers": [observer]} + else: + if z.item() > cases[case]["slices"]: + + cases[case]["slices"] = z.item() + if not observer in cases[case]["observers"]: + cases[case]["observers"].append(observer) + + metrics = ["DSC", "HD", "ASD"] + computed_metrics = { + **{f"probnet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, + **{f"unet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, + } + + for case in cases: + + img_arrs = [] + slices = [] + + if self.hparams.ndims == 2: + for z in range(cases[case]["slices"] + 1): + img_file = self.validation_directory.joinpath(f"img_{case}_{z}.npy") + if img_file.exists(): + img_arrs.append(np.load(img_file)) + slices.append(z) + + img_arr = np.stack(img_arrs) + + else: + img_file = self.validation_directory.joinpath(f"img_{case}_0.npy") + img_arr = np.load(img_file) + img = sitk.GetImageFromArray(img_arr) + img.SetSpacing(self.hparams.spacing) + try: + mean = 
self.infer(img, num_samples=1, sample_strategy="mean", preprocess=False) + samples = self.infer( + img, + sample_strategy="spaced", + num_samples=5, + spaced_range=[-1.5, 1.5], + preprocess=False, + ) + except Exception as e: + print(f"ERROR DURING VALIDATION INFERENCE: {e}") + return + + observers = {} + for _, observer in enumerate(cases[case]["observers"]): + + if self.hparams.ndims == 2: + mask_arrs = [] + for z in slices: + mask_file = self.validation_directory.joinpath( + f"mask_{case}_{z}_{observer}.npy" + ) + + mask_arrs.append(np.load(mask_file)) + + mask_arr = np.stack(mask_arrs, axis=1) + + else: + mask_file = self.validation_directory.joinpath( + f"mask_{case}_{z}_{observer}.npy" + ) + mask_arr = np.load(mask_file) + + observers[f"manual_{observer}"] = {} + for idx, structure in enumerate(self.hparams.structures): + mask = sitk.GetImageFromArray(mask_arr[idx]) + mask = sitk.Cast(mask, sitk.sitkUInt8) + mask.CopyInformation(img) + observers[f"manual_{observer}"][structure] = mask + + #try: + result, fig = self.validate(img, observers, samples, mean, matching_type="best") + #except Exception as e: + # print(f"ERROR DURING VALIDATION VALIDATE: {e}") + # return + + figure_path = f"valid_{case}.png" + fig.savefig(figure_path, dpi=300) + plt.close("all") + + try: + self.logger.experiment.log_image(figure_path) + except AttributeError: + # Likely offline mode + pass + + for t in result: + for m in metrics: + computed_metrics[f"{t}_{m}"] += result[t][m] + + if self.kl_div: + p = u = 0 + for s in self.hparams.structures: + p += np.array(computed_metrics[f"probnet_{s}_DSC"]).mean() + u += np.array(computed_metrics[f"unet_{s}_DSC"]).mean() + + p /= len(self.hparams.structures) + u /= len(self.hparams.structures) + computed_metrics["scaled_DSC"] = ((p + u) / 2) + (p - u) - self.kl_div + + for cm in computed_metrics: + self.log( + cm, + np.array(computed_metrics[cm]).mean(), + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + # 
shutil.rmtree(self.validation_directory) def main(args, config_json_path=None): From 4a038ec9bbc50a22750fbd695e38a1563144a7fd Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 2 Sep 2022 16:28:32 +1000 Subject: [PATCH 180/264] Corrections to prob-unet --- platipy/imaging/cnn/prob_unet.py | 22 +++++++----- platipy/imaging/cnn/train.py | 62 +++++++++++++++++--------------- 2 files changed, 47 insertions(+), 37 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index aa2cfa65..08aa41a0 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -247,7 +247,7 @@ def __init__( if self.loss_type == "geco": self._rec_moving_avg = None self._contour_moving_avg = None - self.register_buffer("_lambda", torch.zeros(2, requires_grad=False)) + self.register_buffer("_lambda", torch.ones(2, requires_grad=False)) self.register_buffer("_pos_weight", torch.ones(num_classes, requires_grad=False)) @@ -478,9 +478,9 @@ def loss(self, segm, mask=None, beta=None): # If using contour mask in loss, we get back those in a list. Unpack here. 
if contour_threshold: contour_loss = rl_sum[1] - # contour_loss_mean = rl_sum[1] + contour_loss_mean = rec_loss_mean[1] reconstruction_loss = rl_sum[0] - # rec_loss_mean = rl_sum[0] + rec_loss_mean = rec_loss_mean[0] else: reconstruction_loss = rl_sum @@ -500,7 +500,7 @@ def loss(self, segm, mask=None, beta=None): moving_avg_factor = 0.5 - rl = reconstruction_loss.detach() + rl = rec_loss_mean.detach() if self._rec_moving_avg is None: self._rec_moving_avg = rl else: @@ -512,7 +512,7 @@ def loss(self, segm, mask=None, beta=None): cc = 0 if contour_threshold: - cl = contour_loss.detach() + cl = contour_loss_mean.detach() if self._contour_moving_avg is None: self._contour_moving_avg = rl else: @@ -526,18 +526,22 @@ def loss(self, segm, mask=None, beta=None): lambda_lower = self.loss_params["clamp_rec"][0] lambda_upper = self.loss_params["clamp_rec"][1] - # self._lambda[0] = (torch.exp(rc) * self._lambda[0]).clamp(lambda_lower, lambda_upper) - self._lambda[0] = (rc * self._lambda[0]).clamp(lambda_lower, lambda_upper) + self._lambda[0] = (torch.exp(rc) * self._lambda[0]).clamp( + lambda_lower, lambda_upper + ) + # self._lambda[0] = (rc * self._lambda[0]).clamp(lambda_lower, lambda_upper) if self._lambda[0].isnan(): self._lambda[0] = lambda_upper if contour_threshold: lambda_lower_contour = self.loss_params["clamp_contour"][0] lambda_upper_contour = self.loss_params["clamp_contour"][1] - # self._lambda[1] = (torch.exp(cc) * self._lambda[1]).clamp(lambda_lower_contour, lambda_upper_contour) - self._lambda[1] = (cc * self._lambda[1]).clamp( + self._lambda[1] = (torch.exp(cc) * self._lambda[1]).clamp( lambda_lower_contour, lambda_upper_contour ) + # self._lambda[1] = (cc * self._lambda[1]).clamp( + # lambda_lower_contour, lambda_upper_contour + # ) if self._lambda[1].isnan(): self._lambda[1] = lambda_upper_contour diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index b78626e7..45cbaddc 100644 --- a/platipy/imaging/cnn/train.py +++ 
b/platipy/imaging/cnn/train.py @@ -76,7 +76,7 @@ def __init__( if self.hparams.prob_type == "prob": self.prob_unet = ProbabilisticUnet( self.hparams.input_channels, - len(self.hparams.structures) + 1, # Add 1 to num classes for background class + len(self.hparams.structures) + 1, # Add 1 to num classes for background class self.hparams.filters_per_layer, self.hparams.latent_dim, self.hparams.no_convs_fcomb, @@ -111,7 +111,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--lr_lambda", type=float, default=0.99) parser.add_argument("--input_channels", type=int, default=1) parser.add_argument( - "--filters_per_layer", nargs="+", type=int, default=[64 * (2 ** x) for x in range(5)] + "--filters_per_layer", nargs="+", type=int, default=[64 * (2**x) for x in range(5)] ) parser.add_argument("--down_channels_per_block", nargs="+", type=int, default=None) parser.add_argument("--latent_dim", type=int, default=6) @@ -271,7 +271,7 @@ def infer( result[sample["name"]] = {} for idx, structure in enumerate(self.hparams.structures): - pred = sitk.GetImageFromArray(pred_arr[idx+1]) # Skip the background + pred = sitk.GetImageFromArray(pred_arr[idx + 1]) # Skip the background pred = pred > 0.5 # Threshold softmax at 0.5 pred = sitk.Cast(pred, sitk.sitkUInt8) @@ -296,7 +296,7 @@ def validate( try: cut = get_com(mean["mean"][structures[0]]) except ValueError: - cut = [int(i/2) for i in img.GetSize()][::-1] + cut = [int(i / 2) for i in img.GetSize()][::-1] vis = ImageVisualiser(img, cut=cut, figure_size_in=16, window=window) @@ -323,7 +323,10 @@ def validate( union_mask = get_union_mask(manual_observers_struct) vis.add_contour( - intersection_mask, name=f"intersection_{structure}", color=manual_color, linewidth=3 + intersection_mask, + name=f"intersection_{structure}", + color=manual_color, + linewidth=3, ) vis.add_contour(union_mask, name=f"union_{structure}", color=manual_color, linewidth=3) @@ -355,7 +358,9 @@ def validate( sample_metrics = get_metrics( 
manual_observers_struct[obs], samples_struct[samp] ) - mean_metrics = get_metrics(manual_observers_struct[obs], mean_contours[f"mean_{structure}"]) + mean_metrics = get_metrics( + manual_observers_struct[obs], mean_contours[f"mean_{structure}"] + ) for k in sample_metrics: sim[k][sid, oid] = sample_metrics[k] @@ -453,8 +458,13 @@ def validation_step(self, batch, _): if self.validation_directory is None: self.validation_directory = Path(tempfile.mkdtemp()) + n = 4 + m = 4 + with torch.set_grad_enabled(False): x, y, _, info = batch + + # Save off slices/volumes for analysis of entire structure in end of validation step for s in range(y.shape[0]): img_file = self.validation_directory.joinpath( @@ -467,29 +477,25 @@ def validation_step(self, batch, _): ) np.save(mask_file, y[s].cpu().numpy()) - n = 4 - m = 4 - - with torch.set_grad_enabled(False): - x, y, _, info = batch - # Image will be same for all in batch - x = x[0, :, :, :].unsqueeze(0) - vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0,:,:,:]), axis="z") - x = x.repeat(m, 1, 1, 1) + x = x[0].unsqueeze(0) + if self.hparams.ndims == 2: + vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0]), axis="z") + else: + vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0])) + + if self.hparams.ndims == 2: + x = x.repeat(m, 1, 1, 1) + else: + x = x.repeat(m, 1, 1, 1, 1) self.prob_unet.forward(x) py = self.prob_unet.sample(testing=True) py = py.to("cpu") - # print(f"{pred_y[0,:,:,:].min()} {pred_y[0,:,:,:].max()}") - # pred_y = torch.sigmoid(pred_y) - # print(f"{pred_y[0,:,:,:].min()} {pred_y[0,:,:,:].max()}") - # pred_y = pred_y[:,1,:,:] > 0.5 - # pred_y = pred_y.unsqueeze(1) - pred_y = torch.zeros(py[:,0,:].shape).int() + pred_y = torch.zeros(py[:, 0, :].shape).int() for b in range(py.shape[0]): - pred_y[b] = py[b,:].argmax(0).int() + pred_y[b] = py[b, :].argmax(0).int() y = y.squeeze(1) y = y.int() @@ -504,7 +510,7 @@ def validation_step(self, batch, _): continue iou = jaccard(pred_y[i], y[j]) 
term_1 += 1 - iou - term_1 = term_1 * (2/(m*n)) + term_1 = term_1 * (2 / (m * n)) term_2 = 0 for i in range(n): @@ -513,7 +519,7 @@ def validation_step(self, batch, _): continue iou = jaccard(pred_y[i], pred_y[j]) term_2 += 1 - iou - term_2 = term_2 * (1/(n*n)) + term_2 = term_2 * (1 / (n * n)) term_3 = 0 for i in range(m): @@ -522,7 +528,7 @@ def validation_step(self, batch, _): continue iou = jaccard(y[i], y[j]) term_3 += 1 - iou - term_3 = term_3 * (1/(m*m)) + term_3 = term_3 * (1 / (m * m)) D_ged = term_1 - term_2 - term_3 @@ -539,7 +545,7 @@ def validation_step(self, batch, _): vis.add_contour(contours, colormap=plt.cm.get_cmap("cool")) vis.show() - figure_path = "valid.png" + figure_path = f"ged_{info['z'][s]}.png" plt.savefig(figure_path, dpi=300) plt.close("all") @@ -633,9 +639,9 @@ def validation_epoch_end(self, validation_step_outputs): mask.CopyInformation(img) observers[f"manual_{observer}"][structure] = mask - #try: + # try: result, fig = self.validate(img, observers, samples, mean, matching_type="best") - #except Exception as e: + # except Exception as e: # print(f"ERROR DURING VALIDATION VALIDATE: {e}") # return From d73467b3ccae90b1d12530631ee7153c259b56aa Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 6 Sep 2022 17:19:28 +1000 Subject: [PATCH 181/264] Update to train --- platipy/imaging/cnn/train.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 45cbaddc..9be2ab3a 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -458,8 +458,8 @@ def validation_step(self, batch, _): if self.validation_directory is None: self.validation_directory = Path(tempfile.mkdtemp()) - n = 4 - m = 4 + n = self.hparams.num_observers + m = self.hparams.num_observers with torch.set_grad_enabled(False): x, y, _, info = batch @@ -482,7 +482,7 @@ def validation_step(self, batch, _): if self.hparams.ndims == 2: vis = 
ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0]), axis="z") else: - vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0])) + vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0, 0])) if self.hparams.ndims == 2: x = x.repeat(m, 1, 1, 1) @@ -535,11 +535,13 @@ def validation_step(self, batch, _): contours = {} for o in range(n): obs_y = y[o].float() - obs_y = obs_y.unsqueeze(0) + if self.hparams.ndims == 2: + obs_y = obs_y.unsqueeze(0) contours[f"obs_{o}"] = sitk.GetImageFromArray(obs_y) for mm in range(m): samp_pred = pred_y[mm].float() - samp_pred = samp_pred.unsqueeze(0) + if self.hparams.ndims == 2: + samp_pred = samp_pred.unsqueeze(0) contours[f"sample_{mm}"] = sitk.GetImageFromArray(samp_pred) vis.add_contour(contours, colormap=plt.cm.get_cmap("cool")) @@ -691,6 +693,7 @@ def main(args, config_json_path=None): # args.default_root_dir = str(args.working_dir) args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") args.default_root_dir = str(args.fold_dir) + args.accumulate_grad_batches = {0: 20, 10: 10, 50: 5, 100: 1} comet_api_key = None comet_workspace = None From e36acd4f71656d2fcbc15761e27cfc727eff3658 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 6 Sep 2022 17:20:32 +1000 Subject: [PATCH 182/264] Correct code --- platipy/imaging/cnn/prob_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 08aa41a0..ca9cab87 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -485,7 +485,7 @@ def loss(self, segm, mask=None, beta=None): reconstruction_loss = rl_sum if self.loss_type == "elbo": - if beta == None: + if beta is None: beta = self.loss_params["beta"] return { From afd3cd32b4b4d9082d449e5027b2ee4053869d0a Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 8 Sep 2022 17:16:58 +1000 Subject: [PATCH 183/264] set geco step size --- platipy/imaging/cnn/prob_unet.py | 6 ++++-- 1 file changed, 4 
insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 08aa41a0..1f91dec0 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -496,6 +496,8 @@ def loss(self, segm, mask=None, beta=None): } elif self.loss_type == "geco": + rec_geco_step_size = 1e-02 + with torch.no_grad(): moving_avg_factor = 0.5 @@ -526,7 +528,7 @@ def loss(self, segm, mask=None, beta=None): lambda_lower = self.loss_params["clamp_rec"][0] lambda_upper = self.loss_params["clamp_rec"][1] - self._lambda[0] = (torch.exp(rc) * self._lambda[0]).clamp( + self._lambda[0] = (torch.exp(rc * rec_geco_step_size) * self._lambda[0]).clamp( lambda_lower, lambda_upper ) # self._lambda[0] = (rc * self._lambda[0]).clamp(lambda_lower, lambda_upper) @@ -536,7 +538,7 @@ def loss(self, segm, mask=None, beta=None): lambda_lower_contour = self.loss_params["clamp_contour"][0] lambda_upper_contour = self.loss_params["clamp_contour"][1] - self._lambda[1] = (torch.exp(cc) * self._lambda[1]).clamp( + self._lambda[1] = (torch.exp(cc * rec_geco_step_size) * self._lambda[1]).clamp( lambda_lower_contour, lambda_upper_contour ) # self._lambda[1] = (cc * self._lambda[1]).clamp( From 007d0434a863abf9f8fb4025bb7fa585956628fe Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 8 Sep 2022 17:19:00 +1000 Subject: [PATCH 184/264] Remove acc grad --- platipy/imaging/cnn/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 9be2ab3a..bb16b8da 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -693,7 +693,7 @@ def main(args, config_json_path=None): # args.default_root_dir = str(args.working_dir) args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") args.default_root_dir = str(args.fold_dir) - args.accumulate_grad_batches = {0: 20, 10: 10, 50: 5, 100: 1} + # args.accumulate_grad_batches = {0: 20, 10: 10, 50: 5, 100: 1} 
comet_api_key = None comet_workspace = None From 81dbd9fb3eba454aa137244711be415fdba39b66 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 15 Sep 2022 13:57:43 +1000 Subject: [PATCH 185/264] adjustments to prob unet --- platipy/imaging/cnn/prob_unet.py | 14 ++++----- platipy/imaging/cnn/train.py | 51 ++++++++++++++++---------------- 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index c3c71695..81c8bf90 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -376,8 +376,8 @@ def reconstruction_loss( top_k_percentage=None, deterministic=True, ): - # criterion = torch.nn.BCEWithLogitsLoss(reduction="none") - criterion = torch.nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) + criterion = torch.nn.BCEWithLogitsLoss(reduction="none") + # criterion = torch.nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) if z_posterior is None: z_posterior = self.posterior_latent_space.rsample() @@ -391,11 +391,11 @@ def reconstruction_loss( n_pixels_in_batch = y_flat.shape[0] batch_size = segm.shape[0] - pos_class_count = t_flat.sum(axis=0) / batch_size - neg_class_count = torch.logical_not(t_flat).sum(axis=0) / batch_size - self._pos_weight = ( - self._pos_weight * 0.5 + (neg_class_count / pos_class_count).clamp(0, 10000) * 0.5 - ) + # pos_class_count = t_flat.sum(axis=0) / batch_size + # neg_class_count = torch.logical_not(t_flat).sum(axis=0) / batch_size + # self._pos_weight = ( + # self._pos_weight * 0.5 + (neg_class_count / pos_class_count).clamp(0, 10000) * 0.5 + # ) # criterion = torch.nn.BCEWithLogitsLoss(reduction="none", pos_weight=self._pos_weight) xe = criterion(input=y_flat, target=t_flat) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index bb16b8da..53e28346 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -85,18 +85,19 @@ def __init__( 
self.hparams.ndims, ) elif self.hparams.prob_type == "hierarchical": - self.prob_unet = HierarchicalProbabilisticUnet( - input_channels=self.hparams.input_channels, - num_classes=len(self.hparams.structures), - filters_per_layer=self.hparams.filters_per_layer, - down_channels_per_block=self.hparams.down_channels_per_block, - latent_dims=[self.hparams.latent_dim] * (len(self.hparams.filters_per_layer) - 1), - convs_per_block=self.hparams.convs_per_block, - blocks_per_level=self.hparams.blocks_per_level, - loss_type=self.hparams.loss_type, - loss_params=loss_params, - ndims=self.hparams.ndims, - ) + raise NotImplementedError("Hierarchical Prob UNet current not working...") + # self.prob_unet = HierarchicalProbabilisticUnet( + # input_channels=self.hparams.input_channels, + # num_classes=len(self.hparams.structures), + # filters_per_layer=self.hparams.filters_per_layer, + # down_channels_per_block=self.hparams.down_channels_per_block, + # latent_dims=[self.hparams.latent_dim] * (len(self.hparams.filters_per_layer) - 1), + # convs_per_block=self.hparams.convs_per_block, + # blocks_per_level=self.hparams.blocks_per_level, + # loss_type=self.hparams.loss_type, + # loss_params=loss_params, + # ndims=self.hparams.ndims, + # ) self.validation_directory = None self.kl_div = None @@ -243,15 +244,15 @@ def infer( use_mean=False, sample_x_stddev_from_mean=sample["std_dev_from_mean"], ) - else: - if sample["name"] == "mean": - y = self.prob_unet.sample(x, mean=True) - else: - y = self.prob_unet.sample( - x, - mean=True, - std_devs_from_mean=sample["std_dev_from_mean"], - ) + # else: + # if sample["name"] == "mean": + # y = self.prob_unet.sample(x, mean=True) + # else: + # y = self.prob_unet.sample( + # x, + # mean=True, + # std_devs_from_mean=sample["std_dev_from_mean"], + # ) y = y.squeeze(0) # y = np.argmax(y.cpu().detach().numpy(), axis=0) @@ -412,13 +413,13 @@ def training_step(self, batch, _): # self.prob_unet.forward(x, y, training=True) if self.hparams.prob_type == 
"prob": self.prob_unet.forward(x, y, training=True) - else: - self.prob_unet.forward(x, y) + # else: + # self.prob_unet.forward(x, y) if self.hparams.prob_type == "prob": loss = self.prob_unet.loss(y, mask=m) - else: - loss = self.prob_unet.loss(x, y, mask=m) + # else: + # loss = self.prob_unet.loss(x, y, mask=m) training_loss = loss["loss"] From 09b704d58a15569762035fa6cacd29e7d0a31158 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 22 Sep 2022 16:31:36 +1000 Subject: [PATCH 186/264] Add psDSC metric --- platipy/imaging/cnn/metrics.py | 155 +++++++++++++++++++++++++++++++ platipy/imaging/cnn/prob_unet.py | 2 +- platipy/imaging/cnn/train.py | 59 ++++++++++-- 3 files changed, 208 insertions(+), 8 deletions(-) create mode 100644 platipy/imaging/cnn/metrics.py diff --git a/platipy/imaging/cnn/metrics.py b/platipy/imaging/cnn/metrics.py new file mode 100644 index 00000000..cd4f3494 --- /dev/null +++ b/platipy/imaging/cnn/metrics.py @@ -0,0 +1,155 @@ +import collections +import math + +import SimpleITK as sitk +import numpy as np + +from platipy.imaging.label.comparison import compute_surface_dsc, compute_metric_dsc +from platipy.imaging.label.utils import get_union_mask, get_intersection_mask + + +def probabilistic_dice(gt_labels, sampled_labels, dsc_type="dsc", tau=3): + + gt_union = get_union_mask(gt_labels) + gt_intersection = get_intersection_mask(gt_labels) + + st_union = get_union_mask(sampled_labels) + st_intersection = get_intersection_mask(sampled_labels) + + if dsc_type == "dsc": + dsc_union = compute_metric_dsc(gt_union, st_union) + dsc_intersection = compute_metric_dsc(gt_intersection, st_intersection) + + if dsc_type == "sdsc": + dsc_union = compute_surface_dsc(gt_union, st_union, tau=tau) + dsc_intersection = compute_surface_dsc(gt_intersection, st_intersection, tau=tau) + + return (dsc_union + dsc_intersection) / 2 + + +def probabilistic_surface_dice(gt_labels, sampled_labels, sd_range=3, tau=0): + + if isinstance(gt_labels, dict): + gt_labels = 
[gt_labels[l] for l in gt_labels] + + if isinstance(sampled_labels, dict): + sampled_labels = [sampled_labels[l] for l in sampled_labels] + + binary_contour_filter = sitk.BinaryContourImageFilter() + binary_contour_filter.FullyConnectedOff() + summed = None + for mask in gt_labels: + + if summed is None: + summed = mask + + else: + summed += mask + + intersection = summed >= 1 + union = summed >= 5 + + mask_mean = summed >= 3 + intersection_minus_mean = intersection - mask_mean + mean_minus_union = mask_mean - union + + contour_i = binary_contour_filter.Execute(intersection) + contour_u = binary_contour_filter.Execute(union) + contour_mean = binary_contour_filter.Execute(mask_mean) + + dist_to_i = sitk.SignedMaurerDistanceMap( + contour_i, useImageSpacing=True, squaredDistance=False + ) + + dist_to_u = sitk.SignedMaurerDistanceMap( + contour_u, useImageSpacing=True, squaredDistance=False + ) + + dist_to_mean = sitk.SignedMaurerDistanceMap( + contour_mean, useImageSpacing=True, squaredDistance=False + ) + + mean = 0 + sd = 1 / sd_range + max_agg = np.pi * sd + + dist_sum = dist_to_mean + dist_to_i + dist_ratio_neg = dist_to_mean / dist_sum + + dist_ratio_arr = sitk.GetArrayFromImage(dist_ratio_neg) + + dist_ratio_arr = (np.pi * sd) * np.exp(-0.5 * ((dist_ratio_arr - mean) / sd) ** 2) + dist_ratio_arr = dist_ratio_arr / max_agg / 2 # Normalise + dist_ratio_arr[sitk.GetArrayFromImage(intersection_minus_mean) == 0] = 0 + dist_ratio_neg = sitk.GetImageFromArray(dist_ratio_arr) + dist_ratio_neg.CopyInformation(dist_sum) + + dist_sum = dist_to_mean + dist_to_u + dist_ratio_pos = dist_to_u / dist_sum + + dist_ratio_arr = sitk.GetArrayFromImage(dist_ratio_pos) + + dist_ratio_arr = (np.pi * sd) * np.exp(-0.5 * ((dist_ratio_arr - mean) / sd) ** 2) + dist_ratio_arr = (dist_ratio_arr / max_agg / 2) + 0.5 # Normalise + dist_ratio_arr[sitk.GetArrayFromImage(mean_minus_union) == 0] = 0 + dist_ratio_arr[sitk.GetArrayFromImage(union) == 1] = 1 + dist_ratio_pos = 
sitk.GetImageFromArray(dist_ratio_arr) + dist_ratio_pos.CopyInformation(dist_sum) + + dist_ratio = dist_ratio_neg + dist_ratio_pos + + sample_count = math.floor(len(sampled_labels) / 2) + + ranges = {} + range_masks = {} + start_mask = None + for pr in np.linspace(0.5, 1, sample_count + 1): + next_mask = dist_ratio >= pr + next_contour = binary_contour_filter.Execute(next_mask) + + if start_mask is None: + ranges[pr] = next_contour + else: + ranges[pr] = ((start_mask - next_mask) + start_contour + next_contour) > 0 + + range_masks[pr] = next_mask + + start_mask = next_mask + start_contour = binary_contour_filter.Execute(start_mask) + + start_mask = None + for pr in np.linspace(0.5, 0.000001, sample_count + 1): + next_mask = dist_ratio >= pr + next_contour = binary_contour_filter.Execute(next_mask) + + if start_mask is None: + ranges[pr] = next_contour + else: + ranges[pr] = ((next_mask - start_mask) + start_contour + next_contour) > 0 + + range_masks[pr] = next_mask + + start_mask = next_mask + start_contour = binary_contour_filter.Execute(start_mask) + + ranges = collections.OrderedDict(sorted(ranges.items())) + range_masks = collections.OrderedDict(sorted(range_masks.items())) + + result = 0 + for idx, r in enumerate(ranges): + auto_mask = sampled_labels[idx] + auto_contour = binary_contour_filter.Execute(auto_mask > 0) + + dist_to_range = sitk.SignedMaurerDistanceMap( + ranges[r], useImageSpacing=True, squaredDistance=False + ) + + auto_intersection = sitk.GetArrayFromImage(auto_contour * (dist_to_range <= tau)).sum() + + this_result = auto_intersection / sitk.GetArrayFromImage(auto_contour).sum() + if np.isnan(this_result): + this_result = 0 + + result += this_result + + return result / len(ranges) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 81c8bf90..84f1453e 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -496,7 +496,7 @@ def loss(self, segm, mask=None, beta=None): } elif 
self.loss_type == "geco": - rec_geco_step_size = 1e-02 + rec_geco_step_size = self.loss_params["rec_geco_step_size"] with torch.no_grad(): diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 53e28346..ab5912b8 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -16,6 +16,7 @@ import os import tempfile import json +from argparse import ArgumentParser from pathlib import Path import SimpleITK as sitk @@ -30,16 +31,17 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint -from argparse import ArgumentParser import matplotlib.pyplot as plt from platipy.imaging.cnn.prob_unet import ProbabilisticUnet -from platipy.imaging.cnn.hierarchical_prob_unet import HierarchicalProbabilisticUnet + +# from platipy.imaging.cnn.hierarchical_prob_unet import HierarchicalProbabilisticUnet from platipy.imaging.cnn.unet import l2_regularisation from platipy.imaging.cnn.dataload import UNetDataModule from platipy.imaging.cnn.dataset import crop_img_using_localise_model from platipy.imaging.cnn.utils import preprocess_image, postprocess_mask, get_metrics +from platipy.imaging.cnn.metrics import probabilistic_dice from platipy.imaging import ImageVisualiser from platipy.imaging.label.utils import get_com, get_union_mask, get_intersection_mask @@ -67,6 +69,7 @@ def __init__( "clamp_rec": self.hparams.clamp_rec, "clamp_contour": self.hparams.clamp_contour, "kappa_contour": self.hparams.kappa_contour, + "rec_geco_step_size": self.hparams.rec_geco_step_size, } loss_params["top_k_percentage"] = self.hparams.top_k_percentage @@ -123,6 +126,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--beta", type=float, default=1.0) parser.add_argument("--kappa", type=float, default=0.02) parser.add_argument("--kappa_contour", type=float, default=None) + parser.add_argument("--rec_geco_step_size", type=float, default=1e-02) parser.add_argument("--clamp_rec", nargs="+", 
type=float, default=[1e-5, 1e5]) parser.add_argument("--clamp_contour", nargs="+", type=float, default=[1e-3, 1e3]) parser.add_argument("--top_k_percentage", type=float, default=None) @@ -583,6 +587,9 @@ def validation_epoch_end(self, validation_step_outputs): **{f"unet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, } + prob_surface_dice = 0 + prob_dice = 0 + for case in cases: img_arrs = [] @@ -607,8 +614,8 @@ def validation_epoch_end(self, validation_step_outputs): samples = self.infer( img, sample_strategy="spaced", - num_samples=5, - spaced_range=[-1.5, 1.5], + num_samples=7, + spaced_range=[-3, 3], preprocess=False, ) except Exception as e: @@ -662,6 +669,43 @@ def validation_epoch_end(self, validation_step_outputs): for m in metrics: computed_metrics[f"{t}_{m}"] += result[t][m] + # Compute the probabilistic (surface) dice + for idx, structure in enumerate(self.hparams.structures): + + gt_labels = [] + for _, observer in enumerate(cases[case]["observers"]): + gt_labels.append(observers[f"manual_{observer}"][structure]) + + sample_labels = [] + for rk in samples: + sample_labels.append(samples[rk][structure]) + + # prob_surface_dice += probabilistic_surface_dice(gt_labels, sample_labels, tau=1) + prob_dice += probabilistic_dice(gt_labels, sample_labels, dsc_type="dsc") + prob_surface_dice += probabilistic_dice( + gt_labels, sample_labels, dsc_type="sdsc", tau=3 + ) + + prob_dice = prob_dice / len(cases) + self.log( + "probabilisticDice", + prob_dice, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + prob_surface_dice = prob_surface_dice / len(cases) + self.log( + "probabilisticSurfaceDice", + prob_surface_dice, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + if self.kl_div: p = u = 0 for s in self.hparams.structures: @@ -695,6 +739,7 @@ def main(args, config_json_path=None): args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") args.default_root_dir = str(args.fold_dir) # 
args.accumulate_grad_batches = {0: 20, 10: 10, 50: 5, 100: 1} + args.precision = 16 comet_api_key = None comet_workspace = None @@ -744,11 +789,11 @@ def main(args, config_json_path=None): # Save the best model checkpoint_callback = ModelCheckpoint( - monitor="GED", + monitor="probabilisticSurfaceDice", dirpath=args.default_root_dir, - filename="probunet-{epoch:02d}-{GED:.2f}", + filename="probunet-{epoch:02d}-{probabilisticSurfaceDice:.2f}", save_top_k=1, - mode="min", + mode="max", ) trainer.callbacks.append(checkpoint_callback) From be80af771031f1ecfd7d190362dc46c3cc47a366 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 22 Sep 2022 16:36:10 +1000 Subject: [PATCH 187/264] add images to gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index b7918abc..577d061c 100644 --- a/.gitignore +++ b/.gitignore @@ -149,5 +149,7 @@ docs/site/ *.npy *.nii.gz +valid*.png +ged*.png test_prob*/ \ No newline at end of file From b3227935f3681800597a1302e06b513e18fb3f6d Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 23 Sep 2022 08:30:06 +1000 Subject: [PATCH 188/264] Use reduce lr on plateu schedule --- platipy/imaging/cnn/train.py | 47 ++++++++++++++++++++++++----------- platipy/imaging/utils/crop.py | 7 ++++++ 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index ab5912b8..c794a41e 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -126,7 +126,8 @@ def add_model_specific_args(parent_parser): parser.add_argument("--beta", type=float, default=1.0) parser.add_argument("--kappa", type=float, default=0.02) parser.add_argument("--kappa_contour", type=float, default=None) - parser.add_argument("--rec_geco_step_size", type=float, default=1e-02) + parser.add_argument("--rec_geco_step_size", type=float, default=1e-2) + parser.add_argument("--weight_decay", type=float, default=1e-2) parser.add_argument("--clamp_rec", 
nargs="+", type=float, default=[1e-5, 1e5]) parser.add_argument("--clamp_contour", nargs="+", type=float, default=[1e-3, 1e3]) parser.add_argument("--top_k_percentage", type=float, default=None) @@ -142,14 +143,31 @@ def forward(self, x): def configure_optimizers(self): optimizer = torch.optim.Adam( - self.parameters(), lr=self.hparams.learning_rate, weight_decay=0 + self.parameters(), + lr=self.hparams.learning_rate, + weight_decay=self.hparams.weight_decay, ) - scheduler = torch.optim.lr_scheduler.LambdaLR( - optimizer, lr_lambda=[lambda epoch: self.hparams.lr_lambda ** (epoch)] - ) - - return [optimizer], [scheduler] + # scheduler = torch.optim.lr_scheduler.LambdaLR( + # optimizer, lr_lambda=[lambda epoch: self.hparams.lr_lambda ** (epoch)] + # ) + # scheduler = torch.optim.lr_scheduler.CyclicLR( + # optimizer, + # base_lr=self.hparams.learning_rate / 10, + # max_lr=self.hparams.learning_rate, + # step_size_up=1000, + # ) + + # return [optimizer], [scheduler] + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, "max", patience=3, threshold=0.01, factor=0.5 + ), + "monitor": "probabilisticSurfaceDice", + }, + } def infer( self, @@ -427,13 +445,14 @@ def training_step(self, batch, _): training_loss = loss["loss"] - if self.hparams.prob_type == "prob": - reg_loss = ( - l2_regularisation(self.prob_unet.posterior) - + l2_regularisation(self.prob_unet.prior) - + l2_regularisation(self.prob_unet.fcomb.layers) - ) - training_loss = training_loss + 1e-5 * reg_loss + # Using weight decay instead + # if self.hparams.prob_type == "prob": + # reg_loss = ( + # l2_regularisation(self.prob_unet.posterior) + # + l2_regularisation(self.prob_unet.prior) + # + l2_regularisation(self.prob_unet.fcomb.layers) + # ) + # training_loss = training_loss + 1e-5 * reg_loss self.log( "training_loss", training_loss.detach(), diff --git a/platipy/imaging/utils/crop.py b/platipy/imaging/utils/crop.py index 
b944f5d8..7bb3a4d6 100644 --- a/platipy/imaging/utils/crop.py +++ b/platipy/imaging/utils/crop.py @@ -45,6 +45,13 @@ def label_to_roi(label, expansion_mm=[0, 0, 0], return_as_list=False): label_stats_image_filter.Execute(reference_label, reference_label) bounding_box = np.array(label_stats_image_filter.GetBoundingBox(1)) + # If bounding_box is empty then the mask is likely empty. Just return entire mask as ROI. + if bounding_box.size == 0: + if return_as_list: + return [0, 0, 0] + [int(x) for x in label.GetSize()] + + return [int(x) for x in label.GetSize()], [0, 0, 0] + index = [bounding_box[x * 2] for x in range(3)] size = [bounding_box[(x * 2) + 1] - bounding_box[x * 2] for x in range(3)] From e77274abf7c83225cab69f5dc57282a213f44527 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 23 Sep 2022 08:38:53 +1000 Subject: [PATCH 189/264] Att early stopping --- platipy/imaging/cnn/train.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index c794a41e..2af01703 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -30,7 +30,7 @@ import torch import pytorch_lightning as pl from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint - +from pytorch_lightning.callbacks.early_stopping import EarlyStopping import matplotlib.pyplot as plt @@ -165,7 +165,7 @@ def configure_optimizers(self): "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, "max", patience=3, threshold=0.01, factor=0.5 ), - "monitor": "probabilisticSurfaceDice", + "monitor": "probabilisticDice", }, } @@ -808,7 +808,7 @@ def main(args, config_json_path=None): # Save the best model checkpoint_callback = ModelCheckpoint( - monitor="probabilisticSurfaceDice", + monitor="probabilisticDice", dirpath=args.default_root_dir, filename="probunet-{epoch:02d}-{probabilisticSurfaceDice:.2f}", save_top_k=1, @@ -816,6 +816,11 @@ def main(args, config_json_path=None): ) 
trainer.callbacks.append(checkpoint_callback) + early_stop_callback = EarlyStopping( + monitor="probabilisticDice", min_delta=0.005, patience=3, verbose=False, mode="max" + ) + trainer.callbacks.append(early_stop_callback) + trainer.fit(prob_unet, data_module) From fb2e0038360a7fe5c3338602c8d8185d12c2d260 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 23 Sep 2022 15:49:03 +1000 Subject: [PATCH 190/264] Be more patient --- platipy/imaging/cnn/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 2af01703..762680d1 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -163,7 +163,7 @@ def configure_optimizers(self): "optimizer": optimizer, "lr_scheduler": { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, "max", patience=3, threshold=0.01, factor=0.5 + optimizer, "max", patience=5, threshold=0.01, factor=0.5 ), "monitor": "probabilisticDice", }, @@ -817,7 +817,7 @@ def main(args, config_json_path=None): trainer.callbacks.append(checkpoint_callback) early_stop_callback = EarlyStopping( - monitor="probabilisticDice", min_delta=0.005, patience=3, verbose=False, mode="max" + monitor="probabilisticDice", min_delta=0.005, patience=10, verbose=False, mode="max" ) trainer.callbacks.append(early_stop_callback) From dc403c3fe2a4e069ad338fccae19a3e6b340ae13 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 24 Sep 2022 10:12:18 +1000 Subject: [PATCH 191/264] Increace patience --- platipy/imaging/cnn/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 762680d1..2b9a1584 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -817,7 +817,7 @@ def main(args, config_json_path=None): trainer.callbacks.append(checkpoint_callback) early_stop_callback = EarlyStopping( - monitor="probabilisticDice", min_delta=0.005, 
patience=10, verbose=False, mode="max" + monitor="probabilisticDice", min_delta=0.005, patience=25, verbose=False, mode="max" ) trainer.callbacks.append(early_stop_callback) From b007d0f8a4a36f638e65f995465f2278bd9377c9 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 25 Sep 2022 09:58:18 +1000 Subject: [PATCH 192/264] Don't early stop --- platipy/imaging/cnn/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 2b9a1584..08b7943c 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -163,7 +163,7 @@ def configure_optimizers(self): "optimizer": optimizer, "lr_scheduler": { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, "max", patience=5, threshold=0.01, factor=0.5 + optimizer, "max", patience=25, threshold=0.01, factor=0.5 ), "monitor": "probabilisticDice", }, @@ -819,7 +819,7 @@ def main(args, config_json_path=None): early_stop_callback = EarlyStopping( monitor="probabilisticDice", min_delta=0.005, patience=25, verbose=False, mode="max" ) - trainer.callbacks.append(early_stop_callback) + # trainer.callbacks.append(early_stop_callback) trainer.fit(prob_unet, data_module) From f213daa3ac469bb24978f06ab9ca4160fee1b48d Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 27 Sep 2022 18:03:23 +1000 Subject: [PATCH 193/264] Reenable early stopping --- platipy/imaging/cnn/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 08b7943c..804ebd18 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -163,7 +163,7 @@ def configure_optimizers(self): "optimizer": optimizer, "lr_scheduler": { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, "max", patience=25, threshold=0.01, factor=0.5 + optimizer, "max", patience=25, threshold=0.75, factor=0.5 ), "monitor": "probabilisticDice", }, @@ -817,7 
+817,7 @@ def main(args, config_json_path=None): trainer.callbacks.append(checkpoint_callback) early_stop_callback = EarlyStopping( - monitor="probabilisticDice", min_delta=0.005, patience=25, verbose=False, mode="max" + monitor="probabilisticDice", min_delta=0.005, patience=50, verbose=False, mode="max" ) # trainer.callbacks.append(early_stop_callback) From 43d50194d2ee7c7c8103fc77be47730ca2f1c65a Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 28 Sep 2022 17:00:37 +1000 Subject: [PATCH 194/264] Actually enable early stopping --- platipy/imaging/cnn/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 804ebd18..8d39c3a4 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -819,7 +819,7 @@ def main(args, config_json_path=None): early_stop_callback = EarlyStopping( monitor="probabilisticDice", min_delta=0.005, patience=50, verbose=False, mode="max" ) - # trainer.callbacks.append(early_stop_callback) + trainer.callbacks.append(early_stop_callback) trainer.fit(prob_unet, data_module) From 9c825249acea0d30334cb3ce43fde7399631803f Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 29 Sep 2022 10:29:25 +1000 Subject: [PATCH 195/264] Remove KL clamp --- platipy/imaging/cnn/prob_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 84f1453e..69dd997c 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -119,8 +119,8 @@ def forward(self, img, seg=None): if self.ndims == 3: mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) - mu = mu_log_sigma[:, : self.latent_dim].clamp(-1000, 1000) - log_sigma = mu_log_sigma[:, self.latent_dim :].clamp(-10, 10) + mu = mu_log_sigma[:, : self.latent_dim]#.clamp(-1000, 1000) + log_sigma = mu_log_sigma[:, self.latent_dim :]#.clamp(-10, 10) # This is a multivariate normal with diagonal 
covariance matrix sigma # https://github.com/pytorch/pytorch/pull/11178 From 1fbe811f1c44c14e103b7e015d76c4f8525b4392 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 29 Sep 2022 17:38:56 +1000 Subject: [PATCH 196/264] Add KL clamp back in --- platipy/imaging/cnn/prob_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 69dd997c..84f1453e 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -119,8 +119,8 @@ def forward(self, img, seg=None): if self.ndims == 3: mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) - mu = mu_log_sigma[:, : self.latent_dim]#.clamp(-1000, 1000) - log_sigma = mu_log_sigma[:, self.latent_dim :]#.clamp(-10, 10) + mu = mu_log_sigma[:, : self.latent_dim].clamp(-1000, 1000) + log_sigma = mu_log_sigma[:, self.latent_dim :].clamp(-10, 10) # This is a multivariate normal with diagonal covariance matrix sigma # https://github.com/pytorch/pytorch/pull/11178 From 9b0f7251b4b3bb56603a5b73fe9cb9b71ae35019 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sat, 1 Oct 2022 10:48:13 +1000 Subject: [PATCH 197/264] Add in dropout probability --- platipy/imaging/cnn/prob_unet.py | 11 ++++++----- platipy/imaging/cnn/train.py | 2 ++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 84f1453e..a37eb401 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -26,7 +26,7 @@ class Encoder(torch.nn.Module): """Encoder part of the probabilistic UNet""" def __init__( - self, input_channels, filters_per_layer=[64 * (2**x) for x in range(5)], ndims=2 + self, input_channels, filters_per_layer=[64 * (2**x) for x in range(5)], ndims=2, dropout_probability=None ): super(Encoder, self).__init__() @@ -44,7 +44,7 @@ def __init__( output_filters, up_down_sample=down_sample, ndims=ndims, - dropout_probability=None, + 
dropout_probability=dropout_probability, ) ) @@ -64,13 +64,14 @@ def __init__( filters_per_layer=[64 * (2**x) for x in range(5)], latent_dim=2, ndims=2, + dropout_probability=None ): super(AxisAlignedConvGaussian, self).__init__() self.latent_dim = latent_dim - self.encoder = Encoder(input_channels, filters_per_layer, ndims=ndims) + self.encoder = Encoder(input_channels, filters_per_layer, ndims=ndims, dropout_probability=dropout_probability) self.final = conv_nd( in_channels=filters_per_layer[-1], @@ -230,10 +231,10 @@ def __init__( ndims=ndims, ) self.prior = AxisAlignedConvGaussian( - input_channels, filters_per_layer, latent_dim, ndims=ndims + input_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=dropout_probability ) self.posterior = AxisAlignedConvGaussian( - input_channels + num_classes, filters_per_layer, latent_dim, ndims=ndims + input_channels + num_classes, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=dropout_probability ) self.fcomb = Fcomb(filters_per_layer, latent_dim, num_classes, no_convs_fcomb, ndims=ndims) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 8d39c3a4..7d4fdb4b 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -86,6 +86,7 @@ def __init__( self.hparams.loss_type, loss_params, self.hparams.ndims, + dropout_probability=self.hparams.dropout_probability ) elif self.hparams.prob_type == "hierarchical": raise NotImplementedError("Hierarchical Prob UNet current not working...") @@ -134,6 +135,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--contour_loss_lambda_threshold", type=float, default=None) parser.add_argument("--contour_loss_weight", type=float, default=0.0) # no longer used parser.add_argument("--epochs_all_rec", type=int, default=0) # no longer used + parser.add_argument("--dropout_probability", type=float, default=0.0) return parent_parser From 27a5b77315ae2c1e0f21d3372d6745e773a57875 Mon Sep 17 00:00:00 
2001 From: Phillip Chlap Date: Mon, 3 Oct 2022 11:05:29 +1100 Subject: [PATCH 198/264] Turn off dropout for prob part --- platipy/imaging/cnn/prob_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index a37eb401..99021732 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -231,10 +231,10 @@ def __init__( ndims=ndims, ) self.prior = AxisAlignedConvGaussian( - input_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=dropout_probability + input_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=None ) self.posterior = AxisAlignedConvGaussian( - input_channels + num_classes, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=dropout_probability + input_channels + num_classes, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=None ) self.fcomb = Fcomb(filters_per_layer, latent_dim, num_classes, no_convs_fcomb, ndims=ndims) From 13a6549eeaf94110f85323d76827217919637a0a Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 7 Oct 2022 16:14:06 +1100 Subject: [PATCH 199/264] Reenable weight agg --- platipy/imaging/cnn/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 7d4fdb4b..d33a502d 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -759,8 +759,8 @@ def main(args, config_json_path=None): # args.default_root_dir = str(args.working_dir) args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") args.default_root_dir = str(args.fold_dir) - # args.accumulate_grad_batches = {0: 20, 10: 10, 50: 5, 100: 1} - args.precision = 16 + args.accumulate_grad_batches = {0: 1, 5: 10, 10: 20} +# args.precision = 16 comet_api_key = None comet_workspace = None From 5ccfab6a576eb62f07d1d5a7208b642269b1774d Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 
17 Oct 2022 15:03:22 +1100 Subject: [PATCH 200/264] Ensure dropout layers are there even if not used --- platipy/imaging/cnn/prob_unet.py | 8 ++++---- platipy/imaging/cnn/unet.py | 10 ++++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 99021732..5ec0da91 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -64,7 +64,7 @@ def __init__( filters_per_layer=[64 * (2**x) for x in range(5)], latent_dim=2, ndims=2, - dropout_probability=None + dropout_probability=0.0 ): super(AxisAlignedConvGaussian, self).__init__() @@ -212,7 +212,7 @@ def __init__( loss_type="elbo", loss_params={"beta": 1}, ndims=2, - dropout_probability=None, + dropout_probability=0.0, ): super(ProbabilisticUnet, self).__init__() @@ -231,10 +231,10 @@ def __init__( ndims=ndims, ) self.prior = AxisAlignedConvGaussian( - input_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=None + input_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 ) self.posterior = AxisAlignedConvGaussian( - input_channels + num_classes, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=None + input_channels + num_classes, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 ) self.fcomb = Fcomb(filters_per_layer, latent_dim, num_classes, no_convs_fcomb, ndims=ndims) diff --git a/platipy/imaging/cnn/unet.py b/platipy/imaging/cnn/unet.py index 6b674c1e..e16c8d0b 100644 --- a/platipy/imaging/cnn/unet.py +++ b/platipy/imaging/cnn/unet.py @@ -157,7 +157,7 @@ def resize_up_func(in_channels, out_channels, scale=2, ndims=2): class Conv(torch.nn.Module): def __init__( - self, input_channels, output_channels, up_down_sample=0, dropout_probability=None, ndims=2 + self, input_channels, output_channels, up_down_sample=0, dropout_probability=0.0, ndims=2 ): super(Conv, self).__init__() @@ -182,8 +182,7 @@ def __init__( ) ) 
layers.append(nn.ReLU(inplace=True)) - if dropout_probability: - layers.append(dropout_nd(ndims=ndims, p=dropout_probability)) + layers.append(dropout_nd(ndims=ndims, p=dropout_probability)) layers.append( conv_nd( ndims=ndims, @@ -193,8 +192,7 @@ def __init__( padding=1, ) ) - if dropout_probability: - layers.append(dropout_nd(ndims=ndims, p=dropout_probability)) + layers.append(dropout_nd(ndims=ndims, p=dropout_probability)) layers.append(nn.ReLU(inplace=True)) self.layers = nn.Sequential(*layers) @@ -219,7 +217,7 @@ def __init__( filters_per_layer=[64 * (2 ** x) for x in range(5)], final_layer=True, ndims=2, - dropout_probability=None + dropout_probability=0.0 ): super(UNet, self).__init__() From 72a67fc1345333621a70a96e0f5350a4a8909634 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 5 Jan 2023 15:10:12 +1100 Subject: [PATCH 201/264] Adjustments for early stopping --- platipy/imaging/cnn/dataload.py | 3 +- platipy/imaging/cnn/dataset.py | 1 + platipy/imaging/cnn/train.py | 84 ++++++++++++++++++--------- platipy/imaging/cnn/train_localise.py | 1 + platipy/imaging/generation/augment.py | 4 +- 5 files changed, 61 insertions(+), 32 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index 82350b55..8939a7d5 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -171,8 +171,6 @@ def setup(self, stage=None): for case in self.train_cases ] - print(train_data) - # If a directory with augmented data is specified, use that for training as well if self.augmented_dir is not None: @@ -229,6 +227,7 @@ def setup(self, stage=None): } for case in self.validation_cases ] + print(self.validation_data) self.test_data = [ { diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index fcfba346..ad8a6434 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -594,6 +594,7 @@ def __getitem__(self, index): contour_mask = torch.FloatTensor( 
np.concatenate([np.expand_dims(l, 0) for l in contour_masks], 0) ) + contour_mask = contour_mask.max(axis=0).values.unsqueeze(0) label_present = [label is not None for label in self.slices[index]["labels"]] return ( diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index d33a502d..ed0d7ea3 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -165,7 +165,8 @@ def configure_optimizers(self): "optimizer": optimizer, "lr_scheduler": { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, "max", patience=25, threshold=0.75, factor=0.5 + optimizer, "max", patience=25, threshold=0.1e-2, factor=0.75 +# optimizer, "max", patience=200, threshold=0.75, factor=0.5 ), "monitor": "probabilisticDice", }, @@ -315,7 +316,7 @@ def validate( metrics = {"DSC": "max", "HD": "min", "ASD": "min"} result = {} - contour_cmaps = ["RdPu", "YlOrRd", "GnBu"] + contour_cmaps = ["RdPu", "YlOrRd", "GnBu", "OrRd", "YlGn", "YlGnBu"] structures = self.hparams.structures try: @@ -527,6 +528,8 @@ def validation_step(self, batch, _): y = y.int() y = y.to("cpu") + + # TODO Make this work for multi class # Intersection over Union (also known as Jaccard Index) jaccard = JaccardIndex(num_classes=2) term_1 = 0 @@ -608,6 +611,8 @@ def validation_epoch_end(self, validation_step_outputs): **{f"unet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, } + if len(cases) == 0: return + prob_surface_dice = 0 prob_dice = 0 @@ -701,13 +706,13 @@ def validation_epoch_end(self, validation_step_outputs): for rk in samples: sample_labels.append(samples[rk][structure]) - # prob_surface_dice += probabilistic_surface_dice(gt_labels, sample_labels, tau=1) prob_dice += probabilistic_dice(gt_labels, sample_labels, dsc_type="dsc") prob_surface_dice += probabilistic_dice( gt_labels, sample_labels, dsc_type="sdsc", tau=3 ) prob_dice = prob_dice / len(cases) + if np.isnan(prob_dice): prob_dice = 0 self.log( "probabilisticDice", prob_dice, @@ -718,6 
+723,7 @@ def validation_epoch_end(self, validation_step_outputs): ) prob_surface_dice = prob_surface_dice / len(cases) + if np.isnan(prob_surface_dice): prob_surface_dice = 0 self.log( "probabilisticSurfaceDice", prob_surface_dice, @@ -809,23 +815,41 @@ def main(args, config_json_path=None): trainer.callbacks.append(lr_monitor) # Save the best model - checkpoint_callback = ModelCheckpoint( - monitor="probabilisticDice", - dirpath=args.default_root_dir, - filename="probunet-{epoch:02d}-{probabilisticSurfaceDice:.2f}", - save_top_k=1, - mode="max", - ) - trainer.callbacks.append(checkpoint_callback) + if args.checkpoint_var: + checkpoint_callback = ModelCheckpoint( + monitor=args.checkpoint_var, + dirpath=args.default_root_dir, + filename="probunet-{epoch:02d}-{"+args.checkpoint_var+":.2f}", + save_top_k=1, + mode=args.checkpoint_mode, + ) + trainer.callbacks.append(checkpoint_callback) - early_stop_callback = EarlyStopping( - monitor="probabilisticDice", min_delta=0.005, patience=50, verbose=False, mode="max" - ) - trainer.callbacks.append(early_stop_callback) + if args.early_stopping_var: + early_stop_callback = EarlyStopping( + monitor=args.early_stopping_var, min_delta=args.early_stopping_min_delta, patience=args.early_stopping_patience, verbose=False, mode=args.early_stopping_mode + ) + trainer.callbacks.append(early_stop_callback) trainer.fit(prob_unet, data_module) +def parse_config_file(config_json_path, args): + + with open(config_json_path, "r") as f: + params = json.load(f) + for key in params: + args.append(f"--{key}") + + if isinstance(params[key], list): + for list_val in params[key]: + args.append(str(list_val)) + else: + args.append(str(params[key])) + + + return args + if __name__ == "__main__": args = None @@ -834,17 +858,7 @@ def main(args, config_json_path=None): # Check if JSON file parsed, if so read arguments from there... 
if sys.argv[-1].endswith(".json"): config_json_path = sys.argv[-1] - with open(config_json_path, "r") as f: - params = json.load(f) - args = [] - for key in params: - args.append(f"--{key}") - - if isinstance(params[key], list): - for list_val in params[key]: - args.append(str(list_val)) - else: - args.append(str(params[key])) + args = parse_config_file(config_json_path, []) arg_parser = ArgumentParser() arg_parser = ProbUNet.add_model_specific_args(arg_parser) @@ -863,5 +877,19 @@ def main(args, config_json_path=None): arg_parser.add_argument("--comet_workspace", type=str, default=None) arg_parser.add_argument("--comet_project", type=str, default=None) arg_parser.add_argument("--resume_from", type=str, default=None) - - main(arg_parser.parse_args(args), config_json_path=config_json_path) + arg_parser.add_argument("--early_stopping_var", type=str, default=None) + arg_parser.add_argument("--early_stopping_min_delta", type=float, default=0.01) + arg_parser.add_argument("--early_stopping_patience", type=int, default=50) + arg_parser.add_argument("--early_stopping_mode", type=str, default="max") + arg_parser.add_argument("--checkpoint_var", type=str, default=None) + arg_parser.add_argument("--checkpoint_mode", type=str, default="max") + + parsed_args = arg_parser.parse_args(args) + + # Check if config arg parsed, if so take over values and reparse + if parsed_args.config: + print("parseing args") + args = parse_config_file(parsed_args.config, sys.argv[1:]) + parsed_args = arg_parser.parse_args(args) + + main(parsed_args) diff --git a/platipy/imaging/cnn/train_localise.py b/platipy/imaging/cnn/train_localise.py index e4932856..af2a3d6c 100644 --- a/platipy/imaging/cnn/train_localise.py +++ b/platipy/imaging/cnn/train_localise.py @@ -37,6 +37,7 @@ def main(args, config_json_path=None): args.working_dir = args.working_dir.joinpath(args.experiment) args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") args.default_root_dir = str(args.fold_dir) + 
args.num_sanity_val_steps = 0 comet_api_key = None comet_workspace = None diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index 338de085..f099f42d 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -285,7 +285,7 @@ def augment_data(args): data = { case: { "image": data_dir.joinpath(args.image_glob.format(case=case)), - "label": [p for p in data_dir.glob(args.label_glob.format(case=case))], + "label": [i for sl in [list(data_dir.glob(lg.format(case=case))) for lg in args.label_glob] for i in sl], } for case in cases } @@ -412,7 +412,7 @@ def augment_data(args): arg_parser.add_argument("--output_dir", type=str, default="./augment") arg_parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") arg_parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") - arg_parser.add_argument("--label_glob", type=str, default="labels/{case}_*.nii.gz") + arg_parser.add_argument("--label_glob", nargs="+", type=str, default="labels/{case}_*.nii.gz") arg_parser.add_argument( "--augmentations_per_case", type=int, From 89adbe7addfe7d8f66f57ab26873bdcfb6ec485b Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 6 Jan 2023 09:42:31 +1100 Subject: [PATCH 202/264] Fix dataset bug --- platipy/imaging/cnn/dataset.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index ad8a6434..ddbebd13 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -428,6 +428,8 @@ def __init__( } ) + print(f"Loaded {len(self.slices)} slices") + print(self.slices) continue logger.debug(f"Generating images for case: {case_id}") @@ -454,7 +456,8 @@ def __init__( for obs in case["observers"]: observers[obs] = {} for structure in case["observers"][obs]: - structure_names.append(structure) + if structure not in structure_names: + structure_names.append(structure) structure_path = 
case["observers"][obs][structure] label = None @@ -559,6 +562,8 @@ def __init__( "observer": obs, } ) + print(f"Generated {len(self.slices)} slices") + print(self.slices) def __len__(self): return len(self.slices) From a4becf12743b7902bcca181ccfdf86e9c8889184 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 6 Jan 2023 09:46:29 +1100 Subject: [PATCH 203/264] Remove debug statements --- platipy/imaging/cnn/dataset.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index ddbebd13..ace3d695 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -428,8 +428,6 @@ def __init__( } ) - print(f"Loaded {len(self.slices)} slices") - print(self.slices) continue logger.debug(f"Generating images for case: {case_id}") @@ -562,8 +560,6 @@ def __init__( "observer": obs, } ) - print(f"Generated {len(self.slices)} slices") - print(self.slices) def __len__(self): return len(self.slices) From cb4f1b7e63fd23ebc3362344b812977c356edf7c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 6 Jan 2023 14:42:12 +1100 Subject: [PATCH 204/264] Adjust patience --- platipy/imaging/cnn/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index ed0d7ea3..c25afce4 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -165,7 +165,7 @@ def configure_optimizers(self): "optimizer": optimizer, "lr_scheduler": { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, "max", patience=25, threshold=0.1e-2, factor=0.75 + optimizer, "max", patience=20, threshold=0.1e-2, factor=0.75 # optimizer, "max", patience=200, threshold=0.75, factor=0.5 ), "monitor": "probabilisticDice", From 0a55d0118d3760f990307584a4dac908d074dcbb Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 8 Jan 2023 12:24:41 +1100 Subject: [PATCH 205/264] Ensure lambda is below zero before early stop --- 
platipy/imaging/cnn/train.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index c25afce4..6a33e601 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -46,6 +46,23 @@ from platipy.imaging import ImageVisualiser from platipy.imaging.label.utils import get_com, get_union_mask, get_intersection_mask +class GECOEarlyStopping(EarlyStopping): + def on_validation_end(self, trainer, pl_module): + + # Make sure the GECO lambda metrics are below 1 before stopping + logs = trainer.callback_metrics + should_consider_early_stop = True + if "lambda_rec" in logs and logs["lambda_rec"] >= 1: + should_consider_early_stop = False + + if "lambda_contour" in logs and logs["lambda_contour"] >= 1: + should_consider_early_stop = False + + if should_consider_early_stop: + self._run_early_stopping_check(trainer) + + def on_train_end(self, trainer, pl_module): + pass class ProbUNet(pl.LightningModule): def __init__( @@ -826,7 +843,7 @@ def main(args, config_json_path=None): trainer.callbacks.append(checkpoint_callback) if args.early_stopping_var: - early_stop_callback = EarlyStopping( + early_stop_callback = GECOEarlyStopping( monitor=args.early_stopping_var, min_delta=args.early_stopping_min_delta, patience=args.early_stopping_patience, verbose=False, mode=args.early_stopping_mode ) trainer.callbacks.append(early_stop_callback) From 9f47ae20a3cbf634b5738922de60ce000258e230 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 9 Jan 2023 08:58:42 +1100 Subject: [PATCH 206/264] Early stopping tweaks --- platipy/imaging/cnn/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 6a33e601..ddd8bc56 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -49,13 +49,13 @@ class GECOEarlyStopping(EarlyStopping): def on_validation_end(self, trainer, 
pl_module): - # Make sure the GECO lambda metrics are below 1 before stopping + # Make sure the GECO lambda metrics are below 0.1 before stopping logs = trainer.callback_metrics should_consider_early_stop = True - if "lambda_rec" in logs and logs["lambda_rec"] >= 1: + if "lambda_rec" in logs and logs["lambda_rec"] >= 0.1: should_consider_early_stop = False - if "lambda_contour" in logs and logs["lambda_contour"] >= 1: + if "lambda_contour" in logs and logs["lambda_contour"] >= 0.1: should_consider_early_stop = False if should_consider_early_stop: @@ -844,7 +844,7 @@ def main(args, config_json_path=None): if args.early_stopping_var: early_stop_callback = GECOEarlyStopping( - monitor=args.early_stopping_var, min_delta=args.early_stopping_min_delta, patience=args.early_stopping_patience, verbose=False, mode=args.early_stopping_mode + monitor=args.early_stopping_var, min_delta=args.early_stopping_min_delta, patience=args.early_stopping_patience, verbose=True, mode=args.early_stopping_mode ) trainer.callbacks.append(early_stop_callback) From 152449ed73040ae8e7d5bfbf02ddc3313ba7cd72 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 9 Jan 2023 10:34:44 +1100 Subject: [PATCH 207/264] Cycle learning rate --- platipy/imaging/cnn/train.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index ddd8bc56..de196893 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -52,16 +52,17 @@ def on_validation_end(self, trainer, pl_module): # Make sure the GECO lambda metrics are below 0.1 before stopping logs = trainer.callback_metrics should_consider_early_stop = True - if "lambda_rec" in logs and logs["lambda_rec"] >= 0.1: + + if "lambda_rec" in logs and logs["lambda_rec"] >= 0.01: should_consider_early_stop = False - if "lambda_contour" in logs and logs["lambda_contour"] >= 0.1: + if "lambda_contour" in logs and logs["lambda_contour"] >= 0.01: 
should_consider_early_stop = False if should_consider_early_stop: self._run_early_stopping_check(trainer) - def on_train_end(self, trainer, pl_module): + def on_train_epoch_end(self, trainer, pl_module): pass class ProbUNet(pl.LightningModule): @@ -170,14 +171,15 @@ def configure_optimizers(self): # scheduler = torch.optim.lr_scheduler.LambdaLR( # optimizer, lr_lambda=[lambda epoch: self.hparams.lr_lambda ** (epoch)] # ) - # scheduler = torch.optim.lr_scheduler.CyclicLR( - # optimizer, - # base_lr=self.hparams.learning_rate / 10, - # max_lr=self.hparams.learning_rate, - # step_size_up=1000, - # ) + scheduler = torch.optim.lr_scheduler.CyclicLR( + optimizer, + base_lr=self.hparams.learning_rate / 100, + max_lr=self.hparams.learning_rate, + step_size_up=200, + ) + + return [optimizer], [scheduler] - # return [optimizer], [scheduler] return { "optimizer": optimizer, "lr_scheduler": { From 8ddd2b195d78766bb895fa85d99997cff028bf4f Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 9 Jan 2023 10:44:32 +1100 Subject: [PATCH 208/264] Don't cycle momentum --- platipy/imaging/cnn/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index de196893..3e56b87d 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -176,6 +176,7 @@ def configure_optimizers(self): base_lr=self.hparams.learning_rate / 100, max_lr=self.hparams.learning_rate, step_size_up=200, + cycle_momentum=False ) return [optimizer], [scheduler] From 1088363d91323f621113029644ba73337eadd9cd Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 9 Jan 2023 14:26:29 +1100 Subject: [PATCH 209/264] change LR step size --- platipy/imaging/cnn/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 3e56b87d..50ad113e 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -175,7 +175,7 @@ def 
configure_optimizers(self): optimizer, base_lr=self.hparams.learning_rate / 100, max_lr=self.hparams.learning_rate, - step_size_up=200, + step_size_up=100, cycle_momentum=False ) @@ -785,7 +785,7 @@ def main(args, config_json_path=None): # args.default_root_dir = str(args.working_dir) args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") args.default_root_dir = str(args.fold_dir) - args.accumulate_grad_batches = {0: 1, 5: 10, 10: 20} +# args.accumulate_grad_batches = {0: 1, 5: 10, 10: 20} # args.precision = 16 comet_api_key = None From 19a9ce256a723105f728e9ab6c04905b717aabad Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 9 Jan 2023 19:30:43 +1100 Subject: [PATCH 210/264] Faster cycle LR --- platipy/imaging/cnn/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 50ad113e..c006afcc 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -175,7 +175,7 @@ def configure_optimizers(self): optimizer, base_lr=self.hparams.learning_rate / 100, max_lr=self.hparams.learning_rate, - step_size_up=100, + step_size_up=25, cycle_momentum=False ) From 6ee758597402f19773bdf175b194e7062cec2043 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 13 Jan 2023 14:27:25 +1100 Subject: [PATCH 211/264] Adjustments to learning rate --- platipy/imaging/cnn/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index c006afcc..2db3a12a 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -175,7 +175,9 @@ def configure_optimizers(self): optimizer, base_lr=self.hparams.learning_rate / 100, max_lr=self.hparams.learning_rate, - step_size_up=25, + step_size_up=10, + mode="exp_range", + gamma=0.99, cycle_momentum=False ) @@ -785,7 +787,7 @@ def main(args, config_json_path=None): # args.default_root_dir = str(args.working_dir) args.fold_dir = 
args.working_dir.joinpath(f"fold_{args.fold}") args.default_root_dir = str(args.fold_dir) -# args.accumulate_grad_batches = {0: 1, 5: 10, 10: 20} + args.accumulate_grad_batches = {0: 10, 5: 15, 10: 20} # args.precision = 16 comet_api_key = None From c8b7a0916f0b71395bc439b85e04bb3d9e65a765 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 13 Jan 2023 14:29:14 +1100 Subject: [PATCH 212/264] change batch agg --- platipy/imaging/cnn/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 2db3a12a..a3f2795f 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -787,7 +787,7 @@ def main(args, config_json_path=None): # args.default_root_dir = str(args.working_dir) args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") args.default_root_dir = str(args.fold_dir) - args.accumulate_grad_batches = {0: 10, 5: 15, 10: 20} + args.accumulate_grad_batches = {0: 5, 5: 10, 10: 15} # args.precision = 16 comet_api_key = None From 2e6d8d1b0905c61ce42e1d9e8153756de42b5539 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 18 Jan 2023 13:37:03 +1100 Subject: [PATCH 213/264] Adjust LR schedule --- platipy/imaging/cnn/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index a3f2795f..12a041d6 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -173,11 +173,11 @@ def configure_optimizers(self): # ) scheduler = torch.optim.lr_scheduler.CyclicLR( optimizer, - base_lr=self.hparams.learning_rate / 100, + base_lr=self.hparams.learning_rate / 10, max_lr=self.hparams.learning_rate, - step_size_up=10, + step_size_up=20, mode="exp_range", - gamma=0.99, + gamma=0.999, cycle_momentum=False ) From f192858d3aa68013818166f8a0c1dfba0b3e293f Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 15 Mar 2023 16:05:26 +1100 Subject: [PATCH 214/264] Use Cosine 
annealing LR --- platipy/imaging/cnn/train.py | 47 +++++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 12a041d6..4ca53886 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -162,23 +162,42 @@ def forward(self, x): return x def configure_optimizers(self): + + params = [{ + 'params': self.prob_unet.unet.parameters(), + 'weight_decay': self.hparams.weight_decay, + 'lr': 1e-4 + }] + for m in [self.prob_unet.prior.parameters(), self.prob_unet.posterior.parameters(), self.prob_unet.fcomb.parameters()]: + params += [{'params': m, 'weight_decay': self.hparams.weight_decay, 'lr': 1e-5}] + optimizer = torch.optim.Adam( - self.parameters(), - lr=self.hparams.learning_rate, - weight_decay=self.hparams.weight_decay, + params ) - # scheduler = torch.optim.lr_scheduler.LambdaLR( - # optimizer, lr_lambda=[lambda epoch: self.hparams.lr_lambda ** (epoch)] + lr_lambda_unet = lambda epoch: self.hparams.lr_lambda ** (epoch) + lr_lambda_prob = lambda epoch: 0.99 ** (epoch) + +# max_epochs = self.hparams.max_epochs +# lr_lambda = lambda x: np.interp(((np.sin(x/(max_epochs/8)) * np.sin(x/(max_epochs/4)))), np.array([-1,0,1]), np.array([0.1,1,10])) + + #scheduler = torch.optim.lr_scheduler.LambdaLR( + # optimizer, lr_lambda=[lr_lambda_unet, lr_lambda_prob, lr_lambda_prob, lr_lambda_prob] + #) + #scheduler = torch.optim.lr_scheduler.CyclicLR( + # optimizer, + # base_lr=self.hparams.learning_rate / 10, + # max_lr=self.hparams.learning_rate, + # step_size_up=20, + # mode="exp_range", + # gamma=0.999, + # cycle_momentum=False # ) - scheduler = torch.optim.lr_scheduler.CyclicLR( + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, - base_lr=self.hparams.learning_rate / 10, - max_lr=self.hparams.learning_rate, - step_size_up=20, - mode="exp_range", - gamma=0.999, - cycle_momentum=False + 50, + eta_min=1e-6, + verbose=True ) return 
[optimizer], [scheduler] @@ -662,8 +681,8 @@ def validation_epoch_end(self, validation_step_outputs): samples = self.infer( img, sample_strategy="spaced", - num_samples=7, - spaced_range=[-3, 3], + num_samples=5, + spaced_range=[-2, 2], preprocess=False, ) except Exception as e: From 51c6df25255797ae026cf7e3be553509ccf5a6f8 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 22 May 2023 19:29:42 +1000 Subject: [PATCH 215/264] Support additional training data --- platipy/imaging/cnn/dataload.py | 98 +++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index 8939a7d5..fde6dcfd 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -18,16 +18,21 @@ class UNetDataModule(pl.LightningDataModule): def __init__( self, data_dir: str = "./data", + data_add_dirs: list = [], augmented_dir: str = None, + augmented_add_dirs: list = [], working_dir: str = "./working", structures=["a", "b", "c"], observers=["0", "1", "2", "3", "4"], + observers_add=[], case_glob="images/*.nii.gz", image_glob="images/{case}.nii.gz", label_glob="labels/{case}_{structure}_*.nii.gz", + label_add_glob="labels/{case}_{structure}.nii.gz", augmented_case_glob="{case}/*", augmented_image_glob="images/{augmented_case}.nii.gz", augmented_label_glob="labels/{augmented_case}_{structure}_*.nii.gz", + augmented_label_add_glob="labels/{augmented_case}_{structure}_*.nii.gz", augment_on_fly=True, fold=0, k_folds=5, @@ -47,15 +52,19 @@ def __init__( ): super().__init__() self.data_dir = Path(data_dir) + self.data_add_dirs = [Path(p) for p in data_add_dirs] self.augmented_dir = augmented_dir + self.augmented_add_dirs = augmented_add_dirs self.working_dir = Path(working_dir) self.case_glob = case_glob self.image_glob = image_glob self.label_glob = label_glob + self.label_add_glob = label_add_glob self.augmented_case_glob = augmented_case_glob self.augmented_image_glob = augmented_image_glob 
self.augmented_label_glob = augmented_label_glob + self.augmented_label_add_glob = augmented_label_add_glob self.augment_on_fly = augment_on_fly self.fold = fold @@ -75,6 +84,7 @@ def __init__( self.contour_mask_kernel = contour_mask_kernel self.structures = structures self.observers = observers + self.observers_add = observers_add self.crop_using_localise_model = crop_using_localise_model self.localise_voxel_grid_size = localise_voxel_grid_size @@ -96,7 +106,9 @@ def add_model_specific_args(parent_parser): """Add arguments used for Data module""" parser = parent_parser.add_argument_group("Data Loader") parser.add_argument("--data_dir", type=str, default="./data") + parser.add_argument("--data_add_dirs", nargs="+", type=str, default=[]) parser.add_argument("--augmented_dir", type=str, default=None) + parser.add_argument("--augmented_add_dirs", nargs="+", type=str, default=[]) parser.add_argument("--augment_on_fly", type=bool, default=True) parser.add_argument("--fold", type=int, default=0) parser.add_argument("--k_folds", type=int, default=5) @@ -104,14 +116,19 @@ def add_model_specific_args(parent_parser): parser.add_argument("--num_workers", type=int, default=4) parser.add_argument("--structures", nargs="+", type=str, default=["a", "b", "c"]) parser.add_argument("--observers", nargs="+", type=str, default=["0", "1", "2", "3", "4"]) + parser.add_argument("--observers_add", nargs="+", type=str, default=[]) parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") parser.add_argument( "--label_glob", type=str, default="labels/{case}_{structure}_{observer}.nii.gz" ) + parser.add_argument( + "--label_add_glob", type=str, default="labels/{case}_{structure}.nii.gz" + ) parser.add_argument("--augmented_case_glob", type=str, default=None) parser.add_argument("--augmented_image_glob", type=str, default=None) parser.add_argument("--augmented_label_glob", type=str, default=None) + 
parser.add_argument("--augmented_label_add_glob", type=str, default=None) parser.add_argument("--crop_to_grid_size_xy", type=int, default=128) parser.add_argument("--intensity_scaling", type=str, default="window") parser.add_argument("--intensity_window", nargs="+", type=int, default=[-500, 500]) @@ -209,6 +226,87 @@ def setup(self, stage=None): for augmented_case in augmented_cases ] + # If observers_add is empty then just add one dummy observer since they are not using + # Multi observer data here + if len(self.observers_add) == 0: + self.observers_add = ["X"] + + # Add in the addtional cases, these are only use for training and may only have 1 observer + for data_add_dir in self.data_add_dirs: + self.add_train_cases = [] + cases = [ + p.name.replace(".nii.gz", "") + for p in data_add_dir.glob(self.case_glob) + if not p.name.startswith(".") + ] + self.add_train_cases += cases + train_data += [ + { + "id": case, + "image": data_add_dir.joinpath(self.image_glob.format(case=case)), + "observers": { + observer: { + structure: data_add_dir.joinpath( + self.label_add_glob.format( + case=case, structure=structure, observer=observer + ) + ) + for structure in self.structures + } + for observer in self.observers_add + }, + } + for case in cases + ] + + for case in cases: + + case_aug_dir = None + for aug_add_dir in self.augmented_add_dirs: + if Path(aug_add_dir.format(case=case)).exists(): + + case_aug_dir = Path(aug_add_dir.format(case=case)) + else: + print(f"No dir {Path(aug_add_dir.format(case=case))}") + + if case_aug_dir is None: + continue + + augmented_cases = [ + p.name.replace(".nii.gz", "") + for p in case_aug_dir.glob(self.augmented_case_glob.format(case=case)) + if not p.name.startswith(".") + ] + print(augmented_cases) + + train_data += [ + { + "id": f"{case}_{augmented_case}", + "image": case_aug_dir.joinpath( + self.augmented_image_glob.format( + case=case, augmented_case=augmented_case + ) + ), + "observers": { + observer: { + structure: 
case_aug_dir.joinpath( + self.augmented_label_add_glob.format( + case=case, + augmented_case=augmented_case, + structure=structure, + observer=observer + ) + ) + for structure in self.structures + } + for observer in self.observers_add + }, + } + for augmented_case in augmented_cases + ] + print(train_data) + print(len(train_data)) + self.validation_data = [ { "id": case, From d75210adbee7b7c0189d9cfd8a6a0ddbccc4fed2 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 27 Nov 2023 16:19:38 +1100 Subject: [PATCH 216/264] Remove loguru imports --- platipy/imaging/cnn/dataload.py | 4 ++-- platipy/imaging/cnn/dataset.py | 4 ++-- platipy/imaging/cnn/train_lidc.py | 5 +++-- platipy/imaging/generation/augment.py | 4 ++-- platipy/imaging/generation/dvf.py | 4 ++-- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index fde6dcfd..4aac2baa 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -1,9 +1,8 @@ import random import math +import logging from pathlib import Path -from loguru import logger - import torch import pytorch_lightning as pl @@ -11,6 +10,7 @@ from platipy.imaging.cnn.dataset import NiftiDataset from platipy.imaging.cnn.sampler import ObserverSampler +logger = logging.getLogger(__name__) class UNetDataModule(pl.LightningDataModule): """PyTorch data module to training UNets""" diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index ace3d695..4dd7b0c2 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -13,6 +13,7 @@ # limitations under the License. 
import re +import logging from pathlib import Path import numpy as np @@ -24,8 +25,6 @@ from imgaug import augmenters as iaa from imgaug.augmentables.segmaps import SegmentationMapsOnImage -from loguru import logger - import math import random from scipy.ndimage import affine_transform @@ -36,6 +35,7 @@ from platipy.imaging.cnn.localise_net import LocaliseUNet from platipy.imaging.utils.crop import label_to_roi, crop_to_roi +logger = logging.getLogger(__name__) class GaussianNoise: def __init__(self, mu=0.0, sigma=0.0, probability=1.0): diff --git a/platipy/imaging/cnn/train_lidc.py b/platipy/imaging/cnn/train_lidc.py index ddf268f5..7a8132ba 100644 --- a/platipy/imaging/cnn/train_lidc.py +++ b/platipy/imaging/cnn/train_lidc.py @@ -16,6 +16,7 @@ import os import tempfile import json +import logging from pathlib import Path import SimpleITK as sitk @@ -48,8 +49,6 @@ import math from pathlib import Path -from loguru import logger - import torch import pytorch_lightning as pl @@ -57,6 +56,8 @@ from platipy.imaging.cnn.lidc_dataset import LIDCDataset from platipy.imaging.cnn.sampler import ObserverSampler +logger = logging.getLogger(__name__) + LIDC_PAT_IDS = [ '100036212881370097961774473021', '100063870746088919758706456900', diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index f099f42d..46332692 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -15,6 +15,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable import random +import logging from pathlib import Path @@ -23,8 +24,6 @@ import SimpleITK as sitk import numpy as np -from loguru import logger - import matplotlib.pyplot as plt from platipy.imaging import ImageVisualiser @@ -43,6 +42,7 @@ from platipy.imaging.label.utils import get_union_mask from platipy.imaging.utils.crop import label_to_roi, crop_to_roi +logger = logging.getLogger(__name__) def apply_augmentation(image, augmentation, masks=[]): 
diff --git a/platipy/imaging/generation/dvf.py b/platipy/imaging/generation/dvf.py index 6ac9e52d..efd9f7b9 100644 --- a/platipy/imaging/generation/dvf.py +++ b/platipy/imaging/generation/dvf.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import numpy as np import SimpleITK as sitk -from loguru import logger - from platipy.imaging.registration.utils import ( apply_transform, convert_mask_to_reg_structure, @@ -30,6 +29,7 @@ from platipy.imaging.label.utils import get_com from platipy.imaging.utils.crop import label_to_roi, crop_to_roi +logger = logging.getLogger(__name__) def generate_field_shift(mask, vector_shift=(10, 10, 10), gaussian_smooth=5): """ From 5bd7b4d01c387c3f25ba6ab07179efa723061475 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 4 Dec 2023 11:23:04 +1100 Subject: [PATCH 217/264] Support context map inputs to prob UNet --- platipy/imaging/cnn/dataload.py | 81 +++++++--- platipy/imaging/cnn/dataset.py | 89 ++++++++--- platipy/imaging/cnn/train.py | 262 ++++++++++++++++++++++---------- 3 files changed, 312 insertions(+), 120 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index 4aac2baa..98f62e79 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -12,6 +12,7 @@ logger = logging.getLogger(__name__) + class UNetDataModule(pl.LightningDataModule): """PyTorch data module to training UNets""" @@ -29,10 +30,12 @@ def __init__( image_glob="images/{case}.nii.gz", label_glob="labels/{case}_{structure}_*.nii.gz", label_add_glob="labels/{case}_{structure}.nii.gz", + context_map_glob="context_maps/{case}.nii.gz", augmented_case_glob="{case}/*", augmented_image_glob="images/{augmented_case}.nii.gz", augmented_label_glob="labels/{augmented_case}_{structure}_*.nii.gz", augmented_label_add_glob="labels/{augmented_case}_{structure}_*.nii.gz", + 
augmented_context_map_glob="context_maps/{case}_{augmented_case}.nii.gz", augment_on_fly=True, fold=0, k_folds=5, @@ -47,6 +50,7 @@ def __init__( crop_using_localise_model=None, localise_voxel_grid_size=[100, 100, 100], validation_sampler="observer", # observer or batch + input_channels=1, ndims=2, **kwargs, ): @@ -61,10 +65,13 @@ def __init__( self.image_glob = image_glob self.label_glob = label_glob self.label_add_glob = label_add_glob + self.context_map_glob = context_map_glob + self.augmented_case_glob = augmented_case_glob self.augmented_image_glob = augmented_image_glob self.augmented_label_glob = augmented_label_glob self.augmented_label_add_glob = augmented_label_add_glob + self.augmented_context_map_glob = augmented_context_map_glob self.augment_on_fly = augment_on_fly self.fold = fold @@ -97,6 +104,7 @@ def __init__( self.validation_data = [] self.test_data = [] + self.input_channels = input_channels self.ndims = ndims print(f"Training fold {self.fold}") @@ -114,24 +122,34 @@ def add_model_specific_args(parent_parser): parser.add_argument("--k_folds", type=int, default=5) parser.add_argument("--batch_size", type=int, default=5) parser.add_argument("--num_workers", type=int, default=4) - parser.add_argument("--structures", nargs="+", type=str, default=["a", "b", "c"]) - parser.add_argument("--observers", nargs="+", type=str, default=["0", "1", "2", "3", "4"]) + parser.add_argument( + "--structures", nargs="+", type=str, default=["a", "b", "c"] + ) + parser.add_argument( + "--observers", nargs="+", type=str, default=["0", "1", "2", "3", "4"] + ) parser.add_argument("--observers_add", nargs="+", type=str, default=[]) parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") parser.add_argument( - "--label_glob", type=str, default="labels/{case}_{structure}_{observer}.nii.gz" + "--label_glob", + type=str, + default="labels/{case}_{structure}_{observer}.nii.gz", ) 
parser.add_argument( "--label_add_glob", type=str, default="labels/{case}_{structure}.nii.gz" ) + parser.add_argument("--context_map_glob", type=str, default=None) parser.add_argument("--augmented_case_glob", type=str, default=None) parser.add_argument("--augmented_image_glob", type=str, default=None) parser.add_argument("--augmented_label_glob", type=str, default=None) parser.add_argument("--augmented_label_add_glob", type=str, default=None) + parser.add_argument("--augmented_context_map_glob", type=str, default=None) parser.add_argument("--crop_to_grid_size_xy", type=int, default=128) parser.add_argument("--intensity_scaling", type=str, default="window") - parser.add_argument("--intensity_window", nargs="+", type=int, default=[-500, 500]) + parser.add_argument( + "--intensity_window", nargs="+", type=int, default=[-500, 500] + ) parser.add_argument("--contour_mask_kernel", type=int, default=5) parser.add_argument("--crop_using_localise_model", type=str, default=None) parser.add_argument( @@ -142,7 +160,6 @@ def add_model_specific_args(parent_parser): return parent_parser def setup(self, stage=None): - cases = [ p.name.replace(".nii.gz", "") for p in self.data_dir.glob(self.case_glob) @@ -153,14 +170,15 @@ def setup(self, stage=None): cases_per_fold = math.ceil(len(cases) / self.k_folds) for f in range(self.k_folds): - if self.fold == f: val_test_cases = cases[f * cases_per_fold : (f + 1) * cases_per_fold] if len(val_test_cases) == 1: self.validation_cases = val_test_cases else: - self.validation_cases = val_test_cases[: int(len(val_test_cases) / 2)] + self.validation_cases = val_test_cases[ + : int(len(val_test_cases) / 2) + ] self.test_cases = val_test_cases[int(len(val_test_cases) / 2) :] else: self.train_cases += cases[f * cases_per_fold : (f + 1) * cases_per_fold] @@ -173,6 +191,9 @@ def setup(self, stage=None): { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), + "context_map": self.data_dir.joinpath( + 
self.context_map_glob.format(case=case) + ), "observers": { observer: { structure: self.data_dir.joinpath( @@ -190,13 +211,13 @@ def setup(self, stage=None): # If a directory with augmented data is specified, use that for training as well if self.augmented_dir is not None: - for case in self.train_cases: - case_aug_dir = Path(self.augmented_dir.format(case=case)) augmented_cases = [ p.name.replace(".nii.gz", "") - for p in case_aug_dir.glob(self.augmented_case_glob.format(case=case)) + for p in case_aug_dir.glob( + self.augmented_case_glob.format(case=case) + ) if not p.name.startswith(".") ] @@ -208,6 +229,11 @@ def setup(self, stage=None): case=case, augmented_case=augmented_case ) ), + "context_map": case_aug_dir.joinpath( + self.augmented_context_map_glob.format( + case=case, augmented_case=augmented_case + ) + ), "observers": { observer: { structure: case_aug_dir.joinpath( @@ -215,7 +241,7 @@ def setup(self, stage=None): case=case, augmented_case=augmented_case, structure=structure, - observer=observer + observer=observer, ) ) for structure in self.structures @@ -226,8 +252,8 @@ def setup(self, stage=None): for augmented_case in augmented_cases ] - # If observers_add is empty then just add one dummy observer since they are not using - # Multi observer data here + # If observers_add is empty then just add one dummy observer since they are not using + # Multi observer data here if len(self.observers_add) == 0: self.observers_add = ["X"] @@ -244,12 +270,15 @@ def setup(self, stage=None): { "id": case, "image": data_add_dir.joinpath(self.image_glob.format(case=case)), + "context_map": data_add_dir.joinpath( + self.context_map_glob.format(case=case) + ), "observers": { observer: { structure: data_add_dir.joinpath( self.label_add_glob.format( case=case, structure=structure, observer=observer - ) + ) ) for structure in self.structures } @@ -260,11 +289,9 @@ def setup(self, stage=None): ] for case in cases: - case_aug_dir = None for aug_add_dir in 
self.augmented_add_dirs: if Path(aug_add_dir.format(case=case)).exists(): - case_aug_dir = Path(aug_add_dir.format(case=case)) else: print(f"No dir {Path(aug_add_dir.format(case=case))}") @@ -274,7 +301,9 @@ def setup(self, stage=None): augmented_cases = [ p.name.replace(".nii.gz", "") - for p in case_aug_dir.glob(self.augmented_case_glob.format(case=case)) + for p in case_aug_dir.glob( + self.augmented_case_glob.format(case=case) + ) if not p.name.startswith(".") ] print(augmented_cases) @@ -287,6 +316,11 @@ def setup(self, stage=None): case=case, augmented_case=augmented_case ) ), + "context_map": case_aug_dir.joinpath( + self.augmented_context_map_glob.format( + case=case, augmented_case=augmented_case + ) + ), "observers": { observer: { structure: case_aug_dir.joinpath( @@ -294,7 +328,7 @@ def setup(self, stage=None): case=case, augmented_case=augmented_case, structure=structure, - observer=observer + observer=observer, ) ) for structure in self.structures @@ -349,7 +383,9 @@ def setup(self, stage=None): crop_to_grid_size = None localise_model_path = None if self.crop_using_localise_model: - localise_model_path = Path(self.crop_using_localise_model.format(fold=self.fold)) + localise_model_path = Path( + self.crop_using_localise_model.format(fold=self.fold) + ) if localise_model_path.is_dir(): localise_model_path = next(localise_model_path.glob("*.ckpt")) @@ -360,6 +396,10 @@ def setup(self, stage=None): augment_on_fly = self.augment_on_fly + use_context_map = False + if self.input_channels > 1: + use_context_map = True + self.training_set = NiftiDataset( train_data, self.working_dir, @@ -371,6 +411,7 @@ def setup(self, stage=None): intensity_scaling=self.intensity_scaling, intensity_window=self.intensity_window, ndims=self.ndims, + use_context_map=use_context_map, ) self.validation_set = NiftiDataset( self.validation_data, @@ -383,6 +424,7 @@ def setup(self, stage=None): intensity_scaling=self.intensity_scaling, intensity_window=self.intensity_window, 
ndims=self.ndims, + use_context_map=use_context_map, ) self.test_set = NiftiDataset( self.test_data, @@ -395,6 +437,7 @@ def setup(self, stage=None): intensity_scaling=self.intensity_scaling, intensity_window=self.intensity_window, ndims=self.ndims, + use_context_map=use_context_map, ) def train_dataloader(self): diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 4dd7b0c2..bd21929e 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -30,16 +30,20 @@ from scipy.ndimage import affine_transform from scipy.ndimage.filters import gaussian_filter, median_filter -from platipy.imaging.cnn.utils import preprocess_image, resample_mask_to_image, get_contour_mask +from platipy.imaging.cnn.utils import ( + preprocess_image, + resample_mask_to_image, + get_contour_mask, +) from platipy.imaging.label.utils import get_union_mask, get_intersection_mask from platipy.imaging.cnn.localise_net import LocaliseUNet from platipy.imaging.utils.crop import label_to_roi, crop_to_roi logger = logging.getLogger(__name__) + class GaussianNoise: def __init__(self, mu=0.0, sigma=0.0, probability=1.0): - self.mu = mu self.sigma = sigma self.probability = probability @@ -51,7 +55,6 @@ def __init__(self, mu=0.0, sigma=0.0, probability=1.0): self.sigma = (self.sigma,) * 2 def apply(self, img, masks=[]): - if random.random() > self.probability: # Don't augment this time return img, masks @@ -65,7 +68,6 @@ def apply(self, img, masks=[]): class GaussianBlur: def __init__(self, sigma=0.0, probability=1.0): - self.sigma = sigma self.probability = probability @@ -73,7 +75,6 @@ def __init__(self, sigma=0.0, probability=1.0): self.sigma = (self.sigma,) * 2 def apply(self, img, masks=[]): - if random.random() > self.probability: # Don't augment this time return img, masks @@ -85,7 +86,6 @@ def apply(self, img, masks=[]): class MedianBlur: def __init__(self, size=1.0, probability=1.0): - self.size = size self.probability = probability @@ -93,7 
+93,6 @@ def __init__(self, size=1.0, probability=1.0): self.size = (self.size,) * 2 def apply(self, img, masks=[]): - if random.random() > self.probability: # Don't augment this time return img, masks @@ -117,7 +116,6 @@ def __init__( cval=-1, probability=1.0, ): - self.scale = scale self.translate_percent = translate_percent self.rotate = rotate @@ -176,7 +174,6 @@ def get_rot(self, theta, d): ) def get_shear(self, shear): - mat = np.identity(4) mat[0, 1] = shear[1] mat[0, 2] = shear[2] @@ -188,7 +185,6 @@ def get_shear(self, shear): return mat def apply(self, img, masks=[]): - if random.random() > self.probability: # Don't augment this time return img, masks @@ -300,7 +296,6 @@ def prepare_3d_transforms(): def prepare_transforms(): - sometimes = lambda aug: iaa.Sometimes(0.5, aug) seq = iaa.Sequential( @@ -351,6 +346,7 @@ def __init__( combine_observers=None, intensity_scaling="window", intensity_window=[-500, 500], + use_context_map=False, ndims=2, ): """Prepare a dataset from Nifti images/labels @@ -376,21 +372,26 @@ def __init__( self.img_dir = working_dir.joinpath("img") self.label_dir = working_dir.joinpath("label") self.contour_mask_dir = working_dir.joinpath("contour_mask") + self.context_map_dir = working_dir.joinpath("context_map") self.img_dir.mkdir(exist_ok=True, parents=True) self.label_dir.mkdir(exist_ok=True, parents=True) self.contour_mask_dir.mkdir(exist_ok=True, parents=True) + self.context_map_dir.mkdir(exist_ok=True, parents=True) for case in data: case_id = case["id"] img_path = str(case["image"]) + if use_context_map: + context_map_path = str(case["context_map"]) + existing_images = [i for i in self.img_dir.glob(f"{case_id}_*.npy")] if len(existing_images) > 0: logger.debug(f"Image for case already exist: {case_id}") for img_path in existing_images: - z_matches = re.findall(fr"{case_id}_([0-9]*)\.npy", img_path.name) + z_matches = re.findall(rf"{case_id}_([0-9]*)\.npy", img_path.name) if len(z_matches) == 0: continue z_slice = 
int(z_matches[0]) @@ -398,8 +399,14 @@ def __init__( img_file = self.img_dir.joinpath(f"{case_id}_{z_slice}.npy") assert img_file.exists() - for obs in case["observers"]: + cmap_file = None + if use_context_map: + cmap_file = self.context_map_dir.joinpath( + f"{case_id}_{z_slice}.npy" + ) + assert cmap_file.exists() + for obs in case["observers"]: labels = [] contour_mask_files = [] for structure in case["observers"][obs]: @@ -423,6 +430,7 @@ def __init__( "image": img_file, "labels": labels, "contour_masks": contour_mask_files, + "context_map": cmap_file, "case": case_id, "observer": obs, } @@ -433,6 +441,9 @@ def __init__( logger.debug(f"Generating images for case: {case_id}") img = sitk.ReadImage(img_path) + if use_context_map: + context_map = sitk.ReadImage(context_map_path) + if crop_using_localise_model: img = crop_img_using_localise_model( img, @@ -449,6 +460,9 @@ def __init__( intensity_window=intensity_window, ) + if use_context_map: + context_map = resample_mask_to_image(img, context_map) + observers = {} structure_names = [] for obs in case["observers"]: @@ -509,16 +523,26 @@ def __init__( if ndims == 3: z_range = range(1) for z_slice in z_range: - # Save the image slice if ndims == 2: img_slice = img[:, :, z_slice] + + if use_context_map: + cmap_slice = context_map[:, :, z_slice] else: img_slice = img + if use_context_map: + cmap_slice = context_map img_file = self.img_dir.joinpath(f"{case_id}_{z_slice}.npy") np.save(img_file, sitk.GetArrayFromImage(img_slice)) + if use_context_map: + cmap_file = self.context_map_dir.joinpath( + f"{case_id}_{z_slice}.npy" + ) + np.save(cmap_file, sitk.GetArrayFromImage(cmap_slice)) + # Save the contour mask slice cmasks = [] for structure in structure_names: @@ -529,14 +553,14 @@ def __init__( contour_mask_file = self.contour_mask_dir.joinpath( f"{case_id}_{structure}_{z_slice}.npy" ) - np.save(contour_mask_file, sitk.GetArrayFromImage(contour_mask_slice)) + np.save( + contour_mask_file, 
sitk.GetArrayFromImage(contour_mask_slice) + ) cmasks.append(contour_mask_file) for obs in observers: - labels = [] for structure in structure_names: - if observers[obs][structure] is None: labels.append(None) continue @@ -547,7 +571,10 @@ def __init__( label_file = self.label_dir.joinpath( f"{case_id}_{structure}_{obs}_{z_slice}.npy" ) - np.save(label_file, sitk.GetArrayFromImage(label_slice).astype(np.int8)) + np.save( + label_file, + sitk.GetArrayFromImage(label_slice).astype(np.int8), + ) labels.append(label_file) self.slices.append( @@ -556,6 +583,7 @@ def __init__( "image": img_file, "labels": labels, "contour_masks": cmasks, + "context_map": cmap_file, "case": case_id, "observer": obs, } @@ -565,16 +593,20 @@ def __len__(self): return len(self.slices) def __getitem__(self, index): - img = np.load(self.slices[index]["image"]) labels = [ np.load(label_file) if label_file else np.zeros(img.shape, dtype=np.ushort) for label_file in self.slices[index]["labels"] ] contour_masks = [ - np.load(contour_mask_file) for contour_mask_file in self.slices[index]["contour_masks"] + np.load(contour_mask_file) + for contour_mask_file in self.slices[index]["contour_masks"] ] + context_map = None + if self.slices[index]["context_map"]: + context_map = np.load(self.slices[index]["context_map"]) + if self.transforms: masks = labels + contour_masks if self.ndims == 2: @@ -583,7 +615,9 @@ def __getitem__(self, index): img, seg = self.transforms(image=img, segmentation_maps=segmap) for idx, _ in enumerate(labels): labels[idx] = seg.get_arr()[:, :, idx].squeeze() - contour_masks[idx] = seg.get_arr()[:, :, len(labels) + idx].squeeze() + contour_masks[idx] = seg.get_arr()[ + :, :, len(labels) + idx + ].squeeze() else: for aug in self.transforms: img, masks = aug.apply(img, masks) @@ -591,7 +625,15 @@ def __getitem__(self, index): contour_masks = masks[len(contour_masks) :] img = torch.FloatTensor(img) - label = torch.FloatTensor(np.concatenate([np.expand_dims(l, 0) for l in labels], 
0)) + img = img.unsqueeze(0) + + if context_map: + context_map = torch.FloatTensor(context_map) + context_map = context_map.unqsqueeze(0) + + label = torch.FloatTensor( + np.concatenate([np.expand_dims(l, 0) for l in labels], 0) + ) contour_mask = torch.FloatTensor( np.concatenate([np.expand_dims(l, 0) for l in contour_masks], 0) ) @@ -599,7 +641,8 @@ def __getitem__(self, index): label_present = [label is not None for label in self.slices[index]["labels"]] return ( - img.unsqueeze(0), + img, + context_map, label, contour_mask, { diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 4ca53886..da1c173b 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -46,9 +46,9 @@ from platipy.imaging import ImageVisualiser from platipy.imaging.label.utils import get_com, get_union_mask, get_intersection_mask + class GECOEarlyStopping(EarlyStopping): def on_validation_end(self, trainer, pl_module): - # Make sure the GECO lambda metrics are below 0.1 before stopping logs = trainer.callback_metrics should_consider_early_stop = True @@ -65,6 +65,7 @@ def on_validation_end(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module): pass + class ProbUNet(pl.LightningModule): def __init__( self, @@ -91,20 +92,23 @@ def __init__( } loss_params["top_k_percentage"] = self.hparams.top_k_percentage - loss_params["contour_loss_lambda_threshold"] = self.hparams.contour_loss_lambda_threshold + loss_params[ + "contour_loss_lambda_threshold" + ] = self.hparams.contour_loss_lambda_threshold loss_params["contour_loss_weight"] = self.hparams.contour_loss_weight if self.hparams.prob_type == "prob": self.prob_unet = ProbabilisticUnet( self.hparams.input_channels, - len(self.hparams.structures) + 1, # Add 1 to num classes for background class + len(self.hparams.structures) + + 1, # Add 1 to num classes for background class self.hparams.filters_per_layer, self.hparams.latent_dim, self.hparams.no_convs_fcomb, self.hparams.loss_type, 
loss_params, self.hparams.ndims, - dropout_probability=self.hparams.dropout_probability + dropout_probability=self.hparams.dropout_probability, ) elif self.hparams.prob_type == "hierarchical": raise NotImplementedError("Hierarchical Prob UNet current not working...") @@ -134,9 +138,14 @@ def add_model_specific_args(parent_parser): parser.add_argument("--lr_lambda", type=float, default=0.99) parser.add_argument("--input_channels", type=int, default=1) parser.add_argument( - "--filters_per_layer", nargs="+", type=int, default=[64 * (2**x) for x in range(5)] + "--filters_per_layer", + nargs="+", + type=int, + default=[64 * (2**x) for x in range(5)], + ) + parser.add_argument( + "--down_channels_per_block", nargs="+", type=int, default=None ) - parser.add_argument("--down_channels_per_block", nargs="+", type=int, default=None) parser.add_argument("--latent_dim", type=int, default=6) parser.add_argument("--no_convs_fcomb", type=int, default=4) parser.add_argument("--convs_per_block", type=int, default=2) @@ -148,10 +157,14 @@ def add_model_specific_args(parent_parser): parser.add_argument("--rec_geco_step_size", type=float, default=1e-2) parser.add_argument("--weight_decay", type=float, default=1e-2) parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) - parser.add_argument("--clamp_contour", nargs="+", type=float, default=[1e-3, 1e3]) + parser.add_argument( + "--clamp_contour", nargs="+", type=float, default=[1e-3, 1e3] + ) parser.add_argument("--top_k_percentage", type=float, default=None) parser.add_argument("--contour_loss_lambda_threshold", type=float, default=None) - parser.add_argument("--contour_loss_weight", type=float, default=0.0) # no longer used + parser.add_argument( + "--contour_loss_weight", type=float, default=0.0 + ) # no longer used parser.add_argument("--epochs_all_rec", type=int, default=0) # no longer used parser.add_argument("--dropout_probability", type=float, default=0.0) @@ -162,29 +175,34 @@ def forward(self, x): 
return x def configure_optimizers(self): + params = [ + { + "params": self.prob_unet.unet.parameters(), + "weight_decay": self.hparams.weight_decay, + "lr": 1e-4, + } + ] + for m in [ + self.prob_unet.prior.parameters(), + self.prob_unet.posterior.parameters(), + self.prob_unet.fcomb.parameters(), + ]: + params += [ + {"params": m, "weight_decay": self.hparams.weight_decay, "lr": 1e-5} + ] - params = [{ - 'params': self.prob_unet.unet.parameters(), - 'weight_decay': self.hparams.weight_decay, - 'lr': 1e-4 - }] - for m in [self.prob_unet.prior.parameters(), self.prob_unet.posterior.parameters(), self.prob_unet.fcomb.parameters()]: - params += [{'params': m, 'weight_decay': self.hparams.weight_decay, 'lr': 1e-5}] - - optimizer = torch.optim.Adam( - params - ) + optimizer = torch.optim.Adam(params) lr_lambda_unet = lambda epoch: self.hparams.lr_lambda ** (epoch) lr_lambda_prob = lambda epoch: 0.99 ** (epoch) -# max_epochs = self.hparams.max_epochs -# lr_lambda = lambda x: np.interp(((np.sin(x/(max_epochs/8)) * np.sin(x/(max_epochs/4)))), np.array([-1,0,1]), np.array([0.1,1,10])) + # max_epochs = self.hparams.max_epochs + # lr_lambda = lambda x: np.interp(((np.sin(x/(max_epochs/8)) * np.sin(x/(max_epochs/4)))), np.array([-1,0,1]), np.array([0.1,1,10])) - #scheduler = torch.optim.lr_scheduler.LambdaLR( + # scheduler = torch.optim.lr_scheduler.LambdaLR( # optimizer, lr_lambda=[lr_lambda_unet, lr_lambda_prob, lr_lambda_prob, lr_lambda_prob] - #) - #scheduler = torch.optim.lr_scheduler.CyclicLR( + # ) + # scheduler = torch.optim.lr_scheduler.CyclicLR( # optimizer, # base_lr=self.hparams.learning_rate / 10, # max_lr=self.hparams.learning_rate, @@ -194,10 +212,7 @@ def configure_optimizers(self): # cycle_momentum=False # ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, - 50, - eta_min=1e-6, - verbose=True + optimizer, 50, eta_min=1e-6, verbose=True ) return [optimizer], [scheduler] @@ -206,8 +221,12 @@ def configure_optimizers(self): "optimizer": 
optimizer, "lr_scheduler": { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, "max", patience=20, threshold=0.1e-2, factor=0.75 -# optimizer, "max", patience=200, threshold=0.75, factor=0.5 + optimizer, + "max", + patience=20, + threshold=0.1e-2, + factor=0.75 + # optimizer, "max", patience=200, threshold=0.75, factor=0.5 ), "monitor": "probabilisticDice", }, @@ -216,6 +235,7 @@ def configure_optimizers(self): def infer( self, img, + context_map=None, num_samples=1, sample_strategy="mean", latent_dim=True, @@ -233,7 +253,9 @@ def infer( samples = [ { "name": "mean", - "std_dev_from_mean": torch.Tensor([0.0] * len(latent_dim)).to(self.device), + "std_dev_from_mean": torch.Tensor([0.0] * len(latent_dim)).to( + self.device + ), "preds": [], } ] @@ -242,7 +264,10 @@ def infer( { "name": f"random_{i}", "std_dev_from_mean": torch.Tensor( - [np.random.normal(0, 1.0, 1)[0] if d else 0.0 for d in latent_dim] + [ + np.random.normal(0, 1.0, 1)[0] if d else 0.0 + for d in latent_dim + ] ).to(self.device), "preds": [], } @@ -254,16 +279,15 @@ def infer( samples = [ { "name": f"spaced_{s:.2f}", - "std_dev_from_mean": torch.Tensor([s if d else 0.0 for d in latent_dim]).to( - self.device - ), + "std_dev_from_mean": torch.Tensor( + [s if d else 0.0 for d in latent_dim] + ).to(self.device), "preds": [], } for s in np.linspace(spaced_range[0], spaced_range[1], num_samples) ] with torch.no_grad(): - if preprocess: if self.hparams.crop_using_localise_model: localise_path = self.hparams.crop_using_localise_model.format( @@ -286,21 +310,36 @@ def infer( img_arr = sitk.GetArrayFromImage(img) + if context_map is not None: + context_map = resample_mask_to_image(img, context_map) + cmap_arr = sitk.GetArrayFromImage(img) + if self.hparams.ndims == 2: slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])] + + if context_map is not None: + cmap_slices = [cmap_arr[z, :, :] for z in range(cmap_arr.shape[0])] else: slices = [img_arr] - for i in slices: + if context_map 
is not None: + cmap_slices = [cmap_arr] + for idx, i in enumerate(slices): x = torch.Tensor(i).to(self.device) x = x.unsqueeze(0) x = x.unsqueeze(0) + if context_map is not None: + c = torch.Tensor(cmap_slices[idx]).to(self.device) + c = c.unsqueeze(0) + c = c.unsqueeze(0) + + x = torch.cat((x, c), dim=1) + if self.hparams.prob_type == "prob": self.prob_unet.forward(x) for sample in samples: - if self.hparams.prob_type == "prob": if sample["name"] == "mean": y = self.prob_unet.sample(testing=True, use_mean=True) @@ -327,7 +366,6 @@ def infer( result = {} for sample in samples: - pred_arr = sample["preds"][0] if self.hparams.ndims == 2: @@ -344,16 +382,23 @@ def infer( pred.CopyInformation(img) pred = postprocess_mask(pred) - pred = sitk.Resample(pred, img, sitk.Transform(), sitk.sitkNearestNeighbor) + pred = sitk.Resample( + pred, img, sitk.Transform(), sitk.sitkNearestNeighbor + ) result[sample["name"]][structure] = pred return result def validate( - self, img, manual_observers, samples, mean, matching_type="best", window=[-0.5, 1.0] + self, + img, + manual_observers, + samples, + mean, + matching_type="best", + window=[-0.5, 1.0], ): - metrics = {"DSC": "max", "HD": "min", "ASD": "min"} result = {} @@ -369,11 +414,12 @@ def validate( mean_contours = {} for idx, structure in enumerate(structures): - color_map = plt.cm.get_cmap(contour_cmaps[idx % len(structures)]) mean_contours[f"mean_{structure}"] = mean["mean"][structure] - vis.add_contour(mean_contours, color=color_map(0.35), linewidth=3, show_legend=False) + vis.add_contour( + mean_contours, color=color_map(0.35), linewidth=3, show_legend=False + ) manual_color = color_map(0.9) @@ -383,7 +429,10 @@ def validate( } vis.add_contour( - manual_observers_struct, color=manual_color, linewidth=0.5, show_legend=False + manual_observers_struct, + color=manual_color, + linewidth=0.5, + show_legend=False, ) intersection_mask = get_intersection_mask(manual_observers_struct) @@ -395,7 +444,9 @@ def validate( 
color=manual_color, linewidth=3, ) - vis.add_contour(union_mask, name=f"union_{structure}", color=manual_color, linewidth=3) + vis.add_contour( + union_mask, name=f"union_{structure}", color=manual_color, linewidth=3 + ) samples_struct = { f"{sample_struct}_{structure}": samples[sample_struct][structure] @@ -407,7 +458,8 @@ def validate( color={ s: c for s, c in zip( - samples_struct, color_map(np.linspace(0.1, 0.7, len(samples_struct))) + samples_struct, + color_map(np.linspace(0.1, 0.7, len(samples_struct))), ) }, ) @@ -415,10 +467,12 @@ def validate( # vis.set_limits_from_label(union_mask, expansion=30) sim = { - k: np.zeros((len(samples_struct), len(manual_observers_struct))) for k in metrics + k: np.zeros((len(samples_struct), len(manual_observers_struct))) + for k in metrics } msim = { - k: np.zeros((len(samples_struct), len(manual_observers_struct))) for k in metrics + k: np.zeros((len(samples_struct), len(manual_observers_struct))) + for k in metrics } for sid, samp in enumerate(samples_struct): for oid, obs in enumerate(manual_observers_struct): @@ -436,7 +490,6 @@ def validate( result[f"probnet_{structure}"] = {k: [] for k in metrics} result[f"unet_{structure}"] = {k: [] for k in metrics} for k in sim: - val = sim[k] if matching_type == "hungarian": if metrics[k] == "max": @@ -468,14 +521,17 @@ def validate( return result, fig def training_step(self, batch, _): - - x, y, m, _ = batch + x, c, y, m, _ = batch # Add background layer for one-hot encoding not_y = 1 - y.max(axis=1).values not_y = torch.unsqueeze(not_y, dim=1) y = torch.cat((not_y, y), dim=1).float() + # Concat context map to image if we have one + if c is not None: + x = torch.cat((x, c), dim=1) + # self.prob_unet.forward(x, y, training=True) if self.hparams.prob_type == "prob": self.prob_unet.forward(x, y, training=True) @@ -522,7 +578,6 @@ def training_step(self, batch, _): return training_loss def validation_step(self, batch, _): - if self.validation_directory is None: 
self.validation_directory = Path(tempfile.mkdtemp()) @@ -530,16 +585,21 @@ def validation_step(self, batch, _): m = self.hparams.num_observers with torch.set_grad_enabled(False): - x, y, _, info = batch + x, c, y, _, info = batch # Save off slices/volumes for analysis of entire structure in end of validation step for s in range(y.shape[0]): - img_file = self.validation_directory.joinpath( f"img_{info['case'][s]}_{info['z'][s]}.npy" ) np.save(img_file, x[s].squeeze(0).cpu().numpy()) + if c is not None: + cmap_file = self.validation_directory.joinpath( + f"cmap_{info['case'][s]}_{info['z'][s]}.npy" + ) + np.save(cmap_file, c[s].squeeze(0).cpu().numpy()) + mask_file = self.validation_directory.joinpath( f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" ) @@ -554,8 +614,18 @@ def validation_step(self, batch, _): if self.hparams.ndims == 2: x = x.repeat(m, 1, 1, 1) + + if c is not None: + c = c.repeat(m, 1, 1, 1) else: x = x.repeat(m, 1, 1, 1, 1) + + if c is not None: + c = c.repeat(m, 1, 1, 1, 1) + + if c is not None: + x = torch.cat((x, c), dim=1) + self.prob_unet.forward(x) py = self.prob_unet.sample(testing=True) @@ -569,7 +639,6 @@ def validation_step(self, batch, _): y = y.int() y = y.to("cpu") - # TODO Make this work for multi class # Intersection over Union (also known as Jaccard Index) jaccard = JaccardIndex(num_classes=2) @@ -632,34 +701,35 @@ def validation_step(self, batch, _): return info def validation_epoch_end(self, validation_step_outputs): - cases = {} for info in validation_step_outputs: - for case, z, observer in zip(info["case"], info["z"], info["observer"]): if not case in cases: cases[case] = {"slices": z.item(), "observers": [observer]} else: if z.item() > cases[case]["slices"]: - cases[case]["slices"] = z.item() if not observer in cases[case]["observers"]: cases[case]["observers"].append(observer) metrics = ["DSC", "HD", "ASD"] computed_metrics = { - **{f"probnet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, + 
**{ + f"probnet_{s}_{m}": [] for m in metrics for s in self.hparams.structures + }, **{f"unet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, } - if len(cases) == 0: return + if len(cases) == 0: + return prob_surface_dice = 0 prob_dice = 0 for case in cases: - img_arrs = [] + cmap_arrs = [] + cmap_arr = None slices = [] if self.hparams.ndims == 2: @@ -669,17 +739,44 @@ def validation_epoch_end(self, validation_step_outputs): img_arrs.append(np.load(img_file)) slices.append(z) + cmap_file = self.validation_directory.joinpath( + f"cmap_{case}_{z}.npy" + ) + if cmap_file.exists(): + cmap_arrs.append(np.load(cmap_file)) + img_arr = np.stack(img_arrs) + if len(cmap_arrs) > 0: + cmap_arr = np.stack(cmap_arr) + else: img_file = self.validation_directory.joinpath(f"img_{case}_0.npy") img_arr = np.load(img_file) + + cmap_file = self.validation_directory.joinpath(f"cmap_{case}_0.npy") + if cmap_file.exists(): + cmap_arr = np.load(cmap_file) + img = sitk.GetImageFromArray(img_arr) img.SetSpacing(self.hparams.spacing) + + context_map = None + if cmap_arr: + context_map = sitk.GetImageFromArray(cmap_arr) + context_map.SetSpacing(self.hparams.spacing) + try: - mean = self.infer(img, num_samples=1, sample_strategy="mean", preprocess=False) + mean = self.infer( + img, + context_map=context_map, + num_samples=1, + sample_strategy="mean", + preprocess=False, + ) samples = self.infer( img, + context_map=context_map, sample_strategy="spaced", num_samples=5, spaced_range=[-2, 2], @@ -691,7 +788,6 @@ def validation_epoch_end(self, validation_step_outputs): observers = {} for _, observer in enumerate(cases[case]["observers"]): - if self.hparams.ndims == 2: mask_arrs = [] for z in slices: @@ -717,7 +813,9 @@ def validation_epoch_end(self, validation_step_outputs): observers[f"manual_{observer}"][structure] = mask # try: - result, fig = self.validate(img, observers, samples, mean, matching_type="best") + result, fig = self.validate( + img, observers, samples, mean, 
matching_type="best" + ) # except Exception as e: # print(f"ERROR DURING VALIDATION VALIDATE: {e}") # return @@ -738,7 +836,6 @@ def validation_epoch_end(self, validation_step_outputs): # Compute the probabilistic (surface) dice for idx, structure in enumerate(self.hparams.structures): - gt_labels = [] for _, observer in enumerate(cases[case]["observers"]): gt_labels.append(observers[f"manual_{observer}"][structure]) @@ -747,13 +844,16 @@ def validation_epoch_end(self, validation_step_outputs): for rk in samples: sample_labels.append(samples[rk][structure]) - prob_dice += probabilistic_dice(gt_labels, sample_labels, dsc_type="dsc") + prob_dice += probabilistic_dice( + gt_labels, sample_labels, dsc_type="dsc" + ) prob_surface_dice += probabilistic_dice( gt_labels, sample_labels, dsc_type="sdsc", tau=3 ) prob_dice = prob_dice / len(cases) - if np.isnan(prob_dice): prob_dice = 0 + if np.isnan(prob_dice): + prob_dice = 0 self.log( "probabilisticDice", prob_dice, @@ -764,7 +864,8 @@ def validation_epoch_end(self, validation_step_outputs): ) prob_surface_dice = prob_surface_dice / len(cases) - if np.isnan(prob_surface_dice): prob_surface_dice = 0 + if np.isnan(prob_surface_dice): + prob_surface_dice = 0 self.log( "probabilisticSurfaceDice", prob_surface_dice, @@ -798,7 +899,6 @@ def validation_epoch_end(self, validation_step_outputs): def main(args, config_json_path=None): - pl.seed_everything(args.seed, workers=True) args.working_dir = Path(args.working_dir) @@ -807,7 +907,7 @@ def main(args, config_json_path=None): args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") args.default_root_dir = str(args.fold_dir) args.accumulate_grad_batches = {0: 5, 5: 10, 10: 15} -# args.precision = 16 + # args.precision = 16 comet_api_key = None comet_workspace = None @@ -860,7 +960,7 @@ def main(args, config_json_path=None): checkpoint_callback = ModelCheckpoint( monitor=args.checkpoint_var, dirpath=args.default_root_dir, - 
filename="probunet-{epoch:02d}-{"+args.checkpoint_var+":.2f}", + filename="probunet-{epoch:02d}-{" + args.checkpoint_var + ":.2f}", save_top_k=1, mode=args.checkpoint_mode, ) @@ -868,7 +968,11 @@ def main(args, config_json_path=None): if args.early_stopping_var: early_stop_callback = GECOEarlyStopping( - monitor=args.early_stopping_var, min_delta=args.early_stopping_min_delta, patience=args.early_stopping_patience, verbose=True, mode=args.early_stopping_mode + monitor=args.early_stopping_var, + min_delta=args.early_stopping_min_delta, + patience=args.early_stopping_patience, + verbose=True, + mode=args.early_stopping_mode, ) trainer.callbacks.append(early_stop_callback) @@ -876,7 +980,6 @@ def main(args, config_json_path=None): def parse_config_file(config_json_path, args): - with open(config_json_path, "r") as f: params = json.load(f) for key in params: @@ -888,11 +991,10 @@ def parse_config_file(config_json_path, args): else: args.append(str(params[key])) - return args -if __name__ == "__main__": +if __name__ == "__main__": args = None config_json_path = None if len(sys.argv) == 2: @@ -908,8 +1010,12 @@ def parse_config_file(config_json_path, args): arg_parser.add_argument( "--config", type=str, default=None, help="JSON file with parameters to load" ) - arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") - arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") + arg_parser.add_argument( + "--seed", type=int, default=42, help="an integer to use as seed" + ) + arg_parser.add_argument( + "--experiment", type=str, default="default", help="Name of experiment" + ) arg_parser.add_argument("--working_dir", type=str, default="./working") arg_parser.add_argument("--num_observers", type=int, default=5) arg_parser.add_argument("--spacing", nargs="+", type=float, default=[1, 1, 1]) From d95b1c33e730f14c962fd3d45963dd997f183e44 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 4 Dec 2023 
12:00:57 +1100 Subject: [PATCH 218/264] Support context map for augmentation --- platipy/imaging/generation/augment.py | 135 ++++++++++++++++++-------- 1 file changed, 94 insertions(+), 41 deletions(-) diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index 46332692..bb4b4484 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -44,8 +44,8 @@ logger = logging.getLogger(__name__) -def apply_augmentation(image, augmentation, masks=[]): +def apply_augmentation(image, augmentation, context_map=None, masks=[]): if not isinstance(image, sitk.Image): raise AttributeError("image should be a SimpleITK.Image") @@ -62,7 +62,6 @@ def apply_augmentation(image, augmentation, masks=[]): transform = None dvf = None for aug in augmentation: - if not isinstance(aug, DeformableAugment): raise AttributeError("Each augmentation must be of type DeformableAugment") @@ -90,21 +89,32 @@ def apply_augmentation(image, augmentation, masks=[]): masks_deformed = [] for mask in masks: def_mask = apply_transform( - mask, transform=transform, default_value=0, interpolator=sitk.sitkNearestNeighbor + mask, + transform=transform, + default_value=0, + interpolator=sitk.sitkNearestNeighbor, ) def_mask = sitk.BinaryMorphologicalClosing(def_mask, [3, 3, 3]) masks_deformed.append(def_mask) + cmap_deformed = None + if context_map is not None: + cmap_deformed = apply_transform( + context_map, + transform=transform, + default_value=0, + interpolator=sitk.sitkNearestNeighbor, + ) + if masks: - return image_deformed, masks_deformed, dvf + return image_deformed, cmap_deformed, masks_deformed, dvf - return image_deformed, dvf + return image_deformed, cmap_deformed, dvf def generate_random_augmentation(ct_image, masks, augmentation_types): - augmentation = [] probabilities = [a["probability"] for a in augmentation_types] @@ -114,7 +124,9 @@ def generate_random_augmentation(ct_image, masks, augmentation_types): prob_none = 0 for 
mask in masks: - aug = random.choices(augmentation_types + [None], weights=probabilities + [prob_none])[0] + aug = random.choices( + augmentation_types + [None], weights=probabilities + [prob_none] + )[0] if aug is None: continue @@ -122,10 +134,8 @@ def generate_random_augmentation(ct_image, masks, augmentation_types): aug_class = aug["class"] aug_args = {} for arg in aug["args"]: - value = aug["args"][arg] if isinstance(value, list): - # Randomly sample for each dim result = [] for rng in value: @@ -146,20 +156,17 @@ def generate_random_augmentation(ct_image, masks, augmentation_types): class DeformableAugment(ABC): @abstractmethod def augment(self): - # return deformation pass class ShiftAugment(DeformableAugment): def __init__(self, mask, vector_shift=(10, 10, 10), gaussian_smooth=5): - self.mask = mask self.vector_shift = vector_shift self.gaussian_smooth = gaussian_smooth def augment(self): - _, transform, dvf = generate_field_shift( self.mask, self.vector_shift, @@ -172,15 +179,15 @@ def __str__(self): class ExpandAugment(DeformableAugment): - def __init__(self, mask, vector_expand=(10, 10, 10), gaussian_smooth=5, bone_mask=False): - + def __init__( + self, mask, vector_expand=(10, 10, 10), gaussian_smooth=5, bone_mask=False + ): self.mask = mask self.vector_expand = vector_expand self.gaussian_smooth = gaussian_smooth self.bone_mask = bone_mask def augment(self): - _, transform, dvf = generate_field_expand( self.mask, bone_mask=self.bone_mask, @@ -191,19 +198,23 @@ def augment(self): return transform, dvf def __str__(self): - return f"Expand with vector: {self.vector_expand}, smooth: {self.gaussian_smooth}" + return ( + f"Expand with vector: {self.vector_expand}, smooth: {self.gaussian_smooth}" + ) class ContractAugment(DeformableAugment): - def __init__(self, mask, vector_contract=(10, 10, 10), gaussian_smooth=5, bone_mask=False): - + def __init__( + self, mask, vector_contract=(10, 10, 10), gaussian_smooth=5, bone_mask=False + ): self.mask = mask - 
self.vector_contract = [int(-x / s) for x, s in zip(vector_contract, mask.GetSpacing())] + self.vector_contract = [ + int(-x / s) for x, s in zip(vector_contract, mask.GetSpacing()) + ] self.gaussian_smooth = gaussian_smooth self.bone_mask = bone_mask def augment(self): - _, transform, dvf = generate_field_expand( self.mask, bone_mask=self.bone_mask, @@ -217,7 +228,6 @@ def __str__(self): def augment_data(args): - random.seed(args.seed) augmentation_types = [] @@ -285,23 +295,32 @@ def augment_data(args): data = { case: { "image": data_dir.joinpath(args.image_glob.format(case=case)), - "label": [i for sl in [list(data_dir.glob(lg.format(case=case))) for lg in args.label_glob] for i in sl], + "context_map": data_dir.joinpath(args.context_map_glob.format(case=case)), + "label": [ + i + for sl in [ + list(data_dir.glob(lg.format(case=case))) for lg in args.label_glob + ] + for i in sl + ], } for case in cases } for case in cases: - logger.info(f"Augmenting for case: {case}") ct_image_original = sitk.ReadImage(str(data[case]["image"])) + cmap_original = None + if data[case]["context_map"]: + cmap_original = sitk.ReadImage(str(data[case]["context_map"])) + # Get list of structures to generate augmentations off logger.debug("Collecting structures") all_masks = [] all_names = [] for structure_path in data[case]["label"]: - mask = sitk.ReadImage(str(structure_path)) all_masks.append(mask) @@ -316,24 +335,25 @@ def augment_data(args): all_masks[m] = crop_to_roi(mask, size, index) if args.enable_fill_holes: - logger.debug("Finding holes") label_image, labels = detect_holes(ct_image) # Generate x random augmentations per case for i in range(args.augmentations_per_case): - logger.debug(f"Generating augmentation {i}") ct_image = sitk.ReadImage(str(data[case]["image"])) ct_image = crop_to_roi(ct_image, size, index) - if args.enable_fill_holes: + cmap = None + if data[case]["context_map"]: + cmap = sitk.ReadImage(str(data[case]["context_map"])) + cmap = crop_to_roi(cmap, size, 
index) + if args.enable_fill_holes: logger.debug("Filling holes") for label in labels[1:]: # Skip first hole since likely air around body - if random.random() > args.fill_probability: continue @@ -357,7 +377,9 @@ def augment_data(args): augmented_case_path.mkdir(exist_ok=True, parents=True) logger.debug("Generating random augmentations") - augmentation = generate_random_augmentation(ct_image, all_masks, augmentation_types) + augmentation = generate_random_augmentation( + ct_image, all_masks, augmentation_types + ) dvf = None @@ -369,12 +391,17 @@ def augment_data(args): augmented_image = ct_image augmented_masks = all_masks else: - logger.debug("Applying augmentation") - augmented_image, augmented_masks, dvf = apply_augmentation( - ct_image, augmentation, masks=all_masks + ( + augmented_image, + augmented_cmap, + augmented_masks, + dvf, + ) = apply_augmentation( + ct_image, augmentation, context_map=cmap, masks=all_masks ) + # Save off image augmented_image_path = augmented_case_path.joinpath("CT.nii.gz") ct_image_original[ index[0] : index[0] + size[0], @@ -383,17 +410,35 @@ def augment_data(args): ] = augmented_image sitk.WriteImage(ct_image_original, str(augmented_image_path)) + # Save off context map if we have one + augmented_cmap_path = augmented_case_path.joinpath("context_map.nii.gz") + cmap_original[ + index[0] : index[0] + size[0], + index[1] : index[1] + size[1], + index[2] : index[2] + size[2], + ] = augmented_cmap + sitk.WriteImage(cmap_original, str(augmented_cmap_path)) + vis = ImageVisualiser(image=ct_image, figure_size_in=6) vis.add_comparison_overlay(augmented_image) if dvf is not None: vis.add_vector_overlay(dvf, arrow_scale=1, subsample=(4, 12, 12)) - for mask_name, mask, augmented_mask in zip(all_names, all_masks, augmented_masks): - vis.add_contour({f"{mask_name}": mask, f"{mask_name} (augmented)": augmented_mask}) + for mask_name, mask, augmented_mask in zip( + all_names, all_masks, augmented_masks + ): + vis.add_contour( + {f"{mask_name}": 
mask, f"{mask_name} (augmented)": augmented_mask} + ) logger.debug(f"Applying augmentation to mask: {mask_name}") - augmented_mask_path = augmented_case_path.joinpath(f"{mask_name}.nii.gz") + augmented_mask_path = augmented_case_path.joinpath( + f"{mask_name}.nii.gz" + ) augmented_mask = sitk.Resample( - augmented_mask, ct_image_original, sitk.Transform(), sitk.sitkNearestNeighbor + augmented_mask, + ct_image_original, + sitk.Transform(), + sitk.sitkNearestNeighbor, ) sitk.WriteImage(augmented_mask, str(augmented_mask_path)) @@ -405,14 +450,18 @@ def augment_data(args): if __name__ == "__main__": - arg_parser = ArgumentParser() - arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") + arg_parser.add_argument( + "--seed", type=int, default=42, help="an integer to use as seed" + ) arg_parser.add_argument("--data_dir", type=str, default="./data") arg_parser.add_argument("--output_dir", type=str, default="./augment") arg_parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") arg_parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") - arg_parser.add_argument("--label_glob", nargs="+", type=str, default="labels/{case}_*.nii.gz") + arg_parser.add_argument( + "--label_glob", nargs="+", type=str, default="labels/{case}_*.nii.gz" + ) + arg_parser.add_argument("--context_map_glob", type=str, default=None) arg_parser.add_argument( "--augmentations_per_case", type=int, @@ -431,7 +480,9 @@ def augment_data(args): arg_parser.add_argument("--expand_x_range", nargs="+", type=int, default=[0, 10]) arg_parser.add_argument("--expand_y_range", nargs="+", type=int, default=[0, 10]) arg_parser.add_argument("--expand_z_range", nargs="+", type=int, default=[0, 10]) - arg_parser.add_argument("--expand_smooth_range", nargs="+", type=int, default=[3, 5]) + arg_parser.add_argument( + "--expand_smooth_range", nargs="+", type=int, default=[3, 5] + ) arg_parser.add_argument("--expand_bone_mask", type=bool, default=True) 
arg_parser.add_argument("--expand_probability", type=float, default=0.5) @@ -439,7 +490,9 @@ def augment_data(args): arg_parser.add_argument("--contract_x_range", nargs="+", type=int, default=[0, 10]) arg_parser.add_argument("--contract_y_range", nargs="+", type=int, default=[0, 10]) arg_parser.add_argument("--contract_z_range", nargs="+", type=int, default=[0, 10]) - arg_parser.add_argument("--contract_smooth_range", nargs="+", type=int, default=[3, 5]) + arg_parser.add_argument( + "--contract_smooth_range", nargs="+", type=int, default=[3, 5] + ) arg_parser.add_argument("--contract_bone_mask", type=bool, default=True) arg_parser.add_argument("--contract_probability", type=float, default=0.5) From dac6f701f8010a75c257402551dfa36ab4630e9b Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 4 Dec 2023 12:07:35 +1100 Subject: [PATCH 219/264] Fix bug --- platipy/imaging/generation/augment.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index bb4b4484..cfd98385 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -382,6 +382,7 @@ def augment_data(args): ) dvf = None + augmented_cmap = None if len(augmentation) == 0: logger.debug( @@ -411,13 +412,14 @@ def augment_data(args): sitk.WriteImage(ct_image_original, str(augmented_image_path)) # Save off context map if we have one - augmented_cmap_path = augmented_case_path.joinpath("context_map.nii.gz") - cmap_original[ - index[0] : index[0] + size[0], - index[1] : index[1] + size[1], - index[2] : index[2] + size[2], - ] = augmented_cmap - sitk.WriteImage(cmap_original, str(augmented_cmap_path)) + if augmented_cmap: + augmented_cmap_path = augmented_case_path.joinpath("context_map.nii.gz") + cmap_original[ + index[0] : index[0] + size[0], + index[1] : index[1] + size[1], + index[2] : index[2] + size[2], + ] = augmented_cmap + sitk.WriteImage(cmap_original, 
str(augmented_cmap_path)) vis = ImageVisualiser(image=ct_image, figure_size_in=6) vis.add_comparison_overlay(augmented_image) From 0de77d5242da36108e8e3613b7e18feea3c680f4 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 4 Dec 2023 12:28:17 +1100 Subject: [PATCH 220/264] augment log to std out --- platipy/imaging/generation/augment.py | 1 + 1 file changed, 1 insertion(+) diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index cfd98385..6a2a8b22 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -42,6 +42,7 @@ from platipy.imaging.label.utils import get_union_mask from platipy.imaging.utils.crop import label_to_roi, crop_to_roi +logging.basicConfig(level=logging.DEBUG, stream=sys.stdout) logger = logging.getLogger(__name__) From eeda24dbecd5a61dfd102339ceaedee9de9d4526 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 4 Dec 2023 12:30:13 +1100 Subject: [PATCH 221/264] Add missing import --- platipy/imaging/generation/augment.py | 1 + 1 file changed, 1 insertion(+) diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index 6a2a8b22..d8fca7ec 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -16,6 +16,7 @@ from collections.abc import Iterable import random import logging +import sys from pathlib import Path From fe45f6976105d5f26a1dd5c7add2b7a995551b3b Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 6 Dec 2023 16:33:41 +1100 Subject: [PATCH 222/264] Fix issues --- platipy/imaging/cnn/dataload.py | 8 ++++---- platipy/imaging/cnn/dataset.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index 98f62e79..a3335f4b 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -191,7 +191,7 @@ def setup(self, stage=None): { "id": case, "image": 
self.data_dir.joinpath(self.image_glob.format(case=case)), - "context_map": self.data_dir.joinpath( + "context_map": None if self.context_map_glob is None else self.data_dir.joinpath( self.context_map_glob.format(case=case) ), "observers": { @@ -229,7 +229,7 @@ def setup(self, stage=None): case=case, augmented_case=augmented_case ) ), - "context_map": case_aug_dir.joinpath( + "context_map": None if self.augmented_context_map_glob is None else case_aug_dir.joinpath( self.augmented_context_map_glob.format( case=case, augmented_case=augmented_case ) @@ -270,7 +270,7 @@ def setup(self, stage=None): { "id": case, "image": data_add_dir.joinpath(self.image_glob.format(case=case)), - "context_map": data_add_dir.joinpath( + "context_map": None if self.context_map_glob is None else data_add_dir.joinpath( self.context_map_glob.format(case=case) ), "observers": { @@ -316,7 +316,7 @@ def setup(self, stage=None): case=case, augmented_case=augmented_case ) ), - "context_map": case_aug_dir.joinpath( + "context_map": None if self.augmented_context_map_glob is None else case_aug_dir.joinpath( self.augmented_context_map_glob.format( case=case, augmented_case=augmented_case ) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index bd21929e..a935c313 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -382,6 +382,7 @@ def __init__( for case in data: case_id = case["id"] img_path = str(case["image"]) + cmap_file = None if use_context_map: context_map_path = str(case["context_map"]) From 11f2ccd642ec369d26650d75b55e3c28e4141eac Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Dec 2023 09:09:28 +1100 Subject: [PATCH 223/264] Add context map for validation and test cases --- platipy/imaging/cnn/dataload.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index a3335f4b..9f63b87d 100644 --- 
a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -191,9 +191,9 @@ def setup(self, stage=None): { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), - "context_map": None if self.context_map_glob is None else self.data_dir.joinpath( - self.context_map_glob.format(case=case) - ), + "context_map": None + if self.context_map_glob is None + else self.data_dir.joinpath(self.context_map_glob.format(case=case)), "observers": { observer: { structure: self.data_dir.joinpath( @@ -229,7 +229,9 @@ def setup(self, stage=None): case=case, augmented_case=augmented_case ) ), - "context_map": None if self.augmented_context_map_glob is None else case_aug_dir.joinpath( + "context_map": None + if self.augmented_context_map_glob is None + else case_aug_dir.joinpath( self.augmented_context_map_glob.format( case=case, augmented_case=augmented_case ) @@ -270,9 +272,9 @@ def setup(self, stage=None): { "id": case, "image": data_add_dir.joinpath(self.image_glob.format(case=case)), - "context_map": None if self.context_map_glob is None else data_add_dir.joinpath( - self.context_map_glob.format(case=case) - ), + "context_map": None + if self.context_map_glob is None + else data_add_dir.joinpath(self.context_map_glob.format(case=case)), "observers": { observer: { structure: data_add_dir.joinpath( @@ -316,7 +318,9 @@ def setup(self, stage=None): case=case, augmented_case=augmented_case ) ), - "context_map": None if self.augmented_context_map_glob is None else case_aug_dir.joinpath( + "context_map": None + if self.augmented_context_map_glob is None + else case_aug_dir.joinpath( self.augmented_context_map_glob.format( case=case, augmented_case=augmented_case ) @@ -345,6 +349,9 @@ def setup(self, stage=None): { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), + "context_map": self.data_dir.joinpath( + self.context_map_glob.format(case=case) + ), "observers": { observer: { structure: self.data_dir.joinpath( @@ 
-365,6 +372,9 @@ def setup(self, stage=None): { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), + "context_map": self.data_dir.joinpath( + self.context_map_glob.format(case=case) + ), "observers": { observer: { structure: self.data_dir.joinpath( From 604985a3fbbff28e7469a82dfd26e83d55b33149 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Dec 2023 09:27:07 +1100 Subject: [PATCH 224/264] Resolve ambigious truth value --- platipy/imaging/cnn/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index a935c313..fb06d086 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -628,7 +628,7 @@ def __getitem__(self, index): img = torch.FloatTensor(img) img = img.unsqueeze(0) - if context_map: + if context_map is not None: context_map = torch.FloatTensor(context_map) context_map = context_map.unqsqueeze(0) From c8ec57e255376c63cd8241a4a2de423af16a1dcc Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Dec 2023 09:28:43 +1100 Subject: [PATCH 225/264] Correct typo --- platipy/imaging/cnn/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index fb06d086..69c48136 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -630,7 +630,7 @@ def __getitem__(self, index): if context_map is not None: context_map = torch.FloatTensor(context_map) - context_map = context_map.unqsqueeze(0) + context_map = context_map.unsqueeze(0) label = torch.FloatTensor( np.concatenate([np.expand_dims(l, 0) for l in labels], 0) From 0ddb078cacdec0a3a5dd2e45d4b3acebdc457113 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Dec 2023 09:42:28 +1100 Subject: [PATCH 226/264] Add missing unsqueeze --- platipy/imaging/cnn/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git 
a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index da1c173b..bba93250 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -605,8 +605,10 @@ def validation_step(self, batch, _): ) np.save(mask_file, y[s].cpu().numpy()) - # Image will be same for all in batch + # Image (and context map) will be same for all in batch x = x[0].unsqueeze(0) + if c is not None: + c = c[0].unsqueeze(0) if self.hparams.ndims == 2: vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0]), axis="z") else: From acad3bb241e4e729d17e4d357c8d20eb45d19376 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Dec 2023 09:47:40 +1100 Subject: [PATCH 227/264] Resolve ambigous truth value --- platipy/imaging/cnn/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index bba93250..6291be93 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -764,7 +764,7 @@ def validation_epoch_end(self, validation_step_outputs): img.SetSpacing(self.hparams.spacing) context_map = None - if cmap_arr: + if cmap_arr is not None: context_map = sitk.GetImageFromArray(cmap_arr) context_map.SetSpacing(self.hparams.spacing) From 2e6adedb7a6bcb9e3f7f2f0ee787f8543039092c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Dec 2023 10:38:27 +1100 Subject: [PATCH 228/264] Add missing import --- platipy/imaging/cnn/train.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 6291be93..1453c92f 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -40,7 +40,12 @@ from platipy.imaging.cnn.unet import l2_regularisation from platipy.imaging.cnn.dataload import UNetDataModule from platipy.imaging.cnn.dataset import crop_img_using_localise_model -from platipy.imaging.cnn.utils import preprocess_image, postprocess_mask, get_metrics +from 
platipy.imaging.cnn.utils import ( + preprocess_image, + postprocess_mask, + get_metrics, + resample_mask_to_image, +) from platipy.imaging.cnn.metrics import probabilistic_dice from platipy.imaging import ImageVisualiser From 45e4bc1a7b3a19dcc5916755d17ba8f97bc6657d Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Dec 2023 10:42:19 +1100 Subject: [PATCH 229/264] update deprecated cmap call --- platipy/imaging/cnn/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 1453c92f..4bf667dc 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -419,7 +419,7 @@ def validate( mean_contours = {} for idx, structure in enumerate(structures): - color_map = plt.cm.get_cmap(contour_cmaps[idx % len(structures)]) + color_map = matplotlib.colormaps.get_cmap(contour_cmaps[idx % len(structures)]) mean_contours[f"mean_{structure}"] = mean["mean"][structure] vis.add_contour( @@ -690,7 +690,7 @@ def validation_step(self, batch, _): samp_pred = samp_pred.unsqueeze(0) contours[f"sample_{mm}"] = sitk.GetImageFromArray(samp_pred) - vis.add_contour(contours, colormap=plt.cm.get_cmap("cool")) + vis.add_contour(contours, colormap=matplotlib.colormaps.get_cmap("cool")) vis.show() figure_path = f"ged_{info['z'][s]}.png" From fb746485143342e26d52743183b8ec5a8c857071 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Dec 2023 10:45:05 +1100 Subject: [PATCH 230/264] Add missing import --- platipy/imaging/cnn/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 4bf667dc..ed255e8f 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -19,6 +19,7 @@ from argparse import ArgumentParser from pathlib import Path +import matplotlib import SimpleITK as sitk import numpy as np from scipy.optimize import linear_sum_assignment @@ -419,7 +420,9 @@ def validate( 
mean_contours = {} for idx, structure in enumerate(structures): - color_map = matplotlib.colormaps.get_cmap(contour_cmaps[idx % len(structures)]) + color_map = matplotlib.colormaps.get_cmap( + contour_cmaps[idx % len(structures)] + ) mean_contours[f"mean_{structure}"] = mean["mean"][structure] vis.add_contour( From 83550f424687230523c1d3e3ecf6cb285d16549a Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Dec 2023 10:46:27 +1100 Subject: [PATCH 231/264] Check none context map glob --- platipy/imaging/cnn/dataload.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py index 9f63b87d..bfcb4fab 100644 --- a/platipy/imaging/cnn/dataload.py +++ b/platipy/imaging/cnn/dataload.py @@ -349,9 +349,9 @@ def setup(self, stage=None): { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), - "context_map": self.data_dir.joinpath( - self.context_map_glob.format(case=case) - ), + "context_map": None + if self.context_map_glob is None + else self.data_dir.joinpath(self.context_map_glob.format(case=case)), "observers": { observer: { structure: self.data_dir.joinpath( @@ -372,9 +372,9 @@ def setup(self, stage=None): { "id": case, "image": self.data_dir.joinpath(self.image_glob.format(case=case)), - "context_map": self.data_dir.joinpath( - self.context_map_glob.format(case=case) - ), + "context_map": None + if self.context_map_glob is None + else self.data_dir.joinpath(self.context_map_glob.format(case=case)), "observers": { observer: { structure: self.data_dir.joinpath( From 5549beddfbc41f99d2e9011be30e9bcb248fa0ed Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Dec 2023 10:54:37 +1100 Subject: [PATCH 232/264] Deal with none cmap --- platipy/imaging/cnn/dataset.py | 4 ++-- platipy/imaging/cnn/train.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 
69c48136..96351a29 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -604,8 +604,8 @@ def __getitem__(self, index): for contour_mask_file in self.slices[index]["contour_masks"] ] - context_map = None - if self.slices[index]["context_map"]: + context_map = torch.Tensor() + if self.slices[index]["context_map"] is not None: context_map = np.load(self.slices[index]["context_map"]) if self.transforms: diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index ed255e8f..b230b3f1 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -537,7 +537,7 @@ def training_step(self, batch, _): y = torch.cat((not_y, y), dim=1).float() # Concat context map to image if we have one - if c is not None: + if c.numel() > 0: x = torch.cat((x, c), dim=1) # self.prob_unet.forward(x, y, training=True) @@ -602,7 +602,7 @@ def validation_step(self, batch, _): ) np.save(img_file, x[s].squeeze(0).cpu().numpy()) - if c is not None: + if c.numel() > 0: cmap_file = self.validation_directory.joinpath( f"cmap_{info['case'][s]}_{info['z'][s]}.npy" ) @@ -615,7 +615,7 @@ def validation_step(self, batch, _): # Image (and context map) will be same for all in batch x = x[0].unsqueeze(0) - if c is not None: + if c.numel() > 0: c = c[0].unsqueeze(0) if self.hparams.ndims == 2: vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0]), axis="z") @@ -625,15 +625,15 @@ def validation_step(self, batch, _): if self.hparams.ndims == 2: x = x.repeat(m, 1, 1, 1) - if c is not None: + if c.numel() > 0: c = c.repeat(m, 1, 1, 1) else: x = x.repeat(m, 1, 1, 1, 1) - if c is not None: + if c.numel() > 0: c = c.repeat(m, 1, 1, 1, 1) - if c is not None: + if c.numel() > 0: x = torch.cat((x, c), dim=1) self.prob_unet.forward(x) From 06249406a2b6fb348bff3348f7db55abeda8ccc0 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Dec 2023 11:30:26 +1100 Subject: [PATCH 233/264] Fix missing augmentations for context map --- 
platipy/imaging/cnn/dataset.py | 43 +++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 96351a29..3914cda3 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -54,16 +54,16 @@ def __init__(self, mu=0.0, sigma=0.0, probability=1.0): if not hasattr(self.sigma, "__iter__"): self.sigma = (self.sigma,) * 2 - def apply(self, img, masks=[]): + def apply(self, img, context_map, masks=[]): if random.random() > self.probability: # Don't augment this time - return img, masks + return img, context_map, masks mean = random.uniform(self.mu[0], self.mu[1]) sigma = random.uniform(self.sigma[0], self.sigma[1]) gaussian = np.random.normal(mean, sigma, img.shape) - return img + gaussian, masks + return img + gaussian, context_map, masks class GaussianBlur: @@ -74,14 +74,14 @@ def __init__(self, sigma=0.0, probability=1.0): if not hasattr(self.sigma, "__iter__"): self.sigma = (self.sigma,) * 2 - def apply(self, img, masks=[]): + def apply(self, img, context_map, masks=[]): if random.random() > self.probability: # Don't augment this time - return img, masks + return img, context_map, masks sigma = random.uniform(self.sigma[0], self.sigma[1]) - return gaussian_filter(img, sigma=sigma), masks + return gaussian_filter(img, sigma=sigma), context_map, masks class MedianBlur: @@ -92,14 +92,14 @@ def __init__(self, size=1.0, probability=1.0): if not hasattr(self.size, "__iter__"): self.size = (self.size,) * 2 - def apply(self, img, masks=[]): + def apply(self, img, context_map, masks=[]): if random.random() > self.probability: # Don't augment this time - return img, masks + return img, context_map, masks size = random.randint(self.size[0], self.size[1]) - return median_filter(img, size=size), masks + return median_filter(img, size=size), context_map, masks DIMS = ["ax", "cor", "sag"] @@ -184,10 +184,10 @@ def get_shear(self, shear): return mat - 
def apply(self, img, masks=[]): + def apply(self, img, context_map, masks=[]): if random.random() > self.probability: # Don't augment this time - return img, masks + return img, context_map, masks deg_to_rad = math.pi / 180 @@ -224,6 +224,8 @@ def apply(self, img, masks=[]): t = t * translation augmented_image = affine_transform(img, t, mode="mirror") + if context_map is not None: + augmented_context_map = affine_transform(context_map, t, mode="mirror") augmented_masks = [] for mask in masks: augmented_masks.append(affine_transform(mask, t, mode="nearest")) @@ -605,7 +607,9 @@ def __getitem__(self, index): ] context_map = torch.Tensor() + use_context = False if self.slices[index]["context_map"] is not None: + use_context = True context_map = np.load(self.slices[index]["context_map"]) if self.transforms: @@ -614,6 +618,12 @@ def __getitem__(self, index): seg_arr = np.concatenate([np.expand_dims(m, 2) for m in masks], 2) segmap = SegmentationMapsOnImage(seg_arr, shape=labels[0].shape) img, seg = self.transforms(image=img, segmentation_maps=segmap) + + # TODO Implement context map aug for 2D + if use_context: + raise NotImplementedError( + "WARNING!!! Augmentation for context map in 2D not yet implemented!" 
+ ) for idx, _ in enumerate(labels): labels[idx] = seg.get_arr()[:, :, idx].squeeze() contour_masks[idx] = seg.get_arr()[ @@ -621,7 +631,10 @@ def __getitem__(self, index): ].squeeze() else: for aug in self.transforms: - img, masks = aug.apply(img, masks) + if use_context: + img, context_map, masks = aug.apply(img, context_map, masks) + else: + img, _, masks = aug.apply(img, None, masks) labels = masks[: len(labels)] contour_masks = masks[len(contour_masks) :] @@ -633,10 +646,12 @@ def __getitem__(self, index): context_map = context_map.unsqueeze(0) label = torch.FloatTensor( - np.concatenate([np.expand_dims(l, 0) for l in labels], 0) + np.concatenate([np.expand_dims(l, 0) for l in labels], 0).astype("int8") ) contour_mask = torch.FloatTensor( - np.concatenate([np.expand_dims(l, 0) for l in contour_masks], 0) + np.concatenate([np.expand_dims(l, 0) for l in contour_masks], 0).astype( + "int8" + ) ) contour_mask = contour_mask.max(axis=0).values.unsqueeze(0) label_present = [label is not None for label in self.slices[index]["labels"]] From fb47e1a4e3f22892f6056dc952c2ab14fab348c7 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Dec 2023 11:34:51 +1100 Subject: [PATCH 234/264] Add missing return value --- platipy/imaging/cnn/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 3914cda3..5a43d3b9 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -230,7 +230,7 @@ def apply(self, img, context_map, masks=[]): for mask in masks: augmented_masks.append(affine_transform(mask, t, mode="nearest")) - return augmented_image, augmented_masks + return augmented_image, augmented_context_map, augmented_masks def crop_img_using_localise_model( From d65ec7a68a966f0e540a4214d8faec0d4ffe57d7 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Thu, 7 Dec 2023 11:39:00 +1100 Subject: [PATCH 235/264] init variable --- platipy/imaging/cnn/dataset.py | 1 + 1 
file changed, 1 insertion(+) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 5a43d3b9..4b5ddf73 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -224,6 +224,7 @@ def apply(self, img, context_map, masks=[]): t = t * translation augmented_image = affine_transform(img, t, mode="mirror") + augmented_context_map = None if context_map is not None: augmented_context_map = affine_transform(context_map, t, mode="mirror") augmented_masks = [] From f3d6a99b54d8c2d6a6124d894fc3b7379c1c1de2 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 15 Dec 2023 09:05:11 +1100 Subject: [PATCH 236/264] Read the label map in nnUNet service --- services/nnunet/service.py | 69 +++++++++++++++++++++++++++++++------- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/services/nnunet/service.py b/services/nnunet/service.py index af101cff..03a00e4d 100644 --- a/services/nnunet/service.py +++ b/services/nnunet/service.py @@ -14,13 +14,14 @@ import os import subprocess +import json from pathlib import Path import logging import SimpleITK as sitk -from platipy.backend import app, DataObject, celery # pylint: disable=unused-import +from platipy.backend import app, DataObject, celery # pylint: disable=unused-import logger = logging.getLogger(__name__) @@ -32,12 +33,13 @@ "clean_sup_slices": False, } + def clean_sup_slices(mask): lssif = sitk.LabelShapeStatisticsImageFilter() max_slice_size = 0 sizes = {} - for z in range(mask.GetSize()[2]-1, -1, -1): - lssif.Execute(sitk.ConnectedComponent(mask[:,:,z])) + for z in range(mask.GetSize()[2] - 1, -1, -1): + lssif.Execute(sitk.ConnectedComponent(mask[:, :, z])) if len(lssif.GetLabels()) == 0: continue @@ -48,13 +50,40 @@ def clean_sup_slices(mask): sizes[z] = phys_size for z in sizes: - if sizes[z] > max_slice_size/2: - mask[:,:,z+1:mask.GetSize()[2]] = 0 + if sizes[z] > max_slice_size / 2: + mask[:, :, z + 1 : mask.GetSize()[2]] = 0 break return mask +def 
get_structure_names(task): + # Look up structure names if we can find them dataset.json file + if "nnUNet_raw_data_base" not in os.environ: + logger.info("nnUNet_raw_data_base not set") + return {} + + raw_path = Path(os.environ["nnUNet_raw_data_base"]) + task_path = raw_path.joinpath("nnUNet_raw_data", task) + dataset_file = task_path.joinpath("dataset.json") + + logger.info("Attempting to read %s", dataset_file) + + if not dataset_file.exists(): + logger.info("dataset.json file does not exist for %s", dataset_file) + return {} + + dataset = {} + with open(dataset_file, "r") as f: + dataset = json.load(f) + + if "labels" not in dataset: + logger.info("Something went wrong reading dataset.json file") + return {} + + return dataset["labels"] + + @app.register("nnUNet Service", default_settings=NNUNET_SETTINGS_DEFAULTS) def nnunet_service(data_objects, working_dir, settings): """ @@ -73,8 +102,10 @@ def nnunet_service(data_objects, working_dir, settings): output_path = Path(working_dir).joinpath("output") output_path.mkdir() - for data_object in data_objects: + labels = get_structure_names(settings["task"]) + logger.info("Read labels: %s", labels) + for data_object in data_objects: # Create a symbolic link for each image to auto-segment using the nnUNet do_path = Path(data_object.path) io_path = input_path.joinpath(f"{settings['task']}_0000.nii.gz") @@ -109,13 +140,28 @@ def nnunet_service(data_objects, working_dir, settings): subprocess.call(command) for op in output_path.glob("*.nii.gz"): + label_map = sitk.ReadImage(str(op)) + + label_map_arr = sitk.GetArrayFromImage(label_map) + label_count = label_map_arr.max() + + for label_id in range(1, label_count + 1): + mask = label_map == label_id + + label_name = f"Structure_{label_id}" + if str(label_id) in labels: + label_name = labels[str(label_id)] - if settings["clean_sup_slices"]: - mask = sitk.ReadImage(str(op)) - mask = clean_sup_slices(mask) - sitk.WriteImage(mask, str(op)) + if settings["clean_sup_slices"]: + 
mask = clean_sup_slices(mask) - output_data_object = DataObject(type="FILE", path=str(op), parent=data_object) + mask_file = output_path.joinpath(f"{label_name}.nii.gz") + + sitk.WriteImage(mask, str(mask_file)) + + output_data_object = DataObject( + type="FILE", path=str(mask_file), parent=data_object + ) output_objects.append(output_data_object) os.remove(io_path) @@ -126,7 +172,6 @@ def nnunet_service(data_objects, working_dir, settings): if __name__ == "__main__": - # Run app by calling "python service.py" from the command line DICOM_LISTENER_PORT = 7777 From 28909431e6eaa3d49720d502871256f3df99669b Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 15 Dec 2023 10:59:59 +1100 Subject: [PATCH 237/264] Save the label mask in the loop --- services/nnunet/service.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/services/nnunet/service.py b/services/nnunet/service.py index 03a00e4d..5cbf0222 100644 --- a/services/nnunet/service.py +++ b/services/nnunet/service.py @@ -159,10 +159,10 @@ def nnunet_service(data_objects, working_dir, settings): sitk.WriteImage(mask, str(mask_file)) - output_data_object = DataObject( - type="FILE", path=str(mask_file), parent=data_object - ) - output_objects.append(output_data_object) + output_data_object = DataObject( + type="FILE", path=str(mask_file), parent=data_object + ) + output_objects.append(output_data_object) os.remove(io_path) From 48e8ba39b634f96862899fc33120d4f89b790420 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 20 Dec 2023 12:05:39 +1100 Subject: [PATCH 238/264] Add HD 95 metric --- platipy/imaging/label/comparison.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/platipy/imaging/label/comparison.py b/platipy/imaging/label/comparison.py index b59981fa..9425fb48 100644 --- a/platipy/imaging/label/comparison.py +++ b/platipy/imaging/label/comparison.py @@ -95,8 +95,7 @@ def compute_surface_metrics(label_a, label_b, verbose=False): std_sd_list = [] 
median_sd_list = [] num_points = [] - for (la, lb) in ((label_a, label_b), (label_b, label_a)): - + for la, lb in ((label_a, label_b), (label_b, label_a)): label_intensity_stat = sitk.LabelIntensityStatisticsImageFilter() reference_distance_map = sitk.Abs( sitk.SignedMaurerDistanceMap( @@ -118,6 +117,7 @@ def compute_surface_metrics(label_a, label_b, verbose=False): mean_surf_dist = np.dot(mean_sd_list, num_points) / np.sum(num_points) max_surf_dist = np.max(max_sd_list) + hd_95 = np.percentile(max_sd_list, 95) std_surf_dist = np.sqrt( np.dot( num_points, @@ -131,6 +131,7 @@ def compute_surface_metrics(label_a, label_b, verbose=False): result = {} result["hausdorffDistance"] = hd + result["hausdorffDistance95"] = hd_95 result["meanSurfaceDistance"] = mean_surf_dist result["medianSurfaceDistance"] = median_surf_dist result["maximumSurfaceDistance"] = max_surf_dist @@ -294,8 +295,7 @@ def compute_metric_masd(label_a, label_b, auto_crop=True): mean_sd_list = [] num_points = [] - for (la, lb) in ((label_a, label_b), (label_b, label_a)): - + for la, lb in ((label_a, label_b), (label_b, label_a)): label_intensity_stat = sitk.LabelIntensityStatisticsImageFilter() reference_distance_map = sitk.Abs( sitk.SignedMaurerDistanceMap( @@ -364,7 +364,6 @@ def compute_apl(label_ref, label_test, distance_threshold_mm=3): # iterate over each slice for i in range(n_slices): - if ( sitk.GetArrayViewFromImage(label_ref)[i].sum() + sitk.GetArrayViewFromImage(label_test)[i].sum() From 34e2ebb8f4fd86cbde9a1363099e60c58cc8fea7 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Wed, 20 Dec 2023 12:14:13 +1100 Subject: [PATCH 239/264] Update docstring --- platipy/imaging/label/comparison.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/label/comparison.py b/platipy/imaging/label/comparison.py index 9425fb48..e154b999 100644 --- a/platipy/imaging/label/comparison.py +++ b/platipy/imaging/label/comparison.py @@ -74,8 +74,8 @@ def 
compute_surface_dsc(label_a, label_b, tau=3.0): def compute_surface_metrics(label_a, label_b, verbose=False): """Compute surface distance metrics between two labels. Surface metrics computed are: - hausdorffDistance, meanSurfaceDistance, medianSurfaceDistance, maximumSurfaceDistance, - sigmaSurfaceDistance, surfaceDSC + hausdorffDistance, hausdorffDistance95, meanSurfaceDistance, medianSurfaceDistance, + maximumSurfaceDistance, sigmaSurfaceDistance, surfaceDSC Args: label_a (sitk.Image): A mask to compare From b56a8767dca4cb29dff342a1688e037ffbefff7c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 1 Apr 2024 17:13:27 -0500 Subject: [PATCH 240/264] Allow inference using a reference segmentation --- platipy/imaging/cnn/prob_unet.py | 11 +++++++- platipy/imaging/cnn/train.py | 45 +++++++++++++++++++++++++++----- 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 5ec0da91..7f565663 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -289,7 +289,7 @@ def sample(self, testing=False, use_mean=False, sample_x_stddev_from_mean=None): return self.fcomb.forward(self.unet_features, z_prior) - def reconstruct(self, use_posterior_mean=False, z_posterior=None): + def reconstruct(self, use_posterior_mean=False, z_posterior=None, sample_x_stddev_from_mean=None): """ Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet feature map @@ -298,6 +298,15 @@ def reconstruct(self, use_posterior_mean=False, z_posterior=None): """ if use_posterior_mean: z_posterior = self.posterior_latent_space.mean + elif sample_x_stddev_from_mean is not None: + if isinstance(sample_x_stddev_from_mean, list): + sample_x_stddev_from_mean = torch.Tensor(sample_x_stddev_from_mean) + sample_x_stddev_from_mean = sample_x_stddev_from_mean.to( + self.posterior_latent_space.base_dist.stddev.device + ) + z_posterior = 
self.posterior_latent_space.base_dist.loc + ( + self.posterior_latent_space.base_dist.scale * sample_x_stddev_from_mean + ) else: if z_posterior is None: z_posterior = self.posterior_latent_space.rsample() diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index b230b3f1..fb576803 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -242,6 +242,7 @@ def infer( self, img, context_map=None, + seg=None, num_samples=1, sample_strategy="mean", latent_dim=True, @@ -314,22 +315,34 @@ def infer( intensity_window=self.hparams.intensity_window, ) + + img_arr = sitk.GetArrayFromImage(img) if context_map is not None: context_map = resample_mask_to_image(img, context_map) cmap_arr = sitk.GetArrayFromImage(img) + if seg is not None: + seg = resample_mask_to_image(img, seg) + seg_arr = sitk.GetArrayFromImage(img) + if self.hparams.ndims == 2: slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])] if context_map is not None: cmap_slices = [cmap_arr[z, :, :] for z in range(cmap_arr.shape[0])] + + if seg is not None: + seg_slices = [seg_arr[z, :, :] for z in range(seg_arr.shape[0])] else: slices = [img_arr] if context_map is not None: cmap_slices = [cmap_arr] + if seg is not None: + seg_slices = [seg_arr] + for idx, i in enumerate(slices): x = torch.Tensor(i).to(self.device) x = x.unsqueeze(0) @@ -342,19 +355,37 @@ def infer( x = torch.cat((x, c), dim=1) + if seg is not None: + s = torch.Tensor(seg_slices[idx]).to(self.device) + s = s.unsqueeze(0) + s = s.unsqueeze(0) + if self.hparams.prob_type == "prob": - self.prob_unet.forward(x) + if seg is not None: + self.prob_unet.forward(img, seg=seg, training=True) + else: + self.prob_unet.forward(x) for sample in samples: if self.hparams.prob_type == "prob": if sample["name"] == "mean": - y = self.prob_unet.sample(testing=True, use_mean=True) + if seg is None: + y = self.prob_unet.sample(testing=True, use_mean=True) + else: + y = self.prob_unet.reconstruct(use_posterior_mean=True) 
else: - y = self.prob_unet.sample( - testing=True, - use_mean=False, - sample_x_stddev_from_mean=sample["std_dev_from_mean"], - ) + if seg is None: + y = self.prob_unet.sample( + testing=True, + use_mean=False, + sample_x_stddev_from_mean=sample["std_dev_from_mean"], + ) + else: + y = self.prob_unet.reconstruct( + use_posterior_mean=False, + sample_x_stddev_from_mean=sample["std_dev_from_mean"], + ) + # else: # if sample["name"] == "mean": # y = self.prob_unet.sample(x, mean=True) From 080aba603ecc03978158dff618020019ab5594f7 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 1 Apr 2024 17:29:31 -0500 Subject: [PATCH 241/264] Fix variable passed --- platipy/imaging/cnn/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index fb576803..ce38a9bc 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -362,7 +362,7 @@ def infer( if self.hparams.prob_type == "prob": if seg is not None: - self.prob_unet.forward(img, seg=seg, training=True) + self.prob_unet.forward(img, seg=s, training=True) else: self.prob_unet.forward(x) From f20acf6eddddd5b39b7d693314645bf47263cbf1 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 1 Apr 2024 17:34:09 -0500 Subject: [PATCH 242/264] Pass img --- platipy/imaging/cnn/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index ce38a9bc..0d7c3a12 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -362,7 +362,7 @@ def infer( if self.hparams.prob_type == "prob": if seg is not None: - self.prob_unet.forward(img, seg=s, training=True) + self.prob_unet.forward(x, seg=s, training=True) else: self.prob_unet.forward(x) From ec18dbc1f946d78568e5f1ede85969ea6d9fbb03 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 1 Apr 2024 17:42:00 -0500 Subject: [PATCH 243/264] Add background channel --- platipy/imaging/cnn/train.py 
| 5 +++++ 1 file changed, 5 insertions(+) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 0d7c3a12..afffdc85 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -360,6 +360,11 @@ def infer( s = s.unsqueeze(0) s = s.unsqueeze(0) + # Add in background channel + not_s = 1 - sample_strategy.max(axis=1).values + not_s = torch.unsqueeze(not_s, dim=1) + s = torch.cat((not_s, s), dim=1).float() + if self.hparams.prob_type == "prob": if seg is not None: self.prob_unet.forward(x, seg=s, training=True) From 7373a2afb478c3e2bd96d1366f94ff0602a411ff Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 1 Apr 2024 17:44:44 -0500 Subject: [PATCH 244/264] Fix var name --- platipy/imaging/cnn/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index afffdc85..09b28f5c 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -361,7 +361,7 @@ def infer( s = s.unsqueeze(0) # Add in background channel - not_s = 1 - sample_strategy.max(axis=1).values + not_s = 1 - s.max(axis=1).values not_s = torch.unsqueeze(not_s, dim=1) s = torch.cat((not_s, s), dim=1).float() From e9e5a145f7e18c79b7ea66eac2b36bd0b44086b7 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 14 Apr 2024 10:52:37 -0500 Subject: [PATCH 245/264] Experiment with using structure context --- platipy/imaging/cnn/prob_unet.py | 40 ++++++++++++++++++++++---------- platipy/imaging/cnn/train.py | 8 +++++++ 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 7f565663..408c2d75 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -213,6 +213,7 @@ def __init__( loss_params={"beta": 1}, ndims=2, dropout_probability=0.0, + use_structure_context=False ): super(ProbabilisticUnet, self).__init__() @@ -221,18 +222,24 @@ def __init__( self.no_convs_fcomb 
= no_convs_fcomb self.initializers = {"w": "he_normal", "b": "normal"} self.z_prior_sample = 0 + + unet_input_channels = input_channels + if use_structure_context: + unet_input_channels += 1 self.unet = UNet( - input_channels, + unet_input_channels, num_classes, filters_per_layer, final_layer=False, dropout_probability=dropout_probability, ndims=ndims, ) - self.prior = AxisAlignedConvGaussian( - input_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 - ) + self.prior = None + if not use_structure_context: + self.prior = AxisAlignedConvGaussian( + input_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 + ) self.posterior = AxisAlignedConvGaussian( input_channels + num_classes, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 ) @@ -257,9 +264,12 @@ def forward(self, img, seg=None, training=False): Construct prior latent space for patch and run patch through UNet, in case training is True also construct posterior latent space """ - if training: + if training or self.prior is None: self.posterior_latent_space = self.posterior.forward(img, seg=seg) - self.prior_latent_space = self.prior.forward(img) + + self.prior_latent_space = None + if self.prior is not None: + self.prior_latent_space = self.prior.forward(img) self.unet_features = self.unet.forward(img) def sample(self, testing=False, use_mean=False, sample_x_stddev_from_mean=None): @@ -268,23 +278,25 @@ def sample(self, testing=False, use_mean=False, sample_x_stddev_from_mean=None): and combining this with UNet features """ + latent_space = self.prior_latent_space if self.prior is not None else self.posterior_latent_space + if testing: if use_mean: - z_prior = self.prior_latent_space.base_dist.loc + z_prior = latent_space.base_dist.loc elif not sample_x_stddev_from_mean is None: if isinstance(sample_x_stddev_from_mean, list): sample_x_stddev_from_mean = torch.Tensor(sample_x_stddev_from_mean) sample_x_stddev_from_mean = 
sample_x_stddev_from_mean.to( - self.prior_latent_space.base_dist.stddev.device + latent_space.base_dist.stddev.device ) z_prior = self.prior_latent_space.base_dist.loc + ( - self.prior_latent_space.base_dist.scale * sample_x_stddev_from_mean + latent_space.base_dist.scale * sample_x_stddev_from_mean ) else: - z_prior = self.prior_latent_space.sample() + z_prior = latent_space.sample() self.z_prior_sample = z_prior else: - z_prior = self.prior_latent_space.rsample() + z_prior = latent_space.rsample() self.z_prior_sample = z_prior return self.fcomb.forward(self.unet_features, z_prior) @@ -317,7 +329,11 @@ def kl_divergence(self): Calculate the KL divergence between the posterior and prior KL(Q||P) """ - kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) + if self.prior_latent_space is None: + dist = Independent(Normal(loc=0, scale=1), 1) + kl_div = kl.kl_divergence(self.posterior_latent_space, dist) + else: + kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) return kl_div diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 09b28f5c..3d6a6e5f 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -103,6 +103,8 @@ def __init__( ] = self.hparams.contour_loss_lambda_threshold loss_params["contour_loss_weight"] = self.hparams.contour_loss_weight + self.use_structure_context = self.hparams.use_structure_context + if self.hparams.prob_type == "prob": self.prob_unet = ProbabilisticUnet( self.hparams.input_channels, @@ -115,6 +117,7 @@ def __init__( loss_params, self.hparams.ndims, dropout_probability=self.hparams.dropout_probability, + use_structure_context=self.use_structure_context, ) elif self.hparams.prob_type == "hierarchical": raise NotImplementedError("Hierarchical Prob UNet current not working...") @@ -173,6 +176,7 @@ def add_model_specific_args(parent_parser): ) # no longer used parser.add_argument("--epochs_all_rec", type=int, default=0) # no longer used 
parser.add_argument("--dropout_probability", type=float, default=0.0) + parser.add_argument("--use_structure_context", type=bool, default=False) return parent_parser @@ -576,6 +580,10 @@ def training_step(self, batch, _): if c.numel() > 0: x = torch.cat((x, c), dim=1) + # Concat input mask if we are using the structure as context + if self.use_structure_context: + x = torch.cat((x, y), dim=1) + # self.prob_unet.forward(x, y, training=True) if self.hparams.prob_type == "prob": self.prob_unet.forward(x, y, training=True) From c71b9505244e112c3602ef7b05e54ebec01e0bc6 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 14 Apr 2024 11:35:35 -0500 Subject: [PATCH 246/264] Use int instead of bool --- platipy/imaging/cnn/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 3d6a6e5f..6eb8809b 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -176,7 +176,7 @@ def add_model_specific_args(parent_parser): ) # no longer used parser.add_argument("--epochs_all_rec", type=int, default=0) # no longer used parser.add_argument("--dropout_probability", type=float, default=0.0) - parser.add_argument("--use_structure_context", type=bool, default=False) + parser.add_argument("--use_structure_context", type=int, default=0) return parent_parser From f3bc014bbcff97c1a23972d80e230544a0324848 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 15 Apr 2024 09:27:49 -0500 Subject: [PATCH 247/264] Fix None prior --- platipy/imaging/cnn/train.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 6eb8809b..f438e652 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -192,11 +192,19 @@ def configure_optimizers(self): "lr": 1e-4, } ] - for m in [ - self.prob_unet.prior.parameters(), - self.prob_unet.posterior.parameters(), - 
self.prob_unet.fcomb.parameters(), - ]: + + if self.prob_unet.prior is not None: + param_list =[ + self.prob_unet.prior.parameters(), + self.prob_unet.posterior.parameters(), + self.prob_unet.fcomb.parameters(), + ] + else: + param_list =[ + self.prob_unet.posterior.parameters(), + self.prob_unet.fcomb.parameters(), + ] + for m in: params += [ {"params": m, "weight_decay": self.hparams.weight_decay, "lr": 1e-5} ] From 06b7ec65d310110b6a25c3249372b882987aa608 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 15 Apr 2024 09:28:43 -0500 Subject: [PATCH 248/264] Fix missing var name --- platipy/imaging/cnn/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index f438e652..9f260336 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -204,7 +204,7 @@ def configure_optimizers(self): self.prob_unet.posterior.parameters(), self.prob_unet.fcomb.parameters(), ] - for m in: + for m in param_list: params += [ {"params": m, "weight_decay": self.hparams.weight_decay, "lr": 1e-5} ] From 8e0491c033f66d9ff260115a41594bb3b6f92ebe Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 16 Apr 2024 08:36:30 +1000 Subject: [PATCH 249/264] Tweak probunet input --- platipy/imaging/cnn/prob_unet.py | 2 +- platipy/imaging/cnn/train.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 408c2d75..37838734 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -222,7 +222,7 @@ def __init__( self.no_convs_fcomb = no_convs_fcomb self.initializers = {"w": "he_normal", "b": "normal"} self.z_prior_sample = 0 - + unet_input_channels = input_channels if use_structure_context: unet_input_channels += 1 diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 9f260336..b4be05a2 100644 --- a/platipy/imaging/cnn/train.py +++ 
b/platipy/imaging/cnn/train.py @@ -579,18 +579,25 @@ def validate( def training_step(self, batch, _): x, c, y, m, _ = batch + # Concat input mask if we are using the structure as context + if self.use_structure_context: + x = torch.cat((x, y), dim=1) + + print(f"y.shape1 {y.shape}") # Add background layer for one-hot encoding not_y = 1 - y.max(axis=1).values not_y = torch.unsqueeze(not_y, dim=1) y = torch.cat((not_y, y), dim=1).float() + print(f"y.shape2 {y.shape}") + + print(f"x.shape0 {x.shape}") + + print(f"c.shape {c.shape}") # Concat context map to image if we have one if c.numel() > 0: x = torch.cat((x, c), dim=1) - - # Concat input mask if we are using the structure as context - if self.use_structure_context: - x = torch.cat((x, y), dim=1) + print(f"x.shape 1 {x.shape}") # self.prob_unet.forward(x, y, training=True) if self.hparams.prob_type == "prob": From a1f98aa1ed0c1c50d92f57e5bd36e23e92051ddf Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 15 Apr 2024 17:41:05 -0500 Subject: [PATCH 250/264] cat in probunet --- platipy/imaging/cnn/prob_unet.py | 10 +++++++++- platipy/imaging/cnn/train.py | 7 ------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 37838734..fc81be62 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -222,10 +222,11 @@ def __init__( self.no_convs_fcomb = no_convs_fcomb self.initializers = {"w": "he_normal", "b": "normal"} self.z_prior_sample = 0 + self.use_structure_context = use_structure_context unet_input_channels = input_channels if use_structure_context: - unet_input_channels += 1 + unet_input_channels = unet_input_channels + num_classes self.unet = UNet( unet_input_channels, @@ -270,6 +271,13 @@ def forward(self, img, seg=None, training=False): self.prior_latent_space = None if self.prior is not None: self.prior_latent_space = self.prior.forward(img) + + if self.use_structure_context: + if seg is None: 
+ raise ValueError("Structure context is enabled, but no segmentation mask provided") + + img = torch.cat((img, seg), dim=1) + self.unet_features = self.unet.forward(img) def sample(self, testing=False, use_mean=False, sample_x_stddev_from_mean=None): diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index b4be05a2..ea8b49de 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -583,21 +583,14 @@ def training_step(self, batch, _): if self.use_structure_context: x = torch.cat((x, y), dim=1) - print(f"y.shape1 {y.shape}") # Add background layer for one-hot encoding not_y = 1 - y.max(axis=1).values not_y = torch.unsqueeze(not_y, dim=1) y = torch.cat((not_y, y), dim=1).float() - print(f"y.shape2 {y.shape}") - - print(f"x.shape0 {x.shape}") - - print(f"c.shape {c.shape}") # Concat context map to image if we have one if c.numel() > 0: x = torch.cat((x, c), dim=1) - print(f"x.shape 1 {x.shape}") # self.prob_unet.forward(x, y, training=True) if self.hparams.prob_type == "prob": From a3e775b21c15f340a574622980d3c28e112bdb69 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 15 Apr 2024 17:45:09 -0500 Subject: [PATCH 251/264] Dont set in train loop --- platipy/imaging/cnn/train.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index ea8b49de..c821acfe 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -579,10 +579,6 @@ def validate( def training_step(self, batch, _): x, c, y, m, _ = batch - # Concat input mask if we are using the structure as context - if self.use_structure_context: - x = torch.cat((x, y), dim=1) - # Add background layer for one-hot encoding not_y = 1 - y.max(axis=1).values not_y = torch.unsqueeze(not_y, dim=1) From 89c34fe52802ea72ce545e4acac24ad6445a877c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 15 Apr 2024 17:55:15 -0500 Subject: [PATCH 252/264] Correct unit dist --- 
platipy/imaging/cnn/prob_unet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index fc81be62..a5250d1a 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -222,6 +222,7 @@ def __init__( self.no_convs_fcomb = no_convs_fcomb self.initializers = {"w": "he_normal", "b": "normal"} self.z_prior_sample = 0 + self.latent_dim = latent_dim self.use_structure_context = use_structure_context unet_input_channels = input_channels @@ -338,7 +339,7 @@ def kl_divergence(self): """ if self.prior_latent_space is None: - dist = Independent(Normal(loc=0, scale=1), 1) + dist = Independent(Normal(loc=torch.zeros(self.latent_dim), scale=torch.ones(self.latent_dim)), 1) kl_div = kl.kl_divergence(self.posterior_latent_space, dist) else: kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) From eb4277462b85542d66a26e5b5d9a6e0ea22a7184 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 15 Apr 2024 18:01:27 -0500 Subject: [PATCH 253/264] Move dist to device --- platipy/imaging/cnn/prob_unet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index a5250d1a..38f5aa64 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -222,7 +222,7 @@ def __init__( self.no_convs_fcomb = no_convs_fcomb self.initializers = {"w": "he_normal", "b": "normal"} self.z_prior_sample = 0 - self.latent_dim = latent_dim + self.dist = Independent(Normal(loc=torch.zeros(latent_dim), scale=torch.ones(latent_dim)), 1) self.use_structure_context = use_structure_context unet_input_channels = input_channels @@ -339,8 +339,8 @@ def kl_divergence(self): """ if self.prior_latent_space is None: - dist = Independent(Normal(loc=torch.zeros(self.latent_dim), scale=torch.ones(self.latent_dim)), 1) - kl_div = kl.kl_divergence(self.posterior_latent_space, 
dist) + self.dist.to(self.posterior_latent_space.loc.device) + kl_div = kl.kl_divergence(self.posterior_latent_space, self.dist) else: kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) From 0a76cb9b4371231c09a0a6f1470d874e3dbdea39 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 15 Apr 2024 18:04:08 -0500 Subject: [PATCH 254/264] Fix move to device --- platipy/imaging/cnn/prob_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 38f5aa64..94107c9f 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -339,7 +339,7 @@ def kl_divergence(self): """ if self.prior_latent_space is None: - self.dist.to(self.posterior_latent_space.loc.device) + self.dist.to(self.posterior_latent_space.base_dist.stddev.device) kl_div = kl.kl_divergence(self.posterior_latent_space, self.dist) else: kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) From 926825b8426ca15b86c0a6911e6b7465bfc4dccd Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 15 Apr 2024 18:07:45 -0500 Subject: [PATCH 255/264] Fix move to device --- platipy/imaging/cnn/prob_unet.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 94107c9f..31db04fe 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -222,7 +222,7 @@ def __init__( self.no_convs_fcomb = no_convs_fcomb self.initializers = {"w": "he_normal", "b": "normal"} self.z_prior_sample = 0 - self.dist = Independent(Normal(loc=torch.zeros(latent_dim), scale=torch.ones(latent_dim)), 1) + self.latent_dim = latent_dim self.use_structure_context = use_structure_context unet_input_channels = input_channels @@ -339,8 +339,10 @@ def kl_divergence(self): """ if self.prior_latent_space is None: - 
self.dist.to(self.posterior_latent_space.base_dist.stddev.device) - kl_div = kl.kl_divergence(self.posterior_latent_space, self.dist) + + device = self.posterior_latent_space.base_dist.stddev.device + dist = Independent(Normal(loc=torch.zeros(self.latent_dim).to(device), scale=torch.ones(self.latent_dim)).to(device), 1) + kl_div = kl.kl_divergence(self.posterior_latent_space, dist) else: kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) From 3c12dd39de7bfa568c21f10bb0cdf8b5b44b6136 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 15 Apr 2024 18:10:25 -0500 Subject: [PATCH 256/264] Fix bracket --- platipy/imaging/cnn/prob_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 31db04fe..087fe65f 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -341,7 +341,7 @@ def kl_divergence(self): if self.prior_latent_space is None: device = self.posterior_latent_space.base_dist.stddev.device - dist = Independent(Normal(loc=torch.zeros(self.latent_dim).to(device), scale=torch.ones(self.latent_dim)).to(device), 1) + dist = Independent(Normal(loc=torch.zeros(self.latent_dim).to(device), scale=torch.ones(self.latent_dim).to(device)), 1) kl_div = kl.kl_divergence(self.posterior_latent_space, dist) else: kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) From c908e85b04d6a34fa6dcd35b6ca19c679d576b05 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 16 Apr 2024 17:33:28 -0500 Subject: [PATCH 257/264] Cat context during validation --- platipy/imaging/cnn/train.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index c821acfe..417560d9 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -684,6 +684,11 @@ def validation_step(self, batch, _): if c.numel() > 0: x = torch.cat((x, c), dim=1) + if 
self.use_structure_context: + not_y = 1 - y.max(axis=1).values + not_y = torch.unsqueeze(not_y, dim=1) + x = torch.cat((x, not_y, y), dim=1).float() + self.prob_unet.forward(x) py = self.prob_unet.sample(testing=True) From 09346ab63d30ff01f5d8493a3b51821d97d81219 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 16 Apr 2024 18:03:39 -0500 Subject: [PATCH 258/264] Pass seg during validation --- platipy/imaging/cnn/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 417560d9..01187866 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -684,12 +684,13 @@ def validation_step(self, batch, _): if c.numel() > 0: x = torch.cat((x, c), dim=1) + seg = None if self.use_structure_context: not_y = 1 - y.max(axis=1).values not_y = torch.unsqueeze(not_y, dim=1) - x = torch.cat((x, not_y, y), dim=1).float() + seg = torch.cat((not_y, y), dim=1).float() - self.prob_unet.forward(x) + self.prob_unet.forward(x, seg=seg) py = self.prob_unet.sample(testing=True) py = py.to("cpu") From 967300fed3af980e72534c4d7175407d1719d983 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 19 Apr 2024 16:27:42 -0500 Subject: [PATCH 259/264] Pass the seg during inference --- platipy/imaging/cnn/train.py | 60 ++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 01187866..4a83d440 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -379,7 +379,7 @@ def infer( if self.hparams.prob_type == "prob": if seg is not None: - self.prob_unet.forward(x, seg=s, training=True) + self.prob_unet.forward(x, seg=s) else: self.prob_unet.forward(x) @@ -825,31 +825,6 @@ def validation_epoch_end(self, validation_step_outputs): img = sitk.GetImageFromArray(img_arr) img.SetSpacing(self.hparams.spacing) - context_map = None - if cmap_arr is not None: - 
context_map = sitk.GetImageFromArray(cmap_arr) - context_map.SetSpacing(self.hparams.spacing) - - try: - mean = self.infer( - img, - context_map=context_map, - num_samples=1, - sample_strategy="mean", - preprocess=False, - ) - samples = self.infer( - img, - context_map=context_map, - sample_strategy="spaced", - num_samples=5, - spaced_range=[-2, 2], - preprocess=False, - ) - except Exception as e: - print(f"ERROR DURING VALIDATION INFERENCE: {e}") - return - observers = {} for _, observer in enumerate(cases[case]["observers"]): if self.hparams.ndims == 2: @@ -876,6 +851,39 @@ def validation_epoch_end(self, validation_step_outputs): mask.CopyInformation(img) observers[f"manual_{observer}"][structure] = mask + context_map = None + if cmap_arr is not None: + context_map = sitk.GetImageFromArray(cmap_arr) + context_map.SetSpacing(self.hparams.spacing) + + seg = None + if self.use_structure_context: + # TODO choose the observer to pass properly + seg = observers[f"manual_{observer}"][structure] + + try: + mean = self.infer( + img, + context_map=context_map, + seg=seg, + num_samples=1, + sample_strategy="mean", + preprocess=False, + ) + samples = self.infer( + img, + context_map=context_map, + seg=seg, + sample_strategy="spaced", + num_samples=5, + spaced_range=[-2, 2], + preprocess=False, + ) + except Exception as e: + print(f"ERROR DURING VALIDATION INFERENCE: {e}") + return + + # try: result, fig = self.validate( img, observers, samples, mean, matching_type="best" From c4daf9939aa5b2f3d2a9a55c8c3b6a9bb24135b5 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 14 May 2024 07:53:04 +1000 Subject: [PATCH 260/264] Debug probunet with multi channels --- platipy/imaging/cnn/prob_unet_debug.py | 647 ++++++++++++++ platipy/imaging/cnn/train_debug.py | 1134 ++++++++++++++++++++++++ 2 files changed, 1781 insertions(+) create mode 100644 platipy/imaging/cnn/prob_unet_debug.py create mode 100644 platipy/imaging/cnn/train_debug.py diff --git 
a/platipy/imaging/cnn/prob_unet_debug.py b/platipy/imaging/cnn/prob_unet_debug.py new file mode 100644 index 00000000..09dda08b --- /dev/null +++ b/platipy/imaging/cnn/prob_unet_debug.py @@ -0,0 +1,647 @@ +# Copyright 2020 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Parts of this work are derived from: +# https://github.com/stefanknegt/Probabilistic-Unet-Pytorch +# which is released under the Apache Licence 2.0 + +import torch +import numpy as np +from torch.distributions import Normal, Independent, kl + +from platipy.imaging.cnn.unet import UNet, Conv, init_weights, conv_nd + + +class Encoder(torch.nn.Module): + """Encoder part of the probabilistic UNet""" + + def __init__( + self, input_channels, filters_per_layer=[64 * (2**x) for x in range(5)], ndims=2, dropout_probability=None + ): + super(Encoder, self).__init__() + + layers = [] + for idx, layer_filters in enumerate(filters_per_layer): + + input_filters = input_channels if idx == 0 else output_filters + output_filters = layer_filters + + down_sample = 0 if idx == 0 else -2 + + layers.append( + Conv( + input_filters, + output_filters, + up_down_sample=down_sample, + ndims=ndims, + dropout_probability=dropout_probability, + ) + ) + + self.layers = torch.nn.Sequential(*layers) + + self.layers.apply(init_weights) + + def forward(self, x): + + return self.layers(x) + + +class AxisAlignedConvGaussian(torch.nn.Module): + def __init__( + self, + 
input_channels, + filters_per_layer=[64 * (2**x) for x in range(5)], + latent_dim=2, + ndims=2, + dropout_probability=0.0 + ): + + super(AxisAlignedConvGaussian, self).__init__() + + self.latent_dim = latent_dim + + self.encoder = Encoder(input_channels, filters_per_layer, ndims=ndims, dropout_probability=dropout_probability) + + self.final = conv_nd( + in_channels=filters_per_layer[-1], + out_channels=2 * self.latent_dim, + kernel_size=1, + stride=1, + ndims=ndims, + ) + + self.ndims = ndims + + self.final.apply(init_weights) + + def forward(self, img, seg=None): + """Forward pass through the network + + Args: + img (torch.Tensor): The image to be passed through. + seg (torch.Tensor, optional): The segmentation mask to use in the case of the prior + network. Defaults to None. + + Returns: + torch.distributions.distribution.Distribution: The distribution output + """ + + x = img + if seg is not None: + # seg = torch.unsqueeze(seg, dim=1) + x = torch.cat((img, seg), dim=1) + + encoding = self.encoder(x) + + # We only want the mean of the resulting hxw image + encoding = torch.mean(encoding, dim=2, keepdim=True) + encoding = torch.mean(encoding, dim=3, keepdim=True) + if self.ndims == 3: + encoding = torch.mean(encoding, dim=4, keepdim=True) + + # Convert encoding to 2 x latent dim and split up for mu and log_sigma + mu_log_sigma = self.final(encoding) + + # We squeeze the second dimension twice, since otherwise it won't work when batch size is + # equal to 1 + mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) + mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) + if self.ndims == 3: + mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) + + mu = mu_log_sigma[:, : self.latent_dim].clamp(-1000, 1000) + log_sigma = mu_log_sigma[:, self.latent_dim :].clamp(-10, 10) + + # This is a multivariate normal with diagonal covariance matrix sigma + # https://github.com/pytorch/pytorch/pull/11178 + dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1) + + return dist + + 
+class Fcomb(torch.nn.Module): + """ + A function composed of no_convs_fcomb times a 1x1 convolution that combines the sample taken + from the latent space, and output of the UNet (the feature map) by concatenating them along + their channel axis. + """ + + def __init__(self, filters_per_layer, latent_dim, num_classes, no_convs_fcomb, ndims=2): + super(Fcomb, self).__init__() + + layers = [] + + # Decoder of N x a 1x1 convolution followed by a ReLU activation function except for the + # last layer + layers.append( + conv_nd( + in_channels=filters_per_layer[0] + latent_dim, + out_channels=filters_per_layer[0], + kernel_size=1, + ndims=ndims, + ) + ) + layers.append(torch.nn.ReLU(inplace=True)) + + for _ in range(no_convs_fcomb - 2): + layers.append( + conv_nd( + in_channels=filters_per_layer[0], + out_channels=filters_per_layer[0], + kernel_size=1, + ndims=ndims, + ) + ) + layers.append(torch.nn.ReLU(inplace=True)) + + self.layers = torch.nn.Sequential(*layers) + + self.last_layer = conv_nd( + in_channels=filters_per_layer[0], out_channels=num_classes, kernel_size=1, ndims=ndims + ) + + self.layers.apply(init_weights) + self.last_layer.apply(init_weights) + + self.ndims = ndims + + def forward(self, feature_map, z): + + #z = torch.unsqueeze(z, 2).expand(-1, -1, feature_map.shape[2]) + #z = torch.unsqueeze(z, 3).expand(-1, -1, -1, feature_map.shape[3]) + #if self.ndims == 3: + # z = torch.unsqueeze(z, 4).expand(-1, -1, -1, -1, feature_map.shape[4]) + + # Concatenate the feature map (output of the UNet) and the sample taken from the latent + # space + # feature_map = torch.cat((feature_map, z), dim=1) + output = self.layers(feature_map) + return self.last_layer(output) + + +class ProbabilisticUnet(torch.nn.Module): + """ + A probabilistic UNet implementation + (https://papers.nips.cc/paper/2018/file/473447ac58e1cd7e96172575f48dca3b-Paper.pdf) + + input_channels: the number of channels in the image (1 for greyscale and 3 for RGB) + num_classes: the number of classes to 
predict + num_filters: is a list consisint of the amount of filters layer + latent_dim: dimension of the latent space + no_cons_per_block: no convs per block in the (convolutional) encoder of prior and posterior + """ + + def __init__( + self, + input_channels=1, + num_classes=2, + filters_per_layer=[64 * (2**x) for x in range(5)], + latent_dim=6, + no_convs_fcomb=4, + loss_type="elbo", + loss_params={"beta": 1}, + ndims=2, + dropout_probability=0.0, + use_structure_context=False + ): + super(ProbabilisticUnet, self).__init__() + + self.num_classes = num_classes + self.no_convs_per_block = 3 + self.no_convs_fcomb = no_convs_fcomb + self.initializers = {"w": "he_normal", "b": "normal"} + self.z_prior_sample = 0 + self.latent_dim = latent_dim + self.use_structure_context = use_structure_context + + unet_input_channels = input_channels + if use_structure_context: + unet_input_channels = unet_input_channels + num_classes + + self.unet = UNet( + unet_input_channels, + num_classes, + filters_per_layer, + final_layer=False, + dropout_probability=dropout_probability, + ndims=ndims, + ) + self.prior = None + if not use_structure_context: + self.prior = AxisAlignedConvGaussian( + input_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 + ) + #self.posterior = AxisAlignedConvGaussian( + # input_channels + num_classes, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 + #) + self.fcomb = Fcomb(filters_per_layer, 0, num_classes, no_convs_fcomb, ndims=ndims) + + self.loss_type = loss_type + self.loss_params = loss_params + + self.posterior_latent_space = None + self.prior_latent_space = None + self.unet_features = None + + if self.loss_type == "geco": + self._rec_moving_avg = None + self._contour_moving_avg = None + self.register_buffer("_lambda", torch.ones(2, requires_grad=False)) + + self.register_buffer("_pos_weight", torch.ones(num_classes, requires_grad=False)) + + def forward(self, img, seg=None, training=False): + """ + 
Construct prior latent space for patch and run patch through UNet, + in case training is True also construct posterior latent space + """ + #if training or self.prior is None: + # self.posterior_latent_space = self.posterior.forward(img, seg=seg) + + self.prior_latent_space = None + if self.prior is not None: + self.prior_latent_space = self.prior.forward(img) + + if self.use_structure_context: + if seg is None: + raise ValueError("Structure context is enabled, but no segmentation mask provided") + import numpy as np + print(f"imgtype: {img.dtype}") + print(f"imgshape: {img.shape}") + print(f"segtype: {seg.dtype}") + np.save("imgg.npy", img.cpu().numpy()) + img = torch.cat((img, seg), dim=1) + np.save("imgg2.npy", img.cpu().numpy()) + np.save("segg.npy", seg.cpu().numpy()) + + self.unet_features = self.unet.forward(img) + + def sample(self, testing=False, use_mean=False, sample_x_stddev_from_mean=None): + """ + Sample a segmentation by reconstructing from a prior sample + and combining this with UNet features + """ + + latent_space = self.prior_latent_space if self.prior is not None else self.posterior_latent_space + z_prior = None + + + testing = False + if testing: + if use_mean: + z_prior = latent_space.base_dist.loc + elif not sample_x_stddev_from_mean is None: + if isinstance(sample_x_stddev_from_mean, list): + sample_x_stddev_from_mean = torch.Tensor(sample_x_stddev_from_mean) + sample_x_stddev_from_mean = sample_x_stddev_from_mean.to( + latent_space.base_dist.stddev.device + ) + z_prior = self.prior_latent_space.base_dist.loc + ( + latent_space.base_dist.scale * sample_x_stddev_from_mean + ) + else: + z_prior = latent_space.sample() + self.z_prior_sample = z_prior + else: + pass + z_prior = None + #z_prior = latent_space.rsample() + #self.z_prior_sample = z_prior + + return self.fcomb.forward(self.unet_features, z_prior) + + def reconstruct(self, use_posterior_mean=False, z_posterior=None, sample_x_stddev_from_mean=None): + """ + Reconstruct a segmentation 
def topk_mask(self, score, k):
    """Returns a mask for the top-k elements in score.

    Each row of ``score`` is reduced to its maximum along dim 1; the ``k`` rows
    with the largest maxima are marked with 1.0 in the returned 1-D tensor of
    length ``score.shape[0]``, every other row with 0.0. The mask is created on
    the same device as ``score``.
    """
    n_rows = score.shape[0]

    # Best value per row, then the indices of the k best rows.
    row_best, _ = torch.topk(score, 1, dim=1)
    _, top_rows = torch.topk(row_best, k, dim=0)

    # Allocate directly on the target device instead of host-then-copy.
    return torch.scatter_add(
        torch.zeros(n_rows, device=score.device),
        0,
        top_rows.reshape(-1),
        torch.ones(n_rows, device=score.device),
    )
top_k_percentage is not None: + + assert 0.0 < top_k_percentage <= 1.0 + k_pixels = int(n_pixels_in_batch * top_k_percentage) + + with torch.no_grad(): + norm_xe = xe / torch.sum(xe) + if deterministic: + score = torch.log(norm_xe) + else: + # TODO Gumbel trick + raise NotImplementedError("Still need to implement Gumbel trick") + + score = score + torch.log(mask.unsqueeze(1).repeat((1, num_classes))) + + top_k_mask = self.topk_mask(score, k_pixels) + top_k_mask = top_k_mask.to(device) + mask = mask * top_k_mask + + else: + mask = torch.reshape(mask, (-1,)) + + mask = mask.unsqueeze(1).repeat((1, num_classes)) + + mask = ( + mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + ) + + return mask + + def reconstruction_loss( + self, + segm, + z_posterior=None, + mask=None, + top_k_percentage=None, + deterministic=True, + ): + criterion = torch.nn.BCEWithLogitsLoss(reduction="none") + + # criterion = torch.nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) + + # if z_posterior is None: + # z_posterior = self.posterior_latent_space.rsample() + z_posterior = 1 + + reconstruction = self.reconstruct(use_posterior_mean=False, z_posterior=z_posterior) + + criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") + + # Take the max of all structure to combine into one big structure to localise + y = segm + pred = reconstruction + np.save("predd.npy", pred.cpu().detach().numpy()) + #y = y.max(axis=1).values + # y = torch.unsqueeze(y, dim=1) + + # Add a background for the localise UNet + # not_y = y.logical_not() + # y = torch.cat((not_y, y), dim=1).float() + np.save("yyy.npy", y.cpu().detach().numpy()) + + loss = criterion(input=pred, target=y) + return loss, None, None + + ##### + num_classes = reconstruction.shape[1] + y_flat = torch.transpose(reconstruction, 1, -1).reshape((-1, num_classes)) + t_flat = torch.transpose(segm, 1, -1).reshape((-1, num_classes)) + n_pixels_in_batch = y_flat.shape[0] + batch_size = 
def loss(self, segm, mask=None, beta=None):
    """
    Calculate the training loss for the configured ``loss_type``

    segm: target segmentation, forwarded to ``self.reconstruction_loss``
    mask: optional contour mask; only used for the "geco" loss when
        ``loss_params["kappa_contour"]`` is set, in which case the
        reconstruction loss is requested for ``[None, mask]``
    beta: weighting reported for the "elbo" loss, defaults to
        ``loss_params["beta"]``

    Returns a dict which always contains "loss", "rec_loss" and "kl_div".
    "elbo" adds "beta"; "geco" adds the lambda / moving average / constraint
    values and, when a contour threshold is configured, the contour entries.

    Raises NotImplementedError for any other ``loss_type``.
    """
    # Posterior sampling and the KL term are disabled in this implementation;
    # ``False`` is forwarded as a placeholder for the posterior sample.
    z_posterior = False

    top_k_percentage = self.loss_params.get("top_k_percentage")

    loss_mask = None
    contour_threshold = None
    if self.loss_type == "geco":
        reconstruction_threshold = self.loss_params["kappa"]
        contour_threshold = self.loss_params.get("kappa_contour")
        if contour_threshold is not None:
            loss_mask = [None, mask]

    # One flag for every contour decision below, so that a threshold of 0.0 is
    # handled the same way everywhere (truthiness checks used to skip it).
    use_contour = contour_threshold is not None

    rl_sum, rec_loss_mean, _ = self.reconstruction_loss(
        segm,
        z_posterior=z_posterior,
        top_k_percentage=top_k_percentage,
        mask=loss_mask,
    )

    # If using contour mask in loss, we get back those in a list. Unpack here.
    if use_contour:
        contour_loss = rl_sum[1]
        contour_loss_mean = rec_loss_mean[1]
        reconstruction_loss = rl_sum[0]
        rec_loss_mean = rec_loss_mean[0]
    else:
        reconstruction_loss = rl_sum

    if self.loss_type == "elbo":
        if beta is None:
            beta = self.loss_params["beta"]

        return {
            "loss": reconstruction_loss,  # KL term disabled
            "rec_loss": reconstruction_loss,
            "kl_div": 1,
            "beta": beta,
        }

    if self.loss_type == "geco":
        rec_geco_step_size = self.loss_params["rec_geco_step_size"]

        with torch.no_grad():
            moving_avg_factor = 0.5

            rl = rec_loss_mean.detach()
            if self._rec_moving_avg is None:
                self._rec_moving_avg = rl
            else:
                self._rec_moving_avg = self._rec_moving_avg * moving_avg_factor + rl * (
                    1 - moving_avg_factor
                )

            rc = self._rec_moving_avg - reconstruction_threshold

            cc = 0
            if use_contour:
                cl = contour_loss_mean.detach()
                if self._contour_moving_avg is None:
                    # Seed with the contour loss itself (was seeded with the
                    # reconstruction loss by mistake).
                    self._contour_moving_avg = cl
                else:
                    self._contour_moving_avg = (
                        self._contour_moving_avg * moving_avg_factor
                        + cl * (1 - moving_avg_factor)
                    )

                cc = self._contour_moving_avg - contour_threshold

            lambda_lower = self.loss_params["clamp_rec"][0]
            lambda_upper = self.loss_params["clamp_rec"][1]

            self._lambda[0] = (torch.exp(rc * rec_geco_step_size) * self._lambda[0]).clamp(
                lambda_lower, lambda_upper
            )
            if self._lambda[0].isnan():
                self._lambda[0] = lambda_upper

            if use_contour:
                lambda_lower_contour = self.loss_params["clamp_contour"][0]
                lambda_upper_contour = self.loss_params["clamp_contour"][1]

                self._lambda[1] = (torch.exp(cc * rec_geco_step_size) * self._lambda[1]).clamp(
                    lambda_lower_contour, lambda_upper_contour
                )
                if self._lambda[1].isnan():
                    self._lambda[1] = lambda_upper_contour

        # pylint: disable=access-member-before-definition
        loss = self._lambda[0] * reconstruction_loss  # KL term disabled

        result = {
            "loss": loss,
            "rec_loss": reconstruction_loss,
            "kl_div": 0,
            "lambda_rec": self._lambda[0],
            "moving_avg": self._rec_moving_avg,
            "reconstruction_threshold": reconstruction_threshold,
            "rec_constraint": rc,
        }

        if use_contour:
            result["loss"] = result["loss"] + (self._lambda[1] * contour_loss)
            result["contour_loss"] = contour_loss
            result["contour_threshold"] = contour_threshold
            result["contour_constraint"] = cc
            result["moving_avg_contour"] = self._contour_moving_avg
            result["lambda_contour"] = self._lambda[1]

        return result

    raise NotImplementedError("Loss must be 'elbo' or 'geco'")
def __init__(
    self,
    **kwargs,
):
    """
    Build the Lightning wrapper around the probabilistic UNet

    All keyword arguments are stored through ``save_hyperparameters`` and read
    back from ``self.hparams``. The loss configuration passed to
    ``ProbabilisticUnet`` is assembled from the hyper-parameters that belong to
    the selected ``loss_type``.

    Raises ValueError if ``loss_type`` is neither "elbo" nor "geco".
    Raises NotImplementedError if ``prob_type`` is "hierarchical".
    """
    super().__init__()

    self.save_hyperparameters()

    if self.hparams.loss_type == "elbo":
        loss_params = {
            "beta": self.hparams.beta,
        }
    elif self.hparams.loss_type == "geco":
        loss_params = {
            "kappa": self.hparams.kappa,
            "clamp_rec": self.hparams.clamp_rec,
            "clamp_contour": self.hparams.clamp_contour,
            "kappa_contour": self.hparams.kappa_contour,
            "rec_geco_step_size": self.hparams.rec_geco_step_size,
        }
    else:
        # Previously loss_params stayed None here and the assignments below
        # failed with an unrelated-looking TypeError.
        raise ValueError(
            f"Unsupported loss_type '{self.hparams.loss_type}', expected 'elbo' or 'geco'"
        )

    loss_params["top_k_percentage"] = self.hparams.top_k_percentage
    loss_params["contour_loss_lambda_threshold"] = self.hparams.contour_loss_lambda_threshold
    loss_params["contour_loss_weight"] = self.hparams.contour_loss_weight

    self.use_structure_context = self.hparams.use_structure_context

    if self.hparams.prob_type == "prob":
        self.prob_unet = ProbabilisticUnet(
            self.hparams.input_channels,
            len(self.hparams.structures) + 1,  # Add 1 to num classes for background class
            self.hparams.filters_per_layer,
            self.hparams.latent_dim,
            self.hparams.no_convs_fcomb,
            self.hparams.loss_type,
            loss_params,
            self.hparams.ndims,
            dropout_probability=self.hparams.dropout_probability,
            use_structure_context=self.use_structure_context,
        )
    elif self.hparams.prob_type == "hierarchical":
        raise NotImplementedError("Hierarchical Prob UNet current not working...")

    # Directory for per-slice validation dumps; created by validation_step.
    self.validation_directory = None
    self.kl_div = None

    self.stddevs = np.linspace(-3, 3, self.hparams.num_observers)
def configure_optimizers(self):
    """
    Create the optimizer used for training

    Returns a single Adam optimizer over all module parameters using
    ``hparams.learning_rate``. Weight decay is fixed at 0 and no learning rate
    scheduler is attached; the ``weight_decay`` and ``lr_lambda``
    hyper-parameters are not consulted here.
    """
    # The per-module parameter groups and schedulers that used to follow the
    # return statement were unreachable (and referenced a posterior network
    # that is not constructed), so they have been dropped.
    return torch.optim.Adam(
        self.parameters(), lr=self.hparams.learning_rate, weight_decay=0
    )
spaced_range=[-1.5, 1.5], + preprocess=True, + ): + # sample strategy in "mean", "random", "spaced" + + if not hasattr(latent_dim, "__iter__"): + latent_dim = [ + latent_dim, + ] * self.hparams.latent_dim + + if sample_strategy == "mean": + samples = [ + { + "name": "mean", + "std_dev_from_mean": torch.Tensor([0.0] * len(latent_dim)).to( + self.device + ), + "preds": [], + } + ] + elif sample_strategy == "random": + samples = [ + { + "name": f"random_{i}", + "std_dev_from_mean": torch.Tensor( + [ + np.random.normal(0, 1.0, 1)[0] if d else 0.0 + for d in latent_dim + ] + ).to(self.device), + "preds": [], + } + for i in range(num_samples) + ] + elif sample_strategy == "spaced": + if self.hparams.prob_type == "hierarchical": + latent_dim = [True] * (len(self.hparams.filters_per_layer) - 1) + samples = [ + { + "name": f"spaced_{s:.2f}", + "std_dev_from_mean": torch.Tensor( + [s if d else 0.0 for d in latent_dim] + ).to(self.device), + "preds": [], + } + for s in np.linspace(spaced_range[0], spaced_range[1], num_samples) + ] + + with torch.no_grad(): + if preprocess: + if self.hparams.crop_using_localise_model: + localise_path = self.hparams.crop_using_localise_model.format( + fold=self.hparams.fold + ) + img = crop_img_using_localise_model( + img, + localise_path, + spacing=self.hparams.spacing, + crop_to_grid_size=self.hparams.localise_voxel_grid_size, + ) + else: + img = preprocess_image( + img, + spacing=self.hparams.spacing, + crop_to_grid_size_xy=self.hparams.crop_to_grid_size, + intensity_scaling=self.hparams.intensity_scaling, + intensity_window=self.hparams.intensity_window, + ) + + + + img_arr = sitk.GetArrayFromImage(img) + + if context_map is not None: + context_map = resample_mask_to_image(img, context_map) + cmap_arr = sitk.GetArrayFromImage(img) + + if seg is not None: + seg = resample_mask_to_image(img, seg) + seg_arr = sitk.GetArrayFromImage(seg) + + if self.hparams.ndims == 2: + slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])] + + if 
context_map is not None: + cmap_slices = [cmap_arr[z, :, :] for z in range(cmap_arr.shape[0])] + + if seg is not None: + seg_slices = [seg_arr[z, :, :] for z in range(seg_arr.shape[0])] + else: + slices = [img_arr] + if context_map is not None: + cmap_slices = [cmap_arr] + + if seg is not None: + seg_slices = [seg_arr] + + for idx, i in enumerate(slices): + x = torch.Tensor(i).to(self.device) + x = x.unsqueeze(0) + x = x.unsqueeze(0) + + if context_map is not None: + c = torch.Tensor(cmap_slices[idx]).to(self.device) + c = c.unsqueeze(0) + c = c.unsqueeze(0) + + x = torch.cat((x, c), dim=1) + + if seg is not None: + s = torch.Tensor(seg_slices[idx]).to(self.device) + s = s.unsqueeze(0) + s = s.unsqueeze(0) + + # Add in background channel + not_s = 1 - s.max(axis=1).values + not_s = torch.unsqueeze(not_s, dim=1) + s = torch.cat((not_s, s), dim=1).float() + + if self.hparams.prob_type == "prob": + if seg is not None: + self.prob_unet.forward(x, seg=s) + else: + self.prob_unet.forward(x) + + for sample in samples: + if self.hparams.prob_type == "prob": + if sample["name"] == "mean": + if seg is None: + y = self.prob_unet.sample(testing=True, use_mean=True) + else: + y = self.prob_unet.reconstruct(use_posterior_mean=True) + else: + if seg is None: + y = self.prob_unet.sample( + testing=True, + use_mean=False, + sample_x_stddev_from_mean=sample["std_dev_from_mean"], + ) + else: + y = self.prob_unet.reconstruct( + use_posterior_mean=False, + sample_x_stddev_from_mean=sample["std_dev_from_mean"], + ) + + # else: + # if sample["name"] == "mean": + # y = self.prob_unet.sample(x, mean=True) + # else: + # y = self.prob_unet.sample( + # x, + # mean=True, + # std_devs_from_mean=sample["std_dev_from_mean"], + # ) + + y = y.squeeze(0) + # y = np.argmax(y.cpu().detach().numpy(), axis=0) + y = torch.sigmoid(y) + sample["preds"].append(y.cpu().detach().numpy()) + + result = {} + for sample in samples: + pred_arr = sample["preds"][0] + + if self.hparams.ndims == 2: + pred_arr = 
def validate(
    self,
    img,
    manual_observers,
    samples,
    mean,
    matching_type="best",
    window=[-0.5, 1.0],
):
    """
    Compare sampled and mean predictions against the manual observers

    img: image handed to ``ImageVisualiser`` (``GetSize`` is used as a
        fallback for the cut position)
    manual_observers: dict of observer name -> dict of structure -> mask
    samples: dict of sample name -> dict of structure -> mask
    mean: dict with a "mean" entry holding a dict of structure -> mask
    matching_type: "hungarian" averages the metric over the optimal
        sample/observer assignment, anything else takes the single best value
    window: intensity window passed to ``ImageVisualiser``

    Returns ``(result, fig)`` where ``result`` maps ``probnet_<structure>`` and
    ``unet_<structure>`` to ``{"DSC": [...], "HD": [...], "ASD": [...]}`` and
    ``fig`` is the figure returned by the visualiser.
    """
    # Metric name -> whether larger ("max") or smaller ("min") is better.
    metrics = {"DSC": "max", "HD": "min", "ASD": "min"}
    result = {}

    contour_cmaps = ["RdPu", "YlOrRd", "GnBu", "OrRd", "YlGn", "YlGnBu"]
    structures = self.hparams.structures

    def _summarise(values, direction):
        """Reduce a (samples x observers) metric matrix to a single score."""
        if matching_type == "hungarian":
            # linear_sum_assignment minimises, so negate metrics to maximise.
            cost = -values if direction == "max" else values
            row_idx, col_idx = linear_sum_assignment(cost)
            return values[row_idx, col_idx].mean()
        return values.max() if direction == "max" else values.min()

    try:
        cut = get_com(mean["mean"][structures[0]])
    except ValueError:
        # Fall back to the centre of the image.
        cut = [int(i / 2) for i in img.GetSize()][::-1]

    vis = ImageVisualiser(img, cut=cut, figure_size_in=16, window=window)

    mean_contours = {}
    for idx, structure in enumerate(structures):
        # Cycle through the available colormaps. The modulus must be the number
        # of colormaps (not structures), otherwise more than six structures
        # index past the end of the list.
        color_map = matplotlib.colormaps.get_cmap(
            contour_cmaps[idx % len(contour_cmaps)]
        )
        mean_contours[f"mean_{structure}"] = mean["mean"][structure]

        vis.add_contour(
            mean_contours, color=color_map(0.35), linewidth=3, show_legend=False
        )

        manual_color = color_map(0.9)

        manual_observers_struct = {
            f"{man_struct}_{structure}": manual_observers[man_struct][structure]
            for man_struct in manual_observers
        }

        vis.add_contour(
            manual_observers_struct,
            color=manual_color,
            linewidth=0.5,
            show_legend=False,
        )

        intersection_mask = get_intersection_mask(manual_observers_struct)
        union_mask = get_union_mask(manual_observers_struct)

        vis.add_contour(
            intersection_mask,
            name=f"intersection_{structure}",
            color=manual_color,
            linewidth=3,
        )
        vis.add_contour(
            union_mask, name=f"union_{structure}", color=manual_color, linewidth=3
        )

        samples_struct = {
            f"{sample_struct}_{structure}": samples[sample_struct][structure]
            for sample_struct in samples
        }
        sample_colors = color_map(np.linspace(0.1, 0.7, len(samples_struct)))
        vis.add_contour(
            samples_struct,
            linewidth=1.5,
            color=dict(zip(samples_struct, sample_colors)),
        )

        # Metric matrices: rows are samples, columns are manual observers.
        matrix_shape = (len(samples_struct), len(manual_observers_struct))
        sim = {k: np.zeros(matrix_shape) for k in metrics}
        msim = {k: np.zeros(matrix_shape) for k in metrics}
        for sid, samp in enumerate(samples_struct):
            for oid, obs in enumerate(manual_observers_struct):
                sample_metrics = get_metrics(
                    manual_observers_struct[obs], samples_struct[samp]
                )
                mean_metrics = get_metrics(
                    manual_observers_struct[obs], mean_contours[f"mean_{structure}"]
                )

                for k in sample_metrics:
                    sim[k][sid, oid] = sample_metrics[k]
                    msim[k][sid, oid] = mean_metrics[k]

        result[f"probnet_{structure}"] = {k: [] for k in metrics}
        result[f"unet_{structure}"] = {k: [] for k in metrics}
        for k in sim:
            result[f"probnet_{structure}"][k].append(_summarise(sim[k], metrics[k]))
            result[f"unet_{structure}"][k].append(_summarise(msim[k], metrics[k]))

    fig = vis.show()

    return result, fig
not_y = torch.unsqueeze(not_y, dim=1) + y = torch.cat((not_y, y), dim=1).float() + + # Concat context map to image if we have one + if c.numel() > 0: + x = torch.cat((x, c), dim=1).float() + + # self.prob_unet.forward(x, y, training=True) + if self.hparams.prob_type == "prob": + self.prob_unet.forward(x, y, training=True) + # else: + # self.prob_unet.forward(x, y) + np.save("yyyy.npy", y.cpu().detach().numpy()) + + if self.hparams.prob_type == "prob": + loss = self.prob_unet.loss(y, mask=m) + # else: + # loss = self.prob_unet.loss(x, y, mask=m) + + training_loss = loss["loss"] + + # Using weight decay instead + # if self.hparams.prob_type == "prob": + # reg_loss = ( + # l2_regularisation(self.prob_unet.posterior) + # + l2_regularisation(self.prob_unet.prior) + # + l2_regularisation(self.prob_unet.fcomb.layers) + # ) + # training_loss = training_loss + 1e-5 * reg_loss + self.log( + "training_loss", + training_loss.detach(), + on_step=True, + on_epoch=False, + prog_bar=True, + logger=True, + ) + + self.kl_div = 1#loss["kl_div"].detach().cpu() + + for k in loss: + if k == "loss": + continue + if k == "kl_div": continue + self.log( + k, + loss[k].detach() if isinstance(loss[k], torch.Tensor) else loss[k], + on_step=True, + on_epoch=False, + prog_bar=True, + logger=True, + ) + return training_loss + + def validation_step(self, batch, _): + if self.validation_directory is None: + self.validation_directory = Path(tempfile.mkdtemp()) + + n = self.hparams.num_observers + m = self.hparams.num_observers + + with torch.set_grad_enabled(False): + x, c, y, _, info = batch + + np.save("img.npy", x.cpu().numpy()) + print(f"x: {x.shape}") + print("y: " + str(y.shape)) + print(f"y1sum: {y.sum()}") + np.save("seg.npy", y.cpu().numpy()) + + # Save off slices/volumes for analysis of entire structure in end of validation step + for s in range(y.shape[0]): + img_file = self.validation_directory.joinpath( + f"img_{info['case'][s]}_{info['z'][s]}.npy" + ) + np.save(img_file, 
x[s].squeeze(0).cpu().numpy()) + + if c.numel() > 0: + cmap_file = self.validation_directory.joinpath( + f"cmap_{info['case'][s]}_{info['z'][s]}.npy" + ) + np.save(cmap_file, c[s].squeeze(0).cpu().numpy()) + + mask_file = self.validation_directory.joinpath( + f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" + ) + np.save(mask_file, y[s].cpu().numpy()) + + # Image (and context map) will be same for all in batch + x = x[0].unsqueeze(0) + if c.numel() > 0: + c = c[0].unsqueeze(0) + if self.hparams.ndims == 2: + vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0]), axis="z") + else: + vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0, 0])) + + if self.hparams.ndims == 2: + x = x.repeat(m, 1, 1, 1) + + if c.numel() > 0: + c = c.repeat(m, 1, 1, 1) + else: + x = x.repeat(m, 1, 1, 1, 1) + + if c.numel() > 0: + c = c.repeat(m, 1, 1, 1, 1) + + if c.numel() > 0: + x = torch.cat((x, c), dim=1) + + seg = None + print(f"y2sum: {y.sum()}") + if self.use_structure_context: + not_y = 1 - y.max(axis=1).values + not_y = torch.unsqueeze(not_y, dim=1) + seg = torch.cat((not_y, y), dim=1).float() + print(f"seg1sum: {seg.sum()}") + np.save("seg2.npy", seg.cpu().numpy()) + + self.prob_unet.forward(x, seg=seg) + + loss = self.prob_unet.loss(seg) + print(f"VAL LOSS: {loss}") + + py = self.prob_unet.sample(testing=True) + py = py.to("cpu") + np.save("pred.npy", py.numpy()) + + pred_y = torch.zeros(py[:, 0, :].shape).int() + for b in range(py.shape[0]): + pred_y[b] = py[b, :].argmax(0).int() + + y = y.squeeze(1) + y = y.int() + y = y.to("cpu") + + # TODO Make this work for multi class + # Intersection over Union (also known as Jaccard Index) + jaccard = JaccardIndex(num_classes=2) + term_1 = 0 + for i in range(n): + for j in range(m): + if pred_y[i].sum() + y[j].sum() == 0: + continue + iou = jaccard(pred_y[i], y[j]) + term_1 += 1 - iou + term_1 = term_1 * (2 / (m * n)) + + term_2 = 0 + for i in range(n): + for j in range(n): + if pred_y[i].sum() + 
pred_y[j].sum() == 0: + continue + iou = jaccard(pred_y[i], pred_y[j]) + term_2 += 1 - iou + term_2 = term_2 * (1 / (n * n)) + + term_3 = 0 + for i in range(m): + for j in range(m): + if y[i].sum() + y[j].sum() == 0: + continue + iou = jaccard(y[i], y[j]) + term_3 += 1 - iou + term_3 = term_3 * (1 / (m * m)) + + D_ged = term_1 - term_2 - term_3 + + contours = {} + for o in range(n): + obs_y = y[o].float() + if self.hparams.ndims == 2: + obs_y = obs_y.unsqueeze(0) + contours[f"obs_{o}"] = sitk.GetImageFromArray(obs_y) + for mm in range(m): + samp_pred = pred_y[mm].float() + if self.hparams.ndims == 2: + samp_pred = samp_pred.unsqueeze(0) + contours[f"sample_{mm}"] = sitk.GetImageFromArray(samp_pred) + + vis.add_contour(contours, colormap=matplotlib.colormaps.get_cmap("cool")) + vis.show() + + figure_path = f"ged_{info['z'][s]}.png" + plt.savefig(figure_path, dpi=300) + plt.close("all") + + try: + self.logger.experiment.log_image(figure_path) + except AttributeError: + # Likely offline mode + pass + + self.log("GED", D_ged) + + return info + + def validation_epoch_end(self, validation_step_outputs): + cases = {} + for info in validation_step_outputs: + for case, z, observer in zip(info["case"], info["z"], info["observer"]): + if not case in cases: + cases[case] = {"slices": z.item(), "observers": [observer]} + else: + if z.item() > cases[case]["slices"]: + cases[case]["slices"] = z.item() + if not observer in cases[case]["observers"]: + cases[case]["observers"].append(observer) + + metrics = ["DSC", "HD", "ASD"] + computed_metrics = { + **{ + f"probnet_{s}_{m}": [] for m in metrics for s in self.hparams.structures + }, + **{f"unet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, + } + + if len(cases) == 0: + return + + prob_surface_dice = 0 + prob_dice = 0 + + for case in cases: + img_arrs = [] + cmap_arrs = [] + cmap_arr = None + slices = [] + + if self.hparams.ndims == 2: + for z in range(cases[case]["slices"] + 1): + img_file = 
self.validation_directory.joinpath(f"img_{case}_{z}.npy") + if img_file.exists(): + img_arrs.append(np.load(img_file)) + slices.append(z) + + cmap_file = self.validation_directory.joinpath( + f"cmap_{case}_{z}.npy" + ) + if cmap_file.exists(): + cmap_arrs.append(np.load(cmap_file)) + + img_arr = np.stack(img_arrs) + + if len(cmap_arrs) > 0: + cmap_arr = np.stack(cmap_arr) + + else: + img_file = self.validation_directory.joinpath(f"img_{case}_0.npy") + img_arr = np.load(img_file) + + cmap_file = self.validation_directory.joinpath(f"cmap_{case}_0.npy") + if cmap_file.exists(): + cmap_arr = np.load(cmap_file) + + img = sitk.GetImageFromArray(img_arr) + img.SetSpacing(self.hparams.spacing) + + observers = {} + for _, observer in enumerate(cases[case]["observers"]): + if self.hparams.ndims == 2: + mask_arrs = [] + for z in slices: + mask_file = self.validation_directory.joinpath( + f"mask_{case}_{z}_{observer}.npy" + ) + + mask_arrs.append(np.load(mask_file)) + + mask_arr = np.stack(mask_arrs, axis=1) + + else: + mask_file = self.validation_directory.joinpath( + f"mask_{case}_{z}_{observer}.npy" + ) + mask_arr = np.load(mask_file) + + observers[f"manual_{observer}"] = {} + for idx, structure in enumerate(self.hparams.structures): + mask = sitk.GetImageFromArray(mask_arr[idx]) + mask = sitk.Cast(mask, sitk.sitkUInt8) + mask.CopyInformation(img) + observers[f"manual_{observer}"][structure] = mask + + context_map = None + if cmap_arr is not None: + context_map = sitk.GetImageFromArray(cmap_arr) + context_map.SetSpacing(self.hparams.spacing) + + seg = None + if self.use_structure_context: + # TODO choose the observer to pass properly + seg = observers[f"manual_{observer}"][structure] + + try: + mean = self.infer( + img, + context_map=context_map, + seg=seg, + num_samples=1, + sample_strategy="mean", + preprocess=False, + ) + samples = self.infer( + img, + context_map=context_map, + seg=seg, + sample_strategy="spaced", + num_samples=5, + spaced_range=[-2, 2], + 
preprocess=False, + ) + except Exception as e: + print(f"ERROR DURING VALIDATION INFERENCE: {e}") + return + + + # try: + result, fig = self.validate( + img, observers, samples, mean, matching_type="best" + ) + # except Exception as e: + # print(f"ERROR DURING VALIDATION VALIDATE: {e}") + # return + + figure_path = f"valid_{case}.png" + fig.savefig(figure_path, dpi=300) + plt.close("all") + + try: + self.logger.experiment.log_image(figure_path) + except AttributeError: + # Likely offline mode + pass + + for t in result: + for m in metrics: + computed_metrics[f"{t}_{m}"] += result[t][m] + + # Compute the probabilistic (surface) dice + for idx, structure in enumerate(self.hparams.structures): + gt_labels = [] + for _, observer in enumerate(cases[case]["observers"]): + gt_labels.append(observers[f"manual_{observer}"][structure]) + + sample_labels = [] + for rk in samples: + sample_labels.append(samples[rk][structure]) + + prob_dice += probabilistic_dice( + gt_labels, sample_labels, dsc_type="dsc" + ) + prob_surface_dice += probabilistic_dice( + gt_labels, sample_labels, dsc_type="sdsc", tau=3 + ) + + prob_dice = prob_dice / len(cases) + if np.isnan(prob_dice): + prob_dice = 0 + self.log( + "probabilisticDice", + prob_dice, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + prob_surface_dice = prob_surface_dice / len(cases) + if np.isnan(prob_surface_dice): + prob_surface_dice = 0 + self.log( + "probabilisticSurfaceDice", + prob_surface_dice, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + if self.kl_div: + p = u = 0 + for s in self.hparams.structures: + p += np.array(computed_metrics[f"probnet_{s}_DSC"]).mean() + u += np.array(computed_metrics[f"unet_{s}_DSC"]).mean() + + p /= len(self.hparams.structures) + u /= len(self.hparams.structures) + computed_metrics["scaled_DSC"] = ((p + u) / 2) + (p - u) - self.kl_div + + for cm in computed_metrics: + self.log( + cm, + np.array(computed_metrics[cm]).mean(), + on_step=False, 
+ on_epoch=True, + prog_bar=False, + logger=True, + ) + + # shutil.rmtree(self.validation_directory) + + +def main(args, config_json_path=None): + pl.seed_everything(args.seed, workers=True) + + args.working_dir = Path(args.working_dir) + args.working_dir = args.working_dir.joinpath(args.experiment) + # args.default_root_dir = str(args.working_dir) + args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") + args.default_root_dir = str(args.fold_dir) + args.accumulate_grad_batches = {0: 5, 5: 10, 10: 15} + # args.precision = 16 + + comet_api_key = None + comet_workspace = None + comet_project = None + + if args.comet_api_key: + comet_api_key = args.comet_api_key + comet_workspace = args.comet_workspace + comet_project = args.comet_project + + if comet_api_key is None: + if "COMET_API_KEY" in os.environ: + comet_api_key = os.environ["COMET_API_KEY"] + if "COMET_WORKSPACE" in os.environ: + comet_workspace = os.environ["COMET_WORKSPACE"] + if "COMET_PROJECT" in os.environ: + comet_project = os.environ["COMET_PROJECT"] + + if comet_api_key is not None: + comet_logger = CometLogger( + api_key=comet_api_key, + workspace=comet_workspace, + project_name=comet_project, + experiment_name=args.experiment, + save_dir=args.working_dir, + offline=args.offline, + ) + if config_json_path: + comet_logger.experiment.log_code(config_json_path) + + dict_args = vars(args) + + data_module = UNetDataModule(**dict_args) + + prob_unet = ProbUNet(**dict_args) + + if args.resume_from is not None: + trainer = pl.Trainer(resume_from_checkpoint=args.resume_from) + else: + trainer = pl.Trainer.from_argparse_args(args) + + if comet_api_key is not None: + trainer.logger = comet_logger + + lr_monitor = LearningRateMonitor(logging_interval="step") + trainer.callbacks.append(lr_monitor) + + # Save the best model + if args.checkpoint_var: + checkpoint_callback = ModelCheckpoint( + monitor=args.checkpoint_var, + dirpath=args.default_root_dir, + filename="probunet-{epoch:02d}-{" + 
args.checkpoint_var + ":.2f}", + save_top_k=1, + mode=args.checkpoint_mode, + ) + trainer.callbacks.append(checkpoint_callback) + + if args.early_stopping_var: + early_stop_callback = GECOEarlyStopping( + monitor=args.early_stopping_var, + min_delta=args.early_stopping_min_delta, + patience=args.early_stopping_patience, + verbose=True, + mode=args.early_stopping_mode, + ) + trainer.callbacks.append(early_stop_callback) + + trainer.fit(prob_unet, data_module) + + +def parse_config_file(config_json_path, args): + with open(config_json_path, "r") as f: + params = json.load(f) + for key in params: + args.append(f"--{key}") + + if isinstance(params[key], list): + for list_val in params[key]: + args.append(str(list_val)) + else: + args.append(str(params[key])) + + return args + + +if __name__ == "__main__": + args = None + config_json_path = None + if len(sys.argv) == 2: + # Check if JSON file parsed, if so read arguments from there... + if sys.argv[-1].endswith(".json"): + config_json_path = sys.argv[-1] + args = parse_config_file(config_json_path, []) + + arg_parser = ArgumentParser() + arg_parser = ProbUNet.add_model_specific_args(arg_parser) + arg_parser = UNetDataModule.add_model_specific_args(arg_parser) + arg_parser = pl.Trainer.add_argparse_args(arg_parser) + arg_parser.add_argument( + "--config", type=str, default=None, help="JSON file with parameters to load" + ) + arg_parser.add_argument( + "--seed", type=int, default=42, help="an integer to use as seed" + ) + arg_parser.add_argument( + "--experiment", type=str, default="default", help="Name of experiment" + ) + arg_parser.add_argument("--working_dir", type=str, default="./working") + arg_parser.add_argument("--num_observers", type=int, default=5) + arg_parser.add_argument("--spacing", nargs="+", type=float, default=[1, 1, 1]) + arg_parser.add_argument("--offline", type=bool, default=False) + arg_parser.add_argument("--comet_api_key", type=str, default=None) + arg_parser.add_argument("--comet_workspace", 
type=str, default=None) + arg_parser.add_argument("--comet_project", type=str, default=None) + arg_parser.add_argument("--resume_from", type=str, default=None) + arg_parser.add_argument("--early_stopping_var", type=str, default=None) + arg_parser.add_argument("--early_stopping_min_delta", type=float, default=0.01) + arg_parser.add_argument("--early_stopping_patience", type=int, default=50) + arg_parser.add_argument("--early_stopping_mode", type=str, default="max") + arg_parser.add_argument("--checkpoint_var", type=str, default=None) + arg_parser.add_argument("--checkpoint_mode", type=str, default="max") + + parsed_args = arg_parser.parse_args(args) + + # Check if config arg parsed, if so take over values and reparse + if parsed_args.config: + print("parseing args") + args = parse_config_file(parsed_args.config, sys.argv[1:]) + parsed_args = arg_parser.parse_args(args) + + main(parsed_args) From 809d501011c3887e1e20f33891e972cf378b7d5c Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 13 May 2024 17:01:06 -0500 Subject: [PATCH 261/264] Prep for multi input experiment --- platipy/imaging/cnn/prob_unet.py | 4 ++++ platipy/imaging/cnn/train.py | 15 +++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 087fe65f..5b2b8a10 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -420,7 +420,11 @@ def reconstruction_loss( z_posterior = self.posterior_latent_space.rsample() reconstruction = self.reconstruct(use_posterior_mean=False, z_posterior=z_posterior) + + criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") + loss = criterion(input=reconstruction, target=segm) + return loss, None, None ##### num_classes = reconstruction.shape[1] y_flat = torch.transpose(reconstruction, 1, -1).reshape((-1, num_classes)) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 4a83d440..7d1f96f8 100644 --- 
a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -185,6 +185,15 @@ def forward(self, x): return x def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.hparams.learning_rate, weight_decay=0 + ) + lr_lambda_unet = lambda epoch: self.hparams.lr_lambda ** (epoch) + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, lr_lambda=[lr_lambda_unet] + ) + + return [optimizer], [scheduler] params = [ { "params": self.prob_unet.unet.parameters(), @@ -337,7 +346,7 @@ def infer( if seg is not None: seg = resample_mask_to_image(img, seg) - seg_arr = sitk.GetArrayFromImage(img) + seg_arr = sitk.GetArrayFromImage(seg) if self.hparams.ndims == 2: slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])] @@ -586,7 +595,7 @@ def training_step(self, batch, _): # Concat context map to image if we have one if c.numel() > 0: - x = torch.cat((x, c), dim=1) + x = torch.cat((x, c), dim=1).float() # self.prob_unet.forward(x, y, training=True) if self.hparams.prob_type == "prob": @@ -691,6 +700,8 @@ def validation_step(self, batch, _): seg = torch.cat((not_y, y), dim=1).float() self.prob_unet.forward(x, seg=seg) + loss = self.prob_unet.loss(seg) + print(f"VAL LOSS: {loss}") py = self.prob_unet.sample(testing=True) py = py.to("cpu") From 4a8e372a4e493148902b2d00663a1e9fe163e6d7 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Tue, 14 May 2024 17:40:16 -0500 Subject: [PATCH 262/264] staple mask for validation --- platipy/imaging/cnn/prob_unet.py | 6 +++--- platipy/imaging/cnn/train.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 5b2b8a10..47465a49 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -421,10 +421,10 @@ def reconstruction_loss( reconstruction = self.reconstruct(use_posterior_mean=False, z_posterior=z_posterior) - criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") 
+ #criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") - loss = criterion(input=reconstruction, target=segm) - return loss, None, None + # loss = criterion(input=reconstruction, target=segm) + # return loss, None, None ##### num_classes = reconstruction.shape[1] y_flat = torch.transpose(reconstruction, 1, -1).reshape((-1, num_classes)) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 7d1f96f8..b5468135 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -869,8 +869,15 @@ def validation_epoch_end(self, validation_step_outputs): seg = None if self.use_structure_context: - # TODO choose the observer to pass properly - seg = observers[f"manual_{observer}"][structure] + # Staple the man observers to pass in as context seg + masks = [] + for man_obs in observers: + masks.append(observers[man_obs][structure]) + + stapled = sitk.STAPLE(masks) + stapled = stapled > 0.5 + stapled = sitk.Cast(stapled, sitk.sitkUInt8) + seg = stapled try: mean = self.infer( From de6628eccc793d6c22488a814d8739e19c42dbf5 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 27 Sep 2024 08:54:44 +1000 Subject: [PATCH 263/264] Allow providing context structure --- platipy/imaging/cnn/dataset.py | 38 ++++-- platipy/imaging/cnn/localise_net.py | 24 +++- platipy/imaging/cnn/prob_unet.py | 54 ++++---- platipy/imaging/cnn/train.py | 182 +++++++++++++++----------- platipy/imaging/cnn/train_localise.py | 1 + 5 files changed, 185 insertions(+), 114 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index 4b5ddf73..ad5c1180 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -428,17 +428,21 @@ def __init__( assert contour_mask_file.exists() contour_mask_files.append(contour_mask_file) - self.slices.append( - { + for oo in case["observers"]: + self.slices.append( + { "z": z_slice, "image": img_file, "labels": labels, + "complabels": [self.label_dir.joinpath( + 
f"{case_id}_{structure}_{oo}_{z_slice}.npy" + ) for structure in case["observers"][obs]], "contour_masks": contour_mask_files, "context_map": cmap_file, "case": case_id, "observer": obs, - } - ) + } + ) continue @@ -581,17 +585,22 @@ def __init__( ) labels.append(label_file) - self.slices.append( - { + # TODO allow enabling this + for oo in observers: + self.slices.append( + { "z": z_slice, "image": img_file, "labels": labels, + "complabels": [self.label_dir.joinpath( + f"{case_id}_{structure}_{oo}_{z_slice}.npy" + ) for structure in structure_names], "contour_masks": cmasks, "context_map": cmap_file, "case": case_id, "observer": obs, - } - ) + } + ) def __len__(self): return len(self.slices) @@ -602,6 +611,10 @@ def __getitem__(self, index): np.load(label_file) if label_file else np.zeros(img.shape, dtype=np.ushort) for label_file in self.slices[index]["labels"] ] + complabels = [ + np.load(label_file) if label_file else np.zeros(img.shape, dtype=np.ushort) + for label_file in self.slices[index]["complabels"] + ] contour_masks = [ np.load(contour_mask_file) for contour_mask_file in self.slices[index]["contour_masks"] @@ -614,7 +627,7 @@ def __getitem__(self, index): context_map = np.load(self.slices[index]["context_map"]) if self.transforms: - masks = labels + contour_masks + masks = labels + complabels + contour_masks if self.ndims == 2: seg_arr = np.concatenate([np.expand_dims(m, 2) for m in masks], 2) segmap = SegmentationMapsOnImage(seg_arr, shape=labels[0].shape) @@ -637,7 +650,8 @@ def __getitem__(self, index): else: img, _, masks = aug.apply(img, None, masks) labels = masks[: len(labels)] - contour_masks = masks[len(contour_masks) :] + complabels = masks[len(labels) : len(complabels) + len(labels)] + contour_masks = masks[len(labels) + len(complabels) : ] img = torch.FloatTensor(img) img = img.unsqueeze(0) @@ -649,6 +663,9 @@ def __getitem__(self, index): label = torch.FloatTensor( np.concatenate([np.expand_dims(l, 0) for l in labels], 0).astype("int8") ) + 
complabel = torch.FloatTensor( + np.concatenate([np.expand_dims(l, 0) for l in complabels], 0).astype("int8") + ) contour_mask = torch.FloatTensor( np.concatenate([np.expand_dims(l, 0) for l in contour_masks], 0).astype( "int8" @@ -661,6 +678,7 @@ def __getitem__(self, index): img, context_map, label, + complabel, contour_mask, { "case": str(self.slices[index]["case"]), diff --git a/platipy/imaging/cnn/localise_net.py b/platipy/imaging/cnn/localise_net.py index 82c85ec5..511bae5e 100644 --- a/platipy/imaging/cnn/localise_net.py +++ b/platipy/imaging/cnn/localise_net.py @@ -88,6 +88,7 @@ def infer(self, img): x = torch.Tensor(x) x = x.unsqueeze(0) x = x.unsqueeze(0) + x = x.to(self.device) y = self(x) y = y.squeeze(0) y = np.argmax(y.cpu().detach().numpy(), axis=0) @@ -104,11 +105,8 @@ def infer(self, img): def training_step(self, batch, _): - x, y, _, _ = batch + x, c, y, m, _ = batch - pred = self.unet.forward(x) - - criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") # Take the max of all structure to combine into one big structure to localise y = y.max(axis=1).values @@ -118,6 +116,12 @@ def training_step(self, batch, _): not_y = y.logical_not() y = torch.cat((not_y, y), dim=1).float() + x = torch.cat((x, y), dim=1) + + pred = self.unet.forward(x) + + criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") + loss = criterion(input=pred, target=y) return loss @@ -129,15 +133,23 @@ def validation_step(self, batch, _): print(self.validation_directory) with torch.set_grad_enabled(False): - x, y, _, info = batch + # x, y, _, info = batch + x, c, y, m, info = batch + y = y.max(axis=1).values + yy = torch.unsqueeze(y, dim=1) + + not_y = yy.logical_not() + yy = torch.cat((not_y, yy), dim=1).float() + + x = torch.cat((x, yy), dim=1) for s in range(y.shape[0]): img_file = self.validation_directory.joinpath( f"img_{info['case'][s]}_{info['z'][s]}.npy" ) - np.save(img_file, x[s].squeeze(0).cpu().numpy()) + np.save(img_file, x[s].squeeze(0)[0].cpu().numpy()) mask_file 
= self.validation_directory.joinpath( f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 47465a49..219a24df 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -237,13 +237,17 @@ def __init__( dropout_probability=dropout_probability, ndims=ndims, ) - self.prior = None - if not use_structure_context: - self.prior = AxisAlignedConvGaussian( - input_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 - ) + + self.prior = AxisAlignedConvGaussian( + unet_input_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 + ) + + post_channels = input_channels + num_classes + if use_structure_context: + post_channels = input_channels + (num_classes * 2) + self.posterior = AxisAlignedConvGaussian( - input_channels + num_classes, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 + post_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 ) self.fcomb = Fcomb(filters_per_layer, latent_dim, num_classes, no_convs_fcomb, ndims=ndims) @@ -261,23 +265,22 @@ def __init__( self.register_buffer("_pos_weight", torch.ones(num_classes, requires_grad=False)) - def forward(self, img, seg=None, training=False): + def forward(self, img, seg=None, cseg=None, training=False): """ Construct prior latent space for patch and run patch through UNet, in case training is True also construct posterior latent space """ - if training or self.prior is None: - self.posterior_latent_space = self.posterior.forward(img, seg=seg) - - self.prior_latent_space = None - if self.prior is not None: - self.prior_latent_space = self.prior.forward(img) if self.use_structure_context: - if seg is None: - raise ValueError("Structure context is enabled, but no segmentation mask provided") - - img = torch.cat((img, seg), dim=1) + if cseg is None: + raise ValueError("Structure context is enabled, but 
no context segmentation mask provided") + + img = torch.cat((img, cseg), dim=1) + + if training: + self.posterior_latent_space = self.posterior.forward(img, seg=seg) + + self.prior_latent_space = self.prior.forward(img) self.unet_features = self.unet.forward(img) @@ -287,7 +290,7 @@ def sample(self, testing=False, use_mean=False, sample_x_stddev_from_mean=None): and combining this with UNet features """ - latent_space = self.prior_latent_space if self.prior is not None else self.posterior_latent_space + latent_space = self.prior_latent_space if testing: if use_mean: @@ -338,13 +341,14 @@ def kl_divergence(self): Calculate the KL divergence between the posterior and prior KL(Q||P) """ - if self.prior_latent_space is None: - - device = self.posterior_latent_space.base_dist.stddev.device - dist = Independent(Normal(loc=torch.zeros(self.latent_dim).to(device), scale=torch.ones(self.latent_dim).to(device)), 1) - kl_div = kl.kl_divergence(self.posterior_latent_space, dist) - else: - kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) + #if self.prior_latent_space is None: + # device = self.posterior_latent_space.base_dist.stddev.device + # dist = Independent(Normal(loc=torch.zeros(self.latent_dim).to(device), scale=torch.ones(self.latent_dim).to(device)), 1) + # kl_div = kl.kl_divergence(self.posterior_latent_space, dist) + #else: + print(self.posterior_latent_space) + print(self.prior_latent_space) + kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) return kl_div diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index b5468135..b30b793e 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -1,4 +1,4 @@ -# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute +# 2021 University of New South Wales, University of Sydney, Ingham Institute # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except 
in compliance with the License. @@ -137,7 +137,7 @@ def __init__( self.validation_directory = None self.kl_div = None - self.stddevs = np.linspace(-3, 3, self.hparams.num_observers) + self.stddevs = np.linspace(-2, 2, self.hparams.num_observers) @staticmethod def add_model_specific_args(parent_parser): @@ -185,20 +185,34 @@ def forward(self, x): return x def configure_optimizers(self): - optimizer = torch.optim.Adam( - self.parameters(), lr=self.hparams.learning_rate, weight_decay=0 - ) - lr_lambda_unet = lambda epoch: self.hparams.lr_lambda ** (epoch) - scheduler = torch.optim.lr_scheduler.LambdaLR( - optimizer, lr_lambda=[lr_lambda_unet] - ) + #optimizer = torch.optim.Adam( + # self.parameters(), lr=self.hparams.learning_rate, weight_decay=0 + #) + #lr_lambda_unet = lambda epoch: self.hparams.lr_lambda ** (epoch) +# scheduler = torch.optim.lr_scheduler.LambdaLR( +# optimizer, lr_lambda=[lr_lambda_unet] +# ) + + #scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + # optimizer, 50, eta_min=1e-5, verbose=True + #) + + #scheduler = torch.optim.lr_scheduler.CyclicLR( + # optimizer, + # base_lr=self.hparams.learning_rate / 10, + # max_lr=self.hparams.learning_rate * 10, + # step_size_up=50, + # mode="exp_range", + # gamma=0.9999, + # cycle_momentum=False + #) - return [optimizer], [scheduler] + #return [optimizer], [scheduler] params = [ { "params": self.prob_unet.unet.parameters(), "weight_decay": self.hparams.weight_decay, - "lr": 1e-4, + "lr": 1e-5, } ] @@ -229,18 +243,18 @@ def configure_optimizers(self): # scheduler = torch.optim.lr_scheduler.LambdaLR( # optimizer, lr_lambda=[lr_lambda_unet, lr_lambda_prob, lr_lambda_prob, lr_lambda_prob] # ) - # scheduler = torch.optim.lr_scheduler.CyclicLR( - # optimizer, - # base_lr=self.hparams.learning_rate / 10, - # max_lr=self.hparams.learning_rate, - # step_size_up=20, - # mode="exp_range", - # gamma=0.999, - # cycle_momentum=False - # ) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, 50, 
eta_min=1e-6, verbose=True + scheduler = torch.optim.lr_scheduler.CyclicLR( + optimizer, + base_lr=self.hparams.learning_rate, + max_lr=self.hparams.learning_rate * 10, + step_size_up=50, + mode="exp_range", + gamma=0.99, + cycle_momentum=False ) + # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + # optimizer, 50, eta_min=1e-6, verbose=True + # ) return [optimizer], [scheduler] @@ -347,7 +361,7 @@ def infer( if seg is not None: seg = resample_mask_to_image(img, seg) seg_arr = sitk.GetArrayFromImage(seg) - + if self.hparams.ndims == 2: slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])] @@ -388,29 +402,29 @@ def infer( if self.hparams.prob_type == "prob": if seg is not None: - self.prob_unet.forward(x, seg=s) + self.prob_unet.forward(x, cseg=s) else: self.prob_unet.forward(x) for sample in samples: if self.hparams.prob_type == "prob": if sample["name"] == "mean": - if seg is None: - y = self.prob_unet.sample(testing=True, use_mean=True) - else: - y = self.prob_unet.reconstruct(use_posterior_mean=True) +# if seg is None: + y = self.prob_unet.sample(testing=True, use_mean=True) +# else: +# y = self.prob_unet.reconstruct(use_posterior_mean=True) else: - if seg is None: - y = self.prob_unet.sample( - testing=True, - use_mean=False, - sample_x_stddev_from_mean=sample["std_dev_from_mean"], - ) - else: - y = self.prob_unet.reconstruct( - use_posterior_mean=False, - sample_x_stddev_from_mean=sample["std_dev_from_mean"], - ) +# if seg is None: + y = self.prob_unet.sample( + testing=True, + use_mean=False, + sample_x_stddev_from_mean=sample["std_dev_from_mean"], + ) + # else: + # y = self.prob_unet.reconstruct( + # use_posterior_mean=False, + # sample_x_stddev_from_mean=sample["std_dev_from_mean"], + # ) # else: # if sample["name"] == "mean": @@ -493,12 +507,12 @@ def validate( for man_struct in manual_observers } - vis.add_contour( - manual_observers_struct, - color=manual_color, - linewidth=0.5, - show_legend=False, - ) +# vis.add_contour( +# 
manual_observers_struct, +# color=manual_color, +# linewidth=0.5, +# show_legend=False, +# ) intersection_mask = get_intersection_mask(manual_observers_struct) union_mask = get_union_mask(manual_observers_struct) @@ -586,20 +600,26 @@ def validate( return result, fig def training_step(self, batch, _): - x, c, y, m, _ = batch + x, c, y, cy, m, _ = batch # Add background layer for one-hot encoding not_y = 1 - y.max(axis=1).values not_y = torch.unsqueeze(not_y, dim=1) y = torch.cat((not_y, y), dim=1).float() + not_cy = 1 - cy.max(axis=1).values + not_cy = torch.unsqueeze(not_cy, dim=1) + cy = torch.cat((not_cy, cy), dim=1).float() + # Concat context map to image if we have one if c.numel() > 0: x = torch.cat((x, c), dim=1).float() + print(f"{y.shape} {cy.shape}") + # self.prob_unet.forward(x, y, training=True) if self.hparams.prob_type == "prob": - self.prob_unet.forward(x, y, training=True) + self.prob_unet.forward(x, y, cy, training=True) # else: # self.prob_unet.forward(x, y) @@ -650,7 +670,7 @@ def validation_step(self, batch, _): m = self.hparams.num_observers with torch.set_grad_enabled(False): - x, c, y, _, info = batch + x, c, y, cy, _, info = batch # Save off slices/volumes for analysis of entire structure in end of validation step for s in range(y.shape[0]): @@ -699,9 +719,9 @@ def validation_step(self, batch, _): not_y = torch.unsqueeze(not_y, dim=1) seg = torch.cat((not_y, y), dim=1).float() - self.prob_unet.forward(x, seg=seg) - loss = self.prob_unet.loss(seg) - print(f"VAL LOSS: {loss}") + self.prob_unet.forward(x, cseg=seg) + # loss = self.prob_unet.loss(seg) + # print(f"VAL LOSS: {loss}") py = self.prob_unet.sample(testing=True) py = py.to("cpu") @@ -714,6 +734,11 @@ def validation_step(self, batch, _): y = y.int() y = y.to("cpu") + + cy = cy.squeeze(1) + cy = cy.int() + cy = cy.to("cpu") + # TODO Make this work for multi class # Intersection over Union (also known as Jaccard Index) jaccard = JaccardIndex(num_classes=2) @@ -747,18 +772,29 @@ def 
validation_step(self, batch, _): D_ged = term_1 - term_2 - term_3 contours = {} + contour_colors = {} for o in range(n): obs_y = y[o].float() if self.hparams.ndims == 2: obs_y = obs_y.unsqueeze(0) contours[f"obs_{o}"] = sitk.GetImageFromArray(obs_y) + contour_colors[f"obs_{o}"] = (0.3, 0.6, 0.3) for mm in range(m): samp_pred = pred_y[mm].float() if self.hparams.ndims == 2: samp_pred = samp_pred.unsqueeze(0) contours[f"sample_{mm}"] = sitk.GetImageFromArray(samp_pred) + contour_colors[f"sample_{mm}"] = (0.1, 0.1, 0.8) - vis.add_contour(contours, colormap=matplotlib.colormaps.get_cmap("cool")) + if self.use_structure_context: + for o in range(n): + obs_y = cy[o].float() + if self.hparams.ndims == 2: + obs_y = obs_y.unsqueeze(0) + contours[f"compobs_{o}"] = sitk.GetImageFromArray(obs_y) + contour_colors[f"compobs_{o}"] = (0.6, 0.3, 0.3) + + vis.add_contour(contours, color=contour_colors) vis.show() figure_path = f"ged_{info['z'][s]}.png" @@ -879,27 +915,27 @@ def validation_epoch_end(self, validation_step_outputs): stapled = sitk.Cast(stapled, sitk.sitkUInt8) seg = stapled - try: - mean = self.infer( - img, - context_map=context_map, - seg=seg, - num_samples=1, - sample_strategy="mean", - preprocess=False, - ) - samples = self.infer( - img, - context_map=context_map, - seg=seg, - sample_strategy="spaced", - num_samples=5, - spaced_range=[-2, 2], - preprocess=False, - ) - except Exception as e: - print(f"ERROR DURING VALIDATION INFERENCE: {e}") - return +# try: + mean = self.infer( + img, + context_map=context_map, + seg=seg, + num_samples=1, + sample_strategy="mean", + preprocess=False, + ) + samples = self.infer( + img, + context_map=context_map, + seg=seg, + sample_strategy="spaced", + num_samples=11, + spaced_range=[-2, 2], + preprocess=False, + ) +# except Exception as e: +# print(f"ERROR DURING VALIDATION INFERENCE: {e}") +# return # try: diff --git a/platipy/imaging/cnn/train_localise.py b/platipy/imaging/cnn/train_localise.py index af2a3d6c..97b97f91 100644 --- 
a/platipy/imaging/cnn/train_localise.py +++ b/platipy/imaging/cnn/train_localise.py @@ -109,6 +109,7 @@ def main(args, config_json_path=None): params = json.load(f) args = [] for key in params: + print(key) args.append(f"--{key}") if isinstance(params[key], list): From 2b9b5e0fa9d9af3572eb62405f726ca237b898a5 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Sun, 17 Nov 2024 08:55:39 +1100 Subject: [PATCH 264/264] Context based inference --- platipy/imaging/cnn/dataset.py | 21 +++++++++++++-------- platipy/imaging/cnn/train.py | 5 +++++ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index ad5c1180..ccfd211b 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -235,7 +235,7 @@ def apply(self, img, context_map, masks=[]): def crop_img_using_localise_model( - img, localise_model, spacing=[1, 1, 1], crop_to_grid_size=[100, 100, 100] + img, localise_model, spacing=[1, 1, 1], crop_to_grid_size=[100, 100, 100], context_seg=None ): """Crops an image using a LocaliseUNet @@ -246,6 +246,8 @@ def crop_img_using_localise_model( spacing (list, optional): The image spacing (mm) to resample to. Defaults to [1,1,1]. crop_to_grid_size (list, optional): The size of the grid to crop to. Defaults to [100,100,100]. + context_seg (sitk.Image, optional): Use this segmentation instead of localise model if + provided. Defaults to None. Returns: SimpleITK.Image: The cropped image. 
@@ -254,15 +256,18 @@ def crop_img_using_localise_model( if isinstance(localise_model, str): localise_model = Path(localise_model) - if isinstance(localise_model, Path): - if localise_model.is_dir(): - # Find the first actual model checkpoint in this directory - localise_model = next(localise_model.glob("*.ckpt")) + if context_seg is not None: + localise_pred = context_seg + else: + if isinstance(localise_model, Path): + if localise_model.is_dir(): + # Find the first actual model checkpoint in this directory + localise_model = next(localise_model.glob("*.ckpt")) - localise_model = LocaliseUNet.load_from_checkpoint(localise_model) + localise_model = LocaliseUNet.load_from_checkpoint(localise_model) - localise_model.eval() - localise_pred = localise_model.infer(img) + localise_model.eval() + localise_pred = localise_model.infer(img) img = preprocess_image(img, spacing=spacing, crop_to_grid_size_xy=None) localise_pred = resample_mask_to_image(img, localise_pred) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index b30b793e..7d865297 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -283,6 +283,7 @@ def infer( latent_dim=True, spaced_range=[-1.5, 1.5], preprocess=True, + return_latent_space=False ): # sample strategy in "mean", "random", "spaced" @@ -340,6 +341,7 @@ def infer( localise_path, spacing=self.hparams.spacing, crop_to_grid_size=self.hparams.localise_voxel_grid_size, + context_seg=seg ) else: img = preprocess_image( @@ -406,6 +408,9 @@ def infer( else: self.prob_unet.forward(x) + if return_latent_space: + return self.prob_unet.prior_latent_space + for sample in samples: if self.hparams.prob_type == "prob": if sample["name"] == "mean":