
Commit 136825d

[Misc] Enhance code formatting in mxfp4.py (#22423)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent c2dba2d commit 136825d

1 file changed: +52 −33 lines

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 52 additions & 33 deletions
@@ -109,55 +109,74 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         self.intermediate_size = intermediate_size_per_partition_after_pad
         self.hidden_size = hidden_size
         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            2 * intermediate_size_per_partition_after_pad,
-            hidden_size // 2,
-            dtype=weight_dtype),
-                                        requires_grad=False)
+        w13_weight = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                2 * intermediate_size_per_partition_after_pad,
+                hidden_size // 2,
+                dtype=weight_dtype,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w13_weight_scale = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            2 * intermediate_size_per_partition_after_pad,
-            hidden_size // mxfp4_block,
-            dtype=scale_dtype),
-                                              requires_grad=False)
+        w13_weight_scale = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                2 * intermediate_size_per_partition_after_pad,
+                hidden_size // mxfp4_block,
+                dtype=scale_dtype,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
 
-        w13_bias = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            2 * intermediate_size_per_partition_after_pad,
-            dtype=torch.bfloat16),
-                                      requires_grad=False)
+        w13_bias = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                2 * intermediate_size_per_partition_after_pad,
+                dtype=torch.bfloat16,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w13_bias", w13_bias)
         set_weight_attrs(w13_bias, extra_weight_attrs)
 
         # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            hidden_size,
-            intermediate_size_per_partition_after_pad // 2,
-            dtype=weight_dtype),
-                                       requires_grad=False)
+        w2_weight = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition_after_pad // 2,
+                dtype=weight_dtype,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-        w2_weight_scale = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            hidden_size,
-            intermediate_size_per_partition_after_pad // mxfp4_block,
-            dtype=scale_dtype),
-                                             requires_grad=False)
+        w2_weight_scale = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition_after_pad // mxfp4_block,
+                dtype=scale_dtype,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
-                                                 hidden_size,
-                                                 dtype=torch.bfloat16),
-                                     requires_grad=False)
+        w2_bias = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                hidden_size,
+                dtype=torch.bfloat16,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w2_bias", w2_bias)
         set_weight_attrs(w2_bias, extra_weight_attrs)

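Note (not part of the commit): the shapes in the diff follow from how MXFP4 stores tensors. Each weight element is 4 bits, so two elements are packed into one byte (hence the "// 2" on the packed dimension), and each block of mxfp4_block elements shares a single scale (hence "// mxfp4_block" on the scale tensor). The sketch below is a minimal illustration of that arithmetic; the concrete values (8 experts, block size 32, uint8 storage for both weights and scales) are assumptions for the example, not taken from the diff.

import torch

# Illustrative values (assumptions, not from the commit):
num_experts = 8
intermediate_size_per_partition_after_pad = 256
hidden_size = 1024
mxfp4_block = 32  # one shared scale per block of 32 fp4 elements

# Packed fp4 weights: two 4-bit values per uint8 byte, so the
# reduction dimension is halved.
w13_weight = torch.zeros(
    num_experts,
    2 * intermediate_size_per_partition_after_pad,  # gate and up fused
    hidden_size // 2,
    dtype=torch.uint8,  # assumed packed-storage dtype
)

# Block scales: one scale per mxfp4_block elements along the
# reduction dimension.
w13_weight_scale = torch.zeros(
    num_experts,
    2 * intermediate_size_per_partition_after_pad,
    hidden_size // mxfp4_block,
    dtype=torch.uint8,  # assumed exponent-scale storage dtype
)

print(w13_weight.shape)        # torch.Size([8, 512, 512])
print(w13_weight_scale.shape)  # torch.Size([8, 512, 32])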