Skip to content

Commit 2c76f31

Browse files
laithsakka authored and pytorchmergebot committed
Compute contiguity symbolically to avoid dde, and introduce c++ sym_is_contiguous (pytorch#155590)
When we compute contiguity for a tensor with dynamic shapes we first: 1) Try to compute it without guarding. 2) If all shapes hinted, compute it with potentially adding guards. 3) if any input is not hinted, compute it symbolically. sym_is_contiguous return a SymBool that is then either evaluated or guard_or_false can be called on it to avoid data dependent errors. ex: bool is_contiguous = input.sym_is_contiguous().guard_or_false(__FILE__, __LINE__); is_contiguous_or_false is a helper function that does that. In this PR I only handle default contiguity, will follow up with changes for other formats like channel_last . We use this patter in this PR for several locations to avoid DDEs. Differential Revision: [D77183032](https://our.internmc.facebook.com/intern/diff/D77183032) Pull Request resolved: pytorch#155590 Approved by: https://github.com/ezyang
1 parent b754b1f commit 2c76f31

34 files changed

+390
-114
lines changed

.github/ci_commit_pins/xla.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
55a75404c9b75cd5fd62ab5d4deafc8c506b3af2
1+
926700d7832caa552ba2e1fc8302f6a2f4d2f6d8

aten/src/ATen/FunctionalTensorWrapper.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,8 +499,8 @@ int64_t FunctionalTensorWrapper::dim_custom() const {
499499
int64_t FunctionalTensorWrapper::numel_custom() const {
500500
return value_.unsafeGetTensorImpl()->numel();
501501
}
502-
bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const {
503-
return value_.unsafeGetTensorImpl()->is_contiguous(memory_format);
502+
c10::SymBool FunctionalTensorWrapper::sym_is_contiguous_custom(at::MemoryFormat memory_format) const {
503+
return value_.unsafeGetTensorImpl()->sym_is_contiguous(memory_format);
504504
}
505505
c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const {
506506
return value_.unsafeGetTensorImpl()->sym_sizes();

aten/src/ATen/FunctionalTensorWrapper.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
236236
at::IntArrayRef strides_custom() const override;
237237
int64_t dim_custom() const override;
238238
int64_t numel_custom() const override;
239-
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
239+
c10::SymBool sym_is_contiguous_custom(
240+
at::MemoryFormat memory_format) const override;
240241
c10::SymIntArrayRef sym_sizes_custom() const override;
241242
c10::SymInt sym_size_custom(int64_t d) const override;
242243
c10::SymIntArrayRef sym_strides_custom() const override;

aten/src/ATen/FunctionalizeFallbackKernel.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -320,11 +320,9 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
320320
auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size);
321321

322322
if (!stride.has_value()) {
323-
// With unbacked symints, computeStride could fail even on contiguous
324-
// tensors. In this case, we can use the strides of an empty tensor of
325-
// inferred_size.
326-
TORCH_CHECK(
327-
self.is_contiguous(),
323+
324+
TORCH_SYM_CHECK(
325+
self.sym_is_contiguous(),
328326
"View is not valid from size:",
329327
self.sym_sizes(),
330328
" stride: ",
@@ -333,6 +331,9 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
333331
inferred_size,
334332
" in case of unbacked symbols consider adding torch.check to guide computing strides.");
335333

334+
// With unbacked symints, computeStride could fail even on contiguous
335+
// tensors. In this case, we can use the strides of an empty tensor of
336+
// inferred_size.
336337
stride = at::detail::empty_symint_meta(
337338
inferred_size,
338339
std::nullopt,

aten/src/ATen/LegacyBatchedTensorImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ IntArrayRef BatchedTensorImpl::strides_custom() const {
8484

8585
// TODO: implement proper contiguity on batched tensor, then put
8686
// sizes_strides_policy back to Default
87-
bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
87+
c10::SymBool BatchedTensorImpl::sym_is_contiguous_custom(at::MemoryFormat memory_format) const {
8888
TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
8989
"NYI: querying is_contiguous inside of vmap for memory_format ",
9090
"other than torch.contiguous_format");

aten/src/ATen/LegacyBatchedTensorImpl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
8282
IntArrayRef strides_custom() const override;
8383
// Override a bunch of methods inherited from TensorImpl to return error
8484
// messages.
85-
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
85+
c10::SymBool sym_is_contiguous_custom(
86+
at::MemoryFormat memory_format) const override;
8687
void set_size(int64_t dim, int64_t new_size) override;
8788
void set_stride(int64_t dim, int64_t new_stride) override;
8889
void set_storage_offset(int64_t storage_offset) override;

aten/src/ATen/MemoryOverlap.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ MemOverlap has_internal_overlap(TensorImpl* t) {
2424
}
2525
}
2626

27-
if (t->is_non_overlapping_and_dense()) {
27+
if (t->is_non_overlapping_and_dense_or_false()) {
2828
return MemOverlap::No;
2929
}
3030

@@ -63,7 +63,7 @@ MemOverlapStatus get_overlap_status(const TensorImpl* a, const TensorImpl* b) {
6363
if (a->numel() == 0 || b->numel() == 0) {
6464
return MemOverlapStatus::No;
6565
}
66-
if (!a->is_non_overlapping_and_dense() || !b->is_non_overlapping_and_dense()) {
66+
if (!a->is_non_overlapping_and_dense_or_false() || !b->is_non_overlapping_and_dense_or_false()) {
6767
return MemOverlapStatus::TooHard;
6868
}
6969
// Test for storage equality, rather than pointer equality.

aten/src/ATen/NestedTensorImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ c10::SymInt NestedTensorImpl::sym_numel_custom() const {
273273
return NestedTensorImpl::numel_custom();
274274
}
275275

276-
bool NestedTensorImpl::is_contiguous_custom(MemoryFormat) const {
276+
c10::SymBool NestedTensorImpl::sym_is_contiguous_custom(MemoryFormat) const {
277277
return nested_tensor_impl_is_contiguous(this);
278278
}
279279
IntArrayRef NestedTensorImpl::sizes_custom() const {

aten/src/ATen/NestedTensorImpl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
115115
// with real implementations
116116
int64_t numel_custom() const override;
117117
c10::SymInt sym_numel_custom() const override;
118-
bool is_contiguous_custom(MemoryFormat) const override;
118+
c10::SymBool sym_is_contiguous_custom(MemoryFormat) const override;
119119
int64_t size_custom(int64_t d) const override {
120120
return this->size(d);
121121
}

aten/src/ATen/SparseCsrTensorImpl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,7 @@ void SparseCsrTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
252252
void SparseCsrTensorImpl::set_storage_offset(int64_t storage_offset) {
253253
TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_storage_offset.");
254254
}
255-
bool SparseCsrTensorImpl::is_contiguous_custom(MemoryFormat) const {
255+
c10::SymBool SparseCsrTensorImpl::sym_is_contiguous_custom(MemoryFormat) const {
256256
TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous");
257257
}
258-
259258
} // namespace at

0 commit comments

Comments
 (0)