
Commit 70d0e77

integration tests for completions (#507)
* get latest inference framework tag from configmap
* comments
* fix for test
* make namespace a config
* fix s3 prefix bug
* fix checkpoint path fn + tests
* integration tests for completions
* values change
* quotes
1 parent a2bf698 commit 70d0e77

2 files changed: +337, -1 lines changed

integration_tests/rest_api_utils.py

Lines changed: 241 additions & 1 deletion
@@ -2,8 +2,9 @@
 import inspect
 import json
 import os
+import re
 import time
-from typing import Any, Dict, List, Sequence
+from typing import Any, Dict, List, Optional, Sequence
 
 import aiohttp
 import requests
@@ -14,6 +15,7 @@
 BASE_PATH = os.environ.get("BASE_PATH", _DEFAULT_BASE_PATH)
 print(f"Integration tests using gateway {BASE_PATH=}")
 DEFAULT_NETWORK_TIMEOUT_SEC = 10
+LONG_NETWORK_TIMEOUT_SEC = 30
 
 # add suffix to avoid name collisions
 SERVICE_IDENTIFIER = os.environ.get("SERVICE_IDENTIFIER", "")
@@ -164,12 +166,87 @@ def my_model(**keyword_args):
     "url": None,
 }
 
+CREATE_LLM_MODEL_ENDPOINT_REQUEST: Dict[str, Any] = {
+    "name": format_name("llama-2-7b-test"),
+    "model_name": "llama-2-7b",
+    "source": "hugging_face",
+    "inference_framework": "vllm",
+    "inference_framework_image_tag": "latest",
+    "endpoint_type": "streaming",
+    "cpus": 20,
+    "gpus": 1,
+    "memory": "20Gi",
+    "gpu_type": "nvidia-ampere-a10",
+    "storage": "40Gi",
+    "optimize_costs": False,
+    "min_workers": 1,
+    "max_workers": 1,
+    "per_worker": 1,
+    "labels": {"team": "infra", "product": "launch"},
+    "metadata": {"key": "value"},
+    "public_inference": False,
+}
+
+
 INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE: Dict[str, Any] = INFERENCE_PAYLOAD.copy()
 INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE["return_pickled"] = False
 
 INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE: Dict[str, Any] = INFERENCE_PAYLOAD.copy()
 INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE["return_pickled"] = True
 
+LLM_PAYLOAD: Dict[str, Any] = {
+    "prompt": "Hello, my name is",
+    "max_new_tokens": 10,
+    "temperature": 0.2,
+}
+
+LLM_PAYLOAD_WITH_STOP_SEQUENCE: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_STOP_SEQUENCE["stop_sequences"] = ["\n"]
+
+LLM_PAYLOAD_WITH_PRESENCE_PENALTY: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_PRESENCE_PENALTY["presence_penalty"] = 0.5
+
+LLM_PAYLOAD_WITH_FREQUENCY_PENALTY: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_FREQUENCY_PENALTY["frequency_penalty"] = 0.5
+
+LLM_PAYLOAD_WITH_TOP_K: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_TOP_K["top_k"] = 10
+
+LLM_PAYLOAD_WITH_TOP_P: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_TOP_P["top_p"] = 0.5
+
+LLM_PAYLOAD_WITH_INCLUDE_STOP_STR_IN_OUTPUT: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_INCLUDE_STOP_STR_IN_OUTPUT["include_stop_str_in_output"] = True
+
+LLM_PAYLOAD_WITH_GUIDED_JSON: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_GUIDED_JSON["guided_json"] = {
+    "properties": {"myString": {"type": "string"}},
+    "required": ["myString"],
+}
+
+LLM_PAYLOAD_WITH_GUIDED_REGEX: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_GUIDED_REGEX["guided_regex"] = "Sean.*"
+
+LLM_PAYLOAD_WITH_GUIDED_CHOICE: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_GUIDED_CHOICE["guided_choice"] = ["dog", "cat"]
+
+LLM_PAYLOAD_WITH_GUIDED_GRAMMAR: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_GUIDED_GRAMMAR["guided_grammar"] = 'start: "John"'
+
+LLM_PAYLOADS_WITH_EXPECTED_RESPONSES = [
+    (LLM_PAYLOAD, None, None),
+    (LLM_PAYLOAD_WITH_STOP_SEQUENCE, None, None),
+    (LLM_PAYLOAD_WITH_PRESENCE_PENALTY, None, None),
+    (LLM_PAYLOAD_WITH_FREQUENCY_PENALTY, None, None),
+    (LLM_PAYLOAD_WITH_TOP_K, None, None),
+    (LLM_PAYLOAD_WITH_TOP_P, None, None),
+    (LLM_PAYLOAD_WITH_INCLUDE_STOP_STR_IN_OUTPUT, ["tokens"], None),
+    (LLM_PAYLOAD_WITH_GUIDED_JSON, None, None),
+    (LLM_PAYLOAD_WITH_GUIDED_REGEX, None, "Sean.*"),
+    (LLM_PAYLOAD_WITH_GUIDED_CHOICE, None, "dog|cat"),
+    (LLM_PAYLOAD_WITH_GUIDED_GRAMMAR, None, "John"),
+]
+
 CREATE_BATCH_JOB_REQUEST: Dict[str, Any] = {
     "bundle_name": "model_bundle_simple",
     "input_path": "TBA",
@@ -524,6 +601,18 @@ def get_model_endpoint(name: str, user_id: str) -> Dict[str, Any]:
     return response.json()["model_endpoints"][0]
 
 
+@retry(stop=stop_after_attempt(6), wait=wait_fixed(1))
+def get_llm_model_endpoint(name: str, user_id: str) -> Dict[str, Any]:
+    response = requests.get(
+        f"{BASE_PATH}/v1/llm/model-endpoints/{name}",
+        auth=(user_id, ""),
+        timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
+    )
+    if not response.ok:
+        raise ValueError(response.content)
+    return response.json()
+
+
 @retry(stop=stop_after_attempt(3), wait=wait_fixed(20))
 def update_model_endpoint(
     endpoint_name: str, update_model_endpoint_request: Dict[str, Any], user_id: str
@@ -556,6 +645,18 @@ def delete_model_endpoint(endpoint_name: str, user_id: str) -> Dict[str, Any]:
     return response.json()
 
 
+def delete_llm_model_endpoint(endpoint_name: str, user_id: str) -> Dict[str, Any]:
+    response = requests.delete(
+        f"{BASE_PATH}/v1/llm/model-endpoints/{endpoint_name}",
+        headers={"Content-Type": "application/json"},
+        auth=(user_id, ""),
+        timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
+    )
+    if not response.ok:
+        raise ValueError(response.content)
+    return response.json()
+
+
 @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
 def list_model_endpoints(user_id: str) -> List[Dict[str, Any]]:
     response = requests.get(
@@ -568,6 +669,44 @@ def list_model_endpoints(user_id: str) -> List[Dict[str, Any]]:
     return response.json()["model_endpoints"]
 
 
+@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
+def list_llm_model_endpoints(user_id: str) -> List[Dict[str, Any]]:
+    response = requests.get(
+        f"{BASE_PATH}/v1/llm/model-endpoints",
+        auth=(user_id, ""),
+        timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
+    )
+    if not response.ok:
+        raise ValueError(response.content)
+    return response.json()["model_endpoints"]
+
+
+@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
+def create_llm_model_endpoint(
+    create_llm_model_endpoint_request: Dict[str, Any],
+    user_id: str,
+    inference_framework: Optional[str],
+    inference_framework_image_tag: Optional[str],
+) -> Dict[str, Any]:
+    create_model_endpoint_request = create_llm_model_endpoint_request.copy()
+    if inference_framework:
+        create_model_endpoint_request["inference_framework"] = inference_framework
+    if inference_framework_image_tag:
+        create_model_endpoint_request[
+            "inference_framework_image_tag"
+        ] = inference_framework_image_tag
+    response = requests.post(
+        f"{BASE_PATH}/v1/llm/model-endpoints",
+        json=create_model_endpoint_request,
+        headers={"Content-Type": "application/json"},
+        auth=(user_id, ""),
+        timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
+    )
+    if not response.ok:
+        raise ValueError(response.content)
+    return response.json()
+
+
 async def create_async_task(
     model_endpoint_id: str,
     create_async_task_request: Dict[str, Any],
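create_llm_model_endpoint mirrors the existing create_model_endpoint helper: it POSTs the canned request body and lets the optional framework and image-tag arguments override it. A rough lifecycle sketch using the readiness and cleanup helpers from elsewhere in this diff (the wrapper function name is hypothetical and assumes it lives alongside these helpers):

def run_llm_endpoint_lifecycle(user_id: str) -> None:
    endpoint_name = CREATE_LLM_MODEL_ENDPOINT_REQUEST["name"]
    # Create the endpoint with the defaults from the canned request (vLLM, "latest" tag).
    create_llm_model_endpoint(CREATE_LLM_MODEL_ENDPOINT_REQUEST, user_id, None, None)
    try:
        # Wait for the endpoint to build, then for at least one worker to become available.
        ensure_n_ready_private_llm_endpoints_short(1, user_id)
        ensure_nonzero_available_llm_workers(endpoint_name, user_id)
    finally:
        # Tear the endpoint down even if the readiness checks fail.
        delete_llm_model_endpoint(endpoint_name, user_id)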
@@ -615,6 +754,23 @@ async def create_sync_task(
     return await response.json()
 
 
+async def create_llm_sync_task(
+    model_endpoint_name: str,
+    create_sync_task_request: Dict[str, Any],
+    user_id: str,
+    session: aiohttp.ClientSession,
+) -> str:
+    async with session.post(
+        f"{BASE_PATH}/v1/llm/completions-sync?model_endpoint_name={model_endpoint_name}",
+        json=create_sync_task_request,
+        headers={"Content-Type": "application/json"},
+        auth=aiohttp.BasicAuth(user_id, ""),
+        timeout=LONG_NETWORK_TIMEOUT_SEC,
+    ) as response:
+        assert response.status == 200, (await response.read()).decode()
+        return await response.json()
+
+
 async def create_streaming_task(
     model_endpoint_id: str,
     create_streaming_task_request: Dict[str, Any],
@@ -632,6 +788,23 @@ async def create_streaming_task(
     return (await response.read()).decode()
 
 
+async def create_llm_streaming_task(
+    model_endpoint_name: str,
+    create_streaming_task_request: Dict[str, Any],
+    user_id: str,
+    session: aiohttp.ClientSession,
+) -> str:
+    async with session.post(
+        f"{BASE_PATH}/v1/llm/completions-stream?model_endpoint_name={model_endpoint_name}",
+        json=create_streaming_task_request,
+        headers={"Content-Type": "application/json"},
+        auth=aiohttp.BasicAuth(user_id, ""),
+        timeout=LONG_NETWORK_TIMEOUT_SEC,
+    ) as response:
+        assert response.status == 200, (await response.read()).decode()
+        return await response.json()
+
+
 async def create_sync_tasks(
     endpoint_name: str, create_sync_task_requests: List[Dict[str, Any]], user_id: str
 ) -> List[Any]:
@@ -646,6 +819,19 @@ async def create_sync_tasks(
     return result  # type: ignore
 
 
+async def create_llm_sync_tasks(
+    endpoint_name: str, create_sync_task_requests: List[Dict[str, Any]], user_id: str
+) -> List[Any]:
+    async with aiohttp.ClientSession() as session:
+        tasks = []
+        for create_sync_task_request in create_sync_task_requests:
+            task = create_llm_sync_task(endpoint_name, create_sync_task_request, user_id, session)
+            tasks.append(asyncio.create_task(task))
+
+        result = await asyncio.gather(*tasks)
+        return result  # type: ignore
+
+
 async def create_streaming_tasks(
     endpoint_name: str, create_streaming_task_requests: List[Dict[str, Any]], user_id: str
 ) -> List[Any]:
@@ -662,6 +848,21 @@ async def create_streaming_tasks(
     return result  # type: ignore
 
 
+async def create_llm_streaming_tasks(
+    endpoint_name: str, create_streaming_task_requests: List[Dict[str, Any]], user_id: str
+) -> List[Any]:
+    async with aiohttp.ClientSession() as session:
+        tasks = []
+        for create_streaming_task_request in create_streaming_task_requests:
+            task = create_llm_streaming_task(
+                endpoint_name, create_streaming_task_request, user_id, session
+            )
+            tasks.append(asyncio.create_task(task))
+
+        result = await asyncio.gather(*tasks)
+        return result  # type: ignore
+
+
 async def get_async_task(
     task_id: str, user_id: str, session: aiohttp.ClientSession
 ) -> Dict[str, Any]:
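Both batch wrappers open a single aiohttp.ClientSession, fire all requests concurrently with asyncio.create_task, and gather the responses in request order. From a synchronous test function they would be bridged through the event loop, roughly as follows (illustrative snippet; assumes endpoint_name and user_id are already in scope):

# Illustrative: drive the batch helpers from synchronous test code.
payloads = [payload for payload, _, _ in LLM_PAYLOADS_WITH_EXPECTED_RESPONSES]
sync_results = asyncio.run(create_llm_sync_tasks(endpoint_name, payloads, user_id))
stream_results = asyncio.run(create_llm_streaming_tasks(endpoint_name, payloads, user_id))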
@@ -708,6 +909,22 @@ def ensure_n_ready_endpoints_short(n: int, user_id: str):
     assert len(ready_endpoints) >= n
 
 
+# Wait 2 minutes (120 seconds) for endpoints to build.
+@retry(stop=stop_after_attempt(12), wait=wait_fixed(10))
+def ensure_n_ready_private_llm_endpoints_short(n: int, user_id: str):
+    endpoints = list_llm_model_endpoints(user_id)
+    private_endpoints = [
+        endpoint for endpoint in endpoints if not endpoint["spec"]["public_inference"]
+    ]
+    ready_endpoints = [endpoint for endpoint in private_endpoints if endpoint["status"] == "READY"]
+    print(
+        f"User {user_id} Current num endpoints: {len(private_endpoints)}, num ready endpoints: {len(ready_endpoints)}"
+    )
+    assert (
+        len(ready_endpoints) >= n
+    ), f"Expected {n} ready endpoints, got {len(ready_endpoints)}. Look through endpoint builder for errors."
+
+
 def delete_all_endpoints(user_id: str, delete_suffix_only: bool):
     endpoints = list_model_endpoints(user_id)
     for i, endpoint in enumerate(endpoints):
@@ -737,6 +954,13 @@ def ensure_nonzero_available_workers(endpoint_name: str, user_id: str):
     assert simple_endpoint.get("deployment_state", {}).get("available_workers", 0)
 
 
+# Wait up to 20 minutes (1200 seconds) for the pods to spin up.
+@retry(stop=stop_after_attempt(120), wait=wait_fixed(10))
+def ensure_nonzero_available_llm_workers(endpoint_name: str, user_id: str):
+    simple_endpoint = get_llm_model_endpoint(endpoint_name, user_id)
+    assert simple_endpoint["spec"].get("deployment_state", {}).get("available_workers", 0)
+
+
 def ensure_inference_task_response_is_correct(response: Dict[str, Any], return_pickled: bool):
     print(response)
     assert response["status"] == "SUCCESS"
@@ -747,6 +971,22 @@ def ensure_inference_task_response_is_correct(response: Dict[str, Any], return_p
     assert response["result"] == {"result": '{"y": 1}'}
 
 
+def ensure_llm_task_response_is_correct(
+    response: Dict[str, Any],
+    required_output_fields: Optional[List[str]],
+    response_text_regex: Optional[str],
+):
+    print(response)
+    assert response["output"] is not None
+
+    if required_output_fields is not None:
+        for field in required_output_fields:
+            assert field in response["output"]
+
+    if response_text_regex is not None:
+        assert re.search(response_text_regex, response["output"]["text"])
+
+
 # Wait up to 30 seconds for the tasks to be returned.
 @retry(
     stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError)
