77from stanza .models .common .utils import build_nonlinearity , unsort
88from stanza .models .common .vocab import VOCAB_PREFIX_SIZE
99from stanza .models .depparse .model import BaseParser , EmbeddingParser
10- from stanza .models .depparse .transition .state import state_from_text , states_from_data_batch , TransitionLSTMEmbedding , SubtreeLSTMEmbedding
10+ from stanza .models .depparse .transition .state import state_from_text , states_from_data_batch , TransitionLSTMEmbedding , SubtreeLSTMEmbedding , ArcLSTMEmbedding
1111from stanza .models .depparse .transition .transitions import Shift , Finalize , ProjectiveLeft , ProjectiveRight , NonprojectiveLeft , NonprojectiveRight
1212
# A few notes on some experiments cross-validating the hyperparameters for this model
@@ -115,16 +115,18 @@ def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_mod
115115 self .transition_subtree_nonlinearity = build_nonlinearity (self .args .get ('transition_subtree_nonlinearity' ))
116116 self .drop = nn .Dropout (self .args ['dropout' ])
117117
118+ self .merge_words_output_dim = self .args ['transition_merge_words_output_dim' ]
119+ self .merge_hidden_dim = self .transition_hidden_dim + self .args ['hidden_dim' ] + self .merge_words_output_dim
120+
118121 # the bidirectional LSTM is x2, adding in the partial trees is another x1
119- self .word_hidden_dim = self .transition_hidden_dim + self .args ['hidden_dim' ] * 3
122+ # the arc embeddings take up merge_hidden_dim
123+ self .word_hidden_dim = self .transition_hidden_dim + self .args ['hidden_dim' ] * 3 + self .merge_hidden_dim
120124 self .word_output_layers = nn .Sequential (self .nonlinearity ,
121125 self .drop ,
122126 nn .Linear (self .word_hidden_dim , self .word_hidden_dim ),
123127 self .nonlinearity ,
124128 self .drop ,
125129 nn .Linear (self .word_hidden_dim , self .word_hidden_dim ))
126- self .merge_words_output_dim = self .args ['transition_merge_words_output_dim' ]
127- self .merge_hidden_dim = self .transition_hidden_dim + self .args ['hidden_dim' ] + self .merge_words_output_dim
128130 # Splitting this into a left and right version is close,
129131 # but seems to be somewhat more accurate than one layer
130132 # 5 model dev avg LAS baseline merge-two-sides
@@ -155,6 +157,11 @@ def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_mod
155157 self .drop ,
156158 nn .Linear (self .merge_hidden_dim , self .merge_hidden_dim ))
157159
160+ self .arc_embedding_lstm = nn .LSTM (input_size = self .merge_hidden_dim , hidden_size = self .merge_hidden_dim , num_layers = self .args ['num_layers' ], dropout = self .args ['dropout' ])
161+ self .arc_embedding_start = nn .Parameter (torch .zeros (self .merge_hidden_dim ))
162+ self .arc_embedding_h0 = nn .Parameter (torch .zeros (self .args ['num_layers' ], self .merge_hidden_dim ))
163+ self .arc_embedding_c0 = nn .Parameter (torch .zeros (self .args ['num_layers' ], self .merge_hidden_dim ))
164+
158165 self .output_basic = nn .Linear (self .word_hidden_dim , 2 )
159166 self .output_left_transition = nn .Linear (self .merge_hidden_dim , 1 )
160167 self .output_right_transition = nn .Linear (self .merge_hidden_dim , 1 )
@@ -274,19 +281,26 @@ def forward(self, states):
274281 partial_tree_embeddings = partial_tree_embeddings .squeeze (0 )
275282 #print(torch.linalg.norm(partial_tree_embeddings))
276283
284+ arc_embeddings = [state .arc_lstm_embeddings [- 1 ].hx for state in states ]
285+ arc_embeddings = torch .stack (arc_embeddings , dim = 0 )
286+
277287 word_embeddings = [state .word_embeddings [state .word_position ] for state in states ]
278288 word_embeddings = torch .stack (word_embeddings )
279- output_hx = torch .cat ([transition_embeddings , partial_tree_embeddings , word_embeddings ], dim = 1 )
289+
290+ output_hx = torch .cat ([transition_embeddings , partial_tree_embeddings , arc_embeddings , word_embeddings ], dim = 1 )
280291 output_hx = self .word_output_layers (output_hx )
281292 # batch size x 2 - Shift or Finalize
282293 basic_output = self .output_basic (self .drop (self .nonlinearity (output_hx )))
283294 final_output = [[x ] for x in basic_output ]
284295 left_deprels = []
285296 right_deprels = []
286-
297+ left_arc_hxs = []
298+ right_arc_hxs = []
287299 for state_idx , state in enumerate (states ):
288300 left_deprel = None
289301 right_deprel = None
302+ left_arc_hx = None
303+ right_arc_hx = None
290304 if len (state .current_heads ) > 1 :
291305 # TODO: add a position embedding for the projective / non-projective attachments?
292306 attachment_embeddings = [torch .cat ([state .subtree_embeddings [x ], state .subtree_embeddings [state .current_heads [- 1 ]]])
@@ -316,8 +330,10 @@ def forward(self, states):
316330 final_output [state_idx ] = [final_output [state_idx ][0 ], left_output .squeeze (1 ), right_output .squeeze (1 )]
317331 left_deprels .append (left_deprel )
318332 right_deprels .append (right_deprel )
333+ left_arc_hxs .append (left_arc_hx )
334+ right_arc_hxs .append (right_arc_hx )
319335 final_output = [torch .cat (x ) for x in final_output ]
320- return final_output , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0
336+ return final_output , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0 , left_arc_hxs , right_arc_hxs
321337
322338 def update_subtree_embeddings (self , states , transitions ):
323339 embeddings = []
@@ -520,6 +536,41 @@ def update_partial_tree_lstm(self, states, state_idxs, partial_tree_h0, partial_
520536 states = new_states
521537 state_idxs = new_state_idxs
522538
def extract_arc_embeddings(self, states, transitions, left_arc_hxs, right_arc_hxs):
    """Pick out, for each state, the arc embedding created by its latest transition.

    Shift and Finalize build no arc, so those states map to None.  The
    Left/Right attachment transitions look up the embedding computed in
    forward(), indexed by the head word participating in the new arc.

    states:        list of parser states, parallel to `transitions`
    transitions:   the transition just chosen for each state
    left_arc_hxs / right_arc_hxs: per-state sequences of candidate arc
        embeddings produced by forward()

    Returns a list parallel to `states`: a tensor for arc-building
    transitions, None otherwise.
    """
    arc_embeddings = []
    for state_idx, (state, transition) in enumerate(zip(states, transitions)):
        if isinstance(transition, (Shift, Finalize)):
            # these transitions attach nothing, so there is no arc embedding
            arc_embeddings.append(None)
        elif isinstance(transition, (ProjectiveLeft, NonprojectiveLeft)):
            if isinstance(transition, ProjectiveLeft):
                head = state.current_heads[-2]
            else:
                head = transition.word_idx
            # left candidates are offset by one relative to the head index
            arc_embeddings.append(left_arc_hxs[state_idx][head - 1])
        elif isinstance(transition, (ProjectiveRight, NonprojectiveRight)):
            if isinstance(transition, ProjectiveRight):
                head = len(state.current_heads) - 2
            else:
                head = state.current_heads.index(transition.word_idx)
            arc_embeddings.append(right_arc_hxs[state_idx][head])
        else:
            # fail fast: silently skipping would desynchronize the returned
            # list from `states` and corrupt indexing in the caller
            raise ValueError("Unhandled transition type: %s" % type(transition))
    return arc_embeddings
559+
def update_arc_embedding_lstm(self, states, arc_embeddings):
    """Advance the arc-embedding LSTM one step for every state that just built an arc.

    Entries of arc_embeddings that are None (Shift / Finalize) are skipped.
    Each remaining state's newest arc embedding is run through
    self.arc_embedding_lstm as a length-1 sequence, and both the raw
    embedding and the resulting (hx, h0, c0) triple are appended to the
    state's history.
    """
    # pair each state with its (possibly None) arc embedding, keeping only real arcs
    active = [(state, arc) for state, arc in zip(states, arc_embeddings) if arc is not None]
    if not active:
        return
    arc_states = [state for state, _ in active]
    arcs = [arc for _, arc in active]
    # batch the single-step inputs: (seq_len=1, batch, merge_hidden_dim)
    step_input = torch.stack(arcs, dim=0).unsqueeze(0)
    # gather each state's most recent hidden/cell state along the batch dim
    h0 = torch.stack([state.arc_lstm_embeddings[-1].h0 for state in arc_states], dim=1)
    c0 = torch.stack([state.arc_lstm_embeddings[-1].c0 for state in arc_states], dim=1)
    hx, (h0, c0) = self.arc_embedding_lstm(step_input, (h0, c0))
    for batch_idx, (state, arc) in enumerate(active):
        state.arc_embeddings.append(arc)
        state.arc_lstm_embeddings.append(ArcLSTMEmbedding(hx[0, batch_idx, :],
                                                          h0[:, batch_idx, :],
                                                          c0[:, batch_idx, :]))
523574
524575 def loss (self , word , word_mask , wordchars , wordchars_mask , upos , xpos , ufeats , pretrained , lemma , head , deprel , word_orig_idx , sentlens , wordlens , text ):
525576 # lstm_outputs will be a list of tensors for each sentence
@@ -534,7 +585,7 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
534585 while len (states ) > 0 :
535586 iteration += 1
536587 #print("ITERATION %d" % iteration)
537- output_hx , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0 = self .forward (states )
588+ output_hx , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0 , left_arc_hxs , right_arc_hxs = self .forward (states )
538589 gold_transitions = [state .gold_sequence [len (state .transitions )] for state in states ]
539590 new_states = []
540591 for state_idx , (state , gold_transition ) in enumerate (zip (states , gold_transitions )):
@@ -595,12 +646,14 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
595646 total_loss += self .deprel_loss_function (deprel_output , deprel_one_hot )
596647 one_hot = one_hot .to (device )
597648 total_loss += self .transition_loss_function (output_hx [state_idx ], one_hot )
649+ arc_embeddings = self .extract_arc_embeddings (states , gold_transitions , left_arc_hxs , right_arc_hxs )
598650 states = self .update_subtree_embeddings (states , gold_transitions )
599651 states = [gold_transition .apply (state ) for state , gold_transition in zip (states , gold_transitions )]
600652 for state_idx , state in enumerate (states ):
601653 # TODO: can this be moved into .apply()
602654 state .transition_lstm_embeddings .append (TransitionLSTMEmbedding (transition_h0 [:, state_idx , :], transition_c0 [:, state_idx , :]))
603655 self .update_partial_tree_lstm (states , range (len (states )), partial_tree_h0 , partial_tree_c0 )
656+ self .update_arc_embedding_lstm (states , arc_embeddings )
604657 states = [state for state in states if not isinstance (state .transitions [- 1 ], Finalize )]
605658
606659 return total_loss
@@ -658,8 +711,11 @@ def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats
658711 while len (states ) > 0 :
659712 iteration += 1
660713 #print("ITERATION %d" % iteration)
661- output_hx , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0 = self .forward (states )
714+ output_hx , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0 , left_arc_hxs , right_arc_hxs = self .forward (states )
662715 transitions = self .choose_transitions (self .relations , states , output_hx , left_deprels , right_deprels )
716+
717+ arc_embeddings = self .extract_arc_embeddings (states , transitions , left_arc_hxs , right_arc_hxs )
718+
663719 #print(transitions[0])
664720 #print(len(states[0].subtree_lstm_embeddings),
665721 # len(states[0].current_heads), states[0].current_heads)
@@ -675,6 +731,7 @@ def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats
675731 # TODO: can this be moved into .apply()
676732 state .transition_lstm_embeddings .append (TransitionLSTMEmbedding (transition_h0 [:, state_idx , :], transition_c0 [:, state_idx , :]))
677733 self .update_partial_tree_lstm (states , range (len (states )), partial_tree_h0 , partial_tree_c0 )
734+ self .update_arc_embedding_lstm (states , arc_embeddings )
678735 #print(len(states[0].subtree_lstm_embeddings),
679736 # len(states[0].current_heads), states[0].current_heads)
680737 #if len(states[0].subtree_lstm_embeddings) > 0:
@@ -713,7 +770,9 @@ def build_initial_states(self, head, deprel, text, lstm_outputs, sentlens):
713770 # the sentences are all prepended with root
714771 # which is fine, since we need an embedding for word 0
715772 state = state ._replace (word_embeddings = lstm_output ,
716- subtree_embeddings = {})
773+ subtree_embeddings = {},
774+ arc_embeddings = [],
775+ arc_lstm_embeddings = [ArcLSTMEmbedding (self .arc_embedding_start , self .arc_embedding_h0 , self .arc_embedding_c0 )])
717776 updated_states .append (state )
718777 return updated_states
719778
@@ -778,7 +837,9 @@ def forward(self, model_states):
778837 model_transition_c0 = [x [4 ] for x in model_forwards ]
779838 model_partial_tree_h0 = [x [5 ] for x in model_forwards ]
780839 model_partial_tree_c0 = [x [6 ] for x in model_forwards ]
781- return output_hx , left_deprels , right_deprels , model_transition_h0 , model_transition_c0 , model_partial_tree_h0 , model_partial_tree_c0
840+ model_left_arc_hx = [x [7 ] for x in model_forwards ]
841+ model_right_arc_hx = [x [8 ] for x in model_forwards ]
842+ return output_hx , left_deprels , right_deprels , model_transition_h0 , model_transition_c0 , model_partial_tree_h0 , model_partial_tree_c0 , model_left_arc_hx , model_right_arc_hx
782843
783844 def predict (self , word , word_mask , wordchars , wordchars_mask , upos , xpos , ufeats , pretrained , lemma , head , deprel , word_orig_idx , sentlens , wordlens , text ):
784845 device = self .get_device ()
@@ -798,7 +859,8 @@ def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats
798859
799860 # output_hx, left_deprels, and right_deprels are already collapsed into one summed value
800861 # the transition and partial_tree vectors are a list of M items long (M=#models)
801- output_hx , left_deprels , right_deprels , model_transition_h0 , model_transition_c0 , model_partial_tree_h0 , model_partial_tree_c0 = self .forward (model_states )
862+ output_hx , left_deprels , right_deprels , model_transition_h0 , model_transition_c0 , model_partial_tree_h0 , model_partial_tree_c0 , model_left_arc_hx , model_right_arc_hx = self .forward (model_states )
863+ # TODO: use the left & right arcs
802864
803865 transitions = TransitionParser .choose_transitions (self .models [0 ].relations , model_states [0 ], output_hx , left_deprels , right_deprels )
804866
0 commit comments