diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 6024eab29722..e70f029f65a7 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -2028,7 +2028,7 @@ object desugar { case Quote(body, _) => new UntypedTreeTraverser { def traverse(tree: untpd.Tree)(using Context): Unit = tree match { - case SplicePattern(body, _) => collect(body) + case SplicePattern(body, _, _) => collect(body) case _ => traverseChildren(tree) } }.traverse(body) diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index 90c8211b3b60..7e8aa870de3a 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -852,7 +852,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => } private object quotePatVars extends TreeAccumulator[List[Symbol]] { def apply(syms: List[Symbol], tree: Tree)(using Context) = tree match { - case SplicePattern(pat, _) => outer.apply(syms, pat) + case SplicePattern(pat, _, _) => outer.apply(syms, pat) case _ => foldOver(syms, tree) } } diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index 2a66eda068c9..e4e997b4b055 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -760,9 +760,10 @@ object Trees { * `SplicePattern` can only be contained within a `QuotePattern`. * * @param body The tree that was spliced + * @param typeargs The type arguments of the splice (the HOAS arguments) * @param args The arguments of the splice (the HOAS arguments) */ - case class SplicePattern[+T <: Untyped] private[ast] (body: Tree[T], args: List[Tree[T]])(implicit @constructorOnly src: SourceFile) + case class SplicePattern[+T <: Untyped] private[ast] (body: Tree[T], typeargs: List[Tree[T]], args: List[Tree[T]])(implicit @constructorOnly src: SourceFile) extends TermTree[T] { type ThisTree[+T <: Untyped] = SplicePattern[T] } @@ -1367,9 +1368,9 @@ object Trees { case tree: QuotePattern if (bindings eq tree.bindings) && (body eq tree.body) && (quotes eq tree.quotes) => tree case _ => finalize(tree, untpd.QuotePattern(bindings, body, quotes)(sourceFile(tree))) } - def SplicePattern(tree: Tree)(body: Tree, args: List[Tree])(using Context): SplicePattern = tree match { - case tree: SplicePattern if (body eq tree.body) && (args eq tree.args) => tree - case _ => finalize(tree, untpd.SplicePattern(body, args)(sourceFile(tree))) + def SplicePattern(tree: Tree)(body: Tree, typeargs: List[Tree], args: List[Tree])(using Context): SplicePattern = tree match { + case tree: SplicePattern if (body eq tree.body) && (typeargs eq tree.typeargs) & (args eq tree.args) => tree + case _ => finalize(tree, untpd.SplicePattern(body, typeargs, args)(sourceFile(tree))) } def SingletonTypeTree(tree: Tree)(ref: Tree)(using Context): SingletonTypeTree = tree match { case tree: SingletonTypeTree if (ref eq tree.ref) => tree @@ -1617,8 +1618,8 @@ object Trees { cpy.Splice(tree)(transform(expr)(using spliceContext)) case tree @ QuotePattern(bindings, body, quotes) => cpy.QuotePattern(tree)(transform(bindings), transform(body)(using quoteContext), transform(quotes)) - case tree @ SplicePattern(body, args) => - cpy.SplicePattern(tree)(transform(body)(using spliceContext), transform(args)) + case tree @ SplicePattern(body, targs, args) => + cpy.SplicePattern(tree)(transform(body)(using spliceContext), transform(targs), transform(args)) case tree @ Hole(isTerm, idx, args, content) => cpy.Hole(tree)(isTerm, idx, transform(args), transform(content)) case _ => @@ -1766,8 +1767,8 @@ object Trees { this(x, expr)(using spliceContext) case QuotePattern(bindings, body, quotes) => this(this(this(x, bindings), body)(using quoteContext), quotes) - case SplicePattern(body, args) => - this(this(x, body)(using spliceContext), args) + case SplicePattern(body, typeargs, args) => + this(this(this(x, body)(using spliceContext), typeargs), args) case Hole(_, _, args, content) => this(this(x, args), content) case _ => diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index e7d38da854a4..e9f00cc098b5 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -412,7 +412,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { def Quote(body: Tree, tags: List[Tree])(implicit src: SourceFile): Quote = new Quote(body, tags) def Splice(expr: Tree)(implicit src: SourceFile): Splice = new Splice(expr) def QuotePattern(bindings: List[Tree], body: Tree, quotes: Tree)(implicit src: SourceFile): QuotePattern = new QuotePattern(bindings, body, quotes) - def SplicePattern(body: Tree, args: List[Tree])(implicit src: SourceFile): SplicePattern = new SplicePattern(body, args) + def SplicePattern(body: Tree, typeargs: List[Tree], args: List[Tree])(implicit src: SourceFile): SplicePattern = new SplicePattern(body, typeargs, args) def TypeTree()(implicit src: SourceFile): TypeTree = new TypeTree() def InferredTypeTree()(implicit src: SourceFile): TypeTree = new InferredTypeTree() def SingletonTypeTree(ref: Tree)(implicit src: SourceFile): SingletonTypeTree = new SingletonTypeTree(ref) diff --git a/compiler/src/dotty/tools/dotc/config/Feature.scala b/compiler/src/dotty/tools/dotc/config/Feature.scala index f1443ad56442..e8b2806af34b 100644 --- a/compiler/src/dotty/tools/dotc/config/Feature.scala +++ b/compiler/src/dotty/tools/dotc/config/Feature.scala @@ -32,6 +32,7 @@ object Feature: val pureFunctions = experimental("pureFunctions") val captureChecking = experimental("captureChecking") val into = experimental("into") + val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions") val globalOnlyImports: Set[TermName] = Set(pureFunctions, captureChecking) @@ -83,6 +84,9 @@ object Feature: def scala2ExperimentalMacroEnabled(using Context) = enabled(scala2macros) + def quotedPatternsWithPolymorphicFunctionsEnabled(using Context) = + enabled(quotedPatternsWithPolymorphicFunctions) + /** Is pureFunctions enabled for this compilation unit? */ def pureFunsEnabled(using Context) = enabledBySetting(pureFunctions) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index edd054375b05..ceea46c38dfd 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -884,6 +884,7 @@ class Definitions { @tu lazy val QuotedRuntimePatterns: Symbol = requiredModule("scala.quoted.runtime.Patterns") @tu lazy val QuotedRuntimePatterns_patternHole: Symbol = QuotedRuntimePatterns.requiredMethod("patternHole") @tu lazy val QuotedRuntimePatterns_higherOrderHole: Symbol = QuotedRuntimePatterns.requiredMethod("higherOrderHole") + @tu lazy val QuotedRuntimePatterns_higherOrderHoleWithTypes: Symbol = QuotedRuntimePatterns.requiredMethod("higherOrderHoleWithTypes") @tu lazy val QuotedRuntimePatterns_patternTypeAnnot: ClassSymbol = QuotedRuntimePatterns.requiredClass("patternType") @tu lazy val QuotedRuntimePatterns_fromAboveAnnot: ClassSymbol = QuotedRuntimePatterns.requiredClass("fromAbove") diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 05e5c34b5a0f..b0aa8aa11922 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1802,7 +1802,7 @@ object Parsers { syntaxError(em"$msg\n\nHint: $hint", Span(start, in.lastOffset)) Ident(nme.ERROR.toTypeName) else if inPattern then - SplicePattern(expr, Nil) + SplicePattern(expr, Nil, Nil) else Splice(expr) } diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index 7465b5c60aa3..530c62860c2e 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -752,11 +752,12 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { val open = if (body.isTerm) keywordStr("{") else keywordStr("[") val close = if (body.isTerm) keywordStr("}") else keywordStr("]") keywordStr("'") ~ quotesText ~ open ~ bindingsText ~ toTextGlobal(body) ~ close - case SplicePattern(pattern, args) => + case SplicePattern(pattern, typeargs, args) => val spliceTypeText = (keywordStr("[") ~ toTextGlobal(tree.typeOpt) ~ keywordStr("]")).provided(printDebug && tree.typeOpt.exists) keywordStr("$") ~ spliceTypeText ~ { - if args.isEmpty then keywordStr("{") ~ inPattern(toText(pattern)) ~ keywordStr("}") - else toText(pattern.symbol.name) ~ "(" ~ toTextGlobal(args, ", ") ~ ")" + if typeargs.isEmpty && args.isEmpty then keywordStr("{") ~ inPattern(toText(pattern)) ~ keywordStr("}") + else if typeargs.isEmpty then toText(pattern.symbol.name) ~ "(" ~ toTextGlobal(args, ", ") ~ ")" + else toText(pattern.symbol.name) ~ "[" ~ toTextGlobal(typeargs, ", ")~ "]" ~ "(" ~ toTextGlobal(args, ", ") ~ ")" } case Hole(isTerm, idx, args, content) => val (prefix, postfix) = if isTerm then ("{{{", "}}}") else ("[[[", "]]]") diff --git a/compiler/src/dotty/tools/dotc/quoted/QuotePatterns.scala b/compiler/src/dotty/tools/dotc/quoted/QuotePatterns.scala index 48884f6b2d6e..1c054759846e 100644 --- a/compiler/src/dotty/tools/dotc/quoted/QuotePatterns.scala +++ b/compiler/src/dotty/tools/dotc/quoted/QuotePatterns.scala @@ -26,31 +26,90 @@ object QuotePatterns: import tpd._ /** Check for restricted patterns */ - def checkPattern(quotePattern: QuotePattern)(using Context): Unit = new tpd.TreeTraverser { - def traverse(tree: Tree)(using Context): Unit = tree match { - case _: SplicePattern => - case tdef: TypeDef if tdef.symbol.isClass => - val kind = if tdef.symbol.is(Module) then "objects" else "classes" - report.error(em"Implementation restriction: cannot match $kind", tree.srcPos) - case tree: NamedDefTree => - if tree.name.is(NameKinds.WildcardParamName) then - report.warning( - "Use of `_` for lambda in quoted pattern. Use explicit lambda instead or use `$_` to match any term.", - tree.srcPos) - if tree.name.isTermName && !tree.nameSpan.isSynthetic && tree.name != nme.ANON_FUN && tree.name.startsWith("$") then - report.error("Names cannot start with $ quote pattern", tree.namePos) - traverseChildren(tree) - case _: Match => - report.error("Implementation restriction: cannot match `match` expressions", tree.srcPos) - case _: Try => - report.error("Implementation restriction: cannot match `try` expressions", tree.srcPos) - case _: Return => - report.error("Implementation restriction: cannot match `return` statements", tree.srcPos) - case _ => - traverseChildren(tree) - } + def checkPattern(quotePattern: QuotePattern)(using Context): Unit = + def validatePatternAndCollectTypeVars(): Set[Symbol] = new tpd.TreeAccumulator[Set[Symbol]] { + override def apply(typevars: Set[Symbol], tree: tpd.Tree)(using Context): Set[Symbol] = + // Collect type variables + val typevars1 = tree match + case tree @ DefDef(_, paramss, _, _) => + typevars union paramss.flatMap{ params => params match + case TypeDefs(tdefs) => tdefs.map(_.symbol) + case _ => List.empty + }.toSet union typevars + case _ => typevars + + // Validate pattern + tree match + case _: SplicePattern => typevars1 + case tdef: TypeDef if tdef.symbol.isClass => + val kind = if tdef.symbol.is(Module) then "objects" else "classes" + report.error(em"Implementation restriction: cannot match $kind", tree.srcPos) + typevars1 + case tree: NamedDefTree => + if tree.name.is(NameKinds.WildcardParamName) then + report.warning( + "Use of `_` for lambda in quoted pattern. Use explicit lambda instead or use `$_` to match any term.", + tree.srcPos) + if tree.name.isTermName && !tree.nameSpan.isSynthetic && tree.name != nme.ANON_FUN && tree.name.startsWith("$") then + report.error("Names cannot start with $ quote pattern", tree.namePos) + foldOver(typevars1, tree) + case _: Match => + report.error("Implementation restriction: cannot match `match` expressions", tree.srcPos) + typevars1 + case _: Try => + report.error("Implementation restriction: cannot match `try` expressions", tree.srcPos) + typevars1 + case _: Return => + report.error("Implementation restriction: cannot match `return` statements", tree.srcPos) + typevars1 + case _ => + foldOver(typevars1, tree) + }.apply(Set.empty, quotePattern.body) + + val boundTypeVars = validatePatternAndCollectTypeVars() - }.traverse(quotePattern.body) + /* + * This part checks well-formedness of arguments to hoas patterns. + * (1) Type arguments of a hoas patterns must be introduced in the quote pattern.ctxShow + * Examples + * well-formed: '{ [A] => (x : A) => $a[A](x) } // A is introduced in the quote pattern + * ill-formed: '{ (x : Int) => $a[Int](x) } // Int is defined outside of the quote pattern + * (2) If value arguments of a hoas pattern has a type with type variables that are introduced in + * the quote pattern, those type variables should be in type arguments to the hoas patternHole + * Examples + * well-formed: '{ [A] => (x : A) => $a[A](x) } // a : [A] => (x:A) => A + * ill-formed: '{ [A] => (x : A) => $a(x) } // a : (x:A) => A ...but A is undefined; hence ill-formed + */ + new tpd.TreeTraverser { + override def traverse(tree: tpd.Tree)(using Context): Unit = tree match { + case tree: SplicePattern => + def uncapturedTypeVars(arg: tpd.Tree, capturedTypeVars: List[tpd.Tree]): Set[Type] = + /* Sometimes arg is untyped when a splice pattern is ill-formed. + * Return early in such case. + * Refer to QuoteAndSplices::typedSplicePattern + */ + if !arg.hasType then return Set.empty + + val capturedTypeVarsSet = capturedTypeVars.map(_.symbol).toSet + new TypeAccumulator[Set[Type]] { + def apply(x: Set[Type], tp: Type): Set[Type] = + if boundTypeVars.contains(tp.typeSymbol) && !capturedTypeVarsSet.contains(tp.typeSymbol) then + foldOver(x + tp, tp) + else + foldOver(x, tp) + }.apply(Set.empty, arg.tpe) + + for (typearg <- tree.typeargs) // case (1) + do + if !boundTypeVars.contains(typearg.symbol) then + report.error("Type arguments of a hoas pattern needs to be defined inside the quoted pattern", typearg.srcPos) + for (arg <- tree.args) // case (2) + do + if !uncapturedTypeVars(arg, tree.typeargs).isEmpty then + report.error("Type variables that this argument depends on are not captured in this hoas pattern", arg.srcPos) + case _ => traverseChildren(tree) + } + }.traverse(quotePattern.body) /** Encode the quote pattern into an `unapply` that the pattern matcher can handle. * @@ -74,7 +133,7 @@ object QuotePatterns: * .ExprMatch // or TypeMatch * .unapply[ * KCons[t1 >: l1 <: b1, ...KCons[tn >: ln <: bn, KNil]...], // scala.quoted.runtime.{KCons, KNil} - * (T1, T2, (A1, ..., An) => T3, ...) + * (Expr[T1], Expr[T2], Expr[(A1, ..., An) => T3], ...) * ]( * '{ * type t1' >: l1' <: b1' @@ -197,16 +256,24 @@ object QuotePatterns: val patBuf = new mutable.ListBuffer[Tree] val shape = new tpd.TreeMap { override def transform(tree: Tree)(using Context) = tree match { - case Typed(splice @ SplicePattern(pat, Nil), tpt) if !tpt.tpe.derivesFrom(defn.RepeatedParamClass) => + case Typed(splice @ SplicePattern(pat, Nil, Nil), tpt) if !tpt.tpe.derivesFrom(defn.RepeatedParamClass) => transform(tpt) // Collect type bindings transform(splice) - case SplicePattern(pat, args) => + case SplicePattern(pat, typeargs, args) => val patType = pat.tpe.widen val patType1 = patType.translateFromRepeated(toArray = false) val pat1 = if (patType eq patType1) pat else pat.withType(patType1) patBuf += pat1 - if args.isEmpty then ref(defn.QuotedRuntimePatterns_patternHole.termRef).appliedToType(tree.tpe).withSpan(tree.span) - else ref(defn.QuotedRuntimePatterns_higherOrderHole.termRef).appliedToType(tree.tpe).appliedTo(SeqLiteral(args, TypeTree(defn.AnyType))).withSpan(tree.span) + if typeargs.isEmpty && args.isEmpty then ref(defn.QuotedRuntimePatterns_patternHole.termRef).appliedToType(tree.tpe).withSpan(tree.span) + else if typeargs.isEmpty then + ref(defn.QuotedRuntimePatterns_higherOrderHole.termRef) + .appliedToType(tree.tpe) + .appliedTo(SeqLiteral(args, TypeTree(defn.AnyType))) + .withSpan(tree.span) + else ref(defn.QuotedRuntimePatterns_higherOrderHoleWithTypes.termRef) + .appliedToTypeTrees(List(TypeTree(tree.tpe), tpd.hkNestedPairsTypeTree(typeargs))) + .appliedTo(SeqLiteral(args, TypeTree(defn.AnyType))) + .withSpan(tree.span) case _ => super.transform(tree) } @@ -232,7 +299,7 @@ object QuotePatterns: fun match // .asInstanceOf[QuoteMatching].{ExprMatch,TypeMatch}.unapply[, ] case TypeApply(Select(Select(TypeApply(Select(quotes, _), _), _), _), typeBindings :: resTypes :: Nil) => - val bindings = unrollBindings(typeBindings) + val bindings = unrollHkNestedPairsTypeTree(typeBindings) val addPattenSplice = new TreeMap { private val patternIterator = patterns.iterator.filter { case pat: Bind => !pat.symbol.name.is(PatMatGivenVarName) @@ -240,9 +307,11 @@ object QuotePatterns: } override def transform(tree: tpd.Tree)(using Context): tpd.Tree = tree match case TypeApply(patternHole, _) if patternHole.symbol == defn.QuotedRuntimePatterns_patternHole => - cpy.SplicePattern(tree)(patternIterator.next(), Nil) + cpy.SplicePattern(tree)(patternIterator.next(), Nil, Nil) case Apply(patternHole, SeqLiteral(args, _) :: Nil) if patternHole.symbol == defn.QuotedRuntimePatterns_higherOrderHole => - cpy.SplicePattern(tree)(patternIterator.next(), args) + cpy.SplicePattern(tree)(patternIterator.next(), Nil, args) + case Apply(TypeApply(patternHole, List(_, targsTpe)), SeqLiteral(args, _) :: Nil) if patternHole.symbol == defn.QuotedRuntimePatterns_higherOrderHoleWithTypes => + cpy.SplicePattern(tree)(patternIterator.next(), unrollHkNestedPairsTypeTree(targsTpe), args) case _ => super.transform(tree) } val body = addPattenSplice.transform(shape) match @@ -260,7 +329,7 @@ object QuotePatterns: case body => body cpy.QuotePattern(tree)(bindings, body, quotes) - private def unrollBindings(tree: Tree)(using Context): List[Tree] = tree match + private def unrollHkNestedPairsTypeTree(tree: Tree)(using Context): List[Tree] = tree match case AppliedTypeTree(tupleN, bindings) if defn.isTupleClass(tupleN.symbol) => bindings // TupleN, 1 <= N <= 22 - case AppliedTypeTree(_, head :: tail :: Nil) => head :: unrollBindings(tail) // KCons or *: + case AppliedTypeTree(_, head :: tail :: Nil) => head :: unrollHkNestedPairsTypeTree(tail) // KCons or *: case _ => Nil // KNil or EmptyTuple diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 9501e51aeb6f..1c15b8a26c0e 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -1114,6 +1114,8 @@ trait Applications extends Compatibility { } else { val app = tree.fun match + case untpd.TypeApply(_: untpd.SplicePattern, _) if Feature.quotedPatternsWithPolymorphicFunctionsEnabled => + typedAppliedSpliceWithTypes(tree, pt) case _: untpd.SplicePattern => typedAppliedSplice(tree, pt) case _ => realApply app match { @@ -1164,9 +1166,16 @@ trait Applications extends Compatibility { if (ctx.mode.is(Mode.Pattern)) return errorTree(tree, em"invalid pattern") + tree.fun match { + case _: untpd.SplicePattern if Feature.quotedPatternsWithPolymorphicFunctionsEnabled => + return errorTree(tree, em"Implementation restriction: A higher-order pattern must carry value arguments") + case _ => + } + val isNamed = hasNamedArg(tree.args) val typedArgs = if (isNamed) typedNamedArgs(tree.args) else tree.args.mapconserve(typedType(_)) record("typedTypeApply") + typedExpr(tree.fun, PolyProto(typedArgs, pt)) match { case fun: TypeApply if !ctx.isAfterTyper => val function = fun.fun diff --git a/compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala b/compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala index bda2c25c26b8..9bacfd76a00f 100644 --- a/compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala +++ b/compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala @@ -119,14 +119,31 @@ trait QuotesAndSplices { EmptyTree } } + val typedTypeargs = tree.typeargs.map { + case typearg: untpd.Ident => + val typedTypearg = typedType(typearg) + val bounds = ctx.gadt.fullBounds(typedTypearg.symbol) + if bounds != null && bounds != TypeBounds.empty then + report.error("Implementation restriction: Type arguments to Open pattern are expected to have no bounds", typearg.srcPos) + typedTypearg + case arg => + report.error("Open pattern expected an identifier", arg.srcPos) + EmptyTree + } for arg <- typedArgs if arg.symbol.is(Mutable) do // TODO support these patterns. Possibly using scala.quoted.util.Var report.error("References to `var`s cannot be used in higher-order pattern", arg.srcPos) val argTypes = typedArgs.map(_.tpe.widenTermRefExpr) - val patType = if tree.args.isEmpty then pt else defn.FunctionOf(argTypes, pt) + val patType = (tree.typeargs.isEmpty, tree.args.isEmpty) match + case (true, true) => pt + case (true, false) => + defn.FunctionOf(argTypes, pt) + case (false, _) => + PolyFunctionOf(typedTypeargs.tpes, argTypes, pt) + val pat = typedPattern(tree.body, defn.QuotedExprClass.typeRef.appliedTo(patType))(using quotePatternSpliceContext) val baseType = pat.tpe.baseType(defn.QuotedExprClass) val argType = if baseType.exists then baseType.argTypesHi.head else defn.NothingType - untpd.cpy.SplicePattern(tree)(pat, typedArgs).withType(pt) + untpd.cpy.SplicePattern(tree)(pat, typedTypeargs, typedArgs).withType(pt) else errorTree(tree, em"Type must be fully defined.\nConsider annotating the splice using a type ascription:\n ($tree: XYZ).", tree.body.srcPos) } @@ -153,7 +170,34 @@ trait QuotesAndSplices { else // $x(...) higher-order quasipattern if args.isEmpty then report.error("Missing arguments for open pattern", tree.srcPos) - typedSplicePattern(untpd.cpy.SplicePattern(tree)(splice.body, args), pt) + typedSplicePattern(untpd.cpy.SplicePattern(tree)(splice.body, Nil, args), pt) + } + + /** Types a splice applied to some type arguments and arguments + * `$f[targs1, ..., targsn](arg1, ..., argn)` in a quote pattern. + * + * Refer to: typedAppliedSplice + */ + def typedAppliedSpliceWithTypes(tree: untpd.Apply, pt: Type)(using Context): Tree = { + assert(ctx.mode.isQuotedPattern) + val untpd.Apply(typeApplyTree @ untpd.TypeApply(splice: untpd.SplicePattern, typeargs), args) = tree: @unchecked + def isInBraces: Boolean = splice.span.end != splice.body.span.end + if isInBraces then // ${x}[...](...) match an application + val typedTypeargs = typeargs.map(arg => typedType(arg)) + val typedArgs = args.map(arg => typedExpr(arg)) + val argTypes = typedArgs.map(_.tpe.widenTermRefExpr) + val splice1 = typedSplicePattern(splice, ProtoTypes.PolyProto(typedArgs, defn.FunctionOf(argTypes, pt))) + val typedTypeApply = untpd.cpy.TypeApply(typeApplyTree)(splice1.select(nme.apply), typedTypeargs) + untpd.cpy.Apply(tree)(typedTypeApply, typedArgs).withType(pt) + else // $x[...](...) higher-order quasipattern + // Empty args is allowed + if typeargs.isEmpty then + report.error("Missing type arguments for open pattern", tree.srcPos) + typedSplicePattern(untpd.cpy.SplicePattern(tree)(splice.body, typeargs, args), pt) + } + + def typedTypeAppliedSplice(tree: untpd.TypeApply, pt: Type)(using Context): Tree = { + typedAppliedSpliceWithTypes(untpd.Apply(tree, Nil), pt) } /** Type check a type binding reference in a quoted pattern. @@ -322,4 +366,22 @@ object QuotesAndSplices { case _ => super.transform(tree) end TreeMapWithVariance + + object PolyFunctionOf { + /** + * Return a poly-type + method type [$typeargs] => ($args) => ($resultType) + * where typeargs occur in args and resulttype + */ + def apply(typeargs: List[Type], args: List[Type], resultType: Type)(using Context): Type = + val typeargs1 = PolyType.syntheticParamNames(typeargs.length) + + val bounds = typeargs map (_ => TypeBounds.empty) + val resultTypeExp = (pt: PolyType) => { + val fromSymbols = typeargs map (_.typeSymbol) + val args1 = args map (_.subst(fromSymbols, pt.paramRefs)) + val resultType1 = resultType.subst(fromSymbols, pt.paramRefs) + MethodType(args1, resultType1) + } + defn.PolyFunctionOf(PolyType(typeargs1)(_ => bounds, resultTypeExp)) + } } diff --git a/compiler/src/dotty/tools/dotc/typer/ReTyper.scala b/compiler/src/dotty/tools/dotc/typer/ReTyper.scala index 20dfe07c3be5..0b1450dafe89 100644 --- a/compiler/src/dotty/tools/dotc/typer/ReTyper.scala +++ b/compiler/src/dotty/tools/dotc/typer/ReTyper.scala @@ -130,14 +130,15 @@ class ReTyper(nestingLevel: Int = 0) extends Typer(nestingLevel) with ReChecking override def typedSplicePattern(tree: untpd.SplicePattern, pt: Type)(using Context): Tree = assertTyped(tree) + val typeargs1 = tree.typeargs.mapconserve(typedType(_)) val args1 = tree.args.mapconserve(typedExpr(_)) val patternTpe = - if args1.isEmpty then tree.typeOpt + if !typeargs1.isEmpty then QuotesAndSplices.PolyFunctionOf(typeargs1.map(_.tpe), args1.map(_.tpe), tree.typeOpt) + else if args1.isEmpty then tree.typeOpt else defn.FunctionType(args1.size).appliedTo(args1.map(_.tpe) :+ tree.typeOpt) val bodyCtx = spliceContext.addMode(Mode.Pattern).retractMode(Mode.QuotedPatternBits) val body1 = typed(tree.body, defn.QuotedExprClass.typeRef.appliedTo(patternTpe))(using bodyCtx) - val args = tree.args.mapconserve(typedExpr(_)) - untpd.cpy.SplicePattern(tree)(body1, args1).withType(tree.typeOpt) + untpd.cpy.SplicePattern(tree)(body1, typeargs1, args1).withType(tree.typeOpt) override def typedHole(tree: untpd.Hole, pt: Type)(using Context): Tree = promote(tree) diff --git a/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala index 44886d59ac12..82a889278507 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala @@ -11,6 +11,7 @@ import dotty.tools.dotc.core.Types.* import dotty.tools.dotc.core.StdNames.nme import dotty.tools.dotc.core.Symbols.* import dotty.tools.dotc.util.optional +import dotty.tools.dotc.ast.TreeTypeMap /** Matches a quoted tree against a quoted pattern tree. * A quoted pattern tree may have type and term holes in addition to normal terms. @@ -112,16 +113,17 @@ class QuoteMatcher(debug: Boolean) { /** Sequence of matched expressions. * These expressions are part of the scrutinee and will be bound to the quote pattern term splices. */ - type MatchingExprs = Seq[MatchResult] + private type MatchingExprs = Seq[MatchResult] - /** A map relating equivalent symbols from the scrutinee and the pattern + /** TODO-18271: update + * A map relating equivalent symbols from the scrutinee and the pattern * For example in * ``` * '{val a = 4; a * a} match case '{ val x = 4; x * x } * ``` * when matching `a * a` with `x * x` the environment will contain `Map(a -> x)`. */ - private type Env = Map[Symbol, Symbol] + private case class Env(val termEnv: Map[Symbol, Symbol], val typeEnv: Map[Symbol, Symbol]) private def withEnv[T](env: Env)(body: Env ?=> T): T = body(using env) @@ -132,7 +134,7 @@ class QuoteMatcher(debug: Boolean) { val (pat1, typeHoles, ctx1) = instrumentTypeHoles(pattern) inContext(ctx1) { optional { - given Env = Map.empty + given Env = new Env(Map.empty, Map.empty) scrutinee =?= pat1 }.map { matchings => lazy val spliceScope = SpliceScope.getCurrent @@ -236,6 +238,26 @@ class QuoteMatcher(debug: Boolean) { case _ => None end TypeTreeTypeTest + /* Some of method symbols in arguments of higher-order term hole are eta-expanded. + * e.g. + * g: (Int) => Int + * => { + * def $anonfun(y: Int): Int = g(y) + * closure($anonfun) + * } + * + * f: (using Int) => Int + * => f(using x) + * This function restores the symbol of the original method from + * the eta-expanded function. + */ + def getCapturedIdent(arg: Tree)(using Context): Ident = + arg match + case id: Ident => id + case Apply(fun, _) => getCapturedIdent(fun) + case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs) + case Typed(expr, _) => getCapturedIdent(expr) + def runMatch(): optional[MatchingExprs] = pattern match /* Term hole */ @@ -244,14 +266,14 @@ class QuoteMatcher(debug: Boolean) { if patternHole.symbol.eq(defn.QuotedRuntimePatterns_patternHole) && tpt2.tpe.derivesFrom(defn.RepeatedParamClass) => scrutinee match - case Typed(s, tpt1) if s.tpe <:< tpt.tpe => matched(scrutinee) + case Typed(s, tpt1) if isSubTypeUnderEnv(s, tpt) => matched(scrutinee) case _ => notMatched /* Term hole */ // Match a scala.internal.Quoted.patternHole and return the scrutinee tree case TypeApply(patternHole, tpt :: Nil) if patternHole.symbol.eq(defn.QuotedRuntimePatterns_patternHole) && - scrutinee.tpe <:< tpt.tpe => + isSubTypeUnderEnv(scrutinee, tpt) => scrutinee match case ClosedPatternTerm(scrutinee) => matched(scrutinee) case _ => notMatched @@ -262,33 +284,32 @@ class QuoteMatcher(debug: Boolean) { case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil) if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) => - /* Some of method symbols in arguments of higher-order term hole are eta-expanded. - * e.g. - * g: (Int) => Int - * => { - * def $anonfun(y: Int): Int = g(y) - * closure($anonfun) - * } - * - * f: (using Int) => Int - * => f(using x) - * This function restores the symbol of the original method from - * the eta-expanded function. - */ - def getCapturedIdent(arg: Tree)(using Context): Ident = - arg match - case id: Ident => id - case Apply(fun, _) => getCapturedIdent(fun) - case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs) - case Typed(expr, _) => getCapturedIdent(expr) - val env = summon[Env] val capturedIds = args.map(getCapturedIdent) val capturedSymbols = capturedIds.map(_.symbol) - val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v)) + val captureEnv = Env( + termEnv = env.termEnv.filter((k, v) => !capturedIds.map(_.symbol).contains(v)), + typeEnv = env.typeEnv) withEnv(captureEnv) { scrutinee match - case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), env) + case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), Nil, env) + case _ => notMatched + } + + /* Higher order term hole */ + // Matches an open term and wraps it into a lambda that provides the free variables + case Apply(TypeApply(Ident(_), List(TypeTree(), targs)), SeqLiteral(args, _) :: Nil) + if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHoleWithTypes) => + + val env = summon[Env] + val capturedIds = args.map(getCapturedIdent) + val capturedTargs = unrollHkNestedPairsTypeTree(targs) + val captureEnv = Env( + termEnv = env.termEnv.filter((k, v) => !capturedIds.map(_.symbol).contains(v)), + typeEnv = env.typeEnv.filter((k, v) => !capturedTargs.map(_.symbol).contains(v))) + withEnv(captureEnv) { + scrutinee match + case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), capturedTargs.map(_.tpe), env) case _ => notMatched } @@ -324,7 +345,7 @@ class QuoteMatcher(debug: Boolean) { /* Match reference */ case _: Ident if symbolMatch(scrutinee, pattern) => matched /* Match type */ - case TypeTreeTypeTest(pattern) if scrutinee.tpe <:< pattern.tpe => matched + case TypeTreeTypeTest(pattern) if isSubTypeUnderEnv(scrutinee, pattern) => matched case _ => notMatched /* Match application */ @@ -346,8 +367,12 @@ class QuoteMatcher(debug: Boolean) { pattern match case Block(stat2 :: stats2, expr2) => val newEnv = (stat1, stat2) match { - case (stat1: MemberDef, stat2: MemberDef) => - summon[Env] + (stat1.symbol -> stat2.symbol) + case (stat1: ValOrDefDef, stat2: ValOrDefDef) => + val Env(termEnv, typeEnv) = summon[Env] + new Env(termEnv + (stat1.symbol -> stat2.symbol), typeEnv) + case (stat1: TypeDef, stat2: TypeDef) => + val Env(termEnv, typeEnv) = summon[Env] + new Env(termEnv, typeEnv + (stat1.symbol -> stat2.symbol)) case _ => summon[Env] } @@ -403,14 +428,16 @@ class QuoteMatcher(debug: Boolean) { // TODO remove this? case TypeTreeTypeTest(scrutinee) => pattern match - case TypeTreeTypeTest(pattern) if scrutinee.tpe <:< pattern.tpe => matched + case TypeTreeTypeTest(pattern) if isSubTypeUnderEnv(scrutinee, pattern) => matched case _ => notMatched /* Match val */ case scrutinee @ ValDef(_, tpt1, _) => pattern match case pattern @ ValDef(_, tpt2, _) if checkValFlags() => - def rhsEnv = summon[Env] + (scrutinee.symbol -> pattern.symbol) + def rhsEnv = + val Env(termEnv, typeEnv) = summon[Env] + new Env(termEnv + (scrutinee.symbol -> pattern.symbol), typeEnv) tpt1 =?= tpt2 &&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs) case _ => notMatched @@ -427,11 +454,38 @@ class QuoteMatcher(debug: Boolean) { notMatched case _ => matched + /** + * Implementation restriction: The current implementation matches type parameters + * only when they have empty bounds (>: Nothing <: Any) + */ + def matchTypeDef(sctypedef: TypeDef, pttypedef: TypeDef): MatchingExprs = sctypedef match + case TypeDef(_, TypeBoundsTree(sclo, schi, EmptyTree)) + if sclo.tpe == defn.NothingType && schi.tpe == defn.AnyType => + pttypedef match + case TypeDef(_, TypeBoundsTree(ptlo, pthi, EmptyTree)) + if sclo.tpe == defn.NothingType && schi.tpe == defn.AnyType => + matched + case _ => notMatched + case _ => notMatched + def matchParamss(scparamss: List[ParamClause], ptparamss: List[ParamClause])(using Env): optional[(Env, MatchingExprs)] = (scparamss, ptparamss) match { - case (scparams :: screst, ptparams :: ptrest) => + case (ValDefs(scparams) :: screst, ValDefs(ptparams) :: ptrest) => val mr1 = matchLists(scparams, ptparams)(_ =?= _) - val newEnv = summon[Env] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol)) + val Env(termEnv, typeEnv) = summon[Env] + val newEnv = new Env( + termEnv = termEnv ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol)), + typeEnv = typeEnv + ) + val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest)) + (resEnv, mr1 &&& mrrest) + case (TypeDefs(scparams) :: screst, TypeDefs(ptparams) :: ptrest) => + val mr1 = matchLists(scparams, ptparams)(matchTypeDef) + val Env(termEnv, typeEnv) = summon[Env] + val newEnv = new Env( + termEnv = termEnv, + typeEnv = typeEnv ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol)), + ) val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest)) (resEnv, mr1 &&& mrrest) case (Nil, Nil) => (summon[Env], matched) @@ -439,8 +493,8 @@ class QuoteMatcher(debug: Boolean) { } val ematch = matchErasedParams(scrutinee.tpe.widenTermRefExpr, pattern.tpe.widenTermRefExpr) - val (pEnv, pmatch) = matchParamss(paramss1, paramss2) - val defEnv = pEnv + (scrutinee.symbol -> pattern.symbol) + val (Env(termEnv, typeEnv), pmatch) = matchParamss(paramss1, paramss2) + val defEnv = Env(termEnv + (scrutinee.symbol -> pattern.symbol), typeEnv) ematch &&& pmatch @@ -514,11 +568,19 @@ class QuoteMatcher(debug: Boolean) { else scrutinee case _ => scrutinee val pattern = patternTree.symbol + val Env(termEnv, typeEnv) = summon[Env] devirtualizedScrutinee == pattern - || summon[Env].get(devirtualizedScrutinee).contains(pattern) + || termEnv.get(devirtualizedScrutinee).contains(pattern) + || typeEnv.get(devirtualizedScrutinee).contains(pattern) || devirtualizedScrutinee.allOverriddenSymbols.contains(pattern) + private def isSubTypeUnderEnv(scrutinee: Tree, pattern: Tree)(using Env, Context): Boolean = + val env = summon[Env].typeEnv + val scType = if env.isEmpty then scrutinee.tpe + else scrutinee.subst(env.keys.toList, env.values.toList).tpe + scType <:< pattern.tpe + private object ClosedPatternTerm { /** Matches a term that does not contain free variables defined in the pattern (i.e. not defined in `Env`) */ def unapply(term: Tree)(using Env, Context): Option[term.type] = @@ -526,16 +588,24 @@ class QuoteMatcher(debug: Boolean) { /** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */ def freePatternVars(term: Tree)(using Env, Context): Set[Symbol] = - val accumulator = new TreeAccumulator[Set[Symbol]] { + val Env(termEnv, typeEnv) = summon[Env] + val typeAccumulator = new TypeAccumulator[Set[Symbol]] { + def apply(x: Set[Symbol], tp: Type): Set[Symbol] = tp match + case tp: TypeRef if typeEnv.contains(tp.typeSymbol) => foldOver(x + tp.typeSymbol, tp) + case tp: TermRef if termEnv.contains(tp.termSymbol) => foldOver(x + tp.termSymbol, tp) + case _ => foldOver(x, tp) + } + val treeAccumulator = new TreeAccumulator[Set[Symbol]] { def apply(x: Set[Symbol], tree: Tree)(using Context): Set[Symbol] = tree match - case tree: Ident if summon[Env].contains(tree.symbol) => foldOver(x + tree.symbol, tree) + case tree: Ident if termEnv.contains(tree.symbol) => foldOver(typeAccumulator(x, tree.tpe) + tree.symbol, tree) + case tree: TypeTree => typeAccumulator(x, tree.tpe) case _ => foldOver(x, tree) } - accumulator.apply(Set.empty, term) + treeAccumulator(Set.empty, term) } - enum MatchResult: + private enum MatchResult: /** Closed pattern extracted value * @param tree Scrutinee sub-tree that matched */ @@ -546,9 +616,10 @@ class QuoteMatcher(debug: Boolean) { * @param patternTpe Type of the pattern hole (from the pattern) * @param argIds Identifiers of HOAS arguments (from the pattern) * @param argTypes Eta-expanded types of HOAS arguments (from the pattern) + * @param typeArgs type arguments from the pattern * @param env Mapping between scrutinee and pattern variables */ - case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env) + case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], typeArgs: List[Type], env: Env) /** Return the expression that was extracted from a hole. * @@ -561,28 +632,61 @@ class QuoteMatcher(debug: Boolean) { def toExpr(mapTypeHoles: Type => Type, spliceScope: Scope)(using Context): Expr[Any] = this match case MatchResult.ClosedTree(tree) => new ExprImpl(tree, spliceScope) - case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env) => + case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, Env(termEnv, typeEnv)) => val names: List[TermName] = argIds.map(_.symbol.name.asTermName) val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr)) - val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe)) + val ptTypeVarSymbols = typeArgs.map(_.typeSymbol) + val isNotPoly = typeArgs.isEmpty + + val methTpe = if isNotPoly then + MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe)) + else + val typeArgs1 = PolyType.syntheticParamNames(typeArgs.length) + val bounds = typeArgs map (_ => TypeBounds.empty) + val resultTypeExp = (pt: PolyType) => { + val argTypes1 = paramTypes.map(_.subst(ptTypeVarSymbols, pt.paramRefs)) + val resultType1 = mapTypeHoles(patternTpe).subst(ptTypeVarSymbols, pt.paramRefs) + MethodType(argTypes1, resultType1) + } + PolyType(typeArgs1)(_ => bounds, resultTypeExp) + val meth = newAnonFun(ctx.owner, methTpe) + def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { - val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap - val body = new TreeMap { - override def transform(tree: Tree)(using Context): Tree = - tree match - /* - * When matching a method call `f(0)` against a HOAS pattern `p(g)` where - * f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold - * `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion. - */ - case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform)) - case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) - case tree => super.transform(tree) - }.transform(tree) + val (typeParams, params) = if isNotPoly then + (List.empty, lambdaArgss.head) + else + (lambdaArgss.head.map(_.tpe), lambdaArgss.tail.head) + + val typeArgsMap = ptTypeVarSymbols.zip(typeParams).toMap + val argsMap = argIds.view.map(_.symbol).zip(params).toMap + + val body = new TreeTypeMap( + typeMap = if isNotPoly then IdentityTypeMap + else new TypeMap() { + override def apply(tp: Type): Type = tp match { + case tr: TypeRef if tr.prefix.eq(NoPrefix) => + typeEnv.get(tr.symbol).flatMap(typeArgsMap.get).getOrElse(tr) + case tp => mapOver(tp) + } + }, + treeMap = new TreeMap { + override def transform(tree: Tree)(using Context): Tree = + tree match + /* + * When matching a method call `f(0)` against a HOAS pattern `p(g)` where + * f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold + * `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion. + */ + case Apply(fun, args) if termEnv.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform)) + case tree: Ident => termEnv.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) + case tree => super.transform(tree) + }.transform + ).transform(tree) + TreeOps(body).changeNonLocalOwners(meth) } - val hoasClosure = Closure(meth, bodyFn) + val hoasClosure = Closure(meth, bodyFn).withSpan(tree.span) new ExprImpl(hoasClosure, spliceScope) private inline def notMatched[T]: optional[T] = @@ -594,12 +698,17 @@ class QuoteMatcher(debug: Boolean) { private inline def matched(tree: Tree)(using Context): MatchingExprs = Seq(MatchResult.ClosedTree(tree)) - private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)(using Context): MatchingExprs = - Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env)) + private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], typeArgs: List[Type], env: Env)(using Context): MatchingExprs = + Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, env)) extension (self: MatchingExprs) /** Concatenates the contents of two successful matchings */ - def &&& (that: MatchingExprs): MatchingExprs = self ++ that + private def &&& (that: MatchingExprs): MatchingExprs = self ++ that end extension + // TODO-18271: Duplicate with QuotePatterns.unrollHkNestedPairsTypeTree + private def unrollHkNestedPairsTypeTree(tree: Tree)(using Context): List[Tree] = tree match + case AppliedTypeTree(tupleN, bindings) if defn.isTupleClass(tupleN.symbol) => bindings // TupleN, 1 <= N <= 22 + case AppliedTypeTree(_, head :: tail :: Nil) => head :: unrollHkNestedPairsTypeTree(tail) // KCons or *: + case _ => Nil // KNil or EmptyTuple } diff --git a/docs/_docs/reference/experimental/quoted-patterns-with-polymorphic-functions.md b/docs/_docs/reference/experimental/quoted-patterns-with-polymorphic-functions.md new file mode 100644 index 000000000000..0c30a867b189 --- /dev/null +++ b/docs/_docs/reference/experimental/quoted-patterns-with-polymorphic-functions.md @@ -0,0 +1,54 @@ +--- +layout: doc-page +title: "Quoted Patterns with Polymorphic Functions" +nightlyOf: https://docs.scala-lang.org/scala3/reference/other-new-features/quoted-patterns-with-polymorphic-functions.html +--- + +This feature extends the capability of quoted patterns with regard to polymorphic functions. It is not yet part of the Scala language standard. To use this feature, turn on the language feature [`experimental.quotedPatternsWithPolymorphicFunctions`](https://scala-lang.org/api/3.x/scala/runtime/stdLibPatches/language$$experimental$$quotedPatternsWithPolymorphicFunctions$.html). This can be done with a language import +```scala +import scala.language.experimental.quotedPatternsWithPolymorphicFunctions +``` +or by setting the command line option `-language:experimental.quotedPatternsWithPolymorphicFunctions`. + +## Background +Quoted patterns allows us to use quoted code as a pattern. Using quoted patterns, we can check if an expression is equivalent to another, or decompose it. Especially, higher-order patterns are useful when extracting code fraguments inside function bodies. + +```scala +def decomposeFunc(x: Expr[Any])(using Quotes): Expr[Int] = + x match + case '{ (a: Int, b: Int) => $y(a, b) : Int } => + '{ $y(0, 0) } + case _ => Expr(0) +``` + +In the example above, the first case matches the case where `x` is a function and `y` is bound to the body of the function. The higher-order pattern `$y(a, b)` states that it matches any code with free occurence of variables `a` and `b`. If it is `$y(a)` instead, an expression like `(a: Int, b: Int) => a + b` will not match because `a + b` has an occurence of `b`, which is not included in the higher-order pattern. + +## Motivation +This experimental feature extends this higher-order pattern syntax to allow type variables. + +```scala +def decomposePoly(x: Expr[Any])(using Quotes): Expr[Int] = + x match + case '{ [A] => (x: List[A]) => $y[A](x) : Int } => + '{ $y[Int](List(1, 2, 3)) } + case _ => Expr(0) +``` + +Now we can use a higher-order pattern `$y[A](x)` with type variables. `y` is bound to the body of code with occurences of `A` and `x`, and has the type `[A] => (x: List[A]) => Int`. + +## Type Dependency +If a higher-order pattern carries a value parameter with a type that has type parameters defined in the quoted pattern, those type parameters should also be captured in the higher-order pattern. For example, the following pattern will not be typed. + +``` +case '{ [A] => (x: List[A]) => $y(x) : Int } => +``` + +In this case, `x` has the type `List[A]`, which includes a type variable `A` that is defined in the pattern. However, the higher-order pattern `$y(x)` does not have any type parameters. This should be ill-typed. One can always avoid this kind of type errors by adding type parameters, like `$y[A](x)` + +## Implementation Restriction +Current implementation only allows type parameters that do not have bounds, because sound typing rules for such pattern is not clear yet. + +```scala +case '{ [A] => (x: List[A]) => $y(x) : Int } => // Allowed +case '{ [A <: Int] => (x: List[A]) => $y(x) : Int } => // Disallowed +``` diff --git a/library/src/scala/quoted/runtime/Patterns.scala b/library/src/scala/quoted/runtime/Patterns.scala index 91ad23c62a98..f8e172d30f62 100644 --- a/library/src/scala/quoted/runtime/Patterns.scala +++ b/library/src/scala/quoted/runtime/Patterns.scala @@ -1,6 +1,7 @@ package scala.quoted.runtime import scala.annotation.{Annotation, compileTimeOnly} +import scala.annotation.experimental @compileTimeOnly("Illegal reference to `scala.quoted.runtime.Patterns`") object Patterns { @@ -26,6 +27,14 @@ object Patterns { @compileTimeOnly("Illegal reference to `scala.quoted.runtime.Patterns.higherOrderHole`") def higherOrderHole[U](args: Any*): U = ??? + /** A higher order splice in a quoted pattern is desugared by the compiler into a call to this method. + * + * Calling this method in source has undefined behavior at compile-time + */ + @experimental + @compileTimeOnly("Illegal reference to `scala.quoted.runtime.Patterns.higherOrderHoleWithTypes`") + def higherOrderHoleWithTypes[U, T](args: Any*): U = ??? + /** A splice of a name in a quoted pattern is that marks the definition of a type splice. * * Adding this annotation in source has undefined behavior at compile-time diff --git a/library/src/scala/runtime/stdLibPatches/language.scala b/library/src/scala/runtime/stdLibPatches/language.scala index fbab0c14c9fb..8b192be6ca40 100644 --- a/library/src/scala/runtime/stdLibPatches/language.scala +++ b/library/src/scala/runtime/stdLibPatches/language.scala @@ -98,6 +98,14 @@ object language: @compileTimeOnly("`relaxedExtensionImports` can only be used at compile time in import statements") @deprecated("The experimental.relaxedExtensionImports language import is no longer needed since the feature is now standard", since = "3.4") object relaxedExtensionImports + + /** Experimental support for quote pattern matching with polymorphic functions + * + * @see [[https://dotty.epfl.ch/docs/reference/experimental/quoted-patterns-with-polymorphic-functions]] + */ + @compileTimeOnly("`quotedPatternsWithPolymorphicFunctions` can only be used at compile time in import statements") + object quotedPatternsWithPolymorphicFunctions + end experimental /** The deprecated object contains features that are no longer officially suypported in Scala. diff --git a/tests/neg-macros/quoted-pattern-with-bounded-type-params-regression.check b/tests/neg-macros/quoted-pattern-with-bounded-type-params-regression.check new file mode 100644 index 000000000000..860482f2e552 --- /dev/null +++ b/tests/neg-macros/quoted-pattern-with-bounded-type-params-regression.check @@ -0,0 +1,6 @@ +-- Error: tests/neg-macros/quoted-pattern-with-bounded-type-params-regression.scala:11:48 ------------------------------ +11 | case '{ [A <: Int, B] => (x : A, y : A) => $b[A](x, y) : A } => ??? // error + | ^ + | Type must be fully defined. + | Consider annotating the splice using a type ascription: + | (${b}: XYZ). diff --git a/tests/neg-macros/quoted-pattern-with-bounded-type-params-regression.scala b/tests/neg-macros/quoted-pattern-with-bounded-type-params-regression.scala new file mode 100644 index 000000000000..6797ae926367 --- /dev/null +++ b/tests/neg-macros/quoted-pattern-with-bounded-type-params-regression.scala @@ -0,0 +1,12 @@ +/** + * Supporting hoas quote pattern with bounded type variable + * is future todo. + * Refer to: quoted-pattern-with-bounded-type-params.scala + */ + +import scala.quoted.* + +def test(body: Expr[Any])(using Quotes): Expr[String] = + body match + case '{ [A <: Int, B] => (x : A, y : A) => $b[A](x, y) : A } => ??? // error + case _ => Expr("not matched") diff --git a/tests/neg-macros/quoted-pattern-with-bounded-type-params.check b/tests/neg-macros/quoted-pattern-with-bounded-type-params.check new file mode 100644 index 000000000000..0e787377bfc5 --- /dev/null +++ b/tests/neg-macros/quoted-pattern-with-bounded-type-params.check @@ -0,0 +1,4 @@ +-- Error: tests/neg-macros/quoted-pattern-with-bounded-type-params.scala:11:50 ----------------------------------------- +11 | case '{ [A <: Int, B] => (x : A, y : A) => $b[A](x, y) : A } => ??? // error + | ^ + | Implementation restriction: Type arguments to Open pattern are expected to have no bounds diff --git a/tests/neg-macros/quoted-pattern-with-bounded-type-params.scala b/tests/neg-macros/quoted-pattern-with-bounded-type-params.scala new file mode 100644 index 000000000000..567efa9ee35d --- /dev/null +++ b/tests/neg-macros/quoted-pattern-with-bounded-type-params.scala @@ -0,0 +1,12 @@ +/* + * Supporting hoas quote pattern with bounded type variable + * is future todo. + */ + +import scala.quoted.* +import scala.language.experimental.quotedPatternsWithPolymorphicFunctions + +def test(body: Expr[Any])(using Quotes): Expr[String] = + body match + case '{ [A <: Int, B] => (x : A, y : A) => $b[A](x, y) : A } => ??? // error + case _ => Expr("not matched") diff --git a/tests/neg-macros/quoted-pattern-with-type-params-regression.check b/tests/neg-macros/quoted-pattern-with-type-params-regression.check new file mode 100644 index 000000000000..543c119b3d33 --- /dev/null +++ b/tests/neg-macros/quoted-pattern-with-type-params-regression.check @@ -0,0 +1,16 @@ +-- Error: tests/neg-macros/quoted-pattern-with-type-params-regression.scala:8:31 --------------------------------------- +8 | case '{ [A] => (x : A) => $b[A] : (A => A) } => ??? // error + | ^ + | Type must be fully defined. + | Consider annotating the splice using a type ascription: + | (${b}: XYZ). +-- Error: tests/neg-macros/quoted-pattern-with-type-params-regression.scala:9:33 --------------------------------------- +9 | case '{ [A] => (x : A) => $b(x) : (A => A) } => ??? // error + | ^ + | Type variables that this argument depends on are not captured in this hoas pattern +-- Error: tests/neg-macros/quoted-pattern-with-type-params-regression.scala:10:24 -------------------------------------- +10 | case '{ (a:Int) => $b[Int](a) : String } => ??? // error + | ^ + | Type must be fully defined. + | Consider annotating the splice using a type ascription: + | (${b}: XYZ). diff --git a/tests/neg-macros/quoted-pattern-with-type-params-regression.scala b/tests/neg-macros/quoted-pattern-with-type-params-regression.scala new file mode 100644 index 000000000000..aa2489bc440b --- /dev/null +++ b/tests/neg-macros/quoted-pattern-with-type-params-regression.scala @@ -0,0 +1,11 @@ +/** + * Refer to: quoted-pattern-with-type-params.scala + */ +import scala.quoted.* + +def test(body: Expr[Any])(using Quotes): Expr[String] = + body match + case '{ [A] => (x : A) => $b[A] : (A => A) } => ??? // error + case '{ [A] => (x : A) => $b(x) : (A => A) } => ??? // error + case '{ (a:Int) => $b[Int](a) : String } => ??? // error + case _ => Expr("not matched") diff --git a/tests/neg-macros/quoted-pattern-with-type-params.check b/tests/neg-macros/quoted-pattern-with-type-params.check new file mode 100644 index 000000000000..37e8f611d5a9 --- /dev/null +++ b/tests/neg-macros/quoted-pattern-with-type-params.check @@ -0,0 +1,12 @@ +-- Error: tests/neg-macros/quoted-pattern-with-type-params.scala:6:32 -------------------------------------------------- +6 | case '{ [A] => (x : A) => $b[A] : (A => A) } => ??? // error + | ^^^^^ + | Implementation restriction: A higher-order pattern must carry value arguments +-- Error: tests/neg-macros/quoted-pattern-with-type-params.scala:7:33 -------------------------------------------------- +7 | case '{ [A] => (x : A) => $b(x) : (A => A) } => ??? // error + | ^ + | Type variables that this argument depends on are not captured in this hoas pattern +-- Error: tests/neg-macros/quoted-pattern-with-type-params.scala:8:26 -------------------------------------------------- +8 | case '{ (a:Int) => $b[Int](a) : String } => ??? // error + | ^^^ + | Type arguments of a hoas pattern needs to be defined inside the quoted pattern diff --git a/tests/neg-macros/quoted-pattern-with-type-params.scala b/tests/neg-macros/quoted-pattern-with-type-params.scala new file mode 100644 index 000000000000..2e4a059ee23a --- /dev/null +++ b/tests/neg-macros/quoted-pattern-with-type-params.scala @@ -0,0 +1,9 @@ +import scala.quoted.* +import scala.language.experimental.quotedPatternsWithPolymorphicFunctions + +def test(body: Expr[Any])(using Quotes): Expr[String] = + body match + case '{ [A] => (x : A) => $b[A] : (A => A) } => ??? // error + case '{ [A] => (x : A) => $b(x) : (A => A) } => ??? // error + case '{ (a:Int) => $b[Int](a) : String } => ??? // error + case _ => Expr("not matched") diff --git a/tests/pos-macros/quoted-patten-with-type-params.scala b/tests/pos-macros/quoted-patten-with-type-params.scala new file mode 100644 index 000000000000..030e3415476e --- /dev/null +++ b/tests/pos-macros/quoted-patten-with-type-params.scala @@ -0,0 +1,14 @@ +import scala.quoted.* +import scala.language.experimental.quotedPatternsWithPolymorphicFunctions + +def test(body: Expr[Any])(using Quotes): Expr[String] = + body match + case '{ [A] => (x : A, y : A) => (x, y) } => ??? + // Bounded type parameters are allowed when they are not used in + // higher-order patterns + case '{ [A <: Iterable[Int]] => (x : A) => x } => ??? + case '{ [A] => (x : A, y : A) => $b[A](x, y) : A } => + '{ $b[String]("truthy", "falsy") } + case '{ [A, B] => (x : A, f : A => B) => $b[A, B](x, f) : B} => + '{ $b[Int, String](10, (x:Int)=>x.toHexString) } + case _ => Expr("not matched") diff --git a/tests/run-macros/quote-match-poly-function-1-regression.check b/tests/run-macros/quote-match-poly-function-1-regression.check new file mode 100644 index 000000000000..d871d3004550 --- /dev/null +++ b/tests/run-macros/quote-match-poly-function-1-regression.check @@ -0,0 +1,3 @@ +Case 1 matched +not matched +not matched diff --git a/tests/run-macros/quote-match-poly-function-1-regression/Macro_1.scala b/tests/run-macros/quote-match-poly-function-1-regression/Macro_1.scala new file mode 100644 index 000000000000..a148fdee3d27 --- /dev/null +++ b/tests/run-macros/quote-match-poly-function-1-regression/Macro_1.scala @@ -0,0 +1,8 @@ +import scala.quoted.* + +inline def testExpr(inline body: Any) = ${ testExprImpl1('body) } +def testExprImpl1(body: Expr[Any])(using Quotes): Expr[String] = + body match + case '{ [A] => (x : A, y : A) => (x, y) } => Expr("Case 1 matched") + case '{ [A <: Iterable[Int]] => (x : A) => x } => Expr("Case 2 matched") + case _ => Expr("not matched") diff --git a/tests/run-macros/quote-match-poly-function-1-regression/Test_2.scala b/tests/run-macros/quote-match-poly-function-1-regression/Test_2.scala new file mode 100644 index 000000000000..9c89b1aa9db0 --- /dev/null +++ b/tests/run-macros/quote-match-poly-function-1-regression/Test_2.scala @@ -0,0 +1,4 @@ +@main def Test: Unit = + println(testExpr([B] => (x : B, y : B) => (x, y))) + println(testExpr([B <: Iterable[Int]] => (x : B) => x)) + println(testExpr([B <: List[Int]] => (x : B) => x)) diff --git a/tests/run-macros/quote-match-poly-function-1.check b/tests/run-macros/quote-match-poly-function-1.check new file mode 100644 index 000000000000..d871d3004550 --- /dev/null +++ b/tests/run-macros/quote-match-poly-function-1.check @@ -0,0 +1,3 @@ +Case 1 matched +not matched +not matched diff --git a/tests/run-macros/quote-match-poly-function-1/Macro_1.scala b/tests/run-macros/quote-match-poly-function-1/Macro_1.scala new file mode 100644 index 000000000000..07fd18ccabb7 --- /dev/null +++ b/tests/run-macros/quote-match-poly-function-1/Macro_1.scala @@ -0,0 +1,9 @@ +import scala.quoted.* +import scala.language.experimental.quotedPatternsWithPolymorphicFunctions + +inline def testExpr(inline body: Any) = ${ testExprImpl1('body) } +def testExprImpl1(body: Expr[Any])(using Quotes): Expr[String] = + body match + case '{ [A] => (x : A, y : A) => (x, y) } => Expr("Case 1 matched") + case '{ [A <: Iterable[Int]] => (x : A) => x } => Expr("Case 2 matched") + case _ => Expr("not matched") diff --git a/tests/run-macros/quote-match-poly-function-1/Test_2.scala b/tests/run-macros/quote-match-poly-function-1/Test_2.scala new file mode 100644 index 000000000000..9c89b1aa9db0 --- /dev/null +++ b/tests/run-macros/quote-match-poly-function-1/Test_2.scala @@ -0,0 +1,4 @@ +@main def Test: Unit = + println(testExpr([B] => (x : B, y : B) => (x, y))) + println(testExpr([B <: Iterable[Int]] => (x : B) => x)) + println(testExpr([B <: List[Int]] => (x : B) => x)) diff --git a/tests/run-macros/quote-match-poly-function-2.check b/tests/run-macros/quote-match-poly-function-2.check new file mode 100644 index 000000000000..a9ad3170d8fb --- /dev/null +++ b/tests/run-macros/quote-match-poly-function-2.check @@ -0,0 +1,7 @@ +case 2 matched => 5 +case 3 matched => truthy +case 4 matched => truthy +case 5 matchd => 1 +case 7 matchd => 1 +case 8 matched => (1,str) +case 9 matched => zero diff --git a/tests/run-macros/quote-match-poly-function-2/Macro_1.scala b/tests/run-macros/quote-match-poly-function-2/Macro_1.scala new file mode 100644 index 000000000000..8b5d5a85942a --- /dev/null +++ b/tests/run-macros/quote-match-poly-function-2/Macro_1.scala @@ -0,0 +1,23 @@ +import scala.quoted.* +import scala.language.experimental.quotedPatternsWithPolymorphicFunctions + +inline def testExpr(inline body: Any) = ${ testExprImpl1('body) } +def testExprImpl1(body: Expr[Any])(using Quotes): Expr[String] = + body match + case '{ [A] => (x : Int, y : Int) => $b(x, y) : Int } => + '{ "case 2 matched => " + $b(2, 3) } + case '{ [A] => (x : A, y : A) => $b[A](x, y) : A } => + '{ "case 3 matched => " + $b[String]("truthy", "falsy") } + case '{ [A] => (x : A, y : A) => $b[A](x, y) : (A, A) } => + '{ "case 4 matched => " + $b[String]("truthy", "falsy")._2 } + case '{ [A, B] => (x : A, y : A => B) => $a[A, B](x, y) : B } => + '{ "case 5 matchd => " + $a[Int, Int](0, x => x + 1) } + case '{ [A] => (x : List[A], y : A) => $a[A](x) : Int } => + '{ "case 6 matchd => " + $a[Int](List(1, 2, 3)) } + case '{ [A] => (x : List[A], y : A) => $a[A](x, y) : Int } => + '{ "case 7 matchd => " + $a[Int](List(1, 2, 3), 2) } + case '{ [A] => (x : A) => [B] => (y : B) => $a[A, B](x, y) : (A, B) } => + '{ "case 8 matched => " + $a[Int, String](1, "str")} + case '{ [A, B] => (x : Map[A, B], y: A) => $a[A, B](x, y) : Option[B] } => + '{ "case 9 matched => " + $a[Int, String](Map(0 -> "zero", 1 -> "one"), 0).getOrElse("failed") } + case _ => Expr("not matched") diff --git a/tests/run-macros/quote-match-poly-function-2/Test_2.scala b/tests/run-macros/quote-match-poly-function-2/Test_2.scala new file mode 100644 index 000000000000..af249ab2eb57 --- /dev/null +++ b/tests/run-macros/quote-match-poly-function-2/Test_2.scala @@ -0,0 +1,8 @@ +@main def Test: Unit = + println(testExpr([B] => (x : Int, y : Int) => x + y)) // Should match case 2 + println(testExpr([B] => (x : B, y : B) => x)) // Should match case 3 + println(testExpr([B] => (x : B, y : B) => (y, x))) // Should match case 4 + println(testExpr([C, D] => (x : C, f : C => D) => f(x))) // Should match case 4 + println(testExpr([B] => (x : List[B], y : B) => x.indexOf(y))) // Should match case 7 + println(testExpr([B] => (x : B) => [C] => (y : C) => (x, y))) // Should match case 8 + println(testExpr([C, D] => (x : Map[C, D], y: C) => x.get(y))) diff --git a/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala b/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala index a01c71724b0e..df0220458fe2 100644 --- a/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala +++ b/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala @@ -97,7 +97,11 @@ val experimentalDefinitionInLibrary = Set( "scala.Tuple$.Helpers$", "scala.Tuple$.Helpers$.ReverseImpl", "scala.Tuple$.Reverse", - "scala.runtime.Tuples$.reverse" + "scala.runtime.Tuples$.reverse", + + // New feature: functions with erased parameters. + // Need quotedPatternsWithPolymorphicFunctions enabled. + "scala.quoted.runtime.Patterns$.higherOrderHoleWithTypes" )