Skip to content

Commit b5b0a92

Browse files
committed
Implement lambdas to E-Node conversion
1 parent 9c76814 commit b5b0a92

File tree

4 files changed

+57
-33
lines changed

4 files changed

+57
-33
lines changed

compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ import dotty.tools.dotc.qualified_types.ENode.Op
4747
import dotty.tools.dotc.reporting.trace
4848
import dotty.tools.dotc.transform.TreeExtractors.BinaryOp
4949
import dotty.tools.dotc.util.Spans.Span
50+
import scala.collection.mutable.ListBuffer
5051

5152
final class EGraph(rootCtx: Context):
5253

@@ -72,23 +73,23 @@ final class EGraph(rootCtx: Context):
7273
/** Map used for hash-consing nodes, keys and values are the same */
7374
private val index = mutable.Map.empty[ENode, ENode]
7475

75-
val trueNode: ENode.Atom = ENode.Atom(ConstantType(Constant(true))(using rootCtx))
76+
final val trueNode: ENode.Atom = ENode.Atom(ConstantType(Constant(true))(using rootCtx))
7677
index(trueNode) = trueNode
7778

78-
val falseNode: ENode.Atom = ENode.Atom(ConstantType(Constant(false))(using rootCtx))
79+
final val falseNode: ENode.Atom = ENode.Atom(ConstantType(Constant(false))(using rootCtx))
7980
index(falseNode) = falseNode
8081

81-
val minusOneIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(-1))(using rootCtx))
82+
final val minusOneIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(-1))(using rootCtx))
8283
index(minusOneIntNode) = minusOneIntNode
8384

84-
val zeroIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(0))(using rootCtx))
85+
final val zeroIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(0))(using rootCtx))
8586
index(zeroIntNode) = zeroIntNode
8687

87-
val oneIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(1))(using rootCtx))
88+
final val oneIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(1))(using rootCtx))
8889
index(oneIntNode) = oneIntNode
8990

90-
val d = defn(using rootCtx) // Need a stable path to match on `defn` members
91-
val builtinOps = Map(
91+
private val d = defn(using rootCtx) // Need a stable path to match on `defn` members
92+
private val builtinOps = Map(
9293
d.Int_== -> Op.Equal,
9394
d.Boolean_== -> Op.Equal,
9495
d.Any_== -> Op.Equal,
@@ -137,16 +138,16 @@ final class EGraph(rootCtx: Context):
137138
}
138139
).asInstanceOf[node.type]
139140

140-
def toNode(tree: Tree, paramSyms: List[Symbol] = Nil, paramNodes: List[ENode.ArgRefType] = Nil)(using
141+
def toNode(tree: Tree, paramSyms: List[Symbol] = Nil, paramTps: List[ENode.ArgRefType] = Nil)(using
141142
Context
142143
): Option[ENode] =
143144
trace(i"EGraph.toNode $tree", Printers.qualifiedTypes):
144-
computeToNode(tree, paramSyms, paramNodes).map(node => representent(unique(node)))
145+
computeToNode(tree, paramSyms, paramTps).map(node => representent(unique(node)))
145146

146147
private def computeToNode(
147148
tree: Tree,
148149
paramSyms: List[Symbol] = Nil,
149-
paramNodes: List[ENode.ArgRefType] = Nil
150+
paramTps: List[ENode.ArgRefType] = Nil
150151
)(using currentCtx: Context): Option[ENode] =
151152
trace(i"ENode.computeToNode $tree", Printers.qualifiedTypes):
152153
def normalizeType(tp: Type): Type =
@@ -159,48 +160,45 @@ final class EGraph(rootCtx: Context):
159160
case tp => tp
160161

161162
def mapType(tp: Type): Type =
162-
normalizeType(tp.subst(paramSyms, paramNodes))
163+
normalizeType(tp.subst(paramSyms, paramTps))
163164

164165
tree match
165166
case Literal(_) | Ident(_) | This(_) if tree.tpe.isInstanceOf[SingletonType] =>
166167
Some(ENode.Atom(mapType(tree.tpe).asInstanceOf[SingletonType]))
167168
case New(clazz) =>
168-
for clazzNode <- toNode(clazz, paramSyms, paramNodes) yield ENode.New(clazzNode)
169+
for clazzNode <- toNode(clazz, paramSyms, paramTps) yield ENode.New(clazzNode)
169170
case Select(qual, name) =>
170-
for qualNode <- toNode(qual, paramSyms, paramNodes) yield ENode.Select(qualNode, tree.symbol)
171+
for qualNode <- toNode(qual, paramSyms, paramTps) yield ENode.Select(qualNode, tree.symbol)
171172
case BinaryOp(lhs, op, rhs) if builtinOps.contains(op) =>
172173
for
173-
lhsNode <- toNode(lhs, paramSyms, paramNodes)
174-
rhsNode <- toNode(rhs, paramSyms, paramNodes)
174+
lhsNode <- toNode(lhs, paramSyms, paramTps)
175+
rhsNode <- toNode(rhs, paramSyms, paramTps)
175176
yield normalizeOp(builtinOps(op), List(lhsNode, rhsNode))
176177
case BinaryOp(lhs, d.Int_-, rhs) if lhs.tpe.isInstanceOf[ValueType] && rhs.tpe.isInstanceOf[ValueType] =>
177178
for
178-
lhsNode <- toNode(lhs, paramSyms, paramNodes)
179-
rhsNode <- toNode(rhs, paramSyms, paramNodes)
179+
lhsNode <- toNode(lhs, paramSyms, paramTps)
180+
rhsNode <- toNode(rhs, paramSyms, paramTps)
180181
yield normalizeOp(Op.IntSum, List(lhsNode, normalizeOp(Op.IntProduct, List(minusOneIntNode, rhsNode))))
181182
case Apply(fun, args) =>
182183
for
183-
funNode <- toNode(fun, paramSyms, paramNodes)
184-
argsNodes <- args.map(toNode(_, paramSyms, paramNodes)).sequence
184+
funNode <- toNode(fun, paramSyms, paramTps)
185+
argsNodes <- args.map(toNode(_, paramSyms, paramTps)).sequence
185186
yield ENode.Apply(funNode, argsNodes)
186187
case TypeApply(fun, args) =>
187-
for funNode <- toNode(fun, paramSyms, paramNodes)
188+
for funNode <- toNode(fun, paramSyms, paramTps)
188189
yield ENode.TypeApply(funNode, args.map(tp => mapType(tp.tpe)))
189190
case closureDef(defDef) =>
190191
defDef.symbol.info.dealias match
191192
case mt: MethodType =>
192193
assert(defDef.termParamss.size == 1, "closures have a single parameter list, right?")
193-
val params = defDef.termParamss.head
194-
val myParamSyms = params.map(_.symbol)
195-
196-
val myParamTps: ArrayBuffer[Type] = ArrayBuffer.empty
197-
???
198-
199-
val myRetTp = ???
200-
201-
val myParamNodes = myParamTps.zipWithIndex.map((tp, i) => ENode.ArgRefType(i, tp)).toList
202-
203-
for body <- toNode(defDef.rhs, myParamSyms ::: paramSyms, myParamNodes ::: paramNodes)
194+
val myParamSyms: List[Symbol] = defDef.termParamss.head.map(_.symbol)
195+
val myParamTps: ListBuffer[ENode.ArgRefType] = ListBuffer.empty
196+
val paramTpsSize = paramTps.size
197+
for myParamSym <- myParamSyms do
198+
val underlying = mapType(myParamSym.info.subst(myParamSyms.take(myParamTps.size), myParamTps.toList))
199+
myParamTps += ENode.ArgRefType(paramTpsSize + myParamTps.size, underlying)
200+
val myRetTp = mapType(defDef.tpt.tpe.subst(myParamSyms, myParamTps.toList))
201+
for body <- toNode(defDef.rhs, myParamSyms ::: paramSyms, myParamTps.toList ::: paramTps)
204202
yield ENode.Lambda(myParamTps.toList, myRetTp, body)
205203
case _ => None
206204
case _ =>
@@ -222,15 +220,15 @@ final class EGraph(rootCtx: Context):
222220
case ENode.TypeApply(fn, args) =>
223221
ENode.TypeApply(representent(fn), args)
224222
case ENode.Lambda(paramTps, retTp, body) =>
225-
226223
ENode.Lambda(paramTps, retTp, representent(body))
227224
))
228225

229226
private def normalizeOp(op: ENode.Op, args: List[ENode]): ENode =
230227
op match
231228
case Op.Equal =>
232229
assert(args.size == 2, s"Expected 2 arguments for equality, got $args")
233-
if args(0) eq args(1) then trueNode
230+
if args(0) eq args(1) then
231+
trueNode
234232
else ENode.OpApply(op, args.sortBy(_.hashCode()))
235233
case Op.And =>
236234
assert(args.size == 2, s"Expected 2 arguments for conjunction, got $args")
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
def toBool[T](x: T): Boolean = ???
2+
def tp[T](): Any = ???
3+
4+
def test: Unit =
5+
val x: {l: List[Int] with toBool((x: String, y: x.type) => x.length > 0)} = ??? // error: cannot turn method type into closure because it has internal parameter dependencies
6+
summon[{l: List[Int] with toBool((x: String, y: String) => tp[x.type]())} =:= {l: List[Int] with toBool((x: String, y: String) => tp[y.type]())}] // error
7+
8+
summon[{l: List[Int] with toBool((x: Double) => (y: Int) => x == y)} =:= {l: List[Int] with toBool((a: Double) => (b: Int) => a == a)}] // error
9+
summon[{l: List[Int] with toBool((x: Int) => (y: Int) => x == y)} =:= {l: List[Int] with toBool((a: Int) => (b: Int) => a == a)}] // error
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
def toBool[T](x: T): Boolean = ???
2+
def tp[T](): Any = ???
3+
4+
5+
def test: Unit =
6+
summon[{l: List[Int] with l.forall(x => x > 0)} =:= {l: List[Int] with l.forall(x => x > 0)}]
7+
summon[{l: List[Int] with l.forall(x => x > 0)} =:= {l: List[Int] with l.forall(y => y > 0)}]
8+
summon[{l: List[Int] with l.forall(x => x > 0)} =:= {l: List[Int] with l.forall(_ > 0)}]
9+
10+
summon[{l: List[Int] with toBool((x: String) => x.length > 0)} =:= {l: List[Int] with toBool((y: String) => y.length > 0)}]
11+
12+
summon[{l: List[Int] with toBool((x: String) => tp[x.type]())} =:= {l: List[Int] with toBool((y: String) => tp[y.type]())}]
13+
summon[{l: List[Int] with toBool((x: String, y: String) => tp[x.type]())} =:= {l: List[Int] with toBool((x: String, y: String) => tp[x.type]())}]
14+
summon[{l: List[Int] with toBool((x: String) => tp[x.type]())} =:= {l: List[Int] with toBool((y: String) => tp[y.type]())}]
15+
summon[{l: List[Int] with toBool((x: String, y: String) => tp[y.type]())} =:= {l: List[Int] with toBool((x: String, y: String) => tp[y.type]())}]
16+
17+
summon[{l: List[Int] with toBool((x: String) => (y: String) => x == y)} =:= {l: List[Int] with toBool((a: String) => (b: String) => a == b)}]

0 commit comments

Comments
 (0)