Skip to content

Commit 329252c

Browse files
authored
[Autowrapper] Fix local names, increase reproducability (#1672)
## Purpose ## * Fix bug where `args` and `kwargs` were not being recognized as local names in the namespace when autowrapping * This was a problem for the Kimi K2 model which includes this line ```python3 def forward(..., kwargs): if "padding_mask" in kwargs: # kwargs was not being recognized as a local name ... ``` * Make autowrapper easier to debug by making wrapped function names and arguments consistent ## Changes ## * Add positional arguments and kwargs to set of local names * Add an incrementing counter which is used to name wrapped functions, rather than naming them with a hash * Refactored autowrapper tests to be easier to follow ## Testing ## * Can trace K2 model with higher granularity * Added test `test_function_variadic` --------- Signed-off-by: Kyle Sayers <[email protected]>
1 parent 3b4a0da commit 329252c

File tree

2 files changed

+107
-27
lines changed

2 files changed

+107
-27
lines changed

src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self, namespace: Dict[str, Any], ignore: List[str]):
2626
self.ignore = ignore
2727
self._wrapper_fn_defs: List[ast.FunctionDef] = list()
2828
self._local_names = set()
29+
self._wrapped_counter = 0
2930

3031
def auto_wrap(self, tree: ast.Module) -> ast.Module:
3132
"""
@@ -56,6 +57,14 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
5657
if node.name == "forward":
5758
for arg in node.args.args:
5859
self._local_names.add(arg.arg)
60+
for arg in node.args.posonlyargs:
61+
self._local_names.add(arg.arg)
62+
for arg in node.args.kwonlyargs:
63+
self._local_names.add(arg.arg)
64+
if node.args.vararg:
65+
self._local_names.add(node.args.vararg.arg)
66+
if node.args.kwarg:
67+
self._local_names.add(node.args.kwarg.arg)
5968
return super().generic_visit(node)
6069

6170
def visit_Name(self, node: ast.Name):
@@ -203,6 +212,11 @@ def _wrap_stmt(self, node: ast.stmt) -> ast.Assign:
203212
returns = assigned | conditionally_assigned
204213
assert "self" not in args, "Cannot trace self, this should be in the namespace"
205214

215+
# sort arguments for reproducability
216+
args = sorted(args)
217+
kwargs = sorted(kwargs)
218+
returns = sorted(returns)
219+
206220
# build function arguments
207221
args_obj = ast.arguments(
208222
args=[ast.arg(arg=name) for name in args],
@@ -217,21 +231,22 @@ def _wrap_stmt(self, node: ast.stmt) -> ast.Assign:
217231
# build body and return statement
218232
return_stmt = ast.Return(
219233
value=ast.Tuple(
220-
elts=[ast.Name(id=name, ctx=ast.Load()) for name in sorted(returns)],
234+
elts=[ast.Name(id=name, ctx=ast.Load()) for name in returns],
221235
ctx=ast.Load(),
222236
)
223237
)
224238
body = [node, return_stmt]
225239

226240
# build function definition, store in `_wrapper_fn_defs`
227-
fn_name = f"wrapped_{hash(node)}"
241+
fn_name = f"wrapped_{self._wrapped_counter}"
228242
fn_def = ast.FunctionDef(
229243
name=fn_name,
230244
args=args_obj,
231245
body=body,
232246
decorator_list=[ast.Name(id="torch.fx.wrap", ctx=ast.Load())],
233247
)
234248
self._wrapper_fn_defs.append(fn_def)
249+
self._wrapped_counter += 1
235250

236251
# build call and assignment
237252
fn_call = ast.Call(
@@ -240,13 +255,13 @@ def _wrap_stmt(self, node: ast.stmt) -> ast.Assign:
240255
keywords=list(),
241256
)
242257
return_tuple = ast.Tuple(
243-
elts=[ast.Name(id=name, ctx=ast.Store()) for name in sorted(returns)],
258+
elts=[ast.Name(id=name, ctx=ast.Store()) for name in returns],
244259
ctx=ast.Store(),
245260
)
246261
assign_call = ast.Assign(targets=[return_tuple], value=fn_call)
247262

248263
# update local names with newly returned values
249-
self._local_names |= returns
264+
self._local_names |= set(returns)
250265

251266
# log newly created function definition
252267
logger.debug("---- Autowrapper ----")

tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py

Lines changed: 88 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88

99
def check_wrapping(
1010
source: str,
11-
output: Optional[str] = None,
12-
num_wrapped: int = 0,
11+
output: str,
1312
namespace: Optional[Dict[str, Any]] = None,
1413
ignore: Optional[List[str]] = None,
1514
):
@@ -20,15 +19,20 @@ def check_wrapping(
2019
wrapper = AutoWrapper(namespace, ignore)
2120
wrapped = wrapper.auto_wrap(tree)
2221

23-
if output is not None:
24-
wrapped_lines = ast.unparse(wrapped).splitlines()
25-
output_lines = textwrap.dedent(output).splitlines()[1:]
26-
assert wrapped_lines == output_lines
22+
wrapped_lines = ast.unparse(wrapped).splitlines()
23+
output_lines = textwrap.dedent(output).splitlines()[1:]
2724

28-
assert len(wrapper._wrapper_fn_defs) == num_wrapped
25+
assert len(wrapped_lines) == len(output_lines)
26+
for wrapped_line, output_line in zip(wrapped_lines, output_lines):
27+
if "# skip" in output:
28+
continue
29+
30+
assert wrapped_line == output_line
2931

3032

3133
def test_static_if():
34+
"""Checks that resolvable if statements are replaced"""
35+
3236
source = """
3337
def forward():
3438
if 1 + 1 == 2:
@@ -39,10 +43,12 @@ def forward():
3943
if True:
4044
pass
4145
"""
42-
check_wrapping(source, output, 0)
46+
check_wrapping(source, output)
4347

4448

4549
def test_static_if_global_vars():
50+
"""Checks that resolvable if statements are replaced"""
51+
4652
source = """
4753
def forward():
4854
if config.is_false:
@@ -54,20 +60,35 @@ def forward():
5460
pass
5561
"""
5662
config = SimpleNamespace(is_false=False)
57-
check_wrapping(source, output, 0, namespace={"config": config})
63+
check_wrapping(source, output, namespace={"config": config})
5864

5965

6066
def test_dynamic_if():
67+
"""Checks that non-resolvable if statements are ignored"""
68+
6169
source = """
6270
def forward():
6371
test = ...
6472
if test:
6573
pass
6674
"""
67-
check_wrapping(source, None, 1)
75+
output = """
76+
@torch.fx.wrap
77+
def wrapped_0(test):
78+
if test:
79+
pass
80+
return ()
81+
82+
def forward():
83+
test = ...
84+
() = wrapped_0(test)
85+
"""
86+
check_wrapping(source, output)
6887

6988

7089
def test_ignore_functions():
90+
"""Checks that ignored functions are wrapped"""
91+
7192
def func_one():
7293
pass
7394

@@ -79,11 +100,23 @@ def forward():
79100
func_one()
80101
func_two()
81102
"""
103+
output = """
104+
@torch.fx.wrap
105+
def wrapped_0():
106+
return func_one()
107+
return ()
108+
109+
def forward():
110+
wrapped_0()
111+
func_two()
112+
"""
82113
namespace = {"func_one": func_one, "func_two": func_two}
83-
check_wrapping(source, None, 1, namespace=namespace, ignore=["func_one"])
114+
check_wrapping(source, output, namespace=namespace, ignore=["func_one"])
84115

85116

86117
def test_ignore_methods():
118+
"""Checks that ignored class methods are wrapped"""
119+
87120
class Model:
88121
def meth_one(self):
89122
pass
@@ -96,11 +129,23 @@ def forward(self):
96129
self.meth_one()
97130
self.meth_two()
98131
"""
132+
output = """
133+
@torch.fx.wrap
134+
def wrapped_0():
135+
return self.meth_one()
136+
return ()
137+
138+
def forward(self):
139+
wrapped_0()
140+
self.meth_two()
141+
"""
99142
namespace = {"self": Model()}
100-
check_wrapping(source, None, 1, namespace=namespace, ignore=["meth_one"])
143+
check_wrapping(source, output, namespace=namespace, ignore=["meth_one"])
101144

102145

103146
def test_branch_with_self_assignment():
147+
"""Checks that names referenced in self assignment are included in fn args"""
148+
104149
source = """
105150
def forward(x, y):
106151
if y > 0:
@@ -109,18 +154,38 @@ def forward(x, y):
109154
x = x - 1
110155
return x
111156
"""
157+
output = """
158+
@torch.fx.wrap
159+
def wrapped_0(x, y):
160+
if y > 0:
161+
x = x + 1
162+
else:
163+
x = x - 1
164+
return (x,)
112165
113-
tree = ast.parse(textwrap.dedent(source))
114-
wrapper = AutoWrapper(namespace={}, ignore=[])
115-
wrapper.auto_wrap(tree)
166+
def forward(x, y):
167+
(x,) = wrapped_0(x, y) # skip: some envs use "(x,)" -> "x,"
168+
return x
169+
"""
170+
check_wrapping(source, output)
116171

117-
assert len(wrapper._wrapper_fn_defs) == 1
118172

119-
# Check if both x, y are included in args
120-
wrapped_fn = wrapper._wrapper_fn_defs[0]
121-
arg_names = {arg.arg for arg in wrapped_fn.args.args}
173+
def test_function_variadic():
174+
"""Checks for handling variadic names created via function def"""
175+
176+
source = """
177+
def forward(a, *b, c=5, **d):
178+
if a == b and c == d:
179+
pass
180+
"""
181+
output = """
182+
@torch.fx.wrap
183+
def wrapped_0(a, b, c, d):
184+
if a == b and c == d:
185+
pass
186+
return ()
122187
123-
assert arg_names == {
124-
"x",
125-
"y",
126-
}, f"Expected arguments {{'x', 'y'}}, but got {arg_names}"
188+
def forward(a, *b, c=5, **d):
189+
() = wrapped_0(a, b, c, d)
190+
"""
191+
check_wrapping(source, output)

0 commit comments

Comments
 (0)