diff --git a/compiler/src/dotty/tools/dotc/transform/init/Cache.scala b/compiler/src/dotty/tools/dotc/transform/init/Cache.scala index c0391a05262d..8c4ee0112219 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Cache.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Cache.scala @@ -93,8 +93,6 @@ class Cache[Config, Res]: * * The algorithmic skeleton is as follows: * - * if don't cache result then - * return eval(expr) * if this.current.contains(config, expr) then * return cached value * else @@ -107,32 +105,28 @@ class Cache[Config, Res]: * this.current(config, expr) = actual * */ - def cachedEval(config: Config, expr: Tree, cacheResult: Boolean, default: Res)(eval: Tree => Res): Res = - if !cacheResult then - eval(expr) - else - this.get(config, expr) match - case Some(value) => value - case None => - val assumeValue: Res = - this.last.get(config, expr) match - case Some(value) => value - case None => - this.last = this.last.updatedNested(config, expr, default) - default - - this.current = this.current.updatedNested(config, expr, assumeValue) - - val actual = eval(expr) - if actual != assumeValue then - // println("Changed! from = " + assumeValue + ", to = " + actual) - this.changed = true - this.current = this.current.updatedNested(config, expr, actual) - // this.current = this.current.removed(config, expr) - end if - - actual - end if + def cachedEval(config: Config, expr: Tree, default: Res)(doEval: => Res): Res = + this.get(config, expr) match + case Some(value) => value + case None => + val assumeValue: Res = + this.last.get(config, expr) match + case Some(value) => value + case None => + this.last = this.last.updatedNested(config, expr, default) + default + + this.current = this.current.updatedNested(config, expr, assumeValue) + + val actual = doEval + if actual != assumeValue then + // println("Changed! from = " + assumeValue + ", to = " + actual) + this.changed = true + this.current = this.current.updatedNested(config, expr, actual) + // this.current = this.current.removed(config, expr) + end if + + actual end cachedEval def hasChanged = changed diff --git a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala index 52760cf8b6c7..e0c158855528 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala @@ -179,26 +179,49 @@ class Objects(using Context @constructorOnly): * * Note that the 2nd parameter block does not take part in the definition of equality. */ - case class OfClass private ( - klass: ClassSymbol, outer: Value, ctor: Symbol, args: List[Value], env: Env.Data)( - valsMap: mutable.Map[Symbol, Value], varsMap: mutable.Map[Symbol, Heap.Addr], outersMap: mutable.Map[ClassSymbol, Value]) + class OfClass private( + val klass: ClassSymbol, + val env: Env.Data, + valsMap: mutable.Map[Symbol, Value], + varsMap: mutable.Map[Symbol, Heap.Addr], + outersMap: mutable.Map[ClassSymbol, Value]) extends Ref(valsMap, varsMap, outersMap): - def widenedCopy(outer: Value, args: List[Value], env: Env.Data): OfClass = - new OfClass(klass, outer, ctor, args, env)(this.valsMap, this.varsMap, this.outersMap) + override def equals(that: Any): Boolean = + that match + case ref: OfClass => + this.klass == ref.klass + && this.vals == ref.vals + && this.vars == ref.vars + && this.outers == ref.outers + && this.env == ref.env + + case _ => false + + override def hashCode(): Int = + this.klass.hashCode + + this.vals.hashCode + + this.vars.hashCode + + this.outers.hashCode + + this.env.hashCode + + def widen(height: Int): OfClass = + val vals2 = vals.map { (k, v) => k -> v.widen(height) } + val outers2 = outers.map { (k, v) => k -> v.widen(height) } + val env2 = env.widen(height) + new OfClass(klass, env2, vals2, this.varsMap, outers2) def show(using Context) = val valFields = vals.map(_.show + " -> " + _.show) - "OfClass(" + klass.show + ", outer = " + outer + ", args = " + args.map(_.show) + ", vals = " + valFields + ")" + "OfClass(" + klass.show + ", vals = " + vals + ", vars = " + vars + ", outers = " + outers + ")" object OfClass: - def apply( - klass: ClassSymbol, outer: Value, ctor: Symbol, args: List[Value], env: Env.Data)( - using Context - ): OfClass = - val instance = new OfClass(klass, outer, ctor, args, env)( - valsMap = mutable.Map.empty, varsMap = mutable.Map.empty, outersMap = mutable.Map.empty + def apply(klass: ClassSymbol, env: Env.Data)(using Context): OfClass = + val instance = new OfClass( + klass, env, + valsMap = mutable.Map.empty, + varsMap = mutable.Map.empty, + outersMap = mutable.Map.empty ) - instance.initOuter(klass, outer) instance /** @@ -452,7 +475,7 @@ class Objects(using Context @constructorOnly): case NoEnv => thisV match case ref: OfClass => - ref.outer match + ref.outerValue(ref.klass) match case outer : ThisValue => resolveEnv(meth, outer, ref.env) case _ => @@ -527,21 +550,54 @@ class Objects(using Context @constructorOnly): /** Cache used to terminate the check */ object Cache: - case class Config(thisV: Value, env: Env.Data, heap: Heap.Data) + enum Config: + val heap: Heap.Data + + /** The cache key for instantiation does not contain the value for `this` */ + case Instantiation(klass: ClassSymbol, outer: Value, ctor: Symbol, args: List[Value], env: Env.Data, heap: Heap.Data) + + /** The cache key for method calls contain the value for `this` */ + case Call(thisV: Value, env: Env.Data, heap: Heap.Data) + case class Res(value: Value, heap: Heap.Data) class Data extends Cache[Config, Res]: - def get(thisV: Value, expr: Tree)(using Heap.MutableData, Env.Data): Option[Value] = - val config = Config(thisV, summon[Env.Data], Heap.getHeapData()) - super.get(config, expr).map(_.value) - - def cachedEval(thisV: ThisValue, expr: Tree, cacheResult: Boolean)(fun: Tree => Value)(using Heap.MutableData, Env.Data): Value = - val config = Config(thisV, summon[Env.Data], Heap.getHeapData()) - val result = super.cachedEval(config, expr, cacheResult, default = Res(Bottom, Heap.getHeapData())) { expr => - Res(fun(expr), Heap.getHeapData()) + def cachedInstantiate(klass: ClassSymbol, outer: Value, ctor: Symbol, argInfos: List[ArgInfo], envOuter: Env.Data): Contextual[Value] = + val args = argInfos.map(_.value) + val instance = OfClass(klass, envOuter) + val config = Config.Instantiation(klass, outer, ctor, args, envOuter, Heap.getHeapData()) + val result = super.cachedEval(config, klass.defTree, default = Res(instance, config.heap)) { + given Env.Data = Env.NoEnv + instance.initOuter(klass, outer) + callConstructor(instance, ctor, argInfos) + Res(instance, Heap.getHeapData()) } Heap.setHeap(result.heap) result.value + + def cachedEval(thisV: ThisValue, expr: Tree, klass: ClassSymbol, ctx: EvalContext): Contextual[Value] = + ctx match + case EvalContext.Other => + cases(expr, thisV, klass) + + case _ => + val env = summon[Env.Data] + val config = Config.Call(thisV, env, Heap.getHeapData()) + val result = super.cachedEval(config, expr, default = Res(Bottom, config.heap)) { + ctx match + case EvalContext.Method(meth) => + Returns.installHandler(meth) + val res = cases(expr, thisV, klass) + val returns = Returns.popHandler(meth) + Res(res.join(returns), Heap.getHeapData()) + + case _ => + val resValue = cases(expr, thisV, klass) + Res(resValue, Heap.getHeapData()) + } + Heap.setHeap(result.heap) + result.value + end Cache /** @@ -603,24 +659,23 @@ class Objects(using Context @constructorOnly): case (a : ValueElement, b : ValueElement) => ValueSet(ListSet(a, b)) def widen(height: Int)(using Context): Value = - if height == 0 then Cold - else - a match - case Bottom => Bottom + a match + case Bottom => Bottom - case ValueSet(values) => - values.map(ref => ref.widen(height)).join + case ValueSet(values) => + values.map(ref => ref.widen(height)).join - case Fun(code, thisV, klass, env) => - Fun(code, thisV.widenRefOrCold(height), klass, env.widen(height - 1)) + case Fun(code, thisV, klass, env) => + if height == 0 then Cold + else Fun(code, thisV.widenRefOrCold(height), klass, env.widen(height - 1)) - case ref @ OfClass(klass, outer, _, args, env) => - val outer2 = outer.widen(height - 1) - val args2 = args.map(_.widen(height - 1)) - val env2 = env.widen(height - 1) - ref.widenedCopy(outer2, args2, env2) + case ref: OfClass => + if height == 0 then + Cold + else + ref.widen(height - 1) - case _ => a + case _ => a def filterType(tpe: Type)(using Context): Value = tpe match @@ -722,12 +777,7 @@ class Objects(using Context @constructorOnly): val env2 = Env.of(ddef, args.map(_.value), outerEnv) extendTrace(ddef) { given Env.Data = env2 - cache.cachedEval(ref, ddef.rhs, cacheResult = true) { expr => - Returns.installHandler(meth) - val res = cases(expr, thisV, cls) - val returns = Returns.popHandler(meth) - res.join(returns) - } + eval(ddef.rhs, thisV, cls, EvalContext.Method(meth)) } else Bottom @@ -751,7 +801,7 @@ class Objects(using Context @constructorOnly): case ddef: DefDef => if meth.name == nme.apply then given Env.Data = Env.of(ddef, args.map(_.value), env) - extendTrace(code) { eval(ddef.rhs, thisV, klass, cacheResult = true) } + extendTrace(code) { eval(ddef.rhs, thisV, klass, EvalContext.Function) } else // The methods defined in `Any` and `AnyRef` are trivial and don't affect initialization. if meth.owner == defn.AnyClass || meth.owner == defn.ObjectClass then @@ -766,7 +816,7 @@ class Objects(using Context @constructorOnly): case _ => // by-name closure given Env.Data = env - extendTrace(code) { eval(code, thisV, klass, cacheResult = true) } + extendTrace(code) { eval(code, thisV, klass, EvalContext.Function) } case ValueSet(vs) => vs.map(v => call(v, meth, args, receiver, superType)).join @@ -778,7 +828,7 @@ class Objects(using Context @constructorOnly): * @param ctor The symbol of the target method. * @param args Arguments of the constructor call (all parameter blocks flatten to a list). */ - def callConstructor(value: Value, ctor: Symbol, args: List[ArgInfo]): Contextual[Value] = log("call " + ctor.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) { + def callConstructor(value: Value, ctor: Symbol, args: List[ArgInfo]): Contextual[Unit] = log("call " + ctor.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) { value match case ref: Ref => if ctor.hasSource then @@ -789,11 +839,13 @@ class Objects(using Context @constructorOnly): given Env.Data = Env.of(ddef, argValues, Env.NoEnv) if ctor.isPrimaryConstructor then val tpl = cls.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template] - extendTrace(cls.defTree) { eval(tpl, ref, cls, cacheResult = true) } + extendTrace(cls.defTree): + init(tpl, ref, cls) else - extendTrace(ddef) { // The return values for secondary constructors can be ignored + extendTrace(ddef) { + // `return` is possible in secondary constructors Returns.installHandler(ctor) - eval(ddef.rhs, ref, cls, cacheResult = true) + eval(ddef.rhs, ref, cls, EvalContext.Other) Returns.popHandler(ctor) } else @@ -824,7 +876,7 @@ class Objects(using Context @constructorOnly): given Env.Data = Env.emptyEnv(target.owner.asInstanceOf[ClassSymbol].primaryConstructor) if target.hasSource then val rhs = target.defTree.asInstanceOf[ValDef].rhs - eval(rhs, ref, target.owner.asClass, cacheResult = true) + eval(rhs, ref, target.owner.asClass, EvalContext.LazyVal) else Bottom else if target.exists then @@ -916,8 +968,6 @@ class Objects(using Context @constructorOnly): /** * Handle new expression `new p.C(args)`. - * The actual instance might be cached without running the constructor. - * See tests/init-global/pos/cache-constructor.scala * * @param outer The value for `p`. * @param klass The symbol of the class `C`. @@ -957,8 +1007,7 @@ class Objects(using Context @constructorOnly): // klass.enclosingMethod returns its primary constructor Env.resolveEnv(klass.owner.enclosingMethod, thisV, summon[Env.Data]).getOrElse(Cold -> Env.NoEnv) - val instance = OfClass(klass, outerWidened, ctor, args.map(_.value), envWidened) - callConstructor(instance, ctor, args) + cache.cachedInstantiate(klass, outerWidened, ctor, args, envWidened) case ValueSet(values) => values.map(ref => instantiate(ref, klass, ctor, args)).join @@ -1006,7 +1055,7 @@ class Objects(using Context @constructorOnly): given Env.Data = env if sym.is(Flags.Lazy) then val rhs = sym.defTree.asInstanceOf[ValDef].rhs - eval(rhs, thisV, sym.enclosingClass.asClass, cacheResult = true) + eval(rhs, thisV, sym.enclosingClass.asClass, EvalContext.LazyVal) else // Assume forward reference check is doing a good job val value = Env.valValue(sym) @@ -1079,6 +1128,12 @@ class Objects(using Context @constructorOnly): do accessObject(classSym) + enum EvalContext: + case Method(sym: Symbol) + case Function + case LazyVal + case Other + /** Evaluate an expression with the given value for `this` in a given class `klass` * * Note that `klass` might be a super class of the object referred by `thisV`. @@ -1097,13 +1152,12 @@ class Objects(using Context @constructorOnly): * @param expr The expression to be evaluated. * @param thisV The value for `C.this` where `C` is represented by the parameter `klass`. * @param klass The enclosing class where the expression is located. - * @param cacheResult It is used to reduce the size of the cache. + * @param ctx The context where `eval` is called. */ - def eval(expr: Tree, thisV: ThisValue, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + ", regions = " + Regions.show + " in " + klass.show, printer, (_: Value).show) { - cache.cachedEval(thisV, expr, cacheResult) { expr => cases(expr, thisV, klass) } + def eval(expr: Tree, thisV: ThisValue, klass: ClassSymbol, ctx: EvalContext = EvalContext.Other): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + ", regions = " + Regions.show + " in " + klass.show, printer, (_: Value).show) { + cache.cachedEval(thisV, expr, klass, ctx) } - /** Evaluate a list of expressions */ def evalExprs(exprs: List[Tree], thisV: ThisValue, klass: ClassSymbol): Contextual[List[Value]] = exprs.map { expr => eval(expr, thisV, klass) } @@ -1173,6 +1227,7 @@ class Objects(using Context @constructorOnly): val receiver = eval(qual, thisV, klass) if ref.symbol.isConstructor then withTrace(trace2) { callConstructor(receiver, ref.symbol, args) } + Bottom else withTrace(trace2) { call(receiver, ref.symbol, args, receiver = qual.tpe, superType = NoType) } @@ -1188,6 +1243,7 @@ class Objects(using Context @constructorOnly): val receiver = withTrace(trace2) { evalType(prefix, thisV, klass) } if id.symbol.isConstructor then withTrace(trace2) { callConstructor(receiver, id.symbol, args) } + Bottom else withTrace(trace2) { call(receiver, id.symbol, args, receiver = prefix, superType = NoType) } @@ -1310,9 +1366,6 @@ class Objects(using Context @constructorOnly): case _: Import | _: Export => Bottom - case tpl: Template => - init(tpl, thisV.asInstanceOf[Ref], klass) - case _ => report.warning("[Internal error] unexpected tree: " + expr + "\n" + Trace.show, expr) Bottom @@ -1583,17 +1636,17 @@ class Objects(using Context @constructorOnly): def widenEscapedValue(value: Value, annotatedTree: Tree): Contextual[Value] = def parseAnnotation: Option[Int] = annotatedTree.tpe.getAnnotation(defn.InitWidenAnnot).flatMap: annot => - annot.argument(0).get match - case arg @ Literal(c: Constants.Constant) => - val height = c.intValue - if height < 0 then - report.warning("The argument should be positive", arg) - None - else - Some(height) - case arg => - report.warning("The argument should be a constant integer value", arg) + annot.argument(0).get match + case arg @ Literal(c: Constants.Constant) => + val height = c.intValue + if height < 0 then + report.warning("The argument should be positive", arg) None + else + Some(height) + case arg => + report.warning("The argument should be a constant integer value", arg) + None end parseAnnotation parseAnnotation match @@ -1601,9 +1654,7 @@ class Objects(using Context @constructorOnly): value.widen(i) case None => - if value.isInstanceOf[Fun] - then value.widen(2) - else value.widen(1) + value.widen(2) /** Evaluate arguments of methods and constructors */ def evalArgs(args: List[Arg], thisV: ThisValue, klass: ClassSymbol): Contextual[List[ArgInfo]] = @@ -1657,7 +1708,6 @@ class Objects(using Context @constructorOnly): tasks.append { () => printer.println("init super class " + cls.show) callConstructor(thisV, ctor, args) - () } // parents diff --git a/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala b/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala index 85b2764ff0f3..2f266243cb80 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala @@ -1181,7 +1181,10 @@ object Semantic: * @param cacheResult It is used to reduce the size of the cache. */ def eval(expr: Tree, thisV: Ref, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Value).show) { - cache.cachedEval(thisV, expr, cacheResult, default = Hot) { expr => cases(expr, thisV, klass) } + if cacheResult then + cache.cachedEval(thisV, expr, default = Hot) { cases(expr, thisV, klass) } + else + cases(expr, thisV, klass) } /** Evaluate a list of expressions */ diff --git a/compiler/src/dotty/tools/dotc/transform/init/Util.scala b/compiler/src/dotty/tools/dotc/transform/init/Util.scala index e11d0e1e21a5..d1529f9497c0 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Util.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Util.scala @@ -109,8 +109,3 @@ object Util: // A concrete class may not be instantiated if the self type is not satisfied instantiable && cls.enclosingPackageClass != defn.StdLibPatchesPackage.moduleClass - - /** Whether the class or its super class/trait contains any mutable fields? */ - def isMutable(cls: ClassSymbol)(using Context): Boolean = - cls.classInfo.decls.exists(_.is(Flags.Mutable)) || - cls.parentSyms.exists(parentCls => isMutable(parentCls.asClass)) diff --git a/compiler/test/dotc/neg-init-global-scala2-library-tasty.blacklist b/compiler/test/dotc/neg-init-global-scala2-library-tasty.blacklist index 48fe29ebc6bc..03b020db64d9 100644 --- a/compiler/test/dotc/neg-init-global-scala2-library-tasty.blacklist +++ b/compiler/test/dotc/neg-init-global-scala2-library-tasty.blacklist @@ -18,3 +18,4 @@ global-list.scala t5366.scala mutable-read7.scala t9115.scala +Color.scala \ No newline at end of file diff --git a/tests/init-global/warn/of-class-unsound.scala b/tests/init-global/warn/of-class-unsound.scala new file mode 100644 index 000000000000..108663187735 --- /dev/null +++ b/tests/init-global/warn/of-class-unsound.scala @@ -0,0 +1,25 @@ +class Box[T](val value: T) + +abstract class Base[T]: + def update(n: T): Unit + +class A[T](var a: T) extends Base[T]: + def update(n: T) = + a = n + +class B[T](var b: T) extends Base[T]: + def update(n: T) = + O.x // warn + b = n + +object O: + val m: Int = 10 + f(if m > 5 then Box(A(3)) else Box(B(4))) + + val x: Int = 10 + + def f(a: Box[Base[Int]]): Unit = + h(a.value) + + def h(a: Base[Int]): Unit = + a.update(10)