-
Notifications
You must be signed in to change notification settings - Fork 453
Ddp v3 #2389
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Ddp v3 #2389
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
0b27aee
poc ddp
yiliu30 bba50ec
format
yiliu30 39d3fe1
add dpp example
yiliu30 80251aa
update example
yiliu30 191e340
fix disptach device
yiliu30 d740be6
fix CR
yiliu30 44867d2
fix
yiliu30 1f10bca
add doc
yiliu30 0471a68
fix
yiliu30 2dbe8c9
update
yiliu30 41c6a8b
fix
yiliu30 4abfaa3
clean code
yiliu30 01254cb
refine code
yiliu30 cdc8ab0
clean code
yiliu30 f0ae0ae
clean code
yiliu30 f4d2ed7
fix
yiliu30 0da4e97
clean
yiliu30 38bc7f0
update
yiliu30 c333ba0
update
yiliu30 7c23d0f
update
yiliu30 90ec240
fix
yiliu30 80669f5
fix
yiliu30 203cda7
clean
yiliu30 f8b5cf7
clean
yiliu30 7516998
add autoround
yiliu30 fa82d04
remove unneeded and test updates
HDCharles File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,152 @@ | ||
| ############################################################################# | ||
| # This script is adapted to use DDP functionality with AutoRound. | ||
| # run this with `torchrun --nproc_per_node=2 ddp_qwen3_example.py` | ||
| # or change nproc_per_node to your desired configuration | ||
| # | ||
| # Example usage: | ||
| # torchrun --nproc_per_node=2 ddp_qwen3_example.py \ | ||
| # --model Qwen/Qwen3-8B \ | ||
| # --nsamples 128 \ | ||
| # --iters 200 \ | ||
| # --disable_torch_compile \ | ||
| # --deterministic | ||
| ############################################################################# | ||
|
|
||
| import argparse | ||
| import os | ||
| import time | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from compressed_tensors.offload import dispatch_model, init_dist, load_offloaded_model | ||
| from datasets import load_dataset | ||
| from loguru import logger | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
| import torch.distributed as dist | ||
| from llmcompressor import oneshot | ||
| from llmcompressor.datasets.utils import get_rank_partition | ||
| from llmcompressor.modifiers.autoround import AutoRoundModifier | ||
|
|
||
|
|
||
| def fix_everything(seed=42): | ||
| import random | ||
| import numpy as np | ||
|
|
||
| random.seed(seed) | ||
| np.random.seed(seed) | ||
| torch.manual_seed(seed) | ||
| torch.cuda.manual_seed_all(seed) | ||
|
|
||
|
|
||
| def config_deterministic(): | ||
| torch.use_deterministic_algorithms(True, warn_only=False) | ||
| os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" | ||
| fix_everything() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser( | ||
| description="AutoRound Quantization with DDP support" | ||
| ) | ||
| parser.add_argument( | ||
| "--model", | ||
| type=str, | ||
| default="Qwen/Qwen3-8B", | ||
| help="Model name or path", | ||
| ) | ||
| parser.add_argument( | ||
| "--scheme", | ||
| type=str, | ||
| default="W4A16", | ||
| help="Quantization scheme (W4A16, MXFP8, MXFP4, etc.)", | ||
| ) | ||
| parser.add_argument("--iters", type=int, default=200, help="Number of iterations") | ||
| parser.add_argument("--nsamples", type=int, default=128, help="Number of samples") | ||
| parser.add_argument( | ||
| "--disable_torch_compile", | ||
| action="store_true", | ||
| help="Disable torch.compile for model acceleration during quantization", | ||
| ) | ||
| parser.add_argument( | ||
| "--deterministic", | ||
| action="store_true", | ||
| help="Enable deterministic mode for reproducibility", | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| if args.deterministic: | ||
| config_deterministic() | ||
|
|
||
| model_id = args.model | ||
|
|
||
| ###### DDP MODEL LOAD CHANGE ##### | ||
| init_dist() | ||
| with load_offloaded_model(): | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| model_id, dtype="auto", device_map="auto_offload" | ||
| ) | ||
| ################################## | ||
|
|
||
| tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| # Select calibration dataset. | ||
| NUM_CALIBRATION_SAMPLES = args.nsamples | ||
| MAX_SEQUENCE_LENGTH = 2048 | ||
| ITERS = args.iters | ||
| # Get aligned calibration dataset. | ||
|
|
||
| ds = get_dataset( | ||
| tokenizer=tokenizer, | ||
| seqlen=MAX_SEQUENCE_LENGTH, | ||
| nsamples=NUM_CALIBRATION_SAMPLES, | ||
| ) | ||
|
|
||
| # Configure the quantization algorithm to run. | ||
| # * quantize the weights to 4 bit with AutoRound with a group size 128 | ||
| recipe = AutoRoundModifier( | ||
| targets="Linear", | ||
| scheme=args.scheme, | ||
| ignore=[ | ||
| "lm_head", | ||
| "re:.*mlp.gate$", | ||
| ], | ||
| iters=ITERS, | ||
| enable_torch_compile=not args.disable_torch_compile, | ||
| ) | ||
|
|
||
| # Apply algorithms. | ||
| oneshot( | ||
| model=model, | ||
| dataset=ds, | ||
| recipe=recipe, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
| shuffle_calibration_samples=False, | ||
| ) | ||
|
|
||
| rank = dist.get_rank() | ||
| logger.info(f"[Rank {rank}] Quantization completed") | ||
| # Confirm generations of the quantized model look sane. | ||
| logger.info("\n\n") | ||
| logger.info("========== SAMPLE GENERATION ==============") | ||
| dispatch_model(model) | ||
| sample = tokenizer("Hello my name is", return_tensors="pt") | ||
| sample = {key: value.to(model.device) for key, value in sample.items()} | ||
| output = model.generate(**sample, max_new_tokens=100) | ||
| logger.info(tokenizer.decode(output[0])) | ||
| logger.info("==========================================\n\n") | ||
|
|
||
| logger.info("Saving...") | ||
| # Save to disk compressed. | ||
| SAVE_DIR = ( | ||
| model_id.rstrip("/").split("/")[-1] | ||
| + f"-{args.scheme}-AutoRound" | ||
| + f"-iters{args.iters}-nsamples{args.nsamples}" | ||
| + "-DDP" | ||
| + str(dist.get_world_size()) | ||
| ) | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| tokenizer.save_pretrained(SAVE_DIR) | ||
| logger.info(f"Saved to {SAVE_DIR}") | ||
|
|
||
| dist.destroy_process_group() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
torch.distributed as distimport is duplicated. It's already imported on line 20.