Skip to content

Commit

Permalink
Fix mapping and pickling of annotated types
Browse files Browse the repository at this point in the history
`Annotation.mapWith` maps an `Annotation` with a type map `tm`. Before actually applying `tm` to the annotation’s `tree`, it first checks if `tm` would result in any change by applying it to the types of the annotation’s arguments, and checking if the mapped types are different. This optimization had two problems: it didn’t include type parameters, and used `frozen_=:=`  to compare types, which failed to detected some changes. This commit changes `Annotation.arguments` to also include type parameters, and, and changes `Annotation.MapWith` to use `==` to compare types instead of `frozen_=:=`.

Furthermore, in case of changes, the symbol in the annotation's tree should be copied to make sure that the same symbol is not used for different trees. This commit achieves this by using a custom `TreeTypeMap` with an overridden `withMappedSyms` method where `Symbols.mapSymbols` is called with the argument `mapAlways = true`.

Finally, positons of trees that appear inside `AnnotatedType` only were not pickled. This commit also fixes this.
  • Loading branch information
mbovel committed Apr 18, 2024
1 parent d148973 commit 878cc3e
Show file tree
Hide file tree
Showing 15 changed files with 133 additions and 20 deletions.
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
case _ => argss
loop(tree, Nil)

/** All term arguments of an application in a single flattened list */
/** All type and term arguments of an application in a single flattened list */
def allArguments(tree: Tree): List[Tree] = unsplice(tree) match {
case Apply(fn, args) => allArguments(fn) ::: args
case TypeApply(fn, _) => allArguments(fn)
case TypeApply(fn, args) => allArguments(fn) ::: args
case Block(_, expr) => allArguments(expr)
case _ => Nil
}
Expand Down
29 changes: 13 additions & 16 deletions compiler/src/dotty/tools/dotc/core/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package dotc
package core

import Symbols.*, Types.*, Contexts.*, Constants.*, Phases.*
import ast.tpd, tpd.*
import util.Spans.Span
import ast.{tpd, untpd, TreeTypeMap}
import tpd.*
import util.Spans.{Span, NoSpan}
import printing.{Showable, Printer}
import printing.Texts.Text

Expand All @@ -30,7 +31,7 @@ object Annotations {
def derivedAnnotation(tree: Tree)(using Context): Annotation =
if (tree eq this.tree) this else Annotation(tree)

/** All arguments to this annotation in a single flat list */
/** All type and term arguments to this annotation in a single flat list */
def arguments(using Context): List[Tree] = tpd.allArguments(tree)

def argument(i: Int)(using Context): Option[Tree] = {
Expand All @@ -54,19 +55,15 @@ object Annotations {
* type, since ranges cannot be types of trees.
*/
def mapWith(tm: TypeMap)(using Context) =
val args = arguments
if args.isEmpty then this
else
val findDiff = new TreeAccumulator[Type]:
def apply(x: Type, tree: Tree)(using Context): Type =
if tm.isRange(x) then x
else
val tp1 = tm(tree.tpe)
foldOver(if tp1 frozen_=:= tree.tpe then x else tp1, tree)
val diff = findDiff(NoType, args)
if tm.isRange(diff) then EmptyAnnotation
else if diff.exists then derivedAnnotation(tm.mapOver(tree))
else this
val originalTree = tree
val mappedTree = tm.mapOver(originalTree)
if mappedTree neq originalTree then
val ttm =
new TreeTypeMap(typeMap = tm):
final override def withMappedSyms(syms: List[Symbol]): TreeTypeMap =
withMappedSyms(syms, mapSymbols(syms, this, mapAlways = true))
derivedAnnotation(ttm.transform(tree))
derivedAnnotation(tm.mapOver(tree))

/** Does this annotation refer to a parameter of `tl`? */
def refersToParamOf(tl: TermLambda)(using Context): Boolean =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ object PositionPickler:
pickler: TastyPickler,
addrOfTree: TreeToAddr,
treeAnnots: untpd.MemberDef => List[tpd.Tree],
typeAnnots: List[tpd.Tree],
relativePathReference: String,
source: SourceFile,
roots: List[Tree],
Expand Down Expand Up @@ -136,6 +137,9 @@ object PositionPickler:
}
for (root <- roots)
traverse(root, NoSource)

for annotTree <- typeAnnots do
traverse(annotTree, NoSource)
end picklePositions
end PositionPickler

7 changes: 7 additions & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
*/
private val annotTrees = util.EqHashMap[untpd.MemberDef, mutable.ListBuffer[Tree]]()

/** A set of annotation trees appearing in annotated types.
*/
private val annotatedTypeTrees = mutable.ListBuffer[Tree]()

/** A map from member definitions to their doc comments, so that later
* parallel comment pickling does not need to access symbols of trees (which
* would involve accessing symbols of named types and possibly changing phases
Expand All @@ -56,6 +60,8 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
val ts = annotTrees.lookup(tree)
if ts == null then Nil else ts.toList

def typeAnnots: List[Tree] = annotatedTypeTrees.toList

def docString(tree: untpd.MemberDef): Option[Comment] =
Option(docStrings.lookup(tree))

Expand Down Expand Up @@ -266,6 +272,7 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
case tpe: AnnotatedType =>
writeByte(ANNOTATEDtype)
withLength { pickleType(tpe.parent, richTypes); pickleTree(tpe.annot.tree) }
annotatedTypeTrees += tpe.annot.tree
case tpe: AndType =>
writeByte(ANDtype)
withLength { pickleType(tpe.tp1, richTypes); pickleType(tpe.tp2, richTypes) }
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ object PickledQuotes {
if tree.span.exists then
val positionWarnings = new mutable.ListBuffer[Message]()
val reference = ctx.settings.sourceroot.value
PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference,
PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference,
ctx.compilationUnit.source, tree :: Nil, positionWarnings)
positionWarnings.foreach(report.warning(_))

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/Pickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class Pickler extends Phase {
if tree.span.exists then
val reference = ctx.settings.sourceroot.value
PositionPickler.picklePositions(
pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference,
pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference,
unit.source, tree :: Nil, positionWarnings,
scratch.positionBuffer, scratch.pickledIndices)

Expand Down
7 changes: 7 additions & 0 deletions tests/pos/annot-17939.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class qualified[T](f: T => Boolean) extends annotation.StaticAnnotation

class Box[T](val x: T)
class Box2(val x: Int)

class A(a: String @qualified((x: Int) => Box(3).x == 3)) // crash
class A2(a2: String @qualified((x: Int) => Box2(3).x == 3)) // works
10 changes: 10 additions & 0 deletions tests/pos/annot-17939b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import scala.annotation.Annotation
class myRefined(f: ? => Boolean) extends Annotation

def test(axes: Int) = true

trait Tensor:
def mean(axes: Int): Int @myRefined(_ => test(axes))

class TensorImpl() extends Tensor:
def mean(axes: Int) = ???
8 changes: 8 additions & 0 deletions tests/pos/annot-19846.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation

class EqualPair(val x: Int, val y: Int @qualified[Int](it => it == x))

@main def main =
val p = EqualPair(42, 42)
val y = p.y
println(42)
7 changes: 7 additions & 0 deletions tests/pos/annot-19846b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation

def f(x: Int): Int @qualified[Int](it => it == x) = ???

@main def main =
val z = f(42)
()
10 changes: 10 additions & 0 deletions tests/pos/annot-5789.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class Annot[T] extends scala.annotation.Annotation

class D[T](val f: Int@Annot[T])

object A{
def main(a:Array[String]) = {
val c = new D[Int](1)
c.f
}
}
16 changes: 16 additions & 0 deletions tests/printing/annot-18064.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[[syntax trees at end of typer]] // tests/printing/annot-18064.scala
package <empty> {
class myAnnot[T >: Nothing <: Any]() extends annotation.Annotation() {
T
}
trait Tensor[T >: Nothing <: Any]() extends Object {
T
def add: Tensor[Tensor.this.T] @myAnnot[T]
}
class TensorImpl[A >: Nothing <: Any]() extends Object(), Tensor[
TensorImpl.this.A] {
A
def add: Tensor[A] @myAnnot[A] = this
}
}

7 changes: 7 additions & 0 deletions tests/printing/annot-18064.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class myAnnot[T]() extends annotation.Annotation

trait Tensor[T]:
def add: Tensor[T] @myAnnot[T]()

class TensorImpl[A]() extends Tensor[A]:
def add /* : Tensor[A] @myAnnot[A] */ = this
33 changes: 33 additions & 0 deletions tests/printing/annot-19846b.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
[[syntax trees at end of typer]] // tests/printing/annot-19846b.scala
package <empty> {
class lambdaAnnot(g: () => Int) extends scala.annotation.Annotation(),
annotation.StaticAnnotation {
private[this] val g: () => Int
}
final lazy module val Test: Test = new Test()
final module class Test() extends Object() { this: Test.type =>
val y: Int = ???
val z:
Int @lambdaAnnot(
{
def $anonfun(): Int = Test.y
closure($anonfun)
}
)
= f(Test.y)
}
final lazy module val annot-19846b$package: annot-19846b$package =
new annot-19846b$package()
final module class annot-19846b$package() extends Object() {
this: annot-19846b$package.type =>
def f(x: Int):
Int @lambdaAnnot(
{
def $anonfun(): Int = x
closure($anonfun)
}
)
= x
}
}

7 changes: 7 additions & 0 deletions tests/printing/annot-19846b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation

def f(x: Int): Int @lambdaAnnot(() => x) = x

object Test:
val y: Int = ???
val z /* : Int @lambdaAnnot(() => y) */ = f(y)

0 comments on commit 878cc3e

Please sign in to comment.