@@ -271,15 +271,15 @@ def relu6(x):
 
 class LeakyReLU(Cell):
 
-    def __init__(self, alpha=0.2):
+    def __init__(self, negative_slope=0.01):
         super(LeakyReLU, self).__init__()
-        self.leakyrelu = ms.nn.LeakyReLU(alpha=alpha)
+        self.leakyrelu = ms.nn.LeakyReLU(alpha=negative_slope)
 
     def construct(self, x):
         return self.leakyrelu(x)
 
 
-def leaky_relu(x, alpha=0.2):
+def leaky_relu(x, negative_slope=0.2):
     """
     Compute the Leaky ReLU activation function.
 
@@ -294,9 +294,9 @@ def leaky_relu(x, alpha=0.2):
         The activation value.
     """
 
-    leaky_relu = LeakyReLU(alpha=alpha)
+    leaky_relu = ms.nn.LeakyReLU(alpha=negative_slope)
     output = leaky_relu(x)
-    return leaky_relu
+    return output
 
 
 class Softplus(Cell):
@@ -348,15 +348,15 @@ def sigmoid(x):
 
 class Softmax(Cell):
 
-    def __init__(self):
+    def __init__(self, axis=-1):
         super(Softmax, self).__init__()
-        self.softmax = P.Softmax()
+        self.softmax = P.Softmax(axis)
 
     def construct(self, x):
         return self.softmax(x)
 
 
-def softmax(logits, axis=None):
+def softmax(logits, axis=-1):
     """
     Computes softmax activations.
 
@@ -2392,3 +2392,22 @@ def __init__(
 
     def construct(self, inputs):
         raise NotImplementedError
+
+class PReLU(Cell):
+
+    def __init__(self, data_format):
+        super(PReLU, self).__init__()
+        self.data_format = data_format
+
+    def __call__(self, input, weight):
+
+        prelu = P.PReLU()
+        v = prelu(input, F.cast(weight, input.dtype))
+        return v
+
+
+def prelu(input, weight, data_format):
+
+    prelu = P.PReLU()
+    v = prelu(input, F.cast(weight, input.dtype))
+    return v
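
Note (not part of the commit): a minimal usage sketch of the changed and added helpers, written as if appended to the same backend module so that leaky_relu, softmax, and prelu are already in scope. It assumes MindSpore and NumPy are installed; the tensor shapes and the data_format value are illustrative only.

import numpy as np
import mindspore as ms

x = ms.Tensor(np.array([[-1.0, 0.0, 2.0]], dtype=np.float32))

# leaky_relu now takes negative_slope and returns the computed tensor,
# not the layer object.
y1 = leaky_relu(x, negative_slope=0.2)

# softmax now defaults to the last axis instead of None.
y2 = softmax(x, axis=-1)

# prelu applies a learned slope to the negative part of the input
# (here a single shared slope); data_format is accepted for parity
# with the other backends.
w = ms.Tensor(np.array([0.25], dtype=np.float32))
y3 = prelu(x, w, data_format='channels_last')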