Skip to content

Commit

Permalink
Backport "Three fixes to SAM type handling" to LTS (#22132)
Browse files Browse the repository at this point in the history
Backports #21596 to the 3.3.5.

PR submitted by the release tooling.
[skip ci]
  • Loading branch information
WojciechMazur authored Dec 4, 2024
2 parents 8a459ab + 81ff618 commit 0c5915c
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 38 deletions.
63 changes: 44 additions & 19 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -305,24 +305,41 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
def TypeDef(sym: TypeSymbol)(using Context): TypeDef =
ta.assignType(untpd.TypeDef(sym.name, TypeTree(sym.info)), sym)

def ClassDef(cls: ClassSymbol, constr: DefDef, body: List[Tree], superArgs: List[Tree] = Nil)(using Context): TypeDef = {
/** Create a class definition
* @param cls the class symbol of the created class
* @param constr its primary constructor
* @param body the statements in its template
* @param superArgs the arguments to pass to the superclass constructor
* @param adaptVarargs if true, allow matching a vararg superclass constructor
* with a missing argument in superArgs, and synthesize an
* empty repeated parameter in the supercall in this case
*/
def ClassDef(cls: ClassSymbol, constr: DefDef, body: List[Tree],
superArgs: List[Tree] = Nil, adaptVarargs: Boolean = false)(using Context): TypeDef =
val firstParent :: otherParents = cls.info.parents: @unchecked

def adaptedSuperArgs(ctpe: Type): List[Tree] = ctpe match
case ctpe: PolyType =>
adaptedSuperArgs(ctpe.instantiate(firstParent.argTypes))
case ctpe: MethodType
if ctpe.paramInfos.length == superArgs.length + 1 =>
// last argument must be a vararg, otherwise isApplicable would have failed
superArgs :+
repeated(Nil, TypeTree(ctpe.paramInfos.last.argInfos.head, inferred = true))
case _ =>
superArgs

val superRef =
if (cls.is(Trait)) TypeTree(firstParent)
else {
def isApplicable(ctpe: Type): Boolean = ctpe match {
case ctpe: PolyType =>
isApplicable(ctpe.instantiate(firstParent.argTypes))
case ctpe: MethodType =>
(superArgs corresponds ctpe.paramInfos)(_.tpe <:< _)
case _ =>
false
}
val constr = firstParent.decl(nme.CONSTRUCTOR).suchThat(constr => isApplicable(constr.info))
New(firstParent, constr.symbol.asTerm, superArgs)
}
if cls.is(Trait) then TypeTree(firstParent)
else
val parentConstr = firstParent.applicableConstructors(superArgs.tpes, adaptVarargs) match
case Nil => assert(false, i"no applicable parent constructor of $firstParent for supercall arguments $superArgs")
case constr :: Nil => constr
case _ => assert(false, i"multiple applicable parent constructors of $firstParent for supercall arguments $superArgs")
New(firstParent, parentConstr.asTerm, adaptedSuperArgs(parentConstr.info))

ClassDefWithParents(cls, constr, superRef :: otherParents.map(TypeTree(_)), body)
}
end ClassDef

def ClassDefWithParents(cls: ClassSymbol, constr: DefDef, parents: List[Tree], body: List[Tree])(using Context): TypeDef = {
val selfType =
Expand All @@ -349,13 +366,18 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
* @param parents a non-empty list of class types
* @param termForwarders a non-empty list of forwarding definitions specified by their name and the definition they forward to.
* @param typeMembers a possibly-empty list of type members specified by their name and their right hand side.
* @param adaptVarargs if true, allow matching a vararg superclass constructor
* with a missing argument in superArgs, and synthesize an
* empty repeated parameter in the supercall in this case
*
* The class has the same owner as the first function in `termForwarders`.
* Its position is the union of all symbols in `termForwarders`.
*/
def AnonClass(parents: List[Type], termForwarders: List[(TermName, TermSymbol)],
typeMembers: List[(TypeName, TypeBounds)] = Nil)(using Context): Block = {
AnonClass(termForwarders.head._2.owner, parents, termForwarders.map(_._2.span).reduceLeft(_ union _)) { cls =>
def AnonClass(parents: List[Type],
termForwarders: List[(TermName, TermSymbol)],
typeMembers: List[(TypeName, TypeBounds)],
adaptVarargs: Boolean)(using Context): Block = {
AnonClass(termForwarders.head._2.owner, parents, termForwarders.map(_._2.span).reduceLeft(_ union _), adaptVarargs) { cls =>
def forwarder(name: TermName, fn: TermSymbol) = {
val fwdMeth = fn.copy(cls, name, Synthetic | Method | Final).entered.asTerm
for overridden <- fwdMeth.allOverriddenSymbols do
Expand All @@ -375,6 +397,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
* with the specified owner and position.
*/
def AnonClass(owner: Symbol, parents: List[Type], coord: Coord)(body: ClassSymbol => List[Tree])(using Context): Block =
AnonClass(owner, parents, coord, adaptVarargs = false)(body)

private def AnonClass(owner: Symbol, parents: List[Type], coord: Coord, adaptVarargs: Boolean)(body: ClassSymbol => List[Tree])(using Context): Block =
val parents1 =
if (parents.head.classSymbol.is(Trait)) {
val head = parents.head.parents.head
Expand All @@ -383,7 +408,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
else parents
val cls = newNormalizedClassSymbol(owner, tpnme.ANON_CLASS, Synthetic | Final, parents1, coord = coord)
val constr = newConstructor(cls, Synthetic, Nil, Nil).entered
val cdef = ClassDef(cls, DefDef(constr), body(cls))
val cdef = ClassDef(cls, DefDef(constr), body(cls), Nil, adaptVarargs)
Block(cdef :: Nil, New(cls.typeRef, Nil))

def Import(expr: Tree, selectors: List[untpd.ImportSelector])(using Context): Import =
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Phases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ object Phases {
def sbtExtractDependenciesPhase(using Context): Phase = ctx.base.sbtExtractDependenciesPhase
def picklerPhase(using Context): Phase = ctx.base.picklerPhase
def inliningPhase(using Context): Phase = ctx.base.inliningPhase
def stagingPhase(using Context): Phase = ctx.base.stagingPhase
def stagingPhase(using Context): Phase = ctx.base.stagingPhase
def splicingPhase(using Context): Phase = ctx.base.splicingPhase
def firstTransformPhase(using Context): Phase = ctx.base.firstTransformPhase
def refchecksPhase(using Context): Phase = ctx.base.refchecksPhase
Expand Down
34 changes: 30 additions & 4 deletions compiler/src/dotty/tools/dotc/core/TypeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ package core
import TypeErasure.ErasedValueType
import Types.*, Contexts.*, Symbols.*, Flags.*, Decorators.*
import Names.Name
import StdNames.nme

class TypeUtils {
class TypeUtils:
/** A decorator that provides methods on types
* that are needed in the transformer pipeline.
*/
extension (self: Type) {
extension (self: Type)

def isErasedValueType(using Context): Boolean =
self.isInstanceOf[ErasedValueType]
Expand Down Expand Up @@ -125,5 +126,30 @@ class TypeUtils {
def takesImplicitParams(using Context): Boolean = self.stripPoly match
case mt: MethodType => mt.isImplicitMethod || mt.resType.takesImplicitParams
case _ => false
}
}

/** The constructors of this type that are applicable to `argTypes`, without needing
* an implicit conversion. Curried constructors are always excluded.
* @param adaptVarargs if true, allow a constructor with just a varargs argument to
* match an empty argument list.
*/
def applicableConstructors(argTypes: List[Type], adaptVarargs: Boolean)(using Context): List[Symbol] =
def isApplicable(constr: Symbol): Boolean =
def recur(ctpe: Type): Boolean = ctpe match
case ctpe: PolyType =>
if argTypes.isEmpty then recur(ctpe.resultType) // no need to know instances
else recur(ctpe.instantiate(self.argTypes))
case ctpe: MethodType =>
var paramInfos = ctpe.paramInfos
if adaptVarargs && paramInfos.length == argTypes.length + 1
&& atPhaseNoLater(Phases.elimRepeatedPhase)(constr.info.isVarArgsMethod)
then // accept missing argument for varargs parameter
paramInfos = paramInfos.init
argTypes.corresponds(paramInfos)(_ <:< _) && !ctpe.resultType.isInstanceOf[MethodType]
case _ =>
false
recur(constr.info)

self.decl(nme.CONSTRUCTOR).altsWith(isApplicable).map(_.symbol)

end TypeUtils

21 changes: 11 additions & 10 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5612,17 +5612,18 @@ object Types extends TypeUtils {

def samClass(tp: Type)(using Context): Symbol = tp match
case tp: ClassInfo =>
def zeroParams(tp: Type): Boolean = tp.stripPoly match
case mt: MethodType => mt.paramInfos.isEmpty && !mt.resultType.isInstanceOf[MethodType]
case et: ExprType => true
case _ => false
val cls = tp.cls
val validCtor =
val ctor = cls.primaryConstructor
// `ContextFunctionN` does not have constructors
!ctor.exists || zeroParams(ctor.info)
val isInstantiable = !cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
if validCtor && isInstantiable then tp.cls
def takesNoArgs(tp: Type) =
!tp.classSymbol.primaryConstructor.exists
// e.g. `ContextFunctionN` does not have constructors
|| tp.applicableConstructors(Nil, adaptVarargs = true).lengthCompare(1) == 0
// we require a unique constructor so that SAM expansion is deterministic
val noArgsNeeded: Boolean =
takesNoArgs(tp)
&& (!tp.cls.is(Trait) || takesNoArgs(tp.parents.head))
def isInstantiable =
!tp.cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
if noArgsNeeded && isInstantiable then tp.cls
else NoSymbol
case tp: AppliedType =>
samClass(tp.superType)
Expand Down
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ class ExpandSAMs extends MiniPhase:
val tpe1 = collectAndStripRefinements(tpe)
val Seq(samDenot) = tpe1.possibleSamMethods
cpy.Block(tree)(stats,
AnonClass(List(tpe1),
List(samDenot.symbol.asTerm.name -> fn.symbol.asTerm),
refinements.toList
)
transformFollowingDeep:
AnonClass(List(tpe1),
List(samDenot.symbol.asTerm.name -> fn.symbol.asTerm),
refinements.toList,
adaptVarargs = true
)
)
}
case _ =>
Expand Down
15 changes: 15 additions & 0 deletions tests/neg/i15855.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class MyFunction(args: String)

trait MyFunction0[+R] extends MyFunction {
def apply(): R
}

def fromFunction0[R](f: Function0[R]): MyFunction0[R] = () => f() // error

class MyFunctionWithImplicit(implicit args: String)

trait MyFunction0WithImplicit[+R] extends MyFunctionWithImplicit {
def apply(): R
}

def fromFunction1[R](f: Function0[R]): MyFunction0WithImplicit[R] = () => f() // error
17 changes: 17 additions & 0 deletions tests/run/i15855.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
class MyFunction(args: String*)

trait MyFunction0[+R] extends MyFunction {
def apply(): R
}

abstract class MyFunction1[R](args: R*):
def apply(): R

def fromFunction0[R](f: Function0[R]): MyFunction0[R] = () => f()
def fromFunction1[R](f: Function0[R]): MyFunction1[R] = () => f()

@main def Test =
val m0: MyFunction0[Int] = fromFunction0(() => 1)
val m1: MyFunction1[Int] = fromFunction1(() => 2)
assert(m0() == 1)
assert(m1() == 2)

0 comments on commit 0c5915c

Please sign in to comment.