1
- #include " divergence_block_chain .h"
1
+ #include " divergence_checker .h"
2
2
3
3
namespace celerity ::detail::divergence_checker_detail {
4
+ bool divergence_block_chain::check_for_divergence () {
5
+ add_new_hashes ();
6
+
7
+ const auto [min_hash_count, max_hash_count] = collect_hash_counts ();
8
+
9
+ if (min_hash_count == 0 ) {
10
+ if (max_hash_count != 0 && m_local_nid == 0 ) {
11
+ check_for_deadlock ();
12
+ } else if (max_hash_count == 0 ) {
13
+ return true ;
14
+ }
15
+ return false ;
16
+ }
17
+
18
+ const per_node_task_hashes task_graphs = collect_hashes (min_hash_count);
19
+
20
+ for (int j = 0 ; j < min_hash_count; ++j) {
21
+ const divergence_map check_map = create_divergence_map (task_graphs, j);
22
+
23
+ // If there is more than one hash for this task, we have a divergence!
24
+ if (check_map.size () > 1 ) { reprot_divergence (check_map, j); }
25
+ }
26
+
27
+ clear (min_hash_count);
28
+
29
+ return false ;
30
+ }
31
+
32
+ void divergence_block_chain::reprot_divergence (const divergence_map& check_map, const int task_num) {
33
+ if (m_local_nid == 0 ) { log_node_divergences (check_map, task_num + static_cast <int >(m_tasks_checked) + 1 ); }
34
+
35
+ // sleep for local_nid * 100 ms such that we have a no lock synchronized output
36
+ std::this_thread::sleep_for (std::chrono::milliseconds (m_local_nid * 100 ));
37
+
38
+ log_task_record_once (check_map, task_num);
39
+
40
+ m_communicator->barrier ();
41
+
42
+ throw std::runtime_error (" Divergence in task graph detected" );
43
+ }
4
44
5
45
void divergence_block_chain::add_new_hashes () {
6
46
std::lock_guard<std::mutex> lock (m_task_records_mutex);
@@ -38,36 +78,36 @@ per_node_task_hashes divergence_block_chain::collect_hashes(const int min_hash_c
38
78
}
39
79
40
80
41
- divergence_map divergence_block_chain::create_check_map (const per_node_task_hashes& task_hashes, const int task_num) const {
81
+ divergence_map divergence_block_chain::create_divergence_map (const per_node_task_hashes& task_hashes, const int task_num) const {
42
82
divergence_map check_map;
43
- for (size_t i = 0 ; i < m_num_nodes; ++i ) {
44
- check_map[task_hashes (i , task_num)].push_back (i );
83
+ for (node_id nid = 0 ; nid < m_num_nodes; ++nid ) {
84
+ check_map[task_hashes (nid , task_num)].push_back (nid );
45
85
}
46
86
return check_map;
47
87
}
48
88
49
- void divergence_block_chain::check_for_deadlock () const {
89
+ void divergence_block_chain::check_for_deadlock () {
50
90
auto diff = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now () - m_last_cleared);
51
- static auto last = std::chrono::seconds (0 );
52
91
53
- if (diff >= std::chrono::seconds (10 ) && diff - last >= std::chrono::seconds (5 )) {
54
- std::string warning = fmt::format (" After {} seconds of waiting nodes " , diff.count ());
92
+ if (diff >= std::chrono::seconds (10 ) && diff - m_time_of_last_warning >= std::chrono::seconds (5 )) {
93
+ std::string warning = fmt::format (" After {} seconds of waiting, node(s) " , diff.count ());
55
94
56
- for (size_t i = 0 ; i < m_num_nodes; ++i) {
57
- if (m_per_node_hash_counts[i] == 0 ) { warning += fmt::format (" {}," , i); }
95
+ std::vector<node_id> stuck_nodes;
96
+ for (node_id nid = 0 ; nid < m_num_nodes; ++nid) {
97
+ if (m_per_node_hash_counts[nid] == 0 ) stuck_nodes.push_back (nid);
58
98
}
59
-
60
- warning += " did not move to the next task. The runtime might be stuck." ;
99
+ warning += fmt::format ( " {} " , fmt::join (stuck_nodes, " , " ));
100
+ warning += " did not move to the next task. The runtime might be stuck." ;
61
101
62
102
CELERITY_WARN (" {}" , warning);
63
- last = diff;
103
+ m_time_of_last_warning = diff;
64
104
}
65
105
}
66
106
67
- void divergence_block_chain::log_node_divergences (const divergence_map& check_map, const int task_num ) {
68
- std::string error = fmt::format (" Divergence detected in task graph at index {} :\n\n " , task_num );
107
+ void divergence_block_chain::log_node_divergences (const divergence_map& check_map, const int task_id ) {
108
+ std::string error = fmt::format (" Divergence detected. Task Nr {} diverges on nodes :\n\n " , task_id );
69
109
for (auto & [hash, nodes] : check_map) {
70
- error += fmt::format (" {:#x} on nodes " , hash);
110
+ error += fmt::format (" Following task-hash {:#x} resulted on {} " , hash, nodes. size () > 1 ? " nodes " : " node " );
71
111
for (auto & node : nodes) {
72
112
error += fmt::format (" {} " , node);
73
113
}
@@ -115,11 +155,6 @@ void divergence_block_chain::log_task_record(const divergence_map& check_map, co
115
155
CELERITY_ERROR (" {}" , task_record_output);
116
156
}
117
157
118
- task_record divergence_block_chain::thread_save_get_task_record (const size_t task_num) {
119
- std::lock_guard<std::mutex> lock (m_task_records_mutex);
120
- return m_task_records[task_num];
121
- }
122
-
123
158
void divergence_block_chain::log_task_record_once (const divergence_map& check_map, const int task_num) {
124
159
for (auto & [hash, nodes] : check_map) {
125
160
if (nodes[0 ] == m_local_nid) {
@@ -129,49 +164,14 @@ void divergence_block_chain::log_task_record_once(const divergence_map& check_ma
129
164
}
130
165
}
131
166
132
- bool divergence_block_chain::check_for_divergence () {
133
- add_new_hashes ();
134
-
135
- const auto [min_hash_count, max_hash_count] = collect_hash_counts ();
136
-
137
- if (min_hash_count == 0 ) {
138
- if (max_hash_count != 0 && m_local_nid == 0 ) {
139
- check_for_deadlock ();
140
- } else if (max_hash_count == 0 ) {
141
- return true ;
142
- }
143
- return false ;
144
- }
145
-
146
- const per_node_task_hashes task_graphs = collect_hashes (min_hash_count);
147
-
148
- for (int j = 0 ; j < min_hash_count; ++j) {
149
- const divergence_map check_map = create_check_map (task_graphs, j);
150
-
151
- if (check_map.size () > 1 ) { divergence_out (check_map, j); }
152
- }
153
-
154
- clear (min_hash_count);
155
-
156
- return false ;
157
- }
158
-
159
- void divergence_block_chain::divergence_out (const divergence_map& check_map, const int task_num) {
160
- if (m_local_nid == 0 ) { log_node_divergences (check_map, task_num); }
161
-
162
- // sleep for local_nid * 100 ms such that we have a no lock synchronized output
163
- std::this_thread::sleep_for (std::chrono::milliseconds (m_local_nid * 100 ));
164
-
165
- log_task_record_once (check_map, task_num);
166
-
167
- m_communicator->barrier ();
168
-
169
- throw std::runtime_error (" Divergence in task graph detected" );
170
- }
171
-
172
167
void divergence_block_chain::add_new_task (const task_record& task) { //
173
168
std::lock_guard<std::mutex> lock (m_task_records_mutex);
174
169
// make copy of task record so that we can access it later
175
170
m_task_records.emplace_back (task);
176
171
}
172
+
173
+ task_record divergence_block_chain::thread_save_get_task_record (const size_t task_num) {
174
+ std::lock_guard<std::mutex> lock (m_task_records_mutex);
175
+ return m_task_records[task_num];
176
+ }
177
177
} // namespace celerity::detail::divergence_checker_detail
0 commit comments