@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 
 import paddle as pd
+import paddle.nn
 from paddle import framework
 import paddle.nn.functional as F
 import numpy as np
@@ -13,7 +14,7 @@
 from paddle.nn.layer.rnn import RNNCellBase
 import warnings
 import math
-
+from paddle import _C_ops
 
 def padding_format(padding):
     """
@@ -1503,7 +1504,6 @@ def concat_states(states, bidirectional=False, state_components=1):
         componnets.append(states[i::state_components])
     return tuple([pd.stack(item) for item in componnets])
 
-
 class rnnbase(LayerList):
 
     def __init__(
@@ -1539,7 +1539,6 @@ def __init__(
         self.bias = bias
         RNN = pd.nn.RNN
         BiRNN = pd.nn.BiRNN
-
         kwargs = {"weight_ih_attr": None, "weight_hh_attr": None, "bias_ih_attr": self.bias, "bias_hh_attr": self.bias}
         act = None
         rnn_cls = None
@@ -1618,16 +1617,11 @@ def __init__(
 
     def flatten_parameters(self):
         if self.could_use_cudnn:
-            params = self.parameters(include_sublayers=False)
-            shape = [np.prod(param.shape) for param in params]
-            self._all_weights = [None] * len(params)
-            for i, param in enumerate(params):
-                offset = 0 if i % 4 < 2 else (2 * self.num_layers * self.bidirect)
-                layer_idx = i // 4
-                self._all_weights[offset + layer_idx * 2 + i % 2] = param
+            self._all_weights = self.parameters(include_sublayers=False)
+            shape = [np.prod(param.shape) for param in self._all_weights]
             self._flat_weight = [
                 self.create_parameter(
-                    shape=[np.sum(shape)], dtype=params[0].dtype, default_initializer=I.Constant(0.0)
+                    shape=[np.sum(shape)], dtype=self._all_weights[0].dtype, default_initializer=I.Constant(0.0)
                 )
             ]
             self._dropout_state = self.create_variable(dtype=fluid.core.VarDesc.VarType.UINT8)
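
Note on the hunk above: the rewrite assumes `self.parameters(include_sublayers=False)` already yields the weights in the order the fused cuDNN buffer expects, so the old `offset`/`layer_idx` reordering loop can be dropped before the weights are coalesced into one flat buffer. A minimal NumPy sketch of the flattening step itself, with hypothetical weight shapes standing in for `self._all_weights` (Paddle does this via the `coalesce_tensor` op, not this loop):

```python
import numpy as np

# Hypothetical per-layer weights standing in for self._all_weights.
weights = [np.random.rand(4, 3), np.random.rand(4, 4),
           np.random.rand(4), np.random.rand(4)]

# Mirrors: shape = [np.prod(param.shape) for param in self._all_weights]
sizes = [int(np.prod(w.shape)) for w in weights]

# One contiguous buffer, analogous to coalesce_tensor's FusedOutput.
flat = np.zeros(int(np.sum(sizes)), dtype=weights[0].dtype)
offset = 0
for w, n in zip(weights, sizes):
    flat[offset:offset + n] = w.ravel()  # copy_data=True: copy each weight in
    offset += n
```
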
@@ -1640,42 +1634,20 @@ def flatten_parameters(self):
             }, attrs={
                 "copy_data": True,
                 "use_align": False,
-                "dtype": params[0].dtype
+                "dtype": self._all_weights[0].dtype
             }
             )
 
     def _cudnn_impl(self, inputs, initial_states, sequence_length):
         if not self.time_major:
             inputs = pd.tensor.transpose(inputs, [1, 0, 2])
-        out = self._helper.create_variable_for_type_inference(inputs.dtype)
-        state = [self._helper.create_variable_for_type_inference(inputs.dtype) for i in range(self.state_components)]
-        reserve = self._helper.create_variable_for_type_inference(
-            dtype=fluid.core.VarDesc.VarType.UINT8, stop_gradient=True
-        )
-
-        inputs = {
-            'Input': inputs,
-            'WeightList': self._all_weights,
-            'PreState': initial_states,
-            'SequenceLength': sequence_length
-        }
-        attrs = {
-            'dropout_prob': self.dropout,
-            'is_bidirec': self.bidirect == 2,
-            'input_size': self.input_size,
-            'hidden_size': self.hidden_size,
-            'num_layers': self.num_layers,
-            'mode': self.mode,
-            'is_test': not self.training
-        }
-
-        outputs = {
-            'Out': out,
-            'State': state,
-            'Reserve': reserve,
-            'DropoutState': self._dropout_state,
-        }
-        self._helper.append_op(type="rnn", inputs=inputs, outputs=outputs, attrs=attrs)
+        _, _, out, state = _C_ops.rnn(
+            inputs, initial_states, self._all_weights, sequence_length,
+            self._dropout_state, self.state_components, 'dropout_prob',
+            self.dropout, 'is_bidirec', self.bidirect == 2,
+            'input_size', self.input_size, 'hidden_size', self.hidden_size,
+            'num_layers', self.num_layers, 'mode', self.mode, 'is_test',
+            not self.training)
         out = pd.tensor.transpose(out, [1, 0, 2]) if not self.time_major else out
         return out, tuple(state) if len(state) > 1 else state[0]
 
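Note on `_cudnn_impl`: the graph-mode plumbing (`create_variable_for_type_inference` plus `append_op(type="rnn", ...)`) is replaced by a direct dygraph call to `_C_ops.rnn`, which takes its tensor inputs positionally, then attributes as alternating name/value arguments, and returns four outputs of which only `out` and `state` are kept here. A small sketch of that flattening convention, using hypothetical attribute values:

```python
# Hypothetical values; the keys mirror the attrs the old append_op passed.
attrs = {
    'dropout_prob': 0.0,
    'is_bidirec': False,
    'input_size': 16,
    'hidden_size': 32,
    'num_layers': 1,
    'mode': 'LSTM',
    'is_test': True,
}

# Flatten into the alternating name/value tail seen in the _C_ops.rnn call:
# ['dropout_prob', 0.0, 'is_bidirec', False, ...]
flat_attrs = [item for pair in attrs.items() for item in pair]
print(flat_attrs)
```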