55#include < stan/math/prim/err.hpp>
66#include < stan/math/prim/fun/as_column_vector_or_scalar.hpp>
77#include < stan/math/prim/fun/as_array_or_scalar.hpp>
8- #include < stan/math/prim/fun/as_value_column_array_or_scalar.hpp>
98#include < stan/math/prim/fun/binomial_coefficient_log.hpp>
109#include < stan/math/prim/fun/inc_beta.hpp>
1110#include < stan/math/prim/fun/inv_logit.hpp>
@@ -34,9 +33,7 @@ namespace math {
3433 * @throw std::domain_error if N is negative or probability parameter is invalid
3534 * @throw std::invalid_argument if vector sizes do not match
3635 */
37- template <bool propto, typename T_n, typename T_N, typename T_prob,
38- require_all_not_nonscalar_prim_or_rev_kernel_expression_t <
39- T_n, T_N, T_prob>* = nullptr >
36+ template <bool propto, typename T_n, typename T_N, typename T_prob>
4037return_type_t <T_prob> binomial_logit_lpmf (const T_n& n, const T_N& N,
4138 const T_prob& alpha) {
4239 using T_partials_return = partials_return_t <T_n, T_N, T_prob>;
@@ -52,11 +49,19 @@ return_type_t<T_prob> binomial_logit_lpmf(const T_n& n, const T_N& N,
5249 T_N_ref N_ref = N;
5350 T_alpha_ref alpha_ref = alpha;
5451
55- decltype ( auto ) n_val = to_ref ( as_value_column_array_or_scalar ( n_ref) );
56- decltype ( auto ) N_val = to_ref ( as_value_column_array_or_scalar ( N_ref) );
57- decltype ( auto ) alpha_val = to_ref ( as_value_column_array_or_scalar ( alpha_ref) );
52+ const auto & n_col = as_column_vector_or_scalar ( n_ref);
53+ const auto & N_col = as_column_vector_or_scalar ( N_ref);
54+ const auto & alpha_col = as_column_vector_or_scalar ( alpha_ref);
5855
59- check_bounded (function, " Successes variable" , value_of (n_val), 0 , N_val);
56+ const auto & n_arr = as_array_or_scalar (n_col);
57+ const auto & N_arr = as_array_or_scalar (N_col);
58+ const auto & alpha_arr = as_array_or_scalar (alpha_col);
59+
60+ ref_type_t <decltype (value_of (n_arr))> n_val = value_of (n_arr);
61+ ref_type_t <decltype (value_of (N_arr))> N_val = value_of (N_arr);
62+ ref_type_t <decltype (value_of (alpha_arr))> alpha_val = value_of (alpha_arr);
63+
64+ check_bounded (function, " Successes variable" , n_val, 0 , N_val);
6065 check_nonnegative (function, " Population size parameter" , N_val);
6166 check_finite (function, " Probability parameter" , alpha_val);
6267
@@ -90,8 +95,7 @@ return_type_t<T_prob> binomial_logit_lpmf(const T_n& n, const T_N& N,
9095 T_partials_return sum_n = sum (n_val) * maximum_size / math::size (n);
9196 ops_partials.edge1_ .partials_ [0 ] = forward_as<T_partials_return>(
9297 sum_n * inv_logit_neg_alpha
93- - (sum (N_val) * maximum_size / math::size (N) - sum_n)
94- * inv_logit_alpha);
98+ - (sum (N_val) * maximum_size / math::size (N) - sum_n) * inv_logit_alpha);
9599 }
96100 }
97101
0 commit comments