|
119 | 119 | "falcon-40b": "tiiuae/falcon-40b", |
120 | 120 | "falcon-40b-instruct": "tiiuae/falcon-40b-instruct", |
121 | 121 | }, |
| 122 | + LLMInferenceFramework.LIGHTLLM: { |
| 123 | + "llama-7b": "decapoda-research/llama-7b-hf", |
| 124 | + "llama-2-7b": "meta-llama/Llama-2-7b-hf", |
| 125 | + "llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", |
| 126 | + "llama-2-13b": "meta-llama/Llama-2-13b-hf", |
| 127 | + "llama-2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", |
| 128 | + "llama-2-70b": "meta-llama/Llama-2-70b-hf", |
| 129 | + "llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf", |
| 130 | + }, |
122 | 131 | } |
123 | 132 |
|
124 | 133 |
|
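To show how the table added above is consumed, here is a minimal, self-contained sketch of the lookup that `create_lightllm_bundle` performs further down when no `checkpoint_path` is supplied. The trimmed dict and the chosen model name are illustrative only; the real code indexes `_SUPPORTED_MODEL_NAMES[LLMInferenceFramework.LIGHTLLM]` directly.

```python
# Illustrative only: a trimmed copy of the LIGHTLLM table above and the lookup used
# when no checkpoint_path is given (the repo id later becomes the server's --model_dir).
lightllm_model_names = {
    "llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf",
    "llama-2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
}

final_weights_folder = lightllm_model_names["llama-2-7b-chat"]
print(final_weights_folder)  # meta-llama/Llama-2-7b-chat-hf
```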
@@ -221,6 +230,15 @@ async def create_model_bundle( |
221 | 230 | num_shards, |
222 | 231 | checkpoint_path, |
223 | 232 | ) |
| 233 | + elif framework == LLMInferenceFramework.LIGHTLLM: |
| 234 | + bundle_id = await self.create_lightllm_bundle( |
| 235 | + user, |
| 236 | + model_name, |
| 237 | + framework_image_tag, |
| 238 | + endpoint_name, |
| 239 | + num_shards, |
| 240 | + checkpoint_path, |
| 241 | + ) |
224 | 242 | else: |
225 | 243 | raise ObjectHasInvalidValueException( |
226 | 244 | f"Framework {framework} is not supported for source {source}." |
@@ -499,6 +517,86 @@ async def create_vllm_bundle( |
499 | 517 | ) |
500 | 518 | ).model_bundle_id |
501 | 519 |
|
| 520 | + async def create_lightllm_bundle( |
| 521 | + self, |
| 522 | + user: User, |
| 523 | + model_name: str, |
| 524 | + framework_image_tag: str, |
| 525 | + endpoint_unique_name: str, |
| 526 | + num_shards: int, |
| 527 | + checkpoint_path: Optional[str], |
| 528 | + ): |
| 529 | + command = [] |
| 530 | + |
| 531 | + # TODO: incorporate auto-calculation of max_total_token_num from https://github.com/ModelTC/lightllm/pull/81 |
| 532 | + max_total_token_num = 6000 # LightLLM default |
| 533 | + if num_shards == 1: |
| 534 | + max_total_token_num = 15000 # Default for Llama 2 7B on 1 x A10 |
| 535 | + elif num_shards == 2: |
| 536 | + max_total_token_num = 21000 # Default for Llama 2 13B on 2 x A10 |
| 537 | + elif num_shards == 4: |
| 538 | + max_total_token_num = 70000 # Default for Llama 2 13B on 4 x A10 |
| 539 | + max_req_input_len = 2047 |
| 540 | + max_req_total_len = 2048 |
| 541 | + if "llama-2" in model_name: |
| 542 | + max_req_input_len = 4095 |
| 543 | + max_req_total_len = 4096 |
| 544 | + |
| 545 | + subcommands = [] |
| 546 | + if checkpoint_path is not None: |
| 547 | + if checkpoint_path.startswith("s3://"): |
| 548 | + final_weights_folder = "model_files" |
| 549 | + subcommands += self.load_model_weights_sub_commands( |
| 550 | + LLMInferenceFramework.LIGHTLLM, |
| 551 | + framework_image_tag, |
| 552 | + checkpoint_path, |
| 553 | + final_weights_folder, |
| 554 | + ) |
| 555 | + else: |
| 556 | + raise ObjectHasInvalidValueException( |
| 557 | + f"Not able to load checkpoint path {checkpoint_path}." |
| 558 | + ) |
| 559 | + else: |
| 560 | + final_weights_folder = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.LIGHTLLM][model_name] |
| 561 | + |
| 562 | + subcommands.append( |
| 563 | + f"python -m lightllm.server.api_server --model_dir {final_weights_folder} --port 5005 --tp {num_shards} --max_total_token_num {max_total_token_num} --max_req_input_len {max_req_input_len} --max_req_total_len {max_req_total_len} --tokenizer_mode auto" |
| 564 | + ) |
| 565 | + |
| 566 | + command = [ |
| 567 | + "/bin/bash", |
| 568 | + "-c", |
| 569 | + ";".join(subcommands), |
| 570 | + ] |
| 571 | + |
| 572 | + return ( |
| 573 | + await self.create_model_bundle_use_case.execute( |
| 574 | + user, |
| 575 | + CreateModelBundleV2Request( |
| 576 | + name=endpoint_unique_name, |
| 577 | + schema_location="TBA", |
| 578 | + flavor=StreamingEnhancedRunnableImageFlavor( |
| 579 | + flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, |
| 580 | + repository=hmi_config.lightllm_repository, |
| 581 | + tag=framework_image_tag, |
| 582 | + command=command, |
| 583 | + streaming_command=command, |
| 584 | + protocol="http", |
| 585 | + readiness_initial_delay_seconds=10, |
| 586 | + healthcheck_route="/health", |
| 587 | + predict_route="/generate", |
| 588 | + streaming_predict_route="/generate_stream", |
| 589 | + env={}, |
| 590 | + ), |
| 591 | + metadata={}, |
| 592 | + ), |
| 593 | + do_auth_check=False, |
| 594 | + # Skip the auth check because the LLM create endpoint is called as the user themselves, |
| 595 | + # but the user isn't directly taking the action; it should come from the fine-tune |
| 596 | + # job. |
| 597 | + ) |
| 598 | + ).model_bundle_id |
| 599 | + |
502 | 600 | async def execute( |
503 | 601 | self, user: User, request: CreateLLMModelEndpointV1Request |
504 | 602 | ) -> CreateLLMModelEndpointV1Response: |
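For concreteness, the sketch below reproduces the launch command that `create_lightllm_bundle` assembles for one hypothetical configuration (2 shards, Llama 2 13B, no S3 checkpoint, so the Hugging Face repo id serves as the weights directory). All values are illustrative and simply restate the logic above.

```python
# Illustrative only: command assembly for a hypothetical 2-shard Llama 2 13B endpoint.
num_shards = 2
max_total_token_num = 21000                         # branch taken for num_shards == 2
max_req_input_len, max_req_total_len = 4095, 4096   # "llama-2" is in the model name
final_weights_folder = "meta-llama/Llama-2-13b-hf"  # no checkpoint_path, so the HF repo id is used

subcommands = [
    f"python -m lightllm.server.api_server --model_dir {final_weights_folder} "
    f"--port 5005 --tp {num_shards} --max_total_token_num {max_total_token_num} "
    f"--max_req_input_len {max_req_input_len} --max_req_total_len {max_req_total_len} "
    f"--tokenizer_mode auto"
]
command = ["/bin/bash", "-c", ";".join(subcommands)]
```

When `checkpoint_path` is an `s3://` URI, the download subcommands returned by `load_model_weights_sub_commands` would precede the server launch in the same semicolon-joined string.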
@@ -764,6 +862,19 @@ def model_output_to_completion_output( |
764 | 862 | num_completion_tokens=model_output["count_output_tokens"], |
765 | 863 | tokens=tokens, |
766 | 864 | ) |
| 865 | + elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: |
| 866 | + print(model_output) |
| 867 | + tokens = None |
| 868 | + if with_token_probs: |
| 869 | + tokens = [ |
| 870 | + TokenOutput(token=t["text"], log_prob=t["logprob"]) |
| 871 | + for t in model_output["tokens"] |
| 872 | + ] |
| 873 | + return CompletionOutput( |
| 874 | + text=model_output["generated_text"][0], |
| 875 | + num_completion_tokens=model_output["count_output_tokens"], |
| 876 | + tokens=tokens, |
| 877 | + ) |
767 | 878 | else: |
768 | 879 | raise EndpointUnsupportedInferenceTypeException( |
769 | 880 | f"Unsupported inference framework {model_content.inference_framework}" |
@@ -925,6 +1036,44 @@ async def execute( |
925 | 1036 | topic=model_endpoint.record.destination, predict_request=inference_request |
926 | 1037 | ) |
927 | 1038 |
|
| 1039 | + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: |
| 1040 | + return CompletionSyncV1Response( |
| 1041 | + request_id=request_id, |
| 1042 | + output=None, |
| 1043 | + ) |
| 1044 | + |
| 1045 | + output = json.loads(predict_result.result["result"]) |
| 1046 | + return CompletionSyncV1Response( |
| 1047 | + request_id=request_id, |
| 1048 | + output=self.model_output_to_completion_output( |
| 1049 | + output, model_endpoint, request.return_token_log_probs |
| 1050 | + ), |
| 1051 | + ) |
| 1052 | + elif endpoint_content.inference_framework == LLMInferenceFramework.LIGHTLLM: |
| 1053 | + lightllm_args: Any = { |
| 1054 | + "inputs": request.prompt, |
| 1055 | + "parameters": { |
| 1056 | + "max_new_tokens": request.max_new_tokens, |
| 1057 | + }, |
| 1058 | + } |
| 1059 | + # TODO: implement stop sequences |
| 1060 | + if request.temperature > 0: |
| 1061 | + lightllm_args["parameters"]["temperature"] = request.temperature |
| 1062 | + lightllm_args["parameters"]["do_sample"] = True |
| 1063 | + else: |
| 1064 | + lightllm_args["parameters"]["do_sample"] = False |
| 1065 | + if request.return_token_log_probs: |
| 1066 | + lightllm_args["parameters"]["return_details"] = True |
| 1067 | + |
| 1068 | + inference_request = SyncEndpointPredictV1Request( |
| 1069 | + args=lightllm_args, |
| 1070 | + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, |
| 1071 | + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, |
| 1072 | + ) |
| 1073 | + predict_result = await inference_gateway.predict( |
| 1074 | + topic=model_endpoint.record.destination, predict_request=inference_request |
| 1075 | + ) |
| 1076 | + |
928 | 1077 | if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: |
929 | 1078 | return CompletionSyncV1Response( |
930 | 1079 | request_id=request_id, |
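As a concrete illustration of what this branch sends downstream, here is a hypothetical `lightllm_args` payload for a sampling request (temperature 0.7, 64 new tokens, token log probs requested). The values are invented and merely restate the dict assembly above.

```python
# Hypothetical sync request body built by the LIGHTLLM branch above.
lightllm_args = {
    "inputs": "Tell me about llamas.",
    "parameters": {
        "max_new_tokens": 64,
        "temperature": 0.7,      # set only because temperature > 0
        "do_sample": True,       # sampling enabled alongside temperature
        "return_details": True,  # set because return_token_log_probs was requested
    },
}
# With temperature == 0, "temperature" is omitted and "do_sample" is False.
```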
@@ -1055,6 +1204,25 @@ async def execute( |
1055 | 1204 | if request.return_token_log_probs: |
1056 | 1205 | args["logprobs"] = 1 |
1057 | 1206 | args["stream"] = True |
| 1207 | + elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: |
| 1208 | + args = { |
| 1209 | + "inputs": request.prompt, |
| 1210 | + "parameters": { |
| 1211 | + "max_new_tokens": request.max_new_tokens, |
| 1212 | + }, |
| 1213 | + } |
| 1214 | + # TODO: stop sequences |
| 1215 | + if request.temperature > 0: |
| 1216 | + args["parameters"]["temperature"] = request.temperature |
| 1217 | + args["parameters"]["do_sample"] = True |
| 1218 | + else: |
| 1219 | + args["parameters"]["do_sample"] = False |
| 1220 | + if request.return_token_log_probs: |
| 1221 | + args["parameters"]["return_details"] = True |
| 1222 | + else: |
| 1223 | + raise EndpointUnsupportedInferenceTypeException( |
| 1224 | + f"Unsupported inference framework {model_content.inference_framework}" |
| 1225 | + ) |
1058 | 1226 |
|
1059 | 1227 | inference_request = SyncEndpointPredictV1Request( |
1060 | 1228 | args=args, |
@@ -1163,6 +1331,30 @@ async def execute( |
1163 | 1331 | request_id=request_id, |
1164 | 1332 | output=None, |
1165 | 1333 | ) |
| 1334 | + elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: |
| 1335 | + if res.status == TaskStatus.SUCCESS and result is not None: |
| 1336 | + print(result) |
| 1337 | + token = None |
| 1338 | + num_completion_tokens += 1 |
| 1339 | + if request.return_token_log_probs: |
| 1340 | + token = TokenOutput( |
| 1341 | + token=result["result"]["token"]["text"], |
| 1342 | + log_prob=result["result"]["token"]["logprob"], |
| 1343 | + ) |
| 1344 | + yield CompletionStreamV1Response( |
| 1345 | + request_id=request_id, |
| 1346 | + output=CompletionStreamOutput( |
| 1347 | + text=result["result"]["token"]["text"], |
| 1348 | + finished=result["result"]["finished"], |
| 1349 | + num_completion_tokens=num_completion_tokens, |
| 1350 | + token=token, |
| 1351 | + ), |
| 1352 | + ) |
| 1353 | + else: |
| 1354 | + yield CompletionStreamV1Response( |
| 1355 | + request_id=request_id, |
| 1356 | + output=None, |
| 1357 | + ) |
1166 | 1358 | else: |
1167 | 1359 | raise EndpointUnsupportedInferenceTypeException( |
1168 | 1360 | f"Unsupported inference framework {model_content.inference_framework}" |
|