Skip to content

Commit 27d9cda

Browse files
committed
fix: scope AWQ mapping resolution to sequential targets for multimodal models
1 parent 0c0ead3 commit 27d9cda

File tree

2 files changed

+165
-3
lines changed

2 files changed

+165
-3
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import requests
from compressed_tensors.offload import dispatch_model
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

# Load model.
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, dtype="auto")
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Oneshot arguments
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
DATASET_ID = "flickr30k"
DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"}

# Recipe — AWQ with vision encoder excluded.
# The vision tower (SigLIP) and multi-modal projector must be ignored because
# their layer names (layer_norm1/2, out_proj, fc1/fc2) don't match the AWQ
# gemma mappings (input_layernorm, o_proj, gate_proj/up_proj/down_proj), and
# attempting to quantize them causes shape mismatches and tracing failures.
recipe = AWQModifier(
    scheme="W4A16",
    ignore=[
        "lm_head",
        r"re:model\.vision_tower.*",
        r"re:model\.multi_modal_projector.*",
    ],
    duo_scaling=False,
)

# Perform oneshot.
# sequential_targets must be set to the text decoder layer only, because the
# default _no_split_modules includes SiglipEncoderLayer and other vision
# components, which would cause the sequential pipeline to crash.
oneshot(
    model=model,
    processor=processor,
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    shuffle_calibration_samples=False,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    sequential_targets=["Gemma3DecoderLayer"],
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe the animal in this image\n"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
# Explicit timeout: requests has no default, so a stalled download would
# otherwise hang the script indefinitely.
raw_image = Image.open(requests.get(image_url, stream=True, timeout=60).raw)

# Note: compile is disabled: https://github.com/huggingface/transformers/issues/38333
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")

# Save to disk.
# Note: save_compressed=True currently fails on multimodal models due to a
# known issue in compressed-tensors with non-quantized vision tower weights.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-AWQ-W4A16"
model.save_pretrained(SAVE_DIR, save_compressed=False)
processor.save_pretrained(SAVE_DIR)

src/llmcompressor/modifiers/awq/base.py

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,10 @@ def on_initialize(self, state: State, **kwargs) -> bool:
225225
# (no offloading by default)
226226
self.offload_device = None
227227

228-
self._set_resolved_mappings(state.model)
228+
# Resolve sequential_targets: prefer oneshot() kwarg, fall back to the
229+
# modifier field; if neither is set, mapping resolution stays unscoped.
230+
seq_targets = kwargs.get("sequential_targets") or self.sequential_targets
231+
self._set_resolved_mappings(state.model, seq_targets)
229232

230233
return True
231234

@@ -320,7 +323,11 @@ def on_finalize(self, state: State, **kwargs) -> bool:
320323

321324
return True
322325

323-
def _set_resolved_mappings(self, model: Module) -> None:
326+
def _set_resolved_mappings(
327+
self,
328+
model: Module,
329+
sequential_targets: str | list[str] | None = None,
330+
) -> None:
324331
"""
325332
Transforms the list of activations to smooth and their corresponding weights
326333
into ResolvedMapping objects, resolving regular expressions.
@@ -331,9 +338,22 @@ def _set_resolved_mappings(self, model: Module) -> None:
331338
weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
332339
would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and
333340
repeat for model.layer.1 and so on
341+
342+
:param model: model to resolve mappings for
343+
:param sequential_targets: optional list of module class names that define
344+
the scope for mapping resolution. When provided, only modules that are
345+
children of these targets participate in matching. This prevents vision
346+
encoder modules from polluting the mapping resolution in multimodal models.
334347
"""
335348
resolved_mappings: list[ResolvedMapping] = []
336349
module_to_name = get_module_to_name_dict(model)
350+
351+
# Build a scoped model view when sequential_targets are available.
352+
# This restricts match_modules_set to only consider modules that live
353+
# under a sequential target (e.g. decoder layers), preventing vision
354+
# encoder modules from breaking the parent-context grouping.
355+
match_model = _build_scoped_model(model, sequential_targets)
356+
337357
# Get names of modules targeted for quantization (excludes ignored)
338358
targeted_names = set(
339359
name
@@ -346,7 +366,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
346366
# so that we can handle layers that need smoothing but not quantization
347367
# we only skip if no layers in mapping are targeted for quantization.
348368
for smooth_layers, *nested_balance_layers in match_modules_set(
349-
model, (mapping.smooth_layer, *mapping.balance_layers)
369+
match_model, (mapping.smooth_layer, *mapping.balance_layers)
350370
):
351371
if len(smooth_layers) > 1:
352372
raise ValueError(
@@ -1040,3 +1060,67 @@ def _accumulate_mean(
10401060
new_count = prev_count + num_added
10411061

10421062
return (prev_sum + sum_added) / new_count, new_count
1063+
1064+
1065+
class _ScopedModuleView(Module):
1066+
"""
1067+
A lightweight wrapper that restricts ``named_modules()`` to only yield
1068+
modules whose names fall under given scope prefixes. Everything else
1069+
(attribute access, ``get_submodule``, etc.) is forwarded to the wrapped
1070+
model so that ``match_modules_set`` can resolve module identities normally.
1071+
"""
1072+
1073+
def __init__(self, model: Module, scope_prefixes: set[str]):
1074+
# bypass Module.__init__ to avoid registering parameters/buffers
1075+
object.__setattr__(self, "_model", model)
1076+
object.__setattr__(self, "_scope_prefixes", scope_prefixes)
1077+
1078+
def named_modules(self, *args, **kwargs):
1079+
for name, mod in self._model.named_modules(*args, **kwargs):
1080+
if not name: # root module — always include
1081+
yield name, mod
1082+
elif any(
1083+
name == p or name.startswith(p + ".") for p in self._scope_prefixes
1084+
):
1085+
yield name, mod
1086+
1087+
def __getattr__(self, name: str):
1088+
return getattr(self._model, name)
1089+
1090+
1091+
def _build_scoped_model(
    model: Module,
    sequential_targets: str | list[str] | None,
) -> Module:
    """
    If *sequential_targets* is provided, return a :class:`_ScopedModuleView`
    that only exposes modules living under instances of those target classes.
    Otherwise return *model* unchanged (no-op for text-only models).
    """
    # No targets supplied — behave as an identity function.
    if not sequential_targets:
        return model

    # Normalize to a set of class names for O(1) membership checks.
    target_names = (
        {sequential_targets}
        if isinstance(sequential_targets, str)
        else set(sequential_targets)
    )

    # Collect the qualified names of every module whose class matches.
    scope_prefixes = {
        name
        for name, submodule in model.named_modules()
        if type(submodule).__name__ in target_names
    }

    if not scope_prefixes:
        # Misconfigured targets degrade gracefully instead of crashing.
        logger.warning(
            "sequential_targets %s did not match any modules, "
            "falling back to unscoped mapping resolution",
            sequential_targets,
        )
        return model

    logger.info(
        "Scoping AWQ mapping resolution to %d sequential targets (%s)",
        len(scope_prefixes),
        sequential_targets,
    )
    return _ScopedModuleView(model, scope_prefixes)

0 commit comments

Comments
 (0)