Skip to content

[Bug]: Gemma3ForCausalLM (MedGemma-27B-text-it) model produces garbage output #2365

@Jeevi10

Description

@Jeevi10

⚙️ Your current environment

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping

# Model to quantize.
# NOTE(review): the original script used `model` and `tokenizer` without ever
# creating them, so it raised NameError before quantization even started.
# Load both here. TODO confirm the exact Hub id for the 27B text-it variant.
MODEL_ID = "google/medgemma-27b-text-it"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",  # keep the checkpoint's native precision (bf16)
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Calibration dataset for AWQ activation statistics.
DATASET_ID = "FreedomIntelligence/medical-o1-reasoning-SFT"
DATASET_SPLIT = "train"
language = "en"

# Increasing the number of calib samples to 256 or higher can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 4096

# Load only the first NUM_CALIBRATION_SAMPLES rows, then shuffle for variety.
ds = load_dataset(DATASET_ID, language, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)
print(f"Dataset size: {len(ds)} samples")


# Preprocess: apply chat template and tokenize.
def preprocess(example):
    """Render one dataset row through the chat template and tokenize it.

    Args:
        example: a dataset row; only the "Question" column is used as the
            user turn.

    Returns:
        The tokenizer's encoding dict (input_ids / attention_mask), truncated
        to MAX_SEQUENCE_LENGTH.
    """
    text = tokenizer.apply_chat_template(
        [{"role": "user", "content": example["Question"]}],
        tokenize=False,
        add_generation_prompt=True,
    )
    return tokenizer(
        text,
        # The chat template already inserts the BOS token into `text`.
        # Tokenizing with the default add_special_tokens=True would prepend a
        # SECOND BOS — a known cause of degraded/garbage output with
        # Gemma-family models, and it also skews the calibration statistics.
        add_special_tokens=False,
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
    )


ds = ds.map(preprocess, remove_columns=ds.column_names)

# Gemma 3 specific AWQ smoothing mappings.
# NOTE(review): the v_proj -> o_proj mapping is deliberately omitted here —
# the comment below attributes this to a GQA dimension mismatch (v_proj output
# is num_kv_heads-sized while o_proj input is num_heads-sized); confirm
# against llm-compressor's built-in Gemma mappings before relying on it.
gemma3_mappings = [
    # 1. Smooth the attention input norm against the Q/K/V projections it feeds.
    AWQMapping(
        smooth_layer="re:.*input_layernorm$",
        balance_layers=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"]
    ),
    # 2. Smooth the MLP entry norm (Gemma 3's pre_feedforward_layernorm)
    #    against the gate/up projections it feeds.
    AWQMapping(
        smooth_layer="re:.*pre_feedforward_layernorm$",
        balance_layers=["re:.*gate_proj$", "re:.*up_proj$"]
    ),
    # 3. Smooth up_proj against down_proj (the pair around the activation).
    AWQMapping(
        smooth_layer="re:.*up_proj$",
        balance_layers=["re:.*down_proj$"]
    ),
]

# Build the one-shot AWQ recipe around the Gemma 3 mappings above:
# 4-bit asymmetric weight-only quantization of every Linear layer except
# the output head.
awq_modifier = AWQModifier(
    targets=["Linear"],
    scheme="W4A16_ASYM",
    ignore=["lm_head"],
    mappings=gemma3_mappings,
)
recipe = [awq_modifier]

# Run the one-shot AWQ quantization pass over the calibration set.
# `ds` is already tokenized above, so oneshot consumes it directly.
# NOTE(review): nothing below saves the quantized model — add
# model.save_pretrained(...) / tokenizer.save_pretrained(...) (or oneshot's
# output_dir argument, if supported by this llmcompressor version) or the
# result is lost when the process exits.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

🐛 Describe the bug

At inference (testing) time, the quantized model produces garbage output.

🛠️ Steps to reproduce

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions