 # TPU requires the head size to be a multiple of 128.
 TPU_HEAD_SIZE_ALIGNMENT = 128

+# Note: TPU can use fp8 as the storage dtype, but it does not support casting
+# from uint8 to fp32 directly, so its dtype mapping differs from the GPU one.
+TPU_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.half,
+    "bfloat16": torch.bfloat16,
+    "float": torch.float,
+    "fp8": torch.float8_e4m3fn,
+    "fp8_e4m3": torch.float8_e4m3fn,
+    "fp8_e5m2": torch.float8_e5m2,
+    "int8": torch.int8,
+    "uint8": torch.uint8,
+}
+

 class PallasAttentionBackend(AttentionBackend):
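The mapping above is what __init__ consults in the next hunk. As a reference, here is a minimal, self-contained sketch of that lookup, including the .lower().strip() normalization; the resolve_kv_cache_dtype helper is hypothetical and only illustrates the behavior ("auto" and unrecognized strings both resolve to no quantized dtype):

import torch

# Same mapping as in the patch, repeated so the sketch is self-contained.
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.float8_e4m3fn,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
    "int8": torch.int8,
    "uint8": torch.uint8,
}


def resolve_kv_cache_dtype(kv_cache_dtype: str):
    """Hypothetical helper mirroring the __init__ logic in the hunk below.

    "auto" means "use the model dtype", so no quantized dtype is resolved;
    unrecognized strings also yield None because dict.get() has no default.
    """
    if kv_cache_dtype == "auto":
        return None
    return TPU_STR_DTYPE_TO_TORCH_DTYPE.get(kv_cache_dtype.lower().strip())


assert resolve_kv_cache_dtype("auto") is None
assert resolve_kv_cache_dtype(" FP8_E4M3 ") is torch.float8_e4m3fn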
@@ -152,15 +165,18 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
-        if kv_cache_dtype != "auto":
-            raise NotImplementedError("FP8 KV cache dtype is not supported.")

         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "PallasAttentionBackendImpl")

+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
+                kv_cache_dtype.lower().strip())
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -194,7 +210,6 @@ def forward(
             output = torch.ones_like(query)
             return output

-        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -215,10 +230,21 @@ def forward(
             # Skip this if sharing KV cache with an earlier attention layer.
             slot_mapping = attn_metadata.slot_mapping
             write_to_kv_cache(
-                key, value, kv_cache, slot_mapping,
+                key,
+                value,
+                kv_cache,
+                slot_mapping,
                 attn_metadata.num_slices_per_kv_cache_update_block,
-                attn_metadata.num_kv_update_slices)
-
+                attn_metadata.num_kv_update_slices,
+                self.kv_cache_quantized_dtype,
+                layer._k_scale_float,
+                layer._v_scale_float,
+            )
+
+        if self.kv_cache_quantized_dtype is not None and (
+                layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0):
+            raise ValueError(
+                "k_scale_float and v_scale_float must be non-zero")
         output = torch.ops.xla.ragged_paged_attention(
             query,
             kv_cache,
@@ -236,6 +262,8 @@ def forward(
             sm_scale=self.scale,
             sliding_window=self.sliding_window,
             soft_cap=self.logits_soft_cap,
+            k_scale=layer._k_scale_float,
+            v_scale=layer._v_scale_float,
         )

         if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
@@ -251,18 +279,32 @@ def write_to_kv_cache(
     slot_mapping: torch.Tensor,
     num_slices_per_kv_cache_update_block: int,
     num_kv_update_slices: torch.Tensor,
+    kv_cache_quantized_dtype: Optional[torch.dtype] = None,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ) -> None:
     """Write the key and values to the KV cache.

     Args:
-        key: shape = [num_tokens, num_kv_heads * head_size]
-        value: shape = [num_tokens, num_kv_heads * head_size]
+        key: shape = [num_tokens, num_kv_heads, head_size]
+        value: shape = [num_tokens, num_kv_heads, head_size]
         kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
         num_slices_per_kv_cache_update_block: int
     """
     _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
     head_size = cdiv(head_size,
                      TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+
+    if kv_cache_quantized_dtype is not None:
+        dtype_info = torch.finfo(kv_cache_quantized_dtype)
+        key = key.to(torch.float32) / k_scale
+        # NOTE: clamp so scaled values stay within the quantized dtype's range.
+        key = torch.clamp(key, dtype_info.min, dtype_info.max)
+        key = key.to(kv_cache_quantized_dtype)
+        value = value.to(torch.float32) / v_scale
+        value = torch.clamp(value, dtype_info.min, dtype_info.max)
+        value = value.to(kv_cache_quantized_dtype)
+
     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
                                                   head_size)
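For context on the new write path: keys and values are divided by the per-layer scale, clamped to the representable range of the quantized dtype, and then cast; k_scale and v_scale are also forwarded to ragged_paged_attention, which is expected to dequantize inside the kernel. Below is a minimal standalone sketch of that round trip, assuming an illustrative amax-based scale; the quantize_kv and dequantize_kv helpers are not part of the backend.

import torch


def quantize_kv(x: torch.Tensor, scale: float,
                quant_dtype: torch.dtype) -> torch.Tensor:
    # Mirrors the write_to_kv_cache logic in the patch: scale, clamp to the
    # quantized dtype's representable range, then cast.
    info = torch.finfo(quant_dtype)
    x = x.to(torch.float32) / scale
    x = torch.clamp(x, info.min, info.max)
    return x.to(quant_dtype)


def dequantize_kv(x_q: torch.Tensor, scale: float,
                  out_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Inverse of the transform above; the attention kernel presumably applies
    # an equivalent rescaling when k_scale/v_scale are passed in.
    return x_q.to(out_dtype) * scale


if __name__ == "__main__":
    key = torch.randn(16, 8, 128)  # [num_tokens, num_kv_heads, head_size]
    # Illustrative per-tensor amax calibration; not how vLLM derives the scale.
    k_scale = key.abs().max().item() / torch.finfo(torch.float8_e4m3fn).max
    key_q = quantize_kv(key, k_scale, torch.float8_e4m3fn)
    key_hat = dequantize_kv(key_q, k_scale)
    print("max abs error:", (key - key_hat).abs().max().item())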