Commit 7e3c77c

fix RNN LSTM GRU

1 parent 1914746 commit 7e3c77c

6 files changed: +14 -44 lines changed

examples/basic_tutorials/imdb_LSTM_simple.py

Lines changed: 3 additions & 3 deletions

@@ -5,8 +5,8 @@
 import os
 # os.environ['TL_BACKEND'] = 'tensorflow'
 # os.environ['TL_BACKEND'] = 'mindspore'
-os.environ['TL_BACKEND'] = 'paddle'
-# os.environ['TL_BACKEND'] = 'torch'
+# os.environ['TL_BACKEND'] = 'paddle'
+os.environ['TL_BACKEND'] = 'torch'
 import tensorlayerx as tlx
 from tensorlayerx.nn import Module
 from tensorlayerx.nn import Linear, LSTM, Embedding
@@ -42,7 +42,7 @@ class ImdbNet(Module):
     def __init__(self):
         super(ImdbNet, self).__init__()
         self.embedding = Embedding(num_embeddings=vocab_size, embedding_dim=64)
-        self.lstm = LSTM(input_size=64, hidden_size=64)
+        self.lstm = LSTM(input_size=64, hidden_size=64, num_layers=2)
         self.linear1 = Linear(in_features=64, out_features=64, act=tlx.nn.ReLU)
         self.linear2 = Linear(in_features=64, out_features=2)
 
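For context, a minimal sketch of the updated example model: the torch backend is selected and the LSTM is stacked two layers deep. The constructor lines come from the diff above; the forward pass (a PyTorch-style return of output plus (h, c) states, followed by mean-pooling over the sequence) and the vocab_size value are illustrative assumptions, not part of the commit.

import os
os.environ['TL_BACKEND'] = 'torch'  # must be set before tensorlayerx is imported

import tensorlayerx as tlx
from tensorlayerx.nn import Module
from tensorlayerx.nn import Linear, LSTM, Embedding

vocab_size = 20000  # hypothetical value; the real script defines its own vocab_size


class ImdbNet(Module):

    def __init__(self):
        super(ImdbNet, self).__init__()
        self.embedding = Embedding(num_embeddings=vocab_size, embedding_dim=64)
        # two stacked LSTM layers, as introduced by this commit
        self.lstm = LSTM(input_size=64, hidden_size=64, num_layers=2)
        self.linear1 = Linear(in_features=64, out_features=64, act=tlx.nn.ReLU)
        self.linear2 = Linear(in_features=64, out_features=2)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)              # assumed PyTorch-style return: output, (h, c)
        x = tlx.reduce_mean(x, axis=1)   # assumption: average over the sequence dimension
        x = self.linear1(x)
        x = self.linear2(x)
        return x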

tensorlayerx/nn/core/core_mindspore.py

Lines changed: 2 additions & 2 deletions

@@ -12,9 +12,9 @@
 import numpy
 from mindspore.common.api import _pynative_executor
 from collections import OrderedDict, abc as container_abcs
-from mindspore.nn.cell import ParameterTuple
 
-__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict', 'ParameterTuple']
+
+__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict']
 
 _global_layer_name_dict = {}
 _global_layer_node = []

tensorlayerx/nn/core/core_paddle.py

Lines changed: 2 additions & 3 deletions

@@ -14,12 +14,11 @@
 from paddle.fluid.dygraph import parallel_helper
 import paddle as pd
 from collections import OrderedDict, abc as container_abcs
-from paddle.nn import ParameterList as ParameterTuple
 
 _global_layer_name_dict = {}
 _global_layer_node = []
-# TODO Need to implement ParameterTuple
-__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict', 'ParameterTuple']
+
+__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict']
 
 
 class Module(Layer):

tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 1 addition & 29 deletions

@@ -10,7 +10,7 @@
 import tensorflow as tf
 from tensorlayerx.nn.layers.utils import (get_variable_with_initializer, random_normal)
 
-__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict', 'ParameterTuple']
+__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict']
 
 _global_layer_name_dict = {}
 _global_layer_node = []
@@ -1287,34 +1287,6 @@ def update(self, parameters):
     def __call__(self, input):
         raise RuntimeError('ParameterDict should not be called.')
 
-
-class ParameterTuple(tuple):
-    """
-    ParameterTuple for storing tuple of parameters.
-    """
-    def __new__(cls, iterable):
-        data = tuple(iterable)
-        ids = set()
-        orders = {}
-        for x in data:
-            if not isinstance(x, tf.Variable):
-                raise TypeError(f"ParameterTuple input should be `Parameter` collection."
-                                f"But got a {type(iterable)}, {iterable}")
-            if id(x) not in ids:
-                ids.add(id(x))
-                if x.name not in orders.keys():
-                    orders[x.name] = [0, x]
-                else:
-                    if isinstance(orders[x.name], list):
-                        name = x.name
-                        orders[name][1].name = name + "_" + str(0)
-                        x.name = x.name + "_" + str(1)
-                        orders[name] = 1
-                    else:
-                        orders[x.name] += 1
-                        x.name = x.name + "_" + str(orders[x.name])
-        return tuple.__new__(ParameterTuple, tuple(data))
-
 def _valid_index(layer_num, index):
     if not isinstance(index, int):
         raise TypeError("Index {} is not int type")

tensorlayerx/nn/core/core_torch.py

Lines changed: 1 addition & 2 deletions

@@ -12,12 +12,11 @@
 from collections import OrderedDict, abc as container_abcs
 import warnings
 import tensorlayerx as tlx
-from torch.nn import ParameterList as ParameterTuple
 
 _global_layer_name_dict = {}
 _global_layer_node = []
 
-__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict', 'ParameterTuple']
+__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict']
 
 
 class Module(T_Module):
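Taken together, the four backend cores above make the same public-API change: ParameterTuple is dropped from __all__, leaving ParameterList and ParameterDict as the supported containers. A one-line sketch of the import that downstream code (such as recurrent.py below) now relies on:

# ParameterTuple is no longer exported by any backend core;
# ParameterList/ParameterDict remain the public containers.
from tensorlayerx.nn.core import Module, ParameterList, ParameterDict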

tensorlayerx/nn/layers/recurrent.py

Lines changed: 5 additions & 5 deletions

@@ -4,7 +4,7 @@
 import numpy as np
 import tensorlayerx as tlx
 from tensorlayerx import logging
-from tensorlayerx.nn.core import Module, ParameterTuple
+from tensorlayerx.nn.core import Module, ParameterList
 
 __all__ = [
     'RNN',
@@ -488,10 +488,10 @@ def build(self, inputs_shape):
                         var_name='bias_hh_l{}{}'.format(layer, suffix), shape=(gate_size, ), init=_init
                     )
                 )
-        # self.weight_ih = ParameterTuple(self.w_ih)
-        # self.weight_hh = ParameterTuple(self.w_hh)
-        # self.bias_ih = ParameterTuple(self.b_ih)
-        # self.bias_hh = ParameterTuple(self.b_hh)
+        self.weight_ih = ParameterList(self.w_ih)
+        self.weight_hh = ParameterList(self.w_hh)
+        self.bias_ih = ParameterList(self.b_ih)
+        self.bias_hh = ParameterList(self.b_hh)
         self.rnn = tlx.ops.rnnbase(
             mode=self.mode, input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,
             bias=self.bias, batch_first=self.batch_first, dropout=self.dropout, bidirectional=self.bidirectional,
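Why a list container is used here: the RNN/LSTM/GRU layers create one weight tensor per layer and direction, and those tensors must be registered on the module so they appear in its parameter collection; a plain Python list would not be tracked. The sketch below illustrates the pattern with torch.nn.ParameterList (torch being the backend enabled in the example above); the TensorLayerX ParameterList used in recurrent.py is assumed to behave analogously, and the class and sizes are hypothetical.

import torch
import torch.nn as nn


class TinyStackedRNN(nn.Module):

    def __init__(self, input_size=8, hidden_size=8, num_layers=2):
        super().__init__()
        # one input-to-hidden weight per layer; wrapping them in a
        # ParameterList registers each tensor as a module parameter
        self.weight_ih = nn.ParameterList(
            [nn.Parameter(torch.randn(hidden_size, input_size if layer == 0 else hidden_size))
             for layer in range(num_layers)]
        )


net = TinyStackedRNN()
# every per-layer weight is visible to optimizers and state_dict()
print([tuple(p.shape) for p in net.parameters()])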
