diff --git a/scrapscript.py b/scrapscript.py
index 68466c4c..407ea947 100755
--- a/scrapscript.py
+++ b/scrapscript.py
@@ -3984,6 +3984,7 @@ def find(self) -> MonoType:
 @dataclasses.dataclass
 class TyVar(MonoType):
     forwarded: MonoType | None = dataclasses.field(init=False, default=None)
+    level: int = dataclasses.field(init=False, default_factory=lambda: current_level)
     name: str

     def find(self) -> MonoType:
@@ -4124,7 +4125,13 @@ def unify_fail(ty1: MonoType, ty2: MonoType) -> None:

 def occurs_in(tyvar: TyVar, ty: MonoType) -> bool:
     if isinstance(ty, TyVar):
-        return tyvar == ty
+        if tyvar == ty:
+            return True
+        if ty.is_unbound():
+            min_level = min(tyvar.level, ty.level) if tyvar.is_unbound() else tyvar.level
+            ty.level = min_level
+            return False
+        return occurs_in(tyvar, ty.find())
     if isinstance(ty, TyCon):
         return any(occurs_in(tyvar, arg) for arg in ty.args)
     if isinstance(ty, TyEmptyRow):
@@ -4135,6 +4142,8 @@ def occurs_in(tyvar: TyVar, ty: MonoType) -> bool:


 def unify_type(ty1: MonoType, ty2: MonoType) -> None:
+    if ty1 == ty2:
+        return
     ty1 = ty1.find()
     ty2 = ty2.find()
     if isinstance(ty1, TyVar):
@@ -4210,6 +4219,18 @@ def unify_type(ty1: MonoType, ty2: MonoType) -> None:


 fresh_var_counter = 0
+current_level = 0
+
+
+def enter_level() -> None:
+    global current_level
+    current_level += 1
+
+
+def leave_level() -> None:
+    global current_level
+    current_level -= 1
+

 def fresh_tyvar(prefix: str = "t") -> TyVar:
     global fresh_var_counter
@@ -4248,28 +4269,26 @@ def instantiate(scheme: Forall) -> MonoType:
     return apply_ty(scheme.ty, fresh)


-def ftv_ty(ty: MonoType) -> set[str]:
-    ty = ty.find()
+def ftv_ty(ty: MonoType, min_level: int = -1) -> set[str]:
     if isinstance(ty, TyVar):
-        return {ty.name}
+        if ty.is_unbound():
+            if ty.level > min_level:
+                return {ty.name}
+        else:
+            return ftv_ty(ty.find(), min_level)
+        return set()
     if isinstance(ty, TyCon):
-        return set().union(*map(ftv_ty, ty.args))
+        return set().union(*(ftv_ty(arg, min_level) for arg in ty.args))
     if isinstance(ty, TyEmptyRow):
         return set()
     if isinstance(ty, TyRow):
-        return set().union(*map(ftv_ty, ty.fields.values()), ftv_ty(ty.rest))
+        return set().union(ftv_ty(ty.rest, min_level), *(ftv_ty(val, min_level) for val in ty.fields.values()))
     raise InferenceError(f"Unknown type: {ty}")


-def generalize(ty: MonoType, ctx: Context) -> Forall:
-    def ftv_scheme(ty: Forall) -> set[str]:
-        return ftv_ty(ty.ty) - set(tyvar.name for tyvar in ty.tyvars)
-
-    def ftv_ctx(ctx: Context) -> set[str]:
-        return set().union(*(ftv_scheme(scheme) for scheme in ctx.values()))
-
+def generalize(ty: MonoType) -> Forall:
     # TODO(max): Freshen?
-    tyvars = ftv_ty(ty) - ftv_ctx(ctx)
+    tyvars = ftv_ty(ty, current_level)
     return Forall([TyVar(name) for name in sorted(tyvars)], ty)


@@ -4355,6 +4374,7 @@ def infer_type(expr: Object, ctx: Context) -> MonoType:
     if isinstance(expr, Where):
         assert isinstance(expr.binding, Assign)
         name, value, body = expr.binding.name.name, expr.binding.value, expr.body
+        enter_level()
         if isinstance(value, (Function, MatchFunction)):
             # Letrec
             func_ty: MonoType = fresh_tyvar()
@@ -4362,7 +4382,8 @@ def infer_type(expr: Object, ctx: Context) -> MonoType:
         else:
             # Let
             value_ty = infer_type(value, ctx)
-        value_scheme = generalize(value_ty, ctx)
+        leave_level()
+        value_scheme = generalize(value_ty)
         body_ty = infer_type(body, {**ctx, name: value_scheme})
         return set_type(expr, body_ty)
     if isinstance(expr, List):
@@ -4676,6 +4697,12 @@ def test_generalization(self) -> None:
         ty = self.infer(expr, {})
         self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a")))

+    def test_generalization2(self) -> None:
+        # From https://okmij.org/ftp/ML/generalization.html
+        expr = parse(tokenize("x -> (y . y = z -> x z)"))
+        ty = self.infer(expr, {})
+        self.assertTyEqual(ty, func_type(func_type(TyVar("a"), TyVar("b")), func_type(TyVar("a"), TyVar("b"))))
+
     def test_id(self) -> None:
         expr = Function(Var("x"), Var("x"))
         ty = self.infer(expr, {})