diff --git a/cppcore/include/hamiltonian/HamiltonianModifiers.hpp b/cppcore/include/hamiltonian/HamiltonianModifiers.hpp index 5e0d8b0b..27504233 100644 --- a/cppcore/include/hamiltonian/HamiltonianModifiers.hpp +++ b/cppcore/include/hamiltonian/HamiltonianModifiers.hpp @@ -218,12 +218,12 @@ struct HoppingBuffer { hoppings(size * unit_hopping.size()), pos1(size), pos2(size) {} - /// Replicate each value from the `unit_hopping` matrix `size` times - void reset_hoppings() { + /// Replicate each value from the `unit_hopping` matrix `num` times + void reset_hoppings(idx_t num) { auto start = idx_t{0}; for (auto const& value : unit_hopping) { - hoppings.segment(start, size).setConstant(value); - start += size; + hoppings.segment(start, num).setConstant(value); + start += num; } } @@ -269,8 +269,6 @@ void HamiltonianModifiers::apply_to_hoppings_impl(System const& system, auto buffer = HoppingBuffer(hopping_family.energy, block.size()); for (auto const coo_slice : sliced(block.coordinates(), buffer.size)) { - buffer.reset_hoppings(); - auto size = idx_t{0}; for (auto const& coo : coo_slice) { buffer.pos1[size] = system.positions[coo.row]; @@ -278,6 +276,7 @@ void HamiltonianModifiers::apply_to_hoppings_impl(System const& system, ++size; } + buffer.reset_hoppings(size); for (auto const& modifier : hopping) { modifier.apply(buffer.hoppings_ref(size), buffer.pos1.head(size), buffer.pos2.head(size), hopping_name); diff --git a/cppcore/tests/test_detail.cpp b/cppcore/tests/test_detail.cpp index 9ba44b6c..a6589768 100644 --- a/cppcore/tests/test_detail.cpp +++ b/cppcore/tests/test_detail.cpp @@ -2,6 +2,7 @@ #include #include "Model.hpp" +#include "detail/algorithm.hpp" using namespace cpb; namespace static_test_typelist { @@ -34,3 +35,24 @@ TEST_CASE("Symmetry masks") { })); } } + +TEST_CASE("sliced") { + auto const v = []{ + auto result = std::vector(10); + std::iota(result.begin(), result.end(), 0); + return result; + }(); + REQUIRE_THAT(v, Catch::Equals(std::vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})); + + auto vectors = std::vector>(); + for (auto const& slice : sliced(v, 3)) { + auto tmp = std::vector(); + std::copy(slice.begin(), slice.end(), std::back_inserter(tmp)); + vectors.push_back(tmp); + } + REQUIRE(vectors.size() == 4); + REQUIRE_THAT(vectors[0], Catch::Equals(std::vector{0, 1, 2})); + REQUIRE_THAT(vectors[1], Catch::Equals(std::vector{3, 4, 5})); + REQUIRE_THAT(vectors[2], Catch::Equals(std::vector{6, 7, 8})); + REQUIRE_THAT(vectors[3], Catch::Equals(std::vector{9})); +}