Skip to content

Commit c18243e

Browse files
committed
DropForMap conversion, extension, inline, args
1 parent 207604b commit c18243e

File tree

6 files changed

+270
-95
lines changed

6 files changed

+270
-95
lines changed

compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala

Lines changed: 105 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,130 @@
11
package dotty.tools.dotc
22
package transform.localopt
33

4+
import dotty.tools.dotc.ast.desugar.TrailingForMap
45
import dotty.tools.dotc.ast.tpd.*
5-
import dotty.tools.dotc.core.Decorators.*
66
import dotty.tools.dotc.core.Contexts.*
7+
import dotty.tools.dotc.core.Decorators.*
8+
import dotty.tools.dotc.core.Flags.*
79
import dotty.tools.dotc.core.StdNames.*
810
import dotty.tools.dotc.core.Symbols.*
911
import dotty.tools.dotc.core.Types.*
1012
import dotty.tools.dotc.transform.MegaPhase.MiniPhase
11-
import dotty.tools.dotc.ast.desugar
1213

1314
/** Drop unused trailing map calls in for comprehensions.
14-
* We can drop the map call if:
15-
* - it won't change the type of the expression, and
16-
* - the function is an identity function or a const function to unit.
17-
*
18-
* The latter condition is checked in [[Desugar.scala#makeFor]]
19-
*/
15+
*
16+
* We can drop the map call if:
17+
* - it won't change the type of the expression, and
18+
* - the function is an identity function or a const function to unit.
19+
*
20+
* The latter condition is checked in [[Desugar.scala#makeFor]]
21+
*/
2022
class DropForMap extends MiniPhase:
21-
import DropForMap.*
2223

2324
override def phaseName: String = DropForMap.name
2425

2526
override def description: String = DropForMap.description
2627

27-
override def transformApply(tree: Apply)(using Context): Tree =
28-
if !tree.hasAttachment(desugar.TrailingForMap) then tree
29-
else tree match
30-
case aply @ Apply(MapCall(f), List(Lambda(List(param), body)))
31-
if f.tpe =:= aply.tpe => // make sure that the type of the expression won't change
28+
import DropForMap.{Converted, Unmapped}
29+
30+
/** r.map(x => x)(using y) --> r
31+
* ^ TrailingForMap
32+
*/
33+
override def transformApply(tree: Apply)(using Context): Tree = tree match
34+
case Unmapped(f0, sym, args) =>
35+
val f =
36+
if sym.is(Extension) then args.head
37+
else f0
38+
if f.tpe.widen =:= tree.tpe then // make sure that the type of the expression won't change
3239
f // drop the map call
33-
case _ =>
34-
tree.removeAttachment(desugar.TrailingForMap)
35-
tree
40+
else
41+
f match
42+
case Converted(r) if r.tpe =:= tree.tpe => r // drop the map call and the conversion
43+
case _ => tree
44+
case tree => tree
3645

37-
private object Lambda:
38-
def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] =
39-
tree match
40-
case Block(List(defdef: DefDef), Closure(Nil, ref, _))
41-
if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) =>
42-
Some((defdef.termParamss.flatten, defdef.rhs))
46+
/** If the map was inlined, fetch the binding for the receiver,
47+
* then find the tree in the expansion that refers to the binding.
48+
* That is the expansion of the result Inlined node.
49+
*/
50+
override def transformInlined(tree: Inlined)(using Context): Tree = tree match
51+
case Inlined(call, bindings, expansion) if call.hasAttachment(TrailingForMap) =>
52+
val expansion1 =
53+
call match
54+
case Unmapped(f0, sym, args) =>
55+
val f =
56+
if sym.is(Extension) then args.head
57+
else f0
58+
if f.tpe.widen =:= expansion.tpe then
59+
bindings.collectFirst:
60+
case vd: ValDef if f.sameTree(vd.rhs) =>
61+
expansion.find:
62+
case Inlined(Thicket(Nil), Nil, Ident(ident)) => ident == vd.name
63+
case _ => false
64+
.getOrElse(expansion)
65+
.getOrElse(expansion)
66+
else
67+
f match
68+
case Converted(r) if r.tpe =:= expansion.tpe => r // drop the map call and the conversion
69+
case _ => expansion
70+
case _ => expansion
71+
if expansion1 ne expansion then
72+
cpy.Inlined(tree)(call, bindings, expansion1)
73+
else tree
74+
case tree => tree
75+
76+
object DropForMap:
77+
val name: String = "dropForMap"
78+
val description: String = "Drop unused trailing map calls in for comprehensions"
79+
80+
// Extracts a fun from a possibly nested Apply with lambda and arbitrary implicit args.
81+
// Specifically, an application `r.map(x => x)` is destructured into (r, map, args).
82+
// If the receiver r was adapted, it is unwrapped.
83+
// If `map` is an extension method, the nominal receiver is `args.head`.
84+
private object Unmapped:
85+
private def loop(tree: Tree, args: List[Tree])(using Context): Option[(Tree, Symbol, List[Tree])] = tree match
86+
case Apply(fun, args @ Lambda(_ :: Nil, _) :: Nil) =>
87+
tree.removeAttachment(TrailingForMap) match
88+
case Some(_) =>
89+
fun match
90+
case MapCall(f, sym, args) => Some((f, sym, args))
91+
case _ => None
92+
case _ => None
93+
case Apply(fun, _) =>
94+
fun.tpe match
95+
case mt: MethodType if mt.isImplicitMethod => loop(fun, args)
4396
case _ => None
97+
case TypeApply(fun, _) => loop(fun, args)
98+
case _ => None
99+
end loop
100+
def unapply(tree: Apply)(using Context): Option[(Tree, Symbol, List[Tree])] =
101+
tree.tpe match
102+
case _: MethodOrPoly => None
103+
case _ => loop(tree, args = Nil)
104+
105+
private object Lambda:
106+
def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] = tree match
107+
case Block(List(defdef: DefDef), Closure(Nil, ref, _))
108+
if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) =>
109+
Some((defdef.termParamss.flatten, defdef.rhs))
110+
case _ => None
44111

45112
private object MapCall:
113+
def unapply(tree: Tree)(using Context): Option[(Tree, Symbol, List[Tree])] =
114+
def loop(tree: Tree, args: List[Tree]): Option[(Tree, Symbol, List[Tree])] =
115+
tree match
116+
case Ident(nme.map) if tree.symbol.is(Extension) => Some((EmptyTree, tree.symbol, args))
117+
case Select(f, nme.map) => Some((f, tree.symbol, args))
118+
case Apply(fn, args) => loop(fn, args)
119+
case TypeApply(fn, _) => loop(fn, args)
120+
case _ => None
121+
loop(tree, Nil)
122+
123+
private object Converted:
46124
def unapply(tree: Tree)(using Context): Option[Tree] = tree match
47-
case Select(f, nme.map) => Some(f)
48-
case Apply(fn, _) => unapply(fn)
125+
case Apply(fn @ Apply(_, _), _) => unapply(fn)
126+
case Apply(fn, r :: Nil)
127+
if fn.symbol.is(Implicit) || fn.symbol.name == nme.apply && fn.symbol.owner.derivesFrom(defn.ConversionClass)
128+
=> Some(r)
49129
case TypeApply(fn, _) => unapply(fn)
50130
case _ => None
51-
52-
object DropForMap:
53-
val name: String = "dropForMap"
54-
val description: String = "Drop unused trailing map calls in for comprehensions"

tests/run/better-fors-map-elim.check

Lines changed: 0 additions & 4 deletions
This file was deleted.

tests/run/better-fors-map-elim.scala

Lines changed: 44 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,44 @@
1-
class myOptionModule(doOnMap: => Unit) {
2-
sealed trait MyOption[+A] {
3-
def map[B](f: A => B): MyOption[B] = this match {
4-
case MySome(x) => {
5-
doOnMap
6-
MySome(f(x))
7-
}
8-
case MyNone => MyNone
9-
}
10-
def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match {
11-
case MySome(x) => f(x)
12-
case MyNone => MyNone
13-
}
14-
}
15-
case class MySome[A](x: A) extends MyOption[A]
16-
case object MyNone extends MyOption[Nothing]
17-
object MyOption {
18-
def apply[A](x: A): MyOption[A] = MySome(x)
19-
}
20-
}
21-
22-
object Test extends App {
23-
24-
val myOption = new myOptionModule(println("map called"))
25-
26-
import myOption.*
27-
28-
def portablePrintMyOption(opt: MyOption[Any]): Unit =
29-
if opt == MySome(()) then
30-
println("MySome(())")
31-
else
32-
println(opt)
33-
34-
val z = for {
35-
a <- MyOption(1)
36-
b <- MyOption(())
37-
} yield ()
38-
39-
portablePrintMyOption(z)
40-
41-
val z2 = for {
42-
a <- MyOption(1)
43-
b <- MyOption(2)
44-
} yield b
45-
46-
portablePrintMyOption(z2)
47-
48-
val z3 = for {
49-
a <- MyOption(1)
50-
(b, c) <- MyOption((2, 3))
51-
} yield (b, c)
52-
53-
portablePrintMyOption(z3)
54-
55-
val z4 = for {
56-
a <- MyOption(1)
57-
(b, (c, d)) <- MyOption((2, (3, 4)))
58-
} yield (b, (c, d))
59-
60-
portablePrintMyOption(z4)
61-
62-
}
1+
enum MyOption[+A]:
2+
case MySome(x: A)
3+
case MyNone
4+
5+
def map[B](f: A => B): MyOption[B] =
6+
this match
7+
case MySome(x) => ??? //MySome(f(x))
8+
case MyNone => ??? //MyNone
9+
def flatMap[B](f: A => MyOption[B]): MyOption[B] =
10+
this match
11+
case MySome(x) => f(x)
12+
case MyNone => MyNone
13+
object MyOption:
14+
def apply[A](x: A): MyOption[A] = MySome(x)
15+
16+
@main def Test =
17+
18+
val _ =
19+
for
20+
a <- MyOption(1)
21+
b <- MyOption(())
22+
yield ()
23+
24+
val _ =
25+
for
26+
a <- MyOption(1)
27+
b <- MyOption(2)
28+
yield b
29+
30+
val _ =
31+
for
32+
a <- MyOption(1)
33+
(b, c) <- MyOption((2, 3))
34+
yield (b, c)
35+
36+
val _ =
37+
for
38+
a <- MyOption(1)
39+
(b, (c, d)) <- MyOption((2, (3, 4)))
40+
yield (b, (c, d))
41+
42+
extension (i: Int) def map[A](f: Int => A): A = ???
43+
44+
val _ = for j <- 42 yield j
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//> using options -preview
2+
3+
enum MyOption[+A]:
4+
case MySome(x: A)
5+
case MyNone
6+
7+
inline def map[B](f: A => B): MyOption[B] =
8+
this match
9+
case MySome(x) => ??? //MySome(f(x))
10+
case MyNone => ??? //MyNone
11+
def flatMap[B](f: A => MyOption[B]): MyOption[B] =
12+
this match
13+
case MySome(x) => f(x)
14+
case MyNone => MyNone
15+
object MyOption:
16+
def apply[A](x: A): MyOption[A] = MySome(x)
17+
18+
@main def Test =
19+
20+
val _ =
21+
for
22+
a <- MyOption(1)
23+
b <- MyOption(())
24+
yield ()
25+
26+
val _ =
27+
for
28+
a <- MyOption(1)
29+
b <- MyOption(2)
30+
yield b
31+
32+
val _ =
33+
for
34+
a <- MyOption(1)
35+
(b, c) <- MyOption((2, 3))
36+
yield (b, c)
37+
38+
val _ =
39+
for
40+
a <- MyOption(1)
41+
(b, (c, d)) <- MyOption((2, (3, 4)))
42+
yield (b, (c, d))

tests/run/i23409.scala

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//> using options -preview
2+
3+
// dropForMap should be aware of conversions to receiver
4+
5+
import language.implicitConversions
6+
7+
trait Func[F[_]]:
8+
def map[A, B](fa: F[A])(f: A => B): F[B]
9+
10+
object Func:
11+
trait Ops[F[_], A]:
12+
type T <: Func[F]
13+
def t: T
14+
def fa: F[A]
15+
def map[B](f: A => B): F[B] = t.map[A, B](fa)(f)
16+
17+
object OldStyle:
18+
implicit def cv[F[_], A](fa0: F[A])(using Func[F]): Ops[F, A] { type T = Func[F] } =
19+
new Ops[F, A]:
20+
type T = Func[F]
21+
def t: T = summon[Func[F]]
22+
def fa = fa0
23+
24+
object NewStyle:
25+
given [F[_], A] => Func[F] => Conversion[F[A], Ops[F, A] { type T = Func[F] }]:
26+
def apply(fa0: F[A]): Ops[F, A] { type T = Func[F] } =
27+
new Ops[F, A]:
28+
type T = Func[F]
29+
def t: T = summon[Func[F]]
30+
def fa = fa0
31+
end Func
32+
33+
def works =
34+
for i <- List(42) yield i
35+
36+
class C[A]
37+
object C:
38+
given Func[C]:
39+
def map[A, B](fa: C[A])(f: A => B): C[B] = ??? // must be elided
40+
41+
def implicitlyConverted() = println:
42+
import Func.OldStyle.given
43+
//C().map(x => x) --> C()
44+
for x <- C() yield x
45+
46+
def usingConversion() = println:
47+
import Func.NewStyle.given
48+
//C().map(x => x) --> C()
49+
for x <- C() yield x
50+
51+
@main def Test =
52+
implicitlyConverted()
53+
usingConversion()

tests/run/i23409b.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//> using options -preview
2+
3+
final class Implicit()
4+
5+
final class Id[+A, -U](val value: A):
6+
def map[B](f: A => B)(using Implicit): Id[B, U] = ??? //Id(f(value))
7+
def flatMap[B, V <: U](f: A => Id[B, V]): Id[B, V] = f(value)
8+
def run: A = value
9+
10+
type Foo = Foo.type
11+
case object Foo:
12+
def get: Id[Int, Foo] = Id(42)
13+
14+
type Bar = Bar.type
15+
case object Bar:
16+
def inc(i: Int): Id[Int, Bar] = Id(i * 10)
17+
18+
def program(using Implicit) =
19+
for
20+
a <- Foo.get
21+
x <- Bar.inc(a)
22+
yield x
23+
24+
@main def Test = println:
25+
given Implicit = Implicit()
26+
program.run

0 commit comments

Comments
 (0)