1 file changed: 5 additions, 7 deletions

@@ -37,16 +37,15 @@ def __init__(
             device = "cuda"
         logger.warning(f"使用{device}加载...")
         model_kwargs = {"device": device}
+        if device == "cuda":
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         # TODO
         self.mode = get_embedding_mode(model_path=model_path)
         self.encode_kwargs = {"normalize_embeddings": True, "batch_size": 64}
         if "clip_text_model" in self.mode:  # clip text 模型
             self.client = AutoModel.from_pretrained(model_path, trust_remote_code=True)
-            if device == "cuda":
-                self.client.to(
-                    torch.device("cuda" if torch.cuda.is_available() else "cpu")
-                )
-                logger.info(f"device: {self.client.device}")
+            self.client.to(device)
+            logger.info(f"device: {self.client.device}")
             self.client.set_processor(model_path)
             self.client.eval()
         elif "vl_rerank" == self.mode:
@@ -56,8 +55,7 @@ def __init__(
                 trust_remote_code=True,
                 # attn_implementation="flash_attention_2",
             )
-
-            self.client.to("cuda")  # or 'cpu' if no GPU is available
+            self.client.to(device)
             self.client.eval()
         elif "rerank" == self.mode:
             self.client = sentence_transformers.CrossEncoder(
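The net effect of the change: the requested device string is resolved into a concrete torch.device once, right after it is chosen, and every model branch then moves the model with the same self.client.to(device) call instead of repeating the CUDA-availability check or hard-coding "cuda" (which previously made the vl_rerank branch fail on CPU-only machines). Below is a minimal sketch of that pattern; resolve_device and load_encoder are hypothetical helper names for illustration, not functions from this repository.

import torch
from transformers import AutoModel


def resolve_device(requested: str = "cuda") -> torch.device:
    """Fall back to CPU when CUDA is requested but not available."""
    if requested == "cuda":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(requested)


def load_encoder(model_path: str, requested_device: str = "cuda") -> torch.nn.Module:
    # Resolve the device once up front; every branch that loads a model
    # can then call .to(device) without its own availability check.
    device = resolve_device(requested_device)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.to(device)  # safe on CPU-only machines, unlike a hard-coded "cuda"
    model.eval()
    return model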