
Commit 2f5dd72

Return TGI errors (#313)
* Return TGI errors
* remove prints
* fix lint
1 parent 86003eb commit 2f5dd72

3 files changed (+199, −13 lines)


model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 21 additions & 10 deletions
@@ -909,7 +909,10 @@ def validate_and_update_completion_params(
         if inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE:
             request.top_k = None if request.top_k == -1 else request.top_k
             request.top_p = None if request.top_p == 1.0 else request.top_p
-        if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]:
+        if inference_framework in [
+            LLMInferenceFramework.VLLM,
+            LLMInferenceFramework.LIGHTLLM,
+        ]:
             request.top_k = -1 if request.top_k is None else request.top_k
             request.top_p = 1.0 if request.top_p is None else request.top_p
         else:
@@ -919,7 +922,10 @@ def validate_and_update_completion_params(
             )
 
     # presence_penalty, frequency_penalty
-    if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]:
+    if inference_framework in [
+        LLMInferenceFramework.VLLM,
+        LLMInferenceFramework.LIGHTLLM,
+    ]:
         request.presence_penalty = (
             0.0 if request.presence_penalty is None else request.presence_penalty
         )
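Note: the two hunks above only reformat the framework checks; the behavior they preserve is that TGI treats the -1 / 1.0 sentinels as "unset" while vLLM and LightLLM expect explicit defaults. A minimal, self-contained sketch of that normalization (hypothetical function name, plain strings standing in for the LLMInferenceFramework enum):

    # Illustrative only -- the real code mutates the request object inside
    # validate_and_update_completion_params and compares against the enum.
    def normalize_sampling_params(framework: str, top_k, top_p):
        if framework == "text_generation_inference":
            # TGI expects "unset" to be None rather than the -1 / 1.0 sentinels.
            top_k = None if top_k == -1 else top_k
            top_p = None if top_p == 1.0 else top_p
        elif framework in ("vllm", "lightllm"):
            # vLLM and LightLLM expect explicit defaults instead of None.
            top_k = -1 if top_k is None else top_k
            top_p = 1.0 if top_p is None else top_p
        return top_k, top_p

    assert normalize_sampling_params("text_generation_inference", -1, 1.0) == (None, None)
    assert normalize_sampling_params("vllm", None, None) == (-1, 1.0)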
@@ -987,14 +993,17 @@ def model_output_to_completion_output(
                 raise InvalidRequestException(model_output.get("error"))  # trigger a 400
             else:
                 raise UpstreamServiceError(
-                    status_code=500, content=bytes(model_output["error"])
+                    status_code=500, content=bytes(model_output["error"], "utf-8")
                 )
 
         elif model_content.inference_framework == LLMInferenceFramework.VLLM:
             tokens = None
             if with_token_probs:
                 tokens = [
-                    TokenOutput(token=model_output["tokens"][index], log_prob=list(t.values())[0])
+                    TokenOutput(
+                        token=model_output["tokens"][index],
+                        log_prob=list(t.values())[0],
+                    )
                     for index, t in enumerate(model_output["log_probs"])
                 ]
             return CompletionOutput(
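The substantive fix in this hunk is the bytes() call: calling bytes() on a str without an encoding raises a TypeError, so the TGI error text was lost before it could be returned to the caller. A small illustration of the old versus new call:

    # Sketch of the behavior behind the one-line fix above.
    error_message = "Request failed during generation: Server error: transport error"

    try:
        bytes(error_message)  # old call: str with no encoding
    except TypeError as exc:
        print(exc)  # "string argument without an encoding"

    content = bytes(error_message, "utf-8")  # fixed call
    assert content == error_message.encode("utf-8")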
@@ -1003,7 +1012,6 @@ def model_output_to_completion_output(
                 tokens=tokens,
             )
         elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
-            print(model_output)
             tokens = None
             if with_token_probs:
                 tokens = [
@@ -1109,7 +1117,8 @@ async def execute(
                 timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS,
             )
             predict_result = await inference_gateway.predict(
-                topic=model_endpoint.record.destination, predict_request=inference_request
+                topic=model_endpoint.record.destination,
+                predict_request=inference_request,
             )
 
             if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None:
@@ -1152,7 +1161,8 @@ async def execute(
                 timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS,
             )
             predict_result = await inference_gateway.predict(
-                topic=model_endpoint.record.destination, predict_request=inference_request
+                topic=model_endpoint.record.destination,
+                predict_request=inference_request,
             )
 
             if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
@@ -1191,7 +1201,8 @@ async def execute(
                 timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS,
             )
             predict_result = await inference_gateway.predict(
-                topic=model_endpoint.record.destination, predict_request=inference_request
+                topic=model_endpoint.record.destination,
+                predict_request=inference_request,
             )
 
             if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
@@ -1233,7 +1244,8 @@ async def execute(
                 timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS,
             )
             predict_result = await inference_gateway.predict(
-                topic=model_endpoint.record.destination, predict_request=inference_request
+                topic=model_endpoint.record.destination,
+                predict_request=inference_request,
             )
 
             if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
@@ -1517,7 +1529,6 @@ async def execute(
             )
         elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
             if res.status == TaskStatus.SUCCESS and result is not None:
-                print(result)
                 token = None
                 num_completion_tokens += 1
                 if request.return_token_log_probs:

model-engine/tests/unit/conftest.py

Lines changed: 132 additions & 0 deletions
@@ -3672,6 +3672,138 @@ def llm_model_endpoint_sync(
     return model_endpoint, model_endpoint_json
 
 
+@pytest.fixture
+def llm_model_endpoint_sync_tgi(
+    test_api_key: str, model_bundle_1: ModelBundle
+) -> Tuple[ModelEndpoint, Any]:
+    model_endpoint = ModelEndpoint(
+        record=ModelEndpointRecord(
+            id="test_llm_model_endpoint_id_2",
+            name="test_llm_model_endpoint_name_1",
+            created_by=test_api_key,
+            created_at=datetime(2022, 1, 3),
+            last_updated_at=datetime(2022, 1, 3),
+            metadata={
+                "_llm": {
+                    "model_name": "llama-7b",
+                    "source": "hugging_face",
+                    "inference_framework": "text_generation_inference",
+                    "inference_framework_image_tag": "123",
+                    "num_shards": 4,
+                }
+            },
+            creation_task_id="test_creation_task_id",
+            endpoint_type=ModelEndpointType.SYNC,
+            destination="test_destination",
+            status=ModelEndpointStatus.READY,
+            current_model_bundle=model_bundle_1,
+            owner=test_api_key,
+            public_inference=True,
+        ),
+        infra_state=ModelEndpointInfraState(
+            deployment_name=f"{test_api_key}-test_llm_model_endpoint_name_1",
+            aws_role="test_aws_role",
+            results_s3_bucket="test_s3_bucket",
+            child_fn_info=None,
+            labels={},
+            prewarm=True,
+            high_priority=False,
+            deployment_state=ModelEndpointDeploymentState(
+                min_workers=1,
+                max_workers=3,
+                per_worker=2,
+                available_workers=1,
+                unavailable_workers=1,
+            ),
+            resource_state=ModelEndpointResourceState(
+                cpus=1,
+                gpus=1,
+                memory="1G",
+                gpu_type=GpuType.NVIDIA_TESLA_T4,
+                storage="10G",
+                optimize_costs=True,
+            ),
+            user_config_state=ModelEndpointUserConfigState(
+                app_config=model_bundle_1.app_config,
+                endpoint_config=ModelEndpointConfig(
+                    bundle_name=model_bundle_1.name,
+                    endpoint_name="test_llm_model_endpoint_name_1",
+                    post_inference_hooks=["callback"],
+                    default_callback_url="http://www.example.com",
+                    default_callback_auth=CallbackAuth(
+                        __root__=CallbackBasicAuth(
+                            kind="basic",
+                            username="test_username",
+                            password="test_password",
+                        ),
+                    ),
+                ),
+            ),
+            num_queued_items=1,
+            image="test_image",
+        ),
+    )
+    model_endpoint_json: Dict[str, Any] = {
+        "id": "test_llm_model_endpoint_id_2",
+        "name": "test_llm_model_endpoint_name_1",
+        "model_name": "llama-7b",
+        "source": "hugging_face",
+        "status": "READY",
+        "inference_framework": "text_generation_inference",
+        "inference_framework_image_tag": "123",
+        "num_shards": 4,
+        "spec": {
+            "id": "test_llm_model_endpoint_id_2",
+            "name": "test_llm_model_endpoint_name_1",
+            "endpoint_type": "sync",
+            "destination": "test_destination",
+            "deployment_name": f"{test_api_key}-test_llm_model_endpoint_name_1",
+            "metadata": {
+                "_llm": {
+                    "model_name": "llama-7b",
+                    "source": "hugging_face",
+                    "inference_framework": "text_generation_inference",
+                    "inference_framework_image_tag": "123",
+                    "num_shards": 4,
+                }
+            },
+            "bundle_name": "test_model_bundle_name_1",
+            "status": "READY",
+            "post_inference_hooks": ["callback"],
+            "default_callback_url": "http://www.example.com",
+            "default_callback_auth": {
+                "kind": "basic",
+                "username": "test_username",
+                "password": "test_password",
+            },
+            "labels": {},
+            "aws_role": "test_aws_role",
+            "results_s3_bucket": "test_s3_bucket",
+            "created_by": test_api_key,
+            "created_at": "2022-01-03T00:00:00",
+            "last_updated_at": "2022-01-03T00:00:00",
+            "deployment_state": {
+                "min_workers": 1,
+                "max_workers": 3,
+                "per_worker": 2,
+                "available_workers": 1,
+                "unavailable_workers": 1,
+            },
+            "resource_state": {
+                "cpus": "1",
+                "gpus": 1,
+                "memory": "1G",
+                "gpu_type": "nvidia-tesla-t4",
+                "storage": "10G",
+                "optimize_costs": True,
+            },
+            "num_queued_items": 1,
+            "public_inference": True,
+        },
+    }
+    return model_endpoint, model_endpoint_json
+
+
 @pytest.fixture
 def llm_model_endpoint_streaming(test_api_key: str, model_bundle_5: ModelBundle) -> ModelEndpoint:
     # model_bundle_5 is a runnable bundle

model-engine/tests/unit/domain/test_llm_use_cases.py

Lines changed: 46 additions & 3 deletions
@@ -22,6 +22,7 @@
     ObjectHasInvalidValueException,
     ObjectNotAuthorizedException,
     ObjectNotFoundException,
+    UpstreamServiceError,
 )
 from model_engine_server.domain.use_cases.llm_fine_tuning_use_cases import (
     MAX_LLM_ENDPOINTS_PER_INTERNAL_USER,
@@ -171,7 +172,8 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success(
     )
     user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
     response_1 = await use_case.execute(
-        user=user, request=create_llm_model_endpoint_text_generation_inference_request_streaming
+        user=user,
+        request=create_llm_model_endpoint_text_generation_inference_request_streaming,
     )
     assert response_1.endpoint_creation_task_id
     assert isinstance(response_1, CreateLLMModelEndpointV1Response)
@@ -196,7 +198,8 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success(
 
     with pytest.raises(ObjectHasInvalidValueException):
         await use_case.execute(
-            user=user, request=create_llm_model_endpoint_text_generation_inference_request_async
+            user=user,
+            request=create_llm_model_endpoint_text_generation_inference_request_async,
         )
 
 
@@ -483,6 +486,40 @@ async def test_completion_sync_use_case_predict_failed(
     assert response_1.output is None
 
 
+@pytest.mark.asyncio
+async def test_completion_sync_use_case_predict_failed_with_errors(
+    test_api_key: str,
+    fake_model_endpoint_service,
+    fake_llm_model_endpoint_service,
+    llm_model_endpoint_sync_tgi: Tuple[ModelEndpoint, Any],
+    completion_sync_request: CompletionSyncV1Request,
+):
+    fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_tgi[0])
+    fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response(
+        status=TaskStatus.SUCCESS,
+        result={
+            "result": """
+    {
+        "error": "Request failed during generation: Server error: transport error",
+        "error_type": "generation"
+    }
+    """
+        },
+        traceback="failed to predict",
+    )
+    use_case = CompletionSyncV1UseCase(
+        model_endpoint_service=fake_model_endpoint_service,
+        llm_model_endpoint_service=fake_llm_model_endpoint_service,
+    )
+    user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
+    with pytest.raises(UpstreamServiceError):
+        await use_case.execute(
+            user=user,
+            model_endpoint_name=llm_model_endpoint_sync_tgi[0].record.name,
+            request=completion_sync_request,
+        )
+
+
 @pytest.mark.asyncio
 async def test_completion_sync_use_case_not_sync_endpoint_raises(
     test_api_key: str,
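For reference, a rough sketch of the TGI-style payload the new test above feeds through the fake sync gateway, and the "error" / "error_type" fields the completion use case keys off before raising UpstreamServiceError (illustrative parsing only, not the production code):

    import json

    # The mocked gateway wraps this JSON string under result["result"]; because it
    # carries an "error" key, the TGI branch raises UpstreamServiceError (status 500)
    # instead of building a CompletionOutput.
    raw_result = """
    {
        "error": "Request failed during generation: Server error: transport error",
        "error_type": "generation"
    }
    """

    model_output = json.loads(raw_result)
    assert model_output.get("error") is not None
    assert model_output["error_type"] == "generation"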
@@ -964,7 +1001,13 @@ async def test_delete_public_inference_model_raises_not_authorized(
 
 @pytest.mark.asyncio
 async def test_exclude_safetensors_or_bin_majority_bin_returns_exclude_safetensors():
-    fake_model_files = ["fake.bin", "fake2.bin", "fake3.safetensors", "model.json", "optimizer.pt"]
+    fake_model_files = [
+        "fake.bin",
+        "fake2.bin",
+        "fake3.safetensors",
+        "model.json",
+        "optimizer.pt",
+    ]
     assert _exclude_safetensors_or_bin(fake_model_files) == "*.safetensors"
 
 