Skip to content

Commit e414c73

Browse files
committed
Gaussian Process Mixtures
- Added `ProbGPMixtureMachine` to create GP mixtures from a single GP model
1 parent ad3374b commit e414c73

File tree

6 files changed

+76
-17
lines changed

6 files changed

+76
-17
lines changed

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/gp/AbstractGPRegressionModel.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,6 @@ abstract class AbstractGPRegressionModel[T, I: ClassTag](
7474

7575
override protected val g: T = data
7676

77-
def _data = g
78-
7977
val npoints = num
8078

8179
protected var blockSize = 1000

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/optimization/AbstractGridSearch.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ under the License.
1818
* */
1919
package io.github.mandar2812.dynaml.optimization
2020

21-
import breeze.linalg.DenseVector
2221
import org.apache.log4j.Logger
23-
import io.github.mandar2812.dynaml.utils
2422

2523
/**
2624
* @author mandar2812 datum 01/12/15.

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/optimization/CoupledSimulatedAnnealing.scala

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,6 @@ under the License.
1818
* */
1919
package io.github.mandar2812.dynaml.optimization
2020

21-
import breeze.linalg.DenseVector
22-
import breeze.stats.distributions.CauchyDistribution
23-
import io.github.mandar2812.dynaml.probability.RandomVariable
24-
import io.github.mandar2812.dynaml.utils
25-
26-
import scala.util.Random
27-
2821
/**
2922
* Implementation of the Coupled Simulated Annealing algorithm
3023
* for global optimization.

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/optimization/GridSearch.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ under the License.
1818
* */
1919
package io.github.mandar2812.dynaml.optimization
2020

21-
import org.apache.log4j.Logger
2221

2322
/**
2423
* @author mandar2812 datum 24/6/15.

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/optimization/ProbGPCommMachine.scala

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,10 @@ object ProbGPCommMachine {
222222

223223
}
224224

225-
/*
225+
226226
class ProbGPMixtureMachine[T, I: ClassTag](
227227
model: AbstractGPRegressionModel[T, I]) extends
228-
ModelTuner[AbstractGPRegressionModel[T, I], GaussianProcessMixture[I]] {
228+
AbstractCSA[AbstractGPRegressionModel[T, I], GaussianProcessMixture[I]](model) {
229229

230230
private var policy: String = "CSA"
231231

@@ -256,14 +256,67 @@ class ProbGPMixtureMachine[T, I: ClassTag](
256256
this
257257
}
258258

259-
override val system = model
259+
private def calculateEnergyLandscape(initialConfig: Map[String, Double], options: Map[String, String]) =
260+
if(policy == "CSA") performCSA(initialConfig, options)
261+
else getEnergyLandscape(initialConfig, options, meanFieldPrior)
262+
263+
private def modelProbabilities = DataPipe(ProbGPCommMachine.calculateModelWeightsSigmoid(baselinePolicy) _)
260264

261265
override def optimize(
262266
initialConfig: Map[String, Double],
263267
options: Map[String, String]) = {
264268

269+
//Find out the blocked hyper parameters and their values
270+
val blockedHypParams = system.covariance.blocked_hyper_parameters ++ system.noiseModel.blocked_hyper_parameters
271+
272+
val (kernelPipe, noisePipe) = (system.covariance.asPipe, system.noiseModel.asPipe)
273+
274+
val (kernelParams, noiseParams) = (
275+
system.covariance.hyper_parameters,
276+
system.noiseModel.hyper_parameters)
277+
278+
val blockedState = system._current_state.filterKeys(blockedHypParams.contains)
279+
280+
val energyLandscape = calculateEnergyLandscape(initialConfig, options)
281+
282+
val data = system.data
283+
284+
//Calculate the weights of each configuration
285+
val (weights, models) = modelProbabilities(energyLandscape).map(c => {
286+
287+
val model_state = c._2 ++ blockedState
288+
289+
implicit val transform = DataPipe(system.dataAsSeq _)
290+
291+
val model = AbstractGPRegressionModel(
292+
kernelPipe(model_state), noisePipe(model_state),
293+
system.mean)(
294+
data, system.npoints)
295+
296+
297+
(c._1, model)
298+
}).unzip
299+
300+
301+
val configsAndWeights = modelProbabilities(energyLandscape).map(c => (c._1, c._2 ++ blockedState))
302+
303+
logger.info("===============================================")
304+
logger.info("Constructing Gaussian Process Mixture")
305+
306+
logger.info("Number of model instances = "+weights.length)
307+
logger.info("--------------------------------------")
308+
logger.info(
309+
"Calculated model probabilities/weights are \n"+
310+
configsAndWeights.map(wc =>
311+
"\nConfiguration: \n"+
312+
GlobalOptimizer.prettyPrint(wc._2)+
313+
"\nProbability = "+wc._1+"\n"
314+
).reduceLeft((a, b) => a++b)
315+
)
316+
logger.info("--------------------------------------")
317+
265318

266319

320+
(new GaussianProcessMixture[I](models, DenseVector(weights.toArray)), Map())
267321
}
268322
}
269-
*/

scripts/stochasticPriors.sc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import io.github.mandar2812.dynaml.models.bayes.{LinearTrendESGPrior, LinearTren
55
import io.github.mandar2812.dynaml.probability._
66
import com.quantifind.charts.Highcharts._
77
import io.github.mandar2812.dynaml.analysis.implicits._
8+
import io.github.mandar2812.dynaml.optimization.ProbGPMixtureMachine
89
import io.github.mandar2812.dynaml.pipes.Encoder
910

1011
val rbfc = new RBFCovFunc(1.5)
@@ -55,8 +56,16 @@ sgp_prior.globalOptConfig_(Map("gridStep" -> "0.15", "gridSize" -> "40"))
5556
val gpModel = gp_prior.posteriorModel(dataset)
5657
val sgpModel = sgp_prior.posteriorModel(dataset)
5758

59+
60+
gp_prior.globalOptConfig_(Map("gridStep" -> "0.0", "gridSize" -> "1", "globalOpt" -> "GS", "policy" -> "GS"))
61+
val gpModel1 = gp_prior.posteriorModel(dataset)
62+
val mixt_machine = new ProbGPMixtureMachine(gpModel1)
63+
val (mix_model, _) = mixt_machine.optimize(gp_prior.covariance.state ++ gp_prior.noiseCovariance.state)
64+
65+
5866
val zs: MultGaussianPRV = gpModel.predictiveDistribution(xs)
5967
val sgp_zs: BlockedMESNRV = sgpModel.predictiveDistribution(xs)
68+
val mix_zs = mix_model.predictiveDistribution(xs)
6069

6170
val MultGaussianPRV(m, c) = zs
6271
val eigD = eig(c.toBreezeMatrix)
@@ -75,14 +84,23 @@ if(eValuesPositive) {
7584
println("Predictive Covariance Ill-Posed!")
7685
}
7786

78-
val samplesSGPPost = sgp_zs.iid(8).sample().map(s => s.toBreezeVector.toArray.toSeq)
87+
val samplesSGPPost = sgp_zs.iid(8).sample().map(_.toBreezeVector.toArray.toSeq)
7988

8089
spline(xs, samplesSGPPost.head)
8190
hold()
8291
samplesSGPPost.tail.foreach((s: Seq[Double]) => spline(xs, s))
8392
unhold()
8493
title("Ext. Skew Gaussian Process posterior samples")
8594

95+
val samplesMixPost = mix_zs.iid(8).sample().map(_.toBreezeVector.toArray.toSeq)
96+
97+
spline(xs, samplesMixPost.head)
98+
hold()
99+
samplesMixPost.tail.foreach((s: Seq[Double]) => spline(xs, s))
100+
unhold()
101+
title("Gaussian Process Mixture posterior samples")
102+
103+
86104
val (dx, dy) = dataset.sorted.unzip
87105

88106
spline(dx, dy)

0 commit comments

Comments
 (0)