@@ -878,10 +878,10 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
         use_dynamo=False,
     )

+  # Compute the normalized Frobenius error.
   def _compute_rel_error(self, x, q_x):
-    return torch.mean(torch.sqrt(torch.mean(torch.square(q_x - x),
-                                            axis=1))) / torch.sqrt(
-                                                torch.mean(torch.square(x)))
+    abs_error = torch.sqrt(torch.mean(torch.square(q_x - x), axis=1))
+    return torch.mean(abs_error) / torch.sqrt(torch.mean(torch.square(x)))

   def _test_quantized_matmul_int8(
       self,
@@ -909,7 +909,9 @@ def _test_quantized_matmul_int8(
         qscheme=torch.per_channel_symmetric)
     w_int = torch.ops.quantized_decomposed.quantize_per_channel(
         w, scalar, zero_point, 0, int_min, int_max, torch.int8)
-    scalar = scalar.to(w.dtype)
+    # In an actual workload such as vLLM, the scalar is obtained
+    # offline and is usually in float32.
+    scalar = scalar.to(torch.float32)

     x_copy = x.clone()
     w_copy = w.clone()
@@ -942,7 +944,7 @@ def quantized_matmul_int8_wrapper(x, w_int, scalar, quantize_activation):
     rel_error = self._compute_rel_error(expected, actual)

     self.assertEqual(actual.shape, expected.shape)
-    self.assertEqual(actual.dtype, expected.dtype)
+    self.assertEqual(actual.dtype, x.dtype)
     self.assertTrue(rel_error < 3e-2)

   @parameterized.product(
@parameterized .product (
@@ -1020,6 +1022,28 @@ def test_quantized_matmul_int8_wrapper_key_not_exists_in_table(
         use_dynamo=use_dynamo,
     )

+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  @parameterized.product(
+      dtype=[torch.bfloat16, torch.float32],
+      use_dynamo=[True, False],
+  )
+  def test_quantized_matmul_int8_wrapper_fallback(self, dtype, use_dynamo):
+    x = torch.randn(10, 20, device='meta', dtype=dtype)
+    w = torch.randint(-128, 127, (30, 20), device='meta', dtype=torch.int8)
+    scalar = torch.randn(30, device='meta', dtype=torch.float32)
+    if use_dynamo:
+
+      def quantized_matmul_int8_wrapper(x, w_int, scalar, quantize_activation):
+        return torch.ops.xla.quantized_matmul_int8(
+            x, w_int, scalar, quantize_activation=quantize_activation)
+
+      quantized_matmul_int8 = torch.compile(
+          quantized_matmul_int8_wrapper, backend="openxla")
+    else:
+      quantized_matmul_int8 = torch.ops.xla.quantized_matmul_int8
+    res = quantized_matmul_int8(x, w, scalar, quantize_activation=True)
+    self.assertEqual(res.dtype, x.dtype)
+
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
   def test_paged_attention_multi_queries_wrapper(self):
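
A minimal standalone sketch of the normalized Frobenius error check that the updated _compute_rel_error performs; the shapes, the noise scale, the helper name, and the threshold usage below are illustrative only, not part of the change itself:

import torch

def compute_rel_error(x, q_x):
  # Per-row RMS of the quantization error, averaged over rows, then
  # normalized by the RMS of the reference tensor.
  abs_error = torch.sqrt(torch.mean(torch.square(q_x - x), axis=1))
  return torch.mean(abs_error) / torch.sqrt(torch.mean(torch.square(x)))

x = torch.randn(10, 20)
q_x = x + 0.01 * torch.randn(10, 20)  # simulate a small quantization error
assert compute_rel_error(x, q_x) < 3e-2  # same threshold the test asserts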