
Commit a3d2a3f

dzhengAP (David Zheng) authored
[Distributed][SmoothQuant] Add distributed activation scale reduction (#2180) (#2471)
## Summary

Implements distributed support for `SmoothQuantModifier` as part of the weight-parallel optimization tracked in #2180, assigned to @dzhengAP.

## What this PR does

In a distributed calibration run, each rank observes a disjoint partition of the calibration dataset. Activation statistics (per-channel min/max) are collected locally via forward hooks. Before smoothing scales are computed, `_reduce_activation_scales()` all-reduces those statistics across all ranks so every rank has the global activation profile; each rank then independently computes identical smoothing scales (a cheap op, so no weight broadcast is needed). This follows the AWQ strategy described in #2180. Single-GPU behavior is completely unchanged — all new code is guarded by `is_distributed()`.

## Changes

- `_reduce_activation_scales()`: all-reduces `min/max_channel_vals` across ranks using async MIN/MAX collectives batched with `wait_for_comms`
- `_apply_smoothing()`: calls `_reduce_activation_scales()` as its first step
- Unit tests: 5 mock-based tests verifying the call contract (no GPU needed)
- DDP example: `examples/quantization_w8a8_int8/smoothquant_ddp_example.py`
- Multi-GPU integration tests verifying weight equivalence vs. single-GPU

## Test results

- Unit tests: all 5 passed (`pytest -m unit`)
- DDP example: ran successfully on 2x V100 32GB; both ranks completed in ~698s, peak GPU mem 1.66 GB per rank
- ruff: all checks passed

## Distributed Speedup Benchmarks

Model: Qwen/Qwen2-7B-Instruct, 512 calibration samples, 4x V100 32GB

| GPUs  | Total Time | Peak GPU Mem | Speedup |
|-------|------------|--------------|---------|
| 1 GPU | 94.1 min   | 8.93 GB      | 1.00x   |
| 2 GPU | 58.7 min   | 7.06 GB      | 1.60x   |
| 4 GPU | 28.7 min   | 7.06 GB      | 3.28x   |

Benchmark script: `examples/quantization_w8a8_int8/benchmark_smoothquant_ddp.py`

Closes part of #2180

cc @kylesayrs

---------

Signed-off-by: David Zheng <dzheng@apple.com>
Signed-off-by: David Zheng <dqzheng1996@gmail.com>
Co-authored-by: David Zheng <dzheng@apple.com>
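The key idea in the summary above — that elementwise MIN/MAX all-reduces leave every rank with identical global statistics, so each rank can compute the same smoothing scales locally — can be illustrated without any GPU or process group. This is a toy sketch with made-up numbers, not the PR's actual code:

```python
# Per-channel activation ranges observed by two hypothetical ranks, each
# calibrating on a disjoint slice of the dataset (toy numbers).
rank_mins = [[-1.0, -0.5, -2.0], [-0.2, -3.0, -1.0]]
rank_maxs = [[0.5, 4.0, 1.0], [2.0, 1.5, 0.8]]

# An all-reduce with ReduceOp.MIN / ReduceOp.MAX leaves every rank holding
# the elementwise min/max across all ranks; afterwards each rank can derive
# identical smoothing scales with no further communication.
global_min = [min(vals) for vals in zip(*rank_mins)]
global_max = [max(vals) for vals in zip(*rank_maxs)]

print(global_min)  # [-1.0, -3.0, -2.0]
print(global_max)  # [2.0, 4.0, 1.0]
```

Because the reduction is commutative and deterministic, no broadcast of the final scale tensors is needed to keep ranks in sync.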
1 parent eb49917 commit a3d2a3f

File tree

4 files changed: +628 −0 lines changed

examples/quantization_w8a8_int8/benchmark_smoothquant_ddp.py

Lines changed: 152 additions & 0 deletions
```python
"""
Benchmark: Single-GPU vs Multi-GPU DDP SmoothQuant calibration time.

Usage:
    # 1 GPU
    python benchmark_smoothquant_ddp.py --num_gpus 1

    # 2 GPU
    torchrun --standalone --nproc_per_node=2 benchmark_smoothquant_ddp.py --num_gpus 2

    # 4 GPU
    torchrun --standalone --nproc_per_node=4 benchmark_smoothquant_ddp.py --num_gpus 4
"""

import argparse
import time

import torch
import torch.distributed as dist
from compressed_tensors.offload import init_dist, load_offloaded_model
from datasets import load_dataset
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.gptq import GPTQModifier
from llmcompressor.modifiers.transform.smoothquant import SmoothQuantModifier

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
MODEL_ID = "Qwen/Qwen2-7B-Instruct"
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048


def get_rank():
    if dist.is_initialized():
        return dist.get_rank()
    return 0


def get_world_size():
    if dist.is_initialized():
        return dist.get_world_size()
    return 1


def main(num_gpus: int):
    is_distributed = num_gpus > 1

    # ------------------------------------------------------------------
    # Init distributed if needed
    # ------------------------------------------------------------------
    if is_distributed:
        init_dist()
        with load_offloaded_model():
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                dtype="auto",
                device_map="auto_offload",
            )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            dtype="auto",
            device_map="auto",
        )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # ------------------------------------------------------------------
    # Dataset — each rank gets a disjoint slice
    # ------------------------------------------------------------------
    rank = get_rank()
    world_size = get_world_size()

    samples_per_rank = NUM_CALIBRATION_SAMPLES // world_size
    start = samples_per_rank * rank
    split = f"{DATASET_SPLIT}[{start}:{start + samples_per_rank}]"

    ds = load_dataset(DATASET_ID, split=split)
    ds = ds.shuffle(seed=42)

    def preprocess(example):
        return {
            "text": tokenizer.apply_chat_template(
                example["messages"],
                tokenize=False,
            )
        }

    def tokenize(sample):
        return tokenizer(
            sample["text"],
            padding=False,
            max_length=MAX_SEQUENCE_LENGTH,
            truncation=True,
            add_special_tokens=False,
        )

    ds = ds.map(preprocess)
    ds = ds.map(tokenize, remove_columns=ds.column_names)

    # ------------------------------------------------------------------
    # Recipe
    # ------------------------------------------------------------------
    recipe = [
        SmoothQuantModifier(smoothing_strength=0.8),
        GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
    ]

    # ------------------------------------------------------------------
    # Benchmark
    # ------------------------------------------------------------------
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    )

    elapsed = time.time() - start_time
    peak_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)

    if rank == 0:
        logger.info("=" * 60)
        logger.info(f"BENCHMARK RESULTS — {world_size} GPU(s)")
        logger.info("=" * 60)
        logger.info(f"Model: {MODEL_ID}")
        logger.info(f"Calibration: {NUM_CALIBRATION_SAMPLES} samples total")
        logger.info(f"Samples/rank: {samples_per_rank}")
        logger.info(f"World size: {world_size}")
        logger.info(f"Total time: {elapsed:.1f}s ({elapsed/60:.2f} min)")
        logger.info(f"Peak GPU mem: {peak_mem_gb:.2f} GB (rank 0)")
        logger.info("=" * 60)

    if is_distributed:
        dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_gpus", type=int, default=1)
    args = parser.parse_args()
    main(args.num_gpus)
```
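The per-rank slice arithmetic in the benchmark script (`samples_per_rank`, `start`, and the `datasets`-style `split[a:b]` string) can be checked in isolation. This hypothetical helper, `rank_split`, simply factors out those three lines; it is an illustration, not a function from the PR:

```python
def rank_split(split: str, total_samples: int, rank: int, world_size: int) -> str:
    """Return a datasets-style slice string giving this rank a disjoint chunk.

    Mirrors the slicing done in benchmark_smoothquant_ddp.py: samples are
    divided evenly (floor division), so any remainder samples are dropped.
    """
    samples_per_rank = total_samples // world_size
    start = samples_per_rank * rank
    return f"{split}[{start}:{start + samples_per_rank}]"


# With 512 samples over 4 ranks, the slices are disjoint and cover [0, 512).
print([rank_split("train_sft", 512, r, 4) for r in range(4)])
# ['train_sft[0:128]', 'train_sft[128:256]', 'train_sft[256:384]', 'train_sft[384:512]']
```

Note the floor division: if `total_samples` is not divisible by `world_size`, the trailing remainder is silently excluded from calibration.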
examples/quantization_w8a8_int8/smoothquant_ddp_example.py

Lines changed: 142 additions & 0 deletions
```python
"""
Distributed SmoothQuant + GPTQ W8A8 quantization using Data-Parallel Calibration.

Run with:
    torchrun --standalone --nproc_per_node=NUM_GPUS smoothquant_ddp_example.py

Each rank loads a disjoint partition of the calibration dataset.
SmoothQuantModifier all-reduces per-channel activation statistics across ranks
before computing smoothing scales (identical on every rank, no weight broadcast
needed). GPTQModifier then applies distributed W8A8 quantization.

This script intentionally mirrors the structure of
examples/quantization_w4a16/llama3_ddp_example.py so it is easy to diff.
"""

import time

import torch
import torch.distributed as dist
from compressed_tensors.offload import dispatch_model, init_dist, load_offloaded_model
from datasets import load_dataset
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.datasets.utils import get_rank_partition
from llmcompressor.modifiers.gptq import GPTQModifier
from llmcompressor.modifiers.transform.smoothquant import SmoothQuantModifier

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
MODEL_ID = "Qwen/Qwen2-7B-Instruct"
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# ---------------------------------------------------------------------------
# DDP init + model load
# ---------------------------------------------------------------------------
init_dist()

with load_offloaded_model():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype="auto",
        device_map="auto_offload",
    )

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# ---------------------------------------------------------------------------
# Dataset: each rank gets a disjoint slice
# ---------------------------------------------------------------------------
ds = load_dataset(
    DATASET_ID,
    split=get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES),
)
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(preprocess)
ds = ds.map(tokenize, remove_columns=ds.column_names)

# ---------------------------------------------------------------------------
# Recipe: SmoothQuant (distributed-aware) + GPTQ W8A8
# ---------------------------------------------------------------------------
recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

# ---------------------------------------------------------------------------
# Run oneshot
# ---------------------------------------------------------------------------
torch.cuda.reset_peak_memory_stats()
start_time = time.time()

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

elapsed = time.time() - start_time
peak_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)

rank = dist.get_rank()
logger.info(
    f"[Rank {rank}] Done in {elapsed:.1f}s | Peak GPU mem: {peak_mem_gb:.2f} GB"
)

# ---------------------------------------------------------------------------
# Sample generation (all ranks must participate; rank 0 prints)
# ---------------------------------------------------------------------------
dist.barrier()
dispatch_model(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {k: v.to(model.device) for k, v in sample.items()}
output = model.generate(**sample, max_new_tokens=50)
if rank == 0:
    logger.info("\n========== SAMPLE GENERATION ==========")
    logger.info(tokenizer.decode(output[0]))
    logger.info("========================================\n")

# ---------------------------------------------------------------------------
# Save (rank 0 only)
# ---------------------------------------------------------------------------
if rank == 0:
    SAVE_DIR = (
        MODEL_ID.rstrip("/").split("/")[-1]
        + "-W8A8-SmoothQuant-DDP"
        + str(dist.get_world_size())
    )
    model.save_pretrained(SAVE_DIR, save_compressed=True)
    tokenizer.save_pretrained(SAVE_DIR)
    logger.info(f"[Rank {rank}] Saved to {SAVE_DIR}")

dist.destroy_process_group()
```

src/llmcompressor/modifiers/transform/smoothquant/base.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -2,7 +2,9 @@
 from typing import Callable
 
 import torch
+import torch.distributed as dist
 from compressed_tensors.offload import update_offload_parameter
+from compressed_tensors.offload.dist_utils import is_distributed
 from compressed_tensors.utils import match_modules_set, match_named_modules
 from loguru import logger
 from pydantic import ConfigDict, Field
@@ -15,6 +17,7 @@
     get_layer_mappings_from_architecture,
     handle_mapping_resolution_errors,
 )
+from llmcompressor.utils.dist import wait_for_comms
 from llmcompressor.utils.pytorch.module import get_module_to_name_dict
 
 MINIMUM_SMOOTHING_SCALE = 1e-5
@@ -290,6 +293,40 @@ def hook_fn(module, inp, out):
             layer = mapping.smooth_layer
             self.register_hook(layer, create_hook_fn(name), "forward")
 
+    def _reduce_activation_scales(self):
+        """
+        In a distributed setting, all-reduce the per-channel min/max activation
+        statistics collected by each rank during calibration.
+
+        Each rank observes a disjoint partition of the calibration dataset, so the
+        global channel-wise min/max must be gathered across all ranks before smoothing
+        scales are computed. We use ``dist.all_reduce`` with MIN/MAX ops so that
+        every rank ends up with identical statistics and can independently compute the
+        same smoothing scales without an extra broadcast of the final scale tensors.
+
+        This is a no-op when not running in a distributed context.
+        """
+        if not is_distributed():
+            return
+
+        pending_comms = []
+        for layer_name, scale in self.scales_.items():
+            pending_comms.append(
+                dist.all_reduce(
+                    scale.min_channel_vals,
+                    op=dist.ReduceOp.MIN,
+                    async_op=True,
+                )
+            )
+            pending_comms.append(
+                dist.all_reduce(
+                    scale.max_channel_vals,
+                    op=dist.ReduceOp.MAX,
+                    async_op=True,
+                )
+            )
+        wait_for_comms(pending_comms)
+
     @torch.no_grad()
     def _apply_smoothing(self, model: Module):
         """
@@ -299,8 +336,16 @@ def _apply_smoothing(self, model: Module):
         Y = (Xdiag(scales)^(-1) * diag(scales)W) where W is the to_balance weights and
         X is the to_smooth weights
 
+        In a distributed setting, activation statistics are first all-reduced across
+        ranks so that every rank computes identical smoothing scales. The scale
+        computation itself is duplicated across ranks (cheap), avoiding the need for
+        a broadcast of the final weight tensors.
+
         This modifies the weights of the model in-place.
         """
+        if is_distributed():
+            self._reduce_activation_scales()
+
         for mapping in self.resolved_mappings_:
             if mapping.smooth_name not in self.scales_:
                 continue
```
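The call contract exercised by the PR's mock-based unit tests can be sketched without a GPU or process group: two async handles per layer (one MIN, one MAX), all of which must be awaited in a single batch. Everything below is a hypothetical stand-in; `wait_for_comms` here only mimics what the real helper presumably does (its implementation is not shown in this diff):

```python
class FakeWork:
    """Stand-in for the work handle returned by dist.all_reduce(async_op=True)."""

    def __init__(self):
        self.waited = False

    def wait(self):
        self.waited = True


def wait_for_comms(pending_comms):
    # Sketch of the assumed behavior: block until every queued async
    # collective has completed before smoothing scales are computed.
    for work in pending_comms:
        work.wait()


# 3 layers x 2 statistics (min + max) -> 6 pending collectives, all awaited.
handles = [FakeWork() for _ in range(2 * 3)]
wait_for_comms(handles)
print(all(w.waited for w in handles))  # True
```

Batching the waits after the issue loop lets the backend overlap the per-layer collectives instead of serializing a blocking all-reduce per tensor.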
