-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d9fbdc0
commit f14f4b2
Showing
3 changed files
with
206 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
"""Main file to modify for submissions. | ||
Once renamed or symlinked as `main.py`, it will be used by `petric.py` as follows: | ||
>>> from main import Submission, submission_callbacks | ||
>>> from petric import data, metrics | ||
>>> algorithm = Submission(data) | ||
>>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks) | ||
""" | ||
from cil.optimisation.algorithms import Algorithm | ||
from cil.optimisation.utilities import callbacks | ||
from petric import Dataset | ||
|
||
from pnp import PnP1 | ||
|
||
from sirf.contrib.partitioner import partitioner | ||
|
||
|
||
assert issubclass(PnP1, Algorithm) | ||
|
||
|
||
class MaxIteration(callbacks.Callback): | ||
""" | ||
The organisers try to `Submission(data).run(inf)` i.e. for infinite iterations (until timeout). | ||
This callback forces stopping after `max_iteration` instead. | ||
""" | ||
def __init__(self, max_iteration: int, verbose: int = 1): | ||
super().__init__(verbose) | ||
self.max_iteration = max_iteration | ||
|
||
def __call__(self, algorithm: Algorithm): | ||
if algorithm.iteration >= self.max_iteration: | ||
raise StopIteration | ||
|
||
|
||
class Submission(PnP1): | ||
# note that `issubclass(PnP1, Algorithm) == True` | ||
def __init__(self, data: Dataset, num_subsets: int = 7, update_objective_interval: int = 10): | ||
""" | ||
Initialisation function, setting up data & (hyper)parameters. | ||
NB: in practice, `num_subsets` should likely be determined from the data. | ||
This is just an example. Try to modify and improve it! | ||
""" | ||
data_sub, acq_models, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term, | ||
data.mult_factors, num_subsets, | ||
initial_image=data.OSEM_image) | ||
# WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations) | ||
data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs)) | ||
|
||
data.prior.set_up(data.OSEM_image) | ||
for f in obj_funs: # add prior evenly to every objective function | ||
f.set_prior(data.prior) | ||
|
||
initial_image = data.OSEM_image.clone() #/ (data.OSEM_image.clone() + 1e-4) | ||
|
||
super().__init__(data_sub, obj_funs, initial=initial_image, initial_step_size=.3, relaxation_eta=.01, | ||
update_objective_interval=update_objective_interval) | ||
|
||
|
||
submission_callbacks = [MaxIteration(2000)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
put model weights for drunet here | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
|
||
# | ||
# Classes implementing the BSREM+PnP algorithm in sirf.STIR | ||
# | ||
# BSREM from https://github.com/SyneRBI/SIRF-Contribs/blob/master/src/Python/sirf/contrib/BSREM/BSREM.py | ||
|
||
import numpy | ||
import sirf.STIR as STIR | ||
from sirf.Utilities import examples_data_path | ||
|
||
from cil.optimisation.algorithms import Algorithm | ||
|
||
import torch | ||
import numpy as np | ||
from deepinv.models import DRUNet | ||
|
||
import time | ||
|
||
class PnPSkeleton(Algorithm): | ||
''' Main implementation of a modified BSREM algorithm | ||
This essentially implements constrained preconditioned gradient ascent | ||
with an EM-type preconditioner. | ||
In each update step, the gradient of a subset is computed, multiplied by a step_size and a EM-type preconditioner. | ||
Before adding this to the previous iterate, an update_filter can be applied. | ||
Step-size uses relaxation: ``initial_step_size`` / (1 + ``relaxation_eta`` * ``epoch()``) | ||
''' | ||
def __init__(self, data, initial, initial_step_size, relaxation_eta, | ||
update_filter=STIR.TruncateToCylinderProcessor(), **kwargs): | ||
''' | ||
Arguments: | ||
``data``: list of items as returned by `partitioner` | ||
``initial``: initial estimate | ||
``initial_step_size``, ``relaxation_eta``: step-size constants | ||
``update_filter`` is applied on the (additive) update term, i.e. before adding to the previous iterate. | ||
Set the filter to `None` if you don't want any. | ||
''' | ||
super().__init__(**kwargs) | ||
self.x = initial.copy() | ||
self.data = data | ||
self.num_subsets = len(data) | ||
self.initial_step_size = initial_step_size | ||
self.relaxation_eta = relaxation_eta | ||
# compute small number to add to image in preconditioner | ||
# don't make it too small as otherwise the algorithm cannot recover from zeroes. | ||
self.eps = initial.max()/1e3 | ||
self.average_sensitivity = initial.get_uniform_copy(0) | ||
for s in range(len(data)): | ||
self.average_sensitivity += self.subset_sensitivity(s)/self.num_subsets | ||
# add a small number to avoid division by zero in the preconditioner | ||
self.average_sensitivity += self.average_sensitivity.max()/1e4 | ||
self.subset = 0 | ||
self.update_filter = update_filter | ||
self.configured = True | ||
|
||
self.max_pnp_iters = 100 | ||
|
||
self.model = DRUNet(in_channels=1, out_channels=1, pretrained="model_weights/drunet_gray.pth", device="cuda") | ||
self.denoising_strength = np.linspace(25./255, 5./255, self.max_pnp_iters) | ||
|
||
|
||
def subset_sensitivity(self, subset_num): | ||
raise NotImplementedError | ||
|
||
def subset_gradient(self, x, subset_num): | ||
raise NotImplementedError | ||
|
||
def epoch(self): | ||
return self.iteration // self.num_subsets | ||
|
||
def step_size(self): | ||
return self.initial_step_size / (1 + self.relaxation_eta * self.epoch()) | ||
|
||
def update(self): | ||
if self.iteration < self.max_pnp_iters: | ||
t1 = time.time() | ||
#print(self.iteration, self.max_pnp_iters) | ||
|
||
x = torch.from_numpy(self.x.as_array()).float().to("cuda").unsqueeze(1) | ||
x_max, _ = torch.max(x.view(x.shape[0], -1), dim=-1) | ||
|
||
#scale to 0-1 | ||
x = x/x_max[:,None,None,None] | ||
|
||
x = self.model(x, self.denoising_strength[self.iteration])*x_max[:,None,None,None] | ||
|
||
self.x.fill(x.squeeze().cpu().numpy()) | ||
self.x.maximum(0, out=self.x) | ||
t2 = time.time() | ||
print("Duration of denoising: ", t2 - t1, " s") | ||
|
||
t1 = time.time() | ||
g = self.subset_gradient(self.x, self.subset) | ||
self.x_update = (self.x + self.eps) * g / self.average_sensitivity * self.step_size() | ||
if self.update_filter is not None: | ||
self.update_filter.apply(self.x_update) | ||
self.x += self.x_update | ||
# threshold to non-negative | ||
self.x.maximum(0, out=self.x) | ||
self.subset = (self.subset + 1) % self.num_subsets | ||
t2 = time.time() | ||
print("Duration of BSREM step: ", t2 - t1, " s") | ||
|
||
def update_objective(self): | ||
# required for current CIL (needs to set self.loss) | ||
self.loss.append(self.objective_function(self.x)) | ||
|
||
def objective_function(self, x): | ||
''' value of objective function summed over all subsets ''' | ||
v = 0 | ||
for s in range(len(self.data)): | ||
v += self.subset_objective(x, s) | ||
return v | ||
|
||
def subset_objective(self, x, subset_num): | ||
''' value of objective function for one subset ''' | ||
raise NotImplementedError | ||
|
||
class PnP1(PnPSkeleton): | ||
''' BSREM implementation using sirf.STIR objective functions''' | ||
def __init__(self, data, obj_funs, initial, initial_step_size=1, relaxation_eta=0, **kwargs): | ||
''' | ||
construct Algorithm with lists of data and, objective functions, initial estimate, initial step size, | ||
step-size relaxation (per epoch) and optionally Algorithm parameters | ||
''' | ||
self.obj_funs = obj_funs | ||
super().__init__(data, initial, initial_step_size, relaxation_eta, **kwargs) | ||
|
||
def subset_sensitivity(self, subset_num): | ||
''' Compute sensitivity for a particular subset''' | ||
self.obj_funs[subset_num].set_up(self.x) | ||
# note: sirf.STIR Poisson likelihood uses `get_subset_sensitivity(0) for the whole | ||
# sensitivity if there are no subsets in that likelihood | ||
return self.obj_funs[subset_num].get_subset_sensitivity(0) | ||
|
||
def subset_gradient(self, x, subset_num): | ||
''' Compute gradient at x for a particular subset''' | ||
return self.obj_funs[subset_num].gradient(x) | ||
|
||
def subset_objective(self, x, subset_num): | ||
''' value of objective function for one subset ''' | ||
return self.obj_funs[subset_num](x) |