
Commit 61ae566

initial generate support
1 parent fddba6d commit 61ae566

File tree: 5 files changed, +156 additions, -13 deletions


src/c++/perf_analyzer/client_backend/openai/openai_client.cc

Lines changed: 23 additions & 6 deletions
@@ -63,6 +63,14 @@ namespace openai {
 void
 ChatCompletionRequest::SendResponse(bool is_final, bool is_null)
 {
+  // if final response has already been sent
+  // due to detecting the [DONE]
+  // ignore final response due to request completion
+  if (final_response_sent_) {
+    return;
+  }
+
+  final_response_sent_ = is_final;
   response_callback_(new ChatCompletionResult(
       http_code_, std::move(response_buffer_), is_final, is_null, request_id_));
 }
@@ -107,13 +115,15 @@ ChatCompletionClient::ResponseHeaderHandler(
       hdr.find("text/event-stream") != std::string::npos) {
     request->is_stream_ = true;
   }
+
   return byte_size;
 }
 
 size_t
 ChatCompletionClient::ResponseHandler(
     void* contents, size_t size, size_t nmemb, void* userp)
 {
+
   // [TODO TMA-1666] verify if the SSE responses received are complete, or the
   // response need to be stitched first. To verify, print out the received
   // responses from SendResponse() to make sure the OpenAI server doesn't chunk
@@ -151,7 +161,7 @@ ChatCompletionClient::ResponseHandler(
   // RECV_END so that we always have the time of the last.
   request->timer_.CaptureTimestamp(
       triton::client::RequestTimers::Kind::RECV_END);
-
+
   return result_bytes;
 }
 
@@ -162,6 +172,8 @@ ChatCompletionClient::AsyncInfer(
     std::string& serialized_request_body, const std::string& request_id,
     const Headers& headers)
 {
+
+
   if (callback == nullptr) {
     return Error(
         "Callback function must be provided along with AsyncInfer() call.");
@@ -172,9 +184,14 @@ ChatCompletionClient::AsyncInfer(
     request->timer_.CaptureTimestamp(
         triton::client::RequestTimers::Kind::REQUEST_END);
     UpdateInferStat(request->timer_);
-    if (!request->is_stream_) {
-      request->SendResponse(true /* is_final */, false /* is_null */);
-    }
+
+    // Updated to be ok to call multiple times
+    // will only send the first final response
+    //
+    // if (!request->is_stream_) {
+    //
+    request->SendResponse(true /* is_final */, false /* is_null */);
+    // }
   };
   std::unique_ptr<HttpRequest> request(new ChatCompletionRequest(
       std::move(completion_callback), std::move(callback), request_id,
@@ -185,7 +202,7 @@ ChatCompletionClient::AsyncInfer(
   request->AddInput(
       reinterpret_cast<uint8_t*>(serialized_request_body.data()),
       serialized_request_body.size());
-
+
   CURL* multi_easy_handle = curl_easy_init();
   Error err = PreRunProcessing(multi_easy_handle, raw_request, headers);
   if (!err.IsOk()) {
@@ -226,7 +243,7 @@ ChatCompletionClient::PreRunProcessing(
 
   // response data handled by ResponseHandler()
   curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, ResponseHandler);
-  curl_easy_setopt(curl, CURLOPT_WRITEDATA, request);
+  curl_easy_setopt(curl, CURLOPT_WRITEDATA, request);
 
   const curl_off_t post_byte_size = request->total_input_byte_size_;
   curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE_LARGE, post_byte_size);

src/c++/perf_analyzer/client_backend/openai/openai_client.h

Lines changed: 2 additions & 1 deletion
@@ -127,6 +127,7 @@ class ChatCompletionRequest : public HttpRequest {
   // The timers for infer request.
   triton::client::RequestTimers timer_;
   const std::string request_id_;
+  bool final_response_sent_{false};
 };
 
 class ChatCompletionClient : public HttpClient {
@@ -172,7 +173,7 @@ class ChatCompletionClient : public HttpClient {
       void* contents, size_t size, size_t nmemb, void* userp);
   static size_t ResponseHeaderHandler(
       void* contents, size_t size, size_t nmemb, void* userp);
-
+
   Error UpdateInferStat(const triton::client::RequestTimers& timer);
   InferStat infer_stat_;
 };
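Taken together, the two C++ changes make the final response idempotent: both the SSE "[DONE]" handler and the request-completion callback may try to send a final response, but only the first one is forwarded. A minimal Python sketch of the same guard pattern, with hypothetical names standing in for ChatCompletionRequest and its response callback (illustration only, not the actual client code):

# Illustrative stand-in for the final-response guard added above.
class GenerateRequest:
    def __init__(self, on_response):
        self._on_response = on_response
        self._final_response_sent = False

    def send_response(self, body: str, is_final: bool, is_null: bool) -> None:
        # If a final response was already forwarded (e.g. after "[DONE]"),
        # ignore the duplicate triggered by request completion.
        if self._final_response_sent:
            return
        self._final_response_sent = is_final
        self._on_response(body, is_final, is_null)

req = GenerateRequest(lambda body, final, null: print(final, repr(body)))
req.send_response('{"text_output": "hi"}', is_final=True, is_null=False)
req.send_response("", is_final=True, is_null=True)  # ignored: final already sent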

src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py

Lines changed: 103 additions & 1 deletion
@@ -41,6 +41,7 @@ class PromptSource(Enum):
 class OutputFormat(Enum):
     OPENAI_CHAT_COMPLETIONS = auto()
     OPENAI_COMPLETIONS = auto()
+    TRITON_GENERATE = auto()
     TENSORRTLLM = auto()
     VLLM = auto()
 
@@ -364,7 +365,18 @@ def _convert_generic_json_to_output_format(
         model_name: list = [],
         model_selection_strategy: ModelSelectionStrategy = ModelSelectionStrategy.ROUND_ROBIN,
     ) -> Dict:
-        if output_format == OutputFormat.OPENAI_CHAT_COMPLETIONS:
+        if output_format == OutputFormat.TRITON_GENERATE:
+            output_json = cls._convert_generic_json_to_generate_format(
+                generic_dataset,
+                add_model_name,
+                add_stream,
+                extra_inputs,
+                output_tokens_mean,
+                output_tokens_stddev,
+                output_tokens_deterministic,
+                model_name,
+            )
+        elif output_format == OutputFormat.OPENAI_CHAT_COMPLETIONS:
             output_json = cls._convert_generic_json_to_openai_chat_completions_format(
                 generic_dataset,
                 add_model_name,
@@ -454,6 +466,43 @@ def _convert_generic_json_to_openai_chat_completions_format(
 
         return pa_json
 
+    @classmethod
+    def _convert_generic_json_to_generate_format(
+        cls,
+        dataset_json: Dict,
+        add_model_name: bool,
+        add_stream: bool,
+        extra_inputs: Dict,
+        output_tokens_mean: int,
+        output_tokens_stddev: int,
+        output_tokens_deterministic: bool,
+        model_name: str = "",
+    ) -> Dict:
+
+        (
+            system_role_headers,
+            user_role_headers,
+            text_input_headers,
+        ) = cls._determine_json_feature_roles(dataset_json)
+
+        pa_json = cls._populate_triton_generate_output_json(
+            dataset_json,
+            system_role_headers,
+            user_role_headers,
+            text_input_headers,
+            add_model_name,
+            add_stream,
+            extra_inputs,
+            output_tokens_mean,
+            output_tokens_stddev,
+            output_tokens_deterministic,
+            model_name,
+        )
+
+        return pa_json
+
     @classmethod
     def _convert_generic_json_to_openai_completions_format(
         cls,
@@ -652,6 +701,59 @@ def _populate_openai_chat_completions_output_json(
         )
 
         return pa_json
+
+    @classmethod
+    def _populate_triton_generate_output_json(
+        cls,
+        dataset: Dict,
+        system_role_headers: List[str],
+        user_role_headers: List[str],
+        text_input_headers: List[str],
+        add_model_name: bool,
+        add_stream: bool,
+        extra_inputs: Dict,
+        output_tokens_mean: int,
+        output_tokens_stddev: int,
+        output_tokens_deterministic: bool,
+        model_name: str = "",
+    ) -> Dict:
+        number_of_rows = len(dataset["rows"])
+        pa_json = cls._create_empty_trtllm_pa_json()
+
+        default_max_tokens = (
+            "max_tokens" not in extra_inputs
+            or output_tokens_mean != cls.DEFAULT_OUTPUT_TOKENS_MEAN
+        )
+
+        pa_json = {"data": [{"payload": [{}]} for _ in dataset["rows"]]}
+
+        for index, entry in enumerate(dataset["rows"]):
+
+            for header, content in entry.items():
+                new_text_input = cls._create_new_text_input(
+                    header,
+                    system_role_headers,
+                    user_role_headers,
+                    text_input_headers,
+                    content,
+                )
+                pa_json["data"][index]["payload"][0]["text_input"] = new_text_input
+
+            pa_json = cls._add_optional_tags_to_openai_json(
+                pa_json,
+                index,
+                False,
+                add_stream,
+                extra_inputs,
+                output_tokens_mean,
+                output_tokens_stddev,
+                output_tokens_deterministic,
+                model_name,
+            )
+
+        return pa_json
+
 
     @classmethod
     def _populate_openai_completions_output_json(
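For orientation, a rough sketch of the input JSON shape that the new generate path produces. The two-row dataset is hypothetical, and header-role handling (_create_new_text_input) and extra inputs (_add_optional_tags_to_openai_json) are simplified away:

# Hypothetical dataset; in the real code each header is routed through
# _create_new_text_input and extra tags are merged per payload.
dataset = {"rows": [{"text_input": "What is Triton?"},
                    {"text_input": "Summarize SSE."}]}

pa_json = {"data": [{"payload": [{}]} for _ in dataset["rows"]]}
for index, entry in enumerate(dataset["rows"]):
    for header, content in entry.items():
        pa_json["data"][index]["payload"][0]["text_input"] = content
    pa_json["data"][index]["payload"][0]["stream"] = True  # when streaming is requested

# pa_json now looks like:
# {"data": [{"payload": [{"text_input": "What is Triton?", "stream": True}]},
#           {"payload": [{"text_input": "Summarize SSE.", "stream": True}]}]}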

src/c++/perf_analyzer/genai-perf/genai_perf/llm_metrics.py

Lines changed: 23 additions & 2 deletions
@@ -46,6 +46,7 @@ class ResponseFormat(Enum):
     OPENAI_CHAT_COMPLETIONS = auto()
     OPENAI_COMPLETIONS = auto()
     TRITON = auto()
+    TRITON_GENERATE = auto()
 
 
 class Metrics:
@@ -446,6 +447,8 @@ def _get_profile_metadata(self, data: dict) -> None:
             self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS
         elif data["endpoint"] == "v1/completions":
             self._response_format = ResponseFormat.OPENAI_COMPLETIONS
+        elif "generate" in data["endpoint"]:
+            self._response_format = ResponseFormat.TRITON_GENERATE
         else:
             # TPA-66: add PA metadata to handle this case
             # When endpoint field is either empty or custom endpoint, fall
@@ -662,6 +665,8 @@ def _tokenize_openai_request_input(self, req_inputs: dict) -> List[int]:
             input_text = payload["messages"][0]["content"]
         elif self._response_format == ResponseFormat.OPENAI_COMPLETIONS:
             input_text = payload["prompt"]
+        elif self._response_format == ResponseFormat.TRITON_GENERATE:
+            input_text = payload["text_input"]
         else:
             raise ValueError(
                 "Failed to parse OpenAI request input in profile export file."
@@ -689,7 +694,10 @@ def _tokenize_openai_response_output(self, res_outputs: dict) -> List[List[int]]
         """Tokenize the OpenAI response output texts."""
         output_texts = []
         for output in res_outputs:
-            text = self._extract_openai_text_output(output["response"])
+            if self._response_format == ResponseFormat.TRITON_GENERATE:
+                text = self._extract_generate_text_output(output["response"])
+            else:
+                text = self._extract_openai_text_output(output["response"])
             output_texts.append(text)
         return self._run_tokenizer(output_texts)
 
@@ -702,6 +710,16 @@ def _run_tokenizer(self, output_texts: List[str]) -> List[List[int]]:
         encodings = self._tokenizer(output_texts)
         return [out[1:] for out in encodings.data["input_ids"]]
 
+    def _extract_generate_text_output(self, response: str) -> str:
+
+        response = remove_sse_prefix(response)
+
+        if response == "":
+            return response
+
+        data = json.loads(response)
+        return data["text_output"]
+
     def _extract_openai_text_output(self, response: str) -> str:
         """Extracts text/content of the OpenAI response object."""
         response = remove_sse_prefix(response)
@@ -731,7 +749,10 @@ def _extract_openai_text_output(self, response: str) -> str:
 
     def _is_openai_empty_response(self, response: str) -> bool:
         """Returns true if the response is an openai response with no content (or empty content)"""
-        text = self._extract_openai_text_output(response)
+        if self._response_format == ResponseFormat.TRITON_GENERATE:
+            text = self._extract_generate_text_output(response)
+        else:
+            text = self._extract_openai_text_output(response)
         if text:
             return False
         return True
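The new _extract_generate_text_output strips the SSE prefix and reads the text_output field of the generate response JSON. A self-contained sketch of that parsing, with a local stand-in for genai-perf's remove_sse_prefix helper and a made-up response line:

import json

def strip_sse_prefix(line: str) -> str:
    # Stand-in for remove_sse_prefix: drop a leading "data:" marker if present.
    line = line.strip()
    return line[len("data:"):].strip() if line.startswith("data:") else line

def extract_generate_text_output(response: str) -> str:
    response = strip_sse_prefix(response)
    if response == "":
        return response
    return json.loads(response)["text_output"]

print(extract_generate_text_output(
    'data: {"model_name": "my_model", "text_output": "Hello"}'))  # -> Hello
print(repr(extract_generate_text_output("data: ")))               # -> ''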

src/c++/perf_analyzer/genai-perf/genai_perf/parser.py

Lines changed: 5 additions & 3 deletions
@@ -51,7 +51,7 @@
 
 logger = logging.getLogger(__name__)
 
-_endpoint_type_map = {"chat": "v1/chat/completions", "completions": "v1/completions"}
+_endpoint_type_map = {"chat": "v1/chat/completions", "completions": "v1/completions", "generate": "v2/models/{MODEL_NAME}/generate"}
 
 
 def _check_model_args(
@@ -109,11 +109,13 @@ def _check_conditional_args(
             args.output_format = OutputFormat.OPENAI_CHAT_COMPLETIONS
         elif args.endpoint_type == "completions":
             args.output_format = OutputFormat.OPENAI_COMPLETIONS
+        elif args.endpoint_type == "generate":
+            args.output_format = OutputFormat.TRITON_GENERATE
 
         if args.endpoint is not None:
             args.endpoint = args.endpoint.lstrip(" /")
         else:
-            args.endpoint = _endpoint_type_map[args.endpoint_type]
+            args.endpoint = _endpoint_type_map[args.endpoint_type].format(MODEL_NAME=args.model)
     elif args.endpoint_type is not None:
         parser.error(
             "The --endpoint-type option should only be used when using the 'openai' service-kind."
@@ -400,7 +402,7 @@ def _add_endpoint_args(parser):
     endpoint_group.add_argument(
         "--endpoint-type",
        type=str,
-        choices=["chat", "completions"],
+        choices=["chat", "completions", "generate"],
        required=False,
        help=f"The endpoint-type to send requests to on the "
        'server. This is only used with the "openai" service-kind.',
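The generate entry in _endpoint_type_map is a template: when no --endpoint is given, the placeholder is filled from --model via str.format. A small sketch of how the endpoint resolves (the model name "gpt2" is just an example):

_endpoint_type_map = {
    "chat": "v1/chat/completions",
    "completions": "v1/completions",
    "generate": "v2/models/{MODEL_NAME}/generate",
}

# e.g. genai-perf --service-kind openai --endpoint-type generate -m gpt2
endpoint = _endpoint_type_map["generate"].format(MODEL_NAME="gpt2")
print(endpoint)  # v2/models/gpt2/generate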
