Skip to content

Commit

Permalink
Fixes for Metals infer expected type feature
Browse files Browse the repository at this point in the history
Fixes to makeVarArg and assignType(SeqLiteral) fixes the propagation of
vararg type errors, which would trip up InferExpectedType's looking at
whether the tree's type are errors.

Force LazyRef in AvoidWildcardsMap to avoid unneeded new LazyRef's.
LazyRefs, which are created for recursive types, aren't cacheable, so if
you TypeMap an AppliedType with one, it will create a brand new
instance.  OrderingConstraint#init runs AvoidWildcardsMap on param
bounds.  So when instDirection compares the constraint bounds and the
original param bounds (to calculate the instantiate direction), because
they are new instances they won't shortcircuit, leading to a recursion
overflow.  By forcing, it will eq check and return true.

Finally, change interpolateTypeVars' instantiation decision, following
the logic that isFullyDefined had.  In particular, if the typevar has an
upper bound constraint, maximise rather than minimise, which fixes the
inference of map/flatMap's B type args.

Finally, drop needless tree.tpe.isInstanceOf[MethodOrPoly] and
tree.tpe.widen in interpolateTypeVars.
  • Loading branch information
dwijnand committed Aug 14, 2024
1 parent e896db2 commit 58cec81
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 87 deletions.
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4693,6 +4693,7 @@ object Types extends TypeUtils {
type BT <: LambdaType
def paramNum: Int
def paramName: binder.ThisName = binder.paramNames(paramNum)
def paramInfo: binder.PInfo = binder.paramInfos(paramNum)

override def underlying(using Context): Type = {
// TODO: update paramInfos's type to nullable
Expand Down Expand Up @@ -6631,6 +6632,9 @@ object Types extends TypeUtils {
range(atVariance(-variance)(apply(bounds.lo)), apply(bounds.hi))
def apply(t: Type): Type = t match
case t: WildcardType => mapWild(t)
case tp: LazyRef => mapOver(tp) match
case tp1: LazyRef if tp.ref eq tp1.ref => tp
case tp1 => tp1
case _ => mapOver(t)

// ----- TypeAccumulators ----------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,8 @@ trait Applications extends Compatibility {
def makeVarArg(n: Int, elemFormal: Type): Unit = {
val args = typedArgBuf.takeRight(n).toList
typedArgBuf.dropRightInPlace(n)
val elemtpt = TypeTree(elemFormal, inferred = true)
val elemtp = if !args.exists(_.tpe.isError) then elemFormal else UnspecifiedErrorType
val elemtpt = TypeTree(elemtp, inferred = true)
typedArgBuf += seqToRepeated(SeqLiteral(args, elemtpt))
}

Expand Down
73 changes: 31 additions & 42 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -240,25 +240,12 @@ object Inferencing {
&& {
var fail = false
var skip = false
val direction = instDirection(tvar.origin)
if minimizeSelected then
if direction <= 0 && tvar.hasLowerBound then
skip = instantiate(tvar, fromBelow = true)
else if direction >= 0 && tvar.hasUpperBound then
skip = instantiate(tvar, fromBelow = false)
// else hold off instantiating unbounded unconstrained variable
else if direction != 0 then
skip = instantiate(tvar, fromBelow = direction < 0)
else if variance >= 0 && tvar.hasLowerBound then
skip = instantiate(tvar, fromBelow = true)
else if (variance > 0 || variance == 0 && !tvar.hasUpperBound)
&& force.ifBottom == IfBottom.ok
then // if variance == 0, prefer upper bound if one is given
skip = instantiate(tvar, fromBelow = true)
else if variance >= 0 && force.ifBottom == IfBottom.fail then
fail = true
else
toMaximize = tvar :: toMaximize
instDecision(tvar.origin, variance, minimizeSelected, force.ifBottom) match
case Decision.Min => skip = instantiate(tvar, fromBelow = true)
case Decision.Max => skip = instantiate(tvar, fromBelow = false)
case Decision.Skip => // hold off instantiating unbounded unconstrained variable
case Decision.Fail => fail = true
case Decision.ToMax => toMaximize ::= tvar
!fail && (skip || foldOver(x, tvar))
}
case tp => foldOver(x, tp)
Expand Down Expand Up @@ -438,22 +425,20 @@ object Inferencing {
occurring(tree, boundVars(tree, Nil), Nil)
}

/** The instantiation direction for given poly param computed
* from the constraint:
* @return 1 (maximize) if constraint is uniformly from above,
* -1 (minimize) if constraint is uniformly from below,
* 0 if unconstrained, or constraint is from below and above.
*/
private def instDirection(param: TypeParamRef)(using Context): Int = {
val constrained = TypeComparer.fullBounds(param)
val original = param.binder.paramInfos(param.paramNum)
val cmp = TypeComparer
val approxBelow =
if (!cmp.isSubTypeWhenFrozen(constrained.lo, original.lo)) 1 else 0
val approxAbove =
if (!cmp.isSubTypeWhenFrozen(original.hi, constrained.hi)) 1 else 0
approxAbove - approxBelow
}
/** The instantiation decision for given poly param computed from the constraint. */
enum Decision { case Min; case Max; case ToMax; case Skip; case Fail }
private def instDecision(param: TypeParamRef, v: Int, min: Boolean, ifBottom: IfBottom)(using Context): Decision =
import Decision.*
val tb = param.paramInfo // type bounds
val cb = TypeComparer.fullBounds(param) // constrained bounds
val dir = (if cb.lo frozen_<:< tb.lo then 0 else -1) + (if tb.hi frozen_<:< cb.hi then 0 else 1)
if dir < 0 || (min || v >= 0) && !cb.lo.isExactlyNothing then Min
else if dir > 0 || (min || v == 0) && !cb.hi.isTopOfSomeKind then Max // prefer upper bound if one is given
else if min then Skip
else ifBottom match
case IfBottom.ok => if v >= 0 then Min else ToMax
case IfBottom.fail => if v >= 0 then Fail else ToMax
case ifBottom_flip => ToMax

/** Following type aliases and stripping refinements and annotations, if one arrives at a
* class type reference where the class has a companion module, a reference to
Expand Down Expand Up @@ -651,16 +636,17 @@ trait Inferencing { this: Typer =>

val ownedVars = state.ownedVars
if (ownedVars ne locked) && !ownedVars.isEmpty then
val qualifying = ownedVars -- locked
val qualifying = (ownedVars -- locked).toList
if (!qualifying.isEmpty) {
typr.println(i"interpolate $tree: ${tree.tpe.widen} in $state, pt = $pt, owned vars = ${state.ownedVars.toList}%, %, qualifying = ${qualifying.toList}%, %, previous = ${locked.toList}%, % / ${state.constraint}")
typr.println(i"interpolate $tree: ${tree.tpe.widen} in $state, pt = $pt, owned vars = ${ownedVars.toList}, qualifying = $qualifying.toList}, previous = ${locked.toList}%, % / ${state.constraint}")
val resultAlreadyConstrained =
tree.isInstanceOf[Apply] || tree.tpe.isInstanceOf[MethodOrPoly]
tree.isInstanceOf[Apply]
if (!resultAlreadyConstrained)
trace(i"constrainResult($tree ${tree.symbol}, ${tree.tpe}, $pt)"):
constrainResult(tree.symbol, tree.tpe, pt)
// This is needed because it could establish singleton type upper bounds. See i2998.scala.

val tp = tree.tpe.widen
val tp = tree.tpe
val vs = variances(tp, pt)

// Avoid interpolating variables occurring in tree's type if typerstate has unreported errors.
Expand All @@ -687,6 +673,8 @@ trait Inferencing { this: Typer =>

def constraint = state.constraint

trace(i"interpolateTypeVars($tree: ${tree.tpe}, $pt, $qualifying)", typr, (_: Any) => i"$qualifying $constraint") {

/** Values of this type report type variables to instantiate with variance indication:
* +1 variable appears covariantly, can be instantiated from lower bound
* -1 variable appears contravariantly, can be instantiated from upper bound
Expand Down Expand Up @@ -782,12 +770,10 @@ trait Inferencing { this: Typer =>
/** Try to instantiate `tvs`, return any suspended type variables */
def tryInstantiate(tvs: ToInstantiate): ToInstantiate = tvs match
case (hd @ (tvar, v)) :: tvs1 =>
val fromBelow = v == 1 || (v == 0 && tvar.hasLowerBound)
typr.println(
i"interpolate${if v == 0 then " non-occurring" else ""} $tvar in $state in $tree: $tp, fromBelow = $fromBelow, $constraint")
if tvar.isInstantiated then
tryInstantiate(tvs1)
else
val fromBelow = instDecision(tvar.origin, v, false, IfBottom.flip) == Decision.Min
val suspend = tvs1.exists{ (following, _) =>
if fromBelow
then constraint.isLess(following.origin, tvar.origin)
Expand All @@ -797,13 +783,16 @@ trait Inferencing { this: Typer =>
typr.println(i"suspended: $hd")
hd :: tryInstantiate(tvs1)
else
typr.println(
i"interpolate${if v == 0 then " non-occurring" else ""} $tvar in $state in $tree: $tp, fromBelow = $fromBelow, $constraint")
tvar.instantiate(fromBelow)
tryInstantiate(tvs1)
case Nil => Nil
if tvs.nonEmpty then doInstantiate(tryInstantiate(tvs))
end doInstantiate

doInstantiate(filterByDeps(toInstantiate))
}
}
end if
tree
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ trait TypeAssigner {
else tree.withType(TypeComparer.lub(expr.tpe :: cases.tpes))

def assignType(tree: untpd.SeqLiteral, elems: List[Tree], elemtpt: Tree)(using Context): SeqLiteral =
tree.withType(seqLitType(tree, elemtpt.tpe))
tree.withType(if elemtpt.tpe.isError then elemtpt.tpe else seqLitType(tree, elemtpt.tpe))

def assignType(tree: untpd.SingletonTypeTree, ref: Tree)(using Context): SingletonTypeTree =
tree.withType(ref.tpe)
Expand Down
50 changes: 24 additions & 26 deletions compiler/test/dotty/tools/dotc/typer/InstantiateModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,16 @@ package typer

// Modelling the decision in IsFullyDefined
object InstantiateModel:
enum LB { case NN; case LL; case L1 }; import LB.*
enum UB { case AA; case UU; case U1 }; import UB.*
enum Var { case V; case NotV }; import Var.*
enum MSe { case M; case NotM }; import MSe.*
enum Bot { case Fail; case Ok; case Flip }; import Bot.*
enum Act { case Min; case Max; case ToMax; case Skip; case False }; import Act.*
enum LB { case NN; case LL; case L1 }; import LB.*
enum UB { case AA; case UU; case U1 }; import UB.*
import Inferencing.Decision.*

// NN/AA = Nothing/Any
// LL/UU = the original bounds, on the type parameter
// L1/U1 = the constrained bounds, on the type variable
// V = variance >= 0 ("non-contravariant")
// MSe = minimisedSelected
// Bot = IfBottom
// ToMax = delayed maximisation, via addition to toMaximize
// Skip = minimisedSelected "hold off instantiating"
// False = return false
// Fail = IfBottom.fail's bail option

// there are 9 combinations:
// # | LB | UB | d | // d = direction
Expand All @@ -34,24 +28,28 @@ object InstantiateModel:
// 8 | NN | UU | 0 | T <: UU
// 9 | NN | AA | 0 | T

def decide(lb: LB, ub: UB, v: Var, bot: Bot, m: MSe): Act = (lb, ub) match
def instDecision(lb: LB, ub: UB, v: Int, ifBottom: IfBottom, min: Boolean) = (lb, ub) match
case (L1, AA) => Min
case (L1, UU) => Min
case (LL, U1) => Max
case (NN, U1) => Max

case (L1, U1) => if m==M || v==V then Min else ToMax
case (LL, UU) => if m==M || v==V then Min else ToMax
case (LL, AA) => if m==M || v==V then Min else ToMax

case (NN, UU) => bot match
case _ if m==M => Max
//case Ok if v==V => Min // removed, i14218 fix
case Fail if v==V => False
case _ => ToMax

case (NN, AA) => bot match
case _ if m==M => Skip
case Ok if v==V => Min
case Fail if v==V => False
case _ => ToMax
case (L1, U1) => if min then Min else pickVar(v, Min, Min, ToMax)
case (LL, UU) => if min then Min else pickVar(v, Min, Min, ToMax)
case (LL, AA) => if min then Min else pickVar(v, Min, Min, ToMax)

case (NN, UU) => ifBottom match
case IfBottom.ok => pickVar(v, Min, ToMax, ToMax)
case IfBottom.fail => pickVar(v, Fail, Fail, ToMax)
case IfBottom.flip => if min then Max else ToMax

case (NN, AA) => ifBottom match
case IfBottom.ok => pickVar(v, Min, Min, ToMax)
case IfBottom.fail => pickVar(v, Fail, Fail, ToMax)
case IfBottom.flip => if min then Skip else ToMax

def interpolateTypeVars(lb: LB, ub: UB, v: Int) =
instDecision(lb, ub, v, IfBottom.flip, min = false)

def pickVar[A](v: Int, cov: A, inv: A, con: A) =
if v > 0 then cov else if v == 0 then inv else con
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.core.Types.Type
import dotty.tools.dotc.interactive.Interactive
import dotty.tools.dotc.interactive.InteractiveDriver
import dotty.tools.dotc.typer.Applications.UnapplyArgs
import dotty.tools.dotc.util.NoSourcePosition
import dotty.tools.dotc.util.SourceFile
import dotty.tools.dotc.util.Spans.Span
import dotty.tools.pc.IndexedContext
Expand Down Expand Up @@ -86,9 +88,15 @@ object InterCompletionType:
// List(@@)
case SeqLiteral(_, tpe) :: _ if !tpe.tpe.isErroneous =>
Some(tpe.tpe)
case SeqLiteral(_, _) :: _typed :: rest =>
inferType(rest, span)
// val _: T = @@
// def _: T = @@
case (defn: ValOrDefDef) :: rest if !defn.tpt.tpe.isErroneous => Some(defn.tpt.tpe)
case UnApply(fun, _, pats) :: _ =>
val ind = pats.indexWhere(_.span.contains(span))
if ind < 0 then None
else Some(UnapplyArgs(fun.tpe.finalResultType, fun, pats, NoSourcePosition).argTypes(ind))
// f(@@)
case (app: Apply) :: rest =>
val param =
Expand All @@ -98,7 +106,7 @@ object InterCompletionType:
}
params <- app.symbol.paramSymss.find(!_.exists(_.isTypeParam))
param <- params.get(ind)
} yield param.info
} yield param.info.repeatedToSingle
param match
// def f[T](a: T): T = ???
// f[Int](@@)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class InferExpectedTypeSuite extends BasePCSuite:
EmptyCancelToken
)
presentationCompiler.asInstanceOf[ScalaPresentationCompiler].inferExpectedType(offsetParams).get().asScala match {
case Some(value) => assertNoDiff(value, expectedType)
case Some(value) => assertNoDiff(expectedType, value)
case None => fail("Empty result.")
}

Expand Down Expand Up @@ -55,7 +55,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
|""".stripMargin
)

@Ignore("Not handled correctly.")
@Test def list =
check(
"""|val i: List[Int] = List(@@)
Expand Down Expand Up @@ -193,7 +192,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
|""".stripMargin
)

@Ignore("Unapply is not handled correctly.")
@Test def unapply =
check(
"""|val _ =
Expand Down Expand Up @@ -223,7 +221,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
|""".stripMargin
)

@Ignore("Generic functions are not handled correctly.")
@Test def flatmap =
check(
"""|val _ : List[Int] = List().flatMap(_ => @@)
Expand All @@ -232,7 +229,14 @@ class InferExpectedTypeSuite extends BasePCSuite:
|""".stripMargin
)

@Ignore("Generic functions are not handled correctly.")
@Test def map =
check(
"""|val _ : List[Int] = List().map(_ => @@)
|""".stripMargin,
"""|Int
|""".stripMargin
)

@Test def `for-comprehension` =
check(
"""|val _ : List[Int] =
Expand All @@ -245,40 +249,36 @@ class InferExpectedTypeSuite extends BasePCSuite:
)

// bounds
@Ignore("Bounds are not handled correctly.")
@Test def any =
check(
"""|trait Foo
|def foo[T](a: T): Boolean = ???
|val _ = foo(@@)
|""".stripMargin,
"""|<: Any
"""|Any
|""".stripMargin
)

@Ignore("Bounds are not handled correctly.")
@Test def `bounds-1` =
check(
"""|trait Foo
|def foo[T <: Foo](a: Foo): Boolean = ???
|def foo[T <: Foo](a: T): Boolean = ???
|val _ = foo(@@)
|""".stripMargin,
"""|<: Foo
"""|Foo
|""".stripMargin
)

@Ignore("Bounds are not handled correctly.")
@Test def `bounds-2` =
check(
"""|trait Foo
|def foo[T :> Foo](a: Foo): Boolean = ???
|def foo[T >: Foo](a: T): Boolean = ???
|val _ = foo(@@)
|""".stripMargin,
"""|:> Foo
"""|Foo
|""".stripMargin
)

@Ignore("Bounds are not handled correctly.")
@Test def `bounds-3` =
check(
"""|trait A
Expand All @@ -287,6 +287,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
|def roo[F >: C <: A](f: F) = ???
|val kjk = roo(@@)
|""".stripMargin,
"""|>: C <: A
"""|C
|""".stripMargin
)
2 changes: 1 addition & 1 deletion tests/neg/recursive-lower-constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ class Bar extends Foo[Bar]

class A {
def foo[T <: Foo[T], U >: Foo[T] <: T](x: T): T = x
foo(new Bar) // error // error
foo(new Bar) // error
}

0 comments on commit 58cec81

Please sign in to comment.