@@ -47,6 +47,7 @@ import dotty.tools.dotc.qualified_types.ENode.Op
47
47
import dotty .tools .dotc .reporting .trace
48
48
import dotty .tools .dotc .transform .TreeExtractors .BinaryOp
49
49
import dotty .tools .dotc .util .Spans .Span
50
+ import scala .collection .mutable .ListBuffer
50
51
51
52
final class EGraph (rootCtx : Context ):
52
53
@@ -72,23 +73,23 @@ final class EGraph(rootCtx: Context):
72
73
/** Map used for hash-consing nodes, keys and values are the same */
73
74
private val index = mutable.Map .empty[ENode , ENode ]
74
75
75
- val trueNode : ENode .Atom = ENode .Atom (ConstantType (Constant (true ))(using rootCtx))
76
+ final val trueNode : ENode .Atom = ENode .Atom (ConstantType (Constant (true ))(using rootCtx))
76
77
index(trueNode) = trueNode
77
78
78
- val falseNode : ENode .Atom = ENode .Atom (ConstantType (Constant (false ))(using rootCtx))
79
+ final val falseNode : ENode .Atom = ENode .Atom (ConstantType (Constant (false ))(using rootCtx))
79
80
index(falseNode) = falseNode
80
81
81
- val minusOneIntNode : ENode .Atom = ENode .Atom (ConstantType (Constant (- 1 ))(using rootCtx))
82
+ final val minusOneIntNode : ENode .Atom = ENode .Atom (ConstantType (Constant (- 1 ))(using rootCtx))
82
83
index(minusOneIntNode) = minusOneIntNode
83
84
84
- val zeroIntNode : ENode .Atom = ENode .Atom (ConstantType (Constant (0 ))(using rootCtx))
85
+ final val zeroIntNode : ENode .Atom = ENode .Atom (ConstantType (Constant (0 ))(using rootCtx))
85
86
index(zeroIntNode) = zeroIntNode
86
87
87
- val oneIntNode : ENode .Atom = ENode .Atom (ConstantType (Constant (1 ))(using rootCtx))
88
+ final val oneIntNode : ENode .Atom = ENode .Atom (ConstantType (Constant (1 ))(using rootCtx))
88
89
index(oneIntNode) = oneIntNode
89
90
90
- val d = defn(using rootCtx) // Need a stable path to match on `defn` members
91
- val builtinOps = Map (
91
+ private val d = defn(using rootCtx) // Need a stable path to match on `defn` members
92
+ private val builtinOps = Map (
92
93
d.Int_== -> Op .Equal ,
93
94
d.Boolean_== -> Op .Equal ,
94
95
d.Any_== -> Op .Equal ,
@@ -137,16 +138,16 @@ final class EGraph(rootCtx: Context):
137
138
}
138
139
).asInstanceOf [node.type ]
139
140
140
- def toNode (tree : Tree , paramSyms : List [Symbol ] = Nil , paramNodes : List [ENode .ArgRefType ] = Nil )(using
141
+ def toNode (tree : Tree , paramSyms : List [Symbol ] = Nil , paramTps : List [ENode .ArgRefType ] = Nil )(using
141
142
Context
142
143
): Option [ENode ] =
143
144
trace(i " EGraph.toNode $tree" , Printers .qualifiedTypes):
144
- computeToNode(tree, paramSyms, paramNodes ).map(node => representent(unique(node)))
145
+ computeToNode(tree, paramSyms, paramTps ).map(node => representent(unique(node)))
145
146
146
147
private def computeToNode (
147
148
tree : Tree ,
148
149
paramSyms : List [Symbol ] = Nil ,
149
- paramNodes : List [ENode .ArgRefType ] = Nil
150
+ paramTps : List [ENode .ArgRefType ] = Nil
150
151
)(using currentCtx : Context ): Option [ENode ] =
151
152
trace(i " ENode.computeToNode $tree" , Printers .qualifiedTypes):
152
153
def normalizeType (tp : Type ): Type =
@@ -159,48 +160,45 @@ final class EGraph(rootCtx: Context):
159
160
case tp => tp
160
161
161
162
def mapType (tp : Type ): Type =
162
- normalizeType(tp.subst(paramSyms, paramNodes ))
163
+ normalizeType(tp.subst(paramSyms, paramTps ))
163
164
164
165
tree match
165
166
case Literal (_) | Ident (_) | This (_) if tree.tpe.isInstanceOf [SingletonType ] =>
166
167
Some (ENode .Atom (mapType(tree.tpe).asInstanceOf [SingletonType ]))
167
168
case New (clazz) =>
168
- for clazzNode <- toNode(clazz, paramSyms, paramNodes ) yield ENode .New (clazzNode)
169
+ for clazzNode <- toNode(clazz, paramSyms, paramTps ) yield ENode .New (clazzNode)
169
170
case Select (qual, name) =>
170
- for qualNode <- toNode(qual, paramSyms, paramNodes ) yield ENode .Select (qualNode, tree.symbol)
171
+ for qualNode <- toNode(qual, paramSyms, paramTps ) yield ENode .Select (qualNode, tree.symbol)
171
172
case BinaryOp (lhs, op, rhs) if builtinOps.contains(op) =>
172
173
for
173
- lhsNode <- toNode(lhs, paramSyms, paramNodes )
174
- rhsNode <- toNode(rhs, paramSyms, paramNodes )
174
+ lhsNode <- toNode(lhs, paramSyms, paramTps )
175
+ rhsNode <- toNode(rhs, paramSyms, paramTps )
175
176
yield normalizeOp(builtinOps(op), List (lhsNode, rhsNode))
176
177
case BinaryOp (lhs, d.Int_- , rhs) if lhs.tpe.isInstanceOf [ValueType ] && rhs.tpe.isInstanceOf [ValueType ] =>
177
178
for
178
- lhsNode <- toNode(lhs, paramSyms, paramNodes )
179
- rhsNode <- toNode(rhs, paramSyms, paramNodes )
179
+ lhsNode <- toNode(lhs, paramSyms, paramTps )
180
+ rhsNode <- toNode(rhs, paramSyms, paramTps )
180
181
yield normalizeOp(Op .IntSum , List (lhsNode, normalizeOp(Op .IntProduct , List (minusOneIntNode, rhsNode))))
181
182
case Apply (fun, args) =>
182
183
for
183
- funNode <- toNode(fun, paramSyms, paramNodes )
184
- argsNodes <- args.map(toNode(_, paramSyms, paramNodes )).sequence
184
+ funNode <- toNode(fun, paramSyms, paramTps )
185
+ argsNodes <- args.map(toNode(_, paramSyms, paramTps )).sequence
185
186
yield ENode .Apply (funNode, argsNodes)
186
187
case TypeApply (fun, args) =>
187
- for funNode <- toNode(fun, paramSyms, paramNodes )
188
+ for funNode <- toNode(fun, paramSyms, paramTps )
188
189
yield ENode .TypeApply (funNode, args.map(tp => mapType(tp.tpe)))
189
190
case closureDef(defDef) =>
190
191
defDef.symbol.info.dealias match
191
192
case mt : MethodType =>
192
193
assert(defDef.termParamss.size == 1 , " closures have a single parameter list, right?" )
193
- val params = defDef.termParamss.head
194
- val myParamSyms = params.map(_.symbol)
195
-
196
- val myParamTps : ArrayBuffer [Type ] = ArrayBuffer .empty
197
- ???
198
-
199
- val myRetTp = ???
200
-
201
- val myParamNodes = myParamTps.zipWithIndex.map((tp, i) => ENode .ArgRefType (i, tp)).toList
202
-
203
- for body <- toNode(defDef.rhs, myParamSyms ::: paramSyms, myParamNodes ::: paramNodes)
194
+ val myParamSyms : List [Symbol ] = defDef.termParamss.head.map(_.symbol)
195
+ val myParamTps : ListBuffer [ENode .ArgRefType ] = ListBuffer .empty
196
+ val paramTpsSize = paramTps.size
197
+ for myParamSym <- myParamSyms do
198
+ val underlying = mapType(myParamSym.info.subst(myParamSyms.take(myParamTps.size), myParamTps.toList))
199
+ myParamTps += ENode .ArgRefType (paramTpsSize + myParamTps.size, underlying)
200
+ val myRetTp = mapType(defDef.tpt.tpe.subst(myParamSyms, myParamTps.toList))
201
+ for body <- toNode(defDef.rhs, myParamSyms ::: paramSyms, myParamTps.toList ::: paramTps)
204
202
yield ENode .Lambda (myParamTps.toList, myRetTp, body)
205
203
case _ => None
206
204
case _ =>
@@ -222,15 +220,15 @@ final class EGraph(rootCtx: Context):
222
220
case ENode .TypeApply (fn, args) =>
223
221
ENode .TypeApply (representent(fn), args)
224
222
case ENode .Lambda (paramTps, retTp, body) =>
225
-
226
223
ENode .Lambda (paramTps, retTp, representent(body))
227
224
))
228
225
229
226
private def normalizeOp (op : ENode .Op , args : List [ENode ]): ENode =
230
227
op match
231
228
case Op .Equal =>
232
229
assert(args.size == 2 , s " Expected 2 arguments for equality, got $args" )
233
- if args(0 ) eq args(1 ) then trueNode
230
+ if args(0 ) eq args(1 ) then
231
+ trueNode
234
232
else ENode .OpApply (op, args.sortBy(_.hashCode()))
235
233
case Op .And =>
236
234
assert(args.size == 2 , s " Expected 2 arguments for conjunction, got $args" )
0 commit comments