Add Wolfe line search to Laplace approximation #3250
base: develop
Changes from all commits
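This PR adds a Wolfe line search to the Laplace approximation, seeded by the Barzilai–Borwein step-size helper in the diff below. For reference, a weak Wolfe line search accepts a step length \(\alpha\) along a descent direction \(p_k\) when both standard conditions below hold; the line-search implementation itself is not part of this excerpt, and the constants shown are common defaults rather than values confirmed by the PR:

```latex
% Weak Wolfe conditions with 0 < c_1 < c_2 < 1
% (c_1 = 1e-4, c_2 = 0.9 are common defaults, assumed here for illustration):
\begin{align*}
  f(x_k + \alpha p_k) &\le f(x_k) + c_1\,\alpha\,\nabla f(x_k)^\top p_k
    && \text{(sufficient decrease / Armijo)} \\
  \nabla f(x_k + \alpha p_k)^\top p_k &\ge c_2\,\nabla f(x_k)^\top p_k
    && \text{(curvature)}
\end{align*}
```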
New file (@@ -0,0 +1,124 @@): stan/math/mix/functor/barzilai_borwein_step_size.hpp

```cpp
#ifndef STAN_MATH_MIX_FUNCTOR_BARZILAI_BORWEIN_STEP_SIZE_HPP
#define STAN_MATH_MIX_FUNCTOR_BARZILAI_BORWEIN_STEP_SIZE_HPP

#include <stan/math/prim/fun/Eigen.hpp>
#include <algorithm>
#include <numeric>
#include <cmath>

namespace stan::math::internal {

/**
 * @brief Curvature-aware Barzilai–Borwein (BB) step length with robust
 * safeguards.
 *
 * Given successive parameter displacements \f$s = x_k - x_{k-1}\f$ and
 * gradients \f$g_k\f$, \f$g_{k-1}\f$, this routine forms
 * \f$y = g_k - g_{k-1}\f$ and computes the two classical BB candidates
 *
 * \f{align*}{
 *   \alpha_{\text{BB1}} &= \frac{\langle s,s\rangle}{\langle s,y\rangle},\\
 *   \alpha_{\text{BB2}} &= \frac{\langle s,y\rangle}{\langle y,y\rangle},
 * \f}
 *
 * then chooses between them using the **spectral cosine**
 * \f$r = \cos^2\!\angle(s,y)
 *      = \dfrac{\langle s,y\rangle^2}{\langle s,s\rangle\,\langle y,y\rangle}
 *      \in [0,1]\f$:
 *
 * - if \f$r > 0.9\f$ (well-aligned curvature) and the previous line search
 *   did **≤ 1** backtrack, prefer the "long" step \f$\alpha_{\text{BB1}}\f$;
 * - if \f$0.1 \le r \le 0.9\f$, take the neutral geometric mean
 *   \f$\sqrt{\alpha_{\text{BB1}}\alpha_{\text{BB2}}}\f$;
 * - otherwise default to the "short" step \f$\alpha_{\text{BB2}}\f$.
 *
 * All candidates are clamped into \f$[\text{min\_alpha},\,\text{max\_alpha}]\f$
 * and must be finite and positive.
 * If the curvature scalars are ill-posed (non-finite or too small),
 * \f$\langle s,y\rangle \le \varepsilon\f$, or if `last_backtracks == 99`
 * (explicitly disabling BB for this iteration), the function falls back to a
 * **safe** step: use `prev_step` when finite and positive, otherwise
 * \f$1.0\f$, then clamp to \f$[\text{min\_alpha},\,\text{max\_alpha}]\f$.
 *
 * @param s Displacement between consecutive iterates
 *          (\f$s = x_k - x_{k-1}\f$).
 * @param g_curr Gradient at the current iterate \f$g_k\f$.
 * @param g_prev Gradient at the previous iterate \f$g_{k-1}\f$.
 * @param prev_step Previously accepted step length (used by the fallback).
 * @param last_backtracks Number of backtracking contractions performed by the
 *        most recent line search; set to 99 to force the safe fallback.
 * @param min_alpha Lower bound for the returned step length.
 * @param max_alpha Upper bound for the returned step length.
 *
 * @return A finite, positive BB-style step length \f$\alpha \in
 *         [\text{min\_alpha},\,\text{max\_alpha}]\f$ suitable for seeding a
 *         line search or as a spectral preconditioner scale.
 *
 * @note Uses \f$\varepsilon = 10^{-16}\f$ to guard against division by very
 *       small curvature terms, and applies `std::abs` to BB ratios to avoid
 *       negative steps; descent is enforced by the line search.
 * @warning The vectors must have identical size. Non-finite inputs yield the
 *          safe fallback.
 */
inline double barzilai_borwein_step_size(const Eigen::VectorXd& s,
                                         const Eigen::VectorXd& g_curr,
                                         const Eigen::VectorXd& g_prev,
                                         double prev_step, int last_backtracks,
                                         double min_alpha, double max_alpha) {
  // Fallback: reuse the previous step when usable, else 1.0, then clamp.
  auto safe_fallback = [&]() -> double {
    double a = std::clamp(
        prev_step > 0.0 && std::isfinite(prev_step) ? prev_step : 1.0,
        min_alpha, max_alpha);
    return a;
  };

  const Eigen::VectorXd y = g_curr - g_prev;
  const double sty = s.dot(y);
  const double sts = s.squaredNorm();
  const double yty = y.squaredNorm();

  // Basic validity checks
  constexpr double eps = 1e-16;
  if (!(std::isfinite(sty) && std::isfinite(sts) && std::isfinite(yty))
      || sts <= eps || yty <= eps || sty <= eps || last_backtracks == 99) {
    return safe_fallback();
  }

  // BB candidates
  double alpha_bb1 = std::clamp(std::abs(sts / sty), min_alpha, max_alpha);
  double alpha_bb2 = std::clamp(std::abs(sty / yty), min_alpha, max_alpha);

  // Safeguard candidates
  if (!std::isfinite(alpha_bb1) || !std::isfinite(alpha_bb2) || alpha_bb1 <= 0.0
      || alpha_bb2 <= 0.0) {
    return safe_fallback();
  }

  // Spectral cosine r = cos^2(angle(s, y)) in [0, 1]
  const double r = (sty * sty) / (sts * yty);

  // Heuristic thresholds (robust defaults)
  constexpr double kLoose = 0.9;  // "nice" curvature
  constexpr double kTight = 0.1;  // "dodgy" curvature

  double alpha0 = alpha_bb2;  // default to short BB for robustness
  if (r > kLoose && last_backtracks <= 1) {
    // Spectrum looks friendly and line search was not harsh -> try long BB
    alpha0 = alpha_bb1;
  } else if (r >= kTight && r <= kLoose) {
    // Neither clearly friendly nor clearly dodgy -> neutral middle
    alpha0 = std::sqrt(alpha_bb1 * alpha_bb2);
  }  // else keep alpha_bb2

  // Clip to user bounds
  alpha0 = std::clamp(alpha0, min_alpha, max_alpha);

  if (!std::isfinite(alpha0) || alpha0 <= 0.0) {
    return safe_fallback();
  }
  return alpha0;
}

}  // namespace stan::math::internal
#endif
```
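A minimal sketch of how the seed might be exercised, assuming the header path implied by the include guard; the toy quadratic and all numeric values are illustrative, not taken from the PR:

```cpp
// Toy quadratic f(x) = 0.5 * x' H x with gradient g(x) = H x, so that
// y = g_curr - g_prev = H s and both BB ratios are inverse Rayleigh
// quotients of H (they land in [1/10, 1/2] for eigenvalues {2, 10}).
#include <stan/math/mix/functor/barzilai_borwein_step_size.hpp>  // assumed path
#include <Eigen/Dense>
#include <iostream>

int main() {
  Eigen::MatrixXd H(2, 2);
  H << 2.0, 0.0,
       0.0, 10.0;
  Eigen::VectorXd x_prev(2), x_curr(2);
  x_prev << 1.0, 1.0;
  x_curr << 0.9, 0.5;

  const Eigen::VectorXd s = x_curr - x_prev;  // s = x_k - x_{k-1}
  const Eigen::VectorXd g_prev = H * x_prev;  // g_{k-1}
  const Eigen::VectorXd g_curr = H * x_curr;  // g_k

  // Here s'y = 2.52, s's = 0.26, y'y = 25.04, so the spectral cosine is
  // r = 2.52^2 / (0.26 * 25.04) ~= 0.975 > 0.9; with last_backtracks = 0
  // the long step BB1 = s's / s'y ~= 0.103 is returned.
  const double alpha = stan::math::internal::barzilai_borwein_step_size(
      s, g_curr, g_prev, /*prev_step=*/1.0, /*last_backtracks=*/0,
      /*min_alpha=*/1e-8, /*max_alpha=*/1e2);

  std::cout << "BB seed step: " << alpha << "\n";
  return 0;
}
```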
New file (@@ -0,0 +1,106 @@): stan/math/mix/functor/conditional_copy_and_promote.hpp

```cpp
#ifndef STAN_MATH_MIX_FUNCTOR_CONDITIONAL_COPY_AND_PROMOTE_HPP
#define STAN_MATH_MIX_FUNCTOR_CONDITIONAL_COPY_AND_PROMOTE_HPP

#include <stan/math/mix/functor/hessian_block_diag.hpp>
#include <stan/math/prim/functor.hpp>
#include <stan/math/prim/fun.hpp>

namespace stan::math::internal {

/**
 * Decide if an object should be deep or shallow copied when
 * using @ref conditional_copy_and_promote .
 */
enum class COPY_TYPE : uint8_t { SHALLOW = 0, DEEP = 1 };

/**
 * Conditionally copy and promote a type's scalar type to a `PromotedType`.
 * @tparam Filter type trait with a static constexpr bool member `value` that
 *   is true if the type should be promoted. Otherwise, the type is left
 *   unchanged.
 * @tparam PromotedType type to promote the scalar to.
 * @tparam CopyType type of copy to perform.
 * @tparam Args variadic argument types.
 * @param args variadic arguments to conditionally copy and promote.
 * @return a tuple where each element is either a reference to the original
 *   argument or a promoted copy of the argument.
 */
template <template <typename...> class Filter,
          typename PromotedType = stan::math::var,
          COPY_TYPE CopyType = COPY_TYPE::DEEP, typename... Args>
inline auto conditional_copy_and_promote(Args&&... args) {
  return map_if<Filter>(
      [](auto&& arg) {
        if constexpr (is_tuple_v<decltype(arg)>) {
          // Recurse into tuples, rebuilding them element by element.
          return stan::math::apply(
              [](auto&&... inner_args) {
                return make_holder_tuple(
                    conditional_copy_and_promote<Filter, PromotedType,
                                                 CopyType>(
                        std::forward<decltype(inner_args)>(inner_args))...);
              },
              std::forward<decltype(arg)>(arg));
        } else if constexpr (is_std_vector_v<decltype(arg)>) {
          // Recurse into std::vector elements.
          std::vector<decltype(conditional_copy_and_promote<
                               Filter, PromotedType, CopyType>(arg[0]))>
              ret;
          for (std::size_t i = 0; i < arg.size(); ++i) {
            ret.push_back(
                conditional_copy_and_promote<Filter, PromotedType, CopyType>(
                    arg[i]));
          }
          return ret;
        } else {
          if constexpr (CopyType == COPY_TYPE::DEEP) {
            return stan::math::eval(promote_scalar<PromotedType>(
                value_of_rec(std::forward<decltype(arg)>(arg))));
          } else if constexpr (CopyType == COPY_TYPE::SHALLOW) {
            if constexpr (std::is_same_v<PromotedType,
                                         scalar_type_t<decltype(arg)>>) {
              return std::forward<decltype(arg)>(arg);
            } else {
              return stan::math::eval(promote_scalar<PromotedType>(
                  std::forward<decltype(arg)>(arg)));
            }
          }
        }
      },
      std::forward<Args>(args)...);
}
```
Member: q: should the
Author: It's only used in other internal functions, so I think it is better to have it in internal.
stan/math/mix/functor/conditional_copy_and_promote.hpp, continued:

```cpp
/**
 * Conditionally deep copy types with a `var` scalar type to `PromotedType`.
 * @tparam PromotedType type to promote the scalar to.
 * @tparam Args variadic argument types.
 * @param args variadic arguments to conditionally copy and promote.
 * @return a tuple where each element is either a reference to the original
 *   argument or a promoted copy of the argument.
 */
template <typename PromotedType, typename... Args>
inline auto deep_copy_vargs(Args&&... args) {
  return conditional_copy_and_promote<is_any_var_scalar, PromotedType,
                                      COPY_TYPE::DEEP>(
      std::forward<Args>(args)...);
}

/**
 * Conditionally shallow copy types with a `var` scalar type to `PromotedType`.
 * @note This function is useful whenever you are inside of nested autodiff
 *   and want to allow the input arguments from an outer autodiff to be used
 *   in an inner autodiff without making a hard copy of the input arguments.
 * @tparam PromotedType type to promote the scalar to.
 * @tparam Args variadic argument types.
 * @param args variadic arguments to conditionally copy and promote.
 * @return a tuple where each element is either a reference to the original
 *   argument or a promoted copy of the argument.
 */
template <typename PromotedType, typename... Args>
inline auto shallow_copy_vargs(Args&&... args) {
  return conditional_copy_and_promote<is_any_var_scalar, PromotedType,
                                      COPY_TYPE::SHALLOW>(
      std::forward<Args>(args)...);
}

}  // namespace stan::math::internal

#endif
```
Reviewer: @charlesm93 the C++ looks fine to me for this function, but I'd appreciate someone else putting eyes on the math.
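The helpers above manage argument copies across nested autodiff sweeps. A hypothetical call-pattern sketch, assuming the usual Stan Math headers; the values and the surrounding function are illustrative, not taken from the PR's tests:

```cpp
#include <stan/math/mix.hpp>

// Sketch: `theta` carries outer-sweep vars, `x` is plain data. Per the docs
// above, only arguments matched by is_any_var_scalar are touched; `x` comes
// back as a reference inside the returned tuple.
void nested_autodiff_sketch() {
  using stan::math::var;

  Eigen::Matrix<var, Eigen::Dynamic, 1> theta(2);
  theta << 1.0, 2.0;
  Eigen::VectorXd x(2);
  x << 0.5, 1.5;

  // Deep copy: theta's values are extracted (value_of_rec) and re-promoted
  // to fresh vars, detaching the copies from the outer tape.
  auto deep = stan::math::internal::deep_copy_vargs<var>(theta, x);

  // Shallow copy: theta's scalar type already is `var`, so it is forwarded
  // by reference instead of copied.
  auto shallow = stan::math::internal::shallow_copy_vargs<var>(theta, x);

  (void)deep;
  (void)shallow;
}
```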