Skip to content

Commit 02fa81e

Browse files
committed
DropForMap conversion, extension, inline, args
1 parent d7f3a80 commit 02fa81e

File tree

6 files changed

+270
-95
lines changed

6 files changed

+270
-95
lines changed

compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala

Lines changed: 105 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,130 @@
11
package dotty.tools.dotc
22
package transform.localopt
33

4+
import dotty.tools.dotc.ast.desugar.TrailingForMap
45
import dotty.tools.dotc.ast.tpd.*
5-
import dotty.tools.dotc.core.Decorators.*
66
import dotty.tools.dotc.core.Contexts.*
7+
import dotty.tools.dotc.core.Decorators.*
8+
import dotty.tools.dotc.core.Flags.*
79
import dotty.tools.dotc.core.StdNames.*
810
import dotty.tools.dotc.core.Symbols.*
911
import dotty.tools.dotc.core.Types.*
1012
import dotty.tools.dotc.transform.MegaPhase.MiniPhase
11-
import dotty.tools.dotc.ast.desugar
1213

1314
/** Drop unused trailing map calls in for comprehensions.
14-
* We can drop the map call if:
15-
* - it won't change the type of the expression, and
16-
* - the function is an identity function or a const function to unit.
17-
*
18-
* The latter condition is checked in [[Desugar.scala#makeFor]]
19-
*/
15+
*
16+
* We can drop the map call if:
17+
* - it won't change the type of the expression, and
18+
* - the function is an identity function or a const function to unit.
19+
*
20+
* The latter condition is checked in [[Desugar.scala#makeFor]]
21+
*/
2022
class DropForMap extends MiniPhase:
21-
import DropForMap.*
2223

2324
override def phaseName: String = DropForMap.name
2425

2526
override def description: String = DropForMap.description
2627

27-
override def transformApply(tree: Apply)(using Context): Tree =
28-
if !tree.hasAttachment(desugar.TrailingForMap) then tree
29-
else tree match
30-
case aply @ Apply(MapCall(f), List(Lambda(List(param), body)))
31-
if f.tpe =:= aply.tpe => // make sure that the type of the expression won't change
28+
import DropForMap.{Converted, Unmapped}
29+
30+
/** r.map(x => x)(using y) --> r
31+
* ^ TrailingForMap
32+
*/
33+
override def transformApply(tree: Apply)(using Context): Tree = tree match
34+
case Unmapped(f0, sym, args) =>
35+
val f =
36+
if sym.is(Extension) then args.head
37+
else f0
38+
if f.tpe.widen =:= tree.tpe then // make sure that the type of the expression won't change
3239
f // drop the map call
33-
case _ =>
34-
tree.removeAttachment(desugar.TrailingForMap)
35-
tree
40+
else
41+
f match
42+
case Converted(r) if r.tpe =:= tree.tpe => r // drop the map call and the conversion
43+
case _ => tree
44+
case tree => tree
3645

37-
private object Lambda:
38-
def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] =
39-
tree match
40-
case Block(List(defdef: DefDef), Closure(Nil, ref, _))
41-
if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) =>
42-
Some((defdef.termParamss.flatten, defdef.rhs))
46+
/** If the map was inlined, fetch the binding for the receiver,
47+
* then find the tree in the expansion that refers to the binding.
48+
* That is the expansion of the result Inlined node.
49+
*/
50+
override def transformInlined(tree: Inlined)(using Context): Tree = tree match
51+
case Inlined(call, bindings, expansion) if call.hasAttachment(TrailingForMap) =>
52+
val expansion1 =
53+
call match
54+
case Unmapped(f0, sym, args) =>
55+
val f =
56+
if sym.is(Extension) then args.head
57+
else f0
58+
if f.tpe.widen =:= expansion.tpe then
59+
bindings.collectFirst:
60+
case vd: ValDef if f.sameTree(vd.rhs) =>
61+
expansion.find:
62+
case Inlined(Thicket(Nil), Nil, Ident(ident)) => ident == vd.name
63+
case _ => false
64+
.getOrElse(expansion)
65+
.getOrElse(expansion)
66+
else
67+
f match
68+
case Converted(r) if r.tpe =:= expansion.tpe => r // drop the map call and the conversion
69+
case _ => expansion
70+
case _ => expansion
71+
if expansion1 ne expansion then
72+
cpy.Inlined(tree)(call, bindings, expansion1)
73+
else tree
74+
case tree => tree
75+
76+
object DropForMap:
77+
val name: String = "dropForMap"
78+
val description: String = "Drop unused trailing map calls in for comprehensions"
79+
80+
// Extracts a fun from a possibly nested Apply with lambda and arbitrary implicit args.
81+
// Specifically, an application `r.map(x => x)` is destructured into (r, map, args).
82+
// If the receiver r was adapted, it is unwrapped.
83+
// If `map` is an extension method, the nominal receiver is `args.head`.
84+
private object Unmapped:
85+
private def loop(tree: Tree, args: List[Tree])(using Context): Option[(Tree, Symbol, List[Tree])] = tree match
86+
case Apply(fun, args @ Lambda(_ :: Nil, _) :: Nil) =>
87+
tree.removeAttachment(TrailingForMap) match
88+
case Some(_) =>
89+
fun match
90+
case MapCall(f, sym, args) => Some((f, sym, args))
91+
case _ => None
92+
case _ => None
93+
case Apply(fun, _) =>
94+
fun.tpe match
95+
case mt: MethodType if mt.isImplicitMethod => loop(fun, args)
4396
case _ => None
97+
case TypeApply(fun, _) => loop(fun, args)
98+
case _ => None
99+
end loop
100+
def unapply(tree: Apply)(using Context): Option[(Tree, Symbol, List[Tree])] =
101+
tree.tpe match
102+
case _: MethodOrPoly => None
103+
case _ => loop(tree, args = Nil)
104+
105+
private object Lambda:
106+
def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] = tree match
107+
case Block(List(defdef: DefDef), Closure(Nil, ref, _))
108+
if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) =>
109+
Some((defdef.termParamss.flatten, defdef.rhs))
110+
case _ => None
44111

45112
private object MapCall:
113+
def unapply(tree: Tree)(using Context): Option[(Tree, Symbol, List[Tree])] =
114+
def loop(tree: Tree, args: List[Tree]): Option[(Tree, Symbol, List[Tree])] =
115+
tree match
116+
case Ident(nme.map) if tree.symbol.is(Extension) => Some((EmptyTree, tree.symbol, args))
117+
case Select(f, nme.map) => Some((f, tree.symbol, args))
118+
case Apply(fn, args) => loop(fn, args)
119+
case TypeApply(fn, _) => loop(fn, args)
120+
case _ => None
121+
loop(tree, Nil)
122+
123+
private object Converted:
46124
def unapply(tree: Tree)(using Context): Option[Tree] = tree match
47-
case Select(f, nme.map) => Some(f)
48-
case Apply(fn, _) => unapply(fn)
125+
case Apply(fn @ Apply(_, _), _) => unapply(fn)
126+
case Apply(fn, r :: Nil)
127+
if fn.symbol.is(Implicit) || fn.symbol.name == nme.apply && fn.symbol.owner.derivesFrom(defn.ConversionClass)
128+
=> Some(r)
49129
case TypeApply(fn, _) => unapply(fn)
50130
case _ => None
51-
52-
object DropForMap:
53-
val name: String = "dropForMap"
54-
val description: String = "Drop unused trailing map calls in for comprehensions"

tests/run/better-fors-map-elim.check

Lines changed: 0 additions & 4 deletions
This file was deleted.

tests/run/better-fors-map-elim.scala

Lines changed: 44 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,47 @@
11
//> using options -preview
22
// import scala.language.experimental.betterFors
33

4-
class myOptionModule(doOnMap: => Unit) {
5-
sealed trait MyOption[+A] {
6-
def map[B](f: A => B): MyOption[B] = this match {
7-
case MySome(x) => {
8-
doOnMap
9-
MySome(f(x))
10-
}
11-
case MyNone => MyNone
12-
}
13-
def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match {
14-
case MySome(x) => f(x)
15-
case MyNone => MyNone
16-
}
17-
}
18-
case class MySome[A](x: A) extends MyOption[A]
19-
case object MyNone extends MyOption[Nothing]
20-
object MyOption {
21-
def apply[A](x: A): MyOption[A] = MySome(x)
22-
}
23-
}
24-
25-
object Test extends App {
26-
27-
val myOption = new myOptionModule(println("map called"))
28-
29-
import myOption.*
30-
31-
def portablePrintMyOption(opt: MyOption[Any]): Unit =
32-
if opt == MySome(()) then
33-
println("MySome(())")
34-
else
35-
println(opt)
36-
37-
val z = for {
38-
a <- MyOption(1)
39-
b <- MyOption(())
40-
} yield ()
41-
42-
portablePrintMyOption(z)
43-
44-
val z2 = for {
45-
a <- MyOption(1)
46-
b <- MyOption(2)
47-
} yield b
48-
49-
portablePrintMyOption(z2)
50-
51-
val z3 = for {
52-
a <- MyOption(1)
53-
(b, c) <- MyOption((2, 3))
54-
} yield (b, c)
55-
56-
portablePrintMyOption(z3)
57-
58-
val z4 = for {
59-
a <- MyOption(1)
60-
(b, (c, d)) <- MyOption((2, (3, 4)))
61-
} yield (b, (c, d))
62-
63-
portablePrintMyOption(z4)
64-
65-
}
4+
enum MyOption[+A]:
5+
case MySome(x: A)
6+
case MyNone
7+
8+
def map[B](f: A => B): MyOption[B] =
9+
this match
10+
case MySome(x) => ??? //MySome(f(x))
11+
case MyNone => ??? //MyNone
12+
def flatMap[B](f: A => MyOption[B]): MyOption[B] =
13+
this match
14+
case MySome(x) => f(x)
15+
case MyNone => MyNone
16+
object MyOption:
17+
def apply[A](x: A): MyOption[A] = MySome(x)
18+
19+
@main def Test =
20+
21+
val _ =
22+
for
23+
a <- MyOption(1)
24+
b <- MyOption(())
25+
yield ()
26+
27+
val _ =
28+
for
29+
a <- MyOption(1)
30+
b <- MyOption(2)
31+
yield b
32+
33+
val _ =
34+
for
35+
a <- MyOption(1)
36+
(b, c) <- MyOption((2, 3))
37+
yield (b, c)
38+
39+
val _ =
40+
for
41+
a <- MyOption(1)
42+
(b, (c, d)) <- MyOption((2, (3, 4)))
43+
yield (b, (c, d))
44+
45+
extension (i: Int) def map[A](f: Int => A): A = ???
46+
47+
val _ = for j <- 42 yield j
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//> using options -preview
2+
3+
enum MyOption[+A]:
4+
case MySome(x: A)
5+
case MyNone
6+
7+
inline def map[B](f: A => B): MyOption[B] =
8+
this match
9+
case MySome(x) => ??? //MySome(f(x))
10+
case MyNone => ??? //MyNone
11+
def flatMap[B](f: A => MyOption[B]): MyOption[B] =
12+
this match
13+
case MySome(x) => f(x)
14+
case MyNone => MyNone
15+
object MyOption:
16+
def apply[A](x: A): MyOption[A] = MySome(x)
17+
18+
@main def Test =
19+
20+
val _ =
21+
for
22+
a <- MyOption(1)
23+
b <- MyOption(())
24+
yield ()
25+
26+
val _ =
27+
for
28+
a <- MyOption(1)
29+
b <- MyOption(2)
30+
yield b
31+
32+
val _ =
33+
for
34+
a <- MyOption(1)
35+
(b, c) <- MyOption((2, 3))
36+
yield (b, c)
37+
38+
val _ =
39+
for
40+
a <- MyOption(1)
41+
(b, (c, d)) <- MyOption((2, (3, 4)))
42+
yield (b, (c, d))

tests/run/i23409.scala

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//> using options -preview
2+
3+
// dropForMap should be aware of conversions to receiver
4+
5+
import language.implicitConversions
6+
7+
trait Func[F[_]]:
8+
def map[A, B](fa: F[A])(f: A => B): F[B]
9+
10+
object Func:
11+
trait Ops[F[_], A]:
12+
type T <: Func[F]
13+
def t: T
14+
def fa: F[A]
15+
def map[B](f: A => B): F[B] = t.map[A, B](fa)(f)
16+
17+
object OldStyle:
18+
implicit def cv[F[_], A](fa0: F[A])(using Func[F]): Ops[F, A] { type T = Func[F] } =
19+
new Ops[F, A]:
20+
type T = Func[F]
21+
def t: T = summon[Func[F]]
22+
def fa = fa0
23+
24+
object NewStyle:
25+
given [F[_], A] => Func[F] => Conversion[F[A], Ops[F, A] { type T = Func[F] }]:
26+
def apply(fa0: F[A]): Ops[F, A] { type T = Func[F] } =
27+
new Ops[F, A]:
28+
type T = Func[F]
29+
def t: T = summon[Func[F]]
30+
def fa = fa0
31+
end Func
32+
33+
def works =
34+
for i <- List(42) yield i
35+
36+
class C[A]
37+
object C:
38+
given Func[C]:
39+
def map[A, B](fa: C[A])(f: A => B): C[B] = ??? // must be elided
40+
41+
def implicitlyConverted() = println:
42+
import Func.OldStyle.given
43+
//C().map(x => x) --> C()
44+
for x <- C() yield x
45+
46+
def usingConversion() = println:
47+
import Func.NewStyle.given
48+
//C().map(x => x) --> C()
49+
for x <- C() yield x
50+
51+
@main def Test =
52+
implicitlyConverted()
53+
usingConversion()

tests/run/i23409b.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//> using options -preview
2+
3+
final class Implicit()
4+
5+
final class Id[+A, -U](val value: A):
6+
def map[B](f: A => B)(using Implicit): Id[B, U] = ??? //Id(f(value))
7+
def flatMap[B, V <: U](f: A => Id[B, V]): Id[B, V] = f(value)
8+
def run: A = value
9+
10+
type Foo = Foo.type
11+
case object Foo:
12+
def get: Id[Int, Foo] = Id(42)
13+
14+
type Bar = Bar.type
15+
case object Bar:
16+
def inc(i: Int): Id[Int, Bar] = Id(i * 10)
17+
18+
def program(using Implicit) =
19+
for
20+
a <- Foo.get
21+
x <- Bar.inc(a)
22+
yield x
23+
24+
@main def Test = println:
25+
given Implicit = Implicit()
26+
program.run

0 commit comments

Comments
 (0)