@@ -23,13 +23,13 @@ export interface ActivationFunction {
2323}
2424
2525export class TanHFunc implements ActivationFunction {
26- output ( math : NDArrayMath , x : NDArray ) {
26+ output < T extends NDArray > ( math : NDArrayMath , x : T ) {
2727 return math . scope ( ( ) => {
2828 return math . tanh ( x ) ;
2929 } ) ;
3030 }
3131
32- der ( math : NDArrayMath , x : NDArray , y : NDArray ) {
32+ der < T extends NDArray > ( math : NDArrayMath , x : T , y : T ) {
3333 return math . scope ( ( ) => {
3434 const ySquared = math . elementWiseMul ( y , y ) ;
3535 // 1 - y^2.
@@ -39,27 +39,27 @@ export class TanHFunc implements ActivationFunction {
3939}
4040
4141export class ReLUFunc implements ActivationFunction {
42- output ( math : NDArrayMath , x : NDArray ) {
42+ output < T extends NDArray > ( math : NDArrayMath , x : T ) {
4343 return math . scope ( ( ) => {
4444 return math . relu ( x ) ;
4545 } ) ;
4646 }
4747
48- der ( math : NDArrayMath , x : NDArray , y : NDArray ) {
48+ der < T extends NDArray > ( math : NDArrayMath , x : T , y : T ) {
4949 return math . scope ( ( ) => {
5050 return math . step ( x ) ;
5151 } ) ;
5252 }
5353}
5454
5555export class SigmoidFunc implements ActivationFunction {
56- output ( math : NDArrayMath , x : NDArray ) {
56+ output < T extends NDArray > ( math : NDArrayMath , x : T ) {
5757 return math . scope ( ( ) => {
5858 return math . sigmoid ( x ) ;
5959 } ) ;
6060 }
6161
62- der ( math : NDArrayMath , x : NDArray , y : NDArray ) {
62+ der < T extends NDArray > ( math : NDArrayMath , x : T , y : T ) {
6363 return math . scope ( ( ) => {
6464 // y * (1 - y) = y - y^2
6565 const ySquared = math . elementWiseMul ( y , y ) ;
@@ -69,13 +69,13 @@ export class SigmoidFunc implements ActivationFunction {
6969}
7070
7171export class SquareFunc implements ActivationFunction {
72- output ( math : NDArrayMath , x : NDArray ) {
72+ output < T extends NDArray > ( math : NDArrayMath , x : T ) {
7373 return math . scope ( ( ) => {
7474 return math . elementWiseMul ( x , x ) ;
7575 } ) ;
7676 }
7777
78- der ( math : NDArrayMath , x : NDArray , y : NDArray ) {
78+ der < T extends NDArray > ( math : NDArrayMath , x : T , y : T ) {
7979 return math . scope ( ( ) => {
8080 // dy/dx = 2*x.
8181 return math . scalarTimesArray ( Scalar . TWO , x ) ;
0 commit comments