Skip to content

Commit

Permalink
WIP: Pattern match against hoas pattern wtih type vars
Browse files Browse the repository at this point in the history
  • Loading branch information
zeptometer committed Jul 31, 2023
1 parent 7bfa19d commit 282e4ad
Showing 1 changed file with 97 additions and 27 deletions.
124 changes: 97 additions & 27 deletions compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,42 @@ class QuoteMatcher(debug: Boolean) {
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v))
withEnv(captureEnv) {
scrutinee match
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), env)
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), Nil, env)
case _ => notMatched
}

/* Higher order term hole */
// Matches an open term and wraps it into a lambda that provides the free variables
case Apply(TypeApply(Ident(_), List(TypeTree(), targs)), SeqLiteral(args, _) :: Nil)
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHoleWithTypes) =>

/* Some of method symbols in arguments of higher-order term hole are eta-expanded.
* e.g.
* g: (Int) => Int
* => {
* def $anonfun(y: Int): Int = g(y)
* closure($anonfun)
* }
*
* f: (using Int) => Int
* => f(using x)
* This function restores the symbol of the original method from
* the eta-expanded function.
*/
def getCapturedIdent(arg: Tree)(using Context): Ident =
arg match
case id: Ident => id
case Apply(fun, _) => getCapturedIdent(fun)
case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs)
case Typed(expr, _) => getCapturedIdent(expr)

val env = summon[Env]
val capturedIds = args.map(getCapturedIdent)
val capturedSymbols = capturedIds.map(_.symbol)
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v))
withEnv(captureEnv) {
scrutinee match
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), targs, env)
case _ => notMatched
}

Expand Down Expand Up @@ -558,9 +593,10 @@ class QuoteMatcher(debug: Boolean) {
* @param patternTpe Type of the pattern hole (from the pattern)
* @param argIds Identifiers of HOAS arguments (from the pattern)
* @param argTypes Eta-expanded types of HOAS arguments (from the pattern)
* @param typeArgs type arguments from the pattern
* @param env Mapping between scrutinee and pattern variables
*/
case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)
case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], typeArgs: List[Type], env: Env)

/** Return the expression that was extracted from a hole.
*
Expand All @@ -573,29 +609,63 @@ class QuoteMatcher(debug: Boolean) {
def toExpr(mapTypeHoles: Type => Type, spliceScope: Scope)(using Context): Expr[Any] = this match
case MatchResult.ClosedTree(tree) =>
new ExprImpl(tree, spliceScope)
case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env) =>
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
/*
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
*/
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(tree)
TreeOps(body).changeNonLocalOwners(meth)
}
val hoasClosure = Closure(meth, bodyFn)
new ExprImpl(hoasClosure, spliceScope)
case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, env) =>
if typeArgs.isEmpty then
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
/*
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
*/
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(tree)
TreeOps(body).changeNonLocalOwners(meth)
}
val hoasClosure = Closure(meth, bodyFn)
new ExprImpl(hoasClosure, spliceScope)
else
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))

val typeArgs1 = PolyType.syntheticParamNames(typeArgs.length)
val bounds = typeArgs map (_ => TypeBounds.empty)
val resultTypeExp = (pt: PolyType) => {
val fromSymbols = typeArgs.map(_.typeSymbol)
val argTypes1 = argTypes.map(_.subst(fromSymbols, pt.paramRefs))
val resultType1 = mapTypeHoles(patternTpe).subst(fromSymbols, pt.paramRefs)
MethodType(argTypes1, resultType1)
}
val methTpe = PolyType(typeArgs1)(_ => bounds, resultTypeExp)
val meth = newAnonFun(ctx.owner, methTpe)
// TODO-18271
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
/*
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
*/
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(tree)
TreeOps(body).changeNonLocalOwners(meth)
}
val hoasClosure = Closure(meth, bodyFn)
new ExprImpl(hoasClosure, spliceScope)

private inline def notMatched[T]: optional[T] =
optional.break()
Expand All @@ -606,8 +676,8 @@ class QuoteMatcher(debug: Boolean) {
private inline def matched(tree: Tree)(using Context): MatchingExprs =
Seq(MatchResult.ClosedTree(tree))

private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)(using Context): MatchingExprs =
Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env))
private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], typeArgs: List[Type], env: Env)(using Context): MatchingExprs =
Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, env))

extension (self: MatchingExprs)
/** Concatenates the contents of two successful matchings */
Expand Down

0 comments on commit 282e4ad

Please sign in to comment.