Skip to content

Commit 961404d

Browse files
HDCharles and yiliu30
authored and committed
[dist][moe] fix add moe_context for big models (vllm-project#2405)
Summary: Large models like Qwen/Qwen3-VL-235B-A22B-Instruct, when they add moe calibration context, different threads can take different lengths of time; for larger models this difference can be longer than the NCCL timeout. Fix: add a sync point at each module since we're rate limited to the slowest thread as is. At some point this should be changed to add moe calibration context in parallel and broadcast the updated modules. TEST PLAN: tested e2e <details> ``` ############################################################################### # This script quantizes Qwen3-VL-235B-MoE with GPTQ + INT4 using DDP. # run this with `torchrun --nproc_per_node=8 qwen3_vl_235b_moe_gptq_int4_ddp_example.py` # or change nproc_per_node to your desired configuration # NOTE: Currently uses data-free GPTQ as only data-free quantization is supported for Qwen3-VL-MoE ############################################################################### from compressed_tensors.offload import init_dist, load_offloaded_model from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" ###### DDP MODEL LOAD CHANGE ##### init_dist() with load_offloaded_model(): model = Qwen3VLMoeForConditionalGeneration.from_pretrained( MODEL_ID, dtype="auto", device_map="auto_offload" ) ################################## processor = AutoProcessor.from_pretrained(MODEL_ID) # Recipe: GPTQ + INT4 (data-free) # NOTE: only datafree quantization is supported for Qwen3-VL-MoE currently recipe = GPTQModifier( targets="Linear", scheme="W4A16", ignore=[ "re:.*lm_head", "re:visual.*", "re:model.visual.*", "re:.*mlp.gate$", ], ) # Apply quantization (no dataset needed for data-free GPTQ) oneshot(model=model, recipe=recipe) import torch # Save to disk in compressed-tensors format. 
SAVE_DIR = ( MODEL_ID.rstrip("/").split("/")[-1] + "-GPTQ-W4A16-G128-DDP" + str(torch.distributed.get_world_size()) ) model.save_pretrained(SAVE_DIR, save_compressed=True) processor.save_pretrained(SAVE_DIR) ``` </details> Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com> Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent af31a16 commit 961404d

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

src/llmcompressor/modeling/moe_context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
 from abc import ABC

 import torch
+import torch.distributed as dist
+from compressed_tensors.offload import is_distributed
 from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
 from loguru import logger
 from tqdm import tqdm
@@ -111,6 +113,8 @@ def moe_calibration_context(
         )
         model.set_submodule(name, replacement)
         replaced[name] = (module, replacement)
+        if is_distributed():
+            dist.barrier()

     # Log what was replaced
     if replaced:

0 commit comments

Comments (0)