Skip to content

Commit d8392cb

Browse files
authored
Merge pull request #1830 from martinmodrak/bugfix/1495-neg_binomial_2_log_stability
More stable implementation of neg_binomial_2_log_lpmf
2 parents 11742e2 + 766de89 commit d8392cb

File tree

3 files changed: +556 -443 lines changed

stan/math/prim/prob/neg_binomial_2_log_lpmf.hpp

Lines changed: 32 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -3,17 +3,16 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/binomial_coefficient_log.hpp>
67
#include <stan/math/prim/fun/digamma.hpp>
7-
#include <stan/math/prim/fun/exp.hpp>
8-
#include <stan/math/prim/fun/lgamma.hpp>
8+
#include <stan/math/prim/fun/inv.hpp>
99
#include <stan/math/prim/fun/log.hpp>
10+
#include <stan/math/prim/fun/log1p_exp.hpp>
1011
#include <stan/math/prim/fun/log_sum_exp.hpp>
1112
#include <stan/math/prim/fun/max_size.hpp>
12-
#include <stan/math/prim/fun/multiply_log.hpp>
1313
#include <stan/math/prim/fun/size.hpp>
1414
#include <stan/math/prim/fun/size_zero.hpp>
1515
#include <stan/math/prim/fun/value_of.hpp>
16-
#include <stan/math/prim/prob/poisson_log_lpmf.hpp>
1716
#include <cmath>
1817

1918
namespace stan {
@@ -52,13 +51,20 @@ return_type_t<T_log_location, T_precision> neg_binomial_2_log_lpmf(
5251
size_t size_phi = stan::math::size(phi);
5352
size_t size_eta_phi = max_size(eta, phi);
5453
size_t size_n_phi = max_size(n, phi);
55-
size_t max_size_seq_view = max_size(n, eta, phi);
54+
size_t size_all = max_size(n, eta, phi);
5655

5756
VectorBuilder<true, T_partials_return, T_log_location> eta_val(size_eta);
5857
for (size_t i = 0; i < size_eta; ++i) {
5958
eta_val[i] = value_of(eta_vec[i]);
6059
}
6160

61+
VectorBuilder<true, T_partials_return, T_precision> phi_val(size_phi);
62+
VectorBuilder<true, T_partials_return, T_precision> log_phi(size_phi);
63+
for (size_t i = 0; i < size_phi; ++i) {
64+
phi_val[i] = value_of(phi_vec[i]);
65+
log_phi[i] = log(phi_val[i]);
66+
}
67+
6268
VectorBuilder<!is_constant_all<T_log_location, T_precision>::value,
6369
T_partials_return, T_log_location>
6470
exp_eta(size_eta);
@@ -68,17 +74,19 @@ return_type_t<T_log_location, T_precision> neg_binomial_2_log_lpmf(
6874
}
6975
}
7076

71-
VectorBuilder<true, T_partials_return, T_precision> phi_val(size_phi);
72-
VectorBuilder<true, T_partials_return, T_precision> log_phi(size_phi);
73-
for (size_t i = 0; i < size_phi; ++i) {
74-
phi_val[i] = value_of(phi_vec[i]);
75-
log_phi[i] = log(phi_val[i]);
77+
VectorBuilder<!is_constant_all<T_log_location, T_precision>::value,
78+
T_partials_return, T_log_location, T_precision>
79+
exp_eta_over_exp_eta_phi(size_eta_phi);
80+
if (!is_constant_all<T_log_location, T_precision>::value) {
81+
for (size_t i = 0; i < size_eta_phi; ++i) {
82+
exp_eta_over_exp_eta_phi[i] = inv(phi_val[i] / exp_eta[i] + 1);
83+
}
7684
}
7785

7886
VectorBuilder<true, T_partials_return, T_log_location, T_precision>
79-
logsumexp_eta_logphi(size_eta_phi);
87+
log1p_exp_eta_m_logphi(size_eta_phi);
8088
for (size_t i = 0; i < size_eta_phi; ++i) {
81-
logsumexp_eta_logphi[i] = log_sum_exp(eta_val[i], log_phi[i]);
89+
log1p_exp_eta_m_logphi[i] = log1p_exp(eta_val[i] - log_phi[i]);
8290
}
8391

8492
VectorBuilder<true, T_partials_return, T_n, T_precision> n_plus_phi(
@@ -87,38 +95,25 @@ return_type_t<T_log_location, T_precision> neg_binomial_2_log_lpmf(
8795
n_plus_phi[i] = n_vec[i] + phi_val[i];
8896
}
8997

90-
for (size_t i = 0; i < max_size_seq_view; i++) {
91-
if (phi_val[i] > 1e5) {
92-
// TODO(martinmodrak) This is wrong (doesn't pass propto information),
93-
// and inaccurate for n = 0, but shouldn't break most models.
94-
// Also the 1e5 cutoff is way too low.
95-
// Will be addressed better once PR #1497 is merged
96-
logp += poisson_log_lpmf(n_vec[i], eta_val[i]);
97-
} else {
98-
if (include_summand<propto>::value) {
99-
logp -= lgamma(n_vec[i] + 1.0);
100-
}
101-
if (include_summand<propto, T_precision>::value) {
102-
logp += multiply_log(phi_val[i], phi_val[i]) - lgamma(phi_val[i]);
103-
}
104-
if (include_summand<propto, T_log_location>::value) {
105-
logp += n_vec[i] * eta_val[i];
106-
}
107-
if (include_summand<propto, T_precision>::value) {
108-
logp += lgamma(n_plus_phi[i]);
109-
}
110-
logp -= (n_plus_phi[i]) * logsumexp_eta_logphi[i];
98+
for (size_t i = 0; i < size_all; i++) {
99+
if (include_summand<propto, T_precision>::value) {
100+
logp += binomial_coefficient_log(n_plus_phi[i] - 1, n_vec[i]);
101+
}
102+
if (include_summand<propto, T_log_location>::value) {
103+
logp += n_vec[i] * eta_val[i];
111104
}
105+
logp += -phi_val[i] * log1p_exp_eta_m_logphi[i]
106+
- n_vec[i] * (log_phi[i] + log1p_exp_eta_m_logphi[i]);
112107

113108
if (!is_constant_all<T_log_location>::value) {
114109
ops_partials.edge1_.partials_[i]
115-
+= n_vec[i] - n_plus_phi[i] / (phi_val[i] / exp_eta[i] + 1);
110+
+= n_vec[i] - n_plus_phi[i] * exp_eta_over_exp_eta_phi[i];
116111
}
117112
if (!is_constant_all<T_precision>::value) {
118113
ops_partials.edge2_.partials_[i]
119-
+= 1.0 - n_plus_phi[i] / (exp_eta[i] + phi_val[i]) + log_phi[i]
120-
- logsumexp_eta_logphi[i] - digamma(phi_val[i])
121-
+ digamma(n_plus_phi[i]);
114+
+= exp_eta_over_exp_eta_phi[i] - n_vec[i] / (exp_eta[i] + phi_val[i])
115+
- log1p_exp_eta_m_logphi[i]
116+
- (digamma(phi_val[i]) - digamma(n_plus_phi[i]));
122117
}
123118
}
124119
return ops_partials.build(logp);

test/unit/math/prim/prob/neg_binomial_2_log_test.cpp

Lines changed: 5 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
#include <stan/math/prim.hpp>
22
#include <test/unit/math/prim/prob/vector_rng_test_helper.hpp>
33
#include <test/unit/math/prim/prob/NegativeBinomial2LogTestRig.hpp>
4+
#include <test/unit/math/expect_near_rel.hpp>
45
#include <gtest/gtest.h>
56
#include <boost/random/mersenne_twister.hpp>
67
#include <boost/math/distributions.hpp>
@@ -212,29 +213,21 @@ TEST(ProbNegBinomial2, log_matches_lpmf) {
212213
TEST(ProbDistributionsNegBinomial2Log, neg_binomial_2_log_grid_test) {
213214
std::vector<double> mu_log_to_test
214215
= {-101, -27, -3, -1, -0.132, 0, 4, 10, 87};
215-
// TODO(martinmodrak) Reducing the span of the test, should be fixed
216-
// along with #1495
217-
// std::vector<double> phi_to_test = {2e-5, 0.36, 1, 10, 2.3e5, 1.8e10, 6e16};
218-
std::vector<double> phi_to_test = {0.36, 1, 10};
216+
std::vector<double> phi_to_test = {2e-5, 0.36, 1, 10, 2.3e5, 1.8e10, 6e16};
219217
std::vector<int> n_to_test = {0, 1, 10, 39, 101, 3048, 150054};
220218

221-
// TODO(martinmdorak) Only weak tolerance for this quick fix
222-
auto tolerance = [](double x) { return std::max(fabs(x * 1e-8), 1e-8); };
223-
224219
for (double mu_log : mu_log_to_test) {
225220
for (double phi : phi_to_test) {
226221
for (int n : n_to_test) {
227222
double val_log = stan::math::neg_binomial_2_log_lpmf(n, mu_log, phi);
228-
EXPECT_LE(val_log, 0)
229-
<< "neg_binomial_2_log_lpmf yields " << val_log
230-
<< " which si greater than 0 for n = " << n
231-
<< ", mu_log = " << mu_log << ", phi = " << phi << ".";
223+
std::stringstream msg;
232224
double val_orig
233225
= stan::math::neg_binomial_2_lpmf(n, std::exp(mu_log), phi);
234-
EXPECT_NEAR(val_log, val_orig, tolerance(val_orig))
226+
msg << std::setprecision(22)
235227
<< "neg_binomial_2_log_lpmf yields different result (" << val_log
236228
<< ") than neg_binomial_2_lpmf (" << val_orig << ") for n = " << n
237229
<< ", mu_log = " << mu_log << ", phi = " << phi << ".";
230+
stan::test::expect_near_rel(msg.str(), val_log, val_orig);
238231
}
239232
}
240233
}

0 commit comments

Comments
 (0)