
Commit 6b35428

Integrate LightLLM (#273)
* Integrate LightLLM
* wip
1 parent 46ce8e5 commit 6b35428

5 files changed: 196 additions, 0 deletions

charts/model-engine/values_circleci.yaml

Lines changed: 1 addition & 0 deletions
@@ -141,6 +141,7 @@ config:
   istio_enabled: true
   tgi_repository: "text-generation-inference"
   vllm_repository: "vllm"
+  lightllm_repository: "lightllm"
   hf_user_fine_tuned_weights_prefix: "s3://$CIRCLECI_AWS_S3_BUCKET"

   # Service Account

model-engine/model_engine_server/common/config.py

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ class HostedModelInferenceServiceConfig:
     datadog_trace_enabled: bool
     tgi_repository: str
     vllm_repository: str
+    lightllm_repository: str

     @classmethod
     def from_yaml(cls, yaml_path):

model-engine/model_engine_server/domain/entities/llm_entity.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ class LLMInferenceFramework(str, Enum):
     DEEPSPEED = "deepspeed"
     TEXT_GENERATION_INFERENCE = "text_generation_inference"
     VLLM = "vllm"
+    LIGHTLLM = "lightllm"


 class Quantization(str, Enum):
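Since LLMInferenceFramework mixes in str, the new member can be constructed directly from the raw "lightllm" string used in configs and API requests. A minimal, self-contained sketch of that behavior (a stand-in copy of the enum, not an import from the repo):

from enum import Enum

# Stand-in copy of the enum above, only to illustrate the str-mixin behavior.
class LLMInferenceFramework(str, Enum):
    DEEPSPEED = "deepspeed"
    TEXT_GENERATION_INFERENCE = "text_generation_inference"
    VLLM = "vllm"
    LIGHTLLM = "lightllm"

framework = LLMInferenceFramework("lightllm")   # parse the raw config/request string
assert framework is LLMInferenceFramework.LIGHTLLM
assert framework == "lightllm"                  # str mixin: equal to the plain string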

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 192 additions & 0 deletions
@@ -119,6 +119,15 @@
         "falcon-40b": "tiiuae/falcon-40b",
         "falcon-40b-instruct": "tiiuae/falcon-40b-instruct",
     },
+    LLMInferenceFramework.LIGHTLLM: {
+        "llama-7b": "decapoda-research/llama-7b-hf",
+        "llama-2-7b": "meta-llama/Llama-2-7b-hf",
+        "llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf",
+        "llama-2-13b": "meta-llama/Llama-2-13b-hf",
+        "llama-2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
+        "llama-2-70b": "meta-llama/Llama-2-70b-hf",
+        "llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf",
+    },
 }


@@ -221,6 +230,15 @@ async def create_model_bundle(
                 num_shards,
                 checkpoint_path,
             )
+        elif framework == LLMInferenceFramework.LIGHTLLM:
+            bundle_id = await self.create_lightllm_bundle(
+                user,
+                model_name,
+                framework_image_tag,
+                endpoint_name,
+                num_shards,
+                checkpoint_path,
+            )
         else:
             raise ObjectHasInvalidValueException(
                 f"Framework {framework} is not supported for source {source}."
@@ -499,6 +517,86 @@ async def create_vllm_bundle(
             )
         ).model_bundle_id

+    async def create_lightllm_bundle(
+        self,
+        user: User,
+        model_name: str,
+        framework_image_tag: str,
+        endpoint_unique_name: str,
+        num_shards: int,
+        checkpoint_path: Optional[str],
+    ):
+        command = []
+
+        # TODO: incorporate auto calculate max_total_token_num from https://github.com/ModelTC/lightllm/pull/81
+        max_total_token_num = 6000  # LightLLM default
+        if num_shards == 1:
+            max_total_token_num = 15000  # Default for Llama 2 7B on 1 x A10
+        elif num_shards == 2:
+            max_total_token_num = 21000  # Default for Llama 2 13B on 2 x A10
+        elif num_shards == 4:
+            max_total_token_num = 70000  # Default for Llama 2 13B on 4 x A10
+        max_req_input_len = 2047
+        max_req_total_len = 2048
+        if "llama-2" in model_name:
+            max_req_input_len = 4095
+            max_req_total_len = 4096
+
+        subcommands = []
+        if checkpoint_path is not None:
+            if checkpoint_path.startswith("s3://"):
+                final_weights_folder = "model_files"
+                subcommands += self.load_model_weights_sub_commands(
+                    LLMInferenceFramework.LIGHTLLM,
+                    framework_image_tag,
+                    checkpoint_path,
+                    final_weights_folder,
+                )
+            else:
+                raise ObjectHasInvalidValueException(
+                    f"Not able to load checkpoint path {checkpoint_path}."
+                )
+        else:
+            final_weights_folder = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.VLLM][model_name]
+
+        subcommands.append(
+            f"python -m lightllm.server.api_server --model_dir {final_weights_folder} --port 5005 --tp {num_shards} --max_total_token_num {max_total_token_num} --max_req_input_len {max_req_input_len} --max_req_total_len {max_req_total_len} --tokenizer_mode auto"
+        )
+
+        command = [
+            "/bin/bash",
+            "-c",
+            ";".join(subcommands),
+        ]
+
+        return (
+            await self.create_model_bundle_use_case.execute(
+                user,
+                CreateModelBundleV2Request(
+                    name=endpoint_unique_name,
+                    schema_location="TBA",
+                    flavor=StreamingEnhancedRunnableImageFlavor(
+                        flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE,
+                        repository=hmi_config.lightllm_repository,
+                        tag=framework_image_tag,
+                        command=command,
+                        streaming_command=command,
+                        protocol="http",
+                        readiness_initial_delay_seconds=10,
+                        healthcheck_route="/health",
+                        predict_route="/generate",
+                        streaming_predict_route="/generate_stream",
+                        env={},
+                    ),
+                    metadata={},
+                ),
+                do_auth_check=False,
+                # Skip auth check because llm create endpoint is called as the user itself,
+                # but the user isn't directly making the action. It should come from the fine tune
+                # job.
+            )
+        ).model_bundle_id
+
     async def execute(
         self, user: User, request: CreateLLMModelEndpointV1Request
     ) -> CreateLLMModelEndpointV1Response:
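For illustration, here is a minimal standalone sketch of the launch command create_lightllm_bundle builds from the heuristics above, for a hypothetical llama-2-13b-chat endpoint with num_shards=2 and no s3:// checkpoint (the model name, shard count, and weights folder below are assumptions for the example, with the repo id taken from the supported-model mapping):

# Hypothetical inputs for the sketch (not taken from the diff's callers).
model_name = "llama-2-13b-chat"
num_shards = 2
final_weights_folder = "meta-llama/Llama-2-13b-chat-hf"  # repo id from the mapping above

# Same defaults as create_lightllm_bundle above.
max_total_token_num = {1: 15000, 2: 21000, 4: 70000}.get(num_shards, 6000)
max_req_input_len, max_req_total_len = (4095, 4096) if "llama-2" in model_name else (2047, 2048)

subcommands = [
    f"python -m lightllm.server.api_server --model_dir {final_weights_folder} "
    f"--port 5005 --tp {num_shards} --max_total_token_num {max_total_token_num} "
    f"--max_req_input_len {max_req_input_len} --max_req_total_len {max_req_total_len} "
    f"--tokenizer_mode auto"
]
command = ["/bin/bash", "-c", ";".join(subcommands)]
print(command)

With these inputs the joined command starts the LightLLM API server on port 5005 with tensor parallelism 2 and a max_total_token_num budget of 21000.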
@@ -764,6 +862,19 @@ def model_output_to_completion_output(
                 num_completion_tokens=model_output["count_output_tokens"],
                 tokens=tokens,
             )
+        elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
+            print(model_output)
+            tokens = None
+            if with_token_probs:
+                tokens = [
+                    TokenOutput(token=t["text"], log_prob=t["logprob"])
+                    for t in model_output["tokens"]
+                ]
+            return CompletionOutput(
+                text=model_output["generated_text"][0],
+                num_completion_tokens=model_output["count_output_tokens"],
+                tokens=tokens,
+            )
         else:
             raise EndpointUnsupportedInferenceTypeException(
                 f"Unsupported inference framework {model_content.inference_framework}"
@@ -925,6 +1036,44 @@ async def execute(
                 topic=model_endpoint.record.destination, predict_request=inference_request
             )

+            if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
+                return CompletionSyncV1Response(
+                    request_id=request_id,
+                    output=None,
+                )
+
+            output = json.loads(predict_result.result["result"])
+            return CompletionSyncV1Response(
+                request_id=request_id,
+                output=self.model_output_to_completion_output(
+                    output, model_endpoint, request.return_token_log_probs
+                ),
+            )
+        elif endpoint_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
+            lightllm_args: Any = {
+                "inputs": request.prompt,
+                "parameters": {
+                    "max_new_tokens": request.max_new_tokens,
+                },
+            }
+            # TODO: implement stop sequences
+            if request.temperature > 0:
+                lightllm_args["parameters"]["temperature"] = request.temperature
+                lightllm_args["parameters"]["do_sample"] = True
+            else:
+                lightllm_args["parameters"]["do_sample"] = False
+            if request.return_token_log_probs:
+                lightllm_args["parameters"]["return_details"] = True
+
+            inference_request = SyncEndpointPredictV1Request(
+                args=lightllm_args,
+                num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES,
+                timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS,
+            )
+            predict_result = await inference_gateway.predict(
+                topic=model_endpoint.record.destination, predict_request=inference_request
+            )
+
         if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
             return CompletionSyncV1Response(
                 request_id=request_id,
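Both the sync branch above and the streaming branch further down build the same LightLLM request body. A minimal sketch of the resulting args for a hypothetical request with temperature 0.7, max_new_tokens 100, and token log probs enabled (these request values are assumptions for the example):

import json

# Hypothetical request values for the sketch.
prompt = "What is the capital of France?"
max_new_tokens = 100
temperature = 0.7
return_token_log_probs = True

lightllm_args = {
    "inputs": prompt,
    "parameters": {"max_new_tokens": max_new_tokens},
}
if temperature > 0:
    lightllm_args["parameters"]["temperature"] = temperature
    lightllm_args["parameters"]["do_sample"] = True
else:
    lightllm_args["parameters"]["do_sample"] = False  # greedy decoding when temperature is 0
if return_token_log_probs:
    lightllm_args["parameters"]["return_details"] = True  # ask LightLLM for per-token details

print(json.dumps(lightllm_args, indent=2))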
@@ -1055,6 +1204,25 @@ async def execute(
             if request.return_token_log_probs:
                 args["logprobs"] = 1
             args["stream"] = True
+        elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
+            args = {
+                "inputs": request.prompt,
+                "parameters": {
+                    "max_new_tokens": request.max_new_tokens,
+                },
+            }
+            # TODO: stop sequences
+            if request.temperature > 0:
+                args["parameters"]["temperature"] = request.temperature
+                args["parameters"]["do_sample"] = True
+            else:
+                args["parameters"]["do_sample"] = False
+            if request.return_token_log_probs:
+                args["parameters"]["return_details"] = True
+        else:
+            raise EndpointUnsupportedInferenceTypeException(
+                f"Unsupported inference framework {model_content.inference_framework}"
+            )

         inference_request = SyncEndpointPredictV1Request(
             args=args,
@@ -1163,6 +1331,30 @@ async def execute(
                         request_id=request_id,
                         output=None,
                     )
+            elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
+                if res.status == TaskStatus.SUCCESS and result is not None:
+                    print(result)
+                    token = None
+                    num_completion_tokens += 1
+                    if request.return_token_log_probs:
+                        token = TokenOutput(
+                            token=result["result"]["token"]["text"],
+                            log_prob=result["result"]["token"]["logprob"],
+                        )
+                    yield CompletionStreamV1Response(
+                        request_id=request_id,
+                        output=CompletionStreamOutput(
+                            text=result["result"]["token"]["text"],
+                            finished=result["result"]["finished"],
+                            num_completion_tokens=num_completion_tokens,
+                            token=token,
+                        ),
+                    )
+                else:
+                    yield CompletionStreamV1Response(
+                        request_id=request_id,
+                        output=None,
+                    )
             else:
                 raise EndpointUnsupportedInferenceTypeException(
                     f"Unsupported inference framework {model_content.inference_framework}"

model-engine/service_configs/service_config_circleci.yaml

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ datadog_trace_enabled: false
 istio_enabled: true
 tgi_repository: "text-generation-inference"
 vllm_repository: "vllm"
+lightllm_repository: "lightllm"

 # S3 access
 hf_user_fine_tuned_weights_prefix: "s3://test-bucket"
