Skip to content

Commit 6e6ba3b

Browse files
committedAug 18, 2023
lock-free async poggers
1 parent 4ad1e63 commit 6e6ba3b

File tree

1 file changed

+146
-100
lines changed

1 file changed

+146
-100
lines changed
 

‎include/dpp/coro/async.h

+146-100
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,82 @@
2828
#include <utility>
2929
#include <type_traits>
3030
#include <functional>
31+
#include <atomic>
32+
#include <cstddef>
3133

3234
namespace dpp {
3335

36+
namespace detail {
37+
38+
/**
39+
* @brief Empty struct used for overload resolution.
40+
*/
41+
struct empty_tag_t{};
42+
43+
/**
44+
* @brief Represents the step a dpp::async is at.
45+
*/
46+
enum class async_state_t {
47+
sent, /* Request was sent but not yet co_await-ed. coro_handle is nullptr, result_storage is not constructed */
48+
waiting, /* Request was co_await-ed and is not completed yet. coro_handle is valid, result_storage is not constructed */
49+
done, /* Request was completed. coro_handle may or may not be set, result_storage holds a constructed result */
50+
dangling /* The async was destroyed while the request was still in the sent state, i.e. it was never co_await-ed. */
51+
};
52+
53+
/**
54+
* @brief State shared between an async and its API callback.
55+
*
56+
* Defined outside of dpp::async because this seems to work better with Intellisense.
57+
*/
58+
template <typename R>
59+
struct async_callback_data {
60+
/**
61+
* @brief Number of references to this callback state. Starts at 1.
62+
*/
63+
std::atomic<int> ref_count{1};
64+
65+
/**
66+
* @brief State of the awaitable and the API callback, starts as async_state_t::sent
67+
*/
68+
std::atomic<detail::async_state_t> state = detail::async_state_t::sent;
69+
70+
/**
71+
* @brief The stored result of the API call, kept as a suitably aligned array of bytes so that the R object can be constructed directly in place, e.g. with its copy constructor (see construct_result)
72+
*/
73+
alignas(R) std::array<std::byte, sizeof(R)> result_storage;
74+
75+
/**
76+
* @brief Handle to the coroutine co_await-ing on this API call, nullptr until one is stored
77+
*
78+
* @see <a href="https://en.cppreference.com/w/cpp/coroutine/coroutine_handle">std::coroutine_handle</a>
79+
*/
80+
std_coroutine::coroutine_handle<> coro_handle = nullptr;
81+
82+
/**
83+
* @brief Convenience function to construct the result in the storage and initialize its lifetime
84+
*
85+
* @warning This is only a convenience function, ONLY CALL THIS IN THE CALLBACK, before setting state to done.
86+
*/
87+
template <typename... Ts>
88+
void construct_result(Ts&&... ts) {
89+
// Constructs an R in place inside result_storage, forwarding the arguments to its constructor. NOTE(review): std::construct_at is C++20 (<memory>) and std::array needs <array>; neither include is visible in this hunk -- confirm they are provided elsewhere
90+
std::construct_at<R>(reinterpret_cast<R *>(result_storage.data()), std::forward<Ts>(ts)...);
91+
}
92+
93+
/**
94+
* @brief Destructor.
95+
*
96+
* Also destroys the result if present, that is only when state is async_state_t::done.
97+
*/
98+
~async_callback_data() {
99+
if (state.load() == detail::async_state_t::done) {
100+
std::destroy_at<R>(reinterpret_cast<R *>(result_storage.data()));
101+
}
102+
}
103+
};
104+
105+
}
106+
34107
struct confirmation_callback_t;
35108

36109
/**
@@ -41,7 +114,6 @@ struct confirmation_callback_t;
41114
* @remark - This object's methods, other than constructors and operators, should not be called directly. It is designed to be used with coroutine keywords such as co_await.
42115
* @remark - The coroutine may be resumed in another thread, do not rely on thread_local variables.
43116
* @warning - This feature is EXPERIMENTAL. The API may change at any time and there may be bugs. Please report any to <a href="https://github.com/brainboxdotcc/DPP/issues">GitHub issues</a> or to the <a href="https://discord.gg/dpp">D++ Discord server</a>.
44-
* @warning - Using co_await on this object more than once is undefined behavior.
45117
* @tparam R The return type of the API call. Defaults to confirmation_callback_t
46118
*/
47119
template <typename R>
@@ -50,97 +122,35 @@ class async {
50122
* @brief Ref-counted callback, contains the callback logic and manages the lifetime of the callback data over multiple threads.
51123
*/
52124
struct shared_callback {
53-
struct empty_tag_t{};
54-
55-
/**
56-
* @brief State of the async and its callback.
57-
*/
58-
struct callback_state {
59-
enum state_t {
60-
waiting,
61-
done,
62-
dangling
63-
};
64-
65-
/**
66-
* @brief Mutex to ensure the API result isn't set at the same time the coroutine is awaited and its value is checked, or the async is destroyed
67-
*/
68-
std::mutex mutex{};
69-
70-
/**
71-
* @brief Number of references to this callback state.
72-
*/
73-
int ref_count;
74-
75-
/**
76-
* @brief State of the awaitable and the API callback
77-
*/
78-
state_t state = waiting;
79-
80-
/**
81-
* @brief The stored result of the API call
82-
*/
83-
std::optional<R> result = std::nullopt;
84-
85-
/**
86-
* @brief Handle to the coroutine co_await-ing on this API call
87-
*
88-
* @see <a href="https://en.cppreference.com/w/cpp/coroutine/coroutine_handle">std::coroutine_handle</a>
89-
*/
90-
detail::std_coroutine::coroutine_handle<> coro_handle = nullptr;
91-
};
92-
93-
callback_state *state;
125+
detail::async_callback_data<R> *state = new detail::async_callback_data<R>;
94126

95127
/**
96128
* @brief Callback function.
97129
*
130+
* Constructs the callback data, and if the coroutine was awaiting, resume it
98131
* @param cback The result of the API call.
99132
*/
100133
void operator()(const R &cback) const {
101-
std::unique_lock lock{get_mutex()};
102-
103-
if (state->state == callback_state::dangling) // Async object is gone - likely an exception killed it or it was never co_await-ed
104-
return;
105-
state->result = cback;
106-
state->state = callback_state::done;
107-
if (state->coro_handle) {
108-
auto handle = state->coro_handle;
109-
state->coro_handle = nullptr;
110-
lock.unlock();
111-
handle.resume();
134+
state->construct_result(cback);
135+
if (state->state.exchange(detail::async_state_t::done) == detail::async_state_t::waiting) {
136+
state->coro_handle.resume();
112137
}
113138
}
114139

115140
/**
116141
* @brief Main constructor, allocates a new detail::async_callback_data<R> object through the default member initializer of `state`.
117142
*/
118-
shared_callback() : state{new callback_state{.ref_count = 1}} {}
119-
120-
shared_callback(empty_tag_t) noexcept : state{nullptr} {}
143+
shared_callback() = default;
121144

122145
/**
123-
* @brief Destructor. Releases the held reference and destroys if no other references exist.
146+
* @brief Empty constructor, holds no state.
124147
*/
125-
~shared_callback() {
126-
if (!state) // Moved-from object
127-
return;
128-
129-
std::unique_lock lock{state->mutex};
130-
131-
if (state->ref_count) {
132-
--(state->ref_count);
133-
if (state->ref_count <= 0) {;
134-
lock.unlock();
135-
delete state;
136-
}
137-
}
138-
}
148+
explicit shared_callback(detail::empty_tag_t) noexcept : state{nullptr} {}
139149

140150
/**
141151
* @brief Copy constructor. Takes shared ownership of the callback state, increasing the reference count. `state` is explicitly initialized to nullptr so that its default member initializer does not allocate a fresh callback state that the assignment below would overwrite and leak.
142152
*/
143-
shared_callback(const shared_callback &other) {
153+
shared_callback(const shared_callback &other) noexcept : state{nullptr} {
144154
this->operator=(other);
145155
}
146156

@@ -151,12 +161,23 @@ class async {
151161
this->operator=(std::move(other));
152162
}
153163

164+
/**
165+
* @brief Destructor. Releases the held reference and destroys the callback state if this was the last reference.
166+
*/
167+
~shared_callback() {
168+
if (!state) // Moved-from or empty object
169+
return;
170+
171+
auto count = state->ref_count.fetch_sub(1); // fetch_sub returns the value held BEFORE the decrement
172+
if (count == 1) { // we held the last reference
173+
delete state;
174+
}
175+
}
176+
154177
/**
155178
* @brief Copy assignment. Takes shared ownership of the callback state, increasing the reference count.
156179
*/
157180
shared_callback &operator=(const shared_callback &other) noexcept {
158-
std::lock_guard lock{other.get_mutex()};
159-
160181
state = other.state;
161182
++state->ref_count;
162183
return *this;
@@ -166,36 +187,46 @@ class async {
166187
* @brief Move assignment. Transfers ownership from another object, leaving intact the reference count. The other object releases the callback state.
167188
*/
168189
shared_callback &operator=(shared_callback &&other) noexcept {
169-
std::lock_guard lock{other.get_mutex()};
170-
171190
state = std::exchange(other.state, nullptr);
172191
return *this;
173192
}
174193

175194
/**
176195
* @brief Function called by the async when it is destroyed when it was never co_awaited, signals to the callback to abort.
177196
*/
178-
void set_dangling() {
197+
void set_dangling() noexcept {
179198
if (!state) // moved-from object
180199
return;
181-
std::lock_guard lock{get_mutex()};
200+
/*
201+
If the state is sent but not awaited, set it to dangling. The exchange itself is sequentially consistent; only the failure case (the state was not "sent") uses a relaxed order, as nothing is done with the value observed there.
202+
"sent" is the only state we care about to set it to dangling, as if it's done it's not dangling, and if it's waiting... Something went seriously wrong and shouldn't be happening.
203+
*/
204+
auto expected = detail::async_state_t::sent;
205+
state->state.compare_exchange_strong(expected, detail::async_state_t::dangling, std::memory_order_seq_cst, std::memory_order_relaxed);
206+
}
182207

183-
if (state->state == callback_state::waiting)
184-
state->state = callback_state::dangling;
208+
bool done(std::memory_order order = std::memory_order_seq_cst) const noexcept {
209+
return (state->state.load(order) == detail::async_state_t::done);
185210
}
186211

187212
/**
188-
* @brief Convenience function to get the shared callback state's mutex.
213+
* @brief Convenience function to get the shared callback state's result.
214+
*
215+
* @warning It is UB to call this on a callback whose state is anything else but async_state_t::done.
189216
*/
190-
std::mutex &get_mutex() const {
191-
return (state->mutex);
217+
R &get_result() noexcept {
218+
assert(state && done());
219+
return (*reinterpret_cast<R *>(state->result_storage.data()));
192220
}
193221

194222
/**
195223
* @brief Convenience function to get the shared callback state's result.
224+
*
225+
* @warning It is UB to call this on a callback whose state is anything else but async_state_t::done.
196226
*/
197-
std::optional<R> &get_result() const {
198-
return (state->result);
227+
const R &get_result() const noexcept {
228+
assert(state && done());
229+
return (*reinterpret_cast<R *>(state->result_storage.data()));
199230
}
200231
};
201232

@@ -246,7 +277,7 @@ class async {
246277
/**
247278
* @brief Construct an empty async. Using `co_await` on an empty async is undefined behavior.
248279
*/
249-
async() noexcept : api_callback{typename shared_callback::empty_tag_t{}} {}
280+
async() noexcept : api_callback{detail::empty_tag_t{}} {}
250281

251282
/**
252283
* @brief Destructor. If any callback is pending it will be aborted.
@@ -294,9 +325,7 @@ class async {
294325
* @return bool Whether we already have the result of the API call or not
295326
*/
296327
bool await_ready() noexcept {
297-
std::lock_guard lock{api_callback.get_mutex()};
298-
299-
return api_callback.get_result().has_value();
328+
return api_callback.done();
300329
}
301330

302331
/**
@@ -307,25 +336,42 @@ class async {
307336
* @remark Do not call this manually, use the co_await keyword instead.
308337
* @param handle The handle to the coroutine co_await-ing and being suspended
309338
*/
310-
template <typename T>
311-
bool await_suspend(detail::std_coroutine::coroutine_handle<T> caller) {
312-
std::lock_guard lock{api_callback.get_mutex()};
313-
314-
if (api_callback.get_result().has_value())
315-
return false; // immediately resume the coroutine as we already have the result of the api call
339+
bool await_suspend(detail::std_coroutine::coroutine_handle<> caller) {
340+
auto sent = detail::async_state_t::sent;
316341
api_callback.state->coro_handle = caller;
317-
return true; // suspend the caller, the callback will resume it
342+
return api_callback.state->state.compare_exchange_strong(sent, detail::async_state_t::waiting); // true (suspend) if `sent` was replaced with `waiting` -- false (resume) if the value was not `sent` (`done` is the only other option)
343+
}
344+
345+
/**
346+
* @brief Function called by the standard library when the async is resumed. Its return value is what the whole co_await expression evaluates to
347+
*
348+
* @remark Do not call this manually, use the co_await keyword instead.
349+
* @return R& The result of the API call as an lvalue reference, as returned by the callback's get_result(); no copy is made.
350+
*/
351+
R& await_resume() & noexcept {
352+
return api_callback.get_result();
353+
}
354+
355+
356+
/**
357+
* @brief Function called by the standard library when the async is resumed. Its return value is what the whole co_await expression evaluates to
358+
*
359+
* @remark Do not call this manually, use the co_await keyword instead.
360+
* @return const R& The result of the API call as a const lvalue reference, as returned by the callback's get_result(); no copy is made.
361+
*/
362+
const R& await_resume() const& noexcept {
363+
return api_callback.get_result();
318364
}
365+
319366

320367
/**
321368
* @brief Function called by the standard library when the async is resumed. Its return value is what the whole co_await expression evaluates to
322369
*
323370
* @remark Do not call this manually, use the co_await keyword instead.
324-
* @return R The result of the API call.
371+
* @return R&& The result of the API call as an rvalue reference.
325372
*/
326-
R await_resume() {
327-
// no locking needed here as the callback has already executed
328-
return std::move(*api_callback.get_result());
373+
R&& await_resume() && noexcept {
374+
return std::move(api_callback.get_result());
329375
}
330376
};
331377

0 commit comments

Comments
 (0)