Skip to content

Commit 0b8e9f8

Browse files
committed
improve eq
also improve tracing and debuggability
1 parent 426e5d4 commit 0b8e9f8

File tree

8 files changed

+141
-80
lines changed

8 files changed

+141
-80
lines changed

crates/formality-prove/src/prove/constraints.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use formality_types::{
2-
cast::Upcast,
2+
cast::{Downcast, Upcast},
33
cast_impl,
4-
derive_links::UpcastFrom,
4+
derive_links::{DowncastTo, UpcastFrom},
55
fold::Fold,
6-
grammar::{Parameter, Substitution, Variable},
6+
grammar::{InferenceVar, Parameter, Substitution, Variable},
77
term::Term,
88
visit::Visit,
99
};
@@ -191,6 +191,7 @@ impl<R: Term> Visit for Constraints<R> {
191191
let domain = substitution.domain();
192192

193193
for &v in &domain {
194+
assert!(v.downcast::<InferenceVar>().is_some());
194195
assert!(v.is_free());
195196
assert!(is_valid_binding(v, &substitution[v]));
196197
}

crates/formality-prove/src/prove/prove_after.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@ use super::constraints::Constraints;
1010
judgment_fn! {
1111
pub fn prove_after(
1212
program: Program,
13-
env1: Env,
14-
c1: Constraints,
13+
env: Env,
14+
constraints: Constraints,
1515
assumptions: Wcs,
1616
goal: Wcs,
1717
) => (Env, Constraints) {
1818
(
1919
(let (assumptions, goal) = c1.substitution().apply(&(assumptions, goal)))
20-
(prove(program, env1, assumptions, goal) => (env2, c2))
20+
(prove(program, env, assumptions, goal) => (env, c2))
2121
--- ("prove_after")
22-
(prove_after(program, env1, c1, assumptions, goal) => (env2, c1.seq(c2)))
22+
(prove_after(program, env, c1, assumptions, goal) => (env, c1.seq(c2)))
2323
)
2424
}
2525
}

crates/formality-prove/src/prove/prove_eq.rs

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
use formality_types::{
2-
cast::{Downcast, Upcast, Upcasted},
3-
collections::Deduplicate,
2+
cast::{Downcast, Downcasted, Upcast, Upcasted},
3+
collections::{Deduplicate, Set},
44
grammar::{
5-
AliasTy, AtomicRelation, InferenceVar, Parameter, RigidTy, Substitution, Ty, TyData,
6-
Variable, Wcs,
5+
AliasTy, AtomicRelation, InferenceVar, Parameter, PlaceholderVar, RigidTy, Substitution,
6+
Ty, TyData, Variable, Wcs,
77
},
8-
judgment_fn,
8+
judgment_fn, set,
99
visit::Visit,
1010
};
1111

1212
use crate::{
1313
program::Program,
1414
prove::{
15-
constraints::no_constraints, prove, prove_after::prove_after,
15+
constraints::{no_constraints, occurs_in},
16+
prove,
17+
prove_after::prove_after,
1618
prove_normalize::prove_normalize,
1719
},
1820
};
@@ -75,17 +77,17 @@ judgment_fn! {
7577
)
7678

7779
(
78-
(if let None = t.downcast::<InferenceVar>())
79-
(if let Some(env_c) = equate_variable(env, v, t))
80+
(if let None = t.downcast::<Variable>())
81+
(equate_variable(program, env, assumptions, v, t) => (env, c))
8082
----------------------------- ("existential-l")
81-
(prove_ty_eq(_program, env, _assumptions, Variable::InferenceVar(v), t) => env_c)
83+
(prove_ty_eq(program, env, assumptions, Variable::InferenceVar(v), t) => (env, c))
8284
)
8385

8486
(
85-
(if let None = t.downcast::<InferenceVar>())
86-
(if let Some(env_c) = equate_variable(env, v, t))
87+
(if let None = t.downcast::<Variable>())
88+
(equate_variable(program, env, assumptions, v, t) => (env, c))
8789
----------------------------- ("existential-r")
88-
(prove_ty_eq(_program, env, _assumptions, t, Variable::InferenceVar(v)) => env_c)
90+
(prove_ty_eq(program, env, assumptions, t, Variable::InferenceVar(v)) => (env, c))
8991
)
9092

9193
(
@@ -95,6 +97,12 @@ judgment_fn! {
9597
(prove_ty_eq(_program, env, _assumptions, Variable::InferenceVar(l), Variable::InferenceVar(r)) => (env, (b, a)))
9698
)
9799

100+
(
101+
(if env.universe(p) < env.universe(v))
102+
----------------------------- ("existential-vs-placeholder")
103+
(prove_ty_eq(_program, env, _assumptions, Variable::InferenceVar(v), Variable::PlaceholderVar(p)) => (env, (v, p)))
104+
)
105+
98106
(
99107
(prove_normalize(&program, env, &assumptions, &x) => (env1, y, c1))
100108
(prove_after(&program, env1, c1, &assumptions, eq(y, &z)) => (env2, c2))
@@ -105,20 +113,24 @@ judgment_fn! {
105113
}
106114

107115
fn equate_variable(
116+
program: Program,
108117
mut env: Env,
118+
assumptions: Wcs,
109119
x: InferenceVar,
110120
p: impl Upcast<Parameter>,
111-
) -> Option<(Env, Constraints)> {
121+
) -> Set<(Env, Constraints)> {
112122
let p: Parameter = p.upcast();
113123

114124
let span = tracing::debug_span!("equate_variable", ?x, ?p, ?env);
115125
let _guard = span.enter();
116126

117127
let fvs = p.free_variables().deduplicate();
118128

129+
env.assert_encloses((x, &fvs));
130+
119131
// Ensure that `x` passes the occurs check for the free variables in `p`.
120-
if !passes_occurs_check(&env, x, &fvs) {
121-
return None;
132+
if occurs_in(x, &fvs) {
133+
return set![];
122134
}
123135

124136
// Map each free variable `fv` in `p` that is of higher universe than `x`
@@ -130,7 +142,7 @@ fn equate_variable(
130142
let universe_x = env.universe(x);
131143
let universe_subst: Substitution = fvs
132144
.iter()
133-
.flat_map(|&fv| {
145+
.flat_map(|fv| {
134146
if universe_x < env.universe(fv) {
135147
let y = env.insert_fresh_before(fv.kind(), universe_x);
136148
Some((fv, y))
@@ -146,41 +158,23 @@ fn equate_variable(
146158
// * `x = universe_subst(p)` (e.g., `Vec<Z>` in our example above)
147159
let constraints: Constraints = universe_subst
148160
.iter()
161+
.filter(|(v, _)| v.is_a::<InferenceVar>())
149162
.chain(Some((x, universe_subst.apply(&p)).upcast()))
150163
.collect();
151164

152-
tracing::debug!("success: env={:?}, constraints={:?}", env, constraints);
153-
Some((env, constraints))
154-
}
155-
156-
/// An existential variable `x` *passes the occurs check* with respect to
157-
/// a set of free variables `fvs` if
158-
///
159-
/// * `x` is not a member of `fvs`
160-
/// * all placeholders in `fvs` are in a lower universe than `v`
161-
fn passes_occurs_check(env: &Env, x: InferenceVar, fvs: &[Variable]) -> bool {
162-
env.assert_encloses((x, fvs));
165+
let goals: Wcs = universe_subst
166+
.iter()
167+
.filter(|(v, _)| v.is_a::<PlaceholderVar>())
168+
.map(|(v, p)| eq(v, p))
169+
.upcasted()
170+
.collect();
163171

164-
let universe_x = env.universe(x);
165-
for fv in fvs {
166-
match fv {
167-
Variable::PlaceholderVar(pv) => {
168-
if universe_x < env.universe(pv) {
169-
return false;
170-
} else {
171-
}
172-
}
173-
Variable::InferenceVar(iv) => {
174-
if *iv == x {
175-
return false;
176-
} else {
177-
}
178-
}
179-
Variable::BoundVar(_) => {
180-
panic!("unexpected bound variable");
181-
}
182-
}
183-
}
172+
tracing::debug!(
173+
"equated: env={:?}, constraints={:?}, goals={:?}",
174+
env,
175+
constraints,
176+
goals
177+
);
184178

185-
true
179+
prove_after(program, env, constraints, assumptions, goals)
186180
}

crates/formality-prove/src/test/eq_assumptions.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,24 @@ fn test_b() {
3636
term("<ty A> ({}, {for<ty T, ty U> if {T = u32, U = Vec<T>} A = U})"),
3737
);
3838
expect![[r#"
39-
{}
40-
"#]] // FIXME
39+
{
40+
(
41+
Env {
42+
variables: [
43+
?ty_4_U(0),
44+
?ty_1_U(0),
45+
],
46+
},
47+
Constraints {
48+
result: (),
49+
known_true: true,
50+
substitution: {
51+
?ty_1_U(0) => (rigid (adt Vec) (rigid (scalar u32))),
52+
?ty_4_U(0) => (rigid (scalar u32)),
53+
},
54+
},
55+
),
56+
}
57+
"#]]
4158
.assert_debug_eq(&constraints);
4259
}

crates/formality-types/src/cast.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ pub trait Downcast: Sized {
4444
fn downcast<T>(&self) -> Option<T>
4545
where
4646
T: DowncastFrom<Self>;
47+
48+
fn is_a<T>(&self) -> bool
49+
where
50+
T: DowncastFrom<Self>,
51+
{
52+
self.downcast::<T>().is_some()
53+
}
4754
}
4855

4956
impl<U> Downcast for U {

crates/formality-types/src/fixed_point.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod stack;
77
pub use stack::FixedPointStack;
88

99
pub fn fixed_point<Input, Output>(
10+
tracing_span: impl Fn(&Input) -> tracing::Span,
1011
storage: &'static LocalKey<RefCell<FixedPointStack<Input, Output>>>,
1112
args: Input,
1213
default_value: impl Fn(&Input) -> Output,
@@ -18,6 +19,7 @@ where
1819
{
1920
stacker::maybe_grow(32 * 1024, 1024 * 1024, || {
2021
FixedPoint {
22+
tracing_span,
2123
storage,
2224
default_value,
2325
next_value,
@@ -26,7 +28,12 @@ where
2628
})
2729
}
2830

29-
struct FixedPoint<Input: Value, Output: Value, DefaultValue, NextValue> {
31+
struct FixedPoint<Input, Output, DefaultValue, NextValue, TracingSpan>
32+
where
33+
Input: Value,
34+
Output: Value,
35+
{
36+
tracing_span: TracingSpan,
3037
storage: &'static LocalKey<RefCell<FixedPointStack<Input, Output>>>,
3138
default_value: DefaultValue,
3239
next_value: NextValue,
@@ -35,15 +42,18 @@ struct FixedPoint<Input: Value, Output: Value, DefaultValue, NextValue> {
3542
pub trait Value: Clone + Eq + Debug + Hash + 'static {}
3643
impl<T: Clone + Eq + Debug + Hash + 'static> Value for T {}
3744

38-
impl<Input, Output, DefaultValue, NextValue> FixedPoint<Input, Output, DefaultValue, NextValue>
45+
impl<Input, Output, DefaultValue, NextValue, TracingSpan>
46+
FixedPoint<Input, Output, DefaultValue, NextValue, TracingSpan>
3947
where
4048
Input: Value,
4149
Output: Value,
4250
DefaultValue: Fn(&Input) -> Output,
4351
NextValue: Fn(Input) -> Output,
52+
TracingSpan: Fn(&Input) -> tracing::Span,
4453
{
4554
fn apply(&self, input: Input) -> Output {
4655
if let Some(r) = self.with_stack(|stack| stack.search(&input)) {
56+
tracing::debug!("recursive call to {:?}, yielding {:?}", input, r);
4757
return r;
4858
}
4959

@@ -53,12 +63,14 @@ where
5363
});
5464

5565
loop {
56-
let span = tracing::debug_span!("fixed-point iteration", ?input);
66+
let span = (self.tracing_span)(&input);
5767
let _guard = span.enter();
5868
let output = (self.next_value)(input.clone());
5969
tracing::debug!(?output);
6070
if !self.with_stack(|stack| stack.update_output(&input, output)) {
6171
break;
72+
} else {
73+
tracing::debug!("output is different from previous iteration, re-executing until fixed point is reached");
6274
}
6375
}
6476

0 commit comments

Comments
 (0)