Skip to content

Commit ad17b27

Browse files
committed
Refactor using the contextual lemmatizer so that the training script can evaluate that lemmatizer as well
1 parent 47946c7 commit ad17b27

File tree

4 files changed

+31
-11
lines changed

4 files changed

+31
-11
lines changed

stanza/models/lemma/trainer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch.nn.init as init
1313

1414
import stanza.models.common.seq2seq_constant as constant
15+
from stanza.models.common.doc import TEXT, UPOS
1516
from stanza.models.common.foundation_cache import load_charlm
1617
from stanza.models.common.seq2seq_model import Seq2SeqModel
1718
from stanza.models.common.char_model import CharacterLanguageModelWordAdapter
@@ -171,6 +172,29 @@ def predict_contextual(self, sentence_words, sentence_tags, preds):
171172
preds[sent_id][word_id] = pred
172173
return preds
173174

175+
def update_contextual_preds(self, doc, preds):
176+
"""
177+
Update a flat list of preds with the output of the contextual lemmatizers
178+
179+
- First, it unflattens the preds based on the lengths of the sentences
180+
- Then it uses the contextual lemmatizers
181+
- Finally, it reflattens the preds into the format expected by the caller
182+
"""
183+
if len(self.contextual_lemmatizers) == 0:
184+
return preds
185+
186+
sentence_words = doc.get([TEXT], as_sentences=True)
187+
sentence_tags = doc.get([UPOS], as_sentences=True)
188+
sentence_preds = []
189+
start_index = 0
190+
for sent in sentence_words:
191+
end_index = start_index + len(sent)
192+
sentence_preds.append(preds[start_index:end_index])
193+
start_index += len(sent)
194+
preds = self.predict_contextual(sentence_words, sentence_tags, sentence_preds)
195+
preds = [lemma for sentence in preds for lemma in sentence]
196+
return preds
197+
174198
def update_lr(self, new_lr):
175199
utils.change_lr(self.optimizer, new_lr)
176200

stanza/models/lemmatizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,9 @@ def evaluate(args):
298298
logger.info("[Ensembling dict with seq2seq lemmatizer...]")
299299
preds = trainer.ensemble(batch.doc.get([TEXT, UPOS]), preds)
300300

301+
if trainer.has_contextual_lemmatizers():
302+
preds = trainer.update_contextual_preds(batch.doc, preds)
303+
301304
# write to file and score
302305
batch.doc.set([LEMMA], preds)
303306
CoNLL.write_doc2conll(batch.doc, system_pred_file)

stanza/pipeline/lemma_processor.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,16 +115,7 @@ def process(self, document):
115115
preds = self.trainer.postprocess(batch.doc.get([doc.TEXT]), preds, edits=edits)
116116

117117
if self.trainer.has_contextual_lemmatizers():
118-
sentence_words = batch.doc.get([doc.TEXT], as_sentences=True)
119-
sentence_tags = batch.doc.get([doc.UPOS], as_sentences=True)
120-
sentence_preds = []
121-
start_index = 0
122-
for sent in sentence_words:
123-
end_index = start_index + len(sent)
124-
sentence_preds.append(preds[start_index:end_index])
125-
start_index += len(sent)
126-
preds = self.trainer.predict_contextual(sentence_words, sentence_tags, sentence_preds)
127-
preds = [lemma for sentence in preds for lemma in sentence]
118+
preds = self.trainer.update_contextual_preds(batch.doc, preds)
128119

129120
# map empty string lemmas to '_'
130121
preds = [max([(len(x), x), (0, '_')])[1] for x in preds]

stanza/utils/training/run_lemma.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,9 @@ def run_treebank(mode, paths, treebank, short_name,
167167
'--output', save_name,
168168
'--classifier', 'saved_models/lemma_classifier/%s_lemma_classifier.pt' % short_name]
169169
attach_lemma_classifier.main(attach_args)
170-
# TODO: rerun dev set / test set with the attached classifier?
170+
171+
# now we rerun the dev set - the HI in particular demonstrates some good improvement
172+
lemmatizer.main(dev_args)
171173

172174
def main():
173175
common.main(run_treebank, "lemma", "lemmatizer", add_lemma_args, sub_argparse=lemmatizer.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm)

0 commit comments

Comments
 (0)