Skip to content

Commit 2a66f58

Browse files
HDCharles and yiliu30
committed
Refine Autoround ddp example
Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com> Co-authored-by: yiliu30 <yi4.liu@intel.com>
1 parent 47ec10e commit 2a66f58

File tree

1 file changed

+152
-0
lines changed

1 file changed

+152
-0
lines changed
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
#############################################################################
2+
# This script is adapted to use DDP functionality with AutoRound.
3+
# run this with `torchrun --nproc_per_node=2 ddp_qwen3_example.py`
4+
# or change nproc_per_node to your desired configuration
5+
#
6+
# Example usage:
7+
# torchrun --nproc_per_node=2 ddp_qwen3_example.py \
8+
# --model Qwen/Qwen3-8B \
9+
# --nsamples 128 \
10+
# --iters 200 \
11+
# --disable_torch_compile \
12+
# --deterministic
13+
#############################################################################
14+
15+
import argparse
16+
import os
17+
import time
18+
19+
import torch
20+
import torch.distributed as dist
21+
from compressed_tensors.offload import dispatch_model, init_dist, load_offloaded_model
22+
from datasets import load_dataset
23+
from loguru import logger
24+
from transformers import AutoModelForCausalLM, AutoTokenizer
25+
import torch.distributed as dist
26+
from llmcompressor import oneshot
27+
from llmcompressor.datasets.utils import get_rank_partition
28+
from llmcompressor.modifiers.autoround import AutoRoundModifier
29+
30+
31+
def fix_everything(seed=42):
    """Seed every RNG source (random, numpy, torch CPU and all CUDA devices).

    Args:
        seed: integer seed applied uniformly to all generators (default 42).
    """
    import random

    import numpy as np

    # All four seeders share the same signature, so drive them from one loop.
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
39+
40+
41+
def config_deterministic():
    """Put torch into fully deterministic mode and seed all RNGs.

    Raises (rather than warns) whenever a nondeterministic op is used,
    and sets the cuBLAS workspace config required for deterministic GEMMs.
    """
    # The env var must be set for use_deterministic_algorithms to hold on CUDA;
    # the two lines are order-independent as long as both run before compute.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True, warn_only=False)
    fix_everything()
45+
46+
47+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="AutoRound Quantization with DDP support"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen3-8B",
        help="Model name or path",
    )
    parser.add_argument(
        "--scheme",
        type=str,
        default="W4A16",
        help="Quantization scheme (W4A16, MXFP8, MXFP4, etc.)",
    )
    parser.add_argument("--iters", type=int, default=200, help="Number of iterations")
    parser.add_argument("--nsamples", type=int, default=128, help="Number of samples")
    parser.add_argument(
        "--disable_torch_compile",
        action="store_true",
        help="Disable torch.compile for model acceleration during quantization",
    )
    parser.add_argument(
        "--deterministic",
        action="store_true",
        help="Enable deterministic mode for reproducibility",
    )
    args = parser.parse_args()

    # Determinism must be configured before any model/tensor work happens.
    if args.deterministic:
        config_deterministic()

    model_id = args.model

    ###### DDP MODEL LOAD CHANGE #####
    # Initialize the process group first, then load with offloading so each
    # rank shares the weights rather than each holding a full copy.
    init_dist()
    with load_offloaded_model():
        model = AutoModelForCausalLM.from_pretrained(
            model_id, dtype="auto", device_map="auto_offload"
        )
    ##################################

    # BUG FIX: was `AutoTokenizer.from_pretrained(model_name)` — `model_name`
    # is never defined anywhere in this script; the CLI value is `model_id`.
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Select calibration dataset.
    NUM_CALIBRATION_SAMPLES = args.nsamples
    MAX_SEQUENCE_LENGTH = 2048
    ITERS = args.iters

    # Get aligned calibration dataset (every rank must see identical data).
    # NOTE(review): `get_dataset` is not imported above — only
    # `get_rank_partition` is. Confirm which helper is intended and add the
    # matching import; as written this line raises NameError at runtime.
    ds = get_dataset(
        tokenizer=tokenizer,
        seqlen=MAX_SEQUENCE_LENGTH,
        nsamples=NUM_CALIBRATION_SAMPLES,
    )

    # Configure the quantization algorithm to run.
    # * quantize the weights to 4 bit with AutoRound with a group size 128
    recipe = AutoRoundModifier(
        targets="Linear",
        scheme=args.scheme,
        ignore=[
            "lm_head",
            "re:.*mlp.gate$",
        ],
        iters=ITERS,
        enable_torch_compile=not args.disable_torch_compile,
    )

    # Apply algorithms. Shuffling is disabled so all ranks calibrate on the
    # same sample order.
    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        shuffle_calibration_samples=False,
    )

    rank = dist.get_rank()
    logger.info(f"[Rank {rank}] Quantization completed")

    # Confirm generations of the quantized model look sane.
    logger.info("\n\n")
    logger.info("========== SAMPLE GENERATION ==============")
    dispatch_model(model)
    sample = tokenizer("Hello my name is", return_tensors="pt")
    sample = {key: value.to(model.device) for key, value in sample.items()}
    output = model.generate(**sample, max_new_tokens=100)
    logger.info(tokenizer.decode(output[0]))
    logger.info("==========================================\n\n")

    logger.info("Saving...")
    # Save to disk compressed; directory name encodes the run configuration.
    SAVE_DIR = (
        model_id.rstrip("/").split("/")[-1]
        + f"-{args.scheme}-AutoRound"
        + f"-iters{args.iters}-nsamples{args.nsamples}"
        + "-DDP"
        + str(dist.get_world_size())
    )
    model.save_pretrained(SAVE_DIR, save_compressed=True)
    tokenizer.save_pretrained(SAVE_DIR)
    logger.info(f"Saved to {SAVE_DIR}")

    dist.destroy_process_group()

0 commit comments

Comments
 (0)