
Commit e932fe8

test dynamic rnn with fake data
1 parent aeb0fb0 commit e932fe8

File tree

1 file changed: +49 -1 lines changed


tests/layers/test_layers_recurrent.py

Lines changed: 49 additions & 1 deletion
```diff
@@ -18,15 +18,21 @@ class Layer_RNN_Test(CustomTestCase):
     @classmethod
     def setUpClass(cls):
 
-        cls.batch_size = 2
+        cls.batch_size = 10
 
         cls.vocab_size = 20
         cls.embedding_size = 4
 
         cls.hidden_size = 8
         cls.num_steps = 6
 
+        cls.data_n_steps = np.random.randint(low=cls.num_steps // 2, high=cls.num_steps + 1, size=cls.batch_size)
         cls.data_x = np.random.random([cls.batch_size, cls.num_steps, cls.embedding_size]).astype(np.float32)
+
+        for i in range(cls.batch_size):
+            for j in range(cls.data_n_steps[i], cls.num_steps):
+                cls.data_x[i][j][:] = 0
+
         cls.data_y = np.zeros([cls.batch_size, 1]).astype(np.float32)
         cls.data_y2 = np.zeros([cls.batch_size, cls.num_steps]).astype(np.float32)
```
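The padding scheme above is what makes the dynamic-length test work: every timestep at or beyond a sample's drawn length is zeroed out, so each sequence's true length stays recoverable from the data itself. Below is a minimal standalone NumPy sketch of that invariant (not part of the commit; toy shapes match the test, and the mask logic mirrors what `tl.layers.retrieve_seq_length_op3` is expected to compute from zero-padded input):

```python
import numpy as np

batch_size, num_steps, embedding_size = 10, 6, 4
n_steps = np.random.randint(low=num_steps // 2, high=num_steps + 1, size=batch_size)

# Fake batch: [batch, time, features], with the tail of each sequence zero-padded.
data_x = np.random.random([batch_size, num_steps, embedding_size]).astype(np.float32)
for i in range(batch_size):
    data_x[i, n_steps[i]:, :] = 0

# A timestep counts as real if any feature is nonzero; summing the mask
# recovers each sample's length (assumes no real timestep is all-zero,
# which holds with probability ~1 for random floats in [0, 1)).
recovered = (np.abs(data_x).sum(axis=-1) > 0).sum(axis=-1)
assert (recovered == n_steps).all()
```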

```diff
@@ -865,6 +871,48 @@ def forward(self, x):
         print(output.shape)
         print(state)
 
+    def test_dynamic_rnn_with_fake_data(self):
+
+        class CustomisedModel(tl.models.Model):
+
+            def __init__(self):
+                super(CustomisedModel, self).__init__()
+                self.rnnlayer = tl.layers.LSTMRNN(
+                    units=8, dropout=0.1, in_channels=4,
+                    return_last_output=True,
+                    return_last_state=False
+                )
+                self.dense = tl.layers.Dense(in_channels=8, n_units=1)
+
+            def forward(self, x):
+                z = self.rnnlayer(x, sequence_length=tl.layers.retrieve_seq_length_op3(x))
+                z = self.dense(z[:, :])
+                return z
+
+        rnn_model = CustomisedModel()
+        print(rnn_model)
+        optimizer = tf.optimizers.Adam(learning_rate=0.01)
+        rnn_model.train()
+
+        for epoch in range(50):
+            with tf.GradientTape() as tape:
+                pred_y = rnn_model(self.data_x)
+                loss = tl.cost.mean_squared_error(pred_y, self.data_y)
+
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))
+
+            if (epoch + 1) % 10 == 0:
+                print("epoch %d, loss %f" % (epoch, loss))
+
+        # Testing saving and restoring of RNN weights
+        rnn_model2 = CustomisedModel()
+        rnn_model2.eval()
+        pred_y = rnn_model2(self.data_x)
+        loss = tl.cost.mean_squared_error(pred_y, self.data_y)
+        print("MODEL INIT loss %f" % (loss))
+
+
 
 if __name__ == '__main__':
```
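One detail the test leans on: with `return_last_output=True` and a per-sample `sequence_length`, a dynamic RNN should emit each sequence's output at its true final step, not at step `num_steps - 1`. A hedged TensorFlow sketch of that gather (toy shapes only; the `tf.gather_nd` indexing is a plain illustration of the semantics, not TensorLayer's internal implementation):

```python
import numpy as np
import tensorflow as tf

batch, time, units = 10, 6, 8
outputs = tf.random.normal([batch, time, units])        # all-step RNN outputs
lengths = tf.constant(np.random.randint(3, 7, batch))   # true lengths in [3, 6]

# Index pair (i, lengths[i] - 1) selects sample i's last valid timestep.
idx = tf.stack([tf.range(batch, dtype=tf.int64),
                tf.cast(lengths, tf.int64) - 1], axis=1)
last = tf.gather_nd(outputs, idx)                       # shape [batch, units]
print(last.shape)                                       # (10, 8)
```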