@@ -591,7 +591,9 @@ def forward(self, input, states=None):
591591 output , new_states = self .rnn (input , states )
592592
593593 if not self ._nodes_fixed and self ._build_graph :
594- self ._add_node (input , [output , new_states ])
594+ self .states = states
595+ self .new_states = new_states
596+ self ._add_node (input , output )
595597 self ._nodes_fixed = True
596598 return output , new_states
597599
@@ -671,7 +673,9 @@ def forward(self, input, states=None):
671673 output , new_states = self .rnn (input , states )
672674
673675 if not self ._nodes_fixed and self ._build_graph :
674- self ._add_node (input , [output , new_states ])
676+ self .states = states
677+ self .new_states = new_states
678+ self ._add_node (input , output )
675679 self ._nodes_fixed = True
676680 return output , new_states
677681
@@ -750,6 +754,8 @@ def forward(self, input, states=None):
750754 output , new_states = self .rnn (input , states )
751755
752756 if not self ._nodes_fixed and self ._build_graph :
753- self ._add_node (input , [output , new_states ])
757+ self .states = states
758+ self .new_states = new_states
759+ self ._add_node (input , output )
754760 self ._nodes_fixed = True
755761 return output , new_states
0 commit comments