Skip to content

Commit 0c0049d

Browse files
committed
支持bge-vl
1 parent 24846fb commit 0c0049d

File tree

8 files changed

+208
-56
lines changed

8 files changed

+208
-56
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
<summary><b>2025</b></summary>
5353

5454
```plaintext
55+
2025-6-6 支持了 bge-vl 系列 (代码样例见 tests/test_openai_embedding_vl.py)
5556
2025-6-6 支持了 ritrieve_zh_v1
5657
2025-4-29 支持了 Qwen3
5758
2025-4-24 支持了 Spark-TTS后端的 TTS
@@ -278,11 +279,11 @@ Chat UI界面:
278279

279280
**原则上支持所有的Embedding/Rerank/Classify模型**
280281

281-
**推理速度:** Infinity >> HF
282+
**推理速度:** embedding_infinity > embedding
282283

283284
以下模型经过测试可放心使用:
284285

285-
| Embedding/Rerank/Classify | HF | Infinity |
286+
| Models / model_type | embedding | embedding_infinity |
286287
| ----------------------------------------------------------------------------------- | --- | -------- |
287288
| bge-reranker |||
288289
| bce-reranker |||
@@ -296,6 +297,7 @@ Chat UI界面:
296297
| xiaobu-embedding |||
297298
| Conan-embedding-v1 |||
298299
| ritrieve_zh_v1 |||
300+
| bge-vl || × |
299301
| KoalaAI/Text-Moderation(文本审核/多分类,审核文本是否存在暴力、色情等) | × ||
300302
| protectai/deberta-v3-base-prompt-injection-v2(提示注入/2分类,审核文本为提示注入) | × ||
301303

gpt_server/model_worker/embedding.py

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
import asyncio
12
import os
23
from typing import List
34

45
import sentence_transformers
6+
import torch
7+
from transformers import AutoConfig, AutoModel
58
from loguru import logger
69
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
10+
from gpt_server.model_worker.utils import load_base64_or_url
711

812

913
class EmbeddingWorker(ModelWorkerBase):
@@ -33,23 +37,38 @@ def __init__(
3337
device = "cuda"
3438
logger.warning(f"使用{device}加载...")
3539
model_kwargs = {"device": device}
36-
self.encode_kwargs = {"normalize_embeddings": True, "batch_size": 64}
40+
# TODO
3741
self.mode = "embedding"
38-
# rerank
39-
for model_name in model_names:
40-
if "rerank" in model_name:
41-
self.mode = "rerank"
42-
break
43-
if self.mode == "rerank":
44-
self.client = sentence_transformers.CrossEncoder(
45-
model_name=model_path, **model_kwargs
46-
)
47-
logger.warning("正在使用 rerank 模型...")
48-
elif self.mode == "embedding":
49-
self.client = sentence_transformers.SentenceTransformer(
50-
model_path, **model_kwargs
51-
)
52-
logger.warning("正在使用 embedding 模型...")
42+
model_type = getattr(
43+
getattr(self.model_config, "text_config", {}), "model_type", None
44+
)
45+
logger.warning(f"model_type: {model_type}")
46+
if "clip_text_model" in model_type: # clip text 模型
47+
self.mode = "clip_text_model"
48+
self.client = AutoModel.from_pretrained(
49+
model_path, trust_remote_code=True
50+
) # You must set trust_remote_code=True
51+
self.client.set_processor(model_path)
52+
self.client.eval()
53+
else:
54+
self.encode_kwargs = {"normalize_embeddings": True, "batch_size": 64}
55+
56+
# rerank
57+
for model_name in model_names:
58+
if "rerank" in model_name:
59+
self.mode = "rerank"
60+
break
61+
if self.mode == "rerank":
62+
self.client = sentence_transformers.CrossEncoder(
63+
model_name=model_path, **model_kwargs
64+
)
65+
logger.warning("正在使用 rerank 模型...")
66+
elif self.mode == "embedding":
67+
self.client = sentence_transformers.SentenceTransformer(
68+
model_path, **model_kwargs
69+
)
70+
logger.warning("正在使用 embedding 模型...")
71+
logger.warning(f"模型:{model_names[0]}")
5372

5473
async def get_embeddings(self, params):
5574
self.call_ct += 1
@@ -69,6 +88,38 @@ async def get_embeddings(self, params):
6988
sentence_pairs = [[query, inp] for inp in texts]
7089
scores = self.client.predict(sentence_pairs)
7190
embedding = [[float(score)] for score in scores]
91+
elif self.mode == "clip_text_model":
92+
token_num = 0
93+
if isinstance(texts[0], dict):
94+
text = [i["text"] for i in texts]
95+
text = list(map(lambda x: x.replace("\n", " "), text))
96+
97+
images = [i["image"] for i in texts]
98+
coro_list = []
99+
for i in images:
100+
coro = load_base64_or_url(base64_or_url=i)
101+
coro_list.append(coro)
102+
result_images = await asyncio.gather(*coro_list)
103+
104+
embedding = self.client.encode(
105+
images=result_images,
106+
text=text,
107+
).tolist()
108+
elif isinstance(texts[0], str):
109+
if "http" in texts[0] or "data:image" in texts[0]: # 图片
110+
images = texts
111+
coro_list = []
112+
for i in images:
113+
coro = load_base64_or_url(base64_or_url=i)
114+
coro_list.append(coro)
115+
result_images = await asyncio.gather(*coro_list)
116+
embedding = self.client.encode(
117+
images=result_images,
118+
).tolist()
119+
else: # 文本
120+
embedding = self.client.encode(
121+
text=texts,
122+
).tolist()
72123
ret["embedding"] = embedding
73124
ret["token_num"] = token_num
74125
return ret

gpt_server/model_worker/spark_tts.py

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,12 @@
44
from typing import List
55
from loguru import logger
66
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
7-
7+
from gpt_server.model_worker.utils import load_base64_or_url
88
from flashtts.engine import AutoEngine
99
from flashtts.server.utils.audio_writer import StreamingAudioWriter
1010

1111
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
1212

13-
import httpx
14-
from fastapi import HTTPException
15-
import base64
16-
import io
17-
18-
19-
async def get_audio_bytes_from_url(url: str) -> bytes:
20-
async with httpx.AsyncClient() as client:
21-
response = await client.get(url)
22-
if response.status_code != 200:
23-
raise HTTPException(status_code=400, detail="无法从指定 URL 下载参考音频")
24-
return response.content
25-
26-
27-
async def load_base64_or_url(audio):
28-
# 根据 reference_audio 内容判断读取方式
29-
if audio.startswith("http://") or audio.startswith("https://"):
30-
audio_bytes = await get_audio_bytes_from_url(audio)
31-
else:
32-
try:
33-
audio_bytes = base64.b64decode(audio)
34-
except Exception as e:
35-
logger.warning("无效的 base64 音频数据: " + str(e))
36-
raise HTTPException(
37-
status_code=400, detail="无效的 base64 音频数据: " + str(e)
38-
)
39-
# 利用 BytesIO 包装字节数据,然后使用 soundfile 读取为 numpy 数组
40-
try:
41-
bytes_io = io.BytesIO(audio_bytes)
42-
except Exception as e:
43-
logger.warning("读取参考音频失败: " + str(e))
44-
raise HTTPException(status_code=400, detail="读取参考音频失败: " + str(e))
45-
return bytes_io
46-
4713

4814
class SparkTTSWorker(ModelWorkerBase):
4915
def __init__(

gpt_server/model_worker/utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import httpx
2+
from loguru import logger
3+
from fastapi import HTTPException
4+
import base64
5+
import io
6+
7+
8+
def extract_base64(data_url: str):
9+
"""Return the Base64 payload of a Data URL, i.e. the text after its first comma."""
10+
return data_url.split(",", 1)[-1] # split once on the first comma; with no comma the whole string is returned
11+
12+
13+
async def get_bytes_from_url(url: str) -> bytes:
    """Download ``url`` and return the raw response body.

    Args:
        url: HTTP(S) address of the resource to fetch.

    Returns:
        The response content as bytes.

    Raises:
        HTTPException: With status 400 when the request cannot be completed
            (connection error, timeout, invalid URL) or the server answers
            with a status code other than 200.
    """
    # NOTE(review): callers appear to pass request-supplied strings here, so
    # this may fetch arbitrary addresses — consider host restrictions (SSRF).
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
    except (httpx.HTTPError, httpx.InvalidURL) as e:
        # Transport failures used to escape as unhandled exceptions; report
        # them the same way as a non-200 response.
        logger.warning("无法从指定 URL 下载数据: " + str(e))
        raise HTTPException(status_code=400, detail="无法从指定 URL 下载数据") from e
    if response.status_code != 200:
        raise HTTPException(status_code=400, detail="无法从指定 URL 下载数据")
    return response.content
19+
20+
21+
async def load_base64_or_url(base64_or_url):
    """Load binary data from an HTTP(S) URL, a Data URL, or a Base64 string.

    Args:
        base64_or_url: An ``http://``/``https://`` URL, a
            ``data:...;base64,...`` Data URL, or a bare Base64 string.

    Returns:
        io.BytesIO: In-memory buffer wrapping the downloaded/decoded bytes.

    Raises:
        HTTPException: With status 400 when the download fails or the
            payload is not valid Base64.
    """
    # Choose the loading strategy from the shape of the input string.
    if base64_or_url.startswith(("http://", "https://")):
        raw_bytes = await get_bytes_from_url(base64_or_url)
    else:
        payload = base64_or_url.strip()
        if payload.startswith("data:"):
            payload = extract_base64(data_url=payload)
        # Line-wrapped Base64 is legitimate; drop whitespace before validating.
        payload = "".join(payload.split())
        try:
            # validate=True rejects stray characters instead of silently
            # discarding them and decoding to corrupt bytes.
            raw_bytes = base64.b64decode(payload, validate=True)
        except ValueError as e:  # binascii.Error is a ValueError subclass
            logger.warning("无效的 base64 数据: " + str(e))
            raise HTTPException(
                status_code=400, detail="无效的 base64 数据: " + str(e)
            ) from e
    # Wrap the bytes in a file-like object for downstream readers.
    return io.BytesIO(raw_bytes)
40+
41+
42+
if __name__ == "__main__":
43+
44+
# Example usage
45+
data_url = "..."
46+
pure_base64 = extract_base64(data_url)
47+
print(pure_base64) # Output for a real PNG Data URL: iVBORw0KGgoAAAANSUhEUg... (the "..." placeholder above has no comma, so it is printed unchanged)

tests/test_needle_haystack.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""大海捞针评测"""
2+
3+
import os
4+
from evalscope import TaskConfig, run_task
5+
6+
task_cfg = TaskConfig(
7+
model="qwen",
8+
api_url="http://localhost:8082/v1",
9+
api_key="123",
10+
eval_type="service", # 使用API模型服务
11+
datasets=["needle_haystack"],
12+
eval_batch_size=20,
13+
dataset_args={
14+
"needle_haystack": {
15+
"subset_list": ["chinese", "english"][:1], # 可选,指定使用中文或英文子集
16+
# 支持配置的参数
17+
"extra_params": {
18+
# 问题
19+
"retrieval_question": "What is the best thing to do in San Francisco?",
20+
# 插入的文本(可以设置为多个)
21+
"needles": [
22+
"\nThe best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n"
23+
],
24+
# 语料的最小长度
25+
"context_lengths_min": 1000,
26+
# 语料的最大长度
27+
"context_lengths_max": 64 * 1024, # 64K
28+
# 语料的区间数
29+
"context_lengths_num_intervals": 20,
30+
# 插入文本最小位置(百分数)
31+
"document_depth_percent_min": 0,
32+
# 插入文本最大位置(百分数)
33+
"document_depth_percent_max": 100,
34+
# 插入文本位置区间数
35+
"document_depth_percent_intervals": 10,
36+
# tokenizer的路径(可以指定modelscope的id)
37+
"tokenizer_path": "/home/dev/model/Qwen/Qwen2___5-32B-Instruct-AWQ/",
38+
"show_score": True, # 是否在heatmap上显示分数
39+
},
40+
}
41+
},
42+
generation_config={
43+
"max_tokens": 512, # 最大生成token数
44+
},
45+
judge_worker_num=5,
46+
judge_model_args={
47+
"model_id": "qwen",
48+
"api_url": "http://localhost:8082/v1",
49+
"api_key": "123",
50+
},
51+
)
52+
run_task(task_cfg=task_cfg)

tests/test_openai_embedding_vl.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from openai import OpenAI
2+
from rich import print
3+
import base64
4+
5+
6+
## 测试只对 文本嵌入
7+
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8082/v1")
8+
data = client.embeddings.create(model="bge-vl", input=["你是谁", "你是谁"])
9+
10+
print(data.data)
11+
## 测试只对 图片嵌入
12+
13+
14+
def image_to_base64(image_path):
    """Encode an image file as a Base64 Data URL.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A ``data:<mime>;base64,<payload>`` string. The MIME type is guessed
        from the file name; ``image/png`` is used when the guess fails or is
        not an image type.
    """
    import mimetypes  # local import keeps the script's import header untouched

    # The prefix used to be hard-coded to PNG, mislabelling JPEG/WebP/etc.
    mime_type, _ = mimetypes.guess_type(image_path)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/png"

    with open(image_path, "rb") as image_file:
        payload = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:{mime_type};base64,{payload}"
21+
22+
23+
image_path = "../assets/logo.png"
24+
# 使用本地的图片
25+
url = image_to_base64(image_path)
26+
data = client.embeddings.create(model="bge-vl", input=[url, url])
27+
28+
print(data.data)
29+
## 测试 图文一起嵌入
30+
data = client.embeddings.create(
31+
model="bge-vl", input=[{"text": "你好", "image": url}] * 2
32+
)
33+
34+
print(data.data)

tests/test_openai_vl_chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def image_to_base64(image_path):
2222

2323
stream = True
2424
output = client.chat.completions.create(
25-
model="internvl2", # internlm chatglm3 qwen llama3 chatglm4
25+
model="minicpmv", # internlm chatglm3 qwen llama3 chatglm4
2626
messages=[
2727
{
2828
"role": "user",

tests/test_perf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
if __name__ == "__main__":
66
args = Arguments(
77
url="http://localhost:8082/v1/chat/completions", # 请求的URL地址
8-
parallel=20, # 并行请求的任务数量
8+
parallel=100, # 并行请求的任务数量
99
model="qwen", # 使用的模型名称
10-
number=20, # 请求数量
10+
number=100, # 请求数量
1111
api="openai", # 使用的API服务
1212
dataset="openqa", # 数据集名称
1313
stream=True, # 是否启用流式处理

0 commit comments

Comments
 (0)