From 240df1c471df7aa18554ea48e39a3df021e34eed Mon Sep 17 00:00:00 2001 From: Som Snytt Date: Tue, 24 Jun 2025 09:02:53 -0700 Subject: [PATCH 1/3] Elide conversion of receiver in DropForMap --- .../dotc/transform/localopt/DropForMap.scala | 32 ++++++++--- tests/run/i23409.scala | 54 +++++++++++++++++++ 2 files changed, 78 insertions(+), 8 deletions(-) create mode 100644 tests/run/i23409.scala diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala index f7594f041204..698cb84c137e 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala @@ -2,8 +2,9 @@ package dotty.tools.dotc package transform.localopt import dotty.tools.dotc.ast.tpd.* -import dotty.tools.dotc.core.Decorators.* import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Decorators.* +import dotty.tools.dotc.core.Flags.* import dotty.tools.dotc.core.StdNames.* import dotty.tools.dotc.core.Symbols.* import dotty.tools.dotc.core.Types.* @@ -25,14 +26,20 @@ class DropForMap extends MiniPhase: override def description: String = DropForMap.description override def transformApply(tree: Apply)(using Context): Tree = - if !tree.hasAttachment(desugar.TrailingForMap) then tree - else tree match - case aply @ Apply(MapCall(f), List(Lambda(List(param), body))) - if f.tpe =:= aply.tpe => // make sure that the type of the expression won't change - f // drop the map call + tree.removeAttachment(desugar.TrailingForMap) match + case Some(_) => + tree match + case aply @ Apply(MapCall(f), List(Lambda(List(param), body))) => + if f.tpe =:= aply.tpe then // make sure that the type of the expression won't change + return f // drop the map call + else + f match + case Converted(r) if r.tpe =:= aply.tpe => + return r // drop the map call and the conversion + case _ => case _ => - tree.removeAttachment(desugar.TrailingForMap) - tree + case _ => + tree private object Lambda: def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] = @@ -49,6 +56,15 @@ class DropForMap extends MiniPhase: case TypeApply(fn, _) => unapply(fn) case _ => None + private object Converted: + def unapply(tree: Tree)(using Context): Option[Tree] = tree match + case Apply(fn @ Apply(_, _), _) => unapply(fn) + case Apply(fn, r :: Nil) + if fn.symbol.is(Implicit) || fn.symbol.name == nme.apply && fn.symbol.owner.derivesFrom(defn.ConversionClass) + => Some(r) + case TypeApply(fn, _) => unapply(fn) + case _ => None + object DropForMap: val name: String = "dropForMap" val description: String = "Drop unused trailing map calls in for comprehensions" diff --git a/tests/run/i23409.scala b/tests/run/i23409.scala new file mode 100644 index 000000000000..0357f538e1ef --- /dev/null +++ b/tests/run/i23409.scala @@ -0,0 +1,54 @@ + +//> using options -preview + +// dropForMap should be aware of conversions to receiver + +import language.implicitConversions + +trait Func[F[_]]: + def map[A, B](fa: F[A])(f: A => B): F[B] + +object Func: + trait Ops[F[_], A]: + type T <: Func[F] + def t: T + def fa: F[A] + def map[B](f: A => B): F[B] = t.map[A, B](fa)(f) + + object OldStyle: + implicit def cv[F[_], A](fa0: F[A])(using Func[F]): Ops[F, A] { type T = Func[F] } = + new Ops[F, A]: + type T = Func[F] + def t: T = summon[Func[F]] + def fa = fa0 + + object NewStyle: + given [F[_], A] => Func[F] => Conversion[F[A], Ops[F, A] { type T = Func[F] }]: + def apply(fa0: F[A]): Ops[F, A] { type T = Func[F] } = + new Ops[F, A]: + type T = Func[F] + def t: T = summon[Func[F]] + def fa = fa0 +end Func + +def works = + for i <- List(42) yield i + +class C[A] +object C: + given Func[C]: + def map[A, B](fa: C[A])(f: A => B): C[B] = ??? // must be elided + +def implicitlyConverted() = println: + import Func.OldStyle.given + //C().map(x => x) --> C() + for x <- C() yield x + +def usingConversion() = println: + import Func.NewStyle.given + //C().map(x => x) --> C() + for x <- C() yield x + +@main def Test = + implicitlyConverted() + usingConversion() From 14974ed6a45681b374a0e0acc23a686e0ee0a1b7 Mon Sep 17 00:00:00 2001 From: Som Snytt Date: Fri, 27 Jun 2025 06:44:34 -0700 Subject: [PATCH 2/3] Drop trailing implicit args to dropped map --- .../dotc/transform/localopt/DropForMap.scala | 68 ++++++++++++------- tests/run/i23409b.scala | 26 +++++++ 2 files changed, 70 insertions(+), 24 deletions(-) create mode 100644 tests/run/i23409b.scala diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala index 698cb84c137e..c51a116673ff 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala @@ -12,12 +12,13 @@ import dotty.tools.dotc.transform.MegaPhase.MiniPhase import dotty.tools.dotc.ast.desugar /** Drop unused trailing map calls in for comprehensions. - * We can drop the map call if: - * - it won't change the type of the expression, and - * - the function is an identity function or a const function to unit. - * - * The latter condition is checked in [[Desugar.scala#makeFor]] - */ + * + * We can drop the map call if: + * - it won't change the type of the expression, and + * - the function is an identity function or a const function to unit. + * + * The latter condition is checked in [[Desugar.scala#makeFor]] + */ class DropForMap extends MiniPhase: import DropForMap.* @@ -25,29 +26,48 @@ class DropForMap extends MiniPhase: override def description: String = DropForMap.description - override def transformApply(tree: Apply)(using Context): Tree = - tree.removeAttachment(desugar.TrailingForMap) match - case Some(_) => + /** r.map(x => x)(using y) --> r + * ^ TrailingForMap + */ + override def transformApply(tree: Apply)(using Context): Tree = tree match + case Unmapped(f) => + if f.tpe =:= tree.tpe then // make sure that the type of the expression won't change + f // drop the map call + else + f match + case Converted(r) if r.tpe =:= tree.tpe => r // drop the map call and the conversion + case _ => tree + case tree => tree + + // Extracts a fun from a possibly nested Apply with lambda and arbitrary implicit args. + private object Unmapped: + private def loop(tree: Tree)(using Context): Option[Tree] = tree match - case aply @ Apply(MapCall(f), List(Lambda(List(param), body))) => - if f.tpe =:= aply.tpe then // make sure that the type of the expression won't change - return f // drop the map call - else - f match - case Converted(r) if r.tpe =:= aply.tpe => - return r // drop the map call and the conversion + case Apply(fun, Lambda(_ :: Nil, _) :: Nil) => + tree.removeAttachment(desugar.TrailingForMap) match + case Some(_) => + fun match + case MapCall(f) => return Some(f) case _ => + case _ => + case Apply(fun, _) => + fun.tpe match + case mt: MethodType if mt.isImplicitMethod => return loop(fun) + case _ => case _ => - case _ => - tree + None + end loop + def unapply(tree: Apply)(using Context): Option[Tree] = + tree.tpe match + case _: MethodOrPoly => None + case _ => loop(tree) private object Lambda: - def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] = - tree match - case Block(List(defdef: DefDef), Closure(Nil, ref, _)) - if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) => - Some((defdef.termParamss.flatten, defdef.rhs)) - case _ => None + def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] = tree match + case Block(List(defdef: DefDef), Closure(Nil, ref, _)) + if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) => + Some((defdef.termParamss.flatten, defdef.rhs)) + case _ => None private object MapCall: def unapply(tree: Tree)(using Context): Option[Tree] = tree match diff --git a/tests/run/i23409b.scala b/tests/run/i23409b.scala new file mode 100644 index 000000000000..4fd92cfe0bca --- /dev/null +++ b/tests/run/i23409b.scala @@ -0,0 +1,26 @@ +//> using options -preview + +final class Implicit() + +final class Id[+A, -U](val value: A): + def map[B](f: A => B)(using Implicit): Id[B, U] = ??? //Id(f(value)) + def flatMap[B, V <: U](f: A => Id[B, V]): Id[B, V] = f(value) + def run: A = value + +type Foo = Foo.type +case object Foo: + def get: Id[Int, Foo] = Id(42) + +type Bar = Bar.type +case object Bar: + def inc(i: Int): Id[Int, Bar] = Id(i * 10) + +def program(using Implicit) = + for + a <- Foo.get + x <- Bar.inc(a) + yield x + +@main def Test = println: + given Implicit = Implicit() + program.run From c121bd8e110ddb23ade0507d3460d1a8f1f50957 Mon Sep 17 00:00:00 2001 From: Som Snytt Date: Sat, 28 Jun 2025 15:35:13 -0700 Subject: [PATCH 3/3] Drop inlined map --- .../dotty/tools/dotc/inlines/Inlines.scala | 2 +- .../dotc/transform/localopt/DropForMap.scala | 17 +++++- tests/run/better-fors-map-inlined.check | 4 ++ tests/run/better-fors-map-inlined.scala | 58 +++++++++++++++++++ 4 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 tests/run/better-fors-map-inlined.check create mode 100644 tests/run/better-fors-map-inlined.scala diff --git a/compiler/src/dotty/tools/dotc/inlines/Inlines.scala b/compiler/src/dotty/tools/dotc/inlines/Inlines.scala index a7269c83bccb..e6c8ffd89b05 100644 --- a/compiler/src/dotty/tools/dotc/inlines/Inlines.scala +++ b/compiler/src/dotty/tools/dotc/inlines/Inlines.scala @@ -571,7 +571,7 @@ object Inlines: // Take care that only argument bindings go into `bindings`, since positions are // different for bindings from arguments and bindings from body. - val inlined = tpd.Inlined(call, bindings, expansion) + val inlined = tpd.Inlined(call, bindings, expansion.withAttachmentsFrom(call)) if !hasOpaqueProxies then inlined else diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala index c51a116673ff..3cc40448fe62 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala @@ -20,7 +20,6 @@ import dotty.tools.dotc.ast.desugar * The latter condition is checked in [[Desugar.scala#makeFor]] */ class DropForMap extends MiniPhase: - import DropForMap.* override def phaseName: String = DropForMap.name @@ -39,6 +38,22 @@ class DropForMap extends MiniPhase: case _ => tree case tree => tree + override def transformInlined(tree: Inlined)(using Context): Tree = tree match + case Inlined(call, bindings, expansion) if expansion.hasAttachment(desugar.TrailingForMap) => + call match + case Unmapped(f) => + bindings.collectFirst: + case vd: ValDef if f.sameTree(vd.rhs) => + expansion.find: + case Inlined(Thicket(Nil), Nil, Ident(ident)) => ident == vd.name + case _ => false + .match + case Some(ref) => cpy.Inlined(tree)(call, bindings, ref) + case _ => tree + .getOrElse(tree) + case _ => tree + case tree => tree + // Extracts a fun from a possibly nested Apply with lambda and arbitrary implicit args. private object Unmapped: private def loop(tree: Tree)(using Context): Option[Tree] = diff --git a/tests/run/better-fors-map-inlined.check b/tests/run/better-fors-map-inlined.check new file mode 100644 index 000000000000..0ef3447a47c4 --- /dev/null +++ b/tests/run/better-fors-map-inlined.check @@ -0,0 +1,4 @@ +MySome(()) +MySome(2) +MySome((2,3)) +MySome((2,(3,4))) diff --git a/tests/run/better-fors-map-inlined.scala b/tests/run/better-fors-map-inlined.scala new file mode 100644 index 000000000000..d43fed329c3b --- /dev/null +++ b/tests/run/better-fors-map-inlined.scala @@ -0,0 +1,58 @@ +//> using options -preview + +class myOptionModule(doOnMap: => Unit): + sealed trait MyOption[+A]: + inline def map[B](f: A => B): MyOption[B] = + this match + case MySome(x) => + doOnMap + MySome(f(x)) + case MyNone => MyNone + def flatMap[B](f: A => MyOption[B]): MyOption[B] = + this match + case MySome(x) => f(x) + case MyNone => MyNone + case class MySome[A](x: A) extends MyOption[A] + case object MyNone extends MyOption[Nothing] + object MyOption: + def apply[A](x: A): MyOption[A] = MySome(x) + +@main def Test = + + val myOption = myOptionModule(???) + + import myOption.* + + def portablePrintMyOption(opt: MyOption[Any]): Unit = println: + opt match + case MySome(()) => "MySome(())" + case opt => opt + + val z = for { + a <- MyOption(1) + b <- MyOption(()) + } yield () + + portablePrintMyOption(z) + + val z2 = for { + a <- MyOption(1) + b <- MyOption(2) + } yield b + + portablePrintMyOption(z2) + + val z3 = for { + a <- MyOption(1) + (b, c) <- MyOption((2, 3)) + } yield (b, c) + + portablePrintMyOption(z3) + + val z4 = for { + a <- MyOption(1) + (b, (c, d)) <- MyOption((2, (3, 4))) + } yield (b, (c, d)) + + portablePrintMyOption(z4) +end Test