Skip to content

Commit 633dcf1

Browse files
trieuatpytorchmergebot
authored and committed
Constant folding for lifted graph (pytorch#135060)
Summary: The current implementation for lifted graphs takes a dict of [constant name: constant value]. The constant values are used to run_node and execute the constant graph to get the folded values, and then new get_attr nodes are created for the folded values. We don't have constant values for lifted graphs during model compilation on MTIA. I think it is more general to allow the constant folding pass to take only the constant names to produce the constant graph, and to represent the folded nodes as placeholders to make it consistent with the lifted graph. Additionally, this mimics the real situation on Sigmoid, where Sigmoid executes the constant graph, gets the folded values, and sets the folded values on the main graph. This diff updates the pass to work with a list of constant names. Test Plan: ``` buck run mode/opt caffe2/test:test_export -- -r split_const_gm ``` Differential Revision: D62144791 Pull Request resolved: pytorch#135060 Approved by: https://github.com/SherlockNoMad Co-authored-by: Tuan Trieu <[email protected]>
1 parent a99e8ee commit 633dcf1

File tree

3 files changed

+97
-55
lines changed

3 files changed

+97
-55
lines changed

test/export/test_export.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8089,7 +8089,9 @@ def forward(self, x):
80898089
w_transpose = torch.transpose(self.w_pre, 0, 1)
80908090
w_relu = torch.nn.functional.relu(w_transpose)
80918091
w = w_relu + self.b
8092-
return torch.matmul(x, w)
8092+
return (
8093+
torch.matmul(x, w) + self.b + torch.arange(4, dtype=torch.float16)
8094+
)
80938095

80948096
example_inputs = (torch.randn(4, 4),)
80958097
mod = Model()
@@ -8105,17 +8107,38 @@ def forward(self, x):
81058107
for n, spec in zip(placeholder_nodes, new_sig.input_specs)
81068108
if spec.target is not None
81078109
}
8108-
const_gm, _ = split_const_gm(new_gm, lifted_constants)
8110+
# [self.w_pre, self.b]
8111+
lifted_constant_names = list(lifted_constants)
8112+
lifted_constant_values = [lifted_constants[n] for n in lifted_constant_names]
8113+
const_gm, _ = split_const_gm(new_gm, False, lifted_constant_names)
81098114
counter = 0
81108115
for node in const_gm.graph.nodes:
81118116
if node.op == "call_function":
81128117
counter += 1
8113-
self.assertTrue(counter > 0)
8118+
self.assertTrue(counter == 4)
8119+
counter = 0
8120+
for n in new_gm.graph.nodes:
8121+
if n.op == "placeholder":
8122+
counter += 1
8123+
# expect 3 existing placeholders and 2 folded constant
8124+
self.assertTrue(counter == 5)
8125+
# return (self.b, folded_const, folded_const)
8126+
const_folded_value = const_gm(*lifted_constant_values)
8127+
81148128
test_input = torch.randn(4, 4)
8115-
expected = new_gm(None, None, test_input)[0]
8116-
actual = mod(test_input)
8129+
# new_gm(c_w_pre, b, x, folded_const, folded_const)
8130+
actual = new_gm(
8131+
lifted_constant_values[0],
8132+
const_folded_value[0],
8133+
test_input,
8134+
const_folded_value[1],
8135+
const_folded_value[2],
8136+
)[0]
8137+
expected = mod(test_input)
81178138
self.assertEqual(actual, expected)
8118-
const_gm, _ = split_const_gm(ep.graph_module, lifted_constants, lambda x: True)
8139+
const_gm, _ = split_const_gm(
8140+
ep.graph_module, False, lifted_constant_names, lambda x: True
8141+
)
81198142
counter = 0
81208143
for node in const_gm.graph.nodes:
81218144
if node.op == "call_function":

torch/_inductor/compile_fx.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,8 @@ def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) ->
350350

351351
def split_const_gm(
352352
gm: GraphModule,
353-
lifted_constants: Optional[Dict[str, Any]] = None,
353+
skip_constructor: bool = True,
354+
lifted_constant_names: Optional[List[str]] = None,
354355
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
355356
) -> Tuple[GraphModule, Dict[str, int]]:
356357
"""
@@ -377,9 +378,10 @@ def split_const_gm(
377378
run_and_get_constant_graph,
378379
)
379380

380-
const_gm, const_result = run_and_get_constant_graph(
381-
gm, lifted_constants, skip_folding_node_fn
381+
const_gm = run_and_get_constant_graph(
382+
gm, skip_constructor, lifted_constant_names, skip_folding_node_fn
382383
)
384+
const_result = const_gm() if lifted_constant_names is None else None
383385

384386
const_outputs = {
385387
x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0])
@@ -399,7 +401,11 @@ def split_const_gm(
399401
replace_node_with_constant(
400402
gm,
401403
node,
402-
const_result[const_outputs[node.name]],
404+
(
405+
const_result[const_outputs[node.name]]
406+
if lifted_constant_names is None
407+
else None
408+
),
403409
new_const_name,
404410
)
405411
const_output_index[new_const_name] = const_outputs[node.name]

torch/_inductor/constant_folding.py

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import collections
2-
from typing import Any, Callable, Dict, List, Optional, Tuple
2+
from typing import Any, Callable, Dict, List, Optional
33

44
import torch
55
import torch.utils._pytree as pytree
@@ -18,7 +18,7 @@
1818
def replace_node_with_constant(
1919
gm: torch.fx.GraphModule,
2020
node: torch.fx.Node,
21-
constant: torch.Tensor,
21+
constant: Optional[torch.Tensor] = None,
2222
name: Optional[str] = None,
2323
) -> None:
2424
g = gm.graph
@@ -39,32 +39,33 @@ def replace_node_with_constant(
3939
gm._frozen_param_count = i + 1
4040

4141
with g.inserting_before(node):
42-
new_input_node = g.create_node("get_attr", qualname, (), {})
42+
if constant is not None:
43+
new_input_node = g.create_node("get_attr", qualname, (), {})
44+
else:
45+
# this is the case for lifted constants
46+
new_input_node = g.create_node("placeholder", qualname, (), {})
4347
node.replace_all_uses_with(new_input_node)
4448
new_input_node.meta.update(node.meta)
4549
g.erase_node(node)
4650

47-
# needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
48-
gm.register_buffer(qualname, constant)
49-
setattr(gm, qualname, constant)
51+
if constant is not None:
52+
# needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
53+
gm.register_buffer(qualname, constant)
54+
setattr(gm, qualname, constant)
5055

5156

5257
def is_const_source(
53-
node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]]
58+
node: torch.fx.Node, lifted_constant_names: Optional[List[str]]
5459
) -> bool:
55-
return node.op == "get_attr" or (
56-
node.op == "placeholder"
57-
and lifted_constants is not None
58-
and node.name in lifted_constants
59-
)
60+
return node.op == "get_attr" or node.name in (lifted_constant_names or ())
6061

6162

6263
class ConstantFolder(torch.fx.Interpreter):
6364
def __init__(
6465
self,
6566
gm: torch.fx.GraphModule,
6667
skip_constructors: bool = False,
67-
lifted_constants: Optional[Dict[str, torch.Tensor]] = None,
68+
lifted_constant_names: Optional[List[str]] = None,
6869
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
6970
) -> None:
7071
super().__init__(gm)
@@ -76,14 +77,27 @@ def __init__(
7677
# overwrite this to deallocate env values if their only remaining use
7778
# is the output
7879
self.user_to_last_uses = self.node_to_last_non_output_use()
79-
self.lifted_constants = lifted_constants
80+
self.lifted_constant_names = lifted_constant_names
81+
self.deferred_value = object()
8082

8183
def _support_dynamic_shape(self) -> bool:
8284
# ConstantFolder not support dynamic shape now
8385
return False
8486

8587
def _deduce_value(self, node: torch.fx.Node) -> Any:
86-
return super().run_node(node)
88+
if self.lifted_constant_names is None:
89+
return super().run_node(node)
90+
# if lifted_constant_names is passed in, no concrete value is available
91+
# so we just check if all inputs have values
92+
flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)
93+
for inp in flattened_node_inps:
94+
if (
95+
isinstance(inp, torch.fx.Node)
96+
and inp.name not in (self.lifted_constant_names or ())
97+
and self.env[inp] != self.deferred_value
98+
):
99+
return self.unknown_value
100+
return self.deferred_value
87101

88102
def is_impure(self, node: torch.fx.node.Node) -> bool:
89103
def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
@@ -103,7 +117,7 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
103117
and is_woq_int8_pattern(next(iter(node.users)))
104118
)
105119
) and is_const_source(
106-
node.args[0], self.lifted_constants # type: ignore[arg-type]
120+
node.args[0], self.lifted_constant_names # type: ignore[arg-type]
107121
):
108122
# Case 1: int8_weight -> dq -> bf16_weight
109123
# Case 2: int8_weight -> permute -> dq -> bf16_weight
@@ -191,7 +205,7 @@ def set_env(arg: torch.fx.Node) -> None:
191205
# TODO - more complicated strategy
192206
if (
193207
self.skip_constructors
194-
and not is_const_source(node, self.lifted_constants)
208+
and not is_const_source(node, self.lifted_constant_names)
195209
and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
196210
):
197211
return self.unknown_value
@@ -207,10 +221,10 @@ def set_env(arg: torch.fx.Node) -> None:
207221
if out == self.unknown_value:
208222
return self.unknown_value
209223

210-
if not is_const_source(node, self.lifted_constants) and isinstance(
211-
out, torch.Tensor
224+
if not is_const_source(node, self.lifted_constant_names) and (
225+
isinstance(out, torch.Tensor) or out == self.deferred_value
212226
):
213-
if out.device.type == "meta":
227+
if out != self.deferred_value and out.device.type == "meta":
214228
return out
215229

216230
if not self.insertable_tensor_check(out):
@@ -248,10 +262,12 @@ def run(self) -> Any: # type: ignore[override]
248262

249263
def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
250264
for n in self.module.graph.find_nodes(op="placeholder"):
251-
if self.lifted_constants is not None and n.name in self.lifted_constants:
252-
env[n] = self.lifted_constants[n.name]
253-
else:
254-
env[n] = self.unknown_value # type: ignore[assignment]
265+
env[n] = self.unknown_value # type: ignore[assignment]
266+
if self.lifted_constant_names is None:
267+
return
268+
for n in self.module.graph.nodes:
269+
if n.name in (self.lifted_constant_names or ()):
270+
env[n] = self.deferred_value
255271

256272

257273
def constant_fold(
@@ -284,12 +300,15 @@ def constant_fold(
284300

285301
def constant_graph_tag(
286302
gm: torch.fx.GraphModule,
287-
lifted_constants: Optional[Dict[str, Any]],
288-
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]],
303+
skip_constructors: bool = True,
304+
lifted_constant_names: Optional[List[str]] = None,
305+
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
289306
) -> None:
290307
with torch.utils._python_dispatch._disable_current_modes():
291308
cf = ConstantFolder(
292-
gm, skip_constructors=True, lifted_constants=lifted_constants
309+
gm,
310+
skip_constructors=skip_constructors,
311+
lifted_constant_names=lifted_constant_names,
293312
)
294313
cf.run()
295314

@@ -298,7 +317,7 @@ def constant_graph_tag(
298317
node.meta[META_TAG] = MODULE_TAG
299318
continue
300319
if (
301-
is_const_source(node, lifted_constants)
320+
is_const_source(node, lifted_constant_names)
302321
or node in cf.node_replacements
303322
or node in cf.replaced_uses
304323
):
@@ -309,15 +328,18 @@ def constant_graph_tag(
309328

310329
def run_and_get_constant_graph(
311330
gm: torch.fx.GraphModule,
312-
lifted_constants: Optional[Dict[str, Any]],
313-
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]],
314-
) -> Tuple[torch.fx.GraphModule, Tuple[torch.Tensor, ...]]:
331+
skip_constructors: bool = True,
332+
lifted_constant_names: Optional[List[str]] = None,
333+
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
334+
) -> torch.fx.GraphModule:
315335
"""
316336
Construct a GraphModule which corresponds to the part which could be
317337
constant folded in provided gm.
318338
"""
319339

320-
constant_graph_tag(gm, lifted_constants, skip_folding_node_fn)
340+
constant_graph_tag(
341+
gm, skip_constructors, lifted_constant_names, skip_folding_node_fn
342+
)
321343

322344
def untag(node: torch.fx.Node) -> bool:
323345
used_to_fold = False
@@ -329,19 +351,11 @@ def untag(node: torch.fx.Node) -> bool:
329351
node.meta[META_TAG] = MODULE_TAG
330352
return used_to_fold
331353

332-
const_args = []
333-
if lifted_constants is not None:
334-
placeholders = list(gm.graph.find_nodes(op="placeholder"))
335-
for node in placeholders:
336-
if node.meta[META_TAG] == MODULE_TAG:
337-
continue
338-
if untag(node):
339-
const_args.append(lifted_constants[node.name])
340-
341354
# We rewrite the tags, if it's a constant being directly consumed, without
342355
# any folding opportunity, we keep it in main gm.
343-
for node in gm.graph.find_nodes(op="get_attr"):
344-
untag(node)
356+
for node in gm.graph.nodes:
357+
if node.op == "getattr" or (node.name in (lifted_constant_names or ())):
358+
untag(node)
345359

346360
new_graph = torch.fx.Graph()
347361

@@ -363,5 +377,4 @@ def untag(node: torch.fx.Node) -> bool:
363377
new_graph.lint()
364378
new_gm = torch.fx.GraphModule(gm, new_graph)
365379

366-
const_result = new_gm(*const_args)
367-
return new_gm, const_result
380+
return new_gm

0 commit comments

Comments
 (0)