Skip to content

Commit d2de83d

Browse files
committed
Merge remote-tracking branch 'refs/remotes/origin/fix/wolfe-zoom1' into fix/wolfe-zoom1
2 parents d6859f1 + 200fd9a commit d2de83d

File tree

2 files changed

+65
-59
lines changed

2 files changed

+65
-59
lines changed

stan/math/mix/functor/laplace_marginal_density_estimator.hpp

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,8 @@ inline void validate_laplace_options(const char* frame_name,
420420
std::stringstream msg;
421421
msg << frame_name << ": The size of the initial theta ("
422422
<< options.theta_0.size()
423-
<< ") vector must match the rows and columns of the covariance matrix ("
423+
<< ") vector must match the rows and columns of the covariance "
424+
"matrix ("
424425
<< covariance.rows() << ", " << covariance.cols() << ").";
425426
throw std::domain_error(msg.str());
426427
}
@@ -524,7 +525,6 @@ struct NewtonState {
524525
*/
525526
const auto& prev() const& { return wolfe_info.prev_; }
526527
auto&& prev() && { return std::move(wolfe_info).prev(); }
527-
528528
};
529529

530530
/**
@@ -1220,8 +1220,8 @@ inline auto laplace_marginal_density_est(
12201220
eval_in.obj() = obj_fun(proposal.a(), proposal.theta());
12211221
eval_in.dir() = grad_fun(proposal).dot(p);
12221222
};
1223-
auto update_try = [&update_step](auto&& proposal, auto&& curr_ref, auto&& prev_ref, auto& eval,
1224-
auto&& p) {
1223+
auto update_try = [&update_step](auto&& proposal, auto&& curr_ref,
1224+
auto&& prev_ref, auto& eval, auto&& p) {
12251225
try {
12261226
update_step(proposal, curr_ref, prev_ref, eval, p);
12271227
return std::isfinite(eval.obj()) && std::isfinite(eval.dir());
@@ -1234,43 +1234,44 @@ inline auto laplace_marginal_density_est(
12341234
return eval.alpha() > options.line_search.min_alpha;
12351235
};
12361236
auto is_valid = [](const auto& /* eval */, bool ok) { return ok; };
1237-
auto update_line_search
1238-
= [&grad_fun, &update_step, &options, &msgs, &state, &backoff, &update_try, &is_valid](
1239-
auto&& wolfe_status, auto&& wolfe_info, auto&& curr, auto&& prev) {
1240-
wolfe_info.p_ = curr.a() - prev.a();
1241-
state.prev_g.noalias() = grad_fun(prev);
1242-
wolfe_info.init_dir_ = state.prev_g.dot(wolfe_info.p_);
1243-
// Flip direction if not ascending
1244-
if (wolfe_info.init_dir_ < 0) {
1245-
wolfe_info.p_ = -wolfe_info.p_;
1246-
wolfe_info.init_dir_ = -wolfe_info.init_dir_;
1247-
}
1248-
auto&& scratch = wolfe_info.scratch_;
1249-
scratch.alpha() = 1.0;
1250-
internal::retry_evaluate(update_try, scratch, curr, prev, scratch.eval_,
1251-
wolfe_info.p_, backoff, is_valid);
1252-
if (scratch.alpha() <= options.line_search.min_alpha) {
1253-
wolfe_status.accept_ = false;
1254-
return true;
1255-
}
1256-
if (options.line_search.max_iterations == 0) {
1257-
if (scratch.alpha() > options.line_search.min_alpha) {
1258-
curr.update(scratch);
1259-
wolfe_status.accept_ = true;
1260-
return false;
1261-
}
1262-
} else {
1263-
Eigen::VectorXd s = scratch.a() - prev.a();
1264-
curr.alpha() = barzilai_borwein_step_size(
1265-
s, grad_fun(scratch), state.prev_g, prev.alpha(),
1266-
wolfe_status.num_backtracks_, options.line_search.min_alpha,
1267-
options.line_search.max_alpha);
1268-
wolfe_status = internal::wolfe_line_search(
1269-
wolfe_info, update_step, options.line_search, msgs);
1270-
}
1271-
return std::abs(curr.obj() - prev.obj()) < options.tolerance
1272-
|| (!wolfe_status.accept_ && curr.obj() <= prev.obj());
1273-
};
1237+
auto update_line_search = [&grad_fun, &update_step, &options, &msgs, &state,
1238+
&backoff, &update_try,
1239+
&is_valid](auto&& wolfe_status, auto&& wolfe_info,
1240+
auto&& curr, auto&& prev) {
1241+
wolfe_info.p_ = curr.a() - prev.a();
1242+
state.prev_g.noalias() = grad_fun(prev);
1243+
wolfe_info.init_dir_ = state.prev_g.dot(wolfe_info.p_);
1244+
// Flip direction if not ascending
1245+
if (wolfe_info.init_dir_ < 0) {
1246+
wolfe_info.p_ = -wolfe_info.p_;
1247+
wolfe_info.init_dir_ = -wolfe_info.init_dir_;
1248+
}
1249+
auto&& scratch = wolfe_info.scratch_;
1250+
scratch.alpha() = 1.0;
1251+
internal::retry_evaluate(update_try, scratch, curr, prev, scratch.eval_,
1252+
wolfe_info.p_, backoff, is_valid);
1253+
if (scratch.alpha() <= options.line_search.min_alpha) {
1254+
wolfe_status.accept_ = false;
1255+
return true;
1256+
}
1257+
if (options.line_search.max_iterations == 0) {
1258+
if (scratch.alpha() > options.line_search.min_alpha) {
1259+
curr.update(scratch);
1260+
wolfe_status.accept_ = true;
1261+
return false;
1262+
}
1263+
} else {
1264+
Eigen::VectorXd s = scratch.a() - prev.a();
1265+
curr.alpha() = barzilai_borwein_step_size(
1266+
s, grad_fun(scratch), state.prev_g, prev.alpha(),
1267+
wolfe_status.num_backtracks_, options.line_search.min_alpha,
1268+
options.line_search.max_alpha);
1269+
wolfe_status = internal::wolfe_line_search(wolfe_info, update_step,
1270+
options.line_search, msgs);
1271+
}
1272+
return std::abs(curr.obj() - prev.obj()) < options.tolerance
1273+
|| (!wolfe_status.accept_ && curr.obj() <= prev.obj());
1274+
};
12741275
auto set_next_iter = [&options](auto&& curr, auto&& prev) {
12751276
prev.update(curr);
12761277
curr.alpha() = std::clamp(curr.alpha(), 0.0, options.line_search.max_alpha);

stan/math/mix/functor/wolfe_line_search.hpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -542,14 +542,15 @@ struct WolfeInfo {
542542
/**
543543
* Retry evaluation of a step until it passes a validity check.
544544
*
545-
* The update callable is invoked with `(curr, prev, eval, p)` and is expected to
546-
* fill `eval` (at `eval.alpha()`) with the objective and directional
545+
* The update callable is invoked with `(proposal, curr, prev, eval, p)` and
546+
* should fill `eval` (at `eval.alpha()`) with the objective and directional
547547
* derivative. If the evaluation is not valid, the backoff callable should
548548
* shrink `eval.alpha()` and return whether another retry should be attempted.
549549
* The validity check can inspect the evaluation and, for non-void updates, the
550550
* returned status.
551551
*
552-
* @tparam Update Callable that performs one evaluation step. Must accept 4 arguments.
552+
* @tparam Update Callable that performs one evaluation step. Must accept 5
553+
* arguments.
553554
* @tparam Curr Current state type passed to `update`.
554555
* @tparam Prev Previous state type passed to `update`.
555556
* @tparam Eval Evaluation record containing alpha/obj/dir.
@@ -568,12 +569,13 @@ struct WolfeInfo {
568569
* @return For void updates, returns void. Otherwise returns the value from the
569570
* first valid evaluation.
570571
*/
571-
template <typename Update, typename Proposal, typename Curr, typename Prev, typename Eval,
572-
typename P, typename Backoff, typename IsValid>
573-
inline auto retry_evaluate(Update&& update, Proposal&& proposal, Curr&& curr, Prev&& prev, Eval& eval,
574-
P&& p, Backoff&& backoff, IsValid&& is_valid) {
575-
if constexpr (std::is_void_v<
576-
std::invoke_result_t<Update&, Proposal, Curr, Prev, Eval&, P>>) {
572+
template <typename Update, typename Proposal, typename Curr, typename Prev,
573+
typename Eval, typename P, typename Backoff, typename IsValid>
574+
inline auto retry_evaluate(Update&& update, Proposal&& proposal, Curr&& curr,
575+
Prev&& prev, Eval& eval, P&& p, Backoff&& backoff,
576+
IsValid&& is_valid) {
577+
if constexpr (std::is_void_v<std::invoke_result_t<Update&, Proposal, Curr,
578+
Prev, Eval&, P>>) {
577579
while (true) {
578580
update(proposal, curr, prev, eval, p);
579581
if (is_valid(eval)) {
@@ -596,7 +598,6 @@ inline auto retry_evaluate(Update&& update, Proposal&& proposal, Curr&& curr, Pr
596598
}
597599
}
598600

599-
600601
/**
601602
* @brief Strong-Wolfe line search with expansion, bracketing, and
602603
* cubic/bisection zoom.
@@ -885,8 +886,10 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
885886
&& state.theta().allFinite() && state.theta_grad().allFinite();
886887
};
887888
Eval best = low; // keep the best Armijo-OK in case strong-Wolfe fails
888-
auto update_with_tick
889-
= [&total_updates, &opt, &best, &update_fun, &assign_step, &wolfe_ok, &armijo_ok](auto&& proposal, auto&& curr, auto&& prev, Eval& e, auto&& p) {
889+
auto update_with_tick = [&total_updates, &opt, &best, &update_fun,
890+
&assign_step, &wolfe_ok,
891+
&armijo_ok](auto&& proposal, auto&& curr,
892+
auto&& prev, Eval& e, auto&& p) {
890893
const bool over_budget = total_updates > opt.max_iterations;
891894
if (over_budget) {
892895
// Soft budget: stop evaluating new trial points once exceeded.
@@ -918,10 +921,11 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
918921
return eval.alpha() >= opt.min_alpha;
919922
};
920923
auto is_valid = [&](const Eval& eval, const WolfeStatus& status) {
921-
return status.stop_ != WolfeReturn::Continue || eval_finite(eval, scratch);
924+
return status.stop_ != WolfeReturn::Continue
925+
|| eval_finite(eval, scratch);
922926
};
923-
wolfe_check = retry_evaluate(update_with_tick, scratch, curr, prev, high, p, backoff,
924-
is_valid);
927+
wolfe_check = retry_evaluate(update_with_tick, scratch, curr, prev, high, p,
928+
backoff, is_valid);
925929
if (wolfe_check.stop_ != WolfeReturn::Continue) {
926930
return wolfe_check;
927931
}
@@ -1081,10 +1085,11 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10811085
return eval.alpha() > opt.min_alpha;
10821086
};
10831087
auto is_valid = [&](const Eval& eval, const WolfeStatus& status) {
1084-
return status.stop_ != WolfeReturn::Continue || eval_finite(eval, scratch);
1088+
return status.stop_ != WolfeReturn::Continue
1089+
|| eval_finite(eval, scratch);
10851090
};
1086-
auto wolfe_check = retry_evaluate(update_with_tick, scratch, curr, prev, mid, p,
1087-
backoff, is_valid);
1091+
auto wolfe_check = retry_evaluate(update_with_tick, scratch, curr, prev,
1092+
mid, p, backoff, is_valid);
10881093
if (wolfe_check.stop_ != WolfeReturn::Continue) {
10891094
return wolfe_check;
10901095
}

0 commit comments

Comments
 (0)