Skip to content

Commit 8a27177

Browse files
committed
add celerity blockchain for task divergence checking
1 parent 0822c32 commit 8a27177

16 files changed

+1042
-15
lines changed

CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ set(SOURCES
187187
src/command_graph.cc
188188
src/config.cc
189189
src/device_queue.cc
190+
src/divergence_block_chain.cc
190191
src/executor.cc
191192
src/distributed_graph_generator.cc
192193
src/graph_serializer.cc

include/communicator.h

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#pragma once
2+
3+
#include "types.h"
4+
#include <memory>
5+
#include <mpi.h>
6+
7+
namespace celerity::detail {
8+
class communicator {
9+
public:
10+
communicator() = default;
11+
communicator(const communicator&) = delete;
12+
communicator(communicator&&) noexcept = default;
13+
14+
communicator& operator=(const communicator&) = delete;
15+
communicator& operator=(communicator&&) noexcept = default;
16+
17+
virtual ~communicator() = default;
18+
19+
template <typename S>
20+
void allgather_inplace(S* sendrecvbuf, const int sendrecvcount) {
21+
allgather_inplace_impl(reinterpret_cast<std::byte*>(sendrecvbuf), sendrecvcount * sizeof(S));
22+
}
23+
24+
template <typename S, typename R>
25+
void allgather(const S* sendbuf, const int sendcount, R* recvbuf, const int recvcount) {
26+
allgather_impl(reinterpret_cast<const std::byte*>(sendbuf), sendcount * sizeof(S), reinterpret_cast<std::byte*>(recvbuf), recvcount * sizeof(R));
27+
}
28+
29+
void barrier() { barrier_impl(); }
30+
31+
size_t get_num_nodes() { return num_nodes_impl(); }
32+
33+
node_id get_local_nid() { return local_nid_impl(); }
34+
35+
protected:
36+
virtual void allgather_inplace_impl(std::byte* sendrecvbuf, const int sendrecvcount) = 0;
37+
virtual void allgather_impl(const std::byte* sendbuf, const int sendcount, std::byte* recvbuf, const int recvcount) = 0;
38+
virtual void barrier_impl() = 0;
39+
virtual size_t num_nodes_impl() = 0;
40+
virtual node_id local_nid_impl() = 0;
41+
};
42+
43+
class mpi_communicator : public communicator {
44+
public:
45+
mpi_communicator(MPI_Comm comm) : m_comm(comm) {}
46+
47+
private:
48+
void allgather_inplace_impl(std::byte* sendrecvbuf, const int sendrecvcount) override {
49+
MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, sendrecvbuf, sendrecvcount, MPI_BYTE, m_comm);
50+
};
51+
52+
void allgather_impl(const std::byte* sendbuf, const int sendcount, std::byte* recvbuf, const int recvcount) override {
53+
MPI_Allgather(sendbuf, sendcount, MPI_BYTE, recvbuf, recvcount, MPI_BYTE, m_comm);
54+
};
55+
56+
void barrier_impl() override { MPI_Barrier(m_comm); }
57+
58+
size_t num_nodes_impl() override {
59+
int size = -1;
60+
MPI_Comm_size(m_comm, &size);
61+
return static_cast<size_t>(size);
62+
}
63+
64+
node_id local_nid_impl() override {
65+
int rank = -1;
66+
MPI_Comm_rank(m_comm, &rank);
67+
return static_cast<node_id>(rank);
68+
}
69+
70+
MPI_Comm m_comm;
71+
};
72+
} // namespace celerity::detail

include/divergence_block_chain.h

+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
#pragma once
2+
3+
#include "communicator.h"
#include "recorders.h"

#include <atomic>
#include <cassert>
#include <chrono>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>
8+
9+
namespace celerity::detail {
10+
/**
11+
* @brief This class is a wrapper around a 1D vector that allows us to access it as a 2D array.
12+
*
13+
* It is used to send the task hashes to other nodes using MPI while keeping the code simple and readable.
14+
*/
15+
template <typename T>
16+
struct vector_2d {
17+
public:
18+
vector_2d(size_t width, size_t height) : m_data(width * height), m_width(width){};
19+
20+
const T& operator[](id<2> ij) const {
21+
assert(ij[0] * m_width + ij[1] < m_data.size());
22+
return m_data[ij[0] * m_width + ij[1]];
23+
}
24+
25+
T* data() { return m_data.data(); }
26+
27+
private:
28+
std::vector<T> m_data;
29+
const size_t m_width;
30+
};
31+
32+
/**
33+
* @brief This class gives a view into a const vector.
34+
*
35+
* It is used to give us the currently unhashed task records while keeping track of the offset and width.
36+
*/
37+
template <typename T>
38+
struct sliding_window {
39+
public:
40+
sliding_window(const std::vector<T>& value) : m_value(value) {}
41+
42+
const T& operator[](size_t i) const {
43+
assert(i >= 0 && i < m_width);
44+
return m_value[m_offset + i];
45+
}
46+
47+
size_t size() {
48+
m_width = m_value.size() - m_offset;
49+
return m_width;
50+
}
51+
52+
void slide(size_t i) {
53+
assert(i == 0 || (i >= 0 && i <= m_width));
54+
m_offset += i;
55+
m_width -= i;
56+
}
57+
58+
private:
59+
const std::vector<T>& m_value;
60+
size_t m_offset = 0;
61+
size_t m_width = 0;
62+
};
63+
64+
using task_hash = size_t;
65+
using task_hash_data = vector_2d<task_hash>;
66+
using divergence_map = std::unordered_map<task_hash, std::vector<node_id>>;
67+
68+
/**
69+
* @brief This class checks for divergences of tasks between nodes.
70+
*
71+
* It is responsible for collecting the task hashes from all nodes and checking for differences -> divergence.
72+
* When a divergence is found, the task record for the diverging task is printed and the program is terminated.
73+
* Additionally it also checks for deadlocks and prints a warning if one is detected.
74+
*/
75+
76+
class divergence_block_chain {
77+
friend struct divergence_block_chain_testspy;
78+
79+
public:
80+
divergence_block_chain(const std::vector<task_record>& task_recorder, std::unique_ptr<communicator> comm)
81+
: m_local_nid(comm->get_local_nid()), m_num_nodes(comm->get_num_nodes()), m_sizes(comm->get_num_nodes()), m_task_recorder_window(task_recorder),
82+
m_communicator(std::move(comm)) {}
83+
84+
divergence_block_chain(const divergence_block_chain&) = delete;
85+
divergence_block_chain(divergence_block_chain&&) = default;
86+
87+
~divergence_block_chain() = default;
88+
89+
divergence_block_chain& operator=(const divergence_block_chain&) = delete;
90+
divergence_block_chain& operator=(divergence_block_chain&&) = delete;
91+
92+
bool check_for_divergence();
93+
94+
private:
95+
node_id m_local_nid;
96+
size_t m_num_nodes;
97+
98+
std::vector<task_hash> m_hashes;
99+
std::vector<int> m_sizes;
100+
101+
sliding_window<task_record> m_task_recorder_window;
102+
103+
std::chrono::time_point<std::chrono::steady_clock> m_last_cleared = std::chrono::steady_clock::now();
104+
105+
std::unique_ptr<communicator> m_communicator;
106+
107+
void divergence_out(const divergence_map& check_map, const int task_num);
108+
109+
void add_new_hashes();
110+
void clear(const int min_progress);
111+
std::pair<int, int> collect_sizes();
112+
task_hash_data collect_hashes(const int max_size);
113+
divergence_map create_check_map(const task_hash_data& task_graphs, const int task_num) const;
114+
115+
void check_for_deadlock() const;
116+
117+
static void log_node_divergences(const divergence_map& check_map, const int task_num);
118+
119+
static void log_task_record(const divergence_map& check_map, const task_record& task, const task_hash hash);
120+
121+
void log_task_record_once(const divergence_map& check_map, const int task_num) const;
122+
};
123+
124+
class divergence_sceduler {
125+
friend struct runtime_testspy;
126+
127+
public:
128+
divergence_sceduler(const std::vector<task_record>& task_recorder, std::unique_ptr<communicator> comm, bool test_mode = false)
129+
: m_block_chain(task_recorder, std::move(comm)) {
130+
if(!test_mode) { start(); }
131+
}
132+
133+
divergence_sceduler(const divergence_sceduler&) = delete;
134+
divergence_sceduler(divergence_sceduler&&) = default;
135+
136+
divergence_sceduler& operator=(const divergence_sceduler&) = delete;
137+
divergence_sceduler& operator=(divergence_sceduler&&) = delete;
138+
139+
~divergence_sceduler() { stop(); }
140+
141+
private:
142+
void start() {
143+
stop();
144+
m_thread = std::thread(&divergence_sceduler::run, this);
145+
m_is_running = true;
146+
}
147+
148+
void stop() {
149+
m_is_running = false;
150+
if(m_thread.joinable()) { m_thread.join(); }
151+
}
152+
153+
void run() {
154+
bool is_finished = false;
155+
while(!is_finished || m_is_running) {
156+
is_finished = m_block_chain.check_for_divergence();
157+
158+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
159+
}
160+
}
161+
162+
std::thread m_thread;
163+
bool m_is_running = false;
164+
divergence_block_chain m_block_chain;
165+
};
166+
}; // namespace celerity::detail

include/grid.h

+22
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <gch/small_vector.hpp>
99

1010
#include "ranges.h"
11+
#include "utils.h"
1112
#include "workaround.h"
1213

1314
namespace celerity::detail {
@@ -257,6 +258,27 @@ class region {
257258

258259
} // namespace celerity::detail
259260

261+
template <int Dims>
262+
struct std::hash<celerity::detail::box<Dims>> {
263+
std::size_t operator()(const celerity::detail::box<Dims> r) {
264+
std::size_t seed = 0;
265+
celerity::detail::utils::hash_combine(seed, std::hash<celerity::id<Dims>>{}(r.get_min()), std::hash<celerity::id<Dims>>{}(r.get_max()));
266+
return seed;
267+
};
268+
};
269+
270+
template <int Dims>
271+
struct std::hash<celerity::detail::region<Dims>> {
272+
std::size_t operator()(const celerity::detail::region<Dims> r) {
273+
std::size_t seed = 0;
274+
for(auto box : r.get_boxes()) {
275+
celerity::detail::utils::hash_combine(seed, std::hash<celerity::detail::box<Dims>>{}(box));
276+
}
277+
return seed;
278+
};
279+
};
280+
281+
260282
namespace celerity::detail::grid_detail {
261283

262284
// forward-declaration for tests (explicitly instantiated)

include/ranges.h

+23
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "sycl_wrappers.h"
4+
#include "utils.h"
45
#include "workaround.h"
56

67
namespace celerity {
@@ -229,6 +230,17 @@ struct ones_t {
229230

230231
}; // namespace celerity::detail
231232

233+
template <typename Interface, int Dims>
234+
struct std::hash<celerity::detail::coordinate<Interface, Dims>> {
235+
std::size_t operator()(const celerity::detail::coordinate<Interface, Dims>& r) const noexcept {
236+
std::size_t seed = 0;
237+
for(int i = 0; i < Dims; ++i) {
238+
celerity::detail::utils::hash_combine(seed, std::hash<int>{}(r[i]));
239+
}
240+
return seed;
241+
};
242+
};
243+
232244
namespace celerity {
233245

234246
template <int Dims>
@@ -401,6 +413,17 @@ nd_range(range<3> global_range, range<3> local_range)->nd_range<3>;
401413

402414
} // namespace celerity
403415

416+
417+
template <int Dims>
418+
struct std::hash<celerity::range<Dims>> {
419+
std::size_t operator()(const celerity::range<Dims>& r) const noexcept { return std::hash<celerity::detail::coordinate<celerity::range<Dims>, Dims>>{}(r); };
420+
};
421+
422+
template <int Dims>
423+
struct std::hash<celerity::id<Dims>> {
424+
std::size_t operator()(const celerity::id<Dims>& r) const noexcept { return std::hash<celerity::detail::coordinate<celerity::id<Dims>, Dims>>{}(r); };
425+
};
426+
404427
namespace celerity {
405428
namespace detail {
406429

0 commit comments

Comments
 (0)