Skip to content

Commit 1c18b56

Browse files
committed
Added MixturePipe Class
1 parent f3774d9 commit 1c18b56

File tree

4 files changed

+69
-12
lines changed

4 files changed

+69
-12
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package io.github.mandar2812.dynaml.modelpipe
2+
3+
import breeze.linalg.{DenseMatrix, DenseVector}
4+
import breeze.stats.distributions.{ContinuousDistr, Moments}
5+
import io.github.mandar2812.dynaml.algebra.{PartitionedPSDMatrix, PartitionedVector}
6+
import io.github.mandar2812.dynaml.models.gp.AbstractGPRegressionModel
7+
import io.github.mandar2812.dynaml.models.stp.AbstractSTPRegressionModel
8+
import io.github.mandar2812.dynaml.models.{ContinuousProcessModel, GenContinuousMixtureModel, SecondOrderProcessModel, StochasticProcessMixtureModel}
9+
import io.github.mandar2812.dynaml.optimization.GloballyOptimizable
10+
import io.github.mandar2812.dynaml.pipes.DataPipe2
11+
import io.github.mandar2812.dynaml.probability.{ContinuousRVWithDistr, MultGaussianPRV, MultStudentsTPRV}
12+
import io.github.mandar2812.dynaml.probability.distributions.{BlockedMultiVariateGaussian, BlockedMultivariateStudentsT, HasErrorBars}
13+
14+
import scala.reflect.ClassTag
15+
16+
/**
17+
* Mixture Pipe takes a sequence of stochastic process models
18+
* and associated probability weights and returns a mixture model.
19+
* @author mandar2812 date 22/06/2017.
20+
* */
21+
abstract class MixturePipe[
22+
T, I: ClassTag, Y, YDomain, YDomainVar,
23+
BaseDistr <: ContinuousDistr[YDomain]
24+
with Moments[YDomain, YDomainVar]
25+
with HasErrorBars[YDomain],
26+
W1 <: ContinuousRVWithDistr[YDomain, BaseDistr],
27+
BaseProcess <: ContinuousProcessModel[T, I, Y, W1]
28+
with SecondOrderProcessModel[T, I, Y, Double, DenseMatrix[Double], W1]
29+
with GloballyOptimizable] extends
30+
DataPipe2[Seq[BaseProcess], DenseVector[Double],
31+
GenContinuousMixtureModel[
32+
T, I, Y, YDomain, YDomainVar,
33+
BaseDistr, W1, BaseProcess]]
34+
35+
36+
class GPMixturePipe[T, I: ClassTag] extends
37+
MixturePipe[T, I, Double, PartitionedVector, PartitionedPSDMatrix,
38+
BlockedMultiVariateGaussian, MultGaussianPRV,
39+
AbstractGPRegressionModel[T, I]] {
40+
41+
override def run(
42+
models: Seq[AbstractGPRegressionModel[T, I]],
43+
weights: DenseVector[Double]) =
44+
StochasticProcessMixtureModel(models, weights)
45+
}
46+
47+
class StudentTMixturePipe[T, I: ClassTag] extends
48+
MixturePipe[T, I, Double, PartitionedVector, PartitionedPSDMatrix,
49+
BlockedMultivariateStudentsT, MultStudentsTPRV,
50+
AbstractSTPRegressionModel[T, I]] {
51+
52+
override def run(
53+
models: Seq[AbstractSTPRegressionModel[T, I]],
54+
weights: DenseVector[Double]) =
55+
StochasticProcessMixtureModel(models, weights)
56+
}

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/optimization/MixtureMachine.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ import io.github.mandar2812.dynaml.probability.distributions.HasErrorBars
1010
import scala.reflect.ClassTag
1111

1212
/**
13-
* Created by mandar on 21/06/2017.
14-
*/
13+
* @author mandar2812 date 21/06/2017.
14+
* */
1515
abstract class MixtureMachine[
1616
T, I: ClassTag, Y, YDomain, YDomainVar,
1717
BaseDistr <: ContinuousDistr[YDomain]
@@ -71,7 +71,8 @@ BaseProcess <: ContinuousProcessModel[T, I, Y, W1]
7171
if(policy == "CSA") performCSA(initialConfig, options)
7272
else getEnergyLandscape(initialConfig, options, meanFieldPrior)
7373

74-
protected def modelProbabilities: DataPipe[Seq[(Double, Map[String, Double])], Seq[(Double, Map[String, Double])]] =
74+
protected def modelProbabilities
75+
: DataPipe[Seq[(Double, Map[String, Double])], Seq[(Double, Map[String, Double])]] =
7576
DataPipe(ProbGPCommMachine.calculateModelWeightsSigmoid(baselinePolicy))
7677

7778
override def optimize(

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/optimization/ProbGPMixtureMachine.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ package io.github.mandar2812.dynaml.optimization
22

33
import breeze.linalg.DenseVector
44
import io.github.mandar2812.dynaml.algebra.{PartitionedPSDMatrix, PartitionedVector}
5+
import io.github.mandar2812.dynaml.modelpipe._
56
import io.github.mandar2812.dynaml.models.StochasticProcessMixtureModel
6-
import io.github.mandar2812.dynaml.models.gp.{AbstractGPRegressionModel, GaussianProcessMixture}
7+
import io.github.mandar2812.dynaml.models.gp.AbstractGPRegressionModel
78
import io.github.mandar2812.dynaml.pipes.{DataPipe, DataPipe2}
89
import io.github.mandar2812.dynaml.probability.MultGaussianPRV
910
import io.github.mandar2812.dynaml.probability.distributions.BlockedMultiVariateGaussian
@@ -30,13 +31,12 @@ class ProbGPMixtureMachine[T, I: ClassTag](
3031

3132
implicit val transform: DataPipe[T, Seq[(I, Double)]] = DataPipe(system.dataAsSeq)
3233

33-
override val confToModel = DataPipe((model_state: Map[String, Double]) =>
34-
AbstractGPRegressionModel(
35-
kernelPipe(model_state), noisePipe(model_state),
36-
system.mean)(system.data, system.npoints))
34+
override val confToModel = DataPipe(
35+
(model_state: Map[String, Double]) =>
36+
AbstractGPRegressionModel(
37+
kernelPipe(model_state), noisePipe(model_state),
38+
system.mean)(system.data, system.npoints))
3739

38-
override val mixturePipe = DataPipe2(
39-
(models: Seq[AbstractGPRegressionModel[T, I]], weights: DenseVector[Double]) =>
40-
StochasticProcessMixtureModel[T, I](models, weights))
40+
override val mixturePipe = new GPMixturePipe[T, I]
4141

4242
}

scripts/stochasticPriors.sc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import breeze.linalg.eig
2-
import breeze.stats.distributions.{ContinuousDistr, Gamma, Gaussian}
2+
import breeze.stats.distributions.{ContinuousDistr, Gamma}
33
import io.github.mandar2812.dynaml.kernels._
44
import io.github.mandar2812.dynaml.models.bayes.{LinearTrendESGPrior, LinearTrendGaussianPrior}
55
import io.github.mandar2812.dynaml.probability._

0 commit comments

Comments
 (0)