Skip to content

Commit 7ba55d2

Browse files
committed
Begin attempting to port the distance calculation to the transition parser
Connect the distance to the outputs using a flag for the scaling factor
1 parent 421992d commit 7ba55d2

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

stanza/models/depparse/transition/model.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import torch
44
from torch import nn
5+
import torch.nn.functional as F
56
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
67

8+
from stanza.models.common.biaffine import DeepBiaffineScorer
79
from stanza.models.common.utils import build_nonlinearity, unsort
810
from stanza.models.common.vocab import VOCAB_PREFIX_SIZE
911
from stanza.models.depparse.model import BaseParser, EmbeddingParser
@@ -211,6 +213,10 @@ def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_mod
211213
self.transition_loss_function = nn.CrossEntropyLoss(reduction='sum')
212214
self.deprel_loss_function = nn.CrossEntropyLoss(reduction='sum')
213215

216+
# self.args['distance_output_dim']
217+
self.distance = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=self.args['dropout'])
218+
#self.distance_expansion = nn.Linear(self.args['distance_output_dim'], self.merge_hidden_dim)
219+
214220
def forward(self, states):
215221
"""
216222
Builds a list of logits for the different operations, including a separate one for each Left and Right merge
@@ -300,7 +306,8 @@ def forward(self, states):
300306
attachment_input_left = attachment_input.expand(state.word_position, attachment_input.shape[0])
301307
left_arc_hx = torch.cat([attachment_input_left, attachment_embeddings_left], axis=1)
302308
left_arc_hx = self.merge_output_left(left_arc_hx)
303-
left_output = self.output_left_transition(self.drop(self.nonlinearity(left_arc_hx)))
309+
distance_left = state.distance[0, 1:state.word_position+1, state.current_heads[-1]].unsqueeze(1).detach()
310+
left_output = self.output_left_transition(self.drop(self.nonlinearity(left_arc_hx))) + distance_left * self.args['distance_factor']
304311
left_deprel = self.output_left_deprel(self.drop(self.nonlinearity(left_arc_hx)))
305312

306313
# truncate the outputs to only be the current heads,
@@ -311,7 +318,8 @@ def forward(self, states):
311318
attachment_input_right = attachment_input.unsqueeze(0).expand(current_heads.shape[0], attachment_input.shape[0])
312319
right_arc_hx = torch.cat([attachment_input_right, attachment_embeddings_right], axis=1)
313320
right_arc_hx = self.merge_output_right(right_arc_hx)
314-
right_output = self.output_right_transition(self.drop(self.nonlinearity(right_arc_hx)))
321+
distance_right = state.distance[0, state.current_heads[-1], :][current_heads].unsqueeze(1).detach()
322+
right_output = self.output_right_transition(self.drop(self.nonlinearity(right_arc_hx))) + distance_right * self.args['distance_factor']
315323
right_deprel = self.output_right_deprel(self.drop(self.nonlinearity(right_arc_hx)))
316324
final_output[state_idx] = [final_output[state_idx][0], left_output.squeeze(1), right_output.squeeze(1)]
317325
left_deprels.append(left_deprel)
@@ -598,6 +606,10 @@ def loss(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, p
598606
states = self.build_initial_states(head, deprel, text, lstm_outputs, sentlens)
599607

600608
total_loss = 0
609+
for state, sentence_head in zip(states, head):
610+
dist_kld = torch.gather(state.distance[0, 1:, :], 1, sentence_head[:state.num_words].unsqueeze(1))
611+
# definitely not +=... that model is completely broken
612+
total_loss -= dist_kld.sum()
601613

602614
iteration = 0
603615
while len(states) > 0:
@@ -722,12 +734,24 @@ def build_initial_states(self, head, deprel, text, lstm_outputs, sentlens):
722734
else:
723735
states = [state_from_text(sentence) for sentence in text]
724736
updated_states = []
737+
725738
# TODO: list comprehension?
726739
for state, lstm_output, sentlen in zip(states, lstm_outputs, sentlens):
727740
# the sentences are all prepended with root
728741
# which is fine, since we need an embedding for word 0
742+
# for distance, the graph parser uses the extra space with the root
743+
# TODO: stack the distance operation
744+
head_offset = (torch.arange(sentlen, device=lstm_output.device).view(1, 1, -1) -
745+
torch.arange(sentlen, device=lstm_output.device).view(1, -1, 1))
746+
distance_scores = self.distance(self.drop(lstm_output[:sentlen].unsqueeze(0)),
747+
self.drop(lstm_output[:sentlen].unsqueeze(0))).squeeze(3).squeeze(0)
748+
distance_pred = 1 + F.softplus(distance_scores)
749+
distance_target = torch.abs(head_offset)
750+
distance_kld = -torch.log((distance_target.float() - distance_pred)**2/2 + 1)
751+
729752
state = state._replace(word_embeddings=lstm_output,
730-
subtree_embeddings={})
753+
subtree_embeddings={},
754+
distance=distance_kld)
731755
updated_states.append(state)
732756
return updated_states
733757

stanza/models/depparse/transition/state.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
# transition_lstm_embeddings is a list of the above TransitionLSTMEmbedding namedtuple - one per transition
2222
State = namedtuple('State', ['transitions', 'parsed_graph', 'word_position', 'num_words', 'current_heads',
2323
'gold_graph', 'gold_sequence', 'word_embeddings', 'subtree_embeddings',
24-
'transition_lstm_embeddings', 'subtree_lstm_embeddings'])
24+
'transition_lstm_embeddings', 'subtree_lstm_embeddings',
25+
'distance'])
2526

2627
def is_nonproj(gold_graph, node, pred):
2728
for middle in range(node+1, pred):
@@ -33,7 +34,7 @@ def is_nonproj(gold_graph, node, pred):
3334

3435
def build_gold_sequence(gold_graph):
3536
num_words = len(gold_graph.nodes()) - 1
36-
state = State([], nx.DiGraph(), 0, num_words, [], None, None, None, None, [], [])
37+
state = State([], nx.DiGraph(), 0, num_words, [], None, None, None, None, [], [], None)
3738

3839
# determine which arcs are non-projective
3940
# key is the head, value is a set of the children which are non-proj
@@ -123,7 +124,7 @@ def state_from_graph(gold_graph):
123124

124125
gold_sequence = build_gold_sequence(gold_graph)
125126
num_words = len(gold_graph.nodes()) - 1
126-
return State(transitions, empty_graph, 0, num_words, [], gold_graph, gold_sequence, None, None, [], [])
127+
return State(transitions, empty_graph, 0, num_words, [], gold_graph, gold_sequence, None, None, [], [], None)
127128

128129
def from_gold(sentence):
129130
gold_graph = nx.DiGraph()
@@ -160,4 +161,4 @@ def state_from_text(text):
160161
transitions = []
161162
num_words = len(text)
162163
empty_graph = nx.DiGraph()
163-
return State(transitions, empty_graph, 0, num_words, [], None, None, None, None, [], [])
164+
return State(transitions, empty_graph, 0, num_words, [], None, None, None, None, [], [], None)

stanza/models/parser.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ def build_argparse():
8989
parser.add_argument('--char_hidden_dim', type=int, default=400)
9090
parser.add_argument('--deep_biaff_hidden_dim', type=int, default=400)
9191
parser.add_argument('--deep_biaff_output_dim', type=int, default=160)
92+
parser.add_argument('--distance_output_dim', type=int, default=1)
93+
parser.add_argument('--distance_factor', type=float, default=0.1, help="How much weight to put on the distance")
9294
# As an additional option, we implement arc embeddings
9395
# described in https://arxiv.org/pdf/2501.09451
9496
# Scaling Graph-Based Dependency Parsing with Arc Vectorization and Attention-Based Refinement

0 commit comments

Comments (0)