Skip to content

Commit 45ad465

Browse files
authored
Improve closure typing (#23700)
If the closure has an expected function type with a fully defined result type, take the internalized result type as the local return type of the closure. This has the effect that some conformance tests are now done with Fresh instead Result caps. This means one can now widen a local reference to a result cap, since the comparison is done between the local reference and the internalized FreshCap. Previously this failed since we compared a local reference with a result cap, and result caps only subtype other result caps. It also propagates types more aggressively into closure bodies, which sometimes reduces the error span and improves the error message.
2 parents 481e173 + 060cbd2 commit 45ad465

21 files changed

+275
-199
lines changed

compiler/src/dotty/tools/dotc/cc/Capability.scala

Lines changed: 129 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ object Capabilities:
198198
i"a fresh root capability$classifierStr$originStr"
199199

200200
object FreshCap:
201-
def apply(origin: Origin)(using Context): FreshCap | GlobalCap.type =
201+
def apply(origin: Origin)(using Context): FreshCap =
202202
FreshCap(ctx.owner, origin)
203203

204204
/** A root capability associated with a function type. These are conceptually
@@ -837,6 +837,7 @@ object Capabilities:
837837
case Formal(pref: ParamRef, app: tpd.Apply)
838838
case ResultInstance(methType: Type, meth: Symbol)
839839
case UnapplyInstance(info: MethodType)
840+
case LocalInstance(restpe: Type)
840841
case NewMutable(tp: Type)
841842
case NewCapability(tp: Type)
842843
case LambdaExpected(respt: Type)
@@ -865,6 +866,8 @@ object Capabilities:
865866
i" when instantiating $methDescr$mt"
866867
case UnapplyInstance(info) =>
867868
i" when instantiating argument of unapply with type $info"
869+
case LocalInstance(restpe) =>
870+
i" when instantiating expected result type $restpe of function literal"
868871
case NewMutable(tp) =>
869872
i" when constructing mutable $tp"
870873
case NewCapability(tp) =>
@@ -948,6 +951,69 @@ object Capabilities:
948951
def freshToCap(param: Symbol, tp: Type)(using Context): Type =
949952
CapToFresh(Origin.Parameter(param)).inverse(tp)
950953

954+
/** The local dual of a result type of a closure type.
955+
* @param binder the method type of the anonymous function whose result is mapped
956+
* @pre the context's owner is the anonymous function
957+
*/
958+
class Internalize(binder: MethodType)(using Context) extends BiTypeMap:
959+
thisMap =>
960+
961+
val sym = ctx.owner
962+
assert(sym.isAnonymousFunction)
963+
val paramSyms = atPhase(ctx.phase.prev):
964+
// We need to ask one phase before since `sym` should not be completed as a side effect.
965+
// The result of Internalize is used to se the result type of an anonymous function, and
966+
// the new info of that function is built with the result.
967+
sym.paramSymss.head
968+
val resultToFresh = EqHashMap[ResultCap, FreshCap]()
969+
val freshToResult = EqHashMap[FreshCap, ResultCap]()
970+
971+
override def apply(t: Type) =
972+
if variance < 0 then t
973+
else t match
974+
case t: ParamRef =>
975+
if t.binder == this.binder then paramSyms(t.paramNum).termRef else t
976+
case _ => mapOver(t)
977+
978+
override def mapCapability(c: Capability, deep: Boolean): Capability = c match
979+
case r: ResultCap if r.binder == this.binder =>
980+
resultToFresh.get(r) match
981+
case Some(f) => f
982+
case None =>
983+
val f = FreshCap(Origin.LocalInstance(binder.resType))
984+
resultToFresh(r) = f
985+
freshToResult(f) = r
986+
f
987+
case _ =>
988+
super.mapCapability(c, deep)
989+
990+
class Inverse extends BiTypeMap:
991+
def apply(t: Type): Type =
992+
if variance < 0 then t
993+
else t match
994+
case t: TermRef if paramSyms.contains(t) =>
995+
binder.paramRefs(paramSyms.indexOf(t.symbol))
996+
case _ => mapOver(t)
997+
998+
override def mapCapability(c: Capability, deep: Boolean): Capability = c match
999+
case f: FreshCap if f.owner == sym =>
1000+
freshToResult.get(f) match
1001+
case Some(r) => r
1002+
case None =>
1003+
val r = ResultCap(binder)
1004+
resultToFresh(r) = f
1005+
freshToResult(f) = r
1006+
r
1007+
case _ => super.mapCapability(c, deep)
1008+
1009+
def inverse = thisMap
1010+
override def toString = thisMap.toString + ".inverse"
1011+
end Inverse
1012+
1013+
override def toString = "InternalizeClosureResult"
1014+
def inverse = Inverse()
1015+
end Internalize
1016+
9511017
/** Map top-level free existential variables one-to-one to Fresh instances */
9521018
def resultToFresh(tp: Type, origin: Origin)(using Context): Type =
9531019
val subst = new TypeMap:
@@ -977,78 +1043,76 @@ object Capabilities:
9771043
subst(tp)
9781044
end resultToFresh
9791045

980-
/** Replace all occurrences of `cap` (or fresh) in parts of this type by an existentially bound
981-
* variable bound by `mt`.
982-
* Stop at function or method types since these have been mapped before.
983-
*/
984-
def toResult(tp: Type, mt: MethodicType, fail: Message => Unit)(using Context): Type =
985-
986-
abstract class CapMap extends BiTypeMap:
987-
override def mapOver(t: Type): Type = t match
988-
case t @ FunctionOrMethod(args, res) if variance > 0 && !t.isAliasFun =>
989-
t // `t` should be mapped in this case by a different call to `toResult`. See [[toResultInResults]].
990-
case t: (LazyRef | TypeVar) =>
991-
mapConserveSuper(t)
992-
case _ =>
993-
super.mapOver(t)
1046+
abstract class CapMap(using Context) extends BiTypeMap:
1047+
override def mapOver(t: Type): Type = t match
1048+
case t @ FunctionOrMethod(args, res) if variance > 0 && !t.isAliasFun =>
1049+
t // `t` should be mapped in this case by a different call to `toResult`. See [[toResultInResults]].
1050+
case t: (LazyRef | TypeVar) =>
1051+
mapConserveSuper(t)
1052+
case _ =>
1053+
super.mapOver(t)
1054+
1055+
class ToResult(localResType: Type, mt: MethodicType, fail: Message => Unit)(using Context) extends CapMap:
1056+
1057+
def apply(t: Type) = t match
1058+
case defn.FunctionNOf(args, res, contextual) if t.typeSymbol.name.isImpureFunction =>
1059+
if variance > 0 then
1060+
super.mapOver:
1061+
defn.FunctionNOf(args, res, contextual)
1062+
.capturing(ResultCap(mt).singletonCaptureSet)
1063+
else mapOver(t)
1064+
case _ =>
1065+
mapOver(t)
1066+
1067+
override def mapCapability(c: Capability, deep: Boolean) = c match
1068+
case c: (FreshCap | GlobalCap.type) =>
1069+
if variance > 0 then
1070+
val res = ResultCap(mt)
1071+
c match
1072+
case c: FreshCap => res.setOrigin(c)
1073+
case _ =>
1074+
res
1075+
else
1076+
if variance == 0 then
1077+
fail(em"""$localResType captures the root capability `cap` in invariant position.
1078+
|This capability cannot be converted to an existential in the result type of a function.""")
1079+
// we accept variance < 0, and leave the cap as it is
1080+
c
1081+
case _ =>
1082+
super.mapCapability(c, deep)
9941083

995-
object toVar extends CapMap:
1084+
//.showing(i"mapcap $t = $result")
1085+
override def toString = "toVar"
9961086

997-
def apply(t: Type) = t match
998-
case defn.FunctionNOf(args, res, contextual) if t.typeSymbol.name.isImpureFunction =>
999-
if variance > 0 then
1000-
super.mapOver:
1001-
defn.FunctionNOf(args, res, contextual)
1002-
.capturing(ResultCap(mt).singletonCaptureSet)
1003-
else mapOver(t)
1004-
case _ =>
1005-
mapOver(t)
1087+
object inverse extends BiTypeMap:
1088+
def apply(t: Type) = mapOver(t)
10061089

10071090
override def mapCapability(c: Capability, deep: Boolean) = c match
1008-
case c: (FreshCap | GlobalCap.type) =>
1009-
if variance > 0 then
1010-
val res = ResultCap(mt)
1011-
c match
1012-
case c: FreshCap => res.setOrigin(c)
1013-
case _ =>
1014-
res
1015-
else
1016-
if variance == 0 then
1017-
fail(em"""$tp captures the root capability `cap` in invariant position.
1018-
|This capability cannot be converted to an existential in the result type of a function.""")
1019-
// we accept variance < 0, and leave the cap as it is
1020-
c
1091+
case c @ ResultCap(`mt`) =>
1092+
// do a reverse getOrElseUpdate on `seen` to produce the
1093+
// `Fresh` assosicated with `t`
1094+
val primary = c.primaryResultCap
1095+
primary.origin match
1096+
case GlobalCap =>
1097+
val fresh = FreshCap(Origin.LocalInstance(mt.resType))
1098+
primary.setOrigin(fresh)
1099+
fresh
1100+
case origin: FreshCap =>
1101+
origin
10211102
case _ =>
10221103
super.mapCapability(c, deep)
10231104

1024-
//.showing(i"mapcap $t = $result")
1025-
override def toString = "toVar"
1026-
1027-
object inverse extends BiTypeMap:
1028-
def apply(t: Type) = mapOver(t)
1029-
1030-
override def mapCapability(c: Capability, deep: Boolean) = c match
1031-
case c @ ResultCap(`mt`) =>
1032-
// do a reverse getOrElseUpdate on `seen` to produce the
1033-
// `Fresh` assosicated with `t`
1034-
val primary = c.primaryResultCap
1035-
primary.origin match
1036-
case GlobalCap =>
1037-
val fresh = FreshCap(Origin.Unknown)
1038-
primary.setOrigin(fresh)
1039-
fresh
1040-
case origin: FreshCap =>
1041-
origin
1042-
case _ =>
1043-
super.mapCapability(c, deep)
1044-
1045-
def inverse = toVar.this
1046-
override def toString = "toVar.inverse"
1047-
end inverse
1048-
end toVar
1105+
def inverse = ToResult.this
1106+
override def toString = "toVar.inverse"
1107+
end inverse
1108+
end ToResult
10491109

1050-
toVar(tp)
1051-
end toResult
1110+
/** Replace all occurrences of `cap` (or fresh) in parts of this type by an existentially bound
1111+
* variable bound by `mt`.
1112+
* Stop at function or method types since these have been mapped before.
1113+
*/
1114+
def toResult(tp: Type, mt: MethodicType, fail: Message => Unit)(using Context): Type =
1115+
ToResult(tp, mt, fail)(tp)
10521116

10531117
/** Map global roots in function results to result roots. Also,
10541118
* map roots in the types of def methods that are parameterless

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import config.Printers.{capt, recheckr, noPrinter}
1010
import config.{Config, Feature}
1111
import ast.{tpd, untpd, Trees}
1212
import Trees.*
13+
import typer.ForceDegree
14+
import typer.Inferencing.isFullyDefined
1315
import typer.RefChecks.{checkAllOverrides, checkSelfAgainstParents, OverridingPairsChecker}
1416
import typer.Checking.{checkBounds, checkAppliedTypesIn}
1517
import typer.ErrorReporting.{Addenda, NothingToAdd, err}
@@ -25,7 +27,7 @@ import NameKinds.{DefaultGetterName, WildcardParamName, UniqueNameKind}
2527
import reporting.{trace, Message, OverrideError}
2628
import Annotations.Annotation
2729
import Capabilities.*
28-
import dotty.tools.dotc.util.common.alwaysTrue
30+
import util.common.alwaysTrue
2931

3032
/** The capture checker */
3133
object CheckCaptures:
@@ -916,51 +918,45 @@ class CheckCaptures extends Recheck, SymTransformer:
916918
* { def $anonfun(...) = ...; closure($anonfun, ...)}
917919
*/
918920
override def recheckClosureBlock(mdef: DefDef, expr: Closure, pt: Type)(using Context): Type =
921+
val anonfun = mdef.symbol
919922

920-
def matchParams(paramss: List[ParamClause], pt: Type): Unit =
923+
def matchParamsAndResult(paramss: List[ParamClause], pt: Type): Unit =
921924
//println(i"match $mdef against $pt")
922925
paramss match
923926
case params :: paramss1 => pt match
924927
case defn.PolyFunctionOf(poly: PolyType) =>
925928
assert(params.hasSameLengthAs(poly.paramInfos))
926-
matchParams(paramss1, poly.instantiate(params.map(_.symbol.typeRef)))
929+
matchParamsAndResult(paramss1, poly.instantiate(params.map(_.symbol.typeRef)))
927930
case FunctionOrMethod(argTypes, resType) =>
928931
assert(params.hasSameLengthAs(argTypes), i"$mdef vs $pt, ${params}")
929932
for (argType, param) <- argTypes.lazyZip(params) do
930933
val paramTpt = param.asInstanceOf[ValDef].tpt
931934
val paramType = freshToCap(param.symbol, paramTpt.nuType)
932935
checkConformsExpr(argType, paramType, param)
933936
.showing(i"compared expected closure formal $argType against $param with ${paramTpt.nuType}", capt)
934-
if !pt.isInstanceOf[RefinedType]
935-
&& !(isEtaExpansion(mdef) && ccConfig.handleEtaExpansionsSpecially)
936-
then
937-
// If the closure is not an eta expansion and the expected type is a parametric
938-
// function type, check whether the closure's result conforms to the expected
939-
// result type. This constrains parameter types of the closure which can give better
940-
// error messages. It also prevents mapping fresh to result caps in the closure's
941-
// result type.
942-
// If the closure is an eta expanded method reference it's better to not constrain
943-
// its internals early since that would give error messages in generated code
944-
// which are less intelligible. An example is the line `a = x` in
945-
// neg-custom-args/captures/vars.scala. That's why this code is conditioned.
946-
// to apply only to closures that are not eta expansions.
947-
assert(paramss1.isEmpty)
948-
capt.println(i"pre-check closure $expr of type ${mdef.tpt.nuType} against $resType")
949-
checkConformsExpr(mdef.tpt.nuType, resType, expr)
937+
if resType.isValueType && isFullyDefined(resType, ForceDegree.none) then
938+
val localResType = pt match
939+
case RefinedType(_, _, mt: MethodType) =>
940+
inContext(ctx.withOwner(anonfun)):
941+
Internalize(mt)(resType)
942+
case _ => resType
943+
mdef.tpt.updNuType(localResType)
944+
// Make sure we affect the info of the anonfun by the previous updNuType
945+
// unless the info is already defined in a previous phase and does not change.
946+
assert(!anonfun.isCompleted || anonfun.denot.validFor.firstPhaseId != thisPhase.id)
947+
//println(i"updating ${mdef.tpt} to $localResType/${mdef.tpt.nuType}")
950948
case _ =>
951949
case Nil =>
952950

953-
openClosures = (mdef.symbol, pt) :: openClosures
951+
openClosures = (anonfun, pt) :: openClosures
954952
// openClosures is needed for errors but currently makes no difference
955953
// TODO follow up on this
956954
try
957-
matchParams(mdef.paramss, pt)
958-
capt.println(i"recheck closure block $mdef: ${mdef.symbol.infoOrCompleter}")
959-
if !mdef.symbol.isCompleted then
960-
mdef.symbol.ensureCompleted() // this will recheck def
961-
else
962-
recheckDef(mdef, mdef.symbol)
963-
955+
matchParamsAndResult(mdef.paramss, pt)
956+
capt.println(i"recheck closure block $mdef: ${anonfun.infoOrCompleter}")
957+
if !anonfun.isCompleted
958+
then anonfun.ensureCompleted() // this will recheck def
959+
else recheckDef(mdef, anonfun)
964960
recheckClosure(expr, pt, forceDependent = true)
965961
finally
966962
openClosures = openClosures.tail
@@ -1463,7 +1459,8 @@ class CheckCaptures extends Recheck, SymTransformer:
14631459
case FunctionOrMethod(aargs, ares) =>
14641460
val saved = curEnv
14651461
curEnv = Env(
1466-
curEnv.owner, EnvKind.NestedInOwner,
1462+
curEnv.owner,
1463+
if boxed then EnvKind.Boxed else EnvKind.NestedInOwner,
14671464
CaptureSet.Var(curEnv.owner, level = ccState.currentLevel),
14681465
if boxed then null else curEnv)
14691466
try

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,12 @@ abstract class Recheck extends Phase, SymTransformer:
165165
* from the current type.
166166
*/
167167
def setNuType(tpe: Type): Unit =
168-
if nuTypes.lookup(tree) == null then updNuType(tpe)
168+
if nuTypes.lookup(tree) == null && (tpe ne tree.tpe) then
169+
updNuType(tpe)
169170

170171
/** Set new type of the tree unconditionally. */
171172
def updNuType(tpe: Type): Unit =
172-
if tpe ne tree.tpe then nuTypes(tree) = tpe
173+
nuTypes(tree) = tpe
173174

174175
/** The new type of the tree, or if none was installed, the original type */
175176
def nuType(using Context): Type =

tests/neg-custom-args/captures/capt1.check

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:5:2 ------------------------------------------
22
5 | () => if x == null then y else y // error
33
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4-
| Found: () ->{x} C^?
4+
| Found: () ->{x} C
55
| Required: () -> C
66
| Note that capability (x : C^) is not included in capture set {}.
77
|
@@ -52,12 +52,27 @@
5252
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:36:24 ----------------------------------------
5353
36 | val z2 = h[() -> Cap](() => x) // error // error
5454
| ^^^^^^^
55-
|Found: () ->{x} C^{x}
56-
|Required: () -> C^
55+
|Found: () ->? C^
56+
|Required: () -> C^²
57+
|
58+
|where: ^ refers to a root capability associated with the result type of (): C^
59+
| ^² refers to a fresh root capability created in value z2 when checking argument to parameter a of method h
60+
|
61+
|Note that capability <cap of (): C^> is not included in capture set {cap}
62+
|because <cap of (): C^> is not visible from cap in value z2.
63+
|
64+
| longer explanation available when compiling with `-explain`
65+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:37:5 -----------------------------------------
66+
37 | (() => C()) // error
67+
| ^^^^^^^^^
68+
|Found: () ->? C^
69+
|Required: () -> C^²
5770
|
58-
|where: ^ refers to a fresh root capability created in value z2 when checking argument to parameter a of method h
71+
|where: ^ refers to a root capability associated with the result type of (): C^
72+
| ^² refers to a fresh root capability created in value z2 when checking argument to parameter b of method h
5973
|
60-
|Note that capability (x : C^) is not included in capture set {}.
74+
|Note that capability <cap of (): C^> is not included in capture set {cap}
75+
|because <cap of (): C^> is not visible from cap in value z2.
6176
|
6277
| longer explanation available when compiling with `-explain`
6378
-- Error: tests/neg-custom-args/captures/capt1.scala:38:13 -------------------------------------------------------------

0 commit comments

Comments
 (0)