Commit 83dc9da

Calculate vmem limit dynamically in the quantized matmul kernel. (#9470)
1 parent 1340308 commit 83dc9da

File tree

torch_xla/experimental/custom_kernel.py
torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py

2 files changed: 12 additions & 8 deletions

torch_xla/experimental/custom_kernel.py

Lines changed: 3 additions & 6 deletions
@@ -1073,7 +1073,6 @@ def quantized_matmul_int8(
     batch_block_size: int | None = None,
     out_block_size: int | None = None,
     in_block_size: int | None = None,
-    vmem_limit_bytes: int | None = 64 * 1024 * 1024,
 ) -> torch.Tensor:
   from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import (
       quantized_matmul_int8,
@@ -1084,6 +1083,7 @@ def quantized_matmul_int8(
   n_out_features, _ = w.shape
   jax_dtype = convert_torch_dtype_to_jax(x.dtype)
   import jax.numpy as jnp
+  # We fetch the tuned block sizes here instead of in the kernel because if we cannot find the block sizes (meaning we haven't tuned the kernel for that case), then we fall back to the XLA quantized matmul kernel, which has better perf than using kernel with a default but crappy block size.
   batch_block_size, out_block_size, in_block_size = get_tuned_block_sizes(
       TUNED_BLOCK_SIZES, bs, n_out_features, n_in_features,
       jnp.dtype(jax_dtype).name, quantize_activation)
@@ -1096,7 +1096,6 @@ def quantized_matmul_int8(
         "batch_block_size": batch_block_size,
         "out_block_size": out_block_size,
         "in_block_size": in_block_size,
-        "vmem_limit_bytes": vmem_limit_bytes
     })
     from torch_xla.experimental.xla_quantized_matmul import quantized_matmul_xla
     return quantized_matmul_xla(
@@ -1737,7 +1736,7 @@ def gmm_non_xla(lhs: torch.Tensor,
 
 
 XLA_LIB.define(
-    "quantized_matmul_int8(Tensor x, Tensor w, Tensor scalar, Tensor? zero_point=None, Tensor? quant_block_size=None, bool quantize_activation=False, int? batch_block_size=None, int? out_block_size=None, int? in_block_size=None, int? vmem_limit_bytes=None) -> Tensor",
+    "quantized_matmul_int8(Tensor x, Tensor w, Tensor scalar, Tensor? zero_point=None, Tensor? quant_block_size=None, bool quantize_activation=False, int? batch_block_size=None, int? out_block_size=None, int? in_block_size=None) -> Tensor",
 )
 
 
@@ -1752,11 +1751,10 @@ def quantized_matmul_int8_xla(
     batch_block_size: int | None = None,
     out_block_size: int | None = None,
     in_block_size: int | None = None,
-    vmem_limit_bytes: int | None = 64 * 1024 * 1024,
 ) -> torch.Tensor:
   return quantized_matmul_int8(x, w, scalar, zero_point, quant_block_size,
                                quantize_activation, batch_block_size,
-                               out_block_size, in_block_size, vmem_limit_bytes)
+                               out_block_size, in_block_size)
 
 
 @impl(XLA_LIB, "quantized_matmul_int8", "CompositeExplicitAutograd")
@@ -1770,7 +1768,6 @@ def quantized_matmul_int8_non_xla(
     batch_block_size: int | None = None,
     out_block_size: int | None = None,
     in_block_size: int | None = None,
-    vmem_limit_bytes: int | None = 64 * 1024 * 1024,
 ) -> torch.Tensor:
   # This will be called when dynamo use fake tensor to construct the fake output.
   # We need to make sure output tensor's shape is correct.
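
With vmem_limit_bytes removed from the op schema, callers now only choose (or omit) the three block sizes and the kernel derives its own VMEM budget. Below is a minimal sketch of invoking the updated op after this commit; the shapes, dtypes, per-channel scale layout, and the use of quantize_activation are illustrative assumptions, not values taken from the commit.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Illustrative shapes/dtypes (assumptions): bf16 activations, int8 weights,
# per-output-channel scales. Block sizes are omitted so the tuned table
# (or the XLA fallback) picks them.
x = torch.randn(128, 512, dtype=torch.bfloat16, device=device)
w = torch.randint(-128, 128, (1024, 512), dtype=torch.int8, device=device)
scalar = torch.randn(1024, dtype=torch.bfloat16, device=device)

# Note there is no vmem_limit_bytes argument any more; the Pallas kernel
# computes its own cap from the block sizes it ends up using.
out = torch.ops.xla.quantized_matmul_int8(x, w, scalar, quantize_activation=True)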

torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py

Lines changed: 9 additions & 2 deletions
@@ -90,7 +90,6 @@ def _next_multiple(x, multiple):
         'batch_block_size',
         'out_block_size',
         'in_block_size',
-        'vmem_limit_bytes',
     ])
 def quantized_matmul_int8(
     x: jax.Array,  # [bs, n_input_features]
@@ -104,7 +103,6 @@ def quantized_matmul_int8(
     batch_block_size: int | None = None,
     out_block_size: int | None = None,
     in_block_size: int | None = None,
-    vmem_limit_bytes: int | None = 64 * 1024 * 1024,
 ):
   assert zero_point is None, "Not implemented: zero_point is not supported."
   assert quant_block_size is None, "Not implemented: quant_block_size is not supported."
@@ -152,6 +150,15 @@ def quantized_matmul_int8(
       1] % in_block_size == 0, f"x.shape[1] ({x.shape[1]}) must be a multiple of block size ({in_block_size})"
 
   acc_dtype = jnp.int32 if quantize_activation else x.dtype
+  vmem_to_be_transferred = 2 * (
+      batch_block_size * in_block_size * x.dtype.itemsize +
+      out_block_size * in_block_size * w.dtype.itemsize + out_block_size *
+      scalar.dtype.itemsize + batch_block_size * x_abs_max_val.dtype.itemsize +
+      batch_block_size * out_block_size * x.dtype.itemsize
+  ) + batch_block_size * out_block_size * jnp.dtype(acc_dtype).itemsize
+  # Within the kernel, it will use some extra VMEM for computation or vreg spills.
+  vmem_used = vmem_to_be_transferred * 2
+  vmem_limit_bytes = min(vmem_used * 2, 96 * 1024 * 1024)
   kernel = pl.pallas_call(
       functools.partial(
           matmul_kernel,
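
As a rough sanity check of the new heuristic, the same arithmetic can be evaluated standalone. The block sizes and dtypes below are illustrative assumptions (bf16 activations, int8 weights, int32 accumulator), not values from the tuned block-size table.

import jax.numpy as jnp

# Assumed example configuration, not taken from the tuned table.
batch_block_size, out_block_size, in_block_size = 256, 512, 1024
x_itemsize = jnp.dtype(jnp.bfloat16).itemsize        # activation dtype
w_itemsize = jnp.dtype(jnp.int8).itemsize            # quantized weight dtype
scalar_itemsize = jnp.dtype(jnp.bfloat16).itemsize   # per-channel scale dtype
abs_max_itemsize = jnp.dtype(jnp.bfloat16).itemsize  # x_abs_max_val dtype
acc_itemsize = jnp.dtype(jnp.int32).itemsize         # accumulator (quantize_activation=True)

# Same formula the kernel adds above; the leading 2x presumably covers the
# double-buffered block transfers, with one accumulator block added on top.
vmem_to_be_transferred = 2 * (
    batch_block_size * in_block_size * x_itemsize +
    out_block_size * in_block_size * w_itemsize +
    out_block_size * scalar_itemsize +
    batch_block_size * abs_max_itemsize +
    batch_block_size * out_block_size * x_itemsize
) + batch_block_size * out_block_size * acc_itemsize

vmem_used = vmem_to_be_transferred * 2                # slack for compute / vreg spills
vmem_limit_bytes = min(vmem_used * 2, 96 * 1024 * 1024)
print(vmem_limit_bytes)  # 12,595,200 bytes (~12 MiB), well under the 96 MiB cap

For larger block sizes the estimate grows quickly, which is when the 96 MiB ceiling in the min() takes over, replacing the old fixed 64 MiB default.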
