Skip to content

Commit 3e8decd

Browse files
committed
add config methods to bert connective disambiguation model
1 parent a4053e0 commit 3e8decd

File tree

1 file changed

+18
-3
lines changed
  • discopy/components/connective

1 file changed

+18
-3
lines changed

discopy/components/connective/bert.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,33 @@ class ConnectiveClassifier(Component):
7070
def __init__(self, input_dim, used_context: int = 0):
7171
in_size = input_dim + 2 * used_context * input_dim
7272
self.model = get_conn_model(in_size, 1, 1024)
73+
self.input_dim = input_dim
7374
self.used_context = used_context
7475

76+
def get_config(self):
77+
return {
78+
'model_name': self.model_name,
79+
'input_dim': self.input_dim,
80+
'used_context': self.used_context,
81+
}
82+
83+
@staticmethod
84+
def from_config(config: dict):
85+
clf = ConnectiveClassifier(config['input_dim'], config['used_context'])
86+
clf.sense_map = config['sense_map']
87+
clf.classes = config['classes']
88+
return clf
89+
7590
def load(self, path):
76-
if not os.path.exists(os.path.join(path, f'connective_nn_{self.used_context}.model')):
91+
if not os.path.exists(os.path.join(path, self.model_name)):
7792
raise FileNotFoundError("Model not found.")
78-
self.model = tf.keras.models.load_model(os.path.join(path, f'connective_nn_{self.used_context}.model'),
93+
self.model = tf.keras.models.load_model(os.path.join(path, self.model_name),
7994
compile=False)
8095

8196
def save(self, path):
8297
if not os.path.exists(path):
8398
os.makedirs(path)
84-
self.model.save(os.path.join(path, f'connective_nn_{self.used_context}.model'))
99+
self.model.save(os.path.join(path, self.model_name))
85100

86101
def fit(self, docs_train: List[Document], docs_val: List[Document] = None):
87102
if docs_val is None:

0 commit comments

Comments
 (0)