Skip to content

Commit 8814f40

Browse files
committed
Fix RNN/LSTM/GRU initial-hidden-state bug
1 parent 66d49ba commit 8814f40

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

examples/basic_tutorials/imdb_LSTM_simple.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
# The same set of code can switch the backend with one line
55
import os
66
# os.environ['TL_BACKEND'] = 'tensorflow'
7-
os.environ['TL_BACKEND'] = 'mindspore'
8-
# os.environ['TL_BACKEND'] = 'paddle'
7+
# os.environ['TL_BACKEND'] = 'mindspore'
8+
os.environ['TL_BACKEND'] = 'paddle'
99
# os.environ['TL_BACKEND'] = 'torch'
1010
import tensorlayerx as tlx
1111
from tensorlayerx.nn import Module
12-
from tensorlayerx.nn import Linear, LSTM, Embedding
12+
from tensorlayerx.nn import Linear, LSTM, Embedding, RNN
1313
from tensorlayerx.dataflow import Dataset
1414
import numpy as np
15+
prev_h = np.random.random([1, 200, 64]).astype(np.float32)
16+
prev_h = tlx.convert_to_tensor(prev_h)
1517

1618
X_train, y_train, X_test, y_test = tlx.files.load_imdb_dataset('data', nb_words=20000, test_split=0.2)
1719

@@ -48,7 +50,7 @@ def __init__(self):
4850

4951
def forward(self, x):
5052
x = self.embedding(x)
51-
x, _ = self.lstm(x)
53+
x, _ = self.lstm(x, [prev_h, prev_h])
5254
x = tlx.reduce_mean(x, axis=1)
5355
x = self.linear1(x)
5456
x = self.linear2(x)

tensorlayerx/backend/ops/paddle_nn.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,10 +1638,16 @@ def _cudnn_impl(self, inputs, initial_states, sequence_length):
16381638
out = pd.tensor.transpose(out, [1, 0, 2]) if not self.time_major else out
16391639
return out, tuple(state) if len(state) > 1 else state[0]
16401640

1641+
def check_hidden(self, h, batch_size):
1642+
expected_hidden_size = [self.num_layers * self.bidirect, batch_size, self.hidden_size]
1643+
if h.shape != expected_hidden_size:
1644+
raise ValueError('Expected hidden size {}, got {}.'.format(expected_hidden_size, h.shape))
1645+
16411646
def forward(self, inputs, initial_states=None):
16421647
batch_index = 1 if self.time_major else 0
16431648
dtype = inputs.dtype
16441649
sequence_length = None
1650+
batch_size = inputs.shape[batch_index]
16451651
if initial_states is None:
16461652
state_shape = (self.num_layers * self.bidirect, -1, self.hidden_size)
16471653
if self.state_components == 1:
@@ -1655,6 +1661,15 @@ def forward(self, inputs, initial_states=None):
16551661
for _ in range(self.state_components)
16561662
]
16571663
)
1664+
else:
1665+
if self.mode == 'LSTM':
1666+
h, c = initial_states
1667+
self.check_hidden(h, batch_size)
1668+
self.check_hidden(c, batch_size)
1669+
else:
1670+
self.check_hidden(initial_states, batch_size)
1671+
if not isinstance(initial_states, (tuple, list)):
1672+
initial_states = [initial_states,]
16581673
if self.could_use_cudnn:
16591674
# Add CPU kernel and dispatch in backend later
16601675
return self._cudnn_impl(inputs, initial_states, sequence_length)

tensorlayerx/backend/ops/torch_nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1934,7 +1934,7 @@ def check_input(self, input_shape):
19341934
)
19351935

19361936
def check_hidden(self, h, batch_size):
1937-
expected_hidden_size = (self.num_layers * self.bidirect, batch_size, self.hidden_size)
1937+
expected_hidden_size = (self.num_layers * self.num_directions, batch_size, self.hidden_size)
19381938
if h.shape != expected_hidden_size:
19391939
raise ValueError('Expected hidden size {}, got {}.'.format(expected_hidden_size, h.shape))
19401940

0 commit comments

Comments (0)