diff --git a/examples/quantization_w4a4_fp4/qwen3_vl_moe_w4a4_fp4.py b/examples/quantization_w4a4_fp4/qwen3_vl_moe_w4a4_fp4.py
new file mode 100644
index 000000000..2e8f86003
--- /dev/null
+++ b/examples/quantization_w4a4_fp4/qwen3_vl_moe_w4a4_fp4.py
@@ -0,0 +1,103 @@
+import torch
+from datasets import load_dataset
+from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
+
+from llmcompressor import oneshot
+from llmcompressor.modeling import replace_modules_for_calibration
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.utils import dispatch_for_generation
+
+# NOTE: Qwen3-VL-MoE support is not available in transformers<=4.56.2;
+# you may need to install transformers from source.
+
+
+MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct"
+
+# Load model.
+model = Qwen3VLMoeForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype="auto")
+processor = AutoProcessor.from_pretrained(MODEL_ID)
+model = replace_modules_for_calibration(model)
+
+DATASET_ID = "neuralmagic/calibration"
+NUM_CALIBRATION_SAMPLES = 20
+MAX_SEQUENCE_LENGTH = 8192
+
+ds = load_dataset(DATASET_ID, name="LLM", split=f"train[:{NUM_CALIBRATION_SAMPLES}]")
+
+
+def preprocess_function(example):
+    messages = []
+    for message in example["messages"]:
+        messages.append(
+            {
+                "role": message["role"],
+                "content": [{"type": "text", "text": message["content"]}],
+            }
+        )
+
+    return processor.apply_chat_template(
+        messages,
+        return_tensors="pt",
+        padding=False,
+        truncation=True,
+        max_length=MAX_SEQUENCE_LENGTH,
+        tokenize=True,
+        add_special_tokens=False,
+        return_dict=True,
+        add_generation_prompt=False,
+    )
+
+
+ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names)
+
+
+def data_collator(batch):
+    assert len(batch) == 1
+    return {
+        key: (
+            torch.tensor(value)
+            if key != "pixel_values"
+            else torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
+        )
+        for key, value in batch[0].items()
+    }
+
+
+# Configure the quantization algorithm and scheme.
+# In this case, we:
+#   * quantize the weights to fp4 with per-group-16 scales via ptq
+#   * calibrate a global scale for activations, which is used to
+#     quantize activations to fp4 on the fly
+recipe = QuantizationModifier(
+    targets="Linear",
+    scheme="NVFP4",
+    ignore=[
+        "re:.*lm_head",
+        "re:visual.*",
+        "re:model.visual.*",
+        "re:.*mlp.gate$",
+    ],
+)
+
+# Apply quantization.
+oneshot(
+    model=model,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    dataset=ds,
+    data_collator=data_collator,
+)
+
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=20)
+print(processor.decode(output[0]))
+print("==========================================")
+
+
+# Save to disk in compressed-tensors format.
+SAVE_DIR = "/proving-grounds/engine/hub_cache/Qwen3-VL-235B-A22B-Instruct" + "-NVFP4" +model.save_pretrained(SAVE_DIR) +processor.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modeling/qwen3_vl_moe.py b/src/llmcompressor/modeling/qwen3_vl_moe.py index b36869344..1624e2be8 100644 --- a/src/llmcompressor/modeling/qwen3_vl_moe.py +++ b/src/llmcompressor/modeling/qwen3_vl_moe.py @@ -4,13 +4,61 @@ class LinearQwen3VLMoeTextSparseMoeBlock(torch.nn.Module): - def __init__(self, config, original): + def __init__(self, config, original, calibrate_all_experts): super().__init__() self.hidden_size = config.hidden_size self.num_experts = config.num_experts - self.gate = wrap_gate(original.gate) + self.top_k = original.top_k + # Note: gate was changed to be a Linear layer in transformers==4.57.0 + # https://github.com/JJJYmmm/transformers/commit/f5dea1c694af8c994c769170813a8702332119ee + self.gate = original.gate + self.calibrate_all_experts = calibrate_all_experts self.experts = SequentialQwen3VLMoeTextExperts(config, original.experts) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) + + router_logits = self.gate(hidden_states) + routing_weights = torch.nn.functional.softmax( + router_logits, dim=-1, dtype=torch.float + ) + routing_weights, router_indices = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + + router_weights = torch.zeros_like(router_logits).scatter_( + 1, router_indices, routing_weights + ) + + next_states = torch.zeros_like( + hidden_states, dtype=hidden_states.dtype, device=hidden_states.device + ) + expert_mask = torch.nn.functional.one_hot( + router_indices, num_classes=self.num_experts + ).permute(2, 1, 0) + for expert_idx, expert_layer in enumerate(self.experts): + with torch.no_grad(): + _, token_idx = torch.where(expert_mask[expert_idx].squeeze(0)) + + if self.calibrate_all_experts: + expert_out = expert_layer(hidden_states)[token_idx] + else: + expert_out = expert_layer(hidden_states[token_idx]) + + if len(token_idx) > 0: + weighted_output = ( + expert_out * router_weights[token_idx, expert_idx, None] + ) + next_states.index_add_( + 0, token_idx, weighted_output.to(hidden_states.dtype) + ) + + next_states = next_states.view(batch_size, -1, self.hidden_size) + return next_states + class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList): def __init__(self, config, original): @@ -38,19 +86,9 @@ def __init__(self, config, original): self[i].down_proj.weight.data = down.t().clone().contiguous() -def wrap_gate(gate): - # temporary workaround until ct supports ignores of Linear instances - linear_gate = torch.nn.Linear(gate.in_features, gate.out_features) - linear_gate.weight.data.copy_(gate.weight.data) - setattr(linear_gate, "hidden_size", gate.hidden_size) - setattr(linear_gate, "top_k", gate.top_k) - setattr(linear_gate, "forward", gate.forward) - del gate - return linear_gate - - -def replace(config, module, calibrate_all_experts=False): +def replace(config, module, calibrate_all_experts): return LinearQwen3VLMoeTextSparseMoeBlock( config=config.get_text_config(), original=module, + calibrate_all_experts=calibrate_all_experts, )