@@ -109,55 +109,74 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         self.intermediate_size = intermediate_size_per_partition_after_pad
         self.hidden_size = hidden_size
         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            2 * intermediate_size_per_partition_after_pad,
-            hidden_size // 2,
-            dtype=weight_dtype),
-                                        requires_grad=False)
+        w13_weight = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                2 * intermediate_size_per_partition_after_pad,
+                hidden_size // 2,
+                dtype=weight_dtype,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w13_weight_scale = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            2 * intermediate_size_per_partition_after_pad,
-            hidden_size // mxfp4_block,
-            dtype=scale_dtype),
-                                              requires_grad=False)
+        w13_weight_scale = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                2 * intermediate_size_per_partition_after_pad,
+                hidden_size // mxfp4_block,
+                dtype=scale_dtype,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
 
-        w13_bias = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            2 * intermediate_size_per_partition_after_pad,
-            dtype=torch.bfloat16),
-                                      requires_grad=False)
+        w13_bias = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                2 * intermediate_size_per_partition_after_pad,
+                dtype=torch.bfloat16,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w13_bias", w13_bias)
         set_weight_attrs(w13_bias, extra_weight_attrs)
 
         # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            hidden_size,
-            intermediate_size_per_partition_after_pad // 2,
-            dtype=weight_dtype),
-                                       requires_grad=False)
+        w2_weight = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition_after_pad // 2,
+                dtype=weight_dtype,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-        w2_weight_scale = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            hidden_size,
-            intermediate_size_per_partition_after_pad // mxfp4_block,
-            dtype=scale_dtype),
-                                             requires_grad=False)
+        w2_weight_scale = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition_after_pad // mxfp4_block,
+                dtype=scale_dtype,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
-                                                 hidden_size,
-                                                 dtype=torch.bfloat16),
-                                     requires_grad=False)
+        w2_bias = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                hidden_size,
+                dtype=torch.bfloat16,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w2_bias", w2_bias)
         set_weight_attrs(w2_bias, extra_weight_attrs)
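For readers skimming the shapes above: the `// 2` on the last weight dimension and the `// mxfp4_block` on the scale dimension both come from MXFP4 packing, where two 4-bit values share one byte and each block of elements along the contraction dimension shares one exponent scale (the OCP MX spec uses 32-element blocks, so `mxfp4_block` is presumably 32 here). Below is a minimal sketch of that shape arithmetic, with illustrative sizes and `torch.uint8` storage assumed rather than taken from this PR:

```python
import torch

# Illustrative values only; the real ones come from the model config
# and the TP-padding logic earlier in create_weights.
num_experts = 8
hidden_size = 2880
intermediate_size_per_partition_after_pad = 2880
mxfp4_block = 32  # assumption: one shared scale per 32 fp4 elements

# Two fp4 values are packed into each uint8, so the contraction dim halves.
w13_weight = torch.zeros(
    num_experts,
    2 * intermediate_size_per_partition_after_pad,  # fused gate + up rows
    hidden_size // 2,                               # packed fp4 pairs
    dtype=torch.uint8,
)
# One scale per mxfp4_block-sized group along the same contraction dim.
w13_weight_scale = torch.zeros(
    num_experts,
    2 * intermediate_size_per_partition_after_pad,
    hidden_size // mxfp4_block,
    dtype=torch.uint8,
)

unpacked_per_row = w13_weight.shape[-1] * 2       # 2880 fp4 values per row
scale_groups = unpacked_per_row // mxfp4_block    # 90 groups per row
assert w13_weight_scale.shape[-1] == scale_groups
```

The same arithmetic explains the `w2_*` shapes, with the roles swapped: `down_proj` contracts over the (padded) intermediate dimension, so that axis is the one divided by 2 for packing and by `mxfp4_block` for scales, while biases stay unquantized in `bfloat16` at full `hidden_size`.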