Skip to content

Commit 0745316

Browse files
committed
[SYCL] Fix asynchronous exception behavior
This commit makes the following changes to the behavior of asynchronous exception handling: 1. The death of a queue should not consume asynchronous exceptions. 2. Calling wait_and_throw on an event after the associated queue has died should still consume exceptions that were originally associated with the queue. This should respect the async_handler priority to the best of its ability. 3. Calling wait_and_throw or throw_asynchronous on a queue without an async_handler should fall back to using the async_handler of the associated context, then the default async_handler if none were attached to the context. Additionally, this lays the ground work for intel#20266 by moving the tracking of unconsumed asynchronous exception to the devices. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 91653ad commit 0745316

File tree

10 files changed

+374
-54
lines changed

10 files changed

+374
-54
lines changed

sycl/include/sycl/exception_list.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ inline namespace _V1 {
2424
// Forward declaration
2525
namespace detail {
2626
class queue_impl;
27+
class device_impl;
2728
}
2829

2930
/// A list of asynchronous exceptions.
@@ -46,6 +47,7 @@ class __SYCL_EXPORT exception_list {
4647

4748
private:
4849
friend class detail::queue_impl;
50+
friend class detail::device_impl;
4951
void PushBack(const_reference Value);
5052
void PushBack(value_type &&Value);
5153
void Clear() noexcept;

sycl/source/detail/device_impl.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,6 +2262,31 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
22622262
return {};
22632263
}
22642264

2265+
/// Puts exception to the list of asynchronous ecxeptions.
2266+
///
2267+
/// \param QueueWeakPtr is a weak pointer referring to the queue to report
2268+
/// the asynchronous exceptions for.
2269+
/// \param ExceptionPtr is a pointer to exception to be put.
2270+
void reportAsyncException(std::weak_ptr<queue_impl> QueueWeakPtr,
2271+
const std::exception_ptr &ExceptionPtr) {
2272+
std::lock_guard<std::mutex> Lock(MAsyncExceptionsMutex);
2273+
MAsyncExceptions[QueueWeakPtr].PushBack(ExceptionPtr);
2274+
}
2275+
2276+
/// Extracts all unconsumed asynchronous exceptions for a given queue.
2277+
///
2278+
/// \param QueueWeakPtr is a weak pointer referring to the queue to extract
2279+
/// unconsumed asynchronous exceptions for.
2280+
exception_list flushAsyncExceptions(std::weak_ptr<queue_impl> QueueWeakPtr) {
2281+
std::lock_guard<std::mutex> Lock(MAsyncExceptionsMutex);
2282+
auto ExceptionsEntryIt = MAsyncExceptions.find(QueueWeakPtr);
2283+
if (ExceptionsEntryIt == MAsyncExceptions.end())
2284+
return exception_list{};
2285+
exception_list Exceptions = std::move(ExceptionsEntryIt->second);
2286+
MAsyncExceptions.erase(ExceptionsEntryIt);
2287+
return Exceptions;
2288+
}
2289+
22652290
private:
22662291
ur_device_handle_t MDevice = 0;
22672292
// This is used for getAdapter so should be above other properties.
@@ -2272,6 +2297,13 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
22722297

22732298
const ur_device_handle_t MRootDevice;
22742299

2300+
// Asynchronous exceptions are captured at device-level until flushed, either
2301+
// by queues, events or a synchronization on the device itself.
2302+
std::mutex MAsyncExceptionsMutex;
2303+
std::map<std::weak_ptr<queue_impl>, exception_list,
2304+
std::owner_less<std::weak_ptr<queue_impl>>>
2305+
MAsyncExceptions;
2306+
22752307
// Order of caches matters! UR must come before SYCL info descriptors (because
22762308
// get_info calls get_info_impl but the opposite never happens) and both
22772309
// should come before aspects.

sycl/source/detail/event_impl.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,9 @@ void event_impl::initHostProfilingInfo() {
211211
MHostProfilingInfo->setDevice(&Device);
212212
}
213213

214-
void event_impl::setSubmittedQueue(std::weak_ptr<queue_impl> SubmittedQueue) {
215-
MSubmittedQueue = std::move(SubmittedQueue);
214+
void event_impl::setSubmittedQueue(queue_impl *SubmittedQueue) {
215+
MSubmittedQueue = SubmittedQueue->weak_from_this();
216+
MSubmittedDevice = &SubmittedQueue->getDeviceImpl();
216217
}
217218

218219
#ifdef XPTI_ENABLE_INSTRUMENTATION
@@ -308,8 +309,28 @@ void event_impl::wait(bool *Success) {
308309
void event_impl::wait_and_throw() {
309310
wait();
310311

311-
if (std::shared_ptr<queue_impl> SubmittedQueue = MSubmittedQueue.lock())
312+
if (std::shared_ptr<queue_impl> SubmittedQueue = MSubmittedQueue.lock()) {
312313
SubmittedQueue->throw_asynchronous();
314+
return;
315+
}
316+
317+
// If the queue has died, we rely on finding its exceptions through the
318+
// device.
319+
if (MSubmittedDevice == nullptr)
320+
return;
321+
322+
// If MSubmittedQueue has died, get flush any exceptions associated with it
323+
// still, then user either the context async_handler or the default
324+
// async_handler.
325+
exception_list Exceptions =
326+
MSubmittedDevice->flushAsyncExceptions(MSubmittedQueue);
327+
if (Exceptions.size() == 0)
328+
return;
329+
330+
if (MContext && MContext->get_async_handler())
331+
MContext->get_async_handler()(std::move(Exceptions));
332+
else
333+
defaultAsyncHandler(std::move(Exceptions));
313334
}
314335

315336
void event_impl::checkProfilingPreconditions() const {

sycl/source/detail/event_impl.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,10 @@ class event_impl {
264264
MWorkerQueue = std::move(WorkerQueue);
265265
};
266266

267-
/// Sets original queue used for submission.
267+
/// Sets original queue and device used for submission.
268268
///
269269
/// @return
270-
void setSubmittedQueue(std::weak_ptr<queue_impl> SubmittedQueue);
270+
void setSubmittedQueue(queue_impl *SubmittedQueue);
271271

272272
/// Indicates if this event is not associated with any command and doesn't
273273
/// have native handle.
@@ -394,6 +394,7 @@ class event_impl {
394394

395395
std::weak_ptr<queue_impl> MWorkerQueue;
396396
std::weak_ptr<queue_impl> MSubmittedQueue;
397+
device_impl *MSubmittedDevice = nullptr;
397398

398399
/// Dependency events prepared for waiting by backend.
399400
std::vector<EventImplPtr> MPreparedDepsEvents;

sycl/source/detail/queue_impl.hpp

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
253253
// notification and destroy the trace event for this queue.
254254
destructorNotification();
255255
#endif
256-
throw_asynchronous();
257256
auto status =
258257
getAdapter().call_nocheck<UrApiKind::urQueueRelease>(MQueue);
259258
// If loader is already closed, it'll return a not-initialized status
@@ -393,9 +392,6 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
393392
/// @param Loc is the code location of the submit call (default argument)
394393
void wait(const detail::code_location &Loc = {});
395394

396-
/// \return list of asynchronous exceptions occurred during execution.
397-
exception_list getExceptionList() const { return MExceptions; }
398-
399395
/// @param Loc is the code location of the submit call (default argument)
400396
void wait_and_throw(const detail::code_location &Loc = {}) {
401397
wait(Loc);
@@ -408,21 +404,20 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
408404
/// Synchronous errors will be reported through SYCL exceptions.
409405
/// Asynchronous errors will be passed to the async_handler passed to the
410406
/// queue on construction. If no async_handler was provided then
411-
/// asynchronous exceptions will be lost.
407+
/// asynchronous exceptions will be passed to the default async_handler.
412408
void throw_asynchronous() {
413-
if (!MAsyncHandler)
409+
exception_list Exceptions =
410+
getDeviceImpl().flushAsyncExceptions(weak_from_this());
411+
if (Exceptions.size() == 0)
414412
return;
415413

416-
exception_list Exceptions;
417-
{
418-
std::lock_guard<std::mutex> Lock(MMutex);
419-
std::swap(Exceptions, MExceptions);
420-
}
421-
// Unlock the mutex before calling user-provided handler to avoid
422-
// potential deadlock if the same queue is somehow referenced in the
423-
// handler.
424-
if (Exceptions.size())
414+
if (MAsyncHandler)
425415
MAsyncHandler(std::move(Exceptions));
416+
else if (const async_handler &CtxAsyncHandler =
417+
getContextImpl().get_async_handler())
418+
CtxAsyncHandler(std::move(Exceptions));
419+
else
420+
defaultAsyncHandler(std::move(Exceptions));
426421
}
427422

428423
/// Creates UR properties array.
@@ -570,14 +565,6 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
570565
event mem_advise(const void *Ptr, size_t Length, ur_usm_advice_flags_t Advice,
571566
const std::vector<event> &DepEvents, bool CallerNeedsEvent);
572567

573-
/// Puts exception to the list of asynchronous ecxeptions.
574-
///
575-
/// \param ExceptionPtr is a pointer to exception to be put.
576-
void reportAsyncException(const std::exception_ptr &ExceptionPtr) {
577-
std::lock_guard<std::mutex> Lock(MMutex);
578-
MExceptions.PushBack(ExceptionPtr);
579-
}
580-
581568
static ThreadPool &getThreadPool() {
582569
return GlobalHandler::instance().getHostTaskThreadPool();
583570
}
@@ -979,10 +966,6 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
979966
/// These events are tracked, but not owned, by the queue.
980967
std::vector<std::weak_ptr<event_impl>> MEventsWeak;
981968

982-
/// Events without data dependencies (such as USM) need an owner,
983-
/// additionally, USM operations are not added to the scheduler command graph,
984-
/// queue is the only owner on the runtime side.
985-
exception_list MExceptions;
986969
const async_handler MAsyncHandler;
987970
const property_list MPropList;
988971

sycl/source/detail/scheduler/commands.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -359,12 +359,14 @@ class DispatchHostTask {
359359
AdapterWithEvents.first->call<UrApiKind::urEventWait>(RawEvents.size(),
360360
RawEvents.data());
361361
} catch (const sycl::exception &) {
362-
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(
363-
std::current_exception());
362+
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
363+
QueuePtr->getDeviceImpl().reportAsyncException(
364+
QueuePtr, std::current_exception());
364365
return false;
365366
} catch (...) {
366-
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(
367-
std::current_exception());
367+
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
368+
QueuePtr->getDeviceImpl().reportAsyncException(
369+
QueuePtr, std::current_exception());
368370
return false;
369371
}
370372
}
@@ -407,7 +409,8 @@ class DispatchHostTask {
407409
make_error_code(errc::runtime),
408410
std::string("Couldn't wait for host-task's dependencies")));
409411

410-
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(EPtr);
412+
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
413+
QueuePtr->getDeviceImpl().reportAsyncException(QueuePtr, EPtr);
411414
// reset host-task's lambda and quit
412415
HostTask.MHostTask.reset();
413416
Scheduler::getInstance().NotifyHostTaskCompletion(MThisCmd);
@@ -469,8 +472,9 @@ class DispatchHostTask {
469472
}
470473
}
471474
#endif
472-
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(
473-
CurrentException);
475+
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
476+
QueuePtr->getDeviceImpl().reportAsyncException(QueuePtr,
477+
CurrentException);
474478
}
475479

476480
HostTask.MHostTask.reset();
@@ -487,8 +491,9 @@ class DispatchHostTask {
487491
Scheduler::getInstance().NotifyHostTaskCompletion(MThisCmd);
488492
} catch (...) {
489493
auto CurrentException = std::current_exception();
490-
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(
491-
CurrentException);
494+
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
495+
QueuePtr->getDeviceImpl().reportAsyncException(QueuePtr,
496+
CurrentException);
492497
}
493498
}
494499
};
@@ -563,7 +568,8 @@ Command::Command(
563568
MCommandBuffer(CommandBuffer), MSyncPointDeps(SyncPoints) {
564569
MWorkerQueue = MQueue;
565570
MEvent->setWorkerQueue(MWorkerQueue);
566-
MEvent->setSubmittedQueue(MWorkerQueue);
571+
if (Queue)
572+
MEvent->setSubmittedQueue(Queue);
567573
MEvent->setCommand(this);
568574
if (MQueue)
569575
MEvent->setContextImpl(MQueue->getContextImpl());
@@ -1958,7 +1964,7 @@ ExecCGCommand::ExecCGCommand(
19581964
assert(SubmitQueue &&
19591965
"Host task command group must have a valid submit queue");
19601966

1961-
MEvent->setSubmittedQueue(SubmitQueue->weak_from_this());
1967+
MEvent->setSubmittedQueue(SubmitQueue);
19621968
// Initialize host profiling info if the queue has profiling enabled.
19631969
if (SubmitQueue->MIsProfilingEnabled)
19641970
MEvent->initHostProfilingInfo();

sycl/source/detail/scheduler/scheduler.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,8 @@ EventImplPtr Scheduler::addCopyBack(Requirement *Req) {
260260
auto WorkerQueue = NewCmd->getEvent()->getWorkerQueue();
261261
assert(WorkerQueue &&
262262
"WorkerQueue for CopyBack command must be not null");
263-
WorkerQueue->reportAsyncException(std::current_exception());
263+
WorkerQueue->getDeviceImpl().reportAsyncException(
264+
WorkerQueue, std::current_exception());
264265
}
265266
}
266267
EventImplPtr NewEvent = NewCmd->getEvent();

sycl/source/handler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,7 @@ event handler::finalize() {
956956
// it to the graph to create a node, rather than submit it to the scheduler.
957957
if (auto GraphImpl = Queue->getCommandGraph(); GraphImpl) {
958958
auto EventImpl = detail::event_impl::create_completed_host_event();
959-
EventImpl->setSubmittedQueue(Queue->weak_from_this());
959+
EventImpl->setSubmittedQueue(Queue);
960960
ext::oneapi::experimental::detail::node_impl *NodeImpl = nullptr;
961961

962962
// GraphImpl is read and written in this scope so we lock this graph

0 commit comments

Comments
 (0)