Skip to content

Commit e694e92

Browse files
committed
添加模型名称重名校验
1 parent 37f98ae commit e694e92

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

gpt_server/utils.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def start_model_worker(config: dict):
137137
error_msg = f"请参照 https://github.com/shell-nlp/gpt_server/blob/main/gpt_server/script/config.yaml 设置正确的 model_worker_args"
138138
logger.error(error_msg)
139139
raise KeyError(error_msg)
140-
140+
exist_model_names = [] # 记录已经存在的model_name
141141
for model_config_ in config["models"]:
142142
for model_name, model_config in model_config_.items():
143143
# 启用的模型
@@ -202,7 +202,15 @@ def start_model_worker(config: dict):
202202
if lora: # 如果使用lora,将lora的name添加到 model_names 中
203203
lora_names = list(lora.keys())
204204
model_names += "," + ",".join(lora_names)
205-
205+
intersection = list(
206+
set(exist_model_names) & set(model_names.split(","))
207+
) # 获取交集
208+
if intersection: # 如果有交集 则返回True
209+
logger.error(
210+
f"存在重名的模型名称或别名:{intersection} ,请检查 config.yaml 文件"
211+
)
212+
sys.exit()
213+
exist_model_names.extend(model_names.split(","))
206214
# 获取 worker 数目 并获取每个 worker 的资源
207215
workers = model_config["workers"]
208216

@@ -252,8 +260,10 @@ def start_model_worker(config: dict):
252260
if punc_model:
253261
cmd += f" --vad_model '{punc_model}'"
254262
p = Process(target=run_cmd, args=(cmd,))
255-
p.start()
263+
# p.start()
256264
process.append(p)
265+
for p in process:
266+
p.start()
257267
for p in process:
258268
p.join()
259269

0 commit comments

Comments
 (0)