Commit 40a36cc

[ROCm][Bugfix] Use platform specific FP8 dtype (#15717)
Signed-off-by: Gregory Shtrasberg <[email protected]>
1 parent ef608c3 commit 40a36cc

1 file changed, +1 -1 lines changed

vllm/attention/ops/prefix_prefill.py

Lines changed: 1 addition & 1 deletion
@@ -753,7 +753,7 @@ def context_attention_fwd(q,
         assert (v_cache.dtype == torch.uint8)
 
         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
-            target_dtype = torch.float8_e4m3fn
+            target_dtype = current_platform.fp8_dtype()
         elif kv_cache_dtype == "fp8_e5m2":
             target_dtype = torch.float8_e5m2
         else:
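
The branch touched by this change can be illustrated with a minimal sketch. It is not the vLLM implementation: `select_fp8_target_dtype` and its `platform_fp8_dtype` parameter are hypothetical names, dtypes are represented as plain strings so the example runs without torch, and the body of the final `else` branch (not shown in the diff) is assumed to reject unknown values. The motivating assumption is that some ROCm hardware uses a different e4m3 encoding (for example `float8_e4m3fnuz`) than the `float8_e4m3fn` that was hardcoded before.

```python
# Sketch only: dtype names are strings, so no torch install is required.
# `select_fp8_target_dtype` and `platform_fp8_dtype` are hypothetical names.


def select_fp8_target_dtype(kv_cache_dtype: str, platform_fp8_dtype: str) -> str:
    """Pick the dtype name used to reinterpret a uint8 KV cache as FP8."""
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        # After the change: defer to the platform instead of hardcoding e4m3fn.
        return platform_fp8_dtype
    elif kv_cache_dtype == "fp8_e5m2":
        # e5m2 is left unchanged by the commit.
        return "float8_e5m2"
    else:
        # Assumption: the elided else branch rejects unsupported values.
        raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")


# A platform reporting the fnuz variant gets it for both e4m3 aliases.
print(select_fp8_target_dtype("fp8", "float8_e4m3fnuz"))
print(select_fp8_target_dtype("fp8_e5m2", "float8_e4m3fnuz"))
```

Passing the platform's dtype in as an argument keeps the selection logic testable without a GPU; in the actual change the value comes from `current_platform.fp8_dtype()`.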
