1818#include < stan/math/prim/fun/size.hpp>
1919#include < stan/math/prim/fun/size_zero.hpp>
2020#include < stan/math/prim/fun/tgamma.hpp>
21- #include < stan/math/prim/fun/value_of .hpp>
21+ #include < stan/math/prim/fun/value_of_rec .hpp>
2222#include < stan/math/prim/fun/log_gamma_q_dgamma.hpp>
2323#include < stan/math/prim/functor/partials_propagator.hpp>
2424#include < cmath>
2525
2626namespace stan {
2727namespace math {
28+ namespace internal {
29+ template <typename T>
30+ struct Q_eval {
31+ T log_Q{0.0 };
32+ T dlogQ_dalpha{0.0 };
33+ bool ok{false };
34+ };
35+
36+ /* *
37+ * Computes log q and d(log q) / d(alpha) using continued fraction.
38+ */
39+ template <typename T, typename T_shape,
40+ bool any_fvar, bool partials_fvar>
41+ static inline Q_eval<T> eval_q_cf (const T& alpha,
42+ const T& beta_y) {
43+ Q_eval<T> out;
44+ if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
45+ auto log_q_result = log_gamma_q_dgamma (value_of_rec (alpha), value_of_rec (beta_y));
46+ out.log_Q = log_q_result.log_q ;
47+ out.dlogQ_dalpha = log_q_result.dlog_q_da ;
48+ } else {
49+ out.log_Q = internal::log_q_gamma_cf (alpha, beta_y);
50+ if constexpr (is_autodiff_v<T_shape>) {
51+ if constexpr (!partials_fvar) {
52+ out.dlogQ_dalpha
53+ = grad_reg_inc_gamma (alpha, beta_y, tgamma (alpha),
54+ digamma (alpha)) / exp (out.log_Q );
55+ } else {
56+ T alpha_unit = alpha;
57+ alpha_unit.d_ = 1 ;
58+ T beta_y_unit = beta_y;
59+ beta_y_unit.d_ = 0 ;
60+ T log_Q_fvar = internal::log_q_gamma_cf (alpha_unit, beta_y_unit);
61+ out.dlogQ_dalpha = log_Q_fvar.d_ ;
62+ }
63+ }
64+ }
65+
66+ out.ok = std::isfinite (value_of_rec (out.log_Q ));
67+ return out;
68+ }
69+
70+ /* *
71+ * Computes log q and d(log q) / d(alpha) using log1m.
72+ */
73+ template <typename T, typename T_shape,
74+ bool partials_fvar>
75+ static inline Q_eval<T> eval_q_log1m (const T& alpha,
76+ const T& beta_y) {
77+ Q_eval<T> out;
78+ out.log_Q = log1m (gamma_p (alpha, beta_y));
79+
80+ if (!std::isfinite (value_of_rec (out.log_Q ))) {
81+ out.ok = false ;
82+ return out;
83+ }
84+
85+ if constexpr (is_autodiff_v<T_shape>) {
86+ if constexpr (partials_fvar) {
87+ T alpha_unit = alpha;
88+ alpha_unit.d_ = 1 ;
89+ T beta_unit = beta_y;
90+ beta_unit.d_ = 0 ;
91+ T log_Q_fvar = log1m (gamma_p (alpha_unit, beta_unit));
92+ out.dlogQ_dalpha = log_Q_fvar.d_ ;
93+ } else {
94+ out.dlogQ_dalpha = -grad_reg_lower_inc_gamma (alpha, beta_y) / exp (out.log_Q );
95+ }
96+ }
97+
98+ out.ok = true ;
99+ return out;
100+ }
101+ }
28102
29103template <typename T_y, typename T_shape, typename T_inv_scale>
30- return_type_t <T_y, T_shape, T_inv_scale> gamma_lccdf (const T_y& y,
31- const T_shape& alpha,
32- const T_inv_scale& beta) {
33- using T_partials_return = partials_return_t <T_y, T_shape, T_inv_scale>;
104+ inline return_type_t <T_y, T_shape, T_inv_scale> gamma_lccdf (const T_y& y,
105+ const T_shape& alpha,
106+ const T_inv_scale& beta) {
34107 using std::exp;
35108 using std::log;
109+ using T_partials_return = partials_return_t <T_y, T_shape, T_inv_scale>;
36110 using T_y_ref = ref_type_t <T_y>;
37111 using T_alpha_ref = ref_type_t <T_shape>;
38112 using T_beta_ref = ref_type_t <T_inv_scale>;
@@ -58,7 +132,6 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
58132 scalar_seq_view<T_beta_ref> beta_vec (beta_ref);
59133 const size_t N = max_size (y, alpha, beta);
60134
61- constexpr bool need_y_beta_deriv = !is_constant_all<T_y, T_inv_scale>::value;
62135 constexpr bool any_fvar = is_fvar<scalar_type_t <T_y>>::value
63136 || is_fvar<scalar_type_t <T_shape>>::value
64137 || is_fvar<scalar_type_t <T_inv_scale>>::value;
@@ -83,90 +156,35 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
83156 return ops_partials.build (negative_infinity ());
84157 }
85158
86- bool use_cf = beta_y > alpha_dbl + 1.0 ;
87- T_partials_return log_Qn;
88- [[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0 ;
89-
90- // Branch by autodiff type first, then handle use_cf logic inside each path
91- if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
92- // var-only path: use log_gamma_q_dgamma which computes both log_q
93- // and its gradient analytically with double inputs
94- const double beta_y_dbl = value_of (value_of (beta_y));
95- const double alpha_dbl_val = value_of (value_of (alpha_dbl));
96-
97- if (use_cf) {
98- auto log_q_result = log_gamma_q_dgamma (alpha_dbl_val, beta_y_dbl);
99- log_Qn = log_q_result.log_q ;
100- dlogQ_dalpha = log_q_result.dlog_q_da ;
101- } else {
102- const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
103- log_Qn = log1m (Pn);
104- const T_partials_return Qn = exp (log_Qn);
105-
106- // Check if we need to fallback to continued fraction
107- bool need_cf_fallback
108- = !std::isfinite (value_of (value_of (log_Qn))) || Qn <= 0.0 ;
109- if (need_cf_fallback && beta_y > 0.0 ) {
110- auto log_q_result = log_gamma_q_dgamma (alpha_dbl_val, beta_y_dbl);
111- log_Qn = log_q_result.log_q ;
112- dlogQ_dalpha = log_q_result.dlog_q_da ;
113- } else {
114- dlogQ_dalpha = -grad_reg_lower_inc_gamma (alpha_dbl, beta_y) / Qn;
115- }
116- }
117- } else if constexpr (partials_fvar && is_autodiff_v<T_shape>) {
118- // fvar path: use unit derivative trick to compute gradients
119- auto alpha_unit = alpha_dbl;
120- alpha_unit.d_ = 1 ;
121- auto beta_unit = beta_y;
122- beta_unit.d_ = 0 ;
123-
124- if (use_cf) {
125- log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
126- auto log_Qn_fvar = internal::log_q_gamma_cf (alpha_unit, beta_unit);
127- dlogQ_dalpha = log_Qn_fvar.d_ ;
128- } else {
129- const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
130- log_Qn = log1m (Pn);
131-
132- if (!std::isfinite (value_of (value_of (log_Qn))) && beta_y > 0.0 ) {
133- // Fallback to continued fraction
134- log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
135- auto log_Qn_fvar = internal::log_q_gamma_cf (alpha_unit, beta_unit);
136- dlogQ_dalpha = log_Qn_fvar.d_ ;
137- } else {
138- auto log_Qn_fvar = log1m (gamma_p (alpha_unit, beta_unit));
139- dlogQ_dalpha = log_Qn_fvar.d_ ;
140- }
141- }
159+ const bool use_continued_fraction = beta_y > alpha_dbl + 1.0 ;
160+ internal::Q_eval<T_partials_return> result;
161+ if (use_continued_fraction) {
162+ result = internal::eval_q_cf<T_partials_return, T_shape,
163+ any_fvar, partials_fvar>(alpha_dbl, beta_y);
142164 } else {
143- // No alpha derivative needed (alpha is constant or double-only)
144- if (use_cf) {
145- log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
146- } else {
147- const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
148- log_Qn = log1m (Pn);
165+ result = internal::eval_q_log1m<T_partials_return, T_shape,
166+ partials_fvar>(alpha_dbl, beta_y);
149167
150- if (!std::isfinite (value_of (value_of (log_Qn))) && beta_y > 0.0 ) {
151- log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
152- }
168+ if (!result.ok && beta_y > 0.0 ) {
169+ // Fallback to continued fraction if log1m fails
170+ result = internal::eval_q_cf<T_partials_return, T_shape,
171+ any_fvar, partials_fvar>(alpha_dbl, beta_y);
153172 }
154173 }
155- if (!std::isfinite ( value_of ( value_of (log_Qn))) ) {
174+ if (!result. ok ) {
156175 return ops_partials.build (negative_infinity ());
157176 }
158- P += log_Qn;
159177
160- if constexpr (need_y_beta_deriv) {
178+ P += result.log_Q ;
179+
180+ if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
161181 const T_partials_return log_y = log (y_dbl);
162- const T_partials_return log_beta = log (beta_dbl);
163- const T_partials_return lgamma_alpha = lgamma (alpha_dbl);
164182 const T_partials_return alpha_minus_one = fma (alpha_dbl, log_y, -log_y);
165183
166184 const T_partials_return log_pdf
167- = alpha_dbl * log_beta - lgamma_alpha + alpha_minus_one - beta_y;
185+ = alpha_dbl * log (beta_dbl) - lgamma (alpha_dbl) + alpha_minus_one - beta_y;
168186
169- const T_partials_return hazard = exp (log_pdf - log_Qn ); // f/Q
187+ const T_partials_return hazard = exp (log_pdf - result. log_Q ); // f/Q
170188
171189 if constexpr (is_autodiff_v<T_y>) {
172190 partials<0 >(ops_partials)[n] -= hazard;
@@ -176,7 +194,7 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
176194 }
177195 }
178196 if constexpr (is_autodiff_v<T_shape>) {
179- partials<1 >(ops_partials)[n] += dlogQ_dalpha;
197+ partials<1 >(ops_partials)[n] += result. dlogQ_dalpha ;
180198 }
181199 }
182200 return ops_partials.build (P);
0 commit comments