|
1 | 1 | #! /usr/bin/python |
2 | 2 | # -*- coding: utf-8 -*- |
3 | 3 |
|
4 | | -from .common import str2act, str2init |
| 4 | +from .common import str2act, str2init, tolist |
5 | 5 | from .common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict |
6 | 6 | from collections import OrderedDict |
7 | 7 | import time |
|
12 | 12 | __all__ = ['Module', 'Sequential', 'ModuleList'] |
13 | 13 |
|
14 | 14 | _global_layer_name_dict = {} |
| 15 | +_global_layer_node = [] |
15 | 16 |
|
16 | 17 |
|
17 | 18 | class Module(object): |
@@ -322,23 +323,6 @@ def insert_param_to_layer(self, param_name, param, check_name=True): |
322 | 323 | except: |
323 | 324 | pass |
324 | 325 |
|
325 | | - def _add_node(self, input_tensors, output_tensors): |
326 | | - """Add a LayerNode for this layer given input_tensors, output_tensors. |
327 | | -
|
328 | | - WARINING: This function should not be called from outside, it should only be called |
329 | | - in layer.__call__ when building static model. |
330 | | -
|
331 | | - Parameters |
332 | | - ---------- |
333 | | - input_tensors : Tensor or a list of tensors |
334 | | - Input tensors to this layer. |
335 | | - output_tensors : Tensor or a list of tensors |
336 | | - Output tensors to this layer. |
337 | | -
|
338 | | - """ |
339 | | - |
340 | | - raise NotImplementedError |
341 | | - |
342 | 326 | @property |
343 | 327 | def create_time(self): |
344 | 328 | return self._create_time |
@@ -597,6 +581,96 @@ def init_build(self, *inputs, **kwargs): |
597 | 581 | def str_to_init(self, initializer): |
598 | 582 | return str2init(initializer) |
599 | 583 |
|
| 584 | + def node_build(self, *inputs, **kwargs): |
| 585 | + self.forward(*inputs, **kwargs) |
| 586 | + return _global_layer_node |
| 587 | + |
| 588 | + def _add_node(self, input_tensors, output_tensors): |
| 589 | + """Add a ModuleNode for this layer given input_tensors, output_tensors. |
| 590 | +
|
| 591 | + This function should not be called from outside, it should only be called |
| 592 | + in __call__ when building static model. |
| 593 | +
|
| 594 | + Parameters |
| 595 | + ---------- |
| 596 | + input_tensors : Tensor or a list of tensors |
| 597 | + Input tensors to this layer. |
| 598 | + output_tensors : Tensor or a list of tensors |
| 599 | + Output tensors to this layer. |
| 600 | +
|
| 601 | + """ |
| 602 | + inputs_list = tolist(input_tensors) |
| 603 | + outputs_list = tolist(output_tensors) |
| 604 | + |
| 605 | + if self.__class__.__name__ in tlx.layers.inputs.__all__: |
| 606 | + # for InputLayer, there should be no in_nodes |
| 607 | + in_nodes = [] |
| 608 | + in_tensor_idxes = [0] |
| 609 | + else: |
| 610 | + in_nodes = [tensor for tensor in inputs_list] |
| 611 | + in_tensor_idxes = [idx for idx, tensor in enumerate(inputs_list)] |
| 612 | + # in_nodes = [tensor._info[0] for tensor in inputs_list] |
| 613 | + # in_tensor_idxes = [tensor._info[1] for tensor in inputs_list] |
| 614 | + node_index = len(_global_layer_node) |
| 615 | + |
| 616 | + new_node = ModuleNode(self, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes) |
| 617 | + _global_layer_node.append(new_node) |
| 618 | + for idx, tensor in enumerate(outputs_list): |
| 619 | + tensor._info = (new_node, idx) |
| 620 | + |
| 621 | + |
| 622 | +class ModuleNode(object): |
| 623 | + """ |
| 624 | + The class :class:`ModuleNode` class represents a conceptional node for a layer. |
| 625 | +
|
| 626 | + ModuleNode is used for building static model and it is actually a light weighted |
| 627 | + wrapper over Layer. Specifically, it is used for building static computational graph |
| 628 | + (see _construct_graph() in tl.models.Model). In static model, each layer relates to |
| 629 | + one or more ModuleNode, and the connection relationship between layers is built upon |
| 630 | + ModuleNode. In addition, ModuleNode eases layer reuse and weights sharing. |
| 631 | +
|
| 632 | + Parameters |
| 633 | + ---------- |
| 634 | + layer : tl.layers.Layer |
| 635 | + A tl layer that wants to create a node. |
| 636 | + node_index : int |
| 637 | + Index of this node in layer._nodes. |
| 638 | + in_nodes :a list of ModuleNode |
| 639 | + Father nodes to this node. |
| 640 | + in_tensors : a list of tensors |
| 641 | + Input tensors to this node. |
| 642 | + out_tensors : a list of tensors |
| 643 | + Output tensors to this node. |
| 644 | + in_tensor_idxes : a list of int |
| 645 | + Indexes of each input tensor in its corresponding node's out_tensors. |
| 646 | +
|
| 647 | + Methods |
| 648 | + --------- |
| 649 | + __init__() |
| 650 | + Initializing the ModuleNode. |
| 651 | + __call__() |
| 652 | + (1) Forwarding through the layer. (2) Update its input/output tensors. |
| 653 | + """ |
| 654 | + |
| 655 | + def __init__(self, layer, node_index, in_nodes, in_tensors, out_tensors, in_tensor_idxes): |
| 656 | + self.layer = layer |
| 657 | + self.node_index = node_index |
| 658 | + self.in_nodes = in_nodes |
| 659 | + self.out_nodes = [] |
| 660 | + self.in_tensors = in_tensors |
| 661 | + self.out_tensors = out_tensors |
| 662 | + self.name = layer.name + "_node_{}".format(node_index) |
| 663 | + |
| 664 | + self.in_tensors_idxes = in_tensor_idxes |
| 665 | + self.visited = False |
| 666 | + |
| 667 | + def __call__(self, inputs, **kwargs): |
| 668 | + """(1) Forwarding through the layer. (2) Update its input/output tensors.""" |
| 669 | + outputs = self.layer.forward(inputs, **kwargs) |
| 670 | + self.in_tensors = tolist(inputs) |
| 671 | + self.out_tensors = tolist(outputs) |
| 672 | + return self.out_tensors |
| 673 | + |
600 | 674 |
|
601 | 675 | class Sequential(Module): |
602 | 676 | """ |
|
0 commit comments