Skip to content

Commit 0f3e1f9

Browse files
committed
[Examples] Add distributed W8A8 quantization example
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
1 parent 76cf40f commit 0f3e1f9

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
import os

import torch
import torch.distributed as dist

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048

# Initialize distributed.
# Usage: torchrun --nproc_per_node=2 llama3_8b_w8a8_distributed.py
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()

# Bind this process to a GPU by its *local* rank, not its global rank:
# under multi-node torchrun the global rank can exceed the number of
# GPUs on this node. torchrun exports LOCAL_RANK; fall back to the
# global rank for single-node launches where the two coincide.
local_rank = int(os.environ.get("LOCAL_RANK", rank))
torch.cuda.set_device(local_rank)

if rank == 0:
    print(f"Running distributed quantization with {world_size} GPUs")
30+
31+
# Load model to CPU for sequential onloading.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map=None,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load and partition dataset across ranks.
# Each rank loads a disjoint slice of the calibration data. When
# world_size does not divide NUM_CALIBRATION_SAMPLES evenly, hand the
# remainder out one sample each to the lowest ranks instead of silently
# dropping those samples (plain floor division would lose them).
base, remainder = divmod(NUM_CALIBRATION_SAMPLES, world_size)
samples_per_rank = base + (1 if rank < remainder else 0)
start = rank * base + min(rank, remainder)
end = start + samples_per_rank

ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[{start}:{end}]")
ds = ds.shuffle(seed=42)
47+
48+
49+
def preprocess(example):
    """Render one chat transcript into a single plain-text prompt.

    Applies the tokenizer's chat template to the example's message list
    (without tokenizing) and stores the rendered string under "text".
    """
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}


ds = ds.map(preprocess)
59+
60+
61+
# Tokenize inputs.
def tokenize(sample):
    """Tokenize a rendered prompt, truncating to MAX_SEQUENCE_LENGTH.

    Special tokens are not re-added here: the chat template has already
    placed them in the text during preprocessing.
    """
    encode_kwargs = dict(
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )
    return tokenizer(sample["text"], **encode_kwargs)


ds = ds.map(tokenize, remove_columns=ds.column_names)
73+
74+
# Configure the quantization algorithm to run.
# QuantizationModifier automatically detects torch.distributed and:
#   * partitions weight calibration across ranks
#   * all-reduces activation observer statistics at layer boundaries
recipe = [
    QuantizationModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=samples_per_rank,
)

# Save to disk compressed (rank 0 only).
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8-distributed"
if rank == 0:
    model.save_pretrained(SAVE_DIR, save_compressed=True)
    tokenizer.save_pretrained(SAVE_DIR)
    print(f"Model saved to {SAVE_DIR}")

# Hold every rank until rank 0 has finished writing the checkpoint;
# without this, non-zero ranks can destroy the process group (and exit)
# while the save above is still in flight.
dist.barrier()
dist.destroy_process_group()

0 commit comments

Comments
 (0)