diff --git a/examples/quantization_w4a4_fp4/qwen3_example.py b/examples/quantization_w4a4_fp4/qwen3_example.py
new file mode 100644
index 000000000..6f6f8c774
--- /dev/null
+++ b/examples/quantization_w4a4_fp4/qwen3_example.py
@@ -0,0 +1,68 @@
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+# Load model.
+MODEL_ID = "Qwen/Qwen3-32B"
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID, torch_dtype="auto", trust_remote_code=True
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+
+# Select the number of calibration samples. 256 samples is a good starting
+# point; increasing the number of samples can improve accuracy.
+NUM_CALIBRATION_SAMPLES = 256
+MAX_SEQUENCE_LENGTH = 2048
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+ds = ds.shuffle(seed=42)
+
+
+def preprocess(example):
+    formatted = tokenizer.apply_chat_template(
+        example["messages"], tokenize=False, add_generation_prompt=False
+    )
+    return {"text": formatted[: MAX_SEQUENCE_LENGTH * 4]}
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding="max_length",
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        return_tensors="pt",
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names, batched=True)
+
+# Configure the quantization algorithm and scheme.
+# In this case, we:
+#   * quantize the weights to fp4 with a group size of 16 via PTQ
+#   * calibrate a global_scale for activations, which will be used to
+#     quantize activations to fp4 on the fly
+recipe = QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"])
+
+
+# SAVE_DIR is the directory where the compressed model will be saved.
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4"
+oneshot(
+    model=model,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    output_dir=SAVE_DIR,
+)
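
The script ends after `oneshot(...)`, which writes the compressed weights to `SAVE_DIR` but does not save the tokenizer or check the model's output. A minimal follow-up sketch (not part of the diff above): persist the tokenizer so the output directory is a self-contained checkpoint, then run one short generation with the in-memory quantized model as a sanity check. This assumes `model`, `tokenizer`, and `SAVE_DIR` from the script are still in scope, and that the model is already dispatched to an accelerator.

```python
# Sketch only, not part of the diff. Assumes `model`, `tokenizer`, and
# SAVE_DIR from the example script are in scope after oneshot() returns.

# Persist the tokenizer so SAVE_DIR is a complete, loadable checkpoint.
tokenizer.save_pretrained(SAVE_DIR)

# Spot-check a single greedy generation from the quantized model.
sample = tokenizer("Hello my name is", return_tensors="pt").to(model.device)
output = model.generate(**sample, max_new_tokens=50)
print(tokenizer.decode(output[0]))
```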
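The saved NVFP4 checkpoint is intended for an inference engine that understands the compressed-tensors format. A hedged sketch of loading it with vLLM, assuming a vLLM build with NVFP4 support and GPU hardware with fp4 kernels; the directory name matches the `SAVE_DIR` computed in the script:

```python
# Sketch only: serve the compressed checkpoint with vLLM. Assumes a vLLM
# version that can load NVFP4 (compressed-tensors) checkpoints.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen3-32B-NVFP4")  # the SAVE_DIR produced by the script
params = SamplingParams(temperature=0.0, max_tokens=64)
out = llm.generate(["Hello my name is"], params)
print(out[0].outputs[0].text)
```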