 import mindspore as ms
 import mindspore.ops as P
 from mindspore import context
+from mindspore.ops.primitive import constexpr
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Rel
 from mindspore.ops import functional as F
 from mindspore.communication.management import get_group_size, get_rank
 from mindspore.ops.operations import LayerNorm
 import mindspore.numpy as np
+from mindspore.common.parameter import ParameterTuple
+from mindspore.nn.layer.rnns import _DynamicRNN
 import warnings
 import math

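The three added imports back the rewritten `rnnbase` below: `constexpr` marks helpers that MindSpore evaluates while the graph is compiled, `ParameterTuple` groups the per-layer weights so the Cell can register them, and `_DynamicRNN` is MindSpore's internal recurrent loop. A minimal sketch of the `constexpr` pattern used further down; the helper name here is hypothetical and not part of this commit:

```python
import mindspore as ms
import mindspore.numpy as mnp
from mindspore.ops.primitive import constexpr

@constexpr
def zero_state(shape, dtype):
    # Evaluated at graph-compile time, so the zero tensor is folded into the
    # graph as a constant; intended to be called from Cell.construct.
    return ms.Tensor(mnp.zeros(shape), dtype)
```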
@@ -833,7 +836,9 @@ def __init__(self, ksize, strides, padding, data_format=None):
         self.data_format, self.padding = preprocess_2d_format(data_format=data_format, padding=padding)
         ms_ksize = ksize[1]
         ms_strides = strides[1]
-        self.avgpool = P.AvgPool(kernel_size=ms_ksize, strides=ms_strides, pad_mode=padding, data_format=self.data_format)
+        self.avgpool = P.AvgPool(
+            kernel_size=ms_ksize, strides=ms_strides, pad_mode=padding, data_format=self.data_format
+        )

     def construct(self, inputs):
         outputs = self.avgpool(inputs)
@@ -930,7 +935,7 @@ def __init__(self, ksize, strides, padding, data_format='NCDHW'):
         if data_format == 'NCDHW':
             strides = (strides[2], strides[3], strides[4])
         print(ksize, strides, padding)
-        self.avg_pool = P.AvgPool3D(kernel_size=ksize, strides=strides, pad_mode=padding, data_format=data_format)
+        self.avg_pool = P.AvgPool3D(kernel_size=ksize, strides=strides, pad_mode=padding, data_format=data_format)

     def __call__(self, inputs):
         return self.avg_pool(inputs)
@@ -1838,15 +1843,12 @@ def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act):
         self.bias_ih = bias_ih
         self.bias_hh = bias_hh
         self.act_fn = P.ReLU() if act == 'relu' else P.Tanh()
-        self.transpose = P.Transpose()

     def construct(self, input, h):
-        self.weight_ih = self.transpose(self.weight_ih, (1, 0))
-        i2h = P.matmul(input, self.weight_ih)
+        i2h = P.MatMul(False, True)(input, self.weight_ih)
         if self.bias_ih is not None:
             i2h += self.bias_ih
-        self.weight_hh = self.transpose(self.weight_hh, (1, 0))
-        h2h = P.matmul(h, self.weight_hh)
+        h2h = P.MatMul(False, True)(h, self.weight_hh)
         if self.bias_hh is not None:
             h2h += self.bias_hh
         h = self.act_fn(i2h + h2h)
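`P.MatMul(False, True)` folds the weight transpose into the matmul (`transpose_b=True`), which also removes the old pattern of reassigning the `self.weight_ih` / `self.weight_hh` Parameters inside `construct`. A quick equivalence check with hypothetical shapes:

```python
import mindspore as ms
import mindspore.ops as P
import mindspore.numpy as mnp

x = ms.Tensor(mnp.ones((2, 8)), ms.float32)    # (batch, input_size)
w = ms.Tensor(mnp.ones((16, 8)), ms.float32)   # (hidden_size, input_size)

a = P.MatMul(False, True)(x, w)                # transpose_b handled inside the op
b = P.matmul(x, P.Transpose()(w, (1, 0)))      # old explicit-transpose form
print(a.shape, b.shape)                        # (2, 16) (2, 16)
```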
@@ -1863,17 +1865,14 @@ def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
         self.bias_hh = bias_hh
         self.gate_act_fn = P.Sigmoid()
         self.act_fn = P.Tanh()
-        self.transpose = P.Transpose()
         self.split = P.Split(axis=-1, output_num=4)

     def construct(self, input, h, c):

-        self.weight_ih = self.transpose(self.weight_ih, (1, 0))
-        gates = P.matmul(input, self.weight_ih)
+        gates = P.MatMul(False, True)(input, self.weight_ih)
         if self.bias_ih is not None:
             gates += self.bias_ih
-        self.weight_hh = self.transpose(self.weight_hh, (1, 0))
-        gates += P.matmul(h, self.weight_hh)
+        gates += P.MatMul(False, True)(h, self.weight_hh)
         if self.bias_hh is not None:
             gates += self.bias_hh

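The pre-activation `gates` tensor is split into four chunks by `P.Split(axis=-1, output_num=4)`. A minimal sketch of the cell update those chunks feed, assuming the usual i, f, g, o ordering (the split itself is outside this hunk):

```python
import mindspore.ops as P

def lstm_update(gates, c):
    # gates: (batch, 4 * hidden_size), c: (batch, hidden_size)
    i, f, g, o = P.Split(axis=-1, output_num=4)(gates)
    i, f, o = P.Sigmoid()(i), P.Sigmoid()(f), P.Sigmoid()(o)
    g = P.Tanh()(g)
    c_new = f * c + i * g          # new cell state
    h_new = o * P.Tanh()(c_new)    # new hidden state
    return h_new, c_new
```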
@@ -1902,12 +1901,10 @@ def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):

     def construct(self, input, h):

-        self.weight_ih = self.transpose(self.weight_ih, (1, 0))
-        x_gates = P.matmul(input, self.weight_ih)
+        x_gates = P.MatMul(False, True)(input, self.weight_ih)
         if self.bias_ih is not None:
             x_gates += self.bias_ih
-        self.weight_hh = self.transpose(self.weight_hh, (1, 0))
-        h_gates = P.matmul(h, self.weight_hh)
+        h_gates = P.MatMul(False, True)(h, self.weight_hh)
         if self.bias_hh is not None:
             h_gates += self.bias_hh

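Unlike the LSTM cell, the GRU cell keeps `x_gates` and `h_gates` separate, because the reset gate only scales the hidden-side projection of the candidate state. A minimal sketch of that update, assuming the usual r, z, n chunk ordering (the split and combination are outside this hunk):

```python
import mindspore.ops as P

def gru_update(x_gates, h_gates, h):
    # x_gates, h_gates: (batch, 3 * hidden_size), h: (batch, hidden_size)
    x_r, x_z, x_n = P.Split(axis=-1, output_num=3)(x_gates)
    h_r, h_z, h_n = P.Split(axis=-1, output_num=3)(h_gates)
    r = P.Sigmoid()(x_r + h_r)      # reset gate
    z = P.Sigmoid()(x_z + h_z)      # update gate
    n = P.Tanh()(x_n + r * h_n)     # candidate; reset applies to the hidden part only
    return (1 - z) * n + z * h      # new hidden state
```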
@@ -1935,47 +1932,171 @@ def __init__(
         dropout,
         bidirectional,
         is_train,
+        w_ih,
+        w_hh,
+        b_ih,
+        b_hh,
     ):
         super(rnnbase, self).__init__()
+        if not 0 <= dropout < 1:
+            raise ValueError("dropout should be a number in range [0, 1).")
+        if dropout > 0 and num_layers == 1:
+            raise ValueError(
+                "dropout option adds dropout after all but last "
+                "recurrent layer, so non-zero dropout expects "
+                "num_layers greater than 1, but got dropout={} and "
+                "num_layers={}".format(dropout, num_layers)
+            )
         self.mode = mode
+        self.reverse = P.ReverseV2([0])
+        self.reverse_sequence = P.ReverseSequence(0, 1)
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.num_layers = num_layers
-        self.bidirect = 2 if bidirectional else 1
+        self.dropout = dropout
+        self.dropout_op = ms.nn.Dropout(float(1 - dropout))
+        self.has_bias = bias
+        self.bidirectional = bidirectional
         self.batch_first = batch_first
-        if mode == 'LSTM':
-            self.lstm = ms.nn.LSTM(
-                input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=bias,
-                batch_first=batch_first, dropout=dropout, bidirectional=bidirectional
-            )
-        elif mode == 'GRU':
-
-            raise NotImplementedError
-
-        elif mode == 'RNN_TANH':
-
-            raise NotImplementedError
-
-        elif mode == 'RNN_RELU':
-
-            raise NotImplementedError
+        self.train = is_train
+        self.w_ih_list = ParameterTuple(w_ih)
+        self.w_hh_list = ParameterTuple(w_hh)
+        self.b_ih_list = ParameterTuple(b_ih)
+        self.b_hh_list = ParameterTuple(b_hh)
+        self.rnn = _DynamicRNN(mode)
+        self.is_lstm = mode == "LSTM"

         self.zeros = P.Zeros()

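Wrapping the externally created weight lists in `ParameterTuple` groups them so the Cell can register and checkpoint them, and the later `self.w_ih_list[...]` indexing works per layer and direction. A minimal sketch with hypothetical single-layer shapes:

```python
import mindspore as ms
import mindspore.numpy as mnp
from mindspore import Parameter
from mindspore.common.parameter import ParameterTuple

hidden_size, input_size = 16, 8
w_ih = [Parameter(ms.Tensor(mnp.zeros((hidden_size, input_size)), ms.float32), name="w_ih_l0")]
w_hh = [Parameter(ms.Tensor(mnp.zeros((hidden_size, hidden_size)), ms.float32), name="w_hh_l0")]
params = ParameterTuple(w_ih + w_hh)
print([p.name for p in params])    # ['w_ih_l0', 'w_hh_l0']
```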
-    def construct(self, input, states):
-        input_shape = input.shape
-        input_dtype = input.dtype
-        if self.mode == 'LSTM':
-            if self.batch_first:
-                batch_size = input_shape[0]
+    def _stacked_bi_dynamic_rnn(self, x, h, seq_length):
+        """stacked bidirectional dynamic_rnn"""
+        pre_layer = x
+        h_n = ()
+        c_n = ()
+        output = 0
+        for i in range(self.num_layers):
+            offset = i * 2
+            if self.has_bias:
+                w_f_ih, w_f_hh, b_f_ih, b_f_hh = \
+                    self.w_ih_list[offset], self.w_hh_list[offset], \
+                    self.b_ih_list[offset], self.b_hh_list[offset]
+                w_b_ih, w_b_hh, b_b_ih, b_b_hh = \
+                    self.w_ih_list[offset + 1], self.w_hh_list[offset + 1], \
+                    self.b_ih_list[offset + 1], self.b_hh_list[offset + 1]
+            else:
+                w_f_ih, w_f_hh = self.w_ih_list[offset], self.w_hh_list[offset]
+                w_b_ih, w_b_hh = self.w_ih_list[offset + 1], self.w_hh_list[offset + 1]
+                b_f_ih, b_f_hh, b_b_ih, b_b_hh = None, None, None, None
+            if self.is_lstm:
+                h_f_i = (h[0][offset], h[1][offset])
+                h_b_i = (h[0][offset + 1], h[1][offset + 1])
+            else:
+                h_f_i = h[offset]
+                h_b_i = h[offset + 1]
+            if seq_length is None:
+                x_b = self.reverse(pre_layer)
             else:
-                batch_size = input_shape[1]
-            if states is None:
-                h = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
-                c = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
-                states = (h, c)
-            output, (h, c) = self.lstm(input, states)
-            return output, (h, c)
+                x_b = self.reverse_sequence(pre_layer, seq_length)
+            output_f, h_t_f = self.rnn(pre_layer, h_f_i, seq_length, w_f_ih, w_f_hh, b_f_ih, b_f_hh)
+            output_b, h_t_b = self.rnn(x_b, h_b_i, seq_length, w_b_ih, w_b_hh, b_b_ih, b_b_hh)
+            if seq_length is None:
+                output_b = self.reverse(output_b)
+            else:
+                output_b = self.reverse_sequence(output_b, seq_length)
+            output = P.Concat(2)((output_f, output_b))
+            pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output
+            if self.is_lstm:
+                h_n += (
+                    h_t_f[0],
+                    h_t_b[0],
+                )
+                c_n += (
+                    h_t_f[1],
+                    h_t_b[1],
+                )
+            else:
+                h_n += (
+                    h_t_f,
+                    h_t_b,
+                )
+        if self.is_lstm:
+            h_n = P.Concat(0)(h_n)
+            c_n = P.Concat(0)(c_n)
+            h_n = h_n.view(h[0].shape)
+            c_n = c_n.view(h[1].shape)
+            return output, (h_n.view(h[0].shape), c_n.view(h[1].shape))
+        h_n = P.Concat(0)(h_n)
+        return output, h_n.view(h.shape)
+
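For the backward direction the layer input is flipped along the time axis (`ReverseSequence` when per-sample lengths are given), run through its own cell weights, flipped back, and concatenated with the forward outputs on the feature axis. A standalone sketch of that reverse/concat pattern, with a stand-in for the recurrent call:

```python
import mindspore as ms
import mindspore.ops as P
import mindspore.numpy as mnp

seq_len, batch, hidden = 5, 3, 4
out_f = ms.Tensor(mnp.ones((seq_len, batch, hidden)), ms.float32)   # forward outputs
x = ms.Tensor(mnp.ones((seq_len, batch, hidden)), ms.float32)       # layer input

x_rev = P.ReverseV2([0])(x)        # flip the time axis before the backward pass
out_b = x_rev                      # stand-in for self.rnn(x_rev, ...)
out_b = P.ReverseV2([0])(out_b)    # flip back so time steps line up with out_f
out = P.Concat(2)((out_f, out_b))  # (seq_len, batch, 2 * hidden)
print(out.shape)                   # (5, 3, 8)
```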
+    def _stacked_dynamic_rnn(self, x, h, seq_length):
+        """stacked multi-layer dynamic_rnn"""
+        pre_layer = x
+        h_n = ()
+        c_n = ()
+        output = 0
+        for i in range(self.num_layers):
+            if self.has_bias:
+                w_ih, w_hh, b_ih, b_hh = self.w_ih_list[i], self.w_hh_list[i], self.b_ih_list[i], self.b_hh_list[i]
+            else:
+                w_ih, w_hh = self.w_ih_list[i], self.w_hh_list[i]
+                b_ih, b_hh = None, None
+            if self.is_lstm:
+                h_i = (h[0][i], h[1][i])
+            else:
+                h_i = h[i]
+            output, h_t = self.rnn(pre_layer, h_i, seq_length, w_ih, w_hh, b_ih, b_hh)
+            pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output
+            if self.is_lstm:
+                h_n += (h_t[0], )
+                c_n += (h_t[1], )
+            else:
+                h_n += (h_t, )
+        if self.is_lstm:
+            h_n = P.Concat(0)(h_n)
+            c_n = P.Concat(0)(c_n)
+            h_n = h_n.view(h[0].shape)
+            c_n = c_n.view(h[1].shape)
+            return output, (h_n.view(h[0].shape), c_n.view(h[1].shape))
+        h_n = P.Concat(0)(h_n)
+        return output, h_n.view(h.shape)
+
+    @constexpr
+    def _init_state(shape, dtype, is_lstm):
+        hx = ms.Tensor(np.zeros(shape), dtype)
+        cx = ms.Tensor(np.zeros(shape), dtype)
+        if is_lstm:
+            return (hx, cx)
+        return hx
+
+    @constexpr
+    def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
+        validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
+
+    def construct(self, x, hx=None, seq_length=None):
+        '''Defines the RNN like operators performed'''
+        x_dtype = P.dtype(x)
+        self._check_input_dtype(x_dtype, "x", [ms.float32], self.cls_name)
+        if hx is not None:
+            hx_dtype = P.dtype(hx)
+            self._check_input_dtype(hx_dtype, "hx", [ms.float32], self.cls_name)
+        if seq_length is not None:
+            seq_length_dtype = P.dtype(seq_length)
+            self._check_input_dtype(seq_length_dtype, "seq_length", [ms.int32, ms.int64], self.cls_name)
+
+        max_batch_size = x.shape[0] if self.batch_first else x.shape[1]
+        num_directions = 2 if self.bidirectional else 1
+        if hx is None:
+            hx = self._init_state(
+                (self.num_layers * num_directions, max_batch_size, self.hidden_size), x.dtype, self.is_lstm
+            )
+        if self.batch_first:
+            x = P.Transpose()(x, (1, 0, 2))
+        if self.bidirectional:
+            x, h = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
+        else:
+            x, h = self._stacked_dynamic_rnn(x, hx, seq_length)
+        if self.batch_first:
+            x = P.Transpose()(x, (1, 0, 2))
+        return x, h


 class layernorm(Cell):
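For reference, the shape bookkeeping the new `construct` performs, pulled out into a standalone sketch with hypothetical sizes: the default state is `(num_layers * num_directions, batch, hidden_size)`, and batch-first inputs are transposed to time-major before the recurrent loops and back afterwards.

```python
import mindspore as ms
import mindspore.ops as P
import mindspore.numpy as mnp

batch, seq_len, input_size, hidden_size = 4, 10, 8, 16
num_layers, num_directions = 2, 2            # e.g. bidirectional=True

x = ms.Tensor(mnp.ones((batch, seq_len, input_size)), ms.float32)   # batch_first layout
hx = P.Zeros()((num_layers * num_directions, batch, hidden_size), ms.float32)

x_time_major = P.Transpose()(x, (1, 0, 2))   # what construct does when batch_first=True
print(x_time_major.shape, hx.shape)          # (10, 4, 8) (4, 4, 16)
```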