@@ -2,8 +2,9 @@
 import inspect
 import json
 import os
+import re
 import time
-from typing import Any, Dict, List, Sequence
+from typing import Any, Dict, List, Optional, Sequence

 import aiohttp
 import requests
@@ -14,6 +15,7 @@
 BASE_PATH = os.environ.get("BASE_PATH", _DEFAULT_BASE_PATH)
 print(f"Integration tests using gateway {BASE_PATH=}")
 DEFAULT_NETWORK_TIMEOUT_SEC = 10
+LONG_NETWORK_TIMEOUT_SEC = 30

 # add suffix to avoid name collisions
 SERVICE_IDENTIFIER = os.environ.get("SERVICE_IDENTIFIER", "")
@@ -164,12 +166,87 @@ def my_model(**keyword_args):
     "url": None,
 }

+CREATE_LLM_MODEL_ENDPOINT_REQUEST: Dict[str, Any] = {
+    "name": format_name("llama-2-7b-test"),
+    "model_name": "llama-2-7b",
+    "source": "hugging_face",
+    "inference_framework": "vllm",
+    "inference_framework_image_tag": "latest",
+    "endpoint_type": "streaming",
+    "cpus": 20,
+    "gpus": 1,
+    "memory": "20Gi",
+    "gpu_type": "nvidia-ampere-a10",
+    "storage": "40Gi",
+    "optimize_costs": False,
+    "min_workers": 1,
+    "max_workers": 1,
+    "per_worker": 1,
+    "labels": {"team": "infra", "product": "launch"},
+    "metadata": {"key": "value"},
+    "public_inference": False,
+}
+
+
 INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE: Dict[str, Any] = INFERENCE_PAYLOAD.copy()
 INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE["return_pickled"] = False

 INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE: Dict[str, Any] = INFERENCE_PAYLOAD.copy()
 INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE["return_pickled"] = True

+LLM_PAYLOAD: Dict[str, Any] = {
+    "prompt": "Hello, my name is",
+    "max_new_tokens": 10,
+    "temperature": 0.2,
+}
+
+LLM_PAYLOAD_WITH_STOP_SEQUENCE: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_STOP_SEQUENCE["stop_sequences"] = ["\n"]
+
+LLM_PAYLOAD_WITH_PRESENCE_PENALTY: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_PRESENCE_PENALTY["presence_penalty"] = 0.5
+
+LLM_PAYLOAD_WITH_FREQUENCY_PENALTY: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_FREQUENCY_PENALTY["frequency_penalty"] = 0.5
+
+LLM_PAYLOAD_WITH_TOP_K: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_TOP_K["top_k"] = 10
+
+LLM_PAYLOAD_WITH_TOP_P: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_TOP_P["top_p"] = 0.5
+
+LLM_PAYLOAD_WITH_INCLUDE_STOP_STR_IN_OUTPUT: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_INCLUDE_STOP_STR_IN_OUTPUT["include_stop_str_in_output"] = True
+
+LLM_PAYLOAD_WITH_GUIDED_JSON: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_GUIDED_JSON["guided_json"] = {
+    "properties": {"myString": {"type": "string"}},
+    "required": ["myString"],
+}
+
+LLM_PAYLOAD_WITH_GUIDED_REGEX: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_GUIDED_REGEX["guided_regex"] = "Sean.*"
+
+LLM_PAYLOAD_WITH_GUIDED_CHOICE: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_GUIDED_CHOICE["guided_choice"] = ["dog", "cat"]
+
+LLM_PAYLOAD_WITH_GUIDED_GRAMMAR: Dict[str, Any] = LLM_PAYLOAD.copy()
+LLM_PAYLOAD_WITH_GUIDED_GRAMMAR["guided_grammar"] = 'start: "John"'
+
+LLM_PAYLOADS_WITH_EXPECTED_RESPONSES = [
+    (LLM_PAYLOAD, None, None),
+    (LLM_PAYLOAD_WITH_STOP_SEQUENCE, None, None),
+    (LLM_PAYLOAD_WITH_PRESENCE_PENALTY, None, None),
+    (LLM_PAYLOAD_WITH_FREQUENCY_PENALTY, None, None),
+    (LLM_PAYLOAD_WITH_TOP_K, None, None),
+    (LLM_PAYLOAD_WITH_TOP_P, None, None),
+    (LLM_PAYLOAD_WITH_INCLUDE_STOP_STR_IN_OUTPUT, ["tokens"], None),
+    (LLM_PAYLOAD_WITH_GUIDED_JSON, None, None),
+    (LLM_PAYLOAD_WITH_GUIDED_REGEX, None, "Sean.*"),
+    (LLM_PAYLOAD_WITH_GUIDED_CHOICE, None, "dog|cat"),
+    (LLM_PAYLOAD_WITH_GUIDED_GRAMMAR, None, "John"),
+]
+
 CREATE_BATCH_JOB_REQUEST: Dict[str, Any] = {
     "bundle_name": "model_bundle_simple",
     "input_path": "TBA",
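Each triple in `LLM_PAYLOADS_WITH_EXPECTED_RESPONSES` pairs a request body with an optional list of fields the response output must contain and an optional regex the completion text must match. A minimal sketch of how a test might drive the whole list through the helpers added later in this diff (`create_llm_sync_tasks` and `ensure_llm_task_response_is_correct`); the endpoint name and user id are placeholders, not values from the source:

```python
import asyncio

# Sketch only: fans every payload out against one endpoint, then applies each
# triple's expectations. "llama-2-7b-test" and "test-user-id" are hypothetical.
async def check_llm_payloads(endpoint_name: str, user_id: str) -> None:
    payloads = [payload for payload, _, _ in LLM_PAYLOADS_WITH_EXPECTED_RESPONSES]
    responses = await create_llm_sync_tasks(endpoint_name, payloads, user_id)
    for response, (_, fields, regex) in zip(responses, LLM_PAYLOADS_WITH_EXPECTED_RESPONSES):
        ensure_llm_task_response_is_correct(response, fields, regex)

# asyncio.run(check_llm_payloads("llama-2-7b-test", "test-user-id"))
```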
@@ -524,6 +601,18 @@ def get_model_endpoint(name: str, user_id: str) -> Dict[str, Any]:
     return response.json()["model_endpoints"][0]


+@retry(stop=stop_after_attempt(6), wait=wait_fixed(1))
+def get_llm_model_endpoint(name: str, user_id: str) -> Dict[str, Any]:
+    response = requests.get(
+        f"{BASE_PATH}/v1/llm/model-endpoints/{name}",
+        auth=(user_id, ""),
+        timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
+    )
+    if not response.ok:
+        raise ValueError(response.content)
+    return response.json()
+
+
 @retry(stop=stop_after_attempt(3), wait=wait_fixed(20))
 def update_model_endpoint(
     endpoint_name: str, update_model_endpoint_request: Dict[str, Any], user_id: str
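Note the asymmetry with `get_model_endpoint` just above: the non-LLM helper unwraps `response.json()["model_endpoints"][0]`, while `get_llm_model_endpoint` returns the whole body, so callers index into it themselves (`ensure_nonzero_available_llm_workers` later in this diff reads `["spec"]["deployment_state"]`). A small predicate sketch under that assumption:

```python
# Assumes the response shape implied by ensure_nonzero_available_llm_workers
# below; any field beyond ["spec"]["deployment_state"] is an assumption, not
# something this diff shows.
def llm_endpoint_has_workers(name: str, user_id: str) -> bool:
    endpoint = get_llm_model_endpoint(name, user_id)
    deployment_state = endpoint["spec"].get("deployment_state", {})
    return deployment_state.get("available_workers", 0) > 0
```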
@@ -556,6 +645,18 @@ def delete_model_endpoint(endpoint_name: str, user_id: str) -> Dict[str, Any]:
     return response.json()


+def delete_llm_model_endpoint(endpoint_name: str, user_id: str) -> Dict[str, Any]:
+    response = requests.delete(
+        f"{BASE_PATH}/v1/llm/model-endpoints/{endpoint_name}",
+        headers={"Content-Type": "application/json"},
+        auth=(user_id, ""),
+        timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
+    )
+    if not response.ok:
+        raise ValueError(response.content)
+    return response.json()
+
+
 @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
 def list_model_endpoints(user_id: str) -> List[Dict[str, Any]]:
     response = requests.get(
@@ -568,6 +669,44 @@ def list_model_endpoints(user_id: str) -> List[Dict[str, Any]]:
     return response.json()["model_endpoints"]


+@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
+def list_llm_model_endpoints(user_id: str) -> List[Dict[str, Any]]:
+    response = requests.get(
+        f"{BASE_PATH}/v1/llm/model-endpoints",
+        auth=(user_id, ""),
+        timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
+    )
+    if not response.ok:
+        raise ValueError(response.content)
+    return response.json()["model_endpoints"]
+
+
+@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
+def create_llm_model_endpoint(
+    create_llm_model_endpoint_request: Dict[str, Any],
+    user_id: str,
+    inference_framework: Optional[str],
+    inference_framework_image_tag: Optional[str],
+) -> Dict[str, Any]:
+    create_model_endpoint_request = create_llm_model_endpoint_request.copy()
+    if inference_framework:
+        create_model_endpoint_request["inference_framework"] = inference_framework
+    if inference_framework_image_tag:
+        create_model_endpoint_request[
+            "inference_framework_image_tag"
+        ] = inference_framework_image_tag
+    response = requests.post(
+        f"{BASE_PATH}/v1/llm/model-endpoints",
+        json=create_model_endpoint_request,
+        headers={"Content-Type": "application/json"},
+        auth=(user_id, ""),
+        timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
+    )
+    if not response.ok:
+        raise ValueError(response.content)
+    return response.json()
+
+
 async def create_async_task(
     model_endpoint_id: str,
     create_async_task_request: Dict[str, Any],
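`create_llm_model_endpoint` copies the request dict before applying the `inference_framework` and `inference_framework_image_tag` overrides, so the shared `CREATE_LLM_MODEL_ENDPOINT_REQUEST` constant is never mutated. A hedged sketch of the create / wait / clean-up flow a test could build from the helpers in this diff; the image tag is a placeholder, not a value from the source:

```python
# Sketch only, using helpers and constants defined in this diff.
def llm_endpoint_lifecycle(user_id: str) -> None:
    name = CREATE_LLM_MODEL_ENDPOINT_REQUEST["name"]
    # "0.2.0" is a hypothetical image tag; pass None to keep the request's default.
    create_llm_model_endpoint(CREATE_LLM_MODEL_ENDPOINT_REQUEST, user_id, "vllm", "0.2.0")
    try:
        ensure_n_ready_private_llm_endpoints_short(1, user_id)
        ensure_nonzero_available_llm_workers(name, user_id)
    finally:
        delete_llm_model_endpoint(name, user_id)
```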
@@ -615,6 +754,23 @@ async def create_sync_task(
         return await response.json()


+async def create_llm_sync_task(
+    model_endpoint_name: str,
+    create_sync_task_request: Dict[str, Any],
+    user_id: str,
+    session: aiohttp.ClientSession,
+) -> Dict[str, Any]:
+    async with session.post(
+        f"{BASE_PATH}/v1/llm/completions-sync?model_endpoint_name={model_endpoint_name}",
+        json=create_sync_task_request,
+        headers={"Content-Type": "application/json"},
+        auth=aiohttp.BasicAuth(user_id, ""),
+        timeout=LONG_NETWORK_TIMEOUT_SEC,
+    ) as response:
+        assert response.status == 200, (await response.read()).decode()
+        return await response.json()
+
+
 async def create_streaming_task(
     model_endpoint_id: str,
     create_streaming_task_request: Dict[str, Any],
@@ -632,6 +788,23 @@ async def create_streaming_task(
         return (await response.read()).decode()


+async def create_llm_streaming_task(
+    model_endpoint_name: str,
+    create_streaming_task_request: Dict[str, Any],
+    user_id: str,
+    session: aiohttp.ClientSession,
+) -> Dict[str, Any]:
+    async with session.post(
+        f"{BASE_PATH}/v1/llm/completions-stream?model_endpoint_name={model_endpoint_name}",
+        json=create_streaming_task_request,
+        headers={"Content-Type": "application/json"},
+        auth=aiohttp.BasicAuth(user_id, ""),
+        timeout=LONG_NETWORK_TIMEOUT_SEC,
+    ) as response:
+        assert response.status == 200, (await response.read()).decode()
+        return await response.json()
+
+
 async def create_sync_tasks(
     endpoint_name: str, create_sync_task_requests: List[Dict[str, Any]], user_id: str
 ) -> List[Any]:
@@ -646,6 +819,19 @@ async def create_sync_tasks(
         return result  # type: ignore


+async def create_llm_sync_tasks(
+    endpoint_name: str, create_sync_task_requests: List[Dict[str, Any]], user_id: str
+) -> List[Any]:
+    async with aiohttp.ClientSession() as session:
+        tasks = []
+        for create_sync_task_request in create_sync_task_requests:
+            task = create_llm_sync_task(endpoint_name, create_sync_task_request, user_id, session)
+            tasks.append(asyncio.create_task(task))
+
+        result = await asyncio.gather(*tasks)
+        return result  # type: ignore
+
+
 async def create_streaming_tasks(
     endpoint_name: str, create_streaming_task_requests: List[Dict[str, Any]], user_id: str
 ) -> List[Any]:
@@ -662,6 +848,21 @@ async def create_streaming_tasks(
         return result  # type: ignore


+async def create_llm_streaming_tasks(
+    endpoint_name: str, create_streaming_task_requests: List[Dict[str, Any]], user_id: str
+) -> List[Any]:
+    async with aiohttp.ClientSession() as session:
+        tasks = []
+        for create_streaming_task_request in create_streaming_task_requests:
+            task = create_llm_streaming_task(
+                endpoint_name, create_streaming_task_request, user_id, session
+            )
+            tasks.append(asyncio.create_task(task))
+
+        result = await asyncio.gather(*tasks)
+        return result  # type: ignore
+
+
 async def get_async_task(
     task_id: str, user_id: str, session: aiohttp.ClientSession
 ) -> Dict[str, Any]:
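Both fan-out helpers share a single `ClientSession` across all requests and collect results with `asyncio.gather`, which preserves input order, so result *i* always corresponds to request *i*. A brief usage sketch with placeholder endpoint and user values (the payload constants come from this diff):

```python
import asyncio

# "llama-2-7b-test" and "test-user-id" are hypothetical placeholders.
responses = asyncio.run(
    create_llm_streaming_tasks(
        "llama-2-7b-test", [LLM_PAYLOAD, LLM_PAYLOAD_WITH_STOP_SEQUENCE], "test-user-id"
    )
)
# gather() keeps ordering: responses[0] answers LLM_PAYLOAD,
# responses[1] the stop-sequence variant.
```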
@@ -708,6 +909,22 @@ def ensure_n_ready_endpoints_short(n: int, user_id: str):
     assert len(ready_endpoints) >= n


+# Wait 2 minutes (120 seconds) for endpoints to build.
+@retry(stop=stop_after_attempt(12), wait=wait_fixed(10))
+def ensure_n_ready_private_llm_endpoints_short(n: int, user_id: str):
+    endpoints = list_llm_model_endpoints(user_id)
+    private_endpoints = [
+        endpoint for endpoint in endpoints if not endpoint["spec"]["public_inference"]
+    ]
+    ready_endpoints = [endpoint for endpoint in private_endpoints if endpoint["status"] == "READY"]
+    print(
+        f"User {user_id} Current num endpoints: {len(private_endpoints)}, num ready endpoints: {len(ready_endpoints)}"
+    )
+    assert (
+        len(ready_endpoints) >= n
+    ), f"Expected {n} ready endpoints, got {len(ready_endpoints)}. Look through endpoint builder for errors."
+
+
 def delete_all_endpoints(user_id: str, delete_suffix_only: bool):
     endpoints = list_model_endpoints(user_id)
     for i, endpoint in enumerate(endpoints):
@@ -737,6 +954,13 @@ def ensure_nonzero_available_workers(endpoint_name: str, user_id: str):
     assert simple_endpoint.get("deployment_state", {}).get("available_workers", 0)


+# Wait up to 20 minutes (1200 seconds) for the pods to spin up.
+@retry(stop=stop_after_attempt(120), wait=wait_fixed(10))
+def ensure_nonzero_available_llm_workers(endpoint_name: str, user_id: str):
+    simple_endpoint = get_llm_model_endpoint(endpoint_name, user_id)
+    assert simple_endpoint["spec"].get("deployment_state", {}).get("available_workers", 0)
+
+
 def ensure_inference_task_response_is_correct(response: Dict[str, Any], return_pickled: bool):
     print(response)
     assert response["status"] == "SUCCESS"
@@ -747,6 +971,22 @@ def ensure_inference_task_response_is_correct(response: Dict[str, Any], return_p
         assert response["result"] == {"result": '{"y": 1}'}


+def ensure_llm_task_response_is_correct(
+    response: Dict[str, Any],
+    required_output_fields: Optional[List[str]],
+    response_text_regex: Optional[str],
+):
+    print(response)
+    assert response["output"] is not None
+
+    if required_output_fields is not None:
+        for field in required_output_fields:
+            assert field in response["output"]
+
+    if response_text_regex is not None:
+        assert re.search(response_text_regex, response["output"]["text"])
+
+
 # Wait up to 30 seconds for the tasks to be returned.
 @retry(
     stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError)
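`ensure_llm_task_response_is_correct` only inspects what its arguments ask for: `output` must be non-null, each required field must exist in `output`, and `text` must match the regex when one is given. A worked example against a hypothetical response body (the real completions schema is not shown in this diff beyond the fields the validator touches):

```python
# Hypothetical response; "output" and "text" are the only keys taken from the
# validator's own accesses above.
sample = {"output": {"text": "Sean the dog"}}

ensure_llm_task_response_is_correct(sample, None, "Sean.*")       # passes
ensure_llm_task_response_is_correct(sample, ["text"], "dog|cat")  # passes
# ensure_llm_task_response_is_correct(sample, ["tokens"], None)   # AssertionError
```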