Commit 98e1601

Xia-Weiwen authored and pytorchmergebot committed
[Quant][CPU] add a wrapper op for _weight_int4pack_mm_for_cpu with tensor args (pytorch#145245)

**Summary**
This is part of the task to enable max-autotune with the GEMM template for WoQ INT4 GEMM on CPU. This PR adds a wrapper op in the `quantized` namespace for `torch.ops.aten._weight_int4pack_mm_for_cpu`, whose arguments are all tensors. It will be used in Inductor lowering with max-autotune, where scalar arguments are difficult to handle. The new op is not registered to
- `aten`, because that would require changing `native_functions.yaml`, which is not recommended.
- `quantized_decomposed`, because it would only have a Python implementation, which cannot be used with the cpp wrapper in Inductor.

**Test plan**
```
python test/test_linalg.py -k test__int4_mm
```

Pull Request resolved: pytorch#145245
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jerryzh168
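
For reference, here is a minimal usage sketch of the new wrapper op, mirroring the test added in `test/test_linalg.py` below. It is not part of the commit; the preparation of `a`, `b_int4pack`, `q_group`, and `b_scales_and_zeros` is assumed to be the same as for the existing `torch._weight_int4pack_mm_for_cpu`, and the shapes in the comments are my reading of the checks in this PR.

```python
import torch

def int4_mm_via_wrapper(a, b_int4pack, q_group, b_scales_and_zeros):
    # Assumed inputs (prepared elsewhere, as for torch._weight_int4pack_mm_for_cpu):
    #   a:                  [M, K] activation, fp32/fp16/bf16
    #   b_int4pack:         2D uint8 packed INT4 weight
    #   q_group:            Python int group size (32/64/128/256)
    #   b_scales_and_zeros: [K // q_group, N, 2] scales/zero points, same dtype as a

    # Existing aten op: the group size is a Python scalar.
    c = torch._weight_int4pack_mm_for_cpu(a, b_int4pack, q_group, b_scales_and_zeros)

    # New wrapper op: all arguments are tensors, so the group size is passed as an
    # int64 tensor. This is the form Inductor's max-autotune lowering can handle.
    q_group_t = torch.tensor(q_group, dtype=torch.int64)
    c2 = torch.ops.quantized.int4mm_packed_weight_cpu(a, b_int4pack, q_group_t, b_scales_and_zeros)

    assert torch.equal(c, c2)  # both paths dispatch to the same CPU kernel
    return c
```
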
1 parent ac0f206 commit 98e1601

6 files changed, +49 -1 lines changed

aten/src/ATen/native/LinearAlgebra.cpp
Lines changed: 2 additions & 0 deletions

@@ -3489,6 +3489,8 @@ Tensor _weight_int4pack_mm_cpu(
   TORCH_CHECK(qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128
       || qGroupSize == 256,
       __func__, ": expect qGroupSize to be 32, 64, 128 or 256, got ", qGroupSize);
+  TORCH_CHECK(K % qGroupSize == 0,
+      __func__, ": expect K to be divisible by qGroupSize, got K:", K, ", qGroupSize:", qGroupSize);
 
   TORCH_CHECK(qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(1) == N
       && qScaleAndZeros.size(2) == 2,

aten/src/ATen/native/quantized/cpu/qlinear.cpp
Lines changed: 13 additions & 0 deletions

@@ -25,6 +25,7 @@
 #include <ATen/ops/quantize_per_channel_native.h> // for quantize_per_ch...
 #include <ATen/ops/quantize_per_tensor_native.h> // for quantize_per_te...
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/_weight_int4pack_mm_for_cpu.h>
 #endif
 
 #include <c10/util/irange.h>
@@ -1179,6 +1180,17 @@ namespace at::native {
   TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
 }
 
+Tensor _weight_int4pack_mm_cpu_tensor(
+    const Tensor& A,
+    const Tensor& B,
+    const Tensor& qGroupSize,
+    const Tensor& qScaleAndZeros) {
+  TORCH_CHECK(qGroupSize.numel() == 1, __func__, ": group size must be a scalar.");
+  TORCH_CHECK(qGroupSize.scalar_type() == c10::kLong, __func__, ": group size must be int64.");
+  int group_size = qGroupSize.item<int64_t>();
+  return at::_weight_int4pack_mm_for_cpu(A, B, group_size, qScaleAndZeros);
+}
+
 
 namespace {
 
@@ -1346,6 +1358,7 @@ TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
 TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_with_input_q_dq_qweight_dq_output_fp32"), TORCH_FN(QLinearInt8FusedQDQ<false>::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32"), TORCH_FN(QLinearInt8FusedQDQ<true>::run));
+  m.impl(TORCH_SELECTIVE_NAME("quantized::int4mm_packed_weight_cpu"), TORCH_FN(at::native::_weight_int4pack_mm_cpu_tensor));
 }
 
 TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {

aten/src/ATen/native/quantized/cpu/qlinear.h
Lines changed: 6 additions & 0 deletions

@@ -42,4 +42,10 @@ C10_API static Tensor run_pointwise_binary_tensor(
       std::string_view unary_post_op_algorithm);
 };
 
+C10_API Tensor _weight_int4pack_mm_cpu_tensor(
+    const Tensor& A,
+    const Tensor& B,
+    const Tensor& qGroupSize,
+    const Tensor& qScaleAndZeros);
+
 } // namespace at::native

aten/src/ATen/native/quantized/library.cpp
Lines changed: 1 addition & 0 deletions

@@ -216,6 +216,7 @@ TORCH_LIBRARY(quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::prelu(Tensor qx, Tensor weight, float output_scale, int output_zero_point) -> Tensor"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::sigmoid(Tensor qx, float output_scale, int output_zero_point) -> Tensor"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::softmax(Tensor qx, int dim, float output_scale, int output_zero_point) -> Tensor"), {at::Tag::pt2_compliant_tag});
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::int4mm_packed_weight_cpu(Tensor self, Tensor mat2, Tensor qGroupSize, Tensor qScaleAndZeros) -> Tensor"));
 }
 
 // According to #33294: The "_" prefix registration will be

test/test_linalg.py
Lines changed: 8 additions & 1 deletion

@@ -6644,9 +6644,16 @@ def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
             if self.device_type == 'cpu':
                 self.assertTrue(b_int4pack.dtype is torch.uint8)
                 self.assertTrue(b_int4pack.dim() == 2)
-                return torch._weight_int4pack_mm_for_cpu(
+                c = torch._weight_int4pack_mm_for_cpu(
                     a, b_int4pack, q_group, b_scales_and_zeros
                 )
+                # test wrapper
+                q_group_t = torch.tensor(q_group, dtype=torch.int64, device=device)
+                c_2 = torch.ops.quantized.int4mm_packed_weight_cpu(
+                    a, b_int4pack, q_group_t, b_scales_and_zeros
+                )
+                assert torch.equal(c, c_2)
+                return c
             else:
                 self.assertTrue(b_int4pack.dtype is torch.int32)
                 self.assertTrue(b_int4pack.dim() == 4)

torch/_meta_registrations.py
Lines changed: 19 additions & 0 deletions

@@ -2658,6 +2658,25 @@ def meta_quantized_max_pool2d(
         memory_format=memory_format,
     )
 
+@register_meta(torch.ops.quantized.int4mm_packed_weight_cpu)
+def meta_int4mm_packed_weight_cpu(x, w, q_group_size, q_scale_and_zeros):
+    torch._check(x.dim() == 2, f"x must be a 2D tensor, got {x.dim()}D")
+    torch._check(w.dim() == 2, f"w must be a 2D tensor, got {w.dim()}D")
+    torch._check(
+        x.dtype in [torch.float32, torch.float16, torch.bfloat16],
+        f"expected x to be f32/f16/bf16, got {x.dtype}",
+    )
+    torch._check(w.dtype == torch.uint8, f"expected w to be uint8, got {w.dtype}")
+    torch._check(
+        q_group_size.dtype == torch.int64,
+        f"q_group_size must be int64, got {q_group_size.dtype}",
+    )
+    torch._check(
+        q_scale_and_zeros.dtype == x.dtype,
+        f"q_scale_and_zeros must have the same dtype as x, got {q_scale_and_zeros.dtype}",
+    )
+    return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
+
 
 # from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
 def check_dim_size(tensor, dim, dim_size, size):
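
As an aside (not part of the commit), the meta registration above is what allows the op to be traced without real data. Below is a minimal sketch under `FakeTensorMode`, which Inductor tracing relies on; the shapes are illustrative assumptions (M=4, K=64, N=32, group size 32), and the output shape/dtype should come from `meta_int4mm_packed_weight_cpu` via `x.new_empty(x.size(0), w.size(0))`.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Illustrative, assumed shapes; no real data is allocated under fake mode.
    x = torch.empty(4, 64, dtype=torch.bfloat16)              # activation [M, K]
    w = torch.empty(32, 32, dtype=torch.uint8)                 # packed INT4 weight (2D, uint8)
    group_size = torch.empty((), dtype=torch.int64)            # tensor-valued group size
    scales_and_zeros = torch.empty(2, 32, 2, dtype=torch.bfloat16)

    out = torch.ops.quantized.int4mm_packed_weight_cpu(x, w, group_size, scales_and_zeros)
    print(out.shape, out.dtype)  # expected: torch.Size([4, 32]) torch.bfloat16
```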
