11import json
2- import re
32from typing import Any , Dict , Tuple
3+ from unittest import mock
44
55import pytest
66from model_engine_server .common .dtos .llms import GetLLMModelEndpointV1Response
@@ -156,12 +156,14 @@ def test_completion_sync_endpoint_not_found_returns_404(
156156 assert response_1 .status_code == 404
157157
158158
159+ # When enabling this test, other tests fail with "RunTumeError got Future <Future pending> attached to a different loop"
160+ # https://github.com/encode/starlette/issues/1315#issuecomment-980784457
159161@pytest .mark .skip (reason = "Need to figure out FastAPI test client asyncio funkiness" )
160162def test_completion_stream_success (
161163 llm_model_endpoint_streaming : ModelEndpoint ,
162164 completion_stream_request : Dict [str , Any ],
163165 get_test_client_wrapper ,
164- ):
166+ ): # pragma: no cover
165167 client = get_test_client_wrapper (
166168 fake_docker_repository_image_always_exists = True ,
167169 fake_model_bundle_repository_contents = {},
@@ -175,19 +177,28 @@ def test_completion_stream_success(
175177 fake_batch_job_progress_gateway_contents = {},
176178 fake_docker_image_batch_job_bundle_repository_contents = {},
177179 )
178- response_1 = client .post (
179- f"/v1/llm/completions-stream?model_endpoint_name={ llm_model_endpoint_streaming .record .name } " ,
180- auth = ("no_user" , "" ),
181- json = completion_stream_request ,
182- stream = True ,
183- )
180+ with mock .patch (
181+ "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens" ,
182+ return_value = 5 ,
183+ ):
184+ response_1 = client .post (
185+ f"/v1/llm/completions-stream?model_endpoint_name={ llm_model_endpoint_streaming .record .name } " ,
186+ auth = ("no_user" , "" ),
187+ json = completion_stream_request ,
188+ stream = True ,
189+ )
184190 assert response_1 .status_code == 200
185191 count = 0
186192 for message in response_1 :
187- assert re .fullmatch (
188- 'data: {"request_id"}: ".*", "output": null}\r \n \r \n ' ,
189- message .decode ("utf-8" ),
190- )
193+ decoded_message = message .decode ("utf-8" )
194+ assert decoded_message .startswith ("data: " ), "SSE does not start with 'data: '"
195+
196+ # strip 'data: ' prefix from Server-sent events format
197+ json_str = decoded_message [len ("data: " ) :]
198+ parsed_data = json .loads (json_str .strip ())
199+ assert parsed_data ["request_id" ] is not None
200+ assert parsed_data ["output" ] is None
201+ assert parsed_data ["error" ] is None
191202 count += 1
192203 assert count == 1
193204
0 commit comments