Skip to content

Commit

Permalink
[experimental feature] Support HOAS pattern with type variables for q…
Browse files Browse the repository at this point in the history
…uote pattern matching (#18271)

This PR extends higher-order patterns inside quote patterns to allow
type parameters. When this PR is merged, we'll be able to write quote
patterns like the following example with an experimental flag
`experimental.quotedPatternsWithPolymorphicFunctions`.

```scala
def decomposePoly(x: Expr[Any])(using Quotes): Expr[Int] =
  x match
    case '{ [A] => (x: List[A]) => $y[A](x) : Int } =>
      '{ $y[Int](List(1, 2, 3)) }
    case _ => Expr(0)
```

You can see that the higher-order pattern `$y[A](x)` carries an type
parameter `A`. It states that this pattern matches a code fragment with
occurrences of `A`, and `y` is assigned a polymorphic function type `[A]
=> List[A] => x`.

This PR mainly changes two parts: type checker and quote pattern
matcher. Those changes are based on the formalized type system defined
in [Nicolas Stucki's thesis](https://github.com/nicolasstucki#thesis),
and one can expect the soundness of the implementation.

## Type Dependency
If a higher-order pattern carries a value parameter with a type that has
type parameters defined in the quoted pattern, those type parameters
should also be captured in the higher-order pattern. For example, the
following pattern will not be typed.

```
case '{ [A] => (x: List[A]) => $y(x) : Int } =>
```

In this case, `x` has the type `List[A]`, which includes a type variable
`A` that is defined in the pattern. However, the higher-order pattern
`$y(x)` does not have any type parameters. This should be ill-typed. One
can always avoid this kind of type errors by adding type parameters,
like `$y[A](x)`

## Implementation Restriction
Current implementation only allows type parameters that do not have
bounds, because sound typing rules for such pattern is not clear yet.

```scala
case '{ [A] => (x: List[A]) => $y(x) : Int } => // Allowed
case '{ [A <: Int] => (x: List[A]) => $y(x) : Int } => // Disallowed
  • Loading branch information
jchyb authored Jul 22, 2024
2 parents af933c4 + 41e2d52 commit 4c9cf0a
Show file tree
Hide file tree
Showing 38 changed files with 624 additions and 125 deletions.
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2218,7 +2218,7 @@ object desugar {
case Quote(body, _) =>
new UntypedTreeTraverser {
def traverse(tree: untpd.Tree)(using Context): Unit = tree match {
case SplicePattern(body, _) => collect(body)
case SplicePattern(body, _, _) => collect(body)
case _ => traverseChildren(tree)
}
}.traverse(body)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
}
private object quotePatVars extends TreeAccumulator[List[Symbol]] {
def apply(syms: List[Symbol], tree: Tree)(using Context) = tree match {
case SplicePattern(pat, _) => outer.apply(syms, pat)
case SplicePattern(pat, _, _) => outer.apply(syms, pat)
case _ => foldOver(syms, tree)
}
}
Expand Down
17 changes: 9 additions & 8 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -763,9 +763,10 @@ object Trees {
* `SplicePattern` can only be contained within a `QuotePattern`.
*
* @param body The tree that was spliced
* @param typeargs The type arguments of the splice (the HOAS arguments)
* @param args The arguments of the splice (the HOAS arguments)
*/
case class SplicePattern[+T <: Untyped] private[ast] (body: Tree[T], args: List[Tree[T]])(implicit @constructorOnly src: SourceFile)
case class SplicePattern[+T <: Untyped] private[ast] (body: Tree[T], typeargs: List[Tree[T]], args: List[Tree[T]])(implicit @constructorOnly src: SourceFile)
extends TermTree[T] {
type ThisTree[+T <: Untyped] = SplicePattern[T]
}
Expand Down Expand Up @@ -1372,9 +1373,9 @@ object Trees {
case tree: QuotePattern if (bindings eq tree.bindings) && (body eq tree.body) && (quotes eq tree.quotes) => tree
case _ => finalize(tree, untpd.QuotePattern(bindings, body, quotes)(sourceFile(tree)))
}
def SplicePattern(tree: Tree)(body: Tree, args: List[Tree])(using Context): SplicePattern = tree match {
case tree: SplicePattern if (body eq tree.body) && (args eq tree.args) => tree
case _ => finalize(tree, untpd.SplicePattern(body, args)(sourceFile(tree)))
def SplicePattern(tree: Tree)(body: Tree, typeargs: List[Tree], args: List[Tree])(using Context): SplicePattern = tree match {
case tree: SplicePattern if (body eq tree.body) && (typeargs eq tree.typeargs) & (args eq tree.args) => tree
case _ => finalize(tree, untpd.SplicePattern(body, typeargs, args)(sourceFile(tree)))
}
def SingletonTypeTree(tree: Tree)(ref: Tree)(using Context): SingletonTypeTree = tree match {
case tree: SingletonTypeTree if (ref eq tree.ref) => tree
Expand Down Expand Up @@ -1622,8 +1623,8 @@ object Trees {
cpy.Splice(tree)(transform(expr)(using spliceContext))
case tree @ QuotePattern(bindings, body, quotes) =>
cpy.QuotePattern(tree)(transform(bindings), transform(body)(using quoteContext), transform(quotes))
case tree @ SplicePattern(body, args) =>
cpy.SplicePattern(tree)(transform(body)(using spliceContext), transform(args))
case tree @ SplicePattern(body, targs, args) =>
cpy.SplicePattern(tree)(transform(body)(using spliceContext), transform(targs), transform(args))
case tree @ Hole(isTerm, idx, args, content) =>
cpy.Hole(tree)(isTerm, idx, transform(args), transform(content))
case _ =>
Expand Down Expand Up @@ -1771,8 +1772,8 @@ object Trees {
this(x, expr)(using spliceContext)
case QuotePattern(bindings, body, quotes) =>
this(this(this(x, bindings), body)(using quoteContext), quotes)
case SplicePattern(body, args) =>
this(this(x, body)(using spliceContext), args)
case SplicePattern(body, typeargs, args) =>
this(this(this(x, body)(using spliceContext), typeargs), args)
case Hole(_, _, args, content) =>
this(this(x, args), content)
case _ =>
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
def Splice(expr: Tree)(using Context): Splice =
ta.assignType(untpd.Splice(expr), expr)

def SplicePattern(pat: Tree, args: List[Tree], tpe: Type)(using Context): SplicePattern =
untpd.SplicePattern(pat, args).withType(tpe)
def SplicePattern(pat: Tree, targs: List[Tree], args: List[Tree], tpe: Type)(using Context): SplicePattern =
untpd.SplicePattern(pat, targs, args).withType(tpe)

def Hole(isTerm: Boolean, idx: Int, args: List[Tree], content: Tree, tpe: Type)(using Context): Hole =
untpd.Hole(isTerm, idx, args, content).withType(tpe)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def Quote(body: Tree, tags: List[Tree])(implicit src: SourceFile): Quote = new Quote(body, tags)
def Splice(expr: Tree)(implicit src: SourceFile): Splice = new Splice(expr)
def QuotePattern(bindings: List[Tree], body: Tree, quotes: Tree)(implicit src: SourceFile): QuotePattern = new QuotePattern(bindings, body, quotes)
def SplicePattern(body: Tree, args: List[Tree])(implicit src: SourceFile): SplicePattern = new SplicePattern(body, args)
def SplicePattern(body: Tree, typeargs: List[Tree], args: List[Tree])(implicit src: SourceFile): SplicePattern = new SplicePattern(body, typeargs, args)
def TypeTree()(implicit src: SourceFile): TypeTree = new TypeTree()
def InferredTypeTree()(implicit src: SourceFile): TypeTree = new InferredTypeTree()
def SingletonTypeTree(ref: Tree)(implicit src: SourceFile): SingletonTypeTree = new SingletonTypeTree(ref)
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ object Feature:
val namedTuples = experimental("namedTuples")
val modularity = experimental("modularity")
val betterMatchTypeExtractors = experimental("betterMatchTypeExtractors")
val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions")

def experimentalAutoEnableFeatures(using Context): List[TermName] =
defn.languageExperimentalFeatures
Expand Down Expand Up @@ -130,6 +131,9 @@ object Feature:

def betterMatchTypeExtractorsEnabled(using Context) = enabled(betterMatchTypeExtractors)

def quotedPatternsWithPolymorphicFunctionsEnabled(using Context) =
enabled(quotedPatternsWithPolymorphicFunctions)

/** Is pureFunctions enabled for this compilation unit? */
def pureFunsEnabled(using Context) =
enabledBySetting(pureFunctions)
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,7 @@ class Definitions {
@tu lazy val QuotedRuntimePatterns: Symbol = requiredModule("scala.quoted.runtime.Patterns")
@tu lazy val QuotedRuntimePatterns_patternHole: Symbol = QuotedRuntimePatterns.requiredMethod("patternHole")
@tu lazy val QuotedRuntimePatterns_higherOrderHole: Symbol = QuotedRuntimePatterns.requiredMethod("higherOrderHole")
@tu lazy val QuotedRuntimePatterns_higherOrderHoleWithTypes: Symbol = QuotedRuntimePatterns.requiredMethod("higherOrderHoleWithTypes")
@tu lazy val QuotedRuntimePatterns_patternTypeAnnot: ClassSymbol = QuotedRuntimePatterns.requiredClass("patternType")
@tu lazy val QuotedRuntimePatterns_fromAboveAnnot: ClassSymbol = QuotedRuntimePatterns.requiredClass("fromAbove")

Expand Down
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -776,8 +776,7 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
pickleType(tree.tpe)
bindings.foreach(pickleTree)
}
case SplicePattern(pat, args) =>
val targs = Nil // SplicePattern `targs` will be added with #18271
case SplicePattern(pat, targs, args) =>
writeByte(SPLICEPATTERN)
withLength {
pickleTree(pat)
Expand Down
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1668,8 +1668,7 @@ class TreeUnpickler(reader: TastyReader,
val pat = readTree()
val patType = readType()
val (targs, args) = until(end)(readTree()).span(_.isType)
assert(targs.isEmpty, "unexpected type arguments in SPLICEPATTERN") // `targs` will be needed for #18271. Until this fearure is added they should be empty.
SplicePattern(pat, args, patType)
SplicePattern(pat, targs, args, patType)
case HOLE =>
readHole(end, isTerm = true)
case _ =>
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1966,7 +1966,7 @@ object Parsers {
syntaxError(em"$msg\n\nHint: $hint", Span(start, in.lastOffset))
Ident(nme.ERROR.toTypeName)
else if inPattern then
SplicePattern(expr, Nil)
SplicePattern(expr, Nil, Nil)
else
Splice(expr)
}
Expand Down
7 changes: 4 additions & 3 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -793,11 +793,12 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
val open = if (body.isTerm) keywordStr("{") else keywordStr("[")
val close = if (body.isTerm) keywordStr("}") else keywordStr("]")
keywordStr("'") ~ quotesText ~ open ~ bindingsText ~ toTextGlobal(body) ~ close
case SplicePattern(pattern, args) =>
case SplicePattern(pattern, typeargs, args) =>
val spliceTypeText = (keywordStr("[") ~ toTextGlobal(tree.typeOpt) ~ keywordStr("]")).provided(printDebug && tree.typeOpt.exists)
keywordStr("$") ~ spliceTypeText ~ {
if args.isEmpty then keywordStr("{") ~ inPattern(toText(pattern)) ~ keywordStr("}")
else toText(pattern) ~ "(" ~ toTextGlobal(args, ", ") ~ ")"
if typeargs.isEmpty && args.isEmpty then keywordStr("{") ~ inPattern(toText(pattern)) ~ keywordStr("}")
else if typeargs.isEmpty then toText(pattern) ~ "(" ~ toTextGlobal(args, ", ") ~ ")"
else toText(pattern) ~ "[" ~ toTextGlobal(typeargs, ", ")~ "]" ~ "(" ~ toTextGlobal(args, ", ") ~ ")"
}
case Hole(isTerm, idx, args, content) =>
val (prefix, postfix) = if isTerm then ("{{{", "}}}") else ("[[[", "]]]")
Expand Down
Loading

0 comments on commit 4c9cf0a

Please sign in to comment.