-
-
Notifications
You must be signed in to change notification settings - Fork 198
Improved numerical stability of binomial_coefficient_log #1614
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 33 commits
8a8ea67
7991f40
77d1e85
8bef133
118adb2
3578ff6
e03205d
e44bd90
035212d
e7194bf
7922dd5
3fd0f3f
6d6ac35
3a320df
37760cd
6ebef3e
3df81a2
d0f7f45
c01d2be
74d2824
dd59b06
5bda4f5
cd3381a
d014464
9a5ee11
51ef491
215cad0
71cbd40
cfa5b81
4fa9c52
eab19c3
75a6f73
47c7962
cf05d40
5b93278
c84a259
6aed316
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,16 @@ | ||
| #ifndef STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP | ||
| #define STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP | ||
|
|
||
| #include <boost/math/constants/constants.hpp> | ||
| #include <stan/math/prim/meta.hpp> | ||
| #include <stan/math/prim/fun/inv.hpp> | ||
| #include <stan/math/prim/err.hpp> | ||
| #include <stan/math/prim/fun/constants.hpp> | ||
| #include <stan/math/prim/fun/digamma.hpp> | ||
| #include <stan/math/prim/fun/is_any_nan.hpp> | ||
| #include <stan/math/prim/fun/log1p.hpp> | ||
| #include <stan/math/prim/fun/lbeta.hpp> | ||
| #include <stan/math/prim/fun/lgamma.hpp> | ||
| #include <stan/math/prim/fun/multiply_log.hpp> | ||
| #include <stan/math/prim/fun/value_of.hpp> | ||
|
|
||
| namespace stan { | ||
| namespace math { | ||
|
|
@@ -13,22 +19,24 @@ namespace math { | |
| * Return the log of the binomial coefficient for the specified | ||
| * arguments. | ||
| * | ||
| * The binomial coefficient, \f${N \choose n}\f$, read "N choose n", is | ||
| * defined for \f$0 \leq n \leq N\f$ by | ||
| * The binomial coefficient, \f${n \choose k}\f$, read "n choose k", is | ||
| * defined for \f$0 \leq k \leq n\f$ by | ||
| * | ||
| * \f${N \choose n} = \frac{N!}{n! (N-n)!}\f$. | ||
| * \f${n \choose k} = \frac{n!}{k! (n-k)!}\f$. | ||
| * | ||
| * This function uses Gamma functions to define the log | ||
| * and generalize the arguments to continuous N and n. | ||
| * and generalize the arguments to continuous n and k. | ||
| * | ||
| * \f$ \log {n \choose k} | ||
| * = \log \ \Gamma(n+1) - \log \Gamma(k+1) - \log \Gamma(n-k+1)\f$. | ||
| * | ||
| * \f$ \log {N \choose n} | ||
| * = \log \ \Gamma(N+1) - \log \Gamma(n+1) - \log \Gamma(N-n+1)\f$. | ||
| * | ||
| \f[ | ||
| \mbox{binomial\_coefficient\_log}(x, y) = | ||
| \begin{cases} | ||
| \textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\ | ||
| \ln\Gamma(x+1) & \mbox{if } 0\leq y \leq x \\ | ||
| \textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x | ||
| < -1\\ | ||
| \ln\Gamma(x+1) & \mbox{if } -1 < y < x + 1 \\ | ||
| \quad -\ln\Gamma(y+1)& \\ | ||
| \quad -\ln\Gamma(x-y+1)& \\[6pt] | ||
| \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN} | ||
|
|
@@ -38,7 +46,8 @@ namespace math { | |
| \f[ | ||
| \frac{\partial\, \mbox{binomial\_coefficient\_log}(x, y)}{\partial x} = | ||
| \begin{cases} | ||
| \textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\ | ||
| \textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x | ||
| < -1\\ | ||
| \Psi(x+1) & \mbox{if } -1 < y < x + 1 \\ | ||
| \quad -\Psi(x-y+1)& \\[6pt] | ||
| \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN} | ||
|
|
@@ -48,32 +57,91 @@ namespace math { | |
| \f[ | ||
| \frac{\partial\, \mbox{binomial\_coefficient\_log}(x, y)}{\partial y} = | ||
| \begin{cases} | ||
| \textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\ | ||
| \textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x | ||
| < -1\\ | ||
| -\Psi(y+1) & \mbox{if } -1 < y < x + 1 \\ | ||
| \quad +\Psi(x-y+1)& \\[6pt] | ||
| \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN} | ||
| \end{cases} | ||
| \f] | ||
| * | ||
| * @tparam T_N type of the first argument | ||
| * @tparam T_n type of the second argument | ||
| * @param N total number of objects. | ||
| * @param n number of objects chosen. | ||
| * @return log (N choose n). | ||
| * This function is numerically more stable than naive evaluation via lgamma. | ||
| * | ||
| * @tparam T_n type of the first argument | ||
| * @tparam T_k type of the second argument | ||
| * | ||
| * @param n total number of objects. | ||
| * @param k number of objects chosen. | ||
| * @return log (n choose k). | ||
| */ | ||
| template <typename T_N, typename T_n> | ||
| inline return_type_t<T_N, T_n> binomial_coefficient_log(const T_N N, | ||
| const T_n n) { | ||
| const double CUTOFF = 1000; | ||
| if (N - n < CUTOFF) { | ||
| const T_N N_plus_1 = N + 1; | ||
| return lgamma(N_plus_1) - lgamma(n + 1) - lgamma(N_plus_1 - n); | ||
|
|
||
| template <typename T_n, typename T_k> | ||
| inline return_type_t<T_n, T_k> binomial_coefficient_log(const T_n n, | ||
| const T_k k) { | ||
| if (is_any_nan(n, k)) { | ||
| return stan::math::NOT_A_NUMBER; | ||
martinmodrak marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| // Symmetry: log C(n, k) == log C(n, n - k). For n > 0, when k exceeds n / 2 (plus a 1e-8 slack) recurse on n - k so the remainder of the function works with the smaller second argument, which is the more stable of the two symmetric branches. | ||
| if (n > 0 && k > value_of_rec(n) / 2.0 + 1e-8) { | ||
| return binomial_coefficient_log(n, n - k); | ||
| } | ||
|
|
||
| using T_partials_return = partials_return_t<T_n, T_k>; | ||
martinmodrak marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| const T_partials_return n_ = value_of(n); | ||
| const T_partials_return k_ = value_of(k); | ||
martinmodrak marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| const T_partials_return n_plus_1 = n_ + 1; | ||
| const T_partials_return n_plus_1_mk = n_plus_1 - k_; | ||
|
|
||
| static const char* function = "binomial_coefficient_log"; | ||
| check_greater_or_equal(function, "first argument", n, -1); | ||
| check_greater_or_equal(function, "second argument", k, -1); | ||
| check_greater_or_equal(function, "(first argument - second argument + 1)", | ||
| n_plus_1_mk, 0.0); | ||
|
|
||
| operands_and_partials<T_n, T_k> ops_partials(n, k); | ||
|
|
||
| T_partials_return value; | ||
| if (k_ == 0) { | ||
| value = 0; | ||
| } else if (n_plus_1 < lgamma_stirling_diff_useful) { | ||
| value = lgamma(n_plus_1) - lgamma(k_ + 1) - lgamma(n_plus_1_mk); | ||
| } else { | ||
| return_type_t<T_N, T_n> N_minus_n = N - n; | ||
| const double one_twelfth = inv(12); | ||
| return multiply_log(n, N_minus_n) + multiply_log((N + 0.5), N / N_minus_n) | ||
| + one_twelfth / N - n - one_twelfth / N_minus_n - lgamma(n + 1); | ||
| value = -lbeta(n_plus_1_mk, k_ + 1) - log1p(n_); | ||
| } | ||
|
|
||
| if (!is_constant_all<T_n, T_k>::value) { | ||
| // The partials are special-cased at the domain boundaries (n == -1, k == 0, k == -1). | ||
| // Evaluating the digamma differences directly there would in many cases give NaN, | ||
| // but one-sided limits from within the domain exist and are used instead. | ||
| T_partials_return digamma_n_plus_1_mk = digamma(n_plus_1_mk); | ||
|
|
||
| if (!is_constant_all<T_n>::value) { | ||
| if (n_ == -1.0) { | ||
| if (k_ == 0) { | ||
| ops_partials.edge1_.partials_[0] = 0; | ||
| } else { | ||
| ops_partials.edge1_.partials_[0] = stan::math::NEGATIVE_INFTY; | ||
|
||
| } | ||
| } else { | ||
| ops_partials.edge1_.partials_[0] | ||
| = (digamma(n_plus_1) - digamma_n_plus_1_mk); | ||
| } | ||
| } | ||
| if (!is_constant_all<T_k>::value) { | ||
| if (k_ == 0 && n_ == -1.0) { | ||
|
||
| ops_partials.edge2_.partials_[0] = stan::math::NEGATIVE_INFTY; | ||
| } else if (k_ == -1) { | ||
| ops_partials.edge2_.partials_[0] = stan::math::INFTY; | ||
| } else { | ||
| ops_partials.edge2_.partials_[0] | ||
| = (digamma_n_plus_1_mk - digamma(k_ + 1)); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return ops_partials.build(value); | ||
| } | ||
|
|
||
| } // namespace math | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.