Skip to content

Commit ac51c20

Browse files
authored
Merge pull request #414 from brendanzab/unify-branches
Unify constant matches
2 parents 0de5068 + 18470f0 commit ac51c20

File tree

7 files changed

+125
-47
lines changed

7 files changed

+125
-47
lines changed

fathom/src/core.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,8 @@ pub enum Term<'arena> {
175175

176176
/// Constant literals.
177177
ConstLit(Span, Const),
178-
/// Match on a constant.
179-
///
180-
/// (head_expr, pattern_branches, default_expr)
178+
/// Match on a constant. The pattern branches should be unique, and listed
179+
/// in lexicographic order.
181180
ConstMatch(
182181
Span,
183182
&'arena Term<'arena>,

fathom/src/core/semantics.rs

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
//! The operational semantics of the core language, implemented using
2-
//! [normalisation by evaluation](https://en.wikipedia.org/wiki/Normalisation_by_evaluation).
1+
//! The semantics of the core language, implemented using [normalisation by
2+
//! evaluation](https://en.wikipedia.org/wiki/Normalisation_by_evaluation).
33
44
use scoped_arena::Scope;
55
use std::panic::panic_any;
@@ -1072,20 +1072,7 @@ impl<'arena, 'env> ConversionEnv<'arena, 'env> {
10721072
| (_, Value::Stuck(Head::Prim(Prim::ReportedError), _)) => true,
10731073

10741074
(Value::Stuck(head0, spine0), Value::Stuck(head1, spine1)) => {
1075-
use Elim::*;
1076-
1077-
head0 == head1
1078-
&& spine0.len() == spine1.len()
1079-
&& Iterator::zip(spine0.iter(), spine1.iter()).all(|(elim0, elim1)| {
1080-
match (elim0, elim1) {
1081-
(FunApp(expr0), FunApp(expr1)) => self.is_equal(expr0, expr1),
1082-
(RecordProj(label0), RecordProj(label1)) => label0 == label1,
1083-
(ConstMatch(branches0), ConstMatch(branches1)) => {
1084-
self.is_equal_branches(branches0, branches1)
1085-
}
1086-
(_, _) => false,
1087-
}
1088-
})
1075+
head0 == head1 && self.is_equal_spines(spine0, spine1)
10891076
}
10901077
(Value::Universe, Value::Universe) => true,
10911078

@@ -1142,6 +1129,21 @@ impl<'arena, 'env> ConversionEnv<'arena, 'env> {
11421129
}
11431130
}
11441131

1132+
/// Check that two elimination spines are equal.
1133+
pub fn is_equal_spines(&mut self, spine0: &[Elim<'_>], spine1: &[Elim<'_>]) -> bool {
1134+
spine0.len() == spine1.len()
1135+
&& Iterator::zip(spine0.iter(), spine1.iter()).all(|(elim0, elim1)| {
1136+
match (elim0, elim1) {
1137+
(Elim::FunApp(expr0), Elim::FunApp(expr1)) => self.is_equal(expr0, expr1),
1138+
(Elim::RecordProj(label0), Elim::RecordProj(label1)) => label0 == label1,
1139+
(Elim::ConstMatch(branches0), Elim::ConstMatch(branches1)) => {
1140+
self.is_equal_branches(branches0, branches1)
1141+
}
1142+
(_, _) => false,
1143+
}
1144+
})
1145+
}
1146+
11451147
/// Check that two [closures][Closure] are equal.
11461148
pub fn is_equal_closures(&mut self, closure0: &Closure<'_>, closure1: &Closure<'_>) -> bool {
11471149
let var = Spanned::empty(Arc::new(Value::local_var(self.local_exprs.next_level())));
@@ -1258,32 +1260,46 @@ impl<'arena, 'env> ConversionEnv<'arena, 'env> {
12581260
#[cfg(test)]
12591261
mod tests {
12601262
use super::*;
1261-
use crate::core::Const;
1262-
1263-
#[test]
1264-
fn value_has_unify_and_is_equal_impls() {
1265-
let value = Arc::new(Value::ConstLit(Const::Bool(false)));
12661263

1267-
// This test exists in order to cause a test failure when `Value` is changed. If this test
1268-
// has failed and you have added a new variant to Value it is a prompt to ensure that
1269-
// variant is handled in:
1264+
#[allow(dead_code)]
1265+
fn value_has_unify_and_is_equal_impls(value: Value<'_>) {
1266+
// The following match will fail to be exhaustive after new variants
1267+
// are added to `Value`. When this happens, it’s a prompt to make sure
1268+
// that the variants are handled in:
12701269
//
1271-
// - surface::elaboration::Env::unify
1270+
// - surface::elaboration::Context::unify
12721271
// - core::semantics::is_equal
12731272
//
12741273
// NOTE: Only update the match below when you've updated the above functions.
1275-
match value.as_ref() {
1276-
Value::Stuck(_, _) => {}
1274+
match value {
1275+
Value::Stuck(..) => {}
12771276
Value::Universe => {}
1278-
Value::FunType(_, _, _) => {}
1279-
Value::FunLit(_, _) => {}
1280-
Value::RecordType(_, _) => {}
1281-
Value::RecordLit(_, _) => {}
1282-
Value::ArrayLit(_) => {}
1283-
Value::FormatRecord(_, _) => {}
1284-
Value::FormatCond(_, _, _) => {}
1285-
Value::FormatOverlap(_, _) => {}
1286-
Value::ConstLit(_) => {}
1277+
Value::FunType(..) => {}
1278+
Value::FunLit(..) => {}
1279+
Value::RecordType(..) => {}
1280+
Value::RecordLit(..) => {}
1281+
Value::ArrayLit(..) => {}
1282+
Value::FormatRecord(..) => {}
1283+
Value::FormatCond(..) => {}
1284+
Value::FormatOverlap(..) => {}
1285+
Value::ConstLit(..) => {}
1286+
}
1287+
}
1288+
1289+
#[allow(dead_code)]
1290+
fn elim_has_unify_and_is_equal_impls(elim: Elim<'_>) {
1291+
// The following match will fail to be exhaustive after new variants
1292+
// are added to `Elim`. When this happens, it’s a prompt to make sure
1293+
// that the variants are handled in:
1294+
//
1295+
// - surface::elaboration::Context::unify
1296+
// - core::semantics::is_equal
1297+
//
1298+
// NOTE: Only update the match below when you've updated the above functions.
1299+
match elim {
1300+
Elim::FunApp(..) => {}
1301+
Elim::RecordProj(..) => {}
1302+
Elim::ConstMatch(..) => {}
12871303
}
12881304
}
12891305
}

fathom/src/surface/distillation.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,8 +773,8 @@ fn match_if_then_else<'arena>(
773773
default_expr: Option<&'arena core::Term<'arena>>,
774774
) -> Option<(&'arena core::Term<'arena>, &'arena core::Term<'arena>)> {
775775
match (branches, default_expr) {
776-
([(Const::Bool(true), then_expr), (Const::Bool(false), else_expr)], None)
777-
| ([(Const::Bool(false), else_expr), (Const::Bool(true), then_expr)], None)
776+
([(Const::Bool(false), else_expr), (Const::Bool(true), then_expr)], None)
777+
// TODO: Normalise boolean branches when elaborating patterns
778778
| ([(Const::Bool(true), then_expr)], Some(else_expr))
779779
| ([(Const::Bool(false), else_expr)], Some(then_expr)) => Some((then_expr, else_expr)),
780780
_ => None,

fathom/src/surface/elaboration.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -933,16 +933,17 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
933933
let cond_expr = self.check(cond_expr, &self.bool_type.clone());
934934
let then_expr = self.check(then_expr, &expected_type);
935935
let else_expr = self.check(else_expr, &expected_type);
936-
let match_expr = core::Term::ConstMatch(
936+
937+
core::Term::ConstMatch(
937938
range.into(),
938939
self.scope.to_scope(cond_expr),
940+
// NOTE: in lexicographic order: in Rust, `false < true`
939941
self.scope.to_scope_from_iter([
940-
(Const::Bool(true), then_expr),
941942
(Const::Bool(false), else_expr),
943+
(Const::Bool(true), then_expr),
942944
]),
943945
None,
944-
);
945-
match_expr
946+
)
946947
}
947948
(Term::Match(range, scrutinee_expr, equations), _) => {
948949
self.check_match(*range, scrutinee_expr, equations, &expected_type)
@@ -1276,15 +1277,18 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
12761277
let cond_expr = self.check(cond_expr, &self.bool_type.clone());
12771278
let (then_expr, r#type) = self.synth(then_expr);
12781279
let else_expr = self.check(else_expr, &r#type);
1280+
12791281
let match_expr = core::Term::ConstMatch(
12801282
range.into(),
12811283
self.scope.to_scope(cond_expr),
1284+
// NOTE: in lexicographic order: in Rust, `false < true`
12821285
self.scope.to_scope_from_iter([
1283-
(Const::Bool(true), then_expr),
12841286
(Const::Bool(false), else_expr),
1287+
(Const::Bool(true), then_expr),
12851288
]),
12861289
None,
12871290
);
1291+
12881292
(match_expr, r#type)
12891293
}
12901294
Term::Match(range, scrutinee_expr, equations) => {
@@ -2078,6 +2082,10 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
20782082
CheckedPattern::Binder(range, name) => {
20792083
self.check_match_reachable(is_reachable, range);
20802084

2085+
// TODO: If we know this is an exhaustive match, bind the
2086+
// scrutinee to a let binding with the elaborated body, and
2087+
// add it to the branches. This will simplify the
2088+
// distillation of if expressions.
20812089
(self.local_env).push_param(Some(name), match_info.scrutinee.r#type.clone());
20822090
default_expr = self.check(body_expr, &match_info.expected_type);
20832091
self.local_env.pop();

fathom/src/surface/elaboration/unification.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use std::sync::Arc;
2020

2121
use crate::alloc::SliceVec;
2222
use crate::core::semantics::{
23-
self, ArcValue, Closure, Elim, Head, SplitBranches, Telescope, Value,
23+
self, ArcValue, Branches, Closure, Elim, Head, SplitBranches, Telescope, Value,
2424
};
2525
use crate::core::{Prim, Term};
2626
use crate::env::{EnvLen, Index, Level, SharedEnv, SliceEnv, UniqueEnv};
@@ -316,6 +316,9 @@ impl<'arena, 'env> Context<'arena, 'env> {
316316
self.unify(arg_expr0, arg_expr1)?;
317317
}
318318
(Elim::RecordProj(label0), Elim::RecordProj(label1)) if label0 == label1 => {}
319+
(Elim::ConstMatch(branches0), Elim::ConstMatch(branches1)) => {
320+
self.unify_branches(branches0, branches1)?;
321+
}
319322
(_, _) => {
320323
return Err(Error::Mismatch);
321324
}
@@ -374,6 +377,41 @@ impl<'arena, 'env> Context<'arena, 'env> {
374377
Ok(())
375378
}
376379

380+
/// Unify two [constant branches][Branches].
381+
fn unify_branches<P: PartialEq + Copy>(
382+
&mut self,
383+
branches0: &Branches<'arena, P>,
384+
branches1: &Branches<'arena, P>,
385+
) -> Result<(), Error> {
386+
use SplitBranches::*;
387+
388+
let mut branches0 = branches0.clone();
389+
let mut branches1 = branches1.clone();
390+
391+
loop {
392+
match (
393+
self.elim_env().split_branches(branches0),
394+
self.elim_env().split_branches(branches1),
395+
) {
396+
(
397+
Branch((const0, body_expr0), next_branches0),
398+
Branch((const1, body_expr1), next_branches1),
399+
) if const0 == const1 => match self.unify(&body_expr0, &body_expr1) {
400+
Err(err) => return Err(err),
401+
Ok(()) => {
402+
branches0 = next_branches0;
403+
branches1 = next_branches1;
404+
}
405+
},
406+
(Default(default_expr0), Default(default_expr1)) => {
407+
return self.unify_closures(&default_expr0, &default_expr1);
408+
}
409+
(None, None) => return Ok(()),
410+
(_, _) => return Err(Error::Mismatch),
411+
}
412+
}
413+
}
414+
377415
/// Unify a function literal with a value, using eta-conversion.
378416
///
379417
/// ```fathom

tests/succeed/equality.fathom

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,15 @@ let four_chars : Eq U32 "beng" 1650814567 = refl _ _;
3636
let three_chars : Eq U32 "BEN " 1111838240 = refl _ _;
3737

3838

39+
// Branches
40+
41+
let foo = fun (x : U32) =>
42+
match x {
43+
1 => 0 : U32,
44+
x => x
45+
};
46+
47+
let eq_foo : Eq _ foo foo =
48+
refl _ _;
49+
3950
Type

tests/succeed/equality.snap

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ let four_chars : fun (P : U32 -> Type) -> P "beng" -> P 1650814567 =
4141
refl U32 "beng";
4242
let three_chars : fun (P : U32 -> Type) -> P "BEN " -> P 1111838240 =
4343
refl U32 "BEN ";
44+
let foo : U32 -> U32 = fun x => match x { 1 => 0, _ => _ };
45+
let eq_foo : fun (P : (U32 -> U32) -> Type) -> P (fun x => match x {
46+
1 => 0,
47+
_ => _,
48+
}) -> P (fun x => match x { 1 => 0, _ => _ }) = refl (U32 ->
49+
U32) (fun _ => match _ { 1 => 0, _ => _ });
4450
Type : Type
4551
'''
4652
stderr = ''

0 commit comments

Comments
 (0)