Skip to content

Commit 388c399

Browse files
committed
add celerity blockchain for task divergence checking
1 parent 0822c32 commit 388c399

19 files changed

+1058
-19
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Versioning](http://semver.org/spec/v2.0.0.html).
1313
- Introduce new experimental `for_each_item` utility to iterate over a celerity range (#199)
1414
- Add new environment variables `CELERITY_HORIZON_STEP` and `CELERITY_HORIZON_MAX_PARALLELISM` to control Horizon generation (#199)
1515
- Add new `experimental::constrain_split` API to limit how a kernel can be split (#?)
16+
- Add divergence check blockchain for automatic detection of diverging tasks in debug mode (#217)
1617

1718
## Changed
1819

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ set(SOURCES
187187
src/command_graph.cc
188188
src/config.cc
189189
src/device_queue.cc
190+
src/divergence_block_chain.cc
190191
src/executor.cc
191192
src/distributed_graph_generator.cc
192193
src/graph_serializer.cc

include/communicator.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#pragma once
2+
3+
#include "types.h"
4+
5+
namespace celerity::detail {
6+
7+
/*
8+
* @brief Defines an interface for a communicator that can be used to communicate between nodes.
9+
*
10+
* This interface is used to abstract away the communication between nodes. This allows us to use different communication backends during testing and
11+
* runtime. For example, we can use MPI for the runtime and a custom implementation for testing.
12+
*/
13+
class communicator {
14+
public:
15+
communicator() = default;
16+
communicator(const communicator&) = delete;
17+
communicator(communicator&&) noexcept = default;
18+
19+
communicator& operator=(const communicator&) = delete;
20+
communicator& operator=(communicator&&) noexcept = default;
21+
22+
virtual ~communicator() = default;
23+
24+
template <typename S>
25+
void allgather_inplace(S* sendrecvbuf, const int sendrecvcount) {
26+
allgather_inplace_impl(reinterpret_cast<std::byte*>(sendrecvbuf), sendrecvcount * sizeof(S));
27+
}
28+
29+
template <typename S, typename R>
30+
void allgather(const S* sendbuf, const int sendcount, R* recvbuf, const int recvcount) {
31+
allgather_impl(reinterpret_cast<const std::byte*>(sendbuf), sendcount * sizeof(S), reinterpret_cast<std::byte*>(recvbuf), recvcount * sizeof(R));
32+
}
33+
34+
void barrier() { barrier_impl(); }
35+
36+
size_t get_num_nodes() { return num_nodes_impl(); }
37+
38+
node_id get_local_nid() { return local_nid_impl(); }
39+
40+
protected:
41+
virtual void allgather_inplace_impl(std::byte* sendrecvbuf, const int sendrecvcount) = 0;
42+
virtual void allgather_impl(const std::byte* sendbuf, const int sendcount, std::byte* recvbuf, const int recvcount) = 0;
43+
virtual void barrier_impl() = 0;
44+
virtual size_t num_nodes_impl() = 0;
45+
virtual node_id local_nid_impl() = 0;
46+
};
47+
} // namespace celerity::detail

include/divergence_block_chain.h

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#pragma once
2+
3+
#include <atomic>
#include <chrono>
#include <memory>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "communicator.h"
#include "recorders.h"
9+
10+
namespace celerity::detail {
11+
struct runtime_testspy;
12+
}
13+
14+
namespace celerity::detail::divergence_checker_detail {
15+
/// Hash value type used for task hashes throughout the divergence checker.
using task_hash = size_t;
16+
/// Associates a task hash with a list of node ids.
/// NOTE(review): presumably the nodes that produced that hash - confirm against create_check_map() in the .cc file.
using divergence_map = std::unordered_map<task_hash, std::vector<node_id>>;
17+
18+
/**
19+
* @brief Stores the hashes of tasks for each node.
20+
*
21+
* The data is stored densely so it can easily be exchanged through MPI collective operations.
22+
*/
23+
struct per_node_task_hashes {
24+
public:
25+
per_node_task_hashes(const size_t max_hash_count, const size_t num_nodes) : m_data(max_hash_count * num_nodes), m_max_hash_count(max_hash_count){};
26+
const task_hash& operator()(const node_id nid, const size_t i) const { return m_data.at(nid * m_max_hash_count + i); }
27+
task_hash* data() { return m_data.data(); }
28+
29+
private:
30+
std::vector<task_hash> m_data;
31+
size_t m_max_hash_count;
32+
};
33+
34+
/**
35+
* @brief This class checks for divergences of tasks between nodes.
36+
*
37+
* It is responsible for collecting the task hashes from all nodes and checking for differences -> divergence.
38+
* When a divergence is found, the task record for the diverging task is printed and the program is terminated.
39+
* Additionally it will also print a warning when a deadlock is suspected.
40+
*/
41+
42+
// Records task hashes of the local node and compares them against the other nodes through a communicator.
// Only the constructor is defined inline; all other member functions are defined outside this header.
class divergence_block_chain {
43+
friend struct divergence_block_chain_testspy;
44+
45+
public:
46+
/// Reads the local node id and the node count from `comm`, sizes m_per_node_hash_counts to the node count, takes
/// ownership of `comm`, and registers a callback on `task_recorder` that forwards every task record to add_new_task().
/// NOTE(review): the callback captures `this`; presumably `task_recorder` must not invoke it after this object has
/// been destroyed - confirm the lifetimes at the call site.
divergence_block_chain(task_recorder& task_recorder, std::unique_ptr<communicator> comm)
47+
: m_local_nid(comm->get_local_nid()), m_num_nodes(comm->get_num_nodes()), m_per_node_hash_counts(comm->get_num_nodes()),
48+
m_communicator(std::move(comm)) {
49+
task_recorder.add_callback([this](const task_record& task) { add_new_task(task); });
50+
}
51+
52+
// Neither copyable nor movable: the recorder callback registered in the constructor holds a pointer to this object.
divergence_block_chain(const divergence_block_chain&) = delete;
53+
divergence_block_chain(divergence_block_chain&&) = delete;
54+
55+
~divergence_block_chain() = default;
56+
57+
divergence_block_chain& operator=(const divergence_block_chain&) = delete;
58+
divergence_block_chain& operator=(divergence_block_chain&&) = delete;
59+
60+
/// Runs one divergence check. divergence_checker::run() uses the returned bool as its "finished" flag.
bool check_for_divergence();
61+
62+
private:
63+
// Id of this node, as reported by the communicator at construction.
node_id m_local_nid;
64+
// Number of nodes, as reported by the communicator at construction.
size_t m_num_nodes;
65+
66+
std::vector<task_hash> m_local_hashes;
67+
std::vector<task_record> m_task_records;
68+
size_t m_tasks_checked = 0;
69+
size_t m_hashes_added = 0;
70+
71+
// One entry per node (sized to the node count in the constructor).
std::vector<int> m_per_node_hash_counts;
72+
// NOTE(review): presumably guards m_task_records, which add_new_task() fills from the recorder callback - confirm in the .cc file.
std::mutex m_task_records_mutex;
73+
74+
// Starts out as the construction time.
std::chrono::time_point<std::chrono::steady_clock> m_last_cleared = std::chrono::steady_clock::now();
75+
76+
// Must stay declared after every member that the constructor initializes from `comm`: members are initialized in
// declaration order and this one moves from `comm`, leaving it empty for anything initialized afterwards.
std::unique_ptr<communicator> m_communicator;
77+
78+
void divergence_out(const divergence_map& check_map, const int task_num);
79+
80+
void add_new_hashes();
81+
void clear(const int min_progress);
82+
std::pair<int, int> collect_hash_counts();
83+
per_node_task_hashes collect_hashes(const int min_hash_count) const;
84+
divergence_map create_check_map(const per_node_task_hashes& task_hashes, const int task_num) const;
85+
86+
void check_for_deadlock() const;
87+
88+
static void log_node_divergences(const divergence_map& check_map, const int task_num);
89+
static void log_task_record(const divergence_map& check_map, const task_record& task, const task_hash hash);
90+
void log_task_record_once(const divergence_map& check_map, const int task_num);
91+
92+
// Target of the task_recorder callback registered in the constructor.
void add_new_task(const task_record& task);
93+
// NOTE(review): "thread_save" looks like a typo for "thread_safe"; renaming would also touch the out-of-line definition and its callers.
task_record thread_save_get_task_record(const size_t task_num);
94+
};
95+
96+
class divergence_checker {
97+
friend struct ::celerity::detail::runtime_testspy;
98+
99+
public:
100+
divergence_checker(task_recorder& task_recorder, std::unique_ptr<communicator> comm, bool test_mode = false)
101+
: m_block_chain(task_recorder, std::move(comm)) {
102+
if(!test_mode) { start(); }
103+
}
104+
105+
divergence_checker(const divergence_checker&) = delete;
106+
divergence_checker(const divergence_checker&&) = delete;
107+
108+
divergence_checker& operator=(const divergence_checker&) = delete;
109+
divergence_checker& operator=(divergence_checker&&) = delete;
110+
111+
~divergence_checker() { stop(); }
112+
113+
private:
114+
void start() {
115+
m_thread = std::thread(&divergence_checker::run, this);
116+
m_is_running = true;
117+
}
118+
119+
void stop() {
120+
m_is_running = false;
121+
if(m_thread.joinable()) { m_thread.join(); }
122+
}
123+
124+
void run() {
125+
bool is_finished = false;
126+
while(!is_finished || m_is_running) {
127+
is_finished = m_block_chain.check_for_divergence();
128+
129+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
130+
}
131+
}
132+
133+
std::thread m_thread;
134+
bool m_is_running = false;
135+
divergence_block_chain m_block_chain;
136+
};
137+
}; // namespace celerity::detail::divergence_checker_detail

include/grid.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <gch/small_vector.hpp>
99

1010
#include "ranges.h"
11+
#include "utils.h"
1112
#include "workaround.h"
1213

1314
namespace celerity::detail {
@@ -257,6 +258,27 @@ class region {
257258

258259
} // namespace celerity::detail
259260

261+
template <int Dims>
262+
struct std::hash<celerity::detail::box<Dims>> {
263+
std::size_t operator()(const celerity::detail::box<Dims> r) {
264+
std::size_t seed = 0;
265+
celerity::detail::utils::hash_combine(seed, std::hash<celerity::id<Dims>>{}(r.get_min()), std::hash<celerity::id<Dims>>{}(r.get_max()));
266+
return seed;
267+
};
268+
};
269+
270+
template <int Dims>
271+
struct std::hash<celerity::detail::region<Dims>> {
272+
std::size_t operator()(const celerity::detail::region<Dims> r) {
273+
std::size_t seed = 0;
274+
for(auto box : r.get_boxes()) {
275+
celerity::detail::utils::hash_combine(seed, std::hash<celerity::detail::box<Dims>>{}(box));
276+
}
277+
return seed;
278+
};
279+
};
280+
281+
260282
namespace celerity::detail::grid_detail {
261283

262284
// forward-declaration for tests (explicitly instantiated)

include/mpi_communicator.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#pragma once
2+
#include <memory>
3+
4+
#include <mpi.h>
5+
6+
#include "communicator.h"
7+
8+
namespace celerity::detail {
9+
class mpi_communicator : public communicator {
10+
public:
11+
mpi_communicator(MPI_Comm comm) : m_comm(comm) {}
12+
13+
private:
14+
void allgather_inplace_impl(std::byte* sendrecvbuf, const int sendrecvcount) override {
15+
MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, sendrecvbuf, sendrecvcount, MPI_BYTE, m_comm);
16+
};
17+
18+
void allgather_impl(const std::byte* sendbuf, const int sendcount, std::byte* recvbuf, const int recvcount) override {
19+
MPI_Allgather(sendbuf, sendcount, MPI_BYTE, recvbuf, recvcount, MPI_BYTE, m_comm);
20+
};
21+
22+
void barrier_impl() override { MPI_Barrier(m_comm); }
23+
24+
size_t num_nodes_impl() override {
25+
int size = -1;
26+
MPI_Comm_size(m_comm, &size);
27+
return static_cast<size_t>(size);
28+
}
29+
30+
node_id local_nid_impl() override {
31+
int rank = -1;
32+
MPI_Comm_rank(m_comm, &rank);
33+
return static_cast<node_id>(rank);
34+
}
35+
36+
MPI_Comm m_comm;
37+
};
38+
} // namespace celerity::detail

include/ranges.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "sycl_wrappers.h"
4+
#include "utils.h"
45
#include "workaround.h"
56

67
namespace celerity {
@@ -229,6 +230,17 @@ struct ones_t {
229230

230231
}; // namespace celerity::detail
231232

233+
template <typename Interface, int Dims>
234+
struct std::hash<celerity::detail::coordinate<Interface, Dims>> {
235+
std::size_t operator()(const celerity::detail::coordinate<Interface, Dims>& r) const noexcept {
236+
std::size_t seed = 0;
237+
for(int i = 0; i < Dims; ++i) {
238+
celerity::detail::utils::hash_combine(seed, std::hash<int>{}(r[i]));
239+
}
240+
return seed;
241+
};
242+
};
243+
232244
namespace celerity {
233245

234246
template <int Dims>
@@ -401,6 +413,17 @@ nd_range(range<3> global_range, range<3> local_range)->nd_range<3>;
401413

402414
} // namespace celerity
403415

416+
417+
template <int Dims>
418+
struct std::hash<celerity::range<Dims>> {
419+
std::size_t operator()(const celerity::range<Dims>& r) const noexcept { return std::hash<celerity::detail::coordinate<celerity::range<Dims>, Dims>>{}(r); };
420+
};
421+
422+
template <int Dims>
423+
struct std::hash<celerity::id<Dims>> {
424+
std::size_t operator()(const celerity::id<Dims>& r) const noexcept { return std::hash<celerity::detail::coordinate<celerity::id<Dims>, Dims>>{}(r); };
425+
};
426+
404427
namespace celerity {
405428
namespace detail {
406429

0 commit comments

Comments
 (0)