-
Notifications
You must be signed in to change notification settings - Fork 453
[Docs] Add Qwen3.5 to Key Models #2502
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+232
−1
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
6a02f09
Add qwen3.5 docs
dsikka fdd376c
Merge branch 'main' into add_qwen35_docs
dsikka bd386c4
fix nav
dsikka 18d2837
update
dsikka 0bf4324
Merge branch 'main' into add_qwen35_docs
kylesayrs c00fb15
Merge branch 'main' into add_qwen35_docs
dsikka File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| # Qwen3.5 | ||
|
|
||
| Quantization examples for the Qwen3.5 family of models, including dense vision-language and sparse MoE variants. | ||
|
|
||
| > **Note:** These examples require `transformers >= v5`, which can be installed with: | ||
| > ```bash | ||
| > uv pip install --upgrade transformers | ||
| > ``` | ||
| > With this, the examples can run end-to-end on `main`. You may also need to update the version of `transformers` in your vLLM environment in order for the tokenizer to be properly applied. | ||
|
|
||
| - [NVFP4A16 Vision-Language Example](nvfp4-vl-example.md) | ||
| - [NVFP4 MoE Example](nvfp4-moe-example.md) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| ## Qwen3.5 NVFP4 MoE Example | ||
|
|
||
| This example quantizes the Qwen3.5-122B-A10B sparse MoE model to NVFP4 (weights and activations quantized to FP4) using calibration data. | ||
|
|
||
| NOTE: This example requires `transformers >= v5`. | ||
|
|
||
| ### Code Walkthrough | ||
|
|
||
| Let's walk through the main steps of the quantization process: | ||
| 1. Load model | ||
| 2. Load and preprocess calibration dataset | ||
| 3. Configure quantization algorithm and scheme | ||
| 4. Apply quantization | ||
| 5. Save to disk in compressed-tensors format | ||
|
|
||
| ### 1. Load Model | ||
|
|
||
| ```python | ||
| import torch | ||
| from compressed_tensors.utils import save_mtp_tensors_to_checkpoint | ||
| from datasets import load_dataset | ||
| from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
|
|
||
| MODEL_ID = "Qwen/Qwen3.5-122B-A10B" | ||
|
|
||
| # Load model. | ||
| model = Qwen3_5MoeForConditionalGeneration.from_pretrained(MODEL_ID, dtype="auto") | ||
| processor = AutoProcessor.from_pretrained(MODEL_ID) | ||
| ``` | ||
|
|
||
| ### 2. Load and Preprocess Calibration Dataset | ||
|
|
||
| ```python | ||
| NUM_CALIBRATION_SAMPLES = 256 | ||
| MAX_SEQUENCE_LENGTH = 4096 | ||
|
|
||
| ds = load_dataset( | ||
| "HuggingFaceH4/ultrachat_200k", | ||
| split=f"train_sft[:{NUM_CALIBRATION_SAMPLES}]", | ||
| ) | ||
| ds = ds.select_columns(["messages"]) | ||
| ds = ds.shuffle(seed=42) | ||
|
|
||
|
|
||
| def preprocess_function(example): | ||
| messages = [ | ||
| {"role": m["role"], "content": [{"type": "text", "text": m["content"]}]} | ||
| for m in example["messages"] | ||
| ] | ||
| return processor.apply_chat_template( | ||
| messages, | ||
| return_tensors="pt", | ||
| padding=False, | ||
| truncation=True, | ||
| max_length=MAX_SEQUENCE_LENGTH, | ||
| tokenize=True, | ||
| add_special_tokens=False, | ||
| return_dict=True, | ||
| add_generation_prompt=False, | ||
| ) | ||
|
|
||
|
|
||
| ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names) | ||
|
|
||
|
|
||
| def data_collator(batch): | ||
| assert len(batch) == 1 | ||
| return {key: torch.tensor(value) for key, value in batch[0].items()} | ||
| ``` | ||
|
|
||
| ### 3. Configure Quantization Algorithm and Scheme | ||
|
|
||
| In this case, we are doing the following: | ||
| - Quantize the weights and activations to FP4 via calibration-based PTQ | ||
| - Skip `lm_head`, visual layers, MoE gate projections, embedding layers, shared expert gates, and linear attention layers | ||
| - MTP layers are not loaded through `Qwen3_5MoeForConditionalGeneration`, so there is no need to include them in the ignore list | ||
|
|
||
| ```python | ||
| recipe = QuantizationModifier( | ||
| targets="Linear", | ||
| scheme="NVFP4", | ||
| ignore=[ | ||
| "re:.*lm_head", | ||
| "re:visual.*", | ||
| "re:model.visual.*", | ||
kylesayrs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "re:.*mlp.gate$", | ||
| "re:.*embed_tokens$", | ||
| "re:.*shared_expert_gate$", | ||
| "re:.*linear_attn.*", | ||
| ], | ||
| ) | ||
| ``` | ||
|
|
||
| ### 4. Apply Quantization | ||
|
|
||
| `moe_calibrate_all_experts=True` ensures all MoE experts receive calibration data, which improves quantization quality for sparse MoE models. | ||
|
|
||
| ```python | ||
| oneshot( | ||
| model=model, | ||
| recipe=recipe, | ||
| dataset=ds, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
| moe_calibrate_all_experts=True, | ||
| data_collator=data_collator, | ||
| ) | ||
| ``` | ||
|
|
||
| ### 5. Save to Disk in Compressed-Tensors Format | ||
|
|
||
kylesayrs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ```python | ||
| SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4" | ||
| model.save_pretrained(SAVE_DIR) | ||
dsikka marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| processor.save_pretrained(SAVE_DIR) | ||
|
|
||
| # MTP layers are excluded from the model through Qwen3_5MoeForConditionalGeneration | ||
| # Save them as-is from the original checkpoint into the quantized output. | ||
| save_mtp_tensors_to_checkpoint(source_model=MODEL_ID, dest_dir=SAVE_DIR) | ||
| ``` | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| ## Qwen3.5 NVFP4A16 Vision-Language Example | ||
|
|
||
| This example quantizes the Qwen3.5-27B vision-language model to NVFP4A16 (weights quantized to FP4 with per-group-16 granularity, activations in FP16) using data-free PTQ. | ||
|
|
||
| ### Code Walkthrough | ||
|
|
||
| Let's walk through the main steps of the quantization process: | ||
| 1. Load model | ||
| 2. Configure quantization algorithm and scheme | ||
| 3. Apply quantization | ||
| 4. Run sample generation | ||
| 5. Save to disk in compressed-tensors format | ||
|
|
||
| ### 1. Load Model | ||
|
|
||
| ```python | ||
| from compressed_tensors.offload import dispatch_model | ||
| from compressed_tensors.utils import save_mtp_tensors_to_checkpoint | ||
| from transformers import AutoProcessor, Qwen3_5ForConditionalGeneration | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
|
|
||
| # Load model. | ||
| MODEL_ID = "Qwen/Qwen3.5-27B" | ||
| model = Qwen3_5ForConditionalGeneration.from_pretrained( | ||
| MODEL_ID, dtype="auto", trust_remote_code=True | ||
| ) | ||
| processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) | ||
| ``` | ||
|
|
||
| ### 2. Configure Quantization Algorithm and Scheme | ||
|
|
||
| In this case, we are doing the following: | ||
| - Quantize the weights to FP4 with per-group-16 granularity via data-free PTQ | ||
| - Skip the visual encoder, `lm_head`, and linear attention layers (Gated DeltaNet fused projections are incompatible with NVFP4) | ||
| - MTP layers are not loaded through `Qwen3_5ForConditionalGeneration`, so there is no need to include them in the ignore list | ||
|
|
||
| ```python | ||
| # No need to include mtp layers as they are not loaded | ||
| # through Qwen3_5ForConditionalGeneration | ||
| recipe = QuantizationModifier( | ||
| targets="Linear", | ||
| scheme="NVFP4A16", | ||
| ignore=[ | ||
| "lm_head", | ||
| "re:.*visual.*", | ||
| "re:.*linear_attn.*", | ||
| ], | ||
| ) | ||
| ``` | ||
|
|
||
| ### 3. Apply Quantization | ||
|
|
||
| ```python | ||
| oneshot(model=model, recipe=recipe) | ||
| ``` | ||
|
|
||
| ### 4. Run Sample Generation | ||
|
|
||
| ```python | ||
| print("\n\n========== SAMPLE GENERATION ==============") | ||
| dispatch_model(model) | ||
| messages = [{"role": "user", "content": "Hello my name is"}] | ||
| prompt = processor.apply_chat_template( | ||
| messages, tokenize=False, add_generation_prompt=True | ||
| ) | ||
| inputs = processor(text=prompt, return_tensors="pt").to(model.device) | ||
| output = model.generate(**inputs, max_new_tokens=100) | ||
| print(processor.decode(output[0], skip_special_tokens=True)) | ||
| print("==========================================\n\n") | ||
| ``` | ||
|
|
||
| ### 5. Save to Disk in Compressed-Tensors Format | ||
|
|
||
| ```python | ||
| SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4A16" | ||
| model.save_pretrained(SAVE_DIR) | ||
| processor.save_pretrained(SAVE_DIR) | ||
|
|
||
| # MTP layers are excluded from the model through Qwen3_5ForConditionalGeneration | ||
| # Save them as-is from the original checkpoint into the quantized output. | ||
| save_mtp_tensors_to_checkpoint(source_model=MODEL_ID, dest_dir=SAVE_DIR) | ||
| ``` |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.