@@ -74,7 +74,7 @@ async def test_batch_inference(
     new_callable=mock_open,
     read_data="Mocked content",
 )
-async def test_batch_inference_failed_to_download_model(
+async def test_batch_inference_failed_to_download_model_but_proceed(
     mock_open_func,
     mock_popen,
     mock_get_s3_client,
@@ -86,19 +86,36 @@ async def test_batch_inference_failed_to_download_model(
     create_vllm_request_outputs,
     mock_s3_client,
     mock_process,
+    mock_completion_output,
 ):
     # Mock the necessary objects and data
-    mock_process.returncode = 1
+    mock_process.returncode = 1  # Failed to download model
     mock_popen.return_value = mock_process
     mock_get_s3_client.return_value = mock_s3_client
     mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request
     mock_create_batch_completions_request_content.parse_raw.return_value = (
         create_batch_completions_request_content
     )
 
+    mock_results_generator = MagicMock()
+    mock_results_generator.__aiter__.return_value = create_vllm_request_outputs
+
+    # Mock the generate_with_vllm function
+    mock_generate_with_vllm.return_value = [mock_results_generator]
+
     # Call the function
-    with pytest.raises(IOError):
-        await batch_inference()
+    await batch_inference()
+
+    # Assertions
+    mock_create_batch_completions_request.parse_file.assert_called_once()
+    mock_open_func.assert_has_calls(
+        [
+            call("input_data_path", "r"),
+            call("output_data_path", "w"),
+            call().write(json.dumps([mock_completion_output.dict()])),
+        ],
+        any_order=True,
+    )
 
 
 @pytest.mark.asyncio
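
The added mock relies on MagicMock's support for async magic methods: assigning a plain list to __aiter__.return_value lets the mock be consumed with "async for", standing in for a real vLLM results generator. A minimal, self-contained sketch of that pattern follows; the names fake_outputs and consume are illustrative only and do not come from the repository.

# Sketch only: demonstrates the MagicMock.__aiter__ pattern used in the test above.
import asyncio
from unittest.mock import MagicMock

async def consume(results_generator):
    # Iterate the mock with "async for", as the code under test presumably
    # iterates the vLLM request outputs.
    collected = []
    async for output in results_generator:
        collected.append(output)
    return collected

fake_outputs = ["request_output_1", "request_output_2"]
mock_results_generator = MagicMock()
mock_results_generator.__aiter__.return_value = fake_outputs

assert asyncio.run(consume(mock_results_generator)) == fake_outputs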