Skip to content

Commit 5abd5c5

Browse files
committed
Improve closure typing
If the closure has an expected function type with a fully defined result type, take the internalized result type as the local return type of the closure. This has the effect that some conformance tests are now done with Fresh instead Result caps. This means a now can widen a local reference to a result cap, since the comparison is done between the local reference and the internalized FreshCap. Previously this failed since we compared a local cap with result cap, and result caps only subtype other result caps. It also propagates types more aggressively into closure bodies, which sometimes reduces the error span and improves the error message.
1 parent 5669497 commit 5abd5c5

21 files changed

+209
-115
lines changed

compiler/src/dotty/tools/dotc/cc/Capability.scala

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,69 @@ object Capabilities:
948948
def freshToCap(tp: Type)(using Context): Type =
949949
CapToFresh(Origin.Unknown).inverse(tp)
950950

951+
/** The local dual of a result type of a closure type.
952+
* @param binder the method type of the anonymous function whose result is mapped
953+
* @pre the context's owner is the anonymous function
954+
*/
955+
class Internalize(binder: MethodType)(using Context) extends BiTypeMap:
956+
thisMap =>
957+
958+
val sym = ctx.owner
959+
assert(sym.isAnonymousFunction)
960+
val paramSyms = atPhase(ctx.phase.prev):
961+
// We need to ask one phase before since `sym` should not be completed as a side effect.
962+
// The result of Internalize is used to se the result type of an anonymous function, and
963+
// the new info of that function is built with the result.
964+
sym.paramSymss.head
965+
val resultToFresh = EqHashMap[ResultCap, FreshCap]()
966+
val freshToResult = EqHashMap[FreshCap, ResultCap]()
967+
968+
override def apply(t: Type) =
969+
if variance < 0 then t
970+
else t match
971+
case t: ParamRef =>
972+
if t.binder == this.binder then paramSyms(t.paramNum).termRef else t
973+
case _ => mapOver(t)
974+
975+
override def mapCapability(c: Capability, deep: Boolean): Capability = c match
976+
case r: ResultCap if r.binder == this.binder =>
977+
resultToFresh.get(r) match
978+
case Some(f) => f
979+
case None =>
980+
val f = FreshCap(Origin.LocalInstance(binder.resType))
981+
resultToFresh(r) = f
982+
freshToResult(f) = r
983+
f
984+
case _ =>
985+
super.mapCapability(c, deep)
986+
987+
class Inverse extends BiTypeMap:
988+
def apply(t: Type): Type =
989+
if variance < 0 then t
990+
else t match
991+
case t: TermRef if paramSyms.contains(t) =>
992+
binder.paramRefs(paramSyms.indexOf(t.symbol))
993+
case _ => mapOver(t)
994+
995+
override def mapCapability(c: Capability, deep: Boolean): Capability = c match
996+
case f: FreshCap if f.owner == sym =>
997+
freshToResult.get(f) match
998+
case Some(r) => r
999+
case None =>
1000+
val r = ResultCap(binder)
1001+
resultToFresh(r) = f
1002+
freshToResult(f) = r
1003+
r
1004+
case _ => super.mapCapability(c, deep)
1005+
1006+
def inverse = thisMap
1007+
override def toString = thisMap.toString + ".inverse"
1008+
end Inverse
1009+
1010+
override def toString = "InternalizeClosureResult"
1011+
def inverse = Inverse()
1012+
end Internalize
1013+
9511014
/** Map top-level free existential variables one-to-one to Fresh instances */
9521015
def resultToFresh(tp: Type, origin: Origin)(using Context): Type =
9531016
val subst = new TypeMap:

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import config.Printers.{capt, recheckr, noPrinter}
1010
import config.{Config, Feature}
1111
import ast.{tpd, untpd, Trees}
1212
import Trees.*
13+
import typer.ForceDegree
14+
import typer.Inferencing.isFullyDefined
1315
import typer.RefChecks.{checkAllOverrides, checkSelfAgainstParents, OverridingPairsChecker}
1416
import typer.Checking.{checkBounds, checkAppliedTypesIn}
1517
import typer.ErrorReporting.{Addenda, NothingToAdd, err}
@@ -25,7 +27,7 @@ import NameKinds.{DefaultGetterName, WildcardParamName, UniqueNameKind}
2527
import reporting.{trace, Message, OverrideError}
2628
import Annotations.Annotation
2729
import Capabilities.*
28-
import dotty.tools.dotc.util.common.alwaysTrue
30+
import util.common.alwaysTrue
2931

3032
/** The capture checker */
3133
object CheckCaptures:
@@ -907,6 +909,7 @@ class CheckCaptures extends Recheck, SymTransformer:
907909
* { def $anonfun(...) = ...; closure($anonfun, ...)}
908910
*/
909911
override def recheckClosureBlock(mdef: DefDef, expr: Closure, pt: Type)(using Context): Type =
912+
val anonfun = mdef.symbol
910913

911914
def matchParams(paramss: List[ParamClause], pt: Type): Unit =
912915
//println(i"match $mdef against $pt")
@@ -922,7 +925,19 @@ class CheckCaptures extends Recheck, SymTransformer:
922925
val paramType = freshToCap(paramTpt.nuType)
923926
checkConformsExpr(argType, paramType, param)
924927
.showing(i"compared expected closure formal $argType against $param with ${paramTpt.nuType}", capt)
925-
if !pt.isInstanceOf[RefinedType]
928+
if ccConfig.newScheme then
929+
if resType.isValueType && isFullyDefined(resType, ForceDegree.none) then
930+
val localResType = pt match
931+
case RefinedType(_, _, mt: MethodType) =>
932+
inContext(ctx.withOwner(anonfun)):
933+
Internalize(mt)(resType)
934+
case _ => resType
935+
mdef.tpt.updNuType(localResType)
936+
// Make sure we affect the info of the anonfun by the previous updNuType
937+
// unless the info is already defined in a previous phase and does not change.
938+
assert(!anonfun.isCompleted || anonfun.denot.validFor.firstPhaseId != thisPhase.id)
939+
//println(i"updating ${mdef.tpt} to $localResType/${mdef.tpt.nuType}")
940+
else if !pt.isInstanceOf[RefinedType]
926941
&& !(isEtaExpansion(mdef) && ccConfig.handleEtaExpansionsSpecially)
927942
then
928943
// If the closure is not an eta expansion and the expected type is a parametric
@@ -941,16 +956,16 @@ class CheckCaptures extends Recheck, SymTransformer:
941956
case _ =>
942957
case Nil =>
943958

944-
openClosures = (mdef.symbol, pt) :: openClosures
959+
openClosures = (anonfun, pt) :: openClosures
945960
// openClosures is needed for errors but currently makes no difference
946961
// TODO follow up on this
947962
try
948963
matchParams(mdef.paramss, pt)
949-
capt.println(i"recheck closure block $mdef: ${mdef.symbol.infoOrCompleter}")
950-
if !mdef.symbol.isCompleted then
951-
mdef.symbol.ensureCompleted() // this will recheck def
964+
capt.println(i"recheck closure block $mdef: ${anonfun.infoOrCompleter}")
965+
if !anonfun.isCompleted then
966+
anonfun.ensureCompleted() // this will recheck def
952967
else
953-
recheckDef(mdef, mdef.symbol)
968+
recheckDef(mdef, anonfun)
954969

955970
recheckClosure(expr, pt, forceDependent = true)
956971
finally
@@ -1454,7 +1469,8 @@ class CheckCaptures extends Recheck, SymTransformer:
14541469
case FunctionOrMethod(aargs, ares) =>
14551470
val saved = curEnv
14561471
curEnv = Env(
1457-
curEnv.owner, EnvKind.NestedInOwner,
1472+
curEnv.owner,
1473+
if boxed then EnvKind.Boxed else EnvKind.NestedInOwner,
14581474
CaptureSet.Var(curEnv.owner, level = ccState.currentLevel),
14591475
if boxed then null else curEnv)
14601476
try

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,12 @@ abstract class Recheck extends Phase, SymTransformer:
165165
* from the current type.
166166
*/
167167
def setNuType(tpe: Type): Unit =
168-
if nuTypes.lookup(tree) == null then updNuType(tpe)
168+
if nuTypes.lookup(tree) == null && (tpe ne tree.tpe) then
169+
updNuType(tpe)
169170

170171
/** Set new type of the tree unconditionally. */
171172
def updNuType(tpe: Type): Unit =
172-
if tpe ne tree.tpe then nuTypes(tree) = tpe
173+
nuTypes(tree) = tpe
173174

174175
/** The new type of the tree, or if none was installed, the original type */
175176
def nuType(using Context): Type =

tests/neg-custom-args/captures/capt1.check

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:5:2 ------------------------------------------
22
5 | () => if x == null then y else y // error
33
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4-
| Found: () ->{x} C^?
4+
| Found: () ->{x} C
55
| Required: () -> C
66
| Note that capability (x : C^) is not included in capture set {}.
77
|
@@ -52,12 +52,27 @@
5252
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:36:24 ----------------------------------------
5353
36 | val z2 = h[() -> Cap](() => x) // error // error
5454
| ^^^^^^^
55-
|Found: () ->{x} C^{x}
56-
|Required: () -> C^
55+
|Found: () ->? C^
56+
|Required: () -> C^²
57+
|
58+
|where: ^ refers to a root capability associated with the result type of (): C^
59+
| ^² refers to a fresh root capability created in value z2 when checking argument to parameter a of method h
60+
|
61+
|Note that capability <cap of (): C^> is not included in capture set {cap}
62+
|because <cap of (): C^> is not visible from cap in value z2.
63+
|
64+
| longer explanation available when compiling with `-explain`
65+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:37:5 -----------------------------------------
66+
37 | (() => C()) // error
67+
| ^^^^^^^^^
68+
|Found: () ->? C^
69+
|Required: () -> C^²
5770
|
58-
|where: ^ refers to a fresh root capability created in value z2 when checking argument to parameter a of method h
71+
|where: ^ refers to a root capability associated with the result type of (): C^
72+
| ^² refers to a fresh root capability created in value z2 when checking argument to parameter b of method h
5973
|
60-
|Note that capability (x : C^) is not included in capture set {}.
74+
|Note that capability <cap of (): C^> is not included in capture set {cap}
75+
|because <cap of (): C^> is not visible from cap in value z2.
6176
|
6277
| longer explanation available when compiling with `-explain`
6378
-- Error: tests/neg-custom-args/captures/capt1.scala:38:13 -------------------------------------------------------------

tests/neg-custom-args/captures/capt1.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def foo() =
3434
def h[X](a: X)(b: X) = a
3535

3636
val z2 = h[() -> Cap](() => x) // error // error
37-
(() => C())
37+
(() => C()) // error
3838
val z3 = h[(() -> Cap) @retains[x.type]](() => x)(() => C()) // error
3939

4040
val z1: () => Cap = f1(x)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/closure-result-typing.scala:2:30 -------------------------
2+
2 | val x: () -> Object = () => c // error
3+
| ^
4+
| Found: (c : Object^)
5+
| Required: Object
6+
|
7+
| where: ^ refers to a fresh root capability in the type of parameter c
8+
|
9+
| Note that capability cap is not included in capture set {}.
10+
|
11+
| longer explanation available when compiling with `-explain`
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def test(c: Object^): Unit =
2+
val x: () -> Object = () => c // error

tests/neg-custom-args/captures/eta.check

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,3 @@
66
| Note that capability (f : Proc^) is not included in capture set {}.
77
|
88
| longer explanation available when compiling with `-explain`
9-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/eta.scala:6:14 -------------------------------------------
10-
6 | bar( () => f ) // error
11-
| ^^^^^^^
12-
| Found: () ->{f} () ->{f} Unit
13-
| Required: () -> () ->{f} Unit
14-
| Note that capability (f : Proc^) is not included in capture set {}.
15-
|
16-
| longer explanation available when compiling with `-explain`

tests/neg-custom-args/captures/eta.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
def bar[A <: Proc^{f}](g: () -> A): () -> Proc^{f} =
44
g // error
55
val stowaway: () -> Proc^{f} =
6-
bar( () => f ) // error
6+
bar( () => f ) // was error now OK
77
() => { stowaway.apply().apply() }

tests/neg-custom-args/captures/filevar.check

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/filevar.scala:15:12 --------------------------------------
22
15 | withFile: f => // error with level checking, was OK under both schemes before
33
| ^
4-
|Found: (l: scala.caps.Capability^) ?->? File^? ->? Unit
5-
|Required: (l: scala.caps.Capability^) ?-> (f: File^{l}) => Unit
4+
|Found: (f: File^?) ->? Unit
5+
|Required: (f: File^{l}) => Unit
66
|
7-
|where: => refers to a root capability associated with the result type of (using l: scala.caps.Capability^): (f: File^{l}) => Unit
8-
| ^ refers to the universal root capability
7+
|where: => refers to a fresh root capability created in anonymous function of type (using l²: scala.caps.Capability): File^{l²} -> Unit when instantiating expected result type (f: File^{l}) ->{cap} Unit of function literal
98
|
109
|Note that capability l.type
1110
|cannot be included in outer capture set ? of parameter f.

0 commit comments

Comments
 (0)