Skip to content

Commit 66d49ba

Browse files
committed
Fix Sequential mode ONNX node collection.
1 parent f5bc439 commit 66d49ba

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

tensorlayerx/nn/core/core_mindspore.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,9 +380,32 @@ def build(self, inputs_shape):
380380

381381
def forward(self, input_data):
382382
for layer in self.layer_list:
383+
inputs = input_data
383384
input_data = layer(input_data)
385+
outputs = input_data
386+
if not self._nodes_fixed and self._build_graph:
387+
self._add_seq_node(inputs, outputs, layer)
388+
self._nodes_fixed = True
384389
return input_data
385390

391+
def _add_seq_node(self, input_tensors, output_tensors, layer):
392+
inputs_list = tolist(input_tensors)
393+
outputs_list = tolist(output_tensors)
394+
if layer.__class__.__name__ in tlx.layers.inputs.__all__:
395+
in_nodes = []
396+
in_tensor_idxes = [0]
397+
else:
398+
in_nodes = [tensor._info[0] for tensor in inputs_list]
399+
in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
400+
node_index = len(_global_layer_node)
401+
402+
new_node = ModuleNode(
403+
layer, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes, select_attrs(layer)
404+
)
405+
_global_layer_node.append(new_node)
406+
for idx, tensor in enumerate(outputs_list):
407+
tensor._info = (new_node, idx)
408+
386409

387410
class ModuleList(Module):
388411
"""

tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,9 +775,32 @@ def build(self, inputs_shape):
775775

776776
def forward(self, input_data):
777777
for layer in self.layer_list:
778+
inputs = input_data
778779
input_data = layer(input_data)
780+
outputs = input_data
781+
if not self._nodes_fixed and self._build_graph:
782+
self._add_seq_node(inputs, outputs, layer)
783+
self._nodes_fixed = True
779784
return input_data
780785

786+
def _add_seq_node(self, input_tensors, output_tensors, layer):
787+
inputs_list = tolist(input_tensors)
788+
outputs_list = tolist(output_tensors)
789+
if layer.__class__.__name__ in tlx.layers.inputs.__all__:
790+
in_nodes = []
791+
in_tensor_idxes = [0]
792+
else:
793+
in_nodes = [tensor._info[0] for tensor in inputs_list]
794+
in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
795+
node_index = len(_global_layer_node)
796+
797+
new_node = ModuleNode(
798+
layer, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes, select_attrs(layer)
799+
)
800+
_global_layer_node.append(new_node)
801+
for idx, tensor in enumerate(outputs_list):
802+
tensor._info = (new_node, idx)
803+
781804

782805
class ModuleList(Module):
783806
"""

tensorlayerx/nn/core/core_torch.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,9 +360,32 @@ def build(self, inputs_shape):
360360

361361
def forward(self, input_data):
362362
for layer in self.layer_list:
363+
inputs = input_data
363364
input_data = layer(input_data)
365+
outputs = input_data
366+
if not self._nodes_fixed and self._build_graph:
367+
self._add_seq_node(inputs, outputs, layer)
368+
self._nodes_fixed = True
364369
return input_data
365370

371+
def _add_seq_node(self, input_tensors, output_tensors, layer):
372+
inputs_list = tolist(input_tensors)
373+
outputs_list = tolist(output_tensors)
374+
if layer.__class__.__name__ in tlx.layers.inputs.__all__:
375+
in_nodes = []
376+
in_tensor_idxes = [0]
377+
else:
378+
in_nodes = [tensor._info[0] for tensor in inputs_list]
379+
in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
380+
node_index = len(_global_layer_node)
381+
382+
new_node = ModuleNode(
383+
layer, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes, select_attrs(layer)
384+
)
385+
_global_layer_node.append(new_node)
386+
for idx, tensor in enumerate(outputs_list):
387+
tensor._info = (new_node, idx)
388+
366389

367390
class ModuleList(Module):
368391
"""

0 commit comments

Comments
 (0)