Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 86 additions & 33 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -407,25 +407,22 @@ class CheckCaptures extends Recheck, SymTransformer:
else i"references $cs1$cs1description are not all",
cs1, cs2, pos, provenance)

/** If `sym` is a class or method nested inside a term, a capture set variable representing
* the captured variables of the environment associated with `sym`.
/** If `sym` is a method or a non-static inner class, a capture set variable
* representing the captured variables of the environment associated with `sym`.
*/
def capturedVars(sym: Symbol)(using Context): CaptureSet =
myCapturedVars.getOrElseUpdate(sym,
if sym.ownersIterator.exists(_.isTerm)
if sym.isTerm || !sym.owner.isStaticOwner
then CaptureSet.Var(sym.owner, level = ccState.symLevel(sym))
else CaptureSet.empty)

// ---- Record Uses with MarkFree ----------------------------------------------------

/** The next environment enclosing `env` that needs to be charged
* with free references.
* @param included Whether an environment is included in the range of
* environments to charge. Once `included` is false, no
* more environments need to be charged.
*/
def nextEnvToCharge(env: Env, included: Env => Boolean)(using Context): Env =
if env.owner.isConstructor && included(env.outer) then env.outer.outer
def nextEnvToCharge(env: Env)(using Context): Env | Null =
if env.owner.isConstructor then env.outer.outer0
else env.outer

/** A description where this environment comes from */
Expand Down Expand Up @@ -458,21 +455,27 @@ class CheckCaptures extends Recheck, SymTransformer:
markFree(sym, sym.termRef, tree)

def markFree(sym: Symbol, ref: Capability, tree: Tree)(using Context): Unit =
if sym.exists && ref.isTracked then markFree(ref.singletonCaptureSet, tree)
if sym.exists then markFree(ref, tree)

def markFree(ref: Capability, tree: Tree)(using Context): Unit =
if ref.isTracked then markFree(ref.singletonCaptureSet, tree)

/** Make sure the (projected) `cs` is a subset of the capture sets of all enclosing
* environments. At each stage, only include references from `cs` that are outside
* the environment's owner
*/
def markFree(cs: CaptureSet, tree: Tree)(using Context): Unit =
def markFree(cs: CaptureSet, tree: Tree, addUseInfo: Boolean = true)(using Context): Unit =
// A captured reference with the symbol `sym` is visible from the environment
// if `sym` is not defined inside the owner of the environment.
inline def isVisibleFromEnv(sym: Symbol, env: Env) =
sym.exists && {
val effectiveOwner =
if env.owner.isConstructor then env.owner.owner
else env.owner
if env.kind == EnvKind.NestedInOwner then
!sym.isProperlyContainedIn(env.owner)
!sym.isProperlyContainedIn(effectiveOwner)
else
!sym.isContainedIn(env.owner)
!sym.isContainedIn(effectiveOwner)
}

/** Avoid locally defined capability by charging the underlying type
Expand Down Expand Up @@ -535,13 +538,15 @@ class CheckCaptures extends Recheck, SymTransformer:
checkSubset(included, env.captured, tree.srcPos, provenance(env))
capt.println(i"Include call or box capture $included from $cs in ${env.owner} --> ${env.captured}")
if !isOfNestedMethod(env) then
recur(included, nextEnvToCharge(env, !_.owner.isStaticOwner), env)
val nextEnv = nextEnvToCharge(env)
if nextEnv != null && !nextEnv.owner.isStaticOwner then
recur(included, nextEnv, env)
// Under deferredReaches, don't propagate out of methods inside terms.
// The use set of these methods will be charged when that method is called.

if !cs.isAlwaysEmpty then
recur(cs, curEnv, null)
useInfos += ((tree, cs, curEnv))
if addUseInfo then useInfos += ((tree, cs, curEnv))
end markFree

/** If capability `c` refers to a parameter that is not implicitly or explicitly
Expand Down Expand Up @@ -626,25 +631,33 @@ class CheckCaptures extends Recheck, SymTransformer:
// If ident refers to a parameterless method, charge its cv to the environment
includeCallCaptures(sym, sym.info, tree)
else if !sym.isStatic then
// Otherwise charge its symbol, but add all selections and also any `.rd`
// modifier implied by the expected type `pt`.
// Example: If we have `x` and the expected type says we select that with `.a.b`
// where `b` is a read-only method, we charge `x.a.b.rd` instead of `x`.
def addSelects(ref: TermRef, pt: Type): Capability = pt match
case pt: PathSelectionProto if ref.isTracked =>
if pt.sym.isReadOnlyMethod then
ref.readOnly
else
// if `ref` is not tracked then the selection could not give anything new
// class SerializationProxy in stdlib-cc/../LazyListIterable.scala has an example where this matters.
addSelects(ref.select(pt.sym).asInstanceOf[TermRef], pt.pt)
case _ => ref
var pathRef: Capability = addSelects(sym.termRef, pt)
if pathRef.derivesFromMutable && pt.isValueType && !pt.isMutableType then
pathRef = pathRef.readOnly
markFree(sym, pathRef, tree)
markFree(sym, pathRef(sym.termRef, pt), tree)
mapResultRoots(super.recheckIdent(tree, pt), tree.symbol)

override def recheckThis(tree: This, pt: Type)(using Context): Type =
markFree(pathRef(tree.tpe.asInstanceOf[ThisType], pt), tree)
super.recheckThis(tree, pt)

/** Add all selections and also any `.rd modifier implied by the expected
* type `pt` to `base`. Example:
* If we have `x` and the expected type says we select that with `.a.b`
* where `b` is a read-only method, we charge `x.a.b.rd` instead of `x`.
*/
private def pathRef(base: TermRef | ThisType, pt: Type)(using Context): Capability =
def addSelects(ref: TermRef | ThisType, pt: Type): Capability = pt match
case pt: PathSelectionProto if ref.isTracked =>
if pt.sym.isReadOnlyMethod then
ref.readOnly
else
// if `ref` is not tracked then the selection could not give anything new
// class SerializationProxy in stdlib-cc/../LazyListIterable.scala has an example where this matters.
addSelects(ref.select(pt.sym).asInstanceOf[TermRef], pt.pt)
case _ => ref
val ref: Capability = addSelects(base, pt)
if ref.derivesFromMutable && pt.isValueType && !pt.isMutableType
then ref.readOnly
else ref

/** The expected type for the qualifier of a selection. If the selection
* could be part of a capability path or is a a read-only method, we return
* a PathSelectionProto.
Expand Down Expand Up @@ -866,7 +879,7 @@ class CheckCaptures extends Recheck, SymTransformer:
val (refined, cs) = addParamArgRefinements(core, initCs)
refined.capturing(cs)

augmentConstructorType(resType, capturedVars(cls) ++ capturedVars(constr))
augmentConstructorType(resType, capturedVars(cls))
.showing(i"constr type $mt with $argTypes%, % in $constr = $result", capt)
end refineConstructorInstance

Expand Down Expand Up @@ -975,6 +988,8 @@ class CheckCaptures extends Recheck, SymTransformer:
* - Interpolate contravariant capture set variables in result type.
*/
override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Type =
val savedEnv = curEnv
val runInConstructor = !sym.isOneOf(Param | ParamAccessor | Lazy | NonMember)
try
if sym.is(Module) then sym.info // Modules are checked by checking the module class
else
Expand All @@ -993,6 +1008,8 @@ class CheckCaptures extends Recheck, SymTransformer:
""
disallowBadRootsIn(
tree.tpt.nuType, NoSymbol, i"Mutable $sym", "have type", addendum, sym.srcPos)
if runInConstructor then
pushConstructorEnv()
checkInferredResult(super.recheckValDef(tree, sym), tree)
finally
if !sym.is(Param) then
Expand All @@ -1002,6 +1019,22 @@ class CheckCaptures extends Recheck, SymTransformer:
// function is compiled since we do not propagate expected types into blocks.
interpolateIfInferred(tree.tpt, sym)

def declaredCaptures = tree.tpt.nuType.captureSet
if runInConstructor && savedEnv.owner.isClass then
curEnv = savedEnv
markFree(declaredCaptures, tree, addUseInfo = false)

if sym.owner.isStaticOwner && !declaredCaptures.elems.isEmpty && sym != defn.captureRoot then
def where =
if sym.effectiveOwner.is(Package) then "top-level definition"
else i"member of static ${sym.owner}"
report.warning(
em"""$sym has a non-empty capture set but will not be added as
|a capability to computed capture sets since it is globally accessible
|as a $where. Global values cannot be capabilities.""",
tree.namePos)
end recheckValDef

/** Recheck method definitions:
* - check body in a nested environment that tracks uses, in a nested level,
* and in a nested context that knows abaout Contains parameters so that we
Expand Down Expand Up @@ -1228,6 +1261,24 @@ class CheckCaptures extends Recheck, SymTransformer:
recheckFinish(result, arg, pt)
*/

/** If environment is owned by a class, run in a new environment owned by
* its primary constructor instead.
*/
def pushConstructorEnv()(using Context): Unit =
if curEnv.owner.isClass then
val constr = curEnv.owner.primaryConstructor
if constr.exists then
val constrSet = capturedVars(constr)
if capturedVars(constr) ne CaptureSet.empty then
curEnv = Env(constr, EnvKind.Regular, constrSet, curEnv)

override def recheckStat(stat: Tree)(using Context): Unit =
val saved = curEnv
if !stat.isInstanceOf[MemberDef] then
pushConstructorEnv()
try recheck(stat)
finally curEnv = saved

/** The main recheck method does some box adapation for all nodes:
* - If expected type `pt` is boxed and the tree is a lambda or a reference,
* don't propagate free variables.
Expand Down Expand Up @@ -2021,7 +2072,9 @@ class CheckCaptures extends Recheck, SymTransformer:
if env.kind == EnvKind.Boxed then env.owner
else if isOfNestedMethod(env) then env.owner.owner
else if env.owner.isStaticOwner then NoSymbol
else boxedOwner(nextEnvToCharge(env, alwaysTrue))
else
val nextEnv = nextEnvToCharge(env)
if nextEnv == null then NoSymbol else boxedOwner(nextEnv)

def checkUseUnlessBoxed(c: Capability, croot: NamedType) =
if !boxedOwner(env).isContainedIn(croot.symbol.owner) then
Expand Down
15 changes: 13 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/Recheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,12 @@ abstract class Recheck extends Phase, SymTransformer:
def recheckSelection(tree: Select, qualType: Type, name: Name, pt: Type)(using Context): Type =
recheckSelection(tree, qualType, name, sharpen = identity[Denotation])

def recheckThis(tree: This, pt: Type)(using Context): Type =
tree.tpe

def recheckSuper(tree: Super, pt: Type)(using Context): Type =
tree.tpe

def recheckBind(tree: Bind, pt: Type)(using Context): Type = tree match
case Bind(name, body) =>
recheck(body, pt)
Expand Down Expand Up @@ -487,12 +493,15 @@ abstract class Recheck extends Phase, SymTransformer:
recheckStats(tree.stats)
NoType

def recheckStat(stat: Tree)(using Context): Unit =
recheck(stat)

def recheckStats(stats: List[Tree])(using Context): Unit =
@tailrec def traverse(stats: List[Tree])(using Context): Unit = stats match
case (imp: Import) :: rest =>
traverse(rest)(using ctx.importContext(imp, imp.symbol))
case stat :: rest =>
recheck(stat)
recheckStat(stat)
traverse(rest)
case _ =>
traverse(stats)
Expand Down Expand Up @@ -540,7 +549,9 @@ abstract class Recheck extends Phase, SymTransformer:
def recheckUnnamed(tree: Tree, pt: Type): Type = tree match
case tree: Apply => recheckApply(tree, pt)
case tree: TypeApply => recheckTypeApply(tree, pt)
case _: New | _: This | _: Super | _: Literal => tree.tpe
case tree: This => recheckThis(tree, pt)
case tree: Super => recheckSuper(tree, pt)
case _: New | _: Literal => tree.tpe
case tree: Typed => recheckTyped(tree)
case tree: Assign => recheckAssign(tree)
case tree: Block => recheckBlock(tree, pt)
Expand Down
12 changes: 6 additions & 6 deletions library/src/scala/collection/Iterator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite

@deprecated("Call scanRight on an Iterable instead.", "2.13.0")
def scanRight[B](z: B)(op: (A, B) => B): Iterator[B]^{this, op} = ArrayBuffer.from(this).scanRight(z)(op).iterator

/** Finds index of the first element satisfying some predicate after or at some start index.
*
* $mayNotTerminateInf
Expand Down Expand Up @@ -494,9 +494,9 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
while (p(hd) == isFlipped) {
if (!self.hasNext) return false
hd = self.next()
}
}
hdDefined = true
true
true
}

def next() =
Expand Down Expand Up @@ -874,7 +874,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
*/
def duplicate: (Iterator[A]^{this}, Iterator[A]^{this}) = {
val gap = new scala.collection.mutable.Queue[A]
var ahead: Iterator[A] = null
var ahead: Iterator[A]^ = null
class Partner extends AbstractIterator[A] {
override def knownSize: Int = self.synchronized {
val thisSize = self.knownSize
Expand All @@ -890,7 +890,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
if (gap.isEmpty) ahead = this
if (this eq ahead) {
val e = self.next()
gap enqueue e
gap.enqueue(e)
e
} else gap.dequeue()
}
Expand Down Expand Up @@ -918,7 +918,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
*/
def patch[B >: A](from: Int, patchElems: Iterator[B]^, replaced: Int): Iterator[B]^{this, patchElems} =
new AbstractIterator[B] {
private[this] var origElems = self
private[this] var origElems: Iterator[B]^ = self
// > 0 => that many more elems from `origElems` before switching to `patchElems`
// 0 => need to drop elems from `origElems` and start using `patchElems`
// -1 => have dropped elems from `origElems`, will be using `patchElems` until it's empty
Expand Down
2 changes: 1 addition & 1 deletion library/src/scala/collection/LazyZipOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ final class LazyZip4[+El1, +El2, +El3, +El4, C1] private[collection](src: C1,
}

private def toIterable: View[(El1, El2, El3, El4)]^{this} = new AbstractView[(El1, El2, El3, El4)] {
def iterator: AbstractIterator[(El1, El2, El3, El4)] = new AbstractIterator[(El1, El2, El3, El4)] {
def iterator: AbstractIterator[(El1, El2, El3, El4)]^{this} = new AbstractIterator[(El1, El2, El3, El4)] {
private[this] val elems1 = coll1.iterator
private[this] val elems2 = coll2.iterator
private[this] val elems3 = coll3.iterator
Expand Down
4 changes: 2 additions & 2 deletions library/src/scala/collection/SeqView.scala
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,10 @@ object SeqView {
override def knownSize: Int = len
override def isEmpty: Boolean = len == 0
override def to[C1](factory: Factory[A, C1]): C1 = _sorted.to(factory)
override def reverse: SeqView[A] = new ReverseSorted
override def reverse: SeqView[A]^{this} = new ReverseSorted
// we know `_sorted` is either tiny or has efficient random access,
// so this is acceptable for `reversed`
override protected def reversed: Iterable[A] = new ReverseSorted
override protected def reversed: Iterable[A]^{this} = new ReverseSorted

override def sorted[B1 >: A](implicit ord1: Ordering[B1]): SeqView[A]^{this} =
if (ord1 == this.ord) this
Expand Down
6 changes: 3 additions & 3 deletions library/src/scala/collection/Stepper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ trait IntStepper extends Stepper[Int] {

def spliterator[B >: Int]: Spliterator.OfInt^{this} = new IntStepper.IntStepperSpliterator(this)

def javaIterator[B >: Int]: PrimitiveIterator.OfInt = new PrimitiveIterator.OfInt {
def javaIterator[B >: Int]: PrimitiveIterator.OfInt^{this} = new PrimitiveIterator.OfInt {
def hasNext: Boolean = hasStep
def nextInt(): Int = nextStep()
}
Expand Down Expand Up @@ -298,7 +298,7 @@ trait DoubleStepper extends Stepper[Double] {

def spliterator[B >: Double]: Spliterator.OfDouble^{this} = new DoubleStepper.DoubleStepperSpliterator(this)

def javaIterator[B >: Double]: PrimitiveIterator.OfDouble = new PrimitiveIterator.OfDouble {
def javaIterator[B >: Double]: PrimitiveIterator.OfDouble^{this} = new PrimitiveIterator.OfDouble {
def hasNext: Boolean = hasStep
def nextDouble(): Double = nextStep()
}
Expand Down Expand Up @@ -337,7 +337,7 @@ trait LongStepper extends Stepper[Long] {

def spliterator[B >: Long]: Spliterator.OfLong^{this} = new LongStepper.LongStepperSpliterator(this)

def javaIterator[B >: Long]: PrimitiveIterator.OfLong = new PrimitiveIterator.OfLong {
def javaIterator[B >: Long]: PrimitiveIterator.OfLong^{this} = new PrimitiveIterator.OfLong {
def hasNext: Boolean = hasStep
def nextLong(): Long = nextStep()
}
Expand Down
4 changes: 2 additions & 2 deletions library/src/scala/collection/View.scala
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ object View extends IterableFactory[View] {

@SerialVersionUID(3L)
class LeftPartitionMapped[A, A1, A2](underlying: SomeIterableOps[A]^, f: A => Either[A1, A2]) extends AbstractView[A1] {
def iterator: AbstractIterator[A1] = new AbstractIterator[A1] {
def iterator: AbstractIterator[A1]^{this} = new AbstractIterator[A1] {
private[this] val self = underlying.iterator
private[this] var hd: A1 = _
private[this] var hdDefined: Boolean = false
Expand All @@ -197,7 +197,7 @@ object View extends IterableFactory[View] {

@SerialVersionUID(3L)
class RightPartitionMapped[A, A1, A2](underlying: SomeIterableOps[A]^, f: A => Either[A1, A2]) extends AbstractView[A2] {
def iterator: AbstractIterator[A2] = new AbstractIterator[A2] {
def iterator: AbstractIterator[A2]^{this} = new AbstractIterator[A2] {
private[this] val self = underlying.iterator
private[this] var hd: A2 = _
private[this] var hdDefined: Boolean = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import scala.jdk._
* [[scala.jdk.javaapi.StreamConverters]].
*/
trait StreamExtensions {
this: StreamExtensions =>
// collections

implicit class IterableHasSeqStream[A](cc: IterableOnce[A]) {
Expand Down
2 changes: 1 addition & 1 deletion library/src/scala/collection/mutable/HashTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ private[collection] trait HashTable[A, B, Entry >: Null <: HashEntry[A, Entry]]

/** An iterator returning all entries.
*/
def entriesIterator: Iterator[Entry] = new AbstractIterator[Entry] {
def entriesIterator: Iterator[Entry]^{this} = new AbstractIterator[Entry] {
val iterTable = table
var idx = lastPopulatedIndex
var es = iterTable(idx)
Expand Down
Loading
Loading