Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions scrapscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -3984,6 +3984,7 @@ def find(self) -> MonoType:
@dataclasses.dataclass
class TyVar(MonoType):
forwarded: MonoType | None = dataclasses.field(init=False, default=None)
level: int = dataclasses.field(init=False, default_factory=lambda: current_level)
name: str

def find(self) -> MonoType:
Expand Down Expand Up @@ -4124,7 +4125,13 @@ def unify_fail(ty1: MonoType, ty2: MonoType) -> None:

def occurs_in(tyvar: TyVar, ty: MonoType) -> bool:
if isinstance(ty, TyVar):
return tyvar == ty
if tyvar == ty:
return True
if ty.is_unbound():
min_level = min(tyvar.level, ty.level) if tyvar.is_unbound() else tyvar.level
ty.level = min_level
return False
return occurs_in(tyvar, ty.find())
if isinstance(ty, TyCon):
return any(occurs_in(tyvar, arg) for arg in ty.args)
if isinstance(ty, TyEmptyRow):
Expand All @@ -4135,6 +4142,8 @@ def occurs_in(tyvar: TyVar, ty: MonoType) -> bool:


def unify_type(ty1: MonoType, ty2: MonoType) -> None:
if ty1 == ty2:
return
ty1 = ty1.find()
ty2 = ty2.find()
if isinstance(ty1, TyVar):
Expand Down Expand Up @@ -4210,6 +4219,18 @@ def unify_type(ty1: MonoType, ty2: MonoType) -> None:

fresh_var_counter = 0

current_level = 0


def enter_level() -> None:
global current_level
current_level += 1


def leave_level() -> None:
global current_level
current_level -= 1


def fresh_tyvar(prefix: str = "t") -> TyVar:
global fresh_var_counter
Expand Down Expand Up @@ -4248,28 +4269,26 @@ def instantiate(scheme: Forall) -> MonoType:
return apply_ty(scheme.ty, fresh)


def ftv_ty(ty: MonoType) -> set[str]:
ty = ty.find()
def ftv_ty(ty: MonoType, min_level: int = -1) -> set[str]:
if isinstance(ty, TyVar):
return {ty.name}
if ty.is_unbound():
if ty.level > min_level:
return {ty.name}
else:
return ftv_ty(ty.find(), min_level)
return set()
if isinstance(ty, TyCon):
return set().union(*map(ftv_ty, ty.args))
return set().union(*(ftv_ty(arg, min_level) for arg in ty.args))
if isinstance(ty, TyEmptyRow):
return set()
if isinstance(ty, TyRow):
return set().union(*map(ftv_ty, ty.fields.values()), ftv_ty(ty.rest))
return set().union(ftv_ty(ty.rest, min_level), *(ftv_ty(val, min_level) for val in ty.fields.values()))
raise InferenceError(f"Unknown type: {ty}")


def generalize(ty: MonoType, ctx: Context) -> Forall:
def ftv_scheme(ty: Forall) -> set[str]:
return ftv_ty(ty.ty) - set(tyvar.name for tyvar in ty.tyvars)

def ftv_ctx(ctx: Context) -> set[str]:
return set().union(*(ftv_scheme(scheme) for scheme in ctx.values()))

def generalize(ty: MonoType) -> Forall:
# TODO(max): Freshen?
tyvars = ftv_ty(ty) - ftv_ctx(ctx)
tyvars = ftv_ty(ty, current_level)
return Forall([TyVar(name) for name in sorted(tyvars)], ty)


Expand Down Expand Up @@ -4355,14 +4374,16 @@ def infer_type(expr: Object, ctx: Context) -> MonoType:
if isinstance(expr, Where):
assert isinstance(expr.binding, Assign)
name, value, body = expr.binding.name.name, expr.binding.value, expr.body
enter_level()
if isinstance(value, (Function, MatchFunction)):
# Letrec
func_ty: MonoType = fresh_tyvar()
value_ty = infer_type(value, {**ctx, name: Forall([], func_ty)})
else:
# Let
value_ty = infer_type(value, ctx)
value_scheme = generalize(value_ty, ctx)
leave_level()
value_scheme = generalize(value_ty)
body_ty = infer_type(body, {**ctx, name: value_scheme})
return set_type(expr, body_ty)
if isinstance(expr, List):
Expand Down Expand Up @@ -4676,6 +4697,12 @@ def test_generalization(self) -> None:
ty = self.infer(expr, {})
self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a")))

def test_generalization2(self) -> None:
# From https://okmij.org/ftp/ML/generalization.html
expr = parse(tokenize("x -> (y . y = z -> x z)"))
ty = self.infer(expr, {})
self.assertTyEqual(ty, func_type(func_type(TyVar("a"), TyVar("b")), func_type(TyVar("a"), TyVar("b"))))

def test_id(self) -> None:
expr = Function(Var("x"), Var("x"))
ty = self.infer(expr, {})
Expand Down
Loading