Skip to content

Commit e02b772

Browse files
smartermilessabin
authored andcommitted
Support polymorphic function values
A polymorphic function value can be written as: new PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R = body } This is erased to: new FunctionN { def apply(x_1: Object, ..., x_N: Object): Object = body } Getting everything to erase correctly was tricky, the current implementation is a bit messy currently.
1 parent 0742175 commit e02b772

File tree

6 files changed

+95
-4
lines changed

6 files changed

+95
-4
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ class Compiler {
9090
List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements.
9191
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
9292
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations
93+
new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses
9394
new TailRec, // Rewrite tail recursion to loops
9495
new Mixin, // Expand trait fields and trait initializers
9596
new LazyVals, // Expand lazy vals

compiler/src/dotty/tools/dotc/core/TypeErasure.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,24 @@ object TypeErasure {
196196
MethodType(Nil, defn.BoxedUnitType)
197197
else if (sym.isAnonymousFunction && einfo.paramInfos.length > MaxImplementedFunctionArity)
198198
MethodType(nme.ALLARGS :: Nil, JavaArrayType(defn.ObjectType) :: Nil, einfo.resultType)
199+
else if (sym.name == nme.apply && sym.owner.derivesFrom(defn.PolyFunctionClass)) {
200+
// The erasure of `apply` in subclasses of PolyFunction has to match
201+
// the erasure of FunctionN#apply, since after `ElimPolyFunction` we replace
202+
// a `PolyFunction` parent by a `FunctionN` parent.
203+
einfo.derivedLambdaType(
204+
paramInfos = einfo.paramInfos.map(_ => defn.ObjectType),
205+
resType = defn.ObjectType
206+
)
207+
}
199208
else
200209
einfo
201210
case einfo =>
202-
einfo
211+
// Erase the parameters of `apply` in subclasses of PolyFunction
212+
if (sym.is(TermParam) && sym.owner.name == nme.apply
213+
&& sym.owner.owner.derivesFrom(defn.PolyFunctionClass))
214+
defn.ObjectType
215+
else
216+
einfo
203217
}
204218
}
205219

compiler/src/dotty/tools/dotc/transform/ElimErasedValueType.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ class ElimErasedValueType extends MiniPhase with InfoTransformer { thisPhase =>
8787
val site = root.thisType
8888
val info1 = site.memberInfo(sym1)
8989
val info2 = site.memberInfo(sym2)
90-
if (!info1.matchesLoosely(info2))
90+
if (!info1.matchesLoosely(info2) &&
91+
!(sym1.name == nme.apply &&
92+
(sym1.owner.derivesFrom(defn.PolyFunctionClass) ||
93+
sym2.owner.derivesFrom(defn.PolyFunctionClass))))
9194
ctx.error(DoubleDefinition(sym1, sym2, root), root.sourcePos)
9295
}
9396
val earlyCtx = ctx.withPhase(ctx.elimRepeatedPhase.next)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import ast.{Trees, tpd}
5+
import core._, core.Decorators._
6+
import MegaPhase._, Phases.Phase
7+
import Types._, Contexts._, Constants._, Names._, NameOps._, Flags._, DenotTransformers._
8+
import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._, Scopes._, Denotations._
9+
import TypeErasure.ErasedValueType, ValueClasses._
10+
11+
/** This phase rewrite PolyFunction subclasses to FunctionN subclasses
12+
*
13+
* class Foo extends PolyFunction {
14+
* def apply(x_1: P_1, ..., x_N: P_N): R = rhs
15+
* }
16+
* becomes:
17+
* class Foo extends FunctionN {
18+
* def apply(x_1: P_1, ..., x_N: P_N): R = rhs
19+
* }
20+
*/
21+
class ElimPolyFunction extends MiniPhase with DenotTransformer {
22+
23+
import tpd._
24+
25+
override def phaseName: String = ElimPolyFunction.name
26+
27+
override def runsAfter = Set(Erasure.name)
28+
29+
override def changesParents: Boolean = true // Replaces PolyFunction by FunctionN
30+
31+
override def transform(ref: SingleDenotation)(implicit ctx: Context) = ref match {
32+
case ref: ClassDenotation if ref.symbol != defn.PolyFunctionClass && ref.derivesFrom(defn.PolyFunctionClass) =>
33+
val cinfo = ref.classInfo
34+
val newParent = functionTypeOfPoly(cinfo)
35+
val newParents = cinfo.classParents.map(parent =>
36+
if (parent.typeSymbol == defn.PolyFunctionClass)
37+
newParent
38+
else
39+
parent
40+
)
41+
ref.copySymDenotation(info = cinfo.derivedClassInfo(classParents = newParents))
42+
case _ =>
43+
ref
44+
}
45+
46+
def functionTypeOfPoly(cinfo: ClassInfo)(implicit ctx: Context): Type = {
47+
val applyMeth = cinfo.decls.lookup(nme.apply).info
48+
val arity = applyMeth.paramNamess.head.length
49+
defn.FunctionType(arity)
50+
}
51+
52+
override def transformTemplate(tree: Template)(implicit ctx: Context): Tree = {
53+
val newParents = tree.parents.mapconserve(parent =>
54+
if (parent.tpe.typeSymbol == defn.PolyFunctionClass) {
55+
val cinfo = tree.symbol.owner.asClass.classInfo
56+
tpd.TypeTree(functionTypeOfPoly(cinfo))
57+
}
58+
else
59+
parent
60+
)
61+
cpy.Template(tree)(parents = newParents)
62+
}
63+
}
64+
65+
object ElimPolyFunction {
66+
val name = "elimPolyFunction"
67+
}
68+

compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ trait TypeAssigner {
5454
required = EmptyFlagConjunction, excluded = Private)
5555
.suchThat(decl.matches(_))
5656
val inheritedInfo = inherited.info
57-
if (inheritedInfo.exists &&
57+
val isPolyFunctionApply = decl.name == nme.apply && (parent <:< defn.PolyFunctionType)
58+
if (isPolyFunctionApply || inheritedInfo.exists &&
5859
decl.info.widenExpr <:< inheritedInfo.widenExpr &&
5960
!(inheritedInfo.widenExpr <:< decl.info.widenExpr)) {
6061
val r = RefinedType(parent, decl.name, decl.info)

tests/run/polymorphic-functions.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ object Test {
44
}
55

66
def main(args: Array[String]): Unit = {
7-
//test1(...)
7+
val fun = new PolyFunction {
8+
def apply[T <: AnyVal](x: List[T]): List[(T, T)] = x.map(e => (e, e))
9+
}
10+
11+
assert(test1(fun) == List((1, 1), (2, 2), (3, 3)))
812
}
913
}

0 commit comments

Comments
 (0)