From 952eff71f4b2f49ca1870b4ef568cd8da97ee1e0 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Mon, 18 Nov 2024 16:49:08 +0100 Subject: [PATCH] Cleanup context bounds for poly functions implementation after review --- .../src/dotty/tools/dotc/ast/Desugar.scala | 50 +++++++++++++------ 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 6e54dee51c89..56c153498f87 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -350,11 +350,13 @@ object desugar { val Function(vparams: List[untpd.ValDef] @unchecked, rhs) = fun: @unchecked val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil) val params = evidenceParamBuf.toList - if params.isEmpty then return meth - val boundNames = getBoundNames(params, newParamss) - val recur = fitEvidenceParams(params, nme.apply, boundNames) - val (paramsFst, paramsSnd) = recur(newParamss) - functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs) + if params.isEmpty then + meth + else + val boundNames = getBoundNames(params, newParamss) + val recur = fitEvidenceParams(params, nme.apply, boundNames) + val (paramsFst, paramsSnd) = recur(newParamss) + functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs) end elimContextBounds def addDefaultGetters(meth: DefDef)(using Context): Tree = @@ -487,8 +489,27 @@ object desugar { case Ident(name: TermName) => names.contains(name) case _ => false - /** Fit evidence `params` into the `mparamss` parameter lists */ - private def fitEvidenceParams(params: List[ValDef], methName: Name, boundNames: Set[TermName])(mparamss: List[ParamClause])(using Context): (List[ParamClause], List[ParamClause]) = mparamss match + /** Fit evidence `params` into the `mparamss` parameter lists, making sure + * that all parameters referencing `params` are after them. + * - for methods the final parameter lists are := result._1 ++ result._2 + * - for poly functions, each element of the pair contains at most one term + * parameter list + * + * @param params the evidence parameters list that should fit into `mparamss` + * @param methName the name of the method that `mparamss` belongs to + * @param boundNames the names of the evidence parameters + * @param mparamss the original parameter lists of the method + * @return a pair of parameter lists containing all parameter lists in a + * reference-correct order; make sure that `params` is always at the + * intersection of the pair elements; this is relevant, for poly functions + * where `mparamss` is guaranteed to have exectly one term parameter list, + * then each pair element will have at most one term parameter list + */ + private def fitEvidenceParams( + params: List[ValDef], + methName: Name, + boundNames: Set[TermName] + )(mparamss: List[ParamClause])(using Context): (List[ParamClause], List[ParamClause]) = mparamss match case ValDefs(mparams) :: _ if mparams.exists(referencesName(_, boundNames)) => (params :: Nil) -> mparamss case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) => @@ -547,19 +568,20 @@ object desugar { val boundNames = getBoundNames(params, meth.paramss) - val recur = fitEvidenceParams(params, meth.name, boundNames) + val fitParams = fitEvidenceParams(params, meth.name, boundNames) - if meth.hasAttachment(PolyFunctionApply) then - meth.removeAttachment(PolyFunctionApply) - // for PolyFunctions we are limited to a single term param list, so we reuse the recur logic to compute the new parameter lists - // and then we add the other parameter lists as function types to the return type - val (paramsFst, paramsSnd) = recur(meth.paramss) + if meth.removeAttachment(PolyFunctionApply).isDefined then + // for PolyFunctions we are limited to a single term param list, so we + // reuse the fitEvidenceParams logic to compute the new parameter lists + // and then we add the other parameter lists as function types to the + // return type + val (paramsFst, paramsSnd) = fitParams(meth.paramss) if ctx.mode.is(Mode.Type) then cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt)) else cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs)) else - val (paramsFst, paramsSnd) = recur(meth.paramss) + val (paramsFst, paramsSnd) = fitParams(meth.paramss) cpy.DefDef(meth)(paramss = paramsFst ++ paramsSnd) end addEvidenceParams