@@ -1049,15 +1049,35 @@ class Typer extends Namer
1049
1049
*/
1050
1050
var paramIndex = Map [Name , Int ]()
1051
1051
1052
- /** If function is of the form
1052
+ /** Infer parameter type from the body of the function
1053
+ *
1054
+ * 1. If function is of the form
1055
+ *
1053
1056
* (x1, ..., xN) => f(... x1, ..., XN, ...)
1057
+ *
1054
1058
* where each `xi` occurs exactly once in the argument list of `f` (in
1055
1059
* any order), the type of `f`, otherwise NoType.
1060
+ *
1061
+ * 2. If the function is of the form
1062
+ *
1063
+ * (using x1, ..., xN) => f
1064
+ *
1065
+ * where `f` is a contextual function type of the form `(T1, ..., TN) ?=> T`,
1066
+ * then `xi` takes the type `Ti`.
1067
+ *
1056
1068
* Updates `fnBody` and `paramIndex` as a side effect.
1057
1069
* @post: If result exists, `paramIndex` is defined for the name of
1058
1070
* every parameter in `params`.
1059
1071
*/
1060
- lazy val calleeType : Type = fnBody match {
1072
+ lazy val calleeType : Type = untpd.stripAnnotated(fnBody) match {
1073
+ case ident : untpd.Ident if isContextual =>
1074
+ val ident1 = typedIdent(ident, WildcardType )
1075
+ val tp = ident1.tpe.widen
1076
+ if defn.isContextFunctionType(tp) && params.size == defn.functionArity(tp) then
1077
+ paramIndex = params.map(_.name).zipWithIndex.toMap
1078
+ fnBody = untpd.TypedSplice (ident1)
1079
+ tp.select(nme.apply)
1080
+ else NoType
1061
1081
case app @ Apply (expr, args) =>
1062
1082
paramIndex = {
1063
1083
for (param <- params; idx <- paramIndices(param, args))
@@ -2450,7 +2470,34 @@ class Typer extends Namer
2450
2470
2451
2471
protected def makeContextualFunction (tree : untpd.Tree , pt : Type )(using Context ): Tree = {
2452
2472
val defn .FunctionOf (formals, _, true , _) = pt.dropDependentRefinement
2453
- val ifun = desugar.makeContextualFunction(formals, tree, defn.isErasedFunctionType(pt))
2473
+
2474
+ // The getter of default parameters may reach here.
2475
+ // Given the code below
2476
+ //
2477
+ // class Foo[A](run: A ?=> Int) {
2478
+ // def foo[T](f: T ?=> Int = run) = ()
2479
+ // }
2480
+ //
2481
+ // it desugars to
2482
+ //
2483
+ // class Foo[A](run: A ?=> Int) {
2484
+ // def foo$default$1[T] = run
2485
+ // def foo[T](f: T ?=> Int = run) = ()
2486
+ // }
2487
+ //
2488
+ // The expected type for checking `run` in `foo$default$1` is
2489
+ //
2490
+ // <?> ?=> Int
2491
+ //
2492
+ // see tests/pos/i7778b.scala
2493
+
2494
+ val paramTypes = {
2495
+ val hasWildcard = formals.exists(_.isInstanceOf [WildcardType ])
2496
+ if hasWildcard then formals.map(_ => untpd.TypeTree ())
2497
+ else formals.map(untpd.TypeTree )
2498
+ }
2499
+
2500
+ val ifun = desugar.makeContextualFunction(paramTypes, tree, defn.isErasedFunctionType(pt))
2454
2501
typr.println(i " make contextual function $tree / $pt ---> $ifun" )
2455
2502
typed(ifun, pt)
2456
2503
}
0 commit comments