Skip to content

Commit 28e814a

Browse files
committed
Use arc_embeddings as an input when deciding the next transition
Add a place for storing the arc embeddings in the State. Saves the arc embeddings used for each dependency chosen, runs an LSTM over them, then includes them in the merged output layers for the next iteration.
1 parent 1f06efa commit 28e814a

File tree

2 files changed

+81
-17
lines changed

2 files changed

+81
-17
lines changed

stanza/models/depparse/transition/model.py

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from stanza.models.common.utils import build_nonlinearity, unsort
88
from stanza.models.common.vocab import VOCAB_PREFIX_SIZE
99
from stanza.models.depparse.model import BaseParser, EmbeddingParser
10-
from stanza.models.depparse.transition.state import state_from_text, states_from_data_batch, TransitionLSTMEmbedding, SubtreeLSTMEmbedding
10+
from stanza.models.depparse.transition.state import state_from_text, states_from_data_batch, TransitionLSTMEmbedding, SubtreeLSTMEmbedding, ArcLSTMEmbedding
1111
from stanza.models.depparse.transition.transitions import Shift, Finalize, ProjectiveLeft, ProjectiveRight, NonprojectiveLeft, NonprojectiveRight
1212

1313
# A few notes on some experiments crossvalidating the hyperparameters for this model
@@ -115,16 +115,18 @@ def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_mod
115115
self.transition_subtree_nonlinearity = build_nonlinearity(self.args.get('transition_subtree_nonlinearity'))
116116
self.drop = nn.Dropout(self.args['dropout'])
117117

118+
self.merge_words_output_dim = self.args['transition_merge_words_output_dim']
119+
self.merge_hidden_dim = self.transition_hidden_dim + self.args['hidden_dim'] + self.merge_words_output_dim
120+
118121
# the bidirectional LSTM is x2, adding in the partial trees is another x1
119-
self.word_hidden_dim = self.transition_hidden_dim + self.args['hidden_dim'] * 3
122+
# the arc embeddings take up merge_hidden_dim
123+
self.word_hidden_dim = self.transition_hidden_dim + self.args['hidden_dim'] * 3 + self.merge_hidden_dim
120124
self.word_output_layers = nn.Sequential(self.nonlinearity,
121125
self.drop,
122126
nn.Linear(self.word_hidden_dim, self.word_hidden_dim),
123127
self.nonlinearity,
124128
self.drop,
125129
nn.Linear(self.word_hidden_dim, self.word_hidden_dim))
126-
self.merge_words_output_dim = self.args['transition_merge_words_output_dim']
127-
self.merge_hidden_dim = self.transition_hidden_dim + self.args['hidden_dim'] + self.merge_words_output_dim
128130
# Splitting this into a left and right version is close,
129131
# but seems to be somewhat more accurate than one layer
130132
# 5 model dev avg LAS baseline merge-two-sides
@@ -155,6 +157,11 @@ def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_mod
155157
self.drop,
156158
nn.Linear(self.merge_hidden_dim, self.merge_hidden_dim))
157159

160+
self.arc_embedding_lstm = nn.LSTM(input_size=self.merge_hidden_dim, hidden_size=self.merge_hidden_dim, num_layers=self.args['num_layers'], dropout=self.args['dropout'])
161+
self.arc_embedding_start = nn.Parameter(torch.zeros(self.merge_hidden_dim))
162+
self.arc_embedding_h0 = nn.Parameter(torch.zeros(self.args['num_layers'], self.merge_hidden_dim))
163+
self.arc_embedding_c0 = nn.Parameter(torch.zeros(self.args['num_layers'], self.merge_hidden_dim))
164+
158165
self.output_basic = nn.Linear(self.word_hidden_dim, 2)
159166
self.output_left_transition = nn.Linear(self.merge_hidden_dim, 1)
160167
self.output_right_transition = nn.Linear(self.merge_hidden_dim, 1)
@@ -274,19 +281,26 @@ def forward(self, states):
274281
partial_tree_embeddings = partial_tree_embeddings.squeeze(0)
275282
#print(torch.linalg.norm(partial_tree_embeddings))
276283

284+
arc_embeddings = [state.arc_lstm_embeddings[-1].hx for state in states]
285+
arc_embeddings = torch.stack(arc_embeddings, dim=0)
286+
277287
word_embeddings = [state.word_embeddings[state.word_position] for state in states]
278288
word_embeddings = torch.stack(word_embeddings)
279-
output_hx = torch.cat([transition_embeddings, partial_tree_embeddings, word_embeddings], dim=1)
289+
290+
output_hx = torch.cat([transition_embeddings, partial_tree_embeddings, arc_embeddings, word_embeddings], dim=1)
280291
output_hx = self.word_output_layers(output_hx)
281292
# batch size x 2 - Shift or Finalize
282293
basic_output = self.output_basic(self.drop(self.nonlinearity(output_hx)))
283294
final_output = [[x] for x in basic_output]
284295
left_deprels = []
285296
right_deprels = []
286-
297+
left_arc_hxs = []
298+
right_arc_hxs = []
287299
for state_idx, state in enumerate(states):
288300
left_deprel = None
289301
right_deprel = None
302+
left_arc_hx = None
303+
right_arc_hx = None
290304
if len(state.current_heads) > 1:
291305
# TODO: add a position embedding for the projective / non-projective attachments?
292306
attachment_embeddings = [torch.cat([state.subtree_embeddings[x], state.subtree_embeddings[state.current_heads[-1]]])
@@ -316,8 +330,10 @@ def forward(self, states):
316330
final_output[state_idx] = [final_output[state_idx][0], left_output.squeeze(1), right_output.squeeze(1)]
317331
left_deprels.append(left_deprel)
318332
right_deprels.append(right_deprel)
333+
left_arc_hxs.append(left_arc_hx)
334+
right_arc_hxs.append(right_arc_hx)
319335
final_output = [torch.cat(x) for x in final_output]
320-
return final_output, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0
336+
return final_output, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0, left_arc_hxs, right_arc_hxs
321337

322338
def update_subtree_embeddings(self, states, transitions):
323339
embeddings = []
@@ -520,6 +536,41 @@ def update_partial_tree_lstm(self, states, state_idxs, partial_tree_h0, partial_
520536
states = new_states
521537
state_idxs = new_state_idxs
522538

539+
def extract_arc_embeddings(self, states, transitions, left_arc_hxs, right_arc_hxs):
540+
arc_embeddings = []
541+
for state_idx, (state, transition) in enumerate(zip(states, transitions)):
542+
if isinstance(transition, Shift):
543+
arc_embeddings.append(None)
544+
elif isinstance(transition, Finalize):
545+
arc_embeddings.append(None)
546+
elif isinstance(transition, ProjectiveLeft) or isinstance(transition, NonprojectiveLeft):
547+
if isinstance(transition, ProjectiveLeft):
548+
head = state.current_heads[-2]
549+
else:
550+
head = transition.word_idx
551+
arc_embeddings.append(left_arc_hxs[state_idx][head-1])
552+
elif isinstance(transition, ProjectiveRight) or isinstance(transition, NonprojectiveRight):
553+
if isinstance(transition, ProjectiveRight):
554+
head = len(state.current_heads) - 2
555+
else:
556+
head = state.current_heads.index(transition.word_idx)
557+
arc_embeddings.append(right_arc_hxs[state_idx][head])
558+
return arc_embeddings
559+
560+
def update_arc_embedding_lstm(self, states, arc_embeddings):
561+
arc_states = [state for state_idx, state in enumerate(states) if arc_embeddings[state_idx] is not None]
562+
arc_embeddings = [arc for arc in arc_embeddings if arc is not None]
563+
if len(arc_embeddings) == 0:
564+
return
565+
arc_embedding_hx = torch.stack(arc_embeddings, dim=0).unsqueeze(0)
566+
arc_embedding_h0 = torch.stack([state.arc_lstm_embeddings[-1].h0 for state in arc_states], dim=1)
567+
arc_embedding_c0 = torch.stack([state.arc_lstm_embeddings[-1].c0 for state in arc_states], dim=1)
568+
arc_embedding_hx, (arc_embedding_h0, arc_embedding_c0) = self.arc_embedding_lstm(arc_embedding_hx, (arc_embedding_h0, arc_embedding_c0))
569+
for state_idx, state in enumerate(arc_states):
570+
state.arc_embeddings.append(arc_embeddings[state_idx])
571+
state.arc_lstm_embeddings.append(ArcLSTMEmbedding(arc_embedding_hx[0, state_idx, :],
572+
arc_embedding_h0[:, state_idx, :],
573+
arc_embedding_c0[:, state_idx, :]))
523574

524575
def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text):
525576
# lstm_outputs will be a list of tensors for each sentence
@@ -534,7 +585,7 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
534585
while len(states) > 0:
535586
iteration += 1
536587
#print("ITERATION %d" % iteration)
537-
output_hx, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0 = self.forward(states)
588+
output_hx, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0, left_arc_hxs, right_arc_hxs = self.forward(states)
538589
gold_transitions = [state.gold_sequence[len(state.transitions)] for state in states]
539590
new_states = []
540591
for state_idx, (state, gold_transition) in enumerate(zip(states, gold_transitions)):
@@ -595,12 +646,14 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
595646
total_loss += self.deprel_loss_function(deprel_output, deprel_one_hot)
596647
one_hot = one_hot.to(device)
597648
total_loss += self.transition_loss_function(output_hx[state_idx], one_hot)
649+
arc_embeddings = self.extract_arc_embeddings(states, gold_transitions, left_arc_hxs, right_arc_hxs)
598650
states = self.update_subtree_embeddings(states, gold_transitions)
599651
states = [gold_transition.apply(state) for state, gold_transition in zip(states, gold_transitions)]
600652
for state_idx, state in enumerate(states):
601653
# TODO: can this be moved into .apply()
602654
state.transition_lstm_embeddings.append(TransitionLSTMEmbedding(transition_h0[:, state_idx, :], transition_c0[:, state_idx, :]))
603655
self.update_partial_tree_lstm(states, range(len(states)), partial_tree_h0, partial_tree_c0)
656+
self.update_arc_embedding_lstm(states, arc_embeddings)
604657
states = [state for state in states if not isinstance(state.transitions[-1], Finalize)]
605658

606659
return total_loss
@@ -658,8 +711,11 @@ def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats
658711
while len(states) > 0:
659712
iteration += 1
660713
#print("ITERATION %d" % iteration)
661-
output_hx, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0 = self.forward(states)
714+
output_hx, left_deprels, right_deprels, transition_h0, transition_c0, partial_tree_h0, partial_tree_c0, left_arc_hxs, right_arc_hxs = self.forward(states)
662715
transitions = self.choose_transitions(self.relations, states, output_hx, left_deprels, right_deprels)
716+
717+
arc_embeddings = self.extract_arc_embeddings(states, transitions, left_arc_hxs, right_arc_hxs)
718+
663719
#print(transitions[0])
664720
#print(len(states[0].subtree_lstm_embeddings),
665721
# len(states[0].current_heads), states[0].current_heads)
@@ -675,6 +731,7 @@ def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats
675731
# TODO: can this be moved into .apply()
676732
state.transition_lstm_embeddings.append(TransitionLSTMEmbedding(transition_h0[:, state_idx, :], transition_c0[:, state_idx, :]))
677733
self.update_partial_tree_lstm(states, range(len(states)), partial_tree_h0, partial_tree_c0)
734+
self.update_arc_embedding_lstm(states, arc_embeddings)
678735
#print(len(states[0].subtree_lstm_embeddings),
679736
# len(states[0].current_heads), states[0].current_heads)
680737
#if len(states[0].subtree_lstm_embeddings) > 0:
@@ -713,7 +770,9 @@ def build_initial_states(self, head, deprel, text, lstm_outputs, sentlens):
713770
# the sentences are all prepended with root
714771
# which is fine, since we need an embedding for word 0
715772
state = state._replace(word_embeddings=lstm_output,
716-
subtree_embeddings={})
773+
subtree_embeddings={},
774+
arc_embeddings=[],
775+
arc_lstm_embeddings=[ArcLSTMEmbedding(self.arc_embedding_start, self.arc_embedding_h0, self.arc_embedding_c0)])
717776
updated_states.append(state)
718777
return updated_states
719778

@@ -778,7 +837,9 @@ def forward(self, model_states):
778837
model_transition_c0 = [x[4] for x in model_forwards]
779838
model_partial_tree_h0 = [x[5] for x in model_forwards]
780839
model_partial_tree_c0 = [x[6] for x in model_forwards]
781-
return output_hx, left_deprels, right_deprels, model_transition_h0, model_transition_c0, model_partial_tree_h0, model_partial_tree_c0
840+
model_left_arc_hx = [x[7] for x in model_forwards]
841+
model_right_arc_hx = [x[8] for x in model_forwards]
842+
return output_hx, left_deprels, right_deprels, model_transition_h0, model_transition_c0, model_partial_tree_h0, model_partial_tree_c0, model_left_arc_hx, model_right_arc_hx
782843

783844
def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text):
784845
device = self.get_device()
@@ -798,7 +859,8 @@ def predict(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats
798859

799860
# output_hx, left_deprels, and right_deprels are already collapsed into one summed value
800861
# the transition and partial_tree vectors are a list of M items long (M=#models)
801-
output_hx, left_deprels, right_deprels, model_transition_h0, model_transition_c0, model_partial_tree_h0, model_partial_tree_c0 = self.forward(model_states)
862+
output_hx, left_deprels, right_deprels, model_transition_h0, model_transition_c0, model_partial_tree_h0, model_partial_tree_c0, model_left_arc_hx, model_right_arc_hx = self.forward(model_states)
863+
# TODO: use the left & right arcs
802864

803865
transitions = TransitionParser.choose_transitions(self.models[0].relations, model_states[0], output_hx, left_deprels, right_deprels)
804866

stanza/models/depparse/transition/state.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@
1414

1515
TransitionLSTMEmbedding = namedtuple('TransitionLSTMEmbedding', 'h0 c0')
1616
SubtreeLSTMEmbedding = namedtuple('SubtreeLSTMEmbedding', 'h0 c0')
17+
# TODO: is it possible to have TransitionLSTMEmbedding and SubtreeLSTMEmbedding also carry hx along with their h0 and c0?
18+
ArcLSTMEmbedding = namedtuple('ArcLSTMEmbedding', 'hx h0 c0')
1719

1820
# transitions and parsed_graph represent the current state of a parse
1921
# gold_graph and gold_sequence are gold, if that information exists
2022
# current_heads is a list of the word IDs for the heads of the subtrees
2123
# transition_lstm_embeddings is a list of the above TransitionLSTMEmbedding namedtuple - one per transition
2224
State = namedtuple('State', ['transitions', 'parsed_graph', 'word_position', 'num_words', 'current_heads',
23-
'gold_graph', 'gold_sequence', 'word_embeddings', 'subtree_embeddings',
24-
'transition_lstm_embeddings', 'subtree_lstm_embeddings'])
25+
'gold_graph', 'gold_sequence', 'word_embeddings', 'subtree_embeddings', 'arc_embeddings',
26+
'transition_lstm_embeddings', 'subtree_lstm_embeddings', 'arc_lstm_embeddings'])
2527

2628
def is_nonproj(gold_graph, node, pred):
2729
for middle in range(node+1, pred):
@@ -33,7 +35,7 @@ def is_nonproj(gold_graph, node, pred):
3335

3436
def build_gold_sequence(gold_graph):
3537
num_words = len(gold_graph.nodes()) - 1
36-
state = State([], nx.DiGraph(), 0, num_words, [], None, None, None, None, [], [])
38+
state = State([], nx.DiGraph(), 0, num_words, [], None, None, None, None, None, [], [], [])
3739

3840
# determine which arcs are non-projective
3941
# key is the head, value is a set of the children which are non-proj
@@ -123,7 +125,7 @@ def state_from_graph(gold_graph):
123125

124126
gold_sequence = build_gold_sequence(gold_graph)
125127
num_words = len(gold_graph.nodes()) - 1
126-
return State(transitions, empty_graph, 0, num_words, [], gold_graph, gold_sequence, None, None, [], [])
128+
return State(transitions, empty_graph, 0, num_words, [], gold_graph, gold_sequence, None, None, None, [], [], [])
127129

128130
def from_gold(sentence):
129131
gold_graph = nx.DiGraph()
@@ -160,4 +162,4 @@ def state_from_text(text):
160162
transitions = []
161163
num_words = len(text)
162164
empty_graph = nx.DiGraph()
163-
return State(transitions, empty_graph, 0, num_words, [], None, None, None, None, [], [])
165+
return State(transitions, empty_graph, 0, num_words, [], None, None, None, None, None, [], [], [])

0 commit comments

Comments
 (0)