-
Notifications
You must be signed in to change notification settings - Fork 453
[Examples] Add Qwen3.5-27B NVFP4A16 and MXFP4A16 quantization examples #2467
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+94
−0
Merged
Changes from 2 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
df3ad83
[Examples] Add Qwen3.5-27B NVFP4A16 and MXFP4A16 quantization examples
2imi9 fca0429
Merge branch 'main' into add-qwen3.5-fp4-examples
dsikka 71646ef
Merge branch 'main' into add-qwen3.5-fp4-examples
dsikka fa7eef0
fix examples to use proper classes
dsikka File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| from compressed_tensors.offload import dispatch_model | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
|
|
||
| # Load model. | ||
| MODEL_ID = "Qwen/Qwen3.5-27B" | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| MODEL_ID, dtype="auto", trust_remote_code=True | ||
| ) | ||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) | ||
|
|
||
| # Configure the quantization algorithm and scheme. | ||
| # In this case, we: | ||
| # * quantize the weights to fp4 with per group 32 via ptq | ||
| # * skip the visual encoder, lm_head, linear attention (Gated DeltaNet | ||
| # fused projections are incompatible with microscale formats), and MTP modules | ||
| recipe = QuantizationModifier( | ||
| targets="Linear", | ||
| scheme="MXFP4A16", | ||
| ignore=[ | ||
| "lm_head", | ||
| "re:.*visual.*", | ||
| "re:.*linear_attn.*", | ||
| "re:.*mtp.*", | ||
| ], | ||
| ) | ||
|
|
||
| # Apply quantization. | ||
| oneshot(model=model, recipe=recipe) | ||
|
|
||
| print("\n\n========== SAMPLE GENERATION ==============") | ||
| dispatch_model(model) | ||
| input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to( | ||
| model.device | ||
| ) | ||
| output = model.generate(input_ids, max_new_tokens=100) | ||
| print(tokenizer.decode(output[0], skip_special_tokens=True)) | ||
| print("==========================================\n\n") | ||
|
|
||
| # Save to disk in compressed-tensors format. | ||
| SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-MXFP4A16" | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| tokenizer.save_pretrained(SAVE_DIR) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| from compressed_tensors.offload import dispatch_model | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
|
|
||
| # Load model. | ||
| MODEL_ID = "Qwen/Qwen3.5-27B" | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| MODEL_ID, dtype="auto", trust_remote_code=True | ||
| ) | ||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) | ||
|
|
||
| # Configure the quantization algorithm and scheme. | ||
| # In this case, we: | ||
| # * quantize the weights to fp4 with per group 16 via ptq | ||
| # * skip the visual encoder, lm_head, linear attention (Gated DeltaNet | ||
| # fused projections are incompatible with NVFP4), and MTP modules | ||
| recipe = QuantizationModifier( | ||
| targets="Linear", | ||
| scheme="NVFP4A16", | ||
| ignore=[ | ||
| "lm_head", | ||
| "re:.*visual.*", | ||
| "re:.*linear_attn.*", | ||
| "re:.*mtp.*", | ||
| ], | ||
| ) | ||
|
|
||
| # Apply quantization. | ||
| oneshot(model=model, recipe=recipe) | ||
|
|
||
| print("\n\n========== SAMPLE GENERATION ==============") | ||
| dispatch_model(model) | ||
| input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to( | ||
| model.device | ||
| ) | ||
| output = model.generate(input_ids, max_new_tokens=100) | ||
| print(tokenizer.decode(output[0], skip_special_tokens=True)) | ||
| print("==========================================\n\n") | ||
|
|
||
| # Save to disk in compressed-tensors format. | ||
| SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4A16" | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| tokenizer.save_pretrained(SAVE_DIR) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.