
Commit feb0c38: support gemma3n

Signed-off-by: Kyle Sayers <[email protected]>
1 parent: 86f9557

6 files changed: 84 additions, 32 deletions


src/llmcompressor/args/dataset_arguments.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -192,6 +192,7 @@ class DatasetArguments(CustomDatasetArguments):
             "_prepare_4d_causal_attention_mask",
             "_prepare_fsmt_decoder_inputs",
             "_prepare_4d_causal_attention_mask_with_cache_position",
+            "project_per_layer_inputs",
         ],
         metadata={
             "help": "List of functions to ignore during tracing, either "
```

src/llmcompressor/pipelines/sequential/ast_helpers.py

Lines changed: 38 additions & 6 deletions
```diff
@@ -4,14 +4,15 @@
 import linecache
 import sys
 import textwrap
+import traceback
 from typing import List
 
 import torch
 
 from llmcompressor.pipelines.sequential.ast_utils.auto_wrapper import AutoWrapper
 from llmcompressor.utils import patch_attr
 
-__all__ = ["autowrap_forwards"]
+__all__ = ["autowrap_forwards", "append_autowrap_source_on_fail"]
 
 
 @contextlib.contextmanager
@@ -58,22 +59,53 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
     # autowrap untraceable code
     auto_wrapper = AutoWrapper(namespace, ignore)
     tree = auto_wrapper.auto_wrap(tree)
+    source = ast.unparse(tree)
 
     # compile new forward function from autowrapped code
-    filename = f"{module.__class__.__name__}_{hash(module)}_autowrapped"
-    code = compile(tree, filename=filename, mode="exec")
+    filename = f"<Autowrapped {module.__class__.__name__} {id(module)}>"
+    code = compile(source, filename=filename, mode="exec")
     exec(code, namespace)  # ensure ns of functions is the same ns as torch.fx.wrap
 
     # enable better tracebacks if autowrapped code fails
-    source_str = ast.unparse(tree)
     linecache.cache[filename] = (
-        len(source_str),
+        len(source),
         None,
-        [line + "\n" for line in source_str.splitlines()],
+        [line + "\n" for line in source.splitlines()],
         filename,
     )
 
     # patch forward with autowrapped forward
     new_forward = namespace["forward"].__get__(module)
     with patch_attr(module, "forward", new_forward):
         yield
+
+
+@contextlib.contextmanager
+def append_autowrap_source_on_fail():
+    try:
+        yield
+    except Exception as exception:
+        _exc_type, _exc_value, exc_tb = sys.exc_info()
+        tb_list = traceback.extract_tb(exc_tb)
+
+        collected_sources = []
+        for frame in reversed(tb_list):
+            if "Autowrapped" in frame.filename:
+                source_lines = linecache.getlines(frame.filename)
+                lineno = frame.lineno
+
+                # annotate failing line
+                source_lines = [
+                    ("> " if i + 1 == lineno else "  ") + line
+                    for i, line in enumerate(source_lines)
+                ]
+
+                collected_sources.append(
+                    f"\n--- Autowrapped source for {frame.filename}:{lineno} ---\n"
+                    + "".join(source_lines)
+                )
+
+                new_message = f"{exception}\n\n" + "\n".join(collected_sources)
+                raise RuntimeError(new_message) from exception
+
+        raise exception
```
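Two standard-library mechanisms carry this change: `linecache.cache` is seeded so tracebacks can resolve source that exists only in memory (hence the switch to a `<Autowrapped ...>` filename), and `append_autowrap_source_on_fail` uses `traceback.extract_tb` to find frames from that generated code and annotate the failing line. A self-contained sketch of both, independent of llm-compressor:

```python
import linecache
import sys
import traceback

# Source that exists only in memory, compiled under a synthetic filename.
source = "def forward(x):\n    y = x + 1\n    return y / 0\n"
filename = "<Autowrapped Demo 1>"

# Seed linecache with (size, mtime, lines, fullname); mtime=None marks the
# entry as never stale, so checkcache() will not evict it.
linecache.cache[filename] = (
    len(source),
    None,
    [line + "\n" for line in source.splitlines()],
    filename,
)

namespace = {}
exec(compile(source, filename=filename, mode="exec"), namespace)

try:
    namespace["forward"](1)
except Exception:
    _exc_type, _exc_value, exc_tb = sys.exc_info()
    for frame in reversed(traceback.extract_tb(exc_tb)):
        if "Autowrapped" in frame.filename:
            lines = linecache.getlines(frame.filename)
            annotated = [
                ("> " if i + 1 == frame.lineno else "  ") + line
                for i, line in enumerate(lines)
            ]
            print(f"--- Autowrapped source for {frame.filename}:{frame.lineno} ---")
            print("".join(annotated), end="")
            break
```

Using `None` in the mtime slot matters because there is no file on disk to re-stat; `linecache.checkcache` skips such entries instead of discarding them.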

src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py

Lines changed: 13 additions & 3 deletions
```diff
@@ -53,7 +53,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
         :param node: function definition whose decorators will be stripped
         :return: function definition without decorators
         """
-        node.decorator_list = []
+        node.decorator_list = [
+            decorator_name
+            for decorator_name in node.decorator_list
+            if isinstance(decorator_name, ast.Name)
+            and decorator_name.id in ("can_return_tuple",)  # modifies func signature
+        ]
+
         if node.name == "forward":
             for arg in node.args.args:
                 self._local_names.add(arg.arg)
@@ -104,6 +110,11 @@ def visit_If(self, node: ast.If) -> Union[ast.If, ast.Assign]:
         try:
             value = bool(self._eval_expr(node.test))
 
+            # force a wrap if any assignments occur within the if statement
+            for expr in ast.walk(node):
+                if isinstance(expr, ast.NamedExpr):
+                    raise Exception("If statement contains assignment")
+
         except Exception:
             return self._wrap_if_possible(node)
 
@@ -165,8 +176,7 @@ def _can_wrap(self, node: ast.AST) -> bool:
         without its original context. In the future, we can add more checks for module
         calls (see `visit_If`)
         """
-        analyzer = ControlFlowAnalyzer()
-        return analyzer.is_valid(node)
+        return ControlFlowAnalyzer().is_valid(node)
 
     def _wrap_if_possible(self, node: ast.AST) -> Union[ast.AST, ast.Assign, ast.Call]:
         """
```

src/llmcompressor/pipelines/sequential/ast_utils/name_analyzer.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -74,6 +74,13 @@ def visit_Assign(self, node: ast.Assign):
         for target in node.targets:
             self.visit(target)
 
+    def visit_NamedExpr(self, node: ast.NamedExpr):
+        # Visit the right side of the assignment first
+        self.visit(node.value)
+
+        # Now visit the left side of the assignment
+        self.visit(node.target)
+
     def visit_If(self, node: ast.If):
         self.visit(node.test)
 
```
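The analyzer gains a `visit_NamedExpr` handler that mirrors Python's evaluation order for walrus expressions: in `(y := x + 1)`, `x` is read before `y` is bound. A small sketch showing the same ordering; `ReadWriteTracker` is illustrative, not the repo's class:

```python
import ast

class ReadWriteTracker(ast.NodeVisitor):
    """Records names in the order they are read or bound."""

    def __init__(self):
        self.events = []

    def visit_Name(self, node: ast.Name):
        kind = "read" if isinstance(node.ctx, ast.Load) else "write"
        self.events.append((kind, node.id))

    def visit_NamedExpr(self, node: ast.NamedExpr):
        # Right side first: `x` is read before `y` is bound
        self.visit(node.value)
        self.visit(node.target)

tracker = ReadWriteTracker()
tracker.visit(ast.parse("(y := x + 1)"))
print(tracker.events)  # [('read', 'x'), ('write', 'y')]
```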

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 22 additions & 23 deletions
```diff
@@ -2,6 +2,7 @@
 import inspect
 from collections import deque
 from dataclasses import dataclass
+from types import FunctionType, MethodType
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple
 
 import torch
@@ -26,7 +27,7 @@
 from llmcompressor.utils.helpers import calibration_forward_context, patch_attr
 from llmcompressor.utils.pytorch.module import get_no_split_params
 
-from .ast_helpers import autowrap_forwards
+from .ast_helpers import append_autowrap_source_on_fail, autowrap_forwards
 
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -69,17 +70,8 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]:
 
         forward_fn = self._code.globals.get("forward")
 
-        try:
-            outputs = forward_fn(*args, **kwargs)
-        except Exception as exception:
-            raise RuntimeError(
-                "Raised an exception during execution of the following code:\n"
-                f"```\n{add_line_numbers(self._code.src)}\n```\n"
-                "This is likely due to a violation of shape assumptions made when "
-                "tracing"
-            ) from exception
-
-        return outputs
+        with append_autowrap_source_on_fail():
+            return forward_fn(*args, **kwargs)
 
 
 def trace_subgraphs(
@@ -120,19 +112,26 @@ def trace_subgraphs(
 
         # autowrap forwards
         stack.enter_context(autowrap_forwards(ancestors, ignore))
-        stack.enter_context(patch_attr(type(model), "forward", model.forward.__func__))
 
-        graph = GraphModule(
-            model,
-            tracer.trace(
+        # avoid bug where pytorch cannot handle wrapped root functions
+        unwrapped = inspect.unwrap(model.forward).__get__(model)
+        stack.enter_context(patch_attr(model, "forward", unwrapped))
+        stack.enter_context(patch_attr(type(model), "forward", unwrapped.__func__))
+        assert isinstance(model.forward, MethodType)
+        assert isinstance(type(model).forward, FunctionType)
+
+        with append_autowrap_source_on_fail():
+            graph = GraphModule(
                 model,
-                dummy_inputs=sample_input,
-                concrete_args=concrete_args,
-                complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
-                # bug in trace throws an error for variadic
-                # args and kwargs in function signature
-            ),
-        )
+                tracer.trace(
+                    model,
+                    dummy_inputs=sample_input,
+                    concrete_args=concrete_args,
+                    complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
+                    # bug in trace throws an error for variadic
+                    # args and kwargs in function signature
+                ),
+            )
 
     # copy metadata
     graph.config = model.config
```
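Two changes to tracing setup: the root `forward` is unwrapped before patching (`inspect.unwrap` follows the `__wrapped__` chain left by decorators, and `__get__` rebinds the bare function to the instance), working around a PyTorch limitation with wrapped root functions; and the hand-rolled error message is replaced by `append_autowrap_source_on_fail`. A standalone sketch of the unwrap-and-rebind step; the decorator and class are hypothetical:

```python
import functools
import inspect
from types import FunctionType, MethodType

def log_calls(fn):
    @functools.wraps(fn)  # records the original as wrapper.__wrapped__
    def wrapper(*args, **kwargs):
        print("calling", fn.__name__)
        return fn(*args, **kwargs)
    return wrapper

class Model:
    @log_calls
    def forward(self, x):
        return x * 2

model = Model()

# Follow the __wrapped__ chain back to the bare function, then rebind it
# to the instance, mirroring the patch in trace_subgraphs.
unwrapped = inspect.unwrap(model.forward).__get__(model)

assert isinstance(unwrapped, MethodType)
assert isinstance(inspect.unwrap(Model.forward), FunctionType)
print(unwrapped(3))  # 6, with no "calling forward" log line
```

The two `isinstance` asserts in the diff check exactly this invariant: instance access yields a `MethodType`, class access a plain `FunctionType`.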

tests/llmcompressor/transformers/tracing/test_models.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@
 from transformers import (
     AutoModelForCausalLM,
     Gemma3ForConditionalGeneration,
+    Gemma3nForConditionalGeneration,
     Idefics3ForConditionalGeneration,
     Llama4ForConditionalGeneration,
     LlavaForConditionalGeneration,
@@ -49,6 +50,7 @@
         "text",
         [],
     ),
+    ("google/gemma-3n-E2B-it", AutoModelForCausalLM, None, "text", []),
     ("unsloth/DeepSeek-R1-0528-BF16", AutoModelForCausalLM, None, "text", []),
     # --- vision ---
     (
@@ -122,6 +124,7 @@
         "vision",
         [],
    ),
+    ("google/gemma-3n-E2B-it", Gemma3nForConditionalGeneration, None, "vision", []),
     # --- audio ---
     (
         "openai/whisper-large-v3",
```
