Commit c1055f4

laithsakka authored and pytorchmergebot committed
Data dependent free reshape. (pytorch#153198)
#### Change 1: if computeStride fails for a reshape, just clone

Consider the most general case: if torch.compile is asked to reshape a tensor of size [u0, u1] and strides [u3, u4] into [u5, u6], what should it do? Those shapes are general enough to describe both contiguous and non-contiguous tensors, some where a clone-free reshape is possible and others where it is not. The current algorithm fails with data-dependent errors.

The general idea: if it is impossible to tell whether the reshape can happen without a copy (because it can for some concrete inputs and cannot for others), it is fine to take the general path and clone, instead of failing or asking the user for hints. **The user wants a single graph (a single compilation)**, and this is the only way to get one. Had this been a view, the user would be explicitly asking for a copy-free reshape, so we would fail and ask for more information (hints in the form of torch._check calls).

With this change, reshape works as follows:
1. If we know the input is contiguous, we convert the reshape to a view.
2. If computeStride succeeds, we use a view. (computeStride was changed to no longer fail when unbacked symints are present; instead it returns nullopt when it cannot compute the strides, meaning we should clone.)
3. If neither 1 nor 2 applies, we clone and then view the clone.

Side note: producing a view does not mean that inductor will not clone; inductor has a pass that converts all views back to reshapes, and it has its own logic for dealing with those.

#### Change 2: skip _reshape_view_helper and fall back to simpler logic if it fails

We trace _reshape_view_helper during fake-tensor tracing but not during proxy tracing, so that tracing does not affect the graph (it only computes the output shapes of several operations). We should not fail there, because for a reshape it should always be possible to get through it: by the time reshape_symint was called, we either cloned or computeStride succeeded, so the view should pass.

What I did: we run _reshape_view_helper, and if it fails due to unbacked symints we call _view_simple, which always succeeds for reshapes (it may fail for views when the view is impossible; in that case we throw the data-dependent error raised by the original algorithm). Ideally I would register _view_simple as the meta for view and avoid calling _reshape_view_helper entirely, but I am running into dispatcher issues with subclasses and do not have time to debug them. Namely, one test ends up calling a C++ view function that does not support symints during meta dispatch when I register a Python meta decomposition:

```python
test/dynamo/test_subclasses.py SubclassTests.test_subclass_views_dynamic_True
```

pytorch#153303 will follow up with that change in a separate PR.

Two other alternatives to registering _view_simple as the meta and to the try/except approach in this PR:
1. Call _view_simple if any input is dynamic; see pytorch#153521.
2. If we make is_compiling work for framework-code tracing (it does not right now), call _view_simple when is_compiling is true.

#### Note

Reshape can still fail when is_contiguous is called; the next PR will handle that by calling is_known_contiguous.

Pull Request resolved: pytorch#153198
Approved by: https://github.com/etaf, https://github.com/bobrenjc93

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @bdhirsh
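The three-step lowering order in change 1 can be sketched in a few lines of Python. This is a hedged illustration only: `reshape_decision` and its boolean flags are hypothetical stand-ins for the real `is_contiguous()` / computeStride checks, not PyTorch code.

```python
def reshape_decision(is_contiguous: bool, strides_computable: bool) -> str:
    """Return how a reshape is lowered under unbacked symints (sketch)."""
    if is_contiguous:
        return "view"             # step 1: contiguous input lowers to a view
    if strides_computable:
        return "view"             # step 2: computeStride found valid strides
    return "clone_then_view"      # step 3: undecidable -> take the general path

print(reshape_decision(False, False))  # clone_then_view
```

The point of step 3 is that a clone is always correct, so an undecidable data-dependent case degrades to an extra copy rather than a compile-time failure.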
1 parent f74842d commit c1055f4

File tree

7 files changed: +453 additions, −76 deletions


aten/src/ATen/FunctionalizeFallbackKernel.cpp

Lines changed: 27 additions & 1 deletion

@@ -7,6 +7,7 @@
 #include <torch/library.h>
 #include <c10/util/irange.h>
 #include <c10/util/strides.h>
+#include <ATen/EmptyTensor.h>

 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/ATen.h>
@@ -315,8 +316,33 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
   // See Note [Propagating strides in the functionalization pass]
   // (for _unsafe_view, I'm just manually doing the shape inference rule here instead of calling the meta function for unsafe_view)
   auto inferred_size = at::infer_size_dv(size, self.sym_numel());
+
   auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size);
-  TORCH_INTERNAL_ASSERT(stride.has_value());
+
+  if (!stride.has_value()) {
+    // With unbacked symints, computeStride could fail even on contiguous
+    // tensors. In this case, we can use the strides of an empty tensor of
+    // inferred_size.
+    TORCH_CHECK(
+        self.is_contiguous(),
+        "View is not valid from size:",
+        self.sym_sizes(),
+        " stride: ",
+        self.sym_strides(),
+        " to shape: ",
+        inferred_size,
+        " in case of unbacked symbols consider adding torch.check to guide computing strides.");
+
+    stride = at::detail::empty_symint_meta(
+                 inferred_size,
+                 std::nullopt,
+                 std::nullopt,
+                 std::nullopt,
+                 std::nullopt,
+                 std::nullopt)
+                 .sym_strides();
+  }
+
   out.unsafeGetTensorImpl()->set_sizes_and_strides(inferred_size, stride.value());
   return out;
 }
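For context, the strides taken from an empty meta tensor of `inferred_size` in the fallback above are just the ordinary contiguous (row-major) strides of that shape. A minimal sketch of that rule, assuming concrete (non-symbolic) sizes:

```python
def contiguous_strides(sizes):
    # Rightmost dimension has stride 1; each earlier stride is the
    # product of all sizes to its right (row-major contiguous layout).
    strides = [1] * len(sizes)
    for d in range(len(sizes) - 2, -1, -1):
        strides[d] = strides[d + 1] * sizes[d + 1]
    return strides

print(contiguous_strides([2, 3, 4]))  # [12, 4, 1]
```

This is why the fallback is only valid behind the `self.is_contiguous()` check: for a contiguous input these strides describe the same storage layout.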

aten/src/ATen/InferSize.h

Lines changed: 12 additions & 6 deletions

@@ -25,17 +25,23 @@ inline void infer_size_impl(
   // N.B. this is an index, not a sym dim!
   std::optional<int64_t> infer_dim;
   for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
-    // We can avoid failing on unbacked shape[dim] and assert that it is >=0
-    // following python behaviour.
-    if (shape[dim] == -1) {
+    if (TORCH_GUARD_OR_FALSE(sym_eq(shape[dim], -1))) {
       if (infer_dim) {
         throw std::runtime_error("only one dimension can be inferred");
       }
       infer_dim = dim;
-    } else if (shape[dim] >= 0) {
-      newsize *= shape[dim];
     } else {
-      TORCH_CHECK(false, "invalid shape dimension ", shape[dim]);
+      // in case of unbacked shape[dim] we assume it's not -1 and add a runtime
+      // assertion.
+      TORCH_MAYBE_SYM_CHECK(
+          sym_gt(shape[dim], -1),
+          "invalid shape dimension ",
+          shape[dim],
+          " at index ",
+          dim,
+          " of shape ",
+          shape);
+      newsize *= shape[dim];
     }
   }
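A plain-Python sketch of the size-inference rule above, for concrete integer shapes only. This is a hypothetical helper for illustration, not the PyTorch implementation; the unbacked-symint guards in the C++ patch collapse to ordinary comparisons here.

```python
def infer_size(shape, numel):
    """Resolve a single -1 dimension so the shape multiplies out to numel."""
    newsize = 1
    infer_dim = None
    for dim, s in enumerate(shape):
        if s == -1:
            if infer_dim is not None:
                raise RuntimeError("only one dimension can be inferred")
            infer_dim = dim
        elif s >= 0:
            newsize *= s
        else:
            raise RuntimeError(
                f"invalid shape dimension {s} at index {dim} of shape {shape}")
    out = list(shape)
    if infer_dim is not None:
        if newsize == 0 or numel % newsize != 0:
            raise RuntimeError(f"shape {shape} is invalid for input of size {numel}")
        out[infer_dim] = numel // newsize
    elif newsize != numel:
        raise RuntimeError(f"shape {shape} is invalid for input of size {numel}")
    return out

print(infer_size([2, -1, 4], 24))  # [2, 3, 4]
```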

aten/src/ATen/TensorUtils.cpp

Lines changed: 18 additions & 4 deletions

@@ -367,19 +367,33 @@ inline static std::optional<ResultVec> computeStride_impl(
   // numel in current chunk
   Numel tensor_numel = 1;
   Numel view_numel = 1;
+
+  // The usages of TORCH_GUARD_OR_TRUE/TORCH_GUARD_OR_FALSE below could result in returning
+  // std::nullopt, which has the effect of falling back to a clone when unbacked symints are
+  // present. But it will not result in returning different or wrong results.
   for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
     tensor_numel *= oldshape[tensor_d];
     // if end of tensor size chunk, check view
     if ((tensor_d == 0) ||
-        (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(oldshape[tensor_d - 1], 1)) &&
-         TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(oldstride[tensor_d - 1], tensor_numel * chunk_base_stride)))) {
+        (TORCH_GUARD_OR_TRUE(sym_ne(oldshape[tensor_d - 1], 1)) &&
+         TORCH_GUARD_OR_TRUE(sym_ne(oldstride[tensor_d - 1], tensor_numel * chunk_base_stride)))) {
+      // We want to accumulate stuff in view_numel until view_numel == tensor_numel; if we do
+      // not know whether that is satisfied, we keep accumulating. For example, if
+      // view_numel = 1 and tensor_numel = u1, we want to take that path; view_numel will
+      // become u0. Next iteration, if u0 == u1 we want to stop. That's why we use
+      // TORCH_GUARD_OR_TRUE below.
+
+      // We use TORCH_GUARD_OR_FALSE and not TORCH_GUARD_OR_TRUE when comparing
+      // newshape[view_d] == 1 because if we know view_numel < tensor_numel is false, we want
+      // to stop, unless we know for sure newshape[view_d] == 1, in which case we would stop
+      // in the next iteration anyway. For example, if view_numel = u0 and tensor_numel = u1,
+      // and u0 == u1, then we want to stop unless newshape[view_d] == 1; taking one more
+      // iteration would keep [view_numel = u0 and tensor_numel = u1].
       while (view_d >= 0 &&
-          (TORCH_GUARD_SIZE_OBLIVIOUS(sym_lt(view_numel, tensor_numel)) || TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(newshape[view_d], 1)))) {
+          (TORCH_GUARD_OR_TRUE(sym_lt(view_numel, tensor_numel)) || TORCH_GUARD_OR_FALSE(sym_eq(newshape[view_d], 1)))) {
        newstride[view_d] = view_numel * chunk_base_stride;
        view_numel *= newshape[view_d];
        view_d--;
      }
-      if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(view_numel, tensor_numel))) {
+      if (TORCH_GUARD_OR_TRUE(sym_ne(view_numel, tensor_numel))) {
        return std::nullopt;
      }
      if (tensor_d > 0) {
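The chunk-matching loop above can be sketched for concrete integer shapes as follows. This is a hedged re-implementation for illustration only (non-empty shapes, no symbolic ints); with unbacked symints the guarded comparisons can additionally force the `None` (clone) fallback even when a concrete run would succeed.

```python
def compute_stride(oldshape, oldstride, newshape):
    """Return strides for viewing (oldshape, oldstride) as newshape,
    or None when the reshape cannot reuse the storage without a copy."""
    newstride = [0] * len(newshape)
    view_d = len(newshape) - 1
    chunk_base_stride = oldstride[-1] if oldstride else 1
    tensor_numel = 1
    view_numel = 1
    for tensor_d in range(len(oldshape) - 1, -1, -1):
        tensor_numel *= oldshape[tensor_d]
        # End of a contiguous chunk of the old tensor: match it against newshape.
        if tensor_d == 0 or (
            oldshape[tensor_d - 1] != 1
            and oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride
        ):
            while view_d >= 0 and (
                view_numel < tensor_numel or newshape[view_d] == 1
            ):
                newstride[view_d] = view_numel * chunk_base_stride
                view_numel *= newshape[view_d]
                view_d -= 1
            if view_numel != tensor_numel:
                return None  # new dims straddle a chunk boundary -> clone
            if tensor_d > 0:
                chunk_base_stride = oldstride[tensor_d - 1]
                tensor_numel = 1
                view_numel = 1
    return newstride if view_d == -1 else None

# A contiguous [2, 6] -> [3, 4] reshape is copy-free; a transposed layout is not.
print(compute_stride([2, 6], [6, 1], [3, 4]))  # [4, 1]
print(compute_stride([2, 3], [1, 2], [6]))     # None
```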

test/export/test_export.py

Lines changed: 0 additions & 23 deletions

@@ -4570,32 +4570,9 @@ class M_v0(torch.nn.Module):
             def forward(self, t):
                 items = [t[i].item() for i in range(t.numel())]
                 r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Ne(Mod(u1, u2), 0)
                 return r.view(items[0], items[2])

         M = M_v0
-        with self.assertRaisesRegex(
-            error_type,
-            "The following call raised this error(.*\n)+"
-            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
-            "To fix the error, insert one of the following checks before this call.*:\n"
-            f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}.*\n"
-            f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}(.*\n)+"
-            f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
-            f".*{re.escape('or r.shape[1], `u2` with items[2] in Eq(Mod(u1, u2), 0) and its negation.')}",
-        ):
-            export(N(), (t,), strict=strict)
-
-        class M_v1(torch.nn.Module):
-            def forward(self, t):
-                items = [t[i].item() for i in range(t.numel())]
-                r = torch.randn([items[0], items[1]])
-                # TODO(pianpwk): this isn't the suggested fixes.
-                # fix issue with % being interpreted as PythonMod instead of Mod
-                torch._check(items[1] == items[2])
-                return r.view(items[0], items[2])
-
-        M = M_v1
         export(N(), (t,), strict=strict)

     def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
