-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathimdb_model.py
More file actions
133 lines (104 loc) · 7.11 KB
/
imdb_model.py
File metadata and controls
133 lines (104 loc) · 7.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
###############################################################################
# Modifier Author: Md Rizwan Parvez
# BCN Author: Wasi Ahmad
# Project: Biattentive Classification Network for Sentence Classification
# Date Created: 01/06/2018
#
# File Description: This script defines the biattentive classification network
# (BCN) model, preceded by a word-selection layer.
###############################################################################
import torch, helper
import torch.nn as nn
import torch.nn.functional as f
from collections import OrderedDict
from nn_layer import EmbeddingLayer, Encoder, MaxoutNetwork, Selector
# details of BCN can be found in the paper, "Learned in Translation: Contextualized Word Vectors"
class BCN(nn.Module):
    """Biattentive classification network architecture for sentence classification.

    A selector first scores the words of each input sentence. The selected word
    ids are re-embedded and passed through the BCN pipeline: ReLU projection,
    shared encoder, biattention, biattentive encoders, max/mean/min/self-attentive
    pooling and a maxout classifier.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Constructor of the class.

        :param dictionary: vocabulary. Its length sets the embedding table size,
            and it is passed to ``init_embedding_weights``.
        :param embedding_index: passed to ``init_embedding_weights``; presumably
            a word -> pre-trained vector mapping (TODO confirm).
        :param args: configuration namespace. This class reads ``bidirection``,
            ``emsize``, ``emtraining``, ``dropout``, ``nhid``, ``nlayers``,
            ``num_class``, ``num_units`` and (in ``forward``) ``cuda``.
        """
        super(BCN, self).__init__()
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1
        self.dictionary = dictionary
        # Submodules are created in the same order as before, so parameter
        # initialisation consumes the RNG identically.
        self.embedding = EmbeddingLayer(len(self.dictionary), self.config.emsize,
                                        self.config.emtraining, self.config)
        self.embedding.init_embedding_weights(self.dictionary, embedding_index, self.config.emsize)
        self.selector = Selector(self.config.emsize, self.config.dropout)
        self.relu_network = nn.Sequential(OrderedDict([
            ('dense1', nn.Linear(self.config.emsize, self.config.nhid)),
            ('nonlinearity', nn.ReLU())
        ]))
        # Feature width of every encoder output: nhid per direction.
        enc_dim = self.config.nhid * self.num_directions
        self.encoder = Encoder(self.config.nhid, self.config.nhid, self.config.bidirection,
                               self.config.nlayers, self.config)
        # The biattentive encoders consume [x; |x - c|; x * c], i.e. 3 * enc_dim features.
        self.biatt_encoder1 = Encoder(enc_dim * 3, self.config.nhid, self.config.bidirection, 1,
                                      self.config)
        self.biatt_encoder2 = Encoder(enc_dim * 3, self.config.nhid, self.config.bidirection, 1,
                                      self.config)
        # Produces one attention logit per time step for self-attentive pooling.
        self.ffnn = nn.Linear(enc_dim, 1)
        # Input is 4 pooled features (max, mean, min, self-attention) for each of the 2 sentences.
        self.maxout_network = MaxoutNetwork(enc_dim * 4 * 2, self.config.num_class,
                                            num_units=self.config.num_units)

    @staticmethod
    def _selection_variation(selection):
        """Return sum_t |z[t+1] - z[t]| of a [batch_size x max_length] selection, per batch row."""
        return (selection[:, 1:] - selection[:, :-1]).abs().sum(1)

    def _pool(self, biatt):
        """Concatenate max, mean, min and self-attentive pooling over the time axis.

        :param biatt: biattentive encoder output [batch_size x max_length x nhid * num_directions]
        :return: pooled features [batch_size x nhid * num_directions * 4]
        """
        # nn.Linear applies to the last dimension, so no flattening view (and hence
        # no contiguity requirement) is needed. Result: [batch_size x max_length].
        att_weights = f.softmax(self.ffnn(biatt).squeeze(2), 1)
        self_att = torch.bmm(biatt.transpose(1, 2), att_weights.unsqueeze(2)).squeeze(2)
        return torch.cat((biatt.max(1)[0], biatt.mean(1), biatt.min(1)[0], self_att), 1)

    def forward(self, sentence1, sentence1_len_old, sentence2, sentence2_len_old):
        """Forward computation of the biattentive classification network.

        Returns classification scores for a batch of sentence pairs.

        :param sentence1: 2d tensor of word ids [batch_size x max_length]
        :param sentence1_len_old: lengths of ``sentence1`` before selection. Unused;
            kept for interface compatibility, since lengths are recomputed from the
            selected words.
        :param sentence2: 2d tensor of word ids [batch_size x max_length]
        :param sentence2_len_old: lengths of ``sentence2`` before selection. Unused;
            kept for interface compatibility.
        :return: tuple ``(score, sentence1_len, sentence2_len, zdiff1, zdiff2)``:
            - ``score``: maxout-network output [batch_size x num_classes].
            - ``sentence1_len`` / ``sentence2_len``: lengths of the selected
              sequences, as returned by ``helper.get_selected_tensor`` (a numpy
              array according to the original comments).
            - ``zdiff1`` / ``zdiff2``: summed absolute differences between adjacent
              selection values [batch_size].
        """
        # step1: embed the full sentences so the selector can score every word
        embedded_x1 = self.embedding(sentence1)
        embedded_y1 = self.embedding(sentence2)

        # ---------------------------- selection ----------------------------
        selection_x = self.selector(embedded_x1)
        selection_y = self.selector(embedded_y1)
        assert selection_x.size() == sentence1.size()
        assert selection_y.size() == sentence2.size()
        # Ids of unselected words become zero. The helper presumably gathers the
        # remaining non-zero ids into a new padded batch and returns their lengths.
        result_x = sentence1.mul(selection_x)
        result_y = sentence2.mul(selection_y)
        selected_x, sentence1_len = helper.get_selected_tensor(result_x, self.config.cuda)
        selected_y, sentence2_len = helper.get_selected_tensor(result_y, self.config.cuda)
        embedded_x = self.embedding(selected_x)
        embedded_y = self.embedding(selected_y)
        # Total variation of each selection along the sequence. Presumably used by
        # the caller as a continuity penalty (the original comments mirror Theano code).
        zdiff1 = self._selection_variation(selection_x)
        zdiff2 = self._selection_variation(selection_y)
        assert zdiff1.size(0) == len(sentence1_len)
        assert zdiff2.size(0) == len(sentence2_len)

        # step2: ReLU projection [batch_size x max_length x nhid]
        embedded_x = self.relu_network(embedded_x)
        embedded_y = self.relu_network(embedded_y)
        # step3: shared encoder [batch_size x max_length x nhid * num_directions]
        encoded_x = self.encoder(embedded_x, sentence1_len)
        encoded_y = self.encoder(embedded_y, sentence2_len)

        # step4: affinity matrix A = X Y^T [batch_size x len_x x len_y]
        affinity_mat = torch.bmm(encoded_x, encoded_y.transpose(1, 2))
        # step5: context summaries with column-wise normalisation, as in the BCN paper.
        # C_x = softmax_over_x(A)^T X: one convex combination of x per y position
        # [batch_size x len_y x nhid * num_directions].
        conditioned_x = torch.bmm(f.softmax(affinity_mat, 1).transpose(1, 2), encoded_x)
        # C_y = softmax_over_y(A) Y: one convex combination of y per x position
        # [batch_size x len_x x nhid * num_directions].
        conditioned_y = torch.bmm(f.softmax(affinity_mat, 2), encoded_y)
        # NOTE(review): padded positions are not masked in these softmaxes or in the
        # pooling below. Whether that matters depends on what Encoder emits for
        # padding - confirm.

        # step6: inputs of the biattentive encoders [batch_size x max_length x nhid * num_directions * 3].
        # The paper uses the signed difference; this implementation keeps the absolute difference.
        biatt_input_x = torch.cat(
            (encoded_x, torch.abs(encoded_x - conditioned_y), torch.mul(encoded_x, conditioned_y)), 2)
        biatt_input_y = torch.cat(
            (encoded_y, torch.abs(encoded_y - conditioned_x), torch.mul(encoded_y, conditioned_x)), 2)
        # step7: biattentive encoders [batch_size x max_length x nhid * num_directions]
        biatt_x = self.biatt_encoder1(biatt_input_x, sentence1_len)
        biatt_y = self.biatt_encoder2(biatt_input_y, sentence2_len)

        # step8-9: pooled joint representations [batch_size x nhid * num_directions * 4]
        pooled_x = self._pool(biatt_x)
        pooled_y = self._pool(biatt_y)
        # step10: classify the concatenated pair representation
        score = self.maxout_network(torch.cat((pooled_x, pooled_y), 1))
        return score, sentence1_len, sentence2_len, zdiff1, zdiff2