Skip to content

Commit 49216ae

Browse files
xieus and Liguang Xie authored
[Bug] Refactor downloader and artifact_service to be non-blocking (#1895)
* Downloader minor refactor and gcp_test add timeout setting
* Refactor downloaders and artifact_service to be non-blocking
* Add unit tests for downloader
* Remove unused libs
* Move test_downloaders.py
* Fix minor typos in downloader tests
* Recover KV Cache status func

---------

Signed-off-by: Liguang Xie <[email protected]>
Co-authored-by: Liguang Xie <[email protected]>
1 parent 12cbb16 commit 49216ae

File tree

7 files changed: +75 −9 lines changed

deployment/terraform/tests/gcp_test.go

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ import (
2222
"testing"
2323
"net"
2424
"net/http"
25+
"time"
2526

2627
"github.com/gruntwork-io/terratest/modules/terraform"
2728
"github.com/stretchr/testify/assert"
@@ -75,8 +76,11 @@ func TestAIBrixGCPDeployment(t *testing.T) {
7576
option.WithMaxRetries(0),
7677
)
7778

79+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
80+
defer cancel()
81+
7882
// Run a chat completion against the model endpoint
79-
chatCompletion, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
83+
chatCompletion, err := client.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
8084
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
8185
openai.UserMessage("What can you tell me about San Francisco?"),
8286
}),

python/aibrix/aibrix/runtime/artifact_service.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414

1515
"""Artifact delegation service for LoRA adapters."""
1616

17+
import asyncio
1718
import os
1819
import shutil
1920
from pathlib import Path
@@ -313,7 +314,8 @@ async def unload_adapter(
313314
local_path = self._get_local_path_for_adapter(lora_name)
314315
if os.path.exists(local_path):
315316
try:
316-
shutil.rmtree(local_path)
317+
loop = asyncio.get_running_loop()
318+
await loop.run_in_executor(None, shutil.rmtree, local_path)
317319
logger.info(f"Cleaned up local artifacts for {lora_name}")
318320
except Exception as e:
319321
logger.warning(
@@ -337,7 +339,8 @@ async def cleanup_artifact(self, lora_name: str) -> None:
337339

338340
if os.path.exists(local_path):
339341
try:
340-
shutil.rmtree(local_path)
342+
loop = asyncio.get_running_loop()
343+
await loop.run_in_executor(None, shutil.rmtree, local_path)
341344
logger.info(f"Cleaned up artifacts for {lora_name}")
342345
except Exception as e:
343346
logger.warning(f"Failed to clean up artifacts for {lora_name}: {e}")

python/aibrix/aibrix/runtime/downloaders.py

Lines changed: 40 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,8 @@
1414

1515
"""Artifact downloaders for different storage backends."""
1616

17+
import asyncio
18+
import functools
1719
import os
1820
from abc import ABC, abstractmethod
1921
from pathlib import Path
@@ -58,6 +60,18 @@ class S3ArtifactDownloader(ArtifactDownloader):
5860

5961
async def download(
6062
self, source_url: str, local_path: str, credentials: Optional[Dict] = None
63+
) -> str:
64+
"""
65+
Download from S3 (Async wrapper).
66+
"""
67+
loop = asyncio.get_running_loop()
68+
return await loop.run_in_executor(
69+
None,
70+
functools.partial(self._download_sync, source_url, local_path, credentials),
71+
)
72+
73+
def _download_sync(
74+
self, source_url: str, local_path: str, credentials: Optional[Dict] = None
6175
) -> str:
6276
"""
6377
Download from S3.
@@ -154,7 +168,7 @@ def _download_s3_directory(
154168
local_file_path = os.path.join(local_path, relative_path)
155169

156170
# Ensure directory exists
157-
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
171+
self._ensure_directory(os.path.dirname(local_file_path))
158172

159173
# Download file
160174
s3_client.download_file(bucket_name, key, local_file_path)
@@ -168,6 +182,18 @@ class GCSArtifactDownloader(ArtifactDownloader):
168182

169183
async def download(
170184
self, source_url: str, local_path: str, credentials: Optional[Dict] = None
185+
) -> str:
186+
"""
187+
Download from GCS (Async wrapper).
188+
"""
189+
loop = asyncio.get_running_loop()
190+
return await loop.run_in_executor(
191+
None,
192+
functools.partial(self._download_sync, source_url, local_path, credentials),
193+
)
194+
195+
def _download_sync(
196+
self, source_url: str, local_path: str, credentials: Optional[Dict] = None
171197
) -> str:
172198
"""
173199
Download from GCS.
@@ -227,7 +253,7 @@ async def download(
227253
local_file_path = os.path.join(local_path, relative_path)
228254

229255
# Ensure directory exists
230-
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
256+
self._ensure_directory(os.path.dirname(local_file_path))
231257

232258
# Download file
233259
blob.download_to_filename(local_file_path)
@@ -257,6 +283,18 @@ class HuggingFaceArtifactDownloader(ArtifactDownloader):
257283

258284
async def download(
259285
self, source_url: str, local_path: str, credentials: Optional[Dict] = None
286+
) -> str:
287+
"""
288+
Download from HuggingFace Hub (Async wrapper).
289+
"""
290+
loop = asyncio.get_running_loop()
291+
return await loop.run_in_executor(
292+
None,
293+
functools.partial(self._download_sync, source_url, local_path, credentials),
294+
)
295+
296+
def _download_sync(
297+
self, source_url: str, local_path: str, credentials: Optional[Dict] = None
260298
) -> str:
261299
"""
262300
Download from HuggingFace Hub.

python/aibrix/tests/downloader/test_downloader_hf.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ def test_get_downloader_hf():
2929

3030
def test_get_downloader_hf_not_exist():
3131
with pytest.raises(ModelNotFoundError) as exception:
32-
get_downloader("not_exsit_path/model")
32+
get_downloader("not_exist_path/model")
3333
assert "Model not found" in str(exception.value)
3434

3535

@@ -39,5 +39,5 @@ def test_get_downloader_hf_invalid_uri():
3939
assert "not in the expected format: repo/name" in str(exception.value)
4040

4141
with pytest.raises(ArgNotFormatError) as exception:
42-
get_downloader("multi/filed/repo")
42+
get_downloader("multi/field/repo")
4343
assert "not in the expected format: repo/name" in str(exception.value)

python/aibrix/tests/downloader/test_downloader_tos.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ def test_get_downloader_tos_path_not_exist(mock_boto3):
6363
mock_not_exsit_tos(mock_boto3)
6464

6565
with pytest.raises(ModelNotFoundError) as exception:
66-
get_downloader("tos://bucket/not_exsit_path")
66+
get_downloader("tos://bucket/not_exist_path")
6767
assert "Model not found" in str(exception.value)
6868

6969

python/aibrix/tests/downloader/test_downloader_tos_v1.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ def test_get_downloader_tos_path_not_exist(mock_tos):
6161
mock_not_exsit_tos(mock_tos)
6262

6363
with pytest.raises(ModelNotFoundError) as exception:
64-
get_downloader("tos://bucket/not_exsit_path")
64+
get_downloader("tos://bucket/not_exist_path")
6565
assert "Model not found" in str(exception.value)
6666

6767

python/aibrix_kvcache/aibrix_kvcache/status.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,27 @@ def __init__( # type: ignore[misc]
7474
def is_ok(self) -> bool:
7575
return self.error_code == StatusCodes.OK
7676

77+
def is_error(self) -> bool:
78+
return self.error_code == StatusCodes.ERROR
79+
80+
def is_not_found(self) -> bool:
81+
return self.error_code == StatusCodes.NOT_FOUND
82+
83+
def is_invalid(self) -> bool:
84+
return self.error_code == StatusCodes.INVALID
85+
86+
def is_out_of_memory(self) -> bool:
87+
return self.error_code == StatusCodes.OUT_OF_MEMORY
88+
89+
def is_timeout(self) -> bool:
90+
return self.error_code == StatusCodes.TIMEOUT
91+
92+
def is_denied(self) -> bool:
93+
return self.error_code == StatusCodes.DENIED
94+
95+
def is_cancelled(self) -> bool:
96+
return self.error_code == StatusCodes.CANCELLED
97+
7798
def get(self, default=None) -> T:
7899
"""Returns the value if successful, otherwise returns default."""
79100
if not self.is_ok():

0 commit comments

Comments
 (0)