Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Contains handling #21361

Merged
merged 2 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -713,3 +713,21 @@ extension (self: Type)
case _ =>
self

/** An extractor for a contains argument */
object ContainsImpl:
def unapply(tree: TypeApply)(using Context): Option[(Tree, Tree)] =
tree.fun.tpe.widen match
case fntpe: PolyType if tree.fun.symbol == defn.Caps_containsImpl =>
tree.args match
case csArg :: refArg :: Nil => Some((csArg, refArg))
case _ => None
case _ => None

/** An extractor for a contains parameter */
object ContainsParam:
def unapply(sym: Symbol)(using Context): Option[(TypeRef, CaptureRef)] =
sym.info.dealias match
case AppliedType(tycon, (cs: TypeRef) :: (ref: CaptureRef) :: Nil)
if tycon.typeSymbol == defn.Caps_ContainsTrait
&& cs.typeSymbol.isAbstractOrParamType => Some((cs, ref))
case _ => None
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/CaptureRef.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,12 @@ trait CaptureRef extends TypeProxy, ValueType:
case x1: SingletonCaptureRef => x1.subsumes(y)
case _ => false
case x: TermParamRef => subsumesExistentially(x, y)
case x: TypeRef => assumedContainsOf(x).contains(y)
case _ => false

def assumedContainsOf(x: TypeRef)(using Context): SimpleIdentitySet[CaptureRef] =
CaptureSet.assumedContains.getOrElse(x, SimpleIdentitySet.empty)

end CaptureRef

trait SingletonCaptureRef extends SingletonType, CaptureRef
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import util.{SimpleIdentitySet, Property}
import typer.ErrorReporting.Addenda
import TypeComparer.subsumesExistentially
import util.common.alwaysTrue
import scala.collection.mutable
import scala.collection.{mutable, immutable}
import CCState.*

/** A class for capture sets. Capture sets can be constants or variables.
Expand Down Expand Up @@ -1125,6 +1125,12 @@ object CaptureSet:
foldOver(cs, t)
collect(CaptureSet.empty, tp)

type AssumedContains = immutable.Map[TypeRef, SimpleIdentitySet[CaptureRef]]
val AssumedContains: Property.Key[AssumedContains] = Property.Key()

def assumedContains(using Context): AssumedContains =
ctx.property(AssumedContains).getOrElse(immutable.Map.empty)

private val ShownVars: Property.Key[mutable.Set[Var]] = Property.Key()

/** Perform `op`. Under -Ycc-debug, collect and print info about all variables reachable
Expand Down
46 changes: 26 additions & 20 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -676,29 +676,24 @@ class CheckCaptures extends Recheck, SymTransformer:
i"Sealed type variable $pname", "be instantiated to",
i"This is often caused by a local capability$where\nleaking as part of its result.",
tree.srcPos)
val res = handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt)))
if meth == defn.Caps_containsImpl then checkContains(tree)
res
try handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt)))
finally checkContains(tree)
end recheckTypeApply

/** Faced with a tree of form `caps.contansImpl[CS, r.type]`, check that `R` is a tracked
* capability and assert that `{r} <:CS`.
*/
def checkContains(tree: TypeApply)(using Context): Unit =
tree.fun.knownType.widen match
case fntpe: PolyType =>
tree.args match
case csArg :: refArg :: Nil =>
val cs = csArg.knownType.captureSet
val ref = refArg.knownType
capt.println(i"check contains $cs , $ref")
ref match
case ref: CaptureRef if ref.isTracked =>
checkElem(ref, cs, tree.srcPos)
case _ =>
report.error(em"$refArg is not a tracked capability", refArg.srcPos)
case _ =>
case _ =>
def checkContains(tree: TypeApply)(using Context): Unit = tree match
case ContainsImpl(csArg, refArg) =>
val cs = csArg.knownType.captureSet
val ref = refArg.knownType
capt.println(i"check contains $cs , $ref")
ref match
case ref: CaptureRef if ref.isTracked =>
checkElem(ref, cs, tree.srcPos)
case _ =>
report.error(em"$refArg is not a tracked capability", refArg.srcPos)
case _ =>

override def recheckBlock(tree: Block, pt: Type)(using Context): Type =
inNestedLevel(super.recheckBlock(tree, pt))
Expand Down Expand Up @@ -814,15 +809,26 @@ class CheckCaptures extends Recheck, SymTransformer:
val localSet = capturedVars(sym)
if !localSet.isAlwaysEmpty then
curEnv = Env(sym, EnvKind.Regular, localSet, curEnv)

// ctx with AssumedContains entries for each Contains parameter
val bodyCtx =
var ac = CaptureSet.assumedContains
for paramSyms <- sym.paramSymss do
for case ContainsParam(cs, ref) <- paramSyms do
ac = ac.updated(cs, ac.getOrElse(cs, SimpleIdentitySet.empty) + ref)
if ac.isEmpty then ctx
else ctx.withProperty(CaptureSet.AssumedContains, Some(ac))

inNestedLevel: // TODO: needed here?
try checkInferredResult(super.recheckDefDef(tree, sym), tree)
try checkInferredResult(super.recheckDefDef(tree, sym)(using bodyCtx), tree)
finally
if !sym.isAnonymousFunction then
// Anonymous functions propagate their type to the enclosing environment
// so it is not in general sound to interpolate their types.
interpolateVarsIn(tree.tpt)
curEnv = saved

end recheckDefDef

/** If val or def definition with inferred (result) type is visible
* in other compilation units, check that the actual inferred type
* conforms to the expected type where all inferred capture sets are dropped.
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,7 @@ class Definitions {
@tu lazy val Caps_unsafeBox: Symbol = CapsUnsafeModule.requiredMethod("unsafeBox")
@tu lazy val Caps_unsafeUnbox: Symbol = CapsUnsafeModule.requiredMethod("unsafeUnbox")
@tu lazy val Caps_unsafeBoxFunArg: Symbol = CapsUnsafeModule.requiredMethod("unsafeBoxFunArg")
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Capability")
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Contains")
@tu lazy val Caps_containsImpl: TermSymbol = CapsModule.requiredMethod("containsImpl")

@tu lazy val PureClass: Symbol = requiredClass("scala.Pure")
Expand Down
11 changes: 10 additions & 1 deletion tests/pos-custom-args/captures/i21313.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import caps.CapSet

trait Async:
def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T
def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T =
val x: Async^{this} = ???
val y: Async^{Cap^} = x
val ac: Async^ = ???
def f(using caps.Contains[Cap, ac.type]) =
val x2: Async^{this} = ???
val y2: Async^{Cap^} = x2
val x3: Async^{ac} = ???
val y3: Async^{Cap^} = x3
???

trait Source[+T, Cap^]:
final def await(using ac: Async^{Cap^}) = ac.await[T, Cap](this) // Contains[Cap, ac] is assured because {ac} <: Cap.
Expand Down
8 changes: 8 additions & 0 deletions tests/run/Providers.check
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,11 @@ Executing query: insert into subscribers(name, email) values Daniel daniel@Rockt
You've just been subscribed to RockTheJVM. Welcome, Martin
Acquired connection
Executing query: insert into subscribers(name, email) values Martin [email protected]

Injected2
You've just been subscribed to RockTheJVM. Welcome, Daniel
Acquired connection
Executing query: insert into subscribers(name, email) values Daniel [email protected]
You've just been subscribed to RockTheJVM. Welcome, Martin
Acquired connection
Executing query: insert into subscribers(name, email) values Martin [email protected]
natsukagami marked this conversation as resolved.
Show resolved Hide resolved
52 changes: 52 additions & 0 deletions tests/run/Providers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ end Providers
Explicit().test()
println(s"\nInjected")
Injected().test()
println(s"\nInjected2")
Injected2().test()

/** Demonstrator for explicit dependency construction */
class Explicit:
Expand Down Expand Up @@ -173,5 +175,55 @@ class Injected:
end explicit
end Injected

/** Injected with builders in companion objects */
class Injected2:
import Providers.*

case class User(name: String, email: String)

class UserSubscription(emailService: EmailService, db: UserDatabase):
def subscribe(user: User) =
emailService.email(user)
db.insert(user)
object UserSubscription:
def apply()(using Provider[(EmailService, UserDatabase)]): UserSubscription =
new UserSubscription(provided[EmailService], provided[UserDatabase])

class EmailService:
def email(user: User) =
println(s"You've just been subscribed to RockTheJVM. Welcome, ${user.name}")

class UserDatabase(pool: ConnectionPool):
def insert(user: User) =
pool.get().runQuery(s"insert into subscribers(name, email) values ${user.name} ${user.email}")
object UserDatabase:
def apply()(using Provider[(ConnectionPool)]): UserDatabase =
new UserDatabase(provided[ConnectionPool])

class ConnectionPool(n: Int):
def get(): Connection =
println(s"Acquired connection")
Connection()

class Connection():
def runQuery(query: String): Unit =
println(s"Executing query: $query")

def test() =
given Provider[EmailService] = provide(EmailService())
given Provider[ConnectionPool] = provide(ConnectionPool(10))
given Provider[UserDatabase] = provide(UserDatabase())
given Provider[UserSubscription] = provide(UserSubscription())

def subscribe(user: User)(using Provider[UserSubscription]) =
val sub = UserSubscription()
sub.subscribe(user)

subscribe(User("Daniel", "[email protected]"))
subscribe(User("Martin", "[email protected]"))
end test
end Injected2




Loading