Skip to content

Commit 3adceaa

Browse files
committed
Revision: add celerity blockchain for task divergence checking
1 parent fac5661 commit 3adceaa

17 files changed

+170
-148
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Versioning](http://semver.org/spec/v2.0.0.html).
1313
- Introduce new experimental `for_each_item` utility to iterate over a celerity range (#199)
1414
- Add new environment variables `CELERITY_HORIZON_STEP` and `CELERITY_HORIZON_MAX_PARALLELISM` to control Horizon generation (#199)
1515
- Add new `experimental::constrain_split` API to limit how a kernel can be split (#?)
16-
- Add divergence check blockchain for automatic detection of diverging tasks in debug mode (#217)
16+
- Add automatic detection of diverging execution in debug mode (#217)
1717
- `distr_queue::fence` and `buffer_snapshot` are now stable, subsuming the `experimental::` APIs of the same name (#225)
1818
- Celerity now warns at runtime when a task declares reads from uninitialized buffers or writes with overlapping ranges between nodes (#224)
1919
- Introduce new `experimental::hint` API for providing the runtime with additional information on how to execute a task (#227)

CMakeLists.txt

+7-1
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,16 @@ endif()
2323

2424
option(CELERITY_ACCESS_PATTERN_DIAGNOSTICS "Diagnose uninitialized reads and overlapping writes" ${DEFAULT_ENABLE_DEBUG_CHECKS})
2525
option(CELERITY_ACCESSOR_BOUNDARY_CHECK "Enable accessor boundary check" ${DEFAULT_ENABLE_DEBUG_CHECKS})
26+
option(CELERITY_DIVERGENCE_CHECK "Enable divergence check" ${DEFAULT_ENABLE_DEBUG_CHECKS})
2627

2728
if(CELERITY_ACCESSOR_BOUNDARY_CHECK AND NOT (CMAKE_BUILD_TYPE STREQUAL "Debug"))
2829
message(STATUS "Accessor boundary check enabled - this will impact kernel performance")
2930
endif()
3031

32+
if(CELERITY_DIVERGENCE_CHECK AND NOT (CMAKE_BUILD_TYPE STREQUAL "Debug"))
33+
message(STATUS "Divergence checker enabled - this will impact the overall performance")
34+
endif()
35+
3136
set(CELERITY_CMAKE_DIR "${PROJECT_SOURCE_DIR}/cmake")
3237
set(CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}" "${CELERITY_CMAKE_DIR}")
3338
find_package(MPI 2.0 REQUIRED)
@@ -186,7 +191,7 @@ set(SOURCES
186191
src/command_graph.cc
187192
src/config.cc
188193
src/device_queue.cc
189-
src/divergence_block_chain.cc
194+
src/divergence_checker.cc
190195
src/executor.cc
191196
src/distributed_graph_generator.cc
192197
src/graph_serializer.cc
@@ -289,6 +294,7 @@ target_compile_definitions(celerity_runtime PUBLIC
289294
CELERITY_FEATURE_UNNAMED_KERNELS=$<BOOL:${CELERITY_FEATURE_UNNAMED_KERNELS}>
290295
CELERITY_DETAIL_HAS_NAMED_THREADS=$<BOOL:${CELERITY_DETAIL_HAS_NAMED_THREADS}>
291296
CELERITY_ACCESSOR_BOUNDARY_CHECK=$<BOOL:${CELERITY_ACCESSOR_BOUNDARY_CHECK}>
297+
CELERITY_DIVERGENCE_CHECK=$<BOOL:${CELERITY_DIVERGENCE_CHECK}>
292298
CELERITY_ACCESS_PATTERN_DIAGNOSTICS=$<BOOL:${CELERITY_ACCESS_PATTERN_DIAGNOSTICS}>
293299
)
294300

docs/pitfalls.md

+4
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,7 @@ if(rand() > 1337) {
132132
celerity::buffer<float, 2> my_buffer(...);
133133
}
134134
```
135+
136+
> Diverging host execution can be detected at runtime by enabling the
137+
> `CELERITY_DIVERGENCE_CHECK` CMake option at the cost of some runtime
138+
> overhead (enabled by default in debug builds).

include/divergence_block_chain.h renamed to include/divergence_checker.h

+13-14
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77
#include "communicator.h"
88
#include "recorders.h"
99

10-
namespace celerity::detail {
11-
struct runtime_testspy;
12-
}
13-
1410
namespace celerity::detail::divergence_checker_detail {
1511
using task_hash = size_t;
1612
using divergence_map = std::unordered_map<task_hash, std::vector<node_id>>;
@@ -72,29 +68,32 @@ class divergence_block_chain {
7268
std::mutex m_task_records_mutex;
7369

7470
std::chrono::time_point<std::chrono::steady_clock> m_last_cleared = std::chrono::steady_clock::now();
71+
std::chrono::seconds m_time_of_last_warning = std::chrono::seconds(0);
7572

7673
std::unique_ptr<communicator> m_communicator;
7774

78-
void divergence_out(const divergence_map& check_map, const int task_num);
75+
void reprot_divergence(const divergence_map& check_map, const int task_num);
7976

8077
void add_new_hashes();
8178
void clear(const int min_progress);
8279
std::pair<int, int> collect_hash_counts();
8380
per_node_task_hashes collect_hashes(const int min_hash_count) const;
84-
divergence_map create_check_map(const per_node_task_hashes& task_hashes, const int task_num) const;
81+
divergence_map create_divergence_map(const per_node_task_hashes& task_hashes, const int task_num) const;
8582

86-
void check_for_deadlock() const;
83+
void check_for_deadlock();
8784

88-
static void log_node_divergences(const divergence_map& check_map, const int task_num);
85+
static void log_node_divergences(const divergence_map& check_map, const int task_id);
8986
static void log_task_record(const divergence_map& check_map, const task_record& task, const task_hash hash);
9087
void log_task_record_once(const divergence_map& check_map, const int task_num);
9188

9289
void add_new_task(const task_record& task);
9390
task_record thread_save_get_task_record(const size_t task_num);
9491
};
92+
}; // namespace celerity::detail::divergence_checker_detail
9593

94+
namespace celerity::detail {
9695
class divergence_checker {
97-
friend struct ::celerity::detail::runtime_testspy;
96+
friend struct runtime_testspy;
9897

9998
public:
10099
divergence_checker(task_recorder& task_recorder, std::unique_ptr<communicator> comm, bool test_mode = false)
@@ -111,6 +110,10 @@ class divergence_checker {
111110
~divergence_checker() { stop(); }
112111

113112
private:
113+
std::thread m_thread;
114+
bool m_is_running = false;
115+
divergence_checker_detail::divergence_block_chain m_block_chain;
116+
114117
void start() {
115118
m_thread = std::thread(&divergence_checker::run, this);
116119
m_is_running = true;
@@ -129,9 +132,5 @@ class divergence_checker {
129132
std::this_thread::sleep_for(std::chrono::milliseconds(100));
130133
}
131134
}
132-
133-
std::thread m_thread;
134-
bool m_is_running = false;
135-
divergence_block_chain m_block_chain;
136135
};
137-
}; // namespace celerity::detail::divergence_checker_detail
136+
};

include/grid.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ template <int Dims>
271271
struct std::hash<celerity::detail::region<Dims>> {
272272
std::size_t operator()(const celerity::detail::region<Dims> r) {
273273
std::size_t seed = 0;
274-
for(auto box : r.get_boxes()) {
274+
for(auto& box : r.get_boxes()) {
275275
celerity::detail::utils::hash_combine(seed, std::hash<celerity::detail::box<Dims>>{}(box));
276276
}
277277
return seed;

include/mpi_communicator.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#pragma once
2+
23
#include <memory>
34

45
#include <mpi.h>
@@ -11,6 +12,8 @@ class mpi_communicator : public communicator {
1112
mpi_communicator(MPI_Comm comm) : m_comm(comm) {}
1213

1314
private:
15+
MPI_Comm m_comm;
16+
1417
void allgather_inplace_impl(std::byte* sendrecvbuf, const int sendrecvcount) override {
1518
MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, sendrecvbuf, sendrecvcount, MPI_BYTE, m_comm);
1619
};
@@ -32,7 +35,5 @@ class mpi_communicator : public communicator {
3235
MPI_Comm_rank(m_comm, &rank);
3336
return static_cast<node_id>(rank);
3437
}
35-
36-
MPI_Comm m_comm;
3738
};
38-
} // namespace celerity::detail
39+
} // namespace celerity::detail

include/ranges.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ struct std::hash<celerity::detail::coordinate<Interface, Dims>> {
235235
std::size_t operator()(const celerity::detail::coordinate<Interface, Dims>& r) const noexcept {
236236
std::size_t seed = 0;
237237
for(int i = 0; i < Dims; ++i) {
238-
celerity::detail::utils::hash_combine(seed, std::hash<int>{}(r[i]));
238+
celerity::detail::utils::hash_combine(seed, std::hash<size_t>{}(r[i]));
239239
}
240240
return seed;
241241
};

include/recorders.h

+5-4
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,15 @@ class task_recorder {
6161
void record_task(const task& tsk);
6262

6363
void add_callback(task_callback callback);
64-
void invoke_callbacks(const task_record& tsk) const;
6564

6665
const task_records& get_tasks() const { return m_recorded_tasks; }
6766

6867
private:
6968
task_records m_recorded_tasks;
7069
std::vector<task_callback> m_callbacks{};
7170
const buffer_manager* m_buff_mngr;
71+
72+
void invoke_callbacks(const task_record& tsk) const;
7273
};
7374

7475
// Command recording
@@ -104,16 +105,16 @@ struct command_record {
104105

105106
class command_recorder {
106107
public:
107-
using command_record = std::vector<command_record>;
108+
using command_records = std::vector<command_record>;
108109

109110
command_recorder(const task_manager* task_mngr, const buffer_manager* buff_mngr = nullptr) : m_task_mngr(task_mngr), m_buff_mngr(buff_mngr) {}
110111

111112
void record_command(const abstract_command& com);
112113

113-
const command_record& get_commands() const { return m_recorded_commands; }
114+
const command_records& get_commands() const { return m_recorded_commands; }
114115

115116
private:
116-
command_record m_recorded_commands;
117+
command_records m_recorded_commands;
117118
const task_manager* m_task_mngr;
118119
const buffer_manager* m_buff_mngr;
119120
};

include/runtime.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include "command.h"
88
#include "config.h"
99
#include "device_queue.h"
10-
#include "divergence_block_chain.h"
10+
#include "divergence_checker.h"
1111
#include "frame.h"
1212
#include "host_queue.h"
1313
#include "recorders.h"
@@ -102,7 +102,7 @@ namespace detail {
102102
size_t m_num_nodes;
103103
node_id m_local_nid;
104104

105-
std::unique_ptr<divergence_checker_detail::divergence_checker> m_divergence_check;
105+
std::unique_ptr<divergence_checker> m_divergence_check;
106106

107107
// These management classes are only constructed on the master node.
108108
std::unique_ptr<command_graph> m_cdag;

src/config.cc

+5
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,12 @@ namespace detail {
201201
const auto has_dry_run_nodes = parsed_and_validated_envs.get(env_dry_run_nodes);
202202
if(has_dry_run_nodes) { m_dry_run_nodes = *has_dry_run_nodes; }
203203

204+
#if CELERITY_DIVERGENCE_CHECK
205+
// the divergence checker requires recording, so force it on regardless of the environment setting
206+
m_recording = true;
207+
#else
204208
m_recording = parsed_and_validated_envs.get_or(env_recording, false);
209+
#endif
205210
m_horizon_step = parsed_and_validated_envs.get(env_horizon_step);
206211
m_horizon_max_parallelism = parsed_and_validated_envs.get(env_horizon_max_para);
207212

src/divergence_block_chain.cc renamed to src/divergence_checker.cc

+61-61
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,46 @@
1-
#include "divergence_block_chain.h"
1+
#include "divergence_checker.h"
22

33
namespace celerity::detail::divergence_checker_detail {
4+
bool divergence_block_chain::check_for_divergence() {
5+
add_new_hashes();
6+
7+
const auto [min_hash_count, max_hash_count] = collect_hash_counts();
8+
9+
if(min_hash_count == 0) {
10+
if(max_hash_count != 0 && m_local_nid == 0) {
11+
check_for_deadlock();
12+
} else if(max_hash_count == 0) {
13+
return true;
14+
}
15+
return false;
16+
}
17+
18+
const per_node_task_hashes task_graphs = collect_hashes(min_hash_count);
19+
20+
for(int j = 0; j < min_hash_count; ++j) {
21+
const divergence_map check_map = create_divergence_map(task_graphs, j);
22+
23+
// If there is more than one hash for this task, we have a divergence!
24+
if(check_map.size() > 1) { reprot_divergence(check_map, j); }
25+
}
26+
27+
clear(min_hash_count);
28+
29+
return false;
30+
}
31+
32+
void divergence_block_chain::reprot_divergence(const divergence_map& check_map, const int task_num) {
33+
if(m_local_nid == 0) { log_node_divergences(check_map, task_num + static_cast<int>(m_tasks_checked) + 1); }
34+
35+
// sleep for local_nid * 100 ms such that we have a no lock synchronized output
36+
std::this_thread::sleep_for(std::chrono::milliseconds(m_local_nid * 100));
37+
38+
log_task_record_once(check_map, task_num);
39+
40+
m_communicator->barrier();
41+
42+
throw std::runtime_error("Divergence in task graph detected");
43+
}
444

545
void divergence_block_chain::add_new_hashes() {
646
std::lock_guard<std::mutex> lock(m_task_records_mutex);
@@ -38,36 +78,36 @@ per_node_task_hashes divergence_block_chain::collect_hashes(const int min_hash_c
3878
}
3979

4080

41-
divergence_map divergence_block_chain::create_check_map(const per_node_task_hashes& task_hashes, const int task_num) const {
81+
divergence_map divergence_block_chain::create_divergence_map(const per_node_task_hashes& task_hashes, const int task_num) const {
4282
divergence_map check_map;
43-
for(size_t i = 0; i < m_num_nodes; ++i) {
44-
check_map[task_hashes(i, task_num)].push_back(i);
83+
for(node_id nid = 0; nid < m_num_nodes; ++nid) {
84+
check_map[task_hashes(nid, task_num)].push_back(nid);
4585
}
4686
return check_map;
4787
}
4888

49-
void divergence_block_chain::check_for_deadlock() const {
89+
void divergence_block_chain::check_for_deadlock() {
5090
auto diff = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - m_last_cleared);
51-
static auto last = std::chrono::seconds(0);
5291

53-
if(diff >= std::chrono::seconds(10) && diff - last >= std::chrono::seconds(5)) {
54-
std::string warning = fmt::format("After {} seconds of waiting nodes", diff.count());
92+
if(diff >= std::chrono::seconds(10) && diff - m_time_of_last_warning >= std::chrono::seconds(5)) {
93+
std::string warning = fmt::format("After {} seconds of waiting, node(s)", diff.count());
5594

56-
for(size_t i = 0; i < m_num_nodes; ++i) {
57-
if(m_per_node_hash_counts[i] == 0) { warning += fmt::format(" {},", i); }
95+
std::vector<node_id> stuck_nodes;
96+
for(node_id nid = 0; nid < m_num_nodes; ++nid) {
97+
if(m_per_node_hash_counts[nid] == 0) stuck_nodes.push_back(nid);
5898
}
59-
60-
warning += " did not move to the next task. The runtime might be stuck.";
99+
warning += fmt::format(" {} ", fmt::join(stuck_nodes, ","));
100+
warning += "did not move to the next task. The runtime might be stuck.";
61101

62102
CELERITY_WARN("{}", warning);
63-
last = diff;
103+
m_time_of_last_warning = diff;
64104
}
65105
}
66106

67-
void divergence_block_chain::log_node_divergences(const divergence_map& check_map, const int task_num) {
68-
std::string error = fmt::format("Divergence detected in task graph at index {}:\n\n", task_num);
107+
void divergence_block_chain::log_node_divergences(const divergence_map& check_map, const int task_id) {
108+
std::string error = fmt::format("Divergence detected. Task Nr {} diverges on nodes:\n\n", task_id);
69109
for(auto& [hash, nodes] : check_map) {
70-
error += fmt::format("{:#x} on nodes ", hash);
110+
error += fmt::format("Following task-hash {:#x} resulted on {} ", hash, nodes.size() > 1 ? "nodes" : "node ");
71111
for(auto& node : nodes) {
72112
error += fmt::format("{} ", node);
73113
}
@@ -115,11 +155,6 @@ void divergence_block_chain::log_task_record(const divergence_map& check_map, co
115155
CELERITY_ERROR("{}", task_record_output);
116156
}
117157

118-
task_record divergence_block_chain::thread_save_get_task_record(const size_t task_num) {
119-
std::lock_guard<std::mutex> lock(m_task_records_mutex);
120-
return m_task_records[task_num];
121-
}
122-
123158
void divergence_block_chain::log_task_record_once(const divergence_map& check_map, const int task_num) {
124159
for(auto& [hash, nodes] : check_map) {
125160
if(nodes[0] == m_local_nid) {
@@ -129,49 +164,14 @@ void divergence_block_chain::log_task_record_once(const divergence_map& check_ma
129164
}
130165
}
131166

132-
bool divergence_block_chain::check_for_divergence() {
133-
add_new_hashes();
134-
135-
const auto [min_hash_count, max_hash_count] = collect_hash_counts();
136-
137-
if(min_hash_count == 0) {
138-
if(max_hash_count != 0 && m_local_nid == 0) {
139-
check_for_deadlock();
140-
} else if(max_hash_count == 0) {
141-
return true;
142-
}
143-
return false;
144-
}
145-
146-
const per_node_task_hashes task_graphs = collect_hashes(min_hash_count);
147-
148-
for(int j = 0; j < min_hash_count; ++j) {
149-
const divergence_map check_map = create_check_map(task_graphs, j);
150-
151-
if(check_map.size() > 1) { divergence_out(check_map, j); }
152-
}
153-
154-
clear(min_hash_count);
155-
156-
return false;
157-
}
158-
159-
void divergence_block_chain::divergence_out(const divergence_map& check_map, const int task_num) {
160-
if(m_local_nid == 0) { log_node_divergences(check_map, task_num); }
161-
162-
// sleep for local_nid * 100 ms such that we have a no lock synchronized output
163-
std::this_thread::sleep_for(std::chrono::milliseconds(m_local_nid * 100));
164-
165-
log_task_record_once(check_map, task_num);
166-
167-
m_communicator->barrier();
168-
169-
throw std::runtime_error("Divergence in task graph detected");
170-
}
171-
172167
void divergence_block_chain::add_new_task(const task_record& task) { //
173168
std::lock_guard<std::mutex> lock(m_task_records_mutex);
174169
// make copy of task record so that we can access it later
175170
m_task_records.emplace_back(task);
176171
}
172+
173+
task_record divergence_block_chain::thread_save_get_task_record(const size_t task_num) {
174+
std::lock_guard<std::mutex> lock(m_task_records_mutex);
175+
return m_task_records[task_num];
176+
}
177177
} // namespace celerity::detail::divergence_checker_detail

0 commit comments

Comments
 (0)