22
33import torch
44from torch import nn
5+ import torch .nn .functional as F
56from torch .nn .utils .rnn import pack_padded_sequence , pad_packed_sequence
67
8+ from stanza .models .common .biaffine import DeepBiaffineScorer
79from stanza .models .common .utils import build_nonlinearity , unsort
810from stanza .models .common .vocab import VOCAB_PREFIX_SIZE
911from stanza .models .depparse .model import BaseParser , EmbeddingParser
@@ -211,6 +213,10 @@ def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_mod
211213 self .transition_loss_function = nn .CrossEntropyLoss (reduction = 'sum' )
212214 self .deprel_loss_function = nn .CrossEntropyLoss (reduction = 'sum' )
213215
216+ # self.args['distance_output_dim']
217+ self .distance = DeepBiaffineScorer (2 * self .args ['hidden_dim' ], 2 * self .args ['hidden_dim' ], self .args ['deep_biaff_hidden_dim' ], 1 , pairwise = True , dropout = self .args ['dropout' ])
218+ #self.distance_expansion = nn.Linear(self.args['distance_output_dim'], self.merge_hidden_dim)
219+
214220 def forward (self , states ):
215221 """
216222 Builds a list of logits for the different operations, including a separate one for each Left and Right merge
@@ -300,7 +306,8 @@ def forward(self, states):
300306 attachment_input_left = attachment_input .expand (state .word_position , attachment_input .shape [0 ])
301307 left_arc_hx = torch .cat ([attachment_input_left , attachment_embeddings_left ], axis = 1 )
302308 left_arc_hx = self .merge_output_left (left_arc_hx )
303- left_output = self .output_left_transition (self .drop (self .nonlinearity (left_arc_hx )))
309+ distance_left = state .distance [0 , 1 :state .word_position + 1 , state .current_heads [- 1 ]].unsqueeze (1 ).detach ()
310+ left_output = self .output_left_transition (self .drop (self .nonlinearity (left_arc_hx ))) + distance_left * self .args ['distance_factor' ]
304311 left_deprel = self .output_left_deprel (self .drop (self .nonlinearity (left_arc_hx )))
305312
306313 # truncate the outputs to only be the current heads,
@@ -311,7 +318,8 @@ def forward(self, states):
311318 attachment_input_right = attachment_input .unsqueeze (0 ).expand (current_heads .shape [0 ], attachment_input .shape [0 ])
312319 right_arc_hx = torch .cat ([attachment_input_right , attachment_embeddings_right ], axis = 1 )
313320 right_arc_hx = self .merge_output_right (right_arc_hx )
314- right_output = self .output_right_transition (self .drop (self .nonlinearity (right_arc_hx )))
321+ distance_right = state .distance [0 , state .current_heads [- 1 ], :][current_heads ].unsqueeze (1 ).detach ()
322+ right_output = self .output_right_transition (self .drop (self .nonlinearity (right_arc_hx ))) + distance_right * self .args ['distance_factor' ]
315323 right_deprel = self .output_right_deprel (self .drop (self .nonlinearity (right_arc_hx )))
316324 final_output [state_idx ] = [final_output [state_idx ][0 ], left_output .squeeze (1 ), right_output .squeeze (1 )]
317325 left_deprels .append (left_deprel )
@@ -598,6 +606,10 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
598606 states = self .build_initial_states (head , deprel , text , lstm_outputs , sentlens )
599607
600608 total_loss = 0
609+ for state , sentence_head in zip (states , head ):
610+ dist_kld = torch .gather (state .distance [0 , 1 :, :], 1 , sentence_head [:state .num_words ].unsqueeze (1 ))
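611+ # subtract, not +=: state.distance holds -log(err**2 / 2 + 1), which is <= 0, so negating it yields a non-negative penalty; with += the model is completely broken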
612+ total_loss -= dist_kld .sum ()
601613
602614 iteration = 0
603615 while len (states ) > 0 :
@@ -722,12 +734,24 @@ def build_initial_states(self, head, deprel, text, lstm_outputs, sentlens):
722734 else :
723735 states = [state_from_text (sentence ) for sentence in text ]
724736 updated_states = []
737+
725738 # TODO: list comprehension?
726739 for state , lstm_output , sentlen in zip (states , lstm_outputs , sentlens ):
727740 # the sentences are all prepended with root
728741 # which is fine, since we need an embedding for word 0
742+ # for distance, the graph parser uses the extra space with the root
743+ # TODO: stack the distance operation
744+ head_offset = (torch .arange (sentlen , device = lstm_output .device ).view (1 , 1 , - 1 ) -
745+ torch .arange (sentlen , device = lstm_output .device ).view (1 , - 1 , 1 ))
746+ distance_scores = self .distance (self .drop (lstm_output [:sentlen ].unsqueeze (0 )),
747+ self .drop (lstm_output [:sentlen ].unsqueeze (0 ))).squeeze (3 ).squeeze (0 )
748+ distance_pred = 1 + F .softplus (distance_scores )
749+ distance_target = torch .abs (head_offset )
750+ distance_kld = - torch .log ((distance_target .float () - distance_pred )** 2 / 2 + 1 )
751+
729752 state = state ._replace (word_embeddings = lstm_output ,
730- subtree_embeddings = {})
753+ subtree_embeddings = {},
754+ distance = distance_kld )
731755 updated_states .append (state )
732756 return updated_states
733757
0 commit comments