From d24816d657e82f409baf40e9fd390c491d959ba0 Mon Sep 17 00:00:00 2001
From: Imraj-Singh <72553490+Imraj-Singh@users.noreply.github.com>
Date: Tue, 30 Jul 2024 11:00:29 +0000
Subject: [PATCH] greedy multiple iteration learning to optimise

---
 adam.py                      |   4 +-
 greedy_multiple_iteration.py | 268 +++++++++++++++++
 petric.py                    | 566 +++++++++++++++++------------------
 3 files changed, 553 insertions(+), 285 deletions(-)
 create mode 100644 greedy_multiple_iteration.py

diff --git a/adam.py b/adam.py
index cc5d836..15cf978 100644
--- a/adam.py
+++ b/adam.py
@@ -59,8 +59,8 @@ def __init__(self, data, initial, initial_step_size, relaxation_eta,
         self.beta2 = 0.999
         self.eps_adam = 1e-8
 
-        self.m = initial.copy()
-        self.m.fill(np.zeros_like(initial.as_array()))
+        self.m = initial.get_uniform_copy(0)
+        # self.m.fill(np.zeros_like(initial.as_array()))
 
         self.m_hat = initial.copy()
         self.m_hat.fill(np.zeros_like(initial.as_array()))
 
diff --git a/greedy_multiple_iteration.py b/greedy_multiple_iteration.py
new file mode 100644
index 0000000..ab4cc02
--- /dev/null
+++ b/greedy_multiple_iteration.py
@@ -0,0 +1,268 @@
+# %%
+
+from pathlib import Path
+import sirf.STIR as STIR
+import numpy as np
+import logging
+import os
+from dataclasses import dataclass
+from matplotlib import pyplot as plt
+from random import shuffle
+
+
+log = logging.getLogger('petric')
+TEAM = os.getenv("GITHUB_REPOSITORY", "SyneRBI/PETRIC-").split("/PETRIC-", 1)[-1]
+VERSION = os.getenv("GITHUB_REF_NAME", "")
+OUTDIR = Path(f"/o/logs/{TEAM}/{VERSION}" if TEAM and VERSION else "./output")
+if not (SRCDIR := Path("/mnt/share/petric")).is_dir():
+    SRCDIR = Path("./data")
+
+
+def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3):
+    """
+    Construct a smoothed Relative Difference Prior (RDP)
+
+    initial_image: used to determine a smoothing factor (epsilon).
+    kappa: used to pass voxel-dependent weights.
+    """
+    prior = getattr(STIR, 'CudaRelativeDifferencePrior', STIR.RelativeDifferencePrior)()
+    # need to make it differentiable
+    epsilon = initial_image.max() * max_scaling
+    prior.set_epsilon(epsilon)
+    prior.set_penalisation_factor(penalty_strength)
+    prior.set_kappa(kappa)
+    prior.set_up(initial_image)
+    return prior
+
+
+@dataclass
+class Dataset:
+    acquired_data: STIR.AcquisitionData
+    additive_term: STIR.AcquisitionData
+    mult_factors: STIR.AcquisitionData
+    OSEM_image: STIR.ImageData
+    prior: STIR.RelativeDifferencePrior
+    kappa: STIR.ImageData
+    reference_image: STIR.ImageData | None
+    whole_object_mask: STIR.ImageData | None
+    background_mask: STIR.ImageData | None
+    voi_masks: dict[str, STIR.ImageData]
+
+
+def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
+    """
+    Load data from `srcdir`, construct the prior and return everything as a `Dataset`.
+    Also redirects sirf.STIR log output to `outdir`.
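+    sirf_verbosity: passed to STIR.set_verbosity (0 = quiet; higher values help diagnose problems).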
+    """
+    srcdir = Path(srcdir)
+    outdir = Path(outdir)
+    STIR.set_verbosity(sirf_verbosity) # set to higher value to diagnose problems
+    STIR.AcquisitionData.set_storage_scheme('memory') # needed for get_subsets()
+
+    _ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'), str(outdir / 'errors.txt'))
+    acquired_data = STIR.AcquisitionData(str(srcdir / 'prompts.hs'))
+    additive_term = STIR.AcquisitionData(str(srcdir / 'additive_term.hs'))
+    mult_factors = STIR.AcquisitionData(str(srcdir / 'mult_factors.hs'))
+    OSEM_image = STIR.ImageData(str(srcdir / 'OSEM_image.hv'))
+    kappa = STIR.ImageData(str(srcdir / 'kappa.hv'))
+    if (penalty_strength_file := (srcdir / 'penalisation_factor.txt')).is_file():
+        penalty_strength = float(np.loadtxt(penalty_strength_file))
+    else:
+        penalty_strength = 1 / 700 # default choice
+    prior = construct_RDP(penalty_strength, OSEM_image, kappa)
+
+    def get_image(fname):
+        if (source := srcdir / 'PETRIC' / fname).is_file():
+            return STIR.ImageData(str(source))
+        return None # explicit to suppress linter warnings
+
+    reference_image = get_image('reference_image.hv')
+    whole_object_mask = get_image('VOI_whole_object.hv')
+    background_mask = get_image('VOI_background.hv')
+    voi_masks = {
+        voi.stem[4:]: STIR.ImageData(str(voi))
+        for voi in (srcdir / 'PETRIC').glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'whole_object')}
+
+    return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa, reference_image,
+                   whole_object_mask, background_mask, voi_masks)
+
+
+if SRCDIR.is_dir():
+    data_dirs_metrics = [(SRCDIR / "Siemens_mMR_NEMA_IQ", OUTDIR / "mMR_NEMA"),
+                         (SRCDIR / "NeuroLF_Hoffman_Dataset", OUTDIR / "NeuroLF_Hoffman"),
+                         (SRCDIR / "Siemens_Vision600_thorax", OUTDIR / "Vision600_thorax")]
+
+dataset = "nema"
+if dataset == "nema":
+    data = get_data(srcdir=SRCDIR / "Siemens_mMR_NEMA_IQ", outdir=OUTDIR / "mMR_NEMA")
+elif dataset == "hoffman":
+    data = get_data(srcdir=SRCDIR / "NeuroLF_Hoffman_Dataset", outdir=OUTDIR / "NeuroLF_Hoffman")
+elif dataset == "thorax":
+    data = get_data(srcdir=SRCDIR / "Siemens_Vision600_thorax", outdir=OUTDIR / "Vision600_thorax")
+print("Data loaded")
+
+# %%
+from sirf.contrib.partitioner import partitioner
+num_subsets = 5
+
+# the partitioning is identical for every dataset: `num_subsets` subset
+# objectives for the unrolled updates, plus a single-subset (full) objective
+# used as the training loss
+_, _, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
+                                            data.mult_factors, num_subsets,
+                                            initial_image=data.OSEM_image,
+                                            mode="staggered")
+_, _, full_obj_fun = partitioner.data_partition(data.acquired_data, data.additive_term,
+                                                data.mult_factors, 1,
+                                                initial_image=data.OSEM_image,
+                                                mode="staggered")
+
+print("Data partitioned")
+
+
+# make dir if non-existent
+Path("unrolled_imgs").mkdir(exist_ok=True)
+# make subdir of dataset
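+# (per-iteration diagnostic images and the loss curve are saved here during training)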
+Path(f"unrolled_imgs/{dataset}").mkdir(exist_ok=True)
+dir_path = Path(f"unrolled_imgs/{dataset}")
+
+# %%
+import torch
+#torch.cuda.set_per_process_memory_fraction(0.2)
+class _SIRF_objective_wrapper(torch.autograd.Function):
+    """Expose a SIRF objective function to torch autograd.
+
+    The objective is negated (torch minimises, SIRF objectives are maximised);
+    backward evaluates the negated SIRF gradient at the stored image.
+    """
+    @staticmethod
+    def forward(ctx, x_torch, x_sirf, obj):
+        ctx.device = x_torch.device
+        ctx.dtype = x_torch.dtype
+        ctx.shape = x_torch.shape
+        x_np = x_torch.detach().cpu().squeeze().numpy()
+        x_sirf = x_sirf.fill(x_np)
+        ctx.x_sirf = x_sirf
+        ctx.obj = obj
+
+        return -torch.tensor(obj(x_sirf), device=ctx.device, dtype=ctx.dtype)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input = -torch.tensor(ctx.obj.gradient(ctx.x_sirf).as_array(), device=ctx.device,
+                                   dtype=ctx.dtype).view(ctx.shape)*grad_output
+        return grad_input, None, None
+
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = 'cpu' # NB: force CPU for now; remove this line to use the GPU selected above
+
+
+class NetworkPreconditioner(torch.nn.Module):
+    def __init__(self, n_layers=1, hidden_channels=32, kernel_size=5):
+        super(NetworkPreconditioner, self).__init__()
+        self.list_of_conv2 = torch.nn.ModuleList()
+        self.list_of_conv2.append(torch.nn.Conv2d(1, hidden_channels, kernel_size, padding='same', bias=False))
+        for _ in range(n_layers):
+            self.list_of_conv2.append(torch.nn.Conv2d(hidden_channels, hidden_channels, kernel_size, padding='same', bias=False))
+        self.list_of_conv2.append(torch.nn.Conv2d(hidden_channels, 1, kernel_size, padding='same', bias=False))
+        self.activation = torch.nn.ReLU()
+
+    def forward(self, x):
+        for layer in self.list_of_conv2[:-1]:
+            x = layer(x)
+            x = self.activation(x)
+        x = self.list_of_conv2[-1](x)
+        return x
+
+
+class DeepUnrolledPreconditioner(torch.nn.Module):
+    def __init__(self, unrolled_iterations=10, n_layers=1, hidden_channels=32, kernel_size=5, single_network=False):
+        super(DeepUnrolledPreconditioner, self).__init__()
+        self.nets = torch.nn.ModuleList()
+        self.unrolled_iterations = unrolled_iterations
+        self.single_network = single_network
+        if single_network:
+            self.nets.append(NetworkPreconditioner(n_layers, hidden_channels, kernel_size))
+        else:
+            for _ in range(unrolled_iterations):
+                self.nets.append(NetworkPreconditioner(n_layers, hidden_channels, kernel_size))
+
+    def forward(self, x, obj_funs, sirf_img, compute_upto=1, plot=False, epoch=0, update_filter=STIR.TruncateToCylinderProcessor()):
+        xs = []
+        if compute_upto > self.unrolled_iterations:
+            raise ValueError("Cannot compute more than unrolled_iterations")
+        for i in range(compute_upto):
+            if plot:
+                fig, axs = plt.subplots(1, 3, figsize=(30, 10))
+
+            tmp = obj_funs[i].gradient(sirf_img.fill(x.detach().cpu().squeeze().numpy()))
+            update_filter.apply(tmp)
+            grad = -torch.tensor(tmp.as_array(), device=device).unsqueeze(1)
+            # EM-style scaling of the (negated) subset gradient: (x + delta) / (subset sensitivity + delta)
+            grad_sens = grad * (x + 1e-3)/(torch.tensor(obj_funs[i].get_subset_sensitivity(0).as_array(), device=device).unsqueeze(1) + 1e-3)
+            if self.single_network:
+                precond = self.nets[0](grad_sens)
+            else:
+                precond = self.nets[i](grad_sens)
+            if plot:
+                fig.colorbar(axs[0].imshow(grad_sens.detach().cpu().numpy()[72, 0, :, :]), ax=axs[0])
+                axs[0].set_title("Gradient")
+                fig.colorbar(axs[1].imshow(precond.detach().cpu().numpy()[72, 0, :, :]), ax=axs[1])
+                axs[1].set_title("Preconditioner")
+            x = x - precond
+            x.clamp_(0)
+            xs.append(x)
+            if plot:
+                fig.colorbar(axs[2].imshow(x.detach().cpu().numpy()[72, 0, :, :]), ax=axs[2])
+                axs[2].set_title("Updated Image")
+                plt.savefig(f"{dir_path}/image_e{epoch}_it{i}.png")
+                plt.close()
+        return xs
+
+
+unrolled_iterations = num_subsets
+precond = DeepUnrolledPreconditioner(unrolled_iterations=unrolled_iterations, n_layers=1, hidden_channels=16, kernel_size=5, single_network=False)
+precond.to(device)
+print("Preconditioner created and moved to device")
+
+optimizer = torch.optim.Adam(precond.parameters(), lr=1e-4)
+
+data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs))
+data.prior.set_up(data.OSEM_image)
+for f in obj_funs: # add prior evenly to every objective function
+    f.set_prior(data.prior)
+
+osem_input_torch = torch.tensor(data.OSEM_image.as_array(), device=device).unsqueeze(1)
+x_sirf = data.OSEM_image.clone()
+losses = []
+for i in range(unrolled_iterations*100):
+    optimizer.zero_grad()
+    shuffle(obj_funs)
+    compute_upto = (i//100)+1 # greedy schedule: unroll one iteration deeper every 100 epochs
+    xs = precond(osem_input_torch, obj_funs, compute_upto=compute_upto, sirf_img=x_sirf, plot=True, epoch=i)
+    loss = 0
+    # score every unrolled iterate against the full (single-subset) objective
+    for loss_i in range(compute_upto):
+        loss += _SIRF_objective_wrapper.apply(xs[loss_i], x_sirf, full_obj_fun[0])
+    loss = loss/compute_upto
+    print(f"Iteration: {i}, Loss: {loss.item()}")
+    loss.backward()
+    optimizer.step()
+    losses.append(loss.item())
+    plt.imshow(xs[-1].detach().cpu().numpy()[72, 0, :, :])
+    plt.title(f"Loss: {loss.item()}")
+    plt.colorbar()
+    plt.savefig(f"{dir_path}/final_image_{i}.png")
+    plt.close()
+    plt.plot(losses)
+    plt.savefig(f"{dir_path}/losses.png")
+    plt.close()
diff --git a/petric.py b/petric.py
index 0e9ab01..9fc0c8d 100755
--- a/petric.py
+++ b/petric.py
@@ -1,283 +1,283 @@
-#!/usr/bin/env python
-"""
-ANY CHANGES TO THIS FILE ARE IGNORED BY THE ORGANISERS.
-Only the `main.py` file may be modified by participants.
-
-This file is not intended for participants to use, except for
-the `get_data` function (and possibly `QualityMetrics` class).
-It is used by the organisers to run the submissions in a controlled way.
-It is included here purely in the interest of transparency.
-
-Usage:
-  petric.py [options]
-
-Options:
-  --log LEVEL  : Set logging level (DEBUG, [default: INFO], WARNING, ERROR, CRITICAL)
-"""
-import csv
-import logging
-import os
-from dataclasses import dataclass
-from pathlib import Path
-from time import time
-from traceback import print_exc
-
-import numpy as np
-from skimage.metrics import mean_squared_error as mse
-from tensorboardX import SummaryWriter
-
-import sirf.STIR as STIR
-from cil.optimisation.algorithms import Algorithm
-from cil.optimisation.utilities import callbacks as cil_callbacks
-from img_quality_cil_stir import ImageQualityCallback
-
-log = logging.getLogger('petric')
-TEAM = os.getenv("GITHUB_REPOSITORY", "SyneRBI/PETRIC-").split("/PETRIC-", 1)[-1]
-VERSION = os.getenv("GITHUB_REF_NAME", "")
-OUTDIR = Path(f"/o/logs/{TEAM}/{VERSION}" if TEAM and VERSION else "./output")
-if not (SRCDIR := Path("/mnt/share/petric")).is_dir():
-    SRCDIR = Path("./data")
-
-
-class Callback(cil_callbacks.Callback):
-    """
-    CIL Callback but with `self.skip_iteration` checking `min(self.interval, algo.update_objective_interval)`.
-    TODO: backport this class to CIL.
-    """
-    def __init__(self, interval: int = 1 << 31, **kwargs):
-        super().__init__(**kwargs)
-        self.interval = interval
-
-    def skip_iteration(self, algo: Algorithm) -> bool:
-        return algo.iteration % min(self.interval,
-                                    algo.update_objective_interval) != 0 and algo.iteration != algo.max_iteration
-
-
-class SaveIters(Callback):
-    """Saves `algo.x` as "iter_{algo.iteration:04d}.hv" and `algo.loss` in `csv_file`"""
-    def __init__(self, outdir=OUTDIR, csv_file='objectives.csv', **kwargs):
-        super().__init__(**kwargs)
-        self.outdir = Path(outdir)
-        self.outdir.mkdir(parents=True, exist_ok=True)
-        self.csv = csv.writer((self.outdir / csv_file).open("w", buffering=1))
-        self.csv.writerow(("iter", "objective"))
-
-    def __call__(self, algo: Algorithm):
-        if not self.skip_iteration(algo):
-            log.debug("saving iter %d...", algo.iteration)
-            algo.x.write(str(self.outdir / f'iter_{algo.iteration:04d}.hv'))
-            self.csv.writerow((algo.iteration, algo.get_last_loss()))
-            log.debug("...saved")
-        if algo.iteration == algo.max_iteration:
-            algo.x.write(str(self.outdir / 'iter_final.hv'))
-
-
-class StatsLog(Callback):
-    """Log image slices & objective value"""
-    def __init__(self, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR, **kwargs):
-        super().__init__(**kwargs)
-        self.transverse_slice = transverse_slice
-        self.coronal_slice = coronal_slice
-        self.vmax = vmax
-        self.x_prev = None
-        self.tb = logdir if isinstance(logdir, SummaryWriter) else SummaryWriter(logdir=str(logdir))
-
-    def __call__(self, algo: Algorithm):
-        if self.skip_iteration(algo):
-            return
-        t = getattr(self, '__time', None) or time()
-        log.debug("logging iter %d...", algo.iteration)
-        # initialise `None` values
-        self.transverse_slice = algo.x.dimensions()[0] // 2 if self.transverse_slice is None else self.transverse_slice
-        self.coronal_slice = algo.x.dimensions()[1] // 2 if self.coronal_slice is None else self.coronal_slice
-        self.vmax = algo.x.max() if self.vmax is None else self.vmax
-
-        self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t)
-        if self.x_prev is not None:
-            normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm()
-            self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t)
-        self.x_prev = algo.x.clone()
-        x_arr = algo.x.as_array()
-        self.tb.add_image("transverse", np.clip(x_arr[self.transverse_slice:self.transverse_slice + 1] / self.vmax, 0,
-                                                1), algo.iteration, t)
-        self.tb.add_image("coronal", np.clip(x_arr[None, :, self.coronal_slice] / self.vmax, 0, 1), algo.iteration, t)
-        log.debug("...logged")
-
-
-class QualityMetrics(ImageQualityCallback, Callback):
-    """From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds"""
-    def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1 << 31, **kwargs):
-        # TODO: drop multiple inheritance once `interval` included in CIL
-        Callback.__init__(self, interval=interval)
-        ImageQualityCallback.__init__(self, reference_image, **kwargs)
-        self.whole_object_indices = np.where(whole_object_mask.as_array())
-        self.background_indices = np.where(background_mask.as_array())
-        self.ref_im_arr = reference_image.as_array()
-        self.norm = self.ref_im_arr[self.background_indices].mean()
-
-    def __call__(self, algo: Algorithm):
-        if self.skip_iteration(algo):
-            return
-        t = getattr(self, '__time', None) or time()
-        for tag, value in self.evaluate(algo.x).items():
-            self.tb_summary_writer.add_scalar(tag, value, algo.iteration, t)
-
-    def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
-        assert not any(self.filter.values()), "Filtering not implemented"
-        test_im_arr = test_im.as_array()
-        whole = {
-            "RMSE_whole_object": np.sqrt(
-                mse(self.ref_im_arr[self.whole_object_indices], test_im_arr[self.whole_object_indices])) / self.norm,
-            "RMSE_background": np.sqrt(
-                mse(self.ref_im_arr[self.background_indices], test_im_arr[self.background_indices])) / self.norm}
-        local = {
-            f"AEM_VOI_{voi_name}": np.abs(test_im_arr[voi_indices].mean() - self.ref_im_arr[voi_indices].mean()) /
-            self.norm
-            for voi_name, voi_indices in sorted(self.voi_indices.items())}
-        return {**whole, **local}
-
-    def keys(self):
-        return ["RMSE_whole_object", "RMSE_background"] + [f"AEM_VOI_{name}" for name in sorted(self.voi_indices)]
-
-
-class MetricsWithTimeout(cil_callbacks.Callback):
-    """Stops the algorithm after `seconds`"""
-    def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, **kwargs):
-        super().__init__(**kwargs)
-        self._seconds = seconds
-        self.callbacks = [
-            cil_callbacks.ProgressCallback(),
-            SaveIters(outdir=outdir),
-            (tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]
-        self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter
-        self.reset()
-
-    def reset(self, seconds=None):
-        self.limit = time() + (self._seconds if seconds is None else seconds)
-        self.offset = 0
-
-    def __call__(self, algo: Algorithm):
-        if (now := time()) > self.limit + self.offset:
-            log.warning("Timeout reached. Stopping algorithm.")
-            raise StopIteration
-        for c in self.callbacks:
-            c.__time = now - self.offset # privately inject walltime-excluding-petric-callbacks
-            c(algo)
-        self.offset += time() - now
-
-    @staticmethod
-    def mean_absolute_error(y, x):
-        return np.mean(np.abs(y, x))
-
-
-def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3):
-    """
-    Construct a smoothed Relative Difference Prior (RDP)
-
-    initial_image: used to determine a smoothing factor (epsilon).
-    kappa: used to pass voxel-dependent weights.
-    """
-    prior = getattr(STIR, 'CudaRelativeDifferencePrior', STIR.RelativeDifferencePrior)()
-    # need to make it differentiable
-    epsilon = initial_image.max() * max_scaling
-    prior.set_epsilon(epsilon)
-    prior.set_penalisation_factor(penalty_strength)
-    prior.set_kappa(kappa)
-    prior.set_up(initial_image)
-    return prior
-
-
-@dataclass
-class Dataset:
-    acquired_data: STIR.AcquisitionData
-    additive_term: STIR.AcquisitionData
-    mult_factors: STIR.AcquisitionData
-    OSEM_image: STIR.ImageData
-    prior: STIR.RelativeDifferencePrior
-    kappa: STIR.ImageData
-    reference_image: STIR.ImageData | None
-    whole_object_mask: STIR.ImageData | None
-    background_mask: STIR.ImageData | None
-    voi_masks: dict[str, STIR.ImageData]
-
-
-def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
-    """
-    Load data from `srcdir`, constructs prior and return as a `Dataset`.
-    Also redirects sirf.STIR log output to `outdir`.
- """ - srcdir = Path(srcdir) - outdir = Path(outdir) - STIR.set_verbosity(sirf_verbosity) # set to higher value to diagnose problems - STIR.AcquisitionData.set_storage_scheme('memory') # needed for get_subsets() - - _ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'), str(outdir / 'errors.txt')) - acquired_data = STIR.AcquisitionData(str(srcdir / 'prompts.hs')) - additive_term = STIR.AcquisitionData(str(srcdir / 'additive_term.hs')) - mult_factors = STIR.AcquisitionData(str(srcdir / 'mult_factors.hs')) - OSEM_image = STIR.ImageData(str(srcdir / 'OSEM_image.hv')) - kappa = STIR.ImageData(str(srcdir / 'kappa.hv')) - if (penalty_strength_file := (srcdir / 'penalisation_factor.txt')).is_file(): - penalty_strength = float(np.loadtxt(penalty_strength_file)) - else: - penalty_strength = 1 / 700 # default choice - prior = construct_RDP(penalty_strength, OSEM_image, kappa) - - def get_image(fname): - if (source := srcdir / 'PETRIC' / fname).is_file(): - return STIR.ImageData(str(source)) - return None # explicit to suppress linter warnings - - reference_image = get_image('reference_image.hv') - whole_object_mask = get_image('VOI_whole_object.hv') - background_mask = get_image('VOI_background.hv') - voi_masks = { - voi.stem[4:]: STIR.ImageData(str(voi)) - for voi in (srcdir / 'PETRIC').glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'whole_object')} - - return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa, reference_image, - whole_object_mask, background_mask, voi_masks) - - -if SRCDIR.is_dir(): - # create list of existing data - # NB: `MetricsWithTimeout` initialises `SaveIters` which creates `outdir` - data_dirs_metrics = [(SRCDIR / "Siemens_mMR_NEMA_IQ", OUTDIR / "mMR_NEMA", - [MetricsWithTimeout(outdir=OUTDIR / "mMR_NEMA", transverse_slice=72, coronal_slice=109)]), - (SRCDIR / "NeuroLF_Hoffman_Dataset", OUTDIR / "NeuroLF_Hoffman", - [MetricsWithTimeout(outdir=OUTDIR / "NeuroLF_Hoffman", transverse_slice=72)]), - (SRCDIR / "Siemens_Vision600_thorax", OUTDIR / "Vision600_thorax", - [MetricsWithTimeout(outdir=OUTDIR / "Vision600_thorax")])] -else: - log.warning("Source directory does not exist: %s", SRCDIR) - data_dirs_metrics = [(None, None, [])] # type: ignore - -if __name__ != "__main__": - # load up first data-set for people to play with - srcdir, outdir, metrics = data_dirs_metrics[0] - if srcdir is None: - data = None - else: - data = get_data(srcdir=srcdir, outdir=outdir) - metrics[0].reset() -else: - from docopt import docopt - args = docopt(__doc__) - logging.basicConfig(level=getattr(logging, args["--log"].upper())) - from main import Submission, submission_callbacks - assert issubclass(Submission, Algorithm) - for srcdir, outdir, metrics in data_dirs_metrics: - data = get_data(srcdir=srcdir, outdir=outdir) - metrics_with_timeout = metrics[0] - if data.reference_image is not None: - metrics_with_timeout.callbacks.append( - QualityMetrics(data.reference_image, data.whole_object_mask, data.background_mask, - tb_summary_writer=metrics_with_timeout.tb, voi_mask_dict=data.voi_masks)) - metrics_with_timeout.reset() # timeout from now - algo = Submission(data) - try: - algo.run(np.inf, callbacks=metrics + submission_callbacks) - except Exception: - print_exc(limit=2) - finally: - del algo +#!/usr/bin/env python +""" +ANY CHANGES TO THIS FILE ARE IGNORED BY THE ORGANISERS. +Only the `main.py` file may be modified by participants. 
+
+This file is not intended for participants to use, except for
+the `get_data` function (and possibly `QualityMetrics` class).
+It is used by the organisers to run the submissions in a controlled way.
+It is included here purely in the interest of transparency.
+
+Usage:
+  petric.py [options]
+
+Options:
+  --log LEVEL  : Set logging level (DEBUG, [default: INFO], WARNING, ERROR, CRITICAL)
+"""
+import csv
+import logging
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from time import time
+from traceback import print_exc
+
+import numpy as np
+from skimage.metrics import mean_squared_error as mse
+from tensorboardX import SummaryWriter
+
+import sirf.STIR as STIR
+from cil.optimisation.algorithms import Algorithm
+from cil.optimisation.utilities import callbacks as cil_callbacks
+from img_quality_cil_stir import ImageQualityCallback
+
+log = logging.getLogger('petric')
+TEAM = os.getenv("GITHUB_REPOSITORY", "SyneRBI/PETRIC-").split("/PETRIC-", 1)[-1]
+VERSION = os.getenv("GITHUB_REF_NAME", "")
+OUTDIR = Path(f"/o/logs/{TEAM}/{VERSION}" if TEAM and VERSION else "./output")
+if not (SRCDIR := Path("/mnt/share/petric")).is_dir():
+    SRCDIR = Path("./data")
+
+
+class Callback(cil_callbacks.Callback):
+    """
+    CIL Callback but with `self.skip_iteration` checking `min(self.interval, algo.update_objective_interval)`.
+    TODO: backport this class to CIL.
+    """
+    def __init__(self, interval: int = 1 << 31, **kwargs):
+        super().__init__(**kwargs)
+        self.interval = interval
+
+    def skip_iteration(self, algo: Algorithm) -> bool:
+        return algo.iteration % min(self.interval,
+                                    algo.update_objective_interval) != 0 and algo.iteration != algo.max_iteration
+
+
+class SaveIters(Callback):
+    """Saves `algo.x` as "iter_{algo.iteration:04d}.hv" and `algo.loss` in `csv_file`"""
+    def __init__(self, outdir=OUTDIR, csv_file='objectives.csv', **kwargs):
+        super().__init__(**kwargs)
+        self.outdir = Path(outdir)
+        self.outdir.mkdir(parents=True, exist_ok=True)
+        self.csv = csv.writer((self.outdir / csv_file).open("w", buffering=1))
+        self.csv.writerow(("iter", "objective"))
+
+    def __call__(self, algo: Algorithm):
+        if not self.skip_iteration(algo):
+            log.debug("saving iter %d...", algo.iteration)
+            algo.x.write(str(self.outdir / f'iter_{algo.iteration:04d}.hv'))
+            self.csv.writerow((algo.iteration, algo.get_last_loss()))
+            log.debug("...saved")
+        if algo.iteration == algo.max_iteration:
+            algo.x.write(str(self.outdir / 'iter_final.hv'))
+
+
+class StatsLog(Callback):
+    """Log image slices & objective value"""
+    def __init__(self, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR, **kwargs):
+        super().__init__(**kwargs)
+        self.transverse_slice = transverse_slice
+        self.coronal_slice = coronal_slice
+        self.vmax = vmax
+        self.x_prev = None
+        self.tb = logdir if isinstance(logdir, SummaryWriter) else SummaryWriter(logdir=str(logdir))
+
+    def __call__(self, algo: Algorithm):
+        if self.skip_iteration(algo):
+            return
+        t = getattr(self, '__time', None) or time()
+        log.debug("logging iter %d...", algo.iteration)
+        # initialise `None` values
+        self.transverse_slice = algo.x.dimensions()[0] // 2 if self.transverse_slice is None else self.transverse_slice
+        self.coronal_slice = algo.x.dimensions()[1] // 2 if self.coronal_slice is None else self.coronal_slice
+        self.vmax = algo.x.max() if self.vmax is None else self.vmax
+
+        self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t)
+        if self.x_prev is not None:
+            normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm()
+            self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t)
+        self.x_prev = algo.x.clone()
+        x_arr = algo.x.as_array()
+        self.tb.add_image("transverse", np.clip(x_arr[self.transverse_slice:self.transverse_slice + 1] / self.vmax, 0,
+                                                1), algo.iteration, t)
+        self.tb.add_image("coronal", np.clip(x_arr[None, :, self.coronal_slice] / self.vmax, 0, 1), algo.iteration, t)
+        log.debug("...logged")
+
+
+class QualityMetrics(ImageQualityCallback, Callback):
+    """From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds"""
+    def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1 << 31, **kwargs):
+        # TODO: drop multiple inheritance once `interval` included in CIL
+        Callback.__init__(self, interval=interval)
+        ImageQualityCallback.__init__(self, reference_image, **kwargs)
+        self.whole_object_indices = np.where(whole_object_mask.as_array())
+        self.background_indices = np.where(background_mask.as_array())
+        self.ref_im_arr = reference_image.as_array()
+        self.norm = self.ref_im_arr[self.background_indices].mean()
+
+    def __call__(self, algo: Algorithm):
+        if self.skip_iteration(algo):
+            return
+        t = getattr(self, '__time', None) or time()
+        for tag, value in self.evaluate(algo.x).items():
+            self.tb_summary_writer.add_scalar(tag, value, algo.iteration, t)
+
+    def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
+        assert not any(self.filter.values()), "Filtering not implemented"
+        test_im_arr = test_im.as_array()
+        whole = {
+            "RMSE_whole_object": np.sqrt(
+                mse(self.ref_im_arr[self.whole_object_indices], test_im_arr[self.whole_object_indices])) / self.norm,
+            "RMSE_background": np.sqrt(
+                mse(self.ref_im_arr[self.background_indices], test_im_arr[self.background_indices])) / self.norm}
+        local = {
+            f"AEM_VOI_{voi_name}": np.abs(test_im_arr[voi_indices].mean() - self.ref_im_arr[voi_indices].mean()) /
+            self.norm
+            for voi_name, voi_indices in sorted(self.voi_indices.items())}
+        return {**whole, **local}
+
+    def keys(self):
+        return ["RMSE_whole_object", "RMSE_background"] + [f"AEM_VOI_{name}" for name in sorted(self.voi_indices)]
+
+
+class MetricsWithTimeout(cil_callbacks.Callback):
+    """Stops the algorithm after `seconds`"""
+    def __init__(self, seconds=600, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, **kwargs):
+        super().__init__(**kwargs)
+        self._seconds = seconds
+        self.callbacks = [
+            cil_callbacks.ProgressCallback(),
+            SaveIters(outdir=outdir),
+            (tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]
+        self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter
+        self.reset()
+
+    def reset(self, seconds=None):
+        self.limit = time() + (self._seconds if seconds is None else seconds)
+        self.offset = 0
+
+    def __call__(self, algo: Algorithm):
+        if (now := time()) > self.limit + self.offset:
+            log.warning("Timeout reached. Stopping algorithm.")
+            raise StopIteration
+        for c in self.callbacks:
+            c.__time = now - self.offset # privately inject walltime-excluding-petric-callbacks
+            c(algo)
+        self.offset += time() - now
+
+    @staticmethod
+    def mean_absolute_error(y, x):
+        return np.mean(np.abs(y - x))
+
+
+def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3):
+    """
+    Construct a smoothed Relative Difference Prior (RDP)
+
+    initial_image: used to determine a smoothing factor (epsilon).
+    kappa: used to pass voxel-dependent weights.
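+    max_scaling: epsilon is set to max_scaling * initial_image.max(), which smooths the prior so it stays differentiable.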
+    """
+    prior = getattr(STIR, 'CudaRelativeDifferencePrior', STIR.RelativeDifferencePrior)()
+    # need to make it differentiable
+    epsilon = initial_image.max() * max_scaling
+    prior.set_epsilon(epsilon)
+    prior.set_penalisation_factor(penalty_strength)
+    prior.set_kappa(kappa)
+    prior.set_up(initial_image)
+    return prior
+
+
+@dataclass
+class Dataset:
+    acquired_data: STIR.AcquisitionData
+    additive_term: STIR.AcquisitionData
+    mult_factors: STIR.AcquisitionData
+    OSEM_image: STIR.ImageData
+    prior: STIR.RelativeDifferencePrior
+    kappa: STIR.ImageData
+    reference_image: STIR.ImageData | None
+    whole_object_mask: STIR.ImageData | None
+    background_mask: STIR.ImageData | None
+    voi_masks: dict[str, STIR.ImageData]
+
+
+def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
+    """
+    Load data from `srcdir`, construct the prior and return everything as a `Dataset`.
+    Also redirects sirf.STIR log output to `outdir`.
+    """
+    srcdir = Path(srcdir)
+    outdir = Path(outdir)
+    STIR.set_verbosity(sirf_verbosity) # set to higher value to diagnose problems
+    STIR.AcquisitionData.set_storage_scheme('memory') # needed for get_subsets()
+
+    _ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'), str(outdir / 'errors.txt'))
+    acquired_data = STIR.AcquisitionData(str(srcdir / 'prompts.hs'))
+    additive_term = STIR.AcquisitionData(str(srcdir / 'additive_term.hs'))
+    mult_factors = STIR.AcquisitionData(str(srcdir / 'mult_factors.hs'))
+    OSEM_image = STIR.ImageData(str(srcdir / 'OSEM_image.hv'))
+    kappa = STIR.ImageData(str(srcdir / 'kappa.hv'))
+    if (penalty_strength_file := (srcdir / 'penalisation_factor.txt')).is_file():
+        penalty_strength = float(np.loadtxt(penalty_strength_file))
+    else:
+        penalty_strength = 1 / 700 # default choice
+    prior = construct_RDP(penalty_strength, OSEM_image, kappa)
+
+    def get_image(fname):
+        if (source := srcdir / 'PETRIC' / fname).is_file():
+            return STIR.ImageData(str(source))
+        return None # explicit to suppress linter warnings
+
+    reference_image = get_image('reference_image.hv')
+    whole_object_mask = get_image('VOI_whole_object.hv')
+    background_mask = get_image('VOI_background.hv')
+    voi_masks = {
+        voi.stem[4:]: STIR.ImageData(str(voi))
+        for voi in (srcdir / 'PETRIC').glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'whole_object')}
+
+    return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa, reference_image,
+                   whole_object_mask, background_mask, voi_masks)
+
+
+if SRCDIR.is_dir():
+    # create list of existing data
+    # NB: `MetricsWithTimeout` initialises `SaveIters` which creates `outdir`
+    data_dirs_metrics = [(SRCDIR / "Siemens_mMR_NEMA_IQ", OUTDIR / "mMR_NEMA",
+                          [MetricsWithTimeout(outdir=OUTDIR / "mMR_NEMA", transverse_slice=72, coronal_slice=109)]),
+                         (SRCDIR / "NeuroLF_Hoffman_Dataset", OUTDIR / "NeuroLF_Hoffman",
+                          [MetricsWithTimeout(outdir=OUTDIR / "NeuroLF_Hoffman", transverse_slice=72)]),
+                         (SRCDIR / "Siemens_Vision600_thorax", OUTDIR / "Vision600_thorax",
+                          [MetricsWithTimeout(outdir=OUTDIR / "Vision600_thorax")])]
+else:
+    log.warning("Source directory does not exist: %s", SRCDIR)
+    data_dirs_metrics = [(None, None, [])] # type: ignore
+
+if __name__ != "__main__":
+    # load up first data-set for people to play with
+    srcdir, outdir, metrics = data_dirs_metrics[0]
+    if srcdir is None:
+        data = None
+    else:
+        data = get_data(srcdir=srcdir, outdir=outdir)
+        metrics[0].reset()
+else:
+    from docopt import docopt
+    args = docopt(__doc__)
+    logging.basicConfig(level=getattr(logging, args["--log"].upper()))
+    from main import Submission, submission_callbacks
+    assert issubclass(Submission, Algorithm)
+    for srcdir, outdir, metrics in data_dirs_metrics:
+        data = get_data(srcdir=srcdir, outdir=outdir)
+        metrics_with_timeout = metrics[0]
+        if data.reference_image is not None:
+            metrics_with_timeout.callbacks.append(
+                QualityMetrics(data.reference_image, data.whole_object_mask, data.background_mask,
+                               tb_summary_writer=metrics_with_timeout.tb, voi_mask_dict=data.voi_masks))
+        metrics_with_timeout.reset() # timeout from now
+        algo = Submission(data)
+        try:
+            algo.run(np.inf, callbacks=metrics + submission_callbacks)
+        except Exception:
+            print_exc(limit=2)
+        finally:
+            del algo
\ No newline at end of file