Skip to content

Commit b929d0b

Browse files
committed
Refactor the loss into its own function. Will try to stack it
1 parent 734c2a6 commit b929d0b

File tree

1 file changed

+66
-60
lines changed
  • stanza/models/depparse/transition

1 file changed

+66
-60
lines changed

stanza/models/depparse/transition/model.py

Lines changed: 66 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -520,13 +520,74 @@ def update_partial_tree_lstm(self, states, state_idxs, partial_tree_h0, partial_
520520
states = new_states
521521
state_idxs = new_state_idxs
522522

523+
def calculate_iteration_loss(self, states, gold_transitions, output_hx, left_deprels, right_deprels):
    """Sum the transition and deprel losses for one batch of parser states.

    For each (state, gold_transition) pair, builds a one-hot target over the
    legal actions and scores the predicted transition distribution against it;
    attachment transitions additionally contribute a deprel classification loss.

    One-hot targets are laid out in the following order:
      Shift - Finalize - word_position left attachments - num_heads-1 right attachments

    Args:
      states: parser states, one per sentence in the batch
      gold_transitions: the gold next transition for each state (parallel to states)
      output_hx: per-state transition score vectors (parallel to states)
      left_deprels: per-state deprel score matrices for left attachments,
        indexed by head-1 (root occupies slot 0 of the embeddings, words from 1)
      right_deprels: per-state deprel score matrices for right attachments,
        indexed by position in state.current_heads
    Returns:
      scalar loss summed over all states in this iteration
    """
    device = next(self.parameters()).device
    total_loss = 0.0
    for state_idx, (state, gold_transition) in enumerate(zip(states, gold_transitions)):
        if len(state.current_heads) <= 1:
            # only Shift / Finalize are possible with at most one head
            one_hot = torch.zeros(2, device=device)
        else:
            # 2 for shift & finalize
            # len(state.current_heads) - 1 for right attachments
            # state.word_position for the left attachments
            num_hots = len(state.current_heads) + state.word_position + 1
            one_hot = torch.zeros(num_hots, device=device)
        if isinstance(gold_transition, Shift):
            one_hot[0] = 1
        elif isinstance(gold_transition, Finalize):
            one_hot[1] = 1
        elif isinstance(gold_transition, (ProjectiveLeft, NonprojectiveLeft)):
            if isinstance(gold_transition, ProjectiveLeft):
                head = state.current_heads[-2]
            else:
                head = gold_transition.word_idx
            # words are indexed at 1, so there should never be a head for
            # word 0; hence the first word's slot is one_hot[2]
            one_hot[head + 1] = 1
            # the word_embeddings list has an entry for root at 0, whereas
            # words are indexed from 1; here we only attach words to
            # previous heads, so the deprel output lives at head-1
            total_loss += self._deprel_loss(gold_transition.deprel,
                                            left_deprels[state_idx][head - 1],
                                            device)
        elif isinstance(gold_transition, (ProjectiveRight, NonprojectiveRight)):
            if isinstance(gold_transition, ProjectiveRight):
                head = len(state.current_heads) - 2
            else:
                head = state.current_heads.index(gold_transition.word_idx)
            one_hot[2 + state.word_position + head] = 1
            total_loss += self._deprel_loss(gold_transition.deprel,
                                            right_deprels[state_idx][head],
                                            device)
        total_loss += self.transition_loss_function(output_hx[state_idx], one_hot)
    return total_loss

def _deprel_loss(self, deprel, deprel_output, device):
    """Loss for a single gold dependency relation against its predicted scores.

    Builds a one-hot target over self.relations for the gold deprel and
    scores deprel_output against it with self.deprel_loss_function.
    """
    deprel_one_hot = torch.zeros(len(self.relations), device=device)
    deprel_one_hot[self.relation_to_id[deprel]] = 1
    return self.deprel_loss_function(deprel_output, deprel_one_hot)
523585

524586
def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text):
525587
# lstm_outputs will be a list of tensors for each sentence
526588
# max(len) x args['hidden_dim']*2
527589
lstm_outputs = self.embed(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text)
528590
states = self.build_initial_states(head, deprel, text, lstm_outputs, sentlens)
529-
device = next(self.parameters()).device
530591

531592
total_loss = 0
532593

@@ -536,65 +597,10 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
536597
#print("ITERATION %d" % iteration)
537598
output_hx, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0 = self.forward(states)
538599
gold_transitions = [state.gold_sequence[len(state.transitions)] for state in states]
539-
new_states = []
540-
for state_idx, (state, gold_transition) in enumerate(zip(states, gold_transitions)):
541-
# one hot vectors will be made in the following order:
542-
# Shift - Finalize - word_position left attachments - num_heads-1 right attachments
543-
if len(state.current_heads) <= 1:
544-
one_hot = torch.zeros(2)
545-
else:
546-
# 2 for shift & finalize
547-
# state.num_heads - 1 for right attachments
548-
# state.word_position for the left attachments
549-
num_hots = len(state.current_heads) + state.word_position + 1
550-
one_hot = torch.zeros(num_hots)
551-
#print(state_idx, gold_transition, len(state.current_heads), state.word_position)
552-
if isinstance(gold_transition, Shift):
553-
one_hot[0] = 1
554-
elif isinstance(gold_transition, Finalize):
555-
one_hot[1] = 1
556-
elif isinstance(gold_transition, ProjectiveLeft) or isinstance(gold_transition, NonprojectiveLeft):
557-
if isinstance(gold_transition, ProjectiveLeft):
558-
head = state.current_heads[-2]
559-
else:
560-
head = gold_transition.word_idx
561-
#print(" Left", head, num_hots)
562-
# words are indexed at 1
563-
# so there should never be a head for word 0
564-
# hence the first word needs to be at one_hot[2]
565-
one_hot[head+1] = 1
566-
567-
# also, include a loss for the deprel
568-
deprel = gold_transition.deprel
569-
deprel_idx = self.relation_to_id[deprel]
570-
deprel_one_hot = torch.zeros(len(self.relations))
571-
deprel_one_hot[deprel_idx] = 1
572-
deprel_one_hot = deprel_one_hot.to(device)
573-
# the word_embeddings list has an entry for root at 0,
574-
# whereas the words are indexed from 1
575-
# here, though, we only want to attach words to previous heads
576-
# so we do head-1
577-
deprel_output = left_deprels[state_idx][head-1]
578-
total_loss += self.deprel_loss_function(deprel_output, deprel_one_hot)
579-
elif isinstance(gold_transition, ProjectiveRight) or isinstance(gold_transition, NonprojectiveRight):
580-
if isinstance(gold_transition, ProjectiveRight):
581-
head = len(state.current_heads) - 2
582-
else:
583-
head = state.current_heads.index(gold_transition.word_idx)
584-
hot_idx = 2+state.word_position+head
585-
#print(" Right", hot_idx, num_hots)
586-
one_hot[hot_idx] = 1
587-
588-
# also, include a loss for the deprel
589-
deprel = gold_transition.deprel
590-
deprel_idx = self.relation_to_id[deprel]
591-
deprel_one_hot = torch.zeros(len(self.relations))
592-
deprel_one_hot[deprel_idx] = 1
593-
deprel_one_hot = deprel_one_hot.to(device)
594-
deprel_output = right_deprels[state_idx][head]
595-
total_loss += self.deprel_loss_function(deprel_output, deprel_one_hot)
596-
one_hot = one_hot.to(device)
597-
total_loss += self.transition_loss_function(output_hx[state_idx], one_hot)
600+
601+
iteration_loss = self.calculate_iteration_loss(states, gold_transitions, output_hx, left_deprels, right_deprels)
602+
total_loss += iteration_loss
603+
598604
states = self.update_subtree_embeddings(states, gold_transitions)
599605
states = [gold_transition.apply(state) for state, gold_transition in zip(states, gold_transitions)]
600606
for state_idx, state in enumerate(states):

0 commit comments

Comments
 (0)