Skip to content

Commit d736b5c

Browse files
committed
extract SolverStuff trait
1 parent 48d9ec5 commit d736b5c

File tree

1 file changed

+83
-62
lines changed

1 file changed

+83
-62
lines changed

chalk-recursive/src/recursive.rs

Lines changed: 83 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -83,27 +83,12 @@ where
8383
/// Attempt to solve a goal that has been fully broken down into leaf form
8484
/// and canonicalized. This is where the action really happens, and is the
8585
/// place where we would perform caching in rustc (and may eventually do in Chalk).
86-
#[instrument(
87-
level = "info",
88-
skip(
89-
self,
90-
minimums,
91-
is_coinductive_goal,
92-
initial_value,
93-
solve_iteration,
94-
reached_fixed_point,
95-
error_value
96-
)
97-
)]
86+
#[instrument(level = "info", skip(self, minimums, solver_stuff,))]
9887
fn solve_goal(
9988
&mut self,
10089
goal: &K,
10190
minimums: &mut Minimums,
102-
is_coinductive_goal: impl Fn(&K) -> bool,
103-
initial_value: impl Fn(&K, bool) -> V,
104-
solve_iteration: impl FnMut(&mut Self, &K, &mut Minimums) -> V,
105-
reached_fixed_point: impl Fn(&V, &V) -> bool,
106-
error_value: impl Fn() -> V,
91+
solver_stuff: impl SolverStuff<K, V>,
10792
) -> V {
10893
// First check the cache.
10994
if let Some(cache) = &self.cache {
@@ -122,7 +107,7 @@ where
122107
// see the corresponding section in the coinduction chapter:
123108
// https://rust-lang.github.io/chalk/book/recursive/coinduction.html#mixed-co-inductive-and-inductive-cycles
124109
if self.stack.mixed_inductive_coinductive_cycle_from(depth) {
125-
return error_value();
110+
return solver_stuff.error_value();
126111
}
127112
}
128113

@@ -138,13 +123,12 @@ where
138123
} else {
139124
// Otherwise, push the goal onto the stack and create a table.
140125
// The initial result for this table depends on whether the goal is coinductive.
141-
let coinductive_goal = is_coinductive_goal(goal);
142-
let initial_solution = initial_value(goal, coinductive_goal);
126+
let coinductive_goal = solver_stuff.is_coinductive_goal(goal);
127+
let initial_solution = solver_stuff.initial_value(goal, coinductive_goal);
143128
let depth = self.stack.push(coinductive_goal);
144129
let dfn = self.search_graph.insert(&goal, depth, initial_solution);
145130

146-
let subgoal_minimums =
147-
self.solve_new_subgoal(&goal, depth, dfn, solve_iteration, reached_fixed_point);
131+
let subgoal_minimums = self.solve_new_subgoal(&goal, depth, dfn, solver_stuff);
148132

149133
self.search_graph[dfn].links = subgoal_minimums;
150134
self.search_graph[dfn].stack_depth = None;
@@ -175,14 +159,13 @@ where
175159
}
176160
}
177161

178-
#[instrument(level = "debug", skip(self, solve_iteration, reached_fixed_point))]
162+
#[instrument(level = "debug", skip(self, solver_stuff))]
179163
fn solve_new_subgoal(
180164
&mut self,
181165
canonical_goal: &K,
182166
depth: StackDepth,
183167
dfn: DepthFirstNumber,
184-
mut solve_iteration: impl FnMut(&mut Self, &K, &mut Minimums) -> V,
185-
reached_fixed_point: impl Fn(&V, &V) -> bool,
168+
solver_stuff: impl SolverStuff<K, V>,
186169
) -> Minimums {
187170
// We start with `answer = None` and try to solve the goal. At the end of the iteration,
188171
// `answer` will be updated with the result of the solving process. If we detect a cycle
@@ -195,7 +178,7 @@ where
195178
// so this function will eventually be constant and the loop terminates.
196179
loop {
197180
let minimums = &mut Minimums::new();
198-
let current_answer = solve_iteration(self, &canonical_goal, minimums);
181+
let current_answer = solver_stuff.solve_iteration(self, &canonical_goal, minimums);
199182

200183
debug!(
201184
"solve_new_subgoal: loop iteration result = {:?} with minimums {:?}",
@@ -212,7 +195,7 @@ where
212195
let old_answer =
213196
std::mem::replace(&mut self.search_graph[dfn].solution, current_answer);
214197

215-
if reached_fixed_point(&old_answer, &self.search_graph[dfn].solution) {
198+
if solver_stuff.reached_fixed_point(&old_answer, &self.search_graph[dfn].solution) {
216199
return *minimums;
217200
}
218201

@@ -256,47 +239,85 @@ impl<'me, I: Interner> Solver<'me, I> {
256239
}
257240
}
258241

242+
trait SolverStuff<K, V>: Copy
243+
where
244+
K: Hash + Eq + Debug + Clone,
245+
V: Debug + Clone,
246+
{
247+
fn is_coinductive_goal(self, goal: &K) -> bool;
248+
fn initial_value(self, goal: &K, coinductive_goal: bool) -> V;
249+
fn solve_iteration(
250+
self,
251+
context: &mut RecursiveContext<K, V>,
252+
goal: &K,
253+
minimums: &mut Minimums,
254+
) -> V;
255+
fn reached_fixed_point(self, old_value: &V, new_value: &V) -> bool;
256+
fn error_value(self) -> V;
257+
}
258+
259+
impl<I: Interner> SolverStuff<UCanonicalGoal<I>, Fallible<Solution<I>>> for &dyn RustIrDatabase<I> {
260+
fn is_coinductive_goal(self, goal: &UCanonicalGoal<I>) -> bool {
261+
goal.is_coinductive(self)
262+
}
263+
264+
fn initial_value(
265+
self,
266+
goal: &UCanonicalGoal<I>,
267+
coinductive_goal: bool,
268+
) -> Fallible<Solution<I>> {
269+
if coinductive_goal {
270+
Ok(Solution::Unique(Canonical {
271+
value: ConstrainedSubst {
272+
subst: goal.trivial_substitution(self.interner()),
273+
constraints: Constraints::empty(self.interner()),
274+
},
275+
binders: goal.canonical.binders.clone(),
276+
}))
277+
} else {
278+
Err(NoSolution)
279+
}
280+
}
281+
282+
fn solve_iteration(
283+
self,
284+
context: &mut RecursiveContext<UCanonicalGoal<I>, Fallible<Solution<I>>>,
285+
goal: &UCanonicalGoal<I>,
286+
minimums: &mut Minimums,
287+
) -> Fallible<Solution<I>> {
288+
Solver::new(context, self).solve_iteration(goal, minimums)
289+
}
290+
291+
fn reached_fixed_point(
292+
self,
293+
old_answer: &Fallible<Solution<I>>,
294+
current_answer: &Fallible<Solution<I>>,
295+
) -> bool {
296+
// Some of our subgoals depended on us. We need to re-run
297+
// with the current answer.
298+
old_answer == current_answer || {
299+
// Subtle: if our current answer is ambiguous, we can just stop, and
300+
// in fact we *must* -- otherwise, we sometimes fail to reach a
301+
// fixed point. See `multiple_ambiguous_cycles` for more.
302+
match &current_answer {
303+
Ok(s) => s.is_ambig(),
304+
Err(_) => false,
305+
}
306+
}
307+
}
308+
309+
fn error_value(self) -> Fallible<Solution<I>> {
310+
Err(NoSolution)
311+
}
312+
}
313+
259314
impl<'me, I: Interner> SolveDatabase<I> for Solver<'me, I> {
260315
fn solve_goal(
261316
&mut self,
262317
goal: UCanonicalGoal<I>,
263318
minimums: &mut Minimums,
264319
) -> Fallible<Solution<I>> {
265-
let program = self.program;
266-
let interner = program.interner();
267-
self.context.solve_goal(
268-
&goal,
269-
minimums,
270-
|goal| goal.is_coinductive(program),
271-
|goal, coinductive_goal| {
272-
if coinductive_goal {
273-
Ok(Solution::Unique(Canonical {
274-
value: ConstrainedSubst {
275-
subst: goal.trivial_substitution(interner),
276-
constraints: Constraints::empty(interner),
277-
},
278-
binders: goal.canonical.binders.clone(),
279-
}))
280-
} else {
281-
Err(NoSolution)
282-
}
283-
},
284-
|context, goal, minimums| Solver::new(context, program).solve_iteration(goal, minimums),
285-
|old_answer, current_answer| {
286-
// Some of our subgoals depended on us. We need to re-run
287-
// with the current answer.
288-
old_answer == current_answer || {
289-
// Subtle: if our current answer is ambiguous, we can just stop, and
290-
// in fact we *must* -- otherwise, we sometimes fail to reach a
291-
// fixed point. See `multiple_ambiguous_cycles` for more.
292-
match &current_answer {
293-
Ok(s) => s.is_ambig(),
294-
Err(_) => false,
295-
}
296-
}
297-
},
298-
|| Err(NoSolution),
299-
)
320+
self.context.solve_goal(&goal, minimums, self.program)
300321
}
301322

302323
fn interner(&self) -> &I {

0 commit comments

Comments
 (0)