Skip to content

Commit

Permalink
refactor(streaming): retrieve epoch from task local storage (#9488)
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao authored Apr 27, 2023
1 parent c82e175 commit a086a2e
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 86 deletions.
45 changes: 45 additions & 0 deletions src/common/src/util/epoch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,51 @@ impl EpochPair {
Self::new(curr, curr - 1)
}
}

/// Task-local storage of the [`EpochPair`] the current task is working in.
pub mod task_local {
    use futures::Future;
    use tokio::task_local;

    use super::{Epoch, EpochPair};

    task_local! {
        static TASK_LOCAL_EPOCH_PAIR: EpochPair;
    }

    /// Returns the epoch pair held in the task-local storage.
    ///
    /// The value is refreshed after every barrier message is yielded. `None` is returned while
    /// the first barrier message has not been yielded yet, i.e. when no value has been provided
    /// through [`scope`].
    pub fn epoch() -> Option<EpochPair> {
        TASK_LOCAL_EPOCH_PAIR.try_with(|pair| *pair).ok()
    }

    /// Returns the current epoch held in the task-local storage.
    ///
    /// The value is refreshed after every barrier message is yielded. `None` is returned while
    /// the first barrier message has not been yielded yet.
    pub fn curr_epoch() -> Option<Epoch> {
        epoch().map(|pair| Epoch(pair.curr))
    }

    /// Returns the previous epoch held in the task-local storage.
    ///
    /// The value is refreshed after every barrier message is yielded. `None` is returned while
    /// the first barrier message has not been yielded yet.
    pub fn prev_epoch() -> Option<Epoch> {
        epoch().map(|pair| Epoch(pair.prev))
    }

    /// Runs the future `f` with `epoch` installed in the task-local storage.
    ///
    /// The epoch pair is only visible to [`epoch`], [`curr_epoch`] and [`prev_epoch`] while `f`
    /// is being polled through the returned future.
    pub async fn scope<F: Future>(epoch: EpochPair, f: F) -> F::Output {
        TASK_LOCAL_EPOCH_PAIR.scope(epoch, f).await
    }
}

#[cfg(test)]
mod tests {
use chrono::{Local, TimeZone, Utc};
Expand Down
45 changes: 23 additions & 22 deletions src/expr/src/expr/expr_proctime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
use risingwave_common::array::DataChunk;
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum, ScalarImpl};
use risingwave_common::util::epoch;
use risingwave_pb::expr::expr_node::{RexNode, Type};
use risingwave_pb::expr::ExprNode;

use super::{Expression, ValueImpl, CONTEXT};
use super::{Expression, ValueImpl};
use crate::{bail, ensure, ExprError, Result};

#[derive(Debug)]
Expand All @@ -45,57 +46,57 @@ impl<'a> TryFrom<&'a ExprNode> for ProcTimeExpression {
}
}

/// Derive the processing time, in microseconds, from the current task-local epoch.
///
/// Fails with [`ExprError::Context`] when no epoch is available in the task-local storage.
fn proc_time_us_from_epoch() -> Result<ScalarImpl> {
    let curr_epoch = epoch::task_local::curr_epoch().ok_or(ExprError::Context)?;
    // The epoch exposes milliseconds; scale by 1000 to get microseconds.
    let micros = curr_epoch.as_unix_millis() * 1000;
    Ok(ScalarImpl::Int64(micros as i64))
}

#[async_trait::async_trait]
impl Expression for ProcTimeExpression {
fn return_type(&self) -> DataType {
DataType::Timestamptz
}

async fn eval_v2(&self, input: &DataChunk) -> Result<ValueImpl> {
let proctime = CONTEXT
.try_with(|context| context.get_proctime())
.map_err(|_| ExprError::Context)?;
let datum = Some(ScalarImpl::Int64(proctime as i64));

Ok(ValueImpl::Scalar {
value: datum,
proc_time_us_from_epoch().map(|s| ValueImpl::Scalar {
value: Some(s),
capacity: input.capacity(),
})
}

async fn eval_row(&self, _input: &OwnedRow) -> Result<Datum> {
let proctime = CONTEXT
.try_with(|context| context.get_proctime())
.map_err(|_| ExprError::Context)?;
let datum = Some(ScalarImpl::Int64(proctime as i64));

Ok(datum)
proc_time_us_from_epoch().map(Some)
}
}

#[cfg(test)]
mod tests {
use risingwave_common::array::DataChunk;
use risingwave_common::types::ScalarRefImpl;
use risingwave_common::util::epoch::Epoch;
use risingwave_common::util::epoch::{Epoch, EpochPair};

use super::*;
use crate::expr::{ExprContext, CONTEXT};

#[tokio::test]
async fn test_expr_proctime() {
let proctime_expr = ProcTimeExpression::new();
let epoch = Epoch::now();
let time_us = epoch.as_unix_millis() * 1000;
let time_datum = Some(ScalarRefImpl::Int64(time_us as i64));
let context = ExprContext::new(epoch);
let curr_epoch = Epoch::now();
let epoch = EpochPair {
curr: curr_epoch.0,
prev: 0,
};
let chunk = DataChunk::new_dummy(3);

let array = CONTEXT
.scope(context, proctime_expr.eval(&chunk))
let array = epoch::task_local::scope(epoch, proctime_expr.eval(&chunk))
.await
.unwrap();

let time_us = curr_epoch.as_unix_millis() * 1000;
let time_datum = Some(ScalarRefImpl::Int64(time_us as i64));
for datum_ref in array.iter() {
assert_eq!(datum_ref, time_datum)
}
Expand Down
22 changes: 0 additions & 22 deletions src/expr/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ use futures_util::TryFutureExt;
use risingwave_common::array::{ArrayRef, DataChunk};
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::{DataType, Datum};
use risingwave_common::util::epoch::Epoch;
use static_assertions::const_assert;

pub use self::build::*;
Expand Down Expand Up @@ -193,24 +192,3 @@ pub type ExpressionRef = Arc<dyn Expression>;
/// See also <https://github.com/risingwavelabs/risingwave/issues/4625>.
#[allow(dead_code)]
const STRICT_MODE: bool = false;

/// The context used by expressions.
///
/// It carries the epoch from which the processing time is derived (see `get_proctime`).
#[derive(Clone)]
pub struct ExprContext {
/// The epoch that an executor is currently in.
curr_epoch: Epoch,
}

impl ExprContext {
/// Creates a context whose current epoch is `curr_epoch`.
pub fn new(curr_epoch: Epoch) -> Self {
Self { curr_epoch }
}

/// Returns the processing time derived from the current epoch: the value of
/// `Epoch::as_unix_millis` multiplied by 1000 (presumably microseconds since the Unix epoch).
pub fn get_proctime(&self) -> u64 {
self.curr_epoch.as_unix_millis() * 1000
}
}

// Task-local slot holding the `ExprContext` for the expressions evaluated in the current task.
// A value is only present inside a `CONTEXT.scope(..)` call.
tokio::task_local! {
pub static CONTEXT: ExprContext;
}
2 changes: 2 additions & 0 deletions src/stream/src/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ pub type BoxedExecutor = Box<dyn Executor>;
pub type MessageStreamItem = StreamExecutorResult<Message>;
pub type BoxedMessageStream = BoxStream<'static, MessageStreamItem>;

pub use risingwave_common::util::epoch::task_local::{curr_epoch, epoch, prev_epoch};

pub trait MessageStream = futures::Stream<Item = MessageStreamItem> + Send;

/// Static information of an executor.
Expand Down
37 changes: 11 additions & 26 deletions src/stream/src/executor/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use multimap::MultiMap;
use risingwave_common::array::column::Column;
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::{Field, Schema};
use risingwave_expr::expr::{BoxedExpression, ExprContext, CONTEXT};
use risingwave_expr::expr::BoxedExpression;

use super::*;

Expand Down Expand Up @@ -115,7 +115,6 @@ impl Inner {
async fn map_filter_chunk(
&self,
chunk: StreamChunk,
context: ExprContext,
) -> StreamExecutorResult<Option<StreamChunk>> {
let chunk = if chunk.selectivity() <= self.materialize_selectivity_threshold {
chunk.compact()
Expand All @@ -126,15 +125,11 @@ impl Inner {
let mut projected_columns = Vec::new();

for expr in &self.exprs {
let evaluated_expr = CONTEXT
.scope(
context.clone(),
expr.eval_infallible(&data_chunk, |err| {
self.ctx.on_compute_error(err, &self.info.identity)
}),
)
let evaluated_expr = expr
.eval_infallible(&data_chunk, |err| {
self.ctx.on_compute_error(err, &self.info.identity)
})
.await;

let new_column = Column::new(evaluated_expr);
projected_columns.push(new_column);
}
Expand Down Expand Up @@ -175,13 +170,8 @@ impl Inner {

#[try_stream(ok = Message, error = StreamExecutorError)]
async fn execute(self, input: BoxedExecutor) {
let mut input = input.execute();
let first_barrier = expect_first_barrier(&mut input).await?;
let mut context = ExprContext::new(first_barrier.get_curr_epoch());
yield Message::Barrier(first_barrier);

#[for_await]
for msg in input {
for msg in input.execute() {
let msg = msg?;
match msg {
Message::Watermark(w) => {
Expand All @@ -190,16 +180,11 @@ impl Inner {
yield Message::Watermark(watermark)
}
}
Message::Chunk(chunk) => {
match self.map_filter_chunk(chunk, context.clone()).await? {
Some(new_chunk) => yield Message::Chunk(new_chunk),
None => continue,
}
}
Message::Barrier(barrier) => {
context = ExprContext::new(barrier.get_curr_epoch());
yield Message::Barrier(barrier);
}
Message::Chunk(chunk) => match self.map_filter_chunk(chunk).await? {
Some(new_chunk) => yield Message::Chunk(new_chunk),
None => continue,
},
m => yield m,
}
}
}
Expand Down
32 changes: 16 additions & 16 deletions src/stream/src/executor/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use super::{
use crate::task::ActorId;

mod epoch_check;
mod epoch_provide;
mod schema_check;
mod trace;
mod update_check;
Expand Down Expand Up @@ -85,14 +86,7 @@ impl WrapperExecutor {
extra.metrics,
stream,
);
// Await tree
let stream =
trace::instrument_await_tree(info.clone(), extra.actor_id, extra.executor_id, stream);

// Schema check
let stream = schema_check::schema_check(info.clone(), stream);
// Epoch check
let stream = epoch_check::epoch_check(info.clone(), stream);
// Update check
let stream = update_check::update_check(info, stream);

Expand All @@ -102,7 +96,7 @@ impl WrapperExecutor {
#[allow(clippy::let_and_return)]
fn wrap_release(
enable_executor_row_count: bool,
info: Arc<ExecutorInfo>,
_info: Arc<ExecutorInfo>,
extra: ExtraInfo,
stream: impl MessageStream + 'static,
) -> impl MessageStream + 'static {
Expand All @@ -114,14 +108,6 @@ impl WrapperExecutor {
extra.metrics,
stream,
);
// Await tree
let stream =
trace::instrument_await_tree(info.clone(), extra.actor_id, extra.executor_id, stream);

// Schema check
let stream = schema_check::schema_check(info.clone(), stream);
// Epoch check
let stream = epoch_check::epoch_check(info, stream);

stream
}
Expand All @@ -132,6 +118,20 @@ impl WrapperExecutor {
extra: ExtraInfo,
stream: impl MessageStream + 'static,
) -> BoxedMessageStream {
// -- Shared wrappers --

// Await tree
let stream =
trace::instrument_await_tree(info.clone(), extra.actor_id, extra.executor_id, stream);

// Schema check
let stream = schema_check::schema_check(info.clone(), stream);
// Epoch check
let stream = epoch_check::epoch_check(info.clone(), stream);

// Epoch provide
let stream = epoch_provide::epoch_provide(stream);

if cfg!(debug_assertions) {
Self::wrap_debug(enable_executor_row_count, info, extra, stream).boxed()
} else {
Expand Down
42 changes: 42 additions & 0 deletions src/stream/src/executor/wrapper/epoch_provide.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use futures::{pin_mut, TryStreamExt};
use futures_async_stream::try_stream;
use risingwave_common::util::epoch;

use crate::executor::error::StreamExecutorError;
use crate::executor::{Message, MessageStream};

/// Streams wrapped by `epoch_provide` are able to retrieve the current epoch pair through the
/// functions in [`epoch::task_local`].
///
/// Every message of `input` is forwarded unchanged. Once a barrier has been forwarded, all later
/// polls of `input` run inside a task-local scope carrying that barrier's epoch pair; before the
/// first barrier, `input` is polled without any scope.
#[try_stream(ok = Message, error = StreamExecutorError)]
pub async fn epoch_provide(input: impl MessageStream) {
    pin_mut!(input);

    // Epoch pair of the most recently forwarded barrier, if any.
    let mut provided = None;

    loop {
        let next = match provided {
            Some(pair) => epoch::task_local::scope(pair, input.try_next()).await?,
            None => input.try_next().await?,
        };
        let message = match next {
            Some(message) => message,
            None => break,
        };

        // The inner executor has produced a new barrier: from the next poll on, expose the
        // updated epoch pair.
        if let Message::Barrier(barrier) = &message {
            provided = Some(barrier.epoch);
        }
        yield message;
    }
}

0 comments on commit a086a2e

Please sign in to comment.