Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions examples/multimodal_vision/gemma3_awq_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""AWQ W4A16 quantization example for Gemma 3 (multimodal).

Only the text decoder is quantized; the SigLIP vision tower and the
multi-modal projector are excluded because their module naming does not
match the AWQ gemma mappings.
"""

import requests
from compressed_tensors.offload import dispatch_model
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

# Calibration configuration.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
DATASET_ID = "flickr30k"
DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"}

# Load the model and its processor.
MODEL_ID = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(MODEL_ID, dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# AWQ recipe: quantize the text decoder only. The vision tower and the
# multi-modal projector must be ignored — their layer names
# (layer_norm1/2, out_proj, fc1/fc2) do not match the AWQ gemma mappings
# (input_layernorm, o_proj, gate_proj/up_proj/down_proj), and trying to
# quantize them causes shape mismatches and tracing failures.
recipe = AWQModifier(
    scheme="W4A16",
    ignore=[
        "lm_head",
        r"re:model\.vision_tower.*",
        r"re:model\.multi_modal_projector.*",
    ],
    duo_scaling=False,
)

# Run oneshot calibration.
# sequential_targets is pinned to the text decoder layer: the model's
# default _no_split_modules also lists SiglipEncoderLayer and other
# vision components, which would crash the sequential pipeline.
oneshot(
    model=model,
    processor=processor,
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    shuffle_calibration_samples=False,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    sequential_targets=["Gemma3DecoderLayer"],
)

# Sanity-check generation on the quantized model.
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe the animal in this image\n"},
            {"type": "image"},
        ],
    },
]
chat_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
IMAGE_URL = "http://images.cocodataset.org/train2017/000000231895.jpg"
image = Image.open(requests.get(IMAGE_URL, stream=True).raw)

# torch.compile is disabled during generation, see:
# https://github.com/huggingface/transformers/issues/38333
model_inputs = processor(images=image, text=chat_prompt, return_tensors="pt").to(model.device)
generated = model.generate(**model_inputs, max_new_tokens=100, disable_compile=True)
print(processor.decode(generated[0], skip_special_tokens=True))
print("==========================================")

# Persist the quantized checkpoint.
# save_compressed=True currently fails on multimodal models: a known
# compressed-tensors issue with non-quantized vision tower weights.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-AWQ-W4A16"
model.save_pretrained(SAVE_DIR, save_compressed=False)
processor.save_pretrained(SAVE_DIR)
90 changes: 87 additions & 3 deletions src/llmcompressor/modifiers/awq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,10 @@ def on_initialize(self, state: State, **kwargs) -> bool:
# (no offloading by default)
self.offload_device = None

self._set_resolved_mappings(state.model)
# Resolve sequential_targets: prefer oneshot() kwarg, fall back to
# modifier field.
seq_targets = kwargs.get("sequential_targets") or self.sequential_targets
self._set_resolved_mappings(state.model, seq_targets)

return True

Expand Down Expand Up @@ -320,7 +323,11 @@ def on_finalize(self, state: State, **kwargs) -> bool:

return True

def _set_resolved_mappings(self, model: Module) -> None:
def _set_resolved_mappings(
self,
model: Module,
sequential_targets: str | list[str] | None = None,
) -> None:
"""
Transforms the list of activations to smooth and their corresponding weights
into ResolvedMapping objects, resolving regular expressions.
Expand All @@ -331,9 +338,22 @@ def _set_resolved_mappings(self, model: Module) -> None:
weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and
repeat for model.layer.1 and so on

:param model: model to resolve mappings for
:param sequential_targets: optional list of module class names that define
the scope for mapping resolution. When provided, only modules that are
children of these targets participate in matching. This prevents vision
encoder modules from polluting the mapping resolution in multimodal models.
"""
resolved_mappings: list[ResolvedMapping] = []
module_to_name = get_module_to_name_dict(model)

# Build a scoped model view when sequential_targets are available.
# This restricts match_modules_set to only consider modules that live
# under a sequential target (e.g. decoder layers), preventing vision
# encoder modules from breaking the parent-context grouping.
match_model = _build_scoped_model(model, sequential_targets)

# Get names of modules targeted for quantization (excludes ignored)
targeted_names = set(
name
Expand All @@ -346,7 +366,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
# so that we can handle layers that need smoothing but not quantization
# we only skip if no layers in mapping are targeted for quantization.
for smooth_layers, *nested_balance_layers in match_modules_set(
model, (mapping.smooth_layer, *mapping.balance_layers)
match_model, (mapping.smooth_layer, *mapping.balance_layers)
):
if len(smooth_layers) > 1:
raise ValueError(
Expand Down Expand Up @@ -1040,3 +1060,67 @@ def _accumulate_mean(
new_count = prev_count + num_added

return (prev_sum + sum_added) / new_count, new_count


class _ScopedModuleView(Module):
"""
A lightweight wrapper that restricts ``named_modules()`` to only yield
modules whose names fall under given scope prefixes. Everything else
(attribute access, ``get_submodule``, etc.) is forwarded to the wrapped
model so that ``match_modules_set`` can resolve module identities normally.
"""

def __init__(self, model: Module, scope_prefixes: set[str]):
# bypass Module.__init__ to avoid registering parameters/buffers
object.__setattr__(self, "_model", model)
object.__setattr__(self, "_scope_prefixes", scope_prefixes)

def named_modules(self, *args, **kwargs):
for name, mod in self._model.named_modules(*args, **kwargs):
if not name: # root module — always include
yield name, mod
elif any(
name == p or name.startswith(p + ".") for p in self._scope_prefixes
):
yield name, mod

def __getattr__(self, name: str):
return getattr(self._model, name)


def _build_scoped_model(
model: Module,
sequential_targets: str | list[str] | None,
) -> Module:
"""
If *sequential_targets* is provided, return a :class:`_ScopedModuleView`
that only exposes modules living under instances of those target classes.
Otherwise return *model* unchanged (no-op for text-only models).
"""
if not sequential_targets:
return model

if isinstance(sequential_targets, str):
sequential_targets = [sequential_targets]

target_classes = set(sequential_targets)

scope_prefixes: set[str] = set()
for name, mod in model.named_modules():
if type(mod).__name__ in target_classes:
scope_prefixes.add(name)

if not scope_prefixes:
logger.warning(
"sequential_targets %s did not match any modules, "
"falling back to unscoped mapping resolution",
sequential_targets,
)
return model

logger.info(
"Scoping AWQ mapping resolution to %d sequential targets (%s)",
len(scope_prefixes),
sequential_targets,
)
return _ScopedModuleView(model, scope_prefixes)
Loading