Skip to content

Commit 51da241

Browse files
yushangdipytorchmergebot
authored andcommitted
[aoti] Fix cannot determine truth value of Relation error when propagating unbacked symint in lowering (pytorch#150570)
Summary: Fix cannot determine truth value of Relation error when propagating unbacked symint in lowering Test Plan: ``` buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r aoti_runtime_asserts ``` Differential Revision: D72331070 Pull Request resolved: pytorch#150570 Approved by: https://github.com/angelayi, https://github.com/henryoier
1 parent c1d5035 commit 51da241

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3369,6 +3369,56 @@ def forward(self, q, k, v, attn_bias):
33693369
)
33703370
self.check_model(Model(), example_inputs)
33713371

3372+
def test_aoti_runtime_asserts(self):
3373+
from torch._dispatch.python import enable_python_dispatcher
3374+
from torch.export._draft_export import draft_export
3375+
3376+
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
3377+
torch.library.define(
3378+
"mylib::foo",
3379+
"(Tensor a, Tensor b) -> Tensor",
3380+
tags=torch.Tag.pt2_compliant_tag,
3381+
lib=lib,
3382+
)
3383+
3384+
@torch.library.impl("mylib::foo", "cpu", lib=lib)
3385+
def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
3386+
return a[: b.item()]
3387+
3388+
@torch.library.impl_abstract("mylib::foo", lib=lib)
3389+
def foo_fake_impl(a, b):
3390+
ctx = torch.library.get_ctx()
3391+
u = ctx.new_dynamic_size()
3392+
return torch.empty(u)
3393+
3394+
class M(torch.nn.Module):
3395+
def forward(self, a, b):
3396+
res = torch.ops.mylib.foo(a, b)
3397+
s = res.shape[0]
3398+
torch._check(s > 3)
3399+
torch._check(s < a.shape[0])
3400+
return a[s - 3]
3401+
3402+
example_inputs = (torch.randn(100), torch.tensor(10))
3403+
ep = draft_export(M(), example_inputs)
3404+
m = ep.module()
3405+
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
3406+
3407+
example_inputs = [
3408+
node.meta["val"] for node in m.graph.nodes if node.op == "placeholder"
3409+
]
3410+
fake_mode = example_inputs[0].fake_mode
3411+
with enable_python_dispatcher(), fake_mode:
3412+
FakeTensorProp(m, mode=fake_mode).propagate_dont_convert_inputs(
3413+
*example_inputs
3414+
)
3415+
3416+
# TODO: change to the tests below after MetadataMismatchError is fixed
3417+
# pt2_file = torch._inductor.aoti_compile_and_package(ep)
3418+
# optimized = torch._inductor.aoti_load_package(pt2_file)
3419+
3420+
# self.assertTrue(same(optimized(example_inputs), m(example_inputs)))
3421+
33723422
def test_index_put_with_none_index(self):
33733423
# index_put falls back in the deterministic mode
33743424
with DeterministicGuard(True):

torch/_subclasses/fake_tensor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2308,10 +2308,9 @@ def maybe_to_real_tensor(
23082308
if (
23092309
self.propagate_real_tensors
23102310
and all(e.real_tensor is not None for e in flat_arg_fake_tensors)
2311-
# TODO: Handle SymFloat/SymBool
23122311
and not any(
23132312
(
2314-
isinstance(a, SymInt)
2313+
isinstance(a, py_sym_types)
23152314
and (syms := free_unbacked_symbols(a))
23162315
and self.shape_env is not None
23172316
and any(s not in self.shape_env.unbacked_var_to_val for s in syms)

0 commit comments

Comments
 (0)