|
1 | 1 | """LLM Model Endpoint routes for the hosted model inference service. |
2 | 2 | """ |
| 3 | +import traceback |
| 4 | +from datetime import datetime |
3 | 5 | from typing import Optional |
4 | 6 |
|
| 7 | +import pytz |
5 | 8 | from fastapi import APIRouter, Depends, HTTPException, Query |
6 | 9 | from model_engine_server.api.dependencies import ( |
7 | 10 | ExternalInterfaces, |
|
28 | 31 | ListLLMModelEndpointsV1Response, |
29 | 32 | ModelDownloadRequest, |
30 | 33 | ModelDownloadResponse, |
| 34 | + StreamError, |
| 35 | + StreamErrorContent, |
31 | 36 | ) |
32 | 37 | from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy |
33 | 38 | from model_engine_server.core.auth.authentication_repository import User |
|
71 | 76 | logger = make_logger(filename_wo_ext(__name__)) |
72 | 77 |
|
73 | 78 |
|
| 79 | +def handle_streaming_exception( |
| 80 | + e: Exception, |
| 81 | + code: int, |
| 82 | + message: str, |
| 83 | +): |
+    tb_str = traceback.format_exception(type(e), e, e.__traceback__)  # etype kwarg removed in 3.10
+    request_id = get_request_id()
+    timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
+    structured_log = {
+        "error": message,
+        "request_id": str(request_id),
+        "traceback": "".join(tb_str),
+    }
+    logger.error("Exception: %s", structured_log)
+    return {
+        "data": CompletionStreamV1Response(
+            request_id=str(request_id),
+            error=StreamError(
+                status_code=code,
+                content=StreamErrorContent(
+                    error=message,
+                    timestamp=timestamp,
+                ),
+            ),
+        ).json()
+    }
+
+
 @llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response)
 async def create_model_endpoint(
     request: CreateLLMModelEndpointV1Request,
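`handle_streaming_exception` logs the traceback under the request ID, then returns one final SSE frame whose `data` payload is a serialized `CompletionStreamV1Response` carrying a `StreamError`. A rough sketch of what lands on the wire, assuming the response schema contains only the fields used above (the request ID and timestamp values are placeholders):

```python
import json

# Approximate wire format of one error frame; sse-starlette emits the
# "data" value as a `data: ...` line on the event stream.
error_frame = {
    "data": json.dumps(
        {
            "request_id": "req-123",  # placeholder request ID
            "error": {
                "status_code": 404,
                "content": {
                    "error": "The specified endpoint could not be found.",
                    "timestamp": "2023-08-01 12:00:00 PDT",  # US/Pacific, as above
                },
            },
        }
    )
}
print(f"data: {error_frame['data']}")  # what the client actually receives
```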
@@ -226,42 +259,30 @@ async def create_completion_stream_task
     logger.info(
         f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}"
     )
-    try:
-        use_case = CompletionStreamV1UseCase(
-            model_endpoint_service=external_interfaces.model_endpoint_service,
-            llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
-        )
-        response = use_case.execute(
-            user=auth, model_endpoint_name=model_endpoint_name, request=request
-        )
+    use_case = CompletionStreamV1UseCase(
+        model_endpoint_service=external_interfaces.model_endpoint_service,
+        llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
+    )
+    response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request)
 
-        async def event_generator():
-            try:
-                async for message in response:
-                    yield {"data": message.json()}
-            except InvalidRequestException as exc:
-                yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
-                return
+    async def event_generator():
+        try:
+            async for message in response:
+                yield {"data": message.json()}
+        except (InvalidRequestException, ObjectHasInvalidValueException) as exc:
+            yield handle_streaming_exception(exc, 400, str(exc))
+        except (
+            ObjectNotFoundException,
+            ObjectNotAuthorizedException,
+            EndpointUnsupportedInferenceTypeException,
+        ) as exc:
+            yield handle_streaming_exception(exc, 404, str(exc))
+        except Exception as exc:
+            yield handle_streaming_exception(
+                exc, 500, "Internal error occurred. Our team has been notified."
+            )
 
-        return EventSourceResponse(event_generator())
-    except UpstreamServiceError:
-        request_id = get_request_id()
-        logger.exception(f"Upstream service error for request {request_id}")
-        return EventSourceResponse(
-            iter((CompletionStreamV1Response(request_id=request_id).json(),))  # type: ignore
-        )
-    except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
-        raise HTTPException(
-            status_code=404,
-            detail="The specified endpoint could not be found.",
-        ) from exc
-    except ObjectHasInvalidValueException as exc:
-        raise HTTPException(status_code=400, detail=str(exc))
-    except EndpointUnsupportedInferenceTypeException as exc:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Unsupported inference type: {str(exc)}",
-        ) from exc
+    return EventSourceResponse(event_generator())
 
 
 @llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneResponse)
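Because errors now arrive in-band (the HTTP response is already streaming with status 200 when they occur), clients must inspect each frame's `error` field rather than rely on the status code. A minimal consumer sketch; the host, route path, query parameter, and request body are assumptions, not confirmed by this diff:

```python
import json

import httpx

# Placeholder URL, query parameter, and body -- adjust to your deployment.
with httpx.stream(
    "POST",
    "http://localhost:5000/v1/llm/completions-stream",
    params={"model_endpoint_name": "llama-2-7b"},
    json={"prompt": "Hello", "max_new_tokens": 32},
    timeout=60.0,
) as response:
    for line in response.iter_lines():
        if not line.startswith("data:"):
            continue  # skip SSE comments and keep-alives
        payload = json.loads(line[len("data:"):])
        if payload.get("error"):
            err = payload["error"]
            raise RuntimeError(f"stream error {err['status_code']}: {err['content']['error']}")
        print(payload)  # a normal completion chunk
```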
@@ -405,12 +426,12 @@ async def delete_llm_model_endpoint
             model_endpoint_service=external_interfaces.model_endpoint_service,
         )
         return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name)
-    except (ObjectNotFoundException) as exc:
+    except ObjectNotFoundException as exc:
         raise HTTPException(
             status_code=404,
             detail="The requested model endpoint could not be found.",
         ) from exc
-    except (ObjectNotAuthorizedException) as exc:
+    except ObjectNotAuthorizedException as exc:
         raise HTTPException(
             status_code=403,
             detail="You don't have permission to delete the requested model endpoint.",
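This last hunk is purely cosmetic: parentheses around a single exception class in an `except` clause are redundant, since `except (X)` and `except X` match exactly the same thing. Parentheses only change behavior when they form a tuple of several classes:

```python
# All three handlers catch KeyError; only the tuple form widens the match.
try:
    {}["missing"]
except KeyError:  # idiomatic
    pass

try:
    {}["missing"]
except (KeyError):  # legal but noisy -- identical to the above
    pass

try:
    {}["missing"]
except (KeyError, IndexError):  # tuple form: catches either exception
    pass
```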
|