Skip to content

Commit

Permalink
First implementation of capture polymorphism
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed Jun 26, 2024
1 parent ea182f2 commit 60b0486
Show file tree
Hide file tree
Showing 14 changed files with 139 additions and 27 deletions.
11 changes: 8 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -521,12 +521,17 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def captureRoot(using Context): Select =
Select(scalaDot(nme.caps), nme.CAPTURE_ROOT)

def captureRootIn(using Context): Select =
Select(scalaDot(nme.caps), nme.capIn)

def makeRetaining(parent: Tree, refs: List[Tree], annotName: TypeName)(using Context): Annotated =
Annotated(parent, New(scalaAnnotationDot(annotName), List(refs)))

def makeCapsOf(id: Ident)(using Context): Tree =
TypeApply(Select(scalaDot(nme.caps), nme.capsOf), id :: Nil)

def makeCapsBound()(using Context): Tree =
makeRetaining(
Select(scalaDot(nme.caps), tpnme.CapSet),
Nil, tpnme.retainsCap)

def makeConstructor(tparams: List[TypeDef], vparamss: List[List[ValDef]], rhs: Tree = EmptyTree)(using Context): DefDef =
DefDef(nme.CONSTRUCTOR, joinParams(tparams, vparamss), TypeTree(), rhs)

Expand Down
12 changes: 11 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ extension (tree: Tree)
def toCaptureRef(using Context): CaptureRef = tree match
case ReachCapabilityApply(arg) =>
arg.toCaptureRef.reach
case CapsOfApply(arg) =>
arg.toCaptureRef
case _ => tree.tpe match
case ref: CaptureRef if ref.isTrackableRef =>
ref
Expand All @@ -145,7 +147,7 @@ extension (tree: Tree)
case Some(refs) => refs
case None =>
val refs = CaptureSet(tree.retainedElems.map(_.toCaptureRef)*)
.showing(i"toCaptureSet $tree --> $result", capt)
//.showing(i"toCaptureSet $tree --> $result", capt)
tree.putAttachment(Captures, refs)
refs

Expand Down Expand Up @@ -526,6 +528,14 @@ object ReachCapabilityApply:
case Apply(reach, arg :: Nil) if reach.symbol == defn.Caps_reachCapability => Some(arg)
case _ => None

/** An extractor for `caps.capsOf[X]`, which is used to express a generic capture set
* as a tree in a @retains annotation.
*/
object CapsOfApply:
def unapply(tree: TypeApply)(using Context): Option[Tree] = tree match
case TypeApply(capsOf, arg :: Nil) if capsOf.symbol == defn.Caps_capsOf => Some(arg)
case _ => None

class AnnotatedCapability(annot: Context ?=> ClassSymbol):
def apply(tp: Type)(using Context) =
AnnotatedType(tp, Annotation(annot, util.Spans.NoSpan))
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,9 @@ object CaptureSet:
val r1 = tm(r)
val upper = r1.captureSet
def isExact =
upper.isAlwaysEmpty || upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1)
upper.isAlwaysEmpty
|| upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1)
|| r.derivesFrom(defn.Caps_CapSet)
if variance > 0 || isExact then upper
else if variance < 0 then CaptureSet.empty
else upper.maybe
Expand Down
26 changes: 17 additions & 9 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,24 @@ object CheckCaptures:
case _: SingletonType =>
report.error(em"Singleton type $parent cannot have capture set", parent.srcPos)
case _ =>
def check(elem: Tree, pos: SrcPos): Unit = elem.tpe match
case ref: CaptureRef =>
if !ref.isTrackableRef then
report.error(em"$elem cannot be tracked since it is not a parameter or local value", pos)
case tpe =>
report.error(em"$elem: $tpe is not a legal element of a capture set", pos)
for elem <- ann.retainedElems do
val elem1 = elem match
case ReachCapabilityApply(arg) => arg
case _ => elem
elem1.tpe match
case ref: CaptureRef =>
if !ref.isTrackableRef then
report.error(em"$elem cannot be tracked since it is not a parameter or local value", elem.srcPos)
case tpe =>
report.error(em"$elem: $tpe is not a legal element of a capture set", elem.srcPos)
elem match
case CapsOfApply(arg) =>
def isLegalCapsOfArg =
arg.symbol.isAbstractOrParamType && arg.symbol.info.derivesFrom(defn.Caps_CapSet)
if !isLegalCapsOfArg then
report.error(
em"""$arg is not a legal prefix for `^` here,
|is must be a type parameter or abstract type with a caps.CapSet upper bound.""",
elem.srcPos)
case ReachCapabilityApply(arg) => check(arg, elem.srcPos)
case _ => check(elem, elem.srcPos)

/** Report an error if some part of `tp` contains the root capability in its capture set
* or if it refers to an unsealed type parameter that could possibly be instantiated with
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
sym.updateInfo(thisPhase, info, newFlagsFor(sym))
toBeUpdated -= sym
sym.namedType match
case ref: CaptureRef => ref.invalidateCaches() // TODO: needed?
case ref: CaptureRef if ref.isTrackableRef => ref.invalidateCaches() // TODO: needed?
case _ =>

extension (sym: Symbol) def nextInfo(using Context): Type =
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -991,8 +991,10 @@ class Definitions {

@tu lazy val CapsModule: Symbol = requiredModule("scala.caps")
@tu lazy val captureRoot: TermSymbol = CapsModule.requiredValue("cap")
@tu lazy val Caps_Capability: ClassSymbol = requiredClass("scala.caps.Capability")
@tu lazy val Caps_Capability: TypeSymbol = CapsModule.requiredType("Capability")
@tu lazy val Caps_CapSet = requiredClass("scala.caps.CapSet")
@tu lazy val Caps_reachCapability: TermSymbol = CapsModule.requiredMethod("reachCapability")
@tu lazy val Caps_capsOf: TermSymbol = CapsModule.requiredMethod("capsOf")
@tu lazy val Caps_Exists = requiredClass("scala.caps.Exists")
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
@tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure")
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ object StdNames {
val AppliedTypeTree: N = "AppliedTypeTree"
val ArrayAnnotArg: N = "ArrayAnnotArg"
val CAP: N = "CAP"
val CapSet: N = "CapSet"
val Constant: N = "Constant"
val ConstantType: N = "ConstantType"
val Eql: N = "Eql"
Expand Down Expand Up @@ -441,8 +442,8 @@ object StdNames {
val bytes: N = "bytes"
val canEqual_ : N = "canEqual"
val canEqualAny : N = "canEqualAny"
val capIn: N = "capIn"
val caps: N = "caps"
val capsOf: N = "capsOf"
val captureChecking: N = "captureChecking"
val checkInitialized: N = "checkInitialized"
val classOf: N = "classOf"
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2839,7 +2839,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
private def existentialVarsConform(tp1: Type, tp2: Type) =
tp2 match
case tp2: TermParamRef => tp1 match
case tp1: CaptureRef => subsumesExistentially(tp2, tp1)
case tp1: CaptureRef if tp1.isTrackableRef => subsumesExistentially(tp2, tp1)
case _ => false
case _ => false

Expand Down
16 changes: 13 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2313,7 +2313,11 @@ object Types extends TypeUtils {

override def captureSet(using Context): CaptureSet =
val cs = captureSetOfInfo
if isTrackableRef && !cs.isAlwaysEmpty then singletonCaptureSet else cs
if isTrackableRef then
if cs.isAlwaysEmpty then cs else singletonCaptureSet
else dealias match
case _: (TypeRef | TypeParamRef) => CaptureSet.empty
case _ => cs

end CaptureRef

Expand Down Expand Up @@ -3032,7 +3036,7 @@ object Types extends TypeUtils {

abstract case class TypeRef(override val prefix: Type,
private var myDesignator: Designator)
extends NamedType {
extends NamedType, CaptureRef {

type ThisType = TypeRef
type ThisName = TypeName
Expand Down Expand Up @@ -3081,6 +3085,9 @@ object Types extends TypeUtils {
/** Hook that can be called from creation methods in TermRef and TypeRef */
def validated(using Context): this.type =
this

override def isTrackableRef(using Context) =
symbol.isAbstractOrParamType && derivesFrom(defn.Caps_CapSet)
}

final class CachedTermRef(prefix: Type, designator: Designator, hc: Int) extends TermRef(prefix, designator) {
Expand Down Expand Up @@ -4841,7 +4848,8 @@ object Types extends TypeUtils {
/** Only created in `binder.paramRefs`. Use `binder.paramRefs(paramNum)` to
* refer to `TypeParamRef(binder, paramNum)`.
*/
abstract case class TypeParamRef(binder: TypeLambda, paramNum: Int) extends ParamRef {
abstract case class TypeParamRef(binder: TypeLambda, paramNum: Int)
extends ParamRef, CaptureRef {
type BT = TypeLambda
def kindString: String = "Type"
def copyBoundType(bt: BT): Type = bt.paramRefs(paramNum)
Expand All @@ -4861,6 +4869,8 @@ object Types extends TypeUtils {
case bound: OrType => occursIn(bound.tp1, fromBelow) || occursIn(bound.tp2, fromBelow)
case _ => false
}

override def isTrackableRef(using Context) = derivesFrom(defn.Caps_CapSet)
}

private final class TypeParamRefImpl(binder: TypeLambda, paramNum: Int) extends TypeParamRef(binder, paramNum)
Expand Down
17 changes: 13 additions & 4 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1541,7 +1541,7 @@ object Parsers {
case _ => None
}

/** CaptureRef ::= ident | `this` | `cap` [`[` ident `]`]
/** CaptureRef ::= ident [`*` | `^`] | `this`
*/
def captureRef(): Tree =
if in.token == THIS then simpleRef()
Expand All @@ -1551,6 +1551,10 @@ object Parsers {
in.nextToken()
atSpan(startOffset(id)):
PostfixOp(id, Ident(nme.CC_REACH))
else if isIdent(nme.UPARROW) then
in.nextToken()
atSpan(startOffset(id)):
makeCapsOf(cpy.Ident(id)(id.name.toTypeName))
else id

/** CaptureSet ::= `{` CaptureRef {`,` CaptureRef} `}` -- under captureChecking
Expand Down Expand Up @@ -1968,7 +1972,7 @@ object Parsers {
}

/** SimpleType ::= SimpleLiteral
* | ‘?’ SubtypeBounds
* | ‘?’ TypeBounds
* | SimpleType1
* | SimpleType ‘(’ Singletons ‘)’ -- under language.experimental.dependent, checked in Typer
* Singletons ::= Singleton {‘,’ Singleton}
Expand Down Expand Up @@ -2188,9 +2192,15 @@ object Parsers {
inBraces(refineStatSeq())

/** TypeBounds ::= [`>:' Type] [`<:' Type]
* | `^` -- under captureChecking
*/
def typeBounds(): TypeBoundsTree =
atSpan(in.offset) { TypeBoundsTree(bound(SUPERTYPE), bound(SUBTYPE)) }
atSpan(in.offset):
if in.isIdent(nme.UPARROW) && Feature.ccEnabled then
in.nextToken()
TypeBoundsTree(EmptyTree, makeCapsBound())
else
TypeBoundsTree(bound(SUPERTYPE), bound(SUBTYPE))

private def bound(tok: Int): Tree =
if (in.token == tok) { in.nextToken(); toplevelTyp() }
Expand Down Expand Up @@ -3384,7 +3394,6 @@ object Parsers {
*
* DefTypeParamClause::= ‘[’ DefTypeParam {‘,’ DefTypeParam} ‘]’
* DefTypeParam ::= {Annotation}
* [`sealed`] -- under captureChecking
* id [HkTypeParamClause] TypeParamBounds
*
* TypTypeParamClause::= ‘[’ TypTypeParam {‘,’ TypTypeParam} ‘]’
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2328,7 +2328,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val res = Throw(expr1).withSpan(tree.span)
if Feature.ccEnabled && !cap.isEmpty && !ctx.isAfterTyper then
// Record access to the CanThrow capabulity recovered in `cap` by wrapping
// the type of the `throw` (i.e. Nothing) in a `@requiresCapability` annotatoon.
// the type of the `throw` (i.e. Nothing) in a `@requiresCapability` annotation.
Typed(res,
TypeTree(
AnnotatedType(res.tpe,
Expand Down
8 changes: 7 additions & 1 deletion library/src/scala/caps.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package scala

import annotation.experimental
import annotation.{experimental, compileTimeOnly}

@experimental object caps:

Expand All @@ -16,6 +16,12 @@ import annotation.experimental
@deprecated("Use `Capability` instead")
type Cap = Capability

/** Carrier trait for capture set type parameters */
trait CapSet extends Any

@compileTimeOnly("Should be be used only internally by the Scala compiler")
def capsOf[CS]: Any = ???

/** Reach capabilities x* which appear as terms in @retains annotations are encoded
* as `caps.reachCapability(x)`. When converted to CaptureRef types in capture sets
* they are represented as `x.type @annotation.internal.reachCapability`.
Expand Down
23 changes: 23 additions & 0 deletions tests/pos/cc-poly-1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import language.experimental.captureChecking
import annotation.experimental
import caps.{CapSet, Capability}

@experimental object Test:

class C extends Capability
class D

def f[X^](x: D^{X^}): D^{X^} = x
def g[X^](x: D^{X^}, y: D^{X^}): D^{X^} = x

def test(c1: C, c2: C) =
val d: D^{c1, c2} = D()
val x = f[CapSet^{c1, c2}](d)
val _: D^{c1, c2} = x
val d1: D^{c1} = D()
val d2: D^{c2} = D()
val y = g(d1, d2)
val _: D^{d1, d2} = y
val _: D^{c1, c2} = y


36 changes: 36 additions & 0 deletions tests/pos/cc-poly-source.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import language.experimental.captureChecking
import annotation.experimental
import caps.{CapSet, Capability}

@experimental object Test:

class Label //extends Capability

class Listener

class Source[X^]:
private var listeners: Set[Listener^{X^}] = Set.empty
def register(x: Listener^{X^}): Unit =
listeners += x

def allListeners: Set[Listener^{X^}] = listeners

def test1(lbl1: Label^, lbl2: Label^) =
val src = Source[CapSet^{lbl1, lbl2}]
def l1: Listener^{lbl1} = ???
val l2: Listener^{lbl2} = ???
src.register{l1}
src.register{l2}
val ls = src.allListeners
val _: Set[Listener^{lbl1, lbl2}] = ls

def test2(lbls: List[Label^]) =
def makeListener(lbl: Label^): Listener^{lbl} = ???
val listeners = lbls.map(makeListener)
val src = Source[CapSet^{lbls*}]
for l <- listeners do
src.register(l)
val ls = src.allListeners
val _: Set[Listener^{lbls*}] = ls


0 comments on commit 60b0486

Please sign in to comment.