3
3
import torch .nn as nn
4
4
from typing import Optional , Union
5
5
from pathlib import Path
6
- from .utils import MODEL_PATH
6
+ from .lib .inject import inject_script
7
+ from .utils import MODEL_PATH , model as model_utils
7
8
8
9
9
10
class Predicter :
@@ -47,54 +48,38 @@ def __new__(
47
48
# 否则返回父类的实例
48
49
return super ().__new__ (cls )
49
50
51
+ @classmethod
52
+ def get_root (cls ) -> Path :
53
+ return MODEL_PATH
54
+
50
55
@classmethod
51
56
def get_save_directory (cls , name : str ) -> Path :
52
- save_directory = MODEL_PATH / name
57
+ save_directory = cls . get_root () / name
53
58
return save_directory
54
59
60
+ @classmethod
61
+ def get_fork_directory (cls , fork : str ) -> Optional [Path ]:
62
+ return model_utils .get_fork_directory (cls .get_root (), fork )
63
+
55
64
@classmethod
56
65
def get_checkpoint (cls , name : str , checkpoint : Optional [str ] = None ) -> Path :
57
66
save_directory = cls .get_save_directory (name )
58
-
59
- # 寻找文件夹下的最新的 checkpoint 的 name
60
- if checkpoint :
61
- # check if the checkpoint exists
62
- if not (save_directory / f"{ checkpoint } .pth" ).exists ():
63
- raise FileNotFoundError (f"checkpoint { checkpoint } not found" )
64
- return save_directory / f"{ checkpoint } .pth"
65
- try :
66
- checkpoint_path = max (
67
- save_directory .glob ("*.pth" ), key = lambda x : x .stat ().st_ctime
68
- )
69
- # 去掉后缀
70
- return checkpoint_path
71
- except ValueError :
72
- raise FileNotFoundError (f"checkpoint not found in { save_directory } " )
67
+ return model_utils .get_checkpoint (save_directory , checkpoint )
73
68
74
69
@classmethod
75
70
def get_model_config_json (cls , name : str ) -> dict :
76
71
save_directory = cls .get_save_directory (name )
77
- with open (save_directory / "config.json" , "r" ) as f :
78
- config_dict = json .load (f )
79
- return config_dict
72
+ return model_utils .get_model_config_json (save_directory )
80
73
81
74
@classmethod
82
75
def get_external_config_json (cls , name : str ) -> Optional [dict ]:
83
76
save_directory = cls .get_save_directory (name )
84
- external_config_path = save_directory / "external_config.json"
85
- if external_config_path .exists ():
86
- with open (external_config_path , "r" ) as f :
87
- config_dict = json .load (f )
88
- return config_dict
89
- else :
90
- return None
77
+ return model_utils .get_external_config_json (save_directory )
91
78
92
79
@classmethod
93
80
def get_trainer_config_json (cls , name : str ) -> dict :
94
81
save_directory = cls .get_save_directory (name )
95
- with open (save_directory / "trainer_config.json" , "r" ) as f :
96
- config_dict = json .load (f )
97
- return config_dict
82
+ return model_utils .get_trainer_config_json (save_directory )
98
83
99
84
@classmethod
100
85
def get_model (cls , name : str , checkpoint : Optional [str ] = None ) -> nn .Module :
@@ -114,6 +99,16 @@ def get_model(cls, name: str, checkpoint: Optional[str] = None) -> nn.Module:
114
99
print (f"Loading model { sub_class_name } from { checkpoint_path } " )
115
100
return sub_class .get_model (name , checkpoint )
116
101
102
+ @classmethod
103
+ def inject_script (cls , model , name : str ):
104
+ external_config = cls .get_external_config_json (name )
105
+ if external_config :
106
+ fork = external_config .get ("fork" , None )
107
+ fork_directory = cls .get_fork_directory (fork )
108
+ if fork_directory is not None :
109
+ model = inject_script (model , fork_directory )
110
+ return model
111
+
117
112
@classmethod
118
113
def from_pretrained (
119
114
cls ,
@@ -124,6 +119,7 @@ def from_pretrained(
124
119
trainer_config = cls .get_trainer_config_json (name )
125
120
device = device if device else trainer_config .get ("device" , "cuda" )
126
121
model = cls .get_model (name , checkpoint ).to (device )
122
+ model = cls .inject_script (model , name )
127
123
return cls (name , model , device = device )
128
124
129
125
def _predict (self , ctx ):
0 commit comments