From 22a41896eaff4cd3ee3a0d6830402fad3afbd918 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 30 Jul 2025 16:20:34 -0400
Subject: [PATCH 1/9] support gemma3n

Signed-off-by: Kyle Sayers
---
 src/llmcompressor/args/dataset_arguments.py   |  1 +
 .../pipelines/sequential/ast_helpers.py       | 44 +++++++++++++++---
 .../sequential/ast_utils/auto_wrapper.py      | 16 +++++--
 .../sequential/ast_utils/name_analyzer.py     |  7 +++
 .../pipelines/sequential/helpers.py           | 45 +++++++++----------
 .../transformers/tracing/test_models.py       |  3 ++
 6 files changed, 84 insertions(+), 32 deletions(-)

diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index e19850c80..6f90974df 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -192,6 +192,7 @@ class DatasetArguments(CustomDatasetArguments):
             "_prepare_4d_causal_attention_mask",
             "_prepare_fsmt_decoder_inputs",
             "_prepare_4d_causal_attention_mask_with_cache_position",
+            "project_per_layer_inputs",
         ],
         metadata={
             "help": "List of functions to ignore during tracing, either "
diff --git a/src/llmcompressor/pipelines/sequential/ast_helpers.py b/src/llmcompressor/pipelines/sequential/ast_helpers.py
index 49a0ee774..af0334af5 100644
--- a/src/llmcompressor/pipelines/sequential/ast_helpers.py
+++ b/src/llmcompressor/pipelines/sequential/ast_helpers.py
@@ -4,6 +4,7 @@
 import linecache
 import sys
 import textwrap
+import traceback
 from typing import List
 
 import torch
@@ -11,7 +12,7 @@
 from llmcompressor.pipelines.sequential.ast_utils.auto_wrapper import AutoWrapper
 from llmcompressor.utils import patch_attr
 
-__all__ = ["autowrap_forwards"]
+__all__ = ["autowrap_forwards", "append_autowrap_source_on_fail"]
 
 
 @contextlib.contextmanager
@@ -58,18 +59,18 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
     # autowrap untraceable code
     auto_wrapper = AutoWrapper(namespace, ignore)
     tree = auto_wrapper.auto_wrap(tree)
+    source = ast.unparse(tree)
 
     # compile new forward function from autowrapped code
-    filename = f"{module.__class__.__name__}_{hash(module)}_autowrapped"
-    code = compile(tree, filename=filename, mode="exec")
+    filename = f"<Autowrapped {module.__class__.__name__}_{hash(module)}>"
+    code = compile(source, filename=filename, mode="exec")
     exec(code, namespace)  # ensure ns of functions is the same ns as torch.fx.wrap
 
     # enable better tracebacks if autowrapped code fails
-    source_str = ast.unparse(tree)
     linecache.cache[filename] = (
-        len(source_str),
+        len(source),
         None,
-        [line + "\n" for line in source_str.splitlines()],
+        [line + "\n" for line in source.splitlines()],
         filename,
     )
 
@@ -77,3 +78,34 @@
     new_forward = namespace["forward"].__get__(module)
     with patch_attr(module, "forward", new_forward):
         yield
+
+
+@contextlib.contextmanager
+def append_autowrap_source_on_fail():
+    try:
+        yield
+    except Exception as exception:
+        _exc_type, _exc_value, exc_tb = sys.exc_info()
+        tb_list = traceback.extract_tb(exc_tb)
+
+        collected_sources = []
+        for frame in reversed(tb_list):
+            if "Autowrapped" in frame.filename:
+                source_lines = linecache.getlines(frame.filename)
+                lineno = frame.lineno
+
+                # annotate failing line
+                source_lines = [
+                    ("> " if i + 1 == lineno else " ") + line
+                    for i, line in enumerate(source_lines)
+                ]
+
+                collected_sources.append(
+                    f"\n--- Autowrapped source for {frame.filename}:{lineno} ---\n"
+                    + "".join(source_lines)
+                )
+
+        new_message = f"{exception}\n\n" + "\n".join(collected_sources)
+        raise RuntimeError(new_message) from exception
+
+        raise exception
diff --git a/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py b/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
index 2e78994e4..aa516bdfe 100644
--- a/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
+++ b/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
@@ -53,7 +53,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
         :param node: function definition whose decorators will be stripped
         :return: function definition without decorators
         """
-        node.decorator_list = []
+        node.decorator_list = [
+            decorator_name
+            for decorator_name in node.decorator_list
+            if isinstance(decorator_name, ast.Name)
+            and decorator_name.id in ("can_return_tuple",)  # modifies func signature
+        ]
+
         if node.name == "forward":
             for arg in node.args.args:
                 self._local_names.add(arg.arg)
@@ -104,6 +110,11 @@ def visit_If(self, node: ast.If) -> Union[ast.If, ast.Assign]:
 
         try:
             value = bool(self._eval_expr(node.test))
+            # force a wrap if any assignments occur within the if statement
+            for expr in ast.walk(node):
+                if isinstance(expr, ast.NamedExpr):
+                    raise Exception("If statement contains assignment")
+
         except Exception:
             return self._wrap_if_possible(node)
 
@@ -165,8 +176,7 @@ def _can_wrap(self, node: ast.AST) -> bool:
         without its original context. In the future, we can add more checks for
         module calls (see `visit_If`)
         """
-        analyzer = ControlFlowAnalyzer()
-        return analyzer.is_valid(node)
+        return ControlFlowAnalyzer().is_valid(node)
 
     def _wrap_if_possible(self, node: ast.AST) -> Union[ast.AST, ast.Assign, ast.Call]:
         """
diff --git a/src/llmcompressor/pipelines/sequential/ast_utils/name_analyzer.py b/src/llmcompressor/pipelines/sequential/ast_utils/name_analyzer.py
index 2fc75cad8..a9636a96a 100644
--- a/src/llmcompressor/pipelines/sequential/ast_utils/name_analyzer.py
+++ b/src/llmcompressor/pipelines/sequential/ast_utils/name_analyzer.py
@@ -74,6 +74,13 @@ def visit_Assign(self, node: ast.Assign):
         for target in node.targets:
             self.visit(target)
 
+    def visit_NamedExpr(self, node: ast.NamedExpr):
+        # Visit the right side of the assignment first
+        self.visit(node.value)
+
+        # Now visit the left side of the assignment
+        self.visit(node.target)
+
     def visit_If(self, node: ast.If):
         self.visit(node.test)
 
diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py
index f80cb65a7..0252ef41a 100644
--- a/src/llmcompressor/pipelines/sequential/helpers.py
+++ b/src/llmcompressor/pipelines/sequential/helpers.py
@@ -2,6 +2,7 @@
 import inspect
 from collections import deque
 from dataclasses import dataclass
+from types import FunctionType, MethodType
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple
 
 import torch
@@ -26,7 +27,7 @@
 from llmcompressor.utils.helpers import calibration_forward_context, patch_attr
 from llmcompressor.utils.pytorch.module import get_no_split_params
 
-from .ast_helpers import autowrap_forwards
+from .ast_helpers import append_autowrap_source_on_fail, autowrap_forwards
 
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -69,17 +70,8 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]:
 
         forward_fn = self._code.globals.get("forward")
 
-        try:
-            outputs = forward_fn(*args, **kwargs)
-        except Exception as exception:
-            raise RuntimeError(
-                "Raised an exception during execution of the following code:\n"
-                f"```\n{add_line_numbers(self._code.src)}\n```\n"
-                "This is likely due to a violation of shape assumptions made when "
-                "tracing"
-            ) from exception
-
-        return outputs
+        with append_autowrap_source_on_fail():
+            return forward_fn(*args, **kwargs)
 
 
 def trace_subgraphs(
@@ -120,19 +112,26 @@ def trace_subgraphs(
         # autowrap forwards
         stack.enter_context(autowrap_forwards(ancestors, ignore))
-        stack.enter_context(patch_attr(type(model), "forward", model.forward.__func__))
-
-        graph = GraphModule(
-            model,
-            tracer.trace(
+        # avoid bug where pytorch cannot handle wrapped root functions
+        unwrapped = inspect.unwrap(model.forward).__get__(model)
+        stack.enter_context(patch_attr(model, "forward", unwrapped))
+        stack.enter_context(patch_attr(type(model), "forward", unwrapped.__func__))
+        assert isinstance(model.forward, MethodType)
+        assert isinstance(type(model).forward, FunctionType)
+
+        with append_autowrap_source_on_fail():
+            graph = GraphModule(
                 model,
-                dummy_inputs=sample_input,
-                concrete_args=concrete_args,
-                complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
-                # bug in trace throws an error for variadic
-                # args and kwargs in function signature
-            ),
-        )
+                tracer.trace(
+                    model,
+                    dummy_inputs=sample_input,
+                    concrete_args=concrete_args,
+                    complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
+                    # bug in trace throws an error for variadic
+                    # args and kwargs in function signature
+                ),
+            )
 
     # copy metadata
     graph.config = model.config
diff --git a/tests/llmcompressor/transformers/tracing/test_models.py b/tests/llmcompressor/transformers/tracing/test_models.py
index ded1dffda..094d63def 100644
--- a/tests/llmcompressor/transformers/tracing/test_models.py
+++ b/tests/llmcompressor/transformers/tracing/test_models.py
@@ -4,6 +4,7 @@
 from transformers import (
     AutoModelForCausalLM,
     Gemma3ForConditionalGeneration,
+    Gemma3nForConditionalGeneration,
     Idefics3ForConditionalGeneration,
     Llama4ForConditionalGeneration,
     LlavaForConditionalGeneration,
@@ -49,6 +50,7 @@
         "text",
         [],
     ),
+    ("google/gemma-3n-E2B-it", AutoModelForCausalLM, None, "text", []),
     ("unsloth/DeepSeek-R1-0528-BF16", AutoModelForCausalLM, None, "text", []),
     # --- vision ---
     (
@@ -122,6 +124,7 @@
         "vision",
         [],
     ),
+    ("google/gemma-3n-E2B-it", Gemma3nForConditionalGeneration, None, "vision", []),
     # --- audio ---
     (
         "openai/whisper-large-v3",

From 677614d7000a55cd9f41720b95fcdb2a845240e6 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 30 Jul 2025 16:22:35 -0400
Subject: [PATCH 2/9] style

Signed-off-by: Kyle Sayers
---
 examples/quantization_w8a8_fp8/fp8_block_example.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/examples/quantization_w8a8_fp8/fp8_block_example.py b/examples/quantization_w8a8_fp8/fp8_block_example.py
index b5d6ca1f9..68f13cf93 100644
--- a/examples/quantization_w8a8_fp8/fp8_block_example.py
+++ b/examples/quantization_w8a8_fp8/fp8_block_example.py
@@ -15,9 +15,7 @@
 # In this case, we:
 #   * quantize the weights to fp8 with per channel via ptq
 #   * quantize the activations to fp8 with dynamic per token
-recipe = QuantizationModifier(
-    targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"]
-)
+recipe = QuantizationModifier(targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"])
 
 # Apply quantization.
 oneshot(model=model, recipe=recipe)

From 17decaed7c85bacaa1991c6c7cca20b1218098cf Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 30 Jul 2025 16:42:02 -0400
Subject: [PATCH 3/9] cleanup, add test

Signed-off-by: Kyle Sayers
---
 .../pipelines/sequential/ast_helpers.py       | 12 +++------
 .../ast_utils.py/test_auto_wrapper.py         | 26 +++++++++++++++++--
 2 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/src/llmcompressor/pipelines/sequential/ast_helpers.py b/src/llmcompressor/pipelines/sequential/ast_helpers.py
index af0334af5..c038529a1 100644
--- a/src/llmcompressor/pipelines/sequential/ast_helpers.py
+++ b/src/llmcompressor/pipelines/sequential/ast_helpers.py
@@ -88,7 +88,6 @@ def append_autowrap_source_on_fail():
         _exc_type, _exc_value, exc_tb = sys.exc_info()
         tb_list = traceback.extract_tb(exc_tb)
 
-        collected_sources = []
         for frame in reversed(tb_list):
             if "Autowrapped" in frame.filename:
                 source_lines = linecache.getlines(frame.filename)
@@ -100,12 +99,9 @@ def append_autowrap_source_on_fail():
                     for i, line in enumerate(source_lines)
                 ]
 
-                collected_sources.append(
-                    f"\n--- Autowrapped source for {frame.filename}:{lineno} ---\n"
-                    + "".join(source_lines)
-                )
-
-        new_message = f"{exception}\n\n" + "\n".join(collected_sources)
-        raise RuntimeError(new_message) from exception
+                message = f"{exception}\n\n"
+                message += f"\n--- {frame.filename}:{lineno} ---\n"
+                message += "".join(source_lines)
+                raise RuntimeError(message) from exception
 
         raise exception
diff --git a/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py b/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py
index ab2c38161..8ad38a554 100644
--- a/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py
+++ b/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py
@@ -21,13 +21,14 @@ def check_wrapping(
 
     wrapped_lines = ast.unparse(wrapped).splitlines()
     output_lines = textwrap.dedent(output).splitlines()[1:]
+    lines = ("\n".join(wrapped_lines), "\n".join(output_lines))
 
-    assert len(wrapped_lines) == len(output_lines)
+    assert len(wrapped_lines) == len(output_lines), lines
     for wrapped_line, output_line in zip(wrapped_lines, output_lines):
         if "# skip" in output:
             continue
 
-        assert wrapped_line == output_line
+        assert wrapped_line == output_line, lines
 
 
 def test_static_if():
@@ -189,3 +190,24 @@ def forward(a, *b, c=5, **d):
         () = wrapped_0(a, b, c, d)
     """
     check_wrapping(source, output)
+
+
+def test_walrus():
+    """Checks for handling of names assigned via walrus operators"""
+
+    source = """
+    def forward():
+        if (asdf := (1 + 2)):
+            pass
+    """
+    output = """
+    @torch.fx.wrap
+    def wrapped_0():
+        if (asdf := (1 + 2)):
+            pass
+        return (asdf,)
+
+    def forward():
+        (asdf,) = wrapped_0()
+    """
+    check_wrapping(source, output)

From d69da65e98a1616d7e42c5e2dffe358fbdce0c8a Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 30 Jul 2025 16:54:40 -0400
Subject: [PATCH 4/9] add example

Signed-off-by: Kyle Sayers
---
 examples/multimodal_vision/gemma3n_example.py | 83 +++++++++++++++++++
 .../ast_utils.py/test_auto_wrapper.py         |  1 +
 2 files changed, 84 insertions(+)
 create mode 100644 examples/multimodal_vision/gemma3n_example.py

diff --git a/examples/multimodal_vision/gemma3n_example.py b/examples/multimodal_vision/gemma3n_example.py
new file mode 100644
index 000000000..c0d6aefae
--- /dev/null
+++ b/examples/multimodal_vision/gemma3n_example.py
@@ -0,0 +1,83 @@
+import requests
+import torch
+from PIL import Image
+from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor.utils import dispatch_for_generation
+
+# Load model.
+model_id = "google/gemma-3n-E2B-it"
+model = Gemma3nForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
+processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+# Oneshot arguments
+DATASET_ID = "flickr30k"
+DATASET_SPLIT = {"calibration": "test[:512]"}
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+
+
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
+# Recipe
+recipe = [
+    GPTQModifier(
+        targets="Linear",
+        scheme="W4A16",
+        ignore=[
+            "re:.*embed_audio.*",
+            "re:.*embed_vision.*",
+            "re:.*audio_tower.*",
+            "re:.*vision_tower.*",
+            "re:.*altup.*",
+            "re:.*lm_head.*",
+            "re:.*laurel.*",
+        ],
+    ),
+]
+
+# Perform oneshot
+oneshot(
+    model=model,
+    tokenizer=model_id,
+    dataset=DATASET_ID,
+    splits=DATASET_SPLIT,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    trust_remote_code_model=True,
+    data_collator=data_collator,
+)
+
+# Confirm generations of the quantized model look sane.
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Please describe the animal in this image\n"},
+            {"type": "image"},
+        ],
+    },
+]
+prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
+raw_image = Image.open(requests.get(image_url, stream=True).raw)
+
+# Note: compile is disabled: https://github.com/huggingface/transformers/issues/38333
+inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
+output = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
+print(processor.decode(output[0], skip_special_tokens=True))
+print("==========================================")
+
+# Save to disk compressed.
+SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+processor.save_pretrained(SAVE_DIR)
diff --git a/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py b/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py
index 8ad38a554..70be21d68 100644
--- a/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py
+++ b/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py
@@ -1,3 +1,4 @@
+# flake8: noqa
 import ast
 import textwrap
 from types import SimpleNamespace

From 09e6cfeb5f98334360b7b738e9040d63f6487299 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 30 Jul 2025 17:19:13 -0400
Subject: [PATCH 5/9] add gemma3n vision example

Signed-off-by: Kyle Sayers
---
 examples/multimodal_vision/gemma3n_example.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/examples/multimodal_vision/gemma3n_example.py b/examples/multimodal_vision/gemma3n_example.py
index c0d6aefae..59f565067 100644
--- a/examples/multimodal_vision/gemma3n_example.py
+++ b/examples/multimodal_vision/gemma3n_example.py
@@ -53,6 +53,8 @@ def data_collator(batch):
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
     data_collator=data_collator,
+    # gemma3n has broken weight offloading which is required by the sequential pipeline
+    pipeline="basic",
 )
 
 # Confirm generations of the quantized model look sane.

From ab6f51043d248eb7a97ae10c9c7d777a8fd36731 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 30 Jul 2025 17:22:18 -0400
Subject: [PATCH 6/9] account for other environments

Signed-off-by: Kyle Sayers
---
 .../sequential/ast_utils.py/test_auto_wrapper.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py b/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py
index 70be21d68..8bcc0ae17 100644
--- a/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py
+++ b/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py
@@ -198,17 +198,17 @@ def test_walrus():
 
     source = """
     def forward():
-        if (asdf := (1 + 2)):
+        if (x := (1 + 2)):
             pass
     """
     output = """
     @torch.fx.wrap
     def wrapped_0():
-        if (asdf := (1 + 2)):
+        if (x := (1 + 2)):
             pass
-        return (asdf,)
+        return (x,)
 
     def forward():
-        (asdf,) = wrapped_0()
+        (x,) = wrapped_0()  # skip: some envs use "(x,)" -> "x,"
     """
     check_wrapping(source, output)

From 8b8140a43ad26422b1affb68321a7ff2269551ff Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 30 Jul 2025 17:45:51 -0400
Subject: [PATCH 7/9] do not untie

Signed-off-by: Kyle Sayers
---
 examples/multimodal_vision/gemma3n_example.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/examples/multimodal_vision/gemma3n_example.py b/examples/multimodal_vision/gemma3n_example.py
index 59f565067..12428cc2e 100644
--- a/examples/multimodal_vision/gemma3n_example.py
+++ b/examples/multimodal_vision/gemma3n_example.py
@@ -55,6 +55,8 @@ def data_collator(batch):
     data_collator=data_collator,
     # gemma3n has broken weight offloading which is required by the sequential pipeline
     pipeline="basic",
+    # gemma3n does not support untying word embeddings
+    tie_word_embeddings=True,
 )
 
 # Confirm generations of the quantized model look sane.

From e5e8c603d214901d722caa15398a2930c0532ef1 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 30 Jul 2025 18:33:35 -0400
Subject: [PATCH 8/9] add more to ignore list

Signed-off-by: Kyle Sayers
---
 examples/multimodal_vision/gemma3n_example.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/examples/multimodal_vision/gemma3n_example.py b/examples/multimodal_vision/gemma3n_example.py
index 12428cc2e..0a835532d 100644
--- a/examples/multimodal_vision/gemma3n_example.py
+++ b/examples/multimodal_vision/gemma3n_example.py
@@ -38,9 +38,13 @@ def data_collator(batch):
             "re:.*altup.*",
             "re:.*lm_head.*",
             "re:.*laurel.*",
+            "re:model\.language_model\.layers\.\d+\.per_layer_input_gate",
+            "re:model\.language_model\.layers\.\d+\.per_layer_projection",
+            "model.language_model.per_layer_model_projection",
         ],
     ),
 ]
+breakpoint()
 
 # Perform oneshot
 oneshot(
@@ -81,7 +85,7 @@ def data_collator(batch):
 print(processor.decode(output[0], skip_special_tokens=True))
 print("==========================================")
 
-# Save to disk compressed.
-SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
-model.save_pretrained(SAVE_DIR, save_compressed=True)
-processor.save_pretrained(SAVE_DIR)
+# # Save to disk compressed.
+# SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
+# model.save_pretrained(SAVE_DIR, save_compressed=True)
+# processor.save_pretrained(SAVE_DIR)

From 85d163457c76c52bf0c0846cc7b9ca8411dc67b0 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 30 Jul 2025 18:35:04 -0400
Subject: [PATCH 9/9] remove breakpoint

Signed-off-by: Kyle Sayers
---
 examples/multimodal_vision/gemma3n_example.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/multimodal_vision/gemma3n_example.py b/examples/multimodal_vision/gemma3n_example.py
index 0a835532d..47145b3c5 100644
--- a/examples/multimodal_vision/gemma3n_example.py
+++ b/examples/multimodal_vision/gemma3n_example.py
@@ -44,7 +44,6 @@ def data_collator(batch):
         ],
     ),
 ]
-breakpoint()
 
 # Perform oneshot
 oneshot(