1- """
2- # AutoRound: https://github.com/intel/auto-round/tree/main
3- CUDA_VISIBLE_DEVICES=0,1 python ddp_qwen3_example.py \
4- --model Qwen/Qwen3-8B \
5- --ddp \
6- --nsamples 128 \
7- --iters 200 \
8- --disable_torch_compile \
9-    --deterministic
10-
11- """
12- from loguru import logger
13- from auto_round .calib_dataset import get_dataset
14- from transformers import AutoModelForCausalLM , AutoTokenizer
15-
16- from llmcompressor import oneshot
17- from llmcompressor .modifiers .autoround import AutoRoundModifier
18- from llmcompressor .utils import dispatch_for_generation
19-
20- # Select model and load it.
21- model_id = "Qwen/Qwen3-8B"
1+ #############################################################################
2+ # This script is adapted to use DDP functionality with AutoRound.
3+ # run this with `torchrun --nproc_per_node=2 ddp_qwen3_example.py`
4+ # or change nproc_per_node to the number of GPUs you want to use
5+ #
6+ # Example usage:
7+ # torchrun --nproc_per_node=2 ddp_qwen3_example.py \
8+ # --model Qwen/Qwen3-8B \
9+ # --nsamples 128 \
10+ # --iters 200 \
11+ # --disable_torch_compile \
12+ # --deterministic
13+ #############################################################################
2214
2315import argparse
2416import os
17+ import time
2518
2619import torch
2720import torch .distributed as dist
28- import torch .multiprocessing as mp
21+ from compressed_tensors .offload import dispatch_model , init_dist , load_offloaded_model
22+ from datasets import load_dataset
23+ from loguru import logger
24+ from transformers import AutoModelForCausalLM , AutoTokenizer
25+ import torch .distributed as dist
26+ from llmcompressor import oneshot
27+ from llmcompressor .datasets .utils import get_rank_partition
28+ from llmcompressor .modifiers .autoround import AutoRoundModifier
2929
3030
31- def fix_everything (seed = 42 ):
31+ def fix_everything (seed = 42 ):
3232 import random
3333 import numpy as np
3434
@@ -42,144 +42,16 @@ def config_deterministic():
4242 torch .use_deterministic_algorithms (True , warn_only = False )
4343 os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = ":4096:8"
4444 fix_everything ()
45-
46-
47-
48- def setup_ddp (rank , world_size ):
49- """Initialize the distributed environment."""
50- os .environ ["MASTER_ADDR" ] = os .environ .get ("MASTER_ADDR" , "localhost" )
51- os .environ ["MASTER_PORT" ] = os .environ .get ("MASTER_PORT" , "12356" )
52-
53- # Initialize process group
54- backend = "nccl" if torch .cuda .is_available () else "gloo"
55- dist .init_process_group (backend , rank = rank , world_size = world_size )
56- torch .cuda .set_device (rank )
57-
58-
59- def cleanup_ddp ():
60- """Clean up the distributed environment."""
61- if dist .is_initialized ():
62- dist .destroy_process_group ()
63-
64-
65- def quantize_model (rank , world_size , args ):
66- """
67- Quantize model on a specific GPU rank.
68-
69- Args:
70- rank: GPU rank for this process
71- world_size: Total number of GPUs
72- args: Command line arguments
73- """
74- if args .deterministic :
75- config_deterministic ()
76- logger .info (f"[Rank { rank } /{ world_size } ] Starting quantization" )
77-
78- # Setup DDP if using multiple GPUs
79- if world_size > 1 :
80- setup_ddp (rank , world_size )
81-
82- # Set device for this process
83- model_name = args .model_name
84- # device_map = "meta"
85- model = AutoModelForCausalLM .from_pretrained (
86- model_name , torch_dtype = "auto" ,
87- # device_map=device_map
88- )
89- tokenizer = AutoTokenizer .from_pretrained (model_name )
90-
91- # Select calibration dataset.
92- NUM_CALIBRATION_SAMPLES = args .nsamples
93- MAX_SEQUENCE_LENGTH = 2048
94- ITERS = args .iters
95- # Get aligned calibration dataset.
96-
97- ds = get_dataset (
98- tokenizer = tokenizer ,
99- seqlen = MAX_SEQUENCE_LENGTH ,
100- nsamples = NUM_CALIBRATION_SAMPLES ,
101- )
102-
103- # Configure the quantization algorithm to run.
104- # * quantize the weights to 4 bit with AutoRound with a group size 128
105- recipe = AutoRoundModifier (
106- targets = "Linear" ,
107- scheme = args .scheme ,
108- ignore = [
109- "lm_head" ,
110- "re:.*mlp.gate$" ,
111- ],
112- iters = ITERS ,
113- enable_torch_compile = not args .disable_torch_compile ,
114- )
115-
116- # Apply algorithms.
117- oneshot (
118- model = model ,
119- dataset = ds ,
120- recipe = recipe ,
121- max_seq_length = MAX_SEQUENCE_LENGTH ,
122- num_calibration_samples = NUM_CALIBRATION_SAMPLES ,
123- shuffle_calibration_samples = False ,
124- )
125-
126- # Synchronize all processes
127- if world_size > 1 :
128- dist .barrier ()
129-
130- logger .info (f"[Rank { rank } ] Quantization completed" )
131- if rank == 0 :
132- # Confirm generations of the quantized model look sane.
133- logger .info ("\n \n " )
134- logger .info ("========== SAMPLE GENERATION ==============" )
135- dispatch_for_generation (model )
136- sample = tokenizer ("Hello my name is" , return_tensors = "pt" )
137- sample = {key : value .to (model .device ) for key , value in sample .items ()}
138- output = model .generate (** sample , max_new_tokens = 100 )
139- logger .info (tokenizer .decode (output [0 ]))
140- logger .info ("==========================================\n \n " )
141-
142- # Save to disk compressed.
143- SAVE_DIR = (
144- model_name .rstrip ("/" ).split ("/" )[- 1 ]
145- + f"-{ args .scheme } -AutoRound"
146- + f"-iters{ args .iters } -nsamples{ args .nsamples } "
147- )
148- logger .info (f"save to { SAVE_DIR } " )
149- model .save_pretrained (SAVE_DIR , save_compressed = True )
150- tokenizer .save_pretrained (SAVE_DIR )
151- else :
152- # Other ranks just run quantization without saving
153- logger .info (f"[Rank { rank } ] Running quantization (not saving)" )
154-
155- if world_size > 1 :
156- cleanup_ddp ()
157-
158-
159- def main_spawn (args ):
160- """Main function using mp.spawn for multi-GPU quantization."""
161- world_size = torch .cuda .device_count () if torch .cuda .is_available () else 1
162-
163- logger .info (f"Starting DDP quantization with { world_size } GPUs" )
164-
165- mp .spawn (
166- quantize_model ,
167- args = (world_size , args ),
168- nprocs = world_size ,
169- join = True ,
170- )
171-
172- logger .info ("Quantization completed!" )
17345
17446
17547if __name__ == "__main__" :
17648 parser = argparse .ArgumentParser (
17749 description = "AutoRound Quantization with DDP support"
17850 )
17951 parser .add_argument (
180- "--model_name " ,
52+ "--model " ,
18153 type = str ,
182- default = model_id ,
54+ default = "Qwen/Qwen3-8B" ,
18355 help = "Model name or path" ,
18456 )
18557 parser .add_argument (
@@ -188,9 +60,8 @@ def main_spawn(args):
18860 default = "W4A16" ,
18961 help = "Quantization scheme (W4A16, MXFP8, MXFP4, etc.)" ,
19062 )
191- parser .add_argument ("--iters" , type = int , default = 100 , help = "Number of iterations" )
192- parser .add_argument ("--nsamples" , type = int , default = 256 , help = "Number of samples" )
193- parser .add_argument ("--ddp" , action = "store_true" , help = "Enable DDP multi-GPU mode" )
63+ parser .add_argument ("--iters" , type = int , default = 200 , help = "Number of iterations" )
64+ parser .add_argument ("--nsamples" , type = int , default = 128 , help = "Number of samples" )
19465 parser .add_argument (
19566 "--disable_torch_compile" ,
19667 action = "store_true" ,
@@ -203,22 +74,79 @@ def main_spawn(args):
20374 )
20475 args = parser .parse_args ()
20576
206- # For backward compatibility with existing hardcoded values
207- model_name = args .model_name
208-
209- # Parse scheme from string if needed
210- from auto_round import schemes as ar_schemes
211-
212- scheme_map = {
213- "FP8_STATIC" : ar_schemes .FP8_STATIC ,
214- "MXFP8" : ar_schemes .MXFP8 ,
215- "MXFP4" : ar_schemes .MXFP4 ,
216- }
217- # scheme = scheme_map.get(args.scheme, args.scheme)
218-
219- if args .ddp :
220- logger .info ("Using mp.spawn mode for multi-GPU quantization" )
221- main_spawn (args )
222- else :
223- logger .info ("Using single-process quantization" )
224- quantize_model (rank = 0 , world_size = 1 , args = args )
77+ if args .deterministic :
78+ config_deterministic ()
79+
80+ model_id = args .model
81+
82+ ###### DDP MODEL LOAD CHANGE ######
83+ init_dist ()
84+ with load_offloaded_model ():
85+ model = AutoModelForCausalLM .from_pretrained (
86+ model_id , dtype = "auto" , device_map = "auto_offload"
87+ )
88+ ##################################
89+
90+ tokenizer = AutoTokenizer .from_pretrained (model_name )
91+
92+ # Select calibration dataset.
93+ NUM_CALIBRATION_SAMPLES = args .nsamples
94+ MAX_SEQUENCE_LENGTH = 2048
95+ ITERS = args .iters
96+ # Get aligned calibration dataset.
97+
98+ ds = get_dataset (
99+ tokenizer = tokenizer ,
100+ seqlen = MAX_SEQUENCE_LENGTH ,
101+ nsamples = NUM_CALIBRATION_SAMPLES ,
102+ )
103+
104+ # Configure the quantization algorithm to run.
105+ # * quantize the weights to 4 bits using AutoRound with a group size of 128
106+ recipe = AutoRoundModifier (
107+ targets = "Linear" ,
108+ scheme = args .scheme ,
109+ ignore = [
110+ "lm_head" ,
111+ "re:.*mlp.gate$" ,
112+ ],
113+ iters = ITERS ,
114+ enable_torch_compile = not args .disable_torch_compile ,
115+ )
116+
117+ # Apply algorithms.
118+ oneshot (
119+ model = model ,
120+ dataset = ds ,
121+ recipe = recipe ,
122+ max_seq_length = MAX_SEQUENCE_LENGTH ,
123+ num_calibration_samples = NUM_CALIBRATION_SAMPLES ,
124+ shuffle_calibration_samples = False ,
125+ )
126+
127+ rank = dist .get_rank ()
128+ logger .info (f"[Rank { rank } ] Quantization completed" )
129+ # Confirm generations of the quantized model look sane.
130+ logger .info ("\n \n " )
131+ logger .info ("========== SAMPLE GENERATION ==============" )
132+ dispatch_model (model )
133+ sample = tokenizer ("Hello my name is" , return_tensors = "pt" )
134+ sample = {key : value .to (model .device ) for key , value in sample .items ()}
135+ output = model .generate (** sample , max_new_tokens = 100 )
136+ logger .info (tokenizer .decode (output [0 ]))
137+ logger .info ("==========================================\n \n " )
138+
139+ logger .info ("Saving..." )
140+ # Save to disk compressed.
141+ SAVE_DIR = (
142+ model_id .rstrip ("/" ).split ("/" )[- 1 ]
143+ + f"-{ args .scheme } -AutoRound"
144+ + f"-iters{ args .iters } -nsamples{ args .nsamples } "
145+ + "-DDP"
146+ + str (dist .get_world_size ())
147+ )
148+ model .save_pretrained (SAVE_DIR , save_compressed = True )
149+ tokenizer .save_pretrained (SAVE_DIR )
150+ logger .info (f"Saved to { SAVE_DIR } " )
151+
152+ dist .destroy_process_group ()
0 commit comments