Skip to content

Commit 1fdd4f4

Browse files
committed
切换vllm 后端为 V0,加快启动和推理性能
1 parent 3299041 commit 1fdd4f4

File tree

3 files changed

+4
-5
lines changed

3 files changed

+4
-5
lines changed

gpt_server/model_backend/vllm_backend.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
# 解决vllm中 ray集群在 TP>1时死的Bug
2020
import ray
2121

22-
ray.init(ignore_reinit_error=True, num_cpus=4)
23-
24-
os.environ["VLLM_USE_V1"] = "1"
22+
ray.init(ignore_reinit_error=True, num_cpus=8)
2523

2624

2725
class VllmBackend(ModelBackend):

gpt_server/model_worker/base/model_worker_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ def run(cls):
259259
logger.remove(0)
260260
log_level = os.getenv("log_level", "WARNING")
261261
logger.add(sys.stderr, level=log_level)
262+
os.environ["VLLM_USE_V1"] = "0"
262263

263264
host = args.host
264265
controller_address = args.controller_address

gpt_server/model_worker/spark_tts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
1212

13-
14-
os.environ["VLLM_USE_V1"] = "1"
1513
import httpx
1614
from fastapi import HTTPException
1715
import base64
@@ -69,6 +67,7 @@ def __init__(
6967
model_type="tts",
7068
)
7169
backend = os.environ["backend"]
70+
gpu_memory_utilization = float(os.getenv("gpu_memory_utilization", 0.6))
7271
self.engine = AutoEngine(
7372
model_path=model_path,
7473
max_length=32768,
@@ -79,6 +78,7 @@ def __init__(
7978
wav2vec_attn_implementation="sdpa", # 使用flash attn加速wav2vec
8079
llm_gpu_memory_utilization=0.6,
8180
seed=0,
81+
llm_gpu_memory_utilization=gpu_memory_utilization,
8282
)
8383
loop = asyncio.get_running_loop()
8484
# ------------- 添加声音 -------------

0 commit comments

Comments
 (0)