Skip to content

Commit 81c67a0

Browse files
committed
refactor gamma_lccdf
1 parent 78fa2bd commit 81c67a0

File tree

1 file changed

+97
-79
lines changed

1 file changed

+97
-79
lines changed

stan/math/prim/prob/gamma_lccdf.hpp

Lines changed: 97 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,95 @@
1818
#include <stan/math/prim/fun/size.hpp>
1919
#include <stan/math/prim/fun/size_zero.hpp>
2020
#include <stan/math/prim/fun/tgamma.hpp>
21-
#include <stan/math/prim/fun/value_of.hpp>
21+
#include <stan/math/prim/fun/value_of_rec.hpp>
2222
#include <stan/math/prim/fun/log_gamma_q_dgamma.hpp>
2323
#include <stan/math/prim/functor/partials_propagator.hpp>
2424
#include <cmath>
2525

2626
namespace stan {
2727
namespace math {
28+
namespace internal {
29+
template <typename T>
30+
struct Q_eval {
31+
T log_Q{0.0};
32+
T dlogQ_dalpha{0.0};
33+
bool ok{false};
34+
};
35+
36+
/**
37+
* Computes log q and d(log q) / d(alpha) using continued fraction.
38+
*/
39+
template <typename T, typename T_shape,
40+
bool any_fvar, bool partials_fvar>
41+
static inline Q_eval<T> eval_q_cf(const T& alpha,
42+
const T& beta_y) {
43+
Q_eval<T> out;
44+
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
45+
auto log_q_result = log_gamma_q_dgamma(value_of_rec(alpha), value_of_rec(beta_y));
46+
out.log_Q = log_q_result.log_q;
47+
out.dlogQ_dalpha = log_q_result.dlog_q_da;
48+
} else {
49+
out.log_Q = internal::log_q_gamma_cf(alpha, beta_y);
50+
if constexpr (is_autodiff_v<T_shape>) {
51+
if constexpr (!partials_fvar) {
52+
out.dlogQ_dalpha
53+
= grad_reg_inc_gamma(alpha, beta_y, tgamma(alpha),
54+
digamma(alpha)) / exp(out.log_Q);
55+
} else {
56+
T alpha_unit = alpha;
57+
alpha_unit.d_ = 1;
58+
T beta_y_unit = beta_y;
59+
beta_y_unit.d_ = 0;
60+
T log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit);
61+
out.dlogQ_dalpha = log_Q_fvar.d_;
62+
}
63+
}
64+
}
65+
66+
out.ok = std::isfinite(value_of_rec(out.log_Q));
67+
return out;
68+
}
69+
70+
/**
71+
* Computes log q and d(log q) / d(alpha) using log1m.
72+
*/
73+
template <typename T, typename T_shape,
74+
bool partials_fvar>
75+
static inline Q_eval<T> eval_q_log1m(const T& alpha,
76+
const T& beta_y) {
77+
Q_eval<T> out;
78+
out.log_Q = log1m(gamma_p(alpha, beta_y));
79+
80+
if (!std::isfinite(value_of_rec(out.log_Q))) {
81+
out.ok = false;
82+
return out;
83+
}
84+
85+
if constexpr (is_autodiff_v<T_shape>) {
86+
if constexpr (partials_fvar) {
87+
T alpha_unit = alpha;
88+
alpha_unit.d_ = 1;
89+
T beta_unit = beta_y;
90+
beta_unit.d_ = 0;
91+
T log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit));
92+
out.dlogQ_dalpha = log_Q_fvar.d_;
93+
} else {
94+
out.dlogQ_dalpha = -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.log_Q);
95+
}
96+
}
97+
98+
out.ok = true;
99+
return out;
100+
}
101+
}
28102

29103
template <typename T_y, typename T_shape, typename T_inv_scale>
30-
return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
31-
const T_shape& alpha,
32-
const T_inv_scale& beta) {
33-
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
104+
inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
105+
const T_shape& alpha,
106+
const T_inv_scale& beta) {
34107
using std::exp;
35108
using std::log;
109+
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
36110
using T_y_ref = ref_type_t<T_y>;
37111
using T_alpha_ref = ref_type_t<T_shape>;
38112
using T_beta_ref = ref_type_t<T_inv_scale>;
@@ -58,7 +132,6 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
58132
scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
59133
const size_t N = max_size(y, alpha, beta);
60134

61-
constexpr bool need_y_beta_deriv = !is_constant_all<T_y, T_inv_scale>::value;
62135
constexpr bool any_fvar = is_fvar<scalar_type_t<T_y>>::value
63136
|| is_fvar<scalar_type_t<T_shape>>::value
64137
|| is_fvar<scalar_type_t<T_inv_scale>>::value;
@@ -83,90 +156,35 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
83156
return ops_partials.build(negative_infinity());
84157
}
85158

86-
bool use_cf = beta_y > alpha_dbl + 1.0;
87-
T_partials_return log_Qn;
88-
[[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0;
89-
90-
// Branch by autodiff type first, then handle use_cf logic inside each path
91-
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
92-
// var-only path: use log_gamma_q_dgamma which computes both log_q
93-
// and its gradient analytically with double inputs
94-
const double beta_y_dbl = value_of(value_of(beta_y));
95-
const double alpha_dbl_val = value_of(value_of(alpha_dbl));
96-
97-
if (use_cf) {
98-
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
99-
log_Qn = log_q_result.log_q;
100-
dlogQ_dalpha = log_q_result.dlog_q_da;
101-
} else {
102-
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
103-
log_Qn = log1m(Pn);
104-
const T_partials_return Qn = exp(log_Qn);
105-
106-
// Check if we need to fallback to continued fraction
107-
bool need_cf_fallback
108-
= !std::isfinite(value_of(value_of(log_Qn))) || Qn <= 0.0;
109-
if (need_cf_fallback && beta_y > 0.0) {
110-
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
111-
log_Qn = log_q_result.log_q;
112-
dlogQ_dalpha = log_q_result.dlog_q_da;
113-
} else {
114-
dlogQ_dalpha = -grad_reg_lower_inc_gamma(alpha_dbl, beta_y) / Qn;
115-
}
116-
}
117-
} else if constexpr (partials_fvar && is_autodiff_v<T_shape>) {
118-
// fvar path: use unit derivative trick to compute gradients
119-
auto alpha_unit = alpha_dbl;
120-
alpha_unit.d_ = 1;
121-
auto beta_unit = beta_y;
122-
beta_unit.d_ = 0;
123-
124-
if (use_cf) {
125-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
126-
auto log_Qn_fvar = internal::log_q_gamma_cf(alpha_unit, beta_unit);
127-
dlogQ_dalpha = log_Qn_fvar.d_;
128-
} else {
129-
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
130-
log_Qn = log1m(Pn);
131-
132-
if (!std::isfinite(value_of(value_of(log_Qn))) && beta_y > 0.0) {
133-
// Fallback to continued fraction
134-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
135-
auto log_Qn_fvar = internal::log_q_gamma_cf(alpha_unit, beta_unit);
136-
dlogQ_dalpha = log_Qn_fvar.d_;
137-
} else {
138-
auto log_Qn_fvar = log1m(gamma_p(alpha_unit, beta_unit));
139-
dlogQ_dalpha = log_Qn_fvar.d_;
140-
}
141-
}
159+
const bool use_continued_fraction = beta_y > alpha_dbl + 1.0;
160+
internal::Q_eval<T_partials_return> result;
161+
if (use_continued_fraction) {
162+
result = internal::eval_q_cf<T_partials_return, T_shape,
163+
any_fvar, partials_fvar>(alpha_dbl, beta_y);
142164
} else {
143-
// No alpha derivative needed (alpha is constant or double-only)
144-
if (use_cf) {
145-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
146-
} else {
147-
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
148-
log_Qn = log1m(Pn);
165+
result = internal::eval_q_log1m<T_partials_return, T_shape,
166+
partials_fvar>(alpha_dbl, beta_y);
149167

150-
if (!std::isfinite(value_of(value_of(log_Qn))) && beta_y > 0.0) {
151-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
152-
}
168+
if (!result.ok && beta_y > 0.0) {
169+
// Fallback to continued fraction if log1m fails
170+
result = internal::eval_q_cf<T_partials_return, T_shape,
171+
any_fvar, partials_fvar>(alpha_dbl, beta_y);
153172
}
154173
}
155-
if (!std::isfinite(value_of(value_of(log_Qn)))) {
174+
if (!result.ok) {
156175
return ops_partials.build(negative_infinity());
157176
}
158-
P += log_Qn;
159177

160-
if constexpr (need_y_beta_deriv) {
178+
P += result.log_Q;
179+
180+
if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
161181
const T_partials_return log_y = log(y_dbl);
162-
const T_partials_return log_beta = log(beta_dbl);
163-
const T_partials_return lgamma_alpha = lgamma(alpha_dbl);
164182
const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y);
165183

166184
const T_partials_return log_pdf
167-
= alpha_dbl * log_beta - lgamma_alpha + alpha_minus_one - beta_y;
185+
= alpha_dbl * log(beta_dbl) - lgamma(alpha_dbl) + alpha_minus_one - beta_y;
168186

169-
const T_partials_return hazard = exp(log_pdf - log_Qn); // f/Q
187+
const T_partials_return hazard = exp(log_pdf - result.log_Q); // f/Q
170188

171189
if constexpr (is_autodiff_v<T_y>) {
172190
partials<0>(ops_partials)[n] -= hazard;
@@ -176,7 +194,7 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
176194
}
177195
}
178196
if constexpr (is_autodiff_v<T_shape>) {
179-
partials<1>(ops_partials)[n] += dlogQ_dalpha;
197+
partials<1>(ops_partials)[n] += result.dlogQ_dalpha;
180198
}
181199
}
182200
return ops_partials.build(P);

0 commit comments

Comments
 (0)