|
7 | 7 | from transformers import AutoConfig, AutoModel
|
8 | 8 | from loguru import logger
|
9 | 9 | from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
|
10 |
| -from gpt_server.model_worker.utils import load_base64_or_url |
| 10 | +from gpt_server.model_worker.utils import load_base64_or_url, get_embedding_mode |
11 | 11 |
|
12 | 12 |
|
13 | 13 | class EmbeddingWorker(ModelWorkerBase):
|
@@ -38,41 +38,27 @@ def __init__(
|
38 | 38 | logger.warning(f"使用{device}加载...")
|
39 | 39 | model_kwargs = {"device": device}
|
40 | 40 | # TODO
|
41 |
| - self.mode = "embedding" |
42 |
| - model_type = getattr( |
43 |
| - getattr(self.model_config, "text_config", {}), "model_type", None |
44 |
| - ) |
45 |
| - logger.warning(f"model_type: {model_type}") |
46 |
| - if "clip_text_model" in model_type: # clip text 模型 |
47 |
| - self.mode = "clip_text_model" |
48 |
| - self.client = AutoModel.from_pretrained( |
49 |
| - model_path, trust_remote_code=True |
50 |
| - ) # You must set trust_remote_code=True |
| 41 | + self.mode = get_embedding_mode(model_path=model_path) |
| 42 | + self.encode_kwargs = {"normalize_embeddings": True, "batch_size": 64} |
| 43 | + if "clip_text_model" in self.mode: # clip text 模型 |
| 44 | + self.client = AutoModel.from_pretrained(model_path, trust_remote_code=True) |
51 | 45 | if device == "cuda":
|
52 | 46 | self.client.to(
|
53 | 47 | torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
54 | 48 | )
|
55 | 49 | logger.info(f"device: {self.client.device}")
|
56 | 50 | self.client.set_processor(model_path)
|
57 | 51 | self.client.eval()
|
58 |
| - else: |
59 |
| - self.encode_kwargs = {"normalize_embeddings": True, "batch_size": 64} |
60 |
| - |
61 |
| - # rerank |
62 |
| - for model_name in model_names: |
63 |
| - if "rerank" in model_name: |
64 |
| - self.mode = "rerank" |
65 |
| - break |
66 |
| - if self.mode == "rerank": |
67 |
| - self.client = sentence_transformers.CrossEncoder( |
68 |
| - model_name=model_path, **model_kwargs |
69 |
| - ) |
70 |
| - logger.warning("正在使用 rerank 模型...") |
71 |
| - elif self.mode == "embedding": |
72 |
| - self.client = sentence_transformers.SentenceTransformer( |
73 |
| - model_path, **model_kwargs |
74 |
| - ) |
75 |
| - logger.warning("正在使用 embedding 模型...") |
| 52 | + elif "rerank" in self.mode: |
| 53 | + self.client = sentence_transformers.CrossEncoder( |
| 54 | + model_name=model_path, **model_kwargs |
| 55 | + ) |
| 56 | + logger.warning("正在使用 rerank 模型...") |
| 57 | + elif "embedding" in self.mode: |
| 58 | + self.client = sentence_transformers.SentenceTransformer( |
| 59 | + model_path, **model_kwargs |
| 60 | + ) |
| 61 | + logger.warning("正在使用 embedding 模型...") |
76 | 62 | logger.warning(f"模型:{model_names[0]}")
|
77 | 63 |
|
78 | 64 | async def get_embeddings(self, params):
|
|
0 commit comments