Skip to content

Commit a2433a9

Browse files
yiliu30, HDCharles, dsikka
authored
[AutoRound] Add DDP Support and Example (#2411)
SUMMARY: Add DDP support for Autoround and use Qwen as example. Depends on #2410 TEST PLAN: "please outline how the changes were tested" cc @hshen14 @thuang6 --------- Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com> Signed-off-by: yiliu30 <yi4.liu@intel.com> Co-authored-by: HDCharles <charlesdavidhernandez@gmail.com> Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
1 parent a88ebbd commit a2433a9

File tree

2 files changed

+165
-0
lines changed

2 files changed

+165
-0
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
"""
2+
This script is adapted to use DDP functionality with AutoRound.
3+
run this with `torchrun --nproc_per_node=2 ddp_qwen3_example.py`
4+
or change nproc_per_node to your desired configuration
5+
6+
Example usage:
7+
torchrun --nproc_per_node=2 ddp_qwen3_example.py \
8+
--model Qwen/Qwen3-8B \
9+
--nsamples 128 \
10+
--iters 100 \
11+
--disable_torch_compile \
12+
--deterministic
13+
"""
14+
15+
import argparse
16+
import os
17+
18+
import torch
19+
import torch.distributed as dist
20+
from compressed_tensors.offload import dispatch_model, init_dist, load_offloaded_model
21+
from loguru import logger
22+
from transformers import AutoModelForCausalLM, AutoTokenizer
23+
24+
from llmcompressor import oneshot
25+
26+
27+
def fix_everything(seed=42):
    """Seed every RNG source used here (Python, NumPy, Torch CPU and CUDA).

    Args:
        seed: integer seed applied to all generators (default 42).
    """
    import random

    import numpy as np

    # Apply the same seed to each framework's generator; CUDA seeding is a
    # no-op when no GPU is present.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)
36+
37+
38+
def config_deterministic():
    """Put the run into fully deterministic mode.

    Errors out (warn_only=False) on any op without a deterministic
    implementation, configures cuBLAS workspaces for reproducibility,
    and seeds all RNGs via ``fix_everything``.
    """
    # cuBLAS needs this env var set before deterministic GEMMs execute.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True, warn_only=False)
    fix_everything()
42+
43+
44+
if __name__ == "__main__":
    # ---- CLI ----------------------------------------------------------------
    parser = argparse.ArgumentParser(
        description="AutoRound Quantization with DDP support"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen3-8B",
        help="Model name or path",
    )
    parser.add_argument(
        "--scheme",
        type=str,
        default="W4A16",
        help="Quantization scheme (W4A16, MXFP8, MXFP4, etc.)",
    )
    parser.add_argument("--iters", type=int, default=200, help="Number of iterations")
    parser.add_argument("--nsamples", type=int, default=128, help="Number of samples")
    parser.add_argument(
        "--disable_torch_compile",
        action="store_true",
        help="Disable torch.compile for model acceleration during quantization",
    )
    parser.add_argument(
        "--deterministic",
        action="store_true",
        help="Enable deterministic mode for reproducibility",
    )
    args = parser.parse_args()

    # Determinism must be configured before any CUDA/model work happens.
    if args.deterministic:
        config_deterministic()

    model_id = args.model

    ###### DDP MODEL LOAD CHANGE #####
    # Initialize the process group, then load with weights offloaded so each
    # rank does not hold a full on-device copy of the model.
    init_dist()
    with load_offloaded_model():
        model = AutoModelForCausalLM.from_pretrained(
            model_id, dtype="auto", device_map="auto_offload"
        )
    ##################################

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Select calibration dataset.
    NUM_CALIBRATION_SAMPLES = args.nsamples
    MAX_SEQUENCE_LENGTH = 2048
    ITERS = args.iters

    # Get aligned calibration dataset.
    from auto_round.calib_dataset import get_dataset  # noqa: E402

    # Note: Make sure the model is loaded before importing auto-round related
    # code. This requirement will be lifted once switching to a new release of
    # auto-round which includes the fix below:
    from llmcompressor.modifiers.autoround import AutoRoundModifier  # noqa: E402

    ds = get_dataset(
        tokenizer=tokenizer,
        seqlen=MAX_SEQUENCE_LENGTH,
        nsamples=NUM_CALIBRATION_SAMPLES,
    )

    # Configure the quantization algorithm to run.
    # * quantize the weights to 4 bit with AutoRound with a group size 128
    # lm_head and MoE router gates are excluded from quantization.
    recipe = AutoRoundModifier(
        targets="Linear",
        scheme=args.scheme,
        ignore=[
            "lm_head",
            "re:.*mlp.gate$",
        ],
        iters=ITERS,
        enable_torch_compile=not args.disable_torch_compile,
    )

    # Apply algorithms.
    # shuffle_calibration_samples=False keeps sample order identical across
    # ranks, which DDP calibration relies on for consistency.
    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        shuffle_calibration_samples=False,
    )

    rank = dist.get_rank()
    logger.info(f"[Rank {rank}] Quantization completed")
    # Confirm generations of the quantized model look sane.
    logger.info("\n\n")
    logger.info("========== SAMPLE GENERATION ==============")
    # Re-dispatch the (now quantized) model onto devices for inference.
    dispatch_model(model)
    sample = tokenizer("Hello my name is", return_tensors="pt")
    sample = {key: value.to(model.device) for key, value in sample.items()}
    output = model.generate(**sample, max_new_tokens=100)
    logger.info(tokenizer.decode(output[0]))
    logger.info("==========================================\n\n")

    logger.info("Saving...")
    # Save to disk compressed.
    # Output dir encodes model name, scheme, and run settings, e.g.
    # "Qwen3-8B-W4A16-AutoRound-iters200-nsamples128-DDP2".
    # NOTE(review): every rank computes the same SAVE_DIR and saves into it —
    # presumably save_pretrained is rank-aware or idempotent here; confirm.
    SAVE_DIR = (
        model_id.rstrip("/").split("/")[-1]
        + f"-{args.scheme}-AutoRound"
        + f"-iters{args.iters}-nsamples{args.nsamples}"
        + "-DDP"
        + str(dist.get_world_size())
    )
    model.save_pretrained(SAVE_DIR, save_compressed=True)
    tokenizer.save_pretrained(SAVE_DIR)
    logger.info(f"Saved to {SAVE_DIR}")

    dist.destroy_process_group()

src/llmcompressor/modifiers/autoround/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def apply_autoround(self, state, subgraph):
280280
align_module_device(decoding_layer),
281281
suspend_offloading(wrapped_model),
282282
):
283+
self._update_device_map_for_dp(kwargs)
283284
ar = AutoRound(
284285
model=wrapped_model,
285286
**kwargs,
@@ -349,6 +350,13 @@ def get_unquantized_layer_names(self, wrapped_model: torch.nn.Module) -> List[st
349350
unquantized_layers.append(name)
350351
return unquantized_layers
351352

353+
def _update_device_map_for_dp(self, ar_kwargs):
354+
if torch.distributed.is_initialized():
355+
rank = torch.distributed.get_rank()
356+
ar_kwargs["device_map"] = (
357+
f"cuda:{rank}" if torch.cuda.is_available() else "cpu"
358+
)
359+
352360
def _unwrapper_quantized_layer(self, model: torch.nn.Module):
353361
# auto-round will return WrapperWALayer if activation is quantized
354362
for name, module in model.named_modules():

0 commit comments

Comments
 (0)