Skip to content

Commit

Permalink
Code overlaps in computing gradients of objective functions removed. (#…
Browse files Browse the repository at this point in the history
…1284)

* got rid of code overlaps in computing gradients of objective functions

* slimmed C interface for computing gradients of objective functions
  • Loading branch information
evgueni-ovtchinnikov authored Aug 13, 2024
1 parent a4d8cae commit 1e2c5c1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 43 deletions.
42 changes: 7 additions & 35 deletions src/xSTIR/cSTIR/cstir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1198,26 +1198,11 @@ void*
cSTIR_objectiveFunctionGradient(void* ptr_f, void* ptr_i, int subset)
{
try {
ObjectiveFunction3DF& fun = objectFromHandle< ObjectiveFunction3DF>(ptr_f);
STIRImageData& id = objectFromHandle<STIRImageData>(ptr_i);
Image3DF& image = id.data();
STIRImageData* ptr_id = new STIRImageData(image);
shared_ptr<STIRImageData> sptr(ptr_id);
Image3DF& grad = sptr->data();
if (subset >= 0)
fun.compute_sub_gradient(grad, image, subset);
else {
int nsub = fun.get_num_subsets();
grad.fill(0.0);
STIRImageData* ptr_id = new STIRImageData(image);
shared_ptr<STIRImageData> sptr_sub(ptr_id);
Image3DF& subgrad = sptr_sub->data();
for (int sub = 0; sub < nsub; sub++) {
fun.compute_sub_gradient(subgrad, image, sub);
grad += subgrad;
}
}
return newObjectHandle(sptr);
auto& fun = objectFromHandle<xSTIR_ObjFun3DF>(ptr_f);
auto& id = objectFromHandle<STIRImageData>(ptr_i);
auto sptr_gd = std::make_shared<STIRImageData>(id);
fun.compute_gradient(id, subset, *sptr_gd);
return newObjectHandle(sptr_gd);
}
CATCH;
}
Expand All @@ -1227,23 +1212,10 @@ void*
cSTIR_computeObjectiveFunctionGradient(void* ptr_f, void* ptr_i, int subset, void* ptr_g)
{
try {
ObjectiveFunction3DF& fun = objectFromHandle< ObjectiveFunction3DF>(ptr_f);
xSTIR_ObjFun3DF& fun = objectFromHandle<xSTIR_ObjFun3DF>(ptr_f);
STIRImageData& id = objectFromHandle<STIRImageData>(ptr_i);
STIRImageData& gd = objectFromHandle<STIRImageData>(ptr_g);
Image3DF& image = id.data();
Image3DF& grad = gd.data();
if (subset >= 0)
fun.compute_sub_gradient(grad, image, subset);
else {
int nsub = fun.get_num_subsets();
grad.fill(0.0);
shared_ptr<STIRImageData> sptr_sub(new STIRImageData(image));
Image3DF& subgrad = sptr_sub->data();
for (int sub = 0; sub < nsub; sub++) {
fun.compute_sub_gradient(subgrad, image, sub);
grad += subgrad;
}
}
fun.compute_gradient(id, subset, gd);
return (void*) new DataHandle;
}
CATCH;
Expand Down
34 changes: 26 additions & 8 deletions src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h
Original file line number Diff line number Diff line change
Expand Up @@ -1105,11 +1105,33 @@ The actual algorithm is described in
}
};

class xSTIR_GeneralisedObjectiveFunction3DF :
public stir::GeneralisedObjectiveFunction < Image3DF > {
class xSTIR_GeneralisedObjectiveFunction3DF : public ObjectiveFunction3DF {
public:
//! computes the gradientof an objective function
/*! if the subset number is non-negative, computes the gradient of
this objective function for that subset, otherwise computes
the sum of gradients for all subsets
*/
void compute_gradient(const STIRImageData& id, int subset, STIRImageData& gd)
{
const Image3DF& image = id.data();
Image3DF& grad = gd.data();
if (subset >= 0)
compute_sub_gradient(grad, image, subset);
else {
int nsub = get_num_subsets();
grad.fill(0.0);
shared_ptr<STIRImageData> sptr_sub(new STIRImageData(image));
Image3DF& subgrad = sptr_sub->data();
for (int sub = 0; sub < nsub; sub++) {
compute_sub_gradient(subgrad, image, sub);
grad += subgrad;
}
}
}

void multiply_with_Hessian(Image3DF& output, const Image3DF& curr_image_est,
const Image3DF& input, const int subset) const
const Image3DF& input, const int subset) const
{
output.fill(0.0);
if (subset >= 0)
Expand All @@ -1120,13 +1142,9 @@ The actual algorithm is described in
}
}
}

// bool post_process() {
// return post_processing();
// }
};

//typedef xSTIR_GeneralisedObjectiveFunction3DF ObjectiveFunction3DF;
typedef xSTIR_GeneralisedObjectiveFunction3DF xSTIR_ObjFun3DF;

class xSTIR_PoissonLogLikelihoodWithLinearModelForMeanAndProjData3DF :
public stir::PoissonLogLikelihoodWithLinearModelForMeanAndProjData < Image3DF > {
Expand Down

0 comments on commit 1e2c5c1

Please # to comment.