
Commit 4ecb07b

update mindspore rnn
1 parent 6a4f391 commit 4ecb07b

File tree

1 file changed (+33 −32)


tensorlayerx/backend/ops/mindspore_nn.py

Lines changed: 33 additions & 32 deletions
@@ -1919,6 +1919,19 @@ def construct(self, input, h):
         return h, h


+@constexpr
+def _init_state(shape, dtype, is_lstm):
+    hx = ms.Tensor(np.zeros(shape), dtype)
+    cx = ms.Tensor(np.zeros(shape), dtype)
+    if is_lstm:
+        return (hx, cx)
+    return hx
+
+@constexpr
+def _check_input_dtype_same_and_valid(args_name, args_value, valid_values, cls_name):
+    args = {args_name[i]: args_value[i] for i in range(len(args_value))}
+    validator.check_types_same_and_valid(args, valid_values, cls_name)
+
 class rnnbase(Cell):

     def __init__(
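For context (not part of the diff): @constexpr here is MindSpore's mindspore.ops.constexpr decorator, which evaluates the decorated function while the graph is compiled, so _init_state can build zero tensors from a plain Python shape tuple. A minimal, self-contained sketch of the same pattern (the decorator, Tensor, and dtype are MindSpore's documented API; the function name and shapes are illustrative):

    import numpy as np
    import mindspore as ms
    from mindspore.ops import constexpr

    @constexpr
    def zeros_state(shape, dtype, is_lstm):
        # Evaluated at compile time, so `shape` must be a constant.
        hx = ms.Tensor(np.zeros(shape), dtype)
        cx = ms.Tensor(np.zeros(shape), dtype)
        return (hx, cx) if is_lstm else hx

    # (num_layers * num_directions, batch, hidden) = (2, 4, 8); LSTM gets an (h0, c0) pair.
    h0, c0 = zeros_state((2, 4, 8), ms.float32, True)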
@@ -1959,10 +1972,10 @@ def __init__(
         self.bidirectional = bidirectional
         self.batch_first = batch_first
         self.train = is_train
-        self.w_ih = ParameterTuple(w_ih)
-        self.w_hh = ParameterTuple(w_hh)
-        self.b_ih = ParameterTuple(b_ih)
-        self.b_hh = ParameterTuple(b_hh)
+        self.w_ih_list = ParameterTuple(w_ih)
+        self.w_hh_list = ParameterTuple(w_hh)
+        self.b_ih_list = ParameterTuple(b_ih)
+        self.b_hh_list = ParameterTuple(b_hh)
         self.rnn = _DynamicRNN(mode)
         self.is_lstm = mode == "LSTM"
@@ -2060,43 +2073,31 @@ def _stacked_dynamic_rnn(self, x, h, seq_length):
         h_n = P.Concat(0)(h_n)
         return output, h_n.view(h.shape)

-    @constexpr
-    def _init_state(shape, dtype, is_lstm):
-        hx = ms.Tensor(np.zeros(shape), dtype)
-        cx = ms.Tensor(np.zeros(shape), dtype)
-        if is_lstm:
-            return (hx, cx)
-        return hx
-
-    @constexpr
-    def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
-        validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
-
     def construct(self, x, hx=None, seq_length=None):
         '''Defines the RNN like operators performed'''
-        x_dtype = P.dtype(x)
-        hx_dtype = P.dtype(hx)
-        self._check_input_dtype(x_dtype, "x", [ms.float32], self.cls_name)
-        self._check_input_dtype(hx_dtype, "hx", [ms.float32], self.cls_name)
-        if seq_length is not None:
-            seq_length_dtype = P.dtype(seq_length)
-            self._check_input_dtype(seq_length_dtype, "seq_length", [ms.int32, ms.int64], self.cls_name)
-
         max_batch_size = x.shape[0] if self.batch_first else x.shape[1]
         num_directions = 2 if self.bidirectional else 1
-        if hx is None:
-            hx = self._init_state(
-                (self.num_layers * num_directions, max_batch_size, self.hidden_size), x.dtype, self.is_lstm
-            )
+        x_dtype = x.dtype
+        if hx is not None:
+            if not self.is_lstm:
+                _check_input_dtype_same_and_valid(['x', 'hx'], [x_dtype, hx.dtype], \
+                                                  [ms.float32, ms.float16], self.cls_name)
+            else:
+                _check_input_dtype_same_and_valid(['x', 'hx[0]', 'hx[1]'], [x_dtype, hx[0].dtype, hx[1].dtype], \
+                                                  [ms.float32, ms.float16], self.cls_name)
+        else:
+            hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size), x_dtype, self.is_lstm)
         if self.batch_first:
             x = P.Transpose()(x, (1, 0, 2))
         if self.bidirectional:
-            x, h = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
+            x_n, hx_n = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
         else:
-            x, h = self._stacked_dynamic_rnn(x, hx, seq_length)
+            x_n, hx_n = self._stacked_dynamic_rnn(x, hx, seq_length)
         if self.batch_first:
-            x = P.Transpose()(x, (1, 0, 2))
-        return x, h
+            x_n = P.Transpose()(x_n, (1, 0, 2))
+        if not self.is_lstm:
+            return x_n.astype(x_dtype), hx_n.astype(x_dtype)
+        return x_n.astype(x_dtype), (hx_n[0].astype(x_dtype), hx_n[1].astype(x_dtype))


 class layernorm(Cell):
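Not part of the commit, just a plain-Python sketch of the two behavioural changes in construct(): x and hx may now be float16 as well as float32 (their dtypes must agree), and every output is cast back to the input dtype, with LSTM state kept as an (h_n, c_n) pair. The helper name below is hypothetical:

    import numpy as np
    import mindspore as ms

    def cast_outputs(x_n, hx_n, x_dtype, is_lstm):
        # Mirrors the tail of the new construct(): cast results back to the
        # input dtype; an LSTM additionally carries its cell state.
        if not is_lstm:
            return x_n.astype(x_dtype), hx_n.astype(x_dtype)
        return x_n.astype(x_dtype), (hx_n[0].astype(x_dtype), hx_n[1].astype(x_dtype))

    out = ms.Tensor(np.zeros((5, 3, 20)), ms.float32)   # (seq_len, batch, hidden)
    h_n = ms.Tensor(np.zeros((1, 3, 20)), ms.float32)
    y, h = cast_outputs(out, h_n, ms.float16, is_lstm=False)
    print(y.dtype, h.dtype)   # Float16 Float16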
