@@ -1,5 +1,5 @@
 import collections
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 import torch.utils._pytree as pytree
@@ -18,7 +18,7 @@
 def replace_node_with_constant(
     gm: torch.fx.GraphModule,
     node: torch.fx.Node,
-    constant: torch.Tensor,
+    constant: Optional[torch.Tensor] = None,
     name: Optional[str] = None,
 ) -> None:
     g = gm.graph
@@ -39,32 +39,33 @@ def replace_node_with_constant(
     gm._frozen_param_count = i + 1
 
     with g.inserting_before(node):
-        new_input_node = g.create_node("get_attr", qualname, (), {})
+        if constant is not None:
+            new_input_node = g.create_node("get_attr", qualname, (), {})
+        else:
+            # this is the case for lifted constants
+            new_input_node = g.create_node("placeholder", qualname, (), {})
         node.replace_all_uses_with(new_input_node)
         new_input_node.meta.update(node.meta)
         g.erase_node(node)
 
-    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
-    gm.register_buffer(qualname, constant)
-    setattr(gm, qualname, constant)
+    if constant is not None:
+        # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
+        gm.register_buffer(qualname, constant)
+        setattr(gm, qualname, constant)
 
 
 def is_const_source(
-    node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]]
+    node: torch.fx.Node, lifted_constant_names: Optional[List[str]]
 ) -> bool:
-    return node.op == "get_attr" or (
-        node.op == "placeholder"
-        and lifted_constants is not None
-        and node.name in lifted_constants
-    )
+    return node.op == "get_attr" or node.name in (lifted_constant_names or ())
 
 
 class ConstantFolder(torch.fx.Interpreter):
     def __init__(
         self,
         gm: torch.fx.GraphModule,
         skip_constructors: bool = False,
-        lifted_constants: Optional[Dict[str, torch.Tensor]] = None,
+        lifted_constant_names: Optional[List[str]] = None,
         skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
     ) -> None:
         super().__init__(gm)
@@ -76,14 +77,27 @@ def __init__(
         # overwrite this to deallocate env values if their only remaining use
         # is the output
         self.user_to_last_uses = self.node_to_last_non_output_use()
-        self.lifted_constants = lifted_constants
+        self.lifted_constant_names = lifted_constant_names
+        self.deferred_value = object()
 
     def _support_dynamic_shape(self) -> bool:
         # ConstantFolder not support dynamic shape now
         return False
 
     def _deduce_value(self, node: torch.fx.Node) -> Any:
-        return super().run_node(node)
+        if self.lifted_constant_names is None:
+            return super().run_node(node)
+        # if lifted_constant_names is passed in, no concrete value is available
+        # so we just check if all inputs have values
+        flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)
+        for inp in flattened_node_inps:
+            if (
+                isinstance(inp, torch.fx.Node)
+                and inp.name not in (self.lifted_constant_names or ())
+                and self.env[inp] != self.deferred_value
+            ):
+                return self.unknown_value
+        return self.deferred_value
 
     def is_impure(self, node: torch.fx.node.Node) -> bool:
         def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
@@ -103,7 +117,7 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
                 and is_woq_int8_pattern(next(iter(node.users)))
             )
         ) and is_const_source(
-            node.args[0], self.lifted_constants  # type: ignore[arg-type]
+            node.args[0], self.lifted_constant_names  # type: ignore[arg-type]
         ):
             # Case 1: int8_weight -> dq -> bf16_weight
             # Case 2: int8_weight -> permute -> dq -> bf16_weight
@@ -191,7 +205,7 @@ def set_env(arg: torch.fx.Node) -> None:
         # TODO - more complicated strategy
         if (
             self.skip_constructors
-            and not is_const_source(node, self.lifted_constants)
+            and not is_const_source(node, self.lifted_constant_names)
             and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
         ):
             return self.unknown_value
@@ -207,10 +221,10 @@ def set_env(arg: torch.fx.Node) -> None:
         if out == self.unknown_value:
             return self.unknown_value
 
-        if not is_const_source(node, self.lifted_constants) and isinstance(
-            out, torch.Tensor
+        if not is_const_source(node, self.lifted_constant_names) and (
+            isinstance(out, torch.Tensor) or out == self.deferred_value
         ):
-            if out.device.type == "meta":
+            if out != self.deferred_value and out.device.type == "meta":
                 return out
 
             if not self.insertable_tensor_check(out):
@@ -248,10 +262,12 @@ def run(self) -> Any: # type: ignore[override]
 
     def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
         for n in self.module.graph.find_nodes(op="placeholder"):
-            if self.lifted_constants is not None and n.name in self.lifted_constants:
-                env[n] = self.lifted_constants[n.name]
-            else:
-                env[n] = self.unknown_value  # type: ignore[assignment]
+            env[n] = self.unknown_value  # type: ignore[assignment]
+        if self.lifted_constant_names is None:
+            return
+        for n in self.module.graph.nodes:
+            if n.name in (self.lifted_constant_names or ()):
+                env[n] = self.deferred_value
 
 
 def constant_fold(
@@ -284,12 +300,15 @@ def constant_fold(
 
 def constant_graph_tag(
     gm: torch.fx.GraphModule,
-    lifted_constants: Optional[Dict[str, Any]],
-    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]],
+    skip_constructors: bool = True,
+    lifted_constant_names: Optional[List[str]] = None,
+    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
 ) -> None:
     with torch.utils._python_dispatch._disable_current_modes():
         cf = ConstantFolder(
-            gm, skip_constructors=True, lifted_constants=lifted_constants
+            gm,
+            skip_constructors=skip_constructors,
+            lifted_constant_names=lifted_constant_names,
         )
         cf.run()
 
@@ -298,7 +317,7 @@ def constant_graph_tag(
                 node.meta[META_TAG] = MODULE_TAG
                 continue
             if (
-                is_const_source(node, lifted_constants)
+                is_const_source(node, lifted_constant_names)
                 or node in cf.node_replacements
                 or node in cf.replaced_uses
             ):
@@ -309,15 +328,18 @@ def constant_graph_tag(
 
 def run_and_get_constant_graph(
     gm: torch.fx.GraphModule,
-    lifted_constants: Optional[Dict[str, Any]],
-    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]],
-) -> Tuple[torch.fx.GraphModule, Tuple[torch.Tensor, ...]]:
+    skip_constructors: bool = True,
+    lifted_constant_names: Optional[List[str]] = None,
+    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
+) -> torch.fx.GraphModule:
     """
     Construct a GraphModule which corresponds to the part which could be
     constant folded in provided gm.
     """
 
-    constant_graph_tag(gm, lifted_constants, skip_folding_node_fn)
+    constant_graph_tag(
+        gm, skip_constructors, lifted_constant_names, skip_folding_node_fn
+    )
 
     def untag(node: torch.fx.Node) -> bool:
         used_to_fold = False
@@ -329,19 +351,11 @@ def untag(node: torch.fx.Node) -> bool:
             node.meta[META_TAG] = MODULE_TAG
         return used_to_fold
 
-    const_args = []
-    if lifted_constants is not None:
-        placeholders = list(gm.graph.find_nodes(op="placeholder"))
-        for node in placeholders:
-            if node.meta[META_TAG] == MODULE_TAG:
-                continue
-            if untag(node):
-                const_args.append(lifted_constants[node.name])
-
     # We rewrite the tags, if it's a constant being directly consumed, without
     # any folding opportunity, we keep it in main gm.
-    for node in gm.graph.find_nodes(op="get_attr"):
-        untag(node)
+    for node in gm.graph.nodes:
+        if node.op == "get_attr" or (node.name in (lifted_constant_names or ())):
+            untag(node)
 
     new_graph = torch.fx.Graph()
 
@@ -363,5 +377,4 @@ def untag(node: torch.fx.Node) -> bool:
     new_graph.lint()
     new_gm = torch.fx.GraphModule(gm, new_graph)
 
-    const_result = new_gm(*const_args)
-    return new_gm, const_result
+    return new_gm
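
Call-site note (not part of the diff): a minimal sketch of how a caller could adapt to the new signature. Here `gm` and `lifted_constants` stand for the caller's graph module and its dict of lifted-constant tensors (assumed names, not defined in this change), and the comment about placeholder ordering is an assumption rather than something the diff guarantees.

# Before: concrete tensors were passed in and the folded values were computed
# inside run_and_get_constant_graph:
#   const_gm, const_result = run_and_get_constant_graph(gm, lifted_constants, None)

# After: only the names of the lifted constants are needed at fold time; the
# caller runs the returned constant graph later with the real tensors.
const_gm = run_and_get_constant_graph(
    gm,
    skip_constructors=True,
    lifted_constant_names=list(lifted_constants.keys()),
    skip_folding_node_fn=None,
)
const_result = const_gm(*lifted_constants.values())  # assumes placeholder order matches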