Skip to content

Commit 0742175

Browse files
smartermilessabin
authored andcommitted
Add syntactic sugar for polymorphic function types
Desugar the type [T_1, ..., T_M] -> (P_1, ..., P_N) => R Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1 parent 951eaf4 commit 0742175

File tree

6 files changed

+39
-1
lines changed

6 files changed

+39
-1
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,6 +1430,22 @@ object desugar {
14301430
}
14311431

14321432
val desugared = tree match {
1433+
case PolyFunction(targs, body) if (ctx.mode.is(Mode.Type)) =>
1434+
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
1435+
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1436+
val Function(vargs, resType) = body
1437+
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
1438+
val mods = body match {
1439+
case body: FunctionWithMods => body.mods
1440+
case _ => untpd.EmptyModifiers
1441+
}
1442+
val applyTParams = targs.asInstanceOf[List[TypeDef]]
1443+
val applyVParams = vargs.zipWithIndex.map { case (p, n) =>
1444+
makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags)
1445+
}
1446+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1447+
DefDef(nme.apply, applyTParams, List(applyVParams), resType, EmptyTree)
1448+
))
14331449
case SymbolLit(str) =>
14341450
Literal(Constant(scala.Symbol(str)))
14351451
case InterpolatedString(id, segments) =>

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ object Trees {
331331
}
332332

333333
def withFlags(flags: FlagSet): ThisTree[Untyped] = withMods(untpd.Modifiers(flags))
334+
def withAddedFlags(flags: FlagSet): ThisTree[Untyped] = withMods(rawMods | flags)
334335

335336
def setComment(comment: Option[Comment]): this.type = {
336337
comment.map(putAttachment(DocComment, _))

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
7272
class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers)(implicit @constructorOnly src: SourceFile)
7373
extends Function(args, body)
7474

75+
/** A polymorphic function type */
76+
case class PolyFunction(targs: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends Tree {
77+
override def isTerm = body.isTerm
78+
override def isType = body.isType
79+
}
80+
7581
/** A function created from a wildcard expression
7682
* @param placeholderParams a list of definitions of synthetic parameters.
7783
* @param body the function body where wildcards are replaced by
@@ -491,6 +497,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
491497
case tree: Function if (args eq tree.args) && (body eq tree.body) => tree
492498
case _ => finalize(tree, untpd.Function(args, body)(tree.source))
493499
}
500+
def PolyFunction(tree: Tree)(targs: List[Tree], body: Tree)(implicit ctx: Context): Tree = tree match {
501+
case tree: PolyFunction if (targs eq tree.targs) && (body eq tree.body) => tree
502+
case _ => finalize(tree, untpd.PolyFunction(targs, body)(tree.source))
503+
}
494504
def InfixOp(tree: Tree)(left: Tree, op: Ident, right: Tree)(implicit ctx: Context): Tree = tree match {
495505
case tree: InfixOp if (left eq tree.left) && (op eq tree.op) && (right eq tree.right) => tree
496506
case _ => finalize(tree, untpd.InfixOp(left, op, right)(tree.source))
@@ -579,6 +589,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
579589
cpy.InterpolatedString(tree)(id, segments.mapConserve(transform))
580590
case Function(args, body) =>
581591
cpy.Function(tree)(transform(args), transform(body))
592+
case PolyFunction(targs, body) =>
593+
cpy.PolyFunction(tree)(transform(targs), transform(body))
582594
case InfixOp(left, op, right) =>
583595
cpy.InfixOp(tree)(transform(left), op, transform(right))
584596
case PostfixOp(od, op) =>
@@ -634,6 +646,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
634646
this(x, segments)
635647
case Function(args, body) =>
636648
this(this(x, args), body)
649+
case PolyFunction(targs, body) =>
650+
this(this(x, targs), body)
637651
case InfixOp(left, op, right) =>
638652
this(this(this(x, left), op), right)
639653
case PostfixOp(od, op) =>

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,8 @@ object Parsers {
924924
val tparams = typeParamClause(ParamOwner.TypeParam)
925925
if (in.token == TLARROW)
926926
atSpan(start, in.skipToken())(LambdaTypeTree(tparams, toplevelTyp()))
927+
else if (isIdent && in.name.toString == "->")
928+
atSpan(start, in.skipToken())(PolyFunction(tparams, toplevelTyp()))
927929
else { accept(TLARROW); typ() }
928930
}
929931
else infixType()

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
558558
(keywordText("erased ") provided isErased) ~
559559
argsText ~ " => " ~ toText(body)
560560
}
561+
case PolyFunction(targs, body) =>
562+
val targsText = "[" ~ Text(targs.map((arg: Tree) => toText(arg)), ", ") ~ "]"
563+
changePrec(GlobalPrec) {
564+
targsText ~ " -> " ~ toText(body)
565+
}
561566
case InfixOp(l, op, r) =>
562567
val opPrec = parsing.precedence(op.name)
563568
changePrec(opPrec) { toText(l) ~ " " ~ toText(op) ~ " " ~ toText(r) }

tests/run/polymorphic-functions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
object Test {
2-
def test1(f: PolyFunction { def apply[T <: AnyVal](x: List[T]): List[(T, T)] }) = {
2+
def test1(f: [T <: AnyVal] -> List[T] => List[(T, T)]) = {
33
f(List(1, 2, 3))
44
}
55

0 commit comments

Comments
 (0)