|
1 | | -############################################################################# |
2 | | -# This script is adapted to use DDP functionality with AutoRound. |
3 | | -# run this with `torchrun --nproc_per_node=2 ddp_qwen3_example.py` |
4 | | -# or change nproc_per_node to your desired configuration |
5 | | -# |
6 | | -# Example usage: |
7 | | -# torchrun --nproc_per_node=2 ddp_qwen3_example.py \ |
8 | | -# --model Qwen/Qwen3-8B \ |
9 | | -# --nsamples 128 \ |
10 | | -# --iters 100 \ |
11 | | -# --disable_torch_compile \ |
12 | | -# --deterministic |
13 | | -############################################################################# |
14 | | - |
| 1 | +""" |
| 2 | +This script is adapted to use DDP functionality with AutoRound. |
| 3 | +Run it with `torchrun --nproc_per_node=2 ddp_qwen3_example.py`, |
| 4 | +or set nproc_per_node to the number of processes (GPUs) you want to use. |
| 5 | +
| 6 | +Example usage: |
| 7 | +torchrun --nproc_per_node=2 ddp_qwen3_example.py \ |
| 8 | + --model Qwen/Qwen3-8B \ |
| 9 | + --nsamples 128 \ |
| 10 | + --iters 100 \ |
| 11 | + --disable_torch_compile \ |
| 12 | + --deterministic |
| 13 | +""" |
15 | 14 | import argparse |
16 | 15 | import os |
17 | 16 |
@@ -93,12 +92,12 @@ def config_deterministic(): |
93 | 92 |
94 | 93 |
95 | 94 | # Get aligned calibration dataset. |
96 | | -from auto_round.calib_dataset import get_dataset |
| 95 | +from auto_round.calib_dataset import get_dataset # noqa: E402 |
97 | 96 |
98 | 97 | # Note: Make sure the model is loaded before importing auto-round related code. |
99 | 98 | # This requirement will be lifted once we switch to a new release of auto-round that |
100 | 99 | # includes the fix below: |
101 | | -from llmcompressor.modifiers.autoround import AutoRoundModifier |
| 100 | +from llmcompressor.modifiers.autoround import AutoRoundModifier # noqa: E402 |
102 | 101 |
103 | 102 | ds = get_dataset( |
104 | 103 | tokenizer=tokenizer, |
@@ -131,27 +130,28 @@ def config_deterministic(): |
131 | 130 |
132 | 131 | rank = dist.get_rank() |
133 | 132 | logger.info(f"[Rank {rank}] Quantization completed") |
134 | | -# Confirm generations of the quantized model look sane. |
135 | | -logger.info("\n\n") |
136 | | -logger.info("========== SAMPLE GENERATION ==============") |
137 | | -dispatch_model(model) |
138 | | -sample = tokenizer("Hello my name is", return_tensors="pt") |
139 | | -sample = {key: value.to(model.device) for key, value in sample.items()} |
140 | | -output = model.generate(**sample, max_new_tokens=100) |
141 | | -logger.info(tokenizer.decode(output[0])) |
142 | | -logger.info("==========================================\n\n") |
143 | | - |
144 | | -logger.info("Saving...") |
145 | | -# Save to disk compressed. |
146 | | -SAVE_DIR = ( |
147 | | - model_id.rstrip("/").split("/")[-1] |
148 | | - + f"-{args.scheme}-AutoRound" |
149 | | - + f"-iters{args.iters}-nsamples{args.nsamples}" |
150 | | - + "-DDP" |
151 | | - + str(dist.get_world_size()) |
152 | | -) |
153 | | -model.save_pretrained(SAVE_DIR, save_compressed=True) |
154 | | -tokenizer.save_pretrained(SAVE_DIR) |
155 | | -logger.info(f"Saved to {SAVE_DIR}") |
| 133 | +if rank == 0: |
| 134 | + # Confirm generations of the quantized model look sane. |
| 135 | + logger.info("\n\n") |
| 136 | + logger.info("========== SAMPLE GENERATION ==============") |
| 137 | + dispatch_model(model) |
| 138 | + sample = tokenizer("Hello my name is", return_tensors="pt") |
| 139 | + sample = {key: value.to(model.device) for key, value in sample.items()} |
| 140 | + output = model.generate(**sample, max_new_tokens=100) |
| 141 | + logger.info(tokenizer.decode(output[0])) |
| 142 | + logger.info("==========================================\n\n") |
| 143 | + |
| 144 | + logger.info("Saving...") |
| 145 | + # Save to disk compressed. |
| 146 | + SAVE_DIR = ( |
| 147 | + model_id.rstrip("/").split("/")[-1] |
| 148 | + + f"-{args.scheme}-AutoRound" |
| 149 | + + f"-iters{args.iters}-nsamples{args.nsamples}" |
| 150 | + + "-DDP" |
| 151 | + + str(dist.get_world_size()) |
| 152 | + ) |
| 153 | + model.save_pretrained(SAVE_DIR, save_compressed=True) |
| 154 | + tokenizer.save_pretrained(SAVE_DIR) |
| 155 | + logger.info(f"Saved to {SAVE_DIR}") |
156 | 156 |
157 | 157 | dist.destroy_process_group() |
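The key change in the second hunk gates sample generation and saving behind `if rank == 0:`, so that under a multi-process `torchrun` launch only one process generates text and writes the compressed checkpoint, rather than every rank duplicating that work and writing to the same directory. Below is a minimal sketch of that pattern, independent of AutoRound and llm-compressor; the `dist.barrier()` before teardown is an extra precaution assumed here, not something this diff adds.

```python
import torch
import torch.distributed as dist


def main():
    # torchrun sets RANK / WORLD_SIZE / LOCAL_RANK (and MASTER_ADDR / MASTER_PORT)
    # for each process it spawns, which init_process_group reads via the default
    # env:// initialization.
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
    rank = dist.get_rank()

    # ... collective work that every rank participates in (e.g. calibration, quantization) ...

    if rank == 0:
        # Single-process work: run a sanity-check generation and save artifacts exactly once.
        ...

    # Optional precaution: keep the other ranks alive until rank 0 has finished writing.
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Launched with `torchrun --nproc_per_node=2 ddp_qwen3_example.py`, each process sees a distinct rank, so the guard runs the generation and `save_pretrained` calls on a single process only.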