Skip to content

Commit 009c2a4

Browse files
committed
Experiment: restrict allowed trees in type annotations
1 parent aa9115d commit 009c2a4

File tree

7 files changed

+67
-51
lines changed

7 files changed

+67
-51
lines changed

compiler/src/dotty/tools/dotc/core/Annotations.scala

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package dotc
33
package core
44

55
import Symbols.*, Types.*, Contexts.*, Constants.*, Phases.*
6+
import Decorators.i
7+
import StdNames.nme
68
import ast.tpd, tpd.*
79
import util.Spans.Span
810
import printing.{Showable, Printer}
@@ -106,6 +108,54 @@ object Annotations {
106108
go(metaSyms) || orNoneOf.nonEmpty && !go(orNoneOf)
107109
}
108110

111+
/** True if this annotation can be used as a type annotation, false otherwise.
112+
*
113+
* An annotation is a valid type annotation if its tree is one a `Literal`.
114+
*
115+
* Can be overridden.
116+
*/
117+
def checkValidTypeAnnotation()(using Context): Unit =
118+
def isTupleModule(sym: Symbol): Boolean =
119+
ctx.definitions.isTupleClass(sym.companionClass)
120+
121+
def isFunctionAllowed(t: Tree): Boolean =
122+
t match
123+
case Select(qual, nme.apply) => qual.symbol == defn.ArrayModule || isTupleModule(qual.symbol)
124+
case TypeApply(fun, _) => isFunctionAllowed(fun)
125+
case _ => false
126+
127+
def check(t: Tree): Boolean =
128+
t match
129+
case Literal(_) => true
130+
case Typed(expr, _) => check(expr)
131+
case SeqLiteral(elems, _) => elems.forall(check)
132+
case Apply(fun, args) => isFunctionAllowed(fun) && args.forall(check)
133+
case NamedArg(_, arg) => check(arg)
134+
case _ =>
135+
t.tpe.stripped match
136+
case _: SingletonType => true
137+
// We need to handle type refs for these test cases:
138+
// - tests/pos/dependent-annot.scala
139+
// - tests/pos/i16208.scala
140+
// - tests/run/java-ann-super-class.scala
141+
// - tests/run/java-ann-super-class-separate.scala
142+
// - tests/neg/i19470.scala (@retains)
143+
// Why do we get type refs in these cases?
144+
case _: TypeRef => true
145+
case _: TypeParamRef => true
146+
case tp => false
147+
148+
val uncheckedAnnots = Set[Symbol](defn.RetainsAnnot, defn.RetainsByNameAnnot)
149+
if uncheckedAnnots(symbol) then return
150+
151+
for arg <- arguments if !check(arg) do
152+
report.error(
153+
s"""Implementation restriction: not a valid type annotation argument.
154+
| Argument: $arg
155+
| Type: ${arg.tpe}""".stripMargin, arg.srcPos)
156+
157+
()
158+
109159
/** Operations for hash-consing, can be overridden */
110160
def hash: Int = System.identityHashCode(this)
111161
def eql(that: Annotation) = this eq that

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5754,6 +5754,7 @@ object Types extends TypeUtils {
57545754
def make(underlying: Type, annots: List[Annotation])(using Context): Type =
57555755
annots.foldLeft(underlying)(apply(_, _))
57565756
def apply(parent: Type, annot: Annotation)(using Context): AnnotatedType =
5757+
annot.checkValidTypeAnnotation()
57575758
unique(CachedAnnotatedType(parent, annot))
57585759
end AnnotatedType
57595760

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
class annot[T](arg: T) extends scala.annotation.Annotation
2+
3+
def m1(x: Any): Unit = ()
4+
val x: Int = ???
5+
6+
def m2(): String @annot(m1(2)) = ??? // error
7+
def m3(): String @annot(throw new Error()) = ??? // error
8+
def m4(): String @annot((x: Int) => x) = ??? // error
9+
def m5(): String @annot(x + 1) = ??? // error
10+
11+
def main =
12+
@annot(m1(2)) val x1: String = ??? // ok
13+
@annot(throw new Error()) val x2: String = ??? // ok
14+
@annot((x: Int) => x) val x3: String = ??? // ok
15+
@annot(x + 1) val x4: String = ??? // ok
16+
()

tests/pos/annot-17939b.scala

Lines changed: 0 additions & 10 deletions
This file was deleted.

tests/pos/annotDepMethType.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,4 @@ case class pc(calls: Any*) extends annotation.TypeConstraint
33
object Main {
44
class C0 { def baz: String = "" }
55
class C1 { def bar(c0: C0): String @pc(c0.baz) = c0.baz }
6-
def trans(c1: C1): String @pc(c1.bar(throw new Error())) = c1.bar(new C0)
76
}

tests/printing/annot-19846b.check

Lines changed: 0 additions & 33 deletions
This file was deleted.

tests/printing/annot-19846b.scala

Lines changed: 0 additions & 7 deletions
This file was deleted.

0 commit comments

Comments
 (0)