Skip to content

Commit

Permalink
resolve merge
Browse files Browse the repository at this point in the history
  • Loading branch information
alexdenker committed Jul 30, 2024
1 parent d24816d commit 4842cf8
Show file tree
Hide file tree
Showing 7 changed files with 330 additions and 342 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
data/
output/
output_ews/

tmp*/
err*.txt
info.txt
Expand Down
18 changes: 7 additions & 11 deletions adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,9 @@ def __init__(self, data, initial, initial_step_size, relaxation_eta,
self.eps_adam = 1e-8

self.m = initial.get_uniform_copy(0)
# self.m.fill(np.zeros_like(initial.as_array()))

self.m_hat = initial.copy()
self.m_hat.fill(np.zeros_like(initial.as_array()))

self.v = initial.copy()
self.v.fill(np.zeros_like(initial.as_array()))

self.v_hat = initial.copy()
self.v_hat.fill(np.zeros_like(initial.as_array()))
self.m_hat = initial.get_uniform_copy(0)
self.v = initial.get_uniform_copy(0)
self.v_hat = initial.get_uniform_copy(0)


def subset_sensitivity(self, subset_num):
Expand All @@ -81,6 +74,9 @@ def subset_gradient(self, x, subset_num):
def epoch(self):
return self.iteration // self.num_subsets

def step_size(self):
return self.initial_step_size / (1 + self.relaxation_eta * self.epoch())

def update(self):
g = self.subset_gradient(self.x, self.subset)
#g = (self.x + self.eps) * g / self.average_sensitivity
Expand All @@ -92,7 +88,7 @@ def update(self):
self.v_hat = self.v.clone() / (1 - self.beta2 ** (self.iteration+1))
self.v_hat.sqrt(out=self.v_hat)

self.x_update = self.alpha * self.m_hat / (self.v_hat + self.eps_adam)
self.x_update = self.step_size() * self.m_hat / (self.v_hat + self.eps_adam)
if self.update_filter is not None:
self.update_filter.apply(self.x_update)

Expand Down
17 changes: 13 additions & 4 deletions main_ADAM.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ def __call__(self, algorithm: Algorithm):

class Submission(Adam1):
# note that `issubclass(BSREM1, Algorithm) == True`
def __init__(self, data: Dataset, num_subsets: int = 7, update_objective_interval: int = 10):
def __init__(self, data: Dataset,
num_subsets: int = 7,
update_objective_interval: int = 10,
**kwargs):
"""
Initialisation function, setting up data & (hyper)parameters.
NB: in practice, `num_subsets` should likely be determined from the data.
This is just an example. Try to modify and improve it!
"""
data_sub, acq_models, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
data.mult_factors, num_subsets,
Expand All @@ -47,8 +49,15 @@ def __init__(self, data: Dataset, num_subsets: int = 7, update_objective_interva
for f in obj_funs: # add prior evenly to every objective function
f.set_prior(data.prior)

super().__init__(data_sub, obj_funs, initial=data.OSEM_image, initial_step_size=.3, relaxation_eta=.01,
initial_step_size = kwargs.get("initial_step_size", 0.1)
relaxation_eta = kwargs.get("relaxation_eta", 0.01)

super().__init__(data_sub,
obj_funs,
initial=data.OSEM_image,
initial_step_size=initial_step_size,
relaxation_eta=relaxation_eta,
update_objective_interval=update_objective_interval)


submission_callbacks = [MaxIteration(660)]
submission_callbacks = [] #[MaxIteration(660)]
23 changes: 18 additions & 5 deletions main_BSREM.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from cil.optimisation.algorithms import Algorithm
from cil.optimisation.utilities import callbacks
from petric import Dataset
from sirf.contrib.BSREM.BSREM import BSREM1
from BSREM import BSREM1
from sirf.contrib.partitioner import partitioner

assert issubclass(BSREM1, Algorithm)
Expand All @@ -32,23 +32,36 @@ def __call__(self, algorithm: Algorithm):

class Submission(BSREM1):
# note that `issubclass(BSREM1, Algorithm) == True`
def __init__(self, data: Dataset, num_subsets: int = 7, update_objective_interval: int = 10):
def __init__(self, data: Dataset,
num_subsets: int = 7,
update_objective_interval: int = 10,
**kwargs):
"""
Initialisation function, setting up data & (hyper)parameters.
NB: in practice, `num_subsets` should likely be determined from the data.
This is just an example. Try to modify and improve it!
"""
mode = kwargs.get("mode", "sequential")

data_sub, acq_models, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
data.mult_factors, num_subsets,
initial_image=data.OSEM_image)
initial_image=data.OSEM_image,
mode = mode)
# WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations)
data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs))
data.prior.set_up(data.OSEM_image)
for f in obj_funs: # add prior evenly to every objective function
f.set_prior(data.prior)

super().__init__(data_sub, obj_funs, initial=data.OSEM_image, initial_step_size=.3, relaxation_eta=.01,
initial_step_size = kwargs.get("initial_step_size", 0.3)
relaxation_eta = kwargs.get("relaxation_eta", 0.01)

super().__init__(data_sub,
obj_funs,
initial=data.OSEM_image,
initial_step_size=initial_step_size,
relaxation_eta=relaxation_eta,
update_objective_interval=update_objective_interval)


submission_callbacks = [MaxIteration(660)]
submission_callbacks = [] #[MaxIteration(660)]
24 changes: 6 additions & 18 deletions main_EWS.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,15 @@
from sirf.contrib.partitioner import partitioner

from ews import EWS

assert issubclass(EWS, Algorithm)


class MaxIteration(callbacks.Callback):
"""
The organisers try to `Submission(data).run(inf)` i.e. for infinite iterations (until timeout).
This callback forces stopping after `max_iteration` instead.
"""
def __init__(self, max_iteration: int, verbose: int = 1):
super().__init__(verbose)
self.max_iteration = max_iteration

def __call__(self, algorithm: Algorithm):
if algorithm.iteration >= self.max_iteration:
raise StopIteration

class Submission(EWS):
# note that `issubclass(BSREM1, Algorithm) == True`
def __init__(self, data: Dataset, num_subsets: int = 7, update_objective_interval: int = 10):
def __init__(self,
data: Dataset,
num_subsets: int = 7,
update_objective_interval: int = 10,
**kwargs):
"""
Initialisation function, setting up data & (hyper)parameters.
NB: in practice, `num_subsets` should likely be determined from the data.
Expand All @@ -50,5 +39,4 @@ def __init__(self, data: Dataset, num_subsets: int = 7, update_objective_interva
super().__init__(data_sub, obj_funs, initial=data.OSEM_image, initial_step_size=.3, relaxation_eta=.01,
update_objective_interval=update_objective_interval)


submission_callbacks = [MaxIteration(600)]
submission_callbacks = []
Loading

0 comments on commit 4842cf8

Please # to comment.