|
1 | 1 | #! /usr/bin/python |
2 | 2 | # -*- coding: utf-8 -*- |
3 | 3 |
|
4 | | -from .common import str2act, str2init, random_normal, tolist, construct_graph, ModuleNode |
| 4 | +from .common import check_parameter, str2act, str2init, random_normal, tolist, construct_graph, ModuleNode, select_attrs |
5 | 5 | from .common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict |
6 | 6 | from mindspore.nn import Cell |
7 | 7 | import tensorlayerx as tlx |
@@ -105,7 +105,7 @@ def _get_weights(self, var_name, shape, init=random_normal(), trainable=True, tr |
105 | 105 | if len(shape) == 3: |
106 | 106 | shape = shape[::-1] |
107 | 107 | if len(shape) == 4: |
108 | | - if not transposed and self.data_format == 'NHWC': |
| 108 | + if not transposed and self.data_format in ['NHWC', 'channels_last']: |
109 | 109 | shape = (shape[3], shape[0], shape[1], shape[2]) |
110 | 110 | else: |
111 | 111 | shape = (shape[3], shape[2], shape[0], shape[1]) |
@@ -265,6 +265,9 @@ def all_weights(self): |
265 | 265 | def str_to_init(self, initializer): |
266 | 266 | return str2init(initializer) |
267 | 267 |
|
| 268 | + def check_param(self, param, dim='2d'): |
| 269 | + return check_parameter(param, dim) |
| 270 | + |
268 | 271 | def insert_child_to_layer(self, child_name, child): |
269 | 272 | """ |
270 | 273 | Adds a child layer to the current layer. |
@@ -333,7 +336,7 @@ def _add_node(self, input_tensors, output_tensors): |
333 | 336 | in_tensor_idxes = [tensor._info[1] for tensor in inputs_list] |
334 | 337 | node_index = len(_global_layer_node) |
335 | 338 |
|
336 | | - new_node = ModuleNode(self, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes) |
| 339 | + new_node = ModuleNode(self, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes, select_attrs(self)) |
337 | 340 | _global_layer_node.append(new_node) |
338 | 341 | for idx, tensor in enumerate(outputs_list): |
339 | 342 | tensor._info = (new_node, idx) |
|
0 commit comments