Skip to content

Commit 43846ca

Browse files
committed
实现 推理引擎的 shutdown方法
1 parent 836c66d commit 43846ca

File tree

6 files changed

+23
-2
lines changed

6 files changed

+23
-2
lines changed

gpt_server/model_backend/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ class ModelBackend(ABC):
66
@abstractmethod
def stream_chat(self, params: Dict[str, Any]):
    """Stream chat output for the given request *params*.

    Abstract: every concrete backend must provide its own implementation.
    """
    ...
9+
10+
def shutdown(self):
    """Release resources held by the backend; the base-class default is a no-op."""
    return None

gpt_server/model_backend/hf_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def __init__(self, tokenizer: PreTrainedTokenizer, model: torch.nn.Module) -> No
5454
continue
5555
self.model.load_adapter(model_id=lora_path, adapter_name=lora_name)
5656

57+
def shutdown(self):
    """No-op: this backend keeps no engine handle that needs explicit teardown."""
    return None
59+
5760
async def stream_chat(self, params: Dict[str, Any]):
5861
# params 已不需要传入 prompt
5962
messages = params["messages"]

gpt_server/model_backend/lmdeploy_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ def __init__(self, model_path, tokenizer: PreTrainedTokenizer) -> None:
102102
# 自定义日志
103103
self.async_engine.request_logger = CustomRequestLogger(max_log_len=None)
104104

105+
def shutdown(self):
    """No-op shutdown hook for the lmdeploy backend.

    NOTE(review): the lmdeploy async engine may expose its own close/shutdown
    API — confirm whether it should be invoked here instead of doing nothing.
    """
    return None
107+
105108
async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
106109
# params 已不需要传入 prompt
107110
messages = params["messages"]

gpt_server/model_backend/sglang_backend.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import sglang as sgl
99
from transformers import PreTrainedTokenizer
1010
from sglang.utils import convert_json_schema_to_str
11-
11+
from sglang.srt.entrypoints.engine import Engine
1212
from qwen_vl_utils import process_vision_info
1313
from sglang.srt.managers.io_struct import GenerateReqInput
1414
from gpt_server.settings import get_model_config
@@ -48,7 +48,7 @@ def __init__(self, model_path, tokenizer: PreTrainedTokenizer) -> None:
4848
logger.info(f"model_config: {model_config}")
4949
self.lora_requests = []
5050
# ---
51-
self.async_engine = sgl.Engine(
51+
self.async_engine: Engine = sgl.Engine(
5252
model_path=model_path,
5353
trust_remote_code=True,
5454
mem_fraction_static=model_config.gpu_memory_utilization,
@@ -60,6 +60,9 @@ def __init__(self, model_path, tokenizer: PreTrainedTokenizer) -> None:
6060
)
6161
self.tokenizer = tokenizer
6262

63+
def shutdown(self):
    """Tear down the underlying sglang engine so its resources are released."""
    engine = self.async_engine
    engine.shutdown()
65+
6366
async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
6467
# params 已不需要传入 prompt
6568
messages = params["messages"]

gpt_server/model_backend/vllm_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def __init__(self, model_path, tokenizer: PreTrainedTokenizer) -> None:
6161
self.tokenizer = tokenizer
6262
self.reasoning_parser_cache = {}
6363

64+
def shutdown(self):
    """Delegate teardown to the wrapped vLLM engine."""
    engine = self.engine
    engine.shutdown()
66+
6467
async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
6568
# params 已不需要传入 prompt
6669
messages = params["messages"]

gpt_server/model_worker/base/model_worker_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,12 @@ async def startup():
375375
limit_worker_concurrency=limit_worker_concurrency,
376376
)
377377

378+
@app.on_event("shutdown")
379+
async def shutdown():
380+
global worker
381+
# 优雅推出
382+
worker.backend.shutdown()
383+
378384
uvicorn.run(app, host=host, port=port)
379385

380386

0 commit comments

Comments
 (0)