Skip to content

Commit c18e2ce

Browse files
eellisonpytorchmergebot
authored andcommitted
Ignore meta ops in inductor (pytorch#150137)
Fix for pytorch#144607 Pull Request resolved: pytorch#150137 Approved by: https://github.com/BoyuanFeng
1 parent ddb1e97 commit c18e2ce

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

test/inductor/test_torchinductor_dynamic_shapes.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,16 @@ def f(x):
608608

609609
f(torch.tensor([3], device=device))
610610

611+
def test_meta_dynamic_shapes(self):
612+
def foobar(x, y):
613+
return x * 2, y * 3
614+
615+
foo_c = torch.compile(dynamic=True)(foobar)
616+
t = torch.empty((1, 16, 128, 128), device="meta")
617+
y = torch.rand([64])
618+
619+
self.assertEqual(foo_c(t, y), foobar(t, y))
620+
611621
def test_floor(self):
612622
def fn(x):
613623
n = x.size(-1)

torch/_inductor/lowering.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1927,6 +1927,9 @@ def unsupported_input_tensor(t: torch.Tensor, parent=None, node=None):
19271927
_warn_complex_not_supported()
19281928
return True
19291929

1930+
if t.is_meta:
1931+
return True
1932+
19301933
if t.dtype == torch.float8_e8m0fnu:
19311934
if not node:
19321935
return True

torch/_inductor/scheduler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4394,7 +4394,11 @@ def _codegen(self, nodes: list[BaseSchedulerNode]) -> None:
43944394

43954395
if not isinstance(node, NopKernelSchedulerNode):
43964396
device = node.get_device()
4397-
if device is not None and self.get_backend(device).ready_to_flush():
4397+
if (
4398+
device is not None
4399+
and device.type != "meta"
4400+
and self.get_backend(device).ready_to_flush()
4401+
):
43984402
self.flush()
43994403

44004404
if self.current_device and device_need_guard(self.current_device.type):

0 commit comments

Comments
 (0)