@@ -80,11 +80,13 @@ def workspace_shapes(
80
80
topk : int ,
81
81
num_experts : int ,
82
82
) -> tuple [int , int , torch .dtype ]:
83
+
83
84
block_m = self .block_shape [0 ]
84
85
M_sum = (M * topk ) + num_experts * (block_m - 1 )
85
86
M_sum = round_up (M_sum , block_m )
86
87
workspace1 = M_sum * max (N * 2 , K )
87
- workspace2 = M_sum * N
88
+ workspace2 = M_sum * max (N , K )
89
+
88
90
return (workspace1 , workspace2 , a .dtype )
89
91
90
92
def apply (
@@ -135,26 +137,31 @@ def apply(
135
137
136
138
# Note: M_sum is different than the pre-permuted shape of a1q.
137
139
M_sum = a1q .size (0 )
138
- workspace1 = _resize_cache (workspace13 , (M_sum , N ))
139
- workspace2 = _resize_cache (workspace2 , (M_sum , N // 2 ))
140
- workspace3 = _resize_cache (workspace13 , (M_sum , K ))
140
+
141
+ mm1_out = _resize_cache (workspace13 , (M_sum , N ))
142
+ act_out = _resize_cache (workspace2 , (M_sum , N // 2 ))
143
+ quant_out = _resize_cache (workspace13 .view (dtype = torch .float8_e4m3fn ),
144
+ (M_sum , N // 2 ))
145
+ mm2_out = _resize_cache (workspace2 , (M_sum , K ))
146
+ out = _resize_cache (workspace13 , (inv_perm .size (0 ), K ))
141
147
142
148
dg .m_grouped_gemm_fp8_fp8_bf16_nt_contiguous (
143
- (a1q , a1q_scale ), (w1 , w1_scale ), workspace1 , expert_ids )
149
+ (a1q , a1q_scale ), (w1 , w1_scale ), mm1_out , expert_ids )
144
150
145
- self .activation (activation , workspace2 , workspace1 .view (- 1 , N ))
151
+ self .activation (activation , act_out , mm1_out .view (- 1 , N ))
146
152
147
153
a2q_scale : Optional [torch .Tensor ] = None
148
- a2q , a2q_scale = per_token_group_quant_fp8 (workspace2 ,
154
+ a2q , a2q_scale = per_token_group_quant_fp8 (act_out ,
149
155
self .block_shape [1 ],
150
- column_major_scales = True )
156
+ column_major_scales = True ,
157
+ out_q = quant_out )
151
158
152
159
dg .m_grouped_gemm_fp8_fp8_bf16_nt_contiguous (
153
- (a2q , a2q_scale ), (w2 , w2_scale ), workspace3 , expert_ids )
160
+ (a2q , a2q_scale ), (w2 , w2_scale ), mm2_out , expert_ids )
154
161
155
- workspace3 = workspace3 [ inv_perm , ...]
162
+ torch . index_select ( mm2_out , 0 , inv_perm , out = out )
156
163
157
- return workspace3
164
+ return out
158
165
159
166
160
167
def deep_gemm_moe_fp8 (
0 commit comments