Commit 4101ea5

Optimize w8a8 quantized matmul kernel (#9412)

1 parent 9c8ae9f · commit 4101ea5
File tree

4 files changed: +172 -111 lines changed

test/test_pallas.py

Lines changed: 63 additions & 25 deletions

@@ -1,6 +1,7 @@
 import logging
 import sys
 import unittest
+from unittest.mock import patch
 from absl.testing import parameterized

 import torch
@@ -877,6 +878,11 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
         use_dynamo=False,
     )

+  def _compute_rel_error(self, x, q_x):
+    return torch.mean(torch.sqrt(torch.mean(torch.square(q_x - x),
+                                            axis=1))) / torch.sqrt(
+                                                torch.mean(torch.square(x)))
+
   def _test_quantized_matmul_int8(
       self,
       dtype,
@@ -885,10 +891,6 @@ def _test_quantized_matmul_int8(
       n_output_features,
       quantize_activation,
       use_dynamo,
-      batch_block_size=None,
-      out_block_size=None,
-      in_block_size=None,
-      atol=1.5,
       n_bits=8,
   ):
     x = torch.randn((bs, n_input_features), dtype=dtype)
@@ -918,17 +920,9 @@ def _test_quantized_matmul_int8(
     scalar_xla = scalar.to('xla')
     if use_dynamo:

-      def quantized_matmul_int8_wrapper(x, w_int, scalar, quantize_activation,
-                                        batch_block_size, out_block_size,
-                                        in_block_size):
+      def quantized_matmul_int8_wrapper(x, w_int, scalar, quantize_activation):
         return torch.ops.xla.quantized_matmul_int8(
-            x,
-            w_int,
-            scalar,
-            quantize_activation=quantize_activation,
-            batch_block_size=batch_block_size,
-            out_block_size=out_block_size,
-            in_block_size=in_block_size)
+            x, w_int, scalar, quantize_activation=quantize_activation)

       quantized_matmul_int8 = torch.compile(
           quantized_matmul_int8_wrapper, backend="openxla")
@@ -941,46 +935,90 @@ def quantized_matmul_int8_wrapper(x, w_int, scalar, quantize_activation,
         w_int_xla,
         scalar_xla,
         quantize_activation=quantize_activation,
-        batch_block_size=batch_block_size,
-        out_block_size=out_block_size,
-        in_block_size=in_block_size).cpu()
+    ).cpu()
+
+    # print(f'Output max diff: {torch.max(torch.abs(expected - actual))}')
+    # print(f'Output mean diff: {torch.mean(torch.abs(expected - actual))}')
+    rel_error = self._compute_rel_error(expected, actual)

     self.assertEqual(actual.shape, expected.shape)
     self.assertEqual(actual.dtype, expected.dtype)
-    self.assertTrue(torch.allclose(actual, expected, atol=atol))
+    self.assertTrue(rel_error < 3e-2)
+
+  @parameterized.product(
+      dtype=[torch.bfloat16
+            ],  # not testing float32 because we haven't tuned for float32 case.
+      quantize_activation=[True],
+      use_dynamo=[True, False],
+  )
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 5,
+                   "This test only works on TPUv5+.")
+  @patch(
+      'torch_xla.experimental.pallas_kernels.quantized_matmul_kernel.get_tpu_version'
+  )
+  def test_quantized_matmul_int8_wrapper_key_exists_in_table(
+      self,
+      get_tpu_version,
+      dtype,
+      quantize_activation,
+      use_dynamo,
+  ):
+    from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import TUNED_BLOCK_SIZES
+    num_cases_to_test = 2
+    if len(TUNED_BLOCK_SIZES) < num_cases_to_test:
+      self.fail(
+          "Not enough tuned block sizes for quantized matmul int8 test. But we should have {num_cases_to_test} block sizes to test."
+      )
+    input_shapes = []
+    for key in TUNED_BLOCK_SIZES.keys():
+      if len(input_shapes) >= num_cases_to_test:
+        break
+      _, batch_size, n_output_features, n_input_features, *_ = key
+      input_shapes.append((batch_size, n_output_features, n_input_features))
+    tpu_version_to_use = 6
+    get_tpu_version.return_value = tpu_version_to_use
+    for batch_size, n_output_features, n_input_features in input_shapes:
+      self._test_quantized_matmul_int8(
+          dtype,
+          batch_size,
+          n_input_features,
+          n_output_features,
+          quantize_activation,
+          use_dynamo=use_dynamo,
+      )

   @parameterized.product(
       dtype=[torch.bfloat16, torch.float32],
       bs=[256, 512],
       n_input_features=[256, 512],
       n_output_features=[256, 512],
       quantize_activation=[True],
-      kernel_block_sizes=[(None, None, None), (256, 256, 256)],
       use_dynamo=[True, False],
   )
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 5,
                    "This test only works on TPUv5+.")
-  def test_quantized_matmul_int8_wrapper(
+  @patch(
+      'torch_xla.experimental.pallas_kernels.quantized_matmul_kernel.get_tuned_block_sizes'
+  )
+  def test_quantized_matmul_int8_wrapper_key_not_exists_in_table(
       self,
+      get_tuned_block_sizes,
       dtype,
       bs,
       n_input_features,
       n_output_features,
       quantize_activation,
-      kernel_block_sizes,
       use_dynamo,
   ):
-    batch_block_size, out_block_size, in_block_size = kernel_block_sizes
+    get_tuned_block_sizes.return_value = (None, None, None)
     self._test_quantized_matmul_int8(
         dtype,
         bs,
         n_input_features,
         n_output_features,
         quantize_activation,
         use_dynamo=use_dynamo,
-        batch_block_size=batch_block_size,
-        out_block_size=out_block_size,
-        in_block_size=in_block_size)
+    )

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")

test/test_quantized_matmul_pallas_kernel.py

Lines changed: 13 additions & 7 deletions

@@ -128,7 +128,8 @@ def test_quantized_matmul_retrieve_block_sizes(self, get_tpu_version):
       break
     expected_block_sizes = TUNED_BLOCK_SIZES[key0]
     _, bs, n_output_features, n_input_features, activation_dtype, quantize_activation = key0
-    actual_block_sizes = get_tuned_block_sizes(bs, n_output_features,
+    actual_block_sizes = get_tuned_block_sizes(TUNED_BLOCK_SIZES, bs,
+                                               n_output_features,
                                                n_input_features,
                                                activation_dtype,
                                                quantize_activation)
@@ -145,12 +146,17 @@ def test_quantized_matmul_use_tuned_block_sizes(self, dtype, bs,
                                                   n_input_features,
                                                   n_output_features,
                                                   quantize_activation):
-    self._test_quantized_matmul(
-        dtype,
-        bs,
-        n_input_features,
-        n_output_features,
-        quantize_activation=quantize_activation)
+    with self.assertRaises(AssertionError):
+      self._test_quantized_matmul(
+          dtype,
+          bs,
+          n_input_features,
+          n_output_features,
+          quantize_activation=quantize_activation,
+          batch_block_size=None,
+          out_block_size=None,
+          in_block_size=None,
+      )


 if __name__ == "__main__":
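
The first hunk reflects the new calling convention: get_tuned_block_sizes now takes the lookup table as an explicit first argument, and the second hunk expects an AssertionError when all three block sizes are passed as None. A hedged sketch of the new lookup; the example shape, dtype string, and comments are illustrative assumptions based on the tests in this commit:

from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import (
    TUNED_BLOCK_SIZES,
    get_tuned_block_sizes,
)

# The table is passed explicitly; the remaining arguments describe the problem.
# The dtype is the jnp dtype name string, matching jnp.dtype(...).name as used
# in custom_kernel.py. The shape below is illustrative, not from the commit.
blocks = get_tuned_block_sizes(
    TUNED_BLOCK_SIZES,
    512,          # batch size
    256,          # n_output_features
    256,          # n_input_features
    'bfloat16',   # activation dtype name
    True,         # quantize_activation
)
# Returns (batch_block_size, out_block_size, in_block_size), or
# (None, None, None) when the shape/dtype key has no tuned entry.
batch_block_size, out_block_size, in_block_size = blocks

Passing the table explicitly also makes the lookup easy to patch in tests, which is what the new test_pallas.py cases do with get_tuned_block_sizes and get_tpu_version.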

torch_xla/experimental/custom_kernel.py

Lines changed: 26 additions & 11 deletions

@@ -1075,17 +1075,32 @@ def quantized_matmul_int8(
     in_block_size: int | None = None,
     vmem_limit_bytes: int | None = 64 * 1024 * 1024,
 ) -> torch.Tensor:
-  from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import quantized_matmul_int8
-  return xb.call_jax(
-      quantized_matmul_int8, (x, w, scalar), {
-          "zero_point": zero_point,
-          "quant_block_size": quant_block_size,
-          "quantize_activation": quantize_activation,
-          "batch_block_size": batch_block_size,
-          "out_block_size": out_block_size,
-          "in_block_size": in_block_size,
-          "vmem_limit_bytes": vmem_limit_bytes
-      })
+  from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import (
+      quantized_matmul_int8,
+      get_tuned_block_sizes,
+      TUNED_BLOCK_SIZES,
+  )
+  bs, n_in_features = x.shape
+  n_out_features, _ = w.shape
+  jax_dtype = convert_torch_dtype_to_jax(x.dtype)
+  import jax.numpy as jnp
+  batch_block_size, out_block_size, in_block_size = get_tuned_block_sizes(
+      TUNED_BLOCK_SIZES, bs, n_out_features, n_in_features,
+      jnp.dtype(jax_dtype).name, quantize_activation)
+  if batch_block_size is not None and out_block_size is not None and in_block_size is not None:
+    return xb.call_jax(
+        quantized_matmul_int8, (x, w, scalar), {
+            "zero_point": zero_point,
+            "quant_block_size": quant_block_size,
+            "quantize_activation": quantize_activation,
+            "batch_block_size": batch_block_size,
+            "out_block_size": out_block_size,
+            "in_block_size": in_block_size,
+            "vmem_limit_bytes": vmem_limit_bytes
+        })
+  from torch_xla.experimental.xla_quantized_matmul import quantized_matmul_xla
+  return quantized_matmul_xla(
+      x, w, scalar, quantize_activation=quantize_activation)


 def _multi_queries_paged_attention_nonkernel(
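
With the lookup wired into this wrapper, callers no longer pass block sizes: quantized_matmul_int8 derives them from the input shapes and activation dtype, dispatches to the Pallas kernel when all three tuned sizes are found, and otherwise falls back to quantized_matmul_xla. A hedged usage sketch on an XLA device; the tensor shapes and the per-output-channel scale layout are illustrative assumptions, not part of the commit:

import torch
import torch_xla  # registers the torch.ops.xla.* custom ops

# Illustrative shapes: x is (batch, in_features), w_int is (out_features, in_features).
x = torch.randn(512, 256, dtype=torch.bfloat16).to('xla')
w_int = torch.randint(-128, 128, (256, 256), dtype=torch.int8).to('xla')
scalar = torch.randn(256, dtype=torch.bfloat16).to('xla')  # assumed per-output-channel scale

# Block sizes are resolved internally from TUNED_BLOCK_SIZES; if the shape/dtype
# key has no tuned entry, the call falls back to the XLA quantized matmul.
out = torch.ops.xla.quantized_matmul_int8(
    x, w_int, scalar, quantize_activation=True)
print(out.shape)  # torch.Size([512, 256])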
