Skip to content

Commit 5ace29b

Browse files
committed
Maybe a tiny bit faster? At least the results are the same
1 parent b929d0b commit 5ace29b

File tree

1 file changed

+10
-2
lines changed
  • stanza/models/depparse/transition/model.py

1 file changed

+10
-2
lines changed

stanza/models/depparse/transition/model.py

Lines changed: 10 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -523,6 +523,8 @@ def update_partial_tree_lstm(self, states, state_idxs, partial_tree_h0, partial_
523523
def calculate_iteration_loss(self, states, gold_transitions, output_hx, left_deprels, right_deprels):
524524
device = next(self.parameters()).device
525525
total_loss = 0.0
526+
deprel_one_hots = []
527+
deprel_hx = []
526528
for state_idx, (state, gold_transition) in enumerate(zip(states, gold_transitions)):
527529
# one hot vectors will be made in the following order:
528530
# Shift - Finalize - word_position left attachments - num_heads-1 right attachments
@@ -561,7 +563,8 @@ def calculate_iteration_loss(self, states, gold_transitions, output_hx, left_dep
561563
# here, though, we only want to attach words to previous heads
562564
# so we do head-1
563565
deprel_output = left_deprels[state_idx][head-1]
564-
total_loss += self.deprel_loss_function(deprel_output, deprel_one_hot)
566+
deprel_hx.append(deprel_output)
567+
deprel_one_hots.append(deprel_one_hot)
565568
elif isinstance(gold_transition, ProjectiveRight) or isinstance(gold_transition, NonprojectiveRight):
566569
if isinstance(gold_transition, ProjectiveRight):
567570
head = len(state.current_heads) - 2
@@ -578,9 +581,14 @@ def calculate_iteration_loss(self, states, gold_transitions, output_hx, left_dep
578581
deprel_one_hot[deprel_idx] = 1
579582
deprel_one_hot = deprel_one_hot.to(device)
580583
deprel_output = right_deprels[state_idx][head]
581-
total_loss += self.deprel_loss_function(deprel_output, deprel_one_hot)
584+
deprel_hx.append(deprel_output)
585+
deprel_one_hots.append(deprel_one_hot)
582586
one_hot = one_hot.to(device)
583587
total_loss += self.transition_loss_function(output_hx[state_idx], one_hot)
588+
if len(deprel_one_hots) > 0:
589+
deprel_one_hots = torch.stack(deprel_one_hots, dim=0)
590+
deprel_hx = torch.stack(deprel_hx, dim=0)
591+
total_loss += self.deprel_loss_function(deprel_hx, deprel_one_hots)
584592
return total_loss
585593

586594
def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text):

0 commit comments

Comments
 (0)