66#include < stan/math/fwd/meta.hpp>
77#include < stan/math/fwd/fun/softmax.hpp>
88#include < stan/math/prim/fun/log_softmax.hpp>
9+ #include < stan/math/prim/fun/to_ref.hpp>
910
1011namespace stan {
1112namespace math {
@@ -19,24 +20,24 @@ namespace math {
1920 * @throw std::domain_error If the input vector is size 0.
2021 */
2122template <typename T, require_vector_st<is_fvar, T>* = nullptr >
22- inline auto log_softmax (const T & x) {
23- return apply_vector_unary<T>::apply (x , [&]( const auto & alpha) {
23+ inline auto log_softmax (T& & x) {
24+ return apply_vector_unary<T>::apply (std::forward<T>(x) , []( auto & & alpha) {
2425 using T_alpha = decltype (alpha);
2526 using T_fvar = value_type_t <T_alpha>;
2627 using T_fvar_inner = typename T_fvar::Scalar;
2728
28- const Eigen::Ref< const plain_type_t <T_alpha>>& alpha_ref = alpha;
29+ auto && alpha_ref = to_ref (std::forward< decltype ( alpha)>(alpha)) ;
2930 Eigen::Matrix<T_fvar_inner, -1 , 1 > alpha_t = alpha_ref.val ();
3031 Eigen::Matrix<T_fvar_inner, -1 , 1 > softmax_alpha_t = softmax (alpha_t );
3132
32- Eigen::Matrix<T_fvar, -1 , 1 > log_softmax_alpha (alpha .size ());
33+ Eigen::Matrix<T_fvar, -1 , 1 > log_softmax_alpha (alpha_ref .size ());
3334 log_softmax_alpha.val () = log_softmax (alpha_t );
3435 log_softmax_alpha.d ().setZero ();
3536
36- for (int m = 0 ; m < alpha .size (); ++m) {
37+ for (int m = 0 ; m < alpha_ref .size (); ++m) {
3738 T_fvar_inner negative_alpha_m_d_times_softmax_alpha_t_m
3839 = -alpha_ref.coeff (m).d_ * softmax_alpha_t (m);
39- for (int k = 0 ; k < alpha .size (); ++k) {
40+ for (int k = 0 ; k < alpha_ref .size (); ++k) {
4041 if (m == k) {
4142 log_softmax_alpha (k).d_
4243 += alpha_ref.coeff (m).d_
0 commit comments