Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix setup of CapSet arguments. #21309

Merged
merged 3 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,9 @@ extension (tp: Type)
def boxed(using Context): Type = tp.dealias match
case tp @ CapturingType(parent, refs) if !tp.isBoxed && !refs.isAlwaysEmpty =>
tp.annot match
case ann: CaptureAnnotation => AnnotatedType(parent, ann.boxedAnnot)
case ann: CaptureAnnotation =>
assert(!parent.derivesFrom(defn.Caps_CapSet))
AnnotatedType(parent, ann.boxedAnnot)
case ann => tp
case tp: RealTypeBounds =>
tp.derivedTypeBounds(tp.lo.boxed, tp.hi.boxed)
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/cc/CapturingType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ object CapturingType:
* boxing status is the same or if A is boxed.
*/
def apply(parent: Type, refs: CaptureSet, boxed: Boolean = false)(using Context): Type =
assert(!boxed || !parent.derivesFrom(defn.Caps_CapSet))
if refs.isAlwaysEmpty && !refs.keepAlways then parent
else parent match
case parent @ CapturingType(parent1, refs1) if boxed || !parent.isBoxed =>
Expand Down
54 changes: 39 additions & 15 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -388,23 +388,25 @@ class CheckCaptures extends Recheck, SymTransformer:
// should be included.
val included = cs.filter: c =>
c.stripReach match
case ref: TermRef =>
//if c.isReach then println(i"REACH $c in ${env.owner}")
//assert(!env.owner.isAnonymousFunction)
case ref: NamedType =>
val refSym = ref.symbol
val refOwner = refSym.owner
val isVisible = isVisibleFromEnv(refOwner)
if !isVisible && c.isReach && refSym.is(Param) && refOwner == env.owner then
if refSym.hasAnnotation(defn.UnboxAnnot) then
capt.println(i"exempt: $ref in $refOwner")
else
// Reach capabilities that go out of scope have to be approximated
// by their underlying capture set, which cannot be universal.
// Reach capabilities of @unboxed parameters are exempted.
val cs = CaptureSet.ofInfo(c)
cs.disallowRootCapability: () =>
report.error(em"Local reach capability $c leaks into capture scope of ${env.ownerString}", pos)
checkSubset(cs, env.captured, pos, provenance(env))
if !isVisible
&& (c.isReach || ref.isType)
&& refSym.is(Param)
&& refOwner == env.owner
then
if refSym.hasAnnotation(defn.UnboxAnnot) then
capt.println(i"exempt: $ref in $refOwner")
else
// Reach capabilities that go out of scope have to be approximated
// by their underlying capture set, which cannot be universal.
// Reach capabilities of @unboxed parameters are exempted.
val cs = CaptureSet.ofInfo(c)
cs.disallowRootCapability: () =>
report.error(em"Local reach capability $c leaks into capture scope of ${env.ownerString}", pos)
checkSubset(cs, env.captured, pos, provenance(env))
isVisible
case ref: ThisType => isVisibleFromEnv(ref.cls)
case _ => false
Expand Down Expand Up @@ -674,7 +676,29 @@ class CheckCaptures extends Recheck, SymTransformer:
i"Sealed type variable $pname", "be instantiated to",
i"This is often caused by a local capability$where\nleaking as part of its result.",
tree.srcPos)
handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt)))
val res = handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt)))
if meth == defn.Caps_containsImpl then checkContains(tree)
res
end recheckTypeApply

/** Faced with a tree of form `caps.contansImpl[CS, r.type]`, check that `R` is a tracked
* capability and assert that `{r} <:CS`.
*/
def checkContains(tree: TypeApply)(using Context): Unit =
tree.fun.knownType.widen match
case fntpe: PolyType =>
tree.args match
case csArg :: refArg :: Nil =>
val cs = csArg.knownType.captureSet
val ref = refArg.knownType
capt.println(i"check contains $cs , $ref")
ref match
case ref: CaptureRef if ref.isTracked =>
checkElem(ref, cs, tree.srcPos)
case _ =>
report.error(em"$refArg is not a tracked capability", refArg.srcPos)
case _ =>
case _ =>

override def recheckBlock(tree: Block, pt: Type)(using Context): Type =
inNestedLevel(super.recheckBlock(tree, pt))
Expand Down
11 changes: 7 additions & 4 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
private def box(tp: Type)(using Context): Type =
def recur(tp: Type): Type = tp.dealiasKeepAnnotsAndOpaques match
case tp @ CapturingType(parent, refs) =>
if tp.isBoxed then tp else tp.boxed
if tp.isBoxed || parent.derivesFrom(defn.Caps_CapSet) then tp
else tp.boxed
case tp @ AnnotatedType(parent, ann) =>
if ann.symbol.isRetains
if ann.symbol.isRetains && !parent.derivesFrom(defn.Caps_CapSet)
then CapturingType(parent, ann.tree.toCaptureSet, boxed = true)
else tp.derivedAnnotatedType(box(parent), ann)
case tp1 @ AppliedType(tycon, args) if defn.isNonRefinedFunction(tp1) =>
Expand Down Expand Up @@ -605,8 +606,10 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
!refs.isEmpty
case tp: (TypeRef | AppliedType) =>
val sym = tp.typeSymbol
if sym.isClass then !sym.isPureClass
else instanceCanBeImpure(tp.superType)
if sym.isClass
then !sym.isPureClass
else !tp.derivesFrom(defn.Caps_CapSet) // CapSet arguments don't get other capture set variables added
&& instanceCanBeImpure(tp.superType)
case tp: (RefinedOrRecType | MatchType) =>
instanceCanBeImpure(tp.underlying)
case tp: AndType =>
Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -993,15 +993,17 @@ class Definitions {
@tu lazy val CapsModule: Symbol = requiredModule("scala.caps")
@tu lazy val captureRoot: TermSymbol = CapsModule.requiredValue("cap")
@tu lazy val Caps_Capability: TypeSymbol = CapsModule.requiredType("Capability")
@tu lazy val Caps_CapSet = requiredClass("scala.caps.CapSet")
@tu lazy val Caps_CapSet: ClassSymbol = requiredClass("scala.caps.CapSet")
@tu lazy val Caps_reachCapability: TermSymbol = CapsModule.requiredMethod("reachCapability")
@tu lazy val Caps_capsOf: TermSymbol = CapsModule.requiredMethod("capsOf")
@tu lazy val Caps_Exists = requiredClass("scala.caps.Exists")
@tu lazy val Caps_Exists: ClassSymbol = requiredClass("scala.caps.Exists")
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
@tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure")
@tu lazy val Caps_unsafeBox: Symbol = CapsUnsafeModule.requiredMethod("unsafeBox")
@tu lazy val Caps_unsafeUnbox: Symbol = CapsUnsafeModule.requiredMethod("unsafeUnbox")
@tu lazy val Caps_unsafeBoxFunArg: Symbol = CapsUnsafeModule.requiredMethod("unsafeBoxFunArg")
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Capability")
@tu lazy val Caps_containsImpl: TermSymbol = CapsModule.requiredMethod("containsImpl")

@tu lazy val PureClass: Symbol = requiredClass("scala.Pure")

Expand Down
12 changes: 11 additions & 1 deletion library/src/scala/caps.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package scala

import annotation.{experimental, compileTimeOnly}
import annotation.{experimental, compileTimeOnly, retainsCap}

@experimental object caps:

Expand All @@ -19,6 +19,16 @@ import annotation.{experimental, compileTimeOnly}
/** Carrier trait for capture set type parameters */
trait CapSet extends Any

/** A type constraint expressing that the capture set `C` needs to contain
* the capability `R`
*/
sealed trait Contains[C <: CapSet @retainsCap, R <: Singleton]

/** The only implementation of `Contains`. The constraint that `{R} <: C` is
* added separately by the capture checker.
*/
given containsImpl[C <: CapSet @retainsCap, R <: Singleton]: Contains[C, R]()

@compileTimeOnly("Should be be used only internally by the Scala compiler")
def capsOf[CS]: Any = ???

Expand Down
11 changes: 11 additions & 0 deletions tests/neg-custom-args/captures/i21313.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Error: tests/neg-custom-args/captures/i21313.scala:6:27 -------------------------------------------------------------
6 |def foo(x: Async) = x.await(???) // error
| ^
| (x : Async) is not a tracked capability
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i21313.scala:15:12 ---------------------------------------
15 | ac1.await(src2) // error
| ^^^^
| Found: (src2 : Source[Int, caps.CapSet^{ac2}]^?)
| Required: Source[Int, caps.CapSet^{ac1}]^
|
| longer explanation available when compiling with `-explain`
15 changes: 15 additions & 0 deletions tests/neg-custom-args/captures/i21313.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import caps.CapSet

trait Async:
def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T

def foo(x: Async) = x.await(???) // error

trait Source[+T, Cap^]:
final def await(using ac: Async^{Cap^}) = ac.await[T, Cap](this) // Contains[Cap, ac] is assured because {ac} <: Cap.

def test(using ac1: Async^, ac2: Async^, x: String) =
val src1 = new Source[Int, CapSet^{ac1}] {}
ac1.await(src1) // ok
val src2 = new Source[Int, CapSet^{ac2}] {}
ac1.await(src2) // error
15 changes: 15 additions & 0 deletions tests/neg-custom-args/captures/i21347.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-- Error: tests/neg-custom-args/captures/i21347.scala:4:15 -------------------------------------------------------------
4 | ops.foreach: op => // error
| ^
| Local reach capability C leaks into capture scope of method runOps
5 | op()
-- Error: tests/neg-custom-args/captures/i21347.scala:8:14 -------------------------------------------------------------
8 | () => runOps(f :: Nil) // error
| ^^^^^^^^^^^^^^^^
| reference (caps.cap : caps.Capability) is not included in the allowed capture set {}
| of an enclosing function literal with expected type () -> Unit
-- Error: tests/neg-custom-args/captures/i21347.scala:11:15 ------------------------------------------------------------
11 | ops.foreach: op => // error
| ^
| Local reach capability ops* leaks into capture scope of method runOpsAlt
12 | op()
12 changes: 12 additions & 0 deletions tests/neg-custom-args/captures/i21347.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import language.experimental.captureChecking

def runOps[C^](ops: List[() ->{C^} Unit]): Unit =
ops.foreach: op => // error
op()

def boom(f: () => Unit): () -> Unit =
() => runOps(f :: Nil) // error

def runOpsAlt(ops: List[() => Unit]): Unit =
ops.foreach: op => // error
op()
11 changes: 11 additions & 0 deletions tests/pos-custom-args/captures/i21313.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import caps.CapSet

trait Async:
def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T

trait Source[+T, Cap^]:
final def await(using ac: Async^{Cap^}) = ac.await[T, Cap](this) // Contains[Cap, ac] is assured because {ac} <: Cap.

Comment on lines +6 to +8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wouldn't work if Async was declared extends caps.Capability

def test(using ac1: Async^, ac2: Async^, x: String) =
val src1 = new Source[Int, CapSet^{ac1}] {}
ac1.await(src1)
14 changes: 14 additions & 0 deletions tests/pos/polycap.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import language.experimental.captureChecking

class Source[+T, Cap^]

def completed[T, Cap^](result: T): Source[T, Cap] =
//val fut = new Source[T, Cap]()
val fut2 = new Source[T, Cap]()
fut2: Source[T, Cap]






Loading