Skip to content

Commit 23eb1ab

Browse files
committed
Added KL divergence for continuous random variables backed by distributions
1 parent c07f8a5 commit 23eb1ab

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/probability/package.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,34 @@ package object probability {
4343
}
4444

4545

46+
/**
47+
* Calculate the entropy of a random variable
48+
* */
4649
def entropy[I, Distr <: ContinuousDistr[I]](rv: ContinuousRVWithDistr[I, Distr]): Double = {
4750
val logp_x: RandomVariable[Double] = MeasurableFunction[I, Double, ContinuousRVWithDistr[I, Distr]](
4851
rv, DataPipe((sample: I) => -1d*rv.underlyingDist.logPdf(sample)))
4952
E(logp_x)
5053
}
5154

5255

56+
/**
57+
* KL divergence:
58+
* @param p The base random variable
59+
* @param q The random variable used to approximate p
60+
* @return The Kulback Leibler divergence KL(P||Q)
61+
* */
62+
def KL[I, Distr <: ContinuousDistr[I]](
63+
p: ContinuousRVWithDistr[I, Distr])(
64+
q: ContinuousRVWithDistr[I, Distr]): Double = {
65+
66+
67+
val log_q_p: RandomVariable[Double] = MeasurableFunction[I, Double, ContinuousRVWithDistr[I, Distr]](
68+
p, DataPipe((sample: I) => p.underlyingDist.logPdf(sample)-q.underlyingDist.logPdf(sample)))
69+
70+
E(log_q_p)
71+
}
72+
73+
5374
/**
5475
* Calculate the (monte carlo approximation to) mean, median, mode, lower and upper confidence interval.
5576
*

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/utils/package.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,22 @@ package object utils {
292292

293293
}
294294

295+
//TODO: Complete tail recursive implementation of b-spline
296+
/*
297+
def bspline(control: Seq[(Double, Double)])(x: Double): Double = {
298+
val (knots, control_points) = control.unzip
299+
300+
val knot_pairs: Map[Int, (Double, Double)] = knots
301+
.sliding(2).toSeq
302+
.zipWithIndex
303+
.map(p => (p._2, (p._1.head, p._1.last)))
304+
.toMap
305+
306+
def bsplineRec(i: Int, k: Int): Double = 0d
307+
308+
0.0
309+
}*/
310+
295311
/**
296312
* Calculates the Chebyshev polynomials of the first and second kind,
297313
* in a tail recursive manner, using their recurrence relations.

0 commit comments

Comments
 (0)