@@ -2,19 +2,22 @@ package io.github.mandar2812.dynaml.models.neuralnets
22
33import io .github .mandar2812 .dynaml .graph .NeuralGraph
44
5+ import scala .collection .GenTraversableLike
6+
57/**
68 * @author mandar2812 date 17/04/2017.
79 *
810 * Base class for Neural Computation Stack
911 * implementations.
1012 * */
11- abstract class GenericNeuralStack [
12- P , I , T <: Traversable [NeuralLayer [P , I , I ]]
13- ](elements : T ) extends NeuralGraph [T , I , I ] {
13+ class GenericNeuralStack [
14+ P , I , Layer <: NeuralLayer [P , I , I ],
15+ T [Layer ] <: Traversable [Layer ] with GenTraversableLike [Layer , T [Layer ]]
16+ ](elements : T [Layer ]) extends NeuralGraph [T [Layer ], I , I ] {
1417
1518 self =>
1619
17- override protected val g : T = elements
20+ override protected val g : T [ Layer ] = elements
1821
1922 /**
2023 * Do a forward pass through the network outputting only the output layer activations.
@@ -27,12 +30,13 @@ P, I, T <: Traversable[NeuralLayer[P, I, I]]
2730 * Do a forward pass through the network outputting all the intermediate.
2831 * layer activations.
2932 * */
30- def forwardPropagate (x : I ): Traversable [I ] = g.scanLeft(x)((h, layer) => layer.forward(h))
33+ def forwardPropagate (x : I ): T [I ] = g.scanLeft(x)((h, layer) => layer.forward(h)). asInstanceOf [ T [ I ]]
3134
3235 /**
3336 * Batch version of [[forwardPropagate() ]]
3437 * */
35- def forwardPropagateBatch [G <: Traversable [I ]](d : G ): Traversable [G ] = g.scanLeft(d)((h, layer) => layer.forward(h))
38+ def forwardPropagateBatch [G <: Traversable [I ]](d : G ): T [G ] =
39+ g.scanLeft(d)((h, layer) => layer.forward(h)).asInstanceOf [T [G ]]
3640
3741 /**
3842 * Batch version of [[forwardPass() ]]
@@ -42,17 +46,27 @@ P, I, T <: Traversable[NeuralLayer[P, I, I]]
4246 /**
4347 * Slice the stack according to a range.
4448 * */
45- def apply (r : Range ): GenericNeuralStack [P , I , T ]
49+ def apply (r : Range ): GenericNeuralStack [P , I , Layer , T ] =
50+ new GenericNeuralStack (self.g.slice(r.min,r.max+ 1 ).asInstanceOf [T [Layer ]])
4651
4752 /**
4853 * Append another computation stack to the end of the
4954 * current one.
5055 * */
51- def ++ [G <: Traversable [NeuralLayer [P , I , I ]]](otherStack : GenericNeuralStack [P , I , G ]): GenericNeuralStack [P , I , T ]
56+ def ++ [
57+ L <: NeuralLayer [P , I , I ],
58+ G [L ] <: Traversable [L ] with GenTraversableLike [L , G [L ]]](
59+ otherStack : GenericNeuralStack [P , I , L , G ])
60+ : GenericNeuralStack [P , I , NeuralLayer [P , I , I ], T ] = new GenericNeuralStack [P , I , NeuralLayer [P , I , I ], T ](
61+ (self.g.map((l : Layer ) => l.asInstanceOf [NeuralLayer [P , I , I ]]) ++
62+ otherStack._layers.map((l : L ) => l.asInstanceOf [NeuralLayer [P , I , I ]]))
63+ .asInstanceOf [T [NeuralLayer [P , I , I ]]])
5264
5365 /**
5466 * Append a single computation layer to the stack.
5567 * */
56- def :+ (computationLayer : NeuralLayer [P , I , I ]): GenericNeuralStack [P , I , T ]
68+ def :+ (computationLayer : NeuralLayer [P , I , I ])
69+ : GenericNeuralStack [P , I , NeuralLayer [P , I , I ], T ] = self ++
70+ new GenericNeuralStack [P , I , NeuralLayer [P , I , I ], Seq ](Seq (computationLayer))
5771
5872}
0 commit comments