13
13
from vllm .model_executor .layers .quantization .utils .quant_utils import (
14
14
GroupShape )
15
15
from vllm .platforms import current_platform
16
+ from vllm .utils import direct_register_custom_op
16
17
17
18
# Input scaling factors are no longer optional in _scaled_mm starting
18
19
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@@ -156,13 +157,10 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
156
157
return output .view (* output_shape )
157
158
158
159
159
- def rocm_per_tensor_w8a8_scaled_mm (* , qinput : torch .Tensor ,
160
- weight : torch .Tensor ,
161
- out_dtype : torch .dtype ,
162
- scale_a : torch .Tensor ,
163
- scale_b : torch .Tensor , bias : torch .Tensor ,
164
- input_2d : torch .Tensor ,
165
- output_shape : list ) -> torch .Tensor :
160
+ def rocm_per_tensor_w8a8_scaled_mm_impl (
161
+ qinput : torch .Tensor , weight : torch .Tensor , out_dtype : torch .dtype ,
162
+ scale_a : torch .Tensor , scale_b : torch .Tensor , bias : torch .Tensor ,
163
+ input_2d : torch .Tensor ) -> torch .Tensor :
166
164
from vllm .platforms .rocm import on_mi3xx
167
165
if envs .VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx (
168
166
) and qinput .shape [0 ] == 1 and qinput .shape [1 ] % 16 == 0 :
@@ -175,10 +173,38 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
175
173
scale_a = scale_a ,
176
174
scale_b = scale_b ,
177
175
bias = bias )
176
+ return output
177
+
178
+
179
def rocm_per_tensor_w8a8_scaled_mm_fake(
        qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
        scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
        input_2d: torch.Tensor) -> torch.Tensor:
    """Fake (meta) implementation for the custom op registration.

    Allocates an uninitialized output with the GEMM result shape —
    qinput's leading dims with the last dim replaced by weight.shape[1] —
    in ``out_dtype``.  No computation is performed; the scale/bias/input_2d
    arguments exist only to mirror the real impl's signature.
    """
    out_shape = (*qinput.shape[:-1], weight.shape[1])
    return qinput.new_empty(out_shape, dtype=out_dtype)
178
185
186
+
187
def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                   weight: torch.Tensor,
                                   out_dtype: torch.dtype,
                                   scale_a: torch.Tensor,
                                   scale_b: torch.Tensor, bias: torch.Tensor,
                                   input_2d: torch.Tensor,
                                   output_shape: list) -> torch.Tensor:
    """Per-tensor W8A8 scaled matmul on ROCm via the registered custom op.

    Dispatches to ``torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl``,
    trims the result to the first ``input_2d.shape[0]`` rows (the impl may
    produce more rows than the logical input — hence the narrow), and
    reshapes it to ``output_shape``.
    """
    raw = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
        qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d)
    trimmed = torch.narrow(raw, 0, 0, input_2d.shape[0])
    return trimmed.view(*output_shape)
180
197
181
198
199
# Register the impl/fake pair as a torch custom op so the wrapper above can
# invoke it as torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl; the fake
# impl supplies output shape/dtype for tracing without running the kernel.
direct_register_custom_op(
    op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
    op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
    mutates_args=[],  # the op writes no inputs in place
    fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
    # NOTE(review): registered only for the current platform's dispatch key
    # (presumably CUDA/ROCm here) — confirm against direct_register_custom_op.
    dispatch_key=current_platform.dispatch_key,
)
206
+
207
+
182
208
def torch_per_tensor_w8a8_scaled_mm (* , qinput : torch .Tensor ,
183
209
weight : torch .Tensor ,
184
210
out_dtype : torch .dtype ,
0 commit comments