@@ -7,9 +7,13 @@
 from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
 from llmcompressor.utils.dev import skip_weights_initialize
 
-from compressed_tensors import update_offload_parameter
+from compressed_tensors.utils import update_offload_parameter, align_module_device
 
 
 class GptOssExpert(torch.nn.Module):
+    gate_proj: torch.nn.Linear
+    up_proj: torch.nn.Linear
+    down_proj: torch.nn.Linear
+
     def __init__(self, hidden_size: int, expert_dim: int, alpha: float, limit: float):
         super().__init__()
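Note: `align_module_device` (newly imported from `compressed_tensors.utils`) is a context manager that temporarily materializes a module's possibly-offloaded parameters on its execution device and restores the offloaded placement on exit. A minimal sketch of the usage pattern, with a plain `Linear` standing in for a possibly-offloaded submodule:

```python
import torch
from compressed_tensors.utils import align_module_device

module = torch.nn.Linear(8, 8)  # stand-in for a possibly-offloaded submodule

with align_module_device(module):
    # inside the block, module.weight lives on the module's execution device,
    # so it can be read or cloned directly even if offloaded outside the block
    weight_copy = module.weight.clone()
# on exit, offloaded parameters (if any) return to their offloaded placement
```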
@@ -57,17 +61,21 @@ def __init__(self, experts: GptOssExpert):
         self.limit = experts.limit
 
     def load_weights(self, experts: GptOssExperts):
-        for expert_index, expert in enumerate(self.experts):
-            update_offload_parameter(expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T)
-            update_offload_parameter(expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2])
+        # TODO: this code is inefficient. If there was a "get_offloaded_data" util,
+        # we could avoid having to move from cpu -> gpu -> cpu
+        with align_module_device(experts):
+            for expert_index, expert in enumerate(self.experts):
+                update_offload_parameter(expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T)
+                update_offload_parameter(expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2])
 
-            update_offload_parameter(expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T)
-            update_offload_parameter(expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2])
+                update_offload_parameter(expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T)
+                update_offload_parameter(expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2])
 
-            update_offload_parameter(expert.down_proj, "weight", experts.down_proj[expert_index].T)
-            update_offload_parameter(expert.down_proj, "bias", experts.down_proj_bias[expert_index])
+                update_offload_parameter(expert.down_proj, "weight", experts.down_proj[expert_index].T)
+                update_offload_parameter(expert.down_proj, "bias", experts.down_proj_bias[expert_index])
 
     def to_original(self) -> GptOssExperts:
+        # TODO: this doesn't really handle offloading or correct device placement
         with skip_weights_initialize():
             fake_config = GptOssConfig(
                 intermediate_size=self.intermediate_size,
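Aside on the slicing in `load_weights`: GPT-OSS fuses each expert's gate and up projections into a single `gate_up_proj` tensor whose last dimension interleaves the two, so even columns (`::2`) hold the gate projection and odd columns (`1::2`) hold the up projection, and the `.T` maps the fused `(hidden_size, expert_dim)` layout onto `nn.Linear`'s `(out_features, in_features)` convention. A toy de-interleave with made-up shapes:

```python
import torch

num_experts, hidden_size, expert_dim = 2, 8, 4

# fused tensor: even columns of the last dim are gate, odd columns are up
gate_up_proj = torch.randn(num_experts, hidden_size, 2 * expert_dim)

expert_index = 0
gate = gate_up_proj[expert_index, ..., ::2]   # (hidden_size, expert_dim)
up = gate_up_proj[expert_index, ..., 1::2]    # (hidden_size, expert_dim)

# nn.Linear stores weight as (out_features, in_features), hence the transpose
gate_linear = torch.nn.Linear(hidden_size, expert_dim, bias=False)
gate_linear.weight.data = gate.T.contiguous()

x = torch.randn(3, hidden_size)
assert torch.allclose(gate_linear(x), x @ gate)
```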
@@ -78,14 +86,17 @@ def to_original(self) -> GptOssExperts:
             experts = GptOssExperts(fake_config)
 
         for expert_index, expert in enumerate(self.experts):
-            experts.gate_up_proj[expert_index, ..., ::2].data = expert.gate_proj.weight.data.T
-            experts.gate_up_proj_bias[expert_index, ..., ::2].data = expert.gate_proj.bias.data
+            # TODO: this code is inefficient. If there was a "get_offloaded_data" util,
+            # we could avoid having to move from cpu -> gpu -> cpu
+            with align_module_device(expert):
+                experts.gate_up_proj[expert_index, ..., ::2].data = expert.gate_proj.weight.data.T
+                experts.gate_up_proj_bias[expert_index, ..., ::2].data = expert.gate_proj.bias.data
 
-            experts.gate_up_proj[expert_index, ..., 1::2].data = expert.up_proj.weight.data.T
-            experts.gate_up_proj_bias[expert_index, ..., 1::2].data = expert.up_proj.bias.data
+                experts.gate_up_proj[expert_index, ..., 1::2].data = expert.up_proj.weight.data.T
+                experts.gate_up_proj_bias[expert_index, ..., 1::2].data = expert.up_proj.bias.data
 
-            experts.down_proj[expert_index].data = expert.down_proj.weight.data.T
-            experts.down_proj_bias[expert_index] = expert.down_proj.bias.data
+                experts.down_proj[expert_index].data = expert.down_proj.weight.data.T
+                experts.down_proj_bias[expert_index] = expert.down_proj.bias.data
 
         # update offloaded state dict
         update_offload_parameter(experts, "gate_up_proj", experts.gate_up_proj)
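Review note, worth double-checking: assigning to `.data` of a sliced view (as in `experts.gate_up_proj[expert_index, ..., ::2].data = ...`) rebinds the temporary view rather than writing into the parameter's storage, so these writes may silently fail to propagate; indexing into `.data` (or using `copy_`) writes through. A minimal repro of the pitfall:

```python
import torch

param = torch.nn.Parameter(torch.zeros(4, 4))

# Pitfall: `.data = ...` on a slice rebinds the temporary view;
# the parameter's storage is never written
param[..., ::2].data = torch.ones(4, 2)
assert param.data.sum() == 0  # the write was lost

# Write-through alternatives: index into .data, or copy_ into the view
param.data[..., ::2] = torch.ones(4, 2)
param[..., ::2].data.copy_(2 * torch.ones(4, 2))  # also writes through
assert param.data.sum() == 16  # two of four columns, each now 2.0
```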
@@ -134,6 +145,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None):
     setattr(original, name, getattr(original, name).normal_())
 
 original.eval()
+assert original.training == False
 true_output = original(input, routing_weights=routing_weights)
 
 linear = GptOssExpertsLinear(original)
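For context, the equivalence check this test appears to be driving at (a sketch continuing the code above; the `assert_close` tolerances and the `to_original()` round trip are assumptions, not part of this diff):

```python
# continuing the test: the unfused module should match the fused original
output = linear(input, routing_weights=routing_weights)
torch.testing.assert_close(output, true_output, rtol=1e-5, atol=1e-5)

# and the round trip back through to_original() should also match
rebuilt = linear.to_original()
rebuilt_output = rebuilt(input, routing_weights=routing_weights)
torch.testing.assert_close(rebuilt_output, true_output, rtol=1e-5, atol=1e-5)
```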