From be35a267c29a7cd86a2094450ad96a2114ca84d1 Mon Sep 17 00:00:00 2001 From: Stefan Dragnev Date: Mon, 17 May 2021 12:26:32 +0200 Subject: [PATCH] add fixed_shape::operator== --- include/xtensor/xstorage.hpp | 14 ++++++++++++++ test/test_xbuilder.cpp | 8 -------- test/test_xstorage.cpp | 16 +++++++++++++--- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/include/xtensor/xstorage.hpp b/include/xtensor/xstorage.hpp index f43e75049..234012395 100644 --- a/include/xtensor/xstorage.hpp +++ b/include/xtensor/xstorage.hpp @@ -1692,6 +1692,19 @@ namespace xt return sizeof...(X) == 0; } + template + XTENSOR_FIXED_SHAPE_CONSTEXPR bool operator==(const fixed_shape& b) const { + if (size() != b.size()) { + return false; + } + for (std::size_t i = 0; i < size(); ++i) { + if ((*this)[i] != b[i]) { + return false; + } + } + return true; + } + private: XTENSOR_CONSTEXPR_ENHANCED_STATIC cast_type m_array = cast_type({X...}); @@ -1702,6 +1715,7 @@ namespace xt constexpr typename fixed_shape::cast_type fixed_shape::m_array; #endif + #undef XTENSOR_FIXED_SHAPE_CONSTEXPR template diff --git a/test/test_xbuilder.cpp b/test/test_xbuilder.cpp index 0aa86cf16..a03cc131e 100644 --- a/test/test_xbuilder.cpp +++ b/test/test_xbuilder.cpp @@ -380,14 +380,6 @@ namespace xt XT_EXPECT_ANY_THROW(xt::concatenate(xt::xtuple(fa, ta))); } - template - bool operator==(fixed_shape, fixed_shape) - { - std::array ix = {I...}; - std::array jx = {J...}; - return sizeof...(J) == sizeof...(I) && std::equal(ix.begin(), ix.end(), jx.begin()); - } - #ifndef VS_SKIP_CONCATENATE_FIXED // This test mimics the relevant parts of `TEST(xbuilder, concatenate)` TEST(xbuilder, concatenate_fixed) diff --git a/test/test_xstorage.cpp b/test/test_xstorage.cpp index 54ac95e2a..fd9c1d5b7 100644 --- a/test/test_xstorage.cpp +++ b/test/test_xstorage.cpp @@ -207,7 +207,7 @@ namespace xt svector_type e(src); EXPECT_EQ(size_t(10), d.size()); EXPECT_EQ(size_t(1), d[2]); - + svector_type f = { 1, 2, 3, 4 }; EXPECT_EQ(size_t(4), f.size()); EXPECT_EQ(size_t(3), f[2]); @@ -221,7 +221,7 @@ namespace xt TEST(svector, assign) { svector_type a = { 1, 2, 3, 4 }; - + svector_type src1(10, 2); a = src1; EXPECT_EQ(size_t(10), a.size()); @@ -332,7 +332,7 @@ namespace xt EXPECT_EQ(i, a[i]); } } - + TEST(fixed_shape, fixed_shape) { fixed_shape<3, 4, 5> af; @@ -344,4 +344,14 @@ namespace xt EXPECT_EQ(a.front(), size_t(3)); EXPECT_EQ(a.size(), size_t(3)); } + + TEST(fixed_shape, operator_eq) + { + fixed_shape<2, 3> a, b; + fixed_shape<5> c; + fixed_shape<2, 4> d; + EXPECT_TRUE(a == b); + EXPECT_FALSE(a == c); + EXPECT_FALSE(a == d); + } }