
Commit 0deabc5

Added abstract class to create mixture models
1 parent bc6c815 commit 0deabc5

5 files changed: +163 additions, -106 deletions

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/sgp/ESGPModel.scala

Lines changed: 4 additions & 3 deletions

```diff
@@ -39,12 +39,13 @@ import org.apache.log4j.Logger
 import scala.reflect.ClassTag
 
 /**
-  * @author mandar2812 date: 28/02/2017.
-  *
   * Implementation of Extended Skew-Gaussian Process regression model.
   * This is represented with a finite dimensional [[BlockedMESNRV]]
   * distribution of Adcock and Schutes.
-  */
+  *
+  * @author mandar2812 date 28/02/2017.
+  *
+  * */
 abstract class ESGPModel[T, I: ClassTag](
   cov: LocalScalarKernel[I], n: LocalScalarKernel[I],
   data: T, num: Int, lambda: Double, tau: Double,
```

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/stp/AbstractSTPRegressionModel.scala

Lines changed: 3 additions & 2 deletions

```diff
@@ -36,9 +36,10 @@ import org.apache.log4j.Logger
 import scala.reflect.ClassTag
 
 /**
-  * @author mandar2812 date 26/08/16.
   * Implementation of a Students' T Regression model.
-  */
+  * @author mandar2812 date 26/08/16.
+  *
+  * */
 abstract class AbstractSTPRegressionModel[T, I](
   mu: Double, cov: LocalScalarKernel[I],
   n: LocalScalarKernel[I],
```

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/optimization/GloballyOptimizable.scala

Lines changed: 7 additions & 7 deletions

```diff
@@ -19,13 +19,13 @@ under the License.
 package io.github.mandar2812.dynaml.optimization
 
 /**
-  * @author mandar2812, datum: 23/6/15.
-  *
-  * We define a common binding
-  * characteristic between all "globally optimizable"
-  * models i.e. models where hyper-parameters can
-  * be optimized/tuned.
-  */
+  * A common binding characteristic between all "globally optimizable"
+  * models i.e. models where hyper-parameters can
+  * be optimized/tuned.
+  *
+  * @author mandar2812, date 23/6/15.
+  *
+  * */
 trait GloballyOptimizable {
 
   /**
```
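For context, `GloballyOptimizable` is the contract that the new `MixtureMachine` (added below) tunes against. Here is a self-contained sketch of the idea, not DynaML's actual trait definition: the member names `_hyper_parameters` and `_current_state` are the ones this commit uses elsewhere, while the `energy` method shown is an assumption about how a tuner scores a candidate configuration.

```scala
// A self-contained sketch of the "globally optimizable" idea (a hypothetical
// toy trait, NOT DynaML's GloballyOptimizable itself). A tuner such as
// coupled simulated annealing only needs to read the hyper-parameters and
// evaluate an energy (loss) for a candidate configuration.
trait TunableSketch {
  def _hyper_parameters: List[String]
  def _current_state: Map[String, Double]
  // Assumed scoring method: lower energy = better configuration.
  def energy(h: Map[String, Double], options: Map[String, String]): Double
}

// Toy implementer: a one-parameter "model" whose energy is a parabola
// minimised at theta = 2.0.
class ToyModel extends TunableSketch {
  override def _hyper_parameters = List("theta")
  override def _current_state = Map("theta" -> 1.0)
  override def energy(h: Map[String, Double], options: Map[String, String]) =
    math.pow(h("theta") - 2.0, 2)
}
```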
Lines changed: 131 additions & 0 deletions (new file)

```scala
package io.github.mandar2812.dynaml.optimization

import breeze.linalg.{DenseMatrix, DenseVector}
import breeze.stats.distributions.{ContinuousDistr, Moments}
import io.github.mandar2812.dynaml.models.{ContinuousProcessModel, GenContinuousMixtureModel, SecondOrderProcessModel}
import io.github.mandar2812.dynaml.pipes.{DataPipe, DataPipe2}
import io.github.mandar2812.dynaml.probability.ContinuousRVWithDistr
import io.github.mandar2812.dynaml.probability.distributions.HasErrorBars

import scala.reflect.ClassTag

/**
  * Created by mandar on 21/06/2017.
  */
abstract class MixtureMachine[
  T, I: ClassTag, Y, YDomain, YDomainVar,
  BaseDistr <: ContinuousDistr[YDomain]
    with Moments[YDomain, YDomainVar]
    with HasErrorBars[YDomain],
  W1 <: ContinuousRVWithDistr[YDomain, BaseDistr],
  BaseProcess <: ContinuousProcessModel[T, I, Y, W1]
    with SecondOrderProcessModel[T, I, Y, Double, DenseMatrix[Double], W1]
    with GloballyOptimizable](model: BaseProcess) extends
  AbstractCSA[BaseProcess, GenContinuousMixtureModel[
    T, I, Y, YDomain, YDomainVar,
    BaseDistr, W1, BaseProcess]](model) {

  val confToModel: DataPipe[Map[String, Double], BaseProcess]

  val mixturePipe: DataPipe2[
    Seq[BaseProcess],
    DenseVector[Double],
    GenContinuousMixtureModel[
      T, I, Y, YDomain, YDomainVar,
      BaseDistr, W1, BaseProcess]]

  protected var policy: String = "CSA"

  protected var baselinePolicy: String = "max"

  def _policy = policy

  def setPolicy(p: String): this.type = {
    if(p == "CSA" || p == "Coupled Simulated Annealing")
      policy = "CSA"
    else
      policy = "GS"

    this
  }

  def setBaseLinePolicy(p: String): this.type = {

    if(p == "avg" || p == "mean" || p == "average")
      baselinePolicy = "mean"
    else if(p == "min")
      baselinePolicy = "min"
    else if(p == "max")
      baselinePolicy = "max"
    else
      baselinePolicy = "mean"

    this
  }

  protected def calculateEnergyLandscape(initialConfig: Map[String, Double], options: Map[String, String]) =
    if(policy == "CSA") performCSA(initialConfig, options)
    else getEnergyLandscape(initialConfig, options, meanFieldPrior)

  protected def modelProbabilities = DataPipe(ProbGPCommMachine.calculateModelWeightsSigmoid(baselinePolicy) _)

  override def optimize(
    initialConfig: Map[String, Double],
    options: Map[String, String]) = {

    //Find out the blocked hyper parameters and their values
    val blockedHypParams = system._hyper_parameters.filterNot(initialConfig.contains)

    val blockedState = system._current_state.filterKeys(blockedHypParams.contains)

    val energyLandscape = calculateEnergyLandscape(initialConfig, options)

    //Calculate the weights of each configuration
    val (weights, models) = modelProbabilities(energyLandscape).map(c => {

      val model_state = c._2 ++ blockedState

      val model = confToModel(model_state)

      //Persist the model inference primitives to memory.
      model.persist(model_state)

      (c._1, model)
    }).unzip

    val configsAndWeights = modelProbabilities(energyLandscape).map(c => (c._1, c._2 ++ blockedState))

    logger.info("===============================================")
    logger.info("Constructing Gaussian Process Mixture")

    logger.info("Number of model instances = "+weights.length)
    logger.info("--------------------------------------")
    logger.info(
      "Calculated model probabilities/weights are \n"+
        configsAndWeights.map(wc =>
          "\nConfiguration: \n"+
            GlobalOptimizer.prettyPrint(wc._2)+
            "\nProbability = "+wc._1+"\n"
        ).reduceLeft((a, b) => a++b)
    )
    logger.info("--------------------------------------")

    (
      mixturePipe(
        models, DenseVector(weights.toArray)
      ),
      models.map(m => {
        val model_id = m.toString.split("\\.").last
        m._current_state.map(c => (model_id+"/"+c._1,c._2))
      }).reduceLeft((m1, m2) => m1++m2)
    )
  }

}
```
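The core of `optimize` above is turning the energy landscape produced by CSA (or grid search) into model weights via `ProbGPCommMachine.calculateModelWeightsSigmoid`, whose implementation is not part of this commit; the weighted models are then blended through the abstract `mixturePipe`. A hedged sketch of the kind of energy-to-weight mapping involved, assuming a sigmoid of each energy relative to a baseline (here the maximum, cf. the default `baselinePolicy = "max"`); the actual DynaML implementation may differ:

```scala
// Hypothetical energy-to-weight mapping (NOT the actual
// ProbGPCommMachine.calculateModelWeightsSigmoid): lower energy
// (a better fit) should receive a larger normalised weight.
def weightsSketch(energies: Seq[Double]): Seq[Double] = {
  val baseline = energies.max // cf. baselinePolicy = "max"
  val raw = energies.map(e => 1.0 / (1.0 + math.exp(e - baseline)))
  val z = raw.sum
  raw.map(_ / z) // normalise so the weights sum to one
}
```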
Lines changed: 18 additions & 94 deletions

```diff
@@ -1,9 +1,12 @@
 package io.github.mandar2812.dynaml.optimization
 
 import breeze.linalg.DenseVector
+import io.github.mandar2812.dynaml.algebra.{PartitionedPSDMatrix, PartitionedVector}
 import io.github.mandar2812.dynaml.models.StochasticProcessMixtureModel
 import io.github.mandar2812.dynaml.models.gp.{AbstractGPRegressionModel, GaussianProcessMixture}
-import io.github.mandar2812.dynaml.pipes.DataPipe
+import io.github.mandar2812.dynaml.pipes.{DataPipe, DataPipe2}
+import io.github.mandar2812.dynaml.probability.MultGaussianPRV
+import io.github.mandar2812.dynaml.probability.distributions.BlockedMultiVariateGaussian
 
 import scala.reflect.ClassTag
 
@@ -16,104 +19,25 @@ import scala.reflect.ClassTag
   * */
 class ProbGPMixtureMachine[T, I: ClassTag](
   model: AbstractGPRegressionModel[T, I]) extends
-  AbstractCSA[AbstractGPRegressionModel[T, I], GaussianProcessMixture[T, I]](model) {
+  MixtureMachine[T, I, Double, PartitionedVector, PartitionedPSDMatrix, BlockedMultiVariateGaussian,
+    MultGaussianPRV, AbstractGPRegressionModel[T, I]](model) {
 
-  private var policy: String = "CSA"
 
-  private var baselinePolicy: String = "max"
+  val (kernelPipe, noisePipe) = (system.covariance.asPipe, system.noiseModel.asPipe)
 
-  def _policy = policy
+  def blockedHypParams = system.covariance.blocked_hyper_parameters ++ system.noiseModel.blocked_hyper_parameters
 
-  def setPolicy(p: String): this.type = {
-    if(p == "CSA" || p == "Coupled Simulated Annealing")
-      policy = "CSA"
-    else
-      policy = "GS"
+  def blockedState = system._current_state.filterKeys(blockedHypParams.contains)
 
-    this
-  }
+  implicit val transform: DataPipe[T, Seq[(I, Double)]] = DataPipe(system.dataAsSeq _)
 
-  def setBaseLinePolicy(p: String): this.type = {
+  override val confToModel = DataPipe((model_state: Map[String, Double]) =>
+    AbstractGPRegressionModel(
+      kernelPipe(model_state), noisePipe(model_state),
+      system.mean)(system.data, system.npoints))
 
-    if(p == "avg" || p == "mean" || p == "average")
-      baselinePolicy = "mean"
-    else if(p == "min")
-      baselinePolicy = "min"
-    else if(p == "max")
-      baselinePolicy = "max"
-    else
-      baselinePolicy = "mean"
-
-    this
-  }
-
-  private def calculateEnergyLandscape(initialConfig: Map[String, Double], options: Map[String, String]) =
-    if(policy == "CSA") performCSA(initialConfig, options)
-    else getEnergyLandscape(initialConfig, options, meanFieldPrior)
-
-  private def modelProbabilities = DataPipe(ProbGPCommMachine.calculateModelWeightsSigmoid(baselinePolicy) _)
-
-  override def optimize(
-    initialConfig: Map[String, Double],
-    options: Map[String, String]) = {
-
-    //Find out the blocked hyper parameters and their values
-    val blockedHypParams = system.covariance.blocked_hyper_parameters ++ system.noiseModel.blocked_hyper_parameters
-
-    val (kernelPipe, noisePipe) = (system.covariance.asPipe, system.noiseModel.asPipe)
-
-    val blockedState = system._current_state.filterKeys(blockedHypParams.contains)
-
-    val energyLandscape = calculateEnergyLandscape(initialConfig, options)
-
-    val data = system.data
-
-    //Calculate the weights of each configuration
-    val (weights, models) = modelProbabilities(energyLandscape).map(c => {
-
-      val model_state = c._2 ++ blockedState
-
-      implicit val transform = DataPipe(system.dataAsSeq _)
-
-      val model = AbstractGPRegressionModel(
-        kernelPipe(model_state), noisePipe(model_state),
-        system.mean)(
-        data, system.npoints)
-
-      //Persist the model inference primitives to memory.
-      model.persist(model_state)
-
-      (c._1, model)
-    }).unzip
-
-
-    val configsAndWeights = modelProbabilities(energyLandscape).map(c => (c._1, c._2 ++ blockedState))
-
-    logger.info("===============================================")
-    logger.info("Constructing Gaussian Process Mixture")
-
-    logger.info("Number of model instances = "+weights.length)
-    logger.info("--------------------------------------")
-    logger.info(
-      "Calculated model probabilities/weights are \n"+
-        configsAndWeights.map(wc =>
-          "\nConfiguration: \n"+
-            GlobalOptimizer.prettyPrint(wc._2)+
-            "\nProbability = "+wc._1+"\n"
-        ).reduceLeft((a, b) => a++b)
-    )
-    logger.info("--------------------------------------")
-
-
-
-    (
-      StochasticProcessMixtureModel[T, I](
-        models, DenseVector(weights.toArray)
-      ),
-      models.map(m => {
-        val model_id = m.toString.split("\\.").last
-        m._current_state.map(c => (model_id+"/"+c._1,c._2))
-      }).reduceLeft((m1, m2) => m1++m2)
-    )
-  }
+  override val mixturePipe = DataPipe2(
+    (models: Seq[AbstractGPRegressionModel[T, I]], weights: DenseVector[Double]) =>
+      StochasticProcessMixtureModel[T, I](models, weights))
+
 }
```
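Putting the two files together, a hedged usage sketch built only from the API shown in this commit: given an existing `AbstractGPRegressionModel`, the machine tunes its hyper-parameters and returns the Gaussian process mixture together with the per-component states. The names `gp`, `initialConfig`, and `options` are placeholders for illustration.

```scala
import io.github.mandar2812.dynaml.models.gp.AbstractGPRegressionModel
import io.github.mandar2812.dynaml.optimization.ProbGPMixtureMachine

import scala.reflect.ClassTag

// Hypothetical driver: tune a GP's hyper-parameters and build a mixture.
def buildGPMixture[T, I: ClassTag](
  gp: AbstractGPRegressionModel[T, I],
  initialConfig: Map[String, Double],
  options: Map[String, String]) = {

  val machine = new ProbGPMixtureMachine(gp)
    .setPolicy("CSA")          // use coupled simulated annealing
    .setBaseLinePolicy("mean") // baseline for the sigmoid model weights

  // Returns (mixture model, per-component hyper-parameter states,
  // keyed by a model-identifier prefix).
  machine.optimize(initialConfig, options)
}
```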
