Skip to content

Commit 7668120

Browse files
committed
Restrict allowed trees in annotations
1 parent fa43ab8 commit 7668120

31 files changed

+267
-124
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
144144
def allTermArguments(tree: Tree): List[Tree] = unsplice(tree) match {
145145
case Apply(fn, args) => allTermArguments(fn) ::: args
146146
case TypeApply(fn, args) => allTermArguments(fn)
147+
// TOOD(mbovel): is it really safe to skip all blocks here and in `allArguments`?
147148
case Block(_, expr) => allTermArguments(expr)
148149
case _ => Nil
149150
}

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,9 @@ class Definitions {
499499

500500
@tu lazy val DummyImplicitClass: ClassSymbol = requiredClass("scala.DummyImplicit")
501501

502+
@tu lazy val SymbolModule: Symbol = requiredModule("scala.Symbol")
503+
@tu lazy val JSSymbolModule: Symbol = requiredModule("scala.scalajs.js.Symbol")
504+
502505
@tu lazy val ScalaRuntimeModule: Symbol = requiredModule("scala.runtime.ScalaRunTime")
503506
def runtimeMethodRef(name: PreName): TermRef = ScalaRuntimeModule.requiredMethodRef(name)
504507
def ScalaRuntime_drop: Symbol = runtimeMethodRef(nme.drop).symbol

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,17 @@ object TreeChecker {
827827
|${mismatch.message}${mismatch.explanation}
828828
|tree = $tree ${tree.className}""".stripMargin
829829
})
830+
checkWellFormedType(tp1)
831+
checkWellFormedType(tp2)
832+
833+
/** Check that the type `tp` is well-formed. Currently this only means
834+
* checking that annotated types have valid annotation arguments.
835+
*/
836+
private def checkWellFormedType(tp: Type)(using Context): Unit =
837+
tp.foreachPart:
838+
case AnnotatedType(underlying, annot) => checkAnnot(annot.tree)
839+
case _ => ()
840+
830841
}
831842

832843
/** Tree checker that can be applied to a local tree. */

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,6 @@ object Checking {
914914
annot
915915
case _ => annot
916916
end checkNamedArgumentForJavaAnnotation
917-
918917
}
919918

920919
trait Checking {
@@ -1385,12 +1384,21 @@ trait Checking {
13851384
if !Inlines.inInlineMethod && !ctx.isInlineContext then
13861385
report.error(em"$what can only be used in an inline method", pos)
13871386

1387+
def checkAnnot(tree: Tree)(using Context): Tree =
1388+
tree match
1389+
case Ident(tpnme.BOUNDTYPE_ANNOT) =>
1390+
// `FirstTransform.toTypeTree` creates `Annotated` nodes whose `annot` are
1391+
// `Ident`s, not annotation instances. See `tests/pos/annot-boundtype.scala`.
1392+
tree
1393+
case _ =>
1394+
checkAnnotArgs(checkAnnotClass(tree))
1395+
13881396
/** Check that the class corresponding to this tree is either a Scala or Java annotation.
13891397
*
13901398
* @return The original tree or an error tree in case `tree` isn't a valid
13911399
* annotation or already an error tree.
13921400
*/
1393-
def checkAnnotClass(tree: Tree)(using Context): Tree =
1401+
private def checkAnnotClass(tree: Tree)(using Context): Tree =
13941402
if tree.tpe.isError then
13951403
return tree
13961404
val cls = Annotations.annotClass(tree)
@@ -1402,8 +1410,8 @@ trait Checking {
14021410
errorTree(tree, em"$cls is not a valid Scala annotation: it does not extend `scala.annotation.Annotation`")
14031411
else tree
14041412

1405-
/** Check arguments of compiler-defined annotations */
1406-
def checkAnnotArgs(tree: Tree)(using Context): tree.type =
1413+
/** Check arguments of annotations */
1414+
private def checkAnnotArgs(tree: Tree)(using Context): Tree =
14071415
val cls = Annotations.annotClass(tree)
14081416
tree match
14091417
case Apply(tycon, arg :: Nil) if cls == defn.TargetNameAnnot =>
@@ -1414,8 +1422,61 @@ trait Checking {
14141422
case _ =>
14151423
report.error(em"@${cls.name} needs a string literal as argument", arg.srcPos)
14161424
case _ =>
1425+
if cls.isRetainsLike then () // Do not check @retain annotations
1426+
else if cls == defn.ThrowsAnnot then
1427+
// Do not check @throws annotations.
1428+
// TODO(mbovel): in tests/run/t6380.scala, an annotation tree is
1429+
// `new throws[Exception](throws.<init>[Exception])`. What is this?
1430+
()
1431+
else
1432+
tpd.allTermArguments(tree).foreach(checkAnnotArg)
14171433
tree
14181434

1435+
private def checkAnnotArg(tree: Tree)(using Context): Unit =
1436+
def isTupleModule(sym: Symbol): Boolean =
1437+
ctx.definitions.isTupleClass(sym.companionClass)
1438+
1439+
def isFunctionAllowed(t: Tree): Boolean =
1440+
t match
1441+
case Select(qual, nme.apply) =>
1442+
qual.symbol == defn.ArrayModule
1443+
|| qual.symbol == defn.ClassTagModule // class tags are used as arguments to Array.apply
1444+
|| qual.symbol == defn.SymbolModule // used in Akka
1445+
|| qual.symbol == defn.JSSymbolModule // used in Scala.js
1446+
|| isTupleModule(qual.symbol)
1447+
case Select(New(clazz), nme.CONSTRUCTOR) => clazz.symbol.isAnnotation
1448+
case Apply(fun, _) => isFunctionAllowed(fun)
1449+
case TypeApply(fun, _) => isFunctionAllowed(fun)
1450+
case _ => false
1451+
1452+
def valid(t: Tree): Boolean =
1453+
t.tpe.isEffectivelySingleton
1454+
|| (
1455+
t match
1456+
case Literal(_) => true
1457+
// `_` is used as placeholder for unspecified arguments of Java
1458+
// annotations. Example: tests/run/java-ann-super-class
1459+
case Ident(nme.WILDCARD) => true
1460+
case Apply(fun, args) => isFunctionAllowed(fun) && args.forall(valid)
1461+
case TypeApply(fun, args) => isFunctionAllowed(fun)
1462+
// Support for `x.isInstanceOf[T]`. Probably not needed.
1463+
//case TypeApply(meth @ Select(arg, _), _) if meth.symbol == defn.Any_asInstanceOf => valid(arg)
1464+
case SeqLiteral(elems, _) => elems.forall(valid)
1465+
case Typed(expr, _) => valid(expr)
1466+
case NamedArg(_, arg) => valid(arg)
1467+
case Splice(_) => true
1468+
case Hole(_, _, _, _) => true
1469+
case _ => false
1470+
)
1471+
1472+
if !valid(tree) then
1473+
report.error(
1474+
i"""Implementation restriction: not a valid annotation argument.
1475+
| Argument: $tree
1476+
| Type: ${tree.tpe}""",
1477+
tree.srcPos
1478+
)
1479+
14191480
/** 1. Check that all case classes that extend `scala.reflect.Enum` are `enum` cases
14201481
* 2. Check that parameterised `enum` cases do not extend java.lang.Enum.
14211482
* 3. Check that only a static `enum` base class can extend java.lang.Enum.
@@ -1663,7 +1724,7 @@ trait NoChecking extends ReChecking {
16631724
override def checkImplicitConversionDefOK(sym: Symbol)(using Context): Unit = ()
16641725
override def checkImplicitConversionUseOK(tree: Tree, expected: Type)(using Context): Unit = ()
16651726
override def checkFeasibleParent(tp: Type, pos: SrcPos, where: => String = "")(using Context): Type = tp
1666-
override def checkAnnotArgs(tree: Tree)(using Context): tree.type = tree
1727+
override def checkAnnot(tree: Tree)(using Context): tree.type = tree
16671728
override def checkNoTargetNameConflict(stats: List[Tree])(using Context): Unit = ()
16681729
override def checkParentCall(call: Tree, caller: ClassSymbol)(using Context): Unit = ()
16691730
override def checkSimpleKinded(tpt: Tree)(using Context): Tree = tpt

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2780,7 +2780,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
27802780
}
27812781

27822782
def typedAnnotation(annot: untpd.Tree)(using Context): Tree =
2783-
checkAnnotClass(checkAnnotArgs(typed(annot)))
2783+
checkAnnot(typed(annot))
27842784

27852785
def registerNowarn(tree: Tree, mdef: untpd.Tree)(using Context): Unit =
27862786
val annot = Annotations.Annotation(tree)
@@ -3307,7 +3307,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
33073307
end typedPackageDef
33083308

33093309
def typedAnnotated(tree: untpd.Annotated, pt: Type)(using Context): Tree = {
3310-
val annot1 = checkAnnotClass(typedExpr(tree.annot))
3310+
val annot1 = checkAnnot(typedExpr(tree.annot))
33113311
val annotCls = Annotations.annotClass(annot1)
33123312
if annotCls == defn.NowarnAnnot then
33133313
registerNowarn(annot1, tree)

tests/bench/inductive-implicits.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ package shapeless {
6161
import shapeless.*
6262

6363
object Test extends App {
64+
import Selector.given
6465
val sel = Selector[L, Boolean]
6566

6667
type L =

tests/neg/annot-invalid.check

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
-- Error: tests/neg/annot-invalid.scala:7:21 ---------------------------------------------------------------------------
2+
7 | val x1: Int @annot(n + 1) = 0 // error
3+
| ^^^^^
4+
| Implementation restriction: not a valid annotation argument.
5+
| Argument: n.+(1)
6+
| Type: Int
7+
-- Error: tests/neg/annot-invalid.scala:8:22 ---------------------------------------------------------------------------
8+
8 | val x2: Int @annot(f(2)) = 0 // error
9+
| ^^^^
10+
| Implementation restriction: not a valid annotation argument.
11+
| Argument: f(2)
12+
| Type: Unit
13+
-- Error: tests/neg/annot-invalid.scala:9:21 ---------------------------------------------------------------------------
14+
9 | val x3: Int @annot(throw new Error()) = 0 // error
15+
| ^^^^^^^^^^^^^^^^^
16+
| Implementation restriction: not a valid annotation argument.
17+
| Argument: throw new Error()
18+
| Type: Nothing
19+
-- Error: tests/neg/annot-invalid.scala:10:21 --------------------------------------------------------------------------
20+
10 | val x4: Int @annot((x: Int) => x) = 0 // error
21+
| ^^^^^^^^^^^^^
22+
| Implementation restriction: not a valid annotation argument.
23+
| Argument: {
24+
| def $anonfun(x: Int): Int = x
25+
| closure($anonfun)
26+
| }
27+
| Type: Int => Int
28+
-- Error: tests/neg/annot-invalid.scala:12:9 ---------------------------------------------------------------------------
29+
12 | @annot(n + 1) val y1: Int = 0 // error
30+
| ^^^^^
31+
| Implementation restriction: not a valid annotation argument.
32+
| Argument: n.+(1)
33+
| Type: Int
34+
-- Error: tests/neg/annot-invalid.scala:13:10 --------------------------------------------------------------------------
35+
13 | @annot(f(2)) val y2: Int = 0 // error
36+
| ^^^^
37+
| Implementation restriction: not a valid annotation argument.
38+
| Argument: f(2)
39+
| Type: Unit
40+
-- Error: tests/neg/annot-invalid.scala:14:9 ---------------------------------------------------------------------------
41+
14 | @annot(throw new Error()) val y3: Int = 0 // error
42+
| ^^^^^^^^^^^^^^^^^
43+
| Implementation restriction: not a valid annotation argument.
44+
| Argument: throw new Error()
45+
| Type: Nothing
46+
-- Error: tests/neg/annot-invalid.scala:15:9 ---------------------------------------------------------------------------
47+
15 | @annot((x: Int) => x) val y4: Int = 0 // error
48+
| ^^^^^^^^^^^^^
49+
| Implementation restriction: not a valid annotation argument.
50+
| Argument: {
51+
| def $anonfun(x: Int): Int = x
52+
| closure($anonfun)
53+
| }
54+
| Type: Int => Int

tests/neg/annot-invalid.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
class annot[T](arg: T) extends scala.annotation.Annotation
2+
3+
def main =
4+
val n: Int = 0
5+
def f(x: Any): Unit = ()
6+
7+
val x1: Int @annot(n + 1) = 0 // error
8+
val x2: Int @annot(f(2)) = 0 // error
9+
val x3: Int @annot(throw new Error()) = 0 // error
10+
val x4: Int @annot((x: Int) => x) = 0 // error
11+
12+
@annot(n + 1) val y1: Int = 0 // error
13+
@annot(f(2)) val y2: Int = 0 // error
14+
@annot(throw new Error()) val y3: Int = 0 // error
15+
@annot((x: Int) => x) val y4: Int = 0 // error
16+
17+
()

tests/neg/i15054.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import scala.annotation.Annotation
2+
3+
class AnAnnotation(function: Int => String) extends Annotation
4+
5+
@AnAnnotation(_.toString) // error: not a valid annotation
6+
val a = 1
7+
@AnAnnotation(_.toString.length.toString) // error: not a valid annotation
8+
val b = 2
9+
10+
def test =
11+
@AnAnnotation(_.toString) // error: not a valid annotation
12+
val a = 1
13+
@AnAnnotation(_.toString.length.toString) // error: not a valid annotation
14+
val b = 2
15+
a + b

tests/neg/i7740a.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class A(a: Any) extends annotation.StaticAnnotation
2+
@A({val x = 0}) trait B // error: not a valid annotation

0 commit comments

Comments
 (0)