Skip to content

Commit dd4a0f9

Browse files
authored
Fix integration tests for streaming case (#548)
1 parent 0ff1824 commit dd4a0f9

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

integration_tests/rest_api_utils.py

Lines changed: 23 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -168,7 +168,7 @@ def my_model(**keyword_args):
168168

169169
CREATE_LLM_MODEL_ENDPOINT_REQUEST: Dict[str, Any] = {
170170
"name": format_name("llama-2-7b-test"),
171-
"model_name": "llama-2-7b",
171+
"model_name": "llama-2-7b-chat",
172172
"source": "hugging_face",
173173
"inference_framework": "vllm",
174174
"inference_framework_image_tag": "latest",
@@ -802,7 +802,7 @@ async def create_llm_streaming_task(
802802
timeout=LONG_NETWORK_TIMEOUT_SEC,
803803
) as response:
804804
assert response.status == 200, (await response.read()).decode()
805-
return await response.json()
805+
return (await response.read()).decode()
806806

807807

808808
async def create_sync_tasks(
@@ -987,6 +987,27 @@ def ensure_llm_task_response_is_correct(
987987
assert re.search(response_text_regex, response["output"]["text"])
988988

989989

990+
def ensure_llm_task_stream_response_is_correct(
    response: str,
    required_output_fields: Optional[List[str]],
    response_text_regex: Optional[str],
):
    """Validate the raw body of a streaming LLM completion response.

    The body is expected to look like ``"data: <json>\\n\\ndata: <json>\\n\\n"``.
    Every line that starts with ``data:`` is decoded as JSON, the
    ``["output"]["text"]`` pieces of all events are concatenated in order, and
    the result is matched against ``response_text_regex``.

    Args:
        response: The full, already-decoded response body.
        required_output_fields: Accepted so the signature lines up with the
            arguments callers pass; it is currently NOT checked here.
            NOTE(review): confirm whether streamed events should be validated
            against these fields.
        response_text_regex: Pattern searched for (``re.search``) in the
            joined text. ``None`` skips the check.

    Raises:
        AssertionError: If the regex is given and does not match the text.
        json.JSONDecodeError: If a ``data:`` line does not carry valid JSON.
        KeyError: If an event lacks ``["output"]["text"]``.
    """
    data_prefix = "data:"

    events = []
    # Split on "\n" only: str.splitlines() would also break on characters
    # (e.g. U+2028) that may legitimately appear raw inside a JSON string.
    for raw_line in response.split("\n"):
        line = raw_line.strip()
        if not line.startswith(data_prefix):
            continue
        # Strip only the leading marker. Splitting on "data: " would cut the
        # payload short whenever the generated text itself contains "data: ".
        payload = line[len(data_prefix):].strip()
        if payload:
            events.append(json.loads(payload))

    # Join the text field of every streamed event.
    response_text = "".join(event["output"]["text"] for event in events)
    print("response text: ", response_text)

    if response_text_regex is not None:
        assert re.search(response_text_regex, response_text)
1009+
1010+
9901011
# Wait up to 30 seconds for the tasks to be returned.
9911012
@retry(
9921013
stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError)

integration_tests/test_completions.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
delete_llm_model_endpoint,
1414
ensure_launch_gateway_healthy,
1515
ensure_llm_task_response_is_correct,
16+
ensure_llm_task_stream_response_is_correct,
1617
ensure_n_ready_private_llm_endpoints_short,
1718
ensure_nonzero_available_llm_workers,
1819
)
@@ -86,7 +87,7 @@ def test_completions(capsys):
8687
)
8788
)
8889
for response in task_responses:
89-
ensure_llm_task_response_is_correct(
90+
ensure_llm_task_stream_response_is_correct(
9091
response, required_output_fields, response_text_regex
9192
)
9293
except Exception as e:

0 commit comments

Comments
 (0)