diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala index f7594f041204..39b9b7a65bd8 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala @@ -1,54 +1,130 @@ package dotty.tools.dotc package transform.localopt +import dotty.tools.dotc.ast.desugar.TrailingForMap import dotty.tools.dotc.ast.tpd.* -import dotty.tools.dotc.core.Decorators.* import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Decorators.* +import dotty.tools.dotc.core.Flags.* import dotty.tools.dotc.core.StdNames.* import dotty.tools.dotc.core.Symbols.* import dotty.tools.dotc.core.Types.* import dotty.tools.dotc.transform.MegaPhase.MiniPhase -import dotty.tools.dotc.ast.desugar /** Drop unused trailing map calls in for comprehensions. - * We can drop the map call if: - * - it won't change the type of the expression, and - * - the function is an identity function or a const function to unit. - * - * The latter condition is checked in [[Desugar.scala#makeFor]] - */ + * + * We can drop the map call if: + * - it won't change the type of the expression, and + * - the function is an identity function or a const function to unit. + * + * The latter condition is checked in [[Desugar.scala#makeFor]] + */ class DropForMap extends MiniPhase: - import DropForMap.* override def phaseName: String = DropForMap.name override def description: String = DropForMap.description - override def transformApply(tree: Apply)(using Context): Tree = - if !tree.hasAttachment(desugar.TrailingForMap) then tree - else tree match - case aply @ Apply(MapCall(f), List(Lambda(List(param), body))) - if f.tpe =:= aply.tpe => // make sure that the type of the expression won't change + import DropForMap.{Converted, Unmapped} + + /** r.map(x => x)(using y) --> r + * ^ TrailingForMap + */ + override def transformApply(tree: Apply)(using Context): Tree = tree match + case Unmapped(f0, sym, args) => + val f = + if sym.is(Extension) then args.head + else f0 + if f.tpe.widen =:= tree.tpe then // make sure that the type of the expression won't change f // drop the map call - case _ => - tree.removeAttachment(desugar.TrailingForMap) - tree + else + f match + case Converted(r) if r.tpe =:= tree.tpe => r // drop the map call and the conversion + case _ => tree + case tree => tree - private object Lambda: - def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] = - tree match - case Block(List(defdef: DefDef), Closure(Nil, ref, _)) - if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) => - Some((defdef.termParamss.flatten, defdef.rhs)) + /** If the map was inlined, fetch the binding for the receiver, + * then find the tree in the expansion that refers to the binding. + * That is the expansion of the result Inlined node. + */ + override def transformInlined(tree: Inlined)(using Context): Tree = tree match + case Inlined(call, bindings, expansion) if call.hasAttachment(TrailingForMap) => + val expansion1 = + call match + case Unmapped(f0, sym, args) => + val f = + if sym.is(Extension) then args.head + else f0 + if f.tpe.widen =:= expansion.tpe then + bindings.collectFirst: + case vd: ValDef if f.sameTree(vd.rhs) => + expansion.find: + case Inlined(Thicket(Nil), Nil, Ident(ident)) => ident == vd.name + case _ => false + .getOrElse(expansion) + .getOrElse(expansion) + else + f match + case Converted(r) if r.tpe =:= expansion.tpe => r // drop the map call and the conversion + case _ => expansion + case _ => expansion + if expansion1 ne expansion then + cpy.Inlined(tree)(call, bindings, expansion1) + else tree + case tree => tree + +object DropForMap: + val name: String = "dropForMap" + val description: String = "Drop unused trailing map calls in for comprehensions" + + // Extracts a fun from a possibly nested Apply with lambda and arbitrary implicit args. + // Specifically, an application `r.map(x => x)` is destructured into (r, map, args). + // If the receiver r was adapted, it is unwrapped. + // If `map` is an extension method, the nominal receiver is `args.head`. + private object Unmapped: + private def loop(tree: Tree, args: List[Tree])(using Context): Option[(Tree, Symbol, List[Tree])] = tree match + case Apply(fun, args @ Lambda(_ :: Nil, _) :: Nil) => + tree.removeAttachment(TrailingForMap) match + case Some(_) => + fun match + case MapCall(f, sym, args) => Some((f, sym, args)) + case _ => None + case _ => None + case Apply(fun, _) => + fun.tpe match + case mt: MethodType if mt.isImplicitMethod => loop(fun, args) case _ => None + case TypeApply(fun, _) => loop(fun, args) + case _ => None + end loop + def unapply(tree: Apply)(using Context): Option[(Tree, Symbol, List[Tree])] = + tree.tpe match + case _: MethodOrPoly => None + case _ => loop(tree, args = Nil) + + private object Lambda: + def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] = tree match + case Block(List(defdef: DefDef), Closure(Nil, ref, _)) + if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) => + Some((defdef.termParamss.flatten, defdef.rhs)) + case _ => None private object MapCall: + def unapply(tree: Tree)(using Context): Option[(Tree, Symbol, List[Tree])] = + def loop(tree: Tree, args: List[Tree]): Option[(Tree, Symbol, List[Tree])] = + tree match + case Ident(nme.map) if tree.symbol.is(Extension) => Some((EmptyTree, tree.symbol, args)) + case Select(f, nme.map) => Some((f, tree.symbol, args)) + case Apply(fn, args) => loop(fn, args) + case TypeApply(fn, _) => loop(fn, args) + case _ => None + loop(tree, Nil) + + private object Converted: def unapply(tree: Tree)(using Context): Option[Tree] = tree match - case Select(f, nme.map) => Some(f) - case Apply(fn, _) => unapply(fn) + case Apply(fn @ Apply(_, _), _) => unapply(fn) + case Apply(fn, r :: Nil) + if fn.symbol.is(Implicit) || fn.symbol.name == nme.apply && fn.symbol.owner.derivesFrom(defn.ConversionClass) + => Some(r) case TypeApply(fn, _) => unapply(fn) case _ => None - -object DropForMap: - val name: String = "dropForMap" - val description: String = "Drop unused trailing map calls in for comprehensions" diff --git a/tests/debug/eval-in-for-comprehension.check b/tests/debug/eval-in-for-comprehension.check index 6e91c891ebdb..fea2a1c261b2 100644 --- a/tests/debug/eval-in-for-comprehension.check +++ b/tests/debug/eval-in-for-comprehension.check @@ -16,12 +16,7 @@ break Test$ 11 // in main$$anonfun$2 eval x result 1 -break Test$ 13 // in main -eval list(0) -result 1 -break Test$ 13 // in main$$anonfun$4 - break Test$ 14 // in main eval list(0) result 1 -break Test$ 14 // in main$$anonfun$5 +break Test$ 14 // in main$$anonfun$4 diff --git a/tests/run/better-fors-map-elim.check b/tests/run/better-fors-map-elim.check deleted file mode 100644 index 0ef3447a47c4..000000000000 --- a/tests/run/better-fors-map-elim.check +++ /dev/null @@ -1,4 +0,0 @@ -MySome(()) -MySome(2) -MySome((2,3)) -MySome((2,(3,4))) diff --git a/tests/run/better-fors-map-elim.scala b/tests/run/better-fors-map-elim.scala index 390ad8ce5b50..bdeb087258bd 100644 --- a/tests/run/better-fors-map-elim.scala +++ b/tests/run/better-fors-map-elim.scala @@ -1,62 +1,44 @@ -class myOptionModule(doOnMap: => Unit) { - sealed trait MyOption[+A] { - def map[B](f: A => B): MyOption[B] = this match { - case MySome(x) => { - doOnMap - MySome(f(x)) - } - case MyNone => MyNone - } - def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match { - case MySome(x) => f(x) - case MyNone => MyNone - } - } - case class MySome[A](x: A) extends MyOption[A] - case object MyNone extends MyOption[Nothing] - object MyOption { - def apply[A](x: A): MyOption[A] = MySome(x) - } -} - -object Test extends App { - - val myOption = new myOptionModule(println("map called")) - - import myOption.* - - def portablePrintMyOption(opt: MyOption[Any]): Unit = - if opt == MySome(()) then - println("MySome(())") - else - println(opt) - - val z = for { - a <- MyOption(1) - b <- MyOption(()) - } yield () - - portablePrintMyOption(z) - - val z2 = for { - a <- MyOption(1) - b <- MyOption(2) - } yield b - - portablePrintMyOption(z2) - - val z3 = for { - a <- MyOption(1) - (b, c) <- MyOption((2, 3)) - } yield (b, c) - - portablePrintMyOption(z3) - - val z4 = for { - a <- MyOption(1) - (b, (c, d)) <- MyOption((2, (3, 4))) - } yield (b, (c, d)) - - portablePrintMyOption(z4) - -} +enum MyOption[+A]: + case MySome(x: A) + case MyNone + + def map[B](f: A => B): MyOption[B] = + this match + case MySome(x) => ??? //MySome(f(x)) + case MyNone => ??? //MyNone + def flatMap[B](f: A => MyOption[B]): MyOption[B] = + this match + case MySome(x) => f(x) + case MyNone => MyNone +object MyOption: + def apply[A](x: A): MyOption[A] = MySome(x) + +@main def Test = + + val _ = + for + a <- MyOption(1) + b <- MyOption(()) + yield () + + val _ = + for + a <- MyOption(1) + b <- MyOption(2) + yield b + + val _ = + for + a <- MyOption(1) + (b, c) <- MyOption((2, 3)) + yield (b, c) + + val _ = + for + a <- MyOption(1) + (b, (c, d)) <- MyOption((2, (3, 4))) + yield (b, (c, d)) + + extension (i: Int) def map[A](f: Int => A): A = ??? + + val _ = for j <- 42 yield j diff --git a/tests/run/better-fors-map-inlined.scala b/tests/run/better-fors-map-inlined.scala new file mode 100644 index 000000000000..2fec1f446ae9 --- /dev/null +++ b/tests/run/better-fors-map-inlined.scala @@ -0,0 +1,40 @@ +enum MyOption[+A]: + case MySome(x: A) + case MyNone + + inline def map[B](f: A => B): MyOption[B] = + this match + case MySome(x) => ??? //MySome(f(x)) + case MyNone => ??? //MyNone + def flatMap[B](f: A => MyOption[B]): MyOption[B] = + this match + case MySome(x) => f(x) + case MyNone => MyNone +object MyOption: + def apply[A](x: A): MyOption[A] = MySome(x) + +@main def Test = + + val _ = + for + a <- MyOption(1) + b <- MyOption(()) + yield () + + val _ = + for + a <- MyOption(1) + b <- MyOption(2) + yield b + + val _ = + for + a <- MyOption(1) + (b, c) <- MyOption((2, 3)) + yield (b, c) + + val _ = + for + a <- MyOption(1) + (b, (c, d)) <- MyOption((2, (3, 4))) + yield (b, (c, d)) diff --git a/tests/run/i23409.scala b/tests/run/i23409.scala new file mode 100644 index 000000000000..ed86216582d4 --- /dev/null +++ b/tests/run/i23409.scala @@ -0,0 +1,52 @@ + +// dropForMap should be aware of conversions to receiver + +import language.implicitConversions + +trait Func[F[_]]: + def map[A, B](fa: F[A])(f: A => B): F[B] + +object Func: + trait Ops[F[_], A]: + type T <: Func[F] + def t: T + def fa: F[A] + def map[B](f: A => B): F[B] = t.map[A, B](fa)(f) + + object OldStyle: + implicit def cv[F[_], A](fa0: F[A])(using Func[F]): Ops[F, A] { type T = Func[F] } = + new Ops[F, A]: + type T = Func[F] + def t: T = summon[Func[F]] + def fa = fa0 + + object NewStyle: + given [F[_], A] => Func[F] => Conversion[F[A], Ops[F, A] { type T = Func[F] }]: + def apply(fa0: F[A]): Ops[F, A] { type T = Func[F] } = + new Ops[F, A]: + type T = Func[F] + def t: T = summon[Func[F]] + def fa = fa0 +end Func + +def works = + for i <- List(42) yield i + +class C[A] +object C: + given Func[C]: + def map[A, B](fa: C[A])(f: A => B): C[B] = ??? // must be elided + +def implicitlyConverted() = println: + import Func.OldStyle.given + //C().map(x => x) --> C() + for x <- C() yield x + +def usingConversion() = println: + import Func.NewStyle.given + //C().map(x => x) --> C() + for x <- C() yield x + +@main def Test = + implicitlyConverted() + usingConversion() diff --git a/tests/run/i23409b.scala b/tests/run/i23409b.scala new file mode 100644 index 000000000000..e064b2ab6efa --- /dev/null +++ b/tests/run/i23409b.scala @@ -0,0 +1,25 @@ + +final class Implicit() + +final class Id[+A, -U](val value: A): + def map[B](f: A => B)(using Implicit): Id[B, U] = ??? //Id(f(value)) + def flatMap[B, V <: U](f: A => Id[B, V]): Id[B, V] = f(value) + def run: A = value + +type Foo = Foo.type +case object Foo: + def get: Id[Int, Foo] = Id(42) + +type Bar = Bar.type +case object Bar: + def inc(i: Int): Id[Int, Bar] = Id(i * 10) + +def program(using Implicit) = + for + a <- Foo.get + x <- Bar.inc(a) + yield x + +@main def Test = println: + given Implicit = Implicit() + program.run