-
Notifications
You must be signed in to change notification settings - Fork 474
[Qwen3VLMoe] Add linearized definition and FP8 Quantization Example #1874
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
2f00d59
add linearized definition
dsikka 7966cdf
update
dsikka ecb7d93
fix ignore list
dsikka 27e6f07
add note
dsikka 15da5c6
updatE
dsikka fb9e3ce
Merge branch 'main' into qwen3VLMoE_lineared
dsikka 4995535
format, fix import
dsikka 2cb141e
Merge branch 'main' into qwen3VLMoE_lineared
dsikka aca3244
fix format
dsikka 0d502f6
move around imports
dsikka e8f0872
Merge branch 'main' into qwen3VLMoE_lineared
dsikka File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
40 changes: 40 additions & 0 deletions
40
examples/quantization_w8a8_fp8/qwen3_vl_moe_fp8_example.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modeling import replace_modules_for_calibration | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
|
|
||
| # NOTE: Qwen3-VL-MoE support is not in transformers<=4.56.2 | ||
| # you may need to install transformers from source | ||
|
|
||
|
|
||
| MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" | ||
|
|
||
| # Load model. | ||
| model = Qwen3VLMoeForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype="auto") | ||
| processor = AutoProcessor.from_pretrained(MODEL_ID) | ||
| model = replace_modules_for_calibration(model) | ||
|
|
||
| # Configure the quantization algorithm and scheme. | ||
| # In this case, we: | ||
| # * quantize the weights to fp8 with channel-wise quantization | ||
| # * quantize the activations to fp8 with dynamic token activations | ||
| # NOTE: only datafree quantization is supported for Qwen3-VL-MoE currently | ||
| recipe = QuantizationModifier( | ||
| targets="Linear", | ||
| scheme="FP8_DYNAMIC", | ||
| ignore=[ | ||
| "re:.*lm_head", | ||
| "re:visual.*", | ||
| "re:model.visual.*", | ||
| "re:.*mlp.gate$", | ||
| ], | ||
| ) | ||
|
|
||
| # Apply quantization. | ||
| oneshot(model=model, recipe=recipe) | ||
|
|
||
| # Save to disk in compressed-tensors format. | ||
| SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-DYNAMIC" | ||
| model.save_pretrained(SAVE_DIR) | ||
| processor.save_pretrained(SAVE_DIR) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| import torch | ||
|
|
||
| from llmcompressor.utils.dev import skip_weights_initialize | ||
|
|
||
|
|
||
| class LinearQwen3VLMoeTextSparseMoeBlock(torch.nn.Module): | ||
| def __init__(self, config, original): | ||
| super().__init__() | ||
| self.hidden_size = config.hidden_size | ||
| self.num_experts = config.num_experts | ||
| self.gate = wrap_gate(original.gate) | ||
| self.experts = SequentialQwen3VLMoeTextExperts(config, original.experts) | ||
|
|
||
|
|
||
| class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList): | ||
| def __init__(self, config, original): | ||
| from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( | ||
| Qwen3VLMoeTextMLP, | ||
| ) | ||
|
|
||
| self.num_experts = original.gate_up_proj.shape[0] | ||
| with skip_weights_initialize(): | ||
| super().__init__( | ||
| [Qwen3VLMoeTextMLP(config) for _ in range(self.num_experts)] | ||
| ) | ||
|
|
||
| intermediate_size = original.down_proj.shape[1] | ||
|
|
||
| for i in range(self.num_experts): | ||
| gate_up = original.gate_up_proj[i] | ||
| down = original.down_proj[i] | ||
|
|
||
| gate_proj = gate_up[:, :intermediate_size] | ||
| up_proj = gate_up[:, intermediate_size:] | ||
|
|
||
| self[i].gate_proj.weight.data = gate_proj.t().clone().contiguous() | ||
| self[i].up_proj.weight.data = up_proj.t().clone().contiguous() | ||
| self[i].down_proj.weight.data = down.t().clone().contiguous() | ||
|
|
||
|
|
||
| def wrap_gate(gate): | ||
| # temporary workaround until ct supports ignores of Linear instances | ||
| linear_gate = torch.nn.Linear(gate.in_features, gate.out_features) | ||
| linear_gate.weight.data.copy_(gate.weight.data) | ||
| setattr(linear_gate, "hidden_size", gate.hidden_size) | ||
| setattr(linear_gate, "top_k", gate.top_k) | ||
| setattr(linear_gate, "forward", gate.forward) | ||
| del gate | ||
| return linear_gate | ||
|
|
||
|
|
||
| def replace(config, module, calibrate_all_experts=False): | ||
| return LinearQwen3VLMoeTextSparseMoeBlock( | ||
| config=config.get_text_config(), | ||
| original=module, | ||
| ) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.