[Example] [VLM] Gemma3n #1696

Draft · wants to merge 10 commits into main
90 changes: 90 additions & 0 deletions examples/multimodal_vision/gemma3n_example.py
@@ -0,0 +1,90 @@
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3nForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.utils import dispatch_for_generation

# Load model.
model_id = "google/gemma-3n-E2B-it"
model = Gemma3nForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Oneshot arguments
DATASET_ID = "flickr30k"
DATASET_SPLIT = {"calibration": "test[:512]"}
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048


# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}


# Recipe
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        ignore=[
            "re:.*embed_audio.*",
            "re:.*embed_vision.*",
            "re:.*audio_tower.*",
            "re:.*vision_tower.*",
            "re:.*altup.*",
            "re:.*lm_head.*",
            "re:.*laurel.*",
            r"re:model\.language_model\.layers\.\d+\.per_layer_input_gate",
            r"re:model\.language_model\.layers\.\d+\.per_layer_projection",
            "model.language_model.per_layer_model_projection",
        ],
    ),
]

# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
    # the sequential pipeline requires weight offloading, which is currently broken for gemma3n
    pipeline="basic",
    # gemma3n does not support untying word embeddings
    tie_word_embeddings=True,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe the animal in this image\n"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

# Note: compile is disabled: https://github.com/huggingface/transformers/issues/38333
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")

# # Save to disk compressed.
# SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
# model.save_pretrained(SAVE_DIR, save_compressed=True)
# processor.save_pretrained(SAVE_DIR)
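
As a side note on the recipe above: entries in `ignore` prefixed with `re:` are treated as regular expressions over module names, while plain entries match a single module exactly. A minimal standalone sketch of that matching idea follows; the module names below are hypothetical, and the exact matching semantics live in llm-compressor/compressed-tensors, so treat this as an approximation:

import re

patterns = [
    r".*vision_tower.*",
    r"model\.language_model\.layers\.\d+\.per_layer_input_gate",
]
names = [
    "model.vision_tower.encoder.layers.0.attn.q_proj",      # hypothetical
    "model.language_model.layers.3.per_layer_input_gate",   # hypothetical
    "model.language_model.layers.3.mlp.gate_proj",          # hypothetical
]
for name in names:
    skipped = any(re.match(pattern, name) for pattern in patterns)
    print(f"{name}: {'ignored' if skipped else 'quantized'}")  # first two ignored, last quantized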
4 changes: 1 addition & 3 deletions examples/quantization_w8a8_fp8/fp8_block_example.py
@@ -15,9 +15,7 @@
# In this case, we:
# * quantize the weights to fp8 with a block-wise scheme via PTQ
# * quantize the activations to fp8 with dynamic per-token quantization
recipe = QuantizationModifier(
    targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"]
)
recipe = QuantizationModifier(targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"])

# Apply quantization.
oneshot(model=model, recipe=recipe)
1 change: 1 addition & 0 deletions src/llmcompressor/args/dataset_arguments.py
@@ -192,6 +192,7 @@ class DatasetArguments(CustomDatasetArguments):
"_prepare_4d_causal_attention_mask",
"_prepare_fsmt_decoder_inputs",
"_prepare_4d_causal_attention_mask_with_cache_position",
"project_per_layer_inputs",
],
metadata={
"help": "List of functions to ignore during tracing, either "
40 changes: 34 additions & 6 deletions src/llmcompressor/pipelines/sequential/ast_helpers.py
@@ -4,14 +4,15 @@
import linecache
import sys
import textwrap
import traceback
from typing import List

import torch

from llmcompressor.pipelines.sequential.ast_utils.auto_wrapper import AutoWrapper
from llmcompressor.utils import patch_attr

__all__ = ["autowrap_forwards"]
__all__ = ["autowrap_forwards", "append_autowrap_source_on_fail"]


@contextlib.contextmanager
@@ -58,22 +59,49 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
    # autowrap untraceable code
    auto_wrapper = AutoWrapper(namespace, ignore)
    tree = auto_wrapper.auto_wrap(tree)
    source = ast.unparse(tree)

    # compile new forward function from autowrapped code
    filename = f"{module.__class__.__name__}_{hash(module)}_autowrapped"
    code = compile(tree, filename=filename, mode="exec")
    filename = f"<Autowrapped {module.__class__.__name__} {id(module)}>"
    code = compile(source, filename=filename, mode="exec")
    exec(code, namespace)  # ensure ns of functions is the same ns as torch.fx.wrap

    # enable better tracebacks if autowrapped code fails
    source_str = ast.unparse(tree)
    linecache.cache[filename] = (
        len(source_str),
        len(source),
        None,
        [line + "\n" for line in source_str.splitlines()],
        [line + "\n" for line in source.splitlines()],
        filename,
    )

    # patch forward with autowrapped forward
    new_forward = namespace["forward"].__get__(module)
    with patch_attr(module, "forward", new_forward):
        yield


@contextlib.contextmanager
def append_autowrap_source_on_fail():
    try:
        yield
    except Exception as exception:
        _exc_type, _exc_value, exc_tb = sys.exc_info()
        tb_list = traceback.extract_tb(exc_tb)

        for frame in reversed(tb_list):
            if "Autowrapped" in frame.filename:
                source_lines = linecache.getlines(frame.filename)
                lineno = frame.lineno

                # annotate failing line
                source_lines = [
                    ("> " if i + 1 == lineno else "  ") + line
                    for i, line in enumerate(source_lines)
                ]

                message = f"{exception}\n\n"
                message += f"\n--- {frame.filename}:{lineno} ---\n"
                message += "".join(source_lines)
                raise RuntimeError(message) from exception

        raise exception
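
A rough usage sketch of the new context manager, assuming `module` has already had its forward patched by `autowrap_forwards` (both names below are placeholders for whatever the sequential pipeline passes in):

with append_autowrap_source_on_fail():
    outputs = module(**calibration_batch)  # placeholder call
# If the autowrapped forward raises, the exception is re-raised as a RuntimeError whose
# message includes the generated source with "> " marking the failing line.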
16 changes: 13 additions & 3 deletions src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
@@ -53,7 +53,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        :param node: function definition whose decorators will be stripped
        :return: function definition without decorators
        """
        node.decorator_list = []
        node.decorator_list = [
            decorator_name
            for decorator_name in node.decorator_list
            if isinstance(decorator_name, ast.Name)
            and decorator_name.id in ("can_return_tuple",)  # modifies func signature
        ]

        if node.name == "forward":
            for arg in node.args.args:
                self._local_names.add(arg.arg)
@@ -104,6 +110,11 @@ def visit_If(self, node: ast.If) -> Union[ast.If, ast.Assign]:
        try:
            value = bool(self._eval_expr(node.test))

            # force a wrap if any assignments occur within the if statement
            for expr in ast.walk(node):
                if isinstance(expr, ast.NamedExpr):
                    raise Exception("If statement contains assignment")

        except Exception:
            return self._wrap_if_possible(node)
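
The intent of the new check: even when the condition folds to a constant, a walrus assignment inside the `if` binds names that later code may read, so the statement must fall through to wrapping instead of being folded away. A quick standalone illustration of the detection using only the stdlib ast module:

import ast

stmt = ast.parse("if (x := (1 + 2)):\n    pass").body[0]
has_walrus = any(isinstance(node, ast.NamedExpr) for node in ast.walk(stmt))
print(has_walrus)  # True -> visit_If falls back to _wrap_if_possible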

@@ -165,8 +176,7 @@ def _can_wrap(self, node: ast.AST) -> bool:
        without its original context. In the future, we can add more checks for module
        calls (see `visit_If`)
        """
        analyzer = ControlFlowAnalyzer()
        return analyzer.is_valid(node)
        return ControlFlowAnalyzer().is_valid(node)

    def _wrap_if_possible(self, node: ast.AST) -> Union[ast.AST, ast.Assign, ast.Call]:
        """
@@ -74,6 +74,13 @@ def visit_Assign(self, node: ast.Assign):
        for target in node.targets:
            self.visit(target)

    def visit_NamedExpr(self, node: ast.NamedExpr):
        # Visit the right side of the assignment first
        self.visit(node.value)

        # Now visit the left side of the assignment
        self.visit(node.target)

    def visit_If(self, node: ast.If):
        self.visit(node.test)

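For reference, a small sketch of the two fields the new `visit_NamedExpr` handler walks, using only the stdlib ast module (the expression itself is illustrative):

import ast

node = ast.parse("(x := compute())", mode="eval").body
print(type(node).__name__)   # NamedExpr
print(ast.dump(node.value))  # the call on the right-hand side, visited first
print(ast.dump(node.target)) # Name(id='x', ctx=Store()), visited second
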
49 changes: 22 additions & 27 deletions src/llmcompressor/pipelines/sequential/helpers.py
@@ -2,6 +2,7 @@
import inspect
from collections import deque
from dataclasses import dataclass
from types import FunctionType, MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple

import torch
@@ -26,7 +27,7 @@
from llmcompressor.utils.helpers import calibration_forward_context, patch_attr
from llmcompressor.utils.pytorch.module import get_no_split_params

from .ast_helpers import autowrap_forwards
from .ast_helpers import append_autowrap_source_on_fail, autowrap_forwards

if TYPE_CHECKING:
from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -69,15 +70,8 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]:

        forward_fn = self._code.globals.get("forward")

        try:
            outputs = forward_fn(*args, **kwargs)
        except Exception as exception:
            raise RuntimeError(
                "Raised an exception during execution of the following code:\n"
                f"```\n{add_line_numbers(self._code.src)}\n```"
            ) from exception

        return outputs
        with append_autowrap_source_on_fail():
            return forward_fn(*args, **kwargs)


def trace_subgraphs(
@@ -118,19 +112,26 @@ def trace_subgraphs(

# autowrap forwards
stack.enter_context(autowrap_forwards(ancestors, ignore))
stack.enter_context(patch_attr(type(model), "forward", model.forward.__func__))

graph = GraphModule(
model,
tracer.trace(
# avoid bug where pytorch cannot handle wrapped root functions
unwrapped = inspect.unwrap(model.forward).__get__(model)
stack.enter_context(patch_attr(model, "forward", unwrapped))
stack.enter_context(patch_attr(type(model), "forward", unwrapped.__func__))
assert isinstance(model.forward, MethodType)
assert isinstance(type(model).forward, FunctionType)

with append_autowrap_source_on_fail():
graph = GraphModule(
model,
dummy_inputs=sample_input,
concrete_args=concrete_args,
complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
# bug in trace throws an error for variadic
# args and kwargs in function signature
),
)
tracer.trace(
model,
dummy_inputs=sample_input,
concrete_args=concrete_args,
complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
# bug in trace throws an error for variadic
# args and kwargs in function signature
),
)

# copy metadata
graph.config = model.config
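
As an aside on the `inspect.unwrap(model.forward).__get__(model)` line above: it follows `__wrapped__` back past any decorator applied to forward (such as transformers' `can_return_tuple`) and rebinds the original function to the instance. A small self-contained sketch with a toy class and decorator (both illustrative, not from the codebase):

import functools
import inspect

def traced(fn):
    # stand-in for a decorator that wraps forward
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

class ToyModel:
    @traced
    def forward(self, x):
        return x + 1

model = ToyModel()
unwrapped = inspect.unwrap(model.forward).__get__(model)
print(inspect.unwrap(model.forward) is ToyModel.forward)  # False: the wrapper has been peeled off
print(unwrapped(1))  # 2: still behaves like forward, but is an unwrapped bound method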
@@ -483,12 +484,6 @@ def get_sequential_targets(
return sequential_targets


def add_line_numbers(text: str) -> str:
    lines = text.splitlines()
    numbered_lines = [f"{i + 1} {line}" for i, line in enumerate(lines)]
    return "\n".join(numbered_lines)


def get_sequential_ancestors(model: Module, targets: Set[Module]) -> Set[Module]:
"""
Find modules which are call graph ancestors of the given sequential targets
@@ -1,3 +1,4 @@
# flake8: noqa
import ast
import textwrap
from types import SimpleNamespace
@@ -21,13 +22,14 @@ def check_wrapping(

    wrapped_lines = ast.unparse(wrapped).splitlines()
    output_lines = textwrap.dedent(output).splitlines()[1:]
    lines = ("\n".join(wrapped_lines), "\n".join(output_lines))

    assert len(wrapped_lines) == len(output_lines)
    assert len(wrapped_lines) == len(output_lines), lines
    for wrapped_line, output_line in zip(wrapped_lines, output_lines):
        if "# skip" in output:
            continue

        assert wrapped_line == output_line
        assert wrapped_line == output_line, lines


def test_static_if():
@@ -189,3 +191,24 @@ def forward(a, *b, c=5, **d):
() = wrapped_0(a, b, c, d)
"""
check_wrapping(source, output)


def test_walrus():
    """Checks for wrapping of if statements containing walrus (named expression) assignments"""

    source = """
    def forward():
        if (x := (1 + 2)):
            pass
    """
    output = """
    @torch.fx.wrap
    def wrapped_0():
        if (x := (1 + 2)):
            pass
        return (x,)

    def forward():
        (x,) = wrapped_0()  # skip: some envs use "(x,)" -> "x,"
    """
    check_wrapping(source, output)
3 changes: 3 additions & 0 deletions tests/llmcompressor/transformers/tracing/test_models.py
@@ -4,6 +4,7 @@
from transformers import (
    AutoModelForCausalLM,
    Gemma3ForConditionalGeneration,
    Gemma3nForConditionalGeneration,
    Idefics3ForConditionalGeneration,
    Llama4ForConditionalGeneration,
    LlavaForConditionalGeneration,
@@ -49,6 +50,7 @@
"text",
[],
),
("google/gemma-3n-E2B-it", AutoModelForCausalLM, None, "text", []),
("unsloth/DeepSeek-R1-0528-BF16", AutoModelForCausalLM, None, "text", []),
# --- vision ---
(
@@ -122,6 +124,7 @@
"vision",
[],
),
("google/gemma-3n-E2B-it", Gemma3nForConditionalGeneration, None, "vision", []),
# --- audio ---
(
"openai/whisper-large-v3",