+# mypy: allow-untyped-defs
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# pyre-strict
-
 from __future__ import annotations
 
 import torch
 import torch.utils._pytree as pytree
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch.fx.experimental.proxy_tensor import (
+    disable_proxy_modes_tracing,
+    ProxyTorchDispatchMode,
+    track_tensor_tree,
+)
 
 
-AOTI_LOWERED_MODULE = "AOTInductorEPModule"
+AOTI_LOWERED_MODULE = "AOTInductorEPModule/AOTInductorRunnerWrapper"
 class AOTICallDelegate(HigherOrderOperator):
     """aoti_call_delegate is a HOP for calling an AOTInductor-lowered submodule in an ExportedProgram.
 
     It has the following signature:
     aoti_call_delegate(
-        lowered_module: AOTInductorEPModule,
+        lowered_module: Union[AOTInductorEPModule, AOTInductorRunnerWrapper],
         original_gm: fx.GraphModule,
         weight_args: List[Tensor],
         input_args: List[Tensor],
     ) -> outputs: List[Tensor]
 
     where,
     - lowered_module is the AOTInductor-lowered submodule, backed by a compiled .so file, supporting real tensor inputs
-    - original_gm is the original GraphModule before lowering, allowing FakeTensor propagation
+    - original_gm is the stateless version of the original GraphModule before lowering, allowing FakeTensor propagation
     - weight_args is the list of weights in the original GraphModule, including parameters and buffers
     - input_args is the list of flattened inputs
-
-    NOTE: aoti_call_delegate doesn't support retracing yet, as original_gm is currently stateful with weights as get_attr nodes.
-    This will fail functionalization during retrace. When we move AOTI to accept a stateless GraphModule, we can enable retracing.
-
-    During serialization, we have special handling for aoti_call_delegate, as AOTInductorEPModule is not serializable
-    and the stateful original_gm fails the verifier.
     """
 
     def __init__(self) -> None:
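As a usage illustration only, here is a minimal sketch of how a call site for this HOP is shaped, following the signature in the docstring above. `lowered_module`, `gm`, `weights`, and `x` are hypothetical stand-ins produced by an AOTInductor lowering flow, and the import path assumes this file lives at `torch/_higher_order_ops/aoti_call_delegate.py`:

```python
import torch
# assumed module path for the file shown in this diff
from torch._higher_order_ops.aoti_call_delegate import aoti_call_delegate

def run_delegate(lowered_module, gm: torch.fx.GraphModule, weights, x: torch.Tensor):
    # hypothetical call site: `lowered_module` and `gm` come from lowering
    weight_args = list(weights)  # parameters and buffers, flattened
    input_args = [x]             # flattened runtime inputs
    # returns a list of output tensors; real inputs hit the compiled .so,
    # fake inputs are propagated through the stateless `gm`
    return aoti_call_delegate(lowered_module, gm, weight_args, input_args)
```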
@@ -62,7 +61,6 @@ def __call__(
 
 
 @aoti_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
-# pyre-ignore
 def call_delegate_cpu(
     lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
     original_gm: torch.fx.GraphModule,
@@ -77,27 +75,60 @@ def call_delegate_cpu(
     new_args = pytree.tree_map_only(
         tuple(map_types.keys()),
         lambda a: map_types[type(a)](a),
-        input_args,
+        weight_args + input_args,
         lambda a: isinstance(a, tuple(map_types.keys())),
     )
-
-    has_fake_input_args = any(isinstance(arg, FakeTensor) for arg in new_args)
-    has_fake_params = any(
-        isinstance(param, FakeTensor) for param in original_gm.parameters()
-    )
-    has_fake_buffers = any(
-        isinstance(buffer, FakeTensor) for buffer in original_gm.buffers()
+    has_fake_args = any(isinstance(arg, FakeTensor) for arg in new_args)
+    if has_fake_args:
+        # use the stateless original_gm for tracing with fake tensors
+        fake_out = original_gm(*new_args)
+        return fake_out
+    else:
+        # use the AOTI runner for real tensors
+        new_input_args = new_args[len(weight_args) :]
+        if type(lowered_module).__name__ == "AOTInductorRunnerWrapper":
+            return lowered_module(*new_input_args)  # type: ignore[misc]
+        elif type(lowered_module).__name__ == "AOTInductorEPModule":
+            return lowered_module(new_input_args)  # type: ignore[misc]
+        else:
+            raise RuntimeError(
+                f"Unexpected lowered_module type: {type(lowered_module)}."
+            )
+
+
+def trace_aoti_call_delegate(
+    proxy_mode, func_overload, lowered_module, original_gm, weight_args, input_args
+):
+    # expose the delegate and the original graph on the tracer root so the
+    # emitted node can reference them as submodules
+    proxy_mode.tracer.root.register_module("lowered_module", lowered_module)
+    proxy_mode.tracer.root.register_module("original_gm", original_gm)
+
+    node_args = (lowered_module, original_gm, weight_args, input_args)
+    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
+
+    # record the HOP as a single call_function node in the traced graph
+    out_proxy = proxy_mode.tracer.create_proxy(
+        "call_function", func_overload, proxy_args, {}, name="aoti_call_delegate"
     )
+    # compute example outputs outside of tracing so they can be tracked
+    with disable_proxy_modes_tracing():
+        out = call_delegate_cpu(lowered_module, original_gm, weight_args, input_args)
 
-    if has_fake_input_args or has_fake_params or has_fake_buffers:
-        # aoti lowered module doesn't support fake tensor
-        return original_gm(*new_args)
-    else:
-        return lowered_module(new_args)  # type: ignore[misc]
+    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
+
+
+@aoti_call_delegate.py_impl(ProxyTorchDispatchMode)
+def call_delegate_proxy_torch_dispatch_mode(
+    mode: ProxyTorchDispatchMode,
+    lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
+    original_gm: torch.fx.GraphModule,
+    weight_args: list[torch.Tensor],
+    input_args: list[torch.Tensor],
+):
+    res = trace_aoti_call_delegate(
+        mode, aoti_call_delegate, lowered_module, original_gm, weight_args, input_args
+    )
+    return res
 
 
 @aoti_call_delegate.py_impl(FakeTensorMode)
-# pyre-ignore
 def call_delegate_fake_tensor_mode(
     mode: FakeTensorMode,
     lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
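The fake/real split added to `call_delegate_cpu` hinges on one predicate: a single FakeTensor anywhere in the flattened weights-plus-inputs routes the call through the stateless `original_gm`, while all-real tensors go to the compiled AOTI runner. A small self-contained sketch of that predicate (not code from this diff):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

def picks_fake_path(args) -> bool:
    # the same test as `has_fake_args` in call_delegate_cpu above
    return any(isinstance(a, FakeTensor) for a in args)

real_args = [torch.randn(2, 2), torch.randn(2, 2)]
assert not picks_fake_path(real_args)  # all real -> compiled AOTI runner path

with FakeTensorMode() as mode:
    fake_args = [mode.from_tensor(t) for t in real_args]

assert picks_fake_path(fake_args)  # any fake -> stateless original_gm path
```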
@@ -107,3 +138,24 @@ def call_delegate_fake_tensor_mode(
 ) -> list[torch.Tensor]:
     with mode:
         return call_delegate_cpu(lowered_module, original_gm, weight_args, input_args)
+
+
+@aoti_call_delegate.py_functionalize_impl
+def call_delegate_functionalize(
+    ctx,
+    lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
+    original_gm: torch.fx.GraphModule,
+    weight_args: list[torch.Tensor],
+    input_args: list[torch.Tensor],
+):
+    unwrapped_weight_args = tuple(
+        ctx.unwrap_tensors(weight_arg) for weight_arg in weight_args
+    )
+    unwrapped_input_args = tuple(
+        ctx.unwrap_tensors(input_arg) for input_arg in input_args
+    )
+    with ctx.redispatch_to_next():
+        res = aoti_call_delegate(
+            lowered_module, original_gm, unwrapped_weight_args, unwrapped_input_args  # type: ignore[arg-type]
+        )
+    return ctx.wrap_tensors(res)
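The `ProxyTorchDispatchMode` and functionalization implementations above are what make the HOP retraceable, which is why the old docstring NOTE about retracing being unsupported is deleted in this diff. A hedged sketch of a retrace, where `delegate_module` is a hypothetical module whose forward reaches `aoti_call_delegate` (for example, an ExportedProgram module containing an AOTInductor-lowered submodule):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# `delegate_module` is a hypothetical stand-in, not defined in this diff
def f(x):
    return delegate_module(x)

# under fake tracing, the HOP dispatches to the stateless original_gm,
# so no compiled .so is executed during the retrace
gm = make_fx(f, tracing_mode="fake")(torch.randn(2, 2))
# the retraced graph keeps aoti_call_delegate as a single call_function node
print(gm.graph)
```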