@@ -18,15 +18,21 @@ class Layer_RNN_Test(CustomTestCase):
18
18
@classmethod
19
19
def setUpClass (cls ):
20
20
21
- cls .batch_size = 2
21
+ cls .batch_size = 10
22
22
23
23
cls .vocab_size = 20
24
24
cls .embedding_size = 4
25
25
26
26
cls .hidden_size = 8
27
27
cls .num_steps = 6
28
28
29
+ cls .data_n_steps = np .random .randint (low = cls .num_steps // 2 , high = cls .num_steps + 1 , size = cls .batch_size )
29
30
cls .data_x = np .random .random ([cls .batch_size , cls .num_steps , cls .embedding_size ]).astype (np .float32 )
31
+
32
+ for i in range (cls .batch_size ):
33
+ for j in range (cls .data_n_steps [i ], cls .num_steps ):
34
+ cls .data_x [i ][j ][:] = 0
35
+
30
36
cls .data_y = np .zeros ([cls .batch_size , 1 ]).astype (np .float32 )
31
37
cls .data_y2 = np .zeros ([cls .batch_size , cls .num_steps ]).astype (np .float32 )
32
38
@@ -865,6 +871,48 @@ def forward(self, x):
865
871
print (output .shape )
866
872
print (state )
867
873
874
+ def test_dynamic_rnn_with_fake_data (self ):
875
+
876
+ class CustomisedModel (tl .models .Model ):
877
+
878
+ def __init__ (self ):
879
+ super (CustomisedModel , self ).__init__ ()
880
+ self .rnnlayer = tl .layers .LSTMRNN (
881
+ units = 8 , dropout = 0.1 , in_channels = 4 ,
882
+ return_last_output = True ,
883
+ return_last_state = False
884
+ )
885
+ self .dense = tl .layers .Dense (in_channels = 8 , n_units = 1 )
886
+
887
+ def forward (self , x ):
888
+ z = self .rnnlayer (x , sequence_length = tl .layers .retrieve_seq_length_op3 (x ))
889
+ z = self .dense (z [:, :])
890
+ return z
891
+
892
+ rnn_model = CustomisedModel ()
893
+ print (rnn_model )
894
+ optimizer = tf .optimizers .Adam (learning_rate = 0.01 )
895
+ rnn_model .train ()
896
+
897
+ for epoch in range (50 ):
898
+ with tf .GradientTape () as tape :
899
+ pred_y = rnn_model (self .data_x )
900
+ loss = tl .cost .mean_squared_error (pred_y , self .data_y )
901
+
902
+ gradients = tape .gradient (loss , rnn_model .trainable_weights )
903
+ optimizer .apply_gradients (zip (gradients , rnn_model .trainable_weights ))
904
+
905
+ if (epoch + 1 ) % 10 == 0 :
906
+ print ("epoch %d, loss %f" % (epoch , loss ))
907
+
908
+ # Testing saving and restoring of RNN weights
909
+ rnn_model2 = CustomisedModel ()
910
+ rnn_model2 .eval ()
911
+ pred_y = rnn_model2 (self .data_x )
912
+ loss = tl .cost .mean_squared_error (pred_y , self .data_y )
913
+ print ("MODEL INIT loss %f" % (loss ))
914
+
915
+
868
916
869
917
if __name__ == '__main__' :
870
918
0 commit comments