Commit be16f21

BoyuanFeng authored and pytorchmergebot committed

[Graph Partition] add symints to get_graph_inputs (pytorch#154679)
During `codegen_inputs`, we check whether there are undefined symbols:
https://github.com/pytorch/pytorch/blob/65b1aedd09e98fcafcdd893ca4924f4fa598fd18/torch/_inductor/codegen/wrapper.py#L1668-L1674

Previously, graph partition inputs did not explicitly include symints:
https://github.com/pytorch/pytorch/blob/65b1aedd09e98fcafcdd893ca4924f4fa598fd18/torch/_inductor/codegen/wrapper.py#L3265-L3272

Instead, we relied on the sizes/strides of each `TensorBox` input to codegen symint inputs. For example, a tensor with shape `[s0, 2]` implicitly codegens `s0` as an input here. This works in most cases, since a backed symint has to come from some tensor shape:
https://github.com/pytorch/pytorch/blob/65b1aedd09e98fcafcdd893ca4924f4fa598fd18/torch/_inductor/codegen/wrapper.py#L1624-L1632

In rare cases it does not work. One example is tensors saved for backward, which may have a shape such as `[2*s0, 2]`. Since `2*s0` is an expression rather than a symbol, `codegen_input_symbol_assignment` never binds `s0`, and `_verify_input_symbol_assignment` later raises an error.

The fix is to add symints to `get_graph_inputs`. An alternative would be to update `codegen_input_symbol_assignment`, but I want to keep the change scoped to graph partition.

Pull Request resolved: pytorch#154679
Approved by: https://github.com/eellison
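
As an illustration of the failure mode, here is a minimal plain-sympy sketch (not Inductor's actual `codegen_input_symbol_assignment`): discovery that only binds bare symbols misses `s0` inside `2*s0`, while `free_symbols` recovers it.

```python
# Minimal sketch (plain sympy, not Inductor code) of why a shape like [2*s0, 2]
# defeats symbol discovery that only looks at bare symbols in tensor sizes.
import sympy

s0 = sympy.Symbol("s0", integer=True, positive=True)

# Sizes of a tensor saved for backward after cat((x, x), dim=0): [2*s0, 2]
sizes = [2 * s0, sympy.Integer(2)]

# Symbol-only discovery: a size is bound only if it *is* a symbol.
bound = {size for size in sizes if isinstance(size, sympy.Symbol)}
print(bound)  # set() -- s0 is never bound, so it later shows up as undefined

# Discovery via free_symbols: digs s0 out of the expression 2*s0.
recovered = set().union(*(size.free_symbols for size in sizes))
print(recovered)  # {s0}
```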
Parent: d3c8f36

4 files changed, +51 −8 lines


test/distributed/tensor/test_dtensor_compile.py (18 additions, 6 deletions)

```diff
@@ -801,12 +801,7 @@ def fn(x):
         out_dt = torch.matmul(tmp_dt, y_dt)
         out_dt.sum().backward()
 
-    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(1)
-    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
-    @patch.object(torch._inductor.config, "compile_threads", 1)
-    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
-    def test_tp_compile_comm_reordering(self):
+    def _test_tp_compile_comm_reordering(self):
         class FakeAttention(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
@@ -876,6 +871,23 @@ def forward(self, input):
             code
         )
 
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @skip_if_lt_x_gpu(1)
+    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
+    @patch.object(torch._inductor.config, "compile_threads", 1)
+    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
+    def test_tp_compile_comm_reordering(self):
+        self._test_tp_compile_comm_reordering()
+
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @skip_if_lt_x_gpu(1)
+    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
+    @patch.object(torch._inductor.config, "compile_threads", 1)
+    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
+    @torch._inductor.config.patch("graph_partition", True)
+    def test_tp_compile_comm_reordering_graph_partition(self):
+        self._test_tp_compile_comm_reordering()
+
 
 @instantiate_parametrized_tests
 class TestDTensorCompileE2E(DTensorTestBase):
```

test/inductor/test_torchinductor.py (21 additions, 0 deletions)

```diff
@@ -14968,6 +14968,27 @@ def f(x, y):
         compiled_out = f_compiled(x, y)
         self.assertEqual(compiled_out, f(x, y))
 
+    @torch._inductor.config.patch("graph_partition", True)
+    def test_graph_partition_symint_cat_backward(self):
+        def f(x, w):
+            y = torch.cat((x, x), dim=0)
+            z = y @ w
+            return z @ z.T
+
+        compiled_f = torch.compile(f)
+
+        for shape in (2, 3):
+            torch.manual_seed(42)
+            eager_x = torch.randn(shape, 2, device=self.device)
+            eager_w = torch.randn(2, 2, device=self.device, requires_grad=True)
+            torch.manual_seed(42)
+            compiled_x = torch.randn(shape, 2, device=self.device)
+            compiled_w = torch.randn(2, 2, device=self.device, requires_grad=True)
+
+            f(eager_x, eager_w).sum().backward()
+            compiled_f(compiled_x, compiled_w).sum().backward()
+            self.assertEqual(eager_w.grad, compiled_w.grad)
+
     @dynamo_config.patch("capture_dynamic_output_shape_ops", True)
     @config.patch(implicit_fallbacks=True)
     @torch._inductor.config.patch("graph_partition", True)
```

torch/_inductor/codegen/wrapper.py (3 additions, 1 deletion)

```diff
@@ -3282,7 +3282,9 @@ def get_graph_inputs(
         self,
     ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr]]:
         if signature := self.partition_signatures:
-            inputs = signature.input_nodes
+            inputs = signature.input_nodes | {
+                str(s): s for s in signature.symbol_inputs
+            }
         else:
             inputs = V.graph.graph_inputs
         return inputs
```
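
For reference, a hypothetical standalone sketch of the dict-union pattern used above; the string values stand in for `ir.TensorBox` nodes and `symbol_inputs` is a hand-made stand-in for `signature.symbol_inputs`.

```python
# Standalone sketch of merging named tensor inputs with explicit symint inputs.
# The strings stand in for ir.TensorBox nodes; symbol_inputs mimics the ordered
# symbol set produced by the scheduler's graph-partition signature.
import sympy

s0 = sympy.Symbol("s0")

input_nodes = {"primals_1": "<TensorBox [s0, 2]>", "primals_2": "<TensorBox [2, 2]>"}
symbol_inputs = [s0]

# dict union (Python 3.9+): each symint becomes a graph input keyed by its name.
inputs = input_nodes | {str(s): s for s in symbol_inputs}
print(list(inputs))  # ['primals_1', 'primals_2', 's0']
```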

torch/_inductor/scheduler.py (9 additions, 1 deletion)

```diff
@@ -4251,7 +4251,15 @@ def filter_symbols(
             *(get_input_node_symbols(node) for _, node in input_nodes.items())
         )
 
-        return filter_symbols(candidate_symbols)
+        candidate_symbols = filter_symbols(candidate_symbols)
+
+        res: OrderedSet[sympy.Symbol] = OrderedSet()
+        for s in candidate_symbols:
+            symplified_s = V.graph.sizevars.simplify(s)
+            # use free_symbols only when s is simplified to an Integer or expr
+            res.update(symplified_s.free_symbols)
+
+        return OrderedSet(sorted(res, key=operator.attrgetter("name")))
 
     def get_graph_partition_signature(
         self, partitions: list[PartitionType], skip_cudagraphs: list[bool]
```
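
A self-contained sketch of what the new symbol-collection loop computes, with `sympy.simplify` standing in for `V.graph.sizevars.simplify` and a plain `set` plus an explicit sort standing in for the name-sorted `OrderedSet`.

```python
# Self-contained sketch of the symbol-collection step added above.
import operator
import sympy

s0, s1 = sympy.symbols("s0 s1")

# Candidate expressions gathered from partition input sizes/strides.
candidate_symbols = [2 * s0, s1, sympy.Integer(8)]

res: set[sympy.Symbol] = set()
for s in candidate_symbols:
    simplified_s = sympy.simplify(s)
    # free_symbols pulls s0 out of 2*s0 and drops the constant 8 entirely.
    res.update(simplified_s.free_symbols)

# Deterministic, name-sorted order, as in the patch.
print(sorted(res, key=operator.attrgetter("name")))  # [s0, s1]
```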
