@@ -168,7 +168,7 @@ def my_model(**keyword_args):
168168
169169CREATE_LLM_MODEL_ENDPOINT_REQUEST : Dict [str , Any ] = {
170170 "name" : format_name ("llama-2-7b-test" ),
171- "model_name" : "llama-2-7b" ,
171+ "model_name" : "llama-2-7b-chat " ,
172172 "source" : "hugging_face" ,
173173 "inference_framework" : "vllm" ,
174174 "inference_framework_image_tag" : "latest" ,
@@ -802,7 +802,7 @@ async def create_llm_streaming_task(
802802 timeout = LONG_NETWORK_TIMEOUT_SEC ,
803803 ) as response :
804804 assert response .status == 200 , (await response .read ()).decode ()
805- return await response .json ()
805+ return ( await response .read ()). decode ()
806806
807807
808808async def create_sync_tasks (
@@ -987,6 +987,27 @@ def ensure_llm_task_response_is_correct(
987987 assert re .search (response_text_regex , response ["output" ]["text" ])
988988
989989
def ensure_llm_task_stream_response_is_correct(
    response: str,
    required_output_fields: Optional[List[str]],
    response_text_regex: Optional[str],
):
    """Validate a streamed LLM completion response.

    Args:
        response: Raw streamed body in the SSE-like format
            ``"data: <json>\\n\\ndata: <json>\\n\\n"``.
        required_output_fields: Field names that must be present in every
            streamed event's ``output`` object, or ``None`` to skip the check.
        response_text_regex: Regex that the concatenated output text must
            match, or ``None`` to skip the check.

    Raises:
        AssertionError: If a required field is missing from any event or the
            joined text does not match ``response_text_regex``.
    """
    # Each event is "data: <json>"; split on the blank-line separator and
    # parse the JSON payload after the "data: " prefix. maxsplit=1 keeps the
    # payload intact even if it happens to contain the literal "data: ".
    parsed_response = [
        json.loads(event.split("data: ", 1)[1])
        for event in response.split("\n\n")
        if "data:" in event.strip()
    ]

    # Every streamed event must carry all of the required output fields.
    if required_output_fields is not None:
        for event in parsed_response:
            for field in required_output_fields:
                assert field in event["output"], f"Field {field} not in stream response"

    # Join the incremental text chunks into the full completion text.
    response_text = "".join([event["output"]["text"] for event in parsed_response])
    print("response text: ", response_text)

    if response_text_regex is not None:
        assert re.search(response_text_regex, response_text)
1010+
9901011# Wait up to 30 seconds for the tasks to be returned.
9911012@retry (
9921013 stop = stop_after_attempt (10 ), wait = wait_fixed (1 ), retry = retry_if_exception_type (AssertionError )
0 commit comments