Skip to content

Improve closure typing #23700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 129 additions & 65 deletions compiler/src/dotty/tools/dotc/cc/Capability.scala
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ object Capabilities:
i"a fresh root capability$classifierStr$originStr"

object FreshCap:
def apply(origin: Origin)(using Context): FreshCap | GlobalCap.type =
def apply(origin: Origin)(using Context): FreshCap =
FreshCap(ctx.owner, origin)

/** A root capability associated with a function type. These are conceptually
Expand Down Expand Up @@ -837,6 +837,7 @@ object Capabilities:
case Formal(pref: ParamRef, app: tpd.Apply)
case ResultInstance(methType: Type, meth: Symbol)
case UnapplyInstance(info: MethodType)
case LocalInstance(restpe: Type)
case NewMutable(tp: Type)
case NewCapability(tp: Type)
case LambdaExpected(respt: Type)
Expand Down Expand Up @@ -865,6 +866,8 @@ object Capabilities:
i" when instantiating $methDescr$mt"
case UnapplyInstance(info) =>
i" when instantiating argument of unapply with type $info"
case LocalInstance(restpe) =>
i" when instantiating expected result type $restpe of function literal"
case NewMutable(tp) =>
i" when constructing mutable $tp"
case NewCapability(tp) =>
Expand Down Expand Up @@ -948,6 +951,69 @@ object Capabilities:
def freshToCap(param: Symbol, tp: Type)(using Context): Type =
CapToFresh(Origin.Parameter(param)).inverse(tp)

/** The local dual of a result type of a closure type.
* @param binder the method type of the anonymous function whose result is mapped
* @pre the context's owner is the anonymous function
*/
class Internalize(binder: MethodType)(using Context) extends BiTypeMap:
thisMap =>

val sym = ctx.owner
assert(sym.isAnonymousFunction)
val paramSyms = atPhase(ctx.phase.prev):
// We need to ask one phase before since `sym` should not be completed as a side effect.
// The result of Internalize is used to se the result type of an anonymous function, and
// the new info of that function is built with the result.
sym.paramSymss.head
val resultToFresh = EqHashMap[ResultCap, FreshCap]()
val freshToResult = EqHashMap[FreshCap, ResultCap]()

override def apply(t: Type) =
if variance < 0 then t
else t match
case t: ParamRef =>
if t.binder == this.binder then paramSyms(t.paramNum).termRef else t
case _ => mapOver(t)

override def mapCapability(c: Capability, deep: Boolean): Capability = c match
case r: ResultCap if r.binder == this.binder =>
resultToFresh.get(r) match
case Some(f) => f
case None =>
val f = FreshCap(Origin.LocalInstance(binder.resType))
resultToFresh(r) = f
freshToResult(f) = r
f
case _ =>
super.mapCapability(c, deep)

class Inverse extends BiTypeMap:
def apply(t: Type): Type =
if variance < 0 then t
else t match
case t: TermRef if paramSyms.contains(t) =>
binder.paramRefs(paramSyms.indexOf(t.symbol))
case _ => mapOver(t)

override def mapCapability(c: Capability, deep: Boolean): Capability = c match
case f: FreshCap if f.owner == sym =>
freshToResult.get(f) match
case Some(r) => r
case None =>
val r = ResultCap(binder)
resultToFresh(r) = f
freshToResult(f) = r
r
case _ => super.mapCapability(c, deep)

def inverse = thisMap
override def toString = thisMap.toString + ".inverse"
end Inverse

override def toString = "InternalizeClosureResult"
def inverse = Inverse()
end Internalize

/** Map top-level free existential variables one-to-one to Fresh instances */
def resultToFresh(tp: Type, origin: Origin)(using Context): Type =
val subst = new TypeMap:
Expand Down Expand Up @@ -977,78 +1043,76 @@ object Capabilities:
subst(tp)
end resultToFresh

/** Replace all occurrences of `cap` (or fresh) in parts of this type by an existentially bound
* variable bound by `mt`.
* Stop at function or method types since these have been mapped before.
*/
def toResult(tp: Type, mt: MethodicType, fail: Message => Unit)(using Context): Type =

abstract class CapMap extends BiTypeMap:
override def mapOver(t: Type): Type = t match
case t @ FunctionOrMethod(args, res) if variance > 0 && !t.isAliasFun =>
t // `t` should be mapped in this case by a different call to `toResult`. See [[toResultInResults]].
case t: (LazyRef | TypeVar) =>
mapConserveSuper(t)
case _ =>
super.mapOver(t)
abstract class CapMap(using Context) extends BiTypeMap:
override def mapOver(t: Type): Type = t match
case t @ FunctionOrMethod(args, res) if variance > 0 && !t.isAliasFun =>
t // `t` should be mapped in this case by a different call to `toResult`. See [[toResultInResults]].
case t: (LazyRef | TypeVar) =>
mapConserveSuper(t)
case _ =>
super.mapOver(t)

class ToResult(localResType: Type, mt: MethodicType, fail: Message => Unit)(using Context) extends CapMap:

def apply(t: Type) = t match
case defn.FunctionNOf(args, res, contextual) if t.typeSymbol.name.isImpureFunction =>
if variance > 0 then
super.mapOver:
defn.FunctionNOf(args, res, contextual)
.capturing(ResultCap(mt).singletonCaptureSet)
else mapOver(t)
case _ =>
mapOver(t)

override def mapCapability(c: Capability, deep: Boolean) = c match
case c: (FreshCap | GlobalCap.type) =>
if variance > 0 then
val res = ResultCap(mt)
c match
case c: FreshCap => res.setOrigin(c)
case _ =>
res
else
if variance == 0 then
fail(em"""$localResType captures the root capability `cap` in invariant position.
|This capability cannot be converted to an existential in the result type of a function.""")
// we accept variance < 0, and leave the cap as it is
c
case _ =>
super.mapCapability(c, deep)

object toVar extends CapMap:
//.showing(i"mapcap $t = $result")
override def toString = "toVar"

def apply(t: Type) = t match
case defn.FunctionNOf(args, res, contextual) if t.typeSymbol.name.isImpureFunction =>
if variance > 0 then
super.mapOver:
defn.FunctionNOf(args, res, contextual)
.capturing(ResultCap(mt).singletonCaptureSet)
else mapOver(t)
case _ =>
mapOver(t)
object inverse extends BiTypeMap:
def apply(t: Type) = mapOver(t)

override def mapCapability(c: Capability, deep: Boolean) = c match
case c: (FreshCap | GlobalCap.type) =>
if variance > 0 then
val res = ResultCap(mt)
c match
case c: FreshCap => res.setOrigin(c)
case _ =>
res
else
if variance == 0 then
fail(em"""$tp captures the root capability `cap` in invariant position.
|This capability cannot be converted to an existential in the result type of a function.""")
// we accept variance < 0, and leave the cap as it is
c
case c @ ResultCap(`mt`) =>
// do a reverse getOrElseUpdate on `seen` to produce the
// `Fresh` assosicated with `t`
val primary = c.primaryResultCap
primary.origin match
case GlobalCap =>
val fresh = FreshCap(Origin.LocalInstance(mt.resType))
primary.setOrigin(fresh)
fresh
case origin: FreshCap =>
origin
case _ =>
super.mapCapability(c, deep)

//.showing(i"mapcap $t = $result")
override def toString = "toVar"

object inverse extends BiTypeMap:
def apply(t: Type) = mapOver(t)

override def mapCapability(c: Capability, deep: Boolean) = c match
case c @ ResultCap(`mt`) =>
// do a reverse getOrElseUpdate on `seen` to produce the
// `Fresh` assosicated with `t`
val primary = c.primaryResultCap
primary.origin match
case GlobalCap =>
val fresh = FreshCap(Origin.Unknown)
primary.setOrigin(fresh)
fresh
case origin: FreshCap =>
origin
case _ =>
super.mapCapability(c, deep)

def inverse = toVar.this
override def toString = "toVar.inverse"
end inverse
end toVar
def inverse = ToResult.this
override def toString = "toVar.inverse"
end inverse
end ToResult

toVar(tp)
end toResult
/** Replace all occurrences of `cap` (or fresh) in parts of this type by an existentially bound
* variable bound by `mt`.
* Stop at function or method types since these have been mapped before.
*/
def toResult(tp: Type, mt: MethodicType, fail: Message => Unit)(using Context): Type =
ToResult(tp, mt, fail)(tp)

/** Map global roots in function results to result roots. Also,
* map roots in the types of def methods that are parameterless
Expand Down
53 changes: 25 additions & 28 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import config.Printers.{capt, recheckr, noPrinter}
import config.{Config, Feature}
import ast.{tpd, untpd, Trees}
import Trees.*
import typer.ForceDegree
import typer.Inferencing.isFullyDefined
import typer.RefChecks.{checkAllOverrides, checkSelfAgainstParents, OverridingPairsChecker}
import typer.Checking.{checkBounds, checkAppliedTypesIn}
import typer.ErrorReporting.{Addenda, NothingToAdd, err}
Expand All @@ -25,7 +27,7 @@ import NameKinds.{DefaultGetterName, WildcardParamName, UniqueNameKind}
import reporting.{trace, Message, OverrideError}
import Annotations.Annotation
import Capabilities.*
import dotty.tools.dotc.util.common.alwaysTrue
import util.common.alwaysTrue

/** The capture checker */
object CheckCaptures:
Expand Down Expand Up @@ -916,51 +918,45 @@ class CheckCaptures extends Recheck, SymTransformer:
* { def $anonfun(...) = ...; closure($anonfun, ...)}
*/
override def recheckClosureBlock(mdef: DefDef, expr: Closure, pt: Type)(using Context): Type =
val anonfun = mdef.symbol

def matchParams(paramss: List[ParamClause], pt: Type): Unit =
def matchParamsAndResult(paramss: List[ParamClause], pt: Type): Unit =
//println(i"match $mdef against $pt")
paramss match
case params :: paramss1 => pt match
case defn.PolyFunctionOf(poly: PolyType) =>
assert(params.hasSameLengthAs(poly.paramInfos))
matchParams(paramss1, poly.instantiate(params.map(_.symbol.typeRef)))
matchParamsAndResult(paramss1, poly.instantiate(params.map(_.symbol.typeRef)))
case FunctionOrMethod(argTypes, resType) =>
assert(params.hasSameLengthAs(argTypes), i"$mdef vs $pt, ${params}")
for (argType, param) <- argTypes.lazyZip(params) do
val paramTpt = param.asInstanceOf[ValDef].tpt
val paramType = freshToCap(param.symbol, paramTpt.nuType)
checkConformsExpr(argType, paramType, param)
.showing(i"compared expected closure formal $argType against $param with ${paramTpt.nuType}", capt)
if !pt.isInstanceOf[RefinedType]
&& !(isEtaExpansion(mdef) && ccConfig.handleEtaExpansionsSpecially)
then
// If the closure is not an eta expansion and the expected type is a parametric
// function type, check whether the closure's result conforms to the expected
// result type. This constrains parameter types of the closure which can give better
// error messages. It also prevents mapping fresh to result caps in the closure's
// result type.
// If the closure is an eta expanded method reference it's better to not constrain
// its internals early since that would give error messages in generated code
// which are less intelligible. An example is the line `a = x` in
// neg-custom-args/captures/vars.scala. That's why this code is conditioned.
// to apply only to closures that are not eta expansions.
assert(paramss1.isEmpty)
capt.println(i"pre-check closure $expr of type ${mdef.tpt.nuType} against $resType")
checkConformsExpr(mdef.tpt.nuType, resType, expr)
if resType.isValueType && isFullyDefined(resType, ForceDegree.none) then
val localResType = pt match
case RefinedType(_, _, mt: MethodType) =>
inContext(ctx.withOwner(anonfun)):
Internalize(mt)(resType)
case _ => resType
mdef.tpt.updNuType(localResType)
// Make sure we affect the info of the anonfun by the previous updNuType
// unless the info is already defined in a previous phase and does not change.
assert(!anonfun.isCompleted || anonfun.denot.validFor.firstPhaseId != thisPhase.id)
//println(i"updating ${mdef.tpt} to $localResType/${mdef.tpt.nuType}")
case _ =>
case Nil =>

openClosures = (mdef.symbol, pt) :: openClosures
openClosures = (anonfun, pt) :: openClosures
// openClosures is needed for errors but currently makes no difference
// TODO follow up on this
try
matchParams(mdef.paramss, pt)
capt.println(i"recheck closure block $mdef: ${mdef.symbol.infoOrCompleter}")
if !mdef.symbol.isCompleted then
mdef.symbol.ensureCompleted() // this will recheck def
else
recheckDef(mdef, mdef.symbol)

matchParamsAndResult(mdef.paramss, pt)
capt.println(i"recheck closure block $mdef: ${anonfun.infoOrCompleter}")
if !anonfun.isCompleted
then anonfun.ensureCompleted() // this will recheck def
else recheckDef(mdef, anonfun)
recheckClosure(expr, pt, forceDependent = true)
finally
openClosures = openClosures.tail
Expand Down Expand Up @@ -1463,7 +1459,8 @@ class CheckCaptures extends Recheck, SymTransformer:
case FunctionOrMethod(aargs, ares) =>
val saved = curEnv
curEnv = Env(
curEnv.owner, EnvKind.NestedInOwner,
curEnv.owner,
if boxed then EnvKind.Boxed else EnvKind.NestedInOwner,
CaptureSet.Var(curEnv.owner, level = ccState.currentLevel),
if boxed then null else curEnv)
try
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/Recheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,12 @@ abstract class Recheck extends Phase, SymTransformer:
* from the current type.
*/
def setNuType(tpe: Type): Unit =
if nuTypes.lookup(tree) == null then updNuType(tpe)
if nuTypes.lookup(tree) == null && (tpe ne tree.tpe) then
updNuType(tpe)

/** Set new type of the tree unconditionally. */
def updNuType(tpe: Type): Unit =
if tpe ne tree.tpe then nuTypes(tree) = tpe
nuTypes(tree) = tpe

/** The new type of the tree, or if none was installed, the original type */
def nuType(using Context): Type =
Expand Down
25 changes: 20 additions & 5 deletions tests/neg-custom-args/captures/capt1.check
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:5:2 ------------------------------------------
5 | () => if x == null then y else y // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| Found: () ->{x} C^?
| Found: () ->{x} C
| Required: () -> C
| Note that capability (x : C^) is not included in capture set {}.
|
Expand Down Expand Up @@ -52,12 +52,27 @@
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:36:24 ----------------------------------------
36 | val z2 = h[() -> Cap](() => x) // error // error
| ^^^^^^^
|Found: () ->{x} C^{x}
|Required: () -> C^
|Found: () ->? C^
|Required: () -> C^²
|
|where: ^ refers to a root capability associated with the result type of (): C^
| ^² refers to a fresh root capability created in value z2 when checking argument to parameter a of method h
|
|Note that capability <cap of (): C^> is not included in capture set {cap}
|because <cap of (): C^> is not visible from cap in value z2.
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:37:5 -----------------------------------------
37 | (() => C()) // error
| ^^^^^^^^^
|Found: () ->? C^
|Required: () -> C^²
|
|where: ^ refers to a fresh root capability created in value z2 when checking argument to parameter a of method h
|where: ^ refers to a root capability associated with the result type of (): C^
| ^² refers to a fresh root capability created in value z2 when checking argument to parameter b of method h
|
|Note that capability (x : C^) is not included in capture set {}.
|Note that capability <cap of (): C^> is not included in capture set {cap}
|because <cap of (): C^> is not visible from cap in value z2.
|
| longer explanation available when compiling with `-explain`
-- Error: tests/neg-custom-args/captures/capt1.scala:38:13 -------------------------------------------------------------
Expand Down
Loading
Loading