Skip to content

Commit 1f06efa

Browse files
committed
Should be the same - refactor the gold_transitions selection outside the loss block
1 parent 3705363 commit 1f06efa

File tree

1 file changed

+2
-4
lines changed
  • stanza/models/depparse/transition

1 file changed

+2
-4
lines changed

stanza/models/depparse/transition/model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -535,8 +535,9 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
535535
iteration += 1
536536
#print("ITERATION %d" % iteration)
537537
output_hx, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0 = self.forward(states)
538+
gold_transitions = [state.gold_sequence[len(state.transitions)] for state in states]
538539
new_states = []
539-
for state_idx, state in enumerate(states):
540+
for state_idx, (state, gold_transition) in enumerate(zip(states, gold_transitions)):
540541
# one hot vectors will be made in the following order:
541542
# Shift - Finalize - word_position left attachments - num_heads-1 right attachments
542543
if len(state.current_heads) <= 1:
@@ -547,8 +548,6 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
547548
# state.word_position for the left attachments
548549
num_hots = len(state.current_heads) + state.word_position + 1
549550
one_hot = torch.zeros(num_hots)
550-
num_transitions = len(state.transitions)
551-
gold_transition = state.gold_sequence[num_transitions]
552551
#print(state_idx, gold_transition, len(state.current_heads), state.word_position)
553552
if isinstance(gold_transition, Shift):
554553
one_hot[0] = 1
@@ -596,7 +595,6 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
596595
total_loss += self.deprel_loss_function(deprel_output, deprel_one_hot)
597596
one_hot = one_hot.to(device)
598597
total_loss += self.transition_loss_function(output_hx[state_idx], one_hot)
599-
gold_transitions = [state.gold_sequence[len(state.transitions)] for state in states]
600598
states = self.update_subtree_embeddings(states, gold_transitions)
601599
states = [gold_transition.apply(state) for state, gold_transition in zip(states, gold_transitions)]
602600
for state_idx, state in enumerate(states):

0 commit comments

Comments
 (0)