from typing import List

import torch
import contextlib

from transformers import GptOssForCausalLM
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from llmcompressor.utils.dev import skip_weights_initialize

from compressed_tensors.utils import update_offload_parameter, align_module_device

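# This module unfuses the grouped GptOssExperts parameters (gate_up_proj, down_proj and their
# biases, each stored as a single tensor covering all experts) into per-expert torch.nn.Linear
# modules, and provides the inverse conversion. The intended use is an assumption not stated in
# this file: exposing standard Linear weights that per-module compression flows can target,
# then restoring the fused layout afterwards.
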
class GptOssExpert(torch.nn.Module):
    """A single MoE expert as unfused Linear layers, mirroring GptOssExperts' clamped, gated activation."""

    gate_proj: torch.nn.Linear
    up_proj: torch.nn.Linear
    down_proj: torch.nn.Linear

    def __init__(self, experts: GptOssExperts):
        super().__init__()

        self.hidden_size = experts.hidden_size
        self.expert_dim = experts.expert_dim
        self.alpha = experts.alpha
        self.limit = experts.limit

        assert experts.gate_up_proj.dtype == experts.gate_up_proj_bias.dtype
        assert experts.down_proj.dtype == experts.down_proj_bias.dtype

        with skip_weights_initialize():
            self.gate_proj = torch.nn.Linear(self.hidden_size, self.expert_dim, bias=True, dtype=experts.gate_up_proj.dtype)
            self.up_proj = torch.nn.Linear(self.hidden_size, self.expert_dim, bias=True, dtype=experts.gate_up_proj.dtype)
            self.down_proj = torch.nn.Linear(self.expert_dim, self.hidden_size, bias=True, dtype=experts.down_proj.dtype)

    def forward(self, hidden_states: torch.Tensor):
        gate = self.gate_proj(hidden_states)
        gate = gate.clamp(min=None, max=self.limit)

        up = self.up_proj(hidden_states)
        up = up.clamp(min=-self.limit, max=self.limit)

        glu = gate * torch.sigmoid(gate * self.alpha)
        return self.down_proj((up + 1) * glu)

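# Equivalence note (illustrative; test_correctness below verifies this end-to-end): for a single
# expert index e, the module above should reproduce the fused computation
#   gate, up = even/odd column split of (x @ gate_up_proj[e] + gate_up_proj_bias[e])
#   out      = ((up + 1) * gate * sigmoid(alpha * gate)) @ down_proj[e] + down_proj_bias[e]
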
class GptOssExpertsLinear(torch.nn.Module):
    """Drop-in replacement for GptOssExperts that stores each expert as a GptOssExpert of Linear layers."""

    experts: List[GptOssExpert]

    def __init__(self, experts: GptOssExperts):
        super().__init__()

        self.intermediate_size = experts.intermediate_size
        self.num_experts = experts.num_experts
        self.hidden_size = experts.hidden_size
        self.expert_dim = experts.expert_dim

        with skip_weights_initialize():
            self.experts = torch.nn.ModuleList([GptOssExpert(experts) for _ in range(self.num_experts)])

        self.load_weights(experts)

        self.alpha = experts.alpha
        self.limit = experts.limit

    def load_weights(self, experts: GptOssExperts):
        with align_module_device(experts):
            for expert_index, expert in enumerate(self.experts):
                # gate occupies the even columns of the fused gate_up projection, up the odd columns
                update_offload_parameter(expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T)
                update_offload_parameter(expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2])

                update_offload_parameter(expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T)
                update_offload_parameter(expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2])

                update_offload_parameter(expert.down_proj, "weight", experts.down_proj[expert_index].T)
                update_offload_parameter(expert.down_proj, "bias", experts.down_proj_bias[expert_index])

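    # Shape reference for the fused parameters (inferred from the slicing and transposes above;
    # expert_dim equals intermediate_size in GptOssExperts, which is why to_original below only
    # passes intermediate_size):
    #   gate_up_proj:      (num_experts, hidden_size, 2 * expert_dim)
    #   gate_up_proj_bias: (num_experts, 2 * expert_dim)
    #   down_proj:         (num_experts, expert_dim, hidden_size)
    #   down_proj_bias:    (num_experts, hidden_size)
    # nn.Linear stores weight as (out_features, in_features), hence the .T when copying.
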
    def to_original(self) -> GptOssExperts:
        # TODO: this doesn't really handle offloading or correct device placement
        with skip_weights_initialize(use_zeros=True):
            fake_config = GptOssConfig(
                intermediate_size=self.intermediate_size,
                num_local_experts=self.num_experts,
                hidden_size=self.hidden_size,
            )
            experts = GptOssExperts(fake_config)

            # match the dtypes of the unfused Linear weights
            experts.gate_up_proj = torch.nn.Parameter(experts.gate_up_proj.to(dtype=self.experts[0].gate_proj.weight.dtype), requires_grad=False)
            experts.gate_up_proj_bias = torch.nn.Parameter(experts.gate_up_proj_bias.to(dtype=self.experts[0].gate_proj.weight.dtype), requires_grad=False)
            experts.down_proj = torch.nn.Parameter(experts.down_proj.to(dtype=self.experts[0].down_proj.weight.dtype), requires_grad=False)
            experts.down_proj_bias = torch.nn.Parameter(experts.down_proj_bias.to(dtype=self.experts[0].down_proj.weight.dtype), requires_grad=False)

        for expert_index, expert in enumerate(self.experts):
            with align_module_device(expert.gate_proj, "cpu"), align_module_device(expert.up_proj, "cpu"), align_module_device(expert.down_proj, "cpu"):
                experts.gate_up_proj[expert_index, ..., ::2].copy_(expert.gate_proj.weight.data.T)
                experts.gate_up_proj_bias[expert_index, ..., ::2].copy_(expert.gate_proj.bias.data)

                experts.gate_up_proj[expert_index, ..., 1::2].copy_(expert.up_proj.weight.data.T)
                experts.gate_up_proj_bias[expert_index, ..., 1::2].copy_(expert.up_proj.bias.data)

                experts.down_proj[expert_index].copy_(expert.down_proj.weight.data.T)
                experts.down_proj_bias[expert_index].copy_(expert.down_proj.bias.data)

        # NOTE: conversion has been observed to slow down over time; the cause is not yet understood
        experts.eval()
        return experts

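    # Round-trip sketch (mirrors test_restore below): convert to per-expert Linear modules,
    # run whatever per-module processing is needed, then restore the fused layout:
    #   linear = GptOssExpertsLinear(experts)
    #   ...
    #   experts = linear.to_original()
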
    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        """
        When training, it is more efficient to just loop over the experts and compute the output
        for each expert, as otherwise the memory would explode.

        For inference we can sacrifice some memory and compute the output for all experts at once
        by repeating the inputs. This implementation always evaluates every expert for every token
        and weights the results by routing_weights.

        Args:
            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
            router_indices (torch.Tensor): (batch_size * token_num, top_k), unused by this dense implementation
            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
        Returns:
            torch.Tensor
        """
        original_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)

        next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
        for expert_index, expert in enumerate(self.experts):
            next_states += expert(hidden_states) * routing_weights.T[expert_index].unsqueeze(-1)

        next_states = next_states.reshape(original_shape)
        return next_states

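    # Note: routing_weights is expected as a dense (num_tokens, num_experts) tensor (zeros for
    # unselected experts), matching what the tests below construct; experts receiving zero weight
    # still run here but contribute nothing to the sum.
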
def replace_gpt_oss(config: GptOssConfig, module: GptOssExperts) -> GptOssExpertsLinear:
    return GptOssExpertsLinear(module)


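# Example usage (a minimal sketch, not part of this file's tested code): applying the replacement
# across a loaded model. It assumes the standard GPT-OSS module layout, i.e. that each decoder
# layer exposes its fused experts at model.model.layers[i].mlp.experts; adjust the attribute path
# if the layout differs.
def convert_gpt_oss_experts(model: GptOssForCausalLM) -> GptOssForCausalLM:
    for layer in model.model.layers:
        layer.mlp.experts = replace_gpt_oss(model.config, layer.mlp.experts)
    return model

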
def test_restore():
    # GptOssExperts derives expert_dim from intermediate_size, so size the toy config through it
    config = GptOssConfig(hidden_size=7, num_local_experts=3, intermediate_size=5)

    original = GptOssExperts(config)
    linear = GptOssExpertsLinear(original)

    restored = linear.to_original()
    for param_name, param in original.named_parameters(recurse=False):
        restored_param = getattr(restored, param_name)
        assert param.shape == restored_param.shape
        assert param.dtype == restored_param.dtype

        assert torch.all(restored_param == param)


def test_correctness():
    batch_size, seq_len = 13, 12
    config = GptOssConfig(hidden_size=7, num_local_experts=3, intermediate_size=5)

    input = torch.rand((batch_size, seq_len, config.hidden_size))
    routing_weights = torch.rand((batch_size * seq_len, config.num_local_experts))

    with torch.no_grad():
        original = GptOssExperts(config)
        for name in ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias"]:
            setattr(original, name, getattr(original, name).normal_())

        original.eval()
        assert not original.training
        true_output = original(input, routing_weights=routing_weights)

        linear = GptOssExpertsLinear(original)
        output = linear(input, routing_weights=routing_weights)

        assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0)

        restored = linear.to_original()
        restored_output = restored(input, routing_weights=routing_weights)
        assert torch.allclose(restored_output, true_output, atol=1e-3, rtol=0.0)


if __name__ == "__main__":
    test_restore()
    test_correctness()