diff --git a/src/buildblock/ProjDataInMemory.cxx b/src/buildblock/ProjDataInMemory.cxx index dbd53ebcc..b70267c0f 100644 --- a/src/buildblock/ProjDataInMemory.cxx +++ b/src/buildblock/ProjDataInMemory.cxx @@ -30,6 +30,7 @@ #include "stir/SegmentByView.h" #include "stir/Bin.h" #include "stir/is_null_ptr.h" +#include "stir/numerics/norm.h" #include #include #include @@ -361,6 +362,148 @@ ProjDataInMemory::set_bin_value(const Bin& bin) buffer[this->get_index(bin)] = bin.get_bin_value(); } +float +ProjDataInMemory::sum() const +{ + return buffer.sum(); +} + +float +ProjDataInMemory::find_max() const +{ + return buffer.find_max(); +} + +float +ProjDataInMemory::find_min() const +{ + return buffer.find_min(); +} + +double +ProjDataInMemory::norm() const +{ + return stir::norm(this->buffer); +} + +double +ProjDataInMemory::norm_squared() const +{ + return stir::norm_squared(this->buffer); +} + +ProjDataInMemory& +ProjDataInMemory::operator+=(const ProjDataInMemory& v) +{ + this->buffer += v.buffer; + return *this; +} + +ProjDataInMemory& +ProjDataInMemory::operator-=(const ProjDataInMemory& v) +{ + this->buffer -= v.buffer; + return *this; +} + +ProjDataInMemory& +ProjDataInMemory::operator*=(const ProjDataInMemory& v) +{ + this->buffer *= v.buffer; + return *this; +} + +ProjDataInMemory& +ProjDataInMemory::operator/=(const ProjDataInMemory& v) +{ + this->buffer /= v.buffer; + return *this; +} + +ProjDataInMemory& +ProjDataInMemory::operator+=(const float v) +{ + this->buffer += v; + return *this; +} + +ProjDataInMemory& +ProjDataInMemory::operator-=(const float v) +{ + this->buffer -= v; + return *this; +} + +ProjDataInMemory& +ProjDataInMemory::operator*=(const float v) +{ + this->buffer *= v; + return *this; +} + +ProjDataInMemory& +ProjDataInMemory::operator/=(const float v) +{ + this->buffer /= v; + return *this; +} + +ProjDataInMemory +ProjDataInMemory::operator+(const ProjDataInMemory& iv) const +{ + ProjDataInMemory c(*this); + return c += iv; +} + +ProjDataInMemory +ProjDataInMemory::operator-(const ProjDataInMemory& iv) const +{ + ProjDataInMemory c(*this); + return c -= iv; +} + +ProjDataInMemory +ProjDataInMemory::operator*(const ProjDataInMemory& iv) const +{ + ProjDataInMemory c(*this); + return c *= iv; +} + +ProjDataInMemory +ProjDataInMemory::operator/(const ProjDataInMemory& iv) const +{ + ProjDataInMemory c(*this); + return c /= iv; +} + +ProjDataInMemory +ProjDataInMemory::operator+(const float a) const +{ + ProjDataInMemory c(*this); + return c += a; +} + +ProjDataInMemory +ProjDataInMemory::operator-(const float a) const +{ + ProjDataInMemory c(*this); + return c -= a; +} + +ProjDataInMemory +ProjDataInMemory::operator*(const float a) const +{ + ProjDataInMemory c(*this); + return c *= a; +} + +ProjDataInMemory +ProjDataInMemory::operator/(const float a) const +{ + ProjDataInMemory c(*this); + return c /= a; +} + void ProjDataInMemory::axpby(const float a, const ProjData& x, const float b, const ProjData& y) { diff --git a/src/include/stir/ProjDataInMemory.h b/src/include/stir/ProjDataInMemory.h index 9bfad53f7..a9c067697 100644 --- a/src/include/stir/ProjDataInMemory.h +++ b/src/include/stir/ProjDataInMemory.h @@ -37,6 +37,9 @@ class Succeeded; */ class ProjDataInMemory : public ProjData { + + typedef ProjDataInMemory self_type; + public: //! constructor with only info, but no data /*! @@ -57,39 +60,21 @@ class ProjDataInMemory : public ProjData Viewgram get_viewgram(const int view_num, const int segment_num, - const bool make_num_tangential_poss_odd = false -#ifdef STIR_TOF - , - const int timing_pos = 0 -#endif - ) const override; + const bool make_num_tangential_poss_odd = false, + const int timing_pos = 0) const override; Succeeded set_viewgram(const Viewgram& v) override; Sinogram get_sinogram(const int ax_pos_num, const int segment_num, - const bool make_num_tangential_poss_odd = false -#ifdef STIR_TOF - , - const int timing_pos = 0 -#endif - ) const override; + const bool make_num_tangential_poss_odd = false, + const int timing_pos = 0) const override; Succeeded set_sinogram(const Sinogram& s) override; //! Get all sinograms for the given segment - SegmentBySinogram get_segment_by_sinogram(const int segment_num -#ifdef STIR_TOF - , - const int timing_pos = 0 -#endif - ) const override; + SegmentBySinogram get_segment_by_sinogram(const int segment_num, const int timing_pos = 0) const override; //! Get all viewgrams for the given segment - SegmentByView get_segment_by_view(const int segment_num -#ifdef STIR_TOF - , - const int timing_pos = 0 -#endif - ) const override; + SegmentByView get_segment_by_view(const int segment_num, const int timing_pos = 0) const override; //! Set all sinograms for the given segment Succeeded set_segment(const SegmentBySinogram&) override; @@ -115,6 +100,73 @@ class ProjDataInMemory : public ProjData void set_bin_value(const Bin& bin); + //! @name arithmetic operations + ///@{ + //! return sum of all elements + float sum() const; + + //! return maximum value of all elements + float find_max() const; + + //! return minimum value of all elements + float find_min() const; + + //! return L2-norm (sqrt of sum of squares) + double norm() const; + + //! return L2-norm squared (sum of squares) + double norm_squared() const; + + //! elem by elem addition + self_type operator+(const self_type& iv) const; + + //! elem by elem subtraction + self_type operator-(const self_type& iv) const; + + //! elem by elem multiplication + self_type operator*(const self_type& iv) const; + + //! elem by elem division + self_type operator/(const self_type& iv) const; + + //! addition with a 'float' + self_type operator+(const float a) const; + + //! subtraction with a 'float' + self_type operator-(const float a) const; + + //! multiplication with a 'float' + self_type operator*(const float a) const; + + //! division with a 'float' + self_type operator/(const float a) const; + + // corresponding assignment operators + + //! adding elements of \c v to the current vector + self_type& operator+=(const self_type& v); + + //! subtracting elements of \c v from the current vector + self_type& operator-=(const self_type& v); + + //! multiplying elements of the current vector with elements of \c v + self_type& operator*=(const self_type& v); + + //! dividing all elements of the current vector by elements of \c v + self_type& operator/=(const self_type& v); + + //! adding an \c float to the elements of the current vector + self_type& operator+=(const float v); + + //! subtracting an \c float from the elements of the current vector + self_type& operator-=(const float v); + + //! multiplying the elements of the current vector with an \c float + self_type& operator*=(const float v); + + //! dividing the elements of the current vector by an \c float + self_type& operator/=(const float v); + //! \deprecated a*x+b*y (use xapyb) STIR_DEPRECATED void axpby(const float a, const ProjData& x, const float b, const ProjData& y) override; @@ -137,6 +189,7 @@ class ProjDataInMemory : public ProjData /// This implementation requires that a, b and y are ProjDataInMemory /// (else falls back on general method) void sapyb(const ProjData& a, const ProjData& y, const ProjData& b) override; + ///@} /** @name iterator typedefs * iterator typedefs @@ -149,71 +202,35 @@ class ProjDataInMemory : public ProjData ///@} //! start value for iterating through all elements in the array, see iterator - inline iterator begin() - { - return buffer.begin(); - } + iterator begin() { return buffer.begin(); } //! start value for iterating through all elements in the (const) array, see iterator - inline const_iterator begin() const - { - return buffer.begin(); - } + const_iterator begin() const { return buffer.begin(); } //! end value for iterating through all elements in the array, see iterator - inline iterator end() - { - return buffer.end(); - } + iterator end() { return buffer.end(); } //! end value for iterating through all elements in the (const) array, see iterator - inline const_iterator end() const - { - return buffer.end(); - } + const_iterator end() const { return buffer.end(); } //! start value for iterating through all elements in the array, see iterator - inline iterator begin_all() - { - return buffer.begin_all(); - } + iterator begin_all() { return buffer.begin_all(); } //! start value for iterating through all elements in the (const) array, see iterator - inline const_iterator begin_all() const - { - return buffer.begin_all(); - } + const_iterator begin_all() const { return buffer.begin_all(); } //! end value for iterating through all elements in the array, see iterator - inline iterator end_all() - { - return buffer.end_all(); - } + iterator end_all() { return buffer.end_all(); } //! end value for iterating through all elements in the (const) array, see iterator - inline const_iterator end_all() const - { - return buffer.end_all(); - } + const_iterator end_all() const { return buffer.end_all(); } //! \name access to the data via a pointer //@{ //! member function for access to the data via a float* - inline float* get_data_ptr() - { - return buffer.get_data_ptr(); - } + float* get_data_ptr() { return buffer.get_data_ptr(); } //! member function for access to the data via a const float* - inline const float* get_const_data_ptr() const - { - return buffer.get_const_data_ptr(); - } + const float* get_const_data_ptr() const { return buffer.get_const_data_ptr(); } //! signal end of access to float* - inline void release_data_ptr() - { - buffer.release_data_ptr(); - } + void release_data_ptr() { buffer.release_data_ptr(); } //! signal end of access to const float* - inline void release_const_data_ptr() const - { - buffer.release_const_data_ptr(); - } + void release_const_data_ptr() const { buffer.release_const_data_ptr(); } //@} private: @@ -236,6 +253,18 @@ class ProjDataInMemory : public ProjData std::streamoff get_index(const Bin&) const; }; +inline double +norm(const ProjDataInMemory& p) +{ + return p.norm(); +} + +inline double +norm_squared(const ProjDataInMemory& p) +{ + return p.norm_squared(); +} + END_NAMESPACE_STIR #endif diff --git a/src/include/stir/RunTests.h b/src/include/stir/RunTests.h index 8d6707bca..b40696260 100644 --- a/src/include/stir/RunTests.h +++ b/src/include/stir/RunTests.h @@ -2,7 +2,7 @@ Copyright (C) 2000 PARAPET partners Copyright (C) 2000-2005, Hammersmith Imanet Ltd Copyright (C) 2013, Kris Thielemans - Copyright (C) 2013, 2020, 2022, 2023 University College London + Copyright (C) 2013, 2020, 2022, 2023, 2024 University College London Copyright (C) 2018, University of Hull This file is part of STIR. @@ -27,6 +27,7 @@ #include "stir/stream.h" #include "stir/Bin.h" #include "stir/DetectionPosition.h" +#include "stir/ProjDataInMemory.h" #include #include #include @@ -179,6 +180,8 @@ class RunTests return all_equal; } + bool check_if_equal(const ProjDataInMemory& t1, const ProjDataInMemory& t2, const std::string& str = ""); + // VC 6.0 needs definition of template members in the class def unfortunately. template bool check_if_equal(const IndexRange& t1, const IndexRange& t2, const std::string& str = "") @@ -433,6 +436,25 @@ RunTests::check_if_equal(const unsigned long long a, const unsigned long long b, } #endif +bool +RunTests::check_if_equal(const ProjDataInMemory& t1, const ProjDataInMemory& t2, const std::string& str) +{ + if (*t1.get_proj_data_info_sptr() != *t2.get_proj_data_info_sptr()) + { + std::cerr << "Error: unequal proj_data_info. " << str << std::endl; + return everything_ok = false; + } + + for (auto i1 = t1.begin(), i2 = t2.begin(); i1 != t1.end(); ++i1, ++i2) + { + if (!check_if_equal(*i1, *i2, str)) + { + return everything_ok = false; + } + } + return true; +} + bool RunTests::check_if_zero(const short a, const std::string& str) { diff --git a/src/test/test_proj_data_maths.cxx b/src/test/test_proj_data_maths.cxx index cb55aca69..083176615 100644 --- a/src/test/test_proj_data_maths.cxx +++ b/src/test/test_proj_data_maths.cxx @@ -11,7 +11,7 @@ */ /* - Copyright (C) 2020 University College London + Copyright (C) 2020, 2024 University College London This file is part of STIR. SPDX-License-Identifier: Apache-2.0 @@ -29,6 +29,7 @@ #include "stir/Scanner.h" #include "stir/copy_fill.h" #include "stir/error.h" +#include START_NAMESPACE_STIR /*! @@ -109,6 +110,143 @@ ProjDataInMemoryTests::run_tests(shared_ptr exam_info_sptr, shar *pd_iter++ = a * (*x_iter++) + b * (*y_iter++); check_proj_data_are_equal_and_non_zero(pd1, pd3); + + // numeric operations + { + { + auto res1 = pd1 + pd2; + { + ProjDataInMemory res2(pd1); + res2.sapyb(1.F, pd2, 1.F); + check_if_equal(res1, res2, "+ vs sapyb"); + } + { + ProjDataInMemory res2(pd1); + res2 += pd2; + check_if_equal(res1, res2, "+ vs +="); + } + { + check_if_equal(norm(pd1 + pd1), 2 * norm(pd1), "norm of x+x"); + } + } + { + auto res1 = pd1 - pd2; + { + ProjDataInMemory res2(pd1); + res2.sapyb(1.F, pd2, -1.F); + check_if_equal(res1, res2, "- vs sapyb"); + } + { + ProjDataInMemory res2(pd1); + res2 -= pd2; + check_if_equal(res1, res2, "- vs -="); + } + { + res1 += pd2; + check_if_equal(pd1, res1, "- vs +="); + } + { + check_if_zero(norm(pd1 - pd1), "norm of x-x"); + } + } + { + auto res1 = pd1 * pd2; + { + ProjDataInMemory res2(pd1); + for (auto i1 = res2.begin(), i2 = pd2.begin(); i1 != res2.end(); ++i1, ++i2) + *i1 *= *i2; + check_if_equal(res1, res2, "* vs loop"); + } + { + ProjDataInMemory res2(pd1); + res2 *= pd2; + check_if_equal(res1, res2, "* vs *="); + } + { + res1 /= pd2; + check_if_equal(pd1, res1, "* vs /="); + } + } + { + auto res1 = pd1 / pd2; + { + ProjDataInMemory res2(pd1); + for (auto i1 = res2.begin(), i2 = pd2.begin(); i1 != res2.end(); ++i1, ++i2) + *i1 /= *i2; + check_if_equal(res1, res2, "/ vs loop"); + } + { + ProjDataInMemory res2(pd1); + res2 /= pd2; + check_if_equal(res1, res2, "/ vs /="); + } + { + res1 *= pd2; + check_if_equal(pd1, res1, "/ vs *="); + } + { + // assumes that all elements are !=0 + check_if_equal(norm_squared(pd1 / pd1), static_cast(pd1.size_all()), "norm of x/x"); + } + } + // now with floats + { + auto res1 = pd1 + 5.6F; + check_if_equal(res1.find_max(), pd1.find_max() + 5.6F, "max(x + 5.6F)"); + { + ProjDataInMemory res2(pd1); + res2 += 5.6F; + check_if_equal(res1, res2, "+ vs += float"); + } + { + res1 -= pd1; + res1 /= 5.6F; + check_if_equal(norm_squared(res1), static_cast(pd1.size_all()), "norm of x + 5.6"); + } + } + { + auto res1 = pd1 - 5.6F; + check_if_equal(res1.find_min(), pd1.find_min() - 5.6F, "min(x - 5.6F)"); + { + ProjDataInMemory res2(pd1); + res2 -= 5.6F; + check_if_equal(res1, res2, "- vs -= float"); + } + { + res1 += 5.6F; + check_if_equal(res1, pd1, "- vs += float"); + } + } + { + auto res1 = pd1 * 5.6F; + check_if_equal(norm(res1), norm(pd1) * 5.6F, "norm of x*5.6"); + check_if_equal(res1.sum(), pd1.sum() * 5.6F, "sum(x * 5.6F)"); + + { + ProjDataInMemory res2(pd1); + res2 *= 5.6F; + check_if_equal(res1, res2, "* vs *= float"); + } + { + res1 /= 5.6F; + check_if_equal(res1, pd1, "* vs /= float"); + } + } + { + auto res1 = pd1 / 5.6F; + check_if_equal(norm(res1), norm(pd1) / 5.6F, "norm of x/float"); + + { + ProjDataInMemory res2(pd1); + res2 /= 5.6F; + check_if_equal(res1, res2, "/ vs /= float"); + } + { + res1 /= 1 / 5.6F; + check_if_equal(res1, pd1, "/ vs /= float 2"); + } + } + } } void