Skip to content

Commit

Permalink
Task Metadata instead of Droppable Future (#3)
Browse files Browse the repository at this point in the history
- Added TaskMetadata used to monitor when a Task is completed/cancelled
- Removed DroppableFuture (unnecessary indirection)
  • Loading branch information
coder137 authored Jun 28, 2024
1 parent f7f2bbd commit 07eda0e
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 93 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ edition = "2021"

[dependencies]
async-task = "4.7"
pin-project = "1"

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
51 changes: 0 additions & 51 deletions src/droppable_future.rs

This file was deleted.

3 changes: 0 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
mod droppable_future;
use droppable_future::*;

mod task_identifier;
pub use task_identifier::*;

Expand Down
128 changes: 90 additions & 38 deletions src/ticked_async_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
},
};

use crate::{DroppableFuture, TaskIdentifier};
use crate::TaskIdentifier;

#[derive(Debug)]
pub enum TaskState {
Expand All @@ -16,11 +16,37 @@ pub enum TaskState {
Drop(TaskIdentifier),
}

pub type Task<T> = async_task::Task<T>;
type Payload = (TaskIdentifier, async_task::Runnable);
pub type Task<T, O> = async_task::Task<T, TaskMetadata<O>>;
type TaskRunnable<O> = async_task::Runnable<TaskMetadata<O>>;
type Payload<O> = (TaskIdentifier, TaskRunnable<O>);

pub struct TickedAsyncExecutor<O> {
channel: (mpsc::Sender<Payload>, mpsc::Receiver<Payload>),
/// Task Metadata associated with TickedAsyncExecutor
///
/// Primarily used to track when the Task is completed/cancelled
pub struct TaskMetadata<O>
where
O: Fn(TaskState) + Send + Sync + 'static,
{
num_spawned_tasks: Arc<AtomicUsize>,
identifier: TaskIdentifier,
observer: O,
}

impl<O> Drop for TaskMetadata<O>
where
O: Fn(TaskState) + Send + Sync + 'static,
{
fn drop(&mut self) {
self.num_spawned_tasks.fetch_sub(1, Ordering::Relaxed);
(self.observer)(TaskState::Drop(self.identifier.clone()));
}
}

pub struct TickedAsyncExecutor<O>
where
O: Fn(TaskState) + Send + Sync + 'static,
{
channel: (mpsc::Sender<Payload<O>>, mpsc::Receiver<Payload<O>>),
num_woken_tasks: Arc<AtomicUsize>,
num_spawned_tasks: Arc<AtomicUsize>,

Expand Down Expand Up @@ -53,14 +79,22 @@ where
&self,
identifier: impl Into<TaskIdentifier>,
future: impl Future<Output = T> + Send + 'static,
) -> Task<T>
) -> Task<T, O>
where
T: Send + 'static,
{
let identifier = identifier.into();
let future = self.droppable_future(identifier.clone(), future);
let schedule = self.runnable_schedule_cb(identifier);
let (runnable, task) = async_task::spawn(future, schedule);
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
(self.observer)(TaskState::Spawn(identifier.clone()));

let schedule = self.runnable_schedule_cb(identifier.clone());
let (runnable, task) = async_task::Builder::new()
.metadata(TaskMetadata {
num_spawned_tasks: self.num_spawned_tasks.clone(),
identifier,
observer: self.observer.clone(),
})
.spawn(|_m| future, schedule);
runnable.schedule();
task
}
Expand All @@ -69,14 +103,22 @@ where
&self,
identifier: impl Into<TaskIdentifier>,
future: impl Future<Output = T> + 'static,
) -> Task<T>
) -> Task<T, O>
where
T: 'static,
{
let identifier = identifier.into();
let future = self.droppable_future(identifier.clone(), future);
let schedule = self.runnable_schedule_cb(identifier);
let (runnable, task) = async_task::spawn_local(future, schedule);
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
(self.observer)(TaskState::Spawn(identifier.clone()));

let schedule = self.runnable_schedule_cb(identifier.clone());
let (runnable, task) = async_task::Builder::new()
.metadata(TaskMetadata {
num_spawned_tasks: self.num_spawned_tasks.clone(),
identifier,
observer: self.observer.clone(),
})
.spawn_local(move |_m| future, schedule);
runnable.schedule();
task
}
Expand Down Expand Up @@ -104,29 +146,7 @@ where
.fetch_sub(num_woken_tasks, Ordering::Relaxed);
}

fn droppable_future<F>(
&self,
identifier: TaskIdentifier,
future: F,
) -> DroppableFuture<F, impl Fn()>
where
F: Future,
{
let observer = self.observer.clone();

// Spawn Task
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
observer(TaskState::Spawn(identifier.clone()));

// Droppable Future registering on_drop callback
let num_spawned_tasks = self.num_spawned_tasks.clone();
DroppableFuture::new(future, move || {
num_spawned_tasks.fetch_sub(1, Ordering::Relaxed);
observer(TaskState::Drop(identifier.clone()));
})
}

fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(async_task::Runnable) {
fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(TaskRunnable<O>) {
let sender = self.channel.0.clone();
let num_woken_tasks = self.num_woken_tasks.clone();
let observer = self.observer.clone();
Expand All @@ -145,7 +165,7 @@ mod tests {
use super::*;

#[test]
fn test_multiple_tasks() {
fn test_multiple_local_tasks() {
let executor = TickedAsyncExecutor::default();
executor
.spawn_local("A", async move {
Expand All @@ -167,7 +187,7 @@ mod tests {
}

#[test]
fn test_task_cancellation() {
fn test_local_tasks_cancellation() {
let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}"));
let task1 = executor.spawn_local("A", async move {
loop {
Expand Down Expand Up @@ -197,4 +217,36 @@ mod tests {
executor.tick();
}
}

#[test]
fn test_tasks_cancellation() {
let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}"));
let task1 = executor.spawn("A", async move {
loop {
tokio::task::yield_now().await;
}
});

let task2 = executor.spawn(format!("B"), async move {
loop {
tokio::task::yield_now().await;
}
});
assert_eq!(executor.num_tasks(), 2);
executor.tick();

executor
.spawn_local("CancelTasks", async move {
let (t1, t2) = join!(task1.cancel(), task2.cancel());
assert_eq!(t1, None);
assert_eq!(t2, None);
})
.detach();
assert_eq!(executor.num_tasks(), 3);

// Since we have cancelled the tasks above, the loops should eventually end
while executor.num_tasks() != 0 {
executor.tick();
}
}
}

0 comments on commit 07eda0e

Please sign in to comment.