From 7822f6152c29b54f9c09c02efe8ae5dfec0ffc85 Mon Sep 17 00:00:00 2001 From: Johnny Graettinger Date: Wed, 2 Oct 2024 10:36:26 -0500 Subject: [PATCH] job queue WIP --- Cargo.lock | 17 + crates/automations/Cargo.toml | 25 ++ crates/automations/src/executors.rs | 308 ++++++++++++++++++ crates/automations/src/lib.rs | 137 ++++++++ crates/automations/src/server.rs | 194 +++++++++++ crates/automations/tests/fibonacci.rs | 112 +++++++ crates/automations/tests/test_fibonacci.rs | 57 ++++ .../migrations/20241001042256_job_queue.sql | 97 ++++++ 8 files changed, 947 insertions(+) create mode 100644 crates/automations/Cargo.toml create mode 100644 crates/automations/src/executors.rs create mode 100644 crates/automations/src/lib.rs create mode 100644 crates/automations/src/server.rs create mode 100644 crates/automations/tests/fibonacci.rs create mode 100644 crates/automations/tests/test_fibonacci.rs create mode 100644 supabase/migrations/20241001042256_job_queue.sql diff --git a/Cargo.lock b/Cargo.lock index 7054a345ef..0606dc99b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -633,6 +633,23 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "automations" +version = "0.0.0" +dependencies = [ + "anyhow", + "coroutines", + "futures", + "insta", + "models", + "serde", + "serde_json", + "sqlx", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "avro" version = "0.0.0" diff --git a/crates/automations/Cargo.toml b/crates/automations/Cargo.toml new file mode 100644 index 0000000000..59d427ffda --- /dev/null +++ b/crates/automations/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "automations" +version.workspace = true +rust-version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true +repository.workspace = true +license.workspace = true + +[dependencies] +coroutines = { path = "../coroutines" } +models = { path = "../models", features = ["sqlx-support"] } + +anyhow = { workspace = true } +futures = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +sqlx = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } + +[dev-dependencies] +insta = { workspace = true } diff --git a/crates/automations/src/executors.rs b/crates/automations/src/executors.rs new file mode 100644 index 0000000000..044264d158 --- /dev/null +++ b/crates/automations/src/executors.rs @@ -0,0 +1,308 @@ +use super::{server, BoxedRaw, Executor, PollOutcome, TaskType}; +use anyhow::Context; +use futures::future::{BoxFuture, FutureExt}; + +/// ObjSafe is an object-safe and type-erased trait which is implemented for all Executors. +pub trait ObjSafe: Send + Sync + 'static { + fn task_type(&self) -> TaskType; + + fn poll<'s>( + &'s self, + task_id: models::Id, + parent_id: Option, + state: &'s mut Option>, + inbox: &'s mut Option)>>>, + ) -> BoxFuture<'s, anyhow::Result>>; +} + +impl ObjSafe for E { + fn task_type(&self) -> TaskType { + E::TASK_TYPE + } + + fn poll<'s>( + &'s self, + task_id: models::Id, + parent_id: Option, + state: &'s mut Option>, + inbox: &'s mut Option)>>>, + ) -> BoxFuture<'s, anyhow::Result>> { + async move { + let mut state_parsed: E::State = if let Some(state) = state { + serde_json::from_str(state.get()).context("failed to decode task state")? + } else { + E::State::default() + }; + + let mut inbox_parsed: std::collections::VecDeque<(models::Id, Option)> = + inbox + .as_ref() + .into_iter() + .flatten() + .map(|sqlx::types::Json((task_id, rx))| { + if let Some(rx) = rx { + anyhow::Result::Ok((*task_id, Some(serde_json::from_str(rx.get())?))) + } else { + anyhow::Result::Ok((*task_id, None)) + } + }) + .collect::>() + .context("failed to decode received message")?; + + let outcome = E::poll( + self, + task_id, + parent_id, + &mut state_parsed, + &mut inbox_parsed, + ) + .await?; + + // Re-encode state for persistence. + // If we're Done, the state outcome is NULL which is the implicit Default. + if matches!(outcome, PollOutcome::Done) { + *state = None + } else { + *state = Some(sqlx::types::Json( + serde_json::value::to_raw_value(&state_parsed) + .context("failed to encode inner state")?, + )); + } + + // Re-encode the unconsumed portion of the inbox. + if inbox_parsed.is_empty() { + *inbox = None + } else { + *inbox = Some( + inbox_parsed + .into_iter() + .map(|(task_id, msg)| { + Ok(sqlx::types::Json(( + task_id, + match msg { + Some(msg) => Some( + serde_json::value::to_raw_value(&msg) + .context("failed to encode unconsumed inbox message")?, + ), + None => None, + }, + ))) + }) + .collect::>>()?, + ); + } + + Ok(match outcome { + PollOutcome::Done => PollOutcome::Done, + PollOutcome::Send(task_id, msg) => PollOutcome::Send(task_id, msg), + PollOutcome::Sleep(interval) => PollOutcome::Sleep(interval), + PollOutcome::Spawn(task_id, task_type, msg) => { + PollOutcome::Spawn(task_id, task_type, msg) + } + PollOutcome::Suspend => PollOutcome::Suspend, + PollOutcome::Yield(msg) => PollOutcome::Yield( + serde_json::value::to_raw_value(&msg) + .context("failed to encode yielded message")?, + ), + }) + } + .boxed() + } +} + +#[tracing::instrument(skip_all, fields(task_id, parent_id))] +pub async fn poll_task( + server::ReadyTask { + executor, + permit: _guard, + pool, + task: + server::DequeuedTask { + id: task_id, + type_: _, + parent_id, + mut inbox, + mut state, + mut last_heartbeat, + }, + }: server::ReadyTask, + heartbeat_timeout: std::time::Duration, +) -> anyhow::Result<()> { + let mut heartbeat_ticks = tokio::time::interval(heartbeat_timeout / 2); + let _instant = heartbeat_ticks.tick().await; // Discard immediate first tick. + + // Build a Future which forever maintains our heartbeat or errors. + let update_heartbeats = async { + loop { + let _instant = heartbeat_ticks.tick().await; + + last_heartbeat = + match update_heartbeat(&pool, task_id, heartbeat_timeout, last_heartbeat).await { + Ok(last_heartbeat) => last_heartbeat, + Err(err) => return err, + } + } + }; + tokio::pin!(update_heartbeats); + + // Poll the executor and `update_heartbeats` in tandem, so that a failure + // to update our heartbeat also cancels the executor poll. + let outcome = tokio::select! { + outcome = executor.poll(task_id, parent_id, &mut state, &mut inbox) => { outcome? }, + err = &mut update_heartbeats => return Err(err), + }; + + // The possibly long-lived polling operation is now complete. + // Build a Future that commits a (hopefully) brief transaction of `outcome`. + let persist_outcome = async { + let mut txn = pool.begin().await?; + () = persist_outcome(outcome, &mut *txn, task_id, parent_id, state, inbox).await?; + Ok(txn.commit().await?) + }; + + // Poll `persist_outcome` while continuing to poll `update_heartbeats`, + // to guarantee we cannot commit an outcome after our lease is lost. + tokio::select! { + result = persist_outcome => result, + err = update_heartbeats => Err(err), + } +} + +async fn update_heartbeat( + pool: &sqlx::PgPool, + task_id: models::Id, + heartbeat_timeout: std::time::Duration, + expect_heartbeat: String, +) -> anyhow::Result { + let update = sqlx::query!( + r#" + UPDATE internal.tasks + SET heartbeat = NOW() + WHERE task_id = $1 AND heartbeat::TEXT = $2 + RETURNING heartbeat::TEXT AS "heartbeat!"; + "#, + task_id as models::Id, + expect_heartbeat, + ) + .fetch_optional(pool); + + // We must guard against both explicit errors and also timeouts when updating + // the heartbeat, to ensure we bubble up an error that cancels our corresponding + // polling future prior to a complete `heartbeat_timeout` elapsing. + let updated = match tokio::time::timeout(heartbeat_timeout / 4, update).await { + Ok(Ok(Some(updated))) => updated, + Ok(Ok(None)) => anyhow::bail!("task heartbeat was unexpectedly updated externally"), + Ok(Err(err)) => return Err(anyhow::anyhow!(err).context("failed to update task heartbeat")), + Err(err) => return Err(anyhow::anyhow!(err).context("timed out updating task heartbeat")), + }; + + tracing::info!( + last = expect_heartbeat, + next = updated.heartbeat, + "updated heartbeat" + ); + + Ok(updated.heartbeat) +} + +async fn persist_outcome( + outcome: PollOutcome, + txn: &mut sqlx::PgConnection, + task_id: models::Id, + parent_id: Option, + state: Option>, + inbox: Option)>>>, +) -> anyhow::Result<()> { + use std::time::Duration; + + if let PollOutcome::Spawn(spawn_id, spawn_type, _msg) = &outcome { + sqlx::query!( + "SELECT internal.create_task($1, $2, $3)", + *spawn_id as models::Id, + *spawn_type as TaskType, + task_id as models::Id, + ) + .execute(&mut *txn) + .await + .context("failed to spawn new task row")?; + } + + if let Some((send_id, msg)) = match &outcome { + // When a task is spawned, send its first message. + PollOutcome::Spawn(spawn_id, _spawn_type, msg) => Some((*spawn_id, Some(msg))), + // If we're Done but have a parent, send it an EOF. + PollOutcome::Done => parent_id.map(|parent_id| (parent_id, None)), + // Send an arbitrary message to an identified task. + PollOutcome::Send(task_id, msg) => Some((*task_id, msg.as_ref())), + // Yield is sugar for sending to our parent. + PollOutcome::Yield(msg) => { + let Some(parent_id) = parent_id else { + anyhow::bail!("task yielded illegally, because it does not have a parent"); + }; + Some((parent_id, Some(msg))) + } + _ => None, + } { + sqlx::query!( + "SELECT internal.send_to_task($1, $2, $3::JSON);", + send_id as models::Id, + task_id as models::Id, + sqlx::types::Json(msg) as sqlx::types::Json<_>, + ) + .execute(&mut *txn) + .await + .with_context(|| format!("failed to send message to {send_id:?}"))?; + } + + let wake_at_interval = if inbox.is_some() { + Some(Duration::ZERO) // Always poll immediately if inbox items remain. + } else { + match &outcome { + PollOutcome::Sleep(interval) => Some(*interval), + // These outcomes do not suspend the task, and it should wake as soon as possible. + PollOutcome::Spawn(..) | PollOutcome::Send(..) | PollOutcome::Yield(..) => { + Some(Duration::ZERO) + } + // Suspend indefinitely (note that NOW() + NULL::INTERVAL is NULL). + PollOutcome::Done | PollOutcome::Suspend => None, + } + }; + + let updated = sqlx::query!( + r#" + UPDATE internal.tasks SET + heartbeat = '0001-01-01T00:00:00Z', + inbox = $3::JSON[] || inbox_next, + inbox_next = NULL, + inner_state = $2::JSON, + wake_at = + CASE WHEN inbox_next IS NOT NULL + THEN NOW() + ELSE NOW() + $4::INTERVAL + END + WHERE task_id = $1 + RETURNING wake_at IS NULL AS "suspended!" + "#, + task_id as models::Id, + state as Option>, + inbox as Option)>>>, + wake_at_interval as Option, + ) + .fetch_one(&mut *txn) + .await + .context("failed to update task row")?; + + // If we're Done and also successfully suspended, then delete ourselves. + // (Otherwise, the task has been left in a like-new state). + if matches!(&outcome, PollOutcome::Done if updated.suspended) { + sqlx::query!( + "DELETE FROM internal.tasks WHERE task_id = $1;", + task_id as models::Id, + ) + .execute(&mut *txn) + .await + .context("failed to delete task row")?; + } + + Ok(()) +} diff --git a/crates/automations/src/lib.rs b/crates/automations/src/lib.rs new file mode 100644 index 0000000000..b9077d27b7 --- /dev/null +++ b/crates/automations/src/lib.rs @@ -0,0 +1,137 @@ +use anyhow::Context; +use std::sync::Arc; + +pub mod executors; +pub mod server; + +/// BoxedRaw is a type-erased raw JSON message. +type BoxedRaw = Box; + +/// TaskType is the type of a task, and maps it to an Executor. +#[derive( + Debug, + serde::Deserialize, + serde::Serialize, + sqlx::Type, + PartialOrd, + PartialEq, + Ord, + Eq, + Clone, + Copy, +)] +#[sqlx(transparent)] +pub struct TaskType(pub i16); + +/// PollOutcome is the outcome of an `Executor::poll()` for a given task. +#[derive(Debug)] +pub enum PollOutcome { + /// Spawn a new TaskId with the given TaskType and send a first message. + /// The TaskId must not exist. + Spawn(models::Id, TaskType, BoxedRaw), + /// Send a message (Some) or EOF (None) to another TaskId, which must exist. + Send(models::Id, Option), + /// Yield to send a message to this task's parent. + Yield(Yield), + /// Sleep for at-most the indicated Duration, then poll again. + /// The task may be woken earlier if it receives a message. + Sleep(std::time::Duration), + /// Suspend the task until it receives a message. + Suspend, + /// Done completes and removes the task. + /// If this task has a parent, that parent is sent an EOF. + Done, +} + +/// Executor is the core trait implemented by executors of various task task types. +pub trait Executor: Send + Sync + 'static { + const TASK_TYPE: TaskType; + + type Receive: serde::de::DeserializeOwned + serde::Serialize + Send; + type State: Default + serde::de::DeserializeOwned + serde::Serialize + Send; + type Yield: serde::Serialize; + + fn poll<'s>( + &'s self, + task_id: models::Id, + parent_id: Option, + state: &'s mut Self::State, + inbox: &'s mut std::collections::VecDeque<(models::Id, Option)>, + ) -> impl std::future::Future>> + Send + 's; +} + +/// Executors holds registered implementations of Executor, +/// and serves them. +pub struct Executors(Vec>); + +impl Executors { + pub const fn new() -> Self { + Self(Vec::new()) + } + + pub fn register(mut self, executor: E) -> Self { + let index = match self + .0 + .binary_search_by_key(&E::TASK_TYPE, |entry| entry.task_type()) + { + Ok(_index) => panic!("an Executor for {:?} is already registered", E::TASK_TYPE), + Err(index) => index, + }; + + self.0.insert(index, Arc::new(executor)); + self + } + + pub async fn serve( + self, + permits: u32, + pool: sqlx::PgPool, + poll_interval: std::time::Duration, + heartbeat_timeout: std::time::Duration, + shutdown: impl std::future::Future, + ) { + server::serve( + self, + permits, + pool, + poll_interval, + heartbeat_timeout, + shutdown, + ) + .await + } +} + +impl PollOutcome { + pub fn spawn( + spawn_id: models::Id, + task_type: TaskType, + msg: M, + ) -> anyhow::Result { + Ok(Self::Spawn( + spawn_id, + task_type, + serde_json::value::to_raw_value(&msg).context("failed to encode task spawn message")?, + )) + } + + pub fn send(task_id: models::Id, msg: Option) -> anyhow::Result { + Ok(Self::Send( + task_id, + match msg { + Some(msg) => Some( + serde_json::value::to_raw_value(&msg) + .context("failed to encode sent message")?, + ), + None => None, + }, + )) + } +} + +pub fn next_task_id() -> models::Id { + static ID_GENERATOR: std::sync::LazyLock> = + std::sync::LazyLock::new(|| std::sync::Mutex::new(models::IdGenerator::new(1))); + + ID_GENERATOR.lock().unwrap().next() +} diff --git a/crates/automations/src/server.rs b/crates/automations/src/server.rs new file mode 100644 index 0000000000..db054e2ef5 --- /dev/null +++ b/crates/automations/src/server.rs @@ -0,0 +1,194 @@ +use super::{executors, BoxedRaw, Executors, TaskType}; +use futures::stream::StreamExt; +use std::sync::Arc; + +pub struct ReadyTask { + pub executor: Arc, + pub permit: tokio::sync::OwnedSemaphorePermit, + pub pool: sqlx::PgPool, + pub task: DequeuedTask, +} + +pub struct DequeuedTask { + pub id: models::Id, + pub type_: TaskType, + pub parent_id: Option, + pub inbox: Option)>>>, + pub state: Option>, + pub last_heartbeat: String, +} + +pub async fn serve( + executors: Executors, + permits: u32, + pool: sqlx::PgPool, + poll_interval: std::time::Duration, + heartbeat_timeout: std::time::Duration, + shutdown: impl std::future::Future, +) { + let semaphore = Arc::new(tokio::sync::Semaphore::new(permits as usize)); + + // Use Box::pin to ensure we can fullly drop `ready_tasks` later, + // as it may hold `semaphore` permits. + let mut ready_tasks = Box::pin(ready_tasks( + executors, + pool.clone(), + poll_interval, + heartbeat_timeout, + semaphore.clone(), + )); + tokio::pin!(shutdown); + + // Poll for ready tasks and start them until `shutdown` is signaled. + while let Some(ready_tasks) = tokio::select! { + ready = ready_tasks.next() => ready, + () = &mut shutdown => None, + } { + let ready_tasks: Vec = match ready_tasks { + Ok(tasks) => tasks, + Err(err) => { + tracing::error!(?err, "failed to poll for tasks (will retry)"); + Vec::new() + } + }; + + for ready in ready_tasks { + tokio::spawn(async move { + let (task_id, task_type, parent_id) = + (ready.task.id, ready.task.type_, ready.task.parent_id); + + if let Err(err) = executors::poll_task(ready, heartbeat_timeout).await { + tracing::error!( + ?task_id, + ?task_type, + ?parent_id, + ?err, + "failed to poll task (will retry)" + ); + // The task will be retried once it's heartbeat times out. + } + }); + } + } + tracing::info!("task polling loop signaled to stop and is awaiting running tasks"); + std::mem::drop(ready_tasks); + + // Acquire all permits, when only happens after all running tasks have finished. + let _ = semaphore.acquire_many_owned(permits).await.unwrap(); +} + +pub fn ready_tasks( + executors: Executors, + pool: sqlx::PgPool, + poll_interval: std::time::Duration, + heartbeat_timeout: std::time::Duration, + semaphore: Arc, +) -> impl futures::stream::Stream>> { + let task_types: Vec<_> = executors.0.iter().map(|e| e.task_type().0).collect(); + + coroutines::coroutine(move |mut co| async move { + loop { + () = ready_tasks_iter( + &mut co, + &executors, + heartbeat_timeout, + poll_interval, + &pool, + &semaphore, + &task_types, + ) + .await; + } + }) +} + +async fn ready_tasks_iter( + co: &mut coroutines::Suspend>, ()>, + executors: &Executors, + heartbeat_timeout: std::time::Duration, + poll_interval: std::time::Duration, + pool: &sqlx::PgPool, + semaphore: &Arc, + task_types: &[i16], +) { + // Block until at least one permit is available. + if semaphore.available_permits() == 0 { + let _ = semaphore.clone().acquire_owned().await.unwrap(); + } + + // Acquire all available permits, and then poll for up to that many tasks. + let mut permits = semaphore + .clone() + .acquire_many_owned(semaphore.available_permits() as u32) + .await + .unwrap(); + + let dequeued = sqlx::query_as!( + DequeuedTask, + r#" + WITH picked AS ( + SELECT task_id + FROM internal.tasks + WHERE + task_type = ANY($1) AND + wake_at < NOW() AND + heartbeat < NOW() - $2::INTERVAL + ORDER BY wake_at DESC + LIMIT $3 + FOR UPDATE SKIP LOCKED + ) + UPDATE internal.tasks + SET heartbeat = NOW() + WHERE task_id in (SELECT task_id FROM picked) + RETURNING + task_id as "id: models::Id", + task_type as "type_: TaskType", + parent_id as "parent_id: models::Id", + inbox as "inbox: Vec)>>", + inner_state as "state: sqlx::types::Json", + heartbeat::TEXT as "last_heartbeat!"; + "#, + &task_types as &[i16], + heartbeat_timeout as std::time::Duration, + permits.num_permits() as i64, + ) + .fetch_all(pool) + .await; + + let dequeued = match dequeued { + Ok(dequeued) => { + tracing::debug!(dequeued = dequeued.len(), "completed poll"); + dequeued + } + Err(err) => { + () = co.yield_(Err(err)).await; + Vec::new() // We'll sleep as if it were idle, then retry. + } + }; + + let ready = dequeued + .into_iter() + .map(|task| { + let Ok(index) = task_types.binary_search(&task.type_.0) else { + panic!("polled {:?} with unexpected {:?}", task.id, task.type_); + }; + ReadyTask { + task, + executor: executors.0[index].clone(), + permit: permits.split(1).unwrap(), + pool: pool.clone(), + } + }) + .collect(); + + () = co.yield_(Ok(ready)).await; + + // If permits remain, there were not enough tasks to dequeue. + // Sleep for up-to `poll_interval`, cancelling early if a task completes. + if permits.num_permits() != 0 { + tokio::select! { + () = tokio::time::sleep(poll_interval) => (), + _ = semaphore.clone().acquire_owned() => (), // Cancel sleep. + } + } +} diff --git a/crates/automations/tests/fibonacci.rs b/crates/automations/tests/fibonacci.rs new file mode 100644 index 0000000000..66ef84d36a --- /dev/null +++ b/crates/automations/tests/fibonacci.rs @@ -0,0 +1,112 @@ +use automations::PollOutcome; +use std::collections::VecDeque; + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub struct FibMessage { + value: i64, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub enum FibState { + Init, + SpawnOne(i64), + SpawnTwo, + Waiting { partial: i64, pending: usize }, + Finished, +} + +impl Default for FibState { + fn default() -> Self { + Self::Init + } +} + +pub struct Fibonacci {} + +impl automations::Executor for Fibonacci { + const TASK_TYPE: automations::TaskType = automations::TaskType(1); + + type Receive = FibMessage; + type Yield = FibMessage; + type State = FibState; + + #[tracing::instrument( + ret, + err(level = tracing::Level::ERROR), + skip_all, + fields(?task_id, ?parent_id, ?state, ?inbox), + )] + async fn poll<'s>( + &'s self, + task_id: models::Id, + parent_id: Option, + state: &'s mut Self::State, + inbox: &'s mut VecDeque<(models::Id, Option)>, + ) -> anyhow::Result> { + if let FibState::SpawnOne(value) = state { + let spawn = PollOutcome::spawn( + automations::next_task_id(), + Self::TASK_TYPE, + FibMessage { value: *value - 2 }, + ); + *state = FibState::Waiting { + partial: 0, + pending: 2, + }; + + return spawn; + } + + match (std::mem::take(state), inbox.pop_front()) { + // Base case: + (FibState::Init, Some((_parent_id, Some(FibMessage { value })))) if value <= 2 => { + *state = FibState::Finished; + Ok(PollOutcome::Yield(FibMessage { value: 1 })) + } + + // Recursive case: + (FibState::Init, Some((_parent_id, Some(FibMessage { value })))) => { + *state = FibState::SpawnOne(value); + + PollOutcome::spawn( + automations::next_task_id(), + Self::TASK_TYPE, + FibMessage { value: value - 1 }, + ) + } + + (FibState::Waiting { partial, pending }, None) => { + *state = FibState::Waiting { partial, pending }; + Ok(PollOutcome::Suspend) + } + + ( + FibState::Waiting { partial, pending }, + Some((_child_id, Some(FibMessage { value }))), + ) => { + *state = FibState::Waiting { + partial: partial + value, + pending, + }; + Ok(PollOutcome::Suspend) + } + + (FibState::Waiting { partial, pending }, Some((_child_id, None))) => { + if pending != 1 || parent_id.is_none() { + *state = FibState::Waiting { + partial, + pending: pending - 1, + }; + Ok(PollOutcome::Suspend) + } else { + *state = FibState::Finished; + Ok(PollOutcome::Yield(FibMessage { value: partial })) + } + } + + (FibState::Finished, None) => Ok(PollOutcome::Done), + + state => anyhow::bail!("unexpected poll with state {state:?} and inbox {inbox:?}"), + } + } +} diff --git a/crates/automations/tests/test_fibonacci.rs b/crates/automations/tests/test_fibonacci.rs new file mode 100644 index 0000000000..03dc800098 --- /dev/null +++ b/crates/automations/tests/test_fibonacci.rs @@ -0,0 +1,57 @@ +use std::time::Duration; + +mod fibonacci; + +#[tokio::test] +async fn test_fibonacci_bench() { + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_env_filter( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(tracing::level_filters::LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .finish(); + let _ = tracing::subscriber::set_global_default(subscriber); + + let pool = sqlx::postgres::PgPool::connect(&FIXED_DATABASE_URL) + .await + .expect("connect"); + + sqlx::query!("delete from internal.tasks;") + .execute(&pool) + .await + .unwrap(); + + let root_id = automations::next_task_id(); + + sqlx::query!( + "SELECT internal.create_task($1, 1::SMALLINT, NULL::public.flowid)", + root_id as models::Id + ) + .execute(&pool) + .await + .unwrap(); + + sqlx::query!( + r#"SELECT internal.send_to_task($1, '00:00:00:00:00:00:00:00'::flowid, '{"value":20}')"#, + root_id as models::Id + ) + .execute(&pool) + .await + .unwrap(); + + let stop = tokio::time::sleep(Duration::from_secs(60)); + + () = automations::Executors::new() + .register(fibonacci::Fibonacci {}) + .serve( + 30, + pool, + Duration::from_secs(5), + Duration::from_secs(20), + stop, + ) + .await; +} + +const FIXED_DATABASE_URL: &str = "postgresql://postgres:postgres@localhost:5432/postgres"; diff --git a/supabase/migrations/20241001042256_job_queue.sql b/supabase/migrations/20241001042256_job_queue.sql new file mode 100644 index 0000000000..aaa2272e51 --- /dev/null +++ b/supabase/migrations/20241001042256_job_queue.sql @@ -0,0 +1,97 @@ +BEGIN; + +CREATE TABLE internal.tasks ( + task_id public.flowid PRIMARY KEY NOT NULL, + task_type SMALLINT NOT NULL, + parent_id public.flowid, + + inner_state JSON, -- NULL is equivalent to Default. + + wake_at TIMESTAMPTZ, + inbox JSON[], + inbox_next JSON[], + + heartbeat TIMESTAMPTZ NOT NULL DEFAULT '0001-01-01T00:00:00Z' +); + +CREATE FUNCTION internal.create_task( + p_task_id public.flowid, + p_task_type SMALLINT, + p_parent_id public.flowid +) +RETURNS VOID +SET search_path = '' +AS $$ +BEGIN + + INSERT INTO internal.tasks (task_id, task_type, parent_id) + VALUES (p_task_id, p_task_type, p_parent_id); + +END; +$$ LANGUAGE plpgsql; + + +CREATE FUNCTION internal.send_to_task( + p_task_id public.flowid, + p_from_id public.flowid, + p_message JSON +) +RETURNS VOID +SET search_path = '' +AS $$ +BEGIN + + UPDATE internal.tasks SET + wake_at = LEAST(wake_at, NOW()), + inbox = + CASE WHEN heartbeat = '0001-01-01T00:00:00Z' + THEN ARRAY_APPEND(inbox, JSON_BUILD_ARRAY(p_from_id, p_message)) + ELSE inbox + END, + inbox_next = + CASE WHEN heartbeat = '0001-01-01T00:00:00Z' + THEN inbox_next + ELSE ARRAY_APPEND(inbox_next, JSON_BUILD_ARRAY(p_from_id, p_message)) + END + WHERE task_id = p_task_id; + +END; +$$ LANGUAGE plpgsql; + + +CREATE INDEX idx_tasks_ready_at ON internal.tasks + USING btree (wake_at) INCLUDE (task_type); + +COMMENT ON TABLE internal.tasks IS ' +The tasks table supports a distributed and asynchronous task execution system. + +Tasks are poll-able futures which are identified by (task_type, task_key). +They may be short-lived and polled just once, or very long-lived and polled +many times over their life-cycle. + +Tasks are polled by executors which dequeue from the tasks table and run +bespoke handlers parameterized by the task type, key, and context. A polling +routine may take an arbitrarily long amount of time to finish, and the executor +is required to periodically update the task heartbeat as it runs. + +A task is polled by at-most one executor at a time. Executor failures are +detected through a failure to update the task heartbeat within a threshold amount +of time, which makes the task re-eligible for dequeue by another executor. + +A task may be schedule to run many times prior to its actual dequeue by a runner. +If the task has yet to be dequeued, multiple scheduled polls of a task are reduced +by minimizing over `wake_at` and through JSON Merge-Patch of the polling `context`. + +If the task is currently being polled, additional scheduled polls are reduced +by minimizing over `next_wake_at` and through JSON Merge-Patch of `next_context`. +A running executor may schedule future polls its current task, to implement a +recursive or periodic task lifecycle. + +When an executor completes polling a task, the task is updated to be eligible +for a future dequeue in accordance with its `next_wake_at`. Or, if `next_wake_at` +remains NULL then the task is considered completed and its row is removed. +'; + +COMMENT ON COLUMN internal.tasks.task_id IS 'Generated unique ID for the task'; + +COMMIT; \ No newline at end of file