# [Qwen3.5 MoE Support] #2377
New file (+88 lines): example script that calibrates and quantizes Qwen3.5 MoE to NVFP4.

```python
import torch
from datasets import load_dataset
from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Local snapshot path, overridden by the Hub ID below.
# MODEL_ID = "/raid/engine/dsikka/models--Qwen--Qwen3.5-397B-A17B/snapshots/7cad2bae11cb49ca79f7d6a0954de2e2756f4e27"
MODEL_ID = "Qwen/Qwen3.5-122B-A10B"

# Load model and processor.
model = Qwen3_5MoeForConditionalGeneration.from_pretrained(MODEL_ID, dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Quantize all Linear layers to NVFP4, skipping the lm_head, vision tower,
# embeddings, router gates, shared experts, and linear-attention layers.
recipe = QuantizationModifier(
    targets="Linear",
    scheme="NVFP4",
    ignore=[
        "re:.*lm_head",
        "re:visual.*",
        "re:model.visual.*",
        "re:.*mlp.gate$",
        "re:.*embed_tokens$",
        "re:.*shared_expert_gate$",
        "re:.*mlp\\.shared_expert$",
        "re:.*linear_attn.*",
    ],
)

DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 20
MAX_SEQUENCE_LENGTH = 8192

ds = load_dataset(DATASET_ID, name="LLM", split=f"train[:{NUM_CALIBRATION_SAMPLES}]")


def preprocess_function(example):
    # Wrap each message's text in the content format the chat template expects.
    messages = []
    for message in example["messages"]:
        messages.append(
            {
                "role": message["role"],
                "content": [{"type": "text", "text": message["content"]}],
            }
        )

    return processor.apply_chat_template(
        messages,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        tokenize=True,
        add_special_tokens=False,
        return_dict=True,
        add_generation_prompt=False,
    )


ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names)


def data_collator(batch):
    assert len(batch) == 1
    return {
        key: (
            torch.tensor(value)
            if key != "pixel_values"
            else torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
        )
        for key, value in batch[0].items()
    }


# Apply quantization.
oneshot(
    model=model,
    recipe=recipe,
    dataset=ds,
    data_collator=data_collator,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    moe_calibrate_all_experts=True,
)

# Save to disk in compressed-tensors format.
SAVE_DIR = "/mnt/nvme_stripe/playground/dsikka/" + "Qwen3.5-122B-A10B" + "-NVFP4"
model.save_pretrained(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)
```
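A quick way to smoke-test the saved checkpoint is to load it back for generation; a minimal sketch, assuming a vLLM build with Qwen3.5 MoE support and the compressed-tensors NVFP4 format:

```python
# Hypothetical smoke test: load the compressed checkpoint saved above in vLLM.
# Assumes vLLM support for Qwen3.5 MoE and NVFP4 compressed-tensors checkpoints.
from vllm import LLM, SamplingParams

SAVE_DIR = "/mnt/nvme_stripe/playground/dsikka/Qwen3.5-122B-A10B-NVFP4"
llm = LLM(model=SAVE_DIR)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```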
New file (+39 lines): data-free FP8 dynamic quantization example.

```python
from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Local snapshot path, overridden by the Hub ID below.
# MODEL_ID = "/raid/engine/dsikka/models--Qwen--Qwen3.5-397B-A17B/snapshots/7cad2bae11cb49ca79f7d6a0954de2e2756f4e27"
MODEL_ID = "Qwen/Qwen3.5-122B-A10B"

# Load model.
model = Qwen3_5MoeForConditionalGeneration.from_pretrained(MODEL_ID, dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Configure the quantization algorithm and scheme.
# In this case, we:
#   * quantize the weights to fp8 with channel-wise quantization
#   * quantize the activations to fp8 with dynamic per-token quantization
# NOTE: only data-free quantization is currently supported for Qwen3-VL-MoE
recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8_DYNAMIC",
    ignore=[
        "re:.*lm_head",
        "re:visual.*",
        "re:model.visual.*",
        "re:.*mlp.gate$",
        "re:.*embed_tokens$",
        "re:.*shared_expert_gate$",
        "re:.*mlp\\.shared_expert$",
        "re:.*linear_attn.*",
    ],
)

# Apply quantization. No dataset is needed for the data-free FP8_DYNAMIC scheme.
oneshot(model=model, recipe=recipe)

# Save to disk in compressed-tensors format.
SAVE_DIR = "/mnt/nvme_stripe/playground/dsikka/" + "Qwen3.5-122B-A10B" + "-FP8_DYNAMIC"
model.save_pretrained(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)
```
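Before saving, it can be worth confirming that the quantized model still generates sensibly; a sketch reusing the same `apply_chat_template` call style as the calibration script above (the prompt is arbitrary, and the model is assumed to fit on the available devices):

```python
# Sanity-check generation after data-free quantization (sketch; arbitrary prompt).
sample = processor.apply_chat_template(
    [{"role": "user", "content": [{"type": "text", "text": "What is FP8 quantization?"}]}],
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)
output = model.generate(**sample, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True))
```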
New file (+124 lines): MoE calibration module that routes every token through every expert during calibration.

```python
import torch
import torch.nn.functional as F
from transformers import Qwen3_5MoeConfig, Qwen3_5MoeTextConfig
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
    Qwen3_5MoeSparseMoeBlock,
)

from llmcompressor.modeling.moe_context import MoECalibrationModule
from llmcompressor.utils.dev import skip_weights_initialize


@MoECalibrationModule.register("Qwen3_5MoeSparseMoeBlock")
class CalibrateQwen3_5MoeTextSparseMoeBlock(MoECalibrationModule):
    """
    Calibration version of Qwen3_5MoeSparseMoeBlock that sends all tokens to all
    experts.
    """

    is_permanent = True

    def __init__(
        self,
        original: "Qwen3_5MoeSparseMoeBlock",
        config: "Qwen3_5MoeConfig",
        calibrate_all_experts: bool = True,
    ):
        super().__init__()
        text_config: "Qwen3_5MoeTextConfig" = config.get_text_config()

        self.num_experts = text_config.num_experts

        self.shared_expert = original.shared_expert
        self.shared_expert_gate = original.shared_expert_gate
        self.gate = original.gate
        self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts)
        self.calibrate_all_experts = calibrate_all_experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states_reshaped = hidden_states.view(-1, hidden_dim)

        # router: returns (router_logits, router_scores, router_indices)
        _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)

        # expert mask: (num_experts, top_k, num_tokens)
        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(
            2, 1, 0
        )

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        for expert_idx, expert_layer in enumerate(self.experts):
            idx, token_idx = torch.where(expert_mask[expert_idx])

            if self.calibrate_all_experts:
                # Run every token through the expert so it sees calibration
                # data, but keep only the routed tokens' outputs.
                expert_out = expert_layer(hidden_states_reshaped)[token_idx]
            else:
                expert_out = expert_layer(hidden_states_reshaped[token_idx])

            if len(token_idx) > 0:
                current_hidden_states = (
                    expert_out * routing_weights[token_idx, idx, None]
                )
                final_hidden_states.index_add_(
                    0,
                    token_idx,
                    current_hidden_states.to(hidden_states.dtype),
                )

        # shared expert
        shared_expert_output = self.shared_expert(hidden_states_reshaped)
        shared_expert_output = (
            F.sigmoid(self.shared_expert_gate(hidden_states_reshaped))
            * shared_expert_output
        )
        final_hidden_states = final_hidden_states + shared_expert_output

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        )
        return final_hidden_states

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        return original


class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList):
    """Unpacks fused expert weights into a list of per-expert Qwen3_5MoeMLPs."""

    def __init__(self, config, original):
        from compressed_tensors.offload import disable_onloading
        from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
            Qwen3_5MoeMLP,
        )

        self.num_experts = original.gate_up_proj.shape[0]
        with skip_weights_initialize():
            super().__init__(
                [
                    # Per-expert MLP width is moe_intermediate_size (flagged in review).
                    Qwen3_5MoeMLP(
                        config, intermediate_size=config.moe_intermediate_size
                    )
                    for _ in range(self.num_experts)
                ]
            )

        intermediate_size = original.down_proj.shape[-1]

        with disable_onloading():
            gate_up_data = original.gate_up_proj.data  # [num_experts, 2*inter, hidden]
            down_data = original.down_proj.data  # [num_experts, hidden, inter]

        for i in range(self.num_experts):
            gate_up = gate_up_data[i]
            down = down_data[i]

            gate_proj = gate_up[:intermediate_size, :]
            up_proj = gate_up[intermediate_size:, :]

            self[i].gate_proj.weight.data = gate_proj.clone().contiguous()
            self[i].up_proj.weight.data = up_proj.clone().contiguous()
            self[i].down_proj.weight.data = down.clone().contiguous()
```

> **Review comment** (on the `Qwen3_5MoeMLP` construction): Should be `config.moe_intermediate_size`; `config.shared_expert_intermediate_size` will create incorrectly sized linear layers.
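The expert-mask bookkeeping is the subtle part of the forward pass: `F.one_hot(selected_experts, ...).permute(2, 1, 0)` turns the `(num_tokens, top_k)` routing indices into a per-expert mask, and `torch.where` recovers, for each expert, which top-k slot and which token selected it. A toy sketch with made-up sizes (4 tokens, 3 experts, top-2 routing):

```python
import torch
import torch.nn.functional as F

# 4 tokens, each routed to 2 of 3 experts (values are made up).
selected_experts = torch.tensor([[0, 2], [1, 2], [0, 1], [2, 0]])  # (num_tokens, top_k)

# (num_tokens, top_k, num_experts) -> (num_experts, top_k, num_tokens)
expert_mask = F.one_hot(selected_experts, num_classes=3).permute(2, 1, 0)

for expert_idx in range(3):
    idx, token_idx = torch.where(expert_mask[expert_idx])
    # idx[j] is the top-k slot and token_idx[j] the token routed to this expert,
    # so routing_weights[token_idx, idx] picks out the matching gate weights.
    print(f"expert {expert_idx} receives tokens {token_idx.tolist()}")
# expert 0 receives tokens [0, 2, 3]
# expert 1 receives tokens [1, 2]
# expert 2 receives tokens [0, 1, 3]
```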
Modified file: the `TORCH_INIT_FUNCTIONS` import from `transformers.modeling_utils` is replaced with a locally vendored copy of the mapping.

```diff
@@ -12,7 +12,8 @@
 from loguru import logger
 from safetensors.torch import save_file
 from transformers import AutoModelForCausalLM, PreTrainedModel
-from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
+# from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
+
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

 __all__ = [
@@ -22,6 +23,25 @@
     "dispatch_for_generation",
 ]

+from torch import nn
+
+TORCH_INIT_FUNCTIONS = {
+    "uniform_": nn.init.uniform_,
+    "normal_": nn.init.normal_,
+    "trunc_normal_": nn.init.trunc_normal_,
+    "constant_": nn.init.constant_,
+    "xavier_uniform_": nn.init.xavier_uniform_,
+    "xavier_normal_": nn.init.xavier_normal_,
+    "kaiming_uniform_": nn.init.kaiming_uniform_,
+    "kaiming_normal_": nn.init.kaiming_normal_,
+    "uniform": nn.init.uniform,
+    "normal": nn.init.normal,
+    "xavier_uniform": nn.init.xavier_uniform,
+    "xavier_normal": nn.init.xavier_normal,
+    "kaiming_uniform": nn.init.kaiming_uniform,
+    "kaiming_normal": nn.init.kaiming_normal,
+}
+
 @contextlib.contextmanager
 def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausalLM):
```

> **Review comment** (on the commented-out import): Unrelated to Qwen 3.5, as this is a transformers compatibility workaround.
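For context, this vendored mapping is presumably what `skip_weights_initialize` patches over: each named initializer on `torch.nn.init` is temporarily swapped for a no-op so that module construction (such as the per-expert MLPs above) skips random weight initialization. A simplified sketch of that pattern, not the library's exact implementation:

```python
import contextlib

from torch import nn

# Subset of the vendored TORCH_INIT_FUNCTIONS mapping from the diff above.
TORCH_INIT_FUNCTIONS = {
    "uniform_": nn.init.uniform_,
    "normal_": nn.init.normal_,
    "kaiming_uniform_": nn.init.kaiming_uniform_,
}


@contextlib.contextmanager
def skip_weights_initialize_sketch():
    """Temporarily no-op the named torch init functions (simplified)."""

    def noop(tensor, *args, **kwargs):
        return tensor

    saved = {name: getattr(nn.init, name) for name in TORCH_INIT_FUNCTIONS}
    try:
        for name in TORCH_INIT_FUNCTIONS:
            setattr(nn.init, name, noop)
        yield
    finally:
        # Always restore the real initializers, even on error.
        for name, fn in saved.items():
            setattr(nn.init, name, fn)


with skip_weights_initialize_sketch():
    layer = nn.Linear(8, 8)  # constructed without running the random init kernels
```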
> **Review comment** (on the example scripts): The `MODEL_ID` is currently hardcoded to a specific local path. This reduces the portability and reusability of the example script. Consider making this path configurable, perhaps through command-line arguments or environment variables, to allow users to easily specify their model location.
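One lightweight way to act on this suggestion in the example scripts (a sketch; the environment variable name is hypothetical):

```python
import os

# Fall back to the Hub ID unless the user points at a local snapshot
# (hypothetical variable name "QWEN_MODEL_ID").
MODEL_ID = os.environ.get("QWEN_MODEL_ID", "Qwen/Qwen3.5-122B-A10B")
```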