@@ -18,15 +18,21 @@ class Layer_RNN_Test(CustomTestCase):
1818 @classmethod
1919 def setUpClass (cls ):
2020
21- cls .batch_size = 2
21+ cls .batch_size = 10
2222
2323 cls .vocab_size = 20
2424 cls .embedding_size = 4
2525
2626 cls .hidden_size = 8
2727 cls .num_steps = 6
2828
29+ cls .data_n_steps = np .random .randint (low = cls .num_steps // 2 , high = cls .num_steps + 1 , size = cls .batch_size )
2930 cls .data_x = np .random .random ([cls .batch_size , cls .num_steps , cls .embedding_size ]).astype (np .float32 )
31+
32+ for i in range (cls .batch_size ):
33+ for j in range (cls .data_n_steps [i ], cls .num_steps ):
34+ cls .data_x [i ][j ][:] = 0
35+
3036 cls .data_y = np .zeros ([cls .batch_size , 1 ]).astype (np .float32 )
3137 cls .data_y2 = np .zeros ([cls .batch_size , cls .num_steps ]).astype (np .float32 )
3238
@@ -865,6 +871,48 @@ def forward(self, x):
865871 print (output .shape )
866872 print (state )
867873
874+ def test_dynamic_rnn_with_fake_data (self ):
875+
876+ class CustomisedModel (tl .models .Model ):
877+
878+ def __init__ (self ):
879+ super (CustomisedModel , self ).__init__ ()
880+ self .rnnlayer = tl .layers .LSTMRNN (
881+ units = 8 , dropout = 0.1 , in_channels = 4 ,
882+ return_last_output = True ,
883+ return_last_state = False
884+ )
885+ self .dense = tl .layers .Dense (in_channels = 8 , n_units = 1 )
886+
887+ def forward (self , x ):
888+ z = self .rnnlayer (x , sequence_length = tl .layers .retrieve_seq_length_op3 (x ))
889+ z = self .dense (z [:, :])
890+ return z
891+
892+ rnn_model = CustomisedModel ()
893+ print (rnn_model )
894+ optimizer = tf .optimizers .Adam (learning_rate = 0.01 )
895+ rnn_model .train ()
896+
897+ for epoch in range (50 ):
898+ with tf .GradientTape () as tape :
899+ pred_y = rnn_model (self .data_x )
900+ loss = tl .cost .mean_squared_error (pred_y , self .data_y )
901+
902+ gradients = tape .gradient (loss , rnn_model .trainable_weights )
903+ optimizer .apply_gradients (zip (gradients , rnn_model .trainable_weights ))
904+
905+ if (epoch + 1 ) % 10 == 0 :
906+ print ("epoch %d, loss %f" % (epoch , loss ))
907+
908+ # Testing saving and restoring of RNN weights
909+ rnn_model2 = CustomisedModel ()
910+ rnn_model2 .eval ()
911+ pred_y = rnn_model2 (self .data_x )
912+ loss = tl .cost .mean_squared_error (pred_y , self .data_y )
913+ print ("MODEL INIT loss %f" % (loss ))
914+
915+
868916
869917if __name__ == '__main__' :
870918
0 commit comments