@@ -523,6 +523,8 @@ def update_partial_tree_lstm(self, states, state_idxs, partial_tree_h0, partial_
523523 def calculate_iteration_loss (self , states , gold_transitions , output_hx , left_deprels , right_deprels ):
524524 device = next (self .parameters ()).device
525525 total_loss = 0.0
526+ deprel_one_hots = []
527+ deprel_hx = []
526528 for state_idx , (state , gold_transition ) in enumerate (zip (states , gold_transitions )):
527529 # one hot vectors will be made in the following order:
528530 # Shift - Finalize - word_position left attachments - num_heads-1 right attachments
@@ -561,7 +563,8 @@ def calculate_iteration_loss(self, states, gold_transitions, output_hx, left_dep
561563 # here, though, we only want to attach words to previous heads
562564 # so we do head-1
563565 deprel_output = left_deprels [state_idx ][head - 1 ]
564- total_loss += self .deprel_loss_function (deprel_output , deprel_one_hot )
566+ deprel_hx .append (deprel_output )
567+ deprel_one_hots .append (deprel_one_hot )
565568 elif isinstance (gold_transition , ProjectiveRight ) or isinstance (gold_transition , NonprojectiveRight ):
566569 if isinstance (gold_transition , ProjectiveRight ):
567570 head = len (state .current_heads ) - 2
@@ -578,9 +581,14 @@ def calculate_iteration_loss(self, states, gold_transitions, output_hx, left_dep
578581 deprel_one_hot [deprel_idx ] = 1
579582 deprel_one_hot = deprel_one_hot .to (device )
580583 deprel_output = right_deprels [state_idx ][head ]
581- total_loss += self .deprel_loss_function (deprel_output , deprel_one_hot )
584+ deprel_hx .append (deprel_output )
585+ deprel_one_hots .append (deprel_one_hot )
582586 one_hot = one_hot .to (device )
583587 total_loss += self .transition_loss_function (output_hx [state_idx ], one_hot )
588+ if len (deprel_one_hots ) > 0 :
589+ deprel_one_hots = torch .stack (deprel_one_hots , dim = 0 )
590+ deprel_hx = torch .stack (deprel_hx , dim = 0 )
591+ total_loss += self .deprel_loss_function (deprel_hx , deprel_one_hots )
584592 return total_loss
585593
586594 def loss (self , word , word_mask , wordchars , wordchars_mask , upos , xpos , ufeats , pretrained , lemma , head , deprel , word_orig_idx , sentlens , wordlens , text ):
0 commit comments