-
Notifications
You must be signed in to change notification settings - Fork 453
[Distributed][SmoothQuant] Add distributed activation scale reduction (#2180) #2471
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
kylesayrs
merged 6 commits into
vllm-project:main
from
dzhengAP:feature/smoothquant-distributed
Mar 18, 2026
+628
−0
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
33e3c07
feat(smoothquant): add distributed activation scale reduction
a6e80c8
fix(smoothquant): save model on rank 0 only in DDP example
6646930
fix(smoothquant): fix NCCL timeout in DDP example sample generation
12110f6
refactor(smoothquant): address HDCharles review comments
500a1e2
style: apply ruff formatting to test_smoothquant_distributed
dzhengAP 145c073
Merge branch 'main' into feature/smoothquant-distributed
dzhengAP File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
152 changes: 152 additions & 0 deletions
152
examples/quantization_w8a8_int8/benchmark_smoothquant_ddp.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,152 @@ | ||
| """ | ||
| Benchmark: Single-GPU vs Multi-GPU DDP SmoothQuant calibration time. | ||
|
|
||
| Usage: | ||
| # 1 GPU | ||
| python benchmark_smoothquant_ddp.py --num_gpus 1 | ||
|
|
||
| # 2 GPU | ||
| torchrun --standalone --nproc_per_node=2 benchmark_smoothquant_ddp.py --num_gpus 2 | ||
|
|
||
| # 4 GPU | ||
| torchrun --standalone --nproc_per_node=4 benchmark_smoothquant_ddp.py --num_gpus 4 | ||
| """ | ||
|
|
||
| import argparse | ||
| import time | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from compressed_tensors.offload import init_dist, load_offloaded_model | ||
| from datasets import load_dataset | ||
| from loguru import logger | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.gptq import GPTQModifier | ||
| from llmcompressor.modifiers.transform.smoothquant import SmoothQuantModifier | ||
|
|
||
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
MODEL_ID = "Qwen/Qwen2-7B-Instruct"  # HF Hub id of the model to calibrate
DATASET_ID = "HuggingFaceH4/ultrachat_200k"  # calibration dataset
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 512  # total across ALL ranks; divided evenly per rank
MAX_SEQUENCE_LENGTH = 2048  # tokenizer truncation length for calibration samples
|
|
||
|
|
||
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if dist.is_initialized() else 0
|
|
||
|
|
||
def get_world_size():
    """Return the number of participating processes (1 when not distributed)."""
    return dist.get_world_size() if dist.is_initialized() else 1
|
|
||
|
|
||
def main(num_gpus: int):
    """Benchmark SmoothQuant + GPTQ W8A8 calibration time and peak GPU memory.

    When ``num_gpus > 1`` the script must be launched with ``torchrun`` so the
    process group can be initialized; each rank then calibrates on a disjoint
    slice of the dataset and rank 0 reports aggregate results.

    :param num_gpus: number of GPUs requested; > 1 enables the distributed path
    """
    is_distributed = num_gpus > 1

    # ------------------------------------------------------------------
    # Init distributed if needed
    # ------------------------------------------------------------------
    if is_distributed:
        init_dist()
        # Offloaded loading keeps weights off-device until dispatched,
        # so every rank can hold the full model without OOM.
        with load_offloaded_model():
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                dtype="auto",
                device_map="auto_offload",
            )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            dtype="auto",
            device_map="auto",
        )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # ------------------------------------------------------------------
    # Dataset — each rank gets a disjoint slice
    # ------------------------------------------------------------------
    rank = get_rank()
    world_size = get_world_size()

    # NOTE(review): integer division silently drops the remainder when
    # NUM_CALIBRATION_SAMPLES is not divisible by world_size — confirm this
    # is acceptable for the benchmark.
    samples_per_rank = NUM_CALIBRATION_SAMPLES // world_size
    start = samples_per_rank * rank
    # HF datasets split-slicing syntax: "train_sft[start:end]" loads only
    # this rank's contiguous window.
    split = f"{DATASET_SPLIT}[{start}:{start + samples_per_rank}]"

    ds = load_dataset(DATASET_ID, split=split)
    ds = ds.shuffle(seed=42)  # fixed seed keeps runs reproducible

    def preprocess(example):
        # Render the chat messages into one untokenized prompt string.
        return {
            "text": tokenizer.apply_chat_template(
                example["messages"],
                tokenize=False,
            )
        }

    def tokenize(sample):
        # Tokenize without padding; truncation bounds memory per sample.
        return tokenizer(
            sample["text"],
            padding=False,
            max_length=MAX_SEQUENCE_LENGTH,
            truncation=True,
            add_special_tokens=False,
        )

    ds = ds.map(preprocess)
    ds = ds.map(tokenize, remove_columns=ds.column_names)

    # ------------------------------------------------------------------
    # Recipe
    # ------------------------------------------------------------------
    recipe = [
        SmoothQuantModifier(smoothing_strength=0.8),
        GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
    ]

    # ------------------------------------------------------------------
    # Benchmark
    # ------------------------------------------------------------------
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    )

    elapsed = time.time() - start_time
    peak_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)  # bytes -> GiB

    # Only rank 0 logs, to avoid interleaved output from multiple ranks.
    if rank == 0:
        logger.info("=" * 60)
        logger.info(f"BENCHMARK RESULTS — {world_size} GPU(s)")
        logger.info("=" * 60)
        logger.info(f"Model: {MODEL_ID}")
        logger.info(f"Calibration: {NUM_CALIBRATION_SAMPLES} samples total")
        logger.info(f"Samples/rank: {samples_per_rank}")
        logger.info(f"World size: {world_size}")
        logger.info(f"Total time: {elapsed:.1f}s ({elapsed/60:.2f} min)")
        logger.info(f"Peak GPU mem: {peak_mem_gb:.2f} GB (rank 0)")
        logger.info("=" * 60)

    if is_distributed:
        dist.destroy_process_group()
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument("--num_gpus", type=int, default=1) | ||
| args = parser.parse_args() | ||
| main(args.num_gpus) | ||
142 changes: 142 additions & 0 deletions
142
examples/quantization_w8a8_int8/smoothquant_ddp_example.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| """ | ||
| Distributed SmoothQuant + GPTQ W8A8 quantization using Data-Parallel Calibration. | ||
|
|
||
| Run with: | ||
| torchrun --standalone --nproc_per_node=NUM_GPUS smoothquant_ddp_example.py | ||
|
|
||
| Each rank loads a disjoint partition of the calibration dataset. | ||
| SmoothQuantModifier all-reduces per-channel activation statistics across ranks | ||
| before computing smoothing scales (identical on every rank, no weight broadcast | ||
| needed). GPTQModifier then applies distributed W8A8 quantization. | ||
|
|
||
| This script intentionally mirrors the structure of | ||
| examples/quantization_w4a16/llama3_ddp_example.py so it is easy to diff. | ||
| """ | ||
|
|
||
| import time | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from compressed_tensors.offload import dispatch_model, init_dist, load_offloaded_model | ||
| from datasets import load_dataset | ||
| from loguru import logger | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.datasets.utils import get_rank_partition | ||
| from llmcompressor.modifiers.gptq import GPTQModifier | ||
| from llmcompressor.modifiers.transform.smoothquant import SmoothQuantModifier | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Config | ||
| # --------------------------------------------------------------------------- | ||
| MODEL_ID = "Qwen/Qwen2-7B-Instruct" | ||
| DATASET_ID = "HuggingFaceH4/ultrachat_200k" | ||
| DATASET_SPLIT = "train_sft" | ||
| NUM_CALIBRATION_SAMPLES = 512 | ||
| MAX_SEQUENCE_LENGTH = 2048 | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # DDP init + model load | ||
| # --------------------------------------------------------------------------- | ||
| init_dist() | ||
|
|
||
| with load_offloaded_model(): | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| MODEL_ID, | ||
| dtype="auto", | ||
| device_map="auto_offload", | ||
| ) | ||
|
|
||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Dataset: each rank gets a disjoint slice | ||
| # --------------------------------------------------------------------------- | ||
| ds = load_dataset( | ||
| DATASET_ID, | ||
| split=get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES), | ||
| ) | ||
| ds = ds.shuffle(seed=42) | ||
|
|
||
|
|
||
| def preprocess(example): | ||
| return { | ||
| "text": tokenizer.apply_chat_template( | ||
| example["messages"], | ||
| tokenize=False, | ||
| ) | ||
| } | ||
|
|
||
|
|
||
| def tokenize(sample): | ||
| return tokenizer( | ||
| sample["text"], | ||
| padding=False, | ||
| max_length=MAX_SEQUENCE_LENGTH, | ||
| truncation=True, | ||
| add_special_tokens=False, | ||
| ) | ||
|
|
||
|
|
||
| ds = ds.map(preprocess) | ||
| ds = ds.map(tokenize, remove_columns=ds.column_names) | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Recipe: SmoothQuant (distributed-aware) + GPTQ W8A8 | ||
| # --------------------------------------------------------------------------- | ||
| recipe = [ | ||
| SmoothQuantModifier(smoothing_strength=0.8), | ||
| GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]), | ||
| ] | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Run oneshot | ||
| # --------------------------------------------------------------------------- | ||
| torch.cuda.reset_peak_memory_stats() | ||
| start_time = time.time() | ||
|
|
||
| oneshot( | ||
| model=model, | ||
| dataset=ds, | ||
| recipe=recipe, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
| ) | ||
|
|
||
| elapsed = time.time() - start_time | ||
| peak_mem_gb = torch.cuda.max_memory_allocated() / (1024**3) | ||
|
|
||
| rank = dist.get_rank() | ||
| logger.info( | ||
| f"[Rank {rank}] Done in {elapsed:.1f}s | Peak GPU mem: {peak_mem_gb:.2f} GB" | ||
| ) | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Sample generation (rank 0 only) | ||
| # --------------------------------------------------------------------------- | ||
| # Sample generation (all ranks must participate) | ||
| dist.barrier() | ||
| dispatch_model(model) | ||
| sample = tokenizer("Hello my name is", return_tensors="pt") | ||
| sample = {k: v.to(model.device) for k, v in sample.items()} | ||
| output = model.generate(**sample, max_new_tokens=50) | ||
| if rank == 0: | ||
| logger.info("\n========== SAMPLE GENERATION ==========") | ||
| logger.info(tokenizer.decode(output[0])) | ||
| logger.info("========================================\n") | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Save (rank 0 only — save_pretrained handles dist internally) | ||
| # --------------------------------------------------------------------------- | ||
| if rank == 0: | ||
| SAVE_DIR = ( | ||
| MODEL_ID.rstrip("/").split("/")[-1] | ||
| + "-W8A8-SmoothQuant-DDP" | ||
| + str(dist.get_world_size()) | ||
| ) | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| tokenizer.save_pretrained(SAVE_DIR) | ||
| logger.info(f"[Rank {rank}] Saved to {SAVE_DIR}") | ||
|
|
||
| dist.destroy_process_group() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.