Commit 0790054

refator recurrent layer
1 parent 7fc6650 commit 0790054

File tree

7 files changed: +516 -389 lines changed


tensorlayerx/backend/ops/mindspore_nn.py

Lines changed: 167 additions & 46 deletions
@@ -6,6 +6,7 @@
 import mindspore as ms
 import mindspore.ops as P
 from mindspore import context
+from mindspore.ops.primitive import constexpr
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Rel
 from mindspore.ops import functional as F
@@ -17,6 +18,8 @@
 from mindspore.communication.management import get_group_size, get_rank
 from mindspore.ops.operations import LayerNorm
 import mindspore.numpy as np
+from mindspore.common.parameter import ParameterTuple
+from mindspore.nn.layer.rnns import _DynamicRNN
 import warnings
 import math
 
@@ -833,7 +836,9 @@ def __init__(self, ksize, strides, padding, data_format=None):
         self.data_format, self.padding = preprocess_2d_format(data_format=data_format, padding=padding)
         ms_ksize = ksize[1]
         ms_strides = strides[1]
-        self.avgpool = P.AvgPool(kernel_size=ms_ksize, strides=ms_strides, pad_mode=padding, data_format=self.data_format)
+        self.avgpool = P.AvgPool(
+            kernel_size=ms_ksize, strides=ms_strides, pad_mode=padding, data_format=self.data_format
+        )
 
     def construct(self, inputs):
         outputs = self.avgpool(inputs)
@@ -930,7 +935,7 @@ def __init__(self, ksize, strides, padding, data_format='NCDHW'):
         if data_format == 'NCDHW':
             strides = (strides[2], strides[3], strides[4])
         print(ksize, strides, padding)
-        self.avg_pool = P.AvgPool3D(kernel_size=ksize, strides = strides, pad_mode=padding, data_format=data_format)
+        self.avg_pool = P.AvgPool3D(kernel_size=ksize, strides=strides, pad_mode=padding, data_format=data_format)
 
     def __call__(self, inputs):
         return self.avg_pool(inputs)
@@ -1838,15 +1843,12 @@ def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act):
         self.bias_ih = bias_ih
         self.bias_hh = bias_hh
         self.act_fn = P.ReLU() if act == 'relu' else P.Tanh()
-        self.transpose = P.Transpose()
 
     def construct(self, input, h):
-        self.weight_ih = self.transpose(self.weight_ih, (1, 0))
-        i2h = P.matmul(input, self.weight_ih)
+        i2h = P.MatMul(False, True)(input, self.weight_ih)
         if self.bias_ih is not None:
             i2h += self.bias_ih
-        self.weight_hh = self.transpose(self.weight_hh, (1, 0))
-        h2h = P.matmul(h, self.weight_hh)
+        h2h = P.MatMul(False, True)(h, self.weight_hh)
         if self.bias_hh is not None:
             h2h += self.bias_hh
         h = self.act_fn(i2h + h2h)
@@ -1863,17 +1865,14 @@ def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
         self.bias_hh = bias_hh
         self.gate_act_fn = P.Sigmoid()
         self.act_fn = P.Tanh()
-        self.transpose = P.Transpose()
         self.split = P.Split(axis=-1, output_num=4)
 
     def construct(self, input, h, c):
 
-        self.weight_ih = self.transpose(self.weight_ih, (1, 0))
-        gates = P.matmul(input, self.weight_ih)
+        gates = P.MatMul(False, True)(input, self.weight_ih)
         if self.bias_ih is not None:
             gates += self.bias_ih
-        self.weight_hh = self.transpose(self.weight_hh, (1, 0))
-        gates += P.matmul(h, self.weight_hh)
+        gates += P.MatMul(False, True)(h, self.weight_hh)
         if self.bias_hh is not None:
             gates += self.bias_hh
 
@@ -1902,12 +1901,10 @@ def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
 
     def construct(self, input, h):
 
-        self.weight_ih = self.transpose(self.weight_ih, (1, 0))
-        x_gates = P.matmul(input, self.weight_ih)
+        x_gates = P.MatMul(False, True)(input, self.weight_ih)
         if self.bias_ih is not None:
             x_gates += self.bias_ih
-        self.weight_hh = self.transpose(self.weight_hh, (1, 0))
-        h_gates = P.matmul(h, self.weight_hh)
+        h_gates = P.MatMul(False, True)(h, self.weight_hh)
         if self.bias_hh is not None:
             h_gates += self.bias_hh
 
@@ -1935,47 +1932,171 @@ def __init__(
         dropout,
         bidirectional,
         is_train,
+        w_ih,
+        w_hh,
+        b_ih,
+        b_hh,
     ):
         super(rnnbase, self).__init__()
+        if not 0 <= dropout < 1:
+            raise ValueError("dropout should be a number in range [0, 1).")
+        if dropout > 0 and num_layers == 1:
+            raise ValueError(
+                "dropout option adds dropout after all but last "
+                "recurrent layer, so non-zero dropout expects "
+                "num_layers greater than 1, but got dropout={} and "
+                "num_layers={}".format(dropout, num_layers)
+            )
         self.mode = mode
+        self.reverse = P.ReverseV2([0])
+        self.reverse_sequence = P.ReverseSequence(0, 1)
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.num_layers = num_layers
-        self.bidirect = 2 if bidirectional else 1
+        self.dropout = dropout
+        self.dropout_op = ms.nn.Dropout(float(1 - dropout))
+        self.has_bias = bias
+        self.bidirectional = bidirectional
         self.batch_first = batch_first
-        if mode == 'LSTM':
-            self.lstm = ms.nn.LSTM(
-                input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=bias,
-                batch_first=batch_first, dropout=dropout, bidirectional=bidirectional
-            )
-        elif mode == 'GRU':
-
-            raise NotImplementedError
-
-        elif mode == 'RNN_TANH':
-
-            raise NotImplementedError
-
-        elif mode == 'RNN_RELU':
-
-            raise NotImplementedError
+        self.train = is_train
+        self.w_ih = ParameterTuple(w_ih)
+        self.w_hh = ParameterTuple(w_hh)
+        self.b_ih = ParameterTuple(b_ih)
+        self.b_hh = ParameterTuple(b_hh)
+        self.rnn = _DynamicRNN(mode)
+        self.is_lstm = mode == "LSTM"
 
         self.zeros = P.Zeros()
 
-    def construct(self, input, states):
-        input_shape = input.shape
-        input_dtype = input.dtype
-        if self.mode == 'LSTM':
-            if self.batch_first:
-                batch_size = input_shape[0]
+    def _stacked_bi_dynamic_rnn(self, x, h, seq_length):
+        """stacked bidirectional dynamic_rnn"""
+        pre_layer = x
+        h_n = ()
+        c_n = ()
+        output = 0
+        for i in range(self.num_layers):
+            offset = i * 2
+            if self.has_bias:
+                w_f_ih, w_f_hh, b_f_ih, b_f_hh = \
+                    self.w_ih_list[offset], self.w_hh_list[offset], \
+                    self.b_ih_list[offset], self.b_hh_list[offset]
+                w_b_ih, w_b_hh, b_b_ih, b_b_hh = \
+                    self.w_ih_list[offset + 1], self.w_hh_list[offset + 1], \
+                    self.b_ih_list[offset + 1], self.b_hh_list[offset + 1]
+            else:
+                w_f_ih, w_f_hh = self.w_ih_list[offset], self.w_hh_list[offset]
+                w_b_ih, w_b_hh = self.w_ih_list[offset + 1], self.w_hh_list[offset + 1]
+                b_f_ih, b_f_hh, b_b_ih, b_b_hh = None, None, None, None
+            if self.is_lstm:
+                h_f_i = (h[0][offset], h[1][offset])
+                h_b_i = (h[0][offset + 1], h[1][offset + 1])
+            else:
+                h_f_i = h[offset]
+                h_b_i = h[offset + 1]
+            if seq_length is None:
+                x_b = self.reverse(pre_layer)
             else:
-                batch_size = input_shape[1]
-            if states is None:
-                h = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
-                c = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
-                states = (h, c)
-            output, (h, c) = self.lstm(input, states)
-            return output, (h, c)
+                x_b = self.reverse_sequence(pre_layer, seq_length)
+            output_f, h_t_f = self.rnn(pre_layer, h_f_i, seq_length, w_f_ih, w_f_hh, b_f_ih, b_f_hh)
+            output_b, h_t_b = self.rnn(x_b, h_b_i, seq_length, w_b_ih, w_b_hh, b_b_ih, b_b_hh)
+            if seq_length is None:
+                output_b = self.reverse(output_b)
+            else:
+                output_b = self.reverse_sequence(output_b, seq_length)
+            output = P.Concat(2)((output_f, output_b))
+            pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output
+            if self.is_lstm:
+                h_n += (
+                    h_t_f[0],
+                    h_t_b[0],
+                )
+                c_n += (
+                    h_t_f[1],
+                    h_t_b[1],
+                )
+            else:
+                h_n += (
+                    h_t_f,
+                    h_t_b,
+                )
+        if self.is_lstm:
+            h_n = P.Concat(0)(h_n)
+            c_n = P.Concat(0)(c_n)
+            h_n = h_n.view(h[0].shape)
+            c_n = c_n.view(h[1].shape)
+            return output, (h_n.view(h[0].shape), c_n.view(h[1].shape))
+        h_n = P.Concat(0)(h_n)
+        return output, h_n.view(h.shape)
+
+    def _stacked_dynamic_rnn(self, x, h, seq_length):
+        """stacked mutil_layer dynamic_rnn"""
+        pre_layer = x
+        h_n = ()
+        c_n = ()
+        output = 0
+        for i in range(self.num_layers):
+            if self.has_bias:
+                w_ih, w_hh, b_ih, b_hh = self.w_ih_list[i], self.w_hh_list[i], self.b_ih_list[i], self.b_hh_list[i]
+            else:
+                w_ih, w_hh = self.w_ih_list[i], self.w_hh_list[i]
+                b_ih, b_hh = None, None
+            if self.is_lstm:
+                h_i = (h[0][i], h[1][i])
+            else:
+                h_i = h[i]
+            output, h_t = self.rnn(pre_layer, h_i, seq_length, w_ih, w_hh, b_ih, b_hh)
+            pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output
+            if self.is_lstm:
+                h_n += (h_t[0], )
+                c_n += (h_t[1], )
+            else:
+                h_n += (h_t, )
+        if self.is_lstm:
+            h_n = P.Concat(0)(h_n)
+            c_n = P.Concat(0)(c_n)
+            h_n = h_n.view(h[0].shape)
+            c_n = c_n.view(h[1].shape)
+            return output, (h_n.view(h[0].shape), c_n.view(h[1].shape))
+        h_n = P.Concat(0)(h_n)
+        return output, h_n.view(h.shape)
+
+    @constexpr
+    def _init_state(shape, dtype, is_lstm):
+        hx = ms.Tensor(np.zeros(shape), dtype)
+        cx = ms.Tensor(np.zeros(shape), dtype)
+        if is_lstm:
+            return (hx, cx)
+        return hx
+
+    @constexpr
+    def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
+        validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
+
+    def construct(self, x, hx=None, seq_length=None):
+        '''Defines the RNN like operators performed'''
+        x_dtype = P.dtype(x)
+        hx_dtype = P.dtype(hx)
+        self._check_input_dtype(x_dtype, "x", [ms.float32], self.cls_name)
+        self._check_input_dtype(hx_dtype, "hx", [ms.float32], self.cls_name)
+        if seq_length is not None:
+            seq_length_dtype = P.dtype(seq_length)
+            self._check_input_dtype(seq_length_dtype, "seq_length", [ms.int32, ms.int64], self.cls_name)
+
+        max_batch_size = x.shape[0] if self.batch_first else x.shape[1]
+        num_directions = 2 if self.bidirectional else 1
+        if hx is None:
+            hx = self._init_state(
+                (self.num_layers * num_directions, max_batch_size, self.hidden_size), x.dtype, self.is_lstm
+            )
+        if self.batch_first:
+            x = P.Transpose()(x, (1, 0, 2))
+        if self.bidirectional:
+            x, h = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
+        else:
+            x, h = self._stacked_dynamic_rnn(x, hx, seq_length)
+        if self.batch_first:
+            x = P.Transpose()(x, (1, 0, 2))
+        return x, h
 
 
 class layernorm(Cell):
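The refactored rnnbase now builds its own initial states when none are passed: construct reads the batch dimension from x, doubles the layer count for bidirectional networks, and asks _init_state for zero tensors of shape (num_layers * num_directions, batch, hidden_size), returning an (h, c) pair only in LSTM mode. A small sketch of that shape arithmetic in plain NumPy, with made-up sizes:

    import numpy as np

    num_layers, bidirectional, batch_first = 2, True, True
    hidden_size = 16
    x_shape = (3, 5, 8)                                # (batch, seq, feature) because batch_first=True

    max_batch_size = x_shape[0] if batch_first else x_shape[1]
    num_directions = 2 if bidirectional else 1
    state_shape = (num_layers * num_directions, max_batch_size, hidden_size)

    h0 = np.zeros(state_shape, np.float32)             # one hidden-state slot per (layer, direction)
    c0 = np.zeros(state_shape, np.float32)             # cell state, only used when mode == 'LSTM'
    print(state_shape)                                 # (4, 3, 16)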
