Skip to content

Commit 17decae

Browse files
committed
cleanup, add test
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 677614d commit 17decae

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

src/llmcompressor/pipelines/sequential/ast_helpers.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def append_autowrap_source_on_fail():
8888
_exc_type, _exc_value, exc_tb = sys.exc_info()
8989
tb_list = traceback.extract_tb(exc_tb)
9090

91-
collected_sources = []
9291
for frame in reversed(tb_list):
9392
if "Autowrapped" in frame.filename:
9493
source_lines = linecache.getlines(frame.filename)
@@ -100,12 +99,9 @@ def append_autowrap_source_on_fail():
10099
for i, line in enumerate(source_lines)
101100
]
102101

103-
collected_sources.append(
104-
f"\n--- Autowrapped source for {frame.filename}:{lineno} ---\n"
105-
+ "".join(source_lines)
106-
)
107-
108-
new_message = f"{exception}\n\n" + "\n".join(collected_sources)
109-
raise RuntimeError(new_message) from exception
102+
message = f"{exception}\n\n"
103+
message += f"\n--- {frame.filename}:{lineno} ---\n"
104+
message += "".join(source_lines)
105+
raise RuntimeError(message) from exception
110106

111107
raise exception

tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ def check_wrapping(
2121

2222
wrapped_lines = ast.unparse(wrapped).splitlines()
2323
output_lines = textwrap.dedent(output).splitlines()[1:]
24+
lines = ("\n".join(wrapped_lines), "\n".join(output_lines))
2425

25-
assert len(wrapped_lines) == len(output_lines)
26+
assert len(wrapped_lines) == len(output_lines), lines
2627
for wrapped_line, output_line in zip(wrapped_lines, output_lines):
2728
if "# skip" in output:
2829
continue
2930

30-
assert wrapped_line == output_line
31+
assert wrapped_line == output_line, lines
3132

3233

3334
def test_static_if():
@@ -189,3 +190,24 @@ def forward(a, *b, c=5, **d):
189190
() = wrapped_0(a, b, c, d)
190191
"""
191192
check_wrapping(source, output)
193+
194+
195+
def test_walrus():
196+
"""Checks for handling variadic names created via function def"""
197+
198+
source = """
199+
def forward():
200+
if (asdf := (1 + 2)):
201+
pass
202+
"""
203+
output = """
204+
@torch.fx.wrap
205+
def wrapped_0():
206+
if (asdf := (1 + 2)):
207+
pass
208+
return (asdf,)
209+
210+
def forward():
211+
(asdf,) = wrapped_0()
212+
"""
213+
check_wrapping(source, output)

0 commit comments

Comments
 (0)