add pnp example
alexdenker committed Jul 25, 2024
1 parent d9fbdc0 commit f14f4b2
Showing 3 changed files with 206 additions and 0 deletions.
60 changes: 60 additions & 0 deletions main_PnP.py
@@ -0,0 +1,60 @@
"""Main file to modify for submissions.
Once renamed or symlinked as `main.py`, it will be used by `petric.py` as follows:
>>> from main import Submission, submission_callbacks
>>> from petric import data, metrics
>>> algorithm = Submission(data)
>>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks)
"""
from cil.optimisation.algorithms import Algorithm
from cil.optimisation.utilities import callbacks
from petric import Dataset

from pnp import PnP1

from sirf.contrib.partitioner import partitioner


assert issubclass(PnP1, Algorithm)


class MaxIteration(callbacks.Callback):
"""
The organisers will call `Submission(data).run(np.inf)`, i.e. run for infinite iterations (until timeout).
This callback forces stopping after `max_iteration` instead.
"""
def __init__(self, max_iteration: int, verbose: int = 1):
super().__init__(verbose)
self.max_iteration = max_iteration

def __call__(self, algorithm: Algorithm):
if algorithm.iteration >= self.max_iteration:
raise StopIteration


class Submission(PnP1):
# note that `issubclass(PnP1, Algorithm) == True`
def __init__(self, data: Dataset, num_subsets: int = 7, update_objective_interval: int = 10):
"""
Initialisation function, setting up data & (hyper)parameters.
NB: in practice, `num_subsets` should likely be determined from the data.
This is just an example. Try to modify and improve it!
"""
data_sub, acq_models, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
data.mult_factors, num_subsets,
initial_image=data.OSEM_image)
# WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations)
data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs))

data.prior.set_up(data.OSEM_image)
for f in obj_funs: # add prior evenly to every objective function
f.set_prior(data.prior)

initial_image = data.OSEM_image.clone()

super().__init__(data_sub, obj_funs, initial=initial_image, initial_step_size=.3, relaxation_eta=.01,
update_objective_interval=update_objective_interval)


submission_callbacks = [MaxIteration(2000)]
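For reference, a minimal sketch of how `petric.py` is expected to drive this file once it is renamed or symlinked as `main.py`, mirroring the module docstring above (`data` and `metrics` are provided by `petric`):

import numpy as np

from main import Submission, submission_callbacks
from petric import data, metrics

algorithm = Submission(data)
algorithm.run(np.inf, callbacks=metrics + submission_callbacks)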
2 changes: 2 additions & 0 deletions model_weights/readme.txt
@@ -0,0 +1,2 @@
Put model weights for DRUNet here; pnp.py expects the file model_weights/drunet_gray.pth.
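A hedged sketch of one way to fetch the grayscale DRUNet weights into the expected location (this assumes deepinv's "download" option for pretrained weights; check the deepinv documentation for the exact mechanism):

import torch
from deepinv.models import DRUNet

# download pretrained grayscale DRUNet weights, then save them where pnp.py expects them
model = DRUNet(in_channels=1, out_channels=1, pretrained="download", device="cpu")
torch.save(model.state_dict(), "model_weights/drunet_gray.pth")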

144 changes: 144 additions & 0 deletions pnp.py
@@ -0,0 +1,144 @@

#
# Classes implementing the BSREM+PnP algorithm in sirf.STIR
#
# BSREM from https://github.com/SyneRBI/SIRF-Contribs/blob/master/src/Python/sirf/contrib/BSREM/BSREM.py
#

import time

import numpy as np
import torch
from deepinv.models import DRUNet

import sirf.STIR as STIR
from cil.optimisation.algorithms import Algorithm

class PnPSkeleton(Algorithm):
''' Modified BSREM algorithm with a plug-and-play (PnP) denoising step.
This essentially implements constrained preconditioned gradient ascent
with an EM-type preconditioner.
For the first `max_pnp_iters` iterations, the current iterate is first denoised
with a pretrained DRUNet before the gradient update.
In each update step, the gradient of a subset is computed and multiplied by a step size and an EM-type preconditioner.
Before adding this to the previous iterate, an update_filter can be applied.
Step size uses relaxation: ``initial_step_size`` / (1 + ``relaxation_eta`` * ``epoch()``)
'''
def __init__(self, data, initial, initial_step_size, relaxation_eta,
update_filter=STIR.TruncateToCylinderProcessor(), **kwargs):
'''
Arguments:
``data``: list of items as returned by `partitioner`
``initial``: initial estimate
``initial_step_size``, ``relaxation_eta``: step-size constants
``update_filter`` is applied on the (additive) update term, i.e. before adding to the previous iterate.
Set the filter to `None` if you don't want any.
'''
super().__init__(**kwargs)
self.x = initial.copy()
self.data = data
self.num_subsets = len(data)
self.initial_step_size = initial_step_size
self.relaxation_eta = relaxation_eta
# compute small number to add to image in preconditioner
# don't make it too small as otherwise the algorithm cannot recover from zeroes.
self.eps = initial.max()/1e3
self.average_sensitivity = initial.get_uniform_copy(0)
for s in range(len(data)):
self.average_sensitivity += self.subset_sensitivity(s)/self.num_subsets
# add a small number to avoid division by zero in the preconditioner
self.average_sensitivity += self.average_sensitivity.max()/1e4
self.subset = 0
self.update_filter = update_filter
self.configured = True

self.max_pnp_iters = 100

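# Pretrained grayscale DRUNet denoiser from deepinv; the weights file is
# expected at model_weights/drunet_gray.pth (see model_weights/readme.txt).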
self.model = DRUNet(in_channels=1, out_channels=1, pretrained="model_weights/drunet_gray.pth", device="cuda")
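# Denoising strength (sigma) schedule: decays linearly from 25/255 (~0.098)
# to 5/255 (~0.020) over the first `max_pnp_iters` denoising steps.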
self.denoising_strength = np.linspace(25./255, 5./255, self.max_pnp_iters)


def subset_sensitivity(self, subset_num):
raise NotImplementedError

def subset_gradient(self, x, subset_num):
raise NotImplementedError

def epoch(self):
return self.iteration // self.num_subsets

def step_size(self):
return self.initial_step_size / (1 + self.relaxation_eta * self.epoch())
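# Illustration of the relaxed step-size schedule with the values used in
# main_PnP.py (initial_step_size=0.3, relaxation_eta=0.01):
#   epoch   0: 0.3 / (1 + 0.01 *   0) = 0.300
#   epoch  10: 0.3 / (1 + 0.01 *  10) ≈ 0.273
#   epoch 100: 0.3 / (1 + 0.01 * 100) = 0.150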

def update(self):
if self.iteration < self.max_pnp_iters:
t1 = time.time()

x = torch.from_numpy(self.x.as_array()).float().to("cuda").unsqueeze(1)
x_max, _ = torch.max(x.view(x.shape[0], -1), dim=-1)

# scale to [0, 1] for the denoiser
x = x/x_max[:,None,None,None]

# denoise, then undo the scaling
x = self.model(x, self.denoising_strength[self.iteration])*x_max[:,None,None,None]

self.x.fill(x.squeeze().cpu().numpy())
self.x.maximum(0, out=self.x)
t2 = time.time()
print("Duration of denoising: ", t2 - t1, " s")

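# Preconditioned gradient ascent step:
#   x <- x + step_size * (x + eps) / average_sensitivity * grad_subset(x)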
t1 = time.time()
g = self.subset_gradient(self.x, self.subset)
self.x_update = (self.x + self.eps) * g / self.average_sensitivity * self.step_size()
if self.update_filter is not None:
self.update_filter.apply(self.x_update)
self.x += self.x_update
# threshold to non-negative
self.x.maximum(0, out=self.x)
self.subset = (self.subset + 1) % self.num_subsets
t2 = time.time()
print("Duration of BSREM step: ", t2 - t1, " s")

def update_objective(self):
# required for current CIL (needs to set self.loss)
self.loss.append(self.objective_function(self.x))

def objective_function(self, x):
''' value of objective function summed over all subsets '''
v = 0
for s in range(len(self.data)):
v += self.subset_objective(x, s)
return v

def subset_objective(self, x, subset_num):
''' value of objective function for one subset '''
raise NotImplementedError

class PnP1(PnPSkeleton):
''' BSREM+PnP implementation using sirf.STIR objective functions'''
def __init__(self, data, obj_funs, initial, initial_step_size=1, relaxation_eta=0, **kwargs):
'''
construct Algorithm with lists of data and objective functions, initial estimate, initial step size,
step-size relaxation (per epoch) and optionally Algorithm parameters
'''
self.obj_funs = obj_funs
super().__init__(data, initial, initial_step_size, relaxation_eta, **kwargs)

def subset_sensitivity(self, subset_num):
''' Compute sensitivity for a particular subset'''
self.obj_funs[subset_num].set_up(self.x)
# note: the sirf.STIR Poisson likelihood uses `get_subset_sensitivity(0)` for the whole
# sensitivity if there are no subsets in that likelihood
return self.obj_funs[subset_num].get_subset_sensitivity(0)

def subset_gradient(self, x, subset_num):
''' Compute gradient at x for a particular subset'''
return self.obj_funs[subset_num].gradient(x)

def subset_objective(self, x, subset_num):
''' value of objective function for one subset '''
return self.obj_funs[subset_num](x)
