1 file changed: +5 -7 lines changed
Columns: original file line number | diff line number | diff line change
@@ -37,16 +37,15 @@ def __init__(
37
37
device = "cuda"
38
38
logger .warning (f"使用{ device } 加载..." )
39
39
model_kwargs = {"device" : device }
40
+ if device == "cuda" :
41
+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
40
42
# TODO
41
43
self .mode = get_embedding_mode (model_path = model_path )
42
44
self .encode_kwargs = {"normalize_embeddings" : True , "batch_size" : 64 }
43
45
if "clip_text_model" in self .mode : # clip text 模型
44
46
self .client = AutoModel .from_pretrained (model_path , trust_remote_code = True )
45
- if device == "cuda" :
46
- self .client .to (
47
- torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
48
- )
49
- logger .info (f"device: { self .client .device } " )
47
+ self .client .to (device )
48
+ logger .info (f"device: { self .client .device } " )
50
49
self .client .set_processor (model_path )
51
50
self .client .eval ()
52
51
elif "vl_rerank" == self .mode :
@@ -56,8 +55,7 @@ def __init__(
56
55
trust_remote_code = True ,
57
56
# attn_implementation="flash_attention_2",
58
57
)
59
-
60
- self .client .to ("cuda" ) # or 'cpu' if no GPU is available
58
+ self .client .to (device )
61
59
self .client .eval ()
62
60
elif "rerank" == self .mode :
63
61
self .client = sentence_transformers .CrossEncoder (
You can’t perform that action at this time.
0 commit comments