Skip to content

Commit 367b8cd

Browse files
committed
Fix to energy calculation of CSA
- In case of a specified mean field hyper-prior, add its energy to the system energy
1 parent 0152efd commit 367b8cd

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/optimization/CoupledSimulatedAnnealing.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package io.github.mandar2812.dynaml.optimization
2020

2121
import breeze.linalg.DenseVector
2222
import breeze.stats.distributions.CauchyDistribution
23+
import io.github.mandar2812.dynaml.probability.RandomVariable
2324
import io.github.mandar2812.dynaml.utils
2425

2526
import scala.util.Random
@@ -112,6 +113,10 @@ class CoupledSimulatedAnnealing[M <: GloballyOptimizable](model: M)
112113
computeAcceptanceProb(c._1, c._1, gamma_init, accTemp)
113114
})
114115

116+
val hyp = initialConfig.keys
117+
118+
val usePriorFlag: Boolean = hyp.forall(meanFieldPrior.contains)
119+
115120
def CSATRec(eLandscape: List[(Double, Map[String, Double])], it: Int): List[(Double, Map[String, Double])] =
116121
it match {
117122
case 0 => eLandscape
@@ -142,7 +147,15 @@ class CoupledSimulatedAnnealing[M <: GloballyOptimizable](model: M)
142147
val (newEnergyLandscape,probabilities) = eLandscape.map((config) => {
143148
//mutate this config
144149
val new_config = mutate(config._2, mutTemp)
145-
val new_energy = system.energy(new_config, options)
150+
151+
val priorEnergy =
152+
if(usePriorFlag)
153+
new_config.foldLeft(0.0)(
154+
(p_acc, keyValue) => p_acc - meanFieldPrior(keyValue._1).underlyingDist.logPdf(keyValue._2)
155+
)
156+
else 0.0
157+
158+
val new_energy = system.energy(new_config, options) + priorEnergy
146159

147160
logger.info("New Configuration: \n"+GlobalOptimizer.prettyPrint(new_config))
148161
logger.info("Energy = "+new_energy)

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/optimization/GlobalOptimizer.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ import io.github.mandar2812.dynaml.utils
2424
import org.apache.log4j.Logger
2525

2626
/**
27-
* @author mandar2812 datum 24/6/15.
28-
*
29-
* High level interface defining the
30-
* core functions of a global optimizer
31-
*/
27+
* High level interface defining the
28+
* core functions of a global optimizer
29+
* @author mandar2812 datum 24/6/15.
30+
*
31+
* */
3232
trait GlobalOptimizer[T <: GloballyOptimizable] {
3333

3434
val system: T

0 commit comments

Comments
 (0)