Skip to content

Commit

Permalink
Fix pkg obj prefix of opaque tp ext meth (#21527)
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand authored Nov 12, 2024
2 parents 33b3d60 + 7db83c5 commit 896965c
Show file tree
Hide file tree
Showing 12 changed files with 125 additions and 36 deletions.
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package core
import Symbols.*, Types.*, Contexts.*, Flags.*, Names.*, StdNames.*, Phases.*
import Flags.JavaDefined
import Uniques.unique
import TypeOps.makePackageObjPrefixExplicit
import backend.sjs.JSDefinitions
import transform.ExplicitOuter.*
import transform.ValueClasses.*
Expand Down
30 changes: 0 additions & 30 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -560,36 +560,6 @@ object TypeOps:
widenMap(tp)
}

/** If `tpe` is of the form `p.x` where `p` refers to a package
* but `x` is not owned by a package, expand it to
*
* p.package.x
*/
def makePackageObjPrefixExplicit(tpe: NamedType)(using Context): Type = {
def tryInsert(pkgClass: SymDenotation): Type = pkgClass match {
case pkg: PackageClassDenotation =>
var sym = tpe.symbol
if !sym.exists && tpe.denot.isOverloaded then
// we know that all alternatives must come from the same package object, since
// otherwise we would get "is already defined" errors. So we can take the first
// symbol we see.
sym = tpe.denot.alternatives.head.symbol
val pobj = pkg.packageObjFor(sym)
if (pobj.exists) tpe.derivedSelect(pobj.termRef)
else tpe
case _ =>
tpe
}
if (tpe.symbol.isRoot)
tpe
else
tpe.prefix match {
case pre: ThisType if pre.cls.is(Package) => tryInsert(pre.cls)
case pre: TermRef if pre.symbol.is(Package) => tryInsert(pre.symbol.moduleClass)
case _ => tpe
}
}

/** An argument bounds violation is a triple consisting of
* - the argument tree
* - a string "upper" or "lower" indicating which bound is violated
Expand Down
32 changes: 31 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package dotc
package core

import TypeErasure.ErasedValueType
import Types.*, Contexts.*, Symbols.*, Flags.*, Decorators.*
import Types.*, Contexts.*, Symbols.*, Flags.*, Decorators.*, SymDenotations.*
import Names.{Name, TermName}
import Constants.Constant

Expand Down Expand Up @@ -186,6 +186,36 @@ class TypeUtils:
case self: Types.ThisType => self.cls == cls
case _ => false

/** If `self` is of the form `p.x` where `p` refers to a package
* but `x` is not owned by a package, expand it to
*
* p.package.x
*/
def makePackageObjPrefixExplicit(using Context): Type =
def tryInsert(tpe: NamedType, pkgClass: SymDenotation): Type = pkgClass match
case pkg: PackageClassDenotation =>
var sym = tpe.symbol
if !sym.exists && tpe.denot.isOverloaded then
// we know that all alternatives must come from the same package object, since
// otherwise we would get "is already defined" errors. So we can take the first
// symbol we see.
sym = tpe.denot.alternatives.head.symbol
val pobj = pkg.packageObjFor(sym)
if pobj.exists then tpe.derivedSelect(pobj.termRef)
else tpe
case _ =>
tpe
self match
case tpe: NamedType =>
if tpe.symbol.isRoot then
tpe
else
tpe.prefix match
case pre: ThisType if pre.cls.is(Package) => tryInsert(tpe, pre.cls)
case pre: TermRef if pre.symbol.is(Package) => tryInsert(tpe, pre.symbol.moduleClass)
case _ => tpe
case tpe => tpe

/** Strip all outer refinements off this type */
def stripRefinement: Type = self match
case self: RefinedOrRecType => self.parent.stripRefinement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1273,7 +1273,7 @@ class TreeUnpickler(reader: TastyReader,
val tpe0 = name match
case name: TypeName => TypeRef(qualType, name, denot)
case name: TermName => TermRef(qualType, name, denot)
val tpe = TypeOps.makePackageObjPrefixExplicit(tpe0)
val tpe = tpe0.makePackageObjPrefixExplicit
ConstFold.Select(untpd.Select(qual, name).withType(tpe))

def completeSelect(name: Name, sig: Signature, target: Name): Select =
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ trait TypeAssigner {
defn.FromJavaObjectType
else tpe match
case tpe: NamedType =>
val tpe1 = TypeOps.makePackageObjPrefixExplicit(tpe)
val tpe1 = tpe.makePackageObjPrefixExplicit
if tpe1 ne tpe then
accessibleType(tpe1, superAccess)
else
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// so we ignore that import.
if reallyExists(denot) && !isScalaJsPseudoUnion then
if unimported.isEmpty || !unimported.contains(pre.termSymbol) then
return pre.select(name, denot)
return pre.select(name, denot).makePackageObjPrefixExplicit
case _ =>
if imp.importSym.isCompleting then
report.warning(i"cyclic ${imp.importSym}, ignored", pos)
Expand Down Expand Up @@ -504,7 +504,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
defDenot.symbol.owner
else
curOwner
effectiveOwner.thisType.select(name, defDenot)
effectiveOwner.thisType.select(name, defDenot).makePackageObjPrefixExplicit
}
if !curOwner.is(Package) || isDefinedInCurrentUnit(defDenot) then
result = checkNewOrShadowed(found, Definition) // no need to go further out, we found highest prec entry
Expand Down
22 changes: 22 additions & 0 deletions tests/pos/i18097.1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
opaque type Pos = Double

object Pos:
extension (x: Pos)
def mult1(y: Pos): Pos = x * y

extension (x: Pos)
def mult2(y: Pos): Pos = x * y

class Test:
def test(key: String, a: Pos, b: Pos): Unit =
val tup1 = new Tuple1(Pos.mult1(a)(b))
val res1: Pos = tup1._1

val tup2 = new Tuple1(a.mult1(b))
val res2: Pos = tup2._1

val tup3 = new Tuple1(mult2(a)(b))
val res3: Pos = tup3._1

val tup4 = new Tuple1(a.mult2(b))
val res4: Pos = tup4._1 // was error: Found: (tup4._4 : Double) Required: Pos
13 changes: 13 additions & 0 deletions tests/pos/i18097.2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
opaque type Namespace = List[String]

object Namespace:
def apply(head: String): Namespace = List(head)

extension (ns: Namespace)
def appended(segment: String): Namespace = ns.appended(segment)

object Main:
def main(args: Array[String]): Unit =
val a: Namespace = Namespace("a")
.appended("B")
.appended("c") // was error: Found: List[String] Required: Namespace
13 changes: 13 additions & 0 deletions tests/pos/i18097.2.works.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
object Main:
opaque type Namespace = List[String]

object Namespace:
def apply(head: String): Namespace = List(head)

extension (ns: Namespace)
def appended(segment: String): Namespace = ns.appended(segment)

def main(args: Array[String]): Unit =
val a: Namespace = Namespace("a")
.appended("B")
.appended("c")
9 changes: 9 additions & 0 deletions tests/pos/i18097.3/Opaque.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package test

type Foo = Unit
val bar: Foo = ()

opaque type Opaque = Unit

extension (foo: Foo)
def go: Option[Opaque] = ???
13 changes: 13 additions & 0 deletions tests/pos/i18097.3/Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package test

final case class Test(value: Opaque)

def test: Test =
bar.go match
case Some(value) => Test(value) // was error: Found: (value : Unit) Required: test.Opaque
case _ => ???

def test2: Test =
go(bar) match
case Some(value) => Test(value)
case _ => ???
20 changes: 20 additions & 0 deletions tests/pos/i18097.orig.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
opaque type PositiveNumber = Double

object PositiveNumber:
extension (x: PositiveNumber)
def mult1(other: PositiveNumber): PositiveNumber =
x * other

extension (x: PositiveNumber)
def mult2(other: PositiveNumber): PositiveNumber =
x * other

object Test:
def multMap1[A](x: Map[A, PositiveNumber], num: PositiveNumber): Map[A, PositiveNumber] = x.map((key, value) => key -> value.mult1(num)).toMap

def multMap2[A](x: Map[A, PositiveNumber], num: PositiveNumber): Map[A, PositiveNumber] = x.map((key, value) => key -> value.mult2(num)).toMap // was error
// ^
// Cannot prove that (A, Double) <:< (A, V2).
//
// where: V2 is a type variable with constraint <: PositiveNumber
def multMap2_2[A](x: Map[A, PositiveNumber], num: PositiveNumber): Map[A, PositiveNumber] = x.map((key, value) => key -> mult2(value)(num)).toMap

0 comments on commit 896965c

Please sign in to comment.