Commit 0350c7e
[BE] Introduce torch.AcceleratorError (pytorch#152023)
It inherits from `RuntimeError` and carries an `error_code` attribute, which in the case of CUDA should contain the error returned by `cudaGetLastError`.
`torch::detail::_new_accelerator_error_object(c10::AcceleratorError&)` follows the pattern of CPython's [`PyErr_SetString`](https://github.com/python/cpython/blob/cb8a72b301f47e76d93a7fe5b259e9a5758792e1/Python/errors.c#L282), namely (see the sketch after this list):
- Convert the C string into a Python string with `PyUnicode_FromString`
- Create a new exception object using `PyObject_CallOneArg`, just as is done in [`_PyErr_CreateException`](https://github.com/python/cpython/blob/cb8a72b301f47e76d93a7fe5b259e9a5758792e1/Python/errors.c#L32)
- Set the `error_code` property using `PyObject_SetAttrString`
- Decref all temporary references
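
For illustration, here is a minimal sketch of that sequence using the CPython C API. The free-standing helper name, the way the `torch.AcceleratorError` type object is passed in, and the error handling are assumptions for the sketch, not the actual `torch/csrc` implementation:

```cpp
#include <Python.h>

// Sketch only: builds an AcceleratorError instance the same way PyErr_SetString
// builds its exception value. "accelerator_error_type" stands in for the real
// torch.AcceleratorError type object.
static PyObject* new_accelerator_error_object(
    PyObject* accelerator_error_type,
    const char* msg,
    long error_code) {
  // 1. Convert the C string into a Python str.
  PyObject* py_msg = PyUnicode_FromString(msg);
  if (py_msg == nullptr) {
    return nullptr;
  }

  // 2. Instantiate the exception with a single argument, mirroring the
  //    PyObject_CallOneArg path of _PyErr_CreateException.
  PyObject* exc = PyObject_CallOneArg(accelerator_error_type, py_msg);
  Py_DECREF(py_msg);
  if (exc == nullptr) {
    return nullptr;
  }

  // 3. Attach the accelerator-specific code (e.g. the cudaGetLastError value).
  PyObject* py_code = PyLong_FromLong(error_code);
  if (py_code == nullptr ||
      PyObject_SetAttrString(exc, "error_code", py_code) < 0) {
    Py_XDECREF(py_code);
    Py_DECREF(exc);
    return nullptr;
  }

  // 4. Drop the remaining temporary reference; only the exception survives.
  Py_DECREF(py_code);
  return exc;
}
```

The resulting object would then presumably be made the active exception (e.g. via `PyErr_SetObject`) so it propagates to Python as `torch.AcceleratorError`.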
Test that it works and captures the C++ backtrace (in addition to CI) by running:
```python
import os
os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1'
import torch
x = torch.rand(10, device="cuda")
y = torch.arange(20, device="cuda")
try:
    x[y] = 2
    print(x)
except torch.AcceleratorError as e:
    print("Exception was raised", e.args[0])
    print("Captured error code is ", e.error_code)
```
which produces the following output:
```
Exception was raised CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at /home/ubuntu/pytorch/c10/cuda/CUDAException.cpp:41 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) [clone .cold] from CUDAException.cpp:0
#7 void at::native::gpu_kernel_impl<at::native::AbsFunctor<float> >(at::TensorIteratorBase&, at::native::AbsFunctor<float> const&) [clone .isra.0] from tmpxft_000191fc_00000000-6_AbsKernel.cudafe1.cpp:0
#8 at::native::abs_kernel_cuda(at::TensorIteratorBase&) from ??:0
#9 at::Tensor& at::native::unary_op_impl_with_complex_to_float_out<at::native::abs_stub_DECLARE_DISPATCH_type>(at::Tensor&, at::Tensor const&, at::native::abs_stub_DECLARE_DISPATCH_type&, bool) [clone .constprop.0] from UnaryOps.cpp:0
#10 at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA_out_abs_out(at::Tensor const&, at::Tensor&) from RegisterCUDA_0.cpp:0
#11 at::_ops::abs_out::call(at::Tensor const&, at::Tensor&) from ??:0
#12 at::native::abs(at::Tensor const&) from ??:0
#13 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd__abs>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeExplicitAutograd_0.cpp:0
#14 at::_ops::abs::redispatch(c10::DispatchKeySet, at::Tensor const&) from ??:0
#15 torch::autograd::VariableType::(anonymous namespace)::abs(c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#16 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::abs>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#17 at::_ops::abs::call(at::Tensor const&) from ??:0
#18 at::native::isfinite(at::Tensor const&) from ??:0
#19 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__isfinite>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeImplicitAutograd_0.cpp:0
#20 at::_ops::isfinite::call(at::Tensor const&) from ??:0
#21 torch::autograd::THPVariable_isfinite(_object*, _object*, _object*) from python_torch_functions_2.cpp:0
#22 PyObject_CallFunctionObjArgs from ??:0
#23 _PyObject_MakeTpCall from ??:0
#24 _PyEval_EvalFrameDefault from ??:0
#25 _PyObject_FastCallDictTstate from ??:0
#26 _PyStack_AsDict from ??:0
#27 _PyObject_MakeTpCall from ??:0
#28 _PyEval_EvalFrameDefault from ??:0
#29 _PyFunction_Vectorcall from ??:0
#30 _PyEval_EvalFrameDefault from ??:0
#31 _PyFunction_Vectorcall from ??:0
#32 _PyEval_EvalFrameDefault from ??:0
#33 _PyFunction_Vectorcall from ??:0
#34 _PyEval_EvalFrameDefault from ??:0
#35 PyFrame_GetCode from ??:0
#36 PyNumber_Xor from ??:0
#37 PyObject_Str from ??:0
#38 PyFile_WriteObject from ??:0
#39 _PyWideStringList_AsList from ??:0
#40 _PyDict_NewPresized from ??:0
#41 _PyEval_EvalFrameDefault from ??:0
#42 PyEval_EvalCode from ??:0
#43 PyEval_EvalCode from ??:0
#44 PyUnicode_Tailmatch from ??:0
#45 PyInit__collections from ??:0
#46 PyUnicode_Tailmatch from ??:0
#47 _PyRun_SimpleFileObject from ??:0
#48 _PyRun_AnyFileObject from ??:0
#49 Py_RunMain from ??:0
#50 Py_BytesMain from ??:0
#51 __libc_init_first from ??:0
#52 __libc_start_main from ??:0
#53 _start from ??:0
Captured error code is 710
```
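
As an aside (not part of the change itself), the captured code can be cross-checked against the CUDA runtime's error enum: 710 is `cudaErrorAssert`, matching the "device-side assert triggered" message above. A tiny host-side C++ check, assuming the CUDA runtime headers are available:

```cpp
#include <cuda_runtime_api.h>
#include <cstdio>

int main() {
  // Map the raw error_code captured by torch.AcceleratorError back to the
  // CUDA runtime's name/description for it (no GPU needed for this lookup).
  auto err = static_cast<cudaError_t>(710);
  std::printf("%s: %s\n", cudaGetErrorName(err), cudaGetErrorString(err));
  return 0;
}
```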
Pull Request resolved: pytorch#152023
Approved by: https://github.com/eqy, https://github.com/mradmila, https://github.com/ngimel
ghstack dependencies: pytorch#154436
File tree: 9 files changed, +59 -4 lines, across c10 (cuda, util), docs/source, test, and torch (_C, csrc, cuda).