diff --git a/compiler/src/dotty/tools/dotc/core/Annotations.scala b/compiler/src/dotty/tools/dotc/core/Annotations.scala index 1615679a036e..48eb6e0d82f8 100644 --- a/compiler/src/dotty/tools/dotc/core/Annotations.scala +++ b/compiler/src/dotty/tools/dotc/core/Annotations.scala @@ -315,4 +315,15 @@ object Annotations { case Some(Constant(msg: String)) => Some(msg) case _ => Some("") } + + object JavaRecordFieldsAnnotation { + def unapply(a: Annotation)(using Context): Option[List[String]] = + if a.symbol ne defn.JavaRecordFieldsAnnot then None + else + a.tree match + case Apply(_, List(Typed(SeqLiteral(args, _), _))) => + val fields = args.collect { case Literal(Constant(s: String)) => s } + Some(fields) + case _ => None + } } diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index dbe1602e2d82..5ef3dbbb9846 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1112,6 +1112,7 @@ class Definitions { @tu lazy val PublicInBinaryAnnot: ClassSymbol = requiredClass("scala.annotation.publicInBinary") @tu lazy val WitnessNamesAnnot: ClassSymbol = requiredClass("scala.annotation.internal.WitnessNames") @tu lazy val StableNullAnnot: ClassSymbol = requiredClass("scala.annotation.stableNull") + @tu lazy val JavaRecordFieldsAnnot: ClassSymbol = requiredClass("scala.annotation.internal.JavaRecordFields") @tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable") diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 323c59a5711d..71eb3486450a 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -271,6 +271,7 @@ object StdNames { final val MethodParametersATTR: N = "MethodParameters" final val LineNumberTableATTR: N = "LineNumberTable" final val LocalVariableTableATTR: N = "LocalVariableTable" + final val RecordATTR: N = "Record" final val RuntimeVisibleAnnotationATTR: N = "RuntimeVisibleAnnotations" // RetentionPolicy.RUNTIME final val RuntimeInvisibleAnnotationATTR: N = "RuntimeInvisibleAnnotations" // RetentionPolicy.CLASS final val RuntimeParamAnnotationATTR: N = "RuntimeVisibleParameterAnnotations" // RetentionPolicy.RUNTIME (annotations on parameters) diff --git a/compiler/src/dotty/tools/dotc/core/SymUtils.scala b/compiler/src/dotty/tools/dotc/core/SymUtils.scala index c7733acbfdec..a6694366f516 100644 --- a/compiler/src/dotty/tools/dotc/core/SymUtils.scala +++ b/compiler/src/dotty/tools/dotc/core/SymUtils.scala @@ -403,6 +403,8 @@ class SymUtils: || isDefaultArgumentOfCheckedMethod || (!self.is(Package) && checkOwner(self.owner)) + def isJavaRecord(using Context) = self.is(JavaDefined) && self.derivesFrom(defn.JavaRecordClass) + /** The declared self type of this class, as seen from `site`, stripping * all refinements for opaque types. */ diff --git a/compiler/src/dotty/tools/dotc/core/classfile/ClassfileParser.scala b/compiler/src/dotty/tools/dotc/core/classfile/ClassfileParser.scala index ea8a74d18192..af1204ffd2ff 100644 --- a/compiler/src/dotty/tools/dotc/core/classfile/ClassfileParser.scala +++ b/compiler/src/dotty/tools/dotc/core/classfile/ClassfileParser.scala @@ -535,6 +535,35 @@ class ClassfileParser( } } + class RecordUnapplyCompleter() extends LazyType { + override def complete(denot: SymDenotation)(using Context): Unit = + def methType(t: Type) = MethodType(List(nme.x_0), List(t), t) + + val unapplyMethodType = + val recType = classRoot.typeRef + val tparams = classRoot.typeParams + if tparams.length > 0 then + PolyType(tparams.map(_.name))( + pt => tparams.map(_.info.subst(tparams, pt.paramRefs).bounds), + pt => methType(AppliedType(recType, pt.paramRefs)) + ) + else methType(recType) + + // synthetic unapply generated here won't be invalidated by `invalidateIfClashingSynthetic`, so we handle that immediately + val clashes = denot.owner.unforcedDecls.lookupAll(nme.unapply) + if clashes.exists(c => c != denot.symbol && c.info.matches(unapplyMethodType)) then + denot.info = NoType + else + denot.info = unapplyMethodType + val sym = denot.symbol + val ddef = DefDef(sym.asTerm, _.last.last) + .withAddedFlags(Flags.JavaDefined | Flags.SyntheticMethod | Flags.Inline).withSpan(Span(0)) + sym.defTree = ddef + + // typed trees generated here are not a subject to typer's inline logic, so we do that manually + inlines.PrepareInlineable.registerInlineInfo(sym, ddef.rhs) + } + def constantTagToType(tag: Int)(using Context): Type = (tag: @switch) match { case BYTE_TAG => defn.ByteType @@ -990,6 +1019,32 @@ class ClassfileParser( report.log(s"$sym in ${sym.owner} is a java 8+ default method.") } + case tpnme.RecordATTR => + val components = List.fill(in.nextChar): + val name = pool.getName(in.nextChar).value + val _ = in.nextChar + skipAttributes() + name + + classRoot.addAnnotation( + Annotation( + defn.JavaRecordFieldsAnnot, + Typed( + SeqLiteral(components.map(field => Literal(Constant(field))), TypeTree(defn.StringType)), + TypeTree(defn.RepeatedParamType.appliedTo(defn.StringType)) + ), + NoSpan + ) + ) + val completer = RecordUnapplyCompleter() + val member = newSymbol( + moduleRoot.symbol, + nme.unapply, + Flags.JavaDefined | Flags.SyntheticMethod | Flags.Inline, + completer, + ) + staticScope.enter(member) + case _ => } in.bp = end diff --git a/compiler/src/dotty/tools/dotc/parsing/JavaParsers.scala b/compiler/src/dotty/tools/dotc/parsing/JavaParsers.scala index 721c7a36acda..5f2150b99c64 100644 --- a/compiler/src/dotty/tools/dotc/parsing/JavaParsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/JavaParsers.scala @@ -900,8 +900,31 @@ object JavaParsers { needsDummyConstr = true ) ).withMods(mods.withFlags(Flags.JavaDefined | Flags.Final)) + .withAddedAnnotation( + New( + ref(defn.JavaRecordFieldsAnnot), + header.map(field => Literal(Constant(field.name.toString))) :: Nil, + ).withSpan(Span(start)) + ) } - addCompanionObject(statics, recordTypeDef) + + val unapplyDef = { + val tparams2 = tparams.map(td => TypeDef(td.name, td.rhs).withMods(Modifiers(Flags.Param))) + + val selfTpt = if tparams2.isEmpty then Ident(name) else + AppliedTypeTree(Ident(name), tparams2.map(tp => Ident(tp.name))) + val param = ValDef(nme.x_0, selfTpt, EmptyTree) + .withMods(Modifiers(Flags.JavaDefined | Flags.SyntheticParam)) + + DefDef( + nme.unapply, + joinParams(tparams2, List(List(param))), + selfTpt, + Ident(nme.x_0) + ).withMods(Modifiers(Flags.JavaDefined | Flags.SyntheticMethod | Flags.Inline)) + } + + addCompanionObject(unapplyDef :: statics, recordTypeDef) end recordDecl def interfaceDecl(start: Offset, mods: Modifiers): List[Tree] = { diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index 273879f3c3cb..5a5cbb3c7fd3 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -422,8 +422,7 @@ object PatternMatcher { (0 until unappResult.denot.info.tupleElementTypes.getOrElse(Nil).length) .toList.map(tupleApp(_, ref(unappResult))) matchArgsPlan(components, args, onSuccess) - else { - assert(isGetMatch(unappType)) + else if (isGetMatch(unappType)) { val argsPlan = { val get = getOfGetMatch(ref(unappResult)) val arity = productArity(get.tpe.stripNamedTuple, unapp.srcPos) @@ -450,6 +449,11 @@ object PatternMatcher { } } TestPlan(NonEmptyTest, unappResult, unapp.span, argsPlan) + } else { + assert(unappType.classSymbol.isJavaRecord) + val selectors = javaRecordFields(unappType).map: field => + ref(unappResult).select(field, _.paramSymss == List(Nil)).appliedToArgs(Nil) + matchArgsPlan(selectors, args, onSuccess) } } } diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 5ecf1601db8a..80b989e64c79 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -30,7 +30,7 @@ import util.chaining.tap import collection.mutable import config.Printers.{overload, typr, unapp} import TypeApplications.* -import Annotations.Annotation +import Annotations.{Annotation, JavaRecordFieldsAnnotation} import Constants.{Constant, IntTag} import Denotations.SingleDenotation @@ -185,6 +185,11 @@ object Applications { (0 until argsNum).map(i => if (i < arity - 1) selectorTypes(i) else elemTp).toList end seqSelectors + def javaRecordFields(tp: Type)(using Context): List[Name] = + tp.typeSymbol.getAnnotation(defn.JavaRecordFieldsAnnot) match + case Some(JavaRecordFieldsAnnotation(fields)) => fields.map(termName) + case _ => assert(false) + /** A utility class that matches results of unapplys with patterns. Two queriable members: * val argTypes: List[Type] * def typedPatterns(qual: untpd.Tree, typer: Typer): List[Tree] @@ -263,6 +268,10 @@ object Applications { case _ => None case _ => None + private def javaRecordTypes(tp: Type): List[Type] = + javaRecordFields(tp).map: name => + tp.member(name).suchThat(_.paramSymss == List(Nil)).info.resultType + /** The computed argument types which will be the scutinees of the sub-patterns. */ val argTypes: List[Type] = if unapplyName == nme.unapplySeq then @@ -282,6 +291,8 @@ object Applications { productSelectorTypes(unapplyResult, pos) // this will cause a "wrong number of arguments in pattern" error later on, // which is better than the message in `fail`. + else if unapplyResult.classSymbol.isJavaRecord then + javaRecordTypes(unapplyResult) else fail /** The typed pattens of this unapply */ diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 65c9fa2f5ddc..0dd636b61d7c 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -928,9 +928,6 @@ class Namer { typer: Typer => */ private def invalidateIfClashingSynthetic(denot: SymDenotation): Unit = - def isJavaRecord(owner: Symbol) = - owner.is(JavaDefined) && owner.derivesFrom(defn.JavaRecordClass) - def isCaseClassOrCompanion(owner: Symbol) = owner.isClass && { if (owner.is(Module)) owner.linkedClass.is(CaseClass) @@ -954,8 +951,8 @@ class Namer { typer: Typer => && (definesMember || inheritsConcreteMember) ) || - // remove synthetic constructor or method of a java Record if it clashes with a non-synthetic constructor - (isJavaRecord(denot.owner) + // remove synthetic constructor, method or companion's unapply of a java Record if it clashes with a non-synthetic one + ((denot.owner.isJavaRecord || (denot.owner.companionClass.isJavaRecord && denot.name == nme.unapply)) && denot.is(Method) && denot.owner.unforcedDecls.lookupAll(denot.name).exists(c => c != denot.symbol && c.info.matches(denot.info)) ) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 28b05ecd5001..0b558ae3456a 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2996,7 +2996,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val canBeInvalidated: Boolean = sym.is(Synthetic) && (desugar.isRetractableCaseClassMethodName(sym.name) || - (sym.owner.is(JavaDefined) && sym.owner.derivesFrom(defn.JavaRecordClass) && sym.is(Method))) + (sym.owner.isJavaRecord && sym.is(Method))) assert(canBeInvalidated) sym.owner.info.decls.openForMutations.unlink(sym) EmptyTree diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala index baf1b4d66306..1046d8dd67ff 100644 --- a/compiler/test/dotty/tools/dotc/CompilationTests.scala +++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala @@ -169,13 +169,23 @@ class CompilationTests { @Test def runAll: Unit = { implicit val testGroup: TestGroup = TestGroup("runAll") - aggregateTests( + var tests = List( compileFilesInDir("tests/run", defaultOptions.and("-Wsafe-init")), compileFilesInDir("tests/run-deep-subtype", allowDeepSubtypes), compileFilesInDir("tests/run-custom-args/captures", allowDeepSubtypes.and("-language:experimental.captureChecking", "-language:experimental.separationChecking", "-source", "3.8")), // Run tests for legacy lazy vals. compileFilesInDir("tests/run", defaultOptions.and("-Wsafe-init", "-Ylegacy-lazy-vals", "-Ycheck-constraint-deps"), FileFilter.include(TestSources.runLazyValsAllowlist)), - ).checkRuns() + ) + + if scala.util.Properties.isJavaAtLeast("16") then + tests ++= List( + // for separate compilation + compileFilesInDir("tests/run-java16+", defaultOptions), + // for joint compilation + compileDir("tests/run-java16+/java-records-match", defaultOptions), + ) + + aggregateTests(tests*).checkRuns() } // Generic java signatures tests --------------------------------------------- diff --git a/library/src/scala/annotation/internal/JavaRecordFields.scala b/library/src/scala/annotation/internal/JavaRecordFields.scala new file mode 100644 index 000000000000..93a126db31de --- /dev/null +++ b/library/src/scala/annotation/internal/JavaRecordFields.scala @@ -0,0 +1,8 @@ +package scala.annotation.internal + +import scala.annotation.StaticAnnotation + +/** An annotation attached by JavaParsers/ClassfileParser to Java record class + * with a list of that record's fields. Used in pattern matching on records. + */ +final class JavaRecordFields(args: String*) extends StaticAnnotation diff --git a/project/Build.scala b/project/Build.scala index 9c4fbafbb330..f8c27e70379d 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -1315,6 +1315,7 @@ object Build { file(s"${baseDirectory.value}/src/scala/quoted/runtime/StopMacroExpansion.scala"), file(s"${baseDirectory.value}/src/scala/compiletime/Erased.scala"), file(s"${baseDirectory.value}/src/scala/annotation/internal/onlyCapability.scala"), + file(s"${baseDirectory.value}/src/scala/annotation/internal/JavaRecordFields.scala"), file(s"${baseDirectory.value}/src/scala/runtime/VarArgsBuilder.scala"), ) ) diff --git a/project/MiMaFilters.scala b/project/MiMaFilters.scala index e2ca983081d2..225d505a962c 100644 --- a/project/MiMaFilters.scala +++ b/project/MiMaFilters.scala @@ -19,6 +19,7 @@ object MiMaFilters { ProblemFilters.exclude[MissingClassProblem]("scala.Conversion$"), ProblemFilters.exclude[MissingClassProblem]("scala.annotation.internal.RuntimeChecked"), ProblemFilters.exclude[MissingClassProblem]("scala.annotation.stableNull"), + ProblemFilters.exclude[MissingClassProblem]("scala.annotation.internal.JavaRecordFields"), ProblemFilters.exclude[DirectMissingMethodProblem]("scala.NamedTuple.namedTupleOrdering"), ProblemFilters.exclude[MissingClassProblem]("scala.NamedTuple$namedTupleOrdering"), diff --git a/tests/run-java16+/java-records-match.check b/tests/run-java16+/java-records-match.check new file mode 100644 index 000000000000..b25641dfadf4 --- /dev/null +++ b/tests/run-java16+/java-records-match.check @@ -0,0 +1,9 @@ +empty +hello +hahaha +hehehe +21 +hihihi +hohoho +unapply +hejhejhejhej diff --git a/tests/run-java16+/java-records-match/Rec0_1.java b/tests/run-java16+/java-records-match/Rec0_1.java new file mode 100644 index 000000000000..d5f7f71ba3a2 --- /dev/null +++ b/tests/run-java16+/java-records-match/Rec0_1.java @@ -0,0 +1 @@ +public record Rec0_1() {} diff --git a/tests/run-java16+/java-records-match/Rec1_1.java b/tests/run-java16+/java-records-match/Rec1_1.java new file mode 100644 index 000000000000..bde38258e4b0 --- /dev/null +++ b/tests/run-java16+/java-records-match/Rec1_1.java @@ -0,0 +1 @@ +public record Rec1_1(String s) {} diff --git a/tests/run-java16+/java-records-match/Rec2_1.java b/tests/run-java16+/java-records-match/Rec2_1.java new file mode 100644 index 000000000000..e47d48b23070 --- /dev/null +++ b/tests/run-java16+/java-records-match/Rec2_1.java @@ -0,0 +1 @@ +public record Rec2_1(int x, String y) {} diff --git a/tests/run-java16+/java-records-match/Rec3_1.java b/tests/run-java16+/java-records-match/Rec3_1.java new file mode 100644 index 000000000000..331b1c831562 --- /dev/null +++ b/tests/run-java16+/java-records-match/Rec3_1.java @@ -0,0 +1 @@ +public record Rec3_1(int x, T y) {} diff --git a/tests/run-java16+/java-records-match/Rec4_1.java b/tests/run-java16+/java-records-match/Rec4_1.java new file mode 100644 index 000000000000..d1a875781430 --- /dev/null +++ b/tests/run-java16+/java-records-match/Rec4_1.java @@ -0,0 +1 @@ +public record Rec4_1>(int x, T y) {} diff --git a/tests/run-java16+/java-records-match/Rec5_1.java b/tests/run-java16+/java-records-match/Rec5_1.java new file mode 100644 index 000000000000..fa07b8cda3b3 --- /dev/null +++ b/tests/run-java16+/java-records-match/Rec5_1.java @@ -0,0 +1 @@ +public record Rec5_1, W extends Comparable>(T t, U u, W w) {} diff --git a/tests/run-java16+/java-records-match/RecUnapply_1.java b/tests/run-java16+/java-records-match/RecUnapply_1.java new file mode 100644 index 000000000000..1be9bd3ab7df --- /dev/null +++ b/tests/run-java16+/java-records-match/RecUnapply_1.java @@ -0,0 +1,3 @@ +public record RecUnapply_1(int i, String s) { + public static boolean unapply(RecUnapply_1 r) { return true; } +} diff --git a/tests/run-java16+/java-records-match/Test_2.scala b/tests/run-java16+/java-records-match/Test_2.scala new file mode 100644 index 000000000000..a628e3939af6 --- /dev/null +++ b/tests/run-java16+/java-records-match/Test_2.scala @@ -0,0 +1,42 @@ +case class Foo(val value: String) extends Comparable[Integer]: + override def compareTo(other: Integer) = 0 + +case class Bar(val value: String) extends Comparable[Bar]: + override def compareTo(other: Bar) = 0 + +case class Baz(val s: String, val i: Int) + +object Baz: + def unapply(b: Baz): Rec2_1 = Rec2_1(b.i + 1, b.s + "j") + +@main def Test = + val r0 = Rec0_1() + r0 match { case Rec0_1() => println("empty") } + + val r1 = Rec1_1("hello") + r1 match { case Rec1_1(s) => println(s) } + + val r2 = Rec2_1(3, "ha") + r2 match { case Rec2_1(i, s) => println(s * i) } + + // type param (no bounds) + val r3a = Rec3_1(3, "he") + r3a match { case Rec3_1(i, s) => println(s * i) } + val r3b = Rec3_1(3, 7) + r3b match { case Rec3_1(i, j) => println(i * j) } + + // type param with simple bounds + val r4 = Rec4_1(3, Foo("hi")) + r4 match { case Rec4_1(i, f) => println(f.value * i) } + + // type params with recursion / mutual reference + val r5 = Rec5_1(3 : Integer, Foo("h"), Bar("o")) + r5 match { case Rec5_1(i, f, b) => println((f.value + b.value) * i) } + + // custom unapply + val r6 = RecUnapply_1(3, "x") + r6 match { case RecUnapply_1() => println("unapply") } + + // scala class returning record from unapply + val r7 = Baz("he", 3) + r7 match { case Baz(i, s) => println(s * i) }