
Commit 52569ec

Unify the return type of w8a8 matmul between fallback and the actual impl. (#9452)
1 parent 156d913 commit 52569ec

3 files changed, +42 -8 lines changed

test/test_pallas.py

Lines changed: 29 additions & 5 deletions
@@ -878,10 +878,10 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
         use_dynamo=False,
     )
 
+  # compute normalized Frobenius error.
   def _compute_rel_error(self, x, q_x):
-    return torch.mean(torch.sqrt(torch.mean(torch.square(q_x - x),
-                                             axis=1))) / torch.sqrt(
-                                                 torch.mean(torch.square(x)))
+    abs_error = torch.sqrt(torch.mean(torch.square(q_x - x), axis=1))
+    return torch.mean(abs_error) / torch.sqrt(torch.mean(torch.square(x)))
 
   def _test_quantized_matmul_int8(
       self,
@@ -909,7 +909,9 @@ def _test_quantized_matmul_int8(
         qscheme=torch.per_channel_symmetric)
     w_int = torch.ops.quantized_decomposed.quantize_per_channel(
         w, scalar, zero_point, 0, int_min, int_max, torch.int8)
-    scalar = scalar.to(w.dtype)
+    # In the actual workload such as vLLM, the scalar is obtained
+    # offline and is usually in float32.
+    scalar = scalar.to(torch.float32)
 
     x_copy = x.clone()
     w_copy = w.clone()
@@ -942,7 +944,7 @@ def quantized_matmul_int8_wrapper(x, w_int, scalar, quantize_activation):
     rel_error = self._compute_rel_error(expected, actual)
 
     self.assertEqual(actual.shape, expected.shape)
-    self.assertEqual(actual.dtype, expected.dtype)
+    self.assertEqual(actual.dtype, x.dtype)
     self.assertTrue(rel_error < 3e-2)
 
   @parameterized.product(
@@ -1020,6 +1022,28 @@ def test_quantized_matmul_int8_wrapper_key_not_exists_in_table(
         use_dynamo=use_dynamo,
     )
 
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  @parameterized.product(
+      dtype=[torch.bfloat16, torch.float32],
+      use_dynamo=[True, False],
+  )
+  def test_quantized_matmul_int8_wrapper_fallback(self, dtype, use_dynamo):
+    x = torch.randn(10, 20, device='meta', dtype=dtype)
+    w = torch.randint(-128, 127, (30, 20), device='meta', dtype=torch.int8)
+    scalar = torch.randn(30, device='meta', dtype=torch.float32)
+    if use_dynamo:
+
+      def quantized_matmul_int8_wrapper(x, w_int, scalar, quantize_activation):
+        return torch.ops.xla.quantized_matmul_int8(
+            x, w_int, scalar, quantize_activation=quantize_activation)
+
+      quantized_matmul_int8 = torch.compile(
+          quantized_matmul_int8_wrapper, backend="openxla")
+    else:
+      quantized_matmul_int8 = torch.ops.xla.quantized_matmul_int8
+    res = quantized_matmul_int8(x, w, scalar, quantize_activation=True)
+    self.assertEqual(res.dtype, x.dtype)
+
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
   def test_paged_attention_multi_queries_wrapper(self):

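For reference, the error metric the updated tests rely on is a normalized Frobenius-style error: the per-row RMS of the quantization error, averaged over rows and divided by the RMS of the reference tensor. A minimal standalone sketch of how it behaves is below; the function name and the toy tensors are illustrative, not part of the commit.

import torch

def rel_error(x, q_x):
  # Per-row RMS of the error, averaged over rows, normalized by the RMS of x.
  abs_error = torch.sqrt(torch.mean(torch.square(q_x - x), axis=1))
  return torch.mean(abs_error) / torch.sqrt(torch.mean(torch.square(x)))

x = torch.randn(4, 8)
q_x = x + 0.01 * torch.randn(4, 8)  # stand-in for a dequantized result
print(rel_error(x, q_x))            # roughly 1e-2, well under the 3e-2 test bound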
test/test_quantized_matmul_pallas_kernel.py

Lines changed: 11 additions & 1 deletion
@@ -26,6 +26,13 @@ def quantize_array(x, n_bits: int = 8, dim: int = -1):
   return x_int, scale.astype(x.dtype)
 
 
+# compute normalized Frobenius error.
+@jax.jit
+def _compute_rel_error(x, q_x):
+  abs_error = jnp.sqrt(jnp.mean(jnp.square(q_x - x), axis=1))
+  return jnp.mean(abs_error) / jnp.sqrt(jnp.mean(jnp.square(x)))
+
+
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class QuantizedMatmulKernelTest(jtu.JaxTestCase):
 
@@ -69,7 +76,10 @@ def _test_quantized_matmul(self,
     expected = jax.lax.dot_general(
         x_copy, w_copy, dimension_numbers=(((1,), (1,)), ((), ())))
 
-    self.assertEqual(output.dtype, expected.dtype)
+    rel_error = _compute_rel_error(expected, output)
+    self.assertTrue(rel_error < 2e-2)
+
+    self.assertEqual(output.dtype, x.dtype)
     self.assertEqual(output.shape, expected.shape)
     self.assertAllClose(output, expected, atol=atol)

torch_xla/experimental/custom_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -1100,7 +1100,7 @@ def quantized_matmul_int8(
       })
   from torch_xla.experimental.xla_quantized_matmul import quantized_matmul_xla
   return quantized_matmul_xla(
-      x, w, scalar, quantize_activation=quantize_activation)
+      x, w, scalar, quantize_activation=quantize_activation).to(x.dtype)
 
 
 def _multi_queries_paged_attention_nonkernel(
@@ -1778,4 +1778,4 @@ def quantized_matmul_int8_non_xla(
   warnings.warn(
       f'XLA quantized_matmul_int8 should only be applied to tensors on XLA device'
   )
-  return torch.empty(x.shape[0], w.shape[0], device=x.device)
+  return torch.empty(x.shape[0], w.shape[0], device=x.device, dtype=x.dtype)
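Taken together, the contract this commit enforces is that torch.ops.xla.quantized_matmul_int8 returns a tensor with the activation's dtype on both the XLA kernel path (explicit .to(x.dtype)) and the non-XLA fallback (dtype=x.dtype on the placeholder). A minimal sketch modeled on the new fallback test is below; the shapes and the meta device come from that test, and the import that registers the op is an assumption that may differ across torch_xla versions.

import torch
# Assumption: importing torch_xla's custom kernels registers the op used below.
import torch_xla.experimental.custom_kernel  # noqa: F401

x = torch.randn(10, 20, device='meta', dtype=torch.bfloat16)
w = torch.randint(-128, 127, (30, 20), device='meta', dtype=torch.int8)
scalar = torch.randn(30, device='meta', dtype=torch.float32)

# Meta tensors take the non-XLA fallback (quantized_matmul_int8_non_xla), which
# now allocates its placeholder output with dtype=x.dtype; the XLA kernel path
# likewise casts its result back to x.dtype, so both paths agree.
res = torch.ops.xla.quantized_matmul_int8(x, w, scalar, quantize_activation=True)
assert res.dtype == x.dtype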
