diff --git a/examples/multimodal_vision/gemma3n_example.py b/examples/multimodal_vision/gemma3n_example.py
new file mode 100644
index 000000000..47145b3c5
--- /dev/null
+++ b/examples/multimodal_vision/gemma3n_example.py
@@ -0,0 +1,90 @@
+import requests
+import torch
+from PIL import Image
+from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor.utils import dispatch_for_generation
+
+# Load model.
+model_id = "google/gemma-3n-E2B-it"
+model = Gemma3nForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
+processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+# Oneshot arguments
+DATASET_ID = "flickr30k"
+DATASET_SPLIT = {"calibration": "test[:512]"}
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+
+
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
+# Recipe
+recipe = [
+    GPTQModifier(
+        targets="Linear",
+        scheme="W4A16",
+        ignore=[
+            "re:.*embed_audio.*",
+            "re:.*embed_vision.*",
+            "re:.*audio_tower.*",
+            "re:.*vision_tower.*",
+            "re:.*altup.*",
+            "re:.*lm_head.*",
+            r"re:model\.language_model\.layers\.\d+\.per_layer_input_gate",
+            r"re:model\.language_model\.layers\.\d+\.per_layer_projection",
+            "model.language_model.per_layer_model_projection",
+        ],
+    ),
+]
+
+# Perform oneshot
+oneshot(
+    model=model,
+    tokenizer=model_id,
+    dataset=DATASET_ID,
+    splits=DATASET_SPLIT,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    trust_remote_code_model=True,
+    data_collator=data_collator,
+    # gemma3n has broken weight offloading which is required by the sequential pipeline
+    pipeline="basic",
+    # gemma3n does not support untying word embeddings
+    tie_word_embeddings=True,
+)
+
+# Confirm generations of the quantized model look sane.
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Please describe the animal in this image\n"},
+            {"type": "image"},
+        ],
+    },
+]
+prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
+raw_image = Image.open(requests.get(image_url, stream=True).raw)
+
+# Note: compile is disabled due to https://github.com/huggingface/transformers/issues/38333
+inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
+output = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
+print(processor.decode(output[0], skip_special_tokens=True))
+print("==========================================")
+
+# # Save to disk compressed.
+# SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
+# model.save_pretrained(SAVE_DIR, save_compressed=True)
+# processor.save_pretrained(SAVE_DIR)
diff --git a/examples/quantization_w8a8_fp8/fp8_block_example.py b/examples/quantization_w8a8_fp8/fp8_block_example.py
index b5d6ca1f9..68f13cf93 100644
--- a/examples/quantization_w8a8_fp8/fp8_block_example.py
+++ b/examples/quantization_w8a8_fp8/fp8_block_example.py
@@ -15,9 +15,7 @@
 # In this case, we:
 #   * quantize the weights to fp8 with per channel via ptq
 #   * quantize the activations to fp8 with dynamic per token
-recipe = QuantizationModifier(
-    targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"]
-)
+recipe = QuantizationModifier(targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"])
 
 # Apply quantization.
 oneshot(model=model, recipe=recipe)
diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index e19850c80..6f90974df 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -192,6 +192,7 @@ class DatasetArguments(CustomDatasetArguments):
             "_prepare_4d_causal_attention_mask",
             "_prepare_fsmt_decoder_inputs",
             "_prepare_4d_causal_attention_mask_with_cache_position",
+            "project_per_layer_inputs",
         ],
         metadata={
             "help": "List of functions to ignore during tracing, either "
diff --git a/src/llmcompressor/pipelines/sequential/ast_helpers.py b/src/llmcompressor/pipelines/sequential/ast_helpers.py
index 49a0ee774..c038529a1 100644
--- a/src/llmcompressor/pipelines/sequential/ast_helpers.py
+++ b/src/llmcompressor/pipelines/sequential/ast_helpers.py
@@ -4,6 +4,7 @@
 import linecache
 import sys
 import textwrap
+import traceback
 from typing import List
 
 import torch
@@ -11,7 +12,7 @@
 from llmcompressor.pipelines.sequential.ast_utils.auto_wrapper import AutoWrapper
 from llmcompressor.utils import patch_attr
 
-__all__ = ["autowrap_forwards"]
+__all__ = ["autowrap_forwards", "append_autowrap_source_on_fail"]
 
 
 @contextlib.contextmanager
@@ -58,18 +59,18 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
     # autowrap untraceable code
     auto_wrapper = AutoWrapper(namespace, ignore)
     tree = auto_wrapper.auto_wrap(tree)
+    source = ast.unparse(tree)
 
     # compile new forward function from autowrapped code
-    filename = f"{module.__class__.__name__}_{hash(module)}_autowrapped"
-    code = compile(tree, filename=filename, mode="exec")
+    filename = f"<Autowrapped {module.__class__.__name__}.forward>"
+    code = compile(source, filename=filename, mode="exec")
     exec(code, namespace)  # ensure ns of functions is the same ns as torch.fx.wrap
 
     # enable better tracebacks if autowrapped code fails
-    source_str = ast.unparse(tree)
     linecache.cache[filename] = (
-        len(source_str),
+        len(source),
         None,
-        [line + "\n" for line in source_str.splitlines()],
+        [line + "\n" for line in source.splitlines()],
         filename,
     )
 
@@ -77,3 +78,30 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
     new_forward = namespace["forward"].__get__(module)
     with patch_attr(module, "forward", new_forward):
         yield
+
+
+@contextlib.contextmanager
+def append_autowrap_source_on_fail():
+    try:
+        yield
+    except Exception as exception:
+        _exc_type, _exc_value, exc_tb = sys.exc_info()
+        tb_list = traceback.extract_tb(exc_tb)
+
+        for frame in reversed(tb_list):
+            if "Autowrapped" in frame.filename:
+                source_lines = linecache.getlines(frame.filename)
+                lineno = frame.lineno
+
+                # annotate failing line
+                source_lines = [
+                    ("> " if i + 1 == lineno else "  ") + line
+                    for i, line in enumerate(source_lines)
+                ]
+
+                message = f"{exception}\n\n"
f"{exception}\n\n" + message += f"\n--- {frame.filename}:{lineno} ---\n" + message += "".join(source_lines) + raise RuntimeError(message) from exception + + raise exception diff --git a/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py b/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py index 2e78994e4..aa516bdfe 100644 --- a/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py +++ b/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py @@ -53,7 +53,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: :param node: function definition whose decorators will be stripped :return: function definition without decorators """ - node.decorator_list = [] + node.decorator_list = [ + decorator_name + for decorator_name in node.decorator_list + if isinstance(decorator_name, ast.Name) + and decorator_name.id in ("can_return_tuple",) # modifies func signature + ] + if node.name == "forward": for arg in node.args.args: self._local_names.add(arg.arg) @@ -104,6 +110,11 @@ def visit_If(self, node: ast.If) -> Union[ast.If, ast.Assign]: try: value = bool(self._eval_expr(node.test)) + # force a wrap if any assignments occur within the if statement + for expr in ast.walk(node): + if isinstance(expr, ast.NamedExpr): + raise Exception("If statement contains assignment") + except Exception: return self._wrap_if_possible(node) @@ -165,8 +176,7 @@ def _can_wrap(self, node: ast.AST) -> bool: without its original context. In the future, we can add more checks for module calls (see `visit_If`) """ - analyzer = ControlFlowAnalyzer() - return analyzer.is_valid(node) + return ControlFlowAnalyzer().is_valid(node) def _wrap_if_possible(self, node: ast.AST) -> Union[ast.AST, ast.Assign, ast.Call]: """ diff --git a/src/llmcompressor/pipelines/sequential/ast_utils/name_analyzer.py b/src/llmcompressor/pipelines/sequential/ast_utils/name_analyzer.py index 2fc75cad8..a9636a96a 100644 --- a/src/llmcompressor/pipelines/sequential/ast_utils/name_analyzer.py +++ b/src/llmcompressor/pipelines/sequential/ast_utils/name_analyzer.py @@ -74,6 +74,13 @@ def visit_Assign(self, node: ast.Assign): for target in node.targets: self.visit(target) + def visit_NamedExpr(self, node: ast.NamedExpr): + # Visit the right side of the assignment first + self.visit(node.value) + + # Now visit the left side of the assignment + self.visit(node.target) + def visit_If(self, node: ast.If): self.visit(node.test) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index 51e7c3a74..92dddcbaf 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -2,6 +2,7 @@ import inspect from collections import deque from dataclasses import dataclass +from types import FunctionType, MethodType from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple import torch @@ -26,7 +27,7 @@ from llmcompressor.utils.helpers import calibration_forward_context, patch_attr from llmcompressor.utils.pytorch.module import get_no_split_params -from .ast_helpers import autowrap_forwards +from .ast_helpers import append_autowrap_source_on_fail, autowrap_forwards if TYPE_CHECKING: from llmcompressor.args.dataset_arguments import DatasetArguments @@ -69,15 +70,8 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]: forward_fn = self._code.globals.get("forward") - try: - outputs = forward_fn(*args, **kwargs) - except Exception as exception: - raise RuntimeError( - 
"Raised an exception during execution of the following code:\n" - f"```\n{add_line_numbers(self._code.src)}\n```" - ) from exception - - return outputs + with append_autowrap_source_on_fail(): + return forward_fn(*args, **kwargs) def trace_subgraphs( @@ -118,19 +112,26 @@ def trace_subgraphs( # autowrap forwards stack.enter_context(autowrap_forwards(ancestors, ignore)) - stack.enter_context(patch_attr(type(model), "forward", model.forward.__func__)) - graph = GraphModule( - model, - tracer.trace( + # avoid bug where pytorch cannot handle wrapped root functions + unwrapped = inspect.unwrap(model.forward).__get__(model) + stack.enter_context(patch_attr(model, "forward", unwrapped)) + stack.enter_context(patch_attr(type(model), "forward", unwrapped.__func__)) + assert isinstance(model.forward, MethodType) + assert isinstance(type(model).forward, FunctionType) + + with append_autowrap_source_on_fail(): + graph = GraphModule( model, - dummy_inputs=sample_input, - concrete_args=concrete_args, - complete_concrete_args_with_inputs_not_in_dummy_inputs=False, - # bug in trace throws an error for variadic - # args and kwargs in function signature - ), - ) + tracer.trace( + model, + dummy_inputs=sample_input, + concrete_args=concrete_args, + complete_concrete_args_with_inputs_not_in_dummy_inputs=False, + # bug in trace throws an error for variadic + # args and kwargs in function signature + ), + ) # copy metadata graph.config = model.config @@ -483,12 +484,6 @@ def get_sequential_targets( return sequential_targets -def add_line_numbers(text: str) -> str: - lines = text.splitlines() - numbered_lines = [f"{i + 1} {line}" for i, line in enumerate(lines)] - return "\n".join(numbered_lines) - - def get_sequential_ancestors(model: Module, targets: Set[Module]) -> Set[Module]: """ Find modules which are call graph ancestors of the given sequential targets diff --git a/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py b/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py index ab2c38161..8bcc0ae17 100644 --- a/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py +++ b/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py @@ -1,3 +1,4 @@ +# flake8: noqa import ast import textwrap from types import SimpleNamespace @@ -21,13 +22,14 @@ def check_wrapping( wrapped_lines = ast.unparse(wrapped).splitlines() output_lines = textwrap.dedent(output).splitlines()[1:] + lines = ("\n".join(wrapped_lines), "\n".join(output_lines)) - assert len(wrapped_lines) == len(output_lines) + assert len(wrapped_lines) == len(output_lines), lines for wrapped_line, output_line in zip(wrapped_lines, output_lines): if "# skip" in output: continue - assert wrapped_line == output_line + assert wrapped_line == output_line, lines def test_static_if(): @@ -189,3 +191,24 @@ def forward(a, *b, c=5, **d): () = wrapped_0(a, b, c, d) """ check_wrapping(source, output) + + +def test_walrus(): + """Checks for handling variadic names created via function def""" + + source = """ + def forward(): + if (x := (1 + 2)): + pass + """ + output = """ + @torch.fx.wrap + def wrapped_0(): + if (x := (1 + 2)): + pass + return (x,) + + def forward(): + (x,) = wrapped_0() # skip: some envs use "(x,)" -> "x," + """ + check_wrapping(source, output) diff --git a/tests/llmcompressor/transformers/tracing/test_models.py b/tests/llmcompressor/transformers/tracing/test_models.py index ded1dffda..094d63def 100644 --- a/tests/llmcompressor/transformers/tracing/test_models.py +++ 
@@ -4,6 +4,7 @@
 from transformers import (
     AutoModelForCausalLM,
     Gemma3ForConditionalGeneration,
+    Gemma3nForConditionalGeneration,
     Idefics3ForConditionalGeneration,
     Llama4ForConditionalGeneration,
     LlavaForConditionalGeneration,
@@ -49,6 +50,7 @@
         "text",
         [],
     ),
+    ("google/gemma-3n-E2B-it", AutoModelForCausalLM, None, "text", []),
     ("unsloth/DeepSeek-R1-0528-BF16", AutoModelForCausalLM, None, "text", []),
     # --- vision ---
     (
@@ -122,6 +124,7 @@
         "vision",
         [],
    ),
+    ("google/gemma-3n-E2B-it", Gemma3nForConditionalGeneration, None, "vision", []),
     # --- audio ---
     (
         "openai/whisper-large-v3",
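
Note: as a quick way to see the new walrus handling in isolation, the AutoWrapper changes above can be exercised directly, mirroring test_walrus. This is only a sketch; the empty namespace dict and empty ignore list are illustrative assumptions, not part of this diff.

import ast
import textwrap

from llmcompressor.pipelines.sequential.ast_utils.auto_wrapper import AutoWrapper

# Toy forward whose `if` test contains a walrus assignment.
source = textwrap.dedent(
    """
    def forward():
        if (x := (1 + 2)):
            pass
    """
)

# Illustrative setup: empty namespace, no ignored names (real callers pass the
# module's globals and the tracing ignore list).
wrapper = AutoWrapper({}, [])
wrapped = wrapper.auto_wrap(ast.parse(source))

# The walrus `if` should be hoisted into a `wrapped_0` helper decorated with
# @torch.fx.wrap, with `(x,)` returned and unpacked back inside `forward`.
print(ast.unparse(wrapped))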