Skip to content

Commit 6df0adf

Browse files
smartermilessabin
authored andcommitted
Add syntactic sugar for polymorphic function values
Desugar the value [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body }
1 parent e02b772 commit 6df0adf

File tree

3 files changed

+28
-12
lines changed

3 files changed

+28
-12
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,22 +1430,35 @@ object desugar {
14301430
}
14311431

14321432
val desugared = tree match {
1433-
case PolyFunction(targs, body) if (ctx.mode.is(Mode.Type)) =>
1434-
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
1435-
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1436-
val Function(vargs, resType) = body
1433+
case PolyFunction(targs, body) =>
1434+
val Function(vargs, res) = body
14371435
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
14381436
val mods = body match {
14391437
case body: FunctionWithMods => body.mods
14401438
case _ => untpd.EmptyModifiers
14411439
}
1440+
val polyFunctionTpt = ref(defn.PolyFunctionType)
14421441
val applyTParams = targs.asInstanceOf[List[TypeDef]]
1443-
val applyVParams = vargs.zipWithIndex.map { case (p, n) =>
1444-
makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags)
1442+
if (ctx.mode.is(Mode.Type)) {
1443+
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
1444+
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1445+
1446+
val applyVParams = vargs.zipWithIndex.map { case (p, n) =>
1447+
makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags)
1448+
}
1449+
RefinedTypeTree(polyFunctionTpt, List(
1450+
DefDef(nme.apply, applyTParams, List(applyVParams), res, EmptyTree)
1451+
))
1452+
} else {
1453+
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
1454+
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body }
1455+
1456+
val applyVParams = vargs.asInstanceOf[List[ValDef]]
1457+
.map(varg => varg.withAddedFlags(mods.flags | Param))
1458+
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
1459+
List(DefDef(nme.apply, applyTParams, List(applyVParams), TypeTree(), res))
1460+
))
14451461
}
1446-
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1447-
DefDef(nme.apply, applyTParams, List(applyVParams), resType, EmptyTree)
1448-
))
14491462
case SymbolLit(str) =>
14501463
Literal(Constant(scala.Symbol(str)))
14511464
case InterpolatedString(id, segments) =>

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,6 +1325,11 @@ object Parsers {
13251325
atSpan(in.skipToken()) { Return(if (isExprIntro) expr() else EmptyTree, EmptyTree) }
13261326
case FOR =>
13271327
forExpr()
1328+
case LBRACKET =>
1329+
val start = in.offset
1330+
val tparams = typeParamClause(ParamOwner.TypeParam)
1331+
assert(isIdent && in.name.toString == "->", "Expected `->`")
1332+
atSpan(start, in.skipToken())(PolyFunction(tparams, expr()))
13281333
case _ =>
13291334
if (isIdent(nme.inline) && !in.inModifierPosition() && in.lookaheadIn(canStartExpressionTokens)) {
13301335
val start = in.skipToken()

tests/run/polymorphic-functions.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@ object Test {
44
}
55

66
def main(args: Array[String]): Unit = {
7-
val fun = new PolyFunction {
8-
def apply[T <: AnyVal](x: List[T]): List[(T, T)] = x.map(e => (e, e))
9-
}
7+
val fun = [T <: AnyVal] -> (x: List[T]) => x.map(e => (e, e))
108

119
assert(test1(fun) == List((1, 1), (2, 2), (3, 3)))
1210
}

0 commit comments

Comments
 (0)