-
Notifications
You must be signed in to change notification settings - Fork 453
[Distributed] Extend QuantizationModifier to support distributed activation calibration #2391
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Etelis
wants to merge
22
commits into
vllm-project:main
Choose a base branch
from
Etelis:feature/quantization-modifier-ddp
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 7 commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
f60200a
[Distributed] Add distributed utilities for DDP calibration
EtelisIBM c4d630d
[Distributed] Add recompute_qparams_from_observer helper
EtelisIBM 89d1ade
[Distributed] Partition weight calibration across DDP ranks
EtelisIBM ac0cc2a
[Tests] Add unit tests for distributed utilities
EtelisIBM 76cf40f
[Tests] Add multi-GPU integration tests for DDP quantization
EtelisIBM 0f3e1f9
[Examples] Add distributed W8A8 quantization example
EtelisIBM 9975edc
[Distributed] Fix broadcast_module_parameter for CPU-resident models
EtelisIBM 3320812
[Distributed] Refactor DDP activation sync per review feedback
EtelisIBM 87f4b0d
Merge branch 'main' into feature/quantization-modifier-ddp
Etelis 766a70c
Merge remote-tracking branch 'upstream/main' into feature/quantizatio…
EtelisIBM d44c4ab
[Distributed] Address review feedback for DDP activation observer sync
EtelisIBM 5fa31b2
Merge branch 'main' into feature/quantization-modifier-ddp
HDCharles 0e3a843
[Distributed] Use as_broadcastable and simplify moving-average sync
EtelisIBM 82d808c
Merge branch 'feature/quantization-modifier-ddp' of https://github.co…
EtelisIBM d6b3575
Update src/llmcompressor/observers/moving_base.py
Etelis 9680d9e
Merge branch 'main' into feature/quantization-modifier-ddp
Etelis 688b309
Merge branch 'main' into feature/quantization-modifier-ddp
kylesayrs 7f31744
Merge branch 'main' into feature/quantization-modifier-ddp
Etelis 4f80617
Merge branch 'main' into feature/quantization-modifier-ddp
Etelis f959d4b
fix formatting and moving-average test mock path
EtelisIBM 7baf545
Merge branch 'feature/quantization-modifier-ddp' of https://github.co…
EtelisIBM 09e817b
Merge branch 'main' into feature/quantization-modifier-ddp
HDCharles File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
98 changes: 98 additions & 0 deletions
98
examples/big_models_with_sequential_onloading/llama3_8b_w8a8_distributed.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| import torch | ||
| import torch.distributed as dist | ||
| from datasets import load_dataset | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
|
|
||
| # Select model and load it. | ||
| MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" | ||
|
|
||
| # Select calibration dataset. | ||
| DATASET_ID = "HuggingFaceH4/ultrachat_200k" | ||
| DATASET_SPLIT = "train_sft" | ||
|
|
||
| # Select number of samples. | ||
| # Increasing the number of samples can improve accuracy. | ||
| NUM_CALIBRATION_SAMPLES = 256 | ||
| MAX_SEQUENCE_LENGTH = 2048 | ||
|
|
||
| # Initialize distributed. | ||
| # Usage: torchrun --nproc_per_node=2 llama3_8b_w8a8_distributed.py | ||
| dist.init_process_group(backend="nccl") | ||
| rank = dist.get_rank() | ||
| world_size = dist.get_world_size() | ||
| torch.cuda.set_device(rank) | ||
kylesayrs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| if rank == 0: | ||
| print(f"Running distributed quantization with {world_size} GPUs") | ||
|
|
||
| # Load model to CPU for sequential onloading. | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| MODEL_ID, | ||
| dtype="auto", | ||
| device_map=None, | ||
| ) | ||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | ||
|
|
||
| # Load and partition dataset across ranks. | ||
| # Each rank loads a disjoint slice of the calibration data. | ||
| samples_per_rank = NUM_CALIBRATION_SAMPLES // world_size | ||
| start = samples_per_rank * rank | ||
| end = start + samples_per_rank | ||
|
|
||
| ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[{start}:{end}]") | ||
| ds = ds.shuffle(seed=42) | ||
|
|
||
|
|
||
| def preprocess(example): | ||
| return { | ||
| "text": tokenizer.apply_chat_template( | ||
| example["messages"], | ||
| tokenize=False, | ||
| ) | ||
| } | ||
|
|
||
|
|
||
| ds = ds.map(preprocess) | ||
|
|
||
|
|
||
| # Tokenize inputs. | ||
| def tokenize(sample): | ||
| return tokenizer( | ||
| sample["text"], | ||
| padding=False, | ||
| max_length=MAX_SEQUENCE_LENGTH, | ||
| truncation=True, | ||
| add_special_tokens=False, | ||
| ) | ||
|
|
||
|
|
||
| ds = ds.map(tokenize, remove_columns=ds.column_names) | ||
|
|
||
| # Configure the quantization algorithm to run. | ||
| # QuantizationModifier automatically detects torch.distributed and: | ||
| # * partitions weight calibration across ranks | ||
| # * all-reduces activation observer statistics at layer boundaries | ||
| recipe = [ | ||
| QuantizationModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]), | ||
| ] | ||
|
|
||
| # Apply algorithms. | ||
| oneshot( | ||
| model=model, | ||
| dataset=ds, | ||
| recipe=recipe, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=samples_per_rank, | ||
| ) | ||
|
|
||
| # Save to disk compressed (rank 0 only). | ||
| SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8-distributed" | ||
| if rank == 0: | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| tokenizer.save_pretrained(SAVE_DIR) | ||
| print(f"Model saved to {SAVE_DIR}") | ||
|
|
||
| dist.destroy_process_group() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,3 +8,4 @@ | |
| from .dev import * | ||
| from .helpers import * | ||
| from .dist import * | ||
| from .distributed import * | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.