Skip to content

Commit bc6c815

Browse files
committed
Mixture model tuner now returns entire mixture configuration
1 parent 9dea653 commit bc6c815

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/optimization/ProbGPMixtureMachine.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.github.mandar2812.dynaml.optimization
22

33
import breeze.linalg.DenseVector
4+
import io.github.mandar2812.dynaml.models.StochasticProcessMixtureModel
45
import io.github.mandar2812.dynaml.models.gp.{AbstractGPRegressionModel, GaussianProcessMixture}
56
import io.github.mandar2812.dynaml.pipes.DataPipe
67

@@ -105,6 +106,14 @@ class ProbGPMixtureMachine[T, I: ClassTag](
105106

106107

107108

108-
(new GaussianProcessMixture[T, I](models, DenseVector(weights.toArray)), Map())
109+
(
110+
StochasticProcessMixtureModel[T, I](
111+
models, DenseVector(weights.toArray)
112+
),
113+
models.map(m => {
114+
val model_id = m.toString.split("\\.").last
115+
m._current_state.map(c => (model_id+"/"+c._1,c._2))
116+
}).reduceLeft((m1, m2) => m1++m2)
117+
)
109118
}
110119
}

scripts/stochasticPriors.sc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ val mixt_machine = new ProbGPMixtureMachine(gpModel1)
6969
.setLogScale(true)
7070
.setMaxIterations(16)
7171

72-
val (mix_model, _) = mixt_machine.optimize(gp_prior.covariance.state ++ gp_prior.noiseCovariance.state)
72+
val (mix_model, mixt_model_conf) = mixt_machine.optimize(gp_prior.covariance.state ++ gp_prior.noiseCovariance.state)
7373

7474

7575
val zs: MultGaussianPRV = gpModel.predictiveDistribution(xs)

0 commit comments

Comments
 (0)