Skip to content

Commit

Permalink
Replace quoted type variables in signature of HOAS pattern result (#1…
Browse files Browse the repository at this point in the history
…6951)

To be able to construct the lambda returned by the HOAS pattern we need:
first resolve the type variables and then use the result to construct
the
signature of the lambdas.

To simplify this transformation, `QuoteMatcher` returns a
`Seq[MatchResult]`
instead of an untyped `Tuple` containing `Expr[?]`. The tuple is created
once we have accumulated and processed all extracted values.

Fixes #15165
  • Loading branch information
nicolasstucki authored Mar 3, 2023
2 parents 8020c77 + 20174d7 commit 6ea3ea6
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 74 deletions.
19 changes: 19 additions & 0 deletions compiler/src/dotty/tools/dotc/util/optional.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package dotty.tools.dotc.util

import scala.util.boundary

/** Return type that indicates that the method returns a T or aborts to the enclosing boundary with a `None` */
type optional[T] = boundary.Label[None.type] ?=> T

/** A prompt for `Option`, which establishes a boundary which `_.?` on `Option` can return */
object optional:
inline def apply[T](inline body: optional[T]): Option[T] =
boundary(Some(body))

extension [T](r: Option[T])
inline def ? (using label: boundary.Label[None.type]): T = r match
case Some(x) => x
case None => boundary.break(None)

inline def break()(using label: boundary.Label[None.type]): Nothing =
boundary.break(None)
140 changes: 79 additions & 61 deletions compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package scala.quoted
package runtime.impl


import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.Flags.*
import dotty.tools.dotc.core.Names.*
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.core.StdNames.nme
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.util.optional

/** Matches a quoted tree against a quoted pattern tree.
* A quoted pattern tree may have type and term holes in addition to normal terms.
Expand Down Expand Up @@ -103,12 +103,13 @@ import dotty.tools.dotc.core.Symbols.*
object QuoteMatcher {
import tpd.*

// TODO improve performance

// TODO use flag from Context. Maybe -debug or add -debug-macros
private inline val debug = false

import Matching._
/** Sequence of matched expressions.
* These expressions are part of the scrutinee and will be bound to the quote pattern term splices.
*/
type MatchingExprs = Seq[MatchResult]

/** A map relating equivalent symbols from the scrutinee and the pattern
* For example in
Expand All @@ -121,32 +122,34 @@ object QuoteMatcher {

private def withEnv[T](env: Env)(body: Env ?=> T): T = body(using env)

def treeMatch(scrutineeTree: Tree, patternTree: Tree)(using Context): Option[Tuple] =
def treeMatch(scrutineeTree: Tree, patternTree: Tree)(using Context): Option[MatchingExprs] =
given Env = Map.empty
scrutineeTree =?= patternTree
optional:
scrutineeTree =?= patternTree

/** Check that all trees match with `mtch` and concatenate the results with &&& */
private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => Matching): Matching = (l1, l2) match {
private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => MatchingExprs): optional[MatchingExprs] = (l1, l2) match {
case (x :: xs, y :: ys) => mtch(x, y) &&& matchLists(xs, ys)(mtch)
case (Nil, Nil) => matched
case _ => notMatched
}

extension (scrutinees: List[Tree])
private def =?= (patterns: List[Tree])(using Env, Context): Matching =
private def =?= (patterns: List[Tree])(using Env, Context): optional[MatchingExprs] =
matchLists(scrutinees, patterns)(_ =?= _)

extension (scrutinee0: Tree)

/** Check that the trees match and return the contents from the pattern holes.
* Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes.
* Return a sequence containing all the contents in the holes.
* If it does not match, continues to the `optional` with `None`.
*
* @param scrutinee The tree being matched
* @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes.
* @param `summon[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
* @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
* @return The sequence with the contents of the holes of the matched expression.
*/
private def =?= (pattern0: Tree)(using Env, Context): Matching =
private def =?= (pattern0: Tree)(using Env, Context): optional[MatchingExprs] =

/* Match block flattening */ // TODO move to cases
/** Normalize the tree */
Expand Down Expand Up @@ -203,31 +206,12 @@ object QuoteMatcher {
// Matches an open term and wraps it into a lambda that provides the free variables
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>
def hoasClosure = {
val names: List[TermName] = args.map {
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
case arg => arg.symbol.name.asTermName
}
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe)
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(scrutinee)
TreeOps(body).changeNonLocalOwners(meth)
}
Closure(meth, bodyFn)
}
val env = summon[Env]
val capturedArgs = args.map(_.symbol)
val captureEnv = summon[Env].filter((k, v) => !capturedArgs.contains(v))
val captureEnv = env.filter((k, v) => !capturedArgs.contains(v))
withEnv(captureEnv) {
scrutinee match
case ClosedPatternTerm(scrutinee) => matched(hoasClosure)
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env)
case _ => notMatched
}

Expand Down Expand Up @@ -431,7 +415,6 @@ object QuoteMatcher {
case _ => scrutinee
val pattern = patternTree.symbol


devirtualizedScrutinee == pattern
|| summon[Env].get(devirtualizedScrutinee).contains(pattern)
|| devirtualizedScrutinee.allOverriddenSymbols.contains(pattern)
Expand All @@ -452,32 +435,67 @@ object QuoteMatcher {
accumulator.apply(Set.empty, term)
}

/** Result of matching a part of an expression */
private type Matching = Option[Tuple]

private object Matching {

def notMatched: Matching = None

val matched: Matching = Some(Tuple())

def matched(tree: Tree)(using Context): Matching =
Some(Tuple1(new ExprImpl(tree, SpliceScope.getCurrent)))

extension (self: Matching)
def asOptionOfTuple: Option[Tuple] = self

/** Concatenates the contents of two successful matchings or return a `notMatched` */
def &&& (that: => Matching): Matching = self match {
case Some(x) =>
that match {
case Some(y) => Some(x ++ y)
case _ => None
}
case _ => None
}
end extension

}
enum MatchResult:
/** Closed pattern extracted value
* @param tree Scrutinee sub-tree that matched
*/
case ClosedTree(tree: Tree)
/** HOAS pattern extracted value
*
* @param tree Scrutinee sub-tree that matched
* @param patternTpe Type of the pattern hole (from the pattern)
* @param args HOAS arguments (from the pattern)
* @param env Mapping between scrutinee and pattern variables
*/
case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)

/** Return the expression that was extracted from a hole.
*
* If it was a closed expression it returns that expression. Otherwise,
* if it is a HOAS pattern, the surrounding lambda is generated using
* `mapTypeHoles` to create the signature of the lambda.
*
* This expression is assumed to be a valid expression in the given splice scope.
*/
def toExpr(mapTypeHoles: TypeMap, spliceScope: Scope)(using Context): Expr[Any] = this match
case MatchResult.ClosedTree(tree) =>
new ExprImpl(tree, spliceScope)
case MatchResult.OpenTree(tree, patternTpe, args, env) =>
val names: List[TermName] = args.map {
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
case arg => arg.symbol.name.asTermName
}
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(tree)
TreeOps(body).changeNonLocalOwners(meth)
}
val hoasClosure = Closure(meth, bodyFn)
new ExprImpl(hoasClosure, spliceScope)

private inline def notMatched: optional[MatchingExprs] =
optional.break()

private inline def matched: MatchingExprs =
Seq.empty

private inline def matched(tree: Tree)(using Context): MatchingExprs =
Seq(MatchResult.ClosedTree(tree))

private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs =
Seq(MatchResult.OpenTree(tree, patternTpe, args, env))

extension (self: MatchingExprs)
/** Concatenates the contents of two successful matchings */
def &&& (that: MatchingExprs): MatchingExprs = self ++ that
end extension

}
33 changes: 20 additions & 13 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3137,20 +3137,27 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
ctx1.gadtState.addToConstraint(typeHoles)
ctx1

val matchings = QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1)

if typeHoles.isEmpty then matchings
else {
// After matching and doing all subtype checks, we have to approximate all the type bindings
// that we have found, seal them in a quoted.Type and add them to the result
def typeHoleApproximation(sym: Symbol) =
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot)
val fullBounds = ctx1.gadt.fullBounds(sym)
val tp = if fromAboveAnnot then fullBounds.hi else fullBounds.lo
reflect.TypeReprMethods.asType(tp)
matchings.map { tup =>
Tuple.fromIArray(typeHoles.map(typeHoleApproximation).toArray.asInstanceOf[IArray[Object]]) ++ tup
// After matching and doing all subtype checks, we have to approximate all the type bindings
// that we have found, seal them in a quoted.Type and add them to the result
def typeHoleApproximation(sym: Symbol) =
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot)
val fullBounds = ctx1.gadt.fullBounds(sym)
if fromAboveAnnot then fullBounds.hi else fullBounds.lo

QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1).map { matchings =>
import QuoteMatcher.MatchResult.*
lazy val spliceScope = SpliceScope.getCurrent
val typeHoleApproximations = typeHoles.map(typeHoleApproximation)
val typeHoleMapping = Map(typeHoles.zip(typeHoleApproximations)*)
val typeHoleMap = new Types.TypeMap {
def apply(tp: Types.Type): Types.Type = tp match
case Types.TypeRef(Types.NoPrefix, _) => typeHoleMapping.getOrElse(tp.typeSymbol, tp)
case _ => mapOver(tp)
}
val matchedExprs = matchings.map(_.toExpr(typeHoleMap, spliceScope))
val matchedTypes = typeHoleApproximations.map(reflect.TypeReprMethods.asType)
val results = matchedTypes ++ matchedExprs
Tuple.fromIArray(IArray.unsafeFromArray(results.toArray))
}
}

Expand Down
9 changes: 9 additions & 0 deletions tests/pos-macros/i15165a/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.quoted.*

inline def valToFun[T](inline expr: T): T =
${ impl('expr) }

def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
expr match
case '{ { val ident = ($a: α); $rest(ident): T } } =>
'{ { (y: α) => $rest(y) }.apply(???) }
4 changes: 4 additions & 0 deletions tests/pos-macros/i15165a/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = valToFun {
val a: Int = 1
a + 1
}
16 changes: 16 additions & 0 deletions tests/pos-macros/i15165b/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import scala.quoted.*

inline def valToFun[T](inline expr: T): T =
${ impl('expr) }

def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
expr match
case '{ { val ident = ($a: α); $rest(ident): T } } =>
'{
{ (y: α) =>
${
val bound = '{ ${ rest }(y) }
Expr.betaReduce(bound)
}
}.apply($a)
}
4 changes: 4 additions & 0 deletions tests/pos-macros/i15165b/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = valToFun {
val a: Int = 1
a + 1
}
9 changes: 9 additions & 0 deletions tests/pos-macros/i15165c/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.quoted.*

inline def valToFun[T](inline expr: T): T =
${ impl('expr) }

def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
expr match
case '{ type α; { val ident = ($a: `α`); $rest(ident): `α` & T } } =>
'{ { (y: α) => $rest(y) }.apply(???) }
4 changes: 4 additions & 0 deletions tests/pos-macros/i15165c/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = valToFun {
val a: Int = 1
a + 1
}

0 comments on commit 6ea3ea6

Please sign in to comment.