Commit d0319e2

Ianmacleod/completion sync error throws 4xx (#234)
* changing 5xx error to 4xx error
* .
* .
* adding completion stream changes
* parsing error dictionary
* .
* .
* fixing error handling for 400
* .
* hacky way of fixing completion stream w error message
* cleanup
* cleanup, add docs
* .
* fixing indentation on docs
1 parent f3b0306 commit d0319e2

File tree: 4 files changed, +58 −21 lines changed

docs/getting_started.md

Lines changed: 7 additions & 3 deletions
@@ -74,7 +74,11 @@ stream = Completion.create(
 )
 
 for response in stream:
-    if response.output:
-        print(response.output.text, end="")
-        sys.stdout.flush()
+    try:
+        if response.output:
+            print(response.output.text, end="")
+            sys.stdout.flush()
+    except:  # an error occurred
+        print(stream.text)  # print the error message out
+        break
 ```

docs/guides/completions.md

Lines changed: 9 additions & 3 deletions
@@ -67,6 +67,8 @@ applications. When streaming, tokens will be sent as data-only
 
 To enable token streaming, pass `stream=True` to either [Completion.create](../../api/python_client/#llmengine.completion.Completion.create) or [Completion.acreate](../../api/python_client/#llmengine.completion.Completion.acreate).
 
+Note that errors from streaming calls are returned back to the user as plain-text messages and currently need to be handled by the client.
+
 An example of token streaming using the synchronous Completions API looks as follows:
 
 === "Token streaming with synchronous API in python"
@@ -85,9 +87,13 @@ stream = Completion.create(
 )
 
 for response in stream:
-    if response.output:
-        print(response.output.text, end="")
-        sys.stdout.flush()
+    try:
+        if response.output:
+            print(response.output.text, end="")
+            sys.stdout.flush()
+    except:  # an error occurred
+        print(stream.text)  # print the error message out
+        break
 ```
 
 ## Async requests
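For readers following along, here is the documented streaming pattern as one self-contained snippet: a minimal sketch, not part of the commit, assuming the `llmengine` client is installed and that `model`, `prompt`, and `max_new_tokens` are the parameter names accepted by `Completion.create` (the model name is only a placeholder).

```python
import sys

from llmengine import Completion

# Sketch only: the create() arguments below are assumptions/placeholders.
stream = Completion.create(
    model="llama-2-7b",
    prompt="Give me a 200 word summary of lorem ipsum.",
    max_new_tokens=256,
    stream=True,
)

for response in stream:
    try:
        if response.output:
            print(response.output.text, end="")
            sys.stdout.flush()
    except Exception:
        # Per the doc note above, streaming errors currently arrive as
        # plain-text messages that the client must handle itself.
        print(stream.text)
        break
```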

model-engine/model_engine_server/api/llms_v1.py

Lines changed: 8 additions & 2 deletions
@@ -199,6 +199,8 @@ async def create_completion_sync_task(
         ) from exc
     except ObjectHasInvalidValueException as exc:
         raise HTTPException(status_code=400, detail=str(exc))
+    except InvalidRequestException as exc:
+        raise HTTPException(status_code=400, detail=str(exc))
     except EndpointUnsupportedInferenceTypeException as exc:
         raise HTTPException(
             status_code=400,
@@ -230,8 +232,12 @@ async def create_completion_stream_task(
         )
 
         async def event_generator():
-            async for message in response:
-                yield {"data": message.json()}
+            try:
+                async for message in response:
+                    yield {"data": message.json()}
+            except InvalidRequestException as exc:
+                yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
+                return
 
         return EventSourceResponse(event_generator())
     except UpstreamServiceError:
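On the wire, the change above means a streaming consumer can now receive an error object in place of a token payload. A minimal sketch of how a raw server-sent-events client might distinguish the two cases, assuming each event's `data` field is parsed as JSON; `handle_stream_chunk` is a hypothetical helper, and only the error-dictionary shape comes from the diff above.

```python
import json
from typing import Optional


def handle_stream_chunk(data: str) -> Optional[str]:
    """Return the next token's text, or raise if the server reported an error.

    Hypothetical helper: the error shape
    {"error": {"status_code": 400, "detail": "..."}} matches what
    event_generator yields above; the success case is a serialized
    CompletionStreamV1Response with an "output" field.
    """
    payload = json.loads(data)
    if "error" in payload:
        err = payload["error"]
        raise RuntimeError(
            f"completion stream failed ({err['status_code']}): {err['detail']}"
        )
    output = payload.get("output")
    return output["text"] if output else None
```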

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 34 additions & 13 deletions
@@ -53,6 +53,8 @@
 from model_engine_server.domain.exceptions import (
     EndpointLabelsException,
     EndpointUnsupportedInferenceTypeException,
+    InvalidRequestException,
+    UpstreamServiceError,
 )
 from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway
 from model_engine_server.domain.repositories import ModelBundleRepository
@@ -741,9 +743,15 @@ def model_output_to_completion_output(
                     num_completion_tokens=model_output["details"]["generated_tokens"],
                     tokens=tokens,
                 )
-            except Exception as e:
-                logger.exception(f"Error parsing text-generation-inference output {model_output}")
-                raise e
+            except Exception:
+                logger.exception(f"Error parsing text-generation-inference output {model_output}.")
+                if model_output.get("error_type") == "validation":
+                    raise InvalidRequestException(model_output.get("error"))  # trigger a 400
+                else:
+                    raise UpstreamServiceError(
+                        status_code=500, content=bytes(model_output["error"])
+                    )
+
         elif model_content.inference_framework == LLMInferenceFramework.VLLM:
             tokens = None
             if with_token_probs:
@@ -924,7 +932,6 @@ async def execute(
             )
 
             output = json.loads(predict_result.result["result"])
-
             return CompletionSyncV1Response(
                 request_id=request_id,
                 output=self.model_output_to_completion_output(
@@ -1106,15 +1113,29 @@ async def execute(
                         token=result["result"]["token"]["text"],
                         log_prob=result["result"]["token"]["logprob"],
                     )
-                    yield CompletionStreamV1Response(
-                        request_id=request_id,
-                        output=CompletionStreamOutput(
-                            text=result["result"]["token"]["text"],
-                            finished=finished,
-                            num_completion_tokens=num_completion_tokens,
-                            token=token,
-                        ),
-                    )
+                    try:
+                        yield CompletionStreamV1Response(
+                            request_id=request_id,
+                            output=CompletionStreamOutput(
+                                text=result["result"]["token"]["text"],
+                                finished=finished,
+                                num_completion_tokens=num_completion_tokens,
+                                token=token,
+                            ),
+                        )
+                    except Exception:
+                        logger.exception(
+                            f"Error parsing text-generation-inference output. Result: {result['result']}"
+                        )
+                        if result["result"].get("error_type") == "validation":
+                            raise InvalidRequestException(
+                                result["result"].get("error")
+                            )  # trigger a 400
+                        else:
+                            raise UpstreamServiceError(
+                                status_code=500, content=result.get("error")
+                            )  # also change llms_v1.py that will return a 500 HTTPException so user can retry
+
                 else:
                     yield CompletionStreamV1Response(
                         request_id=request_id,
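To summarize the behavior the two changed call sites above implement: a validation error reported by text-generation-inference is surfaced as `InvalidRequestException`, which `llms_v1.py` now translates to HTTP 400, while any other upstream error becomes `UpstreamServiceError` with status 500. A minimal sketch of that mapping follows; the example payloads are hypothetical, and only the `error` / `error_type` keys and the 400-vs-500 split come from the diff.

```python
def status_code_for(model_output: dict) -> int:
    """Status code a client should expect for a failed completion, per the logic above."""
    if model_output.get("error_type") == "validation":
        return 400  # raised as InvalidRequestException -> HTTPException(status_code=400)
    return 500      # raised as UpstreamServiceError -> server-side error, retryable


# Hypothetical payloads for illustration only.
assert status_code_for({"error": "Input validation error: ...", "error_type": "validation"}) == 400
assert status_code_for({"error": "Request failed during generation", "error_type": "generation"}) == 500
```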
