From a22408808b952a3d8294bba5629a036aad476781 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 16 Nov 2024 11:43:33 -0500 Subject: [PATCH 1/7] Add another generalization test --- scrapscript.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scrapscript.py b/scrapscript.py index 68466c4c..72c7e13e 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -4676,6 +4676,12 @@ def test_generalization(self) -> None: ty = self.infer(expr, {}) self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a"))) + def test_generalization2(self) -> None: + # From https://okmij.org/ftp/ML/generalization.html + expr = parse(tokenize("x -> (y . y = z -> x z)")) + ty = self.infer(expr, {}) + self.assertTyEqual(ty, func_type(func_type(TyVar("a"), TyVar("b")), func_type(TyVar("a"), TyVar("b")))) + def test_id(self) -> None: expr = Function(Var("x"), Var("x")) ty = self.infer(expr, {}) From b93b320a4edb762cc05c7aa7d5b8749faa6d4d70 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 16 Nov 2024 12:52:34 -0500 Subject: [PATCH 2/7] Use eager level-based generalization --- scrapscript.py | 71 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 22 deletions(-) diff --git a/scrapscript.py b/scrapscript.py index 72c7e13e..74fe7218 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -3984,6 +3984,7 @@ def find(self) -> MonoType: @dataclasses.dataclass class TyVar(MonoType): forwarded: MonoType | None = dataclasses.field(init=False, default=None) + level: int = dataclasses.field(init=False, default_factory=lambda: current_level) name: str def find(self) -> MonoType: @@ -4122,24 +4123,38 @@ def unify_fail(ty1: MonoType, ty2: MonoType) -> None: raise InferenceError(f"Unification failed for {ty1} and {ty2}") -def occurs_in(tyvar: TyVar, ty: MonoType) -> bool: +def occurs(tyvar: TyVar, ty: MonoType) -> None: if isinstance(ty, TyVar): - return tyvar == ty + if tyvar == ty: + raise InferenceError(f"Occurs check failed for {tyvar} and {ty}") + if ty.is_unbound(): + min_level = min(tyvar.level, ty.level) if tyvar.is_unbound() else tyvar.level + ty.level = min_level + return + occurs(tyvar, ty.forwarded) + return if isinstance(ty, TyCon): - return any(occurs_in(tyvar, arg) for arg in ty.args) + for arg in ty.args: + occurs(tyvar, arg) + return if isinstance(ty, TyEmptyRow): - return False + return if isinstance(ty, TyRow): - return any(occurs_in(tyvar, val) for val in ty.fields.values()) or occurs_in(tyvar, ty.rest) + for val in ty.fields.values(): + occurs(tyvar, val) + occurs(tyvar, ty.rest) + return raise InferenceError(f"Unknown type: {ty}") def unify_type(ty1: MonoType, ty2: MonoType) -> None: + if ty1 == ty2: + return ty1 = ty1.find() ty2 = ty2.find() if isinstance(ty1, TyVar): - if occurs_in(ty1, ty2): - raise InferenceError(f"Occurs check failed for {ty1} and {ty2}") + # ty1 is unbound if it's still a TyVar after .find() + occurs(ty1, ty2) ty1.make_equal_to(ty2) return if isinstance(ty2, TyVar): # Mirror @@ -4210,6 +4225,18 @@ def unify_type(ty1: MonoType, ty2: MonoType) -> None: fresh_var_counter = 0 +current_level = 0 + + +def enter_level() -> None: + global current_level + current_level += 1 + + +def leave_level() -> None: + global current_level + current_level -= 1 + def fresh_tyvar(prefix: str = "t") -> TyVar: global fresh_var_counter @@ -4248,28 +4275,26 @@ def instantiate(scheme: Forall) -> MonoType: return apply_ty(scheme.ty, fresh) -def ftv_ty(ty: MonoType) -> set[str]: - ty = ty.find() +def ftv_ty(ty: MonoType, min_level=-1) -> set[str]: if isinstance(ty, TyVar): - return {ty.name} + if ty.is_unbound(): + if ty.level > min_level: + return {ty.name} + else: + return ftv_ty(ty.forwarded, min_level) + return set() if isinstance(ty, TyCon): - return set().union(*map(ftv_ty, ty.args)) + return set().union(*(ftv_ty(arg, min_level) for arg in ty.args)) if isinstance(ty, TyEmptyRow): return set() if isinstance(ty, TyRow): - return set().union(*map(ftv_ty, ty.fields.values()), ftv_ty(ty.rest)) - raise InferenceError(f"Unknown type: {ty}") - - -def generalize(ty: MonoType, ctx: Context) -> Forall: - def ftv_scheme(ty: Forall) -> set[str]: - return ftv_ty(ty.ty) - set(tyvar.name for tyvar in ty.tyvars) + return set().union(ftv_ty(ty.rest, min_level), *(ftv_ty(val, min_level) for val in ty.fields.values())) + raise InferenceError(f"ftv_ty: Unknown type: {ty}") - def ftv_ctx(ctx: Context) -> set[str]: - return set().union(*(ftv_scheme(scheme) for scheme in ctx.values())) +def generalize(ty: MonoType) -> Forall: # TODO(max): Freshen? - tyvars = ftv_ty(ty) - ftv_ctx(ctx) + tyvars = ftv_ty(ty, current_level) return Forall([TyVar(name) for name in sorted(tyvars)], ty) @@ -4355,6 +4380,7 @@ def infer_type(expr: Object, ctx: Context) -> MonoType: if isinstance(expr, Where): assert isinstance(expr.binding, Assign) name, value, body = expr.binding.name.name, expr.binding.value, expr.body + enter_level() if isinstance(value, (Function, MatchFunction)): # Letrec func_ty: MonoType = fresh_tyvar() @@ -4362,7 +4388,8 @@ def infer_type(expr: Object, ctx: Context) -> MonoType: else: # Let value_ty = infer_type(value, ctx) - value_scheme = generalize(value_ty, ctx) + leave_level() + value_scheme = generalize(value_ty) body_ty = infer_type(body, {**ctx, name: value_scheme}) return set_type(expr, body_ty) if isinstance(expr, List): From 5ddbfe67071e5f7b46664ddf36b711d2394f774a Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 16 Nov 2024 12:54:54 -0500 Subject: [PATCH 3/7] . --- scrapscript.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scrapscript.py b/scrapscript.py index 74fe7218..91759e08 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -4123,7 +4123,7 @@ def unify_fail(ty1: MonoType, ty2: MonoType) -> None: raise InferenceError(f"Unification failed for {ty1} and {ty2}") -def occurs(tyvar: TyVar, ty: MonoType) -> None: +def occurs_in(tyvar: TyVar, ty: MonoType) -> None: if isinstance(ty, TyVar): if tyvar == ty: raise InferenceError(f"Occurs check failed for {tyvar} and {ty}") @@ -4131,18 +4131,18 @@ def occurs(tyvar: TyVar, ty: MonoType) -> None: min_level = min(tyvar.level, ty.level) if tyvar.is_unbound() else tyvar.level ty.level = min_level return - occurs(tyvar, ty.forwarded) + occurs_in(tyvar, ty.forwarded) return if isinstance(ty, TyCon): for arg in ty.args: - occurs(tyvar, arg) + occurs_in(tyvar, arg) return if isinstance(ty, TyEmptyRow): return if isinstance(ty, TyRow): for val in ty.fields.values(): - occurs(tyvar, val) - occurs(tyvar, ty.rest) + occurs_in(tyvar, val) + occurs_in(tyvar, ty.rest) return raise InferenceError(f"Unknown type: {ty}") @@ -4154,7 +4154,7 @@ def unify_type(ty1: MonoType, ty2: MonoType) -> None: ty2 = ty2.find() if isinstance(ty1, TyVar): # ty1 is unbound if it's still a TyVar after .find() - occurs(ty1, ty2) + occurs_in(ty1, ty2) ty1.make_equal_to(ty2) return if isinstance(ty2, TyVar): # Mirror From c631a100b480d677cf21842317fb63c54ea0c55e Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 16 Nov 2024 12:57:37 -0500 Subject: [PATCH 4/7] Fix mypy --- scrapscript.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scrapscript.py b/scrapscript.py index 91759e08..8f237e48 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -4131,7 +4131,7 @@ def occurs_in(tyvar: TyVar, ty: MonoType) -> None: min_level = min(tyvar.level, ty.level) if tyvar.is_unbound() else tyvar.level ty.level = min_level return - occurs_in(tyvar, ty.forwarded) + occurs_in(tyvar, ty.find()) return if isinstance(ty, TyCon): for arg in ty.args: @@ -4275,13 +4275,13 @@ def instantiate(scheme: Forall) -> MonoType: return apply_ty(scheme.ty, fresh) -def ftv_ty(ty: MonoType, min_level=-1) -> set[str]: +def ftv_ty(ty: MonoType, min_level: int = -1) -> set[str]: if isinstance(ty, TyVar): if ty.is_unbound(): if ty.level > min_level: return {ty.name} else: - return ftv_ty(ty.forwarded, min_level) + return ftv_ty(ty.find(), min_level) return set() if isinstance(ty, TyCon): return set().union(*(ftv_ty(arg, min_level) for arg in ty.args)) From 10413ed09f22c696048dfc8794880d4fc6b8b287 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 16 Nov 2024 15:33:44 -0500 Subject: [PATCH 5/7] Minimize diff --- scrapscript.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/scrapscript.py b/scrapscript.py index 8f237e48..1a023198 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -4123,27 +4123,21 @@ def unify_fail(ty1: MonoType, ty2: MonoType) -> None: raise InferenceError(f"Unification failed for {ty1} and {ty2}") -def occurs_in(tyvar: TyVar, ty: MonoType) -> None: +def occurs_in(tyvar: TyVar, ty: MonoType) -> bool: if isinstance(ty, TyVar): if tyvar == ty: raise InferenceError(f"Occurs check failed for {tyvar} and {ty}") if ty.is_unbound(): min_level = min(tyvar.level, ty.level) if tyvar.is_unbound() else tyvar.level ty.level = min_level - return - occurs_in(tyvar, ty.find()) - return + return False + return occurs_in(tyvar, ty.find()) if isinstance(ty, TyCon): - for arg in ty.args: - occurs_in(tyvar, arg) - return + return any(occurs_in(tyvar, arg) for arg in ty.args) if isinstance(ty, TyEmptyRow): - return + return False if isinstance(ty, TyRow): - for val in ty.fields.values(): - occurs_in(tyvar, val) - occurs_in(tyvar, ty.rest) - return + return any(occurs_in(tyvar, val) for val in ty.fields.values()) or occurs_in(tyvar, ty.rest) raise InferenceError(f"Unknown type: {ty}") From 64cee18b05ea7f1f27f307cc660edba9780cdb75 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 16 Nov 2024 15:34:26 -0500 Subject: [PATCH 6/7] Minimize diff --- scrapscript.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scrapscript.py b/scrapscript.py index 1a023198..61d6dc6f 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -4126,7 +4126,7 @@ def unify_fail(ty1: MonoType, ty2: MonoType) -> None: def occurs_in(tyvar: TyVar, ty: MonoType) -> bool: if isinstance(ty, TyVar): if tyvar == ty: - raise InferenceError(f"Occurs check failed for {tyvar} and {ty}") + return True if ty.is_unbound(): min_level = min(tyvar.level, ty.level) if tyvar.is_unbound() else tyvar.level ty.level = min_level @@ -4147,8 +4147,8 @@ def unify_type(ty1: MonoType, ty2: MonoType) -> None: ty1 = ty1.find() ty2 = ty2.find() if isinstance(ty1, TyVar): - # ty1 is unbound if it's still a TyVar after .find() - occurs_in(ty1, ty2) + if occurs_in(ty1, ty2): + raise InferenceError(f"Occurs check failed for {ty1} and {ty2}") ty1.make_equal_to(ty2) return if isinstance(ty2, TyVar): # Mirror From a9e14568b34d8cb10ed65aced85f5f62494be7b7 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 16 Nov 2024 16:50:59 -0500 Subject: [PATCH 7/7] . --- scrapscript.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scrapscript.py b/scrapscript.py index 61d6dc6f..407ea947 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -4283,7 +4283,7 @@ def ftv_ty(ty: MonoType, min_level: int = -1) -> set[str]: return set() if isinstance(ty, TyRow): return set().union(ftv_ty(ty.rest, min_level), *(ftv_ty(val, min_level) for val in ty.fields.values())) - raise InferenceError(f"ftv_ty: Unknown type: {ty}") + raise InferenceError(f"Unknown type: {ty}") def generalize(ty: MonoType) -> Forall: