1717#include " llvm/Support/Compiler.h"
1818#include " llvm/Support/Debug.h"
1919#include " llvm/Support/ExtensibleRTTI.h"
20+ #include " llvm/Support/ErrorHandling.h"
2021#include " llvm/Support/raw_ostream.h"
2122
2223#include < atomic>
2324#include < cassert>
2425#include < string>
26+ #include < type_traits>
2527
2628#if LLVM_ENABLE_THREADS
2729#include < condition_variable>
@@ -35,8 +37,6 @@ namespace orc {
3537
3638// / Forward declarations
3739class future_base ;
38- template <typename T> class future ;
39- template <typename T> class promise ;
4040class TaskDispatcher ;
4141
4242// / Represents an abstract task for ORC to run.
@@ -170,13 +170,23 @@ class LLVM_ABI DynamicThreadPoolTaskDispatcher : public TaskDispatcher {
170170
171171#endif // LLVM_ENABLE_THREADS
172172
173+ // / Status for future/promise state
174+ enum class FutureStatus : uint8_t {
175+ NotReady = 0 ,
176+ Ready = 1 ,
177+ NotValid = 2
178+ };
179+
173180// / Type-erased base class for futures
174181class future_base {
175182public:
176- virtual ~future_base () = default ;
177-
178183 bool is_ready () const {
179- return state_->status_ .load (std::memory_order_acquire) != 0 ;
184+ return state_->status_ .load (std::memory_order_acquire) != FutureStatus::NotReady;
185+ }
186+
187+ // / Check if the future is in a valid state (not moved-from and not consumed)
188+ bool valid () const {
189+ return state_ && state_->status_ .load (std::memory_order_acquire) != FutureStatus::NotValid;
180190 }
181191
182192 // / Wait for the future to be ready, helping with task dispatch
@@ -189,49 +199,84 @@ class future_base {
189199
190200protected:
191201 struct state_base {
192- std::atomic<uint8_t > status_{0 };
202+ std::atomic<FutureStatus > status_{FutureStatus::NotReady };
193203 };
194204
195- future_base (std::shared_ptr< state_base> state) : state_(std::move( state) ) {}
205+ future_base (state_base* state) : state_(state) {}
196206 future_base () = default ;
207+ ~future_base () {
208+ if (valid ())
209+ report_fatal_error (" get() must be called before future destruction" );
210+ delete state_;
211+ }
212+
213+ // Move constructor and assignment
214+ future_base (future_base&& other) noexcept : state_(other.state_) {
215+ other.state_ = nullptr ;
216+ }
217+ future_base& operator =(future_base&& other) noexcept {
218+ if (this != &other) {
219+ delete state_;
220+ state_ = other.state_ ;
221+ other.state_ = nullptr ;
222+ }
223+ return *this ;
224+ }
197225
198- std::shared_ptr< state_base> state_;
226+ state_base* state_;
199227};
200228
201229// / ORC-aware future class that can help with task dispatch while waiting
230+
231+ template <typename T> class future ;
232+ template <typename T> class promise ;
202233template <typename T>
203234class future : public future_base {
204235public:
205236 struct state : public future_base ::state_base {
206- T value_;
237+ template <typename U>
238+ struct value_storage {
239+ U value_;
240+ };
241+
242+ template <>
243+ struct value_storage <void > {
244+ // No value_ member for void
245+ };
246+
247+ value_storage<T> storage;
207248 };
208249
209- future () = default ;
250+ future () = delete ;
210251 future (const future&) = delete ;
211252 future& operator =(const future&) = delete ;
212253 future (future&&) = default ;
213254 future& operator =(future&&) = default ;
214255
215-
216256 // / Get the value, helping with task dispatch while waiting.
217257 // / This will destroy the underlying value, so this must only be called once.
218258 T get (TaskDispatcher& D) {
259+ if (!valid ())
260+ report_fatal_error (" get() must only be called once" );
219261 wait (D);
220- // optionally: state_->ready_.swap(0, std::memory_order_acquire);
221- return std::move (static_cast <typename future<T>::state*>(state_.get ())->value_ );
222- }
223-
224- // / Cast a future to a different type using static_pointer_cast
225- template <typename U>
226- static future<U> static_pointer_cast (future<T>&& f) {
227- std::shared_ptr<typename future<U>::state> casted_state = std::static_pointer_cast<typename future<U>::state>(std::move (f.state_ ));
228- return future<U>(casted_state);
262+ auto old_status = state_->status_ .exchange (FutureStatus::NotValid, std::memory_order_release);
263+ if (old_status != FutureStatus::Ready)
264+ report_fatal_error (" get() must only be called once" );
265+ return take_value ();
229266 }
230267
231268private:
232269 friend class promise <T>;
233270
234- explicit future (std::shared_ptr<state> state) : future_base(state) {}
271+ template <typename U = T>
272+ typename std::enable_if<!std::is_void<U>::value, U>::type take_value () {
273+ return std::move (static_cast <typename future<T>::state*>(state_)->storage .value_ );
274+ }
275+
276+ template <typename U = T>
277+ typename std::enable_if<std::is_void<U>::value, U>::type take_value () {}
278+
279+ explicit future (state* state) : future_base(state) {}
235280};
236281
237282// / ORC-aware promise class that works with ORC future
@@ -240,74 +285,95 @@ class promise {
240285 friend class future <T>;
241286
242287public:
243- promise () : state_(std::make_shared<typename future<T>::state>()) {}
288+ promise () : state_(new typename future<T>::state()), future_created_(false ) {}
289+
290+ ~promise () {
291+ // Delete state only if get_future() was never called
292+ if (!future_created_) {
293+ delete state_;
294+ }
295+ }
296+
244297 promise (const promise&) = delete ;
245298 promise& operator =(const promise&) = delete ;
246- promise (promise&&) = default ;
247- promise& operator =(promise&&) = default ;
299+
300+ promise (promise&& other) noexcept
301+ : state_(other.state_), future_created_(other.future_created_) {
302+ other.state_ = nullptr ;
303+ other.future_created_ = false ;
304+ }
305+
306+ promise& operator =(promise&& other) noexcept {
307+ if (this != &other) {
308+ if (!future_created_) {
309+ delete state_;
310+ }
311+ state_ = other.state_ ;
312+ future_created_ = other.future_created_ ;
313+ other.state_ = nullptr ;
314+ other.future_created_ = false ;
315+ }
316+ return *this ;
317+ }
248318
249- // / Get the associated future
319+ // / Get the associated future (must only be called once)
250320 future<T> get_future () {
321+ assert (!future_created_ && " get_future() can only be called once" );
322+ future_created_ = true ;
251323 return future<T>(state_);
252324 }
253325
254- // / Set the value
255- void set_value (const T& value) {
256- state_->value_ = value;
257- state_->status_ .store (1 , std::memory_order_release);
326+ // / Set the value (must only be called once)
327+ // In C++20, this std::conditional weirdness can probably be replaced just
328+ // with requires. It ensures that we don't try to define a method for `void&`,
329+ // but that if the user calls set_value(v) for any value v that they get a
330+ // member function error, instead of no member named 'value_'.
331+ template <typename U = T>
332+ void set_value (const typename std::conditional<std::is_void<T>::value, std::nullopt_t , T>::type& value) {
333+ assert (state_ && " Invalid promise state" );
334+ state_->storage .value_ = value;
335+ state_->status_ .store (FutureStatus::Ready, std::memory_order_release);
258336 }
259337
260- void set_value (T&& value) {
261- state_->value_ = std::move (value);
262- state_->status_ .store (1 , std::memory_order_release);
338+ template <typename U = T>
339+ void set_value (typename std::conditional<std::is_void<T>::value, std::nullopt_t , T>::type&& value) {
340+ assert (state_ && " Invalid promise state" );
341+ state_->storage .value_ = std::move (value);
342+ state_->status_ .store (FutureStatus::Ready, std::memory_order_release);
263343 }
264344
265- private:
266- std::shared_ptr<typename future<T>::state> state_;
267- };
268-
269- // / Specialization of future<void>
270- template <> class future <void > : public future_base {
271- public:
272- using state = future_base::state_base;
273-
274- future () = default ;
275- future (const future &) = delete ;
276- future &operator =(const future &) = delete ;
277- future (future &&) = default ;
278- future &operator =(future &&) = default ;
279-
280- // / Get the value (void), helping with task dispatch while waiting.
281- void get (TaskDispatcher &D) { wait (D); }
345+ template <typename U = T>
346+ typename std::enable_if<std::is_void<U>::value, void >::type set_value (const std::nullopt_t & value) = delete;
282347
283- private:
284- friend class promise <void >;
285-
286- explicit future (std::shared_ptr<state> state) : future_base(state) {}
287- };
288-
289- // / Specialization of promise<void>
290- template <> class promise <void > {
291- friend class future <void >;
292-
293- public:
294- promise () : state_(std::make_shared<future<void >::state>()) {}
295- promise (const promise &) = delete ;
296- promise &operator =(const promise &) = delete ;
297- promise (promise &&) = default ;
298- promise &operator =(promise &&) = default ;
348+ template <typename U = T>
349+ typename std::enable_if<std::is_void<U>::value, void >::type set_value (std::nullopt_t && value) = delete;
299350
300- // / Get the associated future
301- future<void > get_future () { return future<void >(state_); }
351+ template <typename U = T>
352+ typename std::enable_if<std::is_void<U>::value, void >::type set_value () {
353+ assert (state_ && " Invalid promise state" );
354+ state_->status_ .store (FutureStatus::Ready, std::memory_order_release);
355+ }
302356
303- // / Set the value (void)
304- void set_value () { state_->status_ .store (1 , std::memory_order_release); }
357+ // / Swap with another promise
358+ void swap (promise& other) noexcept {
359+ using std::swap;
360+ swap (state_, other.state_ );
361+ swap (future_created_, other.future_created_ );
362+ }
305363
306364private:
307- std::shared_ptr<future<void >::state> state_;
365+ typename future<T>::state* state_;
366+ bool future_created_;
308367};
309368
310369} // End namespace orc
311370} // End namespace llvm
312371
372+ namespace std {
373+ template <typename T>
374+ void swap (llvm::orc::promise<T>& lhs, llvm::orc::promise<T>& rhs) noexcept {
375+ lhs.swap (rhs);
376+ }
377+ } // End namespace std
378+
313379#endif // LLVM_EXECUTIONENGINE_ORC_TASKDISPATCH_H
0 commit comments