
Commit bea86ee

Update ragged paged attention kernel to prevent vmem oom (#9346)
Signed-off-by: Chenyaaang <[email protected]>
1 parent 3a1ed62 commit bea86ee

2 files changed: 17 additions and 1 deletion (+17 −1)
torch_xla/experimental/custom_kernel.py

Lines changed: 1 addition, 1 deletion
@@ -1009,7 +1009,7 @@ def ragged_paged_attention(
   from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_attention
 
   if vmem_limit_bytes is None:
-    vmem_limit_bytes = 64 * 1024 * 1024
+    vmem_limit_bytes = 120 * 1024 * 1024
 
   payload, _ = trace_pallas(
       ragged_attention,
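
For reference, a minimal standalone sketch of the defaulting behavior this hunk changes: when the caller leaves vmem_limit_bytes unset, the wrapper now falls back to 120 MiB instead of 64 MiB before tracing the Pallas kernel. The helper name resolve_vmem_limit is hypothetical and only illustrates the inline check; callers can still pass an explicit vmem_limit_bytes to override the default.

def resolve_vmem_limit(vmem_limit_bytes=None):
  # Hypothetical helper mirroring the inline check above: fall back to the
  # new 120 MiB default when no explicit limit is passed.
  if vmem_limit_bytes is None:
    vmem_limit_bytes = 120 * 1024 * 1024
  return vmem_limit_bytes

assert resolve_vmem_limit() == 125829120                 # new default, 120 MiB
assert resolve_vmem_limit(64 * 1024 * 1024) == 67108864  # explicit value still wins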

torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py

Lines changed: 16 additions, 0 deletions
@@ -453,6 +453,22 @@ def get_tuned_block_sizes(
 
   # Default block sizes.
   bkv, bq = (128, 32)
+
+  def compute_actual_vmem_bytes(num_kv_pages_per_blk):
+    q_block = bq * num_q_heads_per_blk * head_dim * q_dtype.itemsize
+    in_block = q_block
+    out_block = in_block
+    kv_block = 2 * num_kv_pages_per_blk * page_size * 2 * num_kv_heads_per_blk * head_dim * kv_dtype.itemsize
+    l_ref = num_kv_heads_per_blk * bq * 2 * num_kv_heads_per_blk * 128 * jnp.float32.dtype.itemsize
+    m_ref = l_ref
+    acc_ref = bq * num_q_heads_per_blk * head_dim * jnp.float32.dtype.itemsize
+    return 2 * (in_block + out_block) + kv_block + l_ref + m_ref + acc_ref
+
+  # If the matrices are larger than 64MB, decrease num_kv_pages_per_blk by half.
+  while compute_actual_vmem_bytes(bkv) >= 64 * 1024 * 1024:
+    bkv //= 2
+  bkv = max(bkv, 1)
+
   if tpu_version == 4:
     # This default block size is not tuned, only make sure there's no
     # OOM in vmem
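
To see how the new back-off behaves, here is a self-contained sketch of the same estimate-and-halve loop with made-up shapes. The parameter values (head counts, head_dim, page_size) and the hard-coded dtype sizes (2-byte bf16 inputs, 4-byte float32 accumulators) are assumptions for illustration only; the real get_tuned_block_sizes reads them from its arguments and dtypes.

VMEM_BUDGET = 64 * 1024 * 1024  # same 64 MiB threshold as the while-loop above

def estimate_vmem_bytes(num_kv_pages_per_blk,
                        bq=32,
                        num_q_heads_per_blk=8,
                        num_kv_heads_per_blk=1,
                        head_dim=128,
                        page_size=16,
                        q_itemsize=2,    # assumed bf16 queries
                        kv_itemsize=2,   # assumed bf16 KV cache
                        f32_itemsize=4):
  # Mirrors compute_actual_vmem_bytes above, with dtype sizes passed in
  # explicitly instead of taken from the runtime dtypes.
  q_block = bq * num_q_heads_per_blk * head_dim * q_itemsize
  in_block = q_block
  out_block = in_block
  kv_block = (2 * num_kv_pages_per_blk * page_size *
              2 * num_kv_heads_per_blk * head_dim * kv_itemsize)
  l_ref = num_kv_heads_per_blk * bq * 2 * num_kv_heads_per_blk * 128 * f32_itemsize
  m_ref = l_ref
  acc_ref = bq * num_q_heads_per_blk * head_dim * f32_itemsize
  return 2 * (in_block + out_block) + kv_block + l_ref + m_ref + acc_ref

# Small shapes: the estimate (about 2.5 MiB) is far below the budget, so bkv stays at 128.
bkv = 128
while estimate_vmem_bytes(bkv) >= VMEM_BUDGET:
  bkv //= 2
print(max(bkv, 1))  # 128

# Larger hypothetical shapes push the estimate past 64 MiB, so bkv is halved
# until the per-block footprint fits: 128 -> 64 -> 32.
bkv = 128
while estimate_vmem_bytes(bkv, num_kv_heads_per_blk=8, head_dim=256,
                          page_size=64) >= VMEM_BUDGET:
  bkv //= 2
print(max(bkv, 1))  # 32

The back-off only halves the default bkv until the estimated per-block footprint drops under 64 MiB, and clamps it to at least one KV page per block.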
