@@ -1073,7 +1073,6 @@ def quantized_matmul_int8(
     batch_block_size: int | None = None,
     out_block_size: int | None = None,
     in_block_size: int | None = None,
-    vmem_limit_bytes: int | None = 64 * 1024 * 1024,
 ) -> torch.Tensor:
   from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import (
       quantized_matmul_int8,
@@ -1084,6 +1083,7 @@ def quantized_matmul_int8(
   n_out_features, _ = w.shape
   jax_dtype = convert_torch_dtype_to_jax(x.dtype)
   import jax.numpy as jnp
+  # We fetch the tuned block sizes here instead of in the kernel because if we cannot find the block sizes (meaning we haven't tuned the kernel for that case), we fall back to the XLA quantized matmul kernel, which performs better than running the kernel with a default but untuned block size.
   batch_block_size, out_block_size, in_block_size = get_tuned_block_sizes(
       TUNED_BLOCK_SIZES, bs, n_out_features, n_in_features,
       jnp.dtype(jax_dtype).name, quantize_activation)
@@ -1096,7 +1096,6 @@ def quantized_matmul_int8(
       "batch_block_size": batch_block_size,
       "out_block_size": out_block_size,
       "in_block_size": in_block_size,
-      "vmem_limit_bytes": vmem_limit_bytes
   })
   from torch_xla.experimental.xla_quantized_matmul import quantized_matmul_xla
   return quantized_matmul_xla(
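The comment introduced above spells out a lookup-then-fallback policy: block sizes come from a table of tuned configurations, and when the lookup fails the op routes to the XLA quantized matmul instead of launching the Pallas kernel with untuned defaults. Below is a minimal, self-contained sketch of that pattern; the table layout, key format, and helper names (_TUNED, lookup_block_sizes, matmul_with_fallback) are illustrative assumptions, not the actual TUNED_BLOCK_SIZES / get_tuned_block_sizes definitions in torch_xla.

# Sketch only: the real TUNED_BLOCK_SIZES table and get_tuned_block_sizes()
# in torch_xla may use a different key format and return convention.
from typing import Optional, Tuple

BlockSizes = Tuple[int, int, int]  # (batch_block, out_block, in_block)

# Hypothetical tuning table keyed by problem shape, dtype name, and
# whether the activation is quantized.
_TUNED: dict[tuple[int, int, int, str, bool], BlockSizes] = {
    (1024, 8192, 8192, "bfloat16", True): (256, 1024, 1024),
}

def lookup_block_sizes(bs: int, n_out: int, n_in: int, dtype_name: str,
                       quantize_activation: bool) -> Optional[BlockSizes]:
  """Returns tuned block sizes, or None if this case was never tuned."""
  return _TUNED.get((bs, n_out, n_in, dtype_name, quantize_activation))

def matmul_with_fallback(bs, n_out, n_in, dtype_name, quantize_activation,
                         run_pallas, run_xla):
  sizes = lookup_block_sizes(bs, n_out, n_in, dtype_name, quantize_activation)
  if sizes is None:
    # No tuned entry: the XLA kernel beats an untuned Pallas launch.
    return run_xla()
  batch_block, out_block, in_block = sizes
  return run_pallas(batch_block, out_block, in_block)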
@@ -1737,7 +1736,7 @@ def gmm_non_xla(lhs: torch.Tensor,


 XLA_LIB.define(
-    "quantized_matmul_int8(Tensor x, Tensor w, Tensor scalar, Tensor? zero_point=None, Tensor? quant_block_size=None, bool quantize_activation=False, int? batch_block_size=None, int? out_block_size=None, int? in_block_size=None, int? vmem_limit_bytes=None) -> Tensor",
+    "quantized_matmul_int8(Tensor x, Tensor w, Tensor scalar, Tensor? zero_point=None, Tensor? quant_block_size=None, bool quantize_activation=False, int? batch_block_size=None, int? out_block_size=None, int? in_block_size=None) -> Tensor",
 )


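For reference, a hedged call-site sketch against the updated schema, with block sizes now resolved internally and no vmem_limit_bytes argument. The torch.ops.xla namespace, the per-output-channel layout of scalar, and the example shapes are assumptions for illustration, not taken from this diff.

# Assumes XLA_LIB registers into the "xla" namespace so the op is reachable
# as torch.ops.xla.quantized_matmul_int8; adjust if the namespace differs.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(16, 4096, dtype=torch.bfloat16, device=device)    # activations (bs, n_in)
w = torch.randint(-128, 127, (11008, 4096), dtype=torch.int8,
                  device=device)                                   # int8 weights (n_out, n_in)
scalar = torch.randn(11008, dtype=torch.bfloat16, device=device)   # assumed per-channel scales

# zero_point / quant_block_size keep their None defaults; the tuned block
# sizes are looked up inside the op.
out = torch.ops.xla.quantized_matmul_int8(x, w, scalar, quantize_activation=True)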
@@ -1752,11 +1751,10 @@ def quantized_matmul_int8_xla(
     batch_block_size: int | None = None,
     out_block_size: int | None = None,
     in_block_size: int | None = None,
-    vmem_limit_bytes: int | None = 64 * 1024 * 1024,
 ) -> torch.Tensor:
   return quantized_matmul_int8(x, w, scalar, zero_point, quant_block_size,
                                quantize_activation, batch_block_size,
-                               out_block_size, in_block_size, vmem_limit_bytes)
+                               out_block_size, in_block_size)


 @impl(XLA_LIB, "quantized_matmul_int8", "CompositeExplicitAutograd")
@@ -1770,7 +1768,6 @@ def quantized_matmul_int8_non_xla(
     batch_block_size: int | None = None,
     out_block_size: int | None = None,
     in_block_size: int | None = None,
-    vmem_limit_bytes: int | None = 64 * 1024 * 1024,
 ) -> torch.Tensor:
   # This will be called when dynamo uses fake tensors to construct the fake output.
   # We need to make sure the output tensor's shape is correct.
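The CompositeExplicitAutograd registration above only has to hand dynamo a tensor with the right shape and dtype during fake-tensor tracing. A minimal sketch of that kind of shape-only implementation, assuming the output dtype simply follows x, is shown below; the helper name is hypothetical and this is not the body of quantized_matmul_int8_non_xla.

import torch

def _fake_quantized_matmul_int8(x: torch.Tensor, w: torch.Tensor,
                                scalar: torch.Tensor) -> torch.Tensor:
  # Shape-only stand-in: x is (bs, n_in), w is (n_out, n_in), so the
  # result is (bs, n_out). No arithmetic is performed; dynamo only
  # inspects the shape and dtype of the returned tensor.
  bs = x.shape[0]
  n_out_features = w.shape[0]
  return torch.empty(bs, n_out_features, dtype=x.dtype, device=x.device)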