
Commit d9e9623

Modify v1 completions_stream logic to raise most exceptions before async streaming inference response (#534)
* consolidate streaming response logic into separate inline function. call execute() synchronously and call inline function async
* iterate
* refactor: pull inference result status/empty check outside of framework conditionals to dedupe code. put logic for unsuccessful/empty results before other handling logic for readability. add some commenting and other small edits.
* formatting fixes
* improve commenting
* fix and reenable 404 unit test
* fix stream success unit test, add async test client fixture
* move _response_chunk_generator() from an inline def in execute() to a separate private method for the usecase
* fix issue with streaming tests interacting by defining a per-session event loop fixture and reconfiguring test_create_streaming_task_success as an async test
* one more unit test
* update llm-engine Completions docs with details on streaming error handling
1 parent e46cbd4 commit d9e9623

File tree

7 files changed, +280 -113 lines changed


docs/guides/completions.md

Lines changed: 7 additions & 2 deletions
@@ -67,7 +67,11 @@ applications. When streaming, tokens will be sent as data-only

 To enable token streaming, pass `stream=True` to either [Completion.create](../../api/python_client/#llmengine.completion.Completion.create) or [Completion.acreate](../../api/python_client/#llmengine.completion.Completion.acreate).

-Note that errors from streaming calls are returned back to the user as plain-text messages and currently need to be handled by the client.
+### Streaming Error Handling
+
+Note: Error handling semantics are mixed for streaming calls:
+- Errors that arise *before* streaming begins are returned back to the user as `HTTP` errors with the appropriate status code.
+- Errors that arise *after* streaming begins within a `HTTP 200` response are returned back to the user as plain-text messages and currently need to be handled by the client.

 An example of token streaming using the synchronous Completions API looks as follows:

@@ -78,6 +82,7 @@ import sys

 from llmengine import Completion

+# errors occurring before streaming begins will be thrown here
 stream = Completion.create(
     model="llama-2-7b",
     prompt="Give me a 200 word summary on the current economic events in the US.",
@@ -90,7 +95,7 @@ for response in stream:
     if response.output:
         print(response.output.text, end="")
         sys.stdout.flush()
-    else: # an error occurred
+    else: # an error occurred after streaming began
         print(response.error) # print the error message out
         break
 ```
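As a quick illustration of the two error classes described in the updated docs, here is a client-side sketch based on the documentation example above. The catch-all `except Exception` is illustrative only; the exact exception type raised for a pre-stream HTTP error depends on the `llmengine` client version.

```python
from llmengine import Completion

try:
    # Per the updated docs, errors that occur before streaming begins
    # (e.g. an unknown model endpoint) surface as HTTP errors here.
    stream = Completion.create(
        model="llama-2-7b",
        prompt="Give me a 200 word summary on the current economic events in the US.",
        stream=True,
    )
    for response in stream:
        if response.output:
            print(response.output.text, end="")
        else:
            # Errors after streaming begins arrive inside the HTTP 200 stream
            # as plain-text messages and must be handled by the client.
            print(response.error)
            break
except Exception as exc:  # illustrative catch-all for pre-stream HTTP errors
    print(f"Request failed before streaming began: {exc}")
```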

model-engine/model_engine_server/api/llms_v1.py

Lines changed: 35 additions & 8 deletions
@@ -405,7 +405,29 @@ async def create_completion_stream_task(
         llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
         tokenizer_repository=external_interfaces.tokenizer_repository,
     )
-    response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request)
+
+    try:
+        # Call execute() with await, since it needs to handle exceptions before we begin streaming the response below.
+        # execute() will create a response chunk generator and return a reference to it.
+        response = await use_case.execute(
+            user=auth, model_endpoint_name=model_endpoint_name, request=request
+        )
+    except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
+        raise HTTPException(
+            status_code=404,
+            detail=str(exc),
+        ) from exc
+    except EndpointUnsupportedInferenceTypeException as exc:
+        raise HTTPException(
+            status_code=400,
+            detail=str(exc),
+        ) from exc
+    except ObjectHasInvalidValueException as exc:
+        raise HTTPException(status_code=400, detail=str(exc)) from exc
+    except Exception as exc:
+        raise HTTPException(
+            status_code=500, detail="Internal error occurred. Our team has been notified."
+        ) from exc

     async def event_generator():
         try:
@@ -427,14 +449,19 @@ async def event_generator():
                 ),
                 metric_metadata,
             )
-        except (InvalidRequestException, ObjectHasInvalidValueException) as exc:
+        # The following two exceptions are only raised after streaming begins, so we wrap the exception within a Response object
+        except InvalidRequestException as exc:
             yield handle_streaming_exception(exc, 400, str(exc))
-        except (
-            ObjectNotFoundException,
-            ObjectNotAuthorizedException,
-            EndpointUnsupportedInferenceTypeException,
-        ) as exc:
-            yield handle_streaming_exception(exc, 404, str(exc))
+        except UpstreamServiceError as exc:
+            request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
+            logger.exception(
+                f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}"
+            )
+            yield handle_streaming_exception(
+                exc,
+                500,
+                f"Upstream service error for request_id {request_id}",
+            )
         except Exception as exc:
             yield handle_streaming_exception(
                 exc, 500, "Internal error occurred. Our team has been notified."
model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 71 additions & 46 deletions
@@ -78,7 +78,10 @@
     ObjectNotFoundException,
     UpstreamServiceError,
 )
-from model_engine_server.domain.gateways import DockerImageBatchJobGateway
+from model_engine_server.domain.gateways import (
+    DockerImageBatchJobGateway,
+    StreamingModelEndpointInferenceGateway,
+)
 from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway
 from model_engine_server.domain.repositories import (
     DockerImageBatchJobBundleRepository,
@@ -1845,18 +1848,27 @@ async def execute(
     ) -> AsyncIterable[CompletionStreamV1Response]:
         """
         Runs the use case to create a stream inference task.
+        NOTE: Must be called with await(), since the function is not a generator itself, but rather creates one and
+        returns a reference to it. This structure allows exceptions that occur before response streaming begins
+        to propagate to the client as HTTP exceptions with the appropriate code.

         Args:
             user: The user who is creating the stream inference task.
             model_endpoint_name: The name of the model endpoint for the task.
             request: The body of the request to forward to the endpoint.

         Returns:
-            A response object that contains the status and result of the task.
+            An asynchronous response chunk generator, containing response objects to be iterated through with 'async for'.
+            Each response object contains the status and result of the task.

         Raises:
             ObjectNotFoundException: If a model endpoint with the given name could not be found.
+            ObjectHasInvalidValueException: If there are multiple model endpoints with the given name.
             ObjectNotAuthorizedException: If the owner does not own the model endpoint.
+            EndpointUnsupportedInferenceTypeException: If the model endpoint does not support streaming or uses
+                an unsupported inference framework.
+            UpstreamServiceError: If an error occurs upstream in the streaming inference API call.
+            InvalidRequestException: If request validation fails during inference.
         """

         request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
@@ -2020,7 +2032,6 @@ async def execute(
                 model_content.model_name,
                 self.tokenizer_repository,
             )
-
         else:
             raise EndpointUnsupportedInferenceTypeException(
                 f"Unsupported inference framework {model_content.inference_framework}"
@@ -2031,15 +2042,55 @@ async def execute(
             num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES,
             timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS,
         )
+
+        return self._response_chunk_generator(
+            request=request,
+            request_id=request_id,
+            model_endpoint=model_endpoint,
+            model_content=model_content,
+            inference_gateway=inference_gateway,
+            inference_request=inference_request,
+            num_prompt_tokens=num_prompt_tokens,
+        )
+
+    async def _response_chunk_generator(
+        self,
+        request: CompletionStreamV1Request,
+        request_id: Optional[str],
+        model_endpoint: ModelEndpoint,
+        model_content: GetLLMModelEndpointV1Response,
+        inference_gateway: StreamingModelEndpointInferenceGateway,
+        inference_request: SyncEndpointPredictV1Request,
+        num_prompt_tokens: Optional[int],
+    ) -> AsyncIterable[CompletionStreamV1Response]:
+        """
+        Async generator yielding tokens to stream for the completions response. Should only be called when
+        returned directly by execute().
+        """
         predict_result = inference_gateway.streaming_predict(
             topic=model_endpoint.record.destination, predict_request=inference_request
         )

         num_completion_tokens = 0
         async for res in predict_result:
-            result = res.result
-            if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED:
-                if res.status == TaskStatus.SUCCESS and result is not None:
+            if not res.status == TaskStatus.SUCCESS or res.result is None:
+                # Raise an UpstreamServiceError if the task has failed
+                if res.status == TaskStatus.FAILURE:
+                    raise UpstreamServiceError(
+                        status_code=500,
+                        content=(
+                            res.traceback.encode("utf-8") if res.traceback is not None else b""
+                        ),
+                    )
+                # Otherwise, yield empty response chunk for unsuccessful or empty results
+                yield CompletionStreamV1Response(
+                    request_id=request_id,
+                    output=None,
+                )
+            else:
+                result = res.result
+                # DEEPSPEED
+                if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED:
                     if "token" in result["result"]:
                         yield CompletionStreamV1Response(
                             request_id=request_id,
@@ -2063,15 +2114,11 @@ async def execute(
                                 num_completion_tokens=completion_token_count,
                             ),
                         )
-                else:
-                    yield CompletionStreamV1Response(
-                        request_id=request_id,
-                        output=None,
-                    )
-            elif (
-                model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE
-            ):
-                if res.status == TaskStatus.SUCCESS and result is not None:
+                # TEXT_GENERATION_INFERENCE
+                elif (
+                    model_content.inference_framework
+                    == LLMInferenceFramework.TEXT_GENERATION_INFERENCE
+                ):
                     if result["result"].get("generated_text") is not None:
                         finished = True
                     else:
@@ -2108,14 +2155,8 @@ async def execute(
                         raise UpstreamServiceError(
                             status_code=500, content=result.get("error")
                         )  # also change llms_v1.py that will return a 500 HTTPException so user can retry
-
-                else:
-                    yield CompletionStreamV1Response(
-                        request_id=request_id,
-                        output=None,
-                    )
-            elif model_content.inference_framework == LLMInferenceFramework.VLLM:
-                if res.status == TaskStatus.SUCCESS and result is not None:
+                # VLLM
+                elif model_content.inference_framework == LLMInferenceFramework.VLLM:
                     token = None
                     if request.return_token_log_probs:
                         token = TokenOutput(
@@ -2134,13 +2175,8 @@ async def execute(
                             token=token,
                         ),
                     )
-                else:
-                    yield CompletionStreamV1Response(
-                        request_id=request_id,
-                        output=None,
-                    )
-            elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
-                if res.status == TaskStatus.SUCCESS and result is not None:
+                # LIGHTLLM
+                elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
                     token = None
                     num_completion_tokens += 1
                     if request.return_token_log_probs:
@@ -2159,13 +2195,8 @@ async def execute(
                             token=token,
                         ),
                     )
-                else:
-                    yield CompletionStreamV1Response(
-                        request_id=request_id,
-                        output=None,
-                    )
-            elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM:
-                if res.status == TaskStatus.SUCCESS and result is not None:
+                # TENSORRT_LLM
+                elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM:
                     num_completion_tokens += 1
                     yield CompletionStreamV1Response(
                         request_id=request_id,
@@ -2176,15 +2207,9 @@ async def execute(
                             num_completion_tokens=num_completion_tokens,
                         ),
                     )
-                else:
-                    yield CompletionStreamV1Response(
-                        request_id=request_id,
-                        output=None,
-                    )
-            else:
-                raise EndpointUnsupportedInferenceTypeException(
-                    f"Unsupported inference framework {model_content.inference_framework}"
-                )
+                # No else clause needed for an unsupported inference framework, since we check
+                # model_content.inference_framework in execute() prior to calling _response_chunk_generator,
+                # raising an exception if it is not one of the frameworks handled above.


 class ModelDownloadV1UseCase:
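The docstring NOTE above hinges on a subtlety of Python async generators: an `async def` that contains `yield` defers its entire body, including validation, until iteration begins, whereas an `async def` that merely returns a generator runs its body at `await` time. The following self-contained sketch (hypothetical names, not model-engine code) shows why splitting `execute()` from `_response_chunk_generator()` lets validation errors surface before streaming starts.

```python
import asyncio
from typing import AsyncIterator


async def execute(valid: bool) -> AsyncIterator[int]:
    if not valid:
        # Raised eagerly, while the caller is still awaiting execute(),
        # i.e. before any streaming response has been started.
        raise ValueError("validation failed before streaming")
    return _chunks()


async def _chunks() -> AsyncIterator[int]:
    # This body only runs once the caller iterates, after streaming began.
    for i in range(3):
        yield i


async def main() -> None:
    try:
        stream = await execute(valid=False)
    except ValueError as exc:
        print(f"caught before streaming: {exc}")  # can become an HTTP error
        return
    async for chunk in stream:
        print(chunk)


asyncio.run(main())
```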

model-engine/tests/unit/api/conftest.py

Lines changed: 80 additions & 0 deletions
@@ -1,10 +1,13 @@
+import asyncio
 import datetime
 from typing import Any, Dict, Iterator, Tuple

 import pytest
+import pytest_asyncio
 from fastapi import Depends, HTTPException
 from fastapi.security import HTTPBasicCredentials
 from fastapi.testclient import TestClient
+from httpx import AsyncClient
 from model_engine_server.api.app import app
 from model_engine_server.api.dependencies import (
     AUTH,
@@ -90,6 +93,14 @@ def fake_auth():
     app.dependency_overrides[verify_authentication] = {}


+@pytest_asyncio.fixture(scope="session", autouse=True)
+def event_loop(request):
+    """Create an instance of the default event loop for each test case."""
+    loop = asyncio.get_event_loop_policy().new_event_loop()
+    yield loop
+    loop.close()
+
+
 @pytest.fixture
 def get_test_client_wrapper(get_repositories_generator_wrapper):
     def get_test_client(
@@ -159,6 +170,75 @@ def get_test_client(
     return get_test_client


+@pytest.fixture
+def get_async_test_client_wrapper(get_repositories_generator_wrapper):
+    def get_async_test_client(
+        fake_docker_repository_image_always_exists=True,
+        fake_model_bundle_repository_contents=None,
+        fake_model_endpoint_record_repository_contents=None,
+        fake_model_endpoint_infra_gateway_contents=None,
+        fake_batch_job_record_repository_contents=None,
+        fake_batch_job_progress_gateway_contents=None,
+        fake_docker_image_batch_job_bundle_repository_contents=None,
+        fake_docker_image_batch_job_gateway_contents=None,
+        fake_llm_fine_tuning_service_contents=None,
+        fake_file_storage_gateway_contents=None,
+        fake_file_system_gateway_contents=None,
+        fake_trigger_repository_contents=None,
+        fake_cron_job_gateway_contents=None,
+        fake_sync_inference_content=None,
+    ) -> AsyncClient:
+        if fake_docker_image_batch_job_gateway_contents is None:
+            fake_docker_image_batch_job_gateway_contents = {}
+        if fake_docker_image_batch_job_bundle_repository_contents is None:
+            fake_docker_image_batch_job_bundle_repository_contents = {}
+        if fake_batch_job_progress_gateway_contents is None:
+            fake_batch_job_progress_gateway_contents = {}
+        if fake_batch_job_record_repository_contents is None:
+            fake_batch_job_record_repository_contents = {}
+        if fake_model_endpoint_infra_gateway_contents is None:
+            fake_model_endpoint_infra_gateway_contents = {}
+        if fake_model_endpoint_record_repository_contents is None:
+            fake_model_endpoint_record_repository_contents = {}
+        if fake_model_bundle_repository_contents is None:
+            fake_model_bundle_repository_contents = {}
+        if fake_llm_fine_tuning_service_contents is None:
+            fake_llm_fine_tuning_service_contents = {}
+        if fake_file_storage_gateway_contents is None:
+            fake_file_storage_gateway_contents = {}
+        if fake_file_system_gateway_contents is None:
+            fake_file_system_gateway_contents = {}
+        if fake_trigger_repository_contents is None:
+            fake_trigger_repository_contents = {}
+        if fake_cron_job_gateway_contents is None:
+            fake_cron_job_gateway_contents = {}
+        if fake_sync_inference_content is None:
+            fake_sync_inference_content = {}
+        app.dependency_overrides[get_external_interfaces] = get_repositories_generator_wrapper(
+            fake_docker_repository_image_always_exists=fake_docker_repository_image_always_exists,
+            fake_model_bundle_repository_contents=fake_model_bundle_repository_contents,
+            fake_model_endpoint_record_repository_contents=fake_model_endpoint_record_repository_contents,
+            fake_model_endpoint_infra_gateway_contents=fake_model_endpoint_infra_gateway_contents,
+            fake_batch_job_record_repository_contents=fake_batch_job_record_repository_contents,
+            fake_batch_job_progress_gateway_contents=fake_batch_job_progress_gateway_contents,
+            fake_docker_image_batch_job_bundle_repository_contents=fake_docker_image_batch_job_bundle_repository_contents,
+            fake_docker_image_batch_job_gateway_contents=fake_docker_image_batch_job_gateway_contents,
+            fake_llm_fine_tuning_service_contents=fake_llm_fine_tuning_service_contents,
+            fake_file_storage_gateway_contents=fake_file_storage_gateway_contents,
+            fake_file_system_gateway_contents=fake_file_system_gateway_contents,
+            fake_trigger_repository_contents=fake_trigger_repository_contents,
+            fake_cron_job_gateway_contents=fake_cron_job_gateway_contents,
+            fake_sync_inference_content=fake_sync_inference_content,
+        )
+        app.dependency_overrides[get_external_interfaces_read_only] = app.dependency_overrides[
+            get_external_interfaces
+        ]
+        client = AsyncClient(app=app, base_url="http://test")
+        return client
+
+    return get_async_test_client
+
+
 @pytest.fixture
 def simple_client(get_test_client_wrapper) -> TestClient:
     """Returns a Client with no initial contents and a Docker repository that always returns True"""
