@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, List, Tuple
+from typing import Any, List, Tuple
 from unittest import mock
 
 import pytest
@@ -54,21 +54,19 @@
     validate_and_update_completion_params,
 )
 from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase
-from model_engine_server.infra.repositories import live_tokenizer_repository
-from model_engine_server.infra.repositories.live_tokenizer_repository import ModelInfo
 
 
-def good_models_info() -> Dict[str, ModelInfo]:
-    return {
-        k: ModelInfo(v.hf_repo, "s3://test-s3.tar")
-        for k, v in live_tokenizer_repository.SUPPORTED_MODELS_INFO.items()
-    }
+def mocked__get_latest_tag():
+    async def async_mock(*args, **kwargs):  # noqa
+        return "fake_docker_repository_latest_image_tag"
+
+    return mock.AsyncMock(side_effect=async_mock)
 
 
 @pytest.mark.asyncio
 @mock.patch(
-    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO",
-    good_models_info(),
+    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_tag",
+    mocked__get_latest_tag(),
 )
 async def test_create_model_endpoint_use_case_success(
     test_api_key: str,
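Note: instead of swapping out the SUPPORTED_MODELS_INFO table, the tests now stub the async `_get_latest_tag` helper directly. A minimal, self-contained sketch of the same patching pattern (the module name `mymod` is hypothetical, not part of this repo):

    # Hypothetical module and check showing mock.patch with an AsyncMock.
    import asyncio
    from unittest import mock

    import mymod  # hypothetical module exposing an async _get_latest_tag()

    async def fake_tag(*args, **kwargs):
        return "fake_docker_repository_latest_image_tag"

    with mock.patch("mymod._get_latest_tag", mock.AsyncMock(side_effect=fake_tag)):
        # The patched attribute is still awaitable, so async callers keep working.
        assert asyncio.run(mymod._get_latest_tag()) == "fake_docker_repository_latest_image_tag"

Because `mock.patch` here receives a ready-made `AsyncMock` as its `new` argument, the decorator does not inject an extra mock parameter into the decorated test function, which is why the test signatures below are unchanged.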
@@ -183,40 +181,33 @@ async def test_create_model_endpoint_use_case_success(
     assert "--max-total-tokens" in bundle.flavor.command[-1] and "4096" in bundle.flavor.command[-1]
 
 
-def bad_models_info() -> Dict[str, ModelInfo]:
-    info = {
-        k: ModelInfo(v.hf_repo, v.s3_repo)
-        for k, v in live_tokenizer_repository.SUPPORTED_MODELS_INFO.items()
-    }
-    info.update(
-        {
-            "mpt-7b": ModelInfo("mosaicml/mpt-7b", None),
-            "mpt-7b-instruct": ModelInfo("mosaicml/mpt-7b-instruct", "gibberish"),
-        }
-    )
-    return info
-
-
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "inference_framework, model_name, expected_error",
+    "inference_framework, model_name, checkpoint_path, expected_error",
     [
-        (LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "mpt-7b", InvalidRequestException),
+        (LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "mpt-7b", None, InvalidRequestException),
         (
             LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
             "mpt-7b-instruct",
+            "gibberish",
+            ObjectHasInvalidValueException,
+        ),
+        (LLMInferenceFramework.LIGHTLLM, "mpt-7b", None, InvalidRequestException),
+        (
+            LLMInferenceFramework.LIGHTLLM,
+            "mpt-7b-instruct",
+            "gibberish",
+            ObjectHasInvalidValueException,
+        ),
+        (LLMInferenceFramework.VLLM, "mpt-7b", None, InvalidRequestException),
+        (
+            LLMInferenceFramework.VLLM,
+            "mpt-7b-instruct",
+            "gibberish",
             ObjectHasInvalidValueException,
         ),
-        (LLMInferenceFramework.LIGHTLLM, "mpt-7b", InvalidRequestException),
-        (LLMInferenceFramework.LIGHTLLM, "mpt-7b-instruct", ObjectHasInvalidValueException),
-        (LLMInferenceFramework.VLLM, "mpt-7b", InvalidRequestException),
-        (LLMInferenceFramework.VLLM, "mpt-7b-instruct", ObjectHasInvalidValueException),
     ],
 )
-@mock.patch(
-    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO",
-    bad_models_info(),
-)
 async def test_create_model_bundle_fails_if_no_checkpoint(
     test_api_key: str,
     fake_model_bundle_repository,
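The bad-checkpoint cases now travel through the parametrization as an explicit `checkpoint_path` column rather than through a patched model-info table. A minimal sketch of the pattern, with a hypothetical stand-in validator (`validate_checkpoint`, `ValueError`, and `TypeError` stand in for the real use case and its InvalidRequestException / ObjectHasInvalidValueException):

    import pytest

    def validate_checkpoint(checkpoint_path):
        # Hypothetical stand-in for the use case's checkpoint validation.
        if checkpoint_path is None:
            raise ValueError("no checkpoint_path provided")
        if not checkpoint_path.startswith("s3://"):
            raise TypeError(f"unsupported checkpoint path: {checkpoint_path}")

    @pytest.mark.parametrize(
        "checkpoint_path, expected_error",
        [
            (None, ValueError),        # missing checkpoint
            ("gibberish", TypeError),  # malformed checkpoint path
        ],
    )
    def test_rejects_bad_checkpoints(checkpoint_path, expected_error):
        with pytest.raises(expected_error):
            validate_checkpoint(checkpoint_path)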
@@ -227,6 +218,7 @@ async def test_create_model_bundle_fails_if_no_checkpoint(
     create_llm_model_endpoint_text_generation_inference_request_streaming: CreateLLMModelEndpointV1Request,
     inference_framework,
     model_name,
+    checkpoint_path,
     expected_error,
 ):
     fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository
@@ -255,7 +247,7 @@ async def test_create_model_bundle_fails_if_no_checkpoint(
         endpoint_type=request.endpoint_type,
         num_shards=request.num_shards,
         quantize=request.quantize,
-        checkpoint_path=None,
+        checkpoint_path=checkpoint_path,
     )
 
 
@@ -269,10 +261,6 @@ async def test_create_model_bundle_fails_if_no_checkpoint(
         (True, LLMInferenceFramework.VLLM, "0.1.3.6"),
     ],
 )
-@mock.patch(
-    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO",
-    good_models_info(),
-)
 async def test_create_model_bundle_inference_framework_image_tag_validation(
     test_api_key: str,
     fake_model_bundle_repository,
@@ -307,6 +295,7 @@ async def test_create_model_bundle_inference_framework_image_tag_validation(
     request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy()
     request.inference_framework = inference_framework
     request.inference_framework_image_tag = inference_framework_image_tag
+    request.checkpoint_path = "s3://test-s3.tar"
     user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
     if valid:
         await use_case.execute(user=user, request=request)
@@ -592,6 +581,10 @@ async def test_get_llm_model_endpoint_use_case_raises_not_authorized(
 
 
 @pytest.mark.asyncio
+@mock.patch(
+    "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_tag",
+    mocked__get_latest_tag(),
+)
 async def test_update_model_endpoint_use_case_success(
     test_api_key: str,
     fake_model_bundle_repository,