Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions examples/quantization_w4a16_fp4/llama3_example_mxfp4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Configure the quantization algorithm and scheme.
# In this case, we:
# * quantize the weights to fp4 with per group 16 via ptq
recipe = QuantizationModifier(targets="Linear", scheme="MXFP4A16", ignore=["lm_head"])

# Apply quantization.
oneshot(model=model, recipe=recipe)

print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
model.device
)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")


# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-MXFP4A16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
5 changes: 4 additions & 1 deletion src/llmcompressor/observers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ def get_qparams(
(rows, num_groups), dtype=observed.dtype, device=observed.device
)
if is_fp4(quantization_args=self.quantization_args):
zp_dtype = FP8_E4M3_DATA.dtype
if self.quantization_args.group_size == 16:
zp_dtype = FP8_E4M3_DATA.dtype
else:
zp_dtype = torch.uint8
else:
zp_dtype = self.quantization_args.pytorch_dtype()

Expand Down