From f4a474307326d544d70927405b5338923bd286c6 Mon Sep 17 00:00:00 2001 From: Catharine Manning <2330+catharinejm@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:09:27 -0400 Subject: [PATCH 1/7] wip adding CondExpr --- compiler.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/compiler.py b/compiler.py index fab688b5..75ab6fc2 100644 --- a/compiler.py +++ b/compiler.py @@ -18,6 +18,7 @@ Hole, Int, List, + MatchCase, MatchFunction, Object, Record, @@ -60,6 +61,82 @@ def decl(self) -> str: return f"struct object* {self.name}({args})" +def group_cases(cases: typing.List[MatchCase], key: object) -> typing.List[typing.List[MatchCase]]: + sorted_by_key = sorted(cases, key) + return [group for _, group in groupby(sorted_by_key, key)] + + +class MatchKind: + pass + + +class IsNumber(MatchKind): + pass + + +class IsHole(MatchKind): + pass + + +class IsString(MatchKind): + pass + + +class IsVar(MatchKind): + pass + + +class IsList(MatchKind): + pass + + +class IsRecord(MatchKind): + pass + + +class NumberHasValue(MatchKind): + value: int + + +@dataclasses.dataclass(frozen=True) +class CondExpr(Object): + arg: Var + condition: MatchKind + body: Object + + +@dataclasses.dataclass(frozen=True) +class MatchExpr(Object): + arg: Object # Maybe not needed? + cases: typing.List[CondExpr] + + def compile_match_function(match_fn: MatchFunction) -> Function: + arg = Var("x") + cases = compile_ungrouped_match_cases(arg, match_fn.cases, type) + return Function(arg, MatchExpr(arg, cases)) + + def compile_ungrouped_match_cases(arg: Var, cases: typing.List[MatchCase], group_key) -> typing.List[CondExpr]: + patterns = [case.pattern for case in cases] + grouped = group_cases(patterns, group_key) + return [expand_group(group) for group in grouped] + + def compile_int_cases(arg: Var, group: typing.List[MatchCase]): + cases = [CondExpr(arg, NumberHasValue(case.pattern), case.body) for case in group] + return MatchExpr(arg, cases) + + def expand_group(arg: Var, group: typing.List[MatchCase]): + canonical_case = group[0].case + if isinstance(canonical_case.pattern, Int): + return CondExpr(arg, IsNumber, compile_int_cases(arg, group)) + # if isinstance(canonical_case.pattern, Hole): + # if isinstance(canonical_case.pattern, Variant): + # if isinstance(canonical_case.pattern, String): + # if isinstance(canonical_case.pattern, Var): + # if isinstance(canonical_case.pattern, List): + # if isinstance(canonical_case.pattern, Record): + raise NotImplementedError("expand_group", canonical_case.pattern) + + class Compiler: def __init__(self, main_fn: CompiledFunction) -> None: self.gensym_counter: int = 0 From c6a29fdbbbf5f11f21ba33c120893c801ac6d032 Mon Sep 17 00:00:00 2001 From: Catharine Manning <2330+catharinejm@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:09:27 -0400 Subject: [PATCH 2/7] matching on numbers seems to work --- compiler.py | 92 ++++++++++++++++++++++++++++++++------------------ scrapscript.py | 4 ++- 2 files changed, 63 insertions(+), 33 deletions(-) diff --git a/compiler.py b/compiler.py index 75ab6fc2..5ca8f16a 100644 --- a/compiler.py +++ b/compiler.py @@ -62,20 +62,23 @@ def decl(self) -> str: def group_cases(cases: typing.List[MatchCase], key: object) -> typing.List[typing.List[MatchCase]]: - sorted_by_key = sorted(cases, key) - return [group for _, group in groupby(sorted_by_key, key)] + sorted_by_key = sorted(cases, key=key) + return [list(group) for _, group in itertools.groupby(sorted_by_key, key)] class MatchKind: - pass + def compile(self, arg: str) -> str: + raise NotImplementedError class IsNumber(MatchKind): - pass + def compile(self, arg: str) -> str: + return f"is_num({arg})" class IsHole(MatchKind): - pass + def compile(self, arg: str) -> str: + return f"is_hole({arg})" class IsString(MatchKind): @@ -94,13 +97,22 @@ class IsRecord(MatchKind): pass +@dataclasses.dataclass class NumberHasValue(MatchKind): value: int + def compile(self, arg: str) -> str: + return f"is_num_equal_word({arg}, {self.value})" + + +def coerce_int(object: Object) -> int: + assert isinstance(object, Int) + return object.value + @dataclasses.dataclass(frozen=True) class CondExpr(Object): - arg: Var + arg: Var # Actually, probably this one isn't needed?? condition: MatchKind body: Object @@ -110,31 +122,35 @@ class MatchExpr(Object): arg: Object # Maybe not needed? cases: typing.List[CondExpr] - def compile_match_function(match_fn: MatchFunction) -> Function: - arg = Var("x") - cases = compile_ungrouped_match_cases(arg, match_fn.cases, type) - return Function(arg, MatchExpr(arg, cases)) - - def compile_ungrouped_match_cases(arg: Var, cases: typing.List[MatchCase], group_key) -> typing.List[CondExpr]: - patterns = [case.pattern for case in cases] - grouped = group_cases(patterns, group_key) - return [expand_group(group) for group in grouped] - - def compile_int_cases(arg: Var, group: typing.List[MatchCase]): - cases = [CondExpr(arg, NumberHasValue(case.pattern), case.body) for case in group] - return MatchExpr(arg, cases) - - def expand_group(arg: Var, group: typing.List[MatchCase]): - canonical_case = group[0].case - if isinstance(canonical_case.pattern, Int): - return CondExpr(arg, IsNumber, compile_int_cases(arg, group)) - # if isinstance(canonical_case.pattern, Hole): - # if isinstance(canonical_case.pattern, Variant): - # if isinstance(canonical_case.pattern, String): - # if isinstance(canonical_case.pattern, Var): - # if isinstance(canonical_case.pattern, List): - # if isinstance(canonical_case.pattern, Record): - raise NotImplementedError("expand_group", canonical_case.pattern) + +def compile_match_function(match_fn: MatchFunction) -> Function: + arg = Var("x") + cases = compile_ungrouped_match_cases(arg, match_fn.cases, lambda x: type(x).__name__) + return Function(arg, MatchExpr(arg, cases)) + + +def compile_ungrouped_match_cases(arg: Var, cases: typing.List[MatchCase], group_key: object) -> typing.List[CondExpr]: + cases = [case for case in cases] + grouped = group_cases(cases, group_key) + return [expand_group(arg, group) for group in grouped] + + +def compile_int_cases(arg: Var, group: typing.List[MatchCase]): + cases = [CondExpr(arg, NumberHasValue(coerce_int(case.pattern)), case.body) for case in group] + return MatchExpr(arg, cases) + + +def expand_group(arg: Var, group: typing.List[MatchCase]): + canonical_case = group[0] + if isinstance(canonical_case.pattern, Int): + return CondExpr(arg, IsNumber(), compile_int_cases(arg, group)) + # if isinstance(canonical_case.pattern, Hole): + # if isinstance(canonical_case.pattern, Variant): + # if isinstance(canonical_case.pattern, String): + # if isinstance(canonical_case.pattern, Var): + # if isinstance(canonical_case.pattern, List): + # if isinstance(canonical_case.pattern, Record): + raise NotImplementedError("expand_group", canonical_case.pattern) class Compiler: @@ -525,9 +541,21 @@ def compile(self, env: Env, exp: Object) -> str: return self.compile_function(env, exp, name=None) if isinstance(exp, MatchFunction): # Anonymous match function - return self.compile_match_function(env, exp, name=None) + return self.compile_function(env, compile_match_function(exp), name=None) + if isinstance(exp, MatchExpr): + return self.compile_match_expr(env, exp) raise NotImplementedError(f"exp {type(exp)} {exp}") + def compile_match_expr(self, env: Env, match_expr: MatchExpr) -> str: + arg = self.compile(env, match_expr.arg) + for cond in match_expr.cases: + fallthrough = self.gensym("case") + c_cond = cond.condition.compile(arg) + self._emit(f"if (!{c_cond}) goto {fallthrough};") + case_result = self.compile(env, cond.body) + self._emit(f"return {case_result};") + self._emit(f"{fallthrough}:;") + def compile_to_string(program: Object, debug: bool) -> str: main_fn = CompiledFunction("scrap_main", params=[]) diff --git a/scrapscript.py b/scrapscript.py index e6418e50..48e527b7 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -1312,7 +1312,9 @@ def free_in(exp: Object) -> Set[str]: if isinstance(exp, Closure): # TODO(max): Should this remove the set of keys in the closure env? return free_in(exp.func) - raise NotImplementedError(("free_in", type(exp))) + # :'( + return set() + # raise NotImplementedError(("free_in", type(exp))) def improve_closure(closure: Closure) -> Closure: From 53d2f97e02f2d47cab4dc277b6609a87f51d4af8 Mon Sep 17 00:00:00 2001 From: Catharine Manning <2330+catharinejm@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:09:27 -0400 Subject: [PATCH 3/7] it's alive! (somewhat!) --- compiler.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/compiler.py b/compiler.py index 5ca8f16a..7f52b0fd 100644 --- a/compiler.py +++ b/compiler.py @@ -61,11 +61,6 @@ def decl(self) -> str: return f"struct object* {self.name}({args})" -def group_cases(cases: typing.List[MatchCase], key: object) -> typing.List[typing.List[MatchCase]]: - sorted_by_key = sorted(cases, key=key) - return [list(group) for _, group in itertools.groupby(sorted_by_key, key)] - - class MatchKind: def compile(self, arg: str) -> str: raise NotImplementedError @@ -123,14 +118,24 @@ class MatchExpr(Object): cases: typing.List[CondExpr] +def group_cases(cases: typing.List[MatchCase], keyof: object) -> typing.List[typing.List[MatchCase]]: + groups = {} + for case in cases: + if keyof(case) in groups: + groups[keyof(case)].append(case) + else: + groups[keyof(case)] = [case] + + return list(groups.values()) + + def compile_match_function(match_fn: MatchFunction) -> Function: arg = Var("x") - cases = compile_ungrouped_match_cases(arg, match_fn.cases, lambda x: type(x).__name__) + cases = compile_ungrouped_match_cases(arg, match_fn.cases, lambda x: type(x.pattern).__name__) return Function(arg, MatchExpr(arg, cases)) def compile_ungrouped_match_cases(arg: Var, cases: typing.List[MatchCase], group_key: object) -> typing.List[CondExpr]: - cases = [case for case in cases] grouped = group_cases(cases, group_key) return [expand_group(arg, group) for group in grouped] @@ -144,7 +149,9 @@ def expand_group(arg: Var, group: typing.List[MatchCase]): canonical_case = group[0] if isinstance(canonical_case.pattern, Int): return CondExpr(arg, IsNumber(), compile_int_cases(arg, group)) - # if isinstance(canonical_case.pattern, Hole): + if isinstance(canonical_case.pattern, Hole): + # throwing away subsequent holes + return CondExpr(arg, IsHole(), canonical_case.body) # if isinstance(canonical_case.pattern, Variant): # if isinstance(canonical_case.pattern, String): # if isinstance(canonical_case.pattern, Var): @@ -548,13 +555,21 @@ def compile(self, env: Env, exp: Object) -> str: def compile_match_expr(self, env: Env, match_expr: MatchExpr) -> str: arg = self.compile(env, match_expr.arg) + result = self.gensym("result") + done = self.gensym("done") + self._emit(f"struct object* {result} = NULL;") for cond in match_expr.cases: fallthrough = self.gensym("case") c_cond = cond.condition.compile(arg) self._emit(f"if (!{c_cond}) goto {fallthrough};") case_result = self.compile(env, cond.body) - self._emit(f"return {case_result};") + self._emit(f"{result} = {case_result};") + self._emit(f"goto {done};") self._emit(f"{fallthrough}:;") + self._emit(r'fprintf(stderr, "no matching cases\n");') + self._emit("abort();") + self._emit(f"{done}:;") + return result def compile_to_string(program: Object, debug: bool) -> str: From 95367a6749e9b2eddc46633330210e195542b03c Mon Sep 17 00:00:00 2001 From: Catharine Manning <2330+catharinejm@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:09:27 -0400 Subject: [PATCH 4/7] match arg --- compiler.py | 44 +++++++++++++++++--------------------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/compiler.py b/compiler.py index 7f52b0fd..d98a29d4 100644 --- a/compiler.py +++ b/compiler.py @@ -66,6 +66,11 @@ def compile(self, arg: str) -> str: raise NotImplementedError +class AcceptAny(MatchKind): + def compile(self, arg: str) -> str: + return "true" + + class IsNumber(MatchKind): def compile(self, arg: str) -> str: return f"is_num({arg})" @@ -121,10 +126,16 @@ class MatchExpr(Object): def group_cases(cases: typing.List[MatchCase], keyof: object) -> typing.List[typing.List[MatchCase]]: groups = {} for case in cases: - if keyof(case) in groups: - groups[keyof(case)].append(case) + if isinstance(case, Var): + if not groups: + raise NotImplementedError + else: + groups[list(groups.keys())[-1]].append(case) else: - groups[keyof(case)] = [case] + if keyof(case) in groups: + groups[keyof(case)].append(case) + else: + groups[keyof(case)] = [case] return list(groups.values()) @@ -152,6 +163,8 @@ def expand_group(arg: Var, group: typing.List[MatchCase]): if isinstance(canonical_case.pattern, Hole): # throwing away subsequent holes return CondExpr(arg, IsHole(), canonical_case.body) + if isinstance(canonical_case.pattern, Var): + return CondExpr(arg, AcceptAny(), Where(canonical_case.body, Assign(canonical_case.pattern, arg))) # if isinstance(canonical_case.pattern, Variant): # if isinstance(canonical_case.pattern, String): # if isinstance(canonical_case.pattern, Var): @@ -252,7 +265,7 @@ def compile_assign(self, env: Env, exp: Assign) -> Env: return {**env, name: value} if isinstance(exp.value, MatchFunction): # Named match function - value = self.compile_match_function(env, exp.value, name) + value = self.compile_function(env, compile_match_function(exp.value), name) return {**env, name: value} value = self.compile(env, exp.value) return {**env, name: value} @@ -362,29 +375,6 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En return updates raise NotImplementedError("try_match", pattern) - def compile_match_function(self, env: Env, exp: MatchFunction, name: Optional[str]) -> str: - arg = self.gensym() - fn = self.make_compiled_function(arg, exp, name) - self.functions.append(fn) - cur = self.function - self.function = fn - funcenv = self.compile_function_env(fn, name) - for i, case in enumerate(exp.cases): - fallthrough = f"case_{i+1}" if i < len(exp.cases) - 1 else "no_match" - env_updates = self.try_match(funcenv, arg, case.pattern, fallthrough) - case_result = self.compile({**funcenv, **env_updates}, case.body) - self._emit(f"return {case_result};") - self._emit(f"{fallthrough}:;") - self._emit(r'fprintf(stderr, "no matching cases\n");') - self._emit("abort();") - # Pacify the C compiler - self._emit("return NULL;") - self.function = cur - if not fn.fields: - # TODO(max): Closure over freevars but only consts - return self._const_closure(fn) - return self.make_closure(env, fn) - def make_closure(self, env: Env, fn: CompiledFunction) -> str: name = self._mktemp(f"mkclosure(heap, {fn.name}, {len(fn.fields)})") for i, field in enumerate(fn.fields): From 00cb47b7493bd71645089e5928fd2d862c279e41 Mon Sep 17 00:00:00 2001 From: Catharine Manning <2330+catharinejm@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:09:27 -0400 Subject: [PATCH 5/7] handle var matches as final fallthrough --- compiler.py | 90 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 68 insertions(+), 22 deletions(-) diff --git a/compiler.py b/compiler.py index d98a29d4..c38108fd 100644 --- a/compiler.py +++ b/compiler.py @@ -82,7 +82,8 @@ def compile(self, arg: str) -> str: class IsString(MatchKind): - pass + def compile(self, arg: str) -> str: + return f"is_string({arg})" class IsVar(MatchKind): @@ -110,6 +111,21 @@ def coerce_int(object: Object) -> int: return object.value +@dataclasses.dataclass +class StringHasValue(MatchKind): + value: str + + def compile(self, arg: str) -> str: + if len(self.value) < 8: + return f"({arg} == mksmallstring({json.dumps(self.value)}, {len(self.value)}))" + return f'string_equal_cstr_len({arg}, "{json.dumps(self.value)}", {len(self.value)})' + + +def coerce_string(object: Object) -> str: + assert isinstance(object, String) + return object.value + + @dataclasses.dataclass(frozen=True) class CondExpr(Object): arg: Var # Actually, probably this one isn't needed?? @@ -121,52 +137,77 @@ class CondExpr(Object): class MatchExpr(Object): arg: Object # Maybe not needed? cases: typing.List[CondExpr] + fallthrough_case: Where | None -def group_cases(cases: typing.List[MatchCase], keyof: object) -> typing.List[typing.List[MatchCase]]: +def group_cases( + cases: typing.List[MatchCase], keyof: object +) -> tuple[typing.List[typing.List[MatchCase]], MatchCase | None]: + print("ungrouped cases") + print(cases) groups = {} + fallthrough = None for case in cases: - if isinstance(case, Var): - if not groups: - raise NotImplementedError - else: - groups[list(groups.keys())[-1]].append(case) + if isinstance(case.pattern, Var): + fallthrough = case + # nothing can match after the var + break else: if keyof(case) in groups: groups[keyof(case)].append(case) else: groups[keyof(case)] = [case] - return list(groups.values()) + print("grouped cases") + print(groups) + return list(groups.values()), fallthrough def compile_match_function(match_fn: MatchFunction) -> Function: arg = Var("x") - cases = compile_ungrouped_match_cases(arg, match_fn.cases, lambda x: type(x.pattern).__name__) - return Function(arg, MatchExpr(arg, cases)) + cases, fallthrough_case = compile_ungrouped_match_cases(arg, match_fn.cases, lambda x: type(x.pattern).__name__) + return Function(arg, MatchExpr(arg, cases, fallthrough_case)) + +def compile_ungrouped_match_cases( + arg: Var, cases: typing.List[MatchCase], group_key: object +) -> tuple[typing.List[CondExpr], Where | None]: + grouped, fallthrough_case = group_cases(cases, group_key) + return [expand_group(arg, group, fallthrough_case) for group in grouped], compile_var_case(arg, fallthrough_case) -def compile_ungrouped_match_cases(arg: Var, cases: typing.List[MatchCase], group_key: object) -> typing.List[CondExpr]: - grouped = group_cases(cases, group_key) - return [expand_group(arg, group) for group in grouped] +def compile_var_case(arg: Var, case: MatchCase | None) -> Where | None: + if case: + assert isinstance(case.pattern, Var) + return Where(case.body, Assign(case.pattern, arg)) + return None -def compile_int_cases(arg: Var, group: typing.List[MatchCase]): + +def compile_int_cases(arg: Var, group: typing.List[MatchCase], fallthrough_case: MatchCase | None): cases = [CondExpr(arg, NumberHasValue(coerce_int(case.pattern)), case.body) for case in group] - return MatchExpr(arg, cases) + return MatchExpr(arg, cases, compile_var_case(arg, fallthrough_case)) + +def compile_string_cases(arg: Var, group: typing.List[MatchCase], fallthrough_case: MatchCase | None): + cases = [CondExpr(arg, StringHasValue(coerce_string(case.pattern)), case.body) for case in group] + return MatchExpr(arg, cases, compile_var_case(arg, fallthrough_case)) -def expand_group(arg: Var, group: typing.List[MatchCase]): + +def expand_group(arg: Var, group: typing.List[MatchCase], fallthrough_case: MatchCase | None): + if not group: + assert fallthrough_case + return compile_var_case(arg, fallthrough_case) canonical_case = group[0] if isinstance(canonical_case.pattern, Int): - return CondExpr(arg, IsNumber(), compile_int_cases(arg, group)) + return CondExpr(arg, IsNumber(), compile_int_cases(arg, group, fallthrough_case)) if isinstance(canonical_case.pattern, Hole): # throwing away subsequent holes return CondExpr(arg, IsHole(), canonical_case.body) if isinstance(canonical_case.pattern, Var): - return CondExpr(arg, AcceptAny(), Where(canonical_case.body, Assign(canonical_case.pattern, arg))) + raise Exception("saw a var") # if isinstance(canonical_case.pattern, Variant): - # if isinstance(canonical_case.pattern, String): + if isinstance(canonical_case.pattern, String): + return CondExpr(arg, IsString(), compile_string_cases(arg, group, fallthrough_case)) # if isinstance(canonical_case.pattern, Var): # if isinstance(canonical_case.pattern, List): # if isinstance(canonical_case.pattern, Record): @@ -541,7 +582,7 @@ def compile(self, env: Env, exp: Object) -> str: return self.compile_function(env, compile_match_function(exp), name=None) if isinstance(exp, MatchExpr): return self.compile_match_expr(env, exp) - raise NotImplementedError(f"exp {type(exp)} {exp}") + raise NotImplementedError(f"exp {type(exp)} {exp!r}") def compile_match_expr(self, env: Env, match_expr: MatchExpr) -> str: arg = self.compile(env, match_expr.arg) @@ -556,8 +597,13 @@ def compile_match_expr(self, env: Env, match_expr: MatchExpr) -> str: self._emit(f"{result} = {case_result};") self._emit(f"goto {done};") self._emit(f"{fallthrough}:;") - self._emit(r'fprintf(stderr, "no matching cases\n");') - self._emit("abort();") + if match_expr.fallthrough_case: + c_name = self.compile(env, match_expr.fallthrough_case) + self._emit(f"{result} = {c_name};") + self._emit(f"goto {done};") + else: + self._emit(r'fprintf(stderr, "no matching cases\n");') + self._emit("abort();") self._emit(f"{done}:;") return result From 830b2bc56ae8666ecc198c1980c243aef7811187 Mon Sep 17 00:00:00 2001 From: Catharine Manning <2330+catharinejm@users.noreply.github.com> Date: Fri, 25 Oct 2024 18:04:36 -0400 Subject: [PATCH 6/7] variants work! probably! --- compiler.py | 88 ++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 74 insertions(+), 14 deletions(-) diff --git a/compiler.py b/compiler.py index c38108fd..d62a052f 100644 --- a/compiler.py +++ b/compiler.py @@ -86,8 +86,9 @@ def compile(self, arg: str) -> str: return f"is_string({arg})" -class IsVar(MatchKind): - pass +class IsVariant(MatchKind): + def compile(self, arg: str) -> str: + return f"is_variant({arg})" class IsList(MatchKind): @@ -126,6 +127,14 @@ def coerce_string(object: Object) -> str: return object.value +@dataclasses.dataclass +class VariantHasTag(MatchKind): + tag: str + + def compile(self, arg: str) -> str: + return f"(variant_tag({arg}) == Tag_{self.tag})" + + @dataclasses.dataclass(frozen=True) class CondExpr(Object): arg: Var # Actually, probably this one isn't needed?? @@ -140,15 +149,20 @@ class MatchExpr(Object): fallthrough_case: Where | None +@dataclasses.dataclass(frozen=True) +class VariantValueExpr(Object): + variant: Object + + def group_cases( - cases: typing.List[MatchCase], keyof: object + cases: typing.List[MatchCase], keyof: object, is_fallthrough: object ) -> tuple[typing.List[typing.List[MatchCase]], MatchCase | None]: print("ungrouped cases") print(cases) groups = {} fallthrough = None for case in cases: - if isinstance(case.pattern, Var): + if is_fallthrough(case): fallthrough = case # nothing can match after the var break @@ -163,16 +177,29 @@ def group_cases( return list(groups.values()), fallthrough +def typename(case: MatchCase) -> str: + return type(case.pattern).__name__ + + +def pattern_is_var(case: MatchCase) -> bool: + return isinstance(case.pattern, Var) + + +def let(name: Var, value: Object, body: Object) -> Where: + return Where(body, Assign(name, value)) + + def compile_match_function(match_fn: MatchFunction) -> Function: - arg = Var("x") - cases, fallthrough_case = compile_ungrouped_match_cases(arg, match_fn.cases, lambda x: type(x.pattern).__name__) - return Function(arg, MatchExpr(arg, cases, fallthrough_case)) + fn_arg = Var(gensym("fn_arg")) + match_arg = Var(gensym("match")) + cases, fallthrough_case = compile_ungrouped_match_cases(match_arg, match_fn.cases, typename, pattern_is_var) + return Function(fn_arg, let(match_arg, fn_arg, MatchExpr(match_arg, cases, fallthrough_case))) def compile_ungrouped_match_cases( - arg: Var, cases: typing.List[MatchCase], group_key: object + arg: Var, cases: typing.List[MatchCase], group_key: object, is_fallthrough: object ) -> tuple[typing.List[CondExpr], Where | None]: - grouped, fallthrough_case = group_cases(cases, group_key) + grouped, fallthrough_case = group_cases(cases, group_key, is_fallthrough) return [expand_group(arg, group, fallthrough_case) for group in grouped], compile_var_case(arg, fallthrough_case) @@ -193,6 +220,26 @@ def compile_string_cases(arg: Var, group: typing.List[MatchCase], fallthrough_ca return MatchExpr(arg, cases, compile_var_case(arg, fallthrough_case)) +def compile_variant_cases(arg: Var, group: typing.List[MatchCase], fallthrough_case: MatchCase | None): + def case_tag(case: MatchCase): + assert isinstance(case.pattern, Variant) + return case.pattern.tag + + grouped_by_variant, _ = group_cases(group, case_tag, lambda x: False) + cond_exprs = [] + for group in grouped_by_variant: + lifted_matches = [MatchCase(case.pattern.value, case.body) for case in group] + print("lifted_matches", repr(lifted_matches)) + inner_arg = Var(gensym("variant_match")) + expanded_cases, inner_fallthrough_case = compile_ungrouped_match_cases( + inner_arg, lifted_matches, typename, pattern_is_var + ) + match_expr = let(inner_arg, VariantValueExpr(arg), MatchExpr(inner_arg, expanded_cases, inner_fallthrough_case)) + cond_exprs.append(CondExpr(arg, VariantHasTag(group[0].pattern.tag), match_expr)) + + return MatchExpr(arg, cond_exprs, compile_var_case(arg, fallthrough_case)) + + def expand_group(arg: Var, group: typing.List[MatchCase], fallthrough_case: MatchCase | None): if not group: assert fallthrough_case @@ -205,18 +252,27 @@ def expand_group(arg: Var, group: typing.List[MatchCase], fallthrough_case: Matc return CondExpr(arg, IsHole(), canonical_case.body) if isinstance(canonical_case.pattern, Var): raise Exception("saw a var") - # if isinstance(canonical_case.pattern, Variant): + if isinstance(canonical_case.pattern, Variant): + return CondExpr(arg, IsVariant(), compile_variant_cases(arg, group, fallthrough_case)) if isinstance(canonical_case.pattern, String): return CondExpr(arg, IsString(), compile_string_cases(arg, group, fallthrough_case)) - # if isinstance(canonical_case.pattern, Var): # if isinstance(canonical_case.pattern, List): # if isinstance(canonical_case.pattern, Record): raise NotImplementedError("expand_group", canonical_case.pattern) +gensym_counter = 0 + + +def gensym(stem: str = "tmp") -> str: + global gensym_counter + gensym_counter += 1 + return f"{stem}_{gensym_counter-1}" + + class Compiler: def __init__(self, main_fn: CompiledFunction) -> None: - self.gensym_counter: int = 0 + # self.gensym_counter: int = 0 self.functions: typing.List[CompiledFunction] = [main_fn] self.function: CompiledFunction = main_fn self.record_keys: Dict[str, int] = {} @@ -259,8 +315,7 @@ def variant_tag(self, key: str) -> int: return result def gensym(self, stem: str = "tmp") -> str: - self.gensym_counter += 1 - return f"{stem}_{self.gensym_counter-1}" + return gensym(stem) def _emit(self, line: str) -> None: self.function.code.append(line) @@ -582,6 +637,9 @@ def compile(self, env: Env, exp: Object) -> str: return self.compile_function(env, compile_match_function(exp), name=None) if isinstance(exp, MatchExpr): return self.compile_match_expr(env, exp) + if isinstance(exp, VariantValueExpr): + value = self.compile(env, exp.variant) + return self._mktemp(f"variant_value({value});") raise NotImplementedError(f"exp {type(exp)} {exp!r}") def compile_match_expr(self, env: Env, match_expr: MatchExpr) -> str: @@ -590,6 +648,8 @@ def compile_match_expr(self, env: Env, match_expr: MatchExpr) -> str: done = self.gensym("done") self._emit(f"struct object* {result} = NULL;") for cond in match_expr.cases: + if isinstance(cond.condition, VariantHasTag): + self.variant_tag(cond.condition.tag) fallthrough = self.gensym("case") c_cond = cond.condition.compile(arg) self._emit(f"if (!{c_cond}) goto {fallthrough};") From b55fc00db47166b35ff6c3a7ba23d68dfc6b633f Mon Sep 17 00:00:00 2001 From: Catharine Manning <2330+catharinejm@users.noreply.github.com> Date: Fri, 25 Oct 2024 18:19:55 -0400 Subject: [PATCH 7/7] remove extra quotes on long string match emitted code --- compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler.py b/compiler.py index d62a052f..893283b6 100644 --- a/compiler.py +++ b/compiler.py @@ -119,7 +119,7 @@ class StringHasValue(MatchKind): def compile(self, arg: str) -> str: if len(self.value) < 8: return f"({arg} == mksmallstring({json.dumps(self.value)}, {len(self.value)}))" - return f'string_equal_cstr_len({arg}, "{json.dumps(self.value)}", {len(self.value)})' + return f"string_equal_cstr_len({arg}, {json.dumps(self.value)}, {len(self.value)})" def coerce_string(object: Object) -> str: