@@ -1919,6 +1919,19 @@ def construct(self, input, h):
         return h, h
 
 
+@constexpr
+def _init_state(shape, dtype, is_lstm):
+    hx = ms.Tensor(np.zeros(shape), dtype)
+    cx = ms.Tensor(np.zeros(shape), dtype)
+    if is_lstm:
+        return (hx, cx)
+    return hx
+
+@constexpr
+def _check_input_dtype_same_and_valid(args_name, args_value, valid_values, cls_name):
+    args = {args_name[i]: args_value[i] for i in range(len(args_value))}
+    validator.check_types_same_and_valid(args, valid_values, cls_name)
+
 class rnnbase(Cell):
 
     def __init__(
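For context, the behaviour of the new module-level `_init_state` helper can be sketched outside of graph mode as below (a minimal sketch, assuming MindSpore and NumPy are installed; the plain-function form without `@constexpr`, the helper name, and the example shapes are illustrative only, not part of the diff):

import numpy as np
import mindspore as ms

def init_state_sketch(shape, dtype, is_lstm):
    # Zero-filled default state: an (h, c) pair for LSTM, a single tensor otherwise.
    hx = ms.Tensor(np.zeros(shape), dtype)
    if is_lstm:
        cx = ms.Tensor(np.zeros(shape), dtype)
        return (hx, cx)
    return hx

# Shape follows construct(): (num_layers * num_directions, batch_size, hidden_size),
# e.g. 2 layers, bidirectional, batch of 8, hidden size 16 (illustrative values).
h0, c0 = init_state_sketch((2 * 2, 8, 16), ms.float32, is_lstm=True)
print(h0.shape)  # (4, 8, 16)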
@@ -1959,10 +1972,10 @@ def __init__(
         self.bidirectional = bidirectional
         self.batch_first = batch_first
         self.train = is_train
-        self.w_ih = ParameterTuple(w_ih)
-        self.w_hh = ParameterTuple(w_hh)
-        self.b_ih = ParameterTuple(b_ih)
-        self.b_hh = ParameterTuple(b_hh)
+        self.w_ih_list = ParameterTuple(w_ih)
+        self.w_hh_list = ParameterTuple(w_hh)
+        self.b_ih_list = ParameterTuple(b_ih)
+        self.b_hh_list = ParameterTuple(b_hh)
         self.rnn = _DynamicRNN(mode)
         self.is_lstm = mode == "LSTM"
 
@@ -2060,43 +2073,31 @@ def _stacked_dynamic_rnn(self, x, h, seq_length):
         h_n = P.Concat(0)(h_n)
         return output, h_n.view(h.shape)
 
-    @constexpr
-    def _init_state(shape, dtype, is_lstm):
-        hx = ms.Tensor(np.zeros(shape), dtype)
-        cx = ms.Tensor(np.zeros(shape), dtype)
-        if is_lstm:
-            return (hx, cx)
-        return hx
-
-    @constexpr
-    def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
-        validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
-
     def construct(self, x, hx=None, seq_length=None):
         '''Defines the RNN like operators performed'''
-        x_dtype = P.dtype(x)
-        hx_dtype = P.dtype(hx)
-        self._check_input_dtype(x_dtype, "x", [ms.float32], self.cls_name)
-        self._check_input_dtype(hx_dtype, "hx", [ms.float32], self.cls_name)
-        if seq_length is not None:
-            seq_length_dtype = P.dtype(seq_length)
-            self._check_input_dtype(seq_length_dtype, "seq_length", [ms.int32, ms.int64], self.cls_name)
-
         max_batch_size = x.shape[0] if self.batch_first else x.shape[1]
         num_directions = 2 if self.bidirectional else 1
-        if hx is None:
-            hx = self._init_state(
-                (self.num_layers * num_directions, max_batch_size, self.hidden_size), x.dtype, self.is_lstm
-            )
+        x_dtype = x.dtype
+        if hx is not None:
+            if not self.is_lstm:
+                _check_input_dtype_same_and_valid(['x', 'hx'], [x_dtype, hx.dtype],
+                                                  [ms.float32, ms.float16], self.cls_name)
+            else:
+                _check_input_dtype_same_and_valid(['x', 'hx[0]', 'hx[1]'], [x_dtype, hx[0].dtype, hx[1].dtype],
+                                                  [ms.float32, ms.float16], self.cls_name)
+        else:
+            hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size), x_dtype, self.is_lstm)
         if self.batch_first:
             x = P.Transpose()(x, (1, 0, 2))
         if self.bidirectional:
-            x, h = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
+            x_n, hx_n = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
         else:
-            x, h = self._stacked_dynamic_rnn(x, hx, seq_length)
+            x_n, hx_n = self._stacked_dynamic_rnn(x, hx, seq_length)
         if self.batch_first:
-            x = P.Transpose()(x, (1, 0, 2))
-        return x, h
+            x_n = P.Transpose()(x_n, (1, 0, 2))
+        if not self.is_lstm:
+            return x_n.astype(x_dtype), hx_n.astype(x_dtype)
+        return x_n.astype(x_dtype), (hx_n[0].astype(x_dtype), hx_n[1].astype(x_dtype))
 
 
 class layernorm(Cell):
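As a closing note on the reworked construct(): x and the supplied hx (or hx[0]/hx[1] for LSTM) must now share one dtype drawn from float32/float16, and the outputs are cast back to the input dtype on return. Below is a plain-Python stand-in for that check, for illustration only; the diff itself routes validation through MindSpore's validator.check_types_same_and_valid inside a @constexpr function, and the helper name here is hypothetical:

import mindspore as ms

def check_same_and_valid_sketch(names, dtypes, valid_values, cls_name):
    # All listed inputs must share a single dtype, and that dtype must be allowed.
    if any(d != dtypes[0] for d in dtypes) or dtypes[0] not in valid_values:
        raise TypeError(f"For '{cls_name}', {list(zip(names, dtypes))} must all "
                        f"use one dtype from {valid_values}.")

# float16 x with float16 hx passes; mixing float16 with float32 would raise TypeError.
check_same_and_valid_sketch(['x', 'hx'], [ms.float16, ms.float16],
                            [ms.float32, ms.float16], 'rnnbase')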