Skip to content

Commit c08d3fe

Browse files
committed
fix rnn construct graph
1 parent 4f2f06f commit c08d3fe

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

tensorlayerx/nn/core/common.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -540,8 +540,7 @@ def construct_graph(inputs, outputs):
540540
if out_node.node_name not in indegrees.keys():
541541
indegrees[out_node.node_name] = len(out_node.in_nodes)
542542
indegrees[out_node.node_name] -= 1
543-
if indegrees[out_node.node_name] == 0 or \
544-
isinstance(out_node.layer, (tlx.nn.RNN, tlx.nn.LSTM, tlx.nn.GRU)):
543+
if indegrees[out_node.node_name] == 0:
545544
next_depth.append(out_node)
546545
cur_depth = next_depth
547546
next_depth = []

tensorlayerx/nn/layers/recurrent.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,9 @@ def forward(self, input, states=None):
591591
output, new_states = self.rnn(input, states)
592592

593593
if not self._nodes_fixed and self._build_graph:
594-
self._add_node(input, [output, new_states])
594+
self.states = states
595+
self.new_states = new_states
596+
self._add_node(input, output)
595597
self._nodes_fixed = True
596598
return output, new_states
597599

@@ -671,7 +673,9 @@ def forward(self, input, states=None):
671673
output, new_states = self.rnn(input, states)
672674

673675
if not self._nodes_fixed and self._build_graph:
674-
self._add_node(input, [output, new_states])
676+
self.states = states
677+
self.new_states = new_states
678+
self._add_node(input, output)
675679
self._nodes_fixed = True
676680
return output, new_states
677681

@@ -750,6 +754,8 @@ def forward(self, input, states=None):
750754
output, new_states = self.rnn(input, states)
751755

752756
if not self._nodes_fixed and self._build_graph:
753-
self._add_node(input, [output, new_states])
757+
self.states = states
758+
self.new_states = new_states
759+
self._add_node(input, output)
754760
self._nodes_fixed = True
755761
return output, new_states

0 commit comments

Comments
 (0)