From 9e8e013b52bfc1a9f95ad201fdf0477a8862276e Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Thu, 28 Aug 2025 21:39:47 +0000
Subject: [PATCH] add mxfp4 support

---
 .../llama3_example_mxfp4.py         | 35 +++++++++++++++++++
 src/llmcompressor/observers/base.py |  5 ++-
 2 files changed, 39 insertions(+), 1 deletion(-)
 create mode 100644 examples/quantization_w4a16_fp4/llama3_example_mxfp4.py

diff --git a/examples/quantization_w4a16_fp4/llama3_example_mxfp4.py b/examples/quantization_w4a16_fp4/llama3_example_mxfp4.py
new file mode 100644
index 000000000..c6de83d2d
--- /dev/null
+++ b/examples/quantization_w4a16_fp4/llama3_example_mxfp4.py
@@ -0,0 +1,35 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.utils import dispatch_for_generation
+
+MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
+
+# Load model.
+model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+# Configure the quantization algorithm and scheme.
+# In this case, we:
+# * quantize the weights to fp4 with per-group scales (MXFP4) via PTQ
+recipe = QuantizationModifier(targets="Linear", scheme="MXFP4A16", ignore=["lm_head"])
+
+# Apply quantization.
+oneshot(model=model, recipe=recipe)
+
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
+    model.device
+)
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+
+# Save to disk in compressed-tensors format.
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-MXFP4A16"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index aa9e1caab..e396474bd 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -151,7 +151,10 @@ def get_qparams(
                     (rows, num_groups), dtype=observed.dtype, device=observed.device
                 )
 
                 if is_fp4(quantization_args=self.quantization_args):
-                    zp_dtype = FP8_E4M3_DATA.dtype
+                    if self.quantization_args.group_size == 16:
+                        zp_dtype = FP8_E4M3_DATA.dtype
+                    else:
+                        zp_dtype = torch.uint8
                 else:
                     zp_dtype = self.quantization_args.pytorch_dtype()
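
For context on the observers/base.py hunk above: a minimal sketch of the fp4
zero-point dtype selection it introduces, assuming (as in compressed-tensors)
that FP8_E4M3_DATA.dtype resolves to torch.float8_e4m3fn; the helper name
fp4_zp_dtype is illustrative only, not part of the patch.

    import torch

    def fp4_zp_dtype(group_size: int) -> torch.dtype:
        # A group size of 16 (the existing fp4 path) keeps the FP8 E4M3
        # zero-point dtype; any other group size (e.g. 32 for MXFP4) now
        # stores the zero point as uint8 instead.
        return torch.float8_e4m3fn if group_size == 16 else torch.uint8

    assert fp4_zp_dtype(16) == torch.float8_e4m3fn
    assert fp4_zp_dtype(32) == torch.uint8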