 #include <ATen/ops/quantize_per_channel_native.h>  // for quantize_per_ch...
 #include <ATen/ops/quantize_per_tensor_native.h>   // for quantize_per_te...
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/_weight_int4pack_mm_for_cpu.h>
 #endif
 
 #include <c10/util/irange.h>
@@ -1179,6 +1180,17 @@ namespace at::native { |
   TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
 }
 
+Tensor _weight_int4pack_mm_cpu_tensor(
+    const Tensor& A,
+    const Tensor& B,
+    const Tensor& qGroupSize,
+    const Tensor& qScaleAndZeros) {
+  TORCH_CHECK(qGroupSize.numel() == 1, __func__, ": group size must be a scalar.");
+  TORCH_CHECK(qGroupSize.scalar_type() == c10::kLong, __func__, ": group size must be int64.");
+  int group_size = qGroupSize.item<int64_t>();
+  return at::_weight_int4pack_mm_for_cpu(A, B, group_size, qScaleAndZeros);
+}
+
 
 namespace {
 
@@ -1346,6 +1358,7 @@ TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { |
 TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_with_input_q_dq_qweight_dq_output_fp32"), TORCH_FN(QLinearInt8FusedQDQ<false>::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32"), TORCH_FN(QLinearInt8FusedQDQ<true>::run));
+  m.impl(TORCH_SELECTIVE_NAME("quantized::int4mm_packed_weight_cpu"), TORCH_FN(at::native::_weight_int4pack_mm_cpu_tensor));
 }
 
 TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {
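
For reference, here is a minimal call-site sketch (not part of the patch) of how the newly registered quantized::int4mm_packed_weight_cpu op could be invoked from C++ through the dispatcher. The helper name run_int4mm is hypothetical, and the sketch assumes B and qScaleAndZeros were produced by the int4 weight-packing path expected by at::_weight_int4pack_mm_for_cpu.

// Hypothetical call-site sketch; run_int4mm is not part of this patch.
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor run_int4mm(const at::Tensor& A,
                      const at::Tensor& B,
                      int64_t group_size,
                      const at::Tensor& qScaleAndZeros) {
  // The op takes the group size as a 1-element int64 tensor; the wrapper
  // registered above unwraps it before calling _weight_int4pack_mm_for_cpu.
  at::Tensor qGroupSize = at::scalar_tensor(group_size, at::dtype(at::kLong));
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("quantized::int4mm_packed_weight_cpu", "")
      .typed<at::Tensor(const at::Tensor&, const at::Tensor&,
                        const at::Tensor&, const at::Tensor&)>();
  return op.call(A, B, qGroupSize, qScaleAndZeros);
}

Passing the group size as a tensor keeps the op signature all-Tensor, which is why the wrapper validates that qGroupSize is a scalar int64 tensor before forwarding the unwrapped value.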