Commit a3f9e04

yiming0416 authored and pytorchmergebot committed
[export] Make aoti_call_delegate hop traceable (pytorch#148804)
Summary: The `aoti_call_delegate` hop now uses a stateless `original_gm` for tracing with fake tensors and the OSS AOTI Runner for running with real tensors.

Differential Revision: D70738393

Pull Request resolved: pytorch#148804
Approved by: https://github.com/SherlockNoMad
1 parent 51da241 commit a3f9e04
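
The summary hinges on `original_gm` being "stateless". As a minimal, hypothetical sketch of that idea (names such as `_linear_fn` and `stateless_gm` are illustrative and not part of this PR): a stateless GraphModule receives its weights as explicit inputs instead of holding them as get_attr attributes, which is what lets the hop re-run it under FakeTensorMode for shape propagation.

    import torch

    def _linear_fn(weight, bias, x):
        # weights arrive as ordinary arguments, so the traced module holds no state
        return torch.nn.functional.linear(x, weight, bias)

    # symbolic_trace yields a GraphModule whose weights are placeholder nodes,
    # analogous to the stateless original_gm passed to aoti_call_delegate
    stateless_gm = torch.fx.symbolic_trace(_linear_fn)
    out = stateless_gm(torch.randn(4, 3), torch.randn(4), torch.randn(2, 3))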

File tree: 3 files changed (+82 -27 lines)

torch/_export/passes/lift_constants_pass.py

Lines changed: 2 additions & 1 deletion
@@ -178,6 +178,8 @@ def lift_constants_pass(
             continue
         if "LoweredBackendModule" in type(constant_val).__name__:
             continue
+        if "AOTInductorRunnerWrapper" in type(constant_val).__name__:
+            continue
         if isinstance(constant_val, torch.utils._pytree.TreeSpec):
             continue
@@ -237,7 +239,6 @@ def lift_constants_pass(
             constant_name = f"lifted_tensor_{num_tensor_constants}"
             constant_fqn = get_constant_fqn(node, constant_name)
             num_tensor_constants += 1
-
         else:
             raise SpecViolationError(
                 f"getattr node {node} referencing unsupported type {type(constant_val)}"

torch/_export/verifier.py

Lines changed: 2 additions & 0 deletions
@@ -271,6 +271,8 @@ def _is_type(name, ty):
             elif type(attr).__name__ == "AOTInductorEPModule":
                 continue
 
+            elif type(attr).__name__ == "AOTInductorRunnerWrapper":
+                continue
 
             if not isinstance(attr, _allowed_getattr_types(is_toplevel_gm)):
                 raise SpecViolationError(
Lines changed: 78 additions & 26 deletions
@@ -1,44 +1,43 @@
+# mypy: allow-untyped-defs
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# pyre-strict
-
 from __future__ import annotations
 
 import torch
 import torch.utils._pytree as pytree
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch.fx.experimental.proxy_tensor import (
+    disable_proxy_modes_tracing,
+    ProxyTorchDispatchMode,
+    track_tensor_tree,
+)
 
 
-AOTI_LOWERED_MODULE = "AOTInductorEPModule"
+AOTI_LOWERED_MODULE = "AOTInductorEPModule/AOTInductorRunnerWrapper"
 
 
 class AOTICallDelegate(HigherOrderOperator):
     """aoti_call_delegate is a HOP for calling AOTInductor lowered submodule in ExportedProgram.
 
     It has the following signature:
     aoti_call_delegate(
-        lowered_module: AOTInductorEPModule,
+        lowered_module: Union[AOTInductorEPModule, AOTInductorRunnerWrapper]
         original_gm:fx.GraphModule,
         weight_args: List[Tensor],
        input_args: List[Tensor],
     ) -> outputs: List[Tensor]
 
     where,
     - lowered_module is the AOTInductor lowered submodule, backed by compiled .so file, supporting real tensor inputs
-    - original_gm is the original GraphModule before lowering, allowing FakeTensor propagation
+    - original_gm is the stateless version of the original GraphModule before lowering, allowing FakeTensor propagation
     - weight_args is the list of weights in original GraphModule, including parameters and buffers
     - input_args is the list of flatten inputs
-
-    NOTE: aoti_call_delegate doesn't support retracing yet, as original_gm is currently stateful with weight as get_attr nodes.
-    This will fail functionalization during retrace. When we move AOTI to accept stateless GraphModule, we can enable retracing.
-
-    When serialization, we have special hanlding for aoti_call_delegate, as AOTInductorEPModule is not serializable
-    and stateful original_gm is failing the verifier.
     """
 
     def __init__(self) -> None:
@@ -62,7 +61,6 @@ def __call__(
 
 
 @aoti_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
-# pyre-ignore
 def call_delegate_cpu(
     lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
     original_gm: torch.fx.GraphModule,
@@ -77,27 +75,60 @@ def call_delegate_cpu(
     new_args = pytree.tree_map_only(
         tuple(map_types.keys()),
         lambda a: map_types[type(a)](a),
-        input_args,
+        weight_args + input_args,
         lambda a: isinstance(a, tuple(map_types.keys())),
     )
-
-    has_fake_input_args = any(isinstance(arg, FakeTensor) for arg in new_args)
-    has_fake_params = any(
-        isinstance(param, FakeTensor) for param in original_gm.parameters()
-    )
-    has_fake_buffers = any(
-        isinstance(buffer, FakeTensor) for buffer in original_gm.buffers()
+    has_fake_args = any(isinstance(arg, FakeTensor) for arg in new_args)
+    if has_fake_args:
+        # use stateless original_gm for tracing with fake tensors
+        fake_out = original_gm(*new_args)
+        return fake_out
+    else:
+        # use AOTI Runner for real tensors
+        new_input_args = new_args[len(weight_args) :]
+        if type(lowered_module).__name__ == "AOTInductorRunnerWrapper":
+            return lowered_module(*new_input_args)  # type: ignore[misc]
+        elif type(lowered_module).__name__ == "AOTInductorEPModule":
+            return lowered_module(new_input_args)  # type: ignore[misc]
+        else:
+            raise RuntimeError(
+                f"Unexpected lowered_module type: {type(lowered_module)}."
+            )
+
+
+def trace_aoti_call_delegate(
+    proxy_mode, func_overload, lowered_module, original_gm, weight_args, input_args
+):
+    proxy_mode.tracer.root.register_module("lowered_module", lowered_module)
+    proxy_mode.tracer.root.register_module("original_gm", original_gm)
+
+    node_args = (lowered_module, original_gm, weight_args, input_args)
+    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
+
+    out_proxy = proxy_mode.tracer.create_proxy(
+        "call_function", func_overload, proxy_args, {}, name="aoti_call_delegate"
     )
+    with disable_proxy_modes_tracing():
+        out = call_delegate_cpu(lowered_module, original_gm, weight_args, input_args)
 
-    if has_fake_input_args or has_fake_params or has_fake_buffers:
-        # aoti lowered module doesn't support fake tensor
-        return original_gm(*new_args)
-    else:
-        return lowered_module(new_args)  # type: ignore[misc]
+    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
+
+
+@aoti_call_delegate.py_impl(ProxyTorchDispatchMode)
+def call_delegate_proxy_torch_dispatch_mode(
+    mode: ProxyTorchDispatchMode,
+    lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
+    original_gm: torch.fx.GraphModule,
+    weight_args: list[torch.Tensor],
+    input_args: list[torch.Tensor],
+):
+    res = trace_aoti_call_delegate(
+        mode, aoti_call_delegate, lowered_module, original_gm, weight_args, input_args
    )
+    return res
 
 
 @aoti_call_delegate.py_impl(FakeTensorMode)
-# pyre-ignore
 def call_delegate_fake_tensor_mode(
     mode: FakeTensorMode,
     lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
@@ -107,3 +138,24 @@ def call_delegate_fake_tensor_mode(
 ) -> list[torch.Tensor]:
     with mode:
         return call_delegate_cpu(lowered_module, original_gm, weight_args, input_args)
+
+
+@aoti_call_delegate.py_functionalize_impl
+def call_delegate_functionalize(
+    ctx,
+    lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
+    original_gm: torch.fx.GraphModule,
+    weight_args: list[torch.Tensor],
+    input_args: list[torch.Tensor],
+):
+    unwrapped_weight_args = tuple(
+        ctx.unwrap_tensors(weight_arg) for weight_arg in weight_args
+    )
+    unwrapped_input_args = tuple(
+        ctx.unwrap_tensors(input_arg) for input_arg in input_args
+    )
+    with ctx.redispatch_to_next():
+        res = aoti_call_delegate(
+            lowered_module, original_gm, unwrapped_weight_args, unwrapped_input_args  # type: ignore[arg-type]
+        )
+    return ctx.wrap_tensors(res)
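
For context, here is a hedged usage sketch of the hop documented in the diff above. It assumes `lowered_module`, `original_gm`, `weights`, and `inputs` were produced by an earlier AOTInductor lowering and are already in scope; it is an illustration rather than an API guarantee.

    from torch._subclasses.fake_tensor import FakeTensorMode

    # Real tensors: call_delegate_cpu routes to the AOTI runner (compiled .so).
    real_out = aoti_call_delegate(lowered_module, original_gm, weights, inputs)

    # Fake tensors: call_delegate_cpu detects FakeTensor args and re-runs the
    # stateless original_gm, so only metadata (shapes/dtypes) is propagated.
    fake_mode = FakeTensorMode()
    fake_weights = [fake_mode.from_tensor(w) for w in weights]
    fake_inputs = [fake_mode.from_tensor(x) for x in inputs]
    fake_out = aoti_call_delegate(lowered_module, original_gm, fake_weights, fake_inputs)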
