From a362cfc13e33c0ac58bcb3a966cb028cd14a2984 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 17 Feb 2026 15:01:34 -0500 Subject: [PATCH 1/5] add support --- examples/quantization_w8a8_fp8/qwen3_5_moe.py | 43 ++++++++++++ src/llmcompressor/modeling/__init__.py | 1 + src/llmcompressor/modeling/qwen3_5_vl_moe.py | 68 +++++++++++++++++++ src/llmcompressor/utils/dev.py | 22 +++++- 4 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 examples/quantization_w8a8_fp8/qwen3_5_moe.py create mode 100644 src/llmcompressor/modeling/qwen3_5_vl_moe.py diff --git a/examples/quantization_w8a8_fp8/qwen3_5_moe.py b/examples/quantization_w8a8_fp8/qwen3_5_moe.py new file mode 100644 index 0000000000..61018a32f6 --- /dev/null +++ b/examples/quantization_w8a8_fp8/qwen3_5_moe.py @@ -0,0 +1,43 @@ +from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier + +MODEL_ID = "/raid/engine/dsikka/models--Qwen--Qwen3.5-397B-A17B/snapshots/7cad2bae11cb49ca79f7d6a0954de2e2756f4e27" + +# Load model. +model = Qwen3_5MoeForConditionalGeneration.from_pretrained(MODEL_ID, dtype="auto") +processor = AutoProcessor.from_pretrained(MODEL_ID) + +# Configure the quantization algorithm and scheme. +# In this case, we: +# * quantize the weights to fp8 with channel-wise quantization +# * quantize the activations to fp8 with dynamic token activations +# NOTE: only datafree quantization is supported for Qwen3-VL-MoE currently +recipe = QuantizationModifier( + targets="Linear", + scheme="FP8_DYNAMIC", + ignore=[ + "re:.*lm_head", + "re:visual.*", + "re:model.visual.*", + "re:.*mlp.gate$", + "re:.*embed_tokens$", + "re:.*shared_expert_gate$", + "re:.*conv1d$", + "re:.*in_proj_a$", + "re:.*in_proj_b$", + "re:.*in_proj_ba$", + "re:.*norm$", + "re:.*pre_fc_norm_hidden$", + "re:.*pre_fc_norm_embedding$", + ], +) + +# Apply quantization. 
+oneshot(model=model, recipe=recipe) + +# Save to disk in compressed-tensors format. +SAVE_DIR = "/raid/engine/dsikka/" + "Qwen3.5-397B-A17B" + "-FP8-DYNAMIC" +model.save_pretrained(SAVE_DIR) +processor.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modeling/__init__.py b/src/llmcompressor/modeling/__init__.py index d7cd3f24ae..39be1c4579 100644 --- a/src/llmcompressor/modeling/__init__.py +++ b/src/llmcompressor/modeling/__init__.py @@ -16,6 +16,7 @@ from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock # noqa: F401 from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock # noqa: F401 from .qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock # noqa: F401 +from .qwen3_5_vl_moe import CalibrateQwen3_5MoeTextSparseMoeBlock # noqa: F401 # TODO: add granite4, Qwen3Next from .fuse import * diff --git a/src/llmcompressor/modeling/qwen3_5_vl_moe.py b/src/llmcompressor/modeling/qwen3_5_vl_moe.py new file mode 100644 index 0000000000..5c6e0f77ed --- /dev/null +++ b/src/llmcompressor/modeling/qwen3_5_vl_moe.py @@ -0,0 +1,68 @@ +import torch +from transformers import Qwen3_5MoeConfig, Qwen3_5MoeTextConfig +from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeSparseMoeBlock, +) + +from llmcompressor.modeling.moe_context import MoECalibrationModule +from llmcompressor.utils.dev import skip_weights_initialize + + +@MoECalibrationModule.register("Qwen3_5MoeSparseMoeBlock") +class CalibrateQwen3_5MoeTextSparseMoeBlock(MoECalibrationModule): + """ + Calibration version of Qwen3_5MoeSparseMoeBlock that sends all tokens to all + experts. 
+ """ + + is_permanent = True + + def __init__( + self, + original: "Qwen3_5MoeSparseMoeBlock", + config: "Qwen3_5MoeConfig", + calibrate_all_experts: bool, + ): + super().__init__() + text_config: "Qwen3_5MoeTextConfig" = config.get_text_config() + + self.num_experts = text_config.num_experts + + self.shared_expert = original.shared_expert + self.shared_expert_gate = original.shared_expert_gate + self.gate = original.gate + self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts) + + def restore(self, original: torch.nn.Module) -> torch.nn.Module: + return original + + +class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList): + def __init__(self, config, original): + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeMLP, + ) + + self.num_experts = original.gate_up_proj.shape[0] + with skip_weights_initialize(): + super().__init__( + [ + Qwen3_5MoeMLP( + config, intermediate_size=config.shared_expert_intermediate_size + ) + for _ in range(self.num_experts) + ] + ) + + intermediate_size = original.down_proj.shape[1] + + for i in range(self.num_experts): + gate_up = original.gate_up_proj[i] + down = original.down_proj[i] + + gate_proj = gate_up[:, :intermediate_size] + up_proj = gate_up[:, intermediate_size:] + + self[i].gate_proj.weight.data = gate_proj.t().clone().contiguous() + self[i].up_proj.weight.data = up_proj.t().clone().contiguous() + self[i].down_proj.weight.data = down.t().clone().contiguous() diff --git a/src/llmcompressor/utils/dev.py b/src/llmcompressor/utils/dev.py index 244cd1489a..aaecddc016 100644 --- a/src/llmcompressor/utils/dev.py +++ b/src/llmcompressor/utils/dev.py @@ -12,7 +12,8 @@ from loguru import logger from safetensors.torch import save_file from transformers import AutoModelForCausalLM, PreTrainedModel -from transformers.modeling_utils import TORCH_INIT_FUNCTIONS + +# from transformers.modeling_utils import TORCH_INIT_FUNCTIONS from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, 
WEIGHTS_INDEX_NAME __all__ = [ @@ -22,6 +23,25 @@ "dispatch_for_generation", ] +from torch import nn + +TORCH_INIT_FUNCTIONS = { + "uniform_": nn.init.uniform_, + "normal_": nn.init.normal_, + "trunc_normal_": nn.init.trunc_normal_, + "constant_": nn.init.constant_, + "xavier_uniform_": nn.init.xavier_uniform_, + "xavier_normal_": nn.init.xavier_normal_, + "kaiming_uniform_": nn.init.kaiming_uniform_, + "kaiming_normal_": nn.init.kaiming_normal_, + "uniform": nn.init.uniform, + "normal": nn.init.normal, + "xavier_uniform": nn.init.xavier_uniform, + "xavier_normal": nn.init.xavier_normal, + "kaiming_uniform": nn.init.kaiming_uniform, + "kaiming_normal": nn.init.kaiming_normal, +} + @contextlib.contextmanager def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausalLM): From 7474c7d9f3694006b86117f7829079399ec44a82 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 18 Feb 2026 02:31:17 -0500 Subject: [PATCH 2/5] update --- examples/quantization_w8a8_fp8/qwen3_5_moe.py | 4 ++-- src/llmcompressor/modeling/qwen3_5_vl_moe.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/quantization_w8a8_fp8/qwen3_5_moe.py b/examples/quantization_w8a8_fp8/qwen3_5_moe.py index 61018a32f6..2abad358d9 100644 --- a/examples/quantization_w8a8_fp8/qwen3_5_moe.py +++ b/examples/quantization_w8a8_fp8/qwen3_5_moe.py @@ -16,7 +16,7 @@ # NOTE: only datafree quantization is supported for Qwen3-VL-MoE currently recipe = QuantizationModifier( targets="Linear", - scheme="FP8_DYNAMIC", + scheme="FP8_BLOCK", ignore=[ "re:.*lm_head", "re:visual.*", @@ -38,6 +38,6 @@ oneshot(model=model, recipe=recipe) # Save to disk in compressed-tensors format. 
-SAVE_DIR = "/raid/engine/dsikka/" + "Qwen3.5-397B-A17B" + "-FP8-DYNAMIC" +SAVE_DIR = "/raid/engine/dsikka/" + "Qwen3.5-397B-A17B" + "-FP8-Block" model.save_pretrained(SAVE_DIR) processor.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modeling/qwen3_5_vl_moe.py b/src/llmcompressor/modeling/qwen3_5_vl_moe.py index 5c6e0f77ed..27a7142913 100644 --- a/src/llmcompressor/modeling/qwen3_5_vl_moe.py +++ b/src/llmcompressor/modeling/qwen3_5_vl_moe.py @@ -54,15 +54,15 @@ def __init__(self, config, original): ] ) - intermediate_size = original.down_proj.shape[1] + intermediate_size = original.down_proj.shape[-1] for i in range(self.num_experts): gate_up = original.gate_up_proj[i] down = original.down_proj[i] - gate_proj = gate_up[:, :intermediate_size] - up_proj = gate_up[:, intermediate_size:] + gate_proj = gate_up[:intermediate_size, :] + up_proj = gate_up[intermediate_size:, :] - self[i].gate_proj.weight.data = gate_proj.t().clone().contiguous() - self[i].up_proj.weight.data = up_proj.t().clone().contiguous() - self[i].down_proj.weight.data = down.t().clone().contiguous() + self[i].gate_proj.weight.data = gate_proj.clone().contiguous() + self[i].up_proj.weight.data = up_proj.clone().contiguous() + self[i].down_proj.weight.data = down.clone().contiguous() \ No newline at end of file From 55593b694b7bc7b0f5daf05e78b4ee86027c65b6 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Sat, 21 Feb 2026 18:33:11 -0500 Subject: [PATCH 3/5] update --- examples/quantization_w4a4_fp4/qwen3_5_moe.py | 77 +++++++++++++++++++ examples/quantization_w8a8_fp8/qwen3_5_moe.py | 13 +--- src/llmcompressor/modeling/qwen3_5_vl_moe.py | 59 +++++++++++++- 3 files changed, 138 insertions(+), 11 deletions(-) create mode 100644 examples/quantization_w4a4_fp4/qwen3_5_moe.py diff --git a/examples/quantization_w4a4_fp4/qwen3_5_moe.py b/examples/quantization_w4a4_fp4/qwen3_5_moe.py new file mode 100644 index 0000000000..82a19ed9fb --- /dev/null +++ b/examples/quantization_w4a4_fp4/qwen3_5_moe.py 
@@ -0,0 +1,77 @@ +from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration +from datasets import load_dataset +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier + +MODEL_ID = "/raid/engine/dsikka/models--Qwen--Qwen3.5-397B-A17B/snapshots/7cad2bae11cb49ca79f7d6a0954de2e2756f4e27" + +# Load model. +model = Qwen3_5MoeForConditionalGeneration.from_pretrained(MODEL_ID, dtype="auto") +processor = AutoProcessor.from_pretrained(MODEL_ID) + + +recipe = QuantizationModifier( + targets="Linear", + scheme="NVFP4", + ignore=[ + "re:.*lm_head", + "re:visual.*", + "re:model.visual.*", + "re:.*mlp.gate$", + "re:.*embed_tokens$", + "re:.*shared_expert_gate$", + "re:.*mlp\\.shared_expert$", + "re:.*linear_attn.*", + ], +) + +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples +NUM_CALIBRATION_SAMPLES = 20 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": processor.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return processor( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + + +# Apply quantization. +oneshot(model=model, + recipe=recipe, + dataset=ds, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + moe_calibrate_all_experts=True) + +# Save to disk in compressed-tensors format. 
+SAVE_DIR = "/raid/engine/dsikka/" + "Qwen3.5-397B-A17B" + "-NVFP4" +model.save_pretrained(SAVE_DIR) +processor.save_pretrained(SAVE_DIR) diff --git a/examples/quantization_w8a8_fp8/qwen3_5_moe.py b/examples/quantization_w8a8_fp8/qwen3_5_moe.py index 2abad358d9..b4edad0e74 100644 --- a/examples/quantization_w8a8_fp8/qwen3_5_moe.py +++ b/examples/quantization_w8a8_fp8/qwen3_5_moe.py @@ -16,7 +16,7 @@ # NOTE: only datafree quantization is supported for Qwen3-VL-MoE currently recipe = QuantizationModifier( targets="Linear", - scheme="FP8_BLOCK", + scheme="FP8_DYNAMIC", ignore=[ "re:.*lm_head", "re:visual.*", @@ -24,13 +24,8 @@ "re:.*mlp.gate$", "re:.*embed_tokens$", "re:.*shared_expert_gate$", - "re:.*conv1d$", - "re:.*in_proj_a$", - "re:.*in_proj_b$", - "re:.*in_proj_ba$", - "re:.*norm$", - "re:.*pre_fc_norm_hidden$", - "re:.*pre_fc_norm_embedding$", + "re:.*mlp\\.shared_expert$", + "re:.*linear_attn.*", ], ) @@ -38,6 +33,6 @@ oneshot(model=model, recipe=recipe) # Save to disk in compressed-tensors format. 
-SAVE_DIR = "/raid/engine/dsikka/" + "Qwen3.5-397B-A17B" + "-FP8-Block" +SAVE_DIR = "/raid/engine/dsikka/" + "Qwen3.5-397B-A17B" + "-FP8-Dynamic-NoLinearAttn" model.save_pretrained(SAVE_DIR) processor.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modeling/qwen3_5_vl_moe.py b/src/llmcompressor/modeling/qwen3_5_vl_moe.py index 27a7142913..7f275fd696 100644 --- a/src/llmcompressor/modeling/qwen3_5_vl_moe.py +++ b/src/llmcompressor/modeling/qwen3_5_vl_moe.py @@ -6,6 +6,7 @@ from llmcompressor.modeling.moe_context import MoECalibrationModule from llmcompressor.utils.dev import skip_weights_initialize +import torch.nn.functional as F @MoECalibrationModule.register("Qwen3_5MoeSparseMoeBlock") @@ -32,6 +33,55 @@ def __init__( self.shared_expert_gate = original.shared_expert_gate self.gate = original.gate self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_reshaped = hidden_states.view(-1, hidden_dim) + + # router: returns (router_logits, router_scores, router_indices) + _, routing_weights, selected_experts = self.gate(hidden_states_reshaped) + + # expert mask: (num_experts, top_k, num_tokens) + expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute( + 2, 1, 0 + ) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + for expert_idx, expert_layer in enumerate(self.experts): + idx, token_idx = torch.where(expert_mask[expert_idx]) + + if self.calibrate_all_experts: + expert_out = expert_layer(hidden_states_reshaped)[token_idx] + else: + expert_out = expert_layer(hidden_states_reshaped[token_idx]) + + if len(token_idx) > 0: + current_hidden_states = ( + expert_out * routing_weights[token_idx, idx, None] + ) + final_hidden_states.index_add_( + 0, + token_idx, + 
current_hidden_states.to(hidden_states.dtype), + ) + + # shared expert + shared_expert_output = self.shared_expert(hidden_states_reshaped) + shared_expert_output = ( + F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) + * shared_expert_output + ) + final_hidden_states = final_hidden_states + shared_expert_output + + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) + return final_hidden_states def restore(self, original: torch.nn.Module) -> torch.nn.Module: return original @@ -42,6 +92,7 @@ def __init__(self, config, original): from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( Qwen3_5MoeMLP, ) + from compressed_tensors.offload import disable_onloading self.num_experts = original.gate_up_proj.shape[0] with skip_weights_initialize(): @@ -56,9 +107,13 @@ def __init__(self, config, original): intermediate_size = original.down_proj.shape[-1] + with disable_onloading(): + gate_up_data = original.gate_up_proj.data # [num_experts, 2*inter, hidden] + down_data = original.down_proj.data # [num_experts, hidden, inter] + for i in range(self.num_experts): - gate_up = original.gate_up_proj[i] - down = original.down_proj[i] + gate_up = gate_up_data[i] + down = down_data[i] gate_proj = gate_up[:intermediate_size, :] up_proj = gate_up[intermediate_size:, :] From a215aaeef63be8ee731cf5c770fc32e1487e1b91 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Sun, 22 Feb 2026 00:26:38 -0500 Subject: [PATCH 4/5] update --- examples/quantization_w4a4_fp4/qwen3_5_moe.py | 60 +++++++++++-------- src/llmcompressor/modeling/qwen3_5_vl_moe.py | 3 +- src/llmcompressor/utils/pytorch/module.py | 2 +- 3 files changed, 38 insertions(+), 27 deletions(-) diff --git a/examples/quantization_w4a4_fp4/qwen3_5_moe.py b/examples/quantization_w4a4_fp4/qwen3_5_moe.py index 82a19ed9fb..06f3144a9a 100644 --- a/examples/quantization_w4a4_fp4/qwen3_5_moe.py +++ b/examples/quantization_w4a4_fp4/qwen3_5_moe.py @@ -2,6 +2,7 @@ from datasets 
import load_dataset from llmcompressor import oneshot from llmcompressor.modifiers.quantization import QuantizationModifier +import torch MODEL_ID = "/raid/engine/dsikka/models--Qwen--Qwen3.5-397B-A17B/snapshots/7cad2bae11cb49ca79f7d6a0954de2e2756f4e27" @@ -25,48 +26,57 @@ ], ) -DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" - -# Select number of samples +DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 20 -MAX_SEQUENCE_LENGTH = 2048 +MAX_SEQUENCE_LENGTH = 8192 -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") -ds = ds.shuffle(seed=42) +ds = load_dataset(DATASET_ID, name="LLM", split=f"train[:{NUM_CALIBRATION_SAMPLES}]") -def preprocess(example): - return { - "text": processor.apply_chat_template( - example["messages"], - tokenize=False, +def preprocess_function(example): + messgages = [] + for message in example["messages"]: + messgages.append( + { + "role": message["role"], + "content": [{"type": "text", "text": message["content"]}], + } ) - } - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - return processor( - sample["text"], + return processor.apply_chat_template( + messgages, + return_tensors="pt", padding=False, - max_length=MAX_SEQUENCE_LENGTH, truncation=True, + max_length=MAX_SEQUENCE_LENGTH, + tokenize=True, add_special_tokens=False, + return_dict=True, + add_generation_prompt=False, ) -ds = ds.map(tokenize, remove_columns=ds.column_names) +ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names) + + +def data_collator(batch): + assert len(batch) == 1 + return { + key: ( + torch.tensor(value) + if key != "pixel_values" + else torch.tensor(value, dtype=torch.bfloat16).squeeze(0) + ) + for key, value in batch[0].items() + } + # Apply quantization. 
oneshot(model=model, recipe=recipe, - dataset=ds, + dataset=ds, + data_collator=data_collator, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, moe_calibrate_all_experts=True) diff --git a/src/llmcompressor/modeling/qwen3_5_vl_moe.py b/src/llmcompressor/modeling/qwen3_5_vl_moe.py index 7f275fd696..3c870582fd 100644 --- a/src/llmcompressor/modeling/qwen3_5_vl_moe.py +++ b/src/llmcompressor/modeling/qwen3_5_vl_moe.py @@ -22,7 +22,7 @@ def __init__( self, original: "Qwen3_5MoeSparseMoeBlock", config: "Qwen3_5MoeConfig", - calibrate_all_experts: bool, + calibrate_all_experts: bool = True, ): super().__init__() text_config: "Qwen3_5MoeTextConfig" = config.get_text_config() @@ -33,6 +33,7 @@ def __init__( self.shared_expert_gate = original.shared_expert_gate self.gate = original.gate self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts) + self.calibrate_all_experts = calibrate_all_experts def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py index 5a8098dd3b..6b22a7c6d9 100644 --- a/src/llmcompressor/utils/pytorch/module.py +++ b/src/llmcompressor/utils/pytorch/module.py @@ -136,7 +136,7 @@ def get_no_split_params(model: PreTrainedModel) -> Union[str, List[str]]: :return: list of class names that shouldn't be split """ - no_split_modules = model._get_no_split_modules("auto") + no_split_modules = model._no_split_modules if len(no_split_modules) <= 0: return ALL_TARGET From 40c6211d1a1c22e6331f9018b687c559827b74db Mon Sep 17 00:00:00 2001 From: Dipika Date: Tue, 3 Mar 2026 16:29:02 +0000 Subject: [PATCH 5/5] update --- examples/quantization_w4a4_fp4/qwen3_5_moe.py | 3 ++- examples/quantization_w8a8_fp8/qwen3_5_moe.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/quantization_w4a4_fp4/qwen3_5_moe.py
b/examples/quantization_w4a4_fp4/qwen3_5_moe.py index 06f3144a9a..b284a31322 100644 --- a/examples/quantization_w4a4_fp4/qwen3_5_moe.py +++ b/examples/quantization_w4a4_fp4/qwen3_5_moe.py @@ -5,6 +5,7 @@ import torch MODEL_ID = "/raid/engine/dsikka/models--Qwen--Qwen3.5-397B-A17B/snapshots/7cad2bae11cb49ca79f7d6a0954de2e2756f4e27" +MODEL_ID = "Qwen/Qwen3.5-122B-A10B" # Load model. model = Qwen3_5MoeForConditionalGeneration.from_pretrained(MODEL_ID, dtype="auto") @@ -82,6 +83,6 @@ def data_collator(batch): moe_calibrate_all_experts=True) # Save to disk in compressed-tensors format. -SAVE_DIR = "/raid/engine/dsikka/" + "Qwen3.5-397B-A17B" + "-NVFP4" +SAVE_DIR = "/mnt/nvme_stripe/playground/dsikka/" + "Qwen3.5-122B-A10B" + "-NVFP4" model.save_pretrained(SAVE_DIR) processor.save_pretrained(SAVE_DIR) diff --git a/examples/quantization_w8a8_fp8/qwen3_5_moe.py b/examples/quantization_w8a8_fp8/qwen3_5_moe.py index b4edad0e74..dd7c0a5d69 100644 --- a/examples/quantization_w8a8_fp8/qwen3_5_moe.py +++ b/examples/quantization_w8a8_fp8/qwen3_5_moe.py @@ -4,6 +4,7 @@ from llmcompressor.modifiers.quantization import QuantizationModifier MODEL_ID = "/raid/engine/dsikka/models--Qwen--Qwen3.5-397B-A17B/snapshots/7cad2bae11cb49ca79f7d6a0954de2e2756f4e27" +MODEL_ID = "Qwen/Qwen3.5-122B-A10B" # Load model. model = Qwen3_5MoeForConditionalGeneration.from_pretrained(MODEL_ID, dtype="auto") @@ -33,6 +34,6 @@ oneshot(model=model, recipe=recipe) # Save to disk in compressed-tensors format. -SAVE_DIR = "/raid/engine/dsikka/" + "Qwen3.5-397B-A17B" + "-FP8-Dynamic-NoLinearAttn" +SAVE_DIR = "/mnt/nvme_stripe/playground/dsikka/" + "Qwen3.5-122B-A10B" + "-FP8_DYNAMIC" model.save_pretrained(SAVE_DIR) processor.save_pretrained(SAVE_DIR)