# feature: triton generate support #675
**OpenAI HTTP client (C++):**

```diff
@@ -63,6 +63,14 @@ namespace openai {
 void
 ChatCompletionRequest::SendResponse(bool is_final, bool is_null)
 {
+  // if final response has already been sent
+  // due to detecting the [DONE]
+  // ignore final response due to request completion
+  if (final_response_sent_) {
+    return;
+  }
+
+  final_response_sent_ = is_final;
   response_callback_(new ChatCompletionResult(
       http_code_, std::move(response_buffer_), is_final, is_null, request_id_));
 }
```
```diff
@@ -172,9 +180,11 @@ ChatCompletionClient::AsyncInfer(
         request->timer_.CaptureTimestamp(
             triton::client::RequestTimers::Kind::REQUEST_END);
         UpdateInferStat(request->timer_);
-        if (!request->is_stream_) {
-          request->SendResponse(true /* is_final */, false /* is_null */);
-        }
+
+        // SendResponse checks whether a final
+        // response has already been sent
+        // (in the case of seeing [DONE] in the streaming case)
+        request->SendResponse(true /* is_final */, false /* is_null */);
       };
   std::unique_ptr<HttpRequest> request(new ChatCompletionRequest(
       std::move(completion_callback), std::move(callback), request_id,
```
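The two hunks above make the final response idempotent: the client records that a final response has gone out (for example after seeing `[DONE]` on the SSE stream), and the request-completion callback now calls `SendResponse` unconditionally, relying on that guard to avoid a duplicate. A minimal Python sketch of the same idea, not the client's actual code (class, callback names, and the lock are illustrative only):

```python
import threading


class StreamingRequest:
    """Models the 'send the final response only once' guard."""

    def __init__(self, on_response):
        self._on_response = on_response
        self._final_sent = False
        self._lock = threading.Lock()  # illustrative; the real client is C++

    def send_response(self, is_final: bool, is_null: bool) -> None:
        with self._lock:
            if self._final_sent:
                # A final response was already emitted (e.g. [DONE] was seen
                # while streaming); ignore the request-completion trigger.
                return
            self._final_sent = is_final
        self._on_response({"final": is_final, "null": is_null})


if __name__ == "__main__":
    responses = []
    req = StreamingRequest(responses.append)
    req.send_response(is_final=True, is_null=False)  # [DONE] seen in stream
    req.send_response(is_final=True, is_null=False)  # request completion: ignored
    assert len(responses) == 1
```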
**Profile data parser (Python):**

```diff
@@ -46,6 +46,7 @@ class ResponseFormat(Enum):
     OPENAI_CHAT_COMPLETIONS = auto()
     OPENAI_COMPLETIONS = auto()
     TRITON = auto()
+    TRITON_GENERATE = auto()
 
 
 class Metrics:
```
```diff
@@ -446,6 +447,8 @@ def _get_profile_metadata(self, data: dict) -> None:
                 self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS
             elif data["endpoint"] == "v1/completions":
                 self._response_format = ResponseFormat.OPENAI_COMPLETIONS
+            elif "generate" in data["endpoint"]:
+                self._response_format = ResponseFormat.TRITON_GENERATE
             else:
                 # TPA-66: add PA metadata to handle this case
                 # When endpoint field is either empty or custom endpoint, fall
```
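With the new enum value, the parser selects `TRITON_GENERATE` whenever the recorded endpoint contains `generate`. A small sketch of that selection logic as added above (the final fallback is simplified here; the real code defers to other metadata, per the TPA-66 comment):

```python
from enum import Enum, auto


class ResponseFormat(Enum):
    OPENAI_CHAT_COMPLETIONS = auto()
    OPENAI_COMPLETIONS = auto()
    TRITON = auto()
    TRITON_GENERATE = auto()


def detect_response_format(endpoint: str) -> ResponseFormat:
    # Any endpoint containing "generate" (e.g. "v2/models/my_model/generate",
    # where my_model is a placeholder) is treated as Triton generate.
    if endpoint == "v1/chat/completions":
        return ResponseFormat.OPENAI_CHAT_COMPLETIONS
    if endpoint == "v1/completions":
        return ResponseFormat.OPENAI_COMPLETIONS
    if "generate" in endpoint:
        return ResponseFormat.TRITON_GENERATE
    return ResponseFormat.TRITON  # simplified fallback


assert detect_response_format("v2/models/my_model/generate") is ResponseFormat.TRITON_GENERATE
```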
```diff
@@ -662,6 +665,8 @@ def _tokenize_openai_request_input(self, req_inputs: dict) -> List[int]:
             input_text = payload["messages"][0]["content"]
         elif self._response_format == ResponseFormat.OPENAI_COMPLETIONS:
             input_text = payload["prompt"]
+        elif self._response_format == ResponseFormat.TRITON_GENERATE:
+            input_text = payload["text_input"]
         else:
             raise ValueError(
                 "Failed to parse OpenAI request input in profile export file."
```
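For generate-format payloads the prompt lives under `text_input` rather than `messages` or `prompt`. A hedged example of the lookup (only `text_input` is taken from the diff; the other field is illustrative):

```python
import json

# A generate-style request as it might appear in the profile export.
request_json = '{"text_input": "What is Triton Inference Server?", "max_tokens": 128}'

payload = json.loads(request_json)
input_text = payload["text_input"]  # what _tokenize_openai_request_input reads
print(input_text)  # -> What is Triton Inference Server?
```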
```diff
@@ -689,7 +694,10 @@ def _tokenize_openai_response_output(self, res_outputs: dict) -> List[List[int]]
         """Tokenize the OpenAI response output texts."""
         output_texts = []
         for output in res_outputs:
-            text = self._extract_openai_text_output(output["response"])
+            if self._response_format == ResponseFormat.TRITON_GENERATE:
+                text = self._extract_generate_text_output(output["response"])
+            else:
+                text = self._extract_openai_text_output(output["response"])
             output_texts.append(text)
         return self._run_tokenizer(output_texts)
 
```
```diff
@@ -702,6 +710,15 @@ def _run_tokenizer(self, output_texts: List[str]) -> List[List[int]]:
         encodings = self._tokenizer(output_texts)
         return [out[1:] for out in encodings.data["input_ids"]]
 
+    def _extract_generate_text_output(self, response: str) -> str:
+        response = remove_sse_prefix(response)
+
+        if response == "":
+            return response
+
+        data = json.loads(response)
+        return data["text_output"]
+
     def _extract_openai_text_output(self, response: str) -> str:
         """Extracts text/content of the OpenAI response object."""
         response = remove_sse_prefix(response)
```
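The new `_extract_generate_text_output` strips the SSE `data:` prefix, tolerates empty chunks, and reads `text_output` from the JSON body. A self-contained sketch, with a stand-in for genai-perf's `remove_sse_prefix` helper (its real implementation may differ) and illustrative response fields:

```python
import json


def remove_sse_prefix(msg: str) -> str:
    # Stand-in for genai-perf's helper: drop the server-sent-events "data:"
    # prefix, leaving the JSON body (or "" for an empty chunk).
    prefix = "data:"
    if msg.startswith(prefix):
        return msg[len(prefix):].strip()
    return msg.strip()


def extract_generate_text_output(response: str) -> str:
    response = remove_sse_prefix(response)
    if response == "":
        return response
    data = json.loads(response)
    return data["text_output"]


# A streamed chunk from the generate endpoint (field values illustrative).
chunk = 'data: {"model_name": "my_model", "text_output": "Hello"}'
assert extract_generate_text_output(chunk) == "Hello"
assert extract_generate_text_output("data: ") == ""
```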
```diff
@@ -731,7 +748,10 @@ def _extract_openai_text_output(self, response: str) -> str:
 
     def _is_openai_empty_response(self, response: str) -> bool:
         """Returns true if the response is an openai response with no content (or empty content)"""
-        text = self._extract_openai_text_output(response)
+        if self._response_format == ResponseFormat.TRITON_GENERATE:
+            text = self._extract_generate_text_output(response)
+        else:
+            text = self._extract_openai_text_output(response)
         if text:
             return False
         return True
```

> **Review comment** (on `_is_openai_empty_response`): We should change the name of the function since it's no longer just openai.
**CLI argument parser (Python):**

```diff
@@ -51,7 +51,12 @@
 
 logger = logging.getLogger(__name__)
 
-_endpoint_type_map = {"chat": "v1/chat/completions", "completions": "v1/completions"}
+_endpoint_type_map = {
+    "chat": "v1/chat/completions",
+    "completions": "v1/completions",
+    "generate": "v2/models/{MODEL_NAME}/generate",
+    "kserve": "v2/models/{MODEL_NAME}/infer",
+}
 
 
 def _check_model_args(
```
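The two new map entries are templates: `{MODEL_NAME}` is filled in later from the model argument (see the `_check_conditional_args` change below). For example, assuming a placeholder model name `my_model`:

```python
_endpoint_type_map = {
    "chat": "v1/chat/completions",
    "completions": "v1/completions",
    "generate": "v2/models/{MODEL_NAME}/generate",
    "kserve": "v2/models/{MODEL_NAME}/infer",
}

# "my_model" is a placeholder model name.
endpoint = _endpoint_type_map["generate"].format(MODEL_NAME="my_model")
assert endpoint == "v2/models/my_model/generate"
```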
```diff
@@ -98,31 +103,31 @@ def _check_conditional_args(
     Check for conditional args and raise an error if they are not set.
     """
 
-    # Endpoint and output format checks
-    if args.service_kind == "openai":
-        if args.endpoint_type is None:
-            parser.error(
-                "The --endpoint-type option is required when using the 'openai' service-kind."
-            )
-        else:
-            if args.endpoint_type == "chat":
-                args.output_format = OutputFormat.OPENAI_CHAT_COMPLETIONS
-            elif args.endpoint_type == "completions":
-                args.output_format = OutputFormat.OPENAI_COMPLETIONS
-
-            if args.endpoint is not None:
-                args.endpoint = args.endpoint.lstrip(" /")
-            else:
-                args.endpoint = _endpoint_type_map[args.endpoint_type]
-    elif args.endpoint_type is not None:
-        parser.error(
-            "The --endpoint-type option should only be used when using the 'openai' service-kind."
-        )
-
-    if args.service_kind == "triton":
+    if args.endpoint_type == "chat":
+        args.output_format = OutputFormat.OPENAI_CHAT_COMPLETIONS
+        args.service_kind = "openai"
+    elif args.endpoint_type == "completions":
+        args.output_format = OutputFormat.OPENAI_COMPLETIONS
+        args.service_kind = "openai"
+    elif args.endpoint_type == "generate":
+        args.output_format = OutputFormat.TRITON_GENERATE
+        args.service_kind = "openai"
+    elif args.endpoint_type == "kserve":
+        args.service_kind = "triton"
         args = _convert_str_to_enum_entry(args, "backend", OutputFormat)
         args.output_format = args.backend
 
+    if args.endpoint is not None:
+        args.endpoint = args.endpoint.lstrip(" /")
+    else:
+        if args.model:
+            model_name = args.model[0]
+        else:
+            model_name = ""
+        args.endpoint = _endpoint_type_map[args.endpoint_type].format(
+            MODEL_NAME=model_name
+        )
+
     # Output token distribution checks
     if args.output_tokens_mean == LlmInputs.DEFAULT_OUTPUT_TOKENS_MEAN:
         if args.output_tokens_stddev != LlmInputs.DEFAULT_OUTPUT_TOKENS_STDDEV:
```
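Summarizing the branch above, `--endpoint-type` alone now determines the service kind, output format, and default endpoint (the model name `my_model` is a placeholder supplied via the model argument):

| `--endpoint-type` | service_kind | output_format | default endpoint |
|---|---|---|---|
| chat | openai | OPENAI_CHAT_COMPLETIONS | v1/chat/completions |
| completions | openai | OPENAI_COMPLETIONS | v1/completions |
| generate | openai | TRITON_GENERATE | v2/models/my_model/generate |
| kserve | triton | value of `--backend` | v2/models/my_model/infer |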
```diff
@@ -137,7 +142,7 @@ def _check_conditional_args(
     if args.service_kind != "triton":
         if args.output_tokens_mean_deterministic:
             parser.error(
-                "The --output-tokens-mean-deterministic option is only supported with the Triton service-kind."
+                "The --output-tokens-mean-deterministic option is only supported with the kserve endpoint type."
            )
 
     return args
```

> **Review comment:** Should the code be changed to check `endpoint_type != kserve`? I know that with the current code it is the same result, but it introduces an assumption (endpoint kserve -> service_kind triton) that could trip up a future developer.
```diff
@@ -272,7 +277,7 @@ def _add_input_args(parser):
         help=f"When using --output-tokens-mean, this flag can be set to "
         "improve precision by setting the minimum number of tokens "
         "equal to the requested number of tokens. This is currently "
-        "supported with the Triton service-kind. "
+        "supported with the kserve endpoint type. "
         "Note that there is still some variability in the requested number "
         "of output tokens, but GenAi-Perf attempts its best effort with your "
         "model to get the right number of output tokens. ",
```
```diff
@@ -380,10 +385,10 @@ def _add_endpoint_args(parser):
     endpoint_group.add_argument(
         "--backend",
         type=str,
-        choices=utils.get_enum_names(OutputFormat)[2:],
+        choices=["tensorrtllm", "vllm"],
         default="tensorrtllm",
         required=False,
-        help=f'When using the "triton" service-kind, '
+        help=f'When using the "kserve" endpoint type, '
         "this is the backend of the model. "
         "For the TENSORRT-LLM backend, you currently must set "
         "'exclude_input_in_output' to true in the model config to "
```

> **Review comment:** Can the generate endpoint not use trtllm vs vllm?
>
> **Reply:** It can - I haven't added any different behavior for the different backends. Actually, it has only been tested against vllm at the moment, so this is a fair point. Let me move this back to draft; I plan to test trt-llm in the next week or so.
```diff
@@ -400,21 +405,10 @@ def _add_endpoint_args(parser):
     endpoint_group.add_argument(
         "--endpoint-type",
         type=str,
-        choices=["chat", "completions"],
-        required=False,
-        help=f"The endpoint-type to send requests to on the "
-        'server. This is only used with the "openai" service-kind.',
-    )
-
-    endpoint_group.add_argument(
-        "--service-kind",
-        type=str,
-        choices=["triton", "openai"],
-        default="triton",
+        choices=["chat", "completions", "generate", "kserve"],
+        default="kserve",
         required=False,
-        help="The kind of service perf_analyzer will "
-        'generate load for. In order to use "openai", '
-        "you must specify an api via --endpoint-type.",
+        help=f"The endpoint-type for requests. Inputs will be formatted according to endpoint-type.",
     )
 
     endpoint_group.add_argument(
```

> **Review comment** (on the removed `--service-kind` option): Finally one less CLI option 😄 Can we also update the README to reflect the changes in CLI options?
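With these changes, targeting Triton's generate endpoint should look something like `genai-perf -m my_model --endpoint-type generate`, while the previous Triton/KServe path becomes `genai-perf -m my_model --endpoint-type kserve --backend vllm` (the model name is a placeholder; the rest of the genai-perf CLI is unchanged by this PR).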