Skip to content

Commit 0bc916e

Browse files
authored
[AWQ][DDP] adding DDP functionality to AWQ (#2457)
This PR enables AWQ to have DDP functionality. Similar to [GPTQ DDP](#2333), I noticed a situation involving compounding floating-point errors. With GPTQ this issue made the non-DDP evaluation performance better; however, this time it made the DDP evaluation performance worse. After correcting the compounding error, DDP and non-DDP evaluation performance are more closely aligned with one another, and both are also slightly better than or equal to (to 2 decimal points) the performance from before. See results below: ``` Script Model Time (min) GPU (GB) Flex Strict Flex(before) Strict(before) --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- examples/awq/llama_example_ddp.py ./Meta-Llama-3-8B-Instruct-awq-asym-DDP4 2.40 4.99 0.7142 0.7149 0.6983 0.6990 ++ examples/awq/llama_example.py ./Meta-Llama-3-8B-Instruct-awq-asym 7.02 10.20 0.7081 0.7074 0.7058 0.7058 ++ examples/awq/llama_example_with_masking_ddp.py ./Meta-Llama-3-8B-Instruct-awq-asym-masked-DDP4 2.67 4.98 0.7119 0.7119 examples/awq/llama_example_with_masking.py ./Meta-Llama-3-8B-Instruct-awq-asym-masked 8.13 10.14 0.7058 0.7074 examples/awq/qwen3_vl_30b_example_ddp.py ./Qwen3-VL-30B-A3B-Instruct-AWQ-W4A16-g32-DDP4 143.10 3.38 0.8764 0.8529 0.8696 0.8453 ++ examples/awq/qwen3-vl-30b-a3b-Instruct-example.py ./Qwen3-VL-30B-A3B-Instruct-AWQ-W4A16-mse-seq 446.68 3.93 0.8643 0.8491 0.8613 0.8499 +- examples/awq/qwen3_moe_example_ddp.py ./Qwen3-30B-A3B-awq-sym-DDP4 143.90 3.36 0.8802 0.8832 0.8848 0.8802 -+ examples/awq/qwen3_moe_example.py ./Qwen3-30B-A3B-awq-sym 459.65 4.13 0.8825 0.8863 0.8878 0.8840 -+ ``` ## Changes: - Added distributed functionality - Accumulate activation sums instead of means to avoid floating-point errors - Make everything broadcastable by changing to tensors - Added a helper for all-reducing with the sum op Test Plan: see the penultimate commit for the test scripts and evaluation framework
--------- Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
1 parent 2ab0244 commit 0bc916e

File tree

3 files changed

+273
-31
lines changed

3 files changed

+273
-31
lines changed

examples/awq/llama_example_ddp.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
#############################################################################
# This script is adapted from ./llama_example.py and adds DDP functionality.
# run this with `torchrun --nproc_per_node=2 llama_example_ddp.py`
# or change nproc_per_node to your desired configuration
# to adapt other examples to use DDP, see the 2 altered sections below
#############################################################################

import time

import torch
from compressed_tensors.offload import dispatch_model, init_dist, load_offloaded_model
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.datasets.utils import get_rank_partition
from llmcompressor.modifiers.awq import AWQModifier

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# DDP section 1: initialize the process group and load the model with weights
# offloaded so each rank does not hold a full on-device copy during loading.
init_dist()
with load_offloaded_model():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype="auto", device_map="auto_offload"
    )
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512

# Load dataset and preprocess. get_rank_partition gives each DDP rank its own
# slice of the calibration split so ranks calibrate on disjoint samples.
ds = load_dataset(
    DATASET_ID, split=get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
)
ds = ds.shuffle(seed=42)


def preprocess(example):
    """Render each chat transcript into a single text string via the chat template."""
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    """Tokenize a rendered text sample, truncating to MAX_SEQUENCE_LENGTH."""
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


# FIX: `tokenize` was defined but never applied, leaving the dataset as raw
# text. Apply it (dropping the now-redundant text columns), matching the
# pattern used by ./llama_example.py.
ds = ds.map(tokenize, remove_columns=ds.column_names)


# Configure the quantization algorithm to run.
recipe = [
    AWQModifier(
        ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"], duo_scaling="both"
    ),
]

torch.cuda.reset_peak_memory_stats()
start_time = time.time()

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

elapsed_time = time.time() - start_time
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
print("Quantization Complete")
print(f"Time: {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")
print(f"Peak GPU Memory: {peak_memory_gb:.2f} GB")

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
    model.device
)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed. The world size is appended to the directory name so
# runs with different DDP configurations do not overwrite one another.
SAVE_DIR = (
    MODEL_ID.rstrip("/").split("/")[-1]
    + "-awq-asym-DDP"
    + str(torch.distributed.get_world_size())
)
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

# DDP section 2: tear down the process group before exit.
torch.distributed.destroy_process_group()
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
#############################################################################
# This script is adapted from ./qwen3_moe_example.py and adds DDP functionality.
# run this with `torchrun --nproc_per_node=2 qwen3_moe_example_ddp.py`
# or change nproc_per_node to your desired configuration
# to adapt other examples to use DDP, see the 2 altered sections below
#############################################################################

import time

import torch
from compressed_tensors.offload import dispatch_model, init_dist, load_offloaded_model
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.datasets.utils import get_rank_partition
from llmcompressor.modifiers.awq import AWQModifier

# Select model and load it.
MODEL_ID = "Qwen/Qwen3-30B-A3B"

# DDP section 1: initialize the process group and load the model with weights
# offloaded so each rank does not hold a full on-device copy during loading.
init_dist()
with load_offloaded_model():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype="auto", device_map="auto_offload"
    )
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512

# Load dataset and preprocess. get_rank_partition gives each DDP rank its own
# slice of the calibration split so ranks calibrate on disjoint samples.
ds = load_dataset(
    DATASET_ID, split=get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
)
ds = ds.shuffle(seed=42)


def preprocess(example):
    """Render each chat transcript into a single text string via the chat template."""
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    """Tokenize a rendered text sample, truncating to MAX_SEQUENCE_LENGTH."""
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


# FIX: `tokenize` was defined but never applied, leaving the dataset as raw
# text. Apply it (dropping the now-redundant text columns), matching the
# pattern used by ./qwen3_moe_example.py.
ds = ds.map(tokenize, remove_columns=ds.column_names)


# Configure the quantization algorithm to run.
# NOTE: vllm currently does not support asym MoE, using symmetric here
recipe = [
    AWQModifier(
        ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
        scheme="W4A16",
        targets=["Linear"],
    ),
]

torch.cuda.reset_peak_memory_stats()
start_time = time.time()

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

elapsed_time = time.time() - start_time
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
print("Quantization Complete")
print(f"Time: {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")
print(f"Peak GPU Memory: {peak_memory_gb:.2f} GB")

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
    model.device
)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed. The world size is appended to the directory name so
# runs with different DDP configurations do not overwrite one another.
SAVE_DIR = (
    MODEL_ID.rstrip("/").split("/")[-1]
    + "-awq-sym-DDP"
    + str(torch.distributed.get_world_size())
)
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

# DDP section 2: tear down the process group before exit.
torch.distributed.destroy_process_group()

0 commit comments

Comments
 (0)