Skip to content

Commit a4da1d4

Browse files
BoyuanFengpytorchmergebot
authored andcommitted
[Graph Partition] support standalone_compile (pytorch#154698)
For graph partition, `write_get_raw_stream_header_once` is done once so the autotune code may not have the header. This PR additionally calls `write_get_raw_stream_header` in `codegen_device_guard_enter` before `get_raw_stream` is used. Pull Request resolved: pytorch#154698 Approved by: https://github.com/oulgen
1 parent d91c85b commit a4da1d4

File tree

3 files changed

+26
-8
lines changed

3 files changed

+26
-8
lines changed

test/inductor/test_codecache.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,7 +1565,10 @@ def backend(gm_, args_, **kwargs_):
15651565
@parametrize("device", (GPU_TYPE, "cpu"))
15661566
@parametrize("format", ("binary", "unpacked"))
15671567
@parametrize("dynamic", (False, True))
1568-
def test_basic(self, device: str, format: str, dynamic: bool) -> None:
1568+
@parametrize("graph_partition", (False, True))
1569+
def test_basic(
1570+
self, device: str, format: str, dynamic: bool, graph_partition: bool
1571+
) -> None:
15691572
if device == GPU_TYPE and not HAS_GPU:
15701573
raise unittest.SkipTest(f"requires {GPU_TYPE}")
15711574

@@ -1580,7 +1583,9 @@ def f(x):
15801583

15811584
eager_out = f(x)
15821585

1583-
with tempfile.TemporaryDirectory() as temp_dir:
1586+
with tempfile.TemporaryDirectory() as temp_dir, config.patch(
1587+
graph_partition=graph_partition
1588+
):
15841589
path = (
15851590
temp_dir
15861591
if format == "unpacked"

torch/_inductor/codegen/wrapper.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
DelayReplaceLine,
4949
get_benchmark_name,
5050
IndentedBuffer,
51+
is_codegen_graph_partition_subgraph,
5152
LineContext,
5253
set_kernel_post_grad_provenance_tracing,
5354
sympy_product,
@@ -891,10 +892,7 @@ def __init__(self):
891892

892893
self.write_header()
893894

894-
if not (
895-
isinstance(self, SubgraphPythonWrapperCodegen)
896-
and self.partition_signatures is not None
897-
):
895+
if not is_codegen_graph_partition_subgraph(self):
898896
# See [Note: Removed Graph Partition Arguments]
899897
self.write_prefix()
900898

@@ -1057,8 +1055,7 @@ def write_triton_header_once(self) -> None:
10571055
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
10581056
)
10591057

1060-
@cache_on_self
1061-
def write_get_raw_stream_header_once(self) -> None:
1058+
def write_get_raw_stream_header(self) -> None:
10621059
if config.triton.autotune_at_compile_time:
10631060
self.kernel_autotune_calls.writeline(
10641061
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
@@ -1068,6 +1065,10 @@ def write_get_raw_stream_header_once(self) -> None:
10681065
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
10691066
)
10701067

1068+
@cache_on_self
1069+
def write_get_raw_stream_header_once(self) -> None:
1070+
self.write_get_raw_stream_header()
1071+
10711072
def add_meta_once(self, meta: TritonMetaParams) -> str:
10721073
meta = repr(meta)
10731074
if meta not in self._metas:
@@ -1248,6 +1249,9 @@ def codegen_device_guard_enter(self, device_idx: int) -> None:
12481249
self.kernel_autotune_calls.writeline(
12491250
V.graph.device_ops.set_device(device_idx)
12501251
)
1252+
if is_codegen_graph_partition_subgraph(self):
1253+
# Need get_raw_stream for subgraph
1254+
self.write_get_raw_stream_header()
12511255
self.kernel_autotune_calls.writeline(
12521256
f"stream{device_idx} = get_raw_stream({device_idx})"
12531257
)

torch/_inductor/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3101,3 +3101,12 @@ def get_ld_library_path() -> str:
31013101
path = os.pathsep.join([lib_path, path]) if path else lib_path
31023102

31033103
return path
3104+
3105+
3106+
def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
3107+
from torch._inductor.codegen.wrapper import SubgraphPythonWrapperCodegen
3108+
3109+
return (
3110+
isinstance(wrapper, SubgraphPythonWrapperCodegen)
3111+
and wrapper.partition_signatures is not None
3112+
)

0 commit comments

Comments
 (0)