
Commit fc89752

Added the SELU activation function
1 parent 157d9f3 commit fc89752

File tree

1 file changed: +28 -0 lines changed


dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/neuralnets/TransferFunctions.scala

Lines changed: 28 additions & 0 deletions
@@ -80,6 +80,14 @@ object TransferFunctions {
     * */
   val DrecLin = (x: Double) => if(x < 0 ) 0.0 else 1.0
 
+
+  val selu = (l: Double, a: Double) => (x: Double) => if(x <= 0d) l*(a*math.exp(x) - a) else x
+
+
+
+  val Dselu = (l: Double, a: Double) => (x: Double) => if(x <= 0d) l*a*math.exp(x) else 1.0
+
+
   /**
     * Function which returns
     * the appropriate activation

@@ -159,6 +167,26 @@ object VectorRecLin extends Activation[DenseVector[Double]] {
   override def run(data: DenseVector[Double]) = data.map(TransferFunctions.recLin)
 }
 
+/**
+  * Implementation of the SELU activation function
+  * proposed by Hochreiter et. al
+  * */
+case class VectorSELU(lambda: Double, alpha: Double) extends Activation[DenseVector[Double]] {
+
+  val SELU = TransferFunctions.selu(lambda, alpha)
+
+  val DSELU = TransferFunctions.Dselu(lambda, alpha)
+
+  override val grad = Scaler((x: DenseVector[Double]) => x.map(DSELU))
+
+  override def run(data: DenseVector[Double]) = data.map(SELU)
+}
+
+/**
+  * 'Magic' SELU activation with specific values of lambda and alpha
+  * */
+object MagicSELU extends VectorSELU(1.050700987355, 1.67326324235)
+
 @Experimental
 class VectorWavelet(motherWavelet: (Double) => Double, motherWaveletGr: (Double) => Double)
   extends Activation[DenseVector[Double]] {
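
For reference, SELU was introduced in "Self-Normalizing Neural Networks" (Klambauer, Unterthiner, Mayr and Hochreiter, 2017) and is defined as selu(x) = lambda * x for x > 0 and lambda * alpha * (exp(x) - 1) for x <= 0, with lambda ≈ 1.0507 and alpha ≈ 1.6733, the constants hard-coded in MagicSELU above. Note that the committed selu and Dselu apply lambda only on the negative branch, returning x and 1.0 (rather than lambda * x and lambda) for positive inputs. The snippet below is a minimal, self-contained Scala sketch of the canonical definition; it does not use the DynaML types in this diff, and the names SeluSketch, seluCanonical and dSeluCanonical are illustrative only, not part of the commit.

object SeluSketch {

  // Canonical constants from Klambauer et al. (2017); MagicSELU above uses
  // the same values truncated to fewer digits.
  val lambda: Double = 1.0507009873554805
  val alpha: Double  = 1.6732632423543772

  // selu(x) = lambda * x                      for x > 0
  //         = lambda * alpha * (exp(x) - 1)   for x <= 0
  // Note the lambda factor on the positive branch.
  val seluCanonical: Double => Double =
    x => if (x <= 0d) lambda * alpha * (math.exp(x) - 1d) else lambda * x

  // Derivative: lambda for x > 0, lambda * alpha * exp(x) for x <= 0.
  val dSeluCanonical: Double => Double =
    x => if (x <= 0d) lambda * alpha * math.exp(x) else lambda

  def main(args: Array[String]): Unit = {
    // Spot-check a few inputs; the function is continuous at 0, and its
    // derivative approaches lambda * alpha from the left.
    Seq(-2.0, -0.5, 0.0, 0.5, 2.0).foreach { x =>
      println(f"x = $x%6.2f  selu(x) = ${seluCanonical(x)}%10.6f  selu'(x) = ${dSeluCanonical(x)}%10.6f")
    }
  }
}

Assuming the Activation API shown in the diff, the new class would presumably be used as, e.g., VectorSELU(1.0507, 1.6733).run(DenseVector(-1.0, 0.0, 2.0)) or simply via MagicSELU, with grad supplying the element-wise derivative during backpropagation.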

0 commit comments
