Skip to content

Commit df26565

Browse files
committed
fix TLX nodes collection bug
1 parent f6f0a06 commit df26565

File tree

4 files changed

+78
-74
lines changed

4 files changed

+78
-74
lines changed

tensorlayerx/nn/core/core_mindspore.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,16 @@ def init_build(self, *inputs, **kwargs):
241241

242242
self.forward(*inputs, **kwargs)
243243

244-
def build_graph(self, *inputs, **kwargs):
245-
# Add nodes only when the composition is needed.
244+
def set_build_graph(self):
246245
for layer_name, layer in self._cells.items():
247246
if isinstance(layer, Module):
247+
if len(layer._cells) > 1:
248+
layer.set_build_graph()
248249
layer._build_graph = True
250+
251+
def build_graph(self, *inputs, **kwargs):
252+
# Add nodes only when the composition is needed.
253+
self.set_build_graph()
249254
self.set_eval()
250255

251256
outputs = self.forward(*inputs, **kwargs)
@@ -382,31 +387,26 @@ def build(self, inputs_shape):
382387

383388
def forward(self, input_data):
384389
for layer in self.layer_list:
385-
inputs = input_data
386390
input_data = layer(input_data)
387-
outputs = input_data
388-
if not self._nodes_fixed and self._build_graph:
389-
self._add_seq_node(inputs, outputs, layer)
390-
self._nodes_fixed = True
391391
return input_data
392392

393-
def _add_seq_node(self, input_tensors, output_tensors, layer):
394-
inputs_list = tolist(input_tensors)
395-
outputs_list = tolist(output_tensors)
396-
if layer.__class__.__name__ in tlx.layers.inputs.__all__:
397-
in_nodes = []
398-
in_tensor_idxes = [0]
399-
else:
400-
in_nodes = [tensor._info[0] for tensor in inputs_list]
401-
in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
402-
node_index = len(_global_layer_node)
403-
404-
new_node = ModuleNode(
405-
layer, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes, select_attrs(layer)
406-
)
407-
_global_layer_node.append(new_node)
408-
for idx, tensor in enumerate(outputs_list):
409-
tensor._info = (new_node, idx)
393+
# def _add_seq_node(self, input_tensors, output_tensors, layer):
394+
# inputs_list = tolist(input_tensors)
395+
# outputs_list = tolist(output_tensors)
396+
# if layer.__class__.__name__ in tlx.layers.inputs.__all__:
397+
# in_nodes = []
398+
# in_tensor_idxes = [0]
399+
# else:
400+
# in_nodes = [tensor._info[0] for tensor in inputs_list]
401+
# in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
402+
# node_index = len(_global_layer_node)
403+
#
404+
# new_node = ModuleNode(
405+
# layer, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes, select_attrs(layer)
406+
# )
407+
# _global_layer_node.append(new_node)
408+
# for idx, tensor in enumerate(outputs_list):
409+
# tensor._info = (new_node, idx)
410410

411411

412412
class ModuleList(Module):

tensorlayerx/nn/core/core_paddle.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,10 @@ def insert_child_to_layer(self, child_name, child):
300300
raise TypeError("Child layer type is incorrect.")
301301
self._sub_layers[child_name] = child
302302

303+
def set_build_graph(self):
304+
305+
raise NotImplementedError
306+
303307
def build_graph(self, *inputs, **kwargs):
304308
# Add nodes only when the composition is needed.
305309
# for layer in self.sublayers():

tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -639,11 +639,16 @@ def str_to_init(self, initializer):
639639
def check_param(self, param, dim='2d'):
640640
return check_parameter(param, dim)
641641

642-
def build_graph(self, *inputs, **kwargs):
643-
# Add nodes only when the composition is needed.
642+
def set_build_graph(self):
644643
for layer_name, layer in self._layers.items():
645644
if isinstance(layer, Module):
645+
if len(layer._layers) > 1:
646+
layer.set_build_graph()
646647
layer._build_graph = True
648+
649+
def build_graph(self, *inputs, **kwargs):
650+
# Add nodes only when the composition is needed.
651+
self.set_build_graph()
647652
self.set_eval()
648653

649654
outputs = self.forward(*inputs, **kwargs)
@@ -774,31 +779,26 @@ def build(self, inputs_shape):
774779

775780
def forward(self, input_data):
776781
for layer in self.layer_list:
777-
inputs = input_data
778782
input_data = layer(input_data)
779-
outputs = input_data
780-
if not self._nodes_fixed and self._build_graph:
781-
self._add_seq_node(inputs, outputs, layer)
782-
self._nodes_fixed = True
783783
return input_data
784-
785-
def _add_seq_node(self, input_tensors, output_tensors, layer):
786-
inputs_list = tolist(input_tensors)
787-
outputs_list = tolist(output_tensors)
788-
if layer.__class__.__name__ in tlx.layers.inputs.__all__:
789-
in_nodes = []
790-
in_tensor_idxes = [0]
791-
else:
792-
in_nodes = [tensor._info[0] for tensor in inputs_list]
793-
in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
794-
node_index = len(_global_layer_node)
795-
796-
new_node = ModuleNode(
797-
layer, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes, select_attrs(layer)
798-
)
799-
_global_layer_node.append(new_node)
800-
for idx, tensor in enumerate(outputs_list):
801-
tensor._info = (new_node, idx)
784+
#
785+
# def _add_seq_node(self, input_tensors, output_tensors, layer):
786+
# inputs_list = tolist(input_tensors)
787+
# outputs_list = tolist(output_tensors)
788+
# if layer.__class__.__name__ in tlx.layers.inputs.__all__:
789+
# in_nodes = []
790+
# in_tensor_idxes = [0]
791+
# else:
792+
# in_nodes = [tensor._info[0] for tensor in inputs_list]
793+
# in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
794+
# node_index = len(_global_layer_node)
795+
#
796+
# new_node = ModuleNode(
797+
# layer, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes, select_attrs(layer)
798+
# )
799+
# _global_layer_node.append(new_node)
800+
# for idx, tensor in enumerate(outputs_list):
801+
# tensor._info = (new_node, idx)
802802

803803

804804
class ModuleList(Module):

tensorlayerx/nn/core/core_torch.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -207,11 +207,16 @@ def init_build(self, *inputs, **kwargs):
207207

208208
self.forward(*inputs, **kwargs)
209209

210-
def build_graph(self, *inputs, **kwargs):
211-
# Add nodes only when the composition is needed.
212-
for name, layer in self._modules.items():
210+
def set_build_graph(self):
211+
for layer_name, layer in self._modules.items():
213212
if isinstance(layer, Module):
213+
if len(layer._modules) > 1:
214+
layer.set_build_graph()
214215
layer._build_graph = True
216+
217+
def build_graph(self, *inputs, **kwargs):
218+
# Add nodes only when the composition is needed.
219+
self.set_build_graph()
215220
self.set_eval()
216221

217222
outputs = self.forward(*inputs, **kwargs)
@@ -360,31 +365,26 @@ def build(self, inputs_shape):
360365

361366
def forward(self, input_data):
362367
for layer in self.layer_list:
363-
inputs = input_data
364368
input_data = layer(input_data)
365-
outputs = input_data
366-
if not self._nodes_fixed and self._build_graph:
367-
self._add_seq_node(inputs, outputs, layer)
368-
self._nodes_fixed = True
369369
return input_data
370370

371-
def _add_seq_node(self, input_tensors, output_tensors, layer):
372-
inputs_list = tolist(input_tensors)
373-
outputs_list = tolist(output_tensors)
374-
if layer.__class__.__name__ in tlx.layers.inputs.__all__:
375-
in_nodes = []
376-
in_tensor_idxes = [0]
377-
else:
378-
in_nodes = [tensor._info[0] for tensor in inputs_list]
379-
in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
380-
node_index = len(_global_layer_node)
381-
382-
new_node = ModuleNode(
383-
layer, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes, select_attrs(layer)
384-
)
385-
_global_layer_node.append(new_node)
386-
for idx, tensor in enumerate(outputs_list):
387-
tensor._info = (new_node, idx)
371+
# def _add_seq_node(self, input_tensors, output_tensors, layer):
372+
# inputs_list = tolist(input_tensors)
373+
# outputs_list = tolist(output_tensors)
374+
# if layer.__class__.__name__ in tlx.layers.inputs.__all__:
375+
# in_nodes = []
376+
# in_tensor_idxes = [0]
377+
# else:
378+
# in_nodes = [tensor._info[0] for tensor in inputs_list]
379+
# in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
380+
# node_index = len(_global_layer_node)
381+
#
382+
# new_node = ModuleNode(
383+
# layer, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes, select_attrs(layer)
384+
# )
385+
# _global_layer_node.append(new_node)
386+
# for idx, tensor in enumerate(outputs_list):
387+
# tensor._info = (new_node, idx)
388388

389389

390390
class ModuleList(Module):

0 commit comments

Comments
 (0)