
Commit affa73f

working, jank
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 8ffa1f9 commit affa73f

File tree: 4 files changed, +408 -31 lines changed
Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor import oneshot
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
model_id = "unsloth/gpt-oss-20b-BF16"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
# replace_modules_for_calibration(model)  # linearize experts so they can be targeted
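# Note: the eager replacement above is left commented out; this script instead relies
# on `calibrate_moe_context=True` in the `oneshot(...)` call below, which updates the
# MoE modules only for the duration of calibration.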

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
#   * quantize the weights to 4 bit with GPTQ, using a group size of 128
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    calibrate_moe_context=True,
    pipeline="sequential",
)
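# `calibrate_moe_context=True` updates the MoE expert modules only while calibration
# runs, and `pipeline="sequential"` calibrates and compresses the model one layer at a
# time to bound memory use; both readings are based on the flag names and the rest of
# this commit rather than on library documentation.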

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to("cuda") for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
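# The checkpoint written to SAVE_DIR uses the compressed-tensors format and can be
# loaded by inference engines that support it (for example vLLM); that step is outside
# the scope of this script.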

src/llmcompressor/modeling/gpt_oss.py

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
from typing import List

import torch
from transformers import GptOssForCausalLM
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig

from compressed_tensors.utils import update_offload_parameter, align_module_device

from llmcompressor.utils.dev import skip_weights_initialize

class GptOssExpert(torch.nn.Module):
    """A single GPT-OSS expert whose fused gate/up projection is split into separate
    `gate_proj`, `up_proj`, and `down_proj` Linear modules."""

    gate_proj: torch.nn.Linear
    up_proj: torch.nn.Linear
    down_proj: torch.nn.Linear

    def __init__(self, experts: GptOssExperts):
        super().__init__()

        self.hidden_size = experts.hidden_size
        self.expert_dim = experts.expert_dim
        self.alpha = experts.alpha
        self.limit = experts.limit

        assert experts.gate_up_proj.dtype == experts.gate_up_proj_bias.dtype
        assert experts.down_proj.dtype == experts.down_proj_bias.dtype

        with skip_weights_initialize():
            self.gate_proj = torch.nn.Linear(
                self.hidden_size, self.expert_dim, bias=True, dtype=experts.gate_up_proj.dtype
            )
            self.up_proj = torch.nn.Linear(
                self.hidden_size, self.expert_dim, bias=True, dtype=experts.gate_up_proj.dtype
            )
            self.down_proj = torch.nn.Linear(
                self.expert_dim, self.hidden_size, bias=True, dtype=experts.down_proj.dtype
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate = self.gate_proj(hidden_states)
        gate = gate.clamp(min=None, max=self.limit)

        up = self.up_proj(hidden_states)
        up = up.clamp(min=-self.limit, max=self.limit)

        glu = gate * torch.sigmoid(gate * self.alpha)
        return self.down_proj((up + 1) * glu)
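# Functional form of the expert above: y = down_proj((up + 1) * gate * sigmoid(alpha * gate)),
# with `gate` clamped from above at `limit` and `up` clamped to [-limit, limit]. The
# correctness test at the bottom of this file checks this against the fused GptOssExperts.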

class GptOssExpertsLinear(torch.nn.Module):
    """Replaces the fused GptOssExperts module with a list of per-expert GptOssExpert
    submodules whose Linear layers can be targeted individually."""

    experts: List[GptOssExpert]

    def __init__(self, experts: GptOssExperts):
        super().__init__()

        self.intermediate_size = experts.intermediate_size
        self.num_experts = experts.num_experts
        self.hidden_size = experts.hidden_size
        self.expert_dim = experts.expert_dim

        with skip_weights_initialize():
            self.experts = torch.nn.ModuleList(
                [GptOssExpert(experts) for _ in range(self.num_experts)]
            )

        self.load_weights(experts)

        self.alpha = experts.alpha
        self.limit = experts.limit

    def load_weights(self, experts: GptOssExperts):
        with align_module_device(experts):
            for expert_index, expert in enumerate(self.experts):
                update_offload_parameter(
                    expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T
                )
                update_offload_parameter(
                    expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2]
                )

                update_offload_parameter(
                    expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T
                )
                update_offload_parameter(
                    expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2]
                )

                update_offload_parameter(
                    expert.down_proj, "weight", experts.down_proj[expert_index].T
                )
                update_offload_parameter(
                    expert.down_proj, "bias", experts.down_proj_bias[expert_index]
                )
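    # Layout note, inferred from the slicing above: `gate_up_proj` interleaves the gate
    # and up projections along the last dimension (even columns -> gate, odd columns -> up),
    # and the `.T` converts the fused (in_features, out_features) layout into the
    # (out_features, in_features) layout expected by nn.Linear.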
    def to_original(self) -> GptOssExperts:
        # TODO: this doesn't really handle offloading or correct device placement
        # TODO: the conversion currently slows down over successive calls
        with skip_weights_initialize(use_zeros=True):
            fake_config = GptOssConfig(
                intermediate_size=self.intermediate_size,
                num_local_experts=self.num_experts,
                hidden_size=self.hidden_size,
            )
            experts = GptOssExperts(fake_config)

        experts.gate_up_proj = torch.nn.Parameter(
            experts.gate_up_proj.to(dtype=self.experts[0].gate_proj.weight.dtype),
            requires_grad=False,
        )
        experts.gate_up_proj_bias = torch.nn.Parameter(
            experts.gate_up_proj_bias.to(dtype=self.experts[0].gate_proj.weight.dtype),
            requires_grad=False,
        )
        experts.down_proj = torch.nn.Parameter(
            experts.down_proj.to(dtype=self.experts[0].down_proj.weight.dtype),
            requires_grad=False,
        )
        experts.down_proj_bias = torch.nn.Parameter(
            experts.down_proj_bias.to(dtype=self.experts[0].down_proj.weight.dtype),
            requires_grad=False,
        )

        for expert_index, expert in enumerate(self.experts):
            with align_module_device(expert.gate_proj, "cpu"), align_module_device(
                expert.up_proj, "cpu"
            ), align_module_device(expert.down_proj, "cpu"):
                experts.gate_up_proj[expert_index, ..., ::2].copy_(expert.gate_proj.weight.data.T)
                experts.gate_up_proj_bias[expert_index, ..., ::2].copy_(expert.gate_proj.bias.data)

                experts.gate_up_proj[expert_index, ..., 1::2].copy_(expert.up_proj.weight.data.T)
                experts.gate_up_proj_bias[expert_index, ..., 1::2].copy_(expert.up_proj.bias.data)

                experts.down_proj[expert_index].copy_(expert.down_proj.weight.data.T)
                experts.down_proj_bias[expert_index].copy_(expert.down_proj.bias.data)

        experts.eval()
        return experts
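    # `to_original()` is the inverse of `load_weights()`: it re-fuses each expert's
    # gate/up/down weights and biases back into the stacked gate_up_proj /
    # gate_up_proj_bias / down_proj / down_proj_bias parameters of a freshly
    # constructed GptOssExperts, e.g. so the model can be saved in its original
    # architecture after calibration.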
    def forward(
        self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None
    ) -> torch.Tensor:
        """
        When training, it is more efficient to loop over the experts and compute the
        output for each expert individually; otherwise memory would explode.

        For inference, we can sacrifice some memory and compute the output for all
        experts at once by repeating the inputs.

        Args:
            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
            router_indices (torch.Tensor): (batch_size * token_num, top_k)
            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
        Returns:
            torch.Tensor
        """
        original_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)

        # Dense mixture: every expert processes every token, weighted by routing_weights
        next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
        for expert_index, expert in enumerate(self.experts):
            next_states += expert(hidden_states) * routing_weights.T[expert_index].unsqueeze(-1)

        next_states = next_states.reshape(original_shape)
        return next_states

def replace_gpt_oss(config: GptOssConfig, module: GptOssExperts) -> GptOssExpertsLinear:
    return GptOssExpertsLinear(module)
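# A minimal sketch (not part of this commit) of how `replace_gpt_oss` might be applied
# to a loaded GptOssForCausalLM by swapping every fused GptOssExperts module for its
# linearized equivalent; `model` here is a hypothetical, already-loaded model instance:
#
#     targets = [
#         (name, module)
#         for name, module in model.named_modules()
#         if isinstance(module, GptOssExperts)
#     ]
#     for name, module in targets:
#         parent_name, _, child_name = name.rpartition(".")
#         parent = model.get_submodule(parent_name)
#         setattr(parent, child_name, replace_gpt_oss(model.config, module))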

def test_restore():
    config = GptOssConfig(hidden_size=7, num_local_experts=3, expert_dim=5)

    original = GptOssExperts(config)
    linear = GptOssExpertsLinear(original)

    restored = linear.to_original()
    for param_name, param in original.named_parameters(recurse=False):
        restored_param = getattr(restored, param_name)
        assert param.shape == restored_param.shape
        assert param.dtype == restored_param.dtype

        assert torch.all(restored_param == param)

def test_correctness():
    batch_size, seq_len = 13, 12
    config = GptOssConfig(hidden_size=7, num_local_experts=3, expert_dim=5)

    input = torch.rand((batch_size, seq_len, config.hidden_size))
    routing_weights = torch.rand((batch_size * seq_len, config.num_local_experts))

    with torch.no_grad():
        original = GptOssExperts(config)
        for name in ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias"]:
            setattr(original, name, getattr(original, name).normal_())

        original.eval()
        assert not original.training
        true_output = original(input, routing_weights=routing_weights)

        linear = GptOssExpertsLinear(original)
        output = linear(input, routing_weights=routing_weights)

        assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0)

        restored = linear.to_original()
        restored_output = restored(input, routing_weights=routing_weights)
        assert torch.allclose(restored_output, true_output, atol=1e-3, rtol=0.0)


if __name__ == "__main__":
    test_restore()
    test_correctness()
