@@ -535,8 +535,9 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
535535 iteration += 1
536536 #print("ITERATION %d" % iteration)
537537 output_hx , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0 = self .forward (states )
538+ gold_transitions = [state .gold_sequence [len (state .transitions )] for state in states ]
538539 new_states = []
539- for state_idx , state in enumerate (states ):
540+ for state_idx , ( state , gold_transition ) in enumerate (zip ( states , gold_transitions ) ):
540541 # one hot vectors will be made in the following order:
541542 # Shift - Finalize - word_position left attachments - num_heads-1 right attachments
542543 if len (state .current_heads ) <= 1 :
@@ -547,8 +548,6 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
547548 # state.word_position for the left attachments
548549 num_hots = len (state .current_heads ) + state .word_position + 1
549550 one_hot = torch .zeros (num_hots )
550- num_transitions = len (state .transitions )
551- gold_transition = state .gold_sequence [num_transitions ]
552551 #print(state_idx, gold_transition, len(state.current_heads), state.word_position)
553552 if isinstance (gold_transition , Shift ):
554553 one_hot [0 ] = 1
@@ -596,7 +595,6 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
596595 total_loss += self .deprel_loss_function (deprel_output , deprel_one_hot )
597596 one_hot = one_hot .to (device )
598597 total_loss += self .transition_loss_function (output_hx [state_idx ], one_hot )
599- gold_transitions = [state .gold_sequence [len (state .transitions )] for state in states ]
600598 states = self .update_subtree_embeddings (states , gold_transitions )
601599 states = [gold_transition .apply (state ) for state , gold_transition in zip (states , gold_transitions )]
602600 for state_idx , state in enumerate (states ):
0 commit comments