Skip to content

Commit e31aa08

Browse files
committed
Handle dependent function types in Typer
1 parent 66b6280 commit e31aa08

File tree

6 files changed

+205
-140
lines changed

6 files changed

+205
-140
lines changed

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ trait Symbols { this: Context =>
130130
newClassSymbol(owner, name, flags, completer, privateWithin, coord, assocFile)
131131
}
132132

133+
def newRefinedClassSymbol = newCompleteClassSymbol(
134+
ctx.owner, tpnme.REFINE_CLASS, NonMember, parents = Nil)
135+
133136
/** Create a module symbol with associated module class
134137
* from its non-info fields and a function producing the info
135138
* of the module class (this info may be lazy).

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2723,6 +2723,20 @@ object Types {
27232723
def isParamDependent(implicit ctx: Context): Boolean = paramDependencyStatus == TrueDeps
27242724

27252725
def newParamRef(n: Int) = new TermParamRef(this, n) {}
2726+
2727+
/** The least supertype of `resultType` that does not contain parameter dependencies */
2728+
def nonDependentResultApprox(implicit ctx: Context): Type =
2729+
if (isDependent) {
2730+
val dropDependencies = new ApproximatingTypeMap {
2731+
def apply(tp: Type) = tp match {
2732+
case tp @ TermParamRef(thisLambdaType, _) =>
2733+
range(tp.bottomType, atVariance(1)(apply(tp.underlying)))
2734+
case _ => mapOver(tp)
2735+
}
2736+
}
2737+
dropDependencies(resultType)
2738+
}
2739+
else resultType
27262740
}
27272741

27282742
abstract case class MethodType(paramNames: List[TermName])(

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -999,8 +999,7 @@ class TreeUnpickler(reader: TastyReader, nameAtRef: NameRef => TermName, posUnpi
999999
val argPats = until(end)(readTerm())
10001000
UnApply(fn, implicitArgs, argPats, patType)
10011001
case REFINEDtpt =>
1002-
val refineCls = ctx.newCompleteClassSymbol(
1003-
ctx.owner, tpnme.REFINE_CLASS, NonMember, parents = Nil)
1002+
val refineCls = ctx.newRefinedClassSymbol
10041003
typeAtAddr(start) = refineCls.typeRef
10051004
val parent = readTpt()
10061005
val refinements = readStats(refineCls, end)(localContext(refineCls))

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 168 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -697,160 +697,190 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
697697
}
698698

699699
def typedFunction(tree: untpd.Function, pt: Type)(implicit ctx: Context) = track("typedFunction") {
700+
if (ctx.mode is Mode.Type) typedFunctionType(tree, pt)
701+
else typedFunctionValue(tree, pt)
702+
}
703+
704+
def typedFunctionType(tree: untpd.Function, pt: Type)(implicit ctx: Context) = {
700705
val untpd.Function(args, body) = tree
701-
if (ctx.mode is Mode.Type) {
702-
val isImplicit = tree match {
703-
case _: untpd.ImplicitFunction =>
704-
if (args.length == 0) {
705-
ctx.error(ImplicitFunctionTypeNeedsNonEmptyParameterList(), tree.pos)
706-
false
707-
}
708-
else true
709-
case _ => false
710-
}
711-
val funCls = defn.FunctionClass(args.length, isImplicit)
712-
typed(cpy.AppliedTypeTree(tree)(
713-
untpd.TypeTree(funCls.typeRef), args :+ body), pt)
706+
val isImplicit = tree match {
707+
case _: untpd.ImplicitFunction =>
708+
if (args.length == 0) {
709+
ctx.error(ImplicitFunctionTypeNeedsNonEmptyParameterList(), tree.pos)
710+
false
711+
}
712+
else true
713+
case _ => false
714714
}
715-
else {
716-
val params = args.asInstanceOf[List[untpd.ValDef]]
717-
718-
pt match {
719-
case pt: TypeVar if untpd.isFunctionWithUnknownParamType(tree) =>
720-
// try to instantiate `pt` if this is possible. If it does not
721-
// work the error will be reported later in `inferredParam`,
722-
// when we try to infer the parameter type.
723-
isFullyDefined(pt, ForceDegree.noBottom)
724-
case _ =>
725-
}
715+
val funCls = defn.FunctionClass(args.length, isImplicit)
716+
717+
def typedDependent(params: List[ValDef])(implicit ctx: Context) = {
718+
completeParams(params)
719+
val params1 = params.map(typedExpr(_).asInstanceOf[ValDef])
720+
val resultTpt = typed(body)
721+
val companion = if (isImplicit) ImplicitMethodType else MethodType
722+
val mt = companion.fromSymbols(params1.map(_.symbol), resultTpt.tpe)
723+
if (mt.isParamDependent)
724+
ctx.error(i"$mt is an illegal function type because it has inter-parameter dependencies")
725+
val resTpt = TypeTree(mt.nonDependentResultApprox).withPos(body.pos)
726+
val typeArgs = params1.map(_.tpt) :+ resTpt
727+
val tycon = TypeTree(funCls.typeRef)
728+
val core = assignType(cpy.AppliedTypeTree(tree)(tycon, typeArgs), tycon, typeArgs)
729+
val appMeth = ctx.newSymbol(ctx.owner, nme.apply, Synthetic | Deferred, mt)
730+
val appDef = assignType(
731+
untpd.DefDef(appMeth.name, Nil, List(params1), resultTpt, EmptyTree),
732+
appMeth)
733+
RefinedTypeTree(core, List(appDef), ctx.owner.asClass)
734+
}
735+
736+
args match {
737+
case ValDef(_, _, _) :: _ =>
738+
typedDependent(args.asInstanceOf[List[ValDef]])(
739+
ctx.fresh.setOwner(ctx.newRefinedClassSymbol).setNewScope)
740+
case _ =>
741+
typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funCls.typeRef), args :+ body), pt)
742+
}
743+
}
726744

727-
val (protoFormals, protoResult) = decomposeProtoFunction(pt, params.length)
745+
def typedFunctionValue(tree: untpd.Function, pt: Type)(implicit ctx: Context) = {
746+
val untpd.Function(args, body) = tree
747+
val params = args.asInstanceOf[List[untpd.ValDef]]
728748

729-
def refersTo(arg: untpd.Tree, param: untpd.ValDef): Boolean = arg match {
730-
case Ident(name) => name == param.name
731-
case _ => false
732-
}
749+
pt match {
750+
case pt: TypeVar if untpd.isFunctionWithUnknownParamType(tree) =>
751+
// try to instantiate `pt` if this is possible. If it does not
752+
// work the error will be reported later in `inferredParam`,
753+
// when we try to infer the parameter type.
754+
isFullyDefined(pt, ForceDegree.noBottom)
755+
case _ =>
756+
}
733757

734-
/** The function body to be returned in the closure. Can become a TypedSplice
735-
* of a typed expression if this is necessary to infer a parameter type.
736-
*/
737-
var fnBody = tree.body
758+
val (protoFormals, protoResult) = decomposeProtoFunction(pt, params.length)
738759

739-
/** A map from parameter names to unique positions where the parameter
740-
* appears in the argument list of an application.
741-
*/
742-
var paramIndex = Map[Name, Int]()
760+
def refersTo(arg: untpd.Tree, param: untpd.ValDef): Boolean = arg match {
761+
case Ident(name) => name == param.name
762+
case _ => false
763+
}
743764

744-
/** If parameter `param` appears exactly once as an argument in `args`,
745-
* the singleton list consisting of its position in `args`, otherwise `Nil`.
746-
*/
747-
def paramIndices(param: untpd.ValDef, args: List[untpd.Tree]): List[Int] = {
748-
def loop(args: List[untpd.Tree], start: Int): List[Int] = args match {
749-
case arg :: args1 =>
750-
val others = loop(args1, start + 1)
751-
if (refersTo(arg, param)) start :: others else others
752-
case _ => Nil
753-
}
754-
val allIndices = loop(args, 0)
755-
if (allIndices.length == 1) allIndices else Nil
765+
/** The function body to be returned in the closure. Can become a TypedSplice
766+
* of a typed expression if this is necessary to infer a parameter type.
767+
*/
768+
var fnBody = tree.body
769+
770+
/** A map from parameter names to unique positions where the parameter
771+
* appears in the argument list of an application.
772+
*/
773+
var paramIndex = Map[Name, Int]()
774+
775+
/** If parameter `param` appears exactly once as an argument in `args`,
776+
* the singleton list consisting of its position in `args`, otherwise `Nil`.
777+
*/
778+
def paramIndices(param: untpd.ValDef, args: List[untpd.Tree]): List[Int] = {
779+
def loop(args: List[untpd.Tree], start: Int): List[Int] = args match {
780+
case arg :: args1 =>
781+
val others = loop(args1, start + 1)
782+
if (refersTo(arg, param)) start :: others else others
783+
case _ => Nil
756784
}
785+
val allIndices = loop(args, 0)
786+
if (allIndices.length == 1) allIndices else Nil
787+
}
757788

758-
/** If function is of the form
759-
* (x1, ..., xN) => f(... x1, ..., XN, ...)
760-
* where each `xi` occurs exactly once in the argument list of `f` (in
761-
* any order), the type of `f`, otherwise NoType.
762-
* Updates `fnBody` and `paramIndex` as a side effect.
763-
* @post: If result exists, `paramIndex` is defined for the name of
764-
* every parameter in `params`.
765-
*/
766-
def calleeType: Type = fnBody match {
767-
case Apply(expr, args) =>
768-
paramIndex = {
769-
for (param <- params; idx <- paramIndices(param, args))
770-
yield param.name -> idx
771-
}.toMap
772-
if (paramIndex.size == params.length)
773-
expr match {
774-
case untpd.TypedSplice(expr1) =>
775-
expr1.tpe
776-
case _ =>
777-
val protoArgs = args map (_ withType WildcardType)
778-
val callProto = FunProto(protoArgs, WildcardType, this)
779-
val expr1 = typedExpr(expr, callProto)
780-
fnBody = cpy.Apply(fnBody)(untpd.TypedSplice(expr1), args)
781-
expr1.tpe
782-
}
783-
else NoType
789+
/** If function is of the form
790+
* (x1, ..., xN) => f(... x1, ..., XN, ...)
791+
* where each `xi` occurs exactly once in the argument list of `f` (in
792+
* any order), the type of `f`, otherwise NoType.
793+
* Updates `fnBody` and `paramIndex` as a side effect.
794+
* @post: If result exists, `paramIndex` is defined for the name of
795+
* every parameter in `params`.
796+
*/
797+
def calleeType: Type = fnBody match {
798+
case Apply(expr, args) =>
799+
paramIndex = {
800+
for (param <- params; idx <- paramIndices(param, args))
801+
yield param.name -> idx
802+
}.toMap
803+
if (paramIndex.size == params.length)
804+
expr match {
805+
case untpd.TypedSplice(expr1) =>
806+
expr1.tpe
807+
case _ =>
808+
val protoArgs = args map (_ withType WildcardType)
809+
val callProto = FunProto(protoArgs, WildcardType, this)
810+
val expr1 = typedExpr(expr, callProto)
811+
fnBody = cpy.Apply(fnBody)(untpd.TypedSplice(expr1), args)
812+
expr1.tpe
813+
}
814+
else NoType
815+
case _ =>
816+
NoType
817+
}
818+
819+
/** Two attempts: First, if expected type is fully defined pick this one.
820+
* Second, if function is of the form
821+
* (x1, ..., xN) => f(... x1, ..., XN, ...)
822+
* where each `xi` occurs exactly once in the argument list of `f` (in
823+
* any order), and f has a method type MT, pick the corresponding parameter
824+
* type in MT, if this one is fully defined.
825+
* If both attempts fail, issue a "missing parameter type" error.
826+
*/
827+
def inferredParamType(param: untpd.ValDef, formal: Type): Type = {
828+
if (isFullyDefined(formal, ForceDegree.noBottom)) return formal
829+
calleeType.widen match {
830+
case mtpe: MethodType =>
831+
val pos = paramIndex(param.name)
832+
if (pos < mtpe.paramInfos.length) {
833+
val ptype = mtpe.paramInfos(pos)
834+
if (isFullyDefined(ptype, ForceDegree.noBottom) && !ptype.isRepeatedParam)
835+
return ptype
836+
}
784837
case _ =>
785-
NoType
786838
}
839+
errorType(AnonymousFunctionMissingParamType(param, args, tree, pt), param.pos)
840+
}
787841

788-
/** Two attempts: First, if expected type is fully defined pick this one.
789-
* Second, if function is of the form
790-
* (x1, ..., xN) => f(... x1, ..., XN, ...)
791-
* where each `xi` occurs exactly once in the argument list of `f` (in
792-
* any order), and f has a method type MT, pick the corresponding parameter
793-
* type in MT, if this one is fully defined.
794-
* If both attempts fail, issue a "missing parameter type" error.
795-
*/
796-
def inferredParamType(param: untpd.ValDef, formal: Type): Type = {
797-
if (isFullyDefined(formal, ForceDegree.noBottom)) return formal
798-
calleeType.widen match {
799-
case mtpe: MethodType =>
800-
val pos = paramIndex(param.name)
801-
if (pos < mtpe.paramInfos.length) {
802-
val ptype = mtpe.paramInfos(pos)
803-
if (isFullyDefined(ptype, ForceDegree.noBottom) && !ptype.isRepeatedParam)
804-
return ptype
805-
}
806-
case _ =>
807-
}
808-
errorType(AnonymousFunctionMissingParamType(param, args, tree, pt), param.pos)
809-
}
842+
def protoFormal(i: Int): Type =
843+
if (protoFormals.length == params.length) protoFormals(i)
844+
else errorType(WrongNumberOfParameters(protoFormals.length), tree.pos)
810845

811-
def protoFormal(i: Int): Type =
812-
if (protoFormals.length == params.length) protoFormals(i)
813-
else errorType(WrongNumberOfParameters(protoFormals.length), tree.pos)
814-
815-
/** Is `formal` a product type which is elementwise compatible with `params`? */
816-
def ptIsCorrectProduct(formal: Type) = {
817-
isFullyDefined(formal, ForceDegree.noBottom) &&
818-
defn.isProductSubType(formal) &&
819-
Applications.productSelectorTypes(formal).corresponds(params) {
820-
(argType, param) =>
821-
param.tpt.isEmpty || argType <:< typedAheadType(param.tpt).tpe
822-
}
846+
/** Is `formal` a product type which is elementwise compatible with `params`? */
847+
def ptIsCorrectProduct(formal: Type) = {
848+
isFullyDefined(formal, ForceDegree.noBottom) &&
849+
defn.isProductSubType(formal) &&
850+
Applications.productSelectorTypes(formal).corresponds(params) {
851+
(argType, param) =>
852+
param.tpt.isEmpty || argType <:< typedAheadType(param.tpt).tpe
823853
}
854+
}
824855

825-
val desugared =
826-
if (protoFormals.length == 1 && params.length != 1 && ptIsCorrectProduct(protoFormals.head)) {
827-
desugar.makeTupledFunction(params, fnBody)
828-
}
829-
else {
830-
val inferredParams: List[untpd.ValDef] =
831-
for ((param, i) <- params.zipWithIndex) yield
832-
if (!param.tpt.isEmpty) param
833-
else cpy.ValDef(param)(
834-
tpt = untpd.TypeTree(
835-
inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false)))
836-
837-
// Define result type of closure as the expected type, thereby pushing
838-
// down any implicit searches. We do this even if the expected type is not fully
839-
// defined, which is a bit of a hack. But it's needed to make the following work
840-
// (see typers.scala and printers/PlainPrinter.scala for examples).
841-
//
842-
// def double(x: Char): String = s"$x$x"
843-
// "abc" flatMap double
844-
//
845-
val resultTpt = protoResult match {
846-
case WildcardType(_) => untpd.TypeTree()
847-
case _ => untpd.TypeTree(protoResult)
848-
}
849-
val inlineable = pt.hasAnnotation(defn.InlineParamAnnot)
850-
desugar.makeClosure(inferredParams, fnBody, resultTpt, inlineable)
856+
val desugared =
857+
if (protoFormals.length == 1 && params.length != 1 && ptIsCorrectProduct(protoFormals.head)) {
858+
desugar.makeTupledFunction(params, fnBody)
859+
}
860+
else {
861+
val inferredParams: List[untpd.ValDef] =
862+
for ((param, i) <- params.zipWithIndex) yield
863+
if (!param.tpt.isEmpty) param
864+
else cpy.ValDef(param)(
865+
tpt = untpd.TypeTree(
866+
inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false)))
867+
868+
// Define result type of closure as the expected type, thereby pushing
869+
// down any implicit searches. We do this even if the expected type is not fully
870+
// defined, which is a bit of a hack. But it's needed to make the following work
871+
// (see typers.scala and printers/PlainPrinter.scala for examples).
872+
//
873+
// def double(x: Char): String = s"$x$x"
874+
// "abc" flatMap double
875+
//
876+
val resultTpt = protoResult match {
877+
case WildcardType(_) => untpd.TypeTree()
878+
case _ => untpd.TypeTree(protoResult)
851879
}
852-
typed(desugared, pt)
853-
}
880+
val inlineable = pt.hasAnnotation(defn.InlineParamAnnot)
881+
desugar.makeClosure(inferredParams, fnBody, resultTpt, inlineable)
882+
}
883+
typed(desugared, pt)
854884
}
855885

856886
def typedClosure(tree: untpd.Closure, pt: Type)(implicit ctx: Context): Tree = track("typedClosure") {

tests/neg/depfuns.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
object Test {
2+
3+
type T = (x: Int) // error: `=>' expected
4+
5+
}

tests/pos/depfuntype.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object Test {
2+
3+
trait C { type M; val m: M }
4+
5+
type DF = (x: C) => x.M
6+
val depfun: DF = ??? // (x: C) => x.m
7+
val c = new C { type M = Int; val m = 0 }
8+
val y = depfun(c)
9+
val y1: Int = y
10+
11+
val d: C = c
12+
val z = depfun(d)
13+
val z1: d.M = z
14+
}

0 commit comments

Comments
 (0)