-
Notifications
You must be signed in to change notification settings - Fork 453
Closed
Closed
Copy link
Labels
bug — Something isn't working
Description
⚙️ Your current environment
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping

# NOTE(review): the original snippet used `model` and `tokenizer` without ever
# defining them, which raises NameError on first use — load them explicitly.
MODEL_ID = "google/gemma-3-4b-it"  # TODO(review): confirm the exact Gemma 3 checkpoint used
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "FreedomIntelligence/medical-o1-reasoning-SFT"
DATASET_SPLIT = "train"
language = "en"

# Increasing the number of calib samples to 256 or higher can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 4096

# Load dataset: slice the first NUM_CALIBRATION_SAMPLES rows, then shuffle
# deterministically so the calibration subset is reproducible.
ds = load_dataset(DATASET_ID, language, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)
print(f"Dataset size: {len(ds)} samples")
# Preprocess: apply chat template and tokenize properly
def preprocess(example):
    """Render one sample as a single-turn user chat prompt and tokenize it.

    Relies on the module-level ``tokenizer`` and ``MAX_SEQUENCE_LENGTH``.
    The prompt is truncated to MAX_SEQUENCE_LENGTH tokens; no padding is
    applied here (padding is left to the downstream collator, if any).
    """
    messages = [{"role": "user", "content": example["Question"]}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = tokenizer(
        prompt,
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
    )
    return encoded
ds = ds.map(preprocess, remove_columns=ds.column_names)
# Gemma 3 specific mappings (GQA-aware: skip v_proj -> o_proj due to dimension mismatch)
gemma3_mappings = [
# 1. Smooth Input Norm against Q, K, V
AWQMapping(
smooth_layer="re:.*input_layernorm$",
balance_layers=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"]
),
# 2. Smooth the MLP using the pre-feedforward norm
AWQMapping(
smooth_layer="re:.*pre_feedforward_layernorm$",
balance_layers=["re:.*gate_proj$", "re:.*up_proj$"]
),
# 3. Smooth Up-proj against Down-proj
AWQMapping(
smooth_layer="re:.*up_proj$",
balance_layers=["re:.*down_proj$"]
),
]
# Initialize the modifier with these mappings
recipe = [
AWQModifier(
targets=["Linear"],
scheme="W4A16_ASYM",
ignore=["lm_head"],
mappings=gemma3_mappings,
)
]
# Run quantization algorithm
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)🐛 Describe the bug
At testing time, the model produces garbage results.
🛠️ Steps to reproduce
No response
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
bug — Something isn't working