Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/llmcompressor/modeling/llama4.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Tuple

import torch
from transformers.models import Llama4Config
from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
from transformers.models.llama4.configuration_llama4 import (
Llama4Config,
Llama4TextConfig,
)
from transformers.models.llama4.modeling_llama4 import (
Llama4TextExperts,
Llama4TextMLP,
Expand Down
8 changes: 5 additions & 3 deletions src/llmcompressor/modifiers/awq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,11 +465,13 @@ def _apply_smoothing(self, model: Module) -> None:
# Calculates the relative magnitude of the weights within
# each of the quantization groups, and rescales each group
# individually so that each group has weights on a 0-1 scale.
w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6)
weight.abs_()
weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6)
# Resizes the rescaled weight matrix back up to its original dimensions
w_scale = w_scale.view(org_shape)
weight = weight.view(org_shape)
# Gets the average rescaled magnitude for each output channel
w_mean = w_scale.mean(0)
w_mean = weight.mean(0)
del weight

with calibration_forward_context(model), HooksMixin.disable_hooks():
# [STEP 3]: Compute output of module
Expand Down
17 changes: 17 additions & 0 deletions src/llmcompressor/modifiers/awq/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,26 @@ class AWQMapping:
),
]

# DeepseekV3
_deepseek_mappings = [
AWQMapping(
"re:.*input_layernorm$",
# Some models use q_proj instead of q_a_proj
["re:.*(q|q_a)_proj$", "re:.*kv_a_proj_with_mqa$"],
),
AWQMapping("re:.*q_a_layernorm$", ["re:.*q_b_proj$"]),
AWQMapping("re:.*kv_a_layernorm$", ["re:.*kv_b_proj$"]),
AWQMapping(
"re:.*post_attention_layernorm$",
["re:.*gate_proj$", "re:.*up_proj$"],
),
AWQMapping("re:.*up_proj$", ["re:.*down_proj$"]),
]

AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = {
"CohereForCausalLM": _cohere_mappings,
"Cohere2ForCausalLM": _cohere_mappings,
"DeepseekV3ForCausalLM": _deepseek_mappings,
"Gemma2ForCausalLM": _gemma_mappings,
"Gemma3ForCausalLM": _gemma_mappings,
"Gemma3ForConditionalGeneration": _gemma_mappings,
Expand Down