Skip to content

Commit 8da2e1d

Browse files
committed
Thread should_continue through recursive solver
1 parent 73e3c7d commit 8da2e1d

File tree

4 files changed

+62
-26
lines changed

4 files changed

+62
-26
lines changed

chalk-recursive/src/fixed_point.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ where
4343
context: &mut RecursiveContext<K, V>,
4444
goal: &K,
4545
minimums: &mut Minimums,
46+
should_continue: impl std::ops::Fn() -> bool,
4647
) -> V;
4748
fn reached_fixed_point(self, old_value: &V, new_value: &V) -> bool;
4849
fn error_value(self) -> V;
@@ -104,22 +105,24 @@ where
104105
&mut self,
105106
canonical_goal: &K,
106107
solver_stuff: impl SolverStuff<K, V>,
108+
should_continue: impl std::ops::Fn() -> bool,
107109
) -> V {
108110
debug!("solve_root_goal(canonical_goal={:?})", canonical_goal);
109111
assert!(self.stack.is_empty());
110112
let minimums = &mut Minimums::new();
111-
self.solve_goal(canonical_goal, minimums, solver_stuff)
113+
self.solve_goal(canonical_goal, minimums, solver_stuff, should_continue)
112114
}
113115

114116
/// Attempt to solve a goal that has been fully broken down into leaf form
115117
/// and canonicalized. This is where the action really happens, and is the
116118
/// place where we would perform caching in rustc (and may eventually do in Chalk).
117-
#[instrument(level = "info", skip(self, minimums, solver_stuff,))]
119+
#[instrument(level = "info", skip(self, minimums, solver_stuff, should_continue))]
118120
pub fn solve_goal(
119121
&mut self,
120122
goal: &K,
121123
minimums: &mut Minimums,
122124
solver_stuff: impl SolverStuff<K, V>,
125+
should_continue: impl std::ops::Fn() -> bool,
123126
) -> V {
124127
// First check the cache.
125128
if let Some(cache) = &self.cache {
@@ -159,7 +162,8 @@ where
159162
let depth = self.stack.push(coinductive_goal);
160163
let dfn = self.search_graph.insert(goal, depth, initial_solution);
161164

162-
let subgoal_minimums = self.solve_new_subgoal(goal, depth, dfn, solver_stuff);
165+
let subgoal_minimums =
166+
self.solve_new_subgoal(goal, depth, dfn, solver_stuff, should_continue);
163167

164168
self.search_graph[dfn].links = subgoal_minimums;
165169
self.search_graph[dfn].stack_depth = None;
@@ -190,13 +194,14 @@ where
190194
}
191195
}
192196

193-
#[instrument(level = "debug", skip(self, solver_stuff))]
197+
#[instrument(level = "debug", skip(self, solver_stuff, should_continue))]
194198
fn solve_new_subgoal(
195199
&mut self,
196200
canonical_goal: &K,
197201
depth: StackDepth,
198202
dfn: DepthFirstNumber,
199203
solver_stuff: impl SolverStuff<K, V>,
204+
should_continue: impl std::ops::Fn() -> bool,
200205
) -> Minimums {
201206
// We start with `answer = None` and try to solve the goal. At the end of the iteration,
202207
// `answer` will be updated with the result of the solving process. If we detect a cycle
@@ -209,7 +214,8 @@ where
209214
// so this function will eventually be constant and the loop terminates.
210215
loop {
211216
let minimums = &mut Minimums::new();
212-
let current_answer = solver_stuff.solve_iteration(self, canonical_goal, minimums);
217+
let current_answer =
218+
solver_stuff.solve_iteration(self, canonical_goal, minimums, &should_continue);
213219

214220
debug!(
215221
"solve_new_subgoal: loop iteration result = {:?} with minimums {:?}",

chalk-recursive/src/fulfill.rs

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -342,24 +342,31 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
342342
Ok(())
343343
}
344344

345-
#[instrument(level = "debug", skip(self, minimums))]
345+
#[instrument(level = "debug", skip(self, minimums, should_continue))]
346346
fn prove(
347347
&mut self,
348348
wc: InEnvironment<Goal<I>>,
349349
minimums: &mut Minimums,
350+
should_continue: impl std::ops::Fn() -> bool,
350351
) -> Fallible<PositiveSolution<I>> {
351352
let interner = self.solver.interner();
352353
let (quantified, free_vars) = canonicalize(&mut self.infer, interner, wc);
353354
let (quantified, universes) = u_canonicalize(&mut self.infer, interner, &quantified);
354-
let result = self.solver.solve_goal(quantified, minimums);
355+
let result = self
356+
.solver
357+
.solve_goal(quantified, minimums, should_continue);
355358
Ok(PositiveSolution {
356359
free_vars,
357360
universes,
358361
solution: result?,
359362
})
360363
}
361364

362-
fn refute(&mut self, goal: InEnvironment<Goal<I>>) -> Fallible<NegativeSolution> {
365+
fn refute(
366+
&mut self,
367+
goal: InEnvironment<Goal<I>>,
368+
should_continue: impl std::ops::Fn() -> bool,
369+
) -> Fallible<NegativeSolution> {
363370
let canonicalized = match self
364371
.infer
365372
.invert_then_canonicalize(self.solver.interner(), goal)
@@ -376,7 +383,10 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
376383
let (quantified, _) =
377384
u_canonicalize(&mut self.infer, self.solver.interner(), &canonicalized);
378385
let mut minimums = Minimums::new(); // FIXME -- minimums here seems wrong
379-
if let Ok(solution) = self.solver.solve_goal(quantified, &mut minimums) {
386+
if let Ok(solution) = self
387+
.solver
388+
.solve_goal(quantified, &mut minimums, should_continue)
389+
{
380390
if solution.is_unique() {
381391
Err(NoSolution)
382392
} else {
@@ -431,7 +441,11 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
431441
}
432442
}
433443

434-
fn fulfill(&mut self, minimums: &mut Minimums) -> Fallible<Outcome> {
444+
fn fulfill(
445+
&mut self,
446+
minimums: &mut Minimums,
447+
should_continue: impl std::ops::Fn() -> bool,
448+
) -> Fallible<Outcome> {
435449
debug_span!("fulfill", obligations=?self.obligations);
436450

437451
// Try to solve all the obligations. We do this via a fixed-point
@@ -460,7 +474,7 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
460474
free_vars,
461475
universes,
462476
solution,
463-
} = self.prove(wc.clone(), minimums)?;
477+
} = self.prove(wc.clone(), minimums, &should_continue)?;
464478

465479
if let Some(constrained_subst) = solution.definite_subst(self.interner()) {
466480
// If the substitution is trivial, we won't actually make any progress by applying it!
@@ -484,7 +498,7 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
484498
solution.is_ambig()
485499
}
486500
Obligation::Refute(goal) => {
487-
let answer = self.refute(goal.clone())?;
501+
let answer = self.refute(goal.clone(), &should_continue)?;
488502
answer == NegativeSolution::Ambiguous
489503
}
490504
};
@@ -514,8 +528,12 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
514528
/// Try to fulfill all pending obligations and build the resulting
515529
/// solution. The returned solution will transform `subst` substitution with
516530
/// the outcome of type inference by updating the replacements it provides.
517-
pub(super) fn solve(mut self, minimums: &mut Minimums) -> Fallible<Solution<I>> {
518-
let outcome = match self.fulfill(minimums) {
531+
pub(super) fn solve(
532+
mut self,
533+
minimums: &mut Minimums,
534+
should_continue: impl std::ops::Fn() -> bool,
535+
) -> Fallible<Solution<I>> {
536+
let outcome = match self.fulfill(minimums, &should_continue) {
519537
Ok(o) => o,
520538
Err(e) => return Err(e),
521539
};
@@ -567,7 +585,7 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
567585
free_vars,
568586
universes,
569587
solution,
570-
} = self.prove(goal, minimums).unwrap();
588+
} = self.prove(goal, minimums, &should_continue).unwrap();
571589
if let Some(constrained_subst) =
572590
solution.constrained_subst(self.solver.interner())
573591
{

chalk-recursive/src/recursive.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,9 @@ impl<I: Interner> SolverStuff<UCanonicalGoal<I>, Fallible<Solution<I>>> for &dyn
7676
context: &mut RecursiveContext<UCanonicalGoal<I>, Fallible<Solution<I>>>,
7777
goal: &UCanonicalGoal<I>,
7878
minimums: &mut Minimums,
79+
should_continue: impl std::ops::Fn() -> bool,
7980
) -> Fallible<Solution<I>> {
80-
Solver::new(context, self).solve_iteration(goal, minimums)
81+
Solver::new(context, self).solve_iteration(goal, minimums, should_continue)
8182
}
8283

8384
fn reached_fixed_point(
@@ -108,8 +109,10 @@ impl<'me, I: Interner> SolveDatabase<I> for Solver<'me, I> {
108109
&mut self,
109110
goal: UCanonicalGoal<I>,
110111
minimums: &mut Minimums,
112+
should_continue: impl std::ops::Fn() -> bool,
111113
) -> Fallible<Solution<I>> {
112-
self.context.solve_goal(&goal, minimums, self.program)
114+
self.context
115+
.solve_goal(&goal, minimums, self.program, should_continue)
113116
}
114117

115118
fn interner(&self) -> I {
@@ -131,17 +134,19 @@ impl<I: Interner> chalk_solve::Solver<I> for RecursiveSolver<I> {
131134
program: &dyn RustIrDatabase<I>,
132135
goal: &UCanonical<InEnvironment<Goal<I>>>,
133136
) -> Option<chalk_solve::Solution<I>> {
134-
self.ctx.solve_root_goal(goal, program).ok()
137+
self.ctx.solve_root_goal(goal, program, || true).ok()
135138
}
136139

137140
fn solve_limited(
138141
&mut self,
139142
program: &dyn RustIrDatabase<I>,
140143
goal: &UCanonical<InEnvironment<Goal<I>>>,
141-
_should_continue: &dyn std::ops::Fn() -> bool,
144+
should_continue: &dyn std::ops::Fn() -> bool,
142145
) -> Option<chalk_solve::Solution<I>> {
143146
// TODO support should_continue in recursive solver
144-
self.ctx.solve_root_goal(goal, program).ok()
147+
self.ctx
148+
.solve_root_goal(goal, program, should_continue)
149+
.ok()
145150
}
146151

147152
fn solve_multiple(

chalk-recursive/src/solve.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub(super) trait SolveDatabase<I: Interner>: Sized {
2020
&mut self,
2121
goal: UCanonical<InEnvironment<Goal<I>>>,
2222
minimums: &mut Minimums,
23+
should_continue: impl std::ops::Fn() -> bool,
2324
) -> Fallible<Solution<I>>;
2425

2526
fn max_size(&self) -> usize;
@@ -35,11 +36,12 @@ pub(super) trait SolveIteration<I: Interner>: SolveDatabase<I> {
3536
/// Executes one iteration of the recursive solver, computing the current
3637
/// solution to the given canonical goal. This is used as part of a loop in
3738
/// the case of cyclic goals.
38-
#[instrument(level = "debug", skip(self))]
39+
#[instrument(level = "debug", skip(self, should_continue))]
3940
fn solve_iteration(
4041
&mut self,
4142
canonical_goal: &UCanonicalGoal<I>,
4243
minimums: &mut Minimums,
44+
should_continue: impl std::ops::Fn() -> bool,
4345
) -> Fallible<Solution<I>> {
4446
let UCanonical {
4547
universes,
@@ -72,7 +74,7 @@ pub(super) trait SolveIteration<I: Interner>: SolveDatabase<I> {
7274
let prog_solution = {
7375
debug_span!("prog_clauses");
7476

75-
self.solve_from_clauses(&canonical_goal, minimums)
77+
self.solve_from_clauses(&canonical_goal, minimums, should_continue)
7678
};
7779
debug!(?prog_solution);
7880

@@ -88,7 +90,7 @@ pub(super) trait SolveIteration<I: Interner>: SolveDatabase<I> {
8890
},
8991
};
9092

91-
self.solve_via_simplification(&canonical_goal, minimums)
93+
self.solve_via_simplification(&canonical_goal, minimums, should_continue)
9294
}
9395
}
9496
}
@@ -103,15 +105,16 @@ where
103105

104106
/// Helper methods for `solve_iteration`, private to this module.
105107
trait SolveIterationHelpers<I: Interner>: SolveDatabase<I> {
106-
#[instrument(level = "debug", skip(self, minimums))]
108+
#[instrument(level = "debug", skip(self, minimums, should_continue))]
107109
fn solve_via_simplification(
108110
&mut self,
109111
canonical_goal: &UCanonicalGoal<I>,
110112
minimums: &mut Minimums,
113+
should_continue: impl std::ops::Fn() -> bool,
111114
) -> Fallible<Solution<I>> {
112115
let (infer, subst, goal) = self.new_inference_table(canonical_goal);
113116
match Fulfill::new_with_simplification(self, infer, subst, goal) {
114-
Ok(fulfill) => fulfill.solve(minimums),
117+
Ok(fulfill) => fulfill.solve(minimums, should_continue),
115118
Err(e) => Err(e),
116119
}
117120
}
@@ -123,6 +126,7 @@ trait SolveIterationHelpers<I: Interner>: SolveDatabase<I> {
123126
&mut self,
124127
canonical_goal: &UCanonical<InEnvironment<DomainGoal<I>>>,
125128
minimums: &mut Minimums,
129+
should_continue: impl std::ops::Fn() -> bool,
126130
) -> Fallible<Solution<I>> {
127131
let mut clauses = vec![];
128132

@@ -159,7 +163,10 @@ trait SolveIterationHelpers<I: Interner>: SolveDatabase<I> {
159163
let subst = subst.clone();
160164
let goal = goal.clone();
161165
let res = match Fulfill::new_with_clause(self, infer, subst, goal, implication) {
162-
Ok(fulfill) => (fulfill.solve(minimums), implication.skip_binders().priority),
166+
Ok(fulfill) => (
167+
fulfill.solve(minimums, &should_continue),
168+
implication.skip_binders().priority,
169+
),
163170
Err(e) => (Err(e), ClausePriority::High),
164171
};
165172

0 commit comments

Comments
 (0)