
Commit eb54005

Use arc_embeddings as an input when deciding the next transition
Add a place for storing the arc embeddings in the State. Saves the arc embeddings used for each dependency chosen, runs an LSTM over them, then includes them in the merged output layers for the next iteration.
1 parent 533c1fa commit eb54005

2 files changed: +82 -17 lines changed

2 files changed

+82
-17
lines changed

stanza/models/depparse/transition/model.py

Lines changed: 75 additions & 12 deletions
@@ -7,7 +7,7 @@
 from stanza.models.common.utils import build_nonlinearity, unsort
 from stanza.models.common.vocab import VOCAB_PREFIX_SIZE
 from stanza.models.depparse.model import BaseParser, EmbeddingParser
-from stanza.models.depparse.transition.state import state_from_text, states_from_data_batch, TransitionLSTMEmbedding, SubtreeLSTMEmbedding
+from stanza.models.depparse.transition.state import state_from_text, states_from_data_batch, TransitionLSTMEmbedding, SubtreeLSTMEmbedding, ArcLSTMEmbedding
 from stanza.models.depparse.transition.transitions import Shift, Finalize, ProjectiveLeft, ProjectiveRight, NonprojectiveLeft, NonprojectiveRight
 
 # A few notes on some experiments crossvalidating the hyperparameters for this model
@@ -115,16 +115,18 @@ def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_mod
         self.transition_subtree_nonlinearity = build_nonlinearity(self.args.get('transition_subtree_nonlinearity'))
         self.drop = nn.Dropout(self.args['dropout'])
 
+        self.merge_words_output_dim = self.args['transition_merge_words_output_dim']
+        self.merge_hidden_dim = self.transition_hidden_dim + self.args['hidden_dim'] + self.merge_words_output_dim
+
         # the bidirectional LSTM is x2, adding in the partial trees is another x1
-        self.word_hidden_dim = self.transition_hidden_dim + self.args['hidden_dim'] * 3
+        # the arc embeddings take up merge_hidden_dim
+        self.word_hidden_dim = self.transition_hidden_dim + self.args['hidden_dim'] * 3 + self.merge_hidden_dim
         self.word_output_layers = nn.Sequential(self.nonlinearity,
                                                 self.drop,
                                                 nn.Linear(self.word_hidden_dim, self.word_hidden_dim),
                                                 self.nonlinearity,
                                                 self.drop,
                                                 nn.Linear(self.word_hidden_dim, self.word_hidden_dim))
-        self.merge_words_output_dim = self.args['transition_merge_words_output_dim']
-        self.merge_hidden_dim = self.transition_hidden_dim + self.args['hidden_dim'] + self.merge_words_output_dim
         # Splitting this into a left and right version is close,
         # but seems to be somewhat more accurate than one layer
         # 5 model dev avg LAS baseline merge-two-sides
@@ -155,6 +157,11 @@ def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_mod
                                                  self.drop,
                                                  nn.Linear(self.merge_hidden_dim, self.merge_hidden_dim))
 
+        self.arc_embedding_lstm = nn.LSTM(input_size=self.merge_hidden_dim, hidden_size=self.merge_hidden_dim, num_layers=self.args['num_layers'], dropout=self.args['dropout'])
+        self.arc_embedding_start = nn.Parameter(torch.zeros(self.merge_hidden_dim))
+        self.arc_embedding_h0 = nn.Parameter(torch.zeros(self.args['num_layers'], self.merge_hidden_dim))
+        self.arc_embedding_c0 = nn.Parameter(torch.zeros(self.args['num_layers'], self.merge_hidden_dim))
+
         self.output_basic = nn.Linear(self.word_hidden_dim, 2)
         self.output_left_transition = nn.Linear(self.merge_hidden_dim, 1)
         self.output_right_transition = nn.Linear(self.merge_hidden_dim, 1)
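These four members give every parser state its own arc LSTM: a learned start input plus a learned initial hidden and cell state. A minimal runnable sketch of that pattern with toy dimensions (nothing here is the actual stanza code):

import torch
import torch.nn as nn

num_layers, merge_hidden_dim = 2, 8    # toy stand-ins for the real config values
lstm = nn.LSTM(input_size=merge_hidden_dim, hidden_size=merge_hidden_dim, num_layers=num_layers)
start = nn.Parameter(torch.zeros(merge_hidden_dim))            # arc_embedding_start
h0 = nn.Parameter(torch.zeros(num_layers, merge_hidden_dim))   # arc_embedding_h0
c0 = nn.Parameter(torch.zeros(num_layers, merge_hidden_dim))   # arc_embedding_c0

# one step for a single fresh state: the input is (seq_len=1, batch=1, dim),
# the recurrent state is (num_layers, batch=1, dim); hx[0, 0], hn[:, 0], cn[:, 0]
# have the per-state shapes an ArcLSTMEmbedding record stores as hx, h0, c0
hx, (hn, cn) = lstm(start.view(1, 1, -1), (h0.unsqueeze(1), c0.unsqueeze(1)))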
@@ -274,19 +281,26 @@ def forward(self, states):
         partial_tree_embeddings = partial_tree_embeddings.squeeze(0)
         #print(torch.linalg.norm(partial_tree_embeddings))
 
+        arc_embeddings = [state.arc_lstm_embeddings[-1].hx for state in states]
+        arc_embeddings = torch.stack(arc_embeddings, dim=0)
+
         word_embeddings = [state.word_embeddings[state.word_position] for state in states]
         word_embeddings = torch.stack(word_embeddings)
-        output_hx = torch.cat([transition_embeddings, partial_tree_embeddings, word_embeddings], dim=1)
+
+        output_hx = torch.cat([transition_embeddings, partial_tree_embeddings, arc_embeddings, word_embeddings], dim=1)
         output_hx = self.word_output_layers(output_hx)
         # batch size x 2 - Shift or Finalize
         basic_output = self.output_basic(self.drop(self.nonlinearity(output_hx)))
         final_output = [[x] for x in basic_output]
         left_deprels = []
         right_deprels = []
-
+        left_arc_hxs = []
+        right_arc_hxs = []
         for state_idx, state in enumerate(states):
             left_deprel = None
             right_deprel = None
+            left_arc_hx = None
+            right_arc_hx = None
             if len(state.current_heads) > 1:
                 # TODO: add a position embedding for the projective / non-projective attachments?
                 attachment_embeddings = [torch.cat([state.subtree_embeddings[x], state.subtree_embeddings[state.current_heads[-1]]])
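The extra arc_embeddings term in this concatenation is what the widened word_hidden_dim above accounts for. A toy shape check with invented dimensions (the real widths come from the config):

import torch

batch, t_dim, h_dim, m_dim = 4, 6, 10, 16             # toy stand-ins
transition_embeddings = torch.randn(batch, t_dim)     # transition_hidden_dim
partial_tree_embeddings = torch.randn(batch, h_dim)   # hidden_dim (partial trees)
arc_embeddings = torch.randn(batch, m_dim)            # merge_hidden_dim
word_embeddings = torch.randn(batch, 2 * h_dim)       # bidirectional LSTM is x2

output_hx = torch.cat([transition_embeddings, partial_tree_embeddings,
                       arc_embeddings, word_embeddings], dim=1)
# matches word_hidden_dim = transition_hidden_dim + hidden_dim * 3 + merge_hidden_dim
assert output_hx.shape == (batch, t_dim + 3 * h_dim + m_dim)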
@@ -316,8 +330,10 @@
                 final_output[state_idx] = [final_output[state_idx][0], left_output.squeeze(1), right_output.squeeze(1)]
             left_deprels.append(left_deprel)
             right_deprels.append(right_deprel)
+            left_arc_hxs.append(left_arc_hx)
+            right_arc_hxs.append(right_arc_hx)
         final_output = [torch.cat(x) for x in final_output]
-        return final_output, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0
+        return final_output, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0, left_arc_hxs, right_arc_hxs
 
     def update_subtree_embeddings(self, states, transitions):
         embeddings = []
@@ -591,6 +607,42 @@ def calculate_iteration_loss(self, states, gold_transitions, output_hx, left_dep
             total_loss += self.deprel_loss_function(deprel_hx, deprel_one_hots)
         return total_loss
 
+    def extract_arc_embeddings(self, states, transitions, left_arc_hxs, right_arc_hxs):
+        arc_embeddings = []
+        for state_idx, (state, transition) in enumerate(zip(states, transitions)):
+            if isinstance(transition, Shift):
+                arc_embeddings.append(None)
+            elif isinstance(transition, Finalize):
+                arc_embeddings.append(None)
+            elif isinstance(transition, ProjectiveLeft) or isinstance(transition, NonprojectiveLeft):
+                if isinstance(transition, ProjectiveLeft):
+                    head = state.current_heads[-2]
+                else:
+                    head = transition.word_idx
+                arc_embeddings.append(left_arc_hxs[state_idx][head-1])
+            elif isinstance(transition, ProjectiveRight) or isinstance(transition, NonprojectiveRight):
+                if isinstance(transition, ProjectiveRight):
+                    head = len(state.current_heads) - 2
+                else:
+                    head = state.current_heads.index(transition.word_idx)
+                arc_embeddings.append(right_arc_hxs[state_idx][head])
+        return arc_embeddings
+
+    def update_arc_embedding_lstm(self, states, arc_embeddings):
+        arc_states = [state for state_idx, state in enumerate(states) if arc_embeddings[state_idx] is not None]
+        arc_embeddings = [arc for arc in arc_embeddings if arc is not None]
+        if len(arc_embeddings) == 0:
+            return
+        arc_embedding_hx = torch.stack(arc_embeddings, dim=0).unsqueeze(0)
+        arc_embedding_h0 = torch.stack([state.arc_lstm_embeddings[-1].h0 for state in arc_states], dim=1)
+        arc_embedding_c0 = torch.stack([state.arc_lstm_embeddings[-1].c0 for state in arc_states], dim=1)
+        arc_embedding_hx, (arc_embedding_h0, arc_embedding_c0) = self.arc_embedding_lstm(arc_embedding_hx, (arc_embedding_h0, arc_embedding_c0))
+        for state_idx, state in enumerate(arc_states):
+            state.arc_embeddings.append(arc_embeddings[state_idx])
+            state.arc_lstm_embeddings.append(ArcLSTMEmbedding(arc_embedding_hx[0, state_idx, :],
+                                                              arc_embedding_h0[:, state_idx, :],
+                                                              arc_embedding_c0[:, state_idx, :]))
+
     def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text):
         # lstm_outputs will be a list of tensors for each sentence
         # max(len) x args['hidden_dim']*2
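update_arc_embedding_lstm advances many independent single-step LSTMs at once by stacking each surviving state's saved (h0, c0) along the batch axis and unpacking afterwards. A self-contained sketch of that batching trick, with toy dimensions in place of merge_hidden_dim (not the stanza code itself):

import torch
import torch.nn as nn
from collections import namedtuple

ArcLSTMEmbedding = namedtuple('ArcLSTMEmbedding', 'hx h0 c0')
num_layers, dim, batch = 2, 8, 3    # toy values
lstm = nn.LSTM(input_size=dim, hidden_size=dim, num_layers=num_layers)

# one record per state, as kept in state.arc_lstm_embeddings[-1]
records = [ArcLSTMEmbedding(torch.zeros(dim),
                            torch.zeros(num_layers, dim),
                            torch.zeros(num_layers, dim)) for _ in range(batch)]
arcs = [torch.randn(dim) for _ in range(batch)]     # the arc embedding chosen per state

hx = torch.stack(arcs, dim=0).unsqueeze(0)          # (seq_len=1, batch, dim)
h0 = torch.stack([r.h0 for r in records], dim=1)    # (num_layers, batch, dim)
c0 = torch.stack([r.c0 for r in records], dim=1)
hx, (h0, c0) = lstm(hx, (h0, c0))
# unpack back into per-state records, each exactly one LSTM step later
records = [ArcLSTMEmbedding(hx[0, i], h0[:, i], c0[:, i]) for i in range(batch)]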
@@ -603,18 +655,20 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
         while len(states) > 0:
             iteration += 1
             #print("ITERATION %d" % iteration)
-            output_hx, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0 = self.forward(states)
+            output_hx, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0, left_arc_hxs, right_arc_hxs = self.forward(states)
             gold_transitions = [state.gold_sequence[len(state.transitions)] for state in states]
 
             iteration_loss = self.calculate_iteration_loss(states, gold_transitions, output_hx, left_deprels, right_deprels)
             total_loss += iteration_loss
 
+            arc_embeddings = self.extract_arc_embeddings(states, gold_transitions, left_arc_hxs, right_arc_hxs)
             states = self.update_subtree_embeddings(states, gold_transitions)
             states = [gold_transition.apply(state) for state, gold_transition in zip(states, gold_transitions)]
             for state_idx, state in enumerate(states):
                 # TODO: can this be moved into .apply()
                 state.transition_lstm_embeddings.append(TransitionLSTMEmbedding(transition_h0[:, state_idx, :], transition_c0[:, state_idx, :]))
             self.update_partial_tree_lstm(states, range(len(states)), partial_tree_h0, partial_tree_c0)
+            self.update_arc_embedding_lstm(states, arc_embeddings)
             states = [state for state in states if not isinstance(state.transitions[-1], Finalize)]
 
         return total_loss
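Note the ordering in this loop: extract_arc_embeddings runs before the transitions are applied, presumably because it indexes current_heads as they stood when the arc was scored, while update_arc_embedding_lstm runs on the post-transition states. A toy sketch of that read-apply-update ordering (invented names and a made-up transition, not the stanza API):

def extract(states):
    # reads the head stack as it stood when the transition was scored
    return [s['current_heads'][-2] for s in states]

def apply_transition(state):
    # toy transition: attaching a subtree shrinks the head stack
    return {'current_heads': state['current_heads'][:-1],
            'arc_history': state['arc_history']}

states = [{'current_heads': [0, 2, 5], 'arc_history': []}]
pending = extract(states)                       # 1. extract before applying
states = [apply_transition(s) for s in states]  # 2. apply the transitions
for s, head in zip(states, pending):            # 3. update after applying
    s['arc_history'].append(head)
assert states == [{'current_heads': [0, 2], 'arc_history': [2]}]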
@@ -672,8 +726,11 @@ def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats
         while len(states) > 0:
             iteration += 1
             #print("ITERATION %d" % iteration)
-            output_hx, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0 = self.forward(states)
+            output_hx, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0, left_arc_hxs, right_arc_hxs = self.forward(states)
             transitions = self.choose_transitions(self.relations, states, output_hx, left_deprels, right_deprels)
+
+            arc_embeddings = self.extract_arc_embeddings(states, transitions, left_arc_hxs, right_arc_hxs)
+
             #print(transitions[0])
             #print(len(states[0].subtree_lstm_embeddings),
             #      len(states[0].current_heads), states[0].current_heads)
@@ -689,6 +746,7 @@ def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats
                 # TODO: can this be moved into .apply()
                 state.transition_lstm_embeddings.append(TransitionLSTMEmbedding(transition_h0[:, state_idx, :], transition_c0[:, state_idx, :]))
             self.update_partial_tree_lstm(states, range(len(states)), partial_tree_h0, partial_tree_c0)
+            self.update_arc_embedding_lstm(states, arc_embeddings)
             #print(len(states[0].subtree_lstm_embeddings),
             #      len(states[0].current_heads), states[0].current_heads)
             #if len(states[0].subtree_lstm_embeddings) > 0:
@@ -727,7 +785,9 @@ def build_initial_states(self, head, deprel, text, lstm_outputs, sentlens):
             # the sentences are all prepended with root
             # which is fine, since we need an embedding for word 0
             state = state._replace(word_embeddings=lstm_output,
-                                   subtree_embeddings={})
+                                   subtree_embeddings={},
+                                   arc_embeddings=[],
+                                   arc_lstm_embeddings=[ArcLSTMEmbedding(self.arc_embedding_start, self.arc_embedding_h0, self.arc_embedding_c0)])
             updated_states.append(state)
         return updated_states
 
@@ -792,7 +852,9 @@ def forward(self, model_states):
         model_transition_c0 = [x[4] for x in model_forwards]
         model_partial_tree_h0 = [x[5] for x in model_forwards]
         model_partial_tree_c0 = [x[6] for x in model_forwards]
-        return output_hx, left_deprels, right_deprels, model_transition_h0, model_transition_c0, model_partial_tree_h0, model_partial_tree_c0
+        model_left_arc_hx = [x[7] for x in model_forwards]
+        model_right_arc_hx = [x[8] for x in model_forwards]
+        return output_hx, left_deprels, right_deprels, model_transition_h0, model_transition_c0, model_partial_tree_h0, model_partial_tree_c0, model_left_arc_hx, model_right_arc_hx
 
     def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text):
         device = self.get_device()
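In the ensemble, model_forwards holds one return tuple per model, so the two new values only need to be picked out by position alongside the existing ones. A toy illustration of the indexing:

# three pretend models, each returning a 9-tuple like TransitionParser.forward
model_forwards = [tuple(f'm{m}_v{i}' for i in range(9)) for m in range(3)]
model_left_arc_hx = [x[7] for x in model_forwards]   # ['m0_v7', 'm1_v7', 'm2_v7']
model_right_arc_hx = [x[8] for x in model_forwards]  # ['m0_v8', 'm1_v8', 'm2_v8']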
@@ -812,7 +874,8 @@ def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats
 
         # output_hx, left_deprels, and right_deprels are already collapsed into one summed value
         # the transition and partial_tree vectors are a list of M items long (M=#models)
-        output_hx, left_deprels, right_deprels, model_transition_h0, model_transition_c0, model_partial_tree_h0, model_partial_tree_c0 = self.forward(model_states)
+        output_hx, left_deprels, right_deprels, model_transition_h0, model_transition_c0, model_partial_tree_h0, model_partial_tree_c0, model_left_arc_hx, model_right_arc_hx = self.forward(model_states)
+        # TODO: use the left & right arcs
 
         transitions = TransitionParser.choose_transitions(self.models[0].relations, model_states[0], output_hx, left_deprels, right_deprels)
 
stanza/models/depparse/transition/state.py

Lines changed: 7 additions & 5 deletions
@@ -14,14 +14,16 @@
 
 TransitionLSTMEmbedding = namedtuple('TransitionLSTMEmbedding', 'h0 c0')
 SubtreeLSTMEmbedding = namedtuple('SubtreeLSTMEmbedding', 'h0 c0')
+# TODO: is it possible to make the others have the hx along with the h0 and c0 in TransitionLSTMEmbedding and SubtreeLSTMEmbedding
+ArcLSTMEmbedding = namedtuple('ArcLSTMEmbedding', 'hx h0 c0')
 
 # transitions and parsed_graph represent the current state of a parse
 # gold_graph and gold_sequence are gold, if that information exists
 # current_heads is a list of the word IDs for the heads of the subtrees
 # transition_lstm_embeddings is a list of the above TransitionLSTMEmbedding namedtuple - one per transition
 State = namedtuple('State', ['transitions', 'parsed_graph', 'word_position', 'num_words', 'current_heads',
-                             'gold_graph', 'gold_sequence', 'word_embeddings', 'subtree_embeddings',
-                             'transition_lstm_embeddings', 'subtree_lstm_embeddings'])
+                             'gold_graph', 'gold_sequence', 'word_embeddings', 'subtree_embeddings', 'arc_embeddings',
+                             'transition_lstm_embeddings', 'subtree_lstm_embeddings', 'arc_lstm_embeddings'])
 
 def is_nonproj(gold_graph, node, pred):
     for middle in range(node+1, pred):
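Because State is a positional namedtuple, the two new fields force an extra None and an extra [] into every constructor call below, while call sites that use _replace (like build_initial_states) only name the fields they set. A cut-down illustration:

from collections import namedtuple

MiniState = namedtuple('MiniState', ['transitions', 'arc_embeddings', 'arc_lstm_embeddings'])
state = MiniState([], None, [])            # every positional slot must be filled
state = state._replace(arc_embeddings=[])  # named update; other fields carried over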
@@ -33,7 +35,7 @@ def is_nonproj(gold_graph, node, pred):
 
 def build_gold_sequence(gold_graph):
     num_words = len(gold_graph.nodes()) - 1
-    state = State([], nx.DiGraph(), 0, num_words, [], None, None, None, None, [], [])
+    state = State([], nx.DiGraph(), 0, num_words, [], None, None, None, None, None, [], [], [])
 
     # determine which arcs are non-projective
     # key is the head, value is a set of the children which are non-proj
@@ -123,7 +125,7 @@ def state_from_graph(gold_graph):
 
     gold_sequence = build_gold_sequence(gold_graph)
     num_words = len(gold_graph.nodes()) - 1
-    return State(transitions, empty_graph, 0, num_words, [], gold_graph, gold_sequence, None, None, [], [])
+    return State(transitions, empty_graph, 0, num_words, [], gold_graph, gold_sequence, None, None, None, [], [], [])
 
 def from_gold(sentence):
     gold_graph = nx.DiGraph()
@@ -160,4 +162,4 @@ def state_from_text(text):
     transitions = []
     num_words = len(text)
     empty_graph = nx.DiGraph()
-    return State(transitions, empty_graph, 0, num_words, [], None, None, None, None, [], [])
+    return State(transitions, empty_graph, 0, num_words, [], None, None, None, None, None, [], [], [])
