diff --git a/compiler.py b/compiler.py index fab688b5..0a9f61c2 100644 --- a/compiler.py +++ b/compiler.py @@ -30,6 +30,10 @@ type_of, IntType, StringType, + TyEmptyRow, + TyRow, + TyCon, + row_flatten, parse, # needed for /compilerepl tokenize, # needed for /compilerepl ) @@ -131,14 +135,35 @@ def _guard(self, cond: str, msg: Optional[str] = None) -> None: self._emit("abort();") self._emit("}") + def _is_int(self, exp: Object) -> bool: + return type_of(exp) == IntType + + def _is_list(self, exp: Object) -> bool: + ty = type_of(exp) + return isinstance(ty, TyCon) and ty.name == "list" + + def _is_hole(self, exp: Object) -> bool: + ty = type_of(exp) + return isinstance(ty, TyCon) and ty.name == "hole" + + def _is_record(self, exp: Object) -> bool: + return isinstance(type_of(exp), TyRow) or isinstance(type_of(exp), TyEmptyRow) + def _guard_int(self, exp: Object, c_name: str) -> None: - if type_of(exp) != IntType: + if not self._is_int(exp): self._guard(f"is_num({c_name})") def _guard_str(self, exp: Object, c_name: str) -> None: if type_of(exp) != StringType: self._guard(f"is_string({c_name})") + def _guaranteed_has_field(self, exp: Object, name: str) -> bool: + ty = type_of(exp) + if not isinstance(ty, TyRow): + return False + fields, _ = row_flatten(ty) + return name in fields + def _mktemp(self, exp: str) -> str: temp = self.gensym() return self._handle(temp, exp) @@ -193,10 +218,12 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En # TODO(max): Give `arg` an AST node so we can track its inferred type # and make use of that in pattern matching if isinstance(pattern, Int): - self._emit(f"if (!is_num_equal_word({arg}, {pattern.value})) {{ goto {fallthrough}; }}") + if not self._is_int(pattern): + self._emit(f"if (!is_num_equal_word({arg}, {pattern.value})) {{ goto {fallthrough}; }}") return {} if isinstance(pattern, Hole): - self._emit(f"if (!is_hole({arg})) {{ goto {fallthrough}; }}") + if not self._is_hole(pattern): + self._emit(f"if (!is_hole({arg})) {{ goto {fallthrough}; }}") return {} if isinstance(pattern, Variant): self.variant_tag(pattern.tag) # register it for the big enum @@ -205,6 +232,7 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En # necessary; the non-Hole case would work just fine. self._emit(f"if ({arg} != mk_immediate_variant(Tag_{pattern.tag})) {{ goto {fallthrough}; }}") return {} + # TODO(max): Check if it's a variant self._emit(f"if (!is_variant({arg})) {{ goto {fallthrough}; }}") self._emit(f"if (variant_tag({arg}) != Tag_{pattern.tag}) {{ goto {fallthrough}; }}") return self.try_match(env, self._mktemp(f"variant_value({arg})"), pattern.value, fallthrough) @@ -214,7 +242,8 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En if len(value) < 8: self._emit(f"if ({arg} != mksmallstring({json.dumps(value)}, {len(value)})) {{ goto {fallthrough}; }}") return {} - self._emit(f"if (!is_string({arg})) {{ goto {fallthrough}; }}") + if not self._is_string(pattern): + self._emit(f"if (!is_string({arg})) {{ goto {fallthrough}; }}") self._emit( f"if (!string_equal_cstr_len({arg}, {json.dumps(value)}, {len(value)})) {{ goto {fallthrough}; }}" ) @@ -222,7 +251,8 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En if isinstance(pattern, Var): return {pattern.name: arg} if isinstance(pattern, List): - self._emit(f"if (!is_list({arg})) {{ goto {fallthrough}; }}") + if not self._is_list(pattern): + self._emit(f"if (!is_list({arg})) {{ goto {fallthrough}; }}") updates = {} the_list = arg use_spread = False @@ -242,7 +272,8 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En self._emit(f"if (!is_empty_list({the_list})) {{ goto {fallthrough}; }}") return updates if isinstance(pattern, Record): - self._emit(f"if (!is_record({arg})) {{ goto {fallthrough}; }}") + if not self._is_record(pattern): + self._emit(f"if (!is_record({arg})) {{ goto {fallthrough}; }}") updates = {} use_spread = False for key, pattern_value in pattern.data.items(): @@ -253,9 +284,11 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En break key_idx = self.record_key(key) record_value = self._mktemp(f"record_get({arg}, {key_idx})") - # TODO(max): If the key is present in the type, don't emit this - # check - self._emit(f"if ({record_value} == NULL) {{ goto {fallthrough}; }}") + # TODO(max): Figure out another way to do this. It's a bit of a + # hack to check the pattern type *even though* it's supposed to + # be unified with the arg type + if not self._guaranteed_has_field(pattern, key): + self._emit(f"if ({record_value} == NULL) {{ goto {fallthrough}; }}") updates.update(self.try_match(env, record_value, pattern_value, fallthrough)) if not use_spread: self._emit(f"if (record_num_fields({arg}) != {len(pattern.data)}) {{ goto {fallthrough}; }}") @@ -439,9 +472,11 @@ def compile(self, env: Env, exp: Object) -> str: record = self.compile(env, exp.obj) key_idx = self.record_key(exp.at.name) # Check if the record is a record - self._guard(f"is_record({record})", "not a record") + if not self._is_record(exp.obj): + self._guard(f"is_record({record})", "not a record") value = self._mktemp(f"record_get({record}, {key_idx})") - self._guard(f"{value} != NULL", f"missing key {exp.at.name!s}") + if not self._guaranteed_has_field(exp.obj, exp.at.name): + self._guard(f"{value} != NULL", f"missing key {exp.at.name!s}") return value if isinstance(exp, Function): # Anonymous function