-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- add `rename_dimesion` to the collection of the SID transformations; - add the regression test that demonstrate the whole k-axis access to a field.
- Loading branch information
Showing
5 changed files
with
178 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/* | ||
* GridTools | ||
* | ||
* Copyright (c) 2014-2019, ETH Zurich | ||
* All rights reserved. | ||
* | ||
* Please, refer to the LICENSE file in the root directory. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <utility> | ||
|
||
#include "../common/hymap.hpp" | ||
#include "../meta.hpp" | ||
#include "concept.hpp" | ||
#include "delegate.hpp" | ||
|
||
namespace gridtools { | ||
namespace sid { | ||
namespace rename_dimension_impl_ { | ||
template <class OldKey, class NewKey, class Map> | ||
auto remap(Map map) { | ||
return hymap::convert_to<hymap::keys, meta::replace<get_keys<Map>, OldKey, NewKey>>(std::move(map)); | ||
} | ||
|
||
template <class OldKey, class NewKey, class Sid> | ||
struct renamed_sid : delegate<Sid> { | ||
template <class Map> | ||
using remapped_t = decltype(remap<OldKey, NewKey>(std::declval<Map>())); | ||
|
||
template <class T> | ||
renamed_sid(T &&obj) : delegate<Sid>(std::forward<T>(obj)) {} | ||
|
||
friend remapped_t<strides_type<Sid>> sid_get_strides(renamed_sid const &obj) { | ||
return remap<OldKey, NewKey>(sid_get_strides(obj.impl())); | ||
} | ||
friend remapped_t<lower_bounds_type<Sid>> sid_get_lower_bounds(renamed_sid const &obj) { | ||
return remap<OldKey, NewKey>(sid_get_lower_bounds(obj.impl())); | ||
} | ||
friend remapped_t<upper_bounds_type<Sid>> sid_get_upper_bounds(renamed_sid const &obj) { | ||
return remap<OldKey, NewKey>(sid_get_upper_bounds(obj.impl())); | ||
} | ||
}; | ||
|
||
template <class...> | ||
struct stride_kind_wrapper {}; | ||
|
||
template <class OldKey, class NewKey, class Sid> | ||
stride_kind_wrapper<OldKey, NewKey, strides_kind<Sid>> sid_get_strides_kind( | ||
renamed_sid<OldKey, NewKey, Sid> const &); | ||
|
||
template <class OldKey, class NewKey, class Sid> | ||
renamed_sid<OldKey, NewKey, Sid> rename_dimension(Sid &&sid) { | ||
return renamed_sid<OldKey, NewKey, Sid>{std::forward<Sid>(sid)}; | ||
} | ||
} // namespace rename_dimension_impl_ | ||
using rename_dimension_impl_::rename_dimension; | ||
} // namespace sid | ||
} // namespace gridtools |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* GridTools | ||
* | ||
* Copyright (c) 2014-2019, ETH Zurich | ||
* All rights reserved. | ||
* | ||
* Please, refer to the LICENSE file in the root directory. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
#include <type_traits> | ||
|
||
#include <gridtools/common/integral_constant.hpp> | ||
#include <gridtools/sid/rename_dimension.hpp> | ||
#include <gridtools/stencil/cartesian.hpp> | ||
#include <gridtools/stencil/positional.hpp> | ||
|
||
#include <stencil_select.hpp> | ||
#include <test_environment.hpp> | ||
|
||
namespace { | ||
using namespace gridtools; | ||
using namespace stencil; | ||
using namespace cartesian; | ||
|
||
struct functor { | ||
using out = inout_accessor<0>; | ||
using in = in_accessor<1, extent<>, 4>; | ||
using k_pos = in_accessor<2>; | ||
using param_list = make_param_list<out, in, k_pos>; | ||
|
||
template <class Eval> | ||
GT_FUNCTION static void apply(Eval &&eval) { | ||
auto k = eval(k_pos()); | ||
std::decay_t<decltype(eval(out()))> res = 0; | ||
for (int kk = 0; kk < k; ++kk) | ||
res += eval(in(0, 0, 0, kk)); | ||
eval(out()) = res; | ||
} | ||
}; | ||
|
||
GT_REGRESSION_TEST(whole_axis_access, test_environment<>, stencil_backend_t) { | ||
auto in = [](int i, int j, int k) { return i + j + k; }; | ||
auto out = TypeParam::make_storage(); | ||
run_single_stage(functor(), | ||
stencil_backend_t(), | ||
TypeParam::make_grid(), | ||
out, | ||
sid::rename_dimension<dim::k, integral_constant<int_t, 3>>(TypeParam::make_storage(in)), | ||
positional<dim::k>()); | ||
TypeParam::verify( | ||
[in](int i, int j, int k) { | ||
int res = 0; | ||
for (int kk = 0; kk < k; ++kk) | ||
res += in(i, j, kk); | ||
return res; | ||
}, | ||
out); | ||
} | ||
} // namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
/* | ||
* GridTools | ||
* | ||
* Copyright (c) 2014-2019, ETH Zurich | ||
* All rights reserved. | ||
* | ||
* Please, refer to the LICENSE file in the root directory. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
|
||
#include <gridtools/sid/rename_dimension.hpp> | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <gridtools/common/hymap.hpp> | ||
#include <gridtools/common/integral_constant.hpp> | ||
#include <gridtools/common/tuple_util.hpp> | ||
#include <gridtools/sid/simple_ptr_holder.hpp> | ||
#include <gridtools/sid/synthetic.hpp> | ||
|
||
namespace gridtools { | ||
namespace { | ||
using sid::property; | ||
using namespace literals; | ||
namespace tu = tuple_util; | ||
|
||
struct a {}; | ||
struct b {}; | ||
struct c {}; | ||
struct d {}; | ||
|
||
TEST(rename_dimensions, smoke) { | ||
double data[3][5][7]; | ||
|
||
auto src = sid::synthetic() | ||
.set<property::origin>(sid::make_simple_ptr_holder(&data[0][0][0])) | ||
.set<property::strides>(tu::make<hymap::keys<a, b, c>::values>(5_c * 7_c, 7_c, 1_c)) | ||
.set<property::upper_bounds>(tu::make<hymap::keys<a, b>::values>(3, 5)); | ||
|
||
auto testee = sid::rename_dimension<b, d>(src); | ||
using testee_t = decltype(testee); | ||
|
||
auto strides = sid::get_strides(testee); | ||
EXPECT_EQ(35, sid::get_stride<a>(strides)); | ||
EXPECT_EQ(0, sid::get_stride<b>(strides)); | ||
EXPECT_EQ(1, sid::get_stride<c>(strides)); | ||
EXPECT_EQ(7, sid::get_stride<d>(strides)); | ||
|
||
static_assert(meta::is_empty<get_keys<sid::lower_bounds_type<testee_t>>>(), ""); | ||
|
||
auto u_bound = sid::get_upper_bounds(testee); | ||
EXPECT_EQ(3, at_key<a>(u_bound)); | ||
EXPECT_EQ(5, at_key<d>(u_bound)); | ||
} | ||
} // namespace | ||
} // namespace gridtools |