
Commit 4096ffd

WIP
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 3e18dd9 commit 4096ffd

File tree: 2 files changed, +122 -0 lines changed


src/llmcompressor/modeling/gpt_oss.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
from typing import List

import torch

from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts

from llmcompressor.utils.dev import skip_weights_initialize


class GptOssExpert(torch.nn.Module):
    def __init__(self, hidden_size: int, expert_dim: int, alpha: float, limit: float):
        super().__init__()

        self.hidden_size = hidden_size
        self.expert_dim = expert_dim
        self.alpha = alpha
        self.limit = limit

        with skip_weights_initialize():
            self.gate_proj = torch.nn.Linear(self.hidden_size, self.expert_dim, bias=True)
            self.up_proj = torch.nn.Linear(self.hidden_size, self.expert_dim, bias=True)
            self.down_proj = torch.nn.Linear(self.expert_dim, self.hidden_size, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Clamped GLU activation, matching the fused GptOssExperts computation:
        # gate is clamped from above only, up is clamped from both sides
        gate = self.gate_proj(hidden_states)
        gate = gate.clamp(min=None, max=self.limit)

        up = self.up_proj(hidden_states)
        up = up.clamp(min=-self.limit, max=self.limit)

        glu = gate * torch.sigmoid(gate * self.alpha)
        return self.down_proj((up + 1) * glu)


class GptOssExpertsLinear(torch.nn.Module):
    experts: List[GptOssExpert]

    def __init__(self, experts: GptOssExperts):
        super().__init__()

        self.intermediate_size = experts.intermediate_size
        self.num_experts = experts.num_experts
        self.hidden_size = experts.hidden_size
        self.expert_dim = experts.expert_dim

        with skip_weights_initialize():
            # ModuleList so the unfused experts are registered as submodules
            self.experts = torch.nn.ModuleList(
                GptOssExpert(self.hidden_size, self.expert_dim, experts.alpha, experts.limit)
                for _ in range(self.num_experts)
            )

        self.load_weights(experts)

        self.alpha = experts.alpha
        self.limit = experts.limit

    def load_weights(self, experts: GptOssExperts):
        # gate_up_proj is (num_experts, hidden_size, 2 * expert_dim), with gate in
        # the even columns and up in the odd columns. torch.nn.Linear stores its
        # weight as (out_features, in_features), hence the transpose.
        for expert_index, expert in enumerate(self.experts):
            expert.gate_proj.weight.data = experts.gate_up_proj[expert_index, ..., ::2].data.T
            expert.gate_proj.bias.data = experts.gate_up_proj_bias[expert_index, ..., ::2].data

            expert.up_proj.weight.data = experts.gate_up_proj[expert_index, ..., 1::2].data.T
            expert.up_proj.bias.data = experts.gate_up_proj_bias[expert_index, ..., 1::2].data

            expert.down_proj.weight.data = experts.down_proj[expert_index].data.T
            expert.down_proj.bias.data = experts.down_proj_bias[expert_index].data

    def to_original(self) -> GptOssExperts:
        # TODO: repack the per-expert Linear weights into a fused GptOssExperts
        raise NotImplementedError

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        """
        Loop over the experts and accumulate each expert's output, weighted by its
        routing weight. Looping keeps peak memory low during training; for inference,
        one could instead trade memory for speed by repeating the inputs and
        computing all experts at once.

        Args:
            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
            router_indices (torch.Tensor): (batch_size * seq_len, top_k), unused here
            routing_weights (torch.Tensor): (batch_size * seq_len, num_experts)
        Returns:
            torch.Tensor: (batch_size, seq_len, hidden_size)
        """
        original_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)

        next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
        for expert_index, expert in enumerate(self.experts):
            next_states += expert(hidden_states) * routing_weights.T[expert_index].unsqueeze(-1)

        next_states = next_states.reshape(original_shape)
        return next_states


if __name__ == "__main__":
    batch_size, seq_len = 13, 12
    config = GptOssConfig(hidden_size=7, num_local_experts=3, intermediate_size=5)

    inputs = torch.rand((batch_size, seq_len, config.hidden_size))
    routing_weights = torch.rand((batch_size * seq_len, config.num_local_experts))

    with torch.no_grad():
        original = GptOssExperts(config)
        for name in ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias"]:
            setattr(original, name, getattr(original, name).normal_())

        original.eval()
        true_output = original(inputs, routing_weights=routing_weights)

        linear = GptOssExpertsLinear(original)
        output = linear(inputs, routing_weights=routing_weights)

    assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0)
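
For reference, the correctness of load_weights rests on one identity: gate occupies the even output columns of the fused gate_up_proj and up the odd ones, so slicing the weight before the matmul must agree with slicing the fused output after it. A minimal self-contained sketch of that identity (shapes chosen arbitrarily, no llmcompressor imports needed):

import torch

hidden_size, expert_dim, tokens = 4, 3, 5
gate_up = torch.randn(hidden_size, 2 * expert_dim)  # one expert's fused weight
x = torch.randn(tokens, hidden_size)

fused = x @ gate_up             # fused projection, (tokens, 2 * expert_dim)
gate_w = gate_up[..., ::2].T    # (expert_dim, hidden_size), torch.nn.Linear layout
up_w = gate_up[..., 1::2].T

assert torch.allclose(fused[..., ::2], x @ gate_w.T)   # even columns == gate
assert torch.allclose(fused[..., 1::2], x @ up_w.T)    # odd columns == up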

src/llmcompressor/modeling/prepare.py

Lines changed: 5 additions & 0 deletions
@@ -42,6 +42,11 @@ def update_qwen3_moe(model, stack):
     )
 
 
+def update_gpt_oss_moe(model, stack):
+    # TODO: WIP
+    pass
+
+
 moe_context = {
     "Qwen3MoeForCausalLM": update_qwen3_moe,
 }
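
update_gpt_oss_moe is left as a stub in this WIP commit. A plausible completion, mirroring the module-replacement pattern used for Qwen3, would swap each fused GptOssExperts for the GptOssExpertsLinear defined in gpt_oss.py above; the sketch below is an assumption about where the commit is headed, not what it implements (the stack argument is ignored here):

from llmcompressor.modeling.gpt_oss import GptOssExpertsLinear

def update_gpt_oss_moe(model, stack):
    # Hypothetical sketch: replace every fused GptOssExperts module with its
    # unfused, Linear-based equivalent so that per-expert Linear layers are
    # visible to downstream compression passes.
    for name, module in list(model.named_modules()):
        if module.__class__.__name__ == "GptOssExperts":
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, GptOssExpertsLinear(module))

Presumably the function would also be registered in moe_context (e.g. under the "GptOssForCausalLM" architecture name), mirroring the Qwen3 entry.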
