From 4df0b578e7127c2042ff482f97fc5bf77e4f5a0c Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Tue, 22 Aug 2023 15:22:54 +0200 Subject: [PATCH] Handle dependent context functions Add `FunctionTypeOfMethod` extractor that matches any kind of function and return its method type. We use this extractor instead of `ContextFunctionType` to all of * `ContextFunctionN[...]` * `ContextFunctionN[...] { def apply(using ...): R }` where `R` might be dependent on the parameters. * `PolyFunction { def apply(using ...): R }` where `R` might be dependent on the parameters. Currently this one would have at least one erased parameter. --- compiler/src/dotty/tools/dotc/ast/TreeInfo.scala | 2 +- .../src/dotty/tools/dotc/core/Definitions.scala | 16 ++++++++++++++++ .../dotc/transform/ContextFunctionResults.scala | 15 ++++++++------- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index 6659818b333e..18a1b73984e1 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -954,7 +954,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = { def isStructuralTermSelect(tree: Select) = def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match - case defn.PolyFunctionOf(_) => + case defn.FunctionTypeOfMethod(_) => false case RefinedType(parent, rname, rinfo) => rname == tree.name || hasRefinement(parent) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index b4df6bcd4ca5..3e50239e90f3 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1108,6 +1108,22 @@ class Definitions { // - .linkedClass: the ClassSymbol of the enumeration (class E) sym.owner.linkedClass.typeRef + object FunctionTypeOfMethod { + /** Matches a `FunctionN[...]`/`ContextFunctionN[...]` or refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`. + * Extracts the method type type and apply info. + */ + def unapply(ft: Type)(using Context): Option[MethodOrPoly] = { + ft match + case RefinedType(parent, nme.apply, mt: MethodOrPoly) + if parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(parent) => + Some(mt) + case FunctionOf(argTypes, resultType, isContextual) => + val methodType = if isContextual then ContextualMethodType else MethodType + Some(methodType(argTypes, resultType)) + case _ => None + } + } + object FunctionOf { def apply(args: List[Type], resultType: Type, isContextual: Boolean = false)(using Context): Type = val mt = MethodType.companion(isContextual, false)(args, resultType) diff --git a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala index b4eb71c541d3..3e6ecc892667 100644 --- a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala +++ b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala @@ -58,7 +58,8 @@ object ContextFunctionResults: */ def contextResultsAreErased(sym: Symbol)(using Context): Boolean = def allErased(tp: Type): Boolean = tp.dealias match - case defn.ContextFunctionType(_, resTpe, erasedParams) => !erasedParams.contains(false) && allErased(resTpe) + case ft @ defn.FunctionTypeOfMethod(mt: MethodType) if mt.isContextualMethod => + !mt.erasedParams.contains(false) && allErased(mt.resType) case _ => true contextResultCount(sym) > 0 && allErased(sym.info.finalResultType) @@ -67,13 +68,13 @@ object ContextFunctionResults: */ def integrateContextResults(tp: Type, crCount: Int)(using Context): Type = if crCount == 0 then tp - else tp match + else tp.dealias match case ExprType(rt) => integrateContextResults(rt, crCount) case tp: MethodOrPoly => tp.derivedLambdaType(resType = integrateContextResults(tp.resType, crCount)) - case defn.ContextFunctionType(argTypes, resType, erasedParams) => - MethodType(argTypes, integrateContextResults(resType, crCount - 1)) + case defn.FunctionTypeOfMethod(mt) if mt.isContextualMethod => + mt.derivedLambdaType(resType = integrateContextResults(mt.resType, crCount - 1)) /** The total number of parameters of method `sym`, not counting * erased parameters, but including context result parameters. @@ -103,7 +104,7 @@ object ContextFunctionResults: def recur(tp: Type, n: Int): Type = if n == 0 then tp else tp match - case defn.ContextFunctionType(_, resTpe, _) => recur(resTpe, n - 1) + case defn.FunctionTypeOfMethod(mt) => recur(mt.resType, n - 1) recur(meth.info.finalResultType, depth) /** Should selection `tree` be eliminated since it refers to an `apply` @@ -117,8 +118,8 @@ object ContextFunctionResults: else tree match case Select(qual, name) => if name == nme.apply then - qual.tpe match - case defn.ContextFunctionType(_, _, _) => + qual.tpe.nn.dealias match + case defn.FunctionTypeOfMethod(mt) if mt.isContextualMethod => integrateSelect(qual, n + 1) case _ if defn.isContextFunctionClass(tree.symbol.maybeOwner) => // for TermRefs integrateSelect(qual, n + 1)