from typing import List

import torch

from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from llmcompressor.utils.dev import skip_weights_initialize


class GptOssExpert(torch.nn.Module):
    def __init__(self, hidden_size: int, expert_dim: int, alpha: float, limit: float):
        super().__init__()

        self.hidden_size = hidden_size
        self.expert_dim = expert_dim
        self.alpha = alpha
        self.limit = limit

        # unfused per-expert projections, exposed as standard Linear modules
        with skip_weights_initialize():
            self.gate_proj = torch.nn.Linear(self.hidden_size, self.expert_dim, bias=True)
            self.up_proj = torch.nn.Linear(self.hidden_size, self.expert_dim, bias=True)
            self.down_proj = torch.nn.Linear(self.expert_dim, self.hidden_size, bias=True)

    def forward(self, hidden_states: torch.Tensor):
        # clamped SwiGLU used by GPT-OSS: the gate is clamped from above only,
        # the up projection symmetrically
        gate = self.gate_proj(hidden_states)
        gate = gate.clamp(min=None, max=self.limit)

        up = self.up_proj(hidden_states)
        up = up.clamp(min=-self.limit, max=self.limit)

        glu = gate * torch.sigmoid(gate * self.alpha)
        return self.down_proj((up + 1) * glu)
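
# Example usage (hypothetical values; in practice alpha and limit are copied
# from the fused GptOssExperts module):
#   expert = GptOssExpert(hidden_size=7, expert_dim=5, alpha=1.702, limit=7.0)
#   out = expert(torch.rand(4, 7))  # -> (4, 7)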


class GptOssExpertsLinear(torch.nn.Module):
    experts: List[GptOssExpert]

    def __init__(self, experts: GptOssExperts):
        super().__init__()

        self.intermediate_size = experts.intermediate_size
        self.num_experts = experts.num_experts
        self.hidden_size = experts.hidden_size
        self.expert_dim = experts.expert_dim

        with skip_weights_initialize():
            # a ModuleList registers the unfused experts as submodules, so their
            # Linear parameters appear in state_dict() and move with the parent
            self.experts = torch.nn.ModuleList(
                GptOssExpert(self.hidden_size, self.expert_dim, experts.alpha, experts.limit)
                for _ in range(self.num_experts)
            )

        self.load_weights(experts)

        self.alpha = experts.alpha
        self.limit = experts.limit

    def load_weights(self, experts: GptOssExperts):
        # gate_up_proj interleaves the two projections along its last dim:
        # even columns hold the gate weights, odd columns hold the up weights
        for expert_index, expert in enumerate(self.experts):
            expert.gate_proj.weight.data = experts.gate_up_proj[expert_index, ..., ::2].data.T
            expert.gate_proj.bias.data = experts.gate_up_proj_bias[expert_index, ..., ::2].data

            expert.up_proj.weight.data = experts.gate_up_proj[expert_index, ..., 1::2].data.T
            expert.up_proj.bias.data = experts.gate_up_proj_bias[expert_index, ..., 1::2].data

            expert.down_proj.weight.data = experts.down_proj[expert_index].T
            expert.down_proj.bias.data = experts.down_proj_bias[expert_index]

    def to_original(self) -> GptOssExperts:
        # inverse of load_weights: re-fuse the per-expert Linear weights back
        # into the interleaved gate_up_proj layout of GptOssExperts
        config = GptOssConfig(
            hidden_size=self.hidden_size,
            num_local_experts=self.num_experts,
            expert_dim=self.expert_dim,
            intermediate_size=self.intermediate_size,
        )
        with skip_weights_initialize():
            experts = GptOssExperts(config)

        for expert_index, expert in enumerate(self.experts):
            experts.gate_up_proj.data[expert_index, ..., ::2] = expert.gate_proj.weight.data.T
            experts.gate_up_proj_bias.data[expert_index, ..., ::2] = expert.gate_proj.bias.data

            experts.gate_up_proj.data[expert_index, ..., 1::2] = expert.up_proj.weight.data.T
            experts.gate_up_proj_bias.data[expert_index, ..., 1::2] = expert.up_proj.bias.data

            experts.down_proj.data[expert_index] = expert.down_proj.weight.data.T
            experts.down_proj_bias.data[expert_index] = expert.down_proj.bias.data

        return experts

    def forward(
        self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None
    ) -> torch.Tensor:
        """
        When training, it is more efficient to loop over the experts and compute
        each expert's output separately, since materializing every expert's
        activations at once would blow up memory.

        For inference we can sacrifice some memory and compute the output for all
        experts at once by repeating the inputs (see the forward_batched sketch
        below).

        Args:
            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
            router_indices (torch.Tensor): (batch_size * token_num, top_k)
            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
        Returns:
            torch.Tensor: (batch_size, seq_len, hidden_size)
        """
        original_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)

        next_states = torch.zeros_like(hidden_states)
        for expert_index, expert in enumerate(self.experts):
            # scale each expert's output by that expert's per-token routing weight
            next_states += expert(hidden_states) * routing_weights.T[expert_index].unsqueeze(-1)

        return next_states.reshape(original_shape)
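
    # Hypothetical helper (not part of the original module): the batched variant
    # the docstring above alludes to. It materializes every expert's output at
    # once, trading memory for fewer Python-level loop iterations. A minimal
    # sketch, numerically equivalent to forward().
    def forward_batched(self, hidden_states: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
        original_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, self.hidden_size)

        # (num_experts, num_tokens, hidden_size)
        expert_outputs = torch.stack([expert(hidden_states) for expert in self.experts])
        # weight each expert's output by its routing weight, then sum over experts
        next_states = (expert_outputs * routing_weights.T.unsqueeze(-1)).sum(dim=0)
        return next_states.reshape(original_shape)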


if __name__ == "__main__":
    batch_size, seq_len = 13, 12
    config = GptOssConfig(hidden_size=7, num_local_experts=3, expert_dim=5)

    input = torch.rand((batch_size, seq_len, config.hidden_size))
    routing_weights = torch.rand((batch_size * seq_len, config.num_local_experts))

    with torch.no_grad():
        original = GptOssExperts(config)
        # fill the fused parameters with random values for the equivalence check
        for name in ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias"]:
            setattr(original, name, getattr(original, name).normal_())

        original.eval()
        true_output = original(input, routing_weights=routing_weights)

        linear = GptOssExpertsLinear(original)
        output = linear(input, routing_weights=routing_weights)

        assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0)
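
        # exercise the to_original inverse sketched above: re-fusing the unfused
        # experts should reproduce the original fused output
        fused = linear.to_original()
        fused.eval()
        fused_output = fused(input, routing_weights=routing_weights)
        assert torch.allclose(fused_output, true_output, atol=1e-3, rtol=0.0)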