Commit e433c11

Fix vllm:prompt_tokens_total metric calculation (#2869)
1 parent: 86fd8bb

File tree: 3 files changed, +41 −6 lines

  tests/conftest.py
  tests/metrics/test_metrics.py
  vllm/engine/llm_engine.py

tests/conftest.py (5 additions, 5 deletions)

@@ -13,12 +13,10 @@
 _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
 
 
-def _read_prompts(filename: str) -> str:
-    prompts = []
+def _read_prompts(filename: str) -> List[str]:
     with open(filename, "r") as f:
-        prompt = f.readline()
-        prompts.append(prompt)
-    return prompts
+        prompts = f.readlines()
+    return prompts
 
 
 @pytest.fixture
@@ -165,6 +163,7 @@ def __init__(
         model_name: str,
         tokenizer_name: Optional[str] = None,
         dtype: str = "half",
+        disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
     ) -> None:
         self.model = LLM(
@@ -173,6 +172,7 @@ def __init__(
             trust_remote_code=True,
             dtype=dtype,
             swap_space=0,
+            disable_log_stats=disable_log_stats,
             tensor_parallel_size=tensor_parallel_size,
         )
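For context on the first hunk: `f.readline()` reads only a single line, so the old helper always returned a one-element list no matter how many prompts the file contained, while `f.readlines()` returns every line. A minimal standalone sketch of the difference (the `prompts.txt` file here is hypothetical, not part of the commit):

# Hypothetical prompt file with two prompts, one per line.
with open("prompts.txt", "w") as f:
    f.write("first prompt\nsecond prompt\n")

with open("prompts.txt") as f:
    first_only = f.readline()    # "first prompt\n" -- just the first line

with open("prompts.txt") as f:
    all_prompts = f.readlines()  # ["first prompt\n", "second prompt\n"]

assert len(all_prompts) == 2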

tests/metrics/test_metrics.py (33 additions, new file)

@@ -0,0 +1,33 @@
+import pytest
+import vllm.engine.metrics
+
+MODELS = [
+    "facebook/opt-125m",
+]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [128])
+def test_metrics(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    vllm_model = vllm_runner(model, dtype=dtype, disable_log_stats=False)
+    tokenizer = vllm_model.model.get_tokenizer()
+    prompt_token_counts = [len(tokenizer.encode(p)) for p in example_prompts]
+    # This test needs at least 2 prompts of different lengths in a batch to verify that the token count is correct despite padding.
+    assert len(example_prompts) > 1, "at least 2 prompts are required"
+    assert prompt_token_counts[0] != prompt_token_counts[1], (
+        "prompts of different lengths are required")
+    vllm_prompt_token_count = sum(prompt_token_counts)
+
+    _ = vllm_model.generate_greedy(example_prompts, max_tokens)
+    metric_count = vllm.engine.metrics.counter_prompt_tokens.get_value({})
+
+    assert vllm_prompt_token_count == metric_count, (
+        f"prompt token count: {vllm_prompt_token_count!r}\nmetric: {metric_count!r}"
+    )
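Outside the pytest fixtures, the same check can be reproduced directly against the `LLM` entry point. The sketch below is a rough interactive equivalent of the test, not part of the commit: the prompts and sampling parameters are made up, and it assumes `LLM` forwards `disable_log_stats` to the engine the same way the fixture above does; `counter_prompt_tokens.get_value({})` is used exactly as in the test.

from vllm import LLM, SamplingParams
import vllm.engine.metrics

# Prompts of different lengths, so any padding in the batch would show up
# as an over-count if the metric were still based on batched tokens.
prompts = [
    "Hello, my name is",
    "The president of the United States of America lives in",
]

llm = LLM(model="facebook/opt-125m", disable_log_stats=False)

tokenizer = llm.get_tokenizer()
expected = sum(len(tokenizer.encode(p)) for p in prompts)

llm.generate(prompts, SamplingParams(temperature=0.0, max_tokens=16))
observed = vllm.engine.metrics.counter_prompt_tokens.get_value({})

assert observed == expected, f"expected {expected}, metric reported {observed}"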

vllm/engine/llm_engine.py (3 additions, 1 deletion)

@@ -867,7 +867,9 @@ def _get_stats(self,
 
         # Number of Tokens.
         if prompt_run:
-            num_prompt_tokens = scheduler_outputs.num_batched_tokens
+            num_prompt_tokens = sum(
+                len(seq_group.prompt_token_ids)
+                for seq_group in scheduler_outputs.scheduled_seq_groups)
         else:
             num_generation_tokens = scheduler_outputs.num_batched_tokens
 
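This hunk is the actual metric fix: for a prompt run, `scheduler_outputs.num_batched_tokens` includes padding when prompts of different lengths are batched together (the situation the new test exercises), so `vllm:prompt_tokens_total` over-counted; summing `len(seq_group.prompt_token_ids)` over the scheduled sequence groups counts only real prompt tokens. A toy illustration of the difference, using a hypothetical stand-in class rather than vLLM's own `SequenceGroup`/`SchedulerOutputs`:

from dataclasses import dataclass
from typing import List

@dataclass
class FakeSeqGroup:
    # Hypothetical stand-in for a scheduled sequence group.
    prompt_token_ids: List[int]

scheduled_seq_groups = [
    FakeSeqGroup(prompt_token_ids=[1, 2, 3]),        # 3-token prompt
    FakeSeqGroup(prompt_token_ids=[4, 5, 6, 7, 8]),  # 5-token prompt
]

# Padded batch size: roughly what a batched-token count sees when shorter
# prompts are padded up to the longest prompt in the batch.
padded_count = len(scheduled_seq_groups) * max(
    len(g.prompt_token_ids) for g in scheduled_seq_groups)  # 2 * 5 = 10

# What the fix reports: the real number of prompt tokens.
actual_count = sum(
    len(g.prompt_token_ids) for g in scheduled_seq_groups)  # 3 + 5 = 8

assert (padded_count, actual_count) == (10, 8)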
