Commit 899d3d3
Don't call
Don't call `sum()` on a tensor that is default constructed.
Previously we could call `sum()` on a tensor that was default-contructed. That would lead to an error like this:
```
Traceback (most recent call last):
File "/home/ahmads/.conda/envs/pt3/lib/python3.12/unittest/case.py", line 58, in testPartExecutor
yield
File "/home/ahmads/.conda/envs/pt3/lib/python3.12/unittest/case.py", line 634, in run
self._callTestMethod(testMethod)
File "/home/ahmads/.conda/envs/pt3/lib/python3.12/unittest/case.py", line 589, in _callTestMethod
if method() is not None:
^^^^^^^^
File "/home/ahmads/personal/pytorch/torch/testing/_internal/common_utils.py", line 3191, in wrapper
method(*args, **kwargs)
File "/home/ahmads/personal/pytorch/test/test_nn.py", line 7235, in test_layer_norm_backwards_eps
ln_out_cuda.backward(grad_output_cuda)
File "/home/ahmads/personal/pytorch/torch/_tensor.py", line 647, in backward
torch.autograd.backward(
File "/home/ahmads/personal/pytorch/torch/autograd/__init__.py", line 354, in backward
_engine_run_backward(
File "/home/ahmads/personal/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: tensor does not have a device
Exception raised from device_default at /home/ahmads/personal/pytorch/c10/core/TensorImpl.h:1265 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
pytorch#6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) from ??:0
pytorch#7 at::TensorBase::options() const from :0
pytorch#8 at::meta::resize_reduction(at::impl::MetaBase&, at::Tensor const&, c10::OptionalArrayRef<long>, bool, c10::ScalarType, bool) from :0
pytorch#9 at::meta::structured_sum_dim_IntList::meta(at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>) from ??:0
pytorch#10 at::(anonymous namespace)::wrapper_CompositeExplicitAutogradNonFunctional_sum_dim_IntList(at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>) from RegisterCompositeExplicitAutogradNonFunctional_0.cpp:0
pytorch#11 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>), &at::(anonymous namespace)::wrapper_CompositeExplicitAutogradNonFunctional_sum_dim_IntList>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType> > >, at::Tensor (at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>) from RegisterCompositeExplicitAutogradNonFunctional_0.cpp:0
pytorch#12 at::_ops::sum_dim_IntList::call(at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>) from ??:0
pytorch#13 void at::native::(anonymous namespace)::LaunchGammaBetaBackwardCUDAKernel<float, float>(float const*, float const*, float const*, float const*, long, long, at::Tensor*, at::Tensor*, CUstream_st*) from ??:0
pytorch#14 void at::native::(anonymous namespace)::LayerNormBackwardKernelImplInternal<float>(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long, long, at::Tensor*, at::Tensor*, at::Tensor*) from ??:0
pytorch#15 at::native::(anonymous namespace)::LayerNormBackwardKernelImpl(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long, long, at::Tensor*, at::Tensor*, at::Tensor*) from ??:0
pytorch#16 at::native::layer_norm_backward_cuda(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::array<bool, 3ul>) from ??:0
pytorch#17 at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA__native_layer_norm_backward(at::Tensor const&, at::Tensor const&, c10::ArrayRef<c10::SymInt>, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::array<bool, 3ul>) from RegisterCUDA_0.cpp:0
```
Now we only call `sum(0)` on tensors that are defined and properly guard the `sum(0)` and assignment.
Pull Request resolved: pytorch#156600
Approved by: https://github.com/eqy, https://github.com/ngimelsum() on a tensor that is not summable in layer_norm (pytorch#156600)1 parent 17eb649 commit 899d3d3
File tree
2 files changed
+27
-16
lines changed- aten/src/ATen/native/cuda
- test
2 files changed
+27
-16
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
884 | 884 | | |
885 | 885 | | |
886 | 886 | | |
887 | | - | |
888 | | - | |
| 887 | + | |
| 888 | + | |
| 889 | + | |
| 890 | + | |
| 891 | + | |
| 892 | + | |
889 | 893 | | |
890 | 894 | | |
891 | 895 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
7212 | 7212 | | |
7213 | 7213 | | |
7214 | 7214 | | |
| 7215 | + | |
7215 | 7216 | | |
7216 | 7217 | | |
7217 | 7218 | | |
7218 | 7219 | | |
7219 | 7220 | | |
7220 | | - | |
7221 | | - | |
7222 | | - | |
7223 | | - | |
7224 | | - | |
7225 | | - | |
7226 | | - | |
7227 | | - | |
7228 | | - | |
7229 | | - | |
7230 | | - | |
7231 | | - | |
7232 | | - | |
7233 | | - | |
| 7221 | + | |
| 7222 | + | |
| 7223 | + | |
| 7224 | + | |
| 7225 | + | |
| 7226 | + | |
| 7227 | + | |
| 7228 | + | |
| 7229 | + | |
| 7230 | + | |
| 7231 | + | |
| 7232 | + | |
| 7233 | + | |
| 7234 | + | |
| 7235 | + | |
| 7236 | + | |
| 7237 | + | |
| 7238 | + | |
| 7239 | + | |
| 7240 | + | |
7234 | 7241 | | |
7235 | 7242 | | |
7236 | 7243 | | |
| |||
0 commit comments