22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
44
5- import pytest
6-
75from vllm .logprobs import (
86 FlatLogprobs ,
97 Logprob ,
1412)
1513
1614
17- def test_create_logprobs_non_flat (monkeypatch : pytest .MonkeyPatch ) -> None :
18- monkeypatch .setenv ("VLLM_FLAT_LOGPROBS" , "0" )
19-
20- prompt_logprobs = create_prompt_logprobs ()
15+ def test_create_logprobs_non_flat () -> None :
16+ prompt_logprobs = create_prompt_logprobs (flat_logprobs = False )
2117 assert isinstance (prompt_logprobs , list )
2218 # Ensure first prompt position logprobs is None
2319 assert len (prompt_logprobs ) == 1
2420 assert prompt_logprobs [0 ] is None
2521
26- sample_logprobs = create_sample_logprobs ()
22+ sample_logprobs = create_sample_logprobs (flat_logprobs = False )
2723 assert isinstance (sample_logprobs , list )
2824 assert len (sample_logprobs ) == 0
2925
3026
31- def test_create_logprobs_flat (monkeypatch : pytest .MonkeyPatch ) -> None :
32- monkeypatch .setenv ("VLLM_FLAT_LOGPROBS" , "1" )
33-
34- prompt_logprobs = create_prompt_logprobs ()
27+ def test_create_logprobs_flat () -> None :
28+ prompt_logprobs = create_prompt_logprobs (flat_logprobs = True )
3529 assert isinstance (prompt_logprobs , FlatLogprobs )
3630 assert prompt_logprobs .start_indices == [0 ]
3731 assert prompt_logprobs .end_indices == [0 ]
@@ -43,7 +37,7 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
4337 assert len (prompt_logprobs ) == 1
4438 assert prompt_logprobs [0 ] == dict ()
4539
46- sample_logprobs = create_sample_logprobs ()
40+ sample_logprobs = create_sample_logprobs (flat_logprobs = True )
4741 assert isinstance (sample_logprobs , FlatLogprobs )
4842 assert len (sample_logprobs .start_indices ) == 0
4943 assert len (sample_logprobs .end_indices ) == 0
@@ -54,11 +48,8 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
5448 assert len (sample_logprobs ) == 0
5549
5650
57- def test_append_logprobs_for_next_position_none_flat (
58- monkeypatch : pytest .MonkeyPatch ,
59- ) -> None :
60- monkeypatch .setenv ("VLLM_FLAT_LOGPROBS" , "0" )
61- logprobs = create_sample_logprobs ()
51+ def test_append_logprobs_for_next_position_none_flat () -> None :
52+ logprobs = create_sample_logprobs (flat_logprobs = False )
6253 append_logprobs_for_next_position (
6354 logprobs ,
6455 token_ids = [1 ],
@@ -85,11 +76,8 @@ def test_append_logprobs_for_next_position_none_flat(
8576 ]
8677
8778
88- def test_append_logprobs_for_next_position_flat (
89- monkeypatch : pytest .MonkeyPatch ,
90- ) -> None :
91- monkeypatch .setenv ("VLLM_FLAT_LOGPROBS" , "1" )
92- logprobs = create_sample_logprobs ()
79+ def test_append_logprobs_for_next_position_flat () -> None :
80+ logprobs = create_sample_logprobs (flat_logprobs = True )
9381 append_logprobs_for_next_position (
9482 logprobs ,
9583 token_ids = [1 ],
0 commit comments