Skip to content

Commit d1780d1

Browse files
feat: Support for request id field in generate API (#7392)
1 parent ac0d4d6 commit d1780d1

File tree

4 files changed

+55
-4
lines changed

4 files changed

+55
-4
lines changed

docs/protocol/extension_generate.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
<!--
2-
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions
@@ -87,10 +87,12 @@ return an error.
8787

8888
$generate_request =
8989
{
90+
"id" : $string, #optional
9091
"text_input" : $string,
9192
"parameters" : $parameters #optional
9293
}
9394

95+
* "id" : An identifier for this request. Optional, but if specified, this identifier must be returned in the response.
9496
* "text_input" : The text input that the model should generate output from.
9597
* "parameters" : An optional object containing zero or more parameters for this
9698
generate request expressed as key/value pairs. See
@@ -121,14 +123,15 @@ specification to set the parameters.
121123
Below is an example to send generate request with additional model parameters `stream` and `temperature`.
122124

123125
```
124-
$ curl -X POST localhost:8000/v2/models/mymodel/generate -d '{"text_input": "client input", "parameters": {"stream": false, "temperature": 0}}'
126+
$ curl -X POST localhost:8000/v2/models/mymodel/generate -d '{"id": "42", "text_input": "client input", "parameters": {"stream": false, "temperature": 0}}'
125127
126128
POST /v2/models/mymodel/generate HTTP/1.1
127129
Host: localhost:8000
128130
Content-Type: application/json
129131
Content-Length: <xx>
130132
{
131-
"text_input": "client input",
133+
"id" : "42",
134+
"text_input" : "client input",
132135
"parameters" :
133136
{
134137
"stream": false,
@@ -145,11 +148,13 @@ the HTTP body.
145148

146149
$generate_response =
147150
{
151+
"id" : $string, #optional
148152
"model_name" : $string,
149153
"model_version" : $string,
150154
"text_output" : $string
151155
}
152156

157+
* "id" : The "id" identifier given in the request, if any.
153158
* "model_name" : The name of the model used for inference.
154159
* "model_version" : The specific model version used for inference.
155160
* "text_output" : The output of the inference.
@@ -159,6 +164,7 @@ the HTTP body.
159164
```
160165
200
161166
{
167+
"id" : "42",
162168
"model_name" : "mymodel",
163169
"model_version" : "1",
164170
"text_output" : "model output"

qa/L0_http/generate_endpoint_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,49 @@ def test_generate(self):
142142
self.assertIn("TEXT", data)
143143
self.assertEqual(text, data["TEXT"])
144144

145+
def test_request_id(self):
146+
# Setup text based input
147+
text = "hello world"
148+
request_id = "42"
149+
150+
# Test when request id in request body
151+
inputs = {"PROMPT": text, "id": request_id, "STREAM": False}
152+
r = self.generate(self._model_name, inputs)
153+
r.raise_for_status()
154+
155+
self.assertIn("Content-Type", r.headers)
156+
self.assertEqual(r.headers["Content-Type"], "application/json")
157+
158+
data = r.json()
159+
self.assertIn("id", data)
160+
self.assertEqual(request_id, data["id"])
161+
self.assertIn("TEXT", data)
162+
self.assertEqual(text, data["TEXT"])
163+
164+
# Test when request id not in request body
165+
inputs = {"PROMPT": text, "STREAM": False}
166+
r = self.generate(self._model_name, inputs)
167+
r.raise_for_status()
168+
169+
self.assertIn("Content-Type", r.headers)
170+
self.assertEqual(r.headers["Content-Type"], "application/json")
171+
172+
data = r.json()
173+
self.assertNotIn("id", data)
174+
175+
# Test when request id is empty
176+
inputs = {"PROMPT": text, "id": "", "STREAM": False}
177+
r = self.generate(self._model_name, inputs)
178+
r.raise_for_status()
179+
180+
self.assertIn("Content-Type", r.headers)
181+
self.assertEqual(r.headers["Content-Type"], "application/json")
182+
183+
data = r.json()
184+
self.assertNotIn("id", data)
185+
self.assertIn("TEXT", data)
186+
self.assertEqual(text, data["TEXT"])
187+
145188
def test_generate_stream(self):
146189
# Setup text-based input
147190
text = "hello world"

qa/L0_http/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ fi
662662
## Python Unit Tests
663663
TEST_RESULT_FILE='test_results.txt'
664664
PYTHON_TEST=generate_endpoint_test.py
665-
EXPECTED_NUM_TESTS=15
665+
EXPECTED_NUM_TESTS=16
666666
set +e
667667
python $PYTHON_TEST >$CLIENT_LOG 2>&1
668668
if [ $? -ne 0 ]; then

src/http_server.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3327,6 +3327,8 @@ HTTPAPIServer::HandleGenerate(
33273327
// thus the string must live as long as the JSON message).
33283328
triton::common::TritonJson::Value request;
33293329
RETURN_AND_CALLBACK_IF_ERR(EVRequestToJson(req, &request), error_callback);
3330+
RETURN_AND_CALLBACK_IF_ERR(
3331+
ParseJsonTritonRequestID(request, irequest), error_callback);
33303332

33313333
RETURN_AND_CALLBACK_IF_ERR(
33323334
generate_request->ConvertGenerateRequest(

0 commit comments

Comments
 (0)