Skip to content

Commit 1d3e473

Browse files
committed
Add primitive compiletime operations on singleton types
1 parent 4237152 commit 1d3e473

File tree

9 files changed

+263
-14
lines changed

9 files changed

+263
-14
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ class Definitions {
233233
@tu lazy val CompiletimeTesting_ErrorKind: Symbol = ctx.requiredModule("scala.compiletime.testing.ErrorKind")
234234
@tu lazy val CompiletimeTesting_ErrorKind_Parser: Symbol = CompiletimeTesting_ErrorKind.requiredMethod("Parser")
235235
@tu lazy val CompiletimeTesting_ErrorKind_Typer: Symbol = CompiletimeTesting_ErrorKind.requiredMethod("Typer")
236+
@tu lazy val CompiletimeIntPackageObject: Symbol = ctx.requiredModule("scala.compiletime.int.package")
237+
@tu lazy val CompiletimeBooleanPackageObject: Symbol = ctx.requiredModule("scala.compiletime.boolean.package")
236238

237239
/** The `scalaShadowing` package is used to safely modify classes and
238240
* objects in scala so that they can be used from dotty. They will
@@ -898,6 +900,28 @@ class Definitions {
898900
final def isCompiletime_S(sym: Symbol)(implicit ctx: Context): Boolean =
899901
sym.name == tpnme.S && sym.owner == CompiletimePackageObject.moduleClass
900902

903+
final def isCompiletimeAppliedType(sym: Symbol)(implicit ctx: Context): Boolean = {
904+
def isPackageObjectAppliedType: Boolean =
905+
sym.owner == CompiletimePackageObject.moduleClass && Set(
906+
tpnme.S, tpnme.Equals, tpnme.NotEquals
907+
).contains(sym.name)
908+
909+
def isIntAppliedType: Boolean =
910+
sym.owner == CompiletimeIntPackageObject.moduleClass && Set(
911+
tpnme.Plus, tpnme.Minus, tpnme.Times, tpnme.Div, tpnme.Mod,
912+
tpnme.Lt, tpnme.Gt, tpnme.Ge, tpnme.Le,
913+
tpnme.Abs, tpnme.Negate, tpnme.Min, tpnme.Max
914+
).contains(sym.name)
915+
916+
def isBooleanAppliedType: Boolean =
917+
sym.owner == CompiletimeBooleanPackageObject.moduleClass && Set(
918+
tpnme.Not, tpnme.Xor, tpnme.And, tpnme.Or
919+
).contains(sym.name)
920+
921+
isPackageObjectAppliedType || isIntAppliedType || isBooleanAppliedType
922+
}
923+
924+
901925
// ----- Symbol sets ---------------------------------------------------
902926

903927
@tu lazy val AbstractFunctionType: Array[TypeRef] = mkArityArray("scala.runtime.AbstractFunction", MaxImplementedFunctionArity, 0)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,33 @@ object StdNames {
201201
final val Product: N = "Product"
202202
final val PartialFunction: N = "PartialFunction"
203203
final val PrefixType: N = "PrefixType"
204-
final val S: N = "S"
205204
final val Serializable: N = "Serializable"
206205
final val Singleton: N = "Singleton"
207206
final val Throwable: N = "Throwable"
208207
final val IOOBException: N = "IndexOutOfBoundsException"
209208
final val FunctionXXL: N = "FunctionXXL"
210209

210+
final val Abs: N = "Abs"
211+
final val And: N = "&&"
212+
final val Div: N = "/"
213+
final val Equals: N = "=="
214+
final val Ge: N = ">="
215+
final val Gt: N = ">"
216+
final val Le: N = "<="
217+
final val Lt: N = "<"
218+
final val Max: N = "Max"
219+
final val Min: N = "Min"
220+
final val Minus: N = "-"
221+
final val Mod: N = "%"
222+
final val Negate: N = "Negate"
223+
final val Not: N = "!"
224+
final val NotEquals: N = "!="
225+
final val Or: N = "||"
226+
final val Plus: N = "+"
227+
final val S: N = "S"
228+
final val Times: N = "*"
229+
final val Xor: N = "^"
230+
211231
final val ClassfileAnnotation: N = "ClassfileAnnotation"
212232
final val ClassManifest: N = "ClassManifest"
213233
final val Enum: N = "Enum"

compiler/src/dotty/tools/dotc/core/TypeApplications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ class TypeApplications(val self: Type) extends AnyVal {
376376
}
377377
}
378378
if ((dealiased eq stripped) || followAlias)
379-
try dealiased.instantiate(args)
379+
try dealiased.instantiate(args).normalized
380380
catch { case ex: IndexOutOfBoundsException => AppliedType(self, args) }
381381
else AppliedType(self, args)
382382
}

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
965965
compareLower(bounds(param2), tyconIsTypeRef = false)
966966
case tycon2: TypeRef =>
967967
isMatchingApply(tp1) ||
968-
defn.isCompiletime_S(tycon2.symbol) && compareS(tp2, tp1, fromBelow = true) || {
968+
defn.isCompiletimeAppliedType(tycon2.symbol) && compareCompiletimeAppliedType(tp2, tp1, fromBelow = true) || {
969969
tycon2.info match {
970970
case info2: TypeBounds =>
971971
compareLower(info2, tyconIsTypeRef = true)
@@ -1005,7 +1005,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
10051005
case tycon1: TypeRef =>
10061006
val sym = tycon1.symbol
10071007
!sym.isClass && {
1008-
defn.isCompiletime_S(sym) && compareS(tp1, tp2, fromBelow = false) ||
1008+
defn.isCompiletimeAppliedType(sym) && compareCompiletimeAppliedType(tp1, tp2, fromBelow = false) ||
10091009
recur(tp1.superType, tp2) ||
10101010
tryLiftedToThis1
10111011
}
@@ -1037,6 +1037,11 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
10371037
case _ => false
10381038
}
10391039

1040+
def compareCompiletimeAppliedType(tp: AppliedType, other: Type, fromBelow: Boolean): Boolean = {
1041+
if (defn.isCompiletime_S(tp.tycon.typeSymbol)) compareS(tp, other, fromBelow)
1042+
else tp.tryCompiletimeConstantFold.exists(folded => recur(folded, other))
1043+
}
1044+
10401045
/** Like tp1 <:< tp2, but returns false immediately if we know that
10411046
* the case was covered previously during subtyping.
10421047
*/

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3595,19 +3595,72 @@ object Types {
35953595
case _ =>
35963596
NoType
35973597
}
3598-
if (defn.isCompiletime_S(tycon.symbol) && args.length == 1)
3599-
trace(i"normalize S $this", typr, show = true) {
3600-
args.head.normalized match {
3601-
case ConstantType(Constant(n: Int)) if n >= 0 && n < Int.MaxValue =>
3602-
ConstantType(Constant(n + 1))
3603-
case none => tryMatchAlias
3604-
}
3605-
}
3606-
else tryMatchAlias
3598+
3599+
tryCompiletimeConstantFold.getOrElse(tryMatchAlias)
3600+
36073601
case _ =>
36083602
NoType
36093603
}
36103604

3605+
def tryCompiletimeConstantFold(implicit ctx: Context): Option[Type] = tycon match {
3606+
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
3607+
def constValue(tp: Type): Option[Any] = tp match {
3608+
case ConstantType(Constant(n)) => Some(n)
3609+
case _ => None
3610+
}
3611+
3612+
def boolValue(tp: Type): Option[Boolean] = tp match {
3613+
case ConstantType(Constant(n: Boolean)) => Some(n)
3614+
case _ => None
3615+
}
3616+
3617+
def intValue(tp: Type): Option[Int] = tp match {
3618+
case ConstantType(Constant(n: Int)) => Some(n)
3619+
case _ => None
3620+
}
3621+
3622+
def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)
3623+
3624+
def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
3625+
extractor(args.head.normalized).map(a => ConstantType(Constant(op(a))))
3626+
3627+
def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
3628+
for {
3629+
a <- extractor(args.head.normalized)
3630+
b <- extractor(args.tail.head.normalized)
3631+
} yield ConstantType(Constant(op(a, b)))
3632+
3633+
trace(i"compiletime constant fold $this", typr, show = true) {
3634+
if (args.length == 1) tycon.symbol.name match {
3635+
case tpnme.S => constantFold1(natValue, _ + 1)
3636+
case tpnme.Abs => constantFold1(intValue, _.abs)
3637+
case tpnme.Negate => constantFold1(intValue, x => -x)
3638+
case tpnme.Not => constantFold1(boolValue, x => !x)
3639+
case _ => None
3640+
} else if (args.length == 2) tycon.symbol.name match {
3641+
case tpnme.Equals => constantFold2(constValue, _ == _)
3642+
case tpnme.NotEquals => constantFold2(constValue, _ != _)
3643+
case tpnme.Plus => constantFold2(intValue, _ + _)
3644+
case tpnme.Minus => constantFold2(intValue, _ - _)
3645+
case tpnme.Times => constantFold2(intValue, _ * _)
3646+
case tpnme.Div => constantFold2(intValue, _ / _)
3647+
case tpnme.Mod => constantFold2(intValue, _ % _)
3648+
case tpnme.Lt => constantFold2(intValue, _ < _)
3649+
case tpnme.Gt => constantFold2(intValue, _ > _)
3650+
case tpnme.Ge => constantFold2(intValue, _ >= _)
3651+
case tpnme.Le => constantFold2(intValue, _ <= _)
3652+
case tpnme.Min => constantFold2(intValue, _ min _)
3653+
case tpnme.Max => constantFold2(intValue, _ max _)
3654+
case tpnme.And => constantFold2(boolValue, _ && _)
3655+
case tpnme.Or => constantFold2(boolValue, _ || _)
3656+
case tpnme.Xor => constantFold2(boolValue, _ ^ _)
3657+
case _ => None
3658+
} else None
3659+
}
3660+
3661+
case _ => None
3662+
}
3663+
36113664
def lowerBound(implicit ctx: Context): Type = tycon.stripTypeVar match {
36123665
case tycon: TypeRef =>
36133666
tycon.info match {
@@ -3974,7 +4027,7 @@ object Types {
39744027
myReduced =
39754028
trace(i"reduce match type $this $hashCode", typr, show = true) {
39764029
try
3977-
typeComparer.matchCases(scrutinee, cases)(trackingCtx)
4030+
typeComparer.matchCases(scrutinee.normalized, cases)(trackingCtx)
39784031
catch {
39794032
case ex: Throwable =>
39804033
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package scala.compiletime
2+
3+
package object boolean {
4+
type ![X <: Boolean] <: Boolean
5+
type ^[X <: Boolean, Y <: Boolean] <: Boolean
6+
type &&[X <: Boolean, Y <: Boolean] <: Boolean
7+
type ||[X <: Boolean, Y <: Boolean] <: Boolean
8+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package scala.compiletime
2+
3+
package object int {
4+
type +[X <: Int, Y <: Int] <: Int
5+
type -[X <: Int, Y <: Int] <: Int
6+
type *[X <: Int, Y <: Int] <: Int
7+
type /[X <: Int, Y <: Int] <: Int
8+
type %[X <: Int, Y <: Int] <: Int
9+
10+
type <[X <: Int, Y <: Int] <: Boolean
11+
type >[X <: Int, Y <: Int] <: Boolean
12+
type >=[X <: Int, Y <: Int] <: Boolean
13+
type <=[X <: Int, Y <: Int] <: Boolean
14+
15+
type Abs[X <: Int] <: Int
16+
type Negate[X <: Int] <: Int
17+
type Min[X <: Int, Y <: Int] <: Int
18+
type Max[X <: Int, Y <: Int] <: Int
19+
}

library/src/scala/compiletime/package.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,7 @@ package object compiletime {
6363
* }
6464
*/
6565
type S[N <: Int] <: Int
66+
67+
type ==[X <: AnyVal, Y <: AnyVal] <: Boolean
68+
type !=[X <: AnyVal, Y <: AnyVal] <: Boolean
6669
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import scala.compiletime._
2+
import scala.compiletime.int._
3+
import scala.compiletime.boolean._
4+
5+
object Test {
6+
val t0: 2 + 3 = 5
7+
val t1: 2 + 2 = 5 // error
8+
val t2: -1 + 1 = 0
9+
val t3: -5 + -5 = -11 // error
10+
11+
val t4: 10 * 20 = 200
12+
val t5: 30 * 10 = 400 // error
13+
val t6: -10 * 2 = -20
14+
val t7: -2 * -2 = 4
15+
16+
val t8: 10 / 2 = 5
17+
val t9: 11 / 2 = 5 // Integer division
18+
val t10: 2 / 4 = 2 // error
19+
val t11: -1 / -1 = 1
20+
21+
val t12: 10 % 3 = 1
22+
val t13: 12 % 2 = 1 // error
23+
val t14: 1 % -3 = 1
24+
val t15: -3 % -2 = 0 // error
25+
26+
val t16: 1 < 0 = false
27+
val t17: 0 < 1 = true
28+
val t18: 10 < 5 = true // error
29+
val t19: 5 < 10 = false // error
30+
31+
val t20: 1 <= 0 = false
32+
val t21: 1 <= 1 = true
33+
val t22: 10 <= 5 = true // error
34+
val t23: 5 <= 10 = false // error
35+
36+
val t24: 1 > 0 = true
37+
val t25: 0 > 1 = false
38+
val t26: 10 > 5 = false // error
39+
val t27: 5 > 10 = true // error
40+
41+
val t28: 1 >= 1 = true
42+
val t29: 0 >= 1 = false
43+
val t30: 10 >= 5 = false // error
44+
val t31: 5 >= 10 = true // error
45+
46+
val t32: 1 == 1 = true
47+
val t33: 0 == 1 = false
48+
val t34: 10 == 5 = true // error
49+
val t35: 10 == 10 = false // error
50+
51+
val t36: 1 != 1 = false
52+
val t37: 0 != 1 = true
53+
val t38: 10 != 5 = false // error
54+
val t39: 10 != 10 = true // error
55+
56+
val t40: Abs[0] = 0
57+
val t41: Abs[-1] = 1
58+
val t42: Abs[-1] = -1 // error
59+
val t43: Abs[1] = -1 // error
60+
61+
val t44: Negate[-10] = 10
62+
val t45: Negate[10] = -10
63+
val t46: Negate[1] = 1 // error
64+
val t47: Negate[-1] = -1 // error
65+
66+
val t48: Max[-1, 10] = 10
67+
val t49: Max[4, 2] = 4
68+
val t50: Max[2, 2] = 1 // error
69+
val t51: Max[-1, -1] = 0 // error
70+
71+
val t52: Min[-1, 10] = -1
72+
val t53: Min[4, 2] = 2
73+
val t54: Min[2, 2] = 1 // error
74+
val t55: Min[-1, -1] = 0 // error
75+
76+
val t56: true && true = true
77+
val t57: true && false = false
78+
val t58: false && true = true // error
79+
val t59: false && false = true // error
80+
81+
val t60: true || true = true
82+
val t61: true || false = true
83+
val t62: false || true = false // error
84+
val t63: false || false = true // error
85+
86+
val t64: ![true] = false
87+
val t65: ![false] = true
88+
val t66: ![true] = true // error
89+
val t67: ![false] = false // error
90+
91+
// Test singleton ops in type alias:
92+
type Xor[A <: Boolean, B <: Boolean] = (A && ![B]) || (![A] && B)
93+
val t68: Xor[true, true] = false
94+
val t69: Xor[false, true] = true
95+
val t70: Xor[true, false] = false // error
96+
val t71: Xor[false, false] = true // error
97+
98+
// Test singleton ops in recursive match types:
99+
type GCD[A <: Int, B <: Int] <: Int = B match {
100+
case 0 => A
101+
case _ => GCD[B, A % B]
102+
}
103+
val t72: GCD[10, 0] = 10
104+
val t73: GCD[252, 105] = 21
105+
val t74: GCD[105, 147] = 10 // error
106+
val t75: GCD[1, 1] = -1 // error
107+
108+
// Test singleton ops in match type scrutinee:
109+
type Max2[A <: Int, B <: Int] <: Int = (A < B) match {
110+
case true => B
111+
case false => A
112+
}
113+
val t76: Max[-1, 10] = 10
114+
val t77: Max[4, 2] = 4
115+
val t78: Max[2, 2] = 1 // error
116+
val t79: Max[-1, -1] = 0 // error
117+
}

0 commit comments

Comments
 (0)