
Commit 4ef95b0

[Bugfix] use float32 precision in samplers/test_logprobs.py for comparing with HF (#6409)
Signed-off-by: Thomas Parnell <[email protected]>
1 parent eaec4b9 commit 4ef95b0

File tree

2 files changed, +8 -1 lines changed

tests/samplers/test_logprobs.py

Lines changed: 2 additions & 1 deletion
@@ -11,7 +11,8 @@
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype",
+                         ["float"])  # needed for comparing logprobs with HF
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
 @pytest.mark.parametrize("num_top_logprobs", [6])  # 32000 == vocab_size
 @pytest.mark.parametrize("detokenize", [True, False])
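
For context, a minimal sketch of the precision issue this commit works around, using hypothetical logits rather than anything from the test file: merely quantizing logits to half precision already perturbs the resulting logprobs relative to a float32 reference, so strict comparisons against HuggingFace are only reliable when the test model runs in float32.

# Minimal sketch (hypothetical tensors, not from tests/samplers/test_logprobs.py):
# quantizing logits to fp16 perturbs the resulting logprobs relative to an
# fp32 reference, which is why the test now runs in "float" when comparing
# against HuggingFace.
import torch

torch.manual_seed(0)
logits = torch.randn(1, 32000)            # hypothetical fp32 logits (vocab size 32000)
logits_quantized = logits.half().float()  # round-trip through half precision

ref_logprobs = torch.log_softmax(logits, dim=-1)
approx_logprobs = torch.log_softmax(logits_quantized, dim=-1)

max_diff = (ref_logprobs - approx_logprobs).abs().max().item()
print(f"max |logprob| drift from fp16 quantization alone: {max_diff:.2e}")

Running the whole model in half precision typically accumulates more error than this single quantization step, which is what makes a tolerance on half-precision logprobs hard to pick robustly.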

vllm/attention/ops/prefix_prefill.py

Lines changed: 6 additions & 0 deletions
@@ -687,6 +687,12 @@ def context_attention_fwd(q,

     cap = current_platform.get_device_capability()
     BLOCK = 128 if cap[0] >= 8 else 64
+
+    # need to reduce num. blocks when using fp32
+    # due to increased use of GPU shared memory
+    if q.dtype is torch.float32:
+        BLOCK = BLOCK // 2
+
     # shape constraints
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
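
A rough illustration of the shared-memory reasoning behind this kernel change, with an assumed head size and tile shape rather than values read from the Triton kernel: a float32 element is twice as wide as a float16 one, so halving BLOCK keeps the per-tile on-chip footprint roughly constant.

# Back-of-the-envelope sketch (assumed head size and tile shape, not taken
# from prefix_prefill.py): fp32 doubles the bytes per element versus fp16,
# so halving BLOCK keeps the per-tile shared-memory footprint comparable.
import torch

def tile_bytes(block: int, head_size: int, dtype: torch.dtype) -> int:
    # bytes for one hypothetical BLOCK x head_size tile staged on-chip
    return block * head_size * torch.finfo(dtype).bits // 8

HEAD_SIZE = 128  # assumed value, for illustration only

print(tile_bytes(128, HEAD_SIZE, torch.float16))  # fp16 with BLOCK = 128 -> 32768 bytes
print(tile_bytes(64, HEAD_SIZE, torch.float32))   # fp32 with BLOCK = 64  -> 32768 bytes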
