Skip to content

Commit a961c69

Browse files
committed
Improvements to GenericNeuralStack class
1 parent c731896 commit a961c69

File tree

3 files changed

+46
-20
lines changed

3 files changed

+46
-20
lines changed

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/neuralnets/GenericNeuralStack.scala

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,22 @@ package io.github.mandar2812.dynaml.models.neuralnets
22

33
import io.github.mandar2812.dynaml.graph.NeuralGraph
44

5+
import scala.collection.GenTraversableLike
6+
57
/**
68
* @author mandar2812 date 17/04/2017.
79
*
810
* Base class for Neural Computation Stack
911
* implementations.
1012
* */
11-
abstract class GenericNeuralStack[
12-
P, I, T <: Traversable[NeuralLayer[P, I, I]]
13-
](elements: T) extends NeuralGraph[T, I, I] {
13+
class GenericNeuralStack[
14+
P, I, Layer <: NeuralLayer[P, I, I],
15+
T[Layer] <: Traversable[Layer] with GenTraversableLike[Layer, T[Layer]]
16+
](elements: T[Layer]) extends NeuralGraph[T[Layer], I, I] {
1417

1518
self =>
1619

17-
override protected val g: T = elements
20+
override protected val g: T[Layer] = elements
1821

1922
/**
2023
* Do a forward pass through the network outputting only the output layer activations.
@@ -27,12 +30,13 @@ P, I, T <: Traversable[NeuralLayer[P, I, I]]
2730
* Do a forward pass through the network outputting all the intermediate.
2831
* layer activations.
2932
* */
30-
def forwardPropagate(x: I): Traversable[I] = g.scanLeft(x)((h, layer) => layer.forward(h))
33+
def forwardPropagate(x: I): T[I] = g.scanLeft(x)((h, layer) => layer.forward(h)).asInstanceOf[T[I]]
3134

3235
/**
3336
* Batch version of [[forwardPropagate()]]
3437
* */
35-
def forwardPropagateBatch[G <: Traversable[I]](d: G): Traversable[G] = g.scanLeft(d)((h, layer) => layer.forward(h))
38+
def forwardPropagateBatch[G <: Traversable[I]](d: G): T[G] =
39+
g.scanLeft(d)((h, layer) => layer.forward(h)).asInstanceOf[T[G]]
3640

3741
/**
3842
* Batch version of [[forwardPass()]]
@@ -42,17 +46,27 @@ P, I, T <: Traversable[NeuralLayer[P, I, I]]
4246
/**
4347
* Slice the stack according to a range.
4448
* */
45-
def apply(r: Range): GenericNeuralStack[P, I, T]
49+
def apply(r: Range): GenericNeuralStack[P, I, Layer, T] =
50+
new GenericNeuralStack(self.g.slice(r.min,r.max+1).asInstanceOf[T[Layer]])
4651

4752
/**
4853
* Append another computation stack to the end of the
4954
* current one.
5055
* */
51-
def ++[G <: Traversable[NeuralLayer[P, I, I]]](otherStack: GenericNeuralStack[P, I, G]): GenericNeuralStack[P, I, T]
56+
def ++[
57+
L <: NeuralLayer[P, I, I],
58+
G[L] <: Traversable[L] with GenTraversableLike[L, G[L]]](
59+
otherStack: GenericNeuralStack[P, I, L, G])
60+
: GenericNeuralStack[P, I, NeuralLayer[P, I, I], T] = new GenericNeuralStack[P, I, NeuralLayer[P, I, I], T](
61+
(self.g.map((l: Layer) => l.asInstanceOf[NeuralLayer[P, I, I]]) ++
62+
otherStack._layers.map((l: L) => l.asInstanceOf[NeuralLayer[P, I, I]]))
63+
.asInstanceOf[T[NeuralLayer[P, I, I]]])
5264

5365
/**
5466
* Append a single computation layer to the stack.
5567
* */
56-
def :+(computationLayer: NeuralLayer[P, I, I]): GenericNeuralStack[P, I, T]
68+
def :+(computationLayer: NeuralLayer[P, I, I])
69+
: GenericNeuralStack[P, I, NeuralLayer[P, I, I], T] = self ++
70+
new GenericNeuralStack[P, I, NeuralLayer[P, I, I], Seq](Seq(computationLayer))
5771

5872
}

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/neuralnets/LazyNeuralStack.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,39 @@
11
package io.github.mandar2812.dynaml.models.neuralnets
22

3+
import scala.collection.GenTraversableLike
4+
35

46
/**
57
* @author mandar2812 date: 17/04/2017.
68
*
79
* A computation stack whose layers are lazily computed
810
* */
911
class LazyNeuralStack[P, I](elements: Stream[NeuralLayer[P, I, I]]) extends
10-
GenericNeuralStack[P, I, Stream[NeuralLayer[P, I, I]]](elements) {
12+
GenericNeuralStack[P, I, NeuralLayer[P, I, I], Stream](elements) {
1113

1214
self =>
1315

1416
/**
1517
* Slice the stack according to a range.
1618
**/
17-
override def apply(r: Range) = new LazyNeuralStack[P, I](g.slice(r.min, r.max + 1))
19+
override def apply(r: Range): LazyNeuralStack[P, I] = new LazyNeuralStack[P, I](g.slice(r.min, r.max + 1))
20+
1821

1922
/**
2023
* Append another computation stack to the end of the
2124
* current one.
2225
**/
23-
override def ++[G <: Traversable[NeuralLayer[P, I, I]]](
24-
otherStack: GenericNeuralStack[P, I, G]) = new LazyNeuralStack[P, I](self.g ++ otherStack._layers)
26+
override def ++[
27+
L <: NeuralLayer[P, I, I],
28+
G[L] <: Traversable[L] with GenTraversableLike[L, G[L]]](
29+
otherStack: GenericNeuralStack[P, I, L, G]) =
30+
new LazyNeuralStack[P, I](self.g ++ otherStack._layers.asInstanceOf[Stream[NeuralLayer[P, I, I]]])
2531

2632
/**
2733
* Append a single computation layer to the stack.
2834
**/
29-
override def :+(computationLayer: NeuralLayer[P, I, I]) = new LazyNeuralStack[P, I](self.g :+ computationLayer)
35+
override def :+(computationLayer: NeuralLayer[P, I, I]): LazyNeuralStack[P, I] =
36+
new LazyNeuralStack[P, I](self.g :+ computationLayer)
3037
}
3138

3239
object LazyNeuralStack {

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/neuralnets/NeuralStack.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
package io.github.mandar2812.dynaml.models.neuralnets
22

3+
import scala.collection.GenTraversableLike
34
import breeze.linalg.{DenseMatrix, DenseVector}
4-
import io.github.mandar2812.dynaml.graph.NeuralGraph
55
import io.github.mandar2812.dynaml.pipes.DataPipe
66

7+
78
/**
89
* A network, represented as a stack of [[NeuralLayer]] objects.
910
* */
1011
class NeuralStack[P, I](elements: Seq[NeuralLayer[P, I, I]])
11-
extends GenericNeuralStack[P, I, Seq[NeuralLayer[P, I, I]]](elements) {
12+
extends GenericNeuralStack[P, I, NeuralLayer[P, I, I], Seq](elements) {
1213

1314
self =>
1415

@@ -42,14 +43,18 @@ class NeuralStack[P, I](elements: Seq[NeuralLayer[P, I, I]])
4243
/**
4344
* Slice the stack according to a range.
4445
* */
45-
def apply(r: Range): NeuralStack[P, I] = NeuralStack(g.slice(r.min, r.max + 1):_*)
46+
override def apply(r: Range): NeuralStack[P, I] = NeuralStack(g.slice(r.min, r.max + 1):_*)
47+
4648

4749
/**
4850
* Append another computation stack to the end of the
4951
* current one.
50-
* */
51-
override def ++[T <: Traversable[NeuralLayer[P, I, I]]](otherStack: GenericNeuralStack[P, I, T]) =
52-
new NeuralStack(self.g ++ otherStack._layers)
52+
**/
53+
override def ++[
54+
L <: NeuralLayer[P, I, I],
55+
G[L] <: Traversable[L] with GenTraversableLike[L, G[L]]](
56+
otherStack: GenericNeuralStack[P, I, L, G]) =
57+
new NeuralStack(self.g ++ otherStack._layers.asInstanceOf[Seq[NeuralLayer[P, I, I]]])
5358

5459
/**
5560
* Append a single computation layer to the stack.

0 commit comments

Comments
 (0)