Skip to content

Commit 7f867f8

Browse files
committed
优化 tts
1 parent 4cc45ee commit 7f867f8

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

gpt_server/model_worker/base/model_worker_base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
multimodal: bool = False,
5555
):
5656
is_vision = False
57-
if model_type != "asr":
57+
if model_type != "asr" and model_type != "tts":
5858
try:
5959
self.model_config = AutoConfig.from_pretrained(
6060
model_path, trust_remote_code=True
@@ -116,7 +116,11 @@ def get_model_class(self):
116116

117117
def load_model_tokenizer(self, model_path):
118118
"""加载 模型 和 分词器 直接对 self.model 和 self.tokenizer 进行赋值"""
119-
if self.model_type == "embedding" or self.model_type == "asr":
119+
if (
120+
self.model_type == "embedding"
121+
or self.model_type == "asr"
122+
or self.model_type == "tts"
123+
):
120124
return 1
121125
self.tokenizer = AutoTokenizer.from_pretrained(
122126
model_path,

gpt_server/model_worker/spark_tts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(
6060
model_names,
6161
limit_worker_concurrency,
6262
conv_template,
63-
model_type="asr",
63+
model_type="tts",
6464
)
6565

6666
self.engine = AutoEngine(

0 commit comments

Comments
 (0)