Skip to content

Commit 1ba10c4

Browse files
committed
DRY with enable_if
1 parent 145ce28 commit 1ba10c4

File tree

1 file changed

+136
-70
lines changed

1 file changed

+136
-70
lines changed

llvm/include/llvm/ExecutionEngine/Orc/TaskDispatch.h

Lines changed: 136 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
#include "llvm/Support/Compiler.h"
1818
#include "llvm/Support/Debug.h"
1919
#include "llvm/Support/ExtensibleRTTI.h"
20+
#include "llvm/Support/ErrorHandling.h"
2021
#include "llvm/Support/raw_ostream.h"
2122

2223
#include <atomic>
2324
#include <cassert>
2425
#include <string>
26+
#include <type_traits>
2527

2628
#if LLVM_ENABLE_THREADS
2729
#include <condition_variable>
@@ -35,8 +37,6 @@ namespace orc {
3537

3638
/// Forward declarations
3739
class future_base;
38-
template <typename T> class future;
39-
template <typename T> class promise;
4040
class TaskDispatcher;
4141

4242
/// Represents an abstract task for ORC to run.
@@ -170,13 +170,23 @@ class LLVM_ABI DynamicThreadPoolTaskDispatcher : public TaskDispatcher {
170170

171171
#endif // LLVM_ENABLE_THREADS
172172

173+
/// Status for future/promise state
174+
enum class FutureStatus : uint8_t {
175+
NotReady = 0,
176+
Ready = 1,
177+
NotValid = 2
178+
};
179+
173180
/// Type-erased base class for futures
174181
class future_base {
175182
public:
176-
virtual ~future_base() = default;
177-
178183
bool is_ready() const {
179-
return state_->status_.load(std::memory_order_acquire) != 0;
184+
return state_->status_.load(std::memory_order_acquire) != FutureStatus::NotReady;
185+
}
186+
187+
/// Check if the future is in a valid state (not moved-from and not consumed)
188+
bool valid() const {
189+
return state_ && state_->status_.load(std::memory_order_acquire) != FutureStatus::NotValid;
180190
}
181191

182192
/// Wait for the future to be ready, helping with task dispatch
@@ -189,49 +199,84 @@ class future_base {
189199

190200
protected:
191201
struct state_base {
192-
std::atomic<uint8_t> status_{0};
202+
std::atomic<FutureStatus> status_{FutureStatus::NotReady};
193203
};
194204

195-
future_base(std::shared_ptr<state_base> state) : state_(std::move(state)) {}
205+
future_base(state_base* state) : state_(state) {}
196206
future_base() = default;
207+
~future_base() {
208+
if (valid())
209+
report_fatal_error("get() must be called before future destruction");
210+
delete state_;
211+
}
212+
213+
// Move constructor and assignment
214+
future_base(future_base&& other) noexcept : state_(other.state_) {
215+
other.state_ = nullptr;
216+
}
217+
future_base& operator=(future_base&& other) noexcept {
218+
if (this != &other) {
219+
delete state_;
220+
state_ = other.state_;
221+
other.state_ = nullptr;
222+
}
223+
return *this;
224+
}
197225

198-
std::shared_ptr<state_base> state_;
226+
state_base* state_;
199227
};
200228

201229
/// ORC-aware future class that can help with task dispatch while waiting
230+
231+
template <typename T> class future;
232+
template <typename T> class promise;
202233
template <typename T>
203234
class future : public future_base {
204235
public:
205236
struct state : public future_base::state_base {
206-
T value_;
237+
template <typename U>
238+
struct value_storage {
239+
U value_;
240+
};
241+
242+
template <>
243+
struct value_storage<void> {
244+
// No value_ member for void
245+
};
246+
247+
value_storage<T> storage;
207248
};
208249

209-
future() = default;
250+
future() = delete;
210251
future(const future&) = delete;
211252
future& operator=(const future&) = delete;
212253
future(future&&) = default;
213254
future& operator=(future&&) = default;
214255

215-
216256
/// Get the value, helping with task dispatch while waiting.
217257
/// This will destroy the underlying value, so this must only be called once.
218258
T get(TaskDispatcher& D) {
259+
if (!valid())
260+
report_fatal_error("get() must only be called once");
219261
wait(D);
220-
// optionally: state_->ready_.swap(0, std::memory_order_acquire);
221-
return std::move(static_cast<typename future<T>::state*>(state_.get())->value_);
222-
}
223-
224-
/// Cast a future to a different type using static_pointer_cast
225-
template <typename U>
226-
static future<U> static_pointer_cast(future<T>&& f) {
227-
std::shared_ptr<typename future<U>::state> casted_state = std::static_pointer_cast<typename future<U>::state>(std::move(f.state_));
228-
return future<U>(casted_state);
262+
auto old_status = state_->status_.exchange(FutureStatus::NotValid, std::memory_order_release);
263+
if (old_status != FutureStatus::Ready)
264+
report_fatal_error("get() must only be called once");
265+
return take_value();
229266
}
230267

231268
private:
232269
friend class promise<T>;
233270

234-
explicit future(std::shared_ptr<state> state) : future_base(state) {}
271+
template <typename U = T>
272+
typename std::enable_if<!std::is_void<U>::value, U>::type take_value() {
273+
return std::move(static_cast<typename future<T>::state*>(state_)->storage.value_);
274+
}
275+
276+
template <typename U = T>
277+
typename std::enable_if<std::is_void<U>::value, U>::type take_value() {}
278+
279+
explicit future(state* state) : future_base(state) {}
235280
};
236281

237282
/// ORC-aware promise class that works with ORC future
@@ -240,74 +285,95 @@ class promise {
240285
friend class future<T>;
241286

242287
public:
243-
promise() : state_(std::make_shared<typename future<T>::state>()) {}
288+
promise() : state_(new typename future<T>::state()), future_created_(false) {}
289+
290+
~promise() {
291+
// Delete state only if get_future() was never called
292+
if (!future_created_) {
293+
delete state_;
294+
}
295+
}
296+
244297
promise(const promise&) = delete;
245298
promise& operator=(const promise&) = delete;
246-
promise(promise&&) = default;
247-
promise& operator=(promise&&) = default;
299+
300+
promise(promise&& other) noexcept
301+
: state_(other.state_), future_created_(other.future_created_) {
302+
other.state_ = nullptr;
303+
other.future_created_ = false;
304+
}
305+
306+
promise& operator=(promise&& other) noexcept {
307+
if (this != &other) {
308+
if (!future_created_) {
309+
delete state_;
310+
}
311+
state_ = other.state_;
312+
future_created_ = other.future_created_;
313+
other.state_ = nullptr;
314+
other.future_created_ = false;
315+
}
316+
return *this;
317+
}
248318

249-
/// Get the associated future
319+
/// Get the associated future (must only be called once)
250320
future<T> get_future() {
321+
assert(!future_created_ && "get_future() can only be called once");
322+
future_created_ = true;
251323
return future<T>(state_);
252324
}
253325

254-
/// Set the value
255-
void set_value(const T& value) {
256-
state_->value_ = value;
257-
state_->status_.store(1, std::memory_order_release);
326+
/// Set the value (must only be called once)
327+
// In C++20, this std::conditional weirdness can probably be replaced just
328+
// with requires. It ensures that we don't try to define a method for `void&`,
329+
// but that if the user calls set_value(v) for any value v that they get a
330+
// member function error, instead of no member named 'value_'.
331+
template <typename U = T>
332+
void set_value(const typename std::conditional<std::is_void<T>::value, std::nullopt_t, T>::type& value) {
333+
assert(state_ && "Invalid promise state");
334+
state_->storage.value_ = value;
335+
state_->status_.store(FutureStatus::Ready, std::memory_order_release);
258336
}
259337

260-
void set_value(T&& value) {
261-
state_->value_ = std::move(value);
262-
state_->status_.store(1, std::memory_order_release);
338+
template <typename U = T>
339+
void set_value(typename std::conditional<std::is_void<T>::value, std::nullopt_t, T>::type&& value) {
340+
assert(state_ && "Invalid promise state");
341+
state_->storage.value_ = std::move(value);
342+
state_->status_.store(FutureStatus::Ready, std::memory_order_release);
263343
}
264344

265-
private:
266-
std::shared_ptr<typename future<T>::state> state_;
267-
};
268-
269-
/// Specialization of future<void>
270-
template <> class future<void> : public future_base {
271-
public:
272-
using state = future_base::state_base;
273-
274-
future() = default;
275-
future(const future &) = delete;
276-
future &operator=(const future &) = delete;
277-
future(future &&) = default;
278-
future &operator=(future &&) = default;
279-
280-
/// Get the value (void), helping with task dispatch while waiting.
281-
void get(TaskDispatcher &D) { wait(D); }
345+
template <typename U = T>
346+
typename std::enable_if<std::is_void<U>::value, void>::type set_value(const std::nullopt_t& value) = delete;
282347

283-
private:
284-
friend class promise<void>;
285-
286-
explicit future(std::shared_ptr<state> state) : future_base(state) {}
287-
};
288-
289-
/// Specialization of promise<void>
290-
template <> class promise<void> {
291-
friend class future<void>;
292-
293-
public:
294-
promise() : state_(std::make_shared<future<void>::state>()) {}
295-
promise(const promise &) = delete;
296-
promise &operator=(const promise &) = delete;
297-
promise(promise &&) = default;
298-
promise &operator=(promise &&) = default;
348+
template <typename U = T>
349+
typename std::enable_if<std::is_void<U>::value, void>::type set_value(std::nullopt_t&& value) = delete;
299350

300-
/// Get the associated future
301-
future<void> get_future() { return future<void>(state_); }
351+
template <typename U = T>
352+
typename std::enable_if<std::is_void<U>::value, void>::type set_value() {
353+
assert(state_ && "Invalid promise state");
354+
state_->status_.store(FutureStatus::Ready, std::memory_order_release);
355+
}
302356

303-
/// Set the value (void)
304-
void set_value() { state_->status_.store(1, std::memory_order_release); }
357+
/// Swap with another promise
358+
void swap(promise& other) noexcept {
359+
using std::swap;
360+
swap(state_, other.state_);
361+
swap(future_created_, other.future_created_);
362+
}
305363

306364
private:
307-
std::shared_ptr<future<void>::state> state_;
365+
typename future<T>::state* state_;
366+
bool future_created_;
308367
};
309368

310369
} // End namespace orc
311370
} // End namespace llvm
312371

372+
namespace std {
373+
template <typename T>
374+
void swap(llvm::orc::promise<T>& lhs, llvm::orc::promise<T>& rhs) noexcept {
375+
lhs.swap(rhs);
376+
}
377+
} // End namespace std
378+
313379
#endif // LLVM_EXECUTIONENGINE_ORC_TASKDISPATCH_H

0 commit comments

Comments
 (0)