-
Notifications
You must be signed in to change notification settings - Fork 453
[AutoRound] Add DDP Support and Example #2411
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,157 @@ | ||
| """ | ||
| This script is adapted to use DDP functionality with AutoRound. | ||
| run this with `torchrun --nproc_per_node=2 ddp_qwen3_example.py` | ||
| or change nproc_per_node to your desired configuration | ||
|
|
||
| Example usage: | ||
| torchrun --nproc_per_node=2 ddp_qwen3_example.py \ | ||
| --model Qwen/Qwen3-8B \ | ||
| --nsamples 128 \ | ||
| --iters 100 \ | ||
| --disable_torch_compile \ | ||
| --deterministic | ||
| """ | ||
|
|
||
| import argparse | ||
| import os | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from compressed_tensors.offload import dispatch_model, init_dist, load_offloaded_model | ||
| from loguru import logger | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
|
|
||
|
|
||
| def fix_everything(seed=42): | ||
| import random | ||
|
|
||
| import numpy as np | ||
|
|
||
| random.seed(seed) | ||
| np.random.seed(seed) | ||
| torch.manual_seed(seed) | ||
| torch.cuda.manual_seed_all(seed) | ||
|
|
||
|
|
||
| def config_deterministic(): | ||
| torch.use_deterministic_algorithms(True, warn_only=False) | ||
| os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" | ||
| fix_everything() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser( | ||
| description="AutoRound Quantization with DDP support" | ||
| ) | ||
| parser.add_argument( | ||
| "--model", | ||
| type=str, | ||
| default="Qwen/Qwen3-8B", | ||
| help="Model name or path", | ||
| ) | ||
| parser.add_argument( | ||
| "--scheme", | ||
| type=str, | ||
| default="W4A16", | ||
| help="Quantization scheme (W4A16, MXFP8, MXFP4, etc.)", | ||
| ) | ||
| parser.add_argument("--iters", type=int, default=200, help="Number of iterations") | ||
| parser.add_argument("--nsamples", type=int, default=128, help="Number of samples") | ||
| parser.add_argument( | ||
| "--disable_torch_compile", | ||
| action="store_true", | ||
| help="Disable torch.compile for model acceleration during quantization", | ||
| ) | ||
| parser.add_argument( | ||
| "--deterministic", | ||
| action="store_true", | ||
| help="Enable deterministic mode for reproducibility", | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| if args.deterministic: | ||
| config_deterministic() | ||
|
|
||
| model_id = args.model | ||
|
|
||
| ###### DDP MODEL LOAD CHANGE ##### | ||
| init_dist() | ||
| with load_offloaded_model(): | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| model_id, dtype="auto", device_map="auto_offload" | ||
| ) | ||
| ################################## | ||
|
|
||
| tokenizer = AutoTokenizer.from_pretrained(model_id) | ||
|
|
||
| # Select calibration dataset. | ||
| NUM_CALIBRATION_SAMPLES = args.nsamples | ||
| MAX_SEQUENCE_LENGTH = 2048 | ||
| ITERS = args.iters | ||
|
|
||
|
|
||
| # Get aligned calibration dataset. | ||
| from auto_round.calib_dataset import get_dataset # noqa: E402 | ||
|
|
||
| # Note: Make sure model are loaded before importing auto-round related code. | ||
| # This requirement will be lifted once switching to new release of auto-round which | ||
| # includes below fix: | ||
| from llmcompressor.modifiers.autoround import AutoRoundModifier # noqa: E402 | ||
|
|
||
| ds = get_dataset( | ||
HDCharles marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| tokenizer=tokenizer, | ||
| seqlen=MAX_SEQUENCE_LENGTH, | ||
| nsamples=NUM_CALIBRATION_SAMPLES, | ||
| ) | ||
|
|
||
| # Configure the quantization algorithm to run. | ||
| # * quantize the weights to 4 bit with AutoRound with a group size 128 | ||
| recipe = AutoRoundModifier( | ||
| targets="Linear", | ||
| scheme=args.scheme, | ||
| ignore=[ | ||
| "lm_head", | ||
| "re:.*mlp.gate$", | ||
| ], | ||
| iters=ITERS, | ||
| enable_torch_compile=not args.disable_torch_compile, | ||
| ) | ||
|
|
||
| # Apply algorithms. | ||
| oneshot( | ||
| model=model, | ||
| dataset=ds, | ||
| recipe=recipe, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
| shuffle_calibration_samples=False, | ||
| ) | ||
|
|
||
| rank = dist.get_rank() | ||
| logger.info(f"[Rank {rank}] Quantization completed") | ||
| # Confirm generations of the quantized model look sane. | ||
| logger.info("\n\n") | ||
| logger.info("========== SAMPLE GENERATION ==============") | ||
| dispatch_model(model) | ||
| sample = tokenizer("Hello my name is", return_tensors="pt") | ||
| sample = {key: value.to(model.device) for key, value in sample.items()} | ||
| output = model.generate(**sample, max_new_tokens=100) | ||
| logger.info(tokenizer.decode(output[0])) | ||
| logger.info("==========================================\n\n") | ||
|
|
||
| logger.info("Saving...") | ||
| # Save to disk compressed. | ||
| SAVE_DIR = ( | ||
| model_id.rstrip("/").split("/")[-1] | ||
| + f"-{args.scheme}-AutoRound" | ||
| + f"-iters{args.iters}-nsamples{args.nsamples}" | ||
| + "-DDP" | ||
| + str(dist.get_world_size()) | ||
| ) | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| tokenizer.save_pretrained(SAVE_DIR) | ||
| logger.info(f"Saved to {SAVE_DIR}") | ||
|
|
||
| dist.destroy_process_group() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.