@@ -1,6 +1,7 @@
 import logging
 import sys
 import unittest
+from unittest.mock import patch
 from absl.testing import parameterized

 import torch
@@ -877,6 +878,11 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
         use_dynamo=False,
     )

+  def _compute_rel_error(self, x, q_x):
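+    # Mean per-row RMS error of q_x against x, normalized by the overall RMS of x.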
+    return torch.mean(torch.sqrt(torch.mean(torch.square(q_x - x),
+                                            axis=1))) / torch.sqrt(
+                                                torch.mean(torch.square(x)))
+
   def _test_quantized_matmul_int8(
       self,
       dtype,
@@ -885,10 +891,6 @@ def _test_quantized_matmul_int8(
       n_output_features,
       quantize_activation,
       use_dynamo,
-      batch_block_size=None,
-      out_block_size=None,
-      in_block_size=None,
-      atol=1.5,
       n_bits=8,
   ):
     x = torch.randn((bs, n_input_features), dtype=dtype)
@@ -918,17 +920,9 @@ def _test_quantized_matmul_int8(
     scalar_xla = scalar.to('xla')
     if use_dynamo:

-      def quantized_matmul_int8_wrapper(x, w_int, scalar, quantize_activation,
-                                        batch_block_size, out_block_size,
-                                        in_block_size):
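+      # Block sizes are no longer passed explicitly; the kernel is expected to
+      # resolve them from its tuned table (see the TUNED_BLOCK_SIZES tests below).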
+      def quantized_matmul_int8_wrapper(x, w_int, scalar, quantize_activation):
         return torch.ops.xla.quantized_matmul_int8(
-            x,
-            w_int,
-            scalar,
-            quantize_activation=quantize_activation,
-            batch_block_size=batch_block_size,
-            out_block_size=out_block_size,
-            in_block_size=in_block_size)
+            x, w_int, scalar, quantize_activation=quantize_activation)

       quantized_matmul_int8 = torch.compile(
           quantized_matmul_int8_wrapper, backend="openxla")
@@ -941,46 +935,90 @@ def quantized_matmul_int8_wrapper(x, w_int, scalar, quantize_activation,
         w_int_xla,
         scalar_xla,
         quantize_activation=quantize_activation,
-        batch_block_size=batch_block_size,
-        out_block_size=out_block_size,
-        in_block_size=in_block_size).cpu()
+    ).cpu()
+
+    # print(f'Output max diff: {torch.max(torch.abs(expected - actual))}')
+    # print(f'Output mean diff: {torch.mean(torch.abs(expected - actual))}')
+    rel_error = self._compute_rel_error(expected, actual)

     self.assertEqual(actual.shape, expected.shape)
     self.assertEqual(actual.dtype, expected.dtype)
-    self.assertTrue(torch.allclose(actual, expected, atol=atol))
+    self.assertTrue(rel_error < 3e-2)
+
+  @parameterized.product(
+      dtype=[torch.bfloat16
+            ],  # not testing float32 because we haven't tuned for the float32 case.
+      quantize_activation=[True],
+      use_dynamo=[True, False],
+  )
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 5,
+                   "This test only works on TPUv5+.")
+  @patch(
+      'torch_xla.experimental.pallas_kernels.quantized_matmul_kernel.get_tpu_version'
+  )
+  def test_quantized_matmul_int8_wrapper_key_exists_in_table(
+      self,
+      get_tpu_version,
+      dtype,
+      quantize_activation,
+      use_dynamo,
+  ):
+    from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import TUNED_BLOCK_SIZES
+    num_cases_to_test = 2
+    if len(TUNED_BLOCK_SIZES) < num_cases_to_test:
+      self.fail(
+          f"Not enough tuned block sizes for the quantized matmul int8 test; we should have at least {num_cases_to_test} block sizes to test."
+      )
+    input_shapes = []
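+    # Take the first few input shapes that already have entries in the tuned block size table.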
+    for key in TUNED_BLOCK_SIZES.keys():
+      if len(input_shapes) >= num_cases_to_test:
+        break
+      _, batch_size, n_output_features, n_input_features, *_ = key
+      input_shapes.append((batch_size, n_output_features, n_input_features))
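+    # Pin the TPU version reported to the kernel so the block size lookup key is deterministic.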
+    tpu_version_to_use = 6
+    get_tpu_version.return_value = tpu_version_to_use
+    for batch_size, n_output_features, n_input_features in input_shapes:
+      self._test_quantized_matmul_int8(
+          dtype,
+          batch_size,
+          n_input_features,
+          n_output_features,
+          quantize_activation,
+          use_dynamo=use_dynamo,
+      )

   @parameterized.product(
       dtype=[torch.bfloat16, torch.float32],
       bs=[256, 512],
       n_input_features=[256, 512],
       n_output_features=[256, 512],
       quantize_activation=[True],
-      kernel_block_sizes=[(None, None, None), (256, 256, 256)],
       use_dynamo=[True, False],
   )
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 5,
                    "This test only works on TPUv5+.")
-  def test_quantized_matmul_int8_wrapper(
+  @patch(
+      'torch_xla.experimental.pallas_kernels.quantized_matmul_kernel.get_tuned_block_sizes'
+  )
+  def test_quantized_matmul_int8_wrapper_key_not_exists_in_table(
       self,
+      get_tuned_block_sizes,
       dtype,
       bs,
       n_input_features,
       n_output_features,
       quantize_activation,
-      kernel_block_sizes,
       use_dynamo,
   ):
-    batch_block_size, out_block_size, in_block_size = kernel_block_sizes
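+    # Force the tuned block size lookup to miss for every shape.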
+    get_tuned_block_sizes.return_value = (None, None, None)
     self._test_quantized_matmul_int8(
         dtype,
         bs,
         n_input_features,
         n_output_features,
         quantize_activation,
         use_dynamo=use_dynamo,
-        batch_block_size=batch_block_size,
-        out_block_size=out_block_size,
-        in_block_size=in_block_size)
+    )

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")