Skip to content

Commit 02cd565

Browse files
committed
Allow reach capabilities from within a nested closure
1 parent 17a4a8c commit 02cd565

File tree

3 files changed

+123
-10
lines changed

3 files changed

+123
-10
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ object CheckCaptures:
5050
owner: Symbol,
5151
kind: EnvKind,
5252
captured: CaptureSet,
53-
outer0: Env | Null):
53+
outer0: Env | Null,
54+
nestedClosure: Symbol = NoSymbol):
5455

5556
def outer = outer0.nn
5657

@@ -400,16 +401,18 @@ class CheckCaptures extends Recheck, SymTransformer:
400401
else
401402
!sym.isContainedIn(env.owner)
402403

403-
def checkUseDeclared(c: CaptureRef, env: Env) =
404-
c.pathRoot match
404+
def checkUseDeclared(c: CaptureRef, env: Env, lastEnv: Env | Null) =
405+
if lastEnv != null && env.nestedClosure.exists && env.nestedClosure == lastEnv.owner then
406+
() // access is from a nested closure, so it's OK
407+
else c.pathRoot match
405408
case ref: NamedType if !ref.symbol.hasAnnotation(defn.UseAnnot) =>
406409
val what = if ref.isType then "Capture set parameter" else "Local reach capability"
407410
report.error(
408411
em"""$what $c leaks into capture scope of ${env.ownerString}.
409412
|To allow this, the ${ref.symbol} should be declared with a @use annotation""", pos)
410413
case _ =>
411414

412-
def recur(cs: CaptureSet, env: Env)(using Context): Unit =
415+
def recur(cs: CaptureSet, env: Env, lastEnv: Env | Null)(using Context): Unit =
413416
if env.isOpen && !env.owner.isStaticOwner && !cs.isAlwaysEmpty then
414417
// Only captured references that are visible from the environment
415418
// should be included.
@@ -423,7 +426,7 @@ class CheckCaptures extends Recheck, SymTransformer:
423426
c match
424427
case ReachCapability(c1) =>
425428
if c1.isParamPath then
426-
checkUseDeclared(c, env)
429+
checkUseDeclared(c, env, lastEnv)
427430
else
428431
// When a reach capabilty x* where `x` is not a parameter goes out
429432
// of scope, we need to continue with `x`'s underlying deep capture set.
@@ -438,16 +441,16 @@ class CheckCaptures extends Recheck, SymTransformer:
438441
capt.println(i"Widen reach $c to $underlying in ${env.owner}")
439442
underlying.disallowRootCapability: () =>
440443
report.error(em"Local reach capability $c leaks into capture scope of ${env.ownerString}", pos)
441-
recur(underlying, env)
444+
recur(underlying, env, lastEnv)
442445
case c: TypeRef if c.isParamPath =>
443-
checkUseDeclared(c, env)
446+
checkUseDeclared(c, env, lastEnv)
444447
case _ =>
445448
isVisible
446449
checkSubset(included, env.captured, pos, provenance(env))
447450
capt.println(i"Include call or box capture $included from $cs in ${env.owner} --> ${env.captured}")
448451
if !isOfNestedMethod(env) then
449-
recur(included, nextEnvToCharge(env, !_.owner.isStaticOwner))
450-
recur(cs, curEnv)
452+
recur(included, nextEnvToCharge(env, !_.owner.isStaticOwner), env)
453+
recur(cs, curEnv, null)
451454
end markFree
452455

453456
/** Include references captured by the called method in the current environment stack */
@@ -843,10 +846,19 @@ class CheckCaptures extends Recheck, SymTransformer:
843846
override def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Type =
844847
if Synthetics.isExcluded(sym) then sym.info
845848
else
849+
// If rhs ends in a closure or anonymous class, the corresponding symbol
850+
def nestedClosure(rhs: Tree)(using Context): Symbol = rhs match
851+
case Closure(_, meth, _) => meth.symbol
852+
case Apply(fn, _) if fn.symbol.isConstructor && fn.symbol.owner.isAnonymousClass => fn.symbol.owner
853+
case Block(_, expr) => nestedClosure(expr)
854+
case Inlined(_, _, expansion) => nestedClosure(expansion)
855+
case Typed(expr, _) => nestedClosure(expr)
856+
case _ => NoSymbol
857+
846858
val saved = curEnv
847859
val localSet = capturedVars(sym)
848860
if !localSet.isAlwaysEmpty then
849-
curEnv = Env(sym, EnvKind.Regular, localSet, curEnv)
861+
curEnv = Env(sym, EnvKind.Regular, localSet, curEnv, nestedClosure(tree.rhs))
850862

851863
// ctx with AssumedContains entries for each Contains parameter
852864
val bodyCtx =

docs/_docs/internals/cc/use-design.md

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
2+
Possible design:
3+
4+
1. Have @use annotation on type parameters and value parameters of regular methods
5+
(not anonymous functions).
6+
2. In markFree, keep track whether a capture set variable or reach capability
7+
is used directly in the method where it is defined, or in a nested context
8+
(either unbound nested closure or unbound anonymous class).
9+
3. Disallow charging a reach capability `xs*` to the environment of the method where
10+
`xs` is a parameter unless `xs` is declared `@use`.
11+
4. Analogously, disallow charging a capture set variable `C^` to the environment of the method where `C^` is a parameter unless `C^` is declared `@use`.
12+
5. When passing an argument to a `@use`d term parameter, charge the `dcs` of the argument type to the environments via markFree.
13+
6. When instantiating a `@use`d type parameter, charge the capture set of the argument
14+
to the environments via markFree.
15+
16+
It follows that we cannot refer to methods with @use term parameters as values. Indeed,
17+
their eta expansion would produce an anonymous function that includes a reach capability of
18+
its parameter in its use set, violating (3).
19+
20+
Example:
21+
22+
```scala
23+
def runOps(@use ops: List[() => Unit]): Unit = ops.foreach(_())
24+
```
25+
Then `runOps` expands to
26+
```scala
27+
(xs: List[() => Unit]) => runOps(xs)
28+
```
29+
Note that `xs` does not carry a `@use` since this is disallowed by (1) for anonymous functions. By (5), we charge the deep capture set of `xs`, which is `xs*` to the environment. By (3), this is actually disallowed.
30+
31+
Now, if we express this with explicit capture set parameters we get:
32+
```scala
33+
def runOpsPoly[@use C^](ops: List[() ->{C^} Unit]): Unit = ops.foreach[C^](_())
34+
```
35+
Then `runOpsPoly` expands to `runOpsPoly[cs]` for some inferred capture set `cs`. And this expands to:
36+
```scala
37+
(xs: List[() ->{cs} Unit]) => runOpsPoly[cs](xs)
38+
```
39+
Since `cs` is passed to the `@use` parameter of `runOpsPoly` it is charged
40+
to the environment of the function body, so the type of the previous expression is
41+
```scala
42+
List[() ->{cs} Unit]) ->{cs} Unit
43+
```
44+
45+
We can also use explicit capture set parameters to eta expand the first `runOps` manually:
46+
47+
```scala
48+
[C^] => (xs: List[() ->{C^} Unit]) => runOps(xs)
49+
: [C^] -> List[() ->{C^} Unit] ->[C^] Unit
50+
```
51+
Except that this currently runs afoul of the implementation restriction that polymorphic functions cannot wrap capturing functions. But that's a restriction we need to lift anyway.
52+
53+
## `@use` inference
54+
55+
- `@use` is implied for a term parameter `x` of a method if `x`'s type contains a boxed cap and `x` or `x*` is not referred to in the result type of the method.
56+
57+
- `@use` is implied for a capture set parameter `C` of a method if `C` is not referred to in the result type of the method.
58+
59+
If `@use` is implied, one can override to no use by giving an explicit use annotation
60+
`@use(false)` instead. Example:
61+
```scala
62+
def f(@use(false) xs: List[() => Unit]): Int = xs.length
63+
```
64+
65+
This works since `@use` is defined like this:
66+
```scala
67+
class use(cond: Boolean = true) extends StaticAnnotation
68+
```
69+
70+
71+
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
def test(xs: List[() => Unit]) =
2+
xs.head // error
3+
4+
def foo =
5+
xs.head // ok
6+
def bar() =
7+
xs.head // ok
8+
9+
class Foo:
10+
println(xs.head) // error, but could be OK
11+
12+
foo // error
13+
bar() // error
14+
Foo() // OK, but could be error
15+
16+
def test2(xs: List[() => Unit]) =
17+
def foo = xs.head // ok
18+
()
19+
20+
def test3(xs: List[() => Unit]): () ->{xs*} Unit = () =>
21+
println(xs.head) // ok
22+
23+
def test4(xs: List[() => Unit]) = () => xs.head // ok
24+
25+
def test5(xs: List[() => Unit]) = new:
26+
println(xs.head) // ok
27+
28+
def test6(xs: List[() => Unit]) =
29+
val x= new { println(xs.head) } // error
30+
x

0 commit comments

Comments
 (0)