@@ -7,6 +7,8 @@
 from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
 from llmcompressor.utils.dev import skip_weights_initialize
 
+from compressed_tensors import update_offload_parameter
+
 
 class GptOssExpert(torch.nn.Module):
     def __init__(self, hidden_size: int, expert_dim: int, alpha: float, limit: float):
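The only new dependency is `update_offload_parameter` from compressed-tensors. Unlike a bare `.data` assignment, it writes a new value into a module's parameter and, if the module's weights are offloaded (e.g. by accelerate), into the offloaded copy as well. A minimal sketch of the call pattern, assuming a plain non-offloaded `torch.nn.Linear`, for which the helper reduces to an in-place copy:

import torch
from compressed_tensors import update_offload_parameter

layer = torch.nn.Linear(4, 4)
new_weight = torch.randn(4, 4)

# copies new_weight into layer.weight; for an offloaded module it would
# also refresh the copy held in the offloaded state dict
update_offload_parameter(layer, "weight", new_weight)
assert torch.equal(layer.weight.data, new_weight)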
@@ -56,18 +58,42 @@ def __init__(self, experts: GptOssExpert):
 
     def load_weights(self, experts: GptOssExperts):
         for expert_index, expert in enumerate(self.experts):
-            expert.gate_proj.weight.data = experts.gate_up_proj[expert_index, ..., ::2].data.T
-            expert.gate_proj.bias.data = experts.gate_up_proj_bias[expert_index, ..., ::2].data
-
-            expert.up_proj.weight.data = experts.gate_up_proj[expert_index, ..., 1::2].data.T
-            expert.up_proj.bias.data = experts.gate_up_proj_bias[expert_index, ..., 1::2].data
+            update_offload_parameter(expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T)
+            update_offload_parameter(expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2])
 
-            expert.down_proj.weight.data = experts.down_proj[expert_index].T
-            expert.down_proj.bias.data = experts.down_proj_bias[expert_index]
+            update_offload_parameter(expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T)
+            update_offload_parameter(expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2])
 
+            update_offload_parameter(expert.down_proj, "weight", experts.down_proj[expert_index].T)
+            update_offload_parameter(expert.down_proj, "bias", experts.down_proj_bias[expert_index])
 
     def to_original(self) -> GptOssExperts:
-        pass
+        with skip_weights_initialize():
+            fake_config = GptOssConfig(
+                intermediate_size=self.intermediate_size,
+                num_local_experts=self.num_experts,
+                hidden_size=self.hidden_size,
+
+            )
+            experts = GptOssExperts(fake_config)
+
+        for expert_index, expert in enumerate(self.experts):
+            experts.gate_up_proj.data[expert_index, ..., ::2] = expert.gate_proj.weight.data.T
+            experts.gate_up_proj_bias.data[expert_index, ..., ::2] = expert.gate_proj.bias.data
+
+            experts.gate_up_proj.data[expert_index, ..., 1::2] = expert.up_proj.weight.data.T
+            experts.gate_up_proj_bias.data[expert_index, ..., 1::2] = expert.up_proj.bias.data
+
+            experts.down_proj.data[expert_index] = expert.down_proj.weight.data.T
+            experts.down_proj_bias.data[expert_index] = expert.down_proj.bias.data
+
+        # write the rebuilt tensors through to any offloaded state dict
+        update_offload_parameter(experts, "gate_up_proj", experts.gate_up_proj)
+        update_offload_parameter(experts, "gate_up_proj_bias", experts.gate_up_proj_bias)
+        update_offload_parameter(experts, "down_proj", experts.down_proj)
+        update_offload_parameter(experts, "down_proj_bias", experts.down_proj_bias)
+
+        return experts
 
 
     def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
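Both `load_weights` and `to_original` depend on the GptOss convention that `gate_up_proj` interleaves the two projections along its last axis: even columns hold the gate weights, odd columns the up weights. A self-contained toy sketch of that split-and-reassemble round trip (the tensor contents are made up):

import torch

gate_up = torch.arange(12.0).reshape(2, 6)        # stand-in for one expert's gate_up_proj
gate, up = gate_up[..., ::2], gate_up[..., 1::2]  # de-interleave, as in load_weights

reassembled = torch.empty_like(gate_up)
reassembled[..., ::2] = gate                      # re-interleave, as in to_original
reassembled[..., 1::2] = up
assert torch.equal(reassembled, gate_up)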
@@ -113,5 +139,8 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
     linear = GptOssExpertsLinear(original)
     output = linear(input, routing_weights=routing_weights)
 
-    breakpoint()
-    assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0)
+    assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0)
+
+    restored = linear.to_original()
+    restored_output = linear(input, routing_weights=routing_weights)
+    assert torch.allclose(restored_output, true_output, atol=1e-3, rtol=0.0)
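For reference, a hypothetical end-to-end use of the round trip the test now exercises. The sizes are illustrative, `GptOssExpertsLinear` is the class defined in this file, and the `GptOssExperts` import path from transformers' gpt_oss modeling module is an assumption:

import torch
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts

config = GptOssConfig(hidden_size=32, intermediate_size=64, num_local_experts=4)
original = GptOssExperts(config)

linear = GptOssExpertsLinear(original)  # unstack into per-expert Linear modules
restored = linear.to_original()         # rebuild a stacked GptOssExperts

# the round trip is pure transposition and slicing, so it should be exact
assert torch.equal(restored.gate_up_proj, original.gate_up_proj)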