diff --git a/bsrem.py b/bsrem.py deleted file mode 100644 index 313947f..0000000 --- a/bsrem.py +++ /dev/null @@ -1,163 +0,0 @@ -# -# SPDX-License-Identifier: Apache-2.0 -# -# Classes implementing the BSREM algorithm in sirf.STIR -# -# Authors: Kris Thielemans -# -# Copyright 2024 University College London - -import numpy -import numpy as np -import sirf.STIR as STIR -from sirf.Utilities import examples_data_path - - -from cil.optimisation.algorithms import Algorithm -from utils.herman_meyer import herman_meyer_order -import time - - -class BSREMSkeleton(Algorithm): - ''' Main implementation of a modified BSREM algorithm - - This essentially implements constrained preconditioned gradient ascent - with an EM-type preconditioner. - - In each update step, the gradient of a subset is computed, multiplied by a step_size and a EM-type preconditioner. - Before adding this to the previous iterate, an update_filter can be applied. - - ''' - def __init__(self, data, initial, - update_filter=STIR.TruncateToCylinderProcessor(), - **kwargs): - ''' - Arguments: - ``data``: list of items as returned by `partitioner` - ``initial``: initial estimate - ``initial_step_size``, ``relaxation_eta``: step-size constants - ``update_filter`` is applied on the (additive) update term, i.e. before adding to the previous iterate. - Set the filter to `None` if you don't want any. - ''' - super().__init__(**kwargs) - self.x = initial.copy() - self.data = data - self.num_subsets = len(data) - - # compute small number to add to image in preconditioner - # don't make it too small as otherwise the algorithm cannot recover from zeroes. - self.eps = initial.max()/1e3 - self.average_sensitivity = initial.get_uniform_copy(0) - for s in range(len(data)): - self.average_sensitivity += self.subset_sensitivity(s)/self.num_subsets - # add a small number to avoid division by zero in the preconditioner - self.average_sensitivity += self.average_sensitivity.max()/1e4 - - self.subset = 0 - self.update_filter = update_filter - self.configured = True - - self.subset_order = herman_meyer_order(self.num_subsets) - - self.x_prev = None - self.x_update_prev = None - - self.x_update = initial.get_uniform_copy(0) - self.new_x = initial.get_uniform_copy(0) - self.last_x = initial.get_uniform_copy(0) - - def subset_sensitivity(self, subset_num): - raise NotImplementedError - - def subset_gradient(self, x, subset_num): - raise NotImplementedError - - def epoch(self): - return (self.iteration + 1) // self.num_subsets - - def step_size(self): - return self.initial_step_size / (1 + self.relaxation_eta * self.epoch()) - - def update(self): - - g = self.subset_gradient(self.x, self.subset_order[self.subset]) - - g.multiply(self.x + self.eps, out=self.x_update) - self.x_update.divide(self.average_sensitivity, out=self.x_update) - - if self.iteration == 0: - step_size = min(max(1/(self.x_update.norm() + 1e-3), 0.005), 3.5) - else: - delta_x = self.x - self.x_prev - delta_g = self.x_update_prev - self.x_update - - dot_product = delta_g.dot(delta_x) # (deltag * deltax).sum() - alpha_long = 1.25*delta_x.norm()**2 / np.abs(dot_product) - #dot_product = delta_x.dot(delta_g) - #alpha_short = np.abs((dot_product).sum()) / delta_g.norm()**2 - #print("short / long: ", alpha_short, alpha_long) - - step_size = max(alpha_long, 0.01) #np.sqrt(alpha_long*alpha_short) - #print("step size: ", step_size) - #print("step size: ", step_size) - - self.x_prev = self.x.copy() - self.x_update_prev = self.x_update.copy() - - if self.update_filter is not None: - 
self.update_filter.apply(self.x_update) - - momentum = 0.4 - self.new_x.fill(self.x + step_size * self.x_update + momentum * (self.x - self.last_x)) - self.last_x = self.x.copy() - - self.x.fill(self.new_x) - #self.x.sapyb(1.0, self.x_update, step_size, out=self.x) - #self.x += beta * (self.x - self.last_x) - #self.x += self.x_update * step_size - - # threshold to non-negative - self.x.maximum(0, out=self.x) - self.subset = (self.subset + 1) % self.num_subsets - - - def update_objective(self): - # required for current CIL (needs to set self.loss) - self.loss.append(self.objective_function(self.x)) - - def objective_function(self, x): - ''' value of objective function summed over all subsets ''' - v = 0 - #for s in range(len(self.data)): - # v += self.subset_objective(x, s) - return v - - def subset_objective(self, x, subset_num): - ''' value of objective function for one subset ''' - raise NotImplementedError - - -class BSREM(BSREMSkeleton): - ''' BSREM implementation using sirf.STIR objective functions''' - def __init__(self, data, obj_funs, initial, **kwargs): - ''' - construct Algorithm with lists of data and, objective functions, initial estimate, initial step size, - step-size relaxation (per epoch) and optionally Algorithm parameters - ''' - self.obj_funs = obj_funs - super().__init__(data, initial, **kwargs) - - def subset_sensitivity(self, subset_num): - ''' Compute sensitiSvity for a particular subset''' - self.obj_funs[subset_num].set_up(self.x) - # note: sirf.STIR Poisson likelihood uses `get_subset_sensitivity(0) for the whole - # sensitivity if there are no subsets in that likelihood - return self.obj_funs[subset_num].get_subset_sensitivity(0) - - def subset_gradient(self, x, subset_num): - ''' Compute gradient at x for a particular subset''' - return self.obj_funs[subset_num].gradient(x) - - def subset_objective(self, x, subset_num): - ''' value of objective function for one subset ''' - return self.obj_funs[subset_num](x) diff --git a/bsrem_bb.py b/bsrem_bb.py deleted file mode 100644 index b004ef6..0000000 --- a/bsrem_bb.py +++ /dev/null @@ -1,159 +0,0 @@ -# -# SPDX-License-Identifier: Apache-2.0 -# -# Classes implementing the BSREM algorithm in sirf.STIR -# -# Authors: Kris Thielemans -# -# Copyright 2024 University College London - -import numpy -import numpy as np -import sirf.STIR as STIR -from sirf.Utilities import examples_data_path - - -from cil.optimisation.algorithms import Algorithm -from utils.herman_meyer import herman_meyer_order -import time - - -class BSREMSkeleton(Algorithm): - ''' Main implementation of a modified BSREM algorithm - - This essentially implements constrained preconditioned gradient ascent - with an EM-type preconditioner. - - In each update step, the gradient of a subset is computed, multiplied by a step_size and a EM-type preconditioner. - Before adding this to the previous iterate, an update_filter can be applied. - - ''' - def __init__(self, data, initial, - update_filter=STIR.TruncateToCylinderProcessor(), - **kwargs): - ''' - Arguments: - ``data``: list of items as returned by `partitioner` - ``initial``: initial estimate - ``initial_step_size``, ``relaxation_eta``: step-size constants - ``update_filter`` is applied on the (additive) update term, i.e. before adding to the previous iterate. - Set the filter to `None` if you don't want any. 
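The BSREM variants deleted in this commit all share the same core update: an EM-type preconditioner applied to a subset gradient, a Barzilai-Borwein-style "long" step size after the first iteration, and a non-negativity projection. A minimal numpy sketch of one such step, mirroring bsrem_bb.py (constants approximate; this is illustrative only, not the sirf.STIR implementation):

import numpy as np

def bsrem_bb_step(x, g, avg_sens, eps, x_prev=None, upd_prev=None):
    """One preconditioned ascent step with a BB 'long' step size (illustrative only)."""
    upd = (x + eps) * g / avg_sens                        # EM-type preconditioner
    if x_prev is None:                                     # first iteration: clamped heuristic
        step = min(max(1.0 / (np.linalg.norm(upd) + 1e-3), 0.005), 3.0)
    else:                                                  # Barzilai-Borwein long step
        dx = x - x_prev
        dg = upd_prev - upd
        step = max(np.linalg.norm(dx) ** 2 / abs(np.vdot(dg, dx)), 0.01)
    x_new = np.maximum(x + step * upd, 0)                  # non-negativity constraint
    return x_new, upd, step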
- ''' - super().__init__(**kwargs) - self.x = initial.copy() - self.data = data - self.num_subsets = len(data) - - # compute small number to add to image in preconditioner - # don't make it too small as otherwise the algorithm cannot recover from zeroes. - self.eps = initial.max()/1e3 - self.average_sensitivity = initial.get_uniform_copy(0) - for s in range(len(data)): - self.average_sensitivity += self.subset_sensitivity(s)/self.num_subsets - # add a small number to avoid division by zero in the preconditioner - self.average_sensitivity += self.average_sensitivity.max()/1e4 - - - - self.precond = initial.get_uniform_copy(0) - - self.subset = 0 - self.update_filter = update_filter - self.configured = True - - self.subset_order = herman_meyer_order(self.num_subsets) - - self.x_prev = None - self.x_update_prev = None - - self.x_update = initial.get_uniform_copy(0) - - def subset_sensitivity(self, subset_num): - raise NotImplementedError - - def subset_gradient(self, x, subset_num): - raise NotImplementedError - - def epoch(self): - return (self.iteration + 1) // self.num_subsets - - def step_size(self): - return self.initial_step_size / (1 + self.relaxation_eta * self.epoch()) - - def update(self): - - g = self.subset_gradient(self.x, self.subset_order[self.subset]) - - g.multiply(self.x + self.eps, out=self.x_update) - - self.x_update.divide(self.average_sensitivity, out=self.x_update) - - if self.iteration == 0: - step_size = min(max(1/(self.x_update.norm() + 1e-3), 0.005), 3.0) - else: - delta_x = self.x - self.x_prev - delta_g = self.x_update_prev - self.x_update - - dot_product = delta_g.dot(delta_x) # (deltag * deltax).sum() - alpha_long = delta_x.norm()**2 / np.abs(dot_product) - #dot_product = delta_x.dot(delta_g) - #alpha_short = np.abs((dot_product).sum()) / delta_g.norm()**2 - #print("short / long: ", alpha_short, alpha_long) - - step_size = max(alpha_long, 0.01) #np.sqrt(alpha_long*alpha_short) - #print("step size: ", step_size) - #print("step size: ", step_size) - - self.x_prev = self.x.copy() - self.x_update_prev = self.x_update.copy() - - if self.update_filter is not None: - self.update_filter.apply(self.x_update) - - self.x.sapyb(1.0, self.x_update, step_size, out=self.x) - - # threshold to non-negative - self.x.maximum(0, out=self.x) - self.subset = (self.subset + 1) % self.num_subsets - - - def update_objective(self): - # required for current CIL (needs to set self.loss) - self.loss.append(self.objective_function(self.x)) - - def objective_function(self, x): - ''' value of objective function summed over all subsets ''' - v = 0 - #for s in range(len(self.data)): - # v += self.subset_objective(x, s) - return v - - def subset_objective(self, x, subset_num): - ''' value of objective function for one subset ''' - raise NotImplementedError - - -class BSREM(BSREMSkeleton): - ''' BSREM implementation using sirf.STIR objective functions''' - def __init__(self, data, obj_funs, initial, **kwargs): - ''' - construct Algorithm with lists of data and, objective functions, initial estimate, initial step size, - step-size relaxation (per epoch) and optionally Algorithm parameters - ''' - self.obj_funs = obj_funs - super().__init__(data, initial, **kwargs) - - def subset_sensitivity(self, subset_num): - ''' Compute sensitiSvity for a particular subset''' - self.obj_funs[subset_num].set_up(self.x) - # note: sirf.STIR Poisson likelihood uses `get_subset_sensitivity(0) for the whole - # sensitivity if there are no subsets in that likelihood - return 
self.obj_funs[subset_num].get_subset_sensitivity(0) - - def subset_gradient(self, x, subset_num): - ''' Compute gradient at x for a particular subset''' - return self.obj_funs[subset_num].gradient(x) - - def subset_objective(self, x, subset_num): - ''' value of objective function for one subset ''' - return self.obj_funs[subset_num](x) diff --git a/bsrem_bb_rdp.py b/bsrem_bb_rdp.py deleted file mode 100644 index a15e5de..0000000 --- a/bsrem_bb_rdp.py +++ /dev/null @@ -1,211 +0,0 @@ -# -# SPDX-License-Identifier: Apache-2.0 -# -# Classes implementing the BSREM algorithm in sirf.STIR -# -# Authors: Kris Thielemans -# -# Copyright 2024 University College London - -import numpy -import numpy as np -import sirf.STIR as STIR -from sirf.Utilities import examples_data_path -import torch -torch.cuda.set_per_process_memory_fraction(0.7) - -from cil.optimisation.algorithms import Algorithm -from utils.herman_meyer import herman_meyer_order -import time - -class RDPDiagHessTorch: - def __init__(self, rdp_diag_hess, prior): - self.epsilon = prior.get_epsilon() - self.gamma = prior.get_gamma() - self.penalty_strength = prior.get_penalisation_factor() - self.rdp_diag_hess = rdp_diag_hess - self.weights = torch.zeros([3,3,3]).cuda() - self.kappa = torch.tensor(prior.get_kappa().as_array()).cuda() - self.kappa_padded = torch.nn.functional.pad(self.kappa[None], pad=(1, 1, 1, 1, 1, 1), mode='replicate')[0] - voxel_sizes = rdp_diag_hess.voxel_sizes() - z_dim, y_dim, x_dim = rdp_diag_hess.shape - for i in range(3): - for j in range(3): - for k in range(3): - self.weights[i,j,k] = voxel_sizes[2]/np.sqrt(((i-1)*voxel_sizes[0])**2 + ((j-1)*voxel_sizes[1])**2 + ((k-1)*voxel_sizes[2])**2) - self.weights[1,1,1] = 0 - self.z_dim = z_dim - self.y_dim = y_dim - self.x_dim = x_dim - - - def compute(self, x): - x = torch.tensor(x.as_array(), dtype=torch.float32).cuda() - x_padded = torch.nn.functional.pad(x[None], pad=(1, 1, 1, 1, 1, 1), mode='replicate')[0] - x_rdp_diag_hess = torch.zeros_like(x) - for dz in range(3): - for dy in range(3): - for dx in range(3): - x_neighbour = x_padded[dz:dz+self.z_dim, dy:dy+self.y_dim, dx:dx+self.x_dim] - kappa_neighbour = self.kappa_padded[dz:dz+self.z_dim, dy:dy+self.y_dim, dx:dx+self.x_dim] - kappa_val = self.kappa * kappa_neighbour - numerator = 4 * (2 * x_neighbour + self.epsilon) ** 2 - denominator = (x + x_neighbour + self.gamma * torch.abs(x - x_neighbour) + self.epsilon) ** 3 - x_rdp_diag_hess += self.weights[dz, dy, dx] * self.penalty_strength * kappa_val * numerator / denominator - return self.rdp_diag_hess.fill(x_rdp_diag_hess.cpu().numpy()) - - -class BSREMSkeleton(Algorithm): - ''' Main implementation of a modified BSREM algorithm - - This essentially implements constrained preconditioned gradient ascent - with an EM-type preconditioner. - - In each update step, the gradient of a subset is computed, multiplied by a step_size and a EM-type preconditioner. - Before adding this to the previous iterate, an update_filter can be applied. - - ''' - def __init__(self, data, initial, - update_filter=STIR.TruncateToCylinderProcessor(), - **kwargs): - ''' - Arguments: - ``data``: list of items as returned by `partitioner` - ``initial``: initial estimate - ``initial_step_size``, ``relaxation_eta``: step-size constants - ``update_filter`` is applied on the (additive) update term, i.e. before adding to the previous iterate. - Set the filter to `None` if you don't want any. 
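bsrem_bb_rdp.py (above) approximates the diagonal Hessian of the smoothed relative difference prior on the GPU via RDPDiagHessTorch. A plain-numpy paraphrase of that 3x3x3 neighbourhood loop, included only to make the formula explicit (argument names are illustrative):

import numpy as np

def rdp_diag_hessian(x, kappa, gamma, epsilon, beta, voxel_sizes):
    """Diagonal RDP Hessian estimate: sum over the 26 neighbours of
    w * beta * kappa_i * kappa_j * 4*(2*x_j + eps)**2 / (x_i + x_j + gamma*|x_i - x_j| + eps)**3."""
    xp = np.pad(x, 1, mode='edge')
    kp = np.pad(kappa, 1, mode='edge')
    nz, ny, nx = x.shape
    out = np.zeros_like(x)
    for dz in range(3):
        for dy in range(3):
            for dx in range(3):
                if (dz, dy, dx) == (1, 1, 1):
                    continue                                  # centre voxel carries weight 0
                dist = np.sqrt(((dz - 1) * voxel_sizes[0]) ** 2
                               + ((dy - 1) * voxel_sizes[1]) ** 2
                               + ((dx - 1) * voxel_sizes[2]) ** 2)
                w = voxel_sizes[2] / dist                     # same weighting as the torch code
                xn = xp[dz:dz + nz, dy:dy + ny, dx:dx + nx]   # shifted neighbour volume
                kn = kp[dz:dz + nz, dy:dy + ny, dx:dx + nx]
                num = 4 * (2 * xn + epsilon) ** 2
                den = (x + xn + gamma * np.abs(x - xn) + epsilon) ** 3
                out += w * beta * kappa * kn * num / den
    return out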
- ''' - super().__init__(**kwargs) - self.x = initial.copy() - self.data = data - self.num_subsets = len(data) - - # compute small number to add to image in preconditioner - # don't make it too small as otherwise the algorithm cannot recover from zeroes. - self.eps = initial.max()/1e3 - self.average_sensitivity = initial.get_uniform_copy(0) - for s in range(len(data)): - self.average_sensitivity += self.subset_sensitivity(s)/self.num_subsets - # add a small number to avoid division by zero in the preconditioner - self.average_sensitivity += self.average_sensitivity.max()/1e4 - - self.precond = initial.get_uniform_copy(0) - - self.subset = 0 - self.update_filter = update_filter - self.configured = True - - self.subset_order = herman_meyer_order(self.num_subsets) - - self.x_prev = None - self.x_update_prev = None - - self.x_update = initial.get_uniform_copy(0) - - def subset_sensitivity(self, subset_num): - raise NotImplementedError - - def subset_gradient(self, x, subset_num): - raise NotImplementedError - - def epoch(self): - return (self.iteration + 1) // self.num_subsets - - def step_size(self): - return self.initial_step_size / (1 + self.relaxation_eta * self.epoch()) - - def update(self): - - g = self.subset_gradient(self.x, self.subset_order[self.subset]) - if self.iteration == 0: - prior_grad = self.dataset.prior.gradient(self.x) - lhkd_grad = g - prior_grad - if prior_grad.norm()/g.norm() > 0.5: - self.rdp_diag_hess_obj = RDPDiagHessTorch(self.dataset.OSEM_image.copy(), self.dataset.prior) - self.lkhd_precond = self.dataset.kappa.power(2) - self.compute_rdp_diag_hess = True - self.eps = self.lkhd_precond.max()/1e4 - else: - self.compute_rdp_diag_hess = False - self.eps = self.dataset.OSEM_image.max()/1e3 - #x_norm = self.x.norm() - #print("prior: ", prior_grad.norm(), " lhkd: ", lhkd_grad.norm(), " x: ", x_norm, " g: ", g.norm(), " prior/x: ", prior_grad.norm()/x_norm, " lhkd/x: ", lhkd_grad.norm()/x_norm, " g/x: ", g.norm()/x_norm) - #print("prior/lhkd: ", prior_grad.norm()/lhkd_grad.norm(), " prior/g: ", prior_grad.norm()/g.norm(), " lhkd/g: ", lhkd_grad.norm()/g.norm()) - - #g.multiply(self.x + self.eps, out=self.x_update) - #self.x_update.divide(self.average_sensitivity, out=self.x_update) - if self.compute_rdp_diag_hess: - g.divide(self.lkhd_precond + self.rdp_diag_hess_obj.compute(self.x) + self.eps, out=self.x_update) - else: - g.multiply(self.x + self.eps, out=self.x_update) - self.x_update.divide(self.average_sensitivity, out=self.x_update) - if self.iteration == 0: - step_size = min(max(1/(self.x_update.norm() + 1e-3), 0.005), 3.0) - else: - delta_x = self.x - self.x_prev - delta_g = self.x_update_prev - self.x_update - - dot_product = delta_g.dot(delta_x) # (deltag * deltax).sum() - alpha_long = delta_x.norm()**2 / np.abs(dot_product) - #dot_product = delta_x.dot(delta_g) - #alpha_short = np.abs((dot_product).sum()) / delta_g.norm()**2 - #print("short / long: ", alpha_short, alpha_long) - - step_size = min(alpha_long, 0.01) #np.sqrt(alpha_long*alpha_short) - #print("step size: ", step_size) - #print("step size: ", step_size) - - self.x_prev = self.x.copy() - self.x_update_prev = self.x_update.copy() - - if self.update_filter is not None: - self.update_filter.apply(self.x_update) - - self.x.sapyb(1.0, self.x_update, step_size, out=self.x) - - # threshold to non-negative - self.x.maximum(0, out=self.x) - self.subset = (self.subset + 1) % self.num_subsets - - - def update_objective(self): - # required for current CIL (needs to set self.loss) - 
self.loss.append(self.objective_function(self.x)) - - def objective_function(self, x): - ''' value of objective function summed over all subsets ''' - v = 0 - #for s in range(len(self.data)): - # v += self.subset_objective(x, s) - return v - - def subset_objective(self, x, subset_num): - ''' value of objective function for one subset ''' - raise NotImplementedError - - -class BSREM(BSREMSkeleton): - ''' BSREM implementation using sirf.STIR objective functions''' - def __init__(self, data, obj_funs, initial, **kwargs): - ''' - construct Algorithm with lists of data and, objective functions, initial estimate, initial step size, - step-size relaxation (per epoch) and optionally Algorithm parameters - ''' - self.obj_funs = obj_funs - super().__init__(data, initial, **kwargs) - - def subset_sensitivity(self, subset_num): - ''' Compute sensitiSvity for a particular subset''' - self.obj_funs[subset_num].set_up(self.x) - # note: sirf.STIR Poisson likelihood uses `get_subset_sensitivity(0) for the whole - # sensitivity if there are no subsets in that likelihood - return self.obj_funs[subset_num].get_subset_sensitivity(0) - - def subset_gradient(self, x, subset_num): - ''' Compute gradient at x for a particular subset''' - return self.obj_funs[subset_num].gradient(x) - - def subset_objective(self, x, subset_num): - ''' value of objective function for one subset ''' - return self.obj_funs[subset_num](x) diff --git a/bsrem_bb_saga.py b/bsrem_bb_saga.py deleted file mode 100644 index 41f21d8..0000000 --- a/bsrem_bb_saga.py +++ /dev/null @@ -1,265 +0,0 @@ -# -# SPDX-License-Identifier: Apache-2.0 -# -# Classes implementing the BSREM algorithm in sirf.STIR -# -# Authors: Kris Thielemans -# -# Copyright 2024 University College London - -import numpy -import numpy as np -import sirf.STIR as STIR -from sirf.Utilities import examples_data_path -import torch -torch.cuda.set_per_process_memory_fraction(0.7) - -from cil.optimisation.algorithms import Algorithm -from utils.herman_meyer import herman_meyer_order -import time - -class RDPDiagHessTorch: - def __init__(self, rdp_diag_hess, prior): - self.epsilon = prior.get_epsilon() - self.gamma = prior.get_gamma() - self.penalty_strength = prior.get_penalisation_factor() - self.weights = torch.zeros([3,3,3]).cuda() - self.kappa = torch.tensor(prior.get_kappa().as_array()).cuda() - self.kappa_padded = torch.nn.functional.pad(self.kappa[None], pad=(1, 1, 1, 1, 1, 1), mode='replicate')[0] - voxel_sizes = rdp_diag_hess.voxel_sizes() - z_dim, y_dim, x_dim = rdp_diag_hess.shape - for i in range(3): - for j in range(3): - for k in range(3): - self.weights[i,j,k] = voxel_sizes[2]/np.sqrt(((i-1)*voxel_sizes[0])**2 + ((j-1)*voxel_sizes[1])**2 + ((k-1)*voxel_sizes[2])**2) - self.weights[1,1,1] = 0 - self.z_dim = z_dim - self.y_dim = y_dim - self.x_dim = x_dim - - - def compute(self, x, rdp_diag_hess): - - x = torch.tensor(x.as_array(), dtype=torch.float32).cuda() - x_padded = torch.nn.functional.pad(x[None], pad=(1, 1, 1, 1, 1, 1), mode='replicate')[0] - x_rdp_diag_hess = torch.zeros_like(x) - for dz in range(3): - for dy in range(3): - for dx in range(3): - x_neighbour = x_padded[dz:dz+self.z_dim, dy:dy+self.y_dim, dx:dx+self.x_dim] - kappa_neighbour = self.kappa_padded[dz:dz+self.z_dim, dy:dy+self.y_dim, dx:dx+self.x_dim] - kappa_val = self.kappa * kappa_neighbour - numerator = 4 * (2 * x_neighbour + self.epsilon) ** 2 - denominator = (x + x_neighbour + self.gamma * torch.abs(x - x_neighbour) + self.epsilon) ** 3 - x_rdp_diag_hess += self.weights[dz, dy, dx] * 
self.penalty_strength * kappa_val * numerator / denominator - - rdp_diag_hess.fill(x_rdp_diag_hess.cpu().numpy()) - - -class BSREMSkeleton(Algorithm): - ''' Main implementation of a modified BSREM algorithm - - This essentially implements constrained preconditioned gradient ascent - with an EM-type preconditioner. - - In each update step, the gradient of a subset is computed, multiplied by a step_size and a EM-type preconditioner. - Before adding this to the previous iterate, an update_filter can be applied. - - ''' - def __init__(self, data, initial, - update_filter=STIR.TruncateToCylinderProcessor(), - **kwargs): - ''' - Arguments: - ``data``: list of items as returned by `partitioner` - ``initial``: initial estimate - ``initial_step_size``, ``relaxation_eta``: step-size constants - ``update_filter`` is applied on the (additive) update term, i.e. before adding to the previous iterate. - Set the filter to `None` if you don't want any. - ''' - super().__init__(**kwargs) - self.x = initial.copy() - self.initial = initial.copy() - self.data = data - self.num_subsets = len(data) - self.num_subsets_initial = len(data) - - # compute small number to add to image in preconditioner - # don't make it too small as otherwise the algorithm cannot recover from zeroes. - self.eps = initial.max()/1e3 - self.average_sensitivity = initial.get_uniform_copy(0) - for s in range(len(data)): - self.average_sensitivity += self.subset_sensitivity(s)/self.num_subsets - # add a small number to avoid division by zero in the preconditioner - self.average_sensitivity += self.average_sensitivity.max()/1e4 - - self.precond = initial.get_uniform_copy(0) - - self.subset = 0 - self.update_filter = update_filter - self.configured = True - - self.subset_order = herman_meyer_order(self.num_subsets) - - self.x_update = initial.get_uniform_copy(0) - self.c = 1 - # DOG parameters - self.max_distance = 0 - self.sum_gradient = 0 - - self.precond = initial.get_uniform_copy(0) - - self.gm = [self.x.get_uniform_copy(0) for _ in range(self.num_subsets)] - - self.sum_gm = self.x.get_uniform_copy(0) - - def subset_sensitivity(self, subset_num): - raise NotImplementedError - - def subset_gradient(self, x, subset_num): - raise NotImplementedError - - def epoch(self): - return (self.iteration) // self.num_subsets - - def step_size(self): - return self.initial_step_size / (1 + self.relaxation_eta * self.epoch()) - - def get_number_of_subsets_to_accumulate_gradient(self): - for index, boundary in enumerate(self.accumulate_gradient_iter): - if self.iteration < boundary*self.num_subsets_initial: - return self.accumulate_gradient_num[index] - return self.num_subsets - - - def update(self): - if self.iteration == 0: - prior_grad = self.prior.gradient(self.x) - g = self.x.get_uniform_copy(0) - for i in range(self.num_subsets): - gm = self.subset_gradient_pll(self.x, self.subset_order[i]) - self.gm[self.subset_order[i]] = gm - prior_grad - g.add(gm, out=g) - - gradient = g / self.num_subsets - prior_grad - self.sum_gm = g - prior_grad * self.num_subsets - - if prior_grad.norm()/gradient.norm() > 0.5: - print("Choose RDP and kappa precond") - self.rdp_diag_hess_obj = RDPDiagHessTorch(self.dataset.OSEM_image.copy(), self.dataset.prior) - self.lkhd_precond = self.dataset.kappa.power(2) - self.compute_rdp_diag_hess = True - self.eps = self.lkhd_precond.max()/1e4 - print("Compute inital precond") - self.rdp_diag_hess_obj.compute(self.x, self.precond) - print(self.precond.min(), self.precond.max(), self.precond.norm()) - print("Finished computing inital 
precond") - else: - print("Choose EM precond") - self.compute_rdp_diag_hess = False - self.eps = self.dataset.OSEM_image.max()/1e3 - - - if self.iteration > 0: - subset_choice = self.subset_order[self.subset] - g = self.subset_gradient(self.x, subset_choice) - gradient = (g - self.gm[subset_choice]) + self.sum_gm / self.num_subsets - - if self.compute_rdp_diag_hess: - print(self.iteration % self.num_subsets) - if self.iteration % self.num_subsets == 0 and self.iteration > 0: - print("Compute Precond Again!") - self.rdp_diag_hess_obj.compute(self.x, self.precond) - gradient.divide(self.lkhd_precond + self.precond + self.eps, out=self.x_update) - else: - gradient.multiply(self.x + self.eps, out=self.x_update) - self.x_update.divide(self.average_sensitivity, out=self.x_update) - - - if self.update_filter is not None: - self.update_filter.apply(self.x_update) - - #distance = (self.x - self.initial).norm() - #if distance > self.max_distance: - # self.max_distance = distance - - #self.sum_gradient += self.x_update.norm()**2 - - if self.iteration == 0: - self.alpha = min(max(1/(self.x_update.norm() + 1e-3), 0.005), 1.0) - else: - nominator = self.last_g.dot(self.last_x_update) - - deltax = self.x - self.last_x - deltag = gradient - self.last_g - - denominator = deltax.dot(deltag) - - step_size = self.alpha**2 * np.abs(nominator) / np.abs(denominator) - k = self.iteration + 2 - phik = (k + 1) # /self.num_subsets_initial - self.c = self.c ** ((k-2)/(k-1)) * (step_size*phik) ** (1/(k-1)) - self.alpha = self.c / phik - #step_size = self.max_distance / np.sqrt(self.sum_gradient) - - self.last_x = self.x.copy() - - print("Step size: ", self.alpha) - self.x.sapyb(1.0, self.x_update, self.alpha, out=self.x) - - # threshold to non-negative - self.x.maximum(0, out=self.x) - self.subset = (self.subset + 1) % self.num_subsets - - self.last_g = gradient.copy() - self.last_x_update = self.x_update.copy() - - if self.iteration > 0: - self.sum_gm = self.sum_gm - self.gm[subset_choice] + g - self.gm[subset_choice] = g - - def update_objective(self): - # required for current CIL (needs to set self.loss) - self.loss.append(self.objective_function(self.x)) - - def objective_function(self, x): - ''' value of objective function summed over all subsets ''' - v = 0 - #for s in range(len(self.data)): - # v += self.subset_objective(x, s) - return v - - def subset_objective(self, x, subset_num): - ''' value of objective function for one subset ''' - raise NotImplementedError - - -class BSREM(BSREMSkeleton): - ''' BSREM implementation using sirf.STIR objective functions''' - def __init__(self, data, obj_funs, prior, initial, **kwargs): - ''' - construct Algorithm with lists of data and, objective functions, initial estimate, initial step size, - step-size relaxation (per epoch) and optionally Algorithm parameters - ''' - self.obj_funs = obj_funs - self.prior = prior - super().__init__(data, initial, **kwargs) - - def subset_sensitivity(self, subset_num): - ''' Compute sensitiSvity for a particular subset''' - self.obj_funs[subset_num].set_up(self.x) - # note: sirf.STIR Poisson likelihood uses `get_subset_sensitivity(0) for the whole - # sensitivity if there are no subsets in that likelihood - return self.obj_funs[subset_num].get_subset_sensitivity(0) - - def subset_gradient_pll(self, x, subset_num): - ''' Compute gradient at x for a particular subset''' - return self.obj_funs[subset_num].gradient(x) - - def subset_gradient(self, x, subset_num): - ''' Compute gradient at x for a particular subset''' - return 
self.obj_funs[subset_num].gradient(x) - self.prior.gradient(x) - - def subset_objective(self, x, subset_num): - ''' value of objective function for one subset ''' - return self.obj_funs[subset_num](x) diff --git a/bsrem_bb_subset.py b/bsrem_bb_subset.py deleted file mode 100644 index c5e4bf1..0000000 --- a/bsrem_bb_subset.py +++ /dev/null @@ -1,271 +0,0 @@ -# -# SPDX-License-Identifier: Apache-2.0 -# -# Classes implementing the BSREM algorithm in sirf.STIR -# -# Authors: Kris Thielemans -# -# Copyright 2024 University College London - -import numpy -import numpy as np -import sirf.STIR as STIR -from sirf.Utilities import examples_data_path -import torch -torch.cuda.set_per_process_memory_fraction(0.7) - -from cil.optimisation.algorithms import Algorithm -from utils.herman_meyer import herman_meyer_order -import time - -class RDPDiagHessTorch: - def __init__(self, rdp_diag_hess, prior): - self.epsilon = prior.get_epsilon() - self.gamma = prior.get_gamma() - self.penalty_strength = prior.get_penalisation_factor() - self.weights = torch.zeros([3,3,3]).cuda() - self.kappa = torch.tensor(prior.get_kappa().as_array()).cuda() - self.kappa_padded = torch.nn.functional.pad(self.kappa[None], pad=(1, 1, 1, 1, 1, 1), mode='replicate')[0] - voxel_sizes = rdp_diag_hess.voxel_sizes() - z_dim, y_dim, x_dim = rdp_diag_hess.shape - for i in range(3): - for j in range(3): - for k in range(3): - self.weights[i,j,k] = voxel_sizes[2]/np.sqrt(((i-1)*voxel_sizes[0])**2 + ((j-1)*voxel_sizes[1])**2 + ((k-1)*voxel_sizes[2])**2) - self.weights[1,1,1] = 0 - self.z_dim = z_dim - self.y_dim = y_dim - self.x_dim = x_dim - - - def compute(self, x, rdp_diag_hess): - - x = torch.tensor(x.as_array(), dtype=torch.float32).cuda() - x_padded = torch.nn.functional.pad(x[None], pad=(1, 1, 1, 1, 1, 1), mode='replicate')[0] - x_rdp_diag_hess = torch.zeros_like(x) - for dz in range(3): - for dy in range(3): - for dx in range(3): - x_neighbour = x_padded[dz:dz+self.z_dim, dy:dy+self.y_dim, dx:dx+self.x_dim] - kappa_neighbour = self.kappa_padded[dz:dz+self.z_dim, dy:dy+self.y_dim, dx:dx+self.x_dim] - kappa_val = self.kappa * kappa_neighbour - numerator = 4 * (2 * x_neighbour + self.epsilon) ** 2 - denominator = (x + x_neighbour + self.gamma * torch.abs(x - x_neighbour) + self.epsilon) ** 3 - x_rdp_diag_hess += self.weights[dz, dy, dx] * self.penalty_strength * kappa_val * numerator / denominator - - rdp_diag_hess.fill(x_rdp_diag_hess.cpu().numpy()) - - -class BSREMSkeleton(Algorithm): - ''' Main implementation of a modified BSREM algorithm - - This essentially implements constrained preconditioned gradient ascent - with an EM-type preconditioner. - - In each update step, the gradient of a subset is computed, multiplied by a step_size and a EM-type preconditioner. - Before adding this to the previous iterate, an update_filter can be applied. - - ''' - def __init__(self, data, initial, - update_filter=STIR.TruncateToCylinderProcessor(), - **kwargs): - ''' - Arguments: - ``data``: list of items as returned by `partitioner` - ``initial``: initial estimate - ``initial_step_size``, ``relaxation_eta``: step-size constants - ``update_filter`` is applied on the (additive) update term, i.e. before adding to the previous iterate. - Set the filter to `None` if you don't want any. 
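bsrem_bb_saga.py (above) combines the preconditioner with a SAGA-style variance-reduced gradient: the freshly computed subset gradient replaces its stored copy, and the estimate is corrected by the running sum of stored gradients. A small numpy sketch of just that bookkeeping (names illustrative):

import numpy as np

class SagaGradient:
    """Keeps one stored gradient per subset plus their sum, like self.gm / self.sum_gm above."""
    def __init__(self, num_subsets, shape):
        self.num_subsets = num_subsets
        self.gm = [np.zeros(shape) for _ in range(num_subsets)]
        self.sum_gm = np.zeros(shape)

    def estimate(self, subset, g):
        # variance-reduced estimate used for the update ...
        est = (g - self.gm[subset]) + self.sum_gm / self.num_subsets
        # ... then refresh the stored gradient and the running sum
        self.sum_gm += g - self.gm[subset]
        self.gm[subset] = g.copy()
        return est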
- ''' - super().__init__(**kwargs) - self.x = initial.copy() - self.initial = initial.copy() - self.data = data - self.num_subsets = len(data) - self.num_subsets_initial = len(data) - - # compute small number to add to image in preconditioner - # don't make it too small as otherwise the algorithm cannot recover from zeroes. - self.eps = initial.max()/1e3 - self.average_sensitivity = initial.get_uniform_copy(0) - for s in range(len(data)): - self.average_sensitivity += self.subset_sensitivity(s)/self.num_subsets - # add a small number to avoid division by zero in the preconditioner - self.average_sensitivity += self.average_sensitivity.max()/1e4 - - self.precond = initial.get_uniform_copy(0) - - self.subset = 0 - self.update_filter = update_filter - self.configured = True - - self.subset_order = herman_meyer_order(self.num_subsets) - - self.x_update = initial.get_uniform_copy(0) - - self.accumulate_gradient_iter = [4, 6, 10, 12, 20] - self.accumulate_gradient_num = [1, 2, 4, 8, 10, 20] - self.rdp_update_interval = 20 - - # DOG parameters - self.max_distance = 0 - self.sum_gradient = 0 - - self.precond = initial.get_uniform_copy(0) - - - def subset_sensitivity(self, subset_num): - raise NotImplementedError - - def subset_gradient(self, x, subset_num): - raise NotImplementedError - - def epoch(self): - return (self.iteration) // self.num_subsets - - def step_size(self): - return self.initial_step_size / (1 + self.relaxation_eta * self.epoch()) - - def get_number_of_subsets_to_accumulate_gradient(self): - for index, boundary in enumerate(self.accumulate_gradient_iter): - if self.iteration < boundary*self.num_subsets_initial: - return self.accumulate_gradient_num[index] - return self.num_subsets - - - def update(self): - if self.iteration == 0: - num_to_accumulate = self.num_subsets - else: - num_to_accumulate = self.get_number_of_subsets_to_accumulate_gradient() - print("Num to accumulate: ", num_to_accumulate) - # use at most all subsets - if num_to_accumulate > self.num_subsets_initial: - num_to_accumulate = self.num_subsets_initial - - #print(f"Use {num_to_accumulate} subsets at iteration {self.iteration}") - for i in range(num_to_accumulate): - if i == 0: - g = self.subset_gradient_pll(self.x, self.subset_order[self.subset]) - else: - g += self.subset_gradient_pll(self.x, self.subset_order[self.subset]) - self.subset = (self.subset + 1) % self.num_subsets - #print(f"\n Added subset {i+1} (i.e. 
{self.subset}) of {num_to_accumulate}\n") - - g /= num_to_accumulate - - prior_grad = self.prior.gradient(self.x) - g -= prior_grad - - - if self.iteration == 0: - #lhkd_grad = g - prior_grad - if prior_grad.norm()/g.norm() > 0.5: - print("Choose RDP and kappa precond") - self.rdp_diag_hess_obj = RDPDiagHessTorch(self.dataset.OSEM_image.copy(), self.dataset.prior) - self.lkhd_precond = self.dataset.kappa.power(2) - self.compute_rdp_diag_hess = True - self.eps = self.lkhd_precond.max()/1e4 - print("Compute inital precond") - self.rdp_diag_hess_obj.compute(self.x, self.precond) - print(self.precond.min(), self.precond.max(), self.precond.norm()) - print("Finished computing inital precond") - else: - print("Choose EM precond") - self.compute_rdp_diag_hess = False - self.eps = self.dataset.OSEM_image.max()/1e3 - - - if self.compute_rdp_diag_hess: - print(self.iteration % self.rdp_update_interval) - if self.iteration % self.rdp_update_interval == 0 and self.iteration > 0: - print("Compute Precond Again!") - self.rdp_diag_hess_obj.compute(self.x, self.precond) - g.divide(self.lkhd_precond + self.precond + self.eps, out=self.x_update) - else: - g.multiply(self.x + self.eps, out=self.x_update) - self.x_update.divide(self.average_sensitivity, out=self.x_update) - - - if self.update_filter is not None: - self.update_filter.apply(self.x_update) - - #distance = (self.x - self.initial).norm() - #if distance > self.max_distance: - # self.max_distance = distance - - self.sum_gradient += self.x_update.norm()**2 - - if self.iteration == 0: - self.alpha = min(max(1/(self.x_update.norm() + 1e-3), 0.005), 1.0) - else: - nominator = self.last_g.dot(self.last_x_update) - - deltax = self.x - self.last_x - deltag = g - self.last_g - - denominator = deltax.dot(deltag) - - step_size = self.alpha**2 * np.abs(nominator) / np.abs(denominator) - - self.alpha = step_size - - #step_size = self.max_distance / np.sqrt(self.sum_gradient) - - self.last_x = self.x.copy() - - print("Step size: ", self.alpha) - self.x.sapyb(1.0, self.x_update, self.alpha, out=self.x) - - # threshold to non-negative - self.x.maximum(0, out=self.x) - self.subset = (self.subset + 1) % self.num_subsets - - self.last_g = g.copy() - self.last_x_update = self.x_update.copy() - - def update_objective(self): - # required for current CIL (needs to set self.loss) - self.loss.append(self.objective_function(self.x)) - - def objective_function(self, x): - ''' value of objective function summed over all subsets ''' - v = 0 - #for s in range(len(self.data)): - # v += self.subset_objective(x, s) - return v - - def subset_objective(self, x, subset_num): - ''' value of objective function for one subset ''' - raise NotImplementedError - - -class BSREM(BSREMSkeleton): - ''' BSREM implementation using sirf.STIR objective functions''' - def __init__(self, data, obj_funs, prior, initial, **kwargs): - ''' - construct Algorithm with lists of data and, objective functions, initial estimate, initial step size, - step-size relaxation (per epoch) and optionally Algorithm parameters - ''' - self.obj_funs = obj_funs - self.prior = prior - super().__init__(data, initial, **kwargs) - - def subset_sensitivity(self, subset_num): - ''' Compute sensitiSvity for a particular subset''' - self.obj_funs[subset_num].set_up(self.x) - # note: sirf.STIR Poisson likelihood uses `get_subset_sensitivity(0) for the whole - # sensitivity if there are no subsets in that likelihood - return self.obj_funs[subset_num].get_subset_sensitivity(0) - - def subset_gradient_pll(self, x, subset_num): - ''' 
Compute gradient at x for a particular subset''' - return self.obj_funs[subset_num].gradient(x) - - def subset_gradient(self, x, subset_num): - ''' Compute gradient at x for a particular subset''' - return self.obj_funs[subset_num].gradient(x) - self.prior.gradient(x) - - def subset_objective(self, x, subset_num): - ''' value of objective function for one subset ''' - return self.obj_funs[subset_num](x) diff --git a/bsrem_dowg.py b/bsrem_dowg.py index bcd45c4..2ca5dcf 100644 --- a/bsrem_dowg.py +++ b/bsrem_dowg.py @@ -79,8 +79,26 @@ def step_size(self): return self.initial_step_size / (1 + self.relaxation_eta * self.epoch()) def update(self): + if self.epoch() < 10: + g = self.subset_gradient(self.x, self.subset_order[self.subset]) + elif self.epoch() >= 10 and self.epoch() < 20: + for i in range(2): + if i == 0: + g = self.subset_gradient(self.x, self.subset_order[self.subset]) + else: + g += self.subset_gradient(self.x, self.subset_order[self.subset]) + self.subset = (self.subset + 1) % self.num_subsets - g = self.subset_gradient(self.x, self.subset_order[self.subset]) + g /= 2 + else: + for i in range(4): + if i == 0: + g = self.subset_gradient(self.x, self.subset_order[self.subset]) + else: + g += self.subset_gradient(self.x, self.subset_order[self.subset]) + self.subset = (self.subset + 1) % self.num_subsets + + g /= 4 g.multiply(self.x + self.eps, out=self.x_update) self.x_update.divide(self.average_sensitivity, out=self.x_update) diff --git a/bsrem_saga.py b/bsrem_saga.py deleted file mode 100644 index a9f34af..0000000 --- a/bsrem_saga.py +++ /dev/null @@ -1,253 +0,0 @@ -# -# -# Classes implementing the SAGA algorithm in sirf.STIR -# -# A. Defazio, F. Bach, and S. Lacoste-Julien, “SAGA: A Fast -# Incremental Gradient Method With Support for Non-Strongly -# Convex Composite Objectives,” in Advances in Neural Infor- -# mation Processing Systems, vol. 27, Curran Associates, Inc., 2014 -# -# Twyman, R., Arridge, S., Kereta, Z., Jin, B., Brusaferri, L., -# Ahn, S., ... & Thielemans, K. (2022). An investigation of stochastic variance -# reduction algorithms for relative difference penalized 3D PET image reconstruction. -# IEEE Transactions on Medical Imaging, 42(1), 29-41. - -import numpy -import numpy as np -import sirf.STIR as STIR - -from cil.optimisation.algorithms import Algorithm -from utils.herman_meyer import herman_meyer_order - -import torch - -class SAGASkeleton(Algorithm): - ''' Main implementation of a modified BSREM algorithm - - This essentially implements constrained preconditioned gradient ascent - with an EM-type preconditioner. - - In each update step, the gradient of a subset is computed, multiplied by a step_size and a EM-type preconditioner. - Before adding this to the previous iterate, an update_filter can be applied. - - ''' - def __init__(self, data, initial, average_sensitivity, - update_filter=STIR.TruncateToCylinderProcessor(), **kwargs): - ''' - Arguments: - ``data``: list of items as returned by `partitioner` - ``initial``: initial estimate - ``update_filter`` is applied on the (additive) update term, i.e. before adding to the previous iterate. - Set the filter to `None` if you don't want any. 
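Both bsrem_bb_subset.py (deleted above) and the retained change to bsrem_dowg.py reduce gradient noise by averaging more subset gradients as the reconstruction progresses. The schedule from bsrem_bb_subset.py as a standalone sketch (note the trailing entry of accumulate_gradient_num is unreachable in the original loop):

def subsets_to_accumulate(iteration, num_subsets,
                          boundaries=(4, 6, 10, 12, 20),
                          counts=(1, 2, 4, 8, 10)):
    """Number of subset gradients to average at this iteration (epoch boundaries scaled by num_subsets)."""
    for boundary, count in zip(boundaries, counts):
        if iteration < boundary * num_subsets:
            return min(count, num_subsets)    # never accumulate more than all subsets
    return num_subsets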
- ''' - super().__init__(**kwargs) - - self.x = initial - self.initial = initial.copy() - self.data = data - self.num_subsets = len(data) - self.average_sensitivity = average_sensitivity - self.eps = self.dataset.OSEM_image.max()/1e3 - - self.subset = 0 - self.update_filter = update_filter - self.configured = True - - # DOG parameters - self.max_distance = 0 - self.sum_gradient = 0 - - self.alpha = None - self.last_alpha = None - self.subset_order = herman_meyer_order(self.num_subsets) - - self.gm = [self.x.get_uniform_copy(0) for _ in range(self.num_subsets)] - - self.sum_gm = self.x.get_uniform_copy(0) - self.x_update = self.x.get_uniform_copy(0) - - self.last_objective_function = self.objective_function_inter(self.x) - self.gamma = 1.0 # scaling for learning rate - - def subset_sensitivity(self, subset_num): - raise NotImplementedError - - def subset_gradient(self, x, subset_num): - raise NotImplementedError - - def subset_gradient_likelihood(self, x, subset_num): - raise NotImplementedError - - def subset_gradient_prior(self, x, subset_num): - raise NotImplementedError - - def epoch(self): - return self.iteration // self.num_subsets - - def update(self): - if self.epoch() % 4 == 0 and self.iteration % self.num_subsets == 0 and self.epoch() > 0: - loss = self.objective_function_inter(self.x) - #print("Objective at ", self.epoch(), " is = ", loss) - - if loss < self.last_objective_function: - #print("Reduce learning rate!") - self.gamma = self.gamma * 0.75 - - self.last_objective_function = loss - # for the first epochs just do SGD - if self.epoch() < 2: - # construct gradient of subset - subset_choice = self.subset_order[self.subset] - g = self.subset_gradient(self.x, subset_choice) - #print("Gradient norm: ", g.norm()) - g.multiply(self.x + self.eps, out=self.x_update) - self.x_update.divide(self.average_sensitivity, out=self.x_update) - #self.x_update = (self.x + self.eps) * g / self.average_sensitivity - - # SGD for two epochs - if self.iteration == 0: - step_size_estimate = min(max(1/(self.x_update.norm() + 1e-3), 0.05), 3.0) - self.alpha = step_size_estimate - - distance = (self.x - self.initial).norm() - if distance > self.max_distance: - self.max_distance = distance - - self.sum_gradient += self.x_update.norm()**2 - - if self.iteration > 0: - self.alpha = self.max_distance / np.sqrt(self.sum_gradient) - - if self.update_filter is not None: - self.update_filter.apply(self.x_update) - - #print(self.alpha, self.sum_gradient) - self.x.sapyb(1.0, self.x_update, self.alpha, out=self.x) - #self.x += self.alpha * self.x_update - self.x.maximum(0, out=self.x) - - # do SAGA - else: - # do one step of full gradient descent to set up subset gradients - - if (self.epoch() in [2]) and self.iteration % self.num_subsets == 0: - # construct gradient of subset - #print("One full gradient step to intialise SAGA") - g = self.x.get_uniform_copy(0) - for i in range(self.num_subsets): - gm = self.subset_gradient(self.x, self.subset_order[i]) - self.gm[self.subset_order[i]] = gm - g.add(gm, out=g) - #g += gm - - g /= self.num_subsets - #print("Gradient norm: ", g.norm()) - distance = (self.x - self.initial).norm() - if distance > self.max_distance: - self.max_distance = distance - - g.multiply(self.x + self.eps, out=self.x_update) - self.x_update.divide(self.average_sensitivity, out=self.x_update) - #self.x_update = (self.x + self.eps) * g / self.average_sensitivity - self.sum_gradient += self.x_update.norm()**2 - self.alpha = self.max_distance / np.sqrt(self.sum_gradient) - - if self.update_filter is 
not None: - self.update_filter.apply(self.x_update) - - self.x.sapyb(1.0, self.x_update, self.gamma*self.alpha, out=self.x) - #self.x += self.alpha * self.x_update - - # threshold to non-negative - self.x.maximum(0, out=self.x) - - self.sum_gm = self.x.get_uniform_copy(0) - for gm in self.gm: - self.sum_gm += gm - - - subset_choice = self.subset_order[self.subset] - g = self.subset_gradient(self.x, subset_choice) - - #gradient = self.num_subsets * (g - self.gm[subset_choice]) + self.sum_gm #/ self.num_subsets - gradient = (g - self.gm[subset_choice]) + self.sum_gm / self.num_subsets - - distance = (self.x - self.initial).norm() - if distance > self.max_distance: - self.max_distance = distance - - gradient.multiply(self.x + self.eps, out=self.x_update) - self.x_update.divide(self.average_sensitivity, out=self.x_update) - #self.x_update = (self.x + self.eps) * gradient / self.average_sensitivity - - self.sum_gradient += self.x_update.norm()**2 - - if self.update_filter is not None: - self.update_filter.apply(self.x_update) - - # DOG lr - self.alpha = self.max_distance / np.sqrt(self.sum_gradient) - - self.x.sapyb(1.0, self.x_update, self.gamma*self.alpha, out=self.x) - #self.x += self.alpha * self.x_update - - # threshold to non-negative - self.x.maximum(0, out=self.x) - - self.sum_gm = self.sum_gm - self.gm[subset_choice] + g - self.gm[subset_choice] = g - - - self.subset = (self.subset + 1) % self.num_subsets - self.last_alpha = self.alpha - - def update_objective(self): - # required for current CIL (needs to set self.loss) - self.loss.append(self.objective_function(self.x)) - - def objective_function(self, x): - ''' value of objective function summed over all subsets ''' - v = 0 - #for s in range(len(self.data)): - # v += self.subset_objective(x, s) - return v - - def objective_function_inter(self, x): - ''' value of objective function summed over all subsets ''' - v = 0 - for s in range(len(self.data)): - v += self.subset_objective(x, s) - return v - - - def subset_objective(self, x, subset_num): - ''' value of objective function for one subset ''' - raise NotImplementedError - - -class SAGA(SAGASkeleton): - ''' SAGA implementation using sirf.STIR objective functions''' - def __init__(self, data, obj_funs, initial, average_sensitivity, **kwargs): - ''' - construct Algorithm with lists of data and, objective functions, initial estimate - and optionally Algorithm parameters - ''' - self.obj_funs = obj_funs - super().__init__(data, initial,average_sensitivity, **kwargs) - - def subset_sensitivity(self, subset_num): - ''' Compute sensitivity for a particular subset''' - self.obj_funs[subset_num].set_up(self.x) - # note: sirf.STIR Poisson likelihood uses `get_subset_sensitivity(0) for the whole - # sensitivity if there are no subsets in that likelihood - return self.obj_funs[subset_num].get_subset_sensitivity(0) - - def subset_gradient(self, x, subset_num): - ''' Compute gradient at x for a particular subset''' - return self.obj_funs[subset_num].gradient(x) - - def subset_objective(self, x, subset_num): - ''' value of objective function for one subset ''' - return self.obj_funs[subset_num](x) - - diff --git a/main.py b/main.py deleted file mode 100644 index 02a6abe..0000000 --- a/main.py +++ /dev/null @@ -1,76 +0,0 @@ - -from cil.optimisation.algorithms import Algorithm -from cil.optimisation.utilities import callbacks - -""" -EWS: Postprocessing + some iterative method - -""" - - -from bsrem_bb import BSREM -from utils.number_of_subsets import compute_number_of_subsets - -from 
sirf.contrib.partitioner import partitioner -#from utils.partioner_function import data_partition -#from utils.partioner_function_no_obj import data_partition - -assert issubclass(BSREM, Algorithm) - - -import torch -torch.cuda.set_per_process_memory_fraction(0.8) - -import setup_postprocessing - - -class MaxIteration(callbacks.Callback): - def __init__(self, max_iteration: int, verbose: int = 1): - super().__init__(verbose) - self.max_iteration = max_iteration - - def __call__(self, algorithm: Algorithm): - if algorithm.iteration >= self.max_iteration: - raise StopIteration - - -class Submission(BSREM): - def __init__(self, data, - update_objective_interval: int = 2, - **kwargs): - - num_subsets = 1 - - data_sub, _, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term, - data.mult_factors, num_subsets, - initial_image=data.OSEM_image, - mode = "staggered") - - self.dataset = data - - # WARNING: modifies prior strength with 1/num_subsets - data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(data_sub)) - data.prior.set_up(data.OSEM_image) - - DEVICE = "cuda" - - initial_images = torch.from_numpy(data.OSEM_image.as_array()).float().to(DEVICE).unsqueeze(0).unsqueeze(0) - with torch.no_grad(): - x_pred = setup_postprocessing.postprocessing_model(initial_images) - x_pred[x_pred < 0] = 0 - - #del setup_model.network_precond - del initial_images - - initial = data.OSEM_image.clone() - initial.fill(x_pred.detach().cpu().numpy().squeeze()) - - for f in obj_funs: # add prior evenly to every objective function - f.set_prior(data.prior) - - super().__init__(data=data_sub, - obj_funs=obj_funs, - initial=initial, - update_objective_interval=update_objective_interval) - -submission_callbacks = [] \ No newline at end of file diff --git a/main_BSREM.py b/main_BSREM.py deleted file mode 100644 index 3a2021d..0000000 --- a/main_BSREM.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Main file to modify for submissions. - -Once renamed or symlinked as `main.py`, it will be used by `petric.py` as follows: - ->>> from main import Submission, submission_callbacks ->>> from petric import data, metrics ->>> algorithm = Submission(data) ->>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks) -""" -from cil.optimisation.algorithms import Algorithm -from cil.optimisation.utilities import callbacks - -from sirf.contrib.partitioner import partitioner -from copy import deepcopy -from utils.number_of_subsets import compute_number_of_subsets -from bsrem_bb_saga import BSREM - -assert issubclass(BSREM, Algorithm) - - -class MaxIteration(callbacks.Callback): - """ - The organisers try to `Submission(data).run(inf)` i.e. for infinite iterations (until timeout). - This callback forces stopping after `max_iteration` instead. 
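The SAGA class above sizes its steps with the DOG ("distance over gradients") rule: the largest distance travelled from the initial image, divided by the root of the accumulated squared preconditioned updates. A minimal sketch (the deleted code seeds the very first step with a clamped heuristic instead):

import numpy as np

class DogStepSize:
    """DOG step-size rule: max_k ||x_k - x_0|| / sqrt(sum_k ||update_k||**2)."""
    def __init__(self):
        self.max_distance = 0.0
        self.sum_sq_updates = 0.0

    def __call__(self, x, x0, update):
        self.max_distance = max(self.max_distance, float(np.linalg.norm(x - x0)))
        self.sum_sq_updates += float(np.linalg.norm(update)) ** 2
        return self.max_distance / np.sqrt(self.sum_sq_updates)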
- """ - def __init__(self, max_iteration: int, verbose: int = 1): - super().__init__(verbose) - self.max_iteration = max_iteration - - def __call__(self, algorithm: Algorithm): - if algorithm.iteration >= self.max_iteration: - raise StopIteration - -class Submission(BSREM): - def __init__(self, data, - update_objective_interval: int = 10, - **kwargs): - - tof = (data.acquired_data.shape[0] > 1) - views = data.acquired_data.shape[2] - num_subsets = compute_number_of_subsets(views, tof) - print("Number of views: ", views, " use ", num_subsets, " subsets") - data_sub, _, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term, - data.mult_factors, num_subsets, - initial_image=data.OSEM_image) - - # WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations - self.dataset = data - data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs)) - data.prior.set_up(data.OSEM_image) - - #print("prior: ", data.prior(data.OSEM_image)) - #data.prior = data.prior.set_penalisation_factor(data.prior.get_penalisation_factor()) - #data.prior.set_up(data.OSEM_image) - #print(data.prior.get_penalisation_factor()) - - - #print("prior: ", data.prior(data.OSEM_image)) - - super().__init__(data_sub, - obj_funs, - prior=data.prior, - initial=data.OSEM_image, - update_objective_interval=update_objective_interval) - - -submission_callbacks = [] #[MaxIteration(660)] diff --git a/main_EWS_SAGA.py b/main_EWS_SAGA.py deleted file mode 100644 index 6e1b74b..0000000 --- a/main_EWS_SAGA.py +++ /dev/null @@ -1,102 +0,0 @@ - -from cil.optimisation.algorithms import Algorithm -from cil.optimisation.utilities import callbacks - - -from bsrem_saga import SAGA -from utils.number_of_subsets import compute_number_of_subsets - -from sirf.contrib.partitioner import partitioner -#from utils.partioner_function import data_partition -#from utils.partioner_function_no_obj import data_partition - -assert issubclass(SAGA, Algorithm) - - -import torch -torch.cuda.set_per_process_memory_fraction(0.8) - -import setup_model - - -class MaxIteration(callbacks.Callback): - def __init__(self, max_iteration: int, verbose: int = 1): - super().__init__(verbose) - self.max_iteration = max_iteration - - def __call__(self, algorithm: Algorithm): - if algorithm.iteration >= self.max_iteration: - raise StopIteration - - -class Submission(SAGA): - def __init__(self, data, - update_objective_interval: int = 10, - **kwargs): - - tof = (data.acquired_data.shape[0] > 1) - views = data.acquired_data.shape[2] - num_subsets = compute_number_of_subsets(views, tof) - - - data_sub, _, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term, - data.mult_factors, num_subsets, - initial_image=data.OSEM_image, - mode = "staggered") - - self.dataset = data - - # WARNING: modifies prior strength with 1/num_subsets - data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(data_sub)) - data.prior.set_up(data.OSEM_image) - - sensitivity = data.OSEM_image.get_uniform_copy(0) - for s in range(len(data_sub)): - obj_funs[s].set_up(data.OSEM_image) - sensitivity.add(obj_funs[s].get_subset_sensitivity(0), out=sensitivity) - - pll_grad = data.OSEM_image.get_uniform_copy(0) - for s in range(len(data_sub)): - pll_grad.add(obj_funs[s].gradient(data.OSEM_image), out=pll_grad) - - average_sensitivity = sensitivity.clone() / num_subsets - average_sensitivity += average_sensitivity.max()/1e4 - - sensitivity += sensitivity.max()/1e4 - eps = 
data.OSEM_image.max()/1e3 - - prior_grad = data.prior.gradient(data.OSEM_image) * num_subsets - - grad = (data.OSEM_image + eps) * pll_grad / sensitivity - prior_grad = (data.OSEM_image + eps) * prior_grad / sensitivity - - DEVICE = "cuda" - - initial_images = torch.from_numpy(data.OSEM_image.as_array()).float().to(DEVICE).unsqueeze(0) - prior_grads = torch.from_numpy(prior_grad.as_array()).float().to(DEVICE).unsqueeze(0) - pll_grads = torch.from_numpy(grad.as_array()).float().to(DEVICE).unsqueeze(0) - - model_inp = torch.cat([initial_images, pll_grads, prior_grads], dim=0).unsqueeze(0) - with torch.no_grad(): - x_pred = setup_model.network_precond(model_inp) - x_pred[x_pred < 0] = 0 - - #del setup_model.network_precond - del initial_images - del prior_grads - del pll_grads - del model_inp - - initial = data.OSEM_image.clone() - initial.fill(x_pred.detach().cpu().numpy().squeeze()) - - for f in obj_funs: # add prior evenly to every objective function - f.set_prior(data.prior) - - super().__init__(data=data_sub, - obj_funs=obj_funs, - initial=initial, - average_sensitivity=average_sensitivity, - update_objective_interval=update_objective_interval) - -submission_callbacks = [] \ No newline at end of file diff --git a/main_Full_Gradient.py b/main_Full_Gradient.py deleted file mode 100644 index 007732e..0000000 --- a/main_Full_Gradient.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Main file to modify for submissions. - -Once renamed or symlinked as `main.py`, it will be used by `petric.py` as follows: - ->>> from main import Submission, submission_callbacks ->>> from petric import data, metrics ->>> algorithm = Submission(data) ->>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks) -""" -from cil.optimisation.algorithms import Algorithm -from cil.optimisation.utilities import callbacks - -from sirf.contrib.partitioner import partitioner - -from bsrem_bb import BSREM - -assert issubclass(BSREM, Algorithm) - - -class MaxIteration(callbacks.Callback): - """ - The organisers try to `Submission(data).run(inf)` i.e. for infinite iterations (until timeout). - This callback forces stopping after `max_iteration` instead. 
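main_EWS_SAGA.py (above) builds an "educated warm start": the OSEM image and its preconditioned likelihood and prior gradients are stacked as a three-channel volume and passed through a pretrained network before the iterative method starts. A hedged torch sketch of that input assembly; `network` stands in for setup_model.network_precond and is not defined here:

import numpy as np
import torch

def warm_start(osem, pll_grad, prior_grad, network, device="cuda"):
    """Stack image + gradients into a (1, 3, z, y, x) tensor and run the warm-start network."""
    chans = [torch.from_numpy(np.asarray(a, dtype=np.float32)).to(device)[None]
             for a in (osem, pll_grad, prior_grad)]
    inp = torch.cat(chans, dim=0).unsqueeze(0)
    with torch.no_grad():
        pred = network(inp).clamp_min(0)       # enforce non-negativity as in the deleted code
    return pred.squeeze().cpu().numpy()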
- """ - def __init__(self, max_iteration: int, verbose: int = 1): - super().__init__(verbose) - self.max_iteration = max_iteration - - def __call__(self, algorithm: Algorithm): - if algorithm.iteration >= self.max_iteration: - raise StopIteration - -class Submission(BSREM): - def __init__(self, data, - update_objective_interval: int = 1, - **kwargs): - - data_sub, _, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term, - data.mult_factors, 1, - initial_image=data.OSEM_image) - # WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations - data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs)) - data.prior.set_up(data.OSEM_image) - for f in obj_funs: # add prior evenly to every objective function - f.set_prior(data.prior) - self.dataset = data - - super().__init__(data_sub, - obj_funs, - initial=data.OSEM_image, - update_objective_interval=update_objective_interval) - - -submission_callbacks = [] #[MaxIteration(660)] diff --git a/main_SAGA.py b/main_SAGA.py deleted file mode 100644 index 4b8fda1..0000000 --- a/main_SAGA.py +++ /dev/null @@ -1,57 +0,0 @@ - -from cil.optimisation.algorithms import Algorithm -from cil.optimisation.utilities import callbacks - - -from bsrem_saga import SAGA -from utils.number_of_subsets import compute_number_of_subsets - -#from sirf.contrib.partitioner import partitioner -from utils.partioner_function import data_partition - -assert issubclass(SAGA, Algorithm) - - -#import torch -#torch.cuda.set_per_process_memory_fraction(0.8) - -#from one_step_model import NetworkPreconditioner - - -class MaxIteration(callbacks.Callback): - def __init__(self, max_iteration: int, verbose: int = 1): - super().__init__(verbose) - self.max_iteration = max_iteration - - def __call__(self, algorithm: Algorithm): - if algorithm.iteration >= self.max_iteration: - raise StopIteration - - -class Submission(SAGA): - def __init__(self, data, - update_objective_interval: int = 10, - **kwargs): - - tof = (data.acquired_data.shape[0] > 1) - views = data.acquired_data.shape[2] - num_subsets = compute_number_of_subsets(views, tof) - print("Number of views: ", views, " use ", num_subsets, " subsets") - data_sub, _, obj_funs = data_partition(data.acquired_data, data.additive_term, - data.mult_factors, num_subsets, - initial_image=data.OSEM_image, - mode = "staggered") - self.dataset = data - - # WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations) - data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs)) - data.prior.set_up(data.OSEM_image) - for f in obj_funs: # add prior evenly to every objective function - f.set_prior(data.prior) - - super().__init__(data_sub, - obj_funs, - data.OSEM_image, - update_objective_interval=1) - -submission_callbacks = [] \ No newline at end of file diff --git a/modified_petric.py b/modified_petric.py deleted file mode 100644 index 800dcd9..0000000 --- a/modified_petric.py +++ /dev/null @@ -1,299 +0,0 @@ -#!/usr/bin/env python -""" -ANY CHANGES TO THIS FILE ARE IGNORED BY THE ORGANISERS. -Only the `main.py` file may be modified by participants. - -This file is not intended for participants to use, except for -the `get_data` function (and possibly `QualityMetrics` class). -It is used by the organisers to run the submissions in a controlled way. -It is included here purely in the interest of transparency. 
- -Usage: - petric.py [options] - -Options: - --log LEVEL : Set logging level (DEBUG, [default: INFO], WARNING, ERROR, CRITICAL) -""" - -import csv -import logging -import os -from dataclasses import dataclass -from pathlib import Path -from time import time -from traceback import print_exc - -import torch -torch.cuda.set_per_process_memory_fraction(0.6) - -import numpy as np -from skimage.metrics import mean_squared_error as mse -from tensorboardX import SummaryWriter - -import sirf.STIR as STIR -from cil.optimisation.algorithms import Algorithm -from cil.optimisation.utilities import callbacks as cil_callbacks -from img_quality_cil_stir import ImageQualityCallback - -from datetime import datetime - - - -log = logging.getLogger('petric') -TEAM = os.getenv("GITHUB_REPOSITORY", "SyneRBI/PETRIC-").split("/PETRIC-", 1)[-1] -VERSION = os.getenv("GITHUB_REF_NAME", "") -OUTDIR = Path(f"/o/logs/{TEAM}/{VERSION}" if TEAM and VERSION else "./output/" + "full_gradient_rdp_diag_or_em" + datetime.now().strftime('%Y-%m-%d_%H-%M-%S')) -if not (SRCDIR := Path("/mnt/share/petric")).is_dir(): - SRCDIR = Path("./data") - - -class Callback(cil_callbacks.Callback): - """ - CIL Callback but with `self.skip_iteration` checking `min(self.interval, algo.update_objective_interval)`. - TODO: backport this class to CIL. - """ - def __init__(self, interval: int = 1 << 31, **kwargs): - super().__init__(**kwargs) - self.interval = interval - - def skip_iteration(self, algo: Algorithm) -> bool: - return algo.iteration % min(self.interval, - algo.update_objective_interval) != 0 and algo.iteration != algo.max_iteration - - -class SaveIters(Callback): - """Saves `algo.x` as "iter_{algo.iteration:04d}.hv" and `algo.loss` in `csv_file`""" - def __init__(self, outdir=OUTDIR, csv_file='objectives.csv', **kwargs): - super().__init__(**kwargs) - self.outdir = Path(outdir) - self.outdir.mkdir(parents=True, exist_ok=True) - self.csv = csv.writer((self.outdir / csv_file).open("w", buffering=1)) - self.csv.writerow(("iter", "objective")) - - def __call__(self, algo: Algorithm): - if not self.skip_iteration(algo): - log.debug("saving iter %d...", algo.iteration) - algo.x.write(str(self.outdir / f'iter_{algo.iteration:04d}.hv')) - self.csv.writerow((algo.iteration, algo.get_last_loss())) - log.debug("...saved") - if algo.iteration == algo.max_iteration: - algo.x.write(str(self.outdir / 'iter_final.hv')) - - -class StatsLog(Callback): - """Log image slices & objective value""" - def __init__(self, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR, **kwargs): - super().__init__(**kwargs) - self.transverse_slice = transverse_slice - self.coronal_slice = coronal_slice - self.vmax = vmax - self.x_prev = None - self.tb = logdir if isinstance(logdir, SummaryWriter) else SummaryWriter(logdir=str(logdir)) - - def __call__(self, algo: Algorithm): - if self.skip_iteration(algo): - return - t = getattr(self, '__time', None) or time() - log.debug("logging iter %d...", algo.iteration) - # initialise `None` values - self.transverse_slice = algo.x.dimensions()[0] // 2 if self.transverse_slice is None else self.transverse_slice - self.coronal_slice = algo.x.dimensions()[1] // 2 if self.coronal_slice is None else self.coronal_slice - self.vmax = algo.x.max() if self.vmax is None else self.vmax - - self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t) - if self.x_prev is not None: - normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm() - self.tb.add_scalar("normalised_change", normalised_change, 
algo.iteration, t) - self.x_prev = algo.x.clone() - x_arr = algo.x.as_array() - self.tb.add_image("transverse", np.clip(x_arr[self.transverse_slice:self.transverse_slice + 1] / self.vmax, 0, - 1), algo.iteration, t) - self.tb.add_image("coronal", np.clip(x_arr[None, :, self.coronal_slice] / self.vmax, 0, 1), algo.iteration, t) - log.debug("...logged") - - -class QualityMetrics(ImageQualityCallback, Callback): - """From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds""" - def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1 << 31, **kwargs): - # TODO: drop multiple inheritance once `interval` included in CIL - Callback.__init__(self, interval=interval) - ImageQualityCallback.__init__(self, reference_image, **kwargs) - self.whole_object_indices = np.where(whole_object_mask.as_array()) - self.background_indices = np.where(background_mask.as_array()) - self.ref_im_arr = reference_image.as_array() - self.norm = self.ref_im_arr[self.background_indices].mean() - - def __call__(self, algo: Algorithm): - if self.skip_iteration(algo): - return - t = getattr(self, '__time', None) or time() - for tag, value in self.evaluate(algo.x).items(): - self.tb_summary_writer.add_scalar(tag, value, algo.iteration, t) - - def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]: - assert not any(self.filter.values()), "Filtering not implemented" - test_im_arr = test_im.as_array() - whole = { - "RMSE_whole_object": np.sqrt( - mse(self.ref_im_arr[self.whole_object_indices], test_im_arr[self.whole_object_indices])) / self.norm, - "RMSE_background": np.sqrt( - mse(self.ref_im_arr[self.background_indices], test_im_arr[self.background_indices])) / self.norm} - local = { - f"AEM_VOI_{voi_name}": np.abs(test_im_arr[voi_indices].mean() - self.ref_im_arr[voi_indices].mean()) / - self.norm - for voi_name, voi_indices in sorted(self.voi_indices.items())} - return {**whole, **local} - - def keys(self): - return ["RMSE_whole_object", "RMSE_background"] + [f"AEM_VOI_{name}" for name in sorted(self.voi_indices)] - - -class MetricsWithTimeout(cil_callbacks.Callback): - """Stops the algorithm after `seconds`""" - def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, **kwargs): - super().__init__(**kwargs) - print("outdir: ", outdir) - self._seconds = seconds - self.callbacks = [ - cil_callbacks.ProgressCallback(), - SaveIters(outdir=outdir), - (tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))] - self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter - self.reset() - - def reset(self, seconds=None): - self.limit = time() + (self._seconds if seconds is None else seconds) - self.offset = 0 - - def __call__(self, algo: Algorithm): - if (now := time()) > self.limit + self.offset: - log.warning("Timeout reached. Stopping algorithm.") - raise StopIteration - for c in self.callbacks: - c.__time = now - self.offset # privately inject walltime-excluding-petric-callbacks - c(algo) - self.offset += time() - now - - @staticmethod - def mean_absolute_error(y, x): - return np.mean(np.abs(y, x)) - - -def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3): - """ - Construct a smoothed Relative Difference Prior (RDP) - - initial_image: used to determine a smoothing factor (epsilon). - kappa: used to pass voxel-dependent weights. 
- """ - prior = getattr(STIR, 'CudaRelativeDifferencePrior', STIR.RelativeDifferencePrior)() - # need to make it differentiable - epsilon = initial_image.max() * max_scaling - prior.set_epsilon(epsilon) - prior.set_penalisation_factor(penalty_strength) - prior.set_kappa(kappa) - prior.set_up(initial_image) - return prior - - -@dataclass -class Dataset: - acquired_data: STIR.AcquisitionData - additive_term: STIR.AcquisitionData - mult_factors: STIR.AcquisitionData - OSEM_image: STIR.ImageData - prior: STIR.RelativeDifferencePrior - kappa: STIR.ImageData - reference_image: STIR.ImageData | None - whole_object_mask: STIR.ImageData | None - background_mask: STIR.ImageData | None - voi_masks: dict[str, STIR.ImageData] - - -def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0): - """ - Load data from `srcdir`, constructs prior and return as a `Dataset`. - Also redirects sirf.STIR log output to `outdir`. - """ - srcdir = Path(srcdir) - outdir = Path(outdir) - STIR.set_verbosity(sirf_verbosity) # set to higher value to diagnose problems - STIR.AcquisitionData.set_storage_scheme('memory') # needed for get_subsets() - - _ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'), str(outdir / 'errors.txt')) - acquired_data = STIR.AcquisitionData(str(srcdir / 'prompts.hs')) - additive_term = STIR.AcquisitionData(str(srcdir / 'additive_term.hs')) - mult_factors = STIR.AcquisitionData(str(srcdir / 'mult_factors.hs')) - OSEM_image = STIR.ImageData(str(srcdir / 'OSEM_image.hv')) - kappa = STIR.ImageData(str(srcdir / 'kappa.hv')) - if (penalty_strength_file := (srcdir / 'penalisation_factor.txt')).is_file(): - penalty_strength = float(np.loadtxt(penalty_strength_file)) - else: - penalty_strength = 1 / 700 # default choice - prior = construct_RDP(penalty_strength, OSEM_image, kappa) - - def get_image(fname): - if (source := srcdir / 'PETRIC' / fname).is_file(): - return STIR.ImageData(str(source)) - return None # explicit to suppress linter warnings - - reference_image = get_image('reference_image.hv') - whole_object_mask = get_image('VOI_whole_object.hv') - background_mask = get_image('VOI_background.hv') - voi_masks = { - voi.stem[4:]: STIR.ImageData(str(voi)) - for voi in (srcdir / 'PETRIC').glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'whole_object')} - return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa, reference_image, - whole_object_mask, background_mask, voi_masks) - - -if SRCDIR.is_dir(): - # create list of existing data - # NB: `MetricsWithTimeout` initialises `SaveIters` which creates `outdir` - data_dirs_metrics = [(SRCDIR / "Mediso_NEMA_IQ", - OUTDIR / "Mediso_NEMA", - [MetricsWithTimeout(seconds=300, outdir=OUTDIR / "Mediso_NEMA")]), - (SRCDIR / "NeuroLF_Hoffman_Dataset", - OUTDIR / "NeuroLF_Hoffman", - [MetricsWithTimeout(seconds=300, outdir=OUTDIR / "NeuroLF_Hoffman")]), - (SRCDIR / "Siemens_mMR_ACR", - OUTDIR / "Siemens_ACR", - [MetricsWithTimeout(seconds=300, outdir=OUTDIR / "Siemens_ACR")]), - (SRCDIR / "Siemens_mMR_NEMA_IQ", - OUTDIR / "Siemens_NEMA", - [MetricsWithTimeout(seconds=300, outdir=OUTDIR / "Siemens_NEMA")]), - (SRCDIR / "Siemens_mMR_NEMA_IQ_lowcounts", - OUTDIR / "Siemens_NEMA_lowcounts", - [MetricsWithTimeout(seconds=300, outdir=OUTDIR / "Siemens_NEMA_lowcounts")]), - (SRCDIR / "Siemens_Vision600_thorax", - OUTDIR / "Vision600_thorax", - [MetricsWithTimeout(seconds=300, outdir=OUTDIR / "Vision600_thorax")]), - ] - - - -else: - log.warning("Source directory does not exist: %s", SRCDIR) - data_dirs_metrics = 
[(None, None, [])] # type: ignore - -if __name__ == "__main__": - from docopt import docopt - args = docopt(__doc__) - logging.basicConfig(level=getattr(logging, args["--log"].upper())) - from main_Full_Gradient import Submission, submission_callbacks - assert issubclass(Submission, Algorithm) - for srcdir, outdir, metrics in data_dirs_metrics: - data = get_data(srcdir=srcdir, outdir=outdir) - metrics_with_timeout = metrics[0] - if data.reference_image is not None: - metrics_with_timeout.callbacks.append( - QualityMetrics(data.reference_image, data.whole_object_mask, data.background_mask, - tb_summary_writer=metrics_with_timeout.tb, voi_mask_dict=data.voi_masks)) - metrics_with_timeout.reset() # timeout from now - algo = Submission(data) - try: - algo.run(np.inf, callbacks=metrics + submission_callbacks) - except Exception: - print_exc(limit=2) - finally: - del algo diff --git a/one_step_model.py b/one_step_model.py deleted file mode 100644 index f499553..0000000 --- a/one_step_model.py +++ /dev/null @@ -1,44 +0,0 @@ - -import torch - - -class NetworkPreconditioner(torch.nn.Module): - def __init__(self, n_layers = 1, hidden_channels = 8, kernel_size = 3): - super(NetworkPreconditioner, self).__init__() - - self.conv1 = torch.nn.Conv3d(3, 3*hidden_channels, kernel_size, groups=3, padding='same', bias=False) - self.conv2 = torch.nn.Conv3d(3*hidden_channels, 3*hidden_channels, kernel_size, groups=3, padding='same', bias=False) - self.conv3 = torch.nn.Conv3d(3*hidden_channels, hidden_channels, kernel_size, padding='same', bias=False) - - self.max_pool = torch.nn.MaxPool3d(kernel_size=2) - - self.conv4 = torch.nn.Conv3d(hidden_channels, hidden_channels, kernel_size, padding='same', bias=False) - self.conv5 = torch.nn.Conv3d(hidden_channels, hidden_channels, kernel_size, padding='same', bias=False) - - # interpolate - - self.conv6 = torch.nn.Conv3d(hidden_channels, hidden_channels, kernel_size, padding='same', bias=False) - self.conv7 = torch.nn.Conv3d(hidden_channels, 1, kernel_size, padding='same', bias=False) - - self.activation = torch.nn.ReLU() - - #self.list_of_conv3[-1].weight.data.fill_(0.0) - #self.list_of_conv3[-1].bias.data.fill_(0.0) - - def forward(self, x): - shape = x.shape - z = self.activation(self.conv1(x)) - z = self.activation(self.conv2(z)) - z1 = self.activation(self.conv3(z)) - - z2 = self.max_pool(z1) - z2 = self.activation(self.conv4(z2)) - z2 = self.activation(self.conv5(z2)) - - z3 = torch.nn.functional.interpolate(z2, size=shape[-3:], mode = "trilinear", align_corners=True) - z3 = z3 + z1 - - z4 = self.activation(self.conv6(z3)) - z_out = self.activation(self.conv7(z4)) - - return z_out diff --git a/plot_results.py b/plot_results.py deleted file mode 100644 index 6fbef52..0000000 --- a/plot_results.py +++ /dev/null @@ -1,110 +0,0 @@ - -import os -import pandas as pd -import numpy as np -import matplotlib.pyplot as plt -import yaml - -def cfg_to_name(cfg, file_id): - try: - method = cfg["method"] - except KeyError: - method = "osem" - if method == "bsrem": - name = f"{method}, subsets={cfg["num_subsets"]}, gamma={cfg["gamma"]} [{file_id}]" # {cfg["num_subsets"]} - return name - if method == "ews": - name = f"{method}, (ews={cfg["ews_method"]}) [{file_id}]" # {cfg["num_subsets"]} - return name - - if method == "bsrem_bb": - try: - name = f"{method}, {cfg["num_subsets"]} ({cfg["mode"]}), beta={cfg["beta"]}, bb init {cfg["bb_init_mode"]} [{file_id}]" - except KeyError: - name = f"{method}, ({cfg["mode"]}), beta={cfg["beta"]}, bb init {cfg["bb_init_mode"]} [{file_id}]" - - 
return name - if method == "adam": - name = f"{method}, {cfg["num_subsets"]} ({cfg["mode"]}), init lr={cfg["initial_step_size"]} [{file_id}]" - return name - if method == "adadelta": - name = f"{method}, {cfg["num_subsets"]} ({cfg["mode"]}), init lr={cfg["initial_step_size"]} [{file_id}]" - return name - return f"{method} [{file_id}]" - - -methods = ["Mediso_NEMA", "mMR_NEMA", "NeuroLF_Hoffman", "Siemens_mMR_ACR", "Vision600_thorax"] - -for method in methods: - print(method) - df = pd.read_csv(f"/home/alexdenker/pet/PETRIC-UCL-EWS/logs/osem/2024-09-16_09-06-06/{method}/results.csv") - metric_keys = list(df.keys())[2:] - - print(metric_keys) - - base_path = "/home/alexdenker/pet/PETRIC-UCL-EWS/logs/osem" - files = [os.path.join(base_path, f, method) for f in os.listdir(base_path)] - - base_path_bsrem = "/home/alexdenker/pet/PETRIC-UCL-EWS/logs/bsrem" - files = files + [os.path.join(base_path_bsrem, f, method) for f in os.listdir(base_path_bsrem)] - - fig, axes = plt.subplots(3, 4, figsize=(16,8)) - ax_ = list(axes.ravel()) - - for f_idx, f in enumerate(files): - try: - df = pd.read_csv(os.path.join(f, "results.csv")) - with open(os.path.join(f, "config.yaml")) as file_: - cfgdict = yaml.safe_load(file_) - - - for idx in range(len(metric_keys)): - - ax = ax_[idx] - ax.set_title(metric_keys[idx]) - - label_name = cfg_to_name(cfgdict, file_id=f.split("/")[-2]) - #ax.plot(df["time"], df[metric_keys[idx]], label=label_name) - try: - if "step" in metric_keys[idx] or "loss" in metric_keys[idx]: - ax.plot(df["time"], df[metric_keys[idx]], "-", label=label_name) - #ax.plot(df[metric_keys[idx]], label=label_name) - - else: - ax.semilogy(df["time"], df[metric_keys[idx]], "-", label=label_name) - - #ax.semilogy(df[metric_keys[idx]], label=label_name) - - if "AEM" in metric_keys[idx]: - ax.hlines(0.005, 0, df["time"].max(), color="k") - ax.set_ylim(0.0001, 0.6) - if "RMSE_whole_object" in metric_keys[idx]: - ax.hlines(0.01, 0, df["time"].max(), color="k") - ax.set_ylim(0.001, 0.9) - if "RMSE_background" in metric_keys[idx]: - ax.hlines(0.01, 0, df["time"].max(), color="k") - ax.set_ylim(0.001, 0.9) - if "change" in metric_keys[idx]: - ax.set_ylim(0.0001, 10.0) - - except KeyError: - if "step" in metric_keys[idx] or "loss" in metric_keys[idx]: - ax.plot(np.nan) - else: - ax.semilogy(np.nan) - - ax.legend(fontsize=7) - except FileNotFoundError: - pass - - ax_ = list(axes.ravel()) - ax_[-1].axis("off") - #ax_[-2].axis("off") - - - #for ax in axes: - # ax.set_xlim([0, 500]) - fig.suptitle(method) - fig.tight_layout() - plt.show() - diff --git a/reduced_petric.py b/reduced_petric.py deleted file mode 100755 index b3986fc..0000000 --- a/reduced_petric.py +++ /dev/null @@ -1,337 +0,0 @@ -#!/usr/bin/env python -""" -ANY CHANGES TO THIS FILE ARE IGNORED BY THE ORGANISERS. -Only the `main.py` file may be modified by participants. - -This file is not intended for participants to use, except for -the `get_data` function (and possibly `QualityMetrics` class). -It is used by the organisers to run the submissions in a controlled way. -It is included here purely in the interest of transparency. 
- -Usage: - petric.py [options] - -Options: - --log LEVEL : Set logging level (DEBUG, [default: INFO], WARNING, ERROR, CRITICAL) -""" -import csv -import logging -import os -from dataclasses import dataclass -from pathlib import Path -import time -from traceback import print_exc -from datetime import datetime -import yaml - -import numpy as np -from skimage.metrics import mean_squared_error as mse -import matplotlib.pyplot as plt - -import sirf.STIR as STIR -from cil.optimisation.algorithms import Algorithm -from cil.optimisation.utilities import callbacks as cil_callbacks -#from img_quality_cil_stir import ImageQualityCallback - -#import torch -#torch.cuda.set_per_process_memory_fraction(0.8) - -method = "bsrem" - -if method == "ews": - from main_EWS import Submission, submission_callbacks - submission_args = { - "method": "ews", - "model_name" : None, - "weights_path": None, - "mode": "staggered", - "initial_step_size": 0.3, - "relaxation_eta": 0.01, - "num_subsets": 10, - } -elif method == "adam": - from main_ADAM import Submission, submission_callbacks - submission_args = { - "method": "adam", - "initial_step_size": 1e-3, - "relaxation_eta": 0.005, - "mode": "staggered" - } -elif method == "osem": - from main_OSEM import Submission, submission_callbacks - submission_args = { - #"method": "osem", - #"mode": "staggered" - } -elif method == "bsrem": - from main import Submission, submission_callbacks - submission_args = { - "method": "bsrem", - "accumulate_gradient_iter": [6, 10, 14, 18, 32], - "accumulate_gradient_num": [1, 2, 4, 8, 16], - "gamma": 0.9, #0.9, - } -else: - raise NotImplementedError - -OUTDIR = Path("logs/" + method + "/" + datetime.now().strftime('%Y-%m-%d_%H-%M-%S')) - -assert issubclass(Submission, Algorithm) - - -log = logging.getLogger('petric') -if not (SRCDIR := Path("/mnt/share/petric")).is_dir(): - SRCDIR = Path("./data") - -class Callback(cil_callbacks.Callback): - """ - CIL Callback but with `self.skip_iteration` checking `min(self.interval, algo.update_objective_interval)`. - TODO: backport this class to CIL. 
- """ - def __init__(self, interval: int = 1 << 31, **kwargs): - super().__init__(**kwargs) - self.interval = interval - - def skip_iteration(self, algo: Algorithm) -> bool: - return algo.iteration % min(self.interval, - algo.update_objective_interval) != 0 and algo.iteration != algo.max_iteration - -class QualityMetrics(Callback): - """From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds""" - def __init__(self, - reference_image, - whole_object_mask, - background_mask, - output_dir, - interval: int = 1 << 31, - **kwargs): - # TODO: drop multiple inheritance once `interval` included in CIL - Callback.__init__(self, interval=interval) - #ImageQualityCallback.__init__(self, reference_image, **kwargs) - self.whole_object_indices = np.where(whole_object_mask.as_array()) - self.background_indices = np.where(background_mask.as_array()) - self.ref_im_arr = reference_image.as_array() - self.norm = self.ref_im_arr[self.background_indices].mean() - - voi_mask_dict = kwargs.get("voi_mask_dict", None) - self.voi_indices = {} - for key, value in (voi_mask_dict or {}).items(): - self.voi_indices[key] = np.where(value.as_array()) - - self.filter = None - self.x_prev = None - self.output_dir = output_dir - headers = ["iteration", "time"] + self.keys() + ["normalised_change"] + ["step_size"] + ["loss"] - with open(os.path.join(self.output_dir, "results.csv"), 'w', newline='') as file: - writer = csv.writer(file) - writer.writerow(headers) - - def __call__(self, algo: Algorithm): - if self.skip_iteration(algo): - print("Skip iteration, dont log") - return - t = getattr(self, '_time', None) or time.time() - row = [algo.iteration, t] - for tag, value in self.evaluate(algo.x).items(): - row.append(value) - - if self.x_prev is not None: - normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm() - row.append(normalised_change) - else: - row.append(np.nan) - self.x_prev = algo.x.clone() - - row.append(algo.alpha) - - loss = algo.get_last_loss() - row.append(loss) - with open(os.path.join(self.output_dir, "results.csv"), 'a', newline='') as file: - writer = csv.writer(file) - writer.writerow(row) - - #fig, (ax1, ax2) = plt.subplots(1,2) - #print("mean of pred / mean of gt: ", np.mean(algo.x.as_array()[72,:,:]), np.mean(self.ref_im_arr[72,:,:])) - #im = ax1.imshow(algo.x.as_array()[56,:,:], cmap="gray") - #fig.colorbar(im, ax=ax1) - - #im = ax2.imshow(self.ref_im_arr[56,:,:], cmap="gray") - #fig.colorbar(im, ax=ax2) - #plt.savefig(os.path.join(self.output_dir, "imgs", f"reco_at_{algo.iteration}.png")) - #plt.close() - - def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]: - #assert not any(self.filter.values()), "Filtering not implemented" - test_im_arr = test_im.as_array() - whole = { - "RMSE_whole_object": np.sqrt( - mse(self.ref_im_arr[self.whole_object_indices], test_im_arr[self.whole_object_indices])) / self.norm, - "RMSE_background": np.sqrt( - mse(self.ref_im_arr[self.background_indices], test_im_arr[self.background_indices])) / self.norm} - local = { - f"AEM_VOI_{voi_name}": np.abs(test_im_arr[voi_indices].mean() - self.ref_im_arr[voi_indices].mean()) / - self.norm - for voi_name, voi_indices in sorted(self.voi_indices.items())} - return {**whole, **local} - - def keys(self): - return ["RMSE_whole_object", "RMSE_background"] + [f"AEM_VOI_{name}" for name in sorted(self.voi_indices)] - - -class MetricsWithTimeout(cil_callbacks.Callback): - """Stops the algorithm after `seconds`""" - def __init__(self, seconds=300, **kwargs): - super().__init__(**kwargs) - self._seconds = 
seconds - self.callbacks = [ - cil_callbacks.ProgressCallback()] - - self.reset() - - def reset(self, seconds=None): - self.limit = time.time() + (self._seconds if seconds is None else seconds) - self.start_time = time.time() #0 - - def __call__(self, algo: Algorithm): - if (now := time.time()) > self.limit: - log.warning("Timeout reached. Stopping algorithm.") - raise StopIteration - for c in self.callbacks: - c._time = now - self.start_time # privatel - c(algo) - - @staticmethod - def mean_absolute_error(y, x): - return np.mean(np.abs(y, x)) - - -def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3): - """ - Construct a smoothed Relative Difference Prior (RDP) - - initial_image: used to determine a smoothing factor (epsilon). - kappa: used to pass voxel-dependent weights. - """ - prior = getattr(STIR, 'CudaRelativeDifferencePrior', STIR.RelativeDifferencePrior)() # CudaRelativeDifferencePrior - # need to make it differentiable - epsilon = initial_image.max() * max_scaling - prior.set_epsilon(epsilon) - prior.set_penalisation_factor(penalty_strength) - prior.set_kappa(kappa) - prior.set_up(initial_image) - return prior - - -@dataclass -class Dataset: - acquired_data: STIR.AcquisitionData - additive_term: STIR.AcquisitionData - mult_factors: STIR.AcquisitionData - OSEM_image: STIR.ImageData - prior: STIR.RelativeDifferencePrior - kappa: STIR.ImageData - reference_image: STIR.ImageData | None - whole_object_mask: STIR.ImageData | None - background_mask: STIR.ImageData | None - voi_masks: dict[str, STIR.ImageData] - - -def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0): - """ - Load data from `srcdir`, constructs prior and return as a `Dataset`. - Also redirects sirf.STIR log output to `outdir`. - """ - srcdir = Path(srcdir) - outdir = Path(outdir) - STIR.set_verbosity(sirf_verbosity) # set to higher value to diagnose problems - STIR.AcquisitionData.set_storage_scheme('memory') # needed for get_subsets() - - _ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'), str(outdir / 'errors.txt')) - acquired_data = STIR.AcquisitionData(str(srcdir / 'prompts.hs')) - additive_term = STIR.AcquisitionData(str(srcdir / 'additive_term.hs')) - mult_factors = STIR.AcquisitionData(str(srcdir / 'mult_factors.hs')) - print("SHAPE: ", acquired_data.shape, additive_term.shape, mult_factors.shape) - #if acquired_data.shape[0] > 1: - # tof_bins = acquired_data.shape[0] - - # acquired_data = acquired_data.rebin(num_tof_bins_to_combine=int(tof_bins), do_normalisation=True, num_segments_to_combine=1) - # additive_term = additive_term.rebin(num_tof_bins_to_combine=int(tof_bins), do_normalisation=True, num_segments_to_combine=1) - # mult_factors = mult_factors.rebin(num_tof_bins_to_combine=int(tof_bins), do_normalisation=True, num_segments_to_combine=1) - # print("NEW SHAPE: ", acquired_data.shape, additive_term.shape, mult_factors.shape) - - OSEM_image = STIR.ImageData(str(srcdir / 'OSEM_image.hv')) - kappa = STIR.ImageData(str(srcdir / 'kappa.hv')) - if (penalty_strength_file := (srcdir / 'penalisation_factor.txt')).is_file(): - penalty_strength = float(np.loadtxt(penalty_strength_file)) - else: - penalty_strength = 1 / 700 # default choice - prior = construct_RDP(penalty_strength, OSEM_image, kappa) - - def get_image(fname): - if (source := srcdir / 'PETRIC' / fname).is_file(): - return STIR.ImageData(str(source)) - return None # explicit to suppress linter warnings - - reference_image = get_image('reference_image.hv') - whole_object_mask = 
get_image('VOI_whole_object.hv') - background_mask = get_image('VOI_background.hv') - voi_masks = { - voi.stem[4:]: STIR.ImageData(str(voi)) - for voi in (srcdir / 'PETRIC').glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'whole_object')} - - return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa, reference_image, - whole_object_mask, background_mask, voi_masks) - -# create list of existing data -# NB: `MetricsWithTimeout` initialises `SaveIters` which creates `outdir` -data_dirs_metrics = [ (SRCDIR / "Siemens_mMR_NEMA_IQ", - OUTDIR / "mMR_NEMA", - [MetricsWithTimeout(seconds=700)]), - (SRCDIR / "Mediso_NEMA_IQ", - OUTDIR / "Mediso_NEMA", - [MetricsWithTimeout(seconds=700)]), - (SRCDIR / "Siemens_Vision600_thorax", - OUTDIR / "Vision600_thorax", - [MetricsWithTimeout(seconds=700)]), - (SRCDIR / "Siemens_mMR_ACR", - OUTDIR / "Siemens_mMR_ACR", - [MetricsWithTimeout(seconds=700)]), - (SRCDIR / "NeuroLF_Hoffman_Dataset", - OUTDIR / "NeuroLF_Hoffman", - [MetricsWithTimeout(seconds=700)]) - ] - - -from docopt import docopt -args = docopt(__doc__) -logging.basicConfig(level=getattr(logging, args["--log"].upper())) - -for srcdir, outdir, metrics in data_dirs_metrics: - print("OUTPUT dir: ", outdir) - os.makedirs(outdir) - - - os.makedirs(os.path.join(outdir, "imgs")) - - data = get_data(srcdir=srcdir, outdir=outdir) - metrics_with_timeout = metrics[0] - if data.reference_image is not None: - metrics_with_timeout.callbacks.append( - QualityMetrics(data.reference_image, - data.whole_object_mask, - data.background_mask, - voi_mask_dict=data.voi_masks, - output_dir=outdir, - interval=1)) - metrics_with_timeout.reset() # timeout from now - algo = Submission(data, update_objective_interval=100000, **submission_args) - submission_args["num_subsets"] = algo.num_subsets - - with open(os.path.join(outdir, "config.yaml"), "w") as file: - yaml.dump(submission_args, file) - try: - algo.run(np.inf, callbacks=metrics + submission_callbacks) - except Exception: - print_exc(limit=2) - finally: - del algo diff --git a/setup_model.py b/setup_model.py deleted file mode 100644 index f110ed5..0000000 --- a/setup_model.py +++ /dev/null @@ -1,10 +0,0 @@ -import torch -torch.cuda.set_per_process_memory_fraction(0.7) - -from one_step_model import NetworkPreconditioner - - -DEVICE = "cuda" -network_precond = NetworkPreconditioner() -network_precond = network_precond.to(DEVICE) -network_precond.load_state_dict(torch.load("checkpoint/model.pt", weights_only=True)) \ No newline at end of file diff --git a/test_single_iteration.py b/test_single_iteration.py deleted file mode 100644 index 3873b76..0000000 --- a/test_single_iteration.py +++ /dev/null @@ -1,285 +0,0 @@ -# %% - -from pathlib import Path -import numpy as np -import logging -import os -from dataclasses import dataclass -from matplotlib import pyplot as plt -from skimage.metrics import mean_squared_error as mse -from tqdm import tqdm - -import torch -torch.cuda.set_per_process_memory_fraction(0.6) - -log = logging.getLogger('petric') -TEAM = os.getenv("GITHUB_REPOSITORY", "SyneRBI/PETRIC-").split("/PETRIC-", 1)[-1] -VERSION = os.getenv("GITHUB_REF_NAME", "") -OUTDIR = Path(f"/o/logs/{TEAM}/{VERSION}" if TEAM and VERSION else "./output") -if not (SRCDIR := Path("/mnt/share/petric")).is_dir(): - SRCDIR = Path("./data") - -import sirf.STIR as STIR -from sirf.contrib.partitioner import partitioner - -@dataclass -class Dataset: - acquired_data: STIR.AcquisitionData - additive_term: STIR.AcquisitionData - mult_factors: 
STIR.AcquisitionData - OSEM_image: STIR.ImageData - prior: STIR.RelativeDifferencePrior - kappa: STIR.ImageData - reference_image: STIR.ImageData | None - whole_object_mask: STIR.ImageData | None - background_mask: STIR.ImageData | None - voi_masks: dict[str, STIR.ImageData] - - -def evaluate_quality_metrics(reference, prediction, whole_object_mask, background_mask, voi_indices): - whole_object_indices = np.where(whole_object_mask) - background_indices = np.where(background_mask) - norm = reference[background_indices].mean() - - - whole = { - "RMSE_whole_object": np.sqrt( - mse(reference[whole_object_indices], prediction[whole_object_indices])) / norm, - "RMSE_background": np.sqrt( - mse(reference[background_indices], prediction[background_indices])) / norm} - local = { - f"AEM_VOI_{voi_name}": np.abs(prediction[voi_indices].mean() - reference[voi_indices].mean()) / - norm for voi_name, voi_indices in sorted(voi_indices.items())} - return {**whole, **local} - -def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3): - """ - Construct a smoothed Relative Difference Prior (RDP) - - initial_image: used to determine a smoothing factor (epsilon). - kappa: used to pass voxel-dependent weights. - """ - prior = getattr(STIR, 'CudaRelativeDifferencePrior', STIR.RelativeDifferencePrior)() - # need to make it differentiable - epsilon = initial_image.max() * max_scaling - prior.set_epsilon(epsilon) - prior.set_penalisation_factor(penalty_strength) - prior.set_kappa(kappa) - prior.set_up(initial_image) - return prior - - -def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0): - """ - Load data from `srcdir`, constructs prior and return as a `Dataset`. - Also redirects sirf.STIR log output to `outdir`. - """ - srcdir = Path(srcdir) - outdir = Path(outdir) - STIR.set_verbosity(sirf_verbosity) # set to higher value to diagnose problems - STIR.AcquisitionData.set_storage_scheme('memory') # needed for get_subsets() - - _ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'), str(outdir / 'errors.txt')) - acquired_data = STIR.AcquisitionData(str(srcdir / 'prompts.hs')) - additive_term = STIR.AcquisitionData(str(srcdir / 'additive_term.hs')) - mult_factors = STIR.AcquisitionData(str(srcdir / 'mult_factors.hs')) - OSEM_image = STIR.ImageData(str(srcdir / 'OSEM_image.hv')) - kappa = STIR.ImageData(str(srcdir / 'kappa.hv')) - if (penalty_strength_file := (srcdir / 'penalisation_factor.txt')).is_file(): - penalty_strength = float(np.loadtxt(penalty_strength_file)) - else: - penalty_strength = 1 / 700 # default choice - prior = construct_RDP(penalty_strength, OSEM_image, kappa) - - def get_image(fname): - if (source := srcdir / 'PETRIC' / fname).is_file(): - return STIR.ImageData(str(source)) - return None # explicit to suppress linter warnings - - reference_image = get_image('reference_image.hv') - whole_object_mask = get_image('VOI_whole_object.hv') - background_mask = get_image('VOI_background.hv') - voi_masks = { - voi.stem[4:]: STIR.ImageData(str(voi)) - for voi in (srcdir / 'PETRIC').glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'whole_object')} - - return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa, reference_image, - whole_object_mask, background_mask, voi_masks) - - -if SRCDIR.is_dir(): - data_dirs_metrics = [ (SRCDIR / "Siemens_mMR_NEMA_IQ", - OUTDIR / "mMR_NEMA"), - (SRCDIR / "NeuroLF_Hoffman_Dataset", - OUTDIR / "NeuroLF_Hoffman"), - (SRCDIR / "Mediso_NEMA_IQ", - OUTDIR / "Mediso_NEMA"), - (SRCDIR / 
"Siemens_Vision600_thorax", - OUTDIR / "Vision600_thorax"), - (SRCDIR / "Siemens_mMR_ACR", - OUTDIR / "Siemens_mMR_ACR"), - ] - - - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -print("device: ", device) - -class NetworkPreconditioner(torch.nn.Module): - def __init__(self, n_layers = 1, hidden_channels = 8, kernel_size = 3): - super(NetworkPreconditioner, self).__init__() - - self.conv1 = torch.nn.Conv3d(3, 3*hidden_channels, kernel_size, groups=3, padding='same', bias=False) - self.conv2 = torch.nn.Conv3d(3*hidden_channels, 3*hidden_channels, kernel_size, groups=3, padding='same', bias=False) - self.conv3 = torch.nn.Conv3d(3*hidden_channels, hidden_channels, kernel_size, padding='same', bias=False) - - self.max_pool = torch.nn.MaxPool3d(kernel_size=2) - - self.conv4 = torch.nn.Conv3d(hidden_channels, hidden_channels, kernel_size, padding='same', bias=False) - self.conv5 = torch.nn.Conv3d(hidden_channels, hidden_channels, kernel_size, padding='same', bias=False) - - # interpolate - - self.conv6 = torch.nn.Conv3d(hidden_channels, hidden_channels, kernel_size, padding='same', bias=False) - self.conv7 = torch.nn.Conv3d(hidden_channels, 1, kernel_size, padding='same', bias=False) - - self.activation = torch.nn.ReLU() - - def forward(self, x): - shape = x.shape - z = self.activation(self.conv1(x)) - z = self.activation(self.conv2(z)) - z1 = self.activation(self.conv3(z)) - - z2 = self.max_pool(z1) - z2 = self.activation(self.conv4(z2)) - z2 = self.activation(self.conv5(z2)) - - z3 = torch.nn.functional.interpolate(z2, size=shape[-3:], mode = "trilinear", align_corners=True) - z3 = z3 + z1 - - z4 = self.activation(self.conv6(z3)) - z_out = self.activation(self.conv7(z4)) - - return z_out - - -precond = NetworkPreconditioner(n_layers=4) -precond = precond.to(device) -precond.load_state_dict(torch.load("checkpoint/model.pt", weights_only=True)) - - -from utils.number_of_subsets import compute_number_of_subsets -for data_name in data_dirs_metrics: - - data = get_data(srcdir=data_name[0], outdir=data_name[1]) - print(data_name[0]) - if data.acquired_data.shape[0] == 1: - views = data.acquired_data.shape[2] - num_subsets = compute_number_of_subsets(views, tof=False) - else: - num_subsets = 25 - - name = str(data_name[0]).split("/")[-1] - data_sub, _, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term, - data.mult_factors, num_subsets, - initial_image=data.OSEM_image, - mode="staggered") - - _, _, full_obj_fun = partitioner.data_partition(data.acquired_data, data.additive_term, - data.mult_factors, 1, - initial_image=data.OSEM_image, - mode="staggered") - - full_obj_fun[0].set_up(data.OSEM_image) - sensitiviy = full_obj_fun[0].get_subset_sensitivity(0) - sensitiviy += sensitiviy.max()/1e4 - - """ - full_obj_fun[0].set_up(data.OSEM_image) - sensitiviy = full_obj_fun[0].get_subset_sensitivity(0) - sensitiviy += sensitiviy.max()/1e4 - - eps = data.OSEM_image.max()/1e3 - - my_prior = data.prior - my_prior.set_penalisation_factor(data.prior.get_penalisation_factor()) - my_prior.set_up(data.OSEM_image) - - grad = full_obj_fun[0].gradient(data.OSEM_image) - prior_grad = my_prior.gradient(data.OSEM_image) - - grad = (data.OSEM_image + eps) * grad / sensitiviy - prior_grad = (data.OSEM_image + eps) * prior_grad / sensitiviy - """ - - pll_grad = data.OSEM_image.get_uniform_copy(0) - for i in range(len(obj_funs)): - obj_funs[i].set_up(data.OSEM_image) - pll_grad += obj_funs[i].gradient(data.OSEM_image) - - average_sensitivity = 
data.OSEM_image.get_uniform_copy(0) - for s in range(len(data_sub)): - subset_sens = obj_funs[s].get_subset_sensitivity(0) - average_sensitivity += subset_sens - - # add a small number to avoid division by zero in the preconditioner - average_sensitivity += average_sensitivity.max()/1e4 - - print("sens: ", (sensitiviy - average_sensitivity).norm()) - - - fig, ax = plt.subplots(1,2, figsize=(16,8)) - im = ax[0].imshow(sensitiviy.as_array()[56, :, :]) - ax[0].set_title("full sens") - fig.colorbar(im, ax=ax[0]) - im = ax[1].imshow(average_sensitivity.as_array()[56, :, :]) - ax[1].set_title("avg sens") - fig.colorbar(im, ax=ax[1]) - - plt.savefig(f"input_{name}.png") - plt.close() - - eps = data.OSEM_image.max()/1e3 - - my_prior = data.prior - my_prior.set_penalisation_factor(data.prior.get_penalisation_factor()) - my_prior.set_up(data.OSEM_image) - - prior_grad = my_prior.gradient(data.OSEM_image) - - grad = (data.OSEM_image + eps) * pll_grad / average_sensitivity - prior_grad = (data.OSEM_image + eps) * prior_grad / average_sensitivity - - whole_object_mask = data.whole_object_mask.as_array() - background_mask = data.background_mask.as_array() - - voi_indices = {} - for key, value in data.voi_masks.items(): - voi_indices[key] = np.where(value.as_array()) - voi_indices = voi_indices - - osem_input_torch = torch.from_numpy(data.OSEM_image.as_array()).float() - osem_input_torch = osem_input_torch.to(device).unsqueeze(0) - - x_reference = torch.from_numpy(data.reference_image.as_array()).float() - x_reference = x_reference.to(device).unsqueeze(0).unsqueeze(0) - - grad = torch.from_numpy(grad.as_array()).float() - grad = grad.to(device).unsqueeze(0) - - prior_grad = torch.from_numpy(prior_grad.as_array()).float() - prior_grad = prior_grad.to(device).unsqueeze(0) - print(torch.sum(osem_input_torch**2), torch.sum(prior_grad**2), torch.sum(grad**2)) - model_inp = torch.cat([osem_input_torch, grad, prior_grad], dim=0).unsqueeze(0) - - x_pred = precond(model_inp) #+ osem_input_torch.unsqueeze(0) - - print(evaluate_quality_metrics(x_reference.detach().cpu().squeeze().numpy(), - x_pred.detach().cpu().squeeze().numpy(), - whole_object_mask, - background_mask, - voi_indices)) - - - diff --git a/train_single_iteration.py b/train_single_iteration.py deleted file mode 100644 index 4e6214d..0000000 --- a/train_single_iteration.py +++ /dev/null @@ -1,339 +0,0 @@ -# %% - -from pathlib import Path -import numpy as np -import logging -import os -from dataclasses import dataclass -from matplotlib import pyplot as plt -from skimage.metrics import mean_squared_error as mse -from tqdm import tqdm - -import torch -torch.cuda.set_per_process_memory_fraction(0.7) - -log = logging.getLogger('petric') -TEAM = os.getenv("GITHUB_REPOSITORY", "SyneRBI/PETRIC-").split("/PETRIC-", 1)[-1] -VERSION = os.getenv("GITHUB_REF_NAME", "") -OUTDIR = Path(f"/o/logs/{TEAM}/{VERSION}" if TEAM and VERSION else "./output") -if not (SRCDIR := Path("/mnt/share/petric")).is_dir(): - SRCDIR = Path("./data") - -def evaluate_quality_metrics(reference, prediction, whole_object_mask, background_mask, voi_indices): - whole_object_indices = np.where(whole_object_mask) - background_indices = np.where(background_mask) - norm = reference[background_indices].mean() - - - whole = { - "RMSE_whole_object": np.sqrt( - mse(reference[whole_object_indices], prediction[whole_object_indices])) / norm, - "RMSE_background": np.sqrt( - mse(reference[background_indices], prediction[background_indices])) / norm} - local = { - f"AEM_VOI_{voi_name}": 
np.abs(prediction[voi_indices].mean() - reference[voi_indices].mean()) / - norm for voi_name, voi_indices in sorted(voi_indices.items())} - return {**whole, **local} - -def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3): - """ - Construct a smoothed Relative Difference Prior (RDP) - - initial_image: used to determine a smoothing factor (epsilon). - kappa: used to pass voxel-dependent weights. - """ - prior = getattr(STIR, 'CudaRelativeDifferencePrior', STIR.RelativeDifferencePrior)() - # need to make it differentiable - epsilon = initial_image.max() * max_scaling - prior.set_epsilon(epsilon) - prior.set_penalisation_factor(penalty_strength) - prior.set_kappa(kappa) - prior.set_up(initial_image) - return prior - - -def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0): - """ - Load data from `srcdir`, constructs prior and return as a `Dataset`. - Also redirects sirf.STIR log output to `outdir`. - """ - srcdir = Path(srcdir) - outdir = Path(outdir) - STIR.set_verbosity(sirf_verbosity) # set to higher value to diagnose problems - STIR.AcquisitionData.set_storage_scheme('memory') # needed for get_subsets() - - _ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'), str(outdir / 'errors.txt')) - acquired_data = STIR.AcquisitionData(str(srcdir / 'prompts.hs')) - additive_term = STIR.AcquisitionData(str(srcdir / 'additive_term.hs')) - mult_factors = STIR.AcquisitionData(str(srcdir / 'mult_factors.hs')) - OSEM_image = STIR.ImageData(str(srcdir / 'OSEM_image.hv')) - kappa = STIR.ImageData(str(srcdir / 'kappa.hv')) - if (penalty_strength_file := (srcdir / 'penalisation_factor.txt')).is_file(): - penalty_strength = float(np.loadtxt(penalty_strength_file)) - else: - penalty_strength = 1 / 700 # default choice - prior = construct_RDP(penalty_strength, OSEM_image, kappa) - - def get_image(fname): - if (source := srcdir / 'PETRIC' / fname).is_file(): - return STIR.ImageData(str(source)) - return None # explicit to suppress linter warnings - - reference_image = get_image('reference_image.hv') - whole_object_mask = get_image('VOI_whole_object.hv') - background_mask = get_image('VOI_background.hv') - voi_masks = { - voi.stem[4:]: STIR.ImageData(str(voi)) - for voi in (srcdir / 'PETRIC').glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'whole_object')} - - return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa, reference_image, - whole_object_mask, background_mask, voi_masks) - - -if SRCDIR.is_dir(): - data_dirs_metrics = [ (SRCDIR / "Siemens_mMR_NEMA_IQ", OUTDIR / "mMR_NEMA"), - (SRCDIR / "Mediso_NEMA_IQ", OUTDIR / "Mediso_NEMA"), - (SRCDIR / "Siemens_Vision600_thorax", OUTDIR / "Vision600_thorax"), - (SRCDIR / "Siemens_mMR_ACR", OUTDIR / "Siemens_mMR_ACR"), - (SRCDIR / "NeuroLF_Hoffman_Dataset", OUTDIR / "NeuroLF_Hoffman"), - (SRCDIR / "Siemens_mMR_NEMA_IQ_lowcounts", OUTDIR / "mMR_NEMA_lowcounts"), - ] - -load_data = True - -initial_images = [] -prior_grads = [] -pll_grads = [] -reference_images = [] -whole_object_mask_ = [] -background_mask_ = [] -voi_indices_ = [] -for data_name in data_dirs_metrics: - - name = str(data_name[0]).split("/")[-1] - print(name) - - if load_data: - - initial_images.append(torch.load(f"training_data/{name}_initial_image.pt")) - prior_grads.append(torch.load(f"training_data/{name}_prior_grads.pt")) - pll_grads.append(torch.load(f"training_data/{name}_pll_grads.pt")) - reference_images.append(torch.load(f"training_data/{name}_reference_images.pt")) - - print("Norm of initial: ", 
torch.sum(initial_images[-1]**2)) - print("Norm of reference: ", torch.sum(reference_images[-1]**2)) - - print(reference_images[-1].shape) - - whole_object_mask_.append(np.load(f"training_data/{name}_whole_object_mask.npy")) - background_mask_.append(np.load(f"training_data/{name}_background_mask.npy")) - - voi_indices_.append(np.load(f"training_data/{name}_voi_masks.npy", allow_pickle=True).tolist()) - print(evaluate_quality_metrics(reference_images[-1].detach().cpu().squeeze().numpy(), - initial_images[-1].detach().cpu().squeeze().numpy(), - whole_object_mask_[-1], - background_mask_[-1], - voi_indices_[-1])) - else: - import sirf.STIR as STIR - from sirf.contrib.partitioner import partitioner - - @dataclass - class Dataset: - acquired_data: STIR.AcquisitionData - additive_term: STIR.AcquisitionData - mult_factors: STIR.AcquisitionData - OSEM_image: STIR.ImageData - prior: STIR.RelativeDifferencePrior - kappa: STIR.ImageData - reference_image: STIR.ImageData | None - whole_object_mask: STIR.ImageData | None - background_mask: STIR.ImageData | None - voi_masks: dict[str, STIR.ImageData] - - - data = get_data(srcdir=data_name[0], outdir=data_name[1]) - name = str(data_name[0]).split("/")[-1] - data_sub, _, full_obj_fun = partitioner.data_partition(data.acquired_data, data.additive_term, - data.mult_factors, 1, - initial_image=data.OSEM_image, - mode="staggered") - - full_obj_fun[0].set_up(data.OSEM_image) - sensitiviy = full_obj_fun[0].get_subset_sensitivity(0) - sensitiviy += sensitiviy.max()/1e4 - - eps = data.OSEM_image.max()/1e3 - - my_prior = data.prior - my_prior.set_penalisation_factor(data.prior.get_penalisation_factor()) - my_prior.set_up(data.OSEM_image) - - grad = full_obj_fun[0].gradient(data.OSEM_image) - prior_grad = my_prior.gradient(data.OSEM_image) - - grad = (data.OSEM_image + eps) * grad / sensitiviy - prior_grad = (data.OSEM_image + eps) * prior_grad / sensitiviy - - initial_images.append(torch.from_numpy(data.OSEM_image.as_array()).float()) - prior_grads.append(torch.from_numpy(prior_grad.as_array()).float()) - pll_grads.append(torch.from_numpy(grad.as_array()).float()) - reference_images.append(torch.from_numpy(data.reference_image.as_array()).float()) - - whole_object_mask_.append(data.whole_object_mask.as_array()) - background_mask_.append(data.background_mask.as_array()) - - voi_indices = {} - for key, value in data.voi_masks.items(): - voi_indices[key] = np.where(value.as_array()) - voi_indices_.append(voi_indices) - - torch.save(initial_images[-1], f"training_data/{name}_initial_image.pt") - torch.save(prior_grads[-1], f"training_data/{name}_prior_grads.pt") - torch.save(pll_grads[-1], f"training_data/{name}_pll_grads.pt") - torch.save(reference_images[-1], f"training_data/{name}_reference_images.pt") - - np.save(f"training_data/{name}_whole_object_mask.npy", whole_object_mask_[-1]) - np.save(f"training_data/{name}_background_mask.npy", background_mask_[-1]) - np.save(f"training_data/{name}_voi_masks.npy", voi_indices_[-1]) - - - fig, ax = plt.subplots(1,2, figsize=(16,8)) - im = ax[0].imshow(grad.as_array()[56, :, :]) - ax[0].set_title("pll grad") - fig.colorbar(im, ax=ax[0]) - im = ax[1].imshow(prior_grad.as_array()[56, :, :]) - ax[1].set_title("rdp grad") - fig.colorbar(im, ax=ax[1]) - - plt.savefig(f"input_{name}.png") - plt.close() - -print("Data loaded") - - - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -print("device: ", device) - -class NetworkPreconditioner(torch.nn.Module): - def __init__(self, hidden_channels = 8, kernel_size 
= 3): - super(NetworkPreconditioner, self).__init__() - - self.conv1 = torch.nn.Conv3d(3, 3*hidden_channels, kernel_size, groups=3, padding='same', bias=False) - self.conv2 = torch.nn.Conv3d(3*hidden_channels, 3*hidden_channels, kernel_size, groups=3, padding='same', bias=False) - self.conv3 = torch.nn.Conv3d(3*hidden_channels, hidden_channels, kernel_size, padding='same', bias=False) - - self.max_pool = torch.nn.MaxPool3d(kernel_size=2) - - self.conv4 = torch.nn.Conv3d(hidden_channels, hidden_channels, kernel_size, padding='same', bias=False) - self.conv5 = torch.nn.Conv3d(hidden_channels, hidden_channels, kernel_size, padding='same', bias=False) - - # interpolate - - self.conv6 = torch.nn.Conv3d(hidden_channels, hidden_channels, kernel_size, padding='same', bias=False) - self.conv7 = torch.nn.Conv3d(hidden_channels, 1, kernel_size, padding='same', bias=False) - - self.activation = torch.nn.ReLU() - - #self.list_of_conv3[-1].weight.data.fill_(0.0) - #self.list_of_conv3[-1].bias.data.fill_(0.0) - - def forward(self, x): - shape = x.shape - z = self.activation(self.conv1(x)) - z = self.activation(self.conv2(z)) - z1 = self.activation(self.conv3(z)) - - z2 = self.max_pool(z1) - z2 = self.activation(self.conv4(z2)) - z2 = self.activation(self.conv5(z2)) - - z3 = torch.nn.functional.interpolate(z2, size=shape[-3:], mode = "trilinear", align_corners=True) - z3 = z3 + z1 - - z4 = self.activation(self.conv6(z3)) - z_out = self.activation(self.conv7(z4)) - - return z_out - - -precond = NetworkPreconditioner(n_layers=4) -precond = precond.to(device) - -optimizer = torch.optim.Adam(precond.parameters(), lr=3e-4) -lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.95) -print("Number of parameters: ", sum([p.numel() for p in precond.parameters()])) - -idx_to_plot = [100, 80,110, 140, 80] -for i in tqdm(range(2000)): - optimizer.zero_grad() - - full_loss = 0 - for j in range(len(initial_images)): - - osem_input_torch = initial_images[j] - osem_input_torch = osem_input_torch.to(device).unsqueeze(0) - - x_reference = reference_images[j] - x_reference = x_reference.to(device).unsqueeze(0).unsqueeze(0) - - grad = pll_grads[j] - grad = grad.to(device).unsqueeze(0) - - prior_grad = prior_grads[j] - prior_grad = prior_grad.to(device).unsqueeze(0) - - model_inp = torch.cat([osem_input_torch, grad, prior_grad], dim=0).unsqueeze(0) - - x_pred = precond(model_inp) #+ osem_input_torch.unsqueeze(0) - - #print(x_pred.shape, osem_input_torch.shape, x_reference.shape, model_inp.shape) - loss = torch.mean((x_pred - x_reference)**2) / torch.mean(x_reference**2) - full_loss += loss.item() - loss.backward() - - if i % 250 == 0: - print(evaluate_quality_metrics(x_reference.detach().cpu().squeeze().numpy(), - x_pred.detach().cpu().squeeze().numpy(), - whole_object_mask_[j], - background_mask_[j], - voi_indices_[j])) - - - idx = idx_to_plot[j] - - fig, ax = plt.subplots(2,3, figsize=(16,8)) - im = ax[0,0].imshow(osem_input_torch.detach().cpu().numpy()[0, :, idx, :], cmap="gray") - ax[0,0].set_title("osem_input_torch") - fig.colorbar(im, ax=ax[0,0]) - - im = ax[0,1].imshow(x_pred.detach().cpu().numpy()[0, 0, :, idx, :], cmap="gray") - ax[0,1].set_title("x_pred") - fig.colorbar(im, ax=ax[0,1]) - - im = ax[0,2].imshow(x_reference.detach().cpu().numpy()[0, 0, :, idx, :], cmap="gray") - ax[0,2].set_title("x_reference") - fig.colorbar(im, ax=ax[0,2]) - - im = ax[1,0].imshow(np.abs(osem_input_torch.detach().cpu().numpy()[0, :, idx, :] - x_reference.detach().cpu().numpy()[0, 0, :, idx, :]), cmap="gray") - 
ax[1,0].set_title("osem_input_torch - x_reference") - fig.colorbar(im, ax=ax[1,0]) - - im = ax[1,1].imshow(np.abs(x_pred.detach().cpu().numpy()[0, 0, :, idx, :] - x_reference.detach().cpu().numpy()[0, 0, :, idx, :]), cmap="gray") - ax[1,1].set_title("x_pred - x_reference") - fig.colorbar(im, ax=ax[1,1]) - - ax[1,2].axis("off") - - plt.savefig(f"tmp_imgs/{i}_{j}.png") - plt.close() - - print(f"Iter {i}, Loss = {full_loss}, || lr = {lr_scheduler.get_last_lr()[0]}") - optimizer.step() - lr_scheduler.step() - - torch.save(precond.state_dict(), "checkpoint/bigger_model.pt") - -torch.save(precond.state_dict(), "checkpoint/bigger_model.pt") -
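Editor's sketch (not part of the removed files; array names are illustrative): how the three input channels for the network preconditioner are assembled in the code above, i.e. the OSEM image plus the EM-preconditioned Poisson log-likelihood and RDP gradients, stacked into a (1, 3, D, H, W) tensor.

import numpy as np
import torch

def make_network_input(osem: np.ndarray, pll_grad: np.ndarray,
                       prior_grad: np.ndarray, sens: np.ndarray) -> torch.Tensor:
    # Small offset so voxels at zero can still receive an update (OSEM_image.max()/1e3 above).
    eps = osem.max() / 1e3
    # Guard against division by zero in the sensitivity, as done for average_sensitivity above.
    sens = sens + sens.max() / 1e4
    em_pll = (osem + eps) * pll_grad / sens      # EM-preconditioned likelihood gradient
    em_prior = (osem + eps) * prior_grad / sens  # EM-preconditioned prior gradient
    channels = np.stack([osem, em_pll, em_prior], axis=0)   # (3, D, H, W)
    return torch.from_numpy(channels).float().unsqueeze(0)  # (1, 3, D, H, W) network input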
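A brief note on the `1/num_subsets` prior rescaling flagged by the WARNING comments: dividing the penalisation factor by the number of subset objective functions before attaching the same prior to each of them keeps the sum of the subset objectives equal to the original penalised objective. A trivial sketch with illustrative numbers (1/700 is the default penalty strength used above; the subset count is hypothetical):

beta, num_subsets = 1 / 700, 7        # penalty strength and a hypothetical subset count
subset_penalty = beta / num_subsets   # what set_penalisation_factor receives per subset
# Summing the prior term over all subsets recovers the full penalty strength.
assert abs(num_subsets * subset_penalty - beta) < 1e-12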
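Finally, a compact restatement (editor's sketch in plain numpy; function and argument names are illustrative) of the figures of merit that `QualityMetrics`/`evaluate_quality_metrics` log: whole-object and background NRMSE plus per-VOI absolute error of the mean, all normalised by the mean of the reference image over the background mask.

import numpy as np

def petric_metrics(reference: np.ndarray, prediction: np.ndarray,
                   whole_object: np.ndarray, background: np.ndarray,
                   vois: dict[str, np.ndarray]) -> dict[str, float]:
    norm = reference[background].mean()  # every metric is normalised by the background mean
    def nrmse(mask: np.ndarray) -> float:
        return float(np.sqrt(np.mean((reference[mask] - prediction[mask]) ** 2)) / norm)
    metrics = {"RMSE_whole_object": nrmse(whole_object),
               "RMSE_background": nrmse(background)}
    for name, mask in sorted(vois.items()):
        metrics[f"AEM_VOI_{name}"] = float(abs(prediction[mask].mean() - reference[mask].mean()) / norm)
    return metrics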