Skip to content

Commit

Permalink
New capture escape checking based on levels (scala#18463)
Browse files Browse the repository at this point in the history
A new scope restriction scheme for capture checking based on levels.

The idea is to have a stack of capture roots where inner capture roots
are super-captures of outer roots.

Refines and supersedes scala#18348
  • Loading branch information
odersky authored Sep 9, 2023
2 parents 08f2faf + 4a45939 commit 64c3138
Show file tree
Hide file tree
Showing 102 changed files with 2,126 additions and 968 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,6 @@ docs/_spec/.jekyll-metadata
# scaladoc related
scaladoc/output/

#coverage
coverage/

5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,9 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
profiler.afterPhase(phase, profileBefore)
if (ctx.settings.Xprint.value.containsPhase(phase))
for (unit <- units)
lastPrintedTree =
printTree(lastPrintedTree)(using ctx.fresh.setPhase(phase.next).setCompilationUnit(unit))
def printCtx(unit: CompilationUnit) = phase.printingContext(
ctx.fresh.setPhase(phase.next).setCompilationUnit(unit))
lastPrintedTree = printTree(lastPrintedTree)(using printCtx(unit))
report.informTime(s"$phase ", start)
Stats.record(s"total trees at end of $phase", ast.Trees.ntrees)
for (unit <- units)
Expand Down
13 changes: 8 additions & 5 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,12 @@ object desugar {

def makeImplicitParameters(
tpts: List[Tree], implicitFlag: FlagSet,
mkParamName: () => TermName,
mkParamName: Int => TermName,
forPrimaryConstructor: Boolean = false
)(using Context): List[ValDef] =
for (tpt, i) <- tpts.zipWithIndex yield {
val paramFlags: FlagSet = if (forPrimaryConstructor) LocalParamAccessor else Param
val epname = mkParamName()
val epname = mkParamName(i)
ValDef(epname, tpt, EmptyTree).withFlags(paramFlags | implicitFlag)
}

Expand Down Expand Up @@ -254,7 +254,7 @@ object desugar {
// using clauses, we only need names that are unique among the
// parameters of the method since shadowing does not affect
// implicit resolution in Scala 3.
mkParamName = () =>
mkParamName = i =>
val index = seenContextBounds + 1 // Start at 1 like FreshNameCreator.
val ret = ContextBoundParamName(EmptyTermName, index)
seenContextBounds += 1
Expand Down Expand Up @@ -1602,9 +1602,12 @@ object desugar {
case vd: ValDef => vd
}

def makeContextualFunction(formals: List[Tree], body: Tree, erasedParams: List[Boolean])(using Context): Function = {
def makeContextualFunction(formals: List[Tree], paramNamesOrNil: List[TermName], body: Tree, erasedParams: List[Boolean])(using Context): Function = {
val mods = Given
val params = makeImplicitParameters(formals, mods, mkParamName = () => ContextFunctionParamName.fresh())
val params = makeImplicitParameters(formals, mods,
mkParamName = i =>
if paramNamesOrNil.isEmpty then ContextFunctionParamName.fresh()
else paramNamesOrNil(i))
FunctionWithMods(params, body, Modifiers(mods), erasedParams)
}

Expand Down
36 changes: 36 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,17 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
case _ =>
tree.tpe.isInstanceOf[ThisType]
}

/** Under capture checking, an extractor for qualified roots `cap[Q]`.
*/
object QualifiedRoot:

def unapply(tree: Apply)(using Context): Option[String] = tree match
case Apply(fn, Literal(lit) :: Nil) if fn.symbol == defn.Caps_capIn =>
Some(lit.value.asInstanceOf[String])
case _ =>
None
end QualifiedRoot
}

trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped] =>
Expand Down Expand Up @@ -799,12 +810,37 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
}
}

/** An extractor for def of a closure contained the block of the closure,
* possibly with type ascriptions.
*/
object possiblyTypedClosureDef:
def unapply(tree: Tree)(using Context): Option[DefDef] = tree match
case Typed(expr, _) => unapply(expr)
case _ => closureDef.unapply(tree)

/** If tree is a closure, its body, otherwise tree itself */
def closureBody(tree: Tree)(using Context): Tree = tree match {
case closureDef(meth) => meth.rhs
case _ => tree
}

/** Is `mdef` an eta-expansion of a method reference? To recognize this, we use
* the following criterion: A method definition is an eta expansion, if
* it contains at least one term paramter, the parameter has a zero extent span,
* and the right hand side is either an application or a closure with'
* an anonymous method that's itself characterized as an eta expansion.
*/
def isEtaExpansion(mdef: DefDef)(using Context): Boolean =
!rhsOfEtaExpansion(mdef).isEmpty

def rhsOfEtaExpansion(mdef: DefDef)(using Context): Tree = mdef.paramss match
case (param :: _) :: _ if param.asInstanceOf[Tree].span.isZeroExtent =>
mdef.rhs match
case rhs: Apply => rhs
case closureDef(mdef1) => rhsOfEtaExpansion(mdef1)
case _ => EmptyTree
case _ => EmptyTree

/** The variables defined by a pattern, in reverse order of their appearance. */
def patVars(tree: Tree)(using Context): List[Symbol] = {
val acc = new TreeAccumulator[List[Symbol]] { outer =>
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case Floating
}

/** {x1, ..., xN} T (only relevant under captureChecking) */
/** {x1, ..., xN} T (only relevant under captureChecking)
* Created when parsing function types so that capture set and result type
* is combined in a single node.
*/
case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree

/** A type tree appearing somewhere in the untyped DefDef of a lambda, it will be typed using `tpFun`.
Expand Down Expand Up @@ -512,6 +515,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def captureRoot(using Context): Select =
Select(scalaDot(nme.caps), nme.CAPTURE_ROOT)

def captureRootIn(using Context): Select =
Select(scalaDot(nme.caps), nme.capIn)

def makeRetaining(parent: Tree, refs: List[Tree], annotName: TypeName)(using Context): Annotated =
Annotated(parent, New(scalaAnnotationDot(annotName), List(refs)))

Expand Down
170 changes: 166 additions & 4 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import config.SourceVersion
import config.Printers.capt
import util.Property.Key
import tpd.*
import StdNames.nme
import config.Feature
import collection.mutable

private val Captures: Key[CaptureSet] = Key()
private val BoxedType: Key[BoxedTypeCache] = Key()
Expand All @@ -21,6 +23,11 @@ private val BoxedType: Key[BoxedTypeCache] = Key()
*/
private val adaptUnpickledFunctionTypes = false

/** Switch whether we constrain a root var that includes the source of a
* root map to be an alias of that source (so that it can be mapped)
*/
private val constrainRootsWhenMapping = true

/** The arguments of a @retains or @retainsByName annotation */
private[cc] def retainedElems(tree: Tree)(using Context): List[Tree] = tree match
case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) => elems
Expand All @@ -32,12 +39,82 @@ def allowUniversalInBoxed(using Context) =
/** An exception thrown if a @retains argument is not syntactically a CaptureRef */
class IllegalCaptureRef(tpe: Type) extends Exception

/** Capture checking state, which is stored in a context property */
class CCState:

val rhsClosure: mutable.HashSet[Symbol] = new mutable.HashSet

val levelOwners: mutable.HashSet[Symbol] = new mutable.HashSet

/** Associates certain symbols (the nesting level owners) with their ccNestingLevel */
val nestingLevels: mutable.HashMap[Symbol, Int] = new mutable.HashMap

/** Associates nesting level owners with the local roots valid in their scopes. */
val localRoots: mutable.HashMap[Symbol, Symbol] = new mutable.HashMap

/** The last pair of capture reference and capture set where
* the reference could not be added to the set due to a level conflict.
*/
var levelError: Option[(CaptureRef, CaptureSet)] = None

/** Under saferExceptions: The <try block> symbol generated for a try.
* Installed by Setup, removed by CheckCaptures.
*/
val tryBlockOwner: mutable.HashMap[Try, Symbol] = new mutable.HashMap
end CCState

/** Property key for capture checking state */
val ccStateKey: Key[CCState] = Key()

/** The currently valid CCState */
def ccState(using Context) = ctx.property(ccStateKey).get

trait FollowAliases extends TypeMap:
def mapOverFollowingAliases(t: Type): Type = t match
case t: LazyRef =>
val t1 = this(t.ref)
if t1 ne t.ref then t1 else t
case _ =>
val t1 = t.dealiasKeepAnnots
if t1 ne t then
val t2 = this(t1)
if t2 ne t1 then return t2
mapOver(t)

class mapRoots(from: CaptureRoot, to: CaptureRoot)(using Context) extends BiTypeMap, FollowAliases:
thisMap =>

def apply(t: Type): Type =
if t eq from then to
else t match
case t: CaptureRoot.Var =>
val ta = t.followAlias
if ta ne t then apply(ta)
else from match
case from: TermRef
if t.upperLevel >= from.symbol.ccNestingLevel
&& constrainRootsWhenMapping // next two lines do the constraining
&& CaptureRoot.isEnclosingRoot(from, t)
&& CaptureRoot.isEnclosingRoot(t, from) => to
case from: CaptureRoot.Var if from.followAlias eq t => to
case _ => t
case _ =>
mapOverFollowingAliases(t)

def inverse = mapRoots(to, from)
end mapRoots

extension (tree: Tree)

/** Map tree with CaptureRef type to its type, throw IllegalCaptureRef otherwise */
def toCaptureRef(using Context): CaptureRef = tree.tpe match
case ref: CaptureRef => ref
case tpe => throw IllegalCaptureRef(tpe)
def toCaptureRef(using Context): CaptureRef = tree match
case QualifiedRoot(outer) =>
ctx.owner.levelOwnerNamed(outer)
.orElse(defn.captureRoot) // non-existing outer roots are reported in Setup's checkQualifiedRoots
.localRoot.termRef
case _ => tree.tpe match
case ref: CaptureRef => ref
case tpe => throw IllegalCaptureRef(tpe) // if this was compiled from cc syntax, problem should have been reported at Typer

/** Convert a @retains or @retainsByName annotation tree to the capture set it represents.
* For efficience, the result is cached as an Attachment on the tree.
Expand Down Expand Up @@ -164,7 +241,7 @@ extension (tp: Type)
* a by name parameter type, turning the latter into an impure by name parameter type.
*/
def adaptByNameArgUnderPureFuns(using Context): Type =
if Feature.pureFunsEnabledSomewhere then
if adaptUnpickledFunctionTypes && Feature.pureFunsEnabledSomewhere then
AnnotatedType(tp,
CaptureAnnotation(CaptureSet.universal, boxed = false)(defn.RetainsByNameAnnot))
else
Expand Down Expand Up @@ -253,6 +330,91 @@ extension (sym: Symbol)
&& sym != defn.Caps_unsafeBox
&& sym != defn.Caps_unsafeUnbox

def isLevelOwner(using Context): Boolean = ccState.levelOwners.contains(sym)

/** The owner of the current level. Qualifying owners are
* - methods other than constructors and anonymous functions
* - anonymous functions, provided they either define a local
* root of type caps.Cap, or they are the rhs of a val definition.
* - classes, if they are not staticOwners
* - _root_
*/
def levelOwner(using Context): Symbol =
if !sym.exists || sym.isRoot || sym.isStaticOwner then defn.RootClass
else if sym.isLevelOwner then sym
else sym.owner.levelOwner

/** The nesting level of `sym` for the purposes of `cc`,
* -1 for NoSymbol
*/
def ccNestingLevel(using Context): Int =
if sym.exists then
val lowner = sym.levelOwner
ccState.nestingLevels.getOrElseUpdate(lowner,
if lowner.isRoot then 0 else lowner.owner.ccNestingLevel + 1)
else -1

/** Optionally, the nesting level of `sym` for the purposes of `cc`, provided
* a capture checker is running.
*/
def ccNestingLevelOpt(using Context): Option[Int] =
if ctx.property(ccStateKey).isDefined then Some(ccNestingLevel) else None

/** The parameter with type caps.Cap in the leading term parameter section,
* or NoSymbol, if none exists.
*/
def definedLocalRoot(using Context): Symbol =
sym.paramSymss.dropWhile(psyms => psyms.nonEmpty && psyms.head.isType) match
case psyms :: _ => psyms.find(_.info.typeSymbol == defn.Caps_Cap).getOrElse(NoSymbol)
case _ => NoSymbol

/** The local root corresponding to sym's level owner */
def localRoot(using Context): Symbol =
val owner = sym.levelOwner
assert(owner.exists)
def newRoot = newSymbol(if owner.isClass then newLocalDummy(owner) else owner,
nme.LOCAL_CAPTURE_ROOT, Synthetic, defn.Caps_Cap.typeRef, nestingLevel = owner.ccNestingLevel)
def lclRoot =
if owner.isTerm then owner.definedLocalRoot.orElse(newRoot)
else newRoot
ccState.localRoots.getOrElseUpdate(owner, lclRoot)

/** The level owner enclosing `sym` which has the given name, or NoSymbol if none exists.
* If name refers to a val that has a closure as rhs, we return the closure as level
* owner.
*/
def levelOwnerNamed(name: String)(using Context): Symbol =
def recur(owner: Symbol, prev: Symbol): Symbol =
if owner.name.toString == name then
if owner.isLevelOwner then owner
else if owner.isTerm && !owner.isOneOf(Method | Module) && prev.exists then prev
else NoSymbol
else if owner == defn.RootClass then
NoSymbol
else
val prev1 = if owner.isAnonymousFunction && owner.isLevelOwner then owner else NoSymbol
recur(owner.owner, prev1)
recur(sym, NoSymbol)
.showing(i"find outer $sym [ $name ] = $result", capt)

def maxNested(other: Symbol)(using Context): Symbol =
if sym.ccNestingLevel < other.ccNestingLevel then other else sym
/* does not work yet, we do mix sets with different levels, for instance in cc-this.scala.
else if sym.ccNestingLevel > other.ccNestingLevel then sym
else
assert(sym == other, i"conflicting symbols at same nesting level: $sym, $other")
sym
*/

def minNested(other: Symbol)(using Context): Symbol =
if sym.ccNestingLevel > other.ccNestingLevel then other else sym

extension (tp: TermRef | ThisType)
/** The nesting level of this reference as defined by capture checking */
def ccNestingLevel(using Context): Int = tp match
case tp: TermRef => tp.symbol.ccNestingLevel
case tp: ThisType => tp.cls.ccNestingLevel

extension (tp: AnnotatedType)
/** Is this a boxed capturing type? */
def isBoxed(using Context): Boolean = tp.annot match
Expand Down
Loading

0 comments on commit 64c3138

Please sign in to comment.