Commit ad67532

WIP: incorrect offloading

Signed-off-by: Kyle Sayers <[email protected]>
1 parent: 337e067

File tree: 3 files changed (+110, -17 lines changed)
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor.transformers import oneshot
+from llmcompressor.utils import dispatch_for_generation
+
+# Select model and load it.
+model_id = "unsloth/gpt-oss-20b-BF16"
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# Select calibration dataset.
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+
+# Select number of samples. 512 samples is a good place to start.
+# Increasing the number of samples can improve accuracy.
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+ds = ds.shuffle(seed=42)
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            example["messages"],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+# Configure the quantization algorithm to run.
+# * quantize the weights to 4 bit with GPTQ with a group size 128
+recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
+
+# Apply algorithms.
+oneshot(
+    model=model,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    calibrate_moe_context=True,
+)
+
+# Confirm generations of the quantized model look sane.
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+sample = tokenizer("Hello my name is", return_tensors="pt")
+sample = {key: value.to("cuda") for key, value in sample.items()}
+output = model.generate(**sample, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+# Save to disk compressed.
+SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
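
Note (not part of this commit): a checkpoint saved with save_compressed=True is typically loaded back for inference with vLLM. A minimal sketch, assuming a vLLM install with compressed-tensors support and using the SAVE_DIR produced by the script above:

# Hypothetical follow-up, not included in the diff above.
from vllm import LLM, SamplingParams

llm = LLM(model="gpt-oss-20b-BF16-W4A16-G128")  # SAVE_DIR from the example script
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=100))
print(outputs[0].outputs[0].text)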

src/llmcompressor/modeling/gpt_oss.py

Lines changed: 26 additions & 14 deletions
@@ -7,10 +7,14 @@
 from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
 from llmcompressor.utils.dev import skip_weights_initialize

-from compressed_tensors import update_offload_parameter
+from compressed_tensors.utils import update_offload_parameter, align_module_device


 class GptOssExpert(torch.nn.Module):
+    gate_proj: torch.nn.Linear
+    up_proj: torch.nn.Linear
+    down_proj: torch.nn.Linear
+
     def __init__(self, hidden_size: int, expert_dim: int, alpha: float, limit: float):
         super().__init__()

@@ -57,17 +61,21 @@ def __init__(self, experts: GptOssExpert):
         self.limit = experts.limit

     def load_weights(self, experts: GptOssExperts):
-        for expert_index, expert in enumerate(self.experts):
-            update_offload_parameter(expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T)
-            update_offload_parameter(expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2])
+        # TODO: this code is inefficient. If there was a "get_offloaded_data" util,
+        # we could avoid having to move from cpu -> gpu -> cpu
+        with align_module_device(experts):
+            for expert_index, expert in enumerate(self.experts):
+                update_offload_parameter(expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T)
+                update_offload_parameter(expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2])

-            update_offload_parameter(expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T)
-            update_offload_parameter(expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2])
+                update_offload_parameter(expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T)
+                update_offload_parameter(expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2])

-            update_offload_parameter(expert.down_proj, "weight", experts.down_proj[expert_index].T)
-            update_offload_parameter(expert.down_proj, "bias", experts.down_proj_bias[expert_index])
+                update_offload_parameter(expert.down_proj, "weight", experts.down_proj[expert_index].T)
+                update_offload_parameter(expert.down_proj, "bias", experts.down_proj_bias[expert_index])

     def to_original(self) -> GptOssExperts:
+        # TODO: this doesn't really handle offloading or correct device placement
         with skip_weights_initialize():
             fake_config = GptOssConfig(
                 intermediate_size=self.intermediate_size,
@@ -78,14 +86,17 @@ def to_original(self) -> GptOssExperts:
         experts = GptOssExperts(fake_config)

         for expert_index, expert in enumerate(self.experts):
-            experts.gate_up_proj[expert_index, ..., ::2].data = expert.gate_proj.weight.data.T
-            experts.gate_up_proj_bias[expert_index, ..., ::2].data = expert.gate_proj.bias.data
+            # TODO: this code is inefficient. If there was a "get_offloaded_data" util,
+            # we could avoid having to move from cpu -> gpu -> cpu
+            with align_module_device(expert):
+                experts.gate_up_proj[expert_index, ..., ::2].data = expert.gate_proj.weight.data.T
+                experts.gate_up_proj_bias[expert_index, ..., ::2].data = expert.gate_proj.bias.data

-            experts.gate_up_proj[expert_index, ..., 1::2].data = expert.up_proj.weight.data.T
-            experts.gate_up_proj_bias[expert_index, ..., 1::2].data = expert.up_proj.bias.data
+                experts.gate_up_proj[expert_index, ..., 1::2].data = expert.up_proj.weight.data.T
+                experts.gate_up_proj_bias[expert_index, ..., 1::2].data = expert.up_proj.bias.data

-            experts.down_proj[expert_index].data = expert.down_proj.weight.data.T
-            experts.down_proj_bias[expert_index] = expert.down_proj.bias.data
+                experts.down_proj[expert_index].data = expert.down_proj.weight.data.T
+                experts.down_proj_bias[expert_index] = expert.down_proj.bias.data

         # update offloaded state dict
         update_offload_parameter(experts, "gate_up_proj", experts.gate_up_proj)
@@ -134,6 +145,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None):
         setattr(original, name, getattr(original, name).normal_())

     original.eval()
+    assert original.training == False
     true_output = original(input, routing_weights=routing_weights)

     linear = GptOssExpertsLinear(original)
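
For reference (not part of the commit): the ::2 / 1::2 slicing above assumes the fused GptOssExperts gate_up_proj tensor interleaves gate and up columns along its last dimension, with each nn.Linear weight being the transpose of the corresponding slice. A minimal, self-contained sketch of that layout:

# Toy illustration of the interleaving assumed by load_weights/to_original above.
import torch

hidden_size, expert_dim = 4, 3
gate = torch.randn(hidden_size, expert_dim)
up = torch.randn(hidden_size, expert_dim)

# Fused layout: gate columns at even indices, up columns at odd indices.
fused = torch.empty(hidden_size, 2 * expert_dim)
fused[..., ::2] = gate
fused[..., 1::2] = up

# Recover per-expert Linear weights as load_weights does (transpose because the
# fused slice is (in_features, out_features) while nn.Linear stores the reverse).
assert torch.equal(fused[..., ::2].T, gate.T)
assert torch.equal(fused[..., 1::2].T, up.T)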

src/llmcompressor/modeling/prepare.py

Lines changed: 5 additions & 3 deletions
@@ -1,5 +1,6 @@
+import tqdm
 import contextlib
-from compressed_tensors.utils import replace_module
+from compressed_tensors.utils import replace_module, match_named_modules
 from transformers import PreTrainedModel

 from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
@@ -57,9 +58,10 @@ def replace_context(model, name, module):
         replace_module(model, name, restored)

     # TODO: need to think about duplicates
-    for name, module in model.named_modules():
+    modules = list(model.named_modules())
+    for name, module in tqdm.tqdm(modules, desc="Checking modules for replacements"):
         cls_name = module.__class__.__name__
-        if cls_name == "GptOssExpert":
+        if cls_name == "GptOssExperts":
             stack.enter_context(replace_context(model, name, module))


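For context (not in the diff): stack.enter_context(replace_context(...)) follows the usual contextlib.ExitStack pattern of swapping a module in for the duration of calibration and restoring it on exit, as the restore path via replace_module(model, name, restored) above suggests. A self-contained toy sketch of that pattern, using hypothetical stand-in modules rather than the real GptOssExperts/GptOssExpertsLinear classes:

import contextlib
import torch

def replace_module(model, name, module):
    # Minimal stand-in for a "set submodule by dotted name" helper.
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, module)

@contextlib.contextmanager
def replace_context(model, name, module):
    # Swap in a calibration-friendly replacement, restore the original on exit.
    replace_module(model, name, torch.nn.Identity())
    try:
        yield
    finally:
        replace_module(model, name, module)

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
with contextlib.ExitStack() as stack:
    for name, module in list(model.named_modules()):
        if isinstance(module, torch.nn.ReLU):
            stack.enter_context(replace_context(model, name, module))
    print(model)  # ReLU temporarily replaced while calibration would run here
print(model)  # original module restored after the stack exits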