|
10 | 10 | import tensorflow as tf |
11 | 11 | from tensorlayerx.nn.layers.utils import (get_variable_with_initializer, random_normal) |
12 | 12 |
|
13 | | -__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict'] |
| 13 | +__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict', 'ParameterTuple'] |
14 | 14 |
|
15 | 15 | _global_layer_name_dict = {} |
16 | 16 | _global_layer_node = [] |
@@ -55,6 +55,7 @@ class Module(object): |
55 | 55 | def __init__(self, name=None, act=None, *args, **kwargs): |
56 | 56 | self._params = OrderedDict() |
57 | 57 | self._layers = OrderedDict() |
| 58 | + self._params_tuple = OrderedDict() |
58 | 59 | self._params_status = OrderedDict() |
59 | 60 | self._parameter_layout_dict = {} |
60 | 61 | self._create_time = int(time.time() * 1e9) |
@@ -145,6 +146,9 @@ def __setattr__(self, name, value): |
145 | 146 | raise TypeError("Expected type is Module, but got Parameter.") |
146 | 147 | self.insert_param_to_layer(name, value) |
147 | 148 |
|
| 149 | + elif isinstance(value, ParameterTuple): |
| 150 | + self.set_attr_for_parameter_tuple(name, value) |
| 151 | + |
148 | 152 | elif isinstance(value, Module): |
149 | 153 | if layers is None: |
150 | 154 | raise AttributeError("Can not assign layers before Module.__init__() call.") |
@@ -293,6 +297,27 @@ def _compute_shape(tensors): |
293 | 297 | shape_mem = tlx.get_tensor_shape(tensors) |
294 | 298 | return shape_mem |
295 | 299 |
|
| 300 | + def set_attr_for_parameter_tuple(self, name, value): |
| 301 | + """Set attr for parameter in ParameterTuple.""" |
| 302 | + params = self.__dict__.get('_params') |
| 303 | + params_tuple = self.__dict__.get('_params_tuple') |
| 304 | + if params is None: |
| 305 | + raise AttributeError("For 'Module', can not assign params before Module.__init__() is called.") |
| 306 | + exist_names = set("") |
| 307 | + |
| 308 | + for item in value: |
| 309 | + self.insert_param_to_layer(item.name, item, check_name=False) |
| 310 | + if item.name in exist_names: |
| 311 | + raise ValueError("The value {} , its name '{}' already exists.". |
| 312 | + format(value, item.name)) |
| 313 | + exist_names.add(item.name) |
| 314 | + |
| 315 | + if name in self.__dict__: |
| 316 | + del self.__dict__[name] |
| 317 | + if name in params: |
| 318 | + del params[name] |
| 319 | + params_tuple[name] = value |
| 320 | + |
296 | 321 | def insert_param_to_layer(self, param_name, param, check_name=True): |
297 | 322 | """ |
298 | 323 | Adds a parameter to the current layer. |
@@ -344,6 +369,11 @@ def __getattr__(self, name): |
344 | 369 | params_status = self.__dict__['_params_status'] |
345 | 370 | if name in params_status: |
346 | 371 | return params_status[name] |
| 372 | + if '_params_tuple' in self.__dict__: |
| 373 | + params_tuple = self.__dict__['_params_tuple'] |
| 374 | + if name in params_tuple: |
| 375 | + para_list = params_tuple[name] |
| 376 | + return para_list |
347 | 377 | raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name)) |
348 | 378 |
|
349 | 379 | def __delattr__(self, name): |
@@ -1258,6 +1288,33 @@ def __call__(self, input): |
1258 | 1288 | raise RuntimeError('ParameterDict should not be called.') |
1259 | 1289 |
|
1260 | 1290 |
|
| 1291 | +class ParameterTuple(tuple): |
| 1292 | + """ |
| 1293 | + ParameterTuple for storing tuple of parameters. |
| 1294 | + """ |
| 1295 | + def __new__(cls, iterable): |
| 1296 | + data = tuple(iterable) |
| 1297 | + ids = set() |
| 1298 | + orders = {} |
| 1299 | + for x in data: |
| 1300 | + if not isinstance(x, tf.Variable): |
| 1301 | + raise TypeError(f"ParameterTuple input should be `Parameter` collection." |
| 1302 | + f"But got a {type(iterable)}, {iterable}") |
| 1303 | + if id(x) not in ids: |
| 1304 | + ids.add(id(x)) |
| 1305 | + if x.name not in orders.keys(): |
| 1306 | + orders[x.name] = [0, x] |
| 1307 | + else: |
| 1308 | + if isinstance(orders[x.name], list): |
| 1309 | + name = x.name |
| 1310 | + orders[name][1].name = name + "_" + str(0) |
| 1311 | + x.name = x.name + "_" + str(1) |
| 1312 | + orders[name] = 1 |
| 1313 | + else: |
| 1314 | + orders[x.name] += 1 |
| 1315 | + x.name = x.name + "_" + str(orders[x.name]) |
| 1316 | + return tuple.__new__(ParameterTuple, tuple(data)) |
| 1317 | + |
1261 | 1318 | def _valid_index(layer_num, index): |
1262 | 1319 | if not isinstance(index, int): |
1263 | 1320 | raise TypeError("Index {} is not int type") |
|
0 commit comments