@@ -520,13 +520,74 @@ def update_partial_tree_lstm(self, states, state_idxs, partial_tree_h0, partial_
520520 states = new_states
521521 state_idxs = new_state_idxs
522522
def calculate_iteration_loss(self, states, gold_transitions, output_hx, left_deprels, right_deprels):
    """Sum the transition and deprel losses for one parsing iteration.

    For each parser state, builds a one-hot target over the transitions
    (Shift, Finalize, left attachments, right attachments), scores it
    against the model's transition output, and — for attachment
    transitions — adds a one-hot dependency-relation loss as well.

    Args:
        states: current parser states, one per sentence.
        gold_transitions: gold transition for each state (parallel to states).
        output_hx: per-state transition scores from forward().
        left_deprels: per-state deprel scores for left attachments.
        right_deprels: per-state deprel scores for right attachments.

    Returns:
        Scalar total loss accumulated over all states.
    """
    device = next(self.parameters()).device

    def deprel_loss(deprel, deprel_output):
        # Shared by both attachment branches: one-hot over the relation
        # inventory for the gold deprel, scored with the deprel loss.
        deprel_one_hot = torch.zeros(len(self.relations), device=device)
        deprel_one_hot[self.relation_to_id[deprel]] = 1
        return self.deprel_loss_function(deprel_output, deprel_one_hot)

    total_loss = 0.0
    for state_idx, (state, gold_transition) in enumerate(zip(states, gold_transitions)):
        # one hot vectors are laid out in the following order:
        # Shift - Finalize - word_position left attachments - num_heads-1 right attachments
        if len(state.current_heads) <= 1:
            # no attachable heads yet: only Shift / Finalize are possible
            one_hot = torch.zeros(2, device=device)
        else:
            # 2 for shift & finalize
            # len(current_heads) - 1 for right attachments
            # state.word_position for the left attachments
            num_hots = len(state.current_heads) + state.word_position + 1
            one_hot = torch.zeros(num_hots, device=device)
        if isinstance(gold_transition, Shift):
            one_hot[0] = 1
        elif isinstance(gold_transition, Finalize):
            one_hot[1] = 1
        elif isinstance(gold_transition, (ProjectiveLeft, NonprojectiveLeft)):
            if isinstance(gold_transition, ProjectiveLeft):
                head = state.current_heads[-2]
            else:
                head = gold_transition.word_idx
            # words are indexed at 1, so there should never be a head for
            # word 0 — hence the first word needs to be at one_hot[2]
            one_hot[head + 1] = 1
            # the word_embeddings list has an entry for root at 0, whereas
            # the words are indexed from 1; here we only attach words to
            # previous heads, so we use head - 1
            total_loss += deprel_loss(gold_transition.deprel,
                                      left_deprels[state_idx][head - 1])
        elif isinstance(gold_transition, (ProjectiveRight, NonprojectiveRight)):
            if isinstance(gold_transition, ProjectiveRight):
                head = len(state.current_heads) - 2
            else:
                head = state.current_heads.index(gold_transition.word_idx)
            # right-attachment slots come after Shift/Finalize and the
            # word_position left-attachment slots
            one_hot[2 + state.word_position + head] = 1
            total_loss += deprel_loss(gold_transition.deprel,
                                      right_deprels[state_idx][head])
        total_loss += self.transition_loss_function(output_hx[state_idx], one_hot)
    return total_loss
523585
524586 def loss (self , word , word_mask , wordchars , wordchars_mask , upos , xpos , ufeats , pretrained , lemma , head , deprel , word_orig_idx , sentlens , wordlens , text ):
525587 # lstm_outputs will be a list of tensors for each sentence
526588 # max(len) x args['hidden_dim']*2
527589 lstm_outputs = self .embed (word , word_mask , wordchars , wordchars_mask , upos , xpos , ufeats , pretrained , lemma , head , deprel , word_orig_idx , sentlens , wordlens , text )
528590 states = self .build_initial_states (head , deprel , text , lstm_outputs , sentlens )
529- device = next (self .parameters ()).device
530591
531592 total_loss = 0
532593
@@ -536,65 +597,10 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
536597 #print("ITERATION %d" % iteration)
537598 output_hx , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0 = self .forward (states )
538599 gold_transitions = [state .gold_sequence [len (state .transitions )] for state in states ]
539- new_states = []
540- for state_idx , (state , gold_transition ) in enumerate (zip (states , gold_transitions )):
541- # one hot vectors will be made in the following order:
542- # Shift - Finalize - word_position left attachments - num_heads-1 right attachments
543- if len (state .current_heads ) <= 1 :
544- one_hot = torch .zeros (2 )
545- else :
546- # 2 for shift & finalize
547- # state.num_heads - 1 for right attachments
548- # state.word_position for the left attachments
549- num_hots = len (state .current_heads ) + state .word_position + 1
550- one_hot = torch .zeros (num_hots )
551- #print(state_idx, gold_transition, len(state.current_heads), state.word_position)
552- if isinstance (gold_transition , Shift ):
553- one_hot [0 ] = 1
554- elif isinstance (gold_transition , Finalize ):
555- one_hot [1 ] = 1
556- elif isinstance (gold_transition , ProjectiveLeft ) or isinstance (gold_transition , NonprojectiveLeft ):
557- if isinstance (gold_transition , ProjectiveLeft ):
558- head = state .current_heads [- 2 ]
559- else :
560- head = gold_transition .word_idx
561- #print(" Left", head, num_hots)
562- # words are indexed at 1
563- # so there should never be a head for word 0
564- # hence the first word needs to be at one_hot[2]
565- one_hot [head + 1 ] = 1
566-
567- # also, include a loss for the deprel
568- deprel = gold_transition .deprel
569- deprel_idx = self .relation_to_id [deprel ]
570- deprel_one_hot = torch .zeros (len (self .relations ))
571- deprel_one_hot [deprel_idx ] = 1
572- deprel_one_hot = deprel_one_hot .to (device )
573- # the word_embeddings list has an entry for root at 0,
574- # whereas the words are indexed from 1
575- # here, though, we only want to attach words to previous heads
576- # so we do head-1
577- deprel_output = left_deprels [state_idx ][head - 1 ]
578- total_loss += self .deprel_loss_function (deprel_output , deprel_one_hot )
579- elif isinstance (gold_transition , ProjectiveRight ) or isinstance (gold_transition , NonprojectiveRight ):
580- if isinstance (gold_transition , ProjectiveRight ):
581- head = len (state .current_heads ) - 2
582- else :
583- head = state .current_heads .index (gold_transition .word_idx )
584- hot_idx = 2 + state .word_position + head
585- #print(" Right", hot_idx, num_hots)
586- one_hot [hot_idx ] = 1
587-
588- # also, include a loss for the deprel
589- deprel = gold_transition .deprel
590- deprel_idx = self .relation_to_id [deprel ]
591- deprel_one_hot = torch .zeros (len (self .relations ))
592- deprel_one_hot [deprel_idx ] = 1
593- deprel_one_hot = deprel_one_hot .to (device )
594- deprel_output = right_deprels [state_idx ][head ]
595- total_loss += self .deprel_loss_function (deprel_output , deprel_one_hot )
596- one_hot = one_hot .to (device )
597- total_loss += self .transition_loss_function (output_hx [state_idx ], one_hot )
600+
601+ iteration_loss = self .calculate_iteration_loss (states , gold_transitions , output_hx , left_deprels , right_deprels )
602+ total_loss += iteration_loss
603+
598604 states = self .update_subtree_embeddings (states , gold_transitions )
599605 states = [gold_transition .apply (state ) for state , gold_transition in zip (states , gold_transitions )]
600606 for state_idx , state in enumerate (states ):
0 commit comments