Skip to content

Commit 5585d8e

Browse files
Fix permissions (#431)
* Fix s5cmd env vars * more fixes for s5cmd * dont error * add back aws_profile * flush * fix test
1 parent 8a35e38 commit 5585d8e

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

model-engine/model_engine_server/inference/batch_inference/vllm_batch.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,17 @@ def get_s3_client():
2727

2828

2929
def download_model(checkpoint_path, final_weights_folder):
30-
s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
30+
s5cmd = f"./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.bin' --include '*.safetensors' --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
31+
env = os.environ.copy()
32+
env["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default")
33+
# Need to override these env vars so s5cmd uses AWS_PROFILE
34+
env["AWS_ROLE_ARN"] = ""
35+
env["AWS_WEB_IDENTITY_TOKEN_FILE"] = ""
3136
process = subprocess.Popen(
32-
s5cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
37+
s5cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env
3338
)
3439
for line in process.stdout:
35-
print(line)
40+
print(line, flush=True)
3641

3742
process.wait()
3843

@@ -41,7 +46,7 @@ def download_model(checkpoint_path, final_weights_folder):
4146
for line in iter(process.stderr.readline, ""):
4247
stderr_lines.append(line.strip())
4348

44-
raise IOError(f"Error downloading model weights: {stderr_lines}")
49+
print(f"Error downloading model weights: {stderr_lines}", flush=True)
4550

4651

4752
def file_exists(path):

model-engine/tests/unit/inference/test_vllm_batch.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ async def test_batch_inference(
7474
new_callable=mock_open,
7575
read_data="Mocked content",
7676
)
77-
async def test_batch_inference_failed_to_download_model(
77+
async def test_batch_inference_failed_to_download_model_but_proceed(
7878
mock_open_func,
7979
mock_popen,
8080
mock_get_s3_client,
@@ -86,19 +86,36 @@ async def test_batch_inference_failed_to_download_model(
8686
create_vllm_request_outputs,
8787
mock_s3_client,
8888
mock_process,
89+
mock_completion_output,
8990
):
9091
# Mock the necessary objects and data
91-
mock_process.returncode = 1
92+
mock_process.returncode = 1 # Failed to download model
9293
mock_popen.return_value = mock_process
9394
mock_get_s3_client.return_value = mock_s3_client
9495
mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request
9596
mock_create_batch_completions_request_content.parse_raw.return_value = (
9697
create_batch_completions_request_content
9798
)
9899

100+
mock_results_generator = MagicMock()
101+
mock_results_generator.__aiter__.return_value = create_vllm_request_outputs
102+
103+
# Mock the generate_with_vllm function
104+
mock_generate_with_vllm.return_value = [mock_results_generator]
105+
99106
# Call the function
100-
with pytest.raises(IOError):
101-
await batch_inference()
107+
await batch_inference()
108+
109+
# Assertions
110+
mock_create_batch_completions_request.parse_file.assert_called_once()
111+
mock_open_func.assert_has_calls(
112+
[
113+
call("input_data_path", "r"),
114+
call("output_data_path", "w"),
115+
call().write(json.dumps([mock_completion_output.dict()])),
116+
],
117+
any_order=True,
118+
)
102119

103120

104121
@pytest.mark.asyncio

0 commit comments

Comments (0)