Skip to content

Commit

Permalink
Merge betterFors desugaring with the default implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
KacperFKorban committed Jul 22, 2024
1 parent 9efe193 commit 4b67eb0
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 160 deletions.
255 changes: 97 additions & 158 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1806,113 +1806,79 @@ object desugar {
* corresponding to whether this is a for-do or a for-yield.
* If betterFors are enabled, the creation performs the following rewrite rules:
*
* 1.
* 1. if betterFors is enabled:
*
* for (P <- G) do E ==> G.foreach (P => E)
* for () do E ==> E
* or
* for () yield E ==> E
*
* Here and in the following (P => E) is interpreted as the function (P => E)
* if P is a variable pattern and as the partial function { case P => E } otherwise.
* (Where empty for-comprehensions are excluded by the parser)
*
* 2.
*
* for (P <- G) yield P ==> G
*
* If P is a variable or a tuple of variables and G is not a withFilter.
* for (P <- G) do E ==> G.foreach (P => E)
*
* for (P <- G) yield E ==> G.map (P => E)
*
* Otherwise
* Here and in the following (P => E) is interpreted as the function (P => E)
* if P is a variable pattern and as the partial function { case P => E } otherwise.
*
* 3.
*
* for (P_1 <- G_1; P_2 <- G_2; ...) ...
* ==>
* G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...)
*
* 4.
*
* for (P <- G; if E; ...) ...
* ==>
* for (P <- G.withFilter (P => E); ...) ...
*
* 5. For any N:
*
* for (P <- G; P_1 = E_1; ... P_N = E_N; rest)
* ==>
* G.flatMap (P => for (P_1 = E_1; ... P_N = E_N; ...)) if rest contains (<-)
* G.map (P => for (P_1 = E_1; ... P_N = E_N; ...)) otherwise
* for (P <- G) yield P ==> G
*
* 6. For any N:
* If betterFors is enabled, P is a variable or a tuple of variables and G is not a withFilter.
*
* for (P <- G; P_1 = E_1; ... P_N = E_N; if E; ...)
* ==>
* for (TupleN(P, P_1, ... P_N) <-
* for (x @ P <- G) yield {
* val x_1 @ P_1 = E_2
* ...
* val x_N @ P_N = E_N
* TupleN(x, x_1, ..., x_N)
* }; if E; ...)
*
* If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated
* and the variable constituting P_i is used instead of x_i
*
* 7. For any N:
*
* for (P_1 = E_1; ... P_N = E_N; ...)
* ==>
* {
* val x_N @ P_N = E_N
* for (...)
* }
*
* 8.
* for () yield E ==> E
*
* (Where empty for-comprehensions are excluded by the parser)
* for (P <- G) yield E ==> G.map (P => E)
*
* If the aliases are not followed by a guard, otherwise an error.
*
* With betterFors disabled, the translation is as follows:
*
* 1.
* Otherwise
*
* for (P <- G) E ==> G.foreach (P => E)
* 4.
*
* Here and in the following (P => E) is interpreted as the function (P => E)
* if P is a variable pattern and as the partial function { case P => E } otherwise.
* for (P_1 <- G_1; P_2 <- G_2; ...) ...
* ==>
* G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...)
*
* 2.
* 5.
*
* for (P <- G) yield E ==> G.map (P => E)
* for (P <- G; if E; ...) ...
* ==>
* for (P <- G.withFilter (P => E); ...) ...
*
* 3.
* 6. For any N, if betterFors is enabled:
*
* for (P_1 <- G_1; P_2 <- G_2; ...) ...
* for (P <- G; P_1 = E_1; ... P_N = E_N; P1 <- G1; ...) ...
* ==>
* G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...)
* G.flatMap (P => for (P_1 = E_1; ... P_N = E_N; ...))
*
* 4.
* 7. For any N, if betterFors is enabled:
*
* for (P <- G; E; ...) ...
* =>
* for (P <- G.filter (P => E); ...) ...
* for (P <- G; P_1 = E_1; ... P_N = E_N) ...
* ==>
* G.map (P => for (P_1 = E_1; ... P_N = E_N) ...)
*
* 5. For any N:
* 8. For any N:
*
* for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...)
* for (P <- G; P_1 = E_1; ... P_N = E_N; ...)
* ==>
* for (TupleN(P_1, P_2, ... P_N) <-
* for (x_1 @ P_1 <- G) yield {
* val x_2 @ P_2 = E_2
* for (TupleN(P, P_1, ... P_N) <-
* for (x @ P <- G) yield {
* val x_1 @ P_1 = E_2
* ...
* val x_N & P_N = E_N
* TupleN(x_1, ..., x_N)
* } ...)
* val x_N @ P_N = E_N
* TupleN(x, x_1, ..., x_N)
* }; if E; ...)
*
* If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated
* and the variable constituting P_i is used instead of x_i
*
* 9. For any N, if betterFors is enabled:
*
* for (P_1 = E_1; ... P_N = E_N; ...)
* ==>
* {
* val x_N @ P_N = E_N
* for (...)
* }
*
* @param mapName The name to be used for maps (either map or foreach)
* @param flatMapName The name to be used for flatMaps (either flatMap or foreach)
* @param enums The enumerators in the for expression
Expand Down Expand Up @@ -2037,86 +2003,59 @@ object desugar {
case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals)
case _ => false

if betterForsEnabled then
enums match {
case Nil => body
case (gen: GenFrom) :: Nil =>
if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
&& deepEquals(gen.pat, body)
then gen.expr // avoid a redundant map with identity
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
case (gen: GenFrom) :: rest
if rest.dropWhile(_.isInstanceOf[GenAlias]).headOption.forall(e => e.isInstanceOf[GenFrom]) =>
val cont = makeFor(mapName, flatMapName, rest, body)
val selectName =
if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName
else mapName
Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
val pats = valeqs map { case GenAlias(pat, _) => pat }
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
val (defpat0, id0) = makeIdPat(gen.pat)
val (defpats, ids) = (pats map makeIdPat).unzip
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
val mods = defpat match
case defTree: DefTree => defTree.mods
case _ => Modifiers()
makePatDef(valeq, mods, defpat, rhs)
}
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
val allpats = gen.pat :: pats
val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
case (gen: GenFrom) :: test :: rest =>
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Filtered)
makeFor(mapName, flatMapName, genFrom :: rest, body)
case GenAlias(_, _) :: _ =>
val (valeqs, rest) = enums.span(_.isInstanceOf[GenAlias])
val pats = valeqs.map { case GenAlias(pat, _) => pat }
val rhss = valeqs.map { case GenAlias(_, rhs) => rhs }
val (defpats, ids) = pats.map(makeIdPat).unzip
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
val mods = defpat match
case defTree: DefTree => defTree.mods
case _ => Modifiers()
makePatDef(valeq, mods, defpat, rhs)
}
Block(pdefs, makeFor(mapName, flatMapName, rest, body))
case _ =>
EmptyTree //may happen for erroneous input
}
else {
enums match {
case (gen: GenFrom) :: Nil =>
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
val cont = makeFor(mapName, flatMapName, rest, body)
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
val pats = valeqs map { case GenAlias(pat, _) => pat }
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
val (defpat0, id0) = makeIdPat(gen.pat)
val (defpats, ids) = (pats map makeIdPat).unzip
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
val mods = defpat match
case defTree: DefTree => defTree.mods
case _ => Modifiers()
makePatDef(valeq, mods, defpat, rhs)
}
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
val allpats = gen.pat :: pats
val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
case (gen: GenFrom) :: test :: rest =>
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
makeFor(mapName, flatMapName, genFrom :: rest, body)
case _ =>
EmptyTree //may happen for erroneous input
}
enums match {
case Nil if betterForsEnabled => body
case (gen: GenFrom) :: Nil =>
if betterForsEnabled
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
&& deepEquals(gen.pat, body)
then gen.expr // avoid a redundant map with identity
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
val cont = makeFor(mapName, flatMapName, rest, body)
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
case (gen: GenFrom) :: rest
if betterForsEnabled
&& rest.dropWhile(_.isInstanceOf[GenAlias]).headOption.forall(e => e.isInstanceOf[GenFrom]) => // possible aliases followed by a generator or end of for
val cont = makeFor(mapName, flatMapName, rest, body)
val selectName =
if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName
else mapName
Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
val pats = valeqs map { case GenAlias(pat, _) => pat }
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
val (defpat0, id0) = makeIdPat(gen.pat)
val (defpats, ids) = (pats map makeIdPat).unzip
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
val mods = defpat match
case defTree: DefTree => defTree.mods
case _ => Modifiers()
makePatDef(valeq, mods, defpat, rhs)
}
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
val allpats = gen.pat :: pats
val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
case (gen: GenFrom) :: test :: rest =>
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
val genFrom = GenFrom(gen.pat, filtered, if betterForsEnabled then GenCheckMode.Filtered else GenCheckMode.Ignore)
makeFor(mapName, flatMapName, genFrom :: rest, body)
case GenAlias(_, _) :: _ if betterForsEnabled =>
val (valeqs, rest) = enums.span(_.isInstanceOf[GenAlias])
val pats = valeqs.map { case GenAlias(pat, _) => pat }
val rhss = valeqs.map { case GenAlias(_, rhs) => rhs }
val (defpats, ids) = pats.map(makeIdPat).unzip
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
val mods = defpat match
case defTree: DefTree => defTree.mods
case _ => Modifiers()
makePatDef(valeq, mods, defpat, rhs)
}
Block(pdefs, makeFor(mapName, flatMapName, rest, body))
case _ =>
EmptyTree //may happen for erroneous input
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {

/** An enum to control checking or filtering of patterns in GenFrom trees */
enum GenCheckMode {
case Ignore // neither filter since pattern is trivially irrefutable
case Ignore // neither filter nor check since pattern is trivially irrefutable
case Filtered // neither filter nor check since filtering was done before
case Check // check that pattern is irrefutable
case CheckAndFilter // both check and filter (transitional period starting with 3.2)
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,6 @@ object StdNames {
val asInstanceOfPM: N = "$asInstanceOf$"
val assert_ : N = "assert"
val assume_ : N = "assume"
val betterFors: N = "betterFors"
val box: N = "box"
val break: N = "break"
val build : N = "build"
Expand Down

0 comments on commit 4b67eb0

Please sign in to comment.