We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e932fe8 commit 6173d1aCopy full SHA for 6173d1a
tests/layers/test_layers_recurrent.py
@@ -905,13 +905,21 @@ def forward(self, x):
905
if (epoch + 1) % 10 == 0:
906
print("epoch %d, loss %f" % (epoch, loss))
907
908
+ filename = "dynamic_rnn.h5"
909
+ rnn_model.save_weights(filename)
910
+
911
# Testing saving and restoring of RNN weights
912
rnn_model2 = CustomisedModel()
913
rnn_model2.eval()
914
pred_y = rnn_model2(self.data_x)
915
loss = tl.cost.mean_squared_error(pred_y, self.data_y)
916
print("MODEL INIT loss %f" % (loss))
917
918
+ rnn_model2.load_weights(filename)
919
+ pred_y = rnn_model2(self.data_x)
920
+ loss = tl.cost.mean_squared_error(pred_y, self.data_y)
921
+ print("MODEL RESTORE W loss %f" % (loss))
922
923
924
925
if __name__ == '__main__':
0 commit comments