Skip to content

Commit 14baa29

Browse files
sam-scale, ian-scale-2, ian-scale
authored
Smartly check safetensors vs. bin (#296)
* Smartly check safetensors vs. bin * Fix formatting * Add unit test * Add unit test * heh hope this works. * refactoring * adding new utils file, removing test * adding in unit test, refactoring again * adding artifact gateway to use case * renaming gateway function * whoops * cleanup --------- Co-authored-by: Ian Macleod <[email protected]> Co-authored-by: Ian Macleod <[email protected]>
1 parent 287ab59 commit 14baa29

File tree

7 files changed

+95
-3
lines changed

7 files changed

+95
-3
lines changed

model-engine/model_engine_server/api/llms_v1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ async def create_model_endpoint(
9292
create_model_bundle_use_case=create_model_bundle_use_case,
9393
model_bundle_repository=external_interfaces.model_bundle_repository,
9494
model_endpoint_service=external_interfaces.model_endpoint_service,
95+
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
9596
)
9697
return await use_case.execute(user=auth, request=request)
9798
except ObjectAlreadyExistsException as exc:

model-engine/model_engine_server/core/aws/storage_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121

2222
def sync_storage_client(**kwargs) -> BaseClient:
    """Return a synchronous boto3 S3 client for the ML-worker profile.

    Extra keyword arguments are forwarded verbatim to ``session(...).client``.
    """
    # type: ignore — the session helper is untyped, so mypy cannot narrow
    # the client's return type here.
    return session(infra_config().profile_ml_worker).client("s3", **kwargs)  # type: ignore
2424

2525

2626
def open(uri: str, mode: str = "rt", **kwargs) -> IO: # pylint: disable=redefined-builtin

model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ class LLMArtifactGateway(ABC):
77
Abstract Base Class for interacting with llm artifacts.
88
"""
99

10+
@abstractmethod
def list_files(self, path: str, **kwargs) -> List[str]:
    """
    Return the names of the files found under ``path``.
    """
1017
@abstractmethod
1118
def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
1219
"""

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,25 @@
134134
DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes
135135

136136

137+
def _exclude_safetensors_or_bin(model_files: List[str]) -> Optional[str]:
138+
"""
139+
This function is used to determine whether to exclude "*.safetensors" or "*.bin" files
140+
based on which file type is present more often in the checkpoint folder. The less
141+
frequently present file type is excluded.
142+
If both files are equally present, no exclusion string is returned.
143+
"""
144+
exclude_str = None
145+
if len([f for f in model_files if f.endswith(".safetensors")]) > len(
146+
[f for f in model_files if f.endswith(".bin")]
147+
):
148+
exclude_str = "*.bin"
149+
elif len([f for f in model_files if f.endswith(".safetensors")]) < len(
150+
[f for f in model_files if f.endswith(".bin")]
151+
):
152+
exclude_str = "*.safetensors"
153+
return exclude_str
154+
155+
137156
def _model_endpoint_entity_to_get_llm_model_endpoint_response(
138157
model_endpoint: ModelEndpoint,
139158
) -> GetLLMModelEndpointV1Response:
@@ -182,11 +201,13 @@ def __init__(
182201
create_model_bundle_use_case: CreateModelBundleV2UseCase,
183202
model_bundle_repository: ModelBundleRepository,
184203
model_endpoint_service: ModelEndpointService,
204+
llm_artifact_gateway: LLMArtifactGateway,
185205
):
186206
self.authz_module = LiveAuthorizationModule()
187207
self.create_model_bundle_use_case = create_model_bundle_use_case
188208
self.model_bundle_repository = model_bundle_repository
189209
self.model_endpoint_service = model_endpoint_service
210+
self.llm_artifact_gateway = llm_artifact_gateway
190211

191212
async def create_model_bundle(
192213
self,
@@ -358,14 +379,21 @@ def load_model_weights_sub_commands(
358379
]
359380
)
360381
else:
361-
if framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE:
382+
# Let's check whether to exclude "*.safetensors" or "*.bin" files
383+
checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path)
384+
model_files = [f for f in checkpoint_files if "model" in f]
385+
386+
exclude_str = _exclude_safetensors_or_bin(model_files)
387+
388+
if exclude_str is None:
362389
subcommands.append(
363390
f"{s5cmd} --numworkers 512 cp --concurrency 10 {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
364391
)
365392
else:
366393
subcommands.append(
367-
f"{s5cmd} --numworkers 512 cp --concurrency 10 --exclude '*.safetensors' {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
394+
f"{s5cmd} --numworkers 512 cp --concurrency 10 --exclude '{exclude_str}' {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
368395
)
396+
369397
return subcommands
370398

371399
async def create_deepspeed_bundle(

model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import boto3
55
from model_engine_server.common.config import get_model_cache_directory_name, hmi_config
6+
from model_engine_server.core.utils.url import parse_attachment_url
67
from model_engine_server.domain.gateways import LLMArtifactGateway
78

89

@@ -17,6 +18,18 @@ def _get_s3_resource(self, kwargs):
1718
resource = session.resource("s3")
1819
return resource
1920

21+
def list_files(self, path: str, **kwargs) -> List[str]:
    """Return the keys of all S3 objects whose key starts with *path*'s key."""
    s3 = self._get_s3_resource(kwargs)
    parsed_remote = parse_attachment_url(path)
    # NOTE(review): _get_s3_resource appears to return a boto3 *resource*
    # (session.resource("s3")), which does not expose list_objects_v2 — that
    # is a client method. Confirm what _get_s3_resource actually returns.
    # NOTE(review): list_objects_v2 returns at most 1000 keys per call and
    # omits "Contents" when the prefix is empty — TODO confirm callers are
    # fine with both limitations.
    response = s3.list_objects_v2(Bucket=parsed_remote.bucket, Prefix=parsed_remote.key)
    return [bucket_object["Key"] for bucket_object in response["Contents"]]
32+
2033
def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
2134
s3 = self._get_s3_resource(kwargs)
2235
# parsing prefix to get S3 bucket name

model-engine/tests/unit/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,11 +748,16 @@ async def initialize_events(self, user_id: str, model_endpoint_name: str):
748748
class FakeLLMArtifactGateway(LLMArtifactGateway):
749749
def __init__(self):
    """Set up fake fixtures for the LLM artifact gateway."""
    self.existing_models = []
    # Fixed a missing quote boundary: the original had
    # ["fake.bin, fake2.bin", "fake3.safetensors"] — i.e. TWO elements, the
    # first being the single string "fake.bin, fake2.bin" — where three
    # separate file names were clearly intended.
    self.s3_bucket = {"fake-checkpoint": ["fake.bin", "fake2.bin", "fake3.safetensors"]}
    self.urls = {"filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz"}
752753

753754
def _add_model(self, owner: str, model_name: str):
754755
self.existing_models.append((owner, model_name))
755756

757+
def list_files(self, path: str, **kwargs) -> List[str]:
    """Fake listing: return the configured object keys for *path*.

    Returns an empty list for unknown paths — the original implicitly
    returned None there, violating the declared ``List[str]`` return type.
    """
    return self.s3_bucket.get(path, [])
760+
756761
def get_model_weights_urls(self, owner: str, model_name: str):
757762
if (owner, model_name) in self.existing_models:
758763
return self.urls

model-engine/tests/unit/domain/test_llm_use_cases.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
DeleteLLMEndpointByNameUseCase,
3737
GetLLMModelEndpointByNameV1UseCase,
3838
ModelDownloadV1UseCase,
39+
_exclude_safetensors_or_bin,
3940
)
4041
from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase
4142

@@ -47,6 +48,7 @@ async def test_create_model_endpoint_use_case_success(
4748
fake_model_endpoint_service,
4849
fake_docker_repository_image_always_exists,
4950
fake_model_primitive_gateway,
51+
fake_llm_artifact_gateway,
5052
create_llm_model_endpoint_request_async: CreateLLMModelEndpointV1Request,
5153
create_llm_model_endpoint_request_sync: CreateLLMModelEndpointV1Request,
5254
create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request,
@@ -62,6 +64,7 @@ async def test_create_model_endpoint_use_case_success(
6264
create_model_bundle_use_case=bundle_use_case,
6365
model_bundle_repository=fake_model_bundle_repository,
6466
model_endpoint_service=fake_model_endpoint_service,
67+
llm_artifact_gateway=fake_llm_artifact_gateway,
6568
)
6669
user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
6770
response_1 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_async)
@@ -150,6 +153,7 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success(
150153
fake_model_endpoint_service,
151154
fake_docker_repository_image_always_exists,
152155
fake_model_primitive_gateway,
156+
fake_llm_artifact_gateway,
153157
create_llm_model_endpoint_text_generation_inference_request_async: CreateLLMModelEndpointV1Request,
154158
create_llm_model_endpoint_text_generation_inference_request_streaming: CreateLLMModelEndpointV1Request,
155159
):
@@ -163,6 +167,7 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success(
163167
create_model_bundle_use_case=bundle_use_case,
164168
model_bundle_repository=fake_model_bundle_repository,
165169
model_endpoint_service=fake_model_endpoint_service,
170+
llm_artifact_gateway=fake_llm_artifact_gateway,
166171
)
167172
user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
168173
response_1 = await use_case.execute(
@@ -202,6 +207,7 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception
202207
fake_model_endpoint_service,
203208
fake_docker_repository_image_always_exists,
204209
fake_model_primitive_gateway,
210+
fake_llm_artifact_gateway,
205211
create_llm_model_endpoint_request_invalid_model_name: CreateLLMModelEndpointV1Request,
206212
):
207213
fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository
@@ -214,6 +220,7 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception
214220
create_model_bundle_use_case=bundle_use_case,
215221
model_bundle_repository=fake_model_bundle_repository,
216222
model_endpoint_service=fake_model_endpoint_service,
223+
llm_artifact_gateway=fake_llm_artifact_gateway,
217224
)
218225
user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
219226
with pytest.raises(ObjectHasInvalidValueException):
@@ -953,3 +960,34 @@ async def test_delete_public_inference_model_raises_not_authorized(
953960
await use_case.execute(
954961
user=user, model_endpoint_name=llm_model_endpoint_sync[0].record.name
955962
)
963+
964+
965+
def test_exclude_safetensors_or_bin_majority_bin_returns_exclude_safetensors():
    """More .bin than .safetensors files -> exclude "*.safetensors"."""
    # The helper under test is synchronous; the original async def +
    # @pytest.mark.asyncio wrapper awaited nothing, so a plain sync test
    # is equivalent and cheaper.
    fake_model_files = ["fake.bin", "fake2.bin", "fake3.safetensors", "model.json", "optimizer.pt"]
    assert _exclude_safetensors_or_bin(fake_model_files) == "*.safetensors"
969+
970+
971+
def test_exclude_safetensors_or_bin_majority_safetensors_returns_exclude_bin():
    """More .safetensors than .bin files -> exclude "*.bin"."""
    # The helper under test is synchronous; the original async def +
    # @pytest.mark.asyncio wrapper awaited nothing, so a plain sync test
    # is equivalent and cheaper.
    fake_model_files = [
        "fake.bin",
        "fake2.safetensors",
        "fake3.safetensors",
        "model.json",
        "optimizer.pt",
    ]
    assert _exclude_safetensors_or_bin(fake_model_files) == "*.bin"
981+
982+
983+
def test_exclude_safetensors_or_bin_equal_bins_and_safetensors_returns_none():
    """Equal counts of .bin and .safetensors files -> no exclusion."""
    # The helper under test is synchronous; the original async def +
    # @pytest.mark.asyncio wrapper awaited nothing, so a plain sync test
    # is equivalent and cheaper.
    fake_model_files = [
        "fake.bin",
        "fake2.safetensors",
        "fake3.safetensors",
        "fake4.bin",
        "model.json",
        "optimizer.pt",
    ]
    assert _exclude_safetensors_or_bin(fake_model_files) is None

0 commit comments

Comments
 (0)