33
44#include < stan/math/prim/meta.hpp>
55#include < stan/math/prim/err.hpp>
6+ #include < stan/math/prim/fun/binomial_coefficient_log.hpp>
67#include < stan/math/prim/fun/digamma.hpp>
7- #include < stan/math/prim/fun/exp.hpp>
8- #include < stan/math/prim/fun/lgamma.hpp>
8+ #include < stan/math/prim/fun/inv.hpp>
99#include < stan/math/prim/fun/log.hpp>
10+ #include < stan/math/prim/fun/log1p_exp.hpp>
1011#include < stan/math/prim/fun/log_sum_exp.hpp>
1112#include < stan/math/prim/fun/max_size.hpp>
12- #include < stan/math/prim/fun/multiply_log.hpp>
1313#include < stan/math/prim/fun/size.hpp>
1414#include < stan/math/prim/fun/size_zero.hpp>
1515#include < stan/math/prim/fun/value_of.hpp>
16- #include < stan/math/prim/prob/poisson_log_lpmf.hpp>
1716#include < cmath>
1817
1918namespace stan {
@@ -52,13 +51,20 @@ return_type_t<T_log_location, T_precision> neg_binomial_2_log_lpmf(
5251 size_t size_phi = stan::math::size (phi);
5352 size_t size_eta_phi = max_size (eta, phi);
5453 size_t size_n_phi = max_size (n, phi);
55- size_t max_size_seq_view = max_size (n, eta, phi);
54+ size_t size_all = max_size (n, eta, phi);
5655
5756 VectorBuilder<true , T_partials_return, T_log_location> eta_val (size_eta);
5857 for (size_t i = 0 ; i < size_eta; ++i) {
5958 eta_val[i] = value_of (eta_vec[i]);
6059 }
6160
61+ VectorBuilder<true , T_partials_return, T_precision> phi_val (size_phi);
62+ VectorBuilder<true , T_partials_return, T_precision> log_phi (size_phi);
63+ for (size_t i = 0 ; i < size_phi; ++i) {
64+ phi_val[i] = value_of (phi_vec[i]);
65+ log_phi[i] = log (phi_val[i]);
66+ }
67+
6268 VectorBuilder<!is_constant_all<T_log_location, T_precision>::value,
6369 T_partials_return, T_log_location>
6470 exp_eta (size_eta);
@@ -68,17 +74,19 @@ return_type_t<T_log_location, T_precision> neg_binomial_2_log_lpmf(
6874 }
6975 }
7076
71- VectorBuilder<true , T_partials_return, T_precision> phi_val (size_phi);
72- VectorBuilder<true , T_partials_return, T_precision> log_phi (size_phi);
73- for (size_t i = 0 ; i < size_phi; ++i) {
74- phi_val[i] = value_of (phi_vec[i]);
75- log_phi[i] = log (phi_val[i]);
77+ VectorBuilder<!is_constant_all<T_log_location, T_precision>::value,
78+ T_partials_return, T_log_location, T_precision>
79+ exp_eta_over_exp_eta_phi (size_eta_phi);
80+ if (!is_constant_all<T_log_location, T_precision>::value) {
81+ for (size_t i = 0 ; i < size_eta_phi; ++i) {
82+ exp_eta_over_exp_eta_phi[i] = inv (phi_val[i] / exp_eta[i] + 1 );
83+ }
7684 }
7785
7886 VectorBuilder<true , T_partials_return, T_log_location, T_precision>
79- logsumexp_eta_logphi (size_eta_phi);
87+ log1p_exp_eta_m_logphi (size_eta_phi);
8088 for (size_t i = 0 ; i < size_eta_phi; ++i) {
81- logsumexp_eta_logphi [i] = log_sum_exp (eta_val[i], log_phi[i]);
89+ log1p_exp_eta_m_logphi [i] = log1p_exp (eta_val[i] - log_phi[i]);
8290 }
8391
8492 VectorBuilder<true , T_partials_return, T_n, T_precision> n_plus_phi (
@@ -87,38 +95,25 @@ return_type_t<T_log_location, T_precision> neg_binomial_2_log_lpmf(
8795 n_plus_phi[i] = n_vec[i] + phi_val[i];
8896 }
8997
90- for (size_t i = 0 ; i < max_size_seq_view; i++) {
91- if (phi_val[i] > 1e5 ) {
92- // TODO(martinmodrak) This is wrong (doesn't pass propto information),
93- // and inaccurate for n = 0, but shouldn't break most models.
94- // Also the 1e5 cutoff is way too low.
95- // Will be addressed better once PR #1497 is merged
96- logp += poisson_log_lpmf (n_vec[i], eta_val[i]);
97- } else {
98- if (include_summand<propto>::value) {
99- logp -= lgamma (n_vec[i] + 1.0 );
100- }
101- if (include_summand<propto, T_precision>::value) {
102- logp += multiply_log (phi_val[i], phi_val[i]) - lgamma (phi_val[i]);
103- }
104- if (include_summand<propto, T_log_location>::value) {
105- logp += n_vec[i] * eta_val[i];
106- }
107- if (include_summand<propto, T_precision>::value) {
108- logp += lgamma (n_plus_phi[i]);
109- }
110- logp -= (n_plus_phi[i]) * logsumexp_eta_logphi[i];
98+ for (size_t i = 0 ; i < size_all; i++) {
99+ if (include_summand<propto, T_precision>::value) {
100+ logp += binomial_coefficient_log (n_plus_phi[i] - 1 , n_vec[i]);
101+ }
102+ if (include_summand<propto, T_log_location>::value) {
103+ logp += n_vec[i] * eta_val[i];
111104 }
105+ logp += -phi_val[i] * log1p_exp_eta_m_logphi[i]
106+ - n_vec[i] * (log_phi[i] + log1p_exp_eta_m_logphi[i]);
112107
113108 if (!is_constant_all<T_log_location>::value) {
114109 ops_partials.edge1_ .partials_ [i]
115- += n_vec[i] - n_plus_phi[i] / (phi_val [i] / exp_eta[i] + 1 ) ;
110+ += n_vec[i] - n_plus_phi[i] * exp_eta_over_exp_eta_phi [i];
116111 }
117112 if (!is_constant_all<T_precision>::value) {
118113 ops_partials.edge2_ .partials_ [i]
119- += 1.0 - n_plus_phi [i] / (exp_eta[i] + phi_val[i]) + log_phi[i]
120- - logsumexp_eta_logphi [i] - digamma (phi_val[i])
121- + digamma (n_plus_phi[i]);
114+ += exp_eta_over_exp_eta_phi[i] - n_vec [i] / (exp_eta[i] + phi_val[i])
115+ - log1p_exp_eta_m_logphi [i]
116+ - ( digamma (phi_val[i]) - digamma ( n_plus_phi[i]) );
122117 }
123118 }
124119 return ops_partials.build (logp);
0 commit comments