99
1010from mindspore .common .parameter import Parameter
1111from mindspore .common .initializer import (
12- initializer , Constant , Normal , TruncatedNormal , Initializer , _assignment , _calculate_in_and_out , One , Zero ,
13- _calculate_fan_in_and_fan_out
12+ initializer , Constant , Normal , TruncatedNormal , Initializer , _assignment , _calculate_gain , One , Zero ,
13+ _calculate_fan_in_and_fan_out , _calculate_correct_fan
1414)
1515from mindspore .common .tensor import Tensor
1616from mindspore .ops import operations as P
@@ -321,31 +321,7 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=mstype.float32, seed=Non
321321 return Tensor (arr , dtype = dtype )
322322
323323
324- class HeNormal (Initializer ):
325- r"""
326- he_normal: It draws samples from a truncated normal distribution centered on 0 with
327- stddev = sqrt(2 / fan_in) where fan_in is the number of input units in the weight tensor.
328-
329- Args:
330- arr (Array): The array to be assigned.
331-
332- Returns:
333- Array, assigned array.
334- """
335-
336- def __init__ (self , seed = None ):
337- super (HeNormal , self ).__init__ (seed = seed )
338- self .seed = seed
339-
340- def _initialize (self , arr ):
341- n_in , _ = _calculate_in_and_out (arr )
342- boundary = np .sqrt (2.0 / n_in )
343- random .seed (self .seed )
344- data = np .random .normal (- boundary , boundary , arr .shape )
345- _assignment (arr , data )
346-
347-
348- def he_normal (shape , dtype , seed = None ):
324+ def he_normal (shape , a = 0 , mode = 'fan_in' , nonlinearity = 'leaky_relu' , dtype = mstype .float32 , seed = None ):
349325 """
350326 He normal initializer.
351327
@@ -362,54 +338,43 @@ def he_normal(shape, dtype, seed=None):
362338 -------
363339 A tensor of the specified shape filled with he normal values.
364340 """
365- # shape = shape[::-1]
366341 arr = np .ndarray (shape )
367- init_obj = HeNormal (seed )
368- init_obj (arr )
342+ fan = _calculate_correct_fan (shape , mode )
343+ gain = _calculate_gain (nonlinearity , a )
344+ std = gain / math .sqrt (fan )
345+ data = np .random .normal (0 , std , shape )
346+ _assignment (arr , data )
369347 return Tensor (arr , dtype = dtype )
370348
371-
372- class XavierUniform (Initializer ):
373-
374- def __init__ (self , seed = None ):
375- super (XavierUniform , self ).__init__ (seed = seed )
376- self .seed = seed
377-
378- def _initialize (self , arr ):
379- n_in , n_out = _calculate_fan_in_and_fan_out (arr .shape )
380- boundary = math .sqrt (6.0 / (n_in + n_out ))
381- random .seed (self .seed )
382- data = np .random .uniform (- boundary , boundary , arr .shape )
383- _assignment (arr , data )
384-
385-
386- def xavier_uniform (shape , dtype , seed = None ):
349+ def he_uniform (shape , a = 0 , mode = 'fan_in' , nonlinearity = 'leaky_relu' ,dtype = mstype .float32 , seed = None ):
387350
388351 arr = np .ndarray (shape )
389- init_obj = XavierUniform (seed )
390- init_obj (arr )
352+ fan = _calculate_correct_fan (shape , mode )
353+ gain = _calculate_gain (nonlinearity , a )
354+ std = gain / math .sqrt (fan )
355+ boundary = math .sqrt (3.0 ) * std
356+ data = np .random .uniform (- boundary , boundary , shape )
357+ _assignment (arr , data )
391358 return Tensor (arr , dtype = dtype )
392359
393360
394- class XavierNormal ( Initializer ):
361+ def xavier_uniform ( shape , gain = 1.0 , dtype = mstype . float32 , seed = None ):
395362
396- def __init__ (self , seed = None ):
397- super (XavierNormal , self ).__init__ (seed = seed )
398- self .seed = seed
399-
400- def _initialize (self , arr ):
401- n_in , n_out = _calculate_fan_in_and_fan_out (arr .shape )
402- boundary = math .sqrt (2.0 / (n_in + n_out ))
403- random .seed (self .seed )
404- data = np .random .normal (0 , boundary , arr .shape )
405- _assignment (arr , data )
363+ arr = np .ndarray (shape )
364+ fan_in , fan_out = _calculate_fan_in_and_fan_out (shape )
365+ bound = gain * math .sqrt (6.0 / (fan_in + fan_out ))
366+ data = np .random .uniform (- bound , bound , shape )
367+ _assignment (arr , data )
368+ return Tensor (arr , dtype = dtype )
406369
407370
408- def xavier_normal (shape , dtype , seed = None ):
371+ def xavier_normal (shape , gain = 1.0 , dtype = mstype . float32 , seed = None ):
409372
410373 arr = np .ndarray (shape )
411- init_obj = XavierNormal (seed )
412- init_obj (arr )
374+ fan_in , fan_out = _calculate_fan_in_and_fan_out (shape )
375+ std = gain * math .sqrt (2.0 / float (fan_in + fan_out ))
376+ data = np .random .normal (0 , std , shape )
377+ _assignment (arr , data )
413378 return Tensor (arr , dtype = dtype )
414379
415380
0 commit comments