Commit e2dd093

Fix streaming endpoint failure handling (#314)
* Fix streaming endpoint failure handling
* Fix streaming endpoint failure handling
* remove print
* comments
* client side changes
* client side changes
* fix
* strong typing
1 parent 2f5dd72 commit e2dd093

File tree

9 files changed: +168 −49 lines

* .pre-commit-config.yaml
* clients/python/llmengine/data_types.py
* docs/getting_started.md
* docs/guides/completions.md
* model-engine/model_engine_server/api/llms_v1.py
* model-engine/model_engine_server/common/dtos/llms.py
* model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
* model-engine/tests/unit/api/test_llms.py
* model-engine/tests/unit/api/test_tasks.py

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions

@@ -55,6 +55,7 @@ repos:
     hooks:
       - id: mypy
         name: mypy-clients-python
+        files: clients/python/.*
         entry: mypy --config-file clients/python/mypy.ini
         language: system
   - repo: https://github.com/pre-commit/mirrors-mypy

clients/python/llmengine/data_types.py

Lines changed: 21 additions & 0 deletions

@@ -355,6 +355,24 @@ class CompletionStreamOutput(BaseModel):
     """Detailed token information."""


+class StreamErrorContent(BaseModel):
+    error: str
+    """Error message."""
+    timestamp: str
+    """Timestamp of the error."""
+
+
+class StreamError(BaseModel):
+    """
+    Error object for a stream prompt completion task.
+    """
+
+    status_code: int
+    """The HTTP status code of the error."""
+    content: StreamErrorContent
+    """The error content."""
+
+
 class CompletionStreamResponse(BaseModel):
     """
     Response object for a stream prompt completion task.
@@ -372,6 +390,9 @@ class CompletionStreamResponse(BaseModel):
     output: Optional[CompletionStreamOutput] = None
     """Completion output."""

+    error: Optional[StreamError] = None
+    """Error of the response (if any)."""
+

 class CreateFineTuneRequest(BaseModel):
     """

docs/getting_started.md

Lines changed: 5 additions & 6 deletions

@@ -81,11 +81,10 @@ stream = Completion.create(
 )

 for response in stream:
-    try:
-        if response.output:
-            print(response.output.text, end="")
-            sys.stdout.flush()
-    except: # an error occurred
-        print(stream.text) # print the error message out
+    if response.output:
+        print(response.output.text, end="")
+        sys.stdout.flush()
+    else: # an error occurred
+        print(response.error) # print the error message out
         break
 ```

docs/guides/completions.md

Lines changed: 5 additions & 6 deletions

@@ -87,12 +87,11 @@ stream = Completion.create(
 )

 for response in stream:
-    try:
-        if response.output:
-            print(response.output.text, end="")
-            sys.stdout.flush()
-    except: # an error occurred
-        print(stream.text) # print the error message out
+    if response.output:
+        print(response.output.text, end="")
+        sys.stdout.flush()
+    else: # an error occurred
+        print(response.error) # print the error message out
         break
 ```

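For context, the docs snippets above only show the changed loop; a self-contained version of the updated pattern might look roughly like the following. The model name, prompt, and sampling parameters are placeholders, and `Completion.create` is used as shown in the surrounding documentation.

```python
# Sketch of the updated streaming pattern from the docs; values are placeholders.
import sys

from llmengine import Completion

stream = Completion.create(
    model="llama-2-7b",             # placeholder model name
    prompt="why is the sky blue?",  # placeholder prompt
    max_new_tokens=10,
    temperature=0.2,
    stream=True,
)

for response in stream:
    if response.output:
        print(response.output.text, end="")
        sys.stdout.flush()
    else:  # an error occurred; no try/except around the loop is needed anymore
        print(response.error)  # print the error message out
        break
```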

model-engine/model_engine_server/api/llms_v1.py

Lines changed: 57 additions & 36 deletions

@@ -1,7 +1,10 @@
 """LLM Model Endpoint routes for the hosted model inference service.
 """
+import traceback
+from datetime import datetime
 from typing import Optional

+import pytz
 from fastapi import APIRouter, Depends, HTTPException, Query
 from model_engine_server.api.dependencies import (
     ExternalInterfaces,
@@ -28,6 +31,8 @@
     ListLLMModelEndpointsV1Response,
     ModelDownloadRequest,
     ModelDownloadResponse,
+    StreamError,
+    StreamErrorContent,
 )
 from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy
 from model_engine_server.core.auth.authentication_repository import User
@@ -71,6 +76,34 @@
 logger = make_logger(filename_wo_ext(__name__))


+def handle_streaming_exception(
+    e: Exception,
+    code: int,
+    message: str,
+):
+    tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
+    request_id = get_request_id()
+    timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
+    structured_log = {
+        "error": message,
+        "request_id": str(request_id),
+        "traceback": "".join(tb_str),
+    }
+    logger.error("Exception: %s", structured_log)
+    return {
+        "data": CompletionStreamV1Response(
+            request_id=str(request_id),
+            error=StreamError(
+                status_code=code,
+                content=StreamErrorContent(
+                    error=message,
+                    timestamp=timestamp,
+                ),
+            ),
+        ).json()
+    }
+
+
 @llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response)
 async def create_model_endpoint(
     request: CreateLLMModelEndpointV1Request,
@@ -226,42 +259,30 @@ async def create_completion_stream_task(
     logger.info(
         f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}"
     )
-    try:
-        use_case = CompletionStreamV1UseCase(
-            model_endpoint_service=external_interfaces.model_endpoint_service,
-            llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
-        )
-        response = use_case.execute(
-            user=auth, model_endpoint_name=model_endpoint_name, request=request
-        )
+    use_case = CompletionStreamV1UseCase(
+        model_endpoint_service=external_interfaces.model_endpoint_service,
+        llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
+    )
+    response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request)

-        async def event_generator():
-            try:
-                async for message in response:
-                    yield {"data": message.json()}
-            except InvalidRequestException as exc:
-                yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
-                return
+    async def event_generator():
+        try:
+            async for message in response:
+                yield {"data": message.json()}
+        except (InvalidRequestException, ObjectHasInvalidValueException) as exc:
+            yield handle_streaming_exception(exc, 400, str(exc))
+        except (
+            ObjectNotFoundException,
+            ObjectNotAuthorizedException,
+            EndpointUnsupportedInferenceTypeException,
+        ) as exc:
+            yield handle_streaming_exception(exc, 404, str(exc))
+        except Exception as exc:
+            yield handle_streaming_exception(
+                exc, 500, "Internal error occurred. Our team has been notified."
+            )

-        return EventSourceResponse(event_generator())
-    except UpstreamServiceError:
-        request_id = get_request_id()
-        logger.exception(f"Upstream service error for request {request_id}")
-        return EventSourceResponse(
-            iter((CompletionStreamV1Response(request_id=request_id).json(),))  # type: ignore
-        )
-    except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
-        raise HTTPException(
-            status_code=404,
-            detail="The specified endpoint could not be found.",
-        ) from exc
-    except ObjectHasInvalidValueException as exc:
-        raise HTTPException(status_code=400, detail=str(exc))
-    except EndpointUnsupportedInferenceTypeException as exc:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Unsupported inference type: {str(exc)}",
-        ) from exc
+    return EventSourceResponse(event_generator())


 @llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneResponse)
@@ -405,12 +426,12 @@ async def delete_llm_model_endpoint(
             model_endpoint_service=external_interfaces.model_endpoint_service,
         )
         return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name)
-    except (ObjectNotFoundException) as exc:
+    except ObjectNotFoundException as exc:
         raise HTTPException(
             status_code=404,
             detail="The requested model endpoint could not be found.",
         ) from exc
-    except (ObjectNotAuthorizedException) as exc:
+    except ObjectNotAuthorizedException as exc:
         raise HTTPException(
             status_code=403,
             detail="You don't have permission to delete the requested model endpoint.",

model-engine/model_engine_server/common/dtos/llms.py

Lines changed: 20 additions & 0 deletions

@@ -202,13 +202,33 @@ class CompletionStreamOutput(BaseModel):
     token: Optional[TokenOutput] = None


+class StreamErrorContent(BaseModel):
+    error: str
+    """Error message."""
+    timestamp: str
+    """Timestamp of the error."""
+
+
+class StreamError(BaseModel):
+    """
+    Error object for a stream prompt completion task.
+    """
+
+    status_code: int
+    """The HTTP status code of the error."""
+    content: StreamErrorContent
+    """The error content."""
+
+
 class CompletionStreamV1Response(BaseModel):
     """
     Response object for a stream prompt completion task.
     """

     request_id: str
     output: Optional[CompletionStreamOutput] = None
+    error: Optional[StreamError] = None
+    """Error of the response (if any)."""


 class CreateFineTuneRequest(BaseModel):

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 1 addition & 1 deletion

@@ -1308,7 +1308,7 @@ async def execute(
         )

         if len(model_endpoints) == 0:
-            raise ObjectNotFoundException
+            raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.")

         if len(model_endpoints) > 1:
             raise ObjectHasInvalidValueException(
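One reason this message matters: the streaming route above forwards `str(exc)` as the 404 error content, so a bare `ObjectNotFoundException` would have produced an empty message. A tiny sketch, using a stand-in exception class rather than the real domain exception:

```python
# Stand-in for the real domain exception, only to show that the message
# now survives str(exc) and ends up in StreamErrorContent.error.
class ObjectNotFoundException(Exception):
    pass

model_endpoint_name = "my-endpoint"  # illustrative
exc = ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.")

# Before this change the exception was raised with no arguments, so str(exc) was "".
assert str(exc) == "Model endpoint my-endpoint not found."
```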

model-engine/tests/unit/api/test_llms.py

Lines changed: 57 additions & 0 deletions

@@ -113,6 +113,32 @@ def test_completion_sync_success(
     assert response_1.json().keys() == {"output", "request_id"}


+def test_completion_sync_endpoint_not_found_returns_404(
+    llm_model_endpoint_sync: Tuple[ModelEndpoint, Any],
+    completion_sync_request: Dict[str, Any],
+    get_test_client_wrapper,
+):
+    client = get_test_client_wrapper(
+        fake_docker_repository_image_always_exists=True,
+        fake_model_bundle_repository_contents={},
+        fake_model_endpoint_record_repository_contents={},
+        fake_model_endpoint_infra_gateway_contents={
+            llm_model_endpoint_sync[0]
+            .infra_state.deployment_name: llm_model_endpoint_sync[0]
+            .infra_state,
+        },
+        fake_batch_job_record_repository_contents={},
+        fake_batch_job_progress_gateway_contents={},
+        fake_docker_image_batch_job_bundle_repository_contents={},
+    )
+    response_1 = client.post(
+        f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}",
+        auth=("no_user", ""),
+        json=completion_sync_request,
+    )
+    assert response_1.status_code == 404
+
+
 @pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
 def test_completion_stream_success(
     llm_model_endpoint_streaming: ModelEndpoint,
@@ -136,6 +162,7 @@ def test_completion_stream_success(
         f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
         auth=("no_user", ""),
         json=completion_stream_request,
+        stream=True,
     )
     assert response_1.status_code == 200
     count = 0
@@ -146,3 +173,33 @@
         )
         count += 1
     assert count == 1
+
+
+@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
+def test_completion_stream_endpoint_not_found_returns_404(
+    llm_model_endpoint_streaming: ModelEndpoint,
+    completion_stream_request: Dict[str, Any],
+    get_test_client_wrapper,
+):
+    client = get_test_client_wrapper(
+        fake_docker_repository_image_always_exists=True,
+        fake_model_bundle_repository_contents={},
+        fake_model_endpoint_record_repository_contents={},
+        fake_model_endpoint_infra_gateway_contents={
+            llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state,
+        },
+        fake_batch_job_record_repository_contents={},
+        fake_batch_job_progress_gateway_contents={},
+        fake_docker_image_batch_job_bundle_repository_contents={},
+    )
+    response_1 = client.post(
+        f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
+        auth=("no_user", ""),
+        json=completion_stream_request,
+        stream=True,
+    )
+
+    assert response_1.status_code == 200
+
+    for message in response_1:
+        assert "404" in message.decode("utf-8")

model-engine/tests/unit/api/test_tasks.py

Lines changed: 1 addition & 0 deletions

@@ -364,6 +364,7 @@ def test_create_streaming_task_success(
         f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}",
         auth=(test_api_key, ""),
         json=endpoint_predict_request_1[1],
+        stream=True,
    )
     assert response.status_code == 200
     count = 0
