77from stanza .models .common .utils import build_nonlinearity , unsort
88from stanza .models .common .vocab import VOCAB_PREFIX_SIZE
99from stanza .models .depparse .model import BaseParser , EmbeddingParser
10- from stanza .models .depparse .transition .state import state_from_text , states_from_data_batch , TransitionLSTMEmbedding , SubtreeLSTMEmbedding
10+ from stanza .models .depparse .transition .state import state_from_text , states_from_data_batch , TransitionLSTMEmbedding , SubtreeLSTMEmbedding , ArcLSTMEmbedding
1111from stanza .models .depparse .transition .transitions import Shift , Finalize , ProjectiveLeft , ProjectiveRight , NonprojectiveLeft , NonprojectiveRight
1212
1313# A few notes on some experiments crossvalidating the hyperparameters for this model
@@ -115,16 +115,18 @@ def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_mod
115115 self .transition_subtree_nonlinearity = build_nonlinearity (self .args .get ('transition_subtree_nonlinearity' ))
116116 self .drop = nn .Dropout (self .args ['dropout' ])
117117
118+ self .merge_words_output_dim = self .args ['transition_merge_words_output_dim' ]
119+ self .merge_hidden_dim = self .transition_hidden_dim + self .args ['hidden_dim' ] + self .merge_words_output_dim
120+
118121 # the bidirectional LSTM is x2, adding in the partial trees is another x1
119- self .word_hidden_dim = self .transition_hidden_dim + self .args ['hidden_dim' ] * 3
122+ # the arc embeddings take up merge_hidden_dim
123+ self .word_hidden_dim = self .transition_hidden_dim + self .args ['hidden_dim' ] * 3 + self .merge_hidden_dim
120124 self .word_output_layers = nn .Sequential (self .nonlinearity ,
121125 self .drop ,
122126 nn .Linear (self .word_hidden_dim , self .word_hidden_dim ),
123127 self .nonlinearity ,
124128 self .drop ,
125129 nn .Linear (self .word_hidden_dim , self .word_hidden_dim ))
126- self .merge_words_output_dim = self .args ['transition_merge_words_output_dim' ]
127- self .merge_hidden_dim = self .transition_hidden_dim + self .args ['hidden_dim' ] + self .merge_words_output_dim
128130 # Splitting this into a left and right version is close,
129131 # but seems to be somewhat more accurate than one layer
130132 # 5 model dev avg LAS baseline merge-two-sides
@@ -155,6 +157,11 @@ def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_mod
155157 self .drop ,
156158 nn .Linear (self .merge_hidden_dim , self .merge_hidden_dim ))
157159
160+ self .arc_embedding_lstm = nn .LSTM (input_size = self .merge_hidden_dim , hidden_size = self .merge_hidden_dim , num_layers = self .args ['num_layers' ], dropout = self .args ['dropout' ])
161+ self .arc_embedding_start = nn .Parameter (torch .zeros (self .merge_hidden_dim ))
162+ self .arc_embedding_h0 = nn .Parameter (torch .zeros (self .args ['num_layers' ], self .merge_hidden_dim ))
163+ self .arc_embedding_c0 = nn .Parameter (torch .zeros (self .args ['num_layers' ], self .merge_hidden_dim ))
164+
158165 self .output_basic = nn .Linear (self .word_hidden_dim , 2 )
159166 self .output_left_transition = nn .Linear (self .merge_hidden_dim , 1 )
160167 self .output_right_transition = nn .Linear (self .merge_hidden_dim , 1 )
@@ -274,19 +281,26 @@ def forward(self, states):
274281 partial_tree_embeddings = partial_tree_embeddings .squeeze (0 )
275282 #print(torch.linalg.norm(partial_tree_embeddings))
276283
284+ arc_embeddings = [state .arc_lstm_embeddings [- 1 ].hx for state in states ]
285+ arc_embeddings = torch .stack (arc_embeddings , dim = 0 )
286+
277287 word_embeddings = [state .word_embeddings [state .word_position ] for state in states ]
278288 word_embeddings = torch .stack (word_embeddings )
279- output_hx = torch .cat ([transition_embeddings , partial_tree_embeddings , word_embeddings ], dim = 1 )
289+
290+ output_hx = torch .cat ([transition_embeddings , partial_tree_embeddings , arc_embeddings , word_embeddings ], dim = 1 )
280291 output_hx = self .word_output_layers (output_hx )
281292 # batch size x 2 - Shift or Finalize
282293 basic_output = self .output_basic (self .drop (self .nonlinearity (output_hx )))
283294 final_output = [[x ] for x in basic_output ]
284295 left_deprels = []
285296 right_deprels = []
286-
297+ left_arc_hxs = []
298+ right_arc_hxs = []
287299 for state_idx , state in enumerate (states ):
288300 left_deprel = None
289301 right_deprel = None
302+ left_arc_hx = None
303+ right_arc_hx = None
290304 if len (state .current_heads ) > 1 :
291305 # TODO: add a position embedding for the projective / non-projective attachments?
292306 attachment_embeddings = [torch .cat ([state .subtree_embeddings [x ], state .subtree_embeddings [state .current_heads [- 1 ]]])
@@ -316,8 +330,10 @@ def forward(self, states):
316330 final_output [state_idx ] = [final_output [state_idx ][0 ], left_output .squeeze (1 ), right_output .squeeze (1 )]
317331 left_deprels .append (left_deprel )
318332 right_deprels .append (right_deprel )
333+ left_arc_hxs .append (left_arc_hx )
334+ right_arc_hxs .append (right_arc_hx )
319335 final_output = [torch .cat (x ) for x in final_output ]
320- return final_output , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0
336+ return final_output , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0 , left_arc_hxs , right_arc_hxs
321337
322338 def update_subtree_embeddings (self , states , transitions ):
323339 embeddings = []
@@ -591,6 +607,42 @@ def calculate_iteration_loss(self, states, gold_transitions, output_hx, left_dep
591607 total_loss += self .deprel_loss_function (deprel_hx , deprel_one_hots )
592608 return total_loss
593609
def extract_arc_embeddings(self, states, transitions, left_arc_hxs, right_arc_hxs):
    """Collect the arc embedding produced by each state's chosen transition.

    Returns a list parallel to ``states`` (indexed by ``state_idx``):
    ``None`` for transitions that create no arc (Shift / Finalize),
    otherwise the merged-arc hidden vector for the head picked by the
    (non)projective left/right transition.
    """
    arc_embeddings = []
    for state_idx, (state, transition) in enumerate(zip(states, transitions)):
        if isinstance(transition, (Shift, Finalize)):
            # these transitions attach nothing, so there is no new arc
            arc_embeddings.append(None)
        elif isinstance(transition, (ProjectiveLeft, NonprojectiveLeft)):
            if isinstance(transition, ProjectiveLeft):
                head = state.current_heads[-2]
            else:
                head = transition.word_idx
            # NOTE(review): head-1 assumes the left arc rows skip word 0
            # (the prepended root) — confirm against how left_arc_hxs is built
            arc_embeddings.append(left_arc_hxs[state_idx][head - 1])
        elif isinstance(transition, (ProjectiveRight, NonprojectiveRight)):
            if isinstance(transition, ProjectiveRight):
                head = len(state.current_heads) - 2
            else:
                head = state.current_heads.index(transition.word_idx)
            arc_embeddings.append(right_arc_hxs[state_idx][head])
        else:
            # Silently skipping an unknown transition would shift every later
            # entry and misalign the result with states — fail loudly instead
            raise ValueError("Unhandled transition type: %s" % type(transition))
    return arc_embeddings
630+
def update_arc_embedding_lstm(self, states, arc_embeddings):
    """Advance the arc LSTM one step for every state that just built an arc.

    ``arc_embeddings`` is parallel to ``states``; entries that are None
    (no arc was created) are left untouched.  For the rest, the new arc
    vector is fed through ``self.arc_embedding_lstm`` starting from each
    state's latest (h0, c0), and both the raw arc vector and the new
    ArcLSTMEmbedding are appended to the state.
    """
    stepped = [(state, arc) for state, arc in zip(states, arc_embeddings) if arc is not None]
    if not stepped:
        return
    arc_states = [state for state, _ in stepped]
    new_arcs = [arc for _, arc in stepped]
    # batch the single LSTM step: seq_len 1 x batch x merge_hidden_dim
    inputs = torch.stack(new_arcs, dim=0).unsqueeze(0)
    h0 = torch.stack([s.arc_lstm_embeddings[-1].h0 for s in arc_states], dim=1)
    c0 = torch.stack([s.arc_lstm_embeddings[-1].c0 for s in arc_states], dim=1)
    hx, (h0, c0) = self.arc_embedding_lstm(inputs, (h0, c0))
    for idx, (state, arc) in enumerate(stepped):
        state.arc_embeddings.append(arc)
        state.arc_lstm_embeddings.append(ArcLSTMEmbedding(hx[0, idx, :],
                                                          h0[:, idx, :],
                                                          c0[:, idx, :]))
645+
594646 def loss (self , word , word_mask , wordchars , wordchars_mask , upos , xpos , ufeats , pretrained , lemma , head , deprel , word_orig_idx , sentlens , wordlens , text ):
595647 # lstm_outputs will be a list of tensors for each sentence
596648 # max(len) x args['hidden_dim']*2
@@ -603,18 +655,20 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
603655 while len (states ) > 0 :
604656 iteration += 1
605657 #print("ITERATION %d" % iteration)
606- output_hx , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0 = self .forward (states )
658+ output_hx , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0 , left_arc_hxs , right_arc_hxs = self .forward (states )
607659 gold_transitions = [state .gold_sequence [len (state .transitions )] for state in states ]
608660
609661 iteration_loss = self .calculate_iteration_loss (states , gold_transitions , output_hx , left_deprels , right_deprels )
610662 total_loss += iteration_loss
611663
664+ arc_embeddings = self .extract_arc_embeddings (states , gold_transitions , left_arc_hxs , right_arc_hxs )
612665 states = self .update_subtree_embeddings (states , gold_transitions )
613666 states = [gold_transition .apply (state ) for state , gold_transition in zip (states , gold_transitions )]
614667 for state_idx , state in enumerate (states ):
615668 # TODO: can this be moved into .apply()
616669 state .transition_lstm_embeddings .append (TransitionLSTMEmbedding (transition_h0 [:, state_idx , :], transition_c0 [:, state_idx , :]))
617670 self .update_partial_tree_lstm (states , range (len (states )), partial_tree_h0 , partial_tree_c0 )
671+ self .update_arc_embedding_lstm (states , arc_embeddings )
618672 states = [state for state in states if not isinstance (state .transitions [- 1 ], Finalize )]
619673
620674 return total_loss
@@ -672,8 +726,11 @@ def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats
672726 while len (states ) > 0 :
673727 iteration += 1
674728 #print("ITERATION %d" % iteration)
675- output_hx , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0 = self .forward (states )
729+ output_hx , left_deprels , right_deprels , transition_h0 , transition_c0 , partial_tree_h0 , partial_tree_c0 , left_arc_hxs , right_arc_hxs = self .forward (states )
676730 transitions = self .choose_transitions (self .relations , states , output_hx , left_deprels , right_deprels )
731+
732+ arc_embeddings = self .extract_arc_embeddings (states , transitions , left_arc_hxs , right_arc_hxs )
733+
677734 #print(transitions[0])
678735 #print(len(states[0].subtree_lstm_embeddings),
679736 # len(states[0].current_heads), states[0].current_heads)
@@ -689,6 +746,7 @@ def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats
689746 # TODO: can this be moved into .apply()
690747 state .transition_lstm_embeddings .append (TransitionLSTMEmbedding (transition_h0 [:, state_idx , :], transition_c0 [:, state_idx , :]))
691748 self .update_partial_tree_lstm (states , range (len (states )), partial_tree_h0 , partial_tree_c0 )
749+ self .update_arc_embedding_lstm (states , arc_embeddings )
692750 #print(len(states[0].subtree_lstm_embeddings),
693751 # len(states[0].current_heads), states[0].current_heads)
694752 #if len(states[0].subtree_lstm_embeddings) > 0:
@@ -727,7 +785,9 @@ def build_initial_states(self, head, deprel, text, lstm_outputs, sentlens):
727785 # the sentences are all prepended with root
728786 # which is fine, since we need an embedding for word 0
729787 state = state ._replace (word_embeddings = lstm_output ,
730- subtree_embeddings = {})
788+ subtree_embeddings = {},
789+ arc_embeddings = [],
790+ arc_lstm_embeddings = [ArcLSTMEmbedding (self .arc_embedding_start , self .arc_embedding_h0 , self .arc_embedding_c0 )])
731791 updated_states .append (state )
732792 return updated_states
733793
@@ -792,7 +852,9 @@ def forward(self, model_states):
792852 model_transition_c0 = [x [4 ] for x in model_forwards ]
793853 model_partial_tree_h0 = [x [5 ] for x in model_forwards ]
794854 model_partial_tree_c0 = [x [6 ] for x in model_forwards ]
795- return output_hx , left_deprels , right_deprels , model_transition_h0 , model_transition_c0 , model_partial_tree_h0 , model_partial_tree_c0
855+ model_left_arc_hx = [x [7 ] for x in model_forwards ]
856+ model_right_arc_hx = [x [8 ] for x in model_forwards ]
857+ return output_hx , left_deprels , right_deprels , model_transition_h0 , model_transition_c0 , model_partial_tree_h0 , model_partial_tree_c0 , model_left_arc_hx , model_right_arc_hx
796858
797859 def predict (self , word , word_mask , wordchars , wordchars_mask , upos , xpos , ufeats , pretrained , lemma , head , deprel , word_orig_idx , sentlens , wordlens , text ):
798860 device = self .get_device ()
@@ -812,7 +874,8 @@ def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats
812874
813875 # output_hx, left_deprels, and right_deprels are already collapsed into one summed value
814876 # the transition and partial_tree vectors are a list of M items long (M=#models)
815- output_hx , left_deprels , right_deprels , model_transition_h0 , model_transition_c0 , model_partial_tree_h0 , model_partial_tree_c0 = self .forward (model_states )
877+ output_hx , left_deprels , right_deprels , model_transition_h0 , model_transition_c0 , model_partial_tree_h0 , model_partial_tree_c0 , model_left_arc_hx , model_right_arc_hx = self .forward (model_states )
878+ # TODO: use the left & right arcs
816879
817880 transitions = TransitionParser .choose_transitions (self .models [0 ].relations , model_states [0 ], output_hx , left_deprels , right_deprels )
818881
0 commit comments