 
 """
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional, TypeAlias, Union
 
 from model_engine_server.common.dtos.core import HttpUrlStr
+from model_engine_server.common.dtos.llms.sglang import SGLangEndpointAdditionalArgs
 from model_engine_server.common.dtos.llms.vllm import VLLMEndpointAdditionalArgs
 from model_engine_server.common.dtos.model_endpoints import (
     CpuSpecificationType,
     ModelEndpointStatus,
     Quantization,
 )
+from pydantic import Discriminator, Tag
+from typing_extensions import Annotated
 
 
-class CreateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel):
-    name: str
-
-    # LLM specific fields
-    model_name: str
-    source: LLMSource = LLMSource.HUGGING_FACE
-    inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM
-    inference_framework_image_tag: str = "latest"
-    num_shards: int = 1
-    """
-    Number of shards to distribute the model onto GPUs.
-    """
-
+class LLMModelEndpointCommonArgs(BaseModel):
     quantize: Optional[Quantization] = None
     """
     Whether to quantize the model.
@@ -51,20 +42,14 @@ class CreateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel):
5142 """
5243
5344 # General endpoint fields
54- metadata : Dict [str , Any ] # TODO: JSON type
5545 post_inference_hooks : Optional [List [str ]] = None
56- endpoint_type : ModelEndpointType = ModelEndpointType .SYNC
5746 cpus : Optional [CpuSpecificationType ] = None
5847 gpus : Optional [int ] = None
5948 memory : Optional [StorageSpecificationType ] = None
6049 gpu_type : Optional [GpuType ] = None
6150 storage : Optional [StorageSpecificationType ] = None
6251 nodes_per_worker : Optional [int ] = None
6352 optimize_costs : Optional [bool ] = None
64- min_workers : int
65- max_workers : int
66- per_worker : int
67- labels : Dict [str , str ]
6853 prewarm : Optional [bool ] = None
6954 high_priority : Optional [bool ] = None
7055 billing_tags : Optional [Dict [str , Any ]] = None
@@ -77,6 +62,83 @@ class CreateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel):
     )
 
 
+class CreateLLMModelEndpointArgs(LLMModelEndpointCommonArgs):
+    name: str
+    model_name: str
+    metadata: Dict[str, Any]  # TODO: JSON type
+    min_workers: int
+    max_workers: int
+    per_worker: int
+    labels: Dict[str, str]
+    source: LLMSource = LLMSource.HUGGING_FACE
+    inference_framework_image_tag: str = "latest"
+    num_shards: int = 1
+    """
+    Number of shards to distribute the model onto GPUs.
+    """
+    endpoint_type: ModelEndpointType = ModelEndpointType.SYNC
+
+
+class CreateVLLMModelEndpointRequest(
+    VLLMEndpointAdditionalArgs, CreateLLMModelEndpointArgs, BaseModel
+):
+    inference_framework: Literal[LLMInferenceFramework.VLLM] = LLMInferenceFramework.VLLM
+    pass
+
+
+class CreateSGLangModelEndpointRequest(
+    SGLangEndpointAdditionalArgs, CreateLLMModelEndpointArgs, BaseModel
+):
+    inference_framework: Literal[LLMInferenceFramework.SGLANG] = LLMInferenceFramework.SGLANG
+    pass
+
+
+class CreateDeepSpeedModelEndpointRequest(CreateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.DEEPSPEED] = LLMInferenceFramework.DEEPSPEED
+    pass
+
+
+class CreateTextGenerationInferenceModelEndpointRequest(CreateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.TEXT_GENERATION_INFERENCE] = (
+        LLMInferenceFramework.TEXT_GENERATION_INFERENCE
+    )
+    pass
+
+
+class CreateLightLLMModelEndpointRequest(CreateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.LIGHTLLM] = LLMInferenceFramework.LIGHTLLM
+    pass
+
+
+class CreateTensorRTLLMModelEndpointRequest(CreateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.TENSORRT_LLM] = (
+        LLMInferenceFramework.TENSORRT_LLM
+    )
+    pass
+
+
+def get_inference_framework(v: Any) -> str:
+    if isinstance(v, dict):
+        return v.get("inference_framework", LLMInferenceFramework.VLLM)
+    return getattr(v, "inference_framework", LLMInferenceFramework.VLLM)
+
+
+CreateLLMModelEndpointV1Request: TypeAlias = Annotated[
+    Union[
+        Annotated[CreateVLLMModelEndpointRequest, Tag(LLMInferenceFramework.VLLM)],
+        Annotated[CreateSGLangModelEndpointRequest, Tag(LLMInferenceFramework.SGLANG)],
+        Annotated[CreateDeepSpeedModelEndpointRequest, Tag(LLMInferenceFramework.DEEPSPEED)],
+        Annotated[
+            CreateTextGenerationInferenceModelEndpointRequest,
+            Tag(LLMInferenceFramework.TEXT_GENERATION_INFERENCE),
+        ],
+        Annotated[CreateLightLLMModelEndpointRequest, Tag(LLMInferenceFramework.LIGHTLLM)],
+        Annotated[CreateTensorRTLLMModelEndpointRequest, Tag(LLMInferenceFramework.TENSORRT_LLM)],
+    ],
+    Discriminator(get_inference_framework),
+]
+
+
 class CreateLLMModelEndpointV1Response(BaseModel):
     endpoint_creation_task_id: str
 
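Aside, not part of the diff: the `Annotated[Union[...], Discriminator(get_inference_framework)]` alias above makes pydantic v2 dispatch an incoming payload to the framework-specific request class named by its `inference_framework` field, falling back to vLLM when the field is absent. A minimal sketch of that behavior, assuming the names above are in scope, that `LLMInferenceFramework` is a string enum with lowercase values (`"vllm"`, `"sglang"`, ...), and that the additional-args mixins have no required fields; the payload values are illustrative:

```python
from pydantic import TypeAdapter

adapter = TypeAdapter(CreateLLMModelEndpointV1Request)

payload = {
    "name": "demo-endpoint",  # illustrative values
    "model_name": "my-model",
    "metadata": {},
    "min_workers": 1,
    "max_workers": 2,
    "per_worker": 1,
    "labels": {"team": "demo"},
    "inference_framework": "sglang",
}

# The callable discriminator reads "inference_framework" and picks the SGLang class.
request = adapter.validate_python(payload)
assert isinstance(request, CreateSGLangModelEndpointRequest)

# Without the field, get_inference_framework falls back to vLLM.
payload.pop("inference_framework")
assert isinstance(adapter.validate_python(payload), CreateVLLMModelEndpointRequest)
```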
@@ -107,57 +169,75 @@ class ListLLMModelEndpointsV1Response(BaseModel):
     model_endpoints: List[GetLLMModelEndpointV1Response]
 
 
-class UpdateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel):
-    # LLM specific fields
+class UpdateLLMModelEndpointArgs(LLMModelEndpointCommonArgs):
     model_name: Optional[str] = None
     source: Optional[LLMSource] = None
+    inference_framework: Optional[LLMInferenceFramework] = None
     inference_framework_image_tag: Optional[str] = None
     num_shards: Optional[int] = None
     """
     Number of shards to distribute the model onto GPUs.
     """
-
-    quantize: Optional[Quantization] = None
-    """
-    Whether to quantize the model.
+    metadata: Optional[Dict[str, Any]] = None
+    force_bundle_recreation: Optional[bool] = False
     """
+    Whether to force recreate the underlying bundle.
 
-    checkpoint_path: Optional[str] = None
-    """
-    Path to the checkpoint to load the model from.
+    If True, the underlying bundle will be recreated. This is useful if there are underlying implementation changes with how bundles are created
+    that we would like to pick up for existing endpoints
     """
-
-    # General endpoint fields
-    metadata: Optional[Dict[str, Any]] = None
-    post_inference_hooks: Optional[List[str]] = None
-    cpus: Optional[CpuSpecificationType] = None
-    gpus: Optional[int] = None
-    memory: Optional[StorageSpecificationType] = None
-    gpu_type: Optional[GpuType] = None
-    storage: Optional[StorageSpecificationType] = None
-    optimize_costs: Optional[bool] = None
     min_workers: Optional[int] = None
     max_workers: Optional[int] = None
     per_worker: Optional[int] = None
     labels: Optional[Dict[str, str]] = None
-    prewarm: Optional[bool] = None
-    high_priority: Optional[bool] = None
-    billing_tags: Optional[Dict[str, Any]] = None
-    default_callback_url: Optional[HttpUrlStr] = None
-    default_callback_auth: Optional[CallbackAuth] = None
-    public_inference: Optional[bool] = None
-    chat_template_override: Optional[str] = Field(
-        default=None,
-        description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint",
+
+
+class UpdateVLLMModelEndpointRequest(
+    VLLMEndpointAdditionalArgs, UpdateLLMModelEndpointArgs, BaseModel
+):
+    inference_framework: Literal[LLMInferenceFramework.VLLM] = LLMInferenceFramework.VLLM
+
+
+class UpdateSGLangModelEndpointRequest(
+    SGLangEndpointAdditionalArgs, UpdateLLMModelEndpointArgs, BaseModel
+):
+    inference_framework: Literal[LLMInferenceFramework.SGLANG] = LLMInferenceFramework.SGLANG
+
+
+class UpdateDeepSpeedModelEndpointRequest(UpdateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.DEEPSPEED] = LLMInferenceFramework.DEEPSPEED
+
+
+class UpdateTextGenerationInferenceModelEndpointRequest(UpdateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.TEXT_GENERATION_INFERENCE] = (
+        LLMInferenceFramework.TEXT_GENERATION_INFERENCE
     )
 
-    force_bundle_recreation: Optional[bool] = False
-    """
-    Whether to force recreate the underlying bundle.
 
-    If True, the underlying bundle will be recreated. This is useful if there are underlying implementation changes with how bundles are created
-    that we would like to pick up for existing endpoints
-    """
+class UpdateLightLLMModelEndpointRequest(UpdateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.LIGHTLLM] = LLMInferenceFramework.LIGHTLLM
+
+
+class UpdateTensorRTLLMModelEndpointRequest(UpdateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.TENSORRT_LLM] = (
+        LLMInferenceFramework.TENSORRT_LLM
+    )
+
+
+UpdateLLMModelEndpointV1Request: TypeAlias = Annotated[
+    Union[
+        Annotated[UpdateVLLMModelEndpointRequest, Tag(LLMInferenceFramework.VLLM)],
+        Annotated[UpdateSGLangModelEndpointRequest, Tag(LLMInferenceFramework.SGLANG)],
+        Annotated[UpdateDeepSpeedModelEndpointRequest, Tag(LLMInferenceFramework.DEEPSPEED)],
+        Annotated[
+            UpdateTextGenerationInferenceModelEndpointRequest,
+            Tag(LLMInferenceFramework.TEXT_GENERATION_INFERENCE),
+        ],
+        Annotated[UpdateLightLLMModelEndpointRequest, Tag(LLMInferenceFramework.LIGHTLLM)],
+        Annotated[UpdateTensorRTLLMModelEndpointRequest, Tag(LLMInferenceFramework.TENSORRT_LLM)],
+    ],
+    Discriminator(get_inference_framework),
+]
 
 
 class UpdateLLMModelEndpointV1Response(BaseModel):
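The update path works the same way, with the same caveats as the sketch above (pydantic v2 `TypeAdapter`, lowercase enum string values). Since every field on `UpdateLLMModelEndpointArgs` is optional, a partial payload naming only the changed fields validates and routes to the framework-specific class:

```python
from pydantic import TypeAdapter

update_adapter = TypeAdapter(UpdateLLMModelEndpointV1Request)

# Partial update: only the fields being changed are supplied.
update = update_adapter.validate_python(
    {
        "inference_framework": "vllm",
        "max_workers": 4,
        "force_bundle_recreation": True,
    }
)
assert isinstance(update, UpdateVLLMModelEndpointRequest)
assert update.min_workers is None  # untouched fields remain None
```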