@@ -70,18 +70,33 @@ class ConnectiveClassifier(Component):
7070 def __init__ (self , input_dim , used_context : int = 0 ):
7171 in_size = input_dim + 2 * used_context * input_dim
7272 self .model = get_conn_model (in_size , 1 , 1024 )
73+ self .input_dim = input_dim
7374 self .used_context = used_context
7475
76+ def get_config (self ):
77+ return {
78+ 'model_name' : self .model_name ,
79+ 'input_dim' : self .input_dim ,
80+ 'used_context' : self .used_context ,
81+ }
82+
83+ @staticmethod
84+ def from_config (config : dict ):
85+ clf = ConnectiveClassifier (config ['input_dim' ], config ['used_context' ])
86+ clf .sense_map = config ['sense_map' ]
87+ clf .classes = config ['classes' ]
88+ return clf
89+
7590 def load (self , path ):
76- if not os .path .exists (os .path .join (path , f'connective_nn_ { self .used_context } .model' )):
91+ if not os .path .exists (os .path .join (path , self .model_name )):
7792 raise FileNotFoundError ("Model not found." )
78- self .model = tf .keras .models .load_model (os .path .join (path , f'connective_nn_ { self .used_context } .model' ),
93+ self .model = tf .keras .models .load_model (os .path .join (path , self .model_name ),
7994 compile = False )
8095
8196 def save (self , path ):
8297 if not os .path .exists (path ):
8398 os .makedirs (path )
84- self .model .save (os .path .join (path , f'connective_nn_ { self .used_context } .model' ))
99+ self .model .save (os .path .join (path , self .model_name ))
85100
86101 def fit (self , docs_train : List [Document ], docs_val : List [Document ] = None ):
87102 if docs_val is None :
0 commit comments