Skip to content

Commit a5df0e7

Browse files
committed
Implement @main functions
1 parent 0b52037 commit a5df0e7

File tree

11 files changed

+268
-7
lines changed

11 files changed

+268
-7
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package dotty.tools.dotc
2+
package ast
3+
4+
import core._
5+
import Symbols._, Types._, Contexts._, Decorators._, util.Spans._, Flags._, Constants._
6+
import StdNames.nme
7+
import ast.Trees._
8+
9+
/** Generate proxy classes for @main functions.
10+
* A function like
11+
*
12+
* @main f(x: S, ys: T*) = ...
13+
*
14+
* would be translated to something like
15+
*
16+
* import CommandLineParser._
17+
* class f {
18+
* @static def main(args: Array[String]): Unit =
19+
* try
20+
* f(
21+
* parseArgument[S](args, 0),
22+
* parseRemainingArguments[T](args, 1): _*
23+
* )
24+
* catch case err: ParseError => showError(err)
25+
* }
26+
*/
27+
object MainProxies {
28+
29+
def mainProxies(stats: List[tpd.Tree]) given Context: List[untpd.Tree] = {
30+
import tpd._
31+
def mainMethods(stats: List[Tree]): List[Symbol] = stats.flatMap {
32+
case stat: DefDef if stat.symbol.hasAnnotation(defn.MainAnnot) =>
33+
stat.symbol :: Nil
34+
case stat @ TypeDef(name, impl: Template) if stat.symbol.is(Module) =>
35+
mainMethods(impl.body)
36+
case _ =>
37+
Nil
38+
}
39+
mainMethods(stats).flatMap(mainProxy)
40+
}
41+
42+
import untpd._
43+
def mainProxy(mainFun: Symbol) given (ctx: Context): List[TypeDef] = {
44+
val mainAnnotSpan = mainFun.getAnnotation(defn.MainAnnot).get.tree.span
45+
46+
val argsRef = Ident(nme.args)
47+
48+
def addArgs(call: untpd.Tree, formals: List[Type], restpe: Type, idx: Int): untpd.Tree = {
49+
val args = formals.zipWithIndex map {
50+
(formal, n) =>
51+
val (parserSym, formalElem) =
52+
if (formal.isRepeatedParam) (defn.CLP_parseRemainingArguments, formal.argTypes.head)
53+
else (defn.CLP_parseArgument, formal)
54+
val arg = Apply(
55+
TypeApply(ref(parserSym.termRef), TypeTree(formalElem) :: Nil),
56+
argsRef :: Literal(Constant(idx + n)) :: Nil)
57+
if (formal.isRepeatedParam) repeated(arg) else arg
58+
}
59+
val call1 = Apply(call, args)
60+
restpe match {
61+
case restpe: MethodType if !restpe.isImplicitMethod =>
62+
if (formals.lastOption.getOrElse(NoType).isRepeatedParam)
63+
ctx.error(s"varargs parameter of @main method must come last", mainFun.sourcePos)
64+
addArgs(call1, restpe.paramInfos, restpe.resType, idx + args.length)
65+
case _ =>
66+
call1
67+
}
68+
}
69+
70+
var result: List[TypeDef] = Nil
71+
if (!mainFun.owner.isStaticOwner)
72+
ctx.error(s"@main method is not statically accessible", mainFun.sourcePos)
73+
else {
74+
var call = ref(mainFun.termRef)
75+
mainFun.info match {
76+
case _: ExprType =>
77+
case mt: MethodType =>
78+
if (!mt.isImplicitMethod) call = addArgs(call, mt.paramInfos, mt.resultType, 0)
79+
case _: PolyType =>
80+
ctx.error(s"@main method cannot have type parameters", mainFun.sourcePos)
81+
case _ =>
82+
ctx.error(s"@main can only annotate a method", mainFun.sourcePos)
83+
}
84+
val errVar = Ident(nme.error)
85+
val handler = CaseDef(
86+
Typed(errVar, TypeTree(defn.CLP_ParseError.typeRef)),
87+
EmptyTree,
88+
Apply(ref(defn.CLP_showError.termRef), errVar :: Nil))
89+
val body = Try(call, handler :: Nil, EmptyTree)
90+
val mainArg = ValDef(nme.args, TypeTree(defn.ArrayType.appliedTo(defn.StringType)), EmptyTree)
91+
.withFlags(Param)
92+
val mainMeth = DefDef(nme.main, Nil, (mainArg :: Nil) :: Nil, TypeTree(defn.UnitType), body)
93+
.withFlags(JavaStatic)
94+
val mainTempl = Template(emptyConstructor, Nil, Nil, EmptyValDef, mainMeth :: Nil)
95+
val mainCls = TypeDef(mainFun.name.toTypeName, mainTempl)
96+
if (!ctx.reporter.hasErrors) result = mainCls.withSpan(mainAnnotSpan) :: Nil
97+
}
98+
result
99+
}
100+
}

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -695,10 +695,16 @@ class Definitions {
695695

696696
@threadUnsafe lazy val ValueOfClass: ClassSymbolPerRun = perRunClass(ctx.requiredClassRef("scala.ValueOf"))
697697
@threadUnsafe lazy val StatsModule: SymbolPerRun = perRunSym(ctx.requiredModuleRef("dotty.tools.dotc.util.Stats"))
698-
@threadUnsafe lazy val Stats_doRecord: SymbolPerRun = perRunSym(StatsModule.requiredMethodRef("doRecord"))
698+
@threadUnsafe lazy val Stats_doRecord: SymbolPerRun = perRunSym(StatsModule.requiredMethodRef("doRecord"))
699699

700700
@threadUnsafe lazy val XMLTopScopeModule: SymbolPerRun = perRunSym(ctx.requiredModuleRef("scala.xml.TopScope"))
701701

702+
@threadUnsafe lazy val CommandLineParserModule: SymbolPerRun = perRunSym(ctx.requiredModuleRef("scala.util.CommandLineParser"))
703+
@threadUnsafe lazy val CLP_ParseError: ClassSymbolPerRun = perRunClass(CommandLineParserModule.requiredClass("ParseError").typeRef)
704+
@threadUnsafe lazy val CLP_parseArgument: SymbolPerRun = perRunSym(CommandLineParserModule.requiredMethodRef("parseArgument"))
705+
@threadUnsafe lazy val CLP_parseRemainingArguments: SymbolPerRun = perRunSym(CommandLineParserModule.requiredMethodRef("parseRemainingArguments"))
706+
@threadUnsafe lazy val CLP_showError: SymbolPerRun = perRunSym(CommandLineParserModule.requiredMethodRef("showError"))
707+
702708
@threadUnsafe lazy val TupleTypeRef: TypeRef = ctx.requiredClassRef("scala.Tuple")
703709
def TupleClass(implicit ctx: Context): ClassSymbol = TupleTypeRef.symbol.asClass
704710
@threadUnsafe lazy val Tuple_cons: SymbolPerRun = perRunSym(TupleClass.requiredMethodRef("*:"))
@@ -712,8 +718,8 @@ class Definitions {
712718

713719
def TupleXXL_fromIterator(implicit ctx: Context): Symbol = TupleXXLModule.requiredMethod("fromIterator")
714720

715-
lazy val DynamicTupleModule: Symbol = ctx.requiredModule("scala.runtime.DynamicTuple")
716-
lazy val DynamicTupleModuleClass: Symbol = DynamicTupleModule.moduleClass
721+
@threadUnsafe lazy val DynamicTupleModule: Symbol = ctx.requiredModule("scala.runtime.DynamicTuple")
722+
@threadUnsafe lazy val DynamicTupleModuleClass: Symbol = DynamicTupleModule.moduleClass
717723
lazy val DynamicTuple_consIterator: Symbol = DynamicTupleModule.requiredMethod("consIterator")
718724
lazy val DynamicTuple_concatIterator: Symbol = DynamicTupleModule.requiredMethod("concatIterator")
719725
lazy val DynamicTuple_dynamicApply: Symbol = DynamicTupleModule.requiredMethod("dynamicApply")
@@ -724,10 +730,10 @@ class Definitions {
724730
lazy val DynamicTuple_dynamicToArray: Symbol = DynamicTupleModule.requiredMethod("dynamicToArray")
725731
lazy val DynamicTuple_productToArray: Symbol = DynamicTupleModule.requiredMethod("productToArray")
726732

727-
lazy val TupledFunctionTypeRef: TypeRef = ctx.requiredClassRef("scala.TupledFunction")
733+
@threadUnsafe lazy val TupledFunctionTypeRef: TypeRef = ctx.requiredClassRef("scala.TupledFunction")
728734
def TupledFunctionClass(implicit ctx: Context): ClassSymbol = TupledFunctionTypeRef.symbol.asClass
729735

730-
lazy val InternalTupledFunctionTypeRef: TypeRef = ctx.requiredClassRef("scala.internal.TupledFunction")
736+
@threadUnsafe lazy val InternalTupledFunctionTypeRef: TypeRef = ctx.requiredClassRef("scala.internal.TupledFunction")
731737
def InternalTupleFunctionClass(implicit ctx: Context): ClassSymbol = InternalTupledFunctionTypeRef.symbol.asClass
732738
def InternalTupleFunctionModule(implicit ctx: Context): Symbol = ctx.requiredModule("scala.internal.TupledFunction")
733739

@@ -751,6 +757,7 @@ class Definitions {
751757
@threadUnsafe lazy val ForceInlineAnnot: ClassSymbolPerRun = perRunClass(ctx.requiredClassRef("scala.forceInline"))
752758
@threadUnsafe lazy val InlineParamAnnot: ClassSymbolPerRun = perRunClass(ctx.requiredClassRef("scala.annotation.internal.InlineParam"))
753759
@threadUnsafe lazy val InvariantBetweenAnnot: ClassSymbolPerRun = perRunClass(ctx.requiredClassRef("scala.annotation.internal.InvariantBetween"))
760+
@threadUnsafe lazy val MainAnnot: ClassSymbolPerRun = perRunClass(ctx.requiredClassRef("scala.main"))
754761
@threadUnsafe lazy val MigrationAnnot: ClassSymbolPerRun = perRunClass(ctx.requiredClassRef("scala.annotation.migration"))
755762
@threadUnsafe lazy val NativeAnnot: ClassSymbolPerRun = perRunClass(ctx.requiredClassRef("scala.native"))
756763
@threadUnsafe lazy val RepeatedAnnot: ClassSymbolPerRun = perRunClass(ctx.requiredClassRef("scala.annotation.internal.Repeated"))

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package dotc
33
package typer
44

55
import core._
6-
import ast.{tpd, _}
6+
import ast._
77
import Trees._
88
import Constants._
99
import StdNames._
@@ -1793,7 +1793,9 @@ class Typer extends Namer
17931793
case pid1: RefTree if pkg.exists =>
17941794
if (!pkg.is(Package)) ctx.error(PackageNameAlreadyDefined(pkg), tree.sourcePos)
17951795
val packageCtx = ctx.packageContext(tree, pkg)
1796-
val stats1 = typedStats(tree.stats, pkg.moduleClass)(packageCtx)
1796+
var stats1 = typedStats(tree.stats, pkg.moduleClass)(packageCtx)
1797+
if (!ctx.isAfterTyper)
1798+
stats1 = stats1 ++ typedBlockStats(MainProxies.mainProxies(stats1))(packageCtx)._2
17971799
cpy.PackageDef(tree)(pid1, stats1).withType(pkg.termRef)
17981800
case _ =>
17991801
// Package will not exist if a duplicate type has already been entered, see `tests/neg/1708.scala`

library/src/scala/main.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
/* __ *\
2+
** ________ ___ / / ___ Scala API **
3+
** / __/ __// _ | / / / _ | (c) 2002-2013, LAMP/EPFL **
4+
** __\ \/ /__/ __ |/ /__/ __ | http://scala-lang.org/ **
5+
** /____/\___/_/ |_/____/_/ | | **
6+
** |/ **
7+
\* */
8+
9+
package scala
10+
11+
/** An annotation that designates a main function
12+
*/
13+
class main extends scala.annotation.Annotation {}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package scala.util
2+
3+
object CommandLineParser {
4+
5+
/** An exception raised for an illegal command line
6+
* @param idx The index of the argument that's faulty (starting from 0)
7+
* @param msg The error message
8+
*/
9+
class ParseError(val idx: Int, val msg: String) extends Exception
10+
11+
/** Parse command line argument `s`, which has index `n`, as a value of type `T` */
12+
def parseString[T](str: String, n: Int) given (fs: FromString[T]): T = {
13+
try fs.fromString(str)
14+
catch {
15+
case ex: IllegalArgumentException => throw ParseError(n, ex.toString)
16+
}
17+
}
18+
19+
/** Parse `n`'th argument in `args` (counting from 0) as a value of type `T` */
20+
def parseArgument[T](args: Array[String], n: Int) given (fs: FromString[T]): T =
21+
if n < args.length then parseString(args(n), n)
22+
else throw ParseError(n, "more arguments expected")
23+
24+
/** Parse all arguments from `n`'th one (counting from 0) as a list of values of type `T` */
25+
def parseRemainingArguments[T](args: Array[String], n: Int) given (fs: FromString[T]): List[T] =
26+
if n < args.length then parseString(args(n), n) :: parseRemainingArguments(args, n + 1)
27+
else Nil
28+
29+
/** Print error message explaining given ParserError */
30+
def showError(err: ParseError): Unit = {
31+
val where =
32+
if err.idx == 0 then ""
33+
else if err.idx == 1 then " after first argument"
34+
else s" after ${err.idx} arguments"
35+
println(s"Illegal command line$where: ${err.msg}")
36+
}
37+
}
38+
39+
/* A function like
40+
41+
@main f(x: S, ys: T*) = ...
42+
43+
would be translated to something like
44+
45+
import CommandLineParser._
46+
class f {
47+
@static def main(args: Array[String]): Unit =
48+
try
49+
f(
50+
parseArgument[S](args, 0),
51+
parseRemainingArguments[T](args, 1): _*
52+
)
53+
catch case err: ParseError => showError(err)
54+
}
55+
*/
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package scala.util
2+
3+
trait FromString[T] {
4+
/** Can throw java.lang.IllegalArgumentException */
5+
def fromString(s: String): T
6+
7+
def fromStringOption(s: String): Option[T] =
8+
try Some(fromString(s))
9+
catch {
10+
case ex: IllegalArgumentException => None
11+
}
12+
}
13+
14+
object FromString {
15+
16+
delegate for FromString[String] {
17+
def fromString(s: String) = s
18+
}
19+
20+
delegate for FromString[Boolean] {
21+
def fromString(s: String) = s.toBoolean
22+
}
23+
24+
delegate for FromString[Byte] {
25+
def fromString(s: String) = s.toByte
26+
}
27+
28+
delegate for FromString[Short] {
29+
def fromString(s: String) = s.toShort
30+
}
31+
32+
delegate for FromString[Int] {
33+
def fromString(s: String) = s.toInt
34+
}
35+
36+
delegate for FromString[Long] {
37+
def fromString(s: String) = s.toLong
38+
}
39+
40+
delegate for FromString[Float] {
41+
def fromString(s: String) = s.toFloat
42+
}
43+
44+
delegate for FromString[Double] {
45+
def fromString(s: String) = s.toDouble
46+
}
47+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
object foo {
3+
@main def foo(x: Int) = () // error: class foo differs only in case from object foo
4+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2+
class bar
3+
object bar {
4+
@main def bar(x: Int) = () // error: class bar has already been compiled once during this run
5+
}
6+
7+
object baz {
8+
@main def bam(x: Int): Unit = ()
9+
@main def bam(x: String): Unit = () // error: class bam has already been compiled once during this run
10+
}

tests/neg/main-functions.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
object Test1 {
2+
@main def f(x: Foo) = () // error: no implicit argument of type util.FromString[Foo] was found
3+
}
4+
5+
object Test2 {
6+
@main val x = 2 // does nothing, should this be made an error?
7+
}
8+
9+
class Foo {
10+
@main def f = () // does nothing, should this be made an error?
11+
}
12+
13+
@main def g(x: Int*)(y: Int*) = () // error: varargs parameter of @main method must come last
14+
15+
@main def h[T: util.FromString](x: T) = () // error: @main method cannot have type parameters

tests/run/main-functions.check

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Hello, world!

0 commit comments

Comments
 (0)