Skip to content

Commit 6bb1ce6

Browse files
authored
Merge pull request #34 from saulshanabrook/simplify-representation
Unified Representation and NumPy Overloading
2 parents 9a27e40 + 02f32b3 commit 6bb1ce6

17 files changed

+351
-872
lines changed

example_rewrite.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,21 @@
55
llvmlite.opaque_pointers_enabled = True
66

77
from mlir_egglog import kernel
8-
from mlir_egglog.term_ir import Term, Sin, Cos, Mul, Add
8+
from mlir_egglog.term_ir import Term, sin, cos
99
from egglog import rewrite, ruleset
1010
from mlir_egglog.optimization_rules import basic_math
1111

1212

1313
# A rewrite rule
1414
@ruleset
1515
def trig_double_angle(a: Term):
16-
sin_a = Sin(a)
17-
cos_a = Cos(a)
18-
mul1 = Mul(sin_a, cos_a)
19-
mul2 = Mul(cos_a, sin_a)
16+
sin_a = sin(a)
17+
cos_a = cos(a)
18+
mul1 = sin_a * cos_a
19+
mul2 = cos_a * sin_a
2020

2121
# sin(a)*cos(a) + cos(a)*sin(a) -> 2 * sin(a)*cos(a)
22-
yield rewrite(Add(mul1, mul2)).to(Mul(Term.lit_f32(2.0), mul1))
22+
yield rewrite(mul1 + mul2).to(Term.lit_f32(2.0) * mul1)
2323

2424

2525
# Apply the rewrites

pyproject.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "mlir-egglog"
33
version = "0.1.0"
44
requires-python = ">=3.12"
55
dependencies = [
6-
"egglog>=8.0.1",
6+
"egglog>=11.0.0",
77
"llvmlite>=0.44.0",
88
"numpy>=2.2.5",
99
"pyyaml>=6.0.2",
@@ -12,7 +12,7 @@ dependencies = [
1212
[dependency-groups]
1313
dev = [
1414
"black>=25.1.0",
15-
"mypy>=1.15.0",
15+
"mypy>=1.17.1",
1616
"ruff>=0.9.10",
1717
"pytest>=8.3.5",
1818
]
@@ -34,3 +34,7 @@ disallow_subclassing_any = "false"
3434
disallow_untyped_decorators = "false"
3535
disallow_any_generics = "true"
3636
follow_imports = "silent"
37+
38+
[tool.uv.sources]
39+
# until https://github.com/python/mypy/pull/19600 is merged
40+
mypy = { git = "https://github.com/saulshanabrook/mypy", branch = "issue-19599" }

src/mlir_egglog/__init__.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,7 @@
99
# Version of the mlir-egglog package
1010
__version__ = "0.1.0"
1111

12-
from mlir_egglog.expr_model import ( # noqa: E402
13-
BinaryOp,
14-
FloatLiteral,
15-
IntLiteral,
16-
Symbol,
17-
)
18-
from mlir_egglog.term_ir import Term, as_egraph # noqa: E402
12+
from mlir_egglog.term_ir import Term # noqa: E402
1913
from mlir_egglog.dispatcher import kernel # noqa: E402
2014

2115
__all__ = [
@@ -24,6 +18,5 @@
2418
"IntLiteral",
2519
"Symbol",
2620
"Term",
27-
"as_egraph",
2821
"kernel",
2922
]

src/mlir_egglog/builtin_functions.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

src/mlir_egglog/dispatcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(
2424
py_func: types.FunctionType,
2525
rewrites: tuple[RewriteOrRule | Ruleset, ...] | None = None,
2626
):
27-
self.py_func = py_func
27+
self.py_func = py_func # type: ignore[assignment]
2828
self._compiled_func = None
2929
self._compiler = None
3030
self.rewrites = rewrites

src/mlir_egglog/egglog_optimizer.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,19 @@
44
from egglog import EGraph, RewriteOrRule, Ruleset
55
from egglog.egraph import UnstableCombinedRuleset
66

7-
from mlir_egglog.term_ir import Term, as_egraph
7+
from mlir_egglog.term_ir import Term
88
from mlir_egglog.python_to_ir import interpret
9-
from mlir_egglog import builtin_functions as ns
10-
from mlir_egglog.expr_model import Expr
11-
from mlir_egglog.ir_to_mlir import convert_term_to_mlir
9+
from mlir_egglog.mlir_gen import MLIRGen
1210

1311
# Rewrite rules
1412
from mlir_egglog.optimization_rules import basic_math, trig_simplify
1513

1614
OPTS: tuple[Ruleset | RewriteOrRule, ...] = (basic_math, trig_simplify)
1715

1816

19-
def extract(ast: Expr, rules: tuple[RewriteOrRule | Ruleset, ...], debug=False) -> Term:
20-
root = as_egraph(ast)
21-
17+
def extract(
18+
root: Term, rules: tuple[RewriteOrRule | Ruleset, ...], debug=False
19+
) -> Term:
2220
egraph = EGraph()
2321
egraph.let("root", root)
2422

@@ -47,10 +45,21 @@ def compile(
4745
fn: FunctionType, rewrites: tuple[RewriteOrRule | Ruleset, ...] = OPTS, debug=True
4846
) -> str:
4947
# Convert np functions according to the namespace map
50-
exprtree = interpret(fn, {"np": ns})
48+
exprtree = interpret(fn)
5149
extracted = extract(exprtree, rewrites, debug)
5250

5351
# Get the argument spec
5452
argspec = inspect.signature(fn)
5553
params = ",".join(map(str, argspec.parameters))
5654
return convert_term_to_mlir(extracted, params)
55+
56+
57+
def convert_term_to_mlir(tree: Term, argspec: str) -> str:
58+
"""
59+
Convert a term to MLIR.
60+
"""
61+
62+
argnames = map(lambda x: x.strip(), argspec.split(","))
63+
argmap = {k: f"%arg_{k}" for k in argnames}
64+
source = MLIRGen(tree, argmap).generate()
65+
return source

0 commit comments

Comments
 (0)