Skip to content
7 changes: 5 additions & 2 deletions compiler/src/dotty/tools/dotc/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,11 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
.setTyper(new Typer)
.addMode(Mode.ImplicitsEnabled)
.setTyperState(ctx.typerState.fresh(ctx.reporter))
if ctx.settings.YexplicitNulls.value && !Feature.enabledBySetting(nme.unsafeNulls) then
start = start.addMode(Mode.SafeNulls)
if ctx.settings.YexplicitNulls.value then
if !Feature.enabledBySetting(nme.unsafeNulls) then
start = start.addMode(Mode.SafeNulls)
if Feature.enabledBySetting(Feature.unsafeJavaReturn) then
start = start.addMode(Mode.UnsafeJavaReturn)
ctx.initialize()(using start) // re-initialize the base context with start

// `this` must be unchecked for safe initialization because by being passed to setRun during
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ object Feature:
val symbolLiterals = deprecated("symbolLiterals")
val fewerBraces = experimental("fewerBraces")
val saferExceptions = experimental("saferExceptions")
val unsafeJavaReturn = experimental("unsafeJavaReturn")

/** Is `feature` enabled by by a command-line setting? The enabling setting is
*
Expand Down
20 changes: 14 additions & 6 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import scala.io.Codec
import collection.mutable
import printing._
import config.{JavaPlatform, SJSPlatform, Platform, ScalaSettings}
import config.Feature
import classfile.ReusableDataReader
import StdNames.nme

Expand Down Expand Up @@ -642,12 +643,19 @@ object Contexts {
def setProfiler(profiler: Profiler): this.type = updateStore(profilerLoc, profiler)
def setNotNullInfos(notNullInfos: List[NotNullInfo]): this.type = updateStore(notNullInfosLoc, notNullInfos)
def setImportInfo(importInfo: ImportInfo): this.type =
importInfo.mentionsFeature(nme.unsafeNulls) match
case Some(true) =>
setMode(this.mode &~ Mode.SafeNulls)
case Some(false) if ctx.settings.YexplicitNulls.value =>
setMode(this.mode | Mode.SafeNulls)
case _ =>
if ctx.settings.YexplicitNulls.value then
importInfo.mentionsFeature(nme.unsafeNulls) match
case Some(true) =>
setMode(this.mode &~ Mode.SafeNulls)
case Some(false) =>
setMode(this.mode | Mode.SafeNulls)
case _ =>
importInfo.mentionsFeature(Feature.unsafeJavaReturn) match
case Some(true) =>
setMode(this.mode | Mode.UnsafeJavaReturn)
case Some(false) =>
setMode(this.mode &~ Mode.UnsafeJavaReturn)
case _ =>
updateStore(importInfoLoc, importInfo)
def setTypeAssigner(typeAssigner: TypeAssigner): this.type = updateStore(typeAssignerLoc, typeAssigner)

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,7 @@ class Definitions {
@tu lazy val FunctionalInterfaceAnnot: ClassSymbol = requiredClass("java.lang.FunctionalInterface")
@tu lazy val TargetNameAnnot: ClassSymbol = requiredClass("scala.annotation.targetName")
@tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs")
@tu lazy val CanEqualNullAnnot: ClassSymbol = requiredClass("scala.annotation.CanEqualNull")

@tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable")

Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Mode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,6 @@ object Mode {
* Type `Null` becomes a subtype of non-primitive value types in TypeComparer.
*/
val RelaxedOverriding: Mode = newMode(30, "RelaxedOverriding")

val UnsafeJavaReturn: Mode = newMode(31, "UnsafeJavaReturn")
}
61 changes: 56 additions & 5 deletions compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package dotty.tools.dotc
package core

import Annotations._
import Contexts._
import Flags._
import Symbols._
import Types._
import transform.SymUtils._

/** Defines operations on nullable types and tree. */
object NullOpsDecorator:
Expand Down Expand Up @@ -42,6 +46,24 @@ object NullOpsDecorator:
if ctx.explicitNulls then strip(self) else self
}

/** Strips `|Null` from the return type of a Java method,
* replacing it with a `@CanEqualNull` annotation
*/
def replaceOrNull(using Context): Type =
// Since this method should only be called on types from Java,
// handling these cases is enough.
def recur(tp: Type): Type = tp match
case tp @ OrType(lhs, rhs) if rhs.isNullType =>
AnnotatedType(recur(lhs), Annotation(defn.CanEqualNullAnnot))
case tp: AndOrType =>
tp.derivedAndOrType(recur(tp.tp1), recur(tp.tp2))
case tp @ AppliedType(tycon, targs) =>
tp.derivedAppliedType(tycon, targs.map(recur))
case mptp: MethodOrPoly =>
mptp.derivedLambdaType(resType = recur(mptp.resType))
case _ => tp
if ctx.explicitNulls then recur(self) else self

/** Is self (after widening and dealiasing) a type of the form `T | Null`? */
def isNullableUnion(using Context): Boolean = {
val stripped = self.stripNull
Expand All @@ -51,10 +73,39 @@ object NullOpsDecorator:

import ast.tpd._

extension (self: Tree)
extension (tree: Tree)

// cast the type of the tree to a non-nullable type
def castToNonNullable(using Context): Tree = self.typeOpt match {
case OrNull(tp) => self.cast(tp)
case _ => self
}
def castToNonNullable(using Context): Tree = tree.typeOpt match
case OrNull(tp) => tree.cast(tp)
case _ => tree

def tryToCastToCanEqualNull(using Context): Tree =
// return the tree directly if not at Typer phase
if !(ctx.explicitNulls && ctx.phase.isTyper) then return tree

val sym = tree.symbol
val tp = tree.tpe

if !ctx.mode.is(Mode.UnsafeJavaReturn)
|| !sym.is(JavaDefined)
|| sym.isNoValue
|| !sym.isTerm
|| tp.isError then
return tree

tree match
case _: Apply if sym.is(Method) =>
val tp2 = tp.replaceOrNull
if tp ne tp2 then
tree.cast(tp2)
else tree
case _: Select | _: Ident if !sym.is(Method) =>
val tpw = tp.widen
val tp2 = tpw.replaceOrNull
if tpw ne tp2 then
tree.cast(tp2)
else tree
case _ => tree

end NullOpsDecorator
8 changes: 6 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -756,8 +756,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
}
compareTypeBounds
case tp2: AnnotatedType if tp2.isRefining =>
(tp1.derivesAnnotWith(tp2.annot.sameAnnotation) || tp1.isBottomType) &&
recur(tp1, tp2.parent)
// `CanEqualNull` is a special refining annotation.
// An annotated type is equivalent to the original type.
(tp1.derivesAnnotWith(tp2.annot.sameAnnotation)
|| tp2.annot.matches(defn.CanEqualNullAnnot)
|| tp1.isBottomType)
&& recur(tp1, tp2.parent)
case ClassInfo(pre2, cls2, _, _, _) =>
def compareClassInfo = tp1 match {
case ClassInfo(pre1, cls1, _, _, _) =>
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import reporting._
import transform.TypeUtils._
import transform.SymUtils._
import Nullables._
import NullOpsDecorator._
import config.Feature

import collection.mutable
Expand Down Expand Up @@ -908,7 +909,7 @@ trait Applications extends Compatibility {
def simpleApply(fun1: Tree, proto: FunProto)(using Context): Tree =
methPart(fun1).tpe match {
case funRef: TermRef =>
val app = ApplyTo(tree, fun1, funRef, proto, pt)
val app = ApplyTo(tree, fun1, funRef, proto, pt).tryToCastToCanEqualNull
convertNewGenericArray(
widenEnumCase(
postProcessByNameArgs(funRef, app).computeNullable(),
Expand Down
19 changes: 16 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
}

/** Is an `CanEqual[cls1, cls2]` instance assumed for predefined classes `cls1`, cls2`? */
def canComparePredefinedClasses(cls1: ClassSymbol, cls2: ClassSymbol): Boolean =
def canComparePredefinedClasses(cls1: ClassSymbol, cls2: ClassSymbol)(using Context): Boolean =

def cmpWithBoxed(cls1: ClassSymbol, cls2: ClassSymbol) =
cls2 == defn.NothingClass
Expand All @@ -164,15 +164,17 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
cmpWithBoxed(cls2, cls1)
else if ctx.mode.is(Mode.SafeNulls) then
// If explicit nulls is enabled, and unsafeNulls is not enabled,
// and the types don't have `@CanEqualNull` annotation,
// we want to disallow comparison between Object and Null.
// If we have to check whether a variable with a non-nullable type has null value
// (for example, a NotNull java method returns null for some reasons),
// we can still cast it to a nullable type then compare its value.
// we can still use `eq/ne null` or cast it to a nullable type then compare its value.
//
// Example:
// val x: String = null.asInstanceOf[String]
// if (x == null) {} // error: x is non-nullable
// if (x.asInstanceOf[String|Null] == null) {} // ok
// if (x eq null) {} // ok
cls1 == defn.NullClass && cls1 == cls2
else if cls1 == defn.NullClass then
cls1 == cls2 || cls2.derivesFrom(defn.ObjectClass)
Expand All @@ -187,9 +189,20 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
* interpret.
*/
def canComparePredefined(tp1: Type, tp2: Type) =
// In explicit nulls, when one of type has `@CanEqualNull` annotation,
// we use unsafe nulls semantic to check, which allows reference types
// to be compared with `Null`.
// Example:
// val s1: String = ???
// s1 == null // error
// val s2: String @CanEqualNull = ???
// s2 == null // ok
val checkCtx = if ctx.explicitNulls
&& (tp1.hasAnnotation(defn.CanEqualNullAnnot) || tp2.hasAnnotation(defn.CanEqualNullAnnot))
then ctx.retractMode(Mode.SafeNulls) else ctx
tp1.classSymbols.exists(cls1 =>
tp2.classSymbols.exists(cls2 =>
canComparePredefinedClasses(cls1, cls2)))
canComparePredefinedClasses(cls1, cls2)(using checkCtx)))

formal.argTypes match
case args @ (arg1 :: arg2 :: Nil) =>
Expand Down
8 changes: 5 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
ref(ownType).withSpan(tree.span)
case _ =>
tree.withType(ownType)
val tree2 = toNotNullTermRef(tree1, pt)
val tree2 = toNotNullTermRef(tree1, pt).tryToCastToCanEqualNull
checkLegalValue(tree2, pt)
tree2

Expand Down Expand Up @@ -646,7 +646,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

def typeSelectOnTerm(using Context): Tree =
val qual = typedExpr(tree.qualifier, shallowSelectionProto(tree.name, pt, this))
typedSelect(tree, pt, qual).withSpan(tree.span).computeNullable()
val sel = typedSelect(tree, pt, qual).withSpan(tree.span).computeNullable()
if pt != AssignProto then sel.tryToCastToCanEqualNull else sel

def javaSelectOnType(qual: Tree)(using Context) =
// semantic name conversion for `O$` in java code
Expand Down Expand Up @@ -3679,7 +3680,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}
simplify(typed(etaExpand(tree, wtp, arity), pt), pt, locked)
else if (wtp.paramInfos.isEmpty && isAutoApplied(tree.symbol))
readaptSimplified(tpd.Apply(tree, Nil))
val app = tpd.Apply(tree, Nil).tryToCastToCanEqualNull
readaptSimplified(app)
else if (wtp.isImplicitMethod)
err.typeMismatch(tree, pt)
else
Expand Down
1 change: 1 addition & 0 deletions compiler/test/dotty/tools/dotc/CompilationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ class CompilationTests {
compileFilesInDir("tests/explicit-nulls/pos-separate", explicitNullsOptions),
compileFilesInDir("tests/explicit-nulls/pos-patmat", explicitNullsOptions and "-Xfatal-warnings"),
compileFilesInDir("tests/explicit-nulls/unsafe-common", explicitNullsOptions and "-language:unsafeNulls"),
compileFilesInDir("tests/explicit-nulls/unsafe-java", explicitNullsOptions),
compileFile("tests/explicit-nulls/pos-special/i14682.scala", explicitNullsOptions and "-Ysafe-init"),
compileFile("tests/explicit-nulls/pos-special/i14947.scala", explicitNullsOptions and "-Ytest-pickler" and "-Xprint-types"),
)
Expand Down
22 changes: 22 additions & 0 deletions library/src/scala/annotation/CanEqualNull.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package scala.annotation

/** An annotation makes reference types comparable to `null` in explicit nulls.
* `CanEqualNull` is a special refining annotation. An annotated type is equivalent to the original type.
*
* For example:
* ```scala
* val s1: String = ???
* s1 == null // error
* val s2: String @CanEqualNull = ???
* s2 == null // ok
*
* // String =:= String @CanEqualNull
* val s3: String = s2
* val s4: String @CanEqualNull = s1
*
* val ss: Array[String @CanEqualNull] = ???
* ss.map(_ == null)
* ```
*/
@experimental
final class CanEqualNull extends RefiningAnnotation
5 changes: 5 additions & 0 deletions library/src/scala/runtime/stdLibPatches/language.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ object language:
@compileTimeOnly("`saferExceptions` can only be used at compile time in import statements")
object saferExceptions

/** Experimental support for unsafe Java return in explicit nulls
*/
@compileTimeOnly("`unsafeJavaReturn` can only be used at compile time in import statements")
object unsafeJavaReturn

end experimental

/** The deprecated object contains features that are no longer officially suypported in Scala.
Expand Down
4 changes: 2 additions & 2 deletions project/MiMaFilters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ object MiMaFilters {
ProblemFilters.exclude[MissingClassProblem]("scala.runtime.TupleMirror"),
ProblemFilters.exclude[MissingTypesProblem]("scala.Tuple$package$EmptyTuple$"), // we made the empty tuple a case object
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.runtime.Scala3RunTime.nnFail"),
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.runtime.Scala3RunTime.nnFail"),
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.runtime.LazyVals.getOffsetStatic"), // Added for #14780
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.runtime.LazyVals.getOffsetStatic"), // Added for #14780
ProblemFilters.exclude[MissingFieldProblem]("scala.runtime.stdLibPatches.language.3.2-migration"),
ProblemFilters.exclude[MissingFieldProblem]("scala.runtime.stdLibPatches.language.3.2"),
Expand All @@ -28,6 +26,8 @@ object MiMaFilters {
ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolMethods.typeRef"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolMethods.termRef"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#TypeTreeModule.ref"),
ProblemFilters.exclude[MissingFieldProblem]("scala.runtime.stdLibPatches.language#experimental.unsafeJavaReturn"),
ProblemFilters.exclude[MissingClassProblem]("scala.runtime.stdLibPatches.language$experimental$unsafeJavaReturn$"),

ProblemFilters.exclude[MissingClassProblem]("scala.annotation.since"),
)
Expand Down
14 changes: 14 additions & 0 deletions tests/explicit-nulls/unsafe-java/JavaStatic.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import language.experimental.unsafeJavaReturn

import java.math.MathContext, MathContext._

val x: MathContext = DECIMAL32
val y: MathContext = MathContext.DECIMAL32

import java.io.File

val s: String = File.separator
import java.time.ZoneId

val zids: java.util.Set[String] = ZoneId.getAvailableZoneIds
val zarr: Array[String] = ZoneId.getAvailableZoneIds.toArray(Array.empty[String | Null])
11 changes: 11 additions & 0 deletions tests/explicit-nulls/unsafe-java/UnaryCall.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import scala.language.experimental.unsafeJavaReturn

import java.lang.reflect.Method

def getMethods(f: String): List[Method] =
val clazz = Class.forName(f)
val methods = clazz.getMethods
if methods == null then List()
else methods.toList

def getClass(o: AnyRef): Class[?] = o.getClass
7 changes: 7 additions & 0 deletions tests/explicit-nulls/unsafe-java/java-chain/J.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class J1 {
J2 getJ2() { return new J2(); }
}

class J2 {
J1 getJ1() { return new J1(); }
}
6 changes: 6 additions & 0 deletions tests/explicit-nulls/unsafe-java/java-chain/S.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import scala.language.experimental.unsafeJavaReturn

def f = {
val j: J2 = new J2()
j.getJ1().getJ2().getJ1().getJ2().getJ1().getJ2()
}
Loading