2 files changed: +17 −1 lines changed

@@ -1009,7 +1009,7 @@ def ragged_paged_attention(
   from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_attention

   if vmem_limit_bytes is None:
-    vmem_limit_bytes = 64 * 1024 * 1024
+    vmem_limit_bytes = 120 * 1024 * 1024

   payload, _ = trace_pallas(
       ragged_attention,
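For context, the new 120MB default only matters once it reaches the Mosaic compiler: trace_pallas presumably forwards vmem_limit_bytes into the kernel's Pallas TPU compiler parameters. Below is a minimal sketch of that plumbing, assuming a JAX version that exposes pltpu.TPUCompilerParams; the copy_kernel and shapes are illustrative stand-ins, not part of this diff.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(x_ref, o_ref):
  # Trivial stand-in for the ragged paged attention kernel body.
  o_ref[...] = x_ref[...]

def run(x, vmem_limit_bytes=120 * 1024 * 1024):
  # vmem_limit_bytes caps how much VMEM Mosaic may allocate for this
  # kernel; 120MB matches the raised default in the diff above.
  return pl.pallas_call(
      copy_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          vmem_limit_bytes=vmem_limit_bytes),
  )(x)

out = run(jnp.ones((256, 256), jnp.float32))

A kernel whose working set exceeds this cap typically fails to compile rather than spilling, which is why the second file's heuristic below also shrinks the KV block toward a fixed VMEM budget.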
@@ -453,6 +453,22 @@ def get_tuned_block_sizes(
 
   # Default block sizes.
   bkv, bq = (128, 32)
+
+  def compute_actual_vmem_bytes(num_kv_pages_per_blk):
+    # Estimated VMEM footprint of one kernel invocation: double-buffered
+    # q input/output blocks, double-buffered K and V pages, plus fp32
+    # l/m/acc scratch.
+    q_block = bq * num_q_heads_per_blk * head_dim * q_dtype.itemsize
+    in_block = q_block
+    out_block = in_block
+    kv_block = 2 * num_kv_pages_per_blk * page_size * 2 * num_kv_heads_per_blk * head_dim * kv_dtype.itemsize
+    l_ref = num_kv_heads_per_blk * bq * 2 * num_kv_heads_per_blk * 128 * jnp.finfo(jnp.float32).dtype.itemsize
+    m_ref = l_ref
+    acc_ref = bq * num_q_heads_per_blk * head_dim * jnp.finfo(jnp.float32).dtype.itemsize
+    return 2 * (in_block + out_block) + kv_block + l_ref + m_ref + acc_ref
+
+  # While the estimated footprint exceeds 64MB, halve num_kv_pages_per_blk.
+  # Stopping at one page also guarantees the loop terminates.
+  while bkv > 1 and compute_actual_vmem_bytes(bkv) >= 64 * 1024 * 1024:
+    bkv //= 2
+
   if tpu_version == 4:
     # This default block size is not tuned, only make sure there's no
     # OOM in vmem
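To sanity-check the heuristic, the same estimate can be run standalone. Every shape below (bq=32, page_size=16, eight bf16 q heads, one kv head, head_dim=128) is a hypothetical configuration for illustration, not a value taken from this diff:

FP32 = 4  # bytes per float32 element

def estimate_vmem_bytes(num_kv_pages_per_blk, bq=32, page_size=16,
                        num_q_heads_per_blk=8, num_kv_heads_per_blk=1,
                        head_dim=128, q_itemsize=2, kv_itemsize=2):
  # Mirrors compute_actual_vmem_bytes: double-buffered q in/out blocks,
  # double-buffered K and V pages, plus fp32 l/m/acc scratch.
  in_block = bq * num_q_heads_per_blk * head_dim * q_itemsize
  out_block = in_block
  kv_block = (2 * num_kv_pages_per_blk * page_size *
              2 * num_kv_heads_per_blk * head_dim * kv_itemsize)
  l_ref = num_kv_heads_per_blk * bq * 2 * num_kv_heads_per_blk * 128 * FP32
  m_ref = l_ref
  acc_ref = bq * num_q_heads_per_blk * head_dim * FP32
  return 2 * (in_block + out_block) + kv_block + l_ref + m_ref + acc_ref

bkv = 128
while bkv > 1 and estimate_vmem_bytes(bkv) >= 64 * 1024 * 1024:
  bkv //= 2  # halve the page count until the estimate fits the budget

print(bkv, estimate_vmem_bytes(bkv))
# -> 128 2555904: roughly 2.4MB, comfortably under the 64MB budget, so the
#    default bkv survives; heavier configs would be halved until they fit.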