@@ -368,8 +368,10 @@ def test_synthetic_to_openai_completions(self, default_tokenizer):
368368 (OutputFormat .VLLM ),
369369 ],
370370 )
371- def test_llm_inputs_extra_inputs (self , default_tokenizer , output_format ) -> None :
372- request_inputs = {"additional_key" : "additional_value" }
371+ def test_extra_inputs (self , default_tokenizer , output_format ) -> None :
372+ input_name = "max_tokens"
373+ input_value = 5
374+ request_inputs = {input_name : input_value }
373375
374376 pa_json = LlmInputs .create_llm_inputs (
375377 input_type = PromptSource .SYNTHETIC ,
@@ -392,18 +394,40 @@ def test_llm_inputs_extra_inputs(self, default_tokenizer, output_format) -> None
392394 payload = entry ["payload" ]
393395 for item in payload :
394396 assert (
395- "additional_key" in item
396- ), "The additional_key is not present in the request"
397+ input_name in item
398+ ), f "The input name { input_name } is not present in the request"
397399 assert (
398- item ["additional_key" ] == "additional_value"
399- ), "The value of additional_key is incorrect"
400+ item [input_name ] == input_value
401+ ), f "The value of { input_name } is incorrect"
400402 elif output_format == OutputFormat .TRTLLM or output_format == OutputFormat .VLLM :
401403 for entry in pa_json ["data" ]:
402404 assert (
403- "additional_key" in entry
404- ), "The additional_key is not present in the request"
405- assert entry ["additional_key" ] == [
406- "additional_value"
407- ], "The value of additional_key is incorrect"
405+ input_name in entry
406+ ), f "The { input_name } is not present in the request"
407+ assert entry [input_name ] == [
408+ input_value
409+ ], f "The value of { input_name } is incorrect"
408410 else :
409411 assert False , f"Unsupported output format: { output_format } "
412+
413+ def test_trtllm_default_max_tokens (self , default_tokenizer ) -> None :
414+ input_name = "max_tokens"
415+ input_value = 256
416+
417+ pa_json = LlmInputs .create_llm_inputs (
418+ input_type = PromptSource .SYNTHETIC ,
419+ output_format = OutputFormat .TRTLLM ,
420+ num_of_output_prompts = 5 ,
421+ add_model_name = False ,
422+ add_stream = True ,
423+ tokenizer = default_tokenizer ,
424+ )
425+
426+ assert len (pa_json ["data" ]) == 5
427+ for entry in pa_json ["data" ]:
428+ assert (
429+ input_name in entry
430+ ), f"The { input_name } is not present in the request"
431+ assert entry [input_name ] == [
432+ input_value
433+ ], f"The value of { input_name } is incorrect"
0 commit comments