Skip to content

Commit a2d0372

Browse files
authored
Merge pull request #3221 from stan-dev/fix/exprs-apply
Use Perfect Forwarding in all functions that use `apply` family of functors
2 parents 1187719 + acd3426 commit a2d0372

File tree

224 files changed

+1588
-1151
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

224 files changed

+1588
-1151
lines changed

stan/math/fwd/fun/inv_logit.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ namespace math {
1616
* @param x argument
1717
* @return inverse logit of argument
1818
*/
19-
template <typename T>
20-
inline fvar<T> inv_logit(const fvar<T>& x) {
21-
return fvar<T>(inv_logit(x.val_),
22-
x.d_ * inv_logit(x.val_) * (1 - inv_logit(x.val_)));
19+
template <typename T, require_fvar_t<T>* = nullptr>
20+
inline auto inv_logit(T&& x) {
21+
return std::decay_t<T>(inv_logit(x.val_),
22+
x.d_ * inv_logit(x.val_) * (1 - inv_logit(x.val_)));
2323
}
2424

2525
} // namespace math

stan/math/fwd/fun/log_softmax.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <stan/math/fwd/meta.hpp>
77
#include <stan/math/fwd/fun/softmax.hpp>
88
#include <stan/math/prim/fun/log_softmax.hpp>
9+
#include <stan/math/prim/fun/to_ref.hpp>
910

1011
namespace stan {
1112
namespace math {
@@ -19,24 +20,24 @@ namespace math {
1920
* @throw std::domain_error If the input vector is size 0.
2021
*/
2122
template <typename T, require_vector_st<is_fvar, T>* = nullptr>
22-
inline auto log_softmax(const T& x) {
23-
return apply_vector_unary<T>::apply(x, [&](const auto& alpha) {
23+
inline auto log_softmax(T&& x) {
24+
return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& alpha) {
2425
using T_alpha = decltype(alpha);
2526
using T_fvar = value_type_t<T_alpha>;
2627
using T_fvar_inner = typename T_fvar::Scalar;
2728

28-
const Eigen::Ref<const plain_type_t<T_alpha>>& alpha_ref = alpha;
29+
auto&& alpha_ref = to_ref(std::forward<decltype(alpha)>(alpha));
2930
Eigen::Matrix<T_fvar_inner, -1, 1> alpha_t = alpha_ref.val();
3031
Eigen::Matrix<T_fvar_inner, -1, 1> softmax_alpha_t = softmax(alpha_t);
3132

32-
Eigen::Matrix<T_fvar, -1, 1> log_softmax_alpha(alpha.size());
33+
Eigen::Matrix<T_fvar, -1, 1> log_softmax_alpha(alpha_ref.size());
3334
log_softmax_alpha.val() = log_softmax(alpha_t);
3435
log_softmax_alpha.d().setZero();
3536

36-
for (int m = 0; m < alpha.size(); ++m) {
37+
for (int m = 0; m < alpha_ref.size(); ++m) {
3738
T_fvar_inner negative_alpha_m_d_times_softmax_alpha_t_m
3839
= -alpha_ref.coeff(m).d_ * softmax_alpha_t(m);
39-
for (int k = 0; k < alpha.size(); ++k) {
40+
for (int k = 0; k < alpha_ref.size(); ++k) {
4041
if (m == k) {
4142
log_softmax_alpha(k).d_
4243
+= alpha_ref.coeff(m).d_

stan/math/fwd/fun/log_sum_exp.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ inline fvar<T> log_sum_exp(const fvar<T>& x1, double x2) {
5050
* @return The log of the sum of the exponentiated vector values.
5151
*/
5252
template <typename T, require_container_st<is_fvar, T>* = nullptr>
53-
inline auto log_sum_exp(const T& x) {
53+
inline auto log_sum_exp(T&& x) {
5454
return apply_vector_unary<ref_type_t<T>>::reduce(
55-
to_ref(x), [&](const auto& v) {
55+
to_ref(std::forward<T>(x)), [](auto&& v) {
5656
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
5757
using mat_type = Eigen::Matrix<T_fvar_inner, -1, -1>;
5858
mat_type vals = v.val();

stan/math/fwd/fun/norm1.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ namespace math {
1515
/**
1616
* Compute the L1 norm of the specified vector of values.
1717
*
18-
* @tparam T Type of input vector.
18+
* @tparam Container Type of input vector.
1919
* @param[in] x Vector of specified values.
2020
* @return L1 norm of x.
2121
*/
2222
template <typename Container, require_eigen_vt<is_fvar, Container>* = nullptr>
23-
inline auto norm1(const Container& x) {
23+
inline auto norm1(Container&& x) {
2424
return apply_vector_unary<ref_type_t<Container>>::reduce(
25-
to_ref(x), [&](const auto& v) {
25+
to_ref(std::forward<Container>(x)), [](auto&& v) {
2626
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
2727
return fvar<T_fvar_inner>(norm1(v.val()),
2828
v.d().cwiseProduct(sign(v.val())).sum());

stan/math/fwd/fun/norm2.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ namespace math {
1414
/**
1515
* Compute the L2 norm of the specified vector of values.
1616
*
17-
* @tparam T Type of input vector.
17+
* @tparam Container Type of input vector.
1818
* @param[in] x Vector of specified values.
1919
* @return L2 norm of x.
2020
*/
2121
template <typename Container, require_eigen_vt<is_fvar, Container>* = nullptr>
22-
inline auto norm2(const Container& x) {
22+
inline auto norm2(Container&& x) {
2323
return apply_vector_unary<ref_type_t<Container>>::reduce(
24-
to_ref(x), [&](const auto& v) {
24+
to_ref(std::forward<Container>(x)), [](auto&& v) {
2525
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
2626
T_fvar_inner res = norm2(v.val());
2727
return fvar<T_fvar_inner>(res,

stan/math/fwd/fun/pow.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,13 @@ inline auto pow(const T1& x1, const T2& x2) {
7373
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
7474
require_all_not_matrix_st<is_var, T1, T2>* = nullptr,
7575
require_any_fvar_t<base_type_t<T1>, base_type_t<T2>>* = nullptr>
76-
inline auto pow(const T1& a, const T2& b) {
76+
inline auto pow(T1&& a, T2&& b) {
7777
return apply_scalar_binary(
78-
[](const auto& c, const auto& d) { return stan::math::pow(c, d); }, a, b);
78+
[](auto&& c, auto&& d) {
79+
return stan::math::pow(std::forward<decltype(c)>(c),
80+
std::forward<decltype(d)>(d));
81+
},
82+
std::forward<T1>(a), std::forward<T2>(b));
7983
}
8084

8185
} // namespace math

stan/math/fwd/functor/apply_scalar_unary.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,23 @@ namespace math {
1818
* autodiff variable.
1919
*/
2020
template <typename F, typename T>
21-
struct apply_scalar_unary<F, fvar<T> > {
21+
struct apply_scalar_unary<F, T, require_fvar_t<T>> {
2222
/**
2323
* Function return type, which is same as the argument type for
2424
* the function, <code>fvar&lt;T&gt;</code>.
2525
*/
26-
using return_t = fvar<T>;
26+
using return_t = std::decay_t<T>;
2727

2828
/**
2929
* Apply the function specified by F to the specified argument.
3030
*
3131
* @param x Argument variable.
3232
* @return Function applied to the variable.
3333
*/
34-
static inline return_t apply(const fvar<T>& x) { return F::fun(x); }
34+
template <typename T2>
35+
static inline auto apply(T2&& x) {
36+
return F::fun(std::forward<T2>(x));
37+
}
3538
};
3639

3740
} // namespace math

stan/math/fwd/functor/finite_diff.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ inline constexpr double aggregate_tangent(const FuncTangent& tangent,
4444
template <typename FuncTangent, typename InputArg,
4545
require_st_fvar<InputArg>* = nullptr>
4646
inline auto aggregate_tangent(const FuncTangent& tangent, const InputArg& arg) {
47-
return sum(apply_scalar_binary(
48-
[](const auto& x, const auto& y) { return x * y.d_; }, tangent, arg));
47+
return sum(apply_scalar_binary([](auto&& x, auto&& y) { return x * y.d_; },
48+
tangent, arg));
4949
}
5050
} // namespace internal
5151

@@ -73,7 +73,7 @@ inline auto finite_diff(const F& func, const TArgs&... args) {
7373
std::vector<FvarInnerT> serialised_args
7474
= serialize<FvarInnerT>(value_of(args)...);
7575

76-
auto serial_functor = [&](const auto& v) {
76+
auto serial_functor = [&](auto&& v) {
7777
auto v_deserializer = to_deserializer(v);
7878
return func(v_deserializer.read(args)...);
7979
};

stan/math/opencl/prim/dirichlet_lpdf.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define STAN_MATH_OPENCL_PRIM_DIRICHLET_LPDF_HPP
33
#ifdef STAN_OPENCL
44

5+
#include <stan/math/prim/fun/Eigen.hpp>
56
#include <stan/math/prim/meta.hpp>
67
#include <stan/math/prim/err.hpp>
78
#include <stan/math/prim/fun/constants.hpp>

stan/math/opencl/prim/log_softmax.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ template <typename T,
2222
require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
2323
inline matrix_cl<double> log_softmax(const T& a) {
2424
check_nonzero_size("log_softmax (OpenCL)", "x", a);
25-
return make_holder_cl([](const auto& x) { return x - log_sum_exp(x); },
26-
to_ref(a));
25+
return make_holder_cl([](auto&& x) { return x - log_sum_exp(x); }, to_ref(a));
2726
}
2827

2928
} // namespace math

0 commit comments

Comments
 (0)