@@ -1954,12 +1954,14 @@ def masked_store(ref, val, start, end, group=1):
1954
1954
# kv lens will be contracting dim, we should mask out the NaNs.
1955
1955
kv_mask = (
1956
1956
lax .broadcasted_iota (jnp .int32 , k .shape , 0 ) < kv_len - kv_len_start )
1957
- k = jnp .where (kv_mask , k .astype (jnp .float32 ), 0 ).astype (k .dtype )
1958
- v = jnp .where (kv_mask , v .astype (jnp .float32 ), 0 ).astype (v .dtype )
1959
-
1960
- qk = (
1961
- jnp .einsum ("nd,md->nm" , q , k , preferred_element_type = jnp .float32 ) *
1962
- sm_scale )
1957
+ k = jnp .where (kv_mask , k , 0 )
1958
+ v = jnp .where (kv_mask , v , 0 )
1959
+
1960
+ qk = jnp .einsum ("nd,md->nm" , q , k , preferred_element_type = jnp .float32 )
1961
+ if k_scale is not None :
1962
+ qk *= sm_scale * k_scale
1963
+ else :
1964
+ qk *= sm_scale
1963
1965
store_start = jnp .maximum (q_start - q_len_start , 0 )
1964
1966
store_end = jnp .minimum (q_end - q_len_start , num_q_per_blk )
1965
1967
@@ -2007,6 +2009,8 @@ def init_scratch_ref():
2007
2009
m_curr = jnp .max (qk , axis = 1 , keepdims = True )
2008
2010
s_curr = jnp .exp (qk - m_curr )
2009
2011
qkv = jnp .dot (s_curr , v , preferred_element_type = jnp .float32 )
2012
+ if v_scale is not None :
2013
+ qkv *= v_scale
2010
2014
lm_store_shape = head_m_ref .shape
2011
2015
m_curr = jnp .broadcast_to (m_curr , lm_store_shape )
2012
2016
l_curr = jnp .broadcast_to (
@@ -2088,14 +2092,6 @@ def prefetch_next_kv_blk():
2088
2092
for step_idx in range (kv_load_step ):
2089
2093
k = k_list [step_idx ]
2090
2094
v = v_list [step_idx ]
2091
- if k_scale is not None :
2092
- # NOTE: Conversion between arbitrary data types is not supported.
2093
- # That's why it is converted to float32 first.
2094
- k = k .astype (jnp .float32 ) * k_scale
2095
- k = k .astype (q_ref .dtype )
2096
- if v_scale is not None :
2097
- v = v .astype (jnp .float32 ) * v_scale
2098
- v = v .astype (q_ref .dtype )
2099
2095
kv_head_idx = kv_head_chunk_idx + step_idx
2100
2096
q_head_idx = kv_head_idx * num_q_heads_per_kv_head
2101
2097
# TODO(jevinjiang): extra handlig for packed type that can start at
0 commit comments