Skip to content

Commit ce57ca6

Browse files
committed
fix
1 parent 5b4f31e commit ce57ca6

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

gpt_server/model_worker/embedding.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -37,16 +37,15 @@ def __init__(
3737
device = "cuda"
3838
logger.warning(f"使用{device}加载...")
3939
model_kwargs = {"device": device}
40+
if device == "cuda":
41+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4042
# TODO
4143
self.mode = get_embedding_mode(model_path=model_path)
4244
self.encode_kwargs = {"normalize_embeddings": True, "batch_size": 64}
4345
if "clip_text_model" in self.mode: # clip text 模型
4446
self.client = AutoModel.from_pretrained(model_path, trust_remote_code=True)
45-
if device == "cuda":
46-
self.client.to(
47-
torch.device("cuda" if torch.cuda.is_available() else "cpu")
48-
)
49-
logger.info(f"device: {self.client.device}")
47+
self.client.to(device)
48+
logger.info(f"device: {self.client.device}")
5049
self.client.set_processor(model_path)
5150
self.client.eval()
5251
elif "vl_rerank" == self.mode:
@@ -56,8 +55,7 @@ def __init__(
5655
trust_remote_code=True,
5756
# attn_implementation="flash_attention_2",
5857
)
59-
60-
self.client.to("cuda") # or 'cpu' if no GPU is available
58+
self.client.to(device)
6159
self.client.eval()
6260
elif "rerank" == self.mode:
6361
self.client = sentence_transformers.CrossEncoder(

0 commit comments

Comments
 (0)