 #! /usr/bin/python
 # -*- coding: utf-8 -*-
 
-from .common import str2act, str2init, random_normal
+from .common import str2act, str2init, random_normal, tolist, construct_graph, ModuleNode
 from .common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict
 from mindspore.nn import Cell
 import tensorlayerx as tlx
-from collections import OrderedDict
+import mindspore as ms
 from mindspore import log as logger
 import inspect
 from mindspore import context
 __all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict']
 
 _global_layer_name_dict = {}
+_global_layer_node = []
 
 
 class Module(Cell):
@@ -138,41 +139,40 @@ def _compute_shape(tensors):
         shape_mem = tlx.get_tensor_shape(tensors)
         return shape_mem
 
-    def __call__(self, *inputs, **kwargs):
+    def __call__(self, *args, **kwargs):
         if self.__class__.construct is Cell.construct:
-            logger.warning(
-                f"The '{self.__class__}' does not override the method 'construct', "
-                f"will call the super class(Cell) 'construct'."
-            )
+            logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
+                           f"will call the super class(Cell) 'construct'.")
         if kwargs:
-            bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
-            inputs = bound_args.args
-            kwargs = bound_args.kwargs
+            bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
+            bound_arguments.apply_defaults()
+            args = bound_arguments.args
+            kwargs = bound_arguments.kwargs
 
         # Run in Graph mode.
-        if context.get_context("mode") == context.GRAPH_MODE:
-            raise NotImplemented("GRAPH MODE is not supported, please select PYNATIVE MODE.")
+        if context._get_mode() == context.GRAPH_MODE:
+            self._check_construct_args(*args, **kwargs)
+            if self.enable_hook:
+                raise ValueError("For 'Cell', it's not support hook function in graph mode, please use "
+                                 "context.set_context to set pynative mode.")
+            out = self.compile_and_run(*args)
+            return out
 
         # Run in PyNative mode.
         if _pynative_executor.is_top_cell():
             _pynative_executor.set_lazy_build(True)
             # There are many Casts in parameter_broadcast. Enable lazy_build to build faster.
             self._do_parameter_broadcast()
 
-        for item in inputs:
-            if isinstance(item, numpy.ndarray):
-                raise TypeError("The cell inputs should not be numpy arrays.")
+        for item in args:
+            if isinstance(item, ms.Tensor) and item.has_init:
+                item.init_data()
+            elif isinstance(item, numpy.ndarray):
+                raise TypeError("For 'Cell', inputs should not be numpy array.")
         if self.requires_grad is True:
             _pynative_executor.set_grad_flag(True)
-        _pynative_executor.new_graph(self, *inputs, **kwargs)
-        cast_inputs = list()
-        if hasattr(self, "_mindspore_flags"):
-            if self._mindspore_flags.get('fp16'):
-                cast_inputs = self._cast_mixed_precision_inputs(inputs, tlx.float16)
-            if self._mindspore_flags.get('fp32'):
-                cast_inputs = self._cast_mixed_precision_inputs(inputs, tlx.float32)
-        if not cast_inputs:
-            cast_inputs = inputs
+        _pynative_executor.new_graph(self, *args, **kwargs)
+        cast_inputs = self.auto_cast_inputs(args)
 
         with self.CellGuard():
             try:
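
The kwargs branch above normalizes a call through inspect.signature(...).bind(...) followed by BoundArguments.apply_defaults(), so defaulted parameters become explicit positional arguments before the executor sees them. A minimal standard-library sketch (the construct function below is an illustrative stand-in, not the method in this file):

import inspect

def construct(x, scale=2):
    return x * scale

bound = inspect.signature(construct).bind(3)
bound.apply_defaults()           # fills in scale=2, making the binding complete
print(bound.args, bound.kwargs)  # -> (3, 2) {}
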
@@ -182,29 +182,13 @@ def __call__(self, *inputs, **kwargs):
                 raise err
 
         if _pynative_executor.is_top_cell():
-            _pynative_executor.execute_all_task()
+            _pynative_executor.execute_lazy_task()
 
         if isinstance(output, Parameter):
             output = output.data
-        _pynative_executor.end_graph(self, output, *inputs, **kwargs)
+        _pynative_executor.end_graph(self, output, *args, **kwargs)
         return output
 
-    def _add_node(self, input_tensors, output_tensors):
-        """Add a LayerNode for this layer given input_tensors, output_tensors.
-
-        WARNING: This function should not be called from outside; it should only be called
-        in layer.__call__ when building a static model.
-
-        Parameters
-        ----------
-        input_tensors : Tensor or a list of tensors
-            Input tensors to this layer.
-        output_tensors : Tensor or a list of tensors
-            Output tensors to this layer.
-
-        """
-        raise NotImplementedError
-
     def set_train(self):
         """
         Sets the cell to training mode.
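
Which path __call__ takes is governed by the global execution mode; as the ValueError above notes, hook functions are only available in PyNative mode. Selecting the mode uses the standard MindSpore context API (a caller-side sketch, not part of this commit):

from mindspore import context

# PyNative (eager) execution; pass context.GRAPH_MODE for the compiled path.
context.set_context(mode=context.PYNATIVE_MODE)
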
@@ -310,6 +294,49 @@ def init_build(self, *inputs, **kwargs):
 
         self.forward(*inputs, **kwargs)
 
+    def build_graph(self, *inputs, **kwargs):
+        # Add nodes only when the composition is needed.
+        layers = self.cells_and_names(name_prefix='')
+        for layer_name, layer in layers:
+            if isinstance(layer, Module):
+                layer._build_graph = True
+
+        outputs = self.forward(*inputs, **kwargs)
+        self.inputs = inputs
+        self.outputs = outputs
+        self._node_by_depth, self._all_layers = construct_graph(self.inputs, self.outputs)
+        return self._node_by_depth, self._all_layers
+
+    def _add_node(self, input_tensors, output_tensors):
+        """Add a ModuleNode for this layer given input_tensors, output_tensors.
+
+        This function should not be called from outside; it should only be called
+        in __call__ when building a static model.
+
+        Parameters
+        ----------
+        input_tensors : Tensor or a list of tensors
+            Input tensors to this layer.
+        output_tensors : Tensor or a list of tensors
+            Output tensors to this layer.
+
+        """
+        inputs_list = tolist(input_tensors)
+        outputs_list = tolist(output_tensors)
+        if self.__class__.__name__ in tlx.layers.inputs.__all__:
+            # For an input layer there are no in_nodes.
+            in_nodes = []
+            in_tensor_idxes = [0]
+        else:
+            in_nodes = [tensor._info[0] for tensor in inputs_list]
+            in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
+        node_index = len(_global_layer_node)
+
+        new_node = ModuleNode(self, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes)
+        _global_layer_node.append(new_node)
+        for idx, tensor in enumerate(outputs_list):
+            tensor._info = (new_node, idx)
+
 
 class Sequential(Module):
     """
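
build_graph above runs a single recording forward pass (every sub-Module gets _build_graph set), after which construct_graph sorts the ModuleNodes recorded by _add_node, following each tensor's _info tag back to its producer. A hedged usage sketch, assuming TensorLayerX-style Input and Linear layers; the MLP model and its layer arguments are illustrative, not taken from this commit:

import tensorlayerx as tlx
from tensorlayerx.nn import Module, Linear

class MLP(Module):

    def __init__(self):
        super(MLP, self).__init__()
        self.fc = Linear(out_features=4, in_features=8)  # hypothetical sizes

    def forward(self, x):
        return self.fc(x)

net = MLP()
# One recording forward pass; returns nodes grouped by depth plus all layers.
node_by_depth, all_layers = net.build_graph(tlx.nn.Input((1, 8)))
print(len(node_by_depth), len(all_layers))
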