Skip to content

Commit 9995e97

Browse files
authored
Optimize KV cache dequantization performance (#9528)
1 parent d3d91a8 commit 9995e97

File tree

1 file changed

+10
-14
lines changed

1 file changed

+10
-14
lines changed

torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py

Lines changed: 10 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1954,12 +1954,14 @@ def masked_store(ref, val, start, end, group=1):
19541954
# kv lens will be contracting dim, we should mask out the NaNs.
19551955
kv_mask = (
19561956
lax.broadcasted_iota(jnp.int32, k.shape, 0) < kv_len - kv_len_start)
1957-
k = jnp.where(kv_mask, k.astype(jnp.float32), 0).astype(k.dtype)
1958-
v = jnp.where(kv_mask, v.astype(jnp.float32), 0).astype(v.dtype)
1959-
1960-
qk = (
1961-
jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) *
1962-
sm_scale)
1957+
k = jnp.where(kv_mask, k, 0)
1958+
v = jnp.where(kv_mask, v, 0)
1959+
1960+
qk = jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32)
1961+
if k_scale is not None:
1962+
qk *= sm_scale * k_scale
1963+
else:
1964+
qk *= sm_scale
19631965
store_start = jnp.maximum(q_start - q_len_start, 0)
19641966
store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk)
19651967

@@ -2007,6 +2009,8 @@ def init_scratch_ref():
20072009
m_curr = jnp.max(qk, axis=1, keepdims=True)
20082010
s_curr = jnp.exp(qk - m_curr)
20092011
qkv = jnp.dot(s_curr, v, preferred_element_type=jnp.float32)
2012+
if v_scale is not None:
2013+
qkv *= v_scale
20102014
lm_store_shape = head_m_ref.shape
20112015
m_curr = jnp.broadcast_to(m_curr, lm_store_shape)
20122016
l_curr = jnp.broadcast_to(
@@ -2088,14 +2092,6 @@ def prefetch_next_kv_blk():
20882092
for step_idx in range(kv_load_step):
20892093
k = k_list[step_idx]
20902094
v = v_list[step_idx]
2091-
if k_scale is not None:
2092-
# NOTE: Conversion between arbitrary data types is not supported.
2093-
# That's why it is converted to float32 first.
2094-
k = k.astype(jnp.float32) * k_scale
2095-
k = k.astype(q_ref.dtype)
2096-
if v_scale is not None:
2097-
v = v.astype(jnp.float32) * v_scale
2098-
v = v.astype(q_ref.dtype)
20992095
kv_head_idx = kv_head_chunk_idx + step_idx
21002096
q_head_idx = kv_head_idx * num_q_heads_per_kv_head
21012097
# TODO(jevinjiang): extra handling for packed type that can start at

0 commit comments

Comments (0)