Commit f49833d

angelayi authored and pytorchmergebot committed
[hoo] Invoke subgraph + effect (pytorch#167231)
This PR adds support for effectful ops within invoke_subgraph.

* Most of the logic is in `invoke_subgraph.py_functionalize_impl`.
* In the functionalization metadata-collection phase, we note the tokens before going further down the dispatcher, and note them again after coming back from the dispatcher. If the invoke_subgraph subgraph contains effectful nodes, either the number of effects or the token used for an effect will have changed.
* We store this effect difference in the `InvokeSubgraphCache`, where the key is the subgraph identifier and the value is the effect. For now we only support one effect within a subgraph.
* During the tracing part of AOTAutograd, we then wrap the subgraph so that it takes in and outputs a token.

Before:
```
def forward(self, x):
    repeated_subgraph0 = self.repeated_subgraph0
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', x)
    return invoke_subgraph

def repeated_subgraph(self, x):
    record_memory = torch.ops.mylib.record_memory.default("forward", "N")
    add = torch.ops.aten.add(x, x)
    return add
```

After:
```
def forward(self, token, x):
    repeated_subgraph0 = self.repeated_subgraph0
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x)
    getitem = invoke_subgraph[0]  # output token
    getitem_1 = invoke_subgraph[1]
    return (getitem, getitem_1)

def repeated_subgraph(self, token, x):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops.mylib.record_memory.default, 'forward', 'N')
    getitem = with_effects[0]  # output token
    add = torch.ops.aten.add(x, x)
    return (getitem, add)
```

* Finally, `_remove_effect_tokens` contains additional logic to remove the effect tokens from the invoke_subgraph subgraph.

Differential Revision: [D87392741](https://our.internmc.facebook.com/intern/diff/D87392741)

Pull Request resolved: pytorch#167231
Approved by: https://github.com/anijain2305
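For orientation, here is a minimal sketch of how this feature is exercised from user code. It is condensed from the test added in this PR (`test_export_invoke_subgraph`); the `_EffectType` import path is an assumption based on the existing with_effects tests, and the toy module has no parameters.

```python
import torch
from torch._higher_order_ops.effects import _EffectType  # import path assumed

# A side-effectful custom op, registered as ORDERED so export threads a token through it.
@torch.library.custom_op("mylib::record_memory", mutates_args=())
def record_memory(prefix: str, module_name: str) -> None:
    print(f"[{prefix}] {module_name}")

@record_memory.register_fake
def _(prefix, module_name):
    return

record_memory.register_effect(_EffectType.ORDERED)

class N(torch.nn.Module):
    # nested_compile_region marks this forward as a repeated region, so export
    # outlines it into invoke_subgraph calls.
    @torch.compiler.nested_compile_region
    def forward(self, x):
        torch.ops.mylib.record_memory("forward", "N")
        return x + x

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mods = torch.nn.ModuleList(N() for _ in range(2))

    def forward(self, x):
        for m in self.mods:
            x = m(x)
        return (x,)

ep = torch.export.export(M(), (torch.randn(4),)).run_decompositions()
# The invoke_subgraph calls now thread an effect token, and the shared subgraph
# wraps the effectful op in with_effects.
print(ep.graph_module.code)
```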
1 parent 28c7602 commit f49833d

File tree

10 files changed, +370 -135 lines changed


test/export/test_converter.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1405,7 +1405,7 @@ def func3(x): # noqa: F841
         )
     # qnnpack not supported on s390x
     @xfailIfS390X
-    def test_ts2ep_convert_quantized_model(self):
+    def test_ts2ep_convert_quantized_model1(self):
         class Standalone(torch.nn.Module):
             def __init__(self):
                 super().__init__()
```

test/export/test_passes.py

Lines changed: 6 additions & 9 deletions
```diff
@@ -640,16 +640,13 @@ def forward(self, x):
         self.assertExpectedInline(
             without_token_ep.graph_module.code.strip(),
             """\
-def forward(self, token, obj_attr, x):
-    with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, foo = obj_attr, x = x); token = x = None
-    getitem = with_effects[0]
-    getitem_1 = with_effects[1]
-    getitem_2 = with_effects[2]; with_effects = None
+def forward(self, obj_attr, x):
+    takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(foo = obj_attr, x = x); x = None
+    getitem_1 = takes_foo_tuple_return_default[0]
+    getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None
     add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
-    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, foo = obj_attr, x = add); getitem = obj_attr = add = None
-    getitem_3 = with_effects_1[0]
-    getitem_4 = with_effects_1[1]; with_effects_1 = None
-    return (getitem_3, getitem_4)""", # noqa: B950
+    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(foo = obj_attr, x = add); obj_attr = add = None
+    return (takes_foo_default,)""", # noqa: B950
         )

     def test_fakify_script_objects(self):
```

test/export/test_torchbind.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -461,9 +461,9 @@ def forward(self, x):
     x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
     attr = self.attr
     _guards_fn = self._guards_fn(x); _guards_fn = None
-    takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
-    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None
-    add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
+    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
+    takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default); attr = takes_foo_default = None
+    add = torch.ops.aten.add.Tensor(x, takes_foo_default_1); x = takes_foo_default_1 = None
     return pytree.tree_unflatten((add,), self._out_spec)""", # noqa: B950
         )
         self.assertExpectedInline(
@@ -1087,10 +1087,12 @@ def forward(self, token, tq, x):
             str(ep.graph_module.graph).strip(),
             """\
 graph():
+    %token : [num_users=1] = placeholder[target=token]
     %tq : [num_users=2] = placeholder[target=tq]
     %x : [num_users=1] = placeholder[target=x]
-    %queue_push_default : [num_users=0] = call_function[target=torch.ops._TorchScriptTesting.queue_push.default](args = (%tq, %x), kwargs = {})
-    return (tq,)""", # noqa: B950
+    %with_effects : [num_users=1] = call_function[target=torch.ops.higher_order.with_effects](args = (%token, _TorchScriptTesting.queue_push.default, %tq, %x), kwargs = {})
+    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {})
+    return (getitem, tq)""", # noqa: B950
         )

     def test_deepcopy(self):
```

test/higher_order_ops/test_with_effects.py

Lines changed: 98 additions & 0 deletions
```diff
@@ -870,6 +870,104 @@ def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
         finally:
             handle.destroy()

+    @unittest.skipIf(not TEST_CUDA, "triton")
+    def test_export_invoke_subgraph(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            recorded_list = []
+
+            @torch.library.custom_op("mylib::record_memory", mutates_args=())
+            def record_memory(prefix: str, module_name: str) -> None:
+                torch.cuda.synchronize()
+                mem_alloc = torch.cuda.memory_allocated() / 1024**2
+                mem_reserved = torch.cuda.memory_reserved() / 1024**2
+                memory_str = f"[{prefix}] {module_name}: allocated={mem_alloc:.2f} MB, reserved={mem_reserved:.2f} MB"
+                recorded_list.append(memory_str)
+
+            @record_memory.register_fake
+            def record_memory_fake(prefix, module_name):
+                return
+
+            record_memory.register_effect(_EffectType.ORDERED)
+
+            class N(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.linear1 = torch.nn.Linear(1024, 1024)
+                    self.relu = torch.nn.ReLU()
+                    self.linear2 = torch.nn.Linear(1024, 1024)
+
+                @torch.compiler.nested_compile_region
+                def forward(self, x):
+                    torch.ops.mylib.record_memory("forward", "N")
+                    x = self.linear1(x)
+                    x = self.relu(x)
+                    x = self.linear2(x)
+                    return x
+
+            class M(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.mod_list = torch.nn.ModuleList(N() for _ in range(3))
+
+                def forward(self, x):
+                    for m in self.mod_list:
+                        x = m(x)
+                    torch.ops.mylib.record_memory("forward", "N")
+                    return (x,)
+
+            model = M().to("cuda")
+            torch.cuda.reset_peak_memory_stats()
+
+            x = torch.randn(32, 1024, requires_grad=True, device="cuda")
+
+            ep = torch.export.export(model, (x,))
+            ep = ep.run_decompositions()
+            self.assertEqual(len(list(ep.graph_module.named_modules())), 2)
+
+            self.assertExpectedInline(
+                ep.graph_module.code.strip(),
+                """\
+def forward(self, token, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias, x):
+    repeated_subgraph0 = self.repeated_subgraph0
+    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias); repeated_subgraph0 = token = x = p_mod_list_0_linear1_weight = p_mod_list_0_linear1_bias = p_mod_list_0_linear2_weight = p_mod_list_0_linear2_bias = None
+    getitem = invoke_subgraph[0]
+    getitem_1 = invoke_subgraph[1]; invoke_subgraph = None
+    repeated_subgraph0_1 = self.repeated_subgraph0
+    invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', getitem, getitem_1, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias); repeated_subgraph0_1 = getitem = getitem_1 = p_mod_list_1_linear1_weight = p_mod_list_1_linear1_bias = p_mod_list_1_linear2_weight = p_mod_list_1_linear2_bias = None
+    getitem_2 = invoke_subgraph_1[0]
+    getitem_3 = invoke_subgraph_1[1]; invoke_subgraph_1 = None
+    repeated_subgraph0_2 = self.repeated_subgraph0
+    invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_2, 'subgraph_0', getitem_2, getitem_3, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias); repeated_subgraph0_2 = getitem_2 = getitem_3 = p_mod_list_2_linear1_weight = p_mod_list_2_linear1_bias = p_mod_list_2_linear2_weight = p_mod_list_2_linear2_bias = None
+    getitem_4 = invoke_subgraph_2[0]
+    getitem_5 = invoke_subgraph_2[1]; invoke_subgraph_2 = None
+    with_effects = torch.ops.higher_order.with_effects(getitem_4, torch.ops.mylib.record_memory.default, 'forward', 'N'); getitem_4 = None
+    getitem_6 = with_effects[0]; with_effects = None
+    return (getitem_6, getitem_5)""",
+            )
+
+            self.assertExpectedInline(
+                ep.graph_module.repeated_subgraph0.code.strip(),
+                """\
+def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
+    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.mylib.record_memory.default, 'forward', 'N'); arg0_1 = None
+    getitem = with_effects[0]; with_effects = None
+    permute = torch.ops.aten.permute.default(arg2_1, [1, 0]); arg2_1 = None
+    addmm = torch.ops.aten.addmm.default(arg3_1, arg1_1, permute); arg3_1 = arg1_1 = permute = None
+    relu = torch.ops.aten.relu.default(addmm); addmm = None
+    permute_1 = torch.ops.aten.permute.default(arg4_1, [1, 0]); arg4_1 = None
+    addmm_1 = torch.ops.aten.addmm.default(arg5_1, relu, permute_1); arg5_1 = relu = permute_1 = None
+    return (getitem, addmm_1)""",
+            )
+
+            recorded_list.clear()
+            # TODO: seems like invoke_subgraph's py_autograd impl calls the subgraph
+            # eagerly twice. Once for get_output_metadata and then once for
+            # InvokeSubgraphAutogradOp. This causes record_memory to be called twice.
+            with torch.no_grad():
+                out2 = ep.module()(x)
+            self.assertEqual(len(recorded_list), 4)
+            self.assertTrue(torch.allclose(model(x)[0], out2[0]))
+

 if __name__ == "__main__":
     run_tests()
```

torch/_guards.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -713,6 +713,9 @@ def __init__(self) -> None:
         self.lazy_bwd_cache: dict[
             str, dict[tuple[object], tuple[torch.fx.GraphModule, int]]
         ] = defaultdict(dict)
+        self.effects_cache: dict[
+            str, set
+        ] = {}  # Maps identifier -> set of effect types

     def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None:
         self.dynamo_installed_submodules[fn_id].append(identifier)
@@ -751,6 +754,21 @@ def get_lazy_bwd_entry(

         return self.lazy_bwd_cache[identifier].get(tangent_metadata, (None, None))

+    def add_effects(self, identifier: str, effects: set) -> None:
+        """Store the effect types for a given invoke_subgraph identifier."""
+        if prev_effects := self.effects_cache.get(identifier, None):
+            assert effects == prev_effects, (
+                "Different number of effects were found for invoke_subgraph "
+                f"call with identifier {identifier}. \n"
+                f"Previously we had the following effects: {prev_effects}.\n"
+                f"But now we have: {effects}."
+            )
+        self.effects_cache[identifier] = effects
+
+    def get_effects(self, identifier: str) -> Optional[set]:
+        """Retrieve the effect types for a given invoke_subgraph identifier."""
+        return self.effects_cache.get(identifier, None)
+

 class HopDispatchSetCache:
     def __init__(self) -> None:
```
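A small, hypothetical illustration of the bookkeeping these two methods implement: the same identifier must always report the same effect set across re-traces. Constructing the cache directly and using `_EffectType.ORDERED` as the stored value is for demonstration only; the import path for `_EffectType` is assumed.

```python
from torch._guards import InvokeSubgraphCache
from torch._higher_order_ops.effects import _EffectType  # assumed import path

cache = InvokeSubgraphCache()

# First trace of 'subgraph_0' discovers one ordered effect and records it.
cache.add_effects("subgraph_0", {_EffectType.ORDERED})
assert cache.get_effects("subgraph_0") == {_EffectType.ORDERED}

# Re-tracing the same identifier with the same effects is fine...
cache.add_effects("subgraph_0", {_EffectType.ORDERED})
# ...but reporting a different effect set for the same identifier would trip
# the assert inside add_effects.
```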

torch/_higher_order_ops/invoke_subgraph.py

Lines changed: 50 additions & 0 deletions
```diff
@@ -80,6 +80,7 @@ def __call__(
         assert all(
             isinstance(o, (torch.Tensor, int, torch.SymInt, torch.Generator))
             for o in operands
+            if o is not None
         ), (
             f"invoke_subgraph operands must be a list of tensors/ints/SymInts/Generator {operands}"
         )
@@ -562,7 +563,34 @@ def _(ctx, subgraph, identifier, *operands):
         do_auto_functionalize_v2,
     )

+    # (in the functionalization metadata phase) Capture tokens before
+    tokens_before = dict(ctx.mode._tokens)
+
+    # Check if this subgraph has effects stored in the cache
+    invoke_subgraph_cache = get_invoke_subgraph_cache()
+    effects = None
+    if invoke_subgraph_cache:
+        effects = invoke_subgraph_cache.get_effects(identifier)
+
+    if effects:
+        assert len(effects) == 1, "Multiple effects within a subgraph NYI"
+        tokens = ctx.mode._tokens
+        effects = next(iter(effects))
+        token_input = tokens[effects]
+
+        operands = (token_input, *operands)
+
+        def wrap_subgraph(subgraph):
+            def wrapped_subgraph(token, *args):
+                res = subgraph(*args)
+                return ctx.unwrap_tensors(ctx.mode._tokens[effects]), *res
+
+            return wrapped_subgraph
+
+        subgraph = wrap_subgraph(subgraph)
+
     unwrapped_operands = ctx.unwrap_tensors(operands)
+
     hop_instance = HopInstance.create(invoke_subgraph, subgraph, identifier, *operands)
     if can_auto_functionalize(hop_instance):
         # NOTE: [auto_functionalize x invoke_subgraph caching]
@@ -587,6 +615,28 @@ def _(ctx, subgraph, identifier, *operands):
     # of invoke_subgraph ops if input aliasing/mutation is detected.
     functionalized_subgraph = FunctionalizeCtxWrapper(ctx, subgraph)
     out = invoke_subgraph(functionalized_subgraph, identifier, *unwrapped_operands)
+
+    if effects:
+        (new_token, *out) = out
+        ctx.mode._tokens[effects] = new_token
+
+    # (in the functionalization metadata phase) Capture tokens after and see if
+    # there are any differences (there are new effects or the token value for an
+    # effect type has changed)
+    tokens_after = dict(ctx.mode._tokens)
+    discovered_effects = set()
+    for effect_type, token in tokens_after.items():
+        if effect_type not in tokens_before or tokens_before[effect_type] is not token:
+            discovered_effects.add(effect_type)
+
+    if discovered_effects:
+        assert ctx.mode._allow_token_discovery, (
+            f"Number of tokens changed by {len(discovered_effects)} when tracing subgraph {subgraph}."
+        )
+        # Store discovered effects in the cache by identifier
+        if invoke_subgraph_cache:
+            invoke_subgraph_cache.add_effects(identifier, discovered_effects)
+
     return ctx.wrap_tensors(out)

```
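Conceptually, the functionalization change above does two things: it threads a token through the wrapped subgraph, and it detects effects by diffing the token map before and after tracing the subgraph. A framework-free sketch of the detection idea (all names here are hypothetical, not PyTorch APIs):

```python
def diff_tokens(tokens_before, tokens_after):
    # An effect is "discovered" if its token is new, or if its token object was
    # replaced while tracing the subgraph, mirroring the tokens_before/tokens_after
    # comparison in the diff above.
    discovered = set()
    for effect_type, token in tokens_after.items():
        if effect_type not in tokens_before or tokens_before[effect_type] is not token:
            discovered.add(effect_type)
    return discovered

before = {"ORDERED": object()}
after = {"ORDERED": object()}  # token replaced -> effect discovered
assert diff_tokens(before, after) == {"ORDERED"}
assert diff_tokens(before, dict(before)) == set()  # unchanged token -> nothing discovered
```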
torch/_library/effects.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -35,13 +35,28 @@ def _set_default_effect(self) -> None:
         if namespace == "higher_order":
             return

+        # These classes do not have side effects as they just store quantization
+        # params, so we dont need to mark them as ordered
+        skip_classes = (
+            "__torch__.torch.classes.quantized.Conv2dPackedParamsBase",
+            "__torch__.torch.classes.quantized.Conv3dPackedParamsBase",
+            "__torch__.torch.classes.quantized.EmbeddingPackedParamsBase",
+            "__torch__.torch.classes.quantized.LinearPackedParamsBase",
+            "__torch__.torch.classes.xnnpack.Conv2dOpContext",
+            "__torch__.torch.classes.xnnpack.LinearOpContext",
+            "__torch__.torch.classes.xnnpack.TransposeConv2dOpContext",
+        )
+
         opname = f"{namespace}::{opname}"
         if torch._C._get_operation_overload(opname, overload) is not None:
             # Since we call this when destroying the library, sometimes the
             # schema will be gone already at that time.
             schema = torch._C._get_schema(opname, overload)
             for arg in schema.arguments:
                 if isinstance(arg.type, torch.ClassType):
+                    type_str = arg.type.str()  # pyrefly: ignore[missing-attribute]
+                    if type_str in skip_classes:
+                        continue
                     self._effect = EffectType.ORDERED
                     return
```
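The decision this change implements can be summarized outside of PyTorch: an op defaults to an ORDERED effect if it takes a script-class argument, unless that class is one of the known side-effect-free packed-params/op-context classes. A hypothetical, framework-free sketch of that rule (names and helper are illustrative only):

```python
SKIP_CLASSES = {
    "__torch__.torch.classes.quantized.LinearPackedParamsBase",
    "__torch__.torch.classes.xnnpack.Conv2dOpContext",
}

def default_effect(arg_class_names):
    # Mirror of the loop above: the first class argument that is not in the
    # skip list makes the op ORDERED; otherwise no default effect is assigned.
    for name in arg_class_names:
        if name in SKIP_CLASSES:
            continue
        return "ORDERED"
    return None

assert default_effect(["__torch__.torch.classes.quantized.LinearPackedParamsBase"]) is None
assert default_effect(["__torch__.torch.classes._TorchScriptTesting._Foo"]) == "ORDERED"
```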