Skip to content

Commit

Permalink
Generalize HOAS patterns to take type parameters (experimental feature)
Browse files Browse the repository at this point in the history
  • Loading branch information
zeptometer authored and nicolasstucki committed Oct 16, 2023
1 parent 8c1cdc2 commit 02b60bf
Show file tree
Hide file tree
Showing 35 changed files with 615 additions and 118 deletions.
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2028,7 +2028,7 @@ object desugar {
case Quote(body, _) =>
new UntypedTreeTraverser {
def traverse(tree: untpd.Tree)(using Context): Unit = tree match {
case SplicePattern(body, _) => collect(body)
case SplicePattern(body, _, _) => collect(body)
case _ => traverseChildren(tree)
}
}.traverse(body)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
}
private object quotePatVars extends TreeAccumulator[List[Symbol]] {
def apply(syms: List[Symbol], tree: Tree)(using Context) = tree match {
case SplicePattern(pat, _) => outer.apply(syms, pat)
case SplicePattern(pat, _, _) => outer.apply(syms, pat)
case _ => foldOver(syms, tree)
}
}
Expand Down
17 changes: 9 additions & 8 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -760,9 +760,10 @@ object Trees {
* `SplicePattern` can only be contained within a `QuotePattern`.
*
* @param body The tree that was spliced
* @param typeargs The type arguments of the splice (the HOAS arguments)
* @param args The arguments of the splice (the HOAS arguments)
*/
case class SplicePattern[+T <: Untyped] private[ast] (body: Tree[T], args: List[Tree[T]])(implicit @constructorOnly src: SourceFile)
case class SplicePattern[+T <: Untyped] private[ast] (body: Tree[T], typeargs: List[Tree[T]], args: List[Tree[T]])(implicit @constructorOnly src: SourceFile)
extends TermTree[T] {
type ThisTree[+T <: Untyped] = SplicePattern[T]
}
Expand Down Expand Up @@ -1367,9 +1368,9 @@ object Trees {
case tree: QuotePattern if (bindings eq tree.bindings) && (body eq tree.body) && (quotes eq tree.quotes) => tree
case _ => finalize(tree, untpd.QuotePattern(bindings, body, quotes)(sourceFile(tree)))
}
def SplicePattern(tree: Tree)(body: Tree, args: List[Tree])(using Context): SplicePattern = tree match {
case tree: SplicePattern if (body eq tree.body) && (args eq tree.args) => tree
case _ => finalize(tree, untpd.SplicePattern(body, args)(sourceFile(tree)))
def SplicePattern(tree: Tree)(body: Tree, typeargs: List[Tree], args: List[Tree])(using Context): SplicePattern = tree match {
case tree: SplicePattern if (body eq tree.body) && (typeargs eq tree.typeargs) & (args eq tree.args) => tree
case _ => finalize(tree, untpd.SplicePattern(body, typeargs, args)(sourceFile(tree)))
}
def SingletonTypeTree(tree: Tree)(ref: Tree)(using Context): SingletonTypeTree = tree match {
case tree: SingletonTypeTree if (ref eq tree.ref) => tree
Expand Down Expand Up @@ -1617,8 +1618,8 @@ object Trees {
cpy.Splice(tree)(transform(expr)(using spliceContext))
case tree @ QuotePattern(bindings, body, quotes) =>
cpy.QuotePattern(tree)(transform(bindings), transform(body)(using quoteContext), transform(quotes))
case tree @ SplicePattern(body, args) =>
cpy.SplicePattern(tree)(transform(body)(using spliceContext), transform(args))
case tree @ SplicePattern(body, targs, args) =>
cpy.SplicePattern(tree)(transform(body)(using spliceContext), transform(targs), transform(args))
case tree @ Hole(isTerm, idx, args, content) =>
cpy.Hole(tree)(isTerm, idx, transform(args), transform(content))
case _ =>
Expand Down Expand Up @@ -1766,8 +1767,8 @@ object Trees {
this(x, expr)(using spliceContext)
case QuotePattern(bindings, body, quotes) =>
this(this(this(x, bindings), body)(using quoteContext), quotes)
case SplicePattern(body, args) =>
this(this(x, body)(using spliceContext), args)
case SplicePattern(body, typeargs, args) =>
this(this(this(x, body)(using spliceContext), typeargs), args)
case Hole(_, _, args, content) =>
this(this(x, args), content)
case _ =>
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def Quote(body: Tree, tags: List[Tree])(implicit src: SourceFile): Quote = new Quote(body, tags)
def Splice(expr: Tree)(implicit src: SourceFile): Splice = new Splice(expr)
def QuotePattern(bindings: List[Tree], body: Tree, quotes: Tree)(implicit src: SourceFile): QuotePattern = new QuotePattern(bindings, body, quotes)
def SplicePattern(body: Tree, args: List[Tree])(implicit src: SourceFile): SplicePattern = new SplicePattern(body, args)
def SplicePattern(body: Tree, typeargs: List[Tree], args: List[Tree])(implicit src: SourceFile): SplicePattern = new SplicePattern(body, typeargs, args)
def TypeTree()(implicit src: SourceFile): TypeTree = new TypeTree()
def InferredTypeTree()(implicit src: SourceFile): TypeTree = new InferredTypeTree()
def SingletonTypeTree(ref: Tree)(implicit src: SourceFile): SingletonTypeTree = new SingletonTypeTree(ref)
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ object Feature:
val pureFunctions = experimental("pureFunctions")
val captureChecking = experimental("captureChecking")
val into = experimental("into")
val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions")

val globalOnlyImports: Set[TermName] = Set(pureFunctions, captureChecking)

Expand Down Expand Up @@ -83,6 +84,9 @@ object Feature:

def scala2ExperimentalMacroEnabled(using Context) = enabled(scala2macros)

def quotedPatternsWithPolymorphicFunctionsEnabled(using Context) =
enabled(quotedPatternsWithPolymorphicFunctions)

/** Is pureFunctions enabled for this compilation unit? */
def pureFunsEnabled(using Context) =
enabledBySetting(pureFunctions)
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,7 @@ class Definitions {
@tu lazy val QuotedRuntimePatterns: Symbol = requiredModule("scala.quoted.runtime.Patterns")
@tu lazy val QuotedRuntimePatterns_patternHole: Symbol = QuotedRuntimePatterns.requiredMethod("patternHole")
@tu lazy val QuotedRuntimePatterns_higherOrderHole: Symbol = QuotedRuntimePatterns.requiredMethod("higherOrderHole")
@tu lazy val QuotedRuntimePatterns_higherOrderHoleWithTypes: Symbol = QuotedRuntimePatterns.requiredMethod("higherOrderHoleWithTypes")
@tu lazy val QuotedRuntimePatterns_patternTypeAnnot: ClassSymbol = QuotedRuntimePatterns.requiredClass("patternType")
@tu lazy val QuotedRuntimePatterns_fromAboveAnnot: ClassSymbol = QuotedRuntimePatterns.requiredClass("fromAbove")

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1802,7 +1802,7 @@ object Parsers {
syntaxError(em"$msg\n\nHint: $hint", Span(start, in.lastOffset))
Ident(nme.ERROR.toTypeName)
else if inPattern then
SplicePattern(expr, Nil)
SplicePattern(expr, Nil, Nil)
else
Splice(expr)
}
Expand Down
7 changes: 4 additions & 3 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -752,11 +752,12 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
val open = if (body.isTerm) keywordStr("{") else keywordStr("[")
val close = if (body.isTerm) keywordStr("}") else keywordStr("]")
keywordStr("'") ~ quotesText ~ open ~ bindingsText ~ toTextGlobal(body) ~ close
case SplicePattern(pattern, args) =>
case SplicePattern(pattern, typeargs, args) =>
val spliceTypeText = (keywordStr("[") ~ toTextGlobal(tree.typeOpt) ~ keywordStr("]")).provided(printDebug && tree.typeOpt.exists)
keywordStr("$") ~ spliceTypeText ~ {
if args.isEmpty then keywordStr("{") ~ inPattern(toText(pattern)) ~ keywordStr("}")
else toText(pattern.symbol.name) ~ "(" ~ toTextGlobal(args, ", ") ~ ")"
if typeargs.isEmpty && args.isEmpty then keywordStr("{") ~ inPattern(toText(pattern)) ~ keywordStr("}")
else if typeargs.isEmpty then toText(pattern.symbol.name) ~ "(" ~ toTextGlobal(args, ", ") ~ ")"
else toText(pattern.symbol.name) ~ "[" ~ toTextGlobal(typeargs, ", ")~ "]" ~ "(" ~ toTextGlobal(args, ", ") ~ ")"
}
case Hole(isTerm, idx, args, content) =>
val (prefix, postfix) = if isTerm then ("{{{", "}}}") else ("[[[", "]]]")
Expand Down
137 changes: 103 additions & 34 deletions compiler/src/dotty/tools/dotc/quoted/QuotePatterns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,90 @@ object QuotePatterns:
import tpd._

/** Check for restricted patterns */
def checkPattern(quotePattern: QuotePattern)(using Context): Unit = new tpd.TreeTraverser {
def traverse(tree: Tree)(using Context): Unit = tree match {
case _: SplicePattern =>
case tdef: TypeDef if tdef.symbol.isClass =>
val kind = if tdef.symbol.is(Module) then "objects" else "classes"
report.error(em"Implementation restriction: cannot match $kind", tree.srcPos)
case tree: NamedDefTree =>
if tree.name.is(NameKinds.WildcardParamName) then
report.warning(
"Use of `_` for lambda in quoted pattern. Use explicit lambda instead or use `$_` to match any term.",
tree.srcPos)
if tree.name.isTermName && !tree.nameSpan.isSynthetic && tree.name != nme.ANON_FUN && tree.name.startsWith("$") then
report.error("Names cannot start with $ quote pattern", tree.namePos)
traverseChildren(tree)
case _: Match =>
report.error("Implementation restriction: cannot match `match` expressions", tree.srcPos)
case _: Try =>
report.error("Implementation restriction: cannot match `try` expressions", tree.srcPos)
case _: Return =>
report.error("Implementation restriction: cannot match `return` statements", tree.srcPos)
case _ =>
traverseChildren(tree)
}
def checkPattern(quotePattern: QuotePattern)(using Context): Unit =
def validatePatternAndCollectTypeVars(): Set[Symbol] = new tpd.TreeAccumulator[Set[Symbol]] {
override def apply(typevars: Set[Symbol], tree: tpd.Tree)(using Context): Set[Symbol] =
// Collect type variables
val typevars1 = tree match
case tree @ DefDef(_, paramss, _, _) =>
typevars union paramss.flatMap{ params => params match
case TypeDefs(tdefs) => tdefs.map(_.symbol)
case _ => List.empty
}.toSet union typevars
case _ => typevars

// Validate pattern
tree match
case _: SplicePattern => typevars1
case tdef: TypeDef if tdef.symbol.isClass =>
val kind = if tdef.symbol.is(Module) then "objects" else "classes"
report.error(em"Implementation restriction: cannot match $kind", tree.srcPos)
typevars1
case tree: NamedDefTree =>
if tree.name.is(NameKinds.WildcardParamName) then
report.warning(
"Use of `_` for lambda in quoted pattern. Use explicit lambda instead or use `$_` to match any term.",
tree.srcPos)
if tree.name.isTermName && !tree.nameSpan.isSynthetic && tree.name != nme.ANON_FUN && tree.name.startsWith("$") then
report.error("Names cannot start with $ quote pattern", tree.namePos)
foldOver(typevars1, tree)
case _: Match =>
report.error("Implementation restriction: cannot match `match` expressions", tree.srcPos)
typevars1
case _: Try =>
report.error("Implementation restriction: cannot match `try` expressions", tree.srcPos)
typevars1
case _: Return =>
report.error("Implementation restriction: cannot match `return` statements", tree.srcPos)
typevars1
case _ =>
foldOver(typevars1, tree)
}.apply(Set.empty, quotePattern.body)

val boundTypeVars = validatePatternAndCollectTypeVars()

}.traverse(quotePattern.body)
/*
* This part checks well-formedness of arguments to hoas patterns.
* (1) Type arguments of a hoas patterns must be introduced in the quote pattern.ctxShow
* Examples
* well-formed: '{ [A] => (x : A) => $a[A](x) } // A is introduced in the quote pattern
* ill-formed: '{ (x : Int) => $a[Int](x) } // Int is defined outside of the quote pattern
* (2) If value arguments of a hoas pattern has a type with type variables that are introduced in
* the quote pattern, those type variables should be in type arguments to the hoas patternHole
* Examples
* well-formed: '{ [A] => (x : A) => $a[A](x) } // a : [A] => (x:A) => A
* ill-formed: '{ [A] => (x : A) => $a(x) } // a : (x:A) => A ...but A is undefined; hence ill-formed
*/
new tpd.TreeTraverser {
override def traverse(tree: tpd.Tree)(using Context): Unit = tree match {
case tree: SplicePattern =>
def uncapturedTypeVars(arg: tpd.Tree, capturedTypeVars: List[tpd.Tree]): Set[Type] =
/* Sometimes arg is untyped when a splice pattern is ill-formed.
* Return early in such case.
* Refer to QuoteAndSplices::typedSplicePattern
*/
if !arg.hasType then return Set.empty

val capturedTypeVarsSet = capturedTypeVars.map(_.symbol).toSet
new TypeAccumulator[Set[Type]] {
def apply(x: Set[Type], tp: Type): Set[Type] =
if boundTypeVars.contains(tp.typeSymbol) && !capturedTypeVarsSet.contains(tp.typeSymbol) then
foldOver(x + tp, tp)
else
foldOver(x, tp)
}.apply(Set.empty, arg.tpe)

for (typearg <- tree.typeargs) // case (1)
do
if !boundTypeVars.contains(typearg.symbol) then
report.error("Type arguments of a hoas pattern needs to be defined inside the quoted pattern", typearg.srcPos)
for (arg <- tree.args) // case (2)
do
if !uncapturedTypeVars(arg, tree.typeargs).isEmpty then
report.error("Type variables that this argument depends on are not captured in this hoas pattern", arg.srcPos)
case _ => traverseChildren(tree)
}
}.traverse(quotePattern.body)

/** Encode the quote pattern into an `unapply` that the pattern matcher can handle.
*
Expand All @@ -74,7 +133,7 @@ object QuotePatterns:
* .ExprMatch // or TypeMatch
* .unapply[
* KCons[t1 >: l1 <: b1, ...KCons[tn >: ln <: bn, KNil]...], // scala.quoted.runtime.{KCons, KNil}
* (T1, T2, (A1, ..., An) => T3, ...)
* (Expr[T1], Expr[T2], Expr[(A1, ..., An) => T3], ...)
* ](
* '{
* type t1' >: l1' <: b1'
Expand Down Expand Up @@ -197,16 +256,24 @@ object QuotePatterns:
val patBuf = new mutable.ListBuffer[Tree]
val shape = new tpd.TreeMap {
override def transform(tree: Tree)(using Context) = tree match {
case Typed(splice @ SplicePattern(pat, Nil), tpt) if !tpt.tpe.derivesFrom(defn.RepeatedParamClass) =>
case Typed(splice @ SplicePattern(pat, Nil, Nil), tpt) if !tpt.tpe.derivesFrom(defn.RepeatedParamClass) =>
transform(tpt) // Collect type bindings
transform(splice)
case SplicePattern(pat, args) =>
case SplicePattern(pat, typeargs, args) =>
val patType = pat.tpe.widen
val patType1 = patType.translateFromRepeated(toArray = false)
val pat1 = if (patType eq patType1) pat else pat.withType(patType1)
patBuf += pat1
if args.isEmpty then ref(defn.QuotedRuntimePatterns_patternHole.termRef).appliedToType(tree.tpe).withSpan(tree.span)
else ref(defn.QuotedRuntimePatterns_higherOrderHole.termRef).appliedToType(tree.tpe).appliedTo(SeqLiteral(args, TypeTree(defn.AnyType))).withSpan(tree.span)
if typeargs.isEmpty && args.isEmpty then ref(defn.QuotedRuntimePatterns_patternHole.termRef).appliedToType(tree.tpe).withSpan(tree.span)
else if typeargs.isEmpty then
ref(defn.QuotedRuntimePatterns_higherOrderHole.termRef)
.appliedToType(tree.tpe)
.appliedTo(SeqLiteral(args, TypeTree(defn.AnyType)))
.withSpan(tree.span)
else ref(defn.QuotedRuntimePatterns_higherOrderHoleWithTypes.termRef)
.appliedToTypeTrees(List(TypeTree(tree.tpe), tpd.hkNestedPairsTypeTree(typeargs)))
.appliedTo(SeqLiteral(args, TypeTree(defn.AnyType)))
.withSpan(tree.span)
case _ =>
super.transform(tree)
}
Expand All @@ -232,17 +299,19 @@ object QuotePatterns:
fun match
// <quotes>.asInstanceOf[QuoteMatching].{ExprMatch,TypeMatch}.unapply[<typeBindings>, <resTypes>]
case TypeApply(Select(Select(TypeApply(Select(quotes, _), _), _), _), typeBindings :: resTypes :: Nil) =>
val bindings = unrollBindings(typeBindings)
val bindings = unrollHkNestedPairsTypeTree(typeBindings)
val addPattenSplice = new TreeMap {
private val patternIterator = patterns.iterator.filter {
case pat: Bind => !pat.symbol.name.is(PatMatGivenVarName)
case _ => true
}
override def transform(tree: tpd.Tree)(using Context): tpd.Tree = tree match
case TypeApply(patternHole, _) if patternHole.symbol == defn.QuotedRuntimePatterns_patternHole =>
cpy.SplicePattern(tree)(patternIterator.next(), Nil)
cpy.SplicePattern(tree)(patternIterator.next(), Nil, Nil)
case Apply(patternHole, SeqLiteral(args, _) :: Nil) if patternHole.symbol == defn.QuotedRuntimePatterns_higherOrderHole =>
cpy.SplicePattern(tree)(patternIterator.next(), args)
cpy.SplicePattern(tree)(patternIterator.next(), Nil, args)
case Apply(TypeApply(patternHole, List(_, targsTpe)), SeqLiteral(args, _) :: Nil) if patternHole.symbol == defn.QuotedRuntimePatterns_higherOrderHoleWithTypes =>
cpy.SplicePattern(tree)(patternIterator.next(), unrollHkNestedPairsTypeTree(targsTpe), args)
case _ => super.transform(tree)
}
val body = addPattenSplice.transform(shape) match
Expand All @@ -260,7 +329,7 @@ object QuotePatterns:
case body => body
cpy.QuotePattern(tree)(bindings, body, quotes)

private def unrollBindings(tree: Tree)(using Context): List[Tree] = tree match
private def unrollHkNestedPairsTypeTree(tree: Tree)(using Context): List[Tree] = tree match
case AppliedTypeTree(tupleN, bindings) if defn.isTupleClass(tupleN.symbol) => bindings // TupleN, 1 <= N <= 22
case AppliedTypeTree(_, head :: tail :: Nil) => head :: unrollBindings(tail) // KCons or *:
case AppliedTypeTree(_, head :: tail :: Nil) => head :: unrollHkNestedPairsTypeTree(tail) // KCons or *:
case _ => Nil // KNil or EmptyTuple
9 changes: 9 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,8 @@ trait Applications extends Compatibility {
}
else {
val app = tree.fun match
case untpd.TypeApply(_: untpd.SplicePattern, _) if Feature.quotedPatternsWithPolymorphicFunctionsEnabled =>
typedAppliedSpliceWithTypes(tree, pt)
case _: untpd.SplicePattern => typedAppliedSplice(tree, pt)
case _ => realApply
app match {
Expand Down Expand Up @@ -1164,9 +1166,16 @@ trait Applications extends Compatibility {
if (ctx.mode.is(Mode.Pattern))
return errorTree(tree, em"invalid pattern")

tree.fun match {
case _: untpd.SplicePattern if Feature.quotedPatternsWithPolymorphicFunctionsEnabled =>
return errorTree(tree, em"Implementation restriction: A higher-order pattern must carry value arguments")
case _ =>
}

val isNamed = hasNamedArg(tree.args)
val typedArgs = if (isNamed) typedNamedArgs(tree.args) else tree.args.mapconserve(typedType(_))
record("typedTypeApply")

typedExpr(tree.fun, PolyProto(typedArgs, pt)) match {
case fun: TypeApply if !ctx.isAfterTyper =>
val function = fun.fun
Expand Down
Loading

0 comments on commit 02b60bf

Please sign in to comment.