Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -315,4 +315,15 @@ object Annotations {
case Some(Constant(msg: String)) => Some(msg)
case _ => Some("")
}

object JavaRecordFieldsAnnotation {
def unapply(a: Annotation)(using Context): Option[List[String]] =
if a.symbol ne defn.JavaRecordFieldsAnnot then None
else
a.tree match
case Apply(_, List(Typed(SeqLiteral(args, _), _))) =>
val fields = args.collect { case Literal(Constant(s: String)) => s }
Some(fields)
case _ => None
}
}
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,7 @@ class Definitions {
@tu lazy val PublicInBinaryAnnot: ClassSymbol = requiredClass("scala.annotation.publicInBinary")
@tu lazy val WitnessNamesAnnot: ClassSymbol = requiredClass("scala.annotation.internal.WitnessNames")
@tu lazy val StableNullAnnot: ClassSymbol = requiredClass("scala.annotation.stableNull")
@tu lazy val JavaRecordFieldsAnnot: ClassSymbol = requiredClass("scala.annotation.internal.JavaRecordFields")

@tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable")

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ object StdNames {
final val MethodParametersATTR: N = "MethodParameters"
final val LineNumberTableATTR: N = "LineNumberTable"
final val LocalVariableTableATTR: N = "LocalVariableTable"
final val RecordATTR: N = "Record"
final val RuntimeVisibleAnnotationATTR: N = "RuntimeVisibleAnnotations" // RetentionPolicy.RUNTIME
final val RuntimeInvisibleAnnotationATTR: N = "RuntimeInvisibleAnnotations" // RetentionPolicy.CLASS
final val RuntimeParamAnnotationATTR: N = "RuntimeVisibleParameterAnnotations" // RetentionPolicy.RUNTIME (annotations on parameters)
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,8 @@ class SymUtils:
|| isDefaultArgumentOfCheckedMethod
|| (!self.is(Package) && checkOwner(self.owner))

def isJavaRecord(using Context) = self.is(JavaDefined) && self.derivesFrom(defn.JavaRecordClass)

/** The declared self type of this class, as seen from `site`, stripping
* all refinements for opaque types.
*/
Expand Down
55 changes: 55 additions & 0 deletions compiler/src/dotty/tools/dotc/core/classfile/ClassfileParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,35 @@ class ClassfileParser(
}
}

class RecordUnapplyCompleter() extends LazyType {
override def complete(denot: SymDenotation)(using Context): Unit =
def methType(t: Type) = MethodType(List(nme.x_0), List(t), t)

val unapplyMethodType =
val recType = classRoot.typeRef
val tparams = classRoot.typeParams
if tparams.length > 0 then
PolyType(tparams.map(_.name))(
pt => tparams.map(_.info.subst(tparams, pt.paramRefs).bounds),
pt => methType(AppliedType(recType, pt.paramRefs))
)
else methType(recType)

// synthetic unapply generated here won't be invalidated by `invalidateIfClashingSynthetic`, so we handle that immediately
val clashes = denot.owner.unforcedDecls.lookupAll(nme.unapply)
if clashes.exists(c => c != denot.symbol && c.info.matches(unapplyMethodType)) then
denot.info = NoType
else
denot.info = unapplyMethodType
val sym = denot.symbol
val ddef = DefDef(sym.asTerm, _.last.last)
.withAddedFlags(Flags.JavaDefined | Flags.SyntheticMethod | Flags.Inline).withSpan(Span(0))
sym.defTree = ddef

// typed trees generated here are not a subject to typer's inline logic, so we do that manually
inlines.PrepareInlineable.registerInlineInfo(sym, ddef.rhs)
}

def constantTagToType(tag: Int)(using Context): Type =
(tag: @switch) match {
case BYTE_TAG => defn.ByteType
Expand Down Expand Up @@ -990,6 +1019,32 @@ class ClassfileParser(
report.log(s"$sym in ${sym.owner} is a java 8+ default method.")
}

case tpnme.RecordATTR =>
val components = List.fill(in.nextChar):
val name = pool.getName(in.nextChar).value
val _ = in.nextChar
skipAttributes()
name

classRoot.addAnnotation(
Annotation(
defn.JavaRecordFieldsAnnot,
Typed(
SeqLiteral(components.map(field => Literal(Constant(field))), TypeTree(defn.StringType)),
TypeTree(defn.RepeatedParamType.appliedTo(defn.StringType))
),
NoSpan
)
)
val completer = RecordUnapplyCompleter()
val member = newSymbol(
moduleRoot.symbol,
nme.unapply,
Flags.JavaDefined | Flags.SyntheticMethod | Flags.Inline,
completer,
)
staticScope.enter(member)

case _ =>
}
in.bp = end
Expand Down
25 changes: 24 additions & 1 deletion compiler/src/dotty/tools/dotc/parsing/JavaParsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -900,8 +900,31 @@ object JavaParsers {
needsDummyConstr = true
)
).withMods(mods.withFlags(Flags.JavaDefined | Flags.Final))
.withAddedAnnotation(
New(
ref(defn.JavaRecordFieldsAnnot),
header.map(field => Literal(Constant(field.name.toString))) :: Nil,
).withSpan(Span(start))
)
}
addCompanionObject(statics, recordTypeDef)

val unapplyDef = {
val tparams2 = tparams.map(td => TypeDef(td.name, td.rhs).withMods(Modifiers(Flags.Param)))

val selfTpt = if tparams2.isEmpty then Ident(name) else
AppliedTypeTree(Ident(name), tparams2.map(tp => Ident(tp.name)))
val param = ValDef(nme.x_0, selfTpt, EmptyTree)
.withMods(Modifiers(Flags.JavaDefined | Flags.SyntheticParam))

DefDef(
nme.unapply,
joinParams(tparams2, List(List(param))),
selfTpt,
Ident(nme.x_0)
).withMods(Modifiers(Flags.JavaDefined | Flags.SyntheticMethod | Flags.Inline))
}

addCompanionObject(unapplyDef :: statics, recordTypeDef)
end recordDecl

def interfaceDecl(start: Offset, mods: Modifiers): List[Tree] = {
Expand Down
8 changes: 6 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,7 @@ object PatternMatcher {
(0 until unappResult.denot.info.tupleElementTypes.getOrElse(Nil).length)
.toList.map(tupleApp(_, ref(unappResult)))
matchArgsPlan(components, args, onSuccess)
else {
assert(isGetMatch(unappType))
else if (isGetMatch(unappType)) {
val argsPlan = {
val get = getOfGetMatch(ref(unappResult))
val arity = productArity(get.tpe.stripNamedTuple, unapp.srcPos)
Expand All @@ -450,6 +449,11 @@ object PatternMatcher {
}
}
TestPlan(NonEmptyTest, unappResult, unapp.span, argsPlan)
} else {
assert(unappType.classSymbol.isJavaRecord)
val selectors = javaRecordFields(unappType).map: field =>
ref(unappResult).select(field, _.paramSymss == List(Nil)).appliedToArgs(Nil)
matchArgsPlan(selectors, args, onSuccess)
}
}
}
Expand Down
13 changes: 12 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import util.chaining.tap
import collection.mutable
import config.Printers.{overload, typr, unapp}
import TypeApplications.*
import Annotations.Annotation
import Annotations.{Annotation, JavaRecordFieldsAnnotation}

import Constants.{Constant, IntTag}
import Denotations.SingleDenotation
Expand Down Expand Up @@ -185,6 +185,11 @@ object Applications {
(0 until argsNum).map(i => if (i < arity - 1) selectorTypes(i) else elemTp).toList
end seqSelectors

def javaRecordFields(tp: Type)(using Context): List[Name] =
tp.typeSymbol.getAnnotation(defn.JavaRecordFieldsAnnot) match
case Some(JavaRecordFieldsAnnotation(fields)) => fields.map(termName)
case _ => assert(false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't matter, but a little cleaner to just return Nil here in my opinion


/** A utility class that matches results of unapplys with patterns. Two queriable members:
* val argTypes: List[Type]
* def typedPatterns(qual: untpd.Tree, typer: Typer): List[Tree]
Expand Down Expand Up @@ -263,6 +268,10 @@ object Applications {
case _ => None
case _ => None

private def javaRecordTypes(tp: Type): List[Type] =
javaRecordFields(tp).map: name =>
tp.member(name).suchThat(_.paramSymss == List(Nil)).info.resultType

/** The computed argument types which will be the scutinees of the sub-patterns. */
val argTypes: List[Type] =
if unapplyName == nme.unapplySeq then
Expand All @@ -282,6 +291,8 @@ object Applications {
productSelectorTypes(unapplyResult, pos)
// this will cause a "wrong number of arguments in pattern" error later on,
// which is better than the message in `fail`.
else if unapplyResult.classSymbol.isJavaRecord then
javaRecordTypes(unapplyResult)
else fail

/** The typed pattens of this unapply */
Expand Down
7 changes: 2 additions & 5 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -928,9 +928,6 @@ class Namer { typer: Typer =>
*/
private def invalidateIfClashingSynthetic(denot: SymDenotation): Unit =

def isJavaRecord(owner: Symbol) =
owner.is(JavaDefined) && owner.derivesFrom(defn.JavaRecordClass)

def isCaseClassOrCompanion(owner: Symbol) =
owner.isClass && {
if (owner.is(Module)) owner.linkedClass.is(CaseClass)
Expand All @@ -954,8 +951,8 @@ class Namer { typer: Typer =>
&& (definesMember || inheritsConcreteMember)
)
||
// remove synthetic constructor or method of a java Record if it clashes with a non-synthetic constructor
(isJavaRecord(denot.owner)
// remove synthetic constructor, method or companion's unapply of a java Record if it clashes with a non-synthetic one
((denot.owner.isJavaRecord || denot.owner.companionClass.isJavaRecord && denot.name == nme.unapply)
&& denot.is(Method)
&& denot.owner.unforcedDecls.lookupAll(denot.name).exists(c => c != denot.symbol && c.info.matches(denot.info))
)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2996,7 +2996,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val canBeInvalidated: Boolean =
sym.is(Synthetic)
&& (desugar.isRetractableCaseClassMethodName(sym.name) ||
(sym.owner.is(JavaDefined) && sym.owner.derivesFrom(defn.JavaRecordClass) && sym.is(Method)))
(sym.owner.isJavaRecord && sym.is(Method)))
assert(canBeInvalidated)
sym.owner.info.decls.openForMutations.unlink(sym)
EmptyTree
Expand Down
12 changes: 10 additions & 2 deletions compiler/test/dotty/tools/dotc/CompilationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,21 @@ class CompilationTests {

@Test def runAll: Unit = {
implicit val testGroup: TestGroup = TestGroup("runAll")
aggregateTests(
var tests = List(
compileFilesInDir("tests/run", defaultOptions.and("-Wsafe-init")),
compileFilesInDir("tests/run-deep-subtype", allowDeepSubtypes),
compileFilesInDir("tests/run-custom-args/captures", allowDeepSubtypes.and("-language:experimental.captureChecking", "-language:experimental.separationChecking", "-source", "3.8")),
// Run tests for legacy lazy vals.
compileFilesInDir("tests/run", defaultOptions.and("-Wsafe-init", "-Ylegacy-lazy-vals", "-Ycheck-constraint-deps"), FileFilter.include(TestSources.runLazyValsAllowlist)),
).checkRuns()
)

if scala.util.Properties.isJavaAtLeast("16") then
tests ++= List(
compileFilesInDir("tests/run-java16+", defaultOptions),
compileDir("tests/run-java16+/java-records-match", defaultOptions),
)

aggregateTests(tests*).checkRuns()
}

// Generic java signatures tests ---------------------------------------------
Expand Down
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we would not need this, since we don't need that information pickled and it's only relevant for the current compilation run. We could add a property key for a tree, but that would only work for the situations where we use the JavaOutlineParser, when we read the classfiles we would not have access to a tree on which a property could be used. I see we can really only get this information out when parsing the tree, and we don't save it anywhere else, which I guess justifies the annotation here.

Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package scala.annotation.internal

import scala.annotation.StaticAnnotation

/** An annotation attached by JavaParsers/ClassfileParser to Java record class
* with a list of that record's fields. Used in pattern matching on records.
*/
final class JavaRecordFields(args: String*) extends StaticAnnotation
1 change: 1 addition & 0 deletions project/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,7 @@ object Build {
file(s"${baseDirectory.value}/src/scala/quoted/runtime/StopMacroExpansion.scala"),
file(s"${baseDirectory.value}/src/scala/compiletime/Erased.scala"),
file(s"${baseDirectory.value}/src/scala/annotation/internal/onlyCapability.scala"),
file(s"${baseDirectory.value}/src/scala/annotation/internal/JavaRecordFields.scala"),
file(s"${baseDirectory.value}/src/scala/runtime/VarArgsBuilder.scala"),
)
)
Expand Down
1 change: 1 addition & 0 deletions project/MiMaFilters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ object MiMaFilters {
ProblemFilters.exclude[MissingClassProblem]("scala.annotation.internal.CaptureChecked"),
ProblemFilters.exclude[MissingClassProblem]("scala.annotation.internal.reachCapability"),
ProblemFilters.exclude[MissingClassProblem]("scala.annotation.internal.preview"),
ProblemFilters.exclude[MissingClassProblem]("scala.annotation.internal.JavaRecordFields"),
ProblemFilters.exclude[MissingClassProblem]("scala.annotation.unchecked.uncheckedCaptures"),
ProblemFilters.exclude[MissingClassProblem]("scala.quoted.Quotes$reflectModule$ValOrDefDefMethods"),
ProblemFilters.exclude[MissingClassProblem]("scala.runtime.stdLibPatches.language$3$u002E4$"),
Expand Down
9 changes: 9 additions & 0 deletions tests/run-java16+/java-records-match.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
empty
hello
hahaha
hehehe
21
hihihi
hohoho
unapply
hejhejhejhej
1 change: 1 addition & 0 deletions tests/run-java16+/java-records-match/Rec0_1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record Rec0_1() {}
1 change: 1 addition & 0 deletions tests/run-java16+/java-records-match/Rec1_1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record Rec1_1(String s) {}
1 change: 1 addition & 0 deletions tests/run-java16+/java-records-match/Rec2_1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record Rec2_1(int x, String y) {}
1 change: 1 addition & 0 deletions tests/run-java16+/java-records-match/Rec3_1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record Rec3_1<T>(int x, T y) {}
1 change: 1 addition & 0 deletions tests/run-java16+/java-records-match/Rec4_1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record Rec4_1<T extends Comparable<Integer>>(int x, T y) {}
1 change: 1 addition & 0 deletions tests/run-java16+/java-records-match/Rec5_1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record Rec5_1<T, U extends Comparable<T>, W extends Comparable<W>>(T t, U u, W w) {}
3 changes: 3 additions & 0 deletions tests/run-java16+/java-records-match/RecUnapply_1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
public record RecUnapply_1(int i, String s) {
public static boolean unapply(RecUnapply_1 r) { return true; }
}
42 changes: 42 additions & 0 deletions tests/run-java16+/java-records-match/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
case class Foo(val value: String) extends Comparable[Integer]:
override def compareTo(other: Integer) = 0

case class Bar(val value: String) extends Comparable[Bar]:
override def compareTo(other: Bar) = 0

case class Baz(val s: String, val i: Int)

object Baz:
def unapply(b: Baz): Rec2_1 = Rec2_1(b.i + 1, b.s + "j")

@main def Test =
val r0 = Rec0_1()
r0 match { case Rec0_1() => println("empty") }

val r1 = Rec1_1("hello")
r1 match { case Rec1_1(s) => println(s) }

val r2 = Rec2_1(3, "ha")
r2 match { case Rec2_1(i, s) => println(s * i) }

// type param (no bounds)
val r3a = Rec3_1(3, "he")
r3a match { case Rec3_1(i, s) => println(s * i) }
val r3b = Rec3_1(3, 7)
r3b match { case Rec3_1(i, j) => println(i * j) }

// type param with simple bounds
val r4 = Rec4_1(3, Foo("hi"))
r4 match { case Rec4_1(i, f) => println(f.value * i) }

// type params with recursion / mutual reference
val r5 = Rec5_1(3 : Integer, Foo("h"), Bar("o"))
r5 match { case Rec5_1(i, f, b) => println((f.value + b.value) * i) }

// custom unapply
val r6 = RecUnapply_1(3, "x")
r6 match { case RecUnapply_1() => println("unapply") }

// scala class returning record from unapply
val r7 = Baz("he", 3)
r7 match { case Baz(i, s) => println(s * i) }