3 files changed: +41 −6 lines changed
 _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]


-def _read_prompts(filename: str) -> str:
-    prompts = []
+def _read_prompts(filename: str) -> List[str]:
     with open(filename, "r") as f:
-        prompt = f.readline()
-        prompts.append(prompt)
-    return prompts
+        prompts = f.readlines()
+    return prompts


 @pytest.fixture

@@ -165,6 +163,7 @@ def __init__(
         model_name: str,
         tokenizer_name: Optional[str] = None,
         dtype: str = "half",
+        disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
     ) -> None:
         self.model = LLM(
@@ -173,6 +172,7 @@ def __init__(
             trust_remote_code=True,
             dtype=dtype,
             swap_space=0,
+            disable_log_stats=disable_log_stats,
             tensor_parallel_size=tensor_parallel_size,
         )
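For reference, a side-by-side sketch of the old and new `_read_prompts` behavior from the hunk above; `prompts.txt` is a hypothetical three-line file used only for illustration. After the change, every line of the prompts file becomes its own prompt instead of just the first one, and the new `disable_log_stats` flag defaults to True so existing tests keep stats logging off while the metrics test below opts back in.

from typing import List


def _read_prompts_old(filename: str) -> List[str]:
    # Previous behavior: only the first line of the file was returned.
    prompts = []
    with open(filename, "r") as f:
        prompt = f.readline()
        prompts.append(prompt)
    return prompts


def _read_prompts_new(filename: str) -> List[str]:
    # New behavior: every line in the file becomes a separate prompt.
    with open(filename, "r") as f:
        prompts = f.readlines()
    return prompts


# With a hypothetical prompts.txt containing three lines,
# the old helper yields 1 prompt and the new one yields 3.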
+import pytest
+import vllm.engine.metrics
+
+MODELS = [
+    "facebook/opt-125m",
+]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [128])
+def test_metrics(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    vllm_model = vllm_runner(model, dtype=dtype, disable_log_stats=False)
+    tokenizer = vllm_model.model.get_tokenizer()
+    prompt_token_counts = [len(tokenizer.encode(p)) for p in example_prompts]
+    # This test needs at least 2 prompts in a batch of different lengths to
+    # verify that their token count is correct despite padding.
+    assert len(example_prompts) > 1, "at least 2 prompts are required"
+    assert prompt_token_counts[0] != prompt_token_counts[1], (
+        "prompts of different lengths are required")
+    vllm_prompt_token_count = sum(prompt_token_counts)
+
+    _ = vllm_model.generate_greedy(example_prompts, max_tokens)
+    metric_count = vllm.engine.metrics.counter_prompt_tokens.get_value({})
+
+    assert vllm_prompt_token_count == metric_count, (
+        f"prompt token count: {vllm_prompt_token_count!r}\n"
+        f"metric: {metric_count!r}")
Original file line number Diff line number Diff line change @@ -867,7 +867,9 @@ def _get_stats(self,
867
867
868
868
# Number of Tokens.
869
869
if prompt_run :
870
- num_prompt_tokens = scheduler_outputs .num_batched_tokens
870
+ num_prompt_tokens = sum (
871
+ len (seq_group .prompt_token_ids )
872
+ for seq_group in scheduler_outputs .scheduled_seq_groups )
871
873
else :
872
874
num_generation_tokens = scheduler_outputs .num_batched_tokens
873
875
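A toy illustration of why the engine-side fix above matters, with made-up numbers; it assumes, as the new test's comment suggests, that the batched token count for a prompt run reflects the padded batch rather than the real prompt lengths.

# Hypothetical prompt batch: two prompts of different lengths.
prompt_token_ids = [
    list(range(7)),   # prompt A: 7 real tokens
    list(range(12)),  # prompt B: 12 real tokens
]

# If the batch is padded to the longest prompt, a padded count sees
# 2 * 12 = 24 token slots, padding included.
padded_count = len(prompt_token_ids) * max(len(ids) for ids in prompt_token_ids)
assert padded_count == 24

# Summing len(prompt_token_ids) per sequence group, as the fix does,
# counts only the 7 + 12 = 19 real prompt tokens.
actual_count = sum(len(ids) for ids in prompt_token_ids)
assert actual_count == 19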