You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This PR adds support for effectful ops within invoke_subgraphs.
* Most of the logic is in `invoke_subgraph.py_functionalize_impl`.
* In the functionalization metadata collection phase, we note the tokens before going further down the dispatcher, and then note the tokens after coming back from the dispatcher. If there are nodes in the invoke_subgraph subgraph that contain effects, the number of effects should change, or the tokens used for an effect should.
* We will store this effect difference in the `InvokeSubgraphCache` where the key is the identifier and value is the effect. For now we only support one effect within a subgraph.
* During the tracing part of AOTAutograd, we will then wrap the subgraph to take in and output a token.
Before:
```
def forward(self, x):
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', x)
return invoke_subgraph
def repeated_subgraph(self, x):
record_memory = torch.ops.mylib.record_memory.default("forward", "N")
add = torch.ops.aten.add(x, x)
return add
```
After:
```
def forward(self, token, x):
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x)
getitem = invoke_subgraph[0] # output token
getitem_1 = invoke_subgraph[1]
return (getitem, getitem_1)
def repeated_subgraph(self, token, x):
with_effects = torch.ops.higher_order.with_effects(token, torch.ops.mylib.record_memory.default, 'forward', 'N')
getitem = with_effects[0] # output token
add = torch.ops.aten.add(x, x)
return (getitem, add)
```
* Then there is a bunch of logic within `_remove_effect_tokens` to handle removing the effects from the invoke_subgraph subgraph
Differential Revision: [D87392741](https://our.internmc.facebook.com/intern/diff/D87392741)
Pull Request resolved: pytorch#167231
Approved by: https://github.com/anijain2305
0 commit comments