Skip to content

Commit 6173d1a

Browse files
committed
test dynamic rnn with fake data
1 parent e932fe8 commit 6173d1a

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

tests/layers/test_layers_recurrent.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,13 +905,21 @@ def forward(self, x):
905905
if (epoch + 1) % 10 == 0:
906906
print("epoch %d, loss %f" % (epoch, loss))
907907

908+
filename = "dynamic_rnn.h5"
909+
rnn_model.save_weights(filename)
910+
908911
# Testing saving and restoring of RNN weights
909912
rnn_model2 = CustomisedModel()
910913
rnn_model2.eval()
911914
pred_y = rnn_model2(self.data_x)
912915
loss = tl.cost.mean_squared_error(pred_y, self.data_y)
913916
print("MODEL INIT loss %f" % (loss))
914917

918+
rnn_model2.load_weights(filename)
919+
pred_y = rnn_model2(self.data_x)
920+
loss = tl.cost.mean_squared_error(pred_y, self.data_y)
921+
print("MODEL RESTORE W loss %f" % (loss))
922+
915923

916924

917925
if __name__ == '__main__':

0 commit comments

Comments
 (0)