Skip to content

Commit e1233fb

Browse files
committed
stan changes
1 parent 192c775 commit e1233fb

File tree

16 files changed

+415
-151
lines changed

16 files changed

+415
-151
lines changed

stan/math/fwd/fun/accumulator.hpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,92 @@
99
#include <vector>
1010
#include <type_traits>
1111

12+
namespace stan {
13+
namespace math {
14+
15+
/**
16+
* Class to accumulate values and eventually return their sum. If
17+
* no values are ever added, the return value is 0.
18+
*
19+
* This class is useful for speeding up autodiff of long sums
20+
* because it uses the <code>sum()</code> operation (either from
21+
* <code>stan::math</code> or one defined by argument-dependent lookup.
22+
*
23+
* @tparam T Type of scalar added
24+
*/
25+
template <typename T>
26+
class accumulator<T, require_fvar_t<T>> {
27+
private:
28+
std::vector<T> buf_;
29+
30+
public:
31+
/**
32+
* Add the specified arithmetic type value to the buffer after
33+
* static casting it to the class type <code>T</code>.
34+
*
35+
* <p>See the std library doc for <code>std::is_arithmetic</code>
36+
* for information on what counts as an arithmetic type.
37+
*
38+
* @tparam S Type of argument
39+
* @param x Value to add
40+
*/
41+
template <typename S, typename = require_stan_scalar_t<S>>
42+
inline void add(S x) {
43+
buf_.push_back(x);
44+
}
45+
46+
/**
47+
* Add each entry in the specified matrix, vector, or row vector
48+
* of values to the buffer.
49+
*
50+
* @tparam S type of the matrix
51+
* @param m Matrix of values to add
52+
*/
53+
template <typename S, require_matrix_t<S>* = nullptr>
54+
inline void add(const S& m) {
55+
buf_.push_back(stan::math::sum(m));
56+
}
57+
58+
/**
59+
* Recursively add each entry in the specified standard vector
60+
* to the buffer. This will allow vectors of primitives,
61+
* autodiff variables to be added; if the vector entries
62+
* are collections, their elements are recursively added.
63+
*
64+
* @tparam S Type of value to recursively add.
65+
* @param xs Vector of entries to add
66+
*/
67+
template <typename S>
68+
inline void add(const std::vector<S>& xs) {
69+
for (size_t i = 0; i < xs.size(); ++i) {
70+
this->add(xs[i]);
71+
}
72+
}
73+
74+
#ifdef STAN_OPENCL
75+
76+
/**
77+
* Sum each entry and then push to the buffer.
78+
* @tparam S A Type inheriting from `matrix_cl_base`
79+
* @param x An OpenCL matrix
80+
*/
81+
template <typename S,
82+
require_all_kernel_expressions_and_none_scalar_t<S>* = nullptr>
83+
inline void add(const S& xs) {
84+
buf_.push_back(stan::math::sum(xs));
85+
}
86+
87+
#endif
88+
89+
/**
90+
* Return the sum of the accumulated values.
91+
*
92+
* @return Sum of accumulated values.
93+
*/
94+
inline T sum() const { return stan::math::sum(buf_); }
95+
};
96+
97+
}
98+
}
99+
12100
#endif

stan/math/fwd/fun/mdivide_right.hpp

Lines changed: 61 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -12,117 +12,89 @@
1212

1313
namespace stan {
1414
namespace math {
15-
15+
/*
1616
template <typename EigMat1, typename EigMat2,
17-
require_all_eigen_vt<is_fvar, EigMat1, EigMat2>* = nullptr,
18-
require_vt_same<EigMat1, EigMat2>* = nullptr>
19-
inline Eigen::Matrix<value_type_t<EigMat1>, EigMat1::RowsAtCompileTime,
20-
EigMat2::ColsAtCompileTime>
21-
mdivide_right(const EigMat1& A, const EigMat2& b) {
22-
using T = typename value_type_t<EigMat1>::Scalar;
23-
constexpr int R1 = EigMat1::RowsAtCompileTime;
24-
constexpr int C1 = EigMat1::ColsAtCompileTime;
25-
constexpr int R2 = EigMat2::RowsAtCompileTime;
26-
constexpr int C2 = EigMat2::ColsAtCompileTime;
27-
28-
check_square("mdivide_right", "b", b);
29-
check_multiplicable("mdivide_right", "A", A, "b", b);
30-
if (b.size() == 0) {
31-
return {A.rows(), 0};
17+
require_all_eigen_vt<is_fvar, EigMat1, EigMat2>* = nullptr>
18+
inline auto
19+
mdivide_right(const EigMat1& b, const EigMat2& A) {
20+
std::cout << "\nUsing 1: " << "\n";
21+
using A_fvar_inner_type = typename value_type_t<EigMat2>::Scalar;
22+
using b_fvar_inner_type = typename value_type_t<EigMat1>::Scalar;
23+
using inner_ret_t = return_type_t<A_fvar_inner_type, b_fvar_inner_type>;
24+
constexpr auto R1 = EigMat1::RowsAtCompileTime;
25+
constexpr auto C1 = EigMat1::ColsAtCompileTime;
26+
constexpr auto R2 = EigMat2::RowsAtCompileTime;
27+
constexpr auto C2 = EigMat2::ColsAtCompileTime;
28+
29+
check_square("mdivide_right", "A", A);
30+
check_multiplicable("mdivide_right", "b", b, "A", A);
31+
if (A.size() == 0) {
32+
using ret_t = decltype(mdivide_right(b.val(), A.val()).eval());
33+
return promote_scalar_t<fvar<inner_ret_t>, ret_t>{b.rows(), 0};
3234
}
3335
34-
Eigen::Matrix<T, R1, C1> val_A(A.rows(), A.cols());
35-
Eigen::Matrix<T, R1, C1> deriv_A(A.rows(), A.cols());
36-
Eigen::Matrix<T, R2, C2> val_b(b.rows(), b.cols());
37-
Eigen::Matrix<T, R2, C2> deriv_b(b.rows(), b.cols());
36+
Eigen::Matrix<A_fvar_inner_type, R2, C2> val_A(A.rows(), A.cols());
37+
Eigen::Matrix<A_fvar_inner_type, R2, C2> deriv_A(A.rows(), A.cols());
3838
39-
const Eigen::Ref<const plain_type_t<EigMat1>>& A_ref = A;
39+
const auto& A_ref = to_ref(A);
4040
for (int j = 0; j < A.cols(); j++) {
4141
for (int i = 0; i < A.rows(); i++) {
4242
val_A.coeffRef(i, j) = A_ref.coeff(i, j).val_;
4343
deriv_A.coeffRef(i, j) = A_ref.coeff(i, j).d_;
4444
}
4545
}
4646
47-
const Eigen::Ref<const plain_type_t<EigMat2>>& b_ref = b;
48-
for (int j = 0; j < b.cols(); j++) {
49-
for (int i = 0; i < b.rows(); i++) {
47+
Eigen::Matrix<b_fvar_inner_type, R1, C1> val_b(b.rows(), b.cols());
48+
Eigen::Matrix<b_fvar_inner_type, R1, C1> deriv_b(b.rows(), b.cols());
49+
const auto& b_ref = to_ref(b);
50+
for (Eigen::Index j = 0; j < b.cols(); j++) {
51+
for (Eigen::Index i = 0; i < b.rows(); i++) {
5052
val_b.coeffRef(i, j) = b_ref.coeff(i, j).val_;
5153
deriv_b.coeffRef(i, j) = b_ref.coeff(i, j).d_;
5254
}
5355
}
54-
55-
Eigen::Matrix<T, R1, C2> A_mult_inv_b = mdivide_right(val_A, val_b);
56-
57-
return to_fvar(A_mult_inv_b,
58-
mdivide_right(deriv_A, val_b)
59-
- A_mult_inv_b * mdivide_right(deriv_b, val_b));
56+
auto A_mult_inv_b = mdivide_right(val_b, val_A).eval();
57+
promote_scalar_t<fvar<inner_ret_t>, decltype(A_mult_inv_b)>
58+
ret(A_mult_inv_b.rows(), A_mult_inv_b.cols()); ret.val() = A_mult_inv_b; ret.d()
59+
= mdivide_right(deriv_b, val_A)
60+
- multiply(A_mult_inv_b, mdivide_right(deriv_A, val_A));
61+
return ret;
6062
}
6163
6264
template <typename EigMat1, typename EigMat2,
6365
require_eigen_vt<is_fvar, EigMat1>* = nullptr,
64-
require_eigen_vt<std::is_arithmetic, EigMat2>* = nullptr>
65-
inline Eigen::Matrix<value_type_t<EigMat1>, EigMat1::RowsAtCompileTime,
66-
EigMat2::ColsAtCompileTime>
67-
mdivide_right(const EigMat1& A, const EigMat2& b) {
68-
using T = typename value_type_t<EigMat1>::Scalar;
69-
constexpr int R1 = EigMat1::RowsAtCompileTime;
70-
constexpr int C1 = EigMat1::ColsAtCompileTime;
71-
72-
check_square("mdivide_right", "b", b);
73-
check_multiplicable("mdivide_right", "A", A, "b", b);
74-
if (b.size() == 0) {
75-
return {A.rows(), 0};
76-
}
77-
78-
Eigen::Matrix<T, R1, C1> val_A(A.rows(), A.cols());
79-
Eigen::Matrix<T, R1, C1> deriv_A(A.rows(), A.cols());
80-
81-
const Eigen::Ref<const plain_type_t<EigMat1>>& A_ref = A;
82-
for (int j = 0; j < A.cols(); j++) {
83-
for (int i = 0; i < A.rows(); i++) {
84-
val_A.coeffRef(i, j) = A_ref.coeff(i, j).val_;
85-
deriv_A.coeffRef(i, j) = A_ref.coeff(i, j).d_;
86-
}
87-
}
88-
89-
return to_fvar(mdivide_right(val_A, b), mdivide_right(deriv_A, b));
90-
}
66+
require_eigen_vt<is_var_or_arithmetic, EigMat2>* = nullptr>
67+
inline auto
68+
mdivide_right(const EigMat1& b, const EigMat2& A) {
69+
using T_return = return_type_t<EigMat1, EigMat2>;
70+
check_square("mdivide_right", "A", A);
71+
check_multiplicable("mdivide_right", "b", b, "A", A);
72+
if (A.size() == 0) {
73+
using ret_type = decltype(A.transpose().template
74+
cast<T_return>().lu().solve(b.template
75+
cast<T_return>().transpose()).transpose().eval()); return ret_type{b.rows(), 0};
76+
}
77+
return A.transpose().template cast<T_return>().lu().solve(b.template
78+
cast<T_return>().transpose()).transpose().eval();
79+
}
9180
9281
template <typename EigMat1, typename EigMat2,
93-
require_eigen_vt<std::is_arithmetic, EigMat1>* = nullptr,
82+
require_eigen_vt<is_var_or_arithmetic, EigMat1>* = nullptr,
9483
require_eigen_vt<is_fvar, EigMat2>* = nullptr>
95-
inline Eigen::Matrix<value_type_t<EigMat2>, EigMat1::RowsAtCompileTime,
96-
EigMat2::ColsAtCompileTime>
97-
mdivide_right(const EigMat1& A, const EigMat2& b) {
98-
using T = typename value_type_t<EigMat2>::Scalar;
99-
constexpr int R1 = EigMat1::RowsAtCompileTime;
100-
constexpr int C1 = EigMat1::ColsAtCompileTime;
101-
constexpr int R2 = EigMat2::RowsAtCompileTime;
102-
constexpr int C2 = EigMat2::ColsAtCompileTime;
103-
104-
check_square("mdivide_right", "b", b);
105-
check_multiplicable("mdivide_right", "A", A, "b", b);
106-
if (b.size() == 0) {
107-
return {A.rows(), 0};
108-
}
109-
110-
Eigen::Matrix<T, R2, C2> val_b(b.rows(), b.cols());
111-
Eigen::Matrix<T, R2, C2> deriv_b(b.rows(), b.cols());
112-
113-
const Eigen::Ref<const plain_type_t<EigMat2>>& b_ref = b;
114-
for (int j = 0; j < b.cols(); j++) {
115-
for (int i = 0; i < b.rows(); i++) {
116-
val_b.coeffRef(i, j) = b_ref.coeff(i, j).val_;
117-
deriv_b.coeffRef(i, j) = b_ref.coeff(i, j).d_;
118-
}
119-
}
120-
121-
Eigen::Matrix<T, R1, C2> A_mult_inv_b = mdivide_right(A, val_b);
122-
123-
return to_fvar(A_mult_inv_b, -A_mult_inv_b * mdivide_right(deriv_b, val_b));
124-
}
125-
84+
inline auto
85+
mdivide_right(const EigMat1& b, const EigMat2& A) {
86+
using T_return = return_type_t<EigMat1, EigMat2>;
87+
check_square("mdivide_right", "A", A);
88+
check_multiplicable("mdivide_right", "b", b, "A", A);
89+
if (A.size() == 0) {
90+
using ret_type = decltype(A.transpose().template
91+
cast<T_return>().lu().solve(b.template
92+
cast<T_return>().transpose()).transpose().eval()); return ret_type{b.rows(), 0};
93+
}
94+
return A.transpose().template cast<T_return>().lu().solve(b.template
95+
cast<T_return>().transpose()).transpose().eval();
96+
}
97+
*/
12698
} // namespace math
12799
} // namespace stan
128100
#endif

stan/math/fwd/fun/sum.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ inline value_type_t<T> sum(const T& m) {
4545
if (m.size() == 0) {
4646
return 0.0;
4747
}
48-
const Eigen::Ref<const plain_type_t<T>>& m_ref = m;
48+
const auto& m_ref = to_ref(m);
4949
return {sum(m_ref.val()), sum(m_ref.d())};
5050
}
5151

stan/math/mix.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
#include <stan/math/mix/fun.hpp>
66
#include <stan/math/mix/functor.hpp>
77

8+
9+
#include <stan/math/rev/core.hpp>
10+
#include <stan/math/rev/meta.hpp>
11+
#include <stan/math/rev/fun.hpp>
12+
#include <stan/math/rev/functor.hpp>
13+
814
#ifdef STAN_OPENCL
915
#include <stan/math/opencl/rev.hpp>
1016
#endif
@@ -14,11 +20,6 @@
1420
#include <stan/math/fwd/fun.hpp>
1521
#include <stan/math/fwd/functor.hpp>
1622

17-
#include <stan/math/rev/core.hpp>
18-
#include <stan/math/rev/meta.hpp>
19-
#include <stan/math/rev/fun.hpp>
20-
#include <stan/math/rev/functor.hpp>
21-
2223
#include <stan/math/prim.hpp>
2324

2425
#endif

stan/math/mix/functor/derivative.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#ifndef STAN_MATH_MIX_FUNCTOR_DERIVATIVE_HPP
22
#define STAN_MATH_MIX_FUNCTOR_DERIVATIVE_HPP
33

4-
#include <stan/math/fwd/core.hpp>
54
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/fwd/core.hpp>
66
#include <stan/math/rev/core.hpp>
77
#include <vector>
88

stan/math/mix/meta.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
#include <stan/math/fwd/core.hpp>
77
#include <stan/math/fwd/meta/is_fvar.hpp>
88
#include <stan/math/fwd/meta/partials_type.hpp>
9+
#include <stan/math/fwd/meta.hpp>
910

1011
#include <stan/math/rev/core.hpp>
1112
#include <stan/math/rev/meta/is_var.hpp>
1213
#include <stan/math/rev/meta/partials_type.hpp>
1314

14-
#include <stan/math/fwd/meta.hpp>
1515
#include <stan/math/rev/meta.hpp>
1616
#include <stan/math/prim/meta.hpp>
1717

stan/math/prim/fun/eigenvalues.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,11 @@
77
namespace stan {
88
namespace math {
99

10-
template <typename T>
11-
Eigen::Matrix<std::complex<T>, -1, 1> eigenvalues(
12-
const Eigen::Matrix<T, -1, -1>& m) {
10+
template <typename Mat, require_eigen_t<Mat>* = nullptr>
11+
inline auto eigenvalues(const Mat& m) {
1312
check_nonzero_size("eigenvalues", "m", m);
1413
check_square("eigenvalues", "m", m);
15-
16-
Eigen::EigenSolver<Eigen::Matrix<T, -1, -1>> solver(m);
17-
return solver.eigenvalues();
14+
return m.eigenvalues().eval();
1815
}
1916

2017
} // namespace math

stan/math/prim/fun/mdivide_left_tri.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,24 @@ namespace math {
2626
template <Eigen::UpLoType TriView, typename T1, typename T2,
2727
require_all_eigen_t<T1, T2> * = nullptr,
2828
require_all_not_eigen_vt<is_var, T1, T2> * = nullptr>
29-
inline Eigen::Matrix<return_type_t<T1, T2>, T1::RowsAtCompileTime,
30-
T2::ColsAtCompileTime>
31-
mdivide_left_tri(const T1 &A, const T2 &b) {
29+
inline auto mdivide_left_tri(const T1 &A, const T2 &b) {
3230
using T_return = return_type_t<T1, T2>;
3331
check_square("mdivide_left_tri", "A", A);
3432
check_multiplicable("mdivide_left_tri", "A", A, "b", b);
33+
using ret_type = decltype(A.template cast<T_return>()
34+
.eval()
35+
.template triangularView<TriView>()
36+
.solve(b.template cast<T_return>().eval())
37+
.eval());
3538
if (A.rows() == 0) {
36-
return {0, b.cols()};
39+
return ret_type(0, b.cols());
3740
}
3841

3942
return A.template cast<T_return>()
4043
.eval()
4144
.template triangularView<TriView>()
42-
.solve(b.template cast<T_return>().eval());
45+
.solve(b.template cast<T_return>().eval())
46+
.eval();
4347
}
4448

4549
/**

0 commit comments

Comments
 (0)