Commit 1b1d24b

Make TRT-LLM max_tokens behavior explicit in GenAi-Perf (#557)
1 parent 28e9746 commit 1b1d24b

File tree

2 files changed: +42 -13 lines changed


src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py

Lines changed: 7 additions & 2 deletions

@@ -608,6 +608,7 @@ def _populate_trtllm_output_json(
         extra_inputs: Dict = {},
     ) -> Dict:
         pa_json = LlmInputs._create_empty_trtllm_pa_json()
+        include_max_tokens = "max_tokens" not in extra_inputs
 
         for index, entry in enumerate(dataset_json["rows"]):
             pa_json["data"].append({"text_input": [""]})
@@ -625,7 +626,9 @@ def _populate_trtllm_output_json(
                     pa_json, index, new_text_input
                 )
 
-            pa_json = LlmInputs._add_required_tags_to_trtllm_json(pa_json, index)
+            pa_json = LlmInputs._add_required_tags_to_trtllm_json(
+                pa_json, index, include_max_tokens
+            )
             pa_json = LlmInputs._add_optional_tags_to_trtllm_json(
                 pa_json, index, add_model_name, add_stream, model_name, extra_inputs
             )
@@ -816,8 +819,10 @@ def _add_required_tags_to_trtllm_json(
         cls,
         pa_json: Dict,
         index: int,
+        include_max_tokens: bool,
     ) -> Dict:
-        pa_json["data"][index]["max_tokens"] = [LlmInputs.DEFAULT_TRTLLM_MAX_TOKENS]
+        if include_max_tokens:
+            pa_json["data"][index]["max_tokens"] = [LlmInputs.DEFAULT_TRTLLM_MAX_TOKENS]
 
         return pa_json
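To make the intent of the change easier to see outside the diff, here is a minimal standalone sketch (not the GenAi-Perf implementation): the TRT-LLM default for max_tokens is injected only when the caller has not already supplied max_tokens via extra inputs. The build_trtllm_entry helper is hypothetical; the default value of 256 is the one asserted by the new test below.

# Minimal sketch of the new include_max_tokens behavior.
# build_trtllm_entry is a hypothetical helper, not part of GenAi-Perf.
DEFAULT_TRTLLM_MAX_TOKENS = 256  # default asserted by the new test below


def build_trtllm_entry(text_input: str, extra_inputs: dict) -> dict:
    entry = {"text_input": [text_input]}
    # Only fall back to the default when the user did not supply max_tokens.
    include_max_tokens = "max_tokens" not in extra_inputs
    if include_max_tokens:
        entry["max_tokens"] = [DEFAULT_TRTLLM_MAX_TOKENS]
    # Extra inputs are added as list-valued fields, matching the TRT-LLM
    # payload shape asserted in the tests.
    for key, value in extra_inputs.items():
        entry[key] = [value]
    return entry


# A user-supplied max_tokens wins; the default is not re-applied.
assert build_trtllm_entry("hello", {"max_tokens": 5})["max_tokens"] == [5]
# Without an override, the default of 256 is used.
assert build_trtllm_entry("hello", {})["max_tokens"] == [256]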

src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py

Lines changed: 35 additions & 11 deletions

@@ -368,8 +368,10 @@ def test_synthetic_to_openai_completions(self, default_tokenizer):
            (OutputFormat.VLLM),
        ],
    )
-    def test_llm_inputs_extra_inputs(self, default_tokenizer, output_format) -> None:
-        request_inputs = {"additional_key": "additional_value"}
+    def test_extra_inputs(self, default_tokenizer, output_format) -> None:
+        input_name = "max_tokens"
+        input_value = 5
+        request_inputs = {input_name: input_value}
 
         pa_json = LlmInputs.create_llm_inputs(
             input_type=PromptSource.SYNTHETIC,
@@ -392,18 +394,40 @@ def test_llm_inputs_extra_inputs(self, default_tokenizer, output_format) -> None
                 payload = entry["payload"]
                 for item in payload:
                     assert (
-                        "additional_key" in item
-                    ), "The additional_key is not present in the request"
+                        input_name in item
+                    ), f"The input name {input_name} is not present in the request"
                     assert (
-                        item["additional_key"] == "additional_value"
-                    ), "The value of additional_key is incorrect"
+                        item[input_name] == input_value
+                    ), f"The value of {input_name} is incorrect"
         elif output_format == OutputFormat.TRTLLM or output_format == OutputFormat.VLLM:
             for entry in pa_json["data"]:
                 assert (
-                    "additional_key" in entry
-                ), "The additional_key is not present in the request"
-                assert entry["additional_key"] == [
-                    "additional_value"
-                ], "The value of additional_key is incorrect"
+                    input_name in entry
+                ), f"The {input_name} is not present in the request"
+                assert entry[input_name] == [
+                    input_value
+                ], f"The value of {input_name} is incorrect"
         else:
             assert False, f"Unsupported output format: {output_format}"
+
+    def test_trtllm_default_max_tokens(self, default_tokenizer) -> None:
+        input_name = "max_tokens"
+        input_value = 256
+
+        pa_json = LlmInputs.create_llm_inputs(
+            input_type=PromptSource.SYNTHETIC,
+            output_format=OutputFormat.TRTLLM,
+            num_of_output_prompts=5,
+            add_model_name=False,
+            add_stream=True,
+            tokenizer=default_tokenizer,
+        )
+
+        assert len(pa_json["data"]) == 5
+        for entry in pa_json["data"]:
+            assert (
+                input_name in entry
+            ), f"The {input_name} is not present in the request"
+            assert entry[input_name] == [
+                input_value
+            ], f"The value of {input_name} is incorrect"
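As an illustration only (not part of the commit), both behaviors could be exercised by a single parametrized test along the lines of the sketch below. It assumes that create_llm_inputs accepts an extra_inputs keyword (implied by the request_inputs built in test_extra_inputs above, but not visible in these hunks), that LlmInputs, OutputFormat, and PromptSource import from genai_perf.llm_inputs.llm_inputs, and that the default_tokenizer fixture from this test module is available; the class name is hypothetical.

# Sketch of a combined parametrized test; see the assumptions noted above.
import pytest

from genai_perf.llm_inputs.llm_inputs import LlmInputs, OutputFormat, PromptSource


class TestTrtllmMaxTokensSketch:
    @pytest.mark.parametrize(
        "extra_inputs, expected_max_tokens",
        [
            ({}, [256]),               # no override: GenAi-Perf default applies
            ({"max_tokens": 5}, [5]),  # user override wins; default not injected
        ],
    )
    def test_trtllm_max_tokens(
        self, default_tokenizer, extra_inputs, expected_max_tokens
    ) -> None:
        pa_json = LlmInputs.create_llm_inputs(
            input_type=PromptSource.SYNTHETIC,
            output_format=OutputFormat.TRTLLM,
            num_of_output_prompts=5,
            add_model_name=False,
            add_stream=True,
            tokenizer=default_tokenizer,
            extra_inputs=extra_inputs,  # keyword name assumed, see note above
        )

        # Every generated TRT-LLM entry should carry the expected max_tokens.
        for entry in pa_json["data"]:
            assert entry["max_tokens"] == expected_max_tokens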
