From 03cc2aebbc8dddaac135bdcb77b2b4395dd2a8af Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 15 Mar 2023 16:52:47 +0800 Subject: [PATCH] refactor(expr): make evaluation async (#8229) Signed-off-by: Runji Wang --- Cargo.lock | 3 + src/batch/src/executor/filter.rs | 2 +- src/batch/src/executor/hash_agg.rs | 2 +- src/batch/src/executor/hop_window.rs | 4 +- src/batch/src/executor/join/hash_join.rs | 73 ++++++--- .../src/executor/join/nested_loop_join.rs | 22 ++- src/batch/src/executor/project.rs | 10 +- src/batch/src/executor/project_set.rs | 11 +- src/batch/src/executor/sort_agg.rs | 33 ++-- src/batch/src/executor/table_function.rs | 2 +- src/batch/src/executor/test_utils.rs | 2 +- src/batch/src/executor/update.rs | 10 +- src/batch/src/executor/values.rs | 2 +- src/common/src/types/mod.rs | 2 + src/expr/Cargo.toml | 3 + src/expr/benches/expr.rs | 27 +++- src/expr/src/expr/build_expr_from_prost.rs | 6 +- src/expr/src/expr/expr_array_concat.rs | 18 ++- src/expr/src/expr/expr_array_distinct.rs | 14 +- src/expr/src/expr/expr_array_to_string.rs | 17 +- src/expr/src/expr/expr_binary_bytes.rs | 20 +-- src/expr/src/expr/expr_binary_nonnull.rs | 68 ++++---- src/expr/src/expr/expr_binary_nullable.rs | 43 ++--- src/expr/src/expr/expr_case.rs | 50 +++--- src/expr/src/expr/expr_coalesce.rs | 21 +-- src/expr/src/expr/expr_concat_ws.rs | 39 +++-- src/expr/src/expr/expr_field.rs | 21 +-- src/expr/src/expr/expr_in.rs | 26 +-- src/expr/src/expr/expr_input_ref.rs | 11 +- src/expr/src/expr/expr_is_null.rs | 40 +++-- src/expr/src/expr/expr_jsonb_access.rs | 13 +- src/expr/src/expr/expr_literal.rs | 17 +- src/expr/src/expr/expr_nested_construct.rs | 39 ++--- src/expr/src/expr/expr_quaternary_bytes.rs | 12 +- src/expr/src/expr/expr_regexp.rs | 9 +- src/expr/src/expr/expr_some_all.rs | 31 ++-- src/expr/src/expr/expr_ternary_bytes.rs | 24 +-- src/expr/src/expr/expr_to_char_const_tmpl.rs | 9 +- .../src/expr/expr_to_timestamp_const_tmpl.rs | 9 +- src/expr/src/expr/expr_udf.rs | 29 ++-- src/expr/src/expr/expr_unary.rs | 44 ++--- src/expr/src/expr/expr_vnode.rs | 17 +- src/expr/src/expr/mod.rs | 58 ++++++- src/expr/src/expr/template.rs | 29 +++- src/expr/src/expr/template_fast.rs | 70 ++++---- .../src/table_function/generate_series.rs | 74 ++++----- src/expr/src/table_function/mod.rs | 20 +-- src/expr/src/table_function/regexp_matches.rs | 5 +- src/expr/src/table_function/unnest.rs | 5 +- src/expr/src/table_function/user_defined.rs | 17 +- src/expr/src/vector_op/agg/aggregator.rs | 5 +- .../vector_op/agg/approx_count_distinct.rs | 16 +- src/expr/src/vector_op/agg/array_agg.rs | 32 ++-- src/expr/src/vector_op/agg/count_star.rs | 5 +- src/expr/src/vector_op/agg/filter.rs | 81 +++++----- src/expr/src/vector_op/agg/general_agg.rs | 70 ++++---- .../src/vector_op/agg/general_distinct_agg.rs | 63 ++++---- src/expr/src/vector_op/agg/string_agg.rs | 28 ++-- src/frontend/src/expr/mod.rs | 7 +- src/frontend/src/scheduler/local.rs | 11 +- src/stream/src/common/infallible_expr.rs | 56 ------- src/stream/src/common/mod.rs | 2 - src/stream/src/executor/aggregation/mod.rs | 4 +- src/stream/src/executor/dynamic_filter.rs | 21 ++- src/stream/src/executor/filter.rs | 85 +++++----- src/stream/src/executor/global_simple_agg.rs | 27 ++-- src/stream/src/executor/hash_agg.rs | 27 ++-- src/stream/src/executor/hash_join.rs | 66 +++++--- src/stream/src/executor/hop_window.rs | 17 +- src/stream/src/executor/local_simple_agg.rs | 31 ++-- src/stream/src/executor/mod.rs | 8 +- src/stream/src/executor/project.rs | 153 ++++++++++-------- src/stream/src/executor/project_set.rs | 11 +- src/stream/src/executor/simple.rs | 95 ----------- src/stream/src/executor/temporal_join.rs | 37 ++--- src/stream/src/executor/watermark_filter.rs | 11 +- src/stream/src/lib.rs | 10 -- 77 files changed, 1099 insertions(+), 1013 deletions(-) delete mode 100644 src/stream/src/common/infallible_expr.rs delete mode 100644 src/stream/src/executor/simple.rs diff --git a/Cargo.lock b/Cargo.lock index 5d8578a62fd43..57aa88f6b9be5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6094,11 +6094,13 @@ dependencies = [ "anyhow", "arrow-array", "arrow-schema", + "async-trait", "chrono", "chrono-tz", "criterion", "dyn-clone", "either", + "futures-util", "itertools", "madsim-tokio", "md5", @@ -6111,6 +6113,7 @@ dependencies = [ "risingwave_pb", "risingwave_udf", "speedate", + "static_assertions", "thiserror", "workspace-hack", ] diff --git a/src/batch/src/executor/filter.rs b/src/batch/src/executor/filter.rs index 70c65d7e74a23..0a4ee63e306f7 100644 --- a/src/batch/src/executor/filter.rs +++ b/src/batch/src/executor/filter.rs @@ -58,7 +58,7 @@ impl FilterExecutor { #[for_await] for data_chunk in self.child.execute() { let data_chunk = data_chunk?.compact(); - let vis_array = self.expr.eval(&data_chunk)?; + let vis_array = self.expr.eval(&data_chunk).await?; if let Bool(vis) = vis_array.as_ref() { // TODO: should we yield masked data chunk directly? diff --git a/src/batch/src/executor/hash_agg.rs b/src/batch/src/executor/hash_agg.rs index 6232e78b470ba..42178ca5c7acb 100644 --- a/src/batch/src/executor/hash_agg.rs +++ b/src/batch/src/executor/hash_agg.rs @@ -213,7 +213,7 @@ impl HashAggExecutor { // TODO: currently not a vectorized implementation for state in states { - state.update_single(&chunk, row_id)? + state.update_single(&chunk, row_id).await? } } } diff --git a/src/batch/src/executor/hop_window.rs b/src/batch/src/executor/hop_window.rs index 3fcbb6ccf45ac..6f84a692f0fd6 100644 --- a/src/batch/src/executor/hop_window.rs +++ b/src/batch/src/executor/hop_window.rs @@ -178,12 +178,12 @@ impl HopWindowExecutor { let len = data_chunk.cardinality(); for i in 0..units { let window_start_col = if output_indices.contains(&window_start_col_index) { - Some(self.window_start_exprs[i].eval(&data_chunk)?) + Some(self.window_start_exprs[i].eval(&data_chunk).await?) } else { None }; let window_end_col = if output_indices.contains(&window_end_col_index) { - Some(self.window_end_exprs[i].eval(&data_chunk)?) + Some(self.window_end_exprs[i].eval(&data_chunk).await?) } else { None }; diff --git a/src/batch/src/executor/join/hash_join.rs b/src/batch/src/executor/join/hash_join.rs index c33f10baa6d76..e5eac41512a4f 100644 --- a/src/batch/src/executor/join/hash_join.rs +++ b/src/batch/src/executor/join/hash_join.rs @@ -366,7 +366,7 @@ impl HashJoinExecutor { #[for_await] for chunk in Self::do_inner_join(params) { let mut chunk = chunk?; - chunk.set_visibility(cond.eval(&chunk)?.as_bool().iter().collect()); + chunk.set_visibility(cond.eval(&chunk).await?.as_bool().iter().collect()); yield chunk } } @@ -473,7 +473,8 @@ impl HashJoinExecutor { spilled, cond.as_ref(), &mut non_equi_state, - )? + ) + .await? } } } else { @@ -494,7 +495,8 @@ impl HashJoinExecutor { spilled, cond.as_ref(), &mut non_equi_state, - )? + ) + .await? } } @@ -593,7 +595,8 @@ impl HashJoinExecutor { spilled, cond.as_ref(), &mut non_equi_state, - )? + ) + .await? } } } @@ -606,7 +609,8 @@ impl HashJoinExecutor { spilled, cond.as_ref(), &mut non_equi_state, - )? + ) + .await? } } @@ -657,7 +661,8 @@ impl HashJoinExecutor { spilled, cond.as_ref(), &mut non_equi_state, - )? + ) + .await? } } } else if let Some(spilled) = Self::append_one_probe_row( @@ -675,7 +680,8 @@ impl HashJoinExecutor { spilled, cond.as_ref(), &mut non_equi_state, - )? + ) + .await? } if let Some(spilled) = remaining_chunk_builder.consume_all() { yield spilled @@ -777,7 +783,8 @@ impl HashJoinExecutor { spilled, cond.as_ref(), &mut non_equi_state, - )? + ) + .await? } } } @@ -787,7 +794,8 @@ impl HashJoinExecutor { spilled, cond.as_ref(), &mut non_equi_state, - )? + ) + .await? } #[for_await] for spilled in Self::handle_remaining_build_rows_for_right_outer_join( @@ -884,7 +892,8 @@ impl HashJoinExecutor { spilled, cond.as_ref(), &mut non_equi_state, - )? + ) + .await? } } } @@ -894,7 +903,8 @@ impl HashJoinExecutor { spilled, cond.as_ref(), &mut non_equi_state, - )? + ) + .await? } #[for_await] for spilled in Self::handle_remaining_build_rows_for_right_semi_anti_join::( @@ -1028,7 +1038,8 @@ impl HashJoinExecutor { cond.as_ref(), &mut left_non_equi_state, &mut right_non_equi_state, - )? + ) + .await? } } } else { @@ -1050,7 +1061,8 @@ impl HashJoinExecutor { cond.as_ref(), &mut left_non_equi_state, &mut right_non_equi_state, - )? + ) + .await? } #[for_await] for spilled in Self::handle_remaining_build_rows_for_right_outer_join( @@ -1199,7 +1211,7 @@ impl HashJoinExecutor { /// /// For more information about how `process_*_join_non_equi_condition` work, see their unit /// tests. - fn process_left_outer_join_non_equi_condition( + async fn process_left_outer_join_non_equi_condition( chunk: DataChunk, cond: &dyn Expression, LeftNonEquiJoinState { @@ -1209,7 +1221,7 @@ impl HashJoinExecutor { found_matched, }: &mut LeftNonEquiJoinState, ) -> Result { - let filter = cond.eval(&chunk)?.as_bool().iter().collect(); + let filter = cond.eval(&chunk).await?.as_bool().iter().collect(); Ok(DataChunkMutator(chunk) .nullify_build_side_for_non_equi_condition(&filter, *probe_column_count) .remove_duplicate_rows_for_left_outer_join( @@ -1223,7 +1235,7 @@ impl HashJoinExecutor { /// Filters for candidate rows which satisfy `non_equi` predicate. /// Removes duplicate rows. - fn process_left_semi_anti_join_non_equi_condition( + async fn process_left_semi_anti_join_non_equi_condition( chunk: DataChunk, cond: &dyn Expression, LeftNonEquiJoinState { @@ -1233,7 +1245,7 @@ impl HashJoinExecutor { .. }: &mut LeftNonEquiJoinState, ) -> Result { - let filter = cond.eval(&chunk)?.as_bool().iter().collect(); + let filter = cond.eval(&chunk).await?.as_bool().iter().collect(); Ok(DataChunkMutator(chunk) .remove_duplicate_rows_for_left_semi_anti_join::( &filter, @@ -1244,7 +1256,7 @@ impl HashJoinExecutor { .take()) } - fn process_right_outer_join_non_equi_condition( + async fn process_right_outer_join_non_equi_condition( chunk: DataChunk, cond: &dyn Expression, RightNonEquiJoinState { @@ -1252,13 +1264,13 @@ impl HashJoinExecutor { build_row_matched, }: &mut RightNonEquiJoinState, ) -> Result { - let filter = cond.eval(&chunk)?.as_bool().iter().collect(); + let filter = cond.eval(&chunk).await?.as_bool().iter().collect(); Ok(DataChunkMutator(chunk) .remove_duplicate_rows_for_right_outer_join(&filter, build_row_ids, build_row_matched) .take()) } - fn process_right_semi_anti_join_non_equi_condition( + async fn process_right_semi_anti_join_non_equi_condition( chunk: DataChunk, cond: &dyn Expression, RightNonEquiJoinState { @@ -1266,7 +1278,7 @@ impl HashJoinExecutor { build_row_matched, }: &mut RightNonEquiJoinState, ) -> Result<()> { - let filter = cond.eval(&chunk)?.as_bool().iter().collect(); + let filter = cond.eval(&chunk).await?.as_bool().iter().collect(); DataChunkMutator(chunk).remove_duplicate_rows_for_right_semi_anti_join( &filter, build_row_ids, @@ -1275,13 +1287,13 @@ impl HashJoinExecutor { Ok(()) } - fn process_full_outer_join_non_equi_condition( + async fn process_full_outer_join_non_equi_condition( chunk: DataChunk, cond: &dyn Expression, left_non_equi_state: &mut LeftNonEquiJoinState, right_non_equi_state: &mut RightNonEquiJoinState, ) -> Result { - let filter = cond.eval(&chunk)?.as_bool().iter().collect(); + let filter = cond.eval(&chunk).await?.as_bool().iter().collect(); Ok(DataChunkMutator(chunk) .nullify_build_side_for_non_equi_condition( &filter, @@ -2609,6 +2621,7 @@ mod tests { cond.as_ref(), &mut state ) + .await .unwrap() .compact(), &expect @@ -2638,6 +2651,7 @@ mod tests { cond.as_ref(), &mut state ) + .await .unwrap() .compact(), &expect @@ -2667,6 +2681,7 @@ mod tests { cond.as_ref(), &mut state ) + .await .unwrap() .compact(), &expect @@ -2706,6 +2721,7 @@ mod tests { cond.as_ref(), &mut state ) + .await .unwrap() .compact(), &expect @@ -2732,6 +2748,7 @@ mod tests { cond.as_ref(), &mut state ) + .await .unwrap() .compact(), &expect @@ -2758,6 +2775,7 @@ mod tests { cond.as_ref(), &mut state ) + .await .unwrap() .compact(), &expect @@ -2799,6 +2817,7 @@ mod tests { cond.as_ref(), &mut state ) + .await .unwrap() .compact(), &expect @@ -2827,6 +2846,7 @@ mod tests { cond.as_ref(), &mut state ) + .await .unwrap() .compact(), &expect @@ -2855,6 +2875,7 @@ mod tests { cond.as_ref(), &mut state ) + .await .unwrap() .compact(), &expect @@ -2918,6 +2939,7 @@ mod tests { cond.as_ref(), &mut state ) + .await .unwrap() .compact(), &expect @@ -2958,6 +2980,7 @@ mod tests { cond.as_ref(), &mut state ) + .await .unwrap() .compact(), &expect @@ -3010,6 +3033,7 @@ mod tests { cond.as_ref(), &mut state ) + .await .is_ok() ); assert_eq!(state.build_row_ids, Vec::new()); @@ -3044,6 +3068,7 @@ mod tests { cond.as_ref(), &mut state ) + .await .is_ok() ); assert_eq!(state.build_row_ids, Vec::new()); @@ -3105,6 +3130,7 @@ mod tests { &mut left_state, &mut right_state, ) + .await .unwrap() .compact(), &expect @@ -3152,6 +3178,7 @@ mod tests { &mut left_state, &mut right_state, ) + .await .unwrap() .compact(), &expect diff --git a/src/batch/src/executor/join/nested_loop_join.rs b/src/batch/src/executor/join/nested_loop_join.rs index c0a46d6865d9b..5c229fa77ee5c 100644 --- a/src/batch/src/executor/join/nested_loop_join.rs +++ b/src/batch/src/executor/join/nested_loop_join.rs @@ -121,7 +121,7 @@ impl NestedLoopJoinExecutor { impl NestedLoopJoinExecutor { /// Create a chunk by concatenating a row with a chunk and set its visibility according to the /// evaluation result of the expression. - fn concatenate_and_eval( + async fn concatenate_and_eval( expr: &dyn Expression, left_row_types: &[DataType], left_row: RowRef<'_>, @@ -129,7 +129,7 @@ impl NestedLoopJoinExecutor { ) -> Result { let left_chunk = convert_row_to_chunk(&left_row, right_chunk.capacity(), left_row_types)?; let mut chunk = concatenate(&left_chunk, right_chunk)?; - chunk.set_visibility(expr.eval(&chunk)?.as_bool().iter().collect()); + chunk.set_visibility(expr.eval(&chunk).await?.as_bool().iter().collect()); Ok(chunk) } } @@ -232,7 +232,8 @@ impl NestedLoopJoinExecutor { &left_data_types, left_row, &right_chunk, - )?; + ) + .await?; // 4. Yield the concatenated chunk. if chunk.cardinality() > 0 { for spilled in chunk_builder.append_chunk(chunk) { @@ -264,7 +265,8 @@ impl NestedLoopJoinExecutor { &left_data_types, left_row, &right_chunk, - )?; + ) + .await?; if chunk.cardinality() > 0 { matched.set(left_row_idx, true); for spilled in chunk_builder.append_chunk(chunk) { @@ -308,7 +310,8 @@ impl NestedLoopJoinExecutor { &left_data_types, left_row, &right_chunk, - )?; + ) + .await?; if chunk.cardinality() > 0 { matched.set(left_row_idx, true) } @@ -345,7 +348,8 @@ impl NestedLoopJoinExecutor { &left_data_types, left_row, &right_chunk, - )?; + ) + .await?; if chunk.cardinality() > 0 { // chunk.visibility() must be Some(_) matched = &matched | chunk.visibility().unwrap(); @@ -385,7 +389,8 @@ impl NestedLoopJoinExecutor { &left_data_types, left_row, &right_chunk, - )?; + ) + .await?; if chunk.cardinality() > 0 { // chunk.visibility() must be Some(_) matched = &matched | chunk.visibility().unwrap(); @@ -424,7 +429,8 @@ impl NestedLoopJoinExecutor { &left_data_types, left_row, &right_chunk, - )?; + ) + .await?; if chunk.cardinality() > 0 { left_matched.set(left_row_idx, true); right_matched = &right_matched | chunk.visibility().unwrap(); diff --git a/src/batch/src/executor/project.rs b/src/batch/src/executor/project.rs index ed8f951e3bbb9..024eac9455d92 100644 --- a/src/batch/src/executor/project.rs +++ b/src/batch/src/executor/project.rs @@ -54,11 +54,11 @@ impl ProjectExecutor { for data_chunk in self.child.execute() { let data_chunk = data_chunk?; // let data_chunk = data_chunk.compact(); - let arrays: Vec = self - .expr - .iter_mut() - .map(|expr| expr.eval(&data_chunk).map(Column::new)) - .try_collect()?; + let mut arrays = Vec::with_capacity(self.expr.len()); + for expr in &mut self.expr { + let column = Column::new(expr.eval(&data_chunk).await?); + arrays.push(column); + } let (_, vis) = data_chunk.into_parts(); let ret = DataChunk::new(arrays, vis); yield ret diff --git a/src/batch/src/executor/project_set.rs b/src/batch/src/executor/project_set.rs index 249728c1e15d6..d41b904d105fe 100644 --- a/src/batch/src/executor/project_set.rs +++ b/src/batch/src/executor/project_set.rs @@ -73,11 +73,12 @@ impl ProjectSetExecutor { .map(|ty| ty.create_array_builder(self.chunk_size)) .collect_vec(); - let results: Vec<_> = self - .select_list - .iter() - .map(|select_item| select_item.eval(&data_chunk)) - .try_collect()?; + let mut results = Vec::with_capacity(self.select_list.len()); + + for select_item in &self.select_list { + let result = select_item.eval(&data_chunk).await?; + results.push(result); + } let mut lens = results .iter() diff --git a/src/batch/src/executor/sort_agg.rs b/src/batch/src/executor/sort_agg.rs index e8fcb784e7b49..b967ffc82a508 100644 --- a/src/batch/src/executor/sort_agg.rs +++ b/src/batch/src/executor/sort_agg.rs @@ -123,11 +123,11 @@ impl SortAggExecutor { if no_input_data && child_chunk.cardinality() > 0 { no_input_data = false; } - let group_columns: Vec<_> = self - .group_key - .iter_mut() - .map(|expr| expr.eval(&child_chunk)) - .try_collect()?; + let mut group_columns = Vec::with_capacity(self.group_key.len()); + for expr in &mut self.group_key { + let result = expr.eval(&child_chunk).await?; + group_columns.push(result); + } let groups: Vec<_> = self .sorted_groupers @@ -153,7 +153,8 @@ impl SortAggExecutor { &child_chunk, start_row_idx, end_row_idx, - )?; + ) + .await?; } Self::output_sorted_groupers(&mut self.sorted_groupers, &mut group_builders)?; Self::output_agg_states(&mut self.agg_states, &mut agg_builders)?; @@ -186,12 +187,8 @@ impl SortAggExecutor { start_row_idx, row_cnt, )?; - Self::update_agg_states( - &mut self.agg_states, - &child_chunk, - start_row_idx, - row_cnt, - )?; + Self::update_agg_states(&mut self.agg_states, &child_chunk, start_row_idx, row_cnt) + .await?; } } @@ -228,16 +225,18 @@ impl SortAggExecutor { .map_err(Into::into) } - fn update_agg_states( + async fn update_agg_states( agg_states: &mut [BoxedAggState], child_chunk: &DataChunk, start_row_idx: usize, end_row_idx: usize, ) -> Result<()> { - agg_states - .iter_mut() - .try_for_each(|state| state.update_multi(child_chunk, start_row_idx, end_row_idx)) - .map_err(Into::into) + for state in agg_states.iter_mut() { + state + .update_multi(child_chunk, start_row_idx, end_row_idx) + .await?; + } + Ok(()) } fn output_sorted_groupers( diff --git a/src/batch/src/executor/table_function.rs b/src/batch/src/executor/table_function.rs index 82ca4521b28cc..3ce3374253957 100644 --- a/src/batch/src/executor/table_function.rs +++ b/src/batch/src/executor/table_function.rs @@ -54,7 +54,7 @@ impl TableFunctionExecutor { .return_type() .create_array_builder(self.chunk_size); let mut len = 0; - for array in self.table_function.eval(&dummy_chunk)? { + for array in self.table_function.eval(&dummy_chunk).await? { len += array.len(); builder.append_array(&array); } diff --git a/src/batch/src/executor/test_utils.rs b/src/batch/src/executor/test_utils.rs index 2870e700fe6ae..cad53e7d70fab 100644 --- a/src/batch/src/executor/test_utils.rs +++ b/src/batch/src/executor/test_utils.rs @@ -113,7 +113,7 @@ pub fn gen_projected_data( let chunk = DataChunk::new(vec![array_builder.finish().into()], batch_size); - let array = expr.eval(&chunk).unwrap(); + let array = futures::executor::block_on(expr.eval(&chunk)).unwrap(); let chunk = DataChunk::new(vec![Column::new(array)], batch_size); ret.push(chunk); } diff --git a/src/batch/src/executor/update.rs b/src/batch/src/executor/update.rs index f820d126fee55..14f3c1eac25ca 100644 --- a/src/batch/src/executor/update.rs +++ b/src/batch/src/executor/update.rs @@ -132,11 +132,11 @@ impl UpdateExecutor { let data_chunk = data_chunk?; let updated_data_chunk = { - let columns: Vec<_> = self - .exprs - .iter_mut() - .map(|expr| expr.eval(&data_chunk).map(Column::new)) - .try_collect()?; + let mut columns = Vec::with_capacity(self.exprs.len()); + for expr in &mut self.exprs { + let column = Column::new(expr.eval(&data_chunk).await?); + columns.push(column); + } DataChunk::new(columns, data_chunk.vis().clone()) }; diff --git a/src/batch/src/executor/values.rs b/src/batch/src/executor/values.rs index c6cc8bf83c611..106af69b57a62 100644 --- a/src/batch/src/executor/values.rs +++ b/src/batch/src/executor/values.rs @@ -83,7 +83,7 @@ impl ValuesExecutor { let mut array_builders = self.schema.create_array_builders(chunk_size); for row in self.rows.by_ref().take(chunk_size) { for (expr, builder) in row.into_iter().zip_eq_fast(&mut array_builders) { - let out = expr.eval(&one_row_chunk)?; + let out = expr.eval(&one_row_chunk).await?; builder.append_array(&out); } } diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 0d8c904d45eb9..55f324ba5ee6c 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -444,6 +444,8 @@ pub fn option_as_scalar_ref(scalar: &Option) -> Option: Copy + std::fmt::Debug + + Send + + Sync + 'a + TryFrom, Error = ArrayError> + Into> diff --git a/src/expr/Cargo.toml b/src/expr/Cargo.toml index 3f80b443c1233..c17dde73e7100 100644 --- a/src/expr/Cargo.toml +++ b/src/expr/Cargo.toml @@ -19,10 +19,12 @@ aho-corasick = "0.7" anyhow = "1" arrow-array = "34" arrow-schema = "34" +async-trait = "0.1" chrono = { version = "0.4", default-features = false, features = ["clock", "std"] } chrono-tz = { version = "0.7", features = ["case-insensitive"] } dyn-clone = "1" either = "1" +futures-util = "0.3" itertools = "0.10" md5 = "0.7.0" num-traits = "0.2" @@ -34,6 +36,7 @@ risingwave_common = { path = "../common" } risingwave_pb = { path = "../prost" } risingwave_udf = { path = "../udf" } speedate = "0.7.0" +static_assertions = "1" thiserror = "1" tokio = { version = "0.2", package = "madsim-tokio", features = ["rt", "rt-multi-thread", "sync", "macros", "time", "signal"] } diff --git a/src/expr/benches/expr.rs b/src/expr/benches/expr.rs index 2f43a37ecf681..eacbb54436616 100644 --- a/src/expr/benches/expr.rs +++ b/src/expr/benches/expr.rs @@ -19,6 +19,9 @@ // `zip_eq` is a source of poor performance. #![allow(clippy::disallowed_methods)] +use std::cell::RefCell; + +use criterion::async_executor::FuturesExecutor; use criterion::{criterion_group, criterion_main, Criterion}; use risingwave_common::array::*; use risingwave_common::types::test_utils::IntervalUnitTestExt; @@ -195,11 +198,15 @@ fn bench_expr(c: &mut Criterion) { c.bench_function("inputref", |bencher| { let inputref = inputrefs[0].clone().boxed(); - bencher.iter(|| inputref.eval(&input).unwrap()) + bencher + .to_async(FuturesExecutor) + .iter(|| inputref.eval(&input)) }); c.bench_function("constant", |bencher| { let constant = LiteralExpression::new(DataType::Int32, Some(1_i32.into())); - bencher.iter(|| constant.eval(&input).unwrap()) + bencher + .to_async(FuturesExecutor) + .iter(|| constant.eval(&input)) }); let sigs = func_sigs(); @@ -251,7 +258,7 @@ fn bench_expr(c: &mut Criterion) { } }; c.bench_function(&sig.to_string_no_return(), |bencher| { - bencher.iter(|| expr.eval(&input).unwrap()) + bencher.to_async(FuturesExecutor).iter(|| expr.eval(&input)) }); } @@ -260,7 +267,7 @@ fn bench_expr(c: &mut Criterion) { println!("todo: {}", sig.to_string_no_return()); continue; } - let mut agg = match create_agg_state_unary( + let agg = match create_agg_state_unary( sig.inputs_type[0].into(), inputref_for_type(sig.inputs_type[0].into()).index(), sig.func, @@ -273,8 +280,16 @@ fn bench_expr(c: &mut Criterion) { continue; } }; + // to workaround the lifetime issue + let agg = RefCell::new(agg); c.bench_function(&sig.to_string_no_return(), |bencher| { - bencher.iter(|| agg.update_multi(&input, 0, CHUNK_SIZE).unwrap()) + #[allow(clippy::await_holding_refcell_ref)] + bencher.to_async(FuturesExecutor).iter(|| async { + agg.borrow_mut() + .update_multi(&input, 0, CHUNK_SIZE) + .await + .unwrap() + }) }); } @@ -310,7 +325,7 @@ fn bench_expr(c: &mut Criterion) { } }; c.bench_function(&sig.to_string_no_return(), |bencher| { - bencher.iter(|| expr.eval(&input).unwrap()) + bencher.to_async(FuturesExecutor).iter(|| expr.eval(&input)) }); } diff --git a/src/expr/src/expr/build_expr_from_prost.rs b/src/expr/src/expr/build_expr_from_prost.rs index 05507af5342bf..7db4ea1e46a89 100644 --- a/src/expr/src/expr/build_expr_from_prost.rs +++ b/src/expr/src/expr/build_expr_from_prost.rs @@ -443,8 +443,8 @@ mod tests { use super::*; - #[test] - fn test_array_access_expr() { + #[tokio::test] + async fn test_array_access_expr() { let values = FunctionCall { children: vec![ ExprNode { @@ -506,7 +506,7 @@ mod tests { let expr = build_nullable_binary_expr_prost(&access); assert!(expr.is_ok()); - let res = expr.unwrap().eval(&DataChunk::new_dummy(1)).unwrap(); + let res = expr.unwrap().eval(&DataChunk::new_dummy(1)).await.unwrap(); assert_eq!(*res, ArrayImpl::Utf8(Utf8Array::from_iter(["foo"]))); } diff --git a/src/expr/src/expr/expr_array_concat.rs b/src/expr/src/expr/expr_array_concat.rs index 7e0723cf82e33..f904afc7ebb4e 100644 --- a/src/expr/src/expr/expr_array_concat.rs +++ b/src/expr/src/expr/expr_array_concat.rs @@ -313,14 +313,15 @@ impl ArrayConcatExpression { } } +#[async_trait::async_trait] impl Expression for ArrayConcatExpression { fn return_type(&self) -> DataType { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> Result { - let left_array = self.left.eval_checked(input)?; - let right_array = self.right.eval_checked(input)?; + async fn eval(&self, input: &DataChunk) -> Result { + let left_array = self.left.eval_checked(input).await?; + let right_array = self.right.eval_checked(input).await?; let mut builder = self .return_type .create_array_builder(left_array.len() + right_array.len()); @@ -338,9 +339,9 @@ impl Expression for ArrayConcatExpression { Ok(Arc::new(builder.finish())) } - fn eval_row(&self, input: &OwnedRow) -> Result { - let left_data = self.left.eval_row(input)?; - let right_data = self.right.eval_row(input)?; + async fn eval_row(&self, input: &OwnedRow) -> Result { + let left_data = self.left.eval_row(input).await?; + let right_data = self.right.eval_row(input).await?; Ok(self.evaluate(left_data.to_datum_ref(), right_data.to_datum_ref())) } } @@ -555,8 +556,8 @@ mod tests { .boxed() } - #[test] - fn test_array_concat_array_of_primitives() { + #[tokio::test] + async fn test_array_concat_array_of_primitives() { let left = make_i64_array_expr(vec![42]); let right = make_i64_array_expr(vec![43, 44]); let expr = ArrayConcatExpression::new( @@ -583,6 +584,7 @@ mod tests { ]; let actual = expr .eval(&chunk) + .await .unwrap() .iter() .map(|v| v.map(|s| s.into_scalar_impl())) diff --git a/src/expr/src/expr/expr_array_distinct.rs b/src/expr/src/expr/expr_array_distinct.rs index 120ad52dd8e5f..010f184023e87 100644 --- a/src/expr/src/expr/expr_array_distinct.rs +++ b/src/expr/src/expr/expr_array_distinct.rs @@ -80,13 +80,14 @@ impl<'a> TryFrom<&'a ExprNode> for ArrayDistinctExpression { } } +#[async_trait::async_trait] impl Expression for ArrayDistinctExpression { fn return_type(&self) -> DataType { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> Result { - let array = self.array.eval_checked(input)?; + async fn eval(&self, input: &DataChunk) -> Result { + let array = self.array.eval_checked(input).await?; let mut builder = self.return_type.create_array_builder(array.len()); for (vis, arr) in input.vis().iter().zip_eq_fast(array.iter()) { if !vis { @@ -98,8 +99,8 @@ impl Expression for ArrayDistinctExpression { Ok(Arc::new(builder.finish())) } - fn eval_row(&self, input: &OwnedRow) -> Result { - let array_data = self.array.eval_row(input)?; + async fn eval_row(&self, input: &OwnedRow) -> Result { + let array_data = self.array.eval_row(input).await?; Ok(self.evaluate(array_data.to_datum_ref())) } } @@ -226,8 +227,8 @@ mod tests { .boxed() } - #[test] - fn test_array_distinct_array_of_primitives() { + #[tokio::test] + async fn test_array_distinct_array_of_primitives() { let array = make_i64_array_expr(vec![42, 43, 42]); let expr = ArrayDistinctExpression { return_type: DataType::List { @@ -250,6 +251,7 @@ mod tests { ]; let actual = expr .eval(&chunk) + .await .unwrap() .iter() .map(|v| v.map(|s| s.into_scalar_impl())) diff --git a/src/expr/src/expr/expr_array_to_string.rs b/src/expr/src/expr/expr_array_to_string.rs index d5064c64fd585..0e991b0b44a97 100644 --- a/src/expr/src/expr/expr_array_to_string.rs +++ b/src/expr/src/expr/expr_array_to_string.rs @@ -121,20 +121,21 @@ impl<'a> TryFrom<&'a ExprNode> for ArrayToStringExpression { } } +#[async_trait::async_trait] impl Expression for ArrayToStringExpression { fn return_type(&self) -> DataType { DataType::Varchar } - fn eval(&self, input: &DataChunk) -> Result { - let list_array = self.array.eval_checked(input)?; + async fn eval(&self, input: &DataChunk) -> Result { + let list_array = self.array.eval_checked(input).await?; let list_array = list_array.as_list(); - let delim_array = self.delimiter.eval_checked(input)?; + let delim_array = self.delimiter.eval_checked(input).await?; let delim_array = delim_array.as_utf8(); let null_string_array = if let Some(expr) = &self.null_string { - let null_string_array = expr.eval_checked(input)?; + let null_string_array = expr.eval_checked(input).await?; Some(null_string_array) } else { None @@ -171,13 +172,13 @@ impl Expression for ArrayToStringExpression { Ok(Arc::new(output.finish().into())) } - fn eval_row(&self, input: &OwnedRow) -> Result { - let array = self.array.eval_row(input)?; - let delimiter = self.delimiter.eval_row(input)?; + async fn eval_row(&self, input: &OwnedRow) -> Result { + let array = self.array.eval_row(input).await?; + let delimiter = self.delimiter.eval_row(input).await?; let result = if let Some(array) = array && let Some(delimiter) = delimiter { let null_string = if let Some(e) = &self.null_string { - e.eval_row(input)? + e.eval_row(input).await? } else { None }; diff --git a/src/expr/src/expr/expr_binary_bytes.rs b/src/expr/src/expr/expr_binary_bytes.rs index 3ab6b822b35de..01a85163be2d3 100644 --- a/src/expr/src/expr/expr_binary_bytes.rs +++ b/src/expr/src/expr/expr_binary_bytes.rs @@ -131,16 +131,16 @@ mod tests { ) } - fn test_evals_dummy(expr: &BoxedExpression, expected: Datum) { - let res = expr.eval(&DataChunk::new_dummy(1)).unwrap(); + async fn test_evals_dummy(expr: &BoxedExpression, expected: Datum) { + let res = expr.eval(&DataChunk::new_dummy(1)).await.unwrap(); assert_eq!(res.to_datum(), expected); - let res = expr.eval_row(&OwnedRow::new(vec![])).unwrap(); + let res = expr.eval_row(&OwnedRow::new(vec![])).await.unwrap(); assert_eq!(res, expected); } - #[test] - fn test_substr() { + #[tokio::test] + async fn test_substr() { let text = "quick brown"; let start_pos = 3; let for_pos = 4; @@ -155,14 +155,15 @@ mod tests { Some(ScalarImpl::from(String::from( &text[start_pos as usize - 1..], ))), - ); + ) + .await; let substr_start_i32_none = create_str_i32_binary_expr( new_substr_start, Some(ScalarImpl::from(String::from(text))), None, ); - test_evals_dummy(&substr_start_i32_none, None); + test_evals_dummy(&substr_start_i32_none, None).await; let substr_for_normal = create_str_i32_binary_expr( new_substr_for, @@ -172,10 +173,11 @@ mod tests { test_evals_dummy( &substr_for_normal, Some(ScalarImpl::from(String::from(&text[..for_pos as usize]))), - ); + ) + .await; let substr_for_str_none = create_str_i32_binary_expr(new_substr_for, None, Some(ScalarImpl::Int32(for_pos))); - test_evals_dummy(&substr_for_str_none, None); + test_evals_dummy(&substr_for_str_none, None).await; } } diff --git a/src/expr/src/expr/expr_binary_nonnull.rs b/src/expr/src/expr/expr_binary_nonnull.rs index 4271136a288d1..613e5c8fe8f3d 100644 --- a/src/expr/src/expr/expr_binary_nonnull.rs +++ b/src/expr/src/expr/expr_binary_nonnull.rs @@ -853,45 +853,47 @@ mod tests { use crate::expr::test_utils::make_expression; use crate::vector_op::arithmetic_op::{date_interval_add, date_interval_sub}; - #[test] - fn test_binary() { - test_binary_i32::(|x, y| x + y, Type::Add); - test_binary_i32::(|x, y| x - y, Type::Subtract); - test_binary_i32::(|x, y| x * y, Type::Multiply); - test_binary_i32::(|x, y| x / y, Type::Divide); - test_binary_i32::(|x, y| x == y, Type::Equal); - test_binary_i32::(|x, y| x != y, Type::NotEqual); - test_binary_i32::(|x, y| x > y, Type::GreaterThan); - test_binary_i32::(|x, y| x >= y, Type::GreaterThanOrEqual); - test_binary_i32::(|x, y| x < y, Type::LessThan); - test_binary_i32::(|x, y| x <= y, Type::LessThanOrEqual); - test_binary_decimal::(|x, y| x + y, Type::Add); - test_binary_decimal::(|x, y| x - y, Type::Subtract); - test_binary_decimal::(|x, y| x * y, Type::Multiply); - test_binary_decimal::(|x, y| x / y, Type::Divide); - test_binary_decimal::(|x, y| x == y, Type::Equal); - test_binary_decimal::(|x, y| x != y, Type::NotEqual); - test_binary_decimal::(|x, y| x > y, Type::GreaterThan); - test_binary_decimal::(|x, y| x >= y, Type::GreaterThanOrEqual); - test_binary_decimal::(|x, y| x < y, Type::LessThan); - test_binary_decimal::(|x, y| x <= y, Type::LessThanOrEqual); + #[tokio::test] + async fn test_binary() { + test_binary_i32::(|x, y| x + y, Type::Add).await; + test_binary_i32::(|x, y| x - y, Type::Subtract).await; + test_binary_i32::(|x, y| x * y, Type::Multiply).await; + test_binary_i32::(|x, y| x / y, Type::Divide).await; + test_binary_i32::(|x, y| x == y, Type::Equal).await; + test_binary_i32::(|x, y| x != y, Type::NotEqual).await; + test_binary_i32::(|x, y| x > y, Type::GreaterThan).await; + test_binary_i32::(|x, y| x >= y, Type::GreaterThanOrEqual).await; + test_binary_i32::(|x, y| x < y, Type::LessThan).await; + test_binary_i32::(|x, y| x <= y, Type::LessThanOrEqual).await; + test_binary_decimal::(|x, y| x + y, Type::Add).await; + test_binary_decimal::(|x, y| x - y, Type::Subtract).await; + test_binary_decimal::(|x, y| x * y, Type::Multiply).await; + test_binary_decimal::(|x, y| x / y, Type::Divide).await; + test_binary_decimal::(|x, y| x == y, Type::Equal).await; + test_binary_decimal::(|x, y| x != y, Type::NotEqual).await; + test_binary_decimal::(|x, y| x > y, Type::GreaterThan).await; + test_binary_decimal::(|x, y| x >= y, Type::GreaterThanOrEqual).await; + test_binary_decimal::(|x, y| x < y, Type::LessThan).await; + test_binary_decimal::(|x, y| x <= y, Type::LessThanOrEqual).await; test_binary_interval::( |x, y| { date_interval_add::(x, y) .unwrap() }, Type::Add, - ); + ) + .await; test_binary_interval::( |x, y| { date_interval_sub::(x, y) .unwrap() }, Type::Subtract, - ); + ) + .await; } - fn test_binary_i32(f: F, kind: Type) + async fn test_binary_i32(f: F, kind: Type) where A: Array, for<'a> &'a A: std::convert::From<&'a ArrayImpl>, @@ -926,7 +928,7 @@ mod tests { let data_chunk = DataChunk::new(vec![col1, col2], 100); let expr = make_expression(kind, &[TypeName::Int32, TypeName::Int32], &[0, 1]); let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).unwrap(); + let res = vec_executor.eval(&data_chunk).await.unwrap(); let arr: &A = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -938,13 +940,13 @@ mod tests { lhs[i].map(|int| int.to_scalar_value()), rhs[i].map(|int| int.to_scalar_value()), ]); - let result = vec_executor.eval_row(&row).unwrap(); + let result = vec_executor.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); assert_eq!(result, expected); } } - fn test_binary_interval(f: F, kind: Type) + async fn test_binary_interval(f: F, kind: Type) where A: Array, for<'a> &'a A: std::convert::From<&'a ArrayImpl>, @@ -974,7 +976,7 @@ mod tests { let data_chunk = DataChunk::new(vec![col1, col2], 100); let expr = make_expression(kind, &[TypeName::Date, TypeName::Interval], &[0, 1]); let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).unwrap(); + let res = vec_executor.eval(&data_chunk).await.unwrap(); let arr: &A = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -986,13 +988,13 @@ mod tests { lhs[i].map(|date| date.to_scalar_value()), rhs[i].map(|date| date.to_scalar_value()), ]); - let result = vec_executor.eval_row(&row).unwrap(); + let result = vec_executor.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); assert_eq!(result, expected); } } - fn test_binary_decimal(f: F, kind: Type) + async fn test_binary_decimal(f: F, kind: Type) where A: Array, for<'a> &'a A: std::convert::From<&'a ArrayImpl>, @@ -1027,7 +1029,7 @@ mod tests { let data_chunk = DataChunk::new(vec![col1, col2], 100); let expr = make_expression(kind, &[TypeName::Decimal, TypeName::Decimal], &[0, 1]); let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).unwrap(); + let res = vec_executor.eval(&data_chunk).await.unwrap(); let arr: &A = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -1039,7 +1041,7 @@ mod tests { lhs[i].map(|dec| dec.to_scalar_value()), rhs[i].map(|dec| dec.to_scalar_value()), ]); - let result = vec_executor.eval_row(&row).unwrap(); + let result = vec_executor.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); assert_eq!(result, expected); } diff --git a/src/expr/src/expr/expr_binary_nullable.rs b/src/expr/src/expr/expr_binary_nullable.rs index 9a17cbf4d7736..2d26232fb1291 100644 --- a/src/expr/src/expr/expr_binary_nullable.rs +++ b/src/expr/src/expr/expr_binary_nullable.rs @@ -78,13 +78,14 @@ impl std::fmt::Debug for BinaryShortCircuitExpression { } } +#[async_trait::async_trait] impl Expression for BinaryShortCircuitExpression { fn return_type(&self) -> DataType { DataType::Boolean } - fn eval(&self, input: &DataChunk) -> Result { - let left = self.expr_ia1.eval_checked(input)?; + async fn eval(&self, input: &DataChunk) -> Result { + let left = self.expr_ia1.eval_checked(input).await?; let left = left.as_bool(); let res_vis: Vis = match self.expr_type { @@ -100,7 +101,7 @@ impl Expression for BinaryShortCircuitExpression { let mut input1 = input.clone(); input1.set_vis(new_vis); - let right = self.expr_ia2.eval_checked(&input1)?; + let right = self.expr_ia2.eval_checked(&input1).await?; let right = right.as_bool(); assert_eq!(left.len(), right.len()); @@ -128,14 +129,14 @@ impl Expression for BinaryShortCircuitExpression { Ok(Arc::new(c.into())) } - fn eval_row(&self, input: &OwnedRow) -> Result { - let ret_ia1 = self.expr_ia1.eval_row(input)?.map(|x| x.into_bool()); + async fn eval_row(&self, input: &OwnedRow) -> Result { + let ret_ia1 = self.expr_ia1.eval_row(input).await?.map(|x| x.into_bool()); match self.expr_type { Type::Or if ret_ia1 == Some(true) => return Ok(Some(true.to_scalar_value())), Type::And if ret_ia1 == Some(false) => return Ok(Some(false.to_scalar_value())), _ => {} } - let ret_ia2 = self.expr_ia2.eval_row(input)?.map(|x| x.into_bool()); + let ret_ia2 = self.expr_ia2.eval_row(input).await?.map(|x| x.into_bool()); match self.expr_type { Type::Or => Ok(or(ret_ia1, ret_ia2)?.map(|x| x.to_scalar_value())), Type::And => Ok(and(ret_ia1, ret_ia2)?.map(|x| x.to_scalar_value())), @@ -312,8 +313,8 @@ mod tests { use crate::expr::build_from_prost; use crate::expr::test_utils::make_expression; - #[test] - fn test_and() { + #[tokio::test] + async fn test_and() { let lhs = vec![ Some(true), Some(true), @@ -356,14 +357,14 @@ mod tests { lhs[i].map(|x| x.to_scalar_value()), rhs[i].map(|x| x.to_scalar_value()), ]); - let res = vec_executor.eval_row(&row).unwrap(); + let res = vec_executor.eval_row(&row).await.unwrap(); let expected = target[i].map(|x| x.to_scalar_value()); assert_eq!(res, expected); } } - #[test] - fn test_or() { + #[tokio::test] + async fn test_or() { let lhs = vec![ Some(true), Some(true), @@ -406,14 +407,14 @@ mod tests { lhs[i].map(|x| x.to_scalar_value()), rhs[i].map(|x| x.to_scalar_value()), ]); - let res = vec_executor.eval_row(&row).unwrap(); + let res = vec_executor.eval_row(&row).await.unwrap(); let expected = target[i].map(|x| x.to_scalar_value()); assert_eq!(res, expected); } } - #[test] - fn test_is_distinct_from() { + #[tokio::test] + async fn test_is_distinct_from() { let lhs = vec![None, None, Some(1), Some(2), Some(3)]; let rhs = vec![None, Some(1), None, Some(2), Some(4)]; let target = vec![Some(false), Some(true), Some(true), Some(false), Some(true)]; @@ -430,14 +431,14 @@ mod tests { lhs[i].map(|x| x.to_scalar_value()), rhs[i].map(|x| x.to_scalar_value()), ]); - let res = vec_executor.eval_row(&row).unwrap(); + let res = vec_executor.eval_row(&row).await.unwrap(); let expected = target[i].map(|x| x.to_scalar_value()); assert_eq!(res, expected); } } - #[test] - fn test_is_not_distinct_from() { + #[tokio::test] + async fn test_is_not_distinct_from() { let lhs = vec![None, None, Some(1), Some(2), Some(3)]; let rhs = vec![None, Some(1), None, Some(2), Some(4)]; let target = vec![ @@ -460,14 +461,14 @@ mod tests { lhs[i].map(|x| x.to_scalar_value()), rhs[i].map(|x| x.to_scalar_value()), ]); - let res = vec_executor.eval_row(&row).unwrap(); + let res = vec_executor.eval_row(&row).await.unwrap(); let expected = target[i].map(|x| x.to_scalar_value()); assert_eq!(res, expected); } } - #[test] - fn test_format_type() { + #[tokio::test] + async fn test_format_type() { let l = vec![Some(16), Some(21), Some(9527), None]; let r = vec![Some(0), None, Some(0), Some(0)]; let target: Vec> = vec![ @@ -488,7 +489,7 @@ mod tests { l[i].map(|x| x.to_scalar_value()), r[i].map(|x| x.to_scalar_value()), ]); - let res = vec_executor.eval_row(&row).unwrap(); + let res = vec_executor.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().map(|x| x.into()); assert_eq!(res, expected); } diff --git a/src/expr/src/expr/expr_case.rs b/src/expr/src/expr/expr_case.rs index e1d6dfc2814b7..8754db3d71f17 100644 --- a/src/expr/src/expr/expr_case.rs +++ b/src/expr/src/expr/expr_case.rs @@ -57,22 +57,28 @@ impl CaseExpression { } } +#[async_trait::async_trait] impl Expression for CaseExpression { fn return_type(&self) -> DataType { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> Result { + async fn eval(&self, input: &DataChunk) -> Result { let mut input = input.clone(); let input_len = input.capacity(); let mut selection = vec![None; input_len]; let when_len = self.when_clauses.len(); let mut result_array = Vec::with_capacity(when_len + 1); for (when_idx, WhenClause { when, then }) in self.when_clauses.iter().enumerate() { - let calc_then_vis: Vis = when.eval_checked(&input)?.as_bool().to_bitmap().into(); + let calc_then_vis: Vis = when + .eval_checked(&input) + .await? + .as_bool() + .to_bitmap() + .into(); let input_vis = input.vis().clone(); input.set_vis(calc_then_vis.clone()); - let then_res = then.eval_checked(&input)?; + let then_res = then.eval_checked(&input).await?; calc_then_vis .iter_ones() .for_each(|pos| selection[pos] = Some(when_idx)); @@ -80,7 +86,7 @@ impl Expression for CaseExpression { result_array.push(then_res); } if let Some(ref else_expr) = self.else_clause { - let else_res = else_expr.eval_checked(&input)?; + let else_res = else_expr.eval_checked(&input).await?; input .vis() .iter_ones() @@ -98,14 +104,14 @@ impl Expression for CaseExpression { Ok(Arc::new(builder.finish())) } - fn eval_row(&self, input: &OwnedRow) -> Result { + async fn eval_row(&self, input: &OwnedRow) -> Result { for WhenClause { when, then } in &self.when_clauses { - if when.eval_row(input)?.map_or(false, |w| w.into_bool()) { - return then.eval_row(input); + if when.eval_row(input).await?.map_or(false, |w| w.into_bool()) { + return then.eval_row(input).await; } } if let Some(ref else_expr) = self.else_clause { - else_expr.eval_row(input) + else_expr.eval_row(input).await } else { Ok(None) } @@ -199,17 +205,17 @@ mod tests { assert!(CaseExpression::try_from(&p).is_ok()); } - fn test_eval_row(expr: CaseExpression, row_inputs: Vec, expected: Vec>) { + async fn test_eval_row(expr: CaseExpression, row_inputs: Vec, expected: Vec>) { for (i, row_input) in row_inputs.iter().enumerate() { let row = OwnedRow::new(vec![Some(row_input.to_scalar_value())]); - let datum = expr.eval_row(&row).unwrap(); + let datum = expr.eval_row(&row).await.unwrap(); let expected = expected[i].map(|f| f.into()); assert_eq!(datum, expected) } } - #[test] - fn test_eval_searched_case() { + #[tokio::test] + async fn test_eval_searched_case() { let ret_type = DataType::Float32; // when x <= 2 then 3.1 let when_clauses = vec![WhenClause::new( @@ -239,7 +245,7 @@ mod tests { 4 5", ); - let output = searched_case_expr.eval(&input).unwrap(); + let output = searched_case_expr.eval(&input).await.unwrap(); assert_eq!(output.datum_at(0), Some(3.1f32.into())); assert_eq!(output.datum_at(1), Some(3.1f32.into())); assert_eq!(output.datum_at(2), Some(4.1f32.into())); @@ -247,8 +253,8 @@ mod tests { assert_eq!(output.datum_at(4), Some(4.1f32.into())); } - #[test] - fn test_eval_without_else() { + #[tokio::test] + async fn test_eval_without_else() { let ret_type = DataType::Float32; // when x <= 3 then 3.1 let when_clauses = vec![WhenClause::new( @@ -272,15 +278,15 @@ mod tests { 3 4", ); - let output = searched_case_expr.eval(&input).unwrap(); + let output = searched_case_expr.eval(&input).await.unwrap(); assert_eq!(output.datum_at(0), Some(3.1f32.into())); assert_eq!(output.datum_at(1), None); assert_eq!(output.datum_at(2), Some(3.1f32.into())); assert_eq!(output.datum_at(3), None); } - #[test] - fn test_eval_row_searched_case() { + #[tokio::test] + async fn test_eval_row_searched_case() { let ret_type = DataType::Float32; // when x <= 2 then 3.1 let when_clauses = vec![WhenClause::new( @@ -312,11 +318,11 @@ mod tests { Some(4.1f32), ]; - test_eval_row(searched_case_expr, row_inputs, expected); + test_eval_row(searched_case_expr, row_inputs, expected).await; } - #[test] - fn test_eval_row_without_else() { + #[tokio::test] + async fn test_eval_row_without_else() { let ret_type = DataType::Float32; // when x <= 3 then 3.1 let when_clauses = vec![WhenClause::new( @@ -337,6 +343,6 @@ mod tests { let row_inputs = vec![2, 3, 4, 5]; let expected = vec![Some(3.1f32), Some(3.1f32), None, None]; - test_eval_row(searched_case_expr, row_inputs, expected); + test_eval_row(searched_case_expr, row_inputs, expected).await; } } diff --git a/src/expr/src/expr/expr_coalesce.rs b/src/expr/src/expr/expr_coalesce.rs index d2e51e7404222..4eb3355daacb9 100644 --- a/src/expr/src/expr/expr_coalesce.rs +++ b/src/expr/src/expr/expr_coalesce.rs @@ -31,19 +31,20 @@ pub struct CoalesceExpression { children: Vec, } +#[async_trait::async_trait] impl Expression for CoalesceExpression { fn return_type(&self) -> DataType { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> Result { + async fn eval(&self, input: &DataChunk) -> Result { let init_vis = input.vis(); let mut input = input.clone(); let len = input.capacity(); let mut selection: Vec> = vec![None; len]; let mut children_array = Vec::with_capacity(self.children.len()); for (child_idx, child) in self.children.iter().enumerate() { - let res = child.eval_checked(&input)?; + let res = child.eval_checked(&input).await?; let res_bitmap = res.null_bitmap(); let orig_vis = input.vis(); let res_bitmap_ref: VisRef<'_> = res_bitmap.into(); @@ -70,9 +71,9 @@ impl Expression for CoalesceExpression { Ok(Arc::new(builder.finish())) } - fn eval_row(&self, input: &OwnedRow) -> Result { + async fn eval_row(&self, input: &OwnedRow) -> Result { for child in &self.children { - let datum = child.eval_row(input)?; + let datum = child.eval_row(input).await?; if datum.is_some() { return Ok(datum); } @@ -138,8 +139,8 @@ mod tests { } } - #[test] - fn test_coalesce_expr() { + #[tokio::test] + async fn test_coalesce_expr() { let input_node1 = make_input_ref(0, TypeName::Int32); let input_node2 = make_input_ref(1, TypeName::Int32); let input_node3 = make_input_ref(2, TypeName::Int32); @@ -157,15 +158,15 @@ mod tests { TypeName::Int32, )) .unwrap(); - let res = nullif_expr.eval(&data_chunk).unwrap(); + let res = nullif_expr.eval(&data_chunk).await.unwrap(); assert_eq!(res.datum_at(0), Some(ScalarImpl::Int32(1))); assert_eq!(res.datum_at(1), Some(ScalarImpl::Int32(2))); assert_eq!(res.datum_at(2), Some(ScalarImpl::Int32(3))); assert_eq!(res.datum_at(3), None); } - #[test] - fn test_eval_row_coalesce_expr() { + #[tokio::test] + async fn test_eval_row_coalesce_expr() { let input_node1 = make_input_ref(0, TypeName::Int32); let input_node2 = make_input_ref(1, TypeName::Int32); let input_node3 = make_input_ref(2, TypeName::Int32); @@ -197,7 +198,7 @@ mod tests { .collect(); let row = OwnedRow::new(datum_vec); - let result = nullif_expr.eval_row(&row).unwrap(); + let result = nullif_expr.eval_row(&row).await.unwrap(); assert_eq!(result, expected[i]); } } diff --git a/src/expr/src/expr/expr_concat_ws.rs b/src/expr/src/expr/expr_concat_ws.rs index 6c5da2acafc4a..0a5780bc3db19 100644 --- a/src/expr/src/expr/expr_concat_ws.rs +++ b/src/expr/src/expr/expr_concat_ws.rs @@ -34,20 +34,20 @@ pub struct ConcatWsExpression { string_exprs: Vec, } +#[async_trait::async_trait] impl Expression for ConcatWsExpression { fn return_type(&self) -> DataType { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> Result { - let sep_column = self.sep_expr.eval_checked(input)?; + async fn eval(&self, input: &DataChunk) -> Result { + let sep_column = self.sep_expr.eval_checked(input).await?; let sep_column = sep_column.as_utf8(); - let string_columns = self - .string_exprs - .iter() - .map(|c| c.eval_checked(input)) - .collect::>>()?; + let mut string_columns = Vec::with_capacity(self.string_exprs.len()); + for expr in &self.string_exprs { + string_columns.push(expr.eval_checked(input).await?); + } let string_columns_ref = string_columns .iter() .map(|c| c.as_utf8()) @@ -92,18 +92,17 @@ impl Expression for ConcatWsExpression { Ok(Arc::new(ArrayImpl::from(builder.finish()))) } - fn eval_row(&self, input: &OwnedRow) -> Result { - let sep = self.sep_expr.eval_row(input)?; + async fn eval_row(&self, input: &OwnedRow) -> Result { + let sep = self.sep_expr.eval_row(input).await?; let sep = match sep { Some(sep) => sep, None => return Ok(None), }; - let strings = self - .string_exprs - .iter() - .map(|c| c.eval_row(input)) - .collect::>>()?; + let mut strings = Vec::with_capacity(self.string_exprs.len()); + for expr in &self.string_exprs { + strings.push(expr.eval_row(input).await?); + } let mut final_string = String::new(); let mut strings_iter = strings.iter(); @@ -183,8 +182,8 @@ mod tests { } } - #[test] - fn test_eval_concat_ws_expr() { + #[tokio::test] + async fn test_eval_concat_ws_expr() { let input_node1 = make_input_ref(0, TypeName::Varchar); let input_node2 = make_input_ref(1, TypeName::Varchar); let input_node3 = make_input_ref(2, TypeName::Varchar); @@ -205,7 +204,7 @@ mod tests { . . . .", ); - let actual = concat_ws_expr.eval(&chunk).unwrap(); + let actual = concat_ws_expr.eval(&chunk).await.unwrap(); let actual = actual .iter() .map(|r| r.map(|s| s.into_utf8())) @@ -216,8 +215,8 @@ mod tests { assert_eq!(actual, expected); } - #[test] - fn test_eval_row_concat_ws_expr() { + #[tokio::test] + async fn test_eval_row_concat_ws_expr() { let input_node1 = make_input_ref(0, TypeName::Varchar); let input_node2 = make_input_ref(1, TypeName::Varchar); let input_node3 = make_input_ref(2, TypeName::Varchar); @@ -242,7 +241,7 @@ mod tests { let datum_vec: Vec = row_input.iter().map(|e| e.map(|s| s.into())).collect(); let row = OwnedRow::new(datum_vec); - let result = concat_ws_expr.eval_row(&row).unwrap(); + let result = concat_ws_expr.eval_row(&row).await.unwrap(); let expected = expected[i].map(|s| s.into()); assert_eq!(result, expected); diff --git a/src/expr/src/expr/expr_field.rs b/src/expr/src/expr/expr_field.rs index 06e91130d1c2f..6ea49d272ffc1 100644 --- a/src/expr/src/expr/expr_field.rs +++ b/src/expr/src/expr/expr_field.rs @@ -33,13 +33,14 @@ pub struct FieldExpression { index: usize, } +#[async_trait::async_trait] impl Expression for FieldExpression { fn return_type(&self) -> DataType { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> Result { - let array = self.input.eval_checked(input)?; + async fn eval(&self, input: &DataChunk) -> Result { + let array = self.input.eval_checked(input).await?; if let ArrayImpl::Struct(struct_array) = array.as_ref() { Ok(struct_array.field_at(self.index)) } else { @@ -47,8 +48,8 @@ impl Expression for FieldExpression { } } - fn eval_row(&self, input: &OwnedRow) -> Result { - let struct_datum = self.input.eval_row(input)?; + async fn eval_row(&self, input: &OwnedRow) -> Result { + let struct_datum = self.input.eval_row(input).await?; struct_datum .map(|s| match s { ScalarImpl::Struct(v) => Ok(v.fields()[self.index].clone()), @@ -110,8 +111,8 @@ mod tests { use crate::expr::test_utils::{make_field_function, make_i32_literal, make_input_ref}; use crate::expr::Expression; - #[test] - fn test_field_expr() { + #[tokio::test] + async fn test_field_expr() { let input_node = make_input_ref(0, TypeName::Struct); let literal_node = make_i32_literal(0); let field_expr = FieldExpression::try_from(&make_field_function( @@ -129,7 +130,7 @@ mod tests { ); let data_chunk = DataChunk::new(vec![array.into()], 1); - let res = field_expr.eval(&data_chunk).unwrap(); + let res = field_expr.eval(&data_chunk).await.unwrap(); assert_eq!(res.datum_at(0), Some(ScalarImpl::Int32(1))); assert_eq!(res.datum_at(1), Some(ScalarImpl::Int32(2))); assert_eq!(res.datum_at(2), Some(ScalarImpl::Int32(3))); @@ -137,8 +138,8 @@ mod tests { assert_eq!(res.datum_at(4), Some(ScalarImpl::Int32(5))); } - #[test] - fn test_nested_field_expr() { + #[tokio::test] + async fn test_nested_field_expr() { let field_node = make_field_function( vec![make_input_ref(0, TypeName::Struct), make_i32_literal(0)], TypeName::Int32, @@ -167,7 +168,7 @@ mod tests { ); let data_chunk = DataChunk::new(vec![array.into()], 1); - let res = field_expr.eval(&data_chunk).unwrap(); + let res = field_expr.eval(&data_chunk).await.unwrap(); assert_eq!(res.datum_at(0), Some(ScalarImpl::Float32(1.0.into()))); assert_eq!(res.datum_at(1), Some(ScalarImpl::Float32(2.0.into()))); assert_eq!(res.datum_at(2), Some(ScalarImpl::Float32(3.0.into()))); diff --git a/src/expr/src/expr/expr_in.rs b/src/expr/src/expr/expr_in.rs index 82399247d23fc..6f8abb1f4cc1a 100644 --- a/src/expr/src/expr/expr_in.rs +++ b/src/expr/src/expr/expr_in.rs @@ -16,6 +16,7 @@ use std::collections::HashSet; use std::fmt::Debug; use std::sync::Arc; +use futures_util::future::FutureExt; use risingwave_common::array::{ArrayBuilder, ArrayRef, BoolArrayBuilder, DataChunk}; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum, Scalar, ToOwnedDatum}; @@ -66,13 +67,14 @@ impl InExpression { } } +#[async_trait::async_trait] impl Expression for InExpression { fn return_type(&self) -> DataType { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> Result { - let input_array = self.left.eval_checked(input)?; + async fn eval(&self, input: &DataChunk) -> Result { + let input_array = self.left.eval_checked(input).await?; let mut output_array = BoolArrayBuilder::new(input_array.len()); for (data, vis) in input_array.iter().zip_eq_fast(input.vis().iter()) { if vis { @@ -85,8 +87,8 @@ impl Expression for InExpression { Ok(Arc::new(output_array.finish().into())) } - fn eval_row(&self, input: &OwnedRow) -> Result { - let data = self.left.eval_row(input)?; + async fn eval_row(&self, input: &OwnedRow) -> Result { + let data = self.left.eval_row(input).await?; let ret = self.exists(&data); Ok(ret.map(|b| b.to_scalar_value())) } @@ -111,7 +113,10 @@ impl<'a> TryFrom<&'a ExprNode> for InExpression { let data_chunk = DataChunk::new_dummy(1); for child in &children[1..] { let const_expr = build_from_prost(child)?; - let array = const_expr.eval(&data_chunk)?; + let array = const_expr + .eval(&data_chunk) + .now_or_never() + .expect("constant expression should not be async")?; let datum = array.value_at(0).to_owned_datum(); data.push(datum); } @@ -182,8 +187,8 @@ mod tests { assert!(InExpression::try_from(&p).is_ok()); } - #[test] - fn test_eval_search_expr() { + #[tokio::test] + async fn test_eval_search_expr() { let input_refs = [ Box::new(InputRefExpression::new(DataType::Varchar, 0)), Box::new(InputRefExpression::new(DataType::Varchar, 0)), @@ -226,6 +231,7 @@ mod tests { let vis = data_chunks[i].visibility(); let res = search_expr .eval(&data_chunks[i]) + .await .unwrap() .compact(vis.unwrap(), expected[i].len()); @@ -235,8 +241,8 @@ mod tests { } } - #[test] - fn test_eval_row_search_expr() { + #[tokio::test] + async fn test_eval_row_search_expr() { let input_refs = [ Box::new(InputRefExpression::new(DataType::Varchar, 0)), Box::new(InputRefExpression::new(DataType::Varchar, 0)), @@ -267,7 +273,7 @@ mod tests { for (j, row_input) in row_inputs[i].iter().enumerate() { let row_input = vec![row_input.map(|s| s.into())]; let row = OwnedRow::new(row_input); - let result = search_expr.eval_row(&row).unwrap(); + let result = search_expr.eval_row(&row).await.unwrap(); assert_eq!(result, expected[i][j].map(ScalarImpl::Bool)); } } diff --git a/src/expr/src/expr/expr_input_ref.rs b/src/expr/src/expr/expr_input_ref.rs index c6388af2b29af..7e4d61475d2fa 100644 --- a/src/expr/src/expr/expr_input_ref.rs +++ b/src/expr/src/expr/expr_input_ref.rs @@ -31,16 +31,17 @@ pub struct InputRefExpression { idx: usize, } +#[async_trait::async_trait] impl Expression for InputRefExpression { fn return_type(&self) -> DataType { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> Result { + async fn eval(&self, input: &DataChunk) -> Result { Ok(input.column_at(self.idx).array()) } - fn eval_row(&self, input: &OwnedRow) -> Result { + async fn eval_row(&self, input: &OwnedRow) -> Result { let cell = input.index(self.idx).as_ref().cloned(); Ok(cell) } @@ -85,14 +86,14 @@ mod tests { use crate::expr::{Expression, InputRefExpression}; - #[test] - fn test_eval_row_input_ref() { + #[tokio::test] + async fn test_eval_row_input_ref() { let datums: Vec = vec![Some(1.into()), Some(2.into()), None]; let input_row = OwnedRow::new(datums.clone()); for (i, expected) in datums.iter().enumerate() { let expr = InputRefExpression::new(DataType::Int32, i); - let result = expr.eval_row(&input_row).unwrap(); + let result = expr.eval_row(&input_row).await.unwrap(); assert_eq!(*expected, result); } } diff --git a/src/expr/src/expr/expr_is_null.rs b/src/expr/src/expr/expr_is_null.rs index 35022becdc0a2..95c91920f9f74 100644 --- a/src/expr/src/expr/expr_is_null.rs +++ b/src/expr/src/expr/expr_is_null.rs @@ -44,32 +44,34 @@ impl IsNotNullExpression { } } +#[async_trait::async_trait] impl Expression for IsNullExpression { fn return_type(&self) -> DataType { DataType::Boolean } - fn eval(&self, input: &DataChunk) -> Result { - let child_arr = self.child.eval_checked(input)?; + async fn eval(&self, input: &DataChunk) -> Result { + let child_arr = self.child.eval_checked(input).await?; let arr = BoolArray::new(!child_arr.null_bitmap(), Bitmap::ones(input.capacity())); Ok(Arc::new(ArrayImpl::Bool(arr))) } - fn eval_row(&self, input: &OwnedRow) -> Result { - let result = self.child.eval_row(input)?; + async fn eval_row(&self, input: &OwnedRow) -> Result { + let result = self.child.eval_row(input).await?; let is_null = result.is_none(); Ok(Some(is_null.to_scalar_value())) } } +#[async_trait::async_trait] impl Expression for IsNotNullExpression { fn return_type(&self) -> DataType { DataType::Boolean } - fn eval(&self, input: &DataChunk) -> Result { - let child_arr = self.child.eval_checked(input)?; + async fn eval(&self, input: &DataChunk) -> Result { + let child_arr = self.child.eval_checked(input).await?; let null_bitmap = match Arc::try_unwrap(child_arr) { Ok(child_arr) => child_arr.into_null_bitmap(), Err(child_arr) => child_arr.null_bitmap().clone(), @@ -79,8 +81,8 @@ impl Expression for IsNotNullExpression { Ok(Arc::new(ArrayImpl::Bool(arr))) } - fn eval_row(&self, input: &OwnedRow) -> Result { - let result = self.child.eval_row(input)?; + async fn eval_row(&self, input: &OwnedRow) -> Result { + let result = self.child.eval_row(input).await?; let is_not_null = result.is_some(); Ok(Some(is_not_null.to_scalar_value())) } @@ -98,7 +100,7 @@ mod tests { use crate::expr::{BoxedExpression, InputRefExpression}; use crate::Result; - fn do_test( + async fn do_test( expr: BoxedExpression, expected_eval_result: Vec, expected_eval_row_result: Vec, @@ -112,7 +114,7 @@ mod tests { }; let input_chunk = DataChunk::new(vec![input_array.into()], 3); - let result_array = expr.eval(&input_chunk).unwrap(); + let result_array = expr.eval(&input_chunk).await.unwrap(); assert_eq!(3, result_array.len()); for (i, v) in expected_eval_result.iter().enumerate() { assert_eq!( @@ -127,25 +129,29 @@ mod tests { ]; for (i, row) in rows.iter().enumerate() { - let result = expr.eval_row(row).unwrap().unwrap(); + let result = expr.eval_row(row).await.unwrap().unwrap(); assert_eq!(expected_eval_row_result[i], result.into_bool()); } Ok(()) } - #[test] - fn test_is_null() -> Result<()> { + #[tokio::test] + async fn test_is_null() -> Result<()> { let expr = IsNullExpression::new(Box::new(InputRefExpression::new(DataType::Decimal, 0))); - do_test(Box::new(expr), vec![false, false, true], vec![false, true]).unwrap(); + do_test(Box::new(expr), vec![false, false, true], vec![false, true]) + .await + .unwrap(); Ok(()) } - #[test] - fn test_is_not_null() -> Result<()> { + #[tokio::test] + async fn test_is_not_null() -> Result<()> { let expr = IsNotNullExpression::new(Box::new(InputRefExpression::new(DataType::Decimal, 0))); - do_test(Box::new(expr), vec![true, true, false], vec![true, false]).unwrap(); + do_test(Box::new(expr), vec![true, true, false], vec![true, false]) + .await + .unwrap(); Ok(()) } } diff --git a/src/expr/src/expr/expr_jsonb_access.rs b/src/expr/src/expr/expr_jsonb_access.rs index 7c26779a39109..8445976c74591 100644 --- a/src/expr/src/expr/expr_jsonb_access.rs +++ b/src/expr/src/expr/expr_jsonb_access.rs @@ -78,6 +78,7 @@ where } } +#[async_trait::async_trait] impl Expression for JsonbAccessExpression where A: Array, @@ -89,14 +90,14 @@ where O::return_type() } - fn eval(&self, input: &DataChunk) -> crate::Result { + async fn eval(&self, input: &DataChunk) -> crate::Result { let Either::Left(path_expr) = &self.path else { unreachable!("optimization for const path not implemented yet"); }; - let path_array = path_expr.eval_checked(input)?; + let path_array = path_expr.eval_checked(input).await?; let path_array: &A = path_array.as_ref().into(); - let input_array = self.input.eval_checked(input)?; + let input_array = self.input.eval_checked(input).await?; let input_array: &JsonbArray = input_array.as_ref().into(); let mut builder = O::new(input.capacity()); @@ -123,16 +124,16 @@ where Ok(std::sync::Arc::new(builder.finish().into())) } - fn eval_row(&self, input: &OwnedRow) -> crate::Result { + async fn eval_row(&self, input: &OwnedRow) -> crate::Result { let Either::Left(path_expr) = &self.path else { unreachable!("optimization for const path not implemented yet"); }; - let p = path_expr.eval_row(input)?; + let p = path_expr.eval_row(input).await?; let p = p .as_ref() .map(|p| p.as_scalar_ref_impl().try_into().unwrap()); - let v = self.input.eval_row(input)?; + let v = self.input.eval_row(input).await?; let v = v .as_ref() .map(|v| v.as_scalar_ref_impl().try_into().unwrap()); diff --git a/src/expr/src/expr/expr_literal.rs b/src/expr/src/expr/expr_literal.rs index da4097e9b2f62..376ed2dda42ed 100644 --- a/src/expr/src/expr/expr_literal.rs +++ b/src/expr/src/expr/expr_literal.rs @@ -33,12 +33,13 @@ pub struct LiteralExpression { literal: Datum, } +#[async_trait::async_trait] impl Expression for LiteralExpression { fn return_type(&self) -> DataType { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> Result { + async fn eval(&self, input: &DataChunk) -> Result { let mut array_builder = self.return_type.create_array_builder(input.capacity()); let capacity = input.capacity(); let builder = &mut array_builder; @@ -67,7 +68,7 @@ impl Expression for LiteralExpression { Ok(Arc::new(array_builder.finish())) } - fn eval_row(&self, _input: &OwnedRow) -> Result { + async fn eval_row(&self, _input: &OwnedRow) -> Result { Ok(self.literal.as_ref().cloned()) } } @@ -244,17 +245,17 @@ mod tests { } } - #[test] - fn test_literal_eval_dummy_chunk() { + #[tokio::test] + async fn test_literal_eval_dummy_chunk() { let literal = LiteralExpression::new(DataType::Int32, Some(1.into())); - let result = literal.eval(&DataChunk::new_dummy(1)).unwrap(); + let result = literal.eval(&DataChunk::new_dummy(1)).await.unwrap(); assert_eq!(*result, array_nonnull!(I32Array, [1]).into()); } - #[test] - fn test_literal_eval_row_dummy_chunk() { + #[tokio::test] + async fn test_literal_eval_row_dummy_chunk() { let literal = LiteralExpression::new(DataType::Int32, Some(1.into())); - let result = literal.eval_row(&OwnedRow::new(vec![])).unwrap(); + let result = literal.eval_row(&OwnedRow::new(vec![])).await.unwrap(); assert_eq!(result, Some(1.into())) } } diff --git a/src/expr/src/expr/expr_nested_construct.rs b/src/expr/src/expr/expr_nested_construct.rs index 038907ed51039..18cbd26ce184e 100644 --- a/src/expr/src/expr/expr_nested_construct.rs +++ b/src/expr/src/expr/expr_nested_construct.rs @@ -34,17 +34,17 @@ pub struct NestedConstructExpression { elements: Vec, } +#[async_trait::async_trait] impl Expression for NestedConstructExpression { fn return_type(&self) -> DataType { self.data_type.clone() } - fn eval(&self, input: &DataChunk) -> Result { - let columns = self - .elements - .iter() - .map(|e| e.eval_checked(input)) - .collect::>>()?; + async fn eval(&self, input: &DataChunk) -> Result { + let mut columns = Vec::with_capacity(self.elements.len()); + for e in &self.elements { + columns.push(e.eval_checked(input).await?); + } if let DataType::Struct(t) = &self.data_type { let mut builder = StructArrayBuilder::with_meta( @@ -80,12 +80,11 @@ impl Expression for NestedConstructExpression { } } - fn eval_row(&self, input: &OwnedRow) -> Result { - let datums = self - .elements - .iter() - .map(|e| e.eval_row(input)) - .collect::>>()?; + async fn eval_row(&self, input: &OwnedRow) -> Result { + let mut datums = Vec::with_capacity(self.elements.len()); + for e in &self.elements { + datums.push(e.eval_row(input).await?); + } if let DataType::Struct { .. } = &self.data_type { Ok(Some(StructValue::new(datums).to_scalar_value())) } else if let DataType::List { datatype: _ } = &self.data_type { @@ -135,8 +134,8 @@ mod tests { use super::NestedConstructExpression; use crate::expr::{BoxedExpression, Expression, LiteralExpression}; - #[test] - fn test_eval_array_expr() { + #[tokio::test] + async fn test_eval_array_expr() { let expr = NestedConstructExpression { data_type: DataType::List { datatype: DataType::Int32.into(), @@ -144,12 +143,12 @@ mod tests { elements: vec![i32_expr(1.into()), i32_expr(2.into())], }; - let arr = expr.eval(&DataChunk::new_dummy(2)).unwrap(); + let arr = expr.eval(&DataChunk::new_dummy(2)).await.unwrap(); assert_eq!(arr.len(), 2); } - #[test] - fn test_eval_row_array_expr() { + #[tokio::test] + async fn test_eval_row_array_expr() { let expr = NestedConstructExpression { data_type: DataType::List { datatype: DataType::Int32.into(), @@ -157,7 +156,11 @@ mod tests { elements: vec![i32_expr(1.into()), i32_expr(2.into())], }; - let scalar_impl = expr.eval_row(&OwnedRow::new(vec![])).unwrap().unwrap(); + let scalar_impl = expr + .eval_row(&OwnedRow::new(vec![])) + .await + .unwrap() + .unwrap(); let expected = ListValue::new(vec![Some(1.into()), Some(2.into())]).to_scalar_value(); assert_eq!(expected, scalar_impl); } diff --git a/src/expr/src/expr/expr_quaternary_bytes.rs b/src/expr/src/expr/expr_quaternary_bytes.rs index 74f9e6904c518..b1a2c097a54f2 100644 --- a/src/expr/src/expr/expr_quaternary_bytes.rs +++ b/src/expr/src/expr/expr_quaternary_bytes.rs @@ -48,15 +48,15 @@ mod tests { use super::*; use crate::expr::LiteralExpression; - fn test_evals_dummy(expr: BoxedExpression, expected: Datum, is_negative_len: bool) { - let res = expr.eval(&DataChunk::new_dummy(1)); + async fn test_evals_dummy(expr: BoxedExpression, expected: Datum, is_negative_len: bool) { + let res = expr.eval(&DataChunk::new_dummy(1)).await; if is_negative_len { assert!(res.is_err()); } else { assert_eq!(res.unwrap().to_datum(), expected); } - let res = expr.eval_row(&OwnedRow::new(vec![])); + let res = expr.eval_row(&OwnedRow::new(vec![])).await; if is_negative_len { assert!(res.is_err()); } else { @@ -64,8 +64,8 @@ mod tests { } } - #[test] - fn test_overlay() { + #[tokio::test] + async fn test_overlay() { let cases = vec![ ("aaa", "XY", 1, 0, "XYaaa"), ("aaa_aaa", "XYZ", 4, 1, "aaaXYZaaa"), @@ -96,7 +96,7 @@ mod tests { DataType::Varchar, ); - test_evals_dummy(expr, Some(ScalarImpl::from(String::from(expected))), false); + test_evals_dummy(expr, Some(ScalarImpl::from(String::from(expected))), false).await; } } } diff --git a/src/expr/src/expr/expr_regexp.rs b/src/expr/src/expr/expr_regexp.rs index 2f3348f6cea4d..ec6e01df4284e 100644 --- a/src/expr/src/expr/expr_regexp.rs +++ b/src/expr/src/expr/expr_regexp.rs @@ -189,6 +189,7 @@ impl RegexpMatchExpression { } } +#[async_trait::async_trait] impl Expression for RegexpMatchExpression { fn return_type(&self) -> DataType { DataType::List { @@ -196,8 +197,8 @@ impl Expression for RegexpMatchExpression { } } - fn eval(&self, input: &DataChunk) -> Result { - let text_arr = self.child.eval_checked(input)?; + async fn eval(&self, input: &DataChunk) -> Result { + let text_arr = self.child.eval_checked(input).await?; let text_arr: &Utf8Array = text_arr.as_ref().into(); let mut output = ListArrayBuilder::with_meta( input.capacity(), @@ -220,8 +221,8 @@ impl Expression for RegexpMatchExpression { Ok(Arc::new(output.finish().into())) } - fn eval_row(&self, input: &OwnedRow) -> Result { - let text = self.child.eval_row(input)?; + async fn eval_row(&self, input: &OwnedRow) -> Result { + let text = self.child.eval_row(input).await?; Ok(if let Some(ScalarImpl::Utf8(text)) = text { self.match_one(Some(&text)).map(Into::into) } else { diff --git a/src/expr/src/expr/expr_some_all.rs b/src/expr/src/expr/expr_some_all.rs index 58ce0094463f6..749ce3951f386 100644 --- a/src/expr/src/expr/expr_some_all.rs +++ b/src/expr/src/expr/expr_some_all.rs @@ -72,14 +72,15 @@ impl SomeAllExpression { } } +#[async_trait::async_trait] impl Expression for SomeAllExpression { fn return_type(&self) -> DataType { DataType::Boolean } - fn eval(&self, data_chunk: &DataChunk) -> Result { - let arr_left = self.left_expr.eval_checked(data_chunk)?; - let arr_right = self.right_expr.eval_checked(data_chunk)?; + async fn eval(&self, data_chunk: &DataChunk) -> Result { + let arr_left = self.left_expr.eval_checked(data_chunk).await?; + let arr_right = self.right_expr.eval_checked(data_chunk).await?; let bitmap = data_chunk.visibility(); let mut num_array = Vec::with_capacity(data_chunk.capacity()); @@ -148,7 +149,7 @@ impl Expression for SomeAllExpression { capacity, ); - let func_results = self.func.eval(&data_chunk)?; + let func_results = self.func.eval(&data_chunk).await?; let mut func_results_iter = func_results.as_bool().iter(); Ok(Arc::new( num_array @@ -164,20 +165,20 @@ impl Expression for SomeAllExpression { )) } - fn eval_row(&self, row: &OwnedRow) -> Result { - let datum_left = self.left_expr.eval_row(row)?; - let datum_right = self.right_expr.eval_row(row)?; + async fn eval_row(&self, row: &OwnedRow) -> Result { + let datum_left = self.left_expr.eval_row(row).await?; + let datum_right = self.right_expr.eval_row(row).await?; if let Some(array) = datum_right { match array { ScalarImpl::List(array) => { - let scalar_vec = array - .values() - .iter() - .map(|d| { - self.func - .eval_row(&OwnedRow::new(vec![datum_left.clone(), d.clone()])) - }) - .collect::>>()?; + let mut scalar_vec = Vec::with_capacity(array.values().len()); + for d in array.values() { + let e = self + .func + .eval_row(&OwnedRow::new(vec![datum_left.clone(), d.clone()])) + .await?; + scalar_vec.push(e); + } let boolean_vec = scalar_vec .into_iter() .map(|scalar_ref| scalar_ref.map(|s| s.into_bool())) diff --git a/src/expr/src/expr/expr_ternary_bytes.rs b/src/expr/src/expr/expr_ternary_bytes.rs index e12ae34dd10a5..5e0a63cceef89 100644 --- a/src/expr/src/expr/expr_ternary_bytes.rs +++ b/src/expr/src/expr/expr_ternary_bytes.rs @@ -119,15 +119,15 @@ mod tests { use super::*; use crate::expr::LiteralExpression; - fn test_evals_dummy(expr: BoxedExpression, expected: Datum, is_negative_len: bool) { - let res = expr.eval(&DataChunk::new_dummy(1)); + async fn test_evals_dummy(expr: BoxedExpression, expected: Datum, is_negative_len: bool) { + let res = expr.eval(&DataChunk::new_dummy(1)).await; if is_negative_len { assert!(res.is_err()); } else { assert_eq!(res.unwrap().to_datum(), expected); } - let res = expr.eval_row(&OwnedRow::new(vec![])); + let res = expr.eval_row(&OwnedRow::new(vec![])).await; if is_negative_len { assert!(res.is_err()); } else { @@ -135,8 +135,8 @@ mod tests { } } - #[test] - fn test_substr_start_end() { + #[tokio::test] + async fn test_substr_start_end() { let text = "quick brown"; let cases = [ ( @@ -186,12 +186,12 @@ mod tests { DataType::Varchar, ); - test_evals_dummy(expr, expected, is_negative_len); + test_evals_dummy(expr, expected, is_negative_len).await; } } - #[test] - fn test_replace() { + #[tokio::test] + async fn test_replace() { let cases = [ ("hello, word", "我的", "world", "hello, word"), ("hello, word", "", "world", "hello, word"), @@ -224,12 +224,12 @@ mod tests { DataType::Varchar, ); - test_evals_dummy(expr, Some(ScalarImpl::from(String::from(expected))), false); + test_evals_dummy(expr, Some(ScalarImpl::from(String::from(expected))), false).await; } } - #[test] - fn test_overlay() { + #[tokio::test] + async fn test_overlay() { let cases = vec![ ("aaa__aaa", "XY", 4, "aaaXYaaa"), ("aaa", "XY", 3, "aaXY"), @@ -255,7 +255,7 @@ mod tests { DataType::Varchar, ); - test_evals_dummy(expr, Some(ScalarImpl::from(String::from(expected))), false); + test_evals_dummy(expr, Some(ScalarImpl::from(String::from(expected))), false).await; } } } diff --git a/src/expr/src/expr/expr_to_char_const_tmpl.rs b/src/expr/src/expr/expr_to_char_const_tmpl.rs index aa3fe82623f71..4544ee9947e00 100644 --- a/src/expr/src/expr/expr_to_char_const_tmpl.rs +++ b/src/expr/src/expr/expr_to_char_const_tmpl.rs @@ -34,16 +34,17 @@ pub(crate) struct ExprToCharConstTmpl { pub(crate) ctx: ExprToCharConstTmplContext, } +#[async_trait::async_trait] impl Expression for ExprToCharConstTmpl { fn return_type(&self) -> DataType { DataType::Varchar } - fn eval( + async fn eval( &self, input: &risingwave_common::array::DataChunk, ) -> crate::Result { - let data_arr = self.child.eval_checked(input)?; + let data_arr = self.child.eval_checked(input).await?; let data_arr: &NaiveDateTimeArray = data_arr.as_ref().into(); let mut output = Utf8ArrayBuilder::new(input.capacity()); for (data, vis) in data_arr.iter().zip_eq_fast(input.vis().iter()) { @@ -64,8 +65,8 @@ impl Expression for ExprToCharConstTmpl { Ok(Arc::new(output.finish().into())) } - fn eval_row(&self, input: &OwnedRow) -> crate::Result { - let data = self.child.eval_row(input)?; + async fn eval_row(&self, input: &OwnedRow) -> crate::Result { + let data = self.child.eval_row(input).await?; Ok(if let Some(ScalarImpl::NaiveDateTime(data)) = data { Some( data.0 diff --git a/src/expr/src/expr/expr_to_timestamp_const_tmpl.rs b/src/expr/src/expr/expr_to_timestamp_const_tmpl.rs index 94003542b0af5..c5223437904ad 100644 --- a/src/expr/src/expr/expr_to_timestamp_const_tmpl.rs +++ b/src/expr/src/expr/expr_to_timestamp_const_tmpl.rs @@ -34,16 +34,17 @@ pub(crate) struct ExprToTimestampConstTmpl { pub(crate) ctx: ExprToTimestampConstTmplContext, } +#[async_trait::async_trait] impl Expression for ExprToTimestampConstTmpl { fn return_type(&self) -> DataType { DataType::Varchar } - fn eval( + async fn eval( &self, input: &risingwave_common::array::DataChunk, ) -> crate::Result { - let data_arr = self.child.eval_checked(input)?; + let data_arr = self.child.eval_checked(input).await?; let data_arr: &Utf8Array = data_arr.as_ref().into(); let mut output = NaiveDateTimeArrayBuilder::new(input.capacity()); for (data, vis) in data_arr.iter().zip_eq_fast(input.vis().iter()) { @@ -60,8 +61,8 @@ impl Expression for ExprToTimestampConstTmpl { Ok(Arc::new(output.finish().into())) } - fn eval_row(&self, input: &OwnedRow) -> crate::Result { - let data = self.child.eval_row(input)?; + async fn eval_row(&self, input: &OwnedRow) -> crate::Result { + let data = self.child.eval_row(input).await?; Ok(if let Some(ScalarImpl::Utf8(data)) = data { let res = to_timestamp_const_tmpl(&data, &self.ctx.chrono_pattern)?; Some(res.into()) diff --git a/src/expr/src/expr/expr_udf.rs b/src/expr/src/expr/expr_udf.rs index 3cadb64555824..b6c5ff7f8e23e 100644 --- a/src/expr/src/expr/expr_udf.rs +++ b/src/expr/src/expr/expr_udf.rs @@ -37,30 +37,26 @@ pub struct UdfExpression { identifier: String, } -// TODO: make evaluation functions async -// At present, we use `block_in_place` + `block_on` as a workaround to run -// async functions in sync context. #[cfg(not(madsim))] +#[async_trait::async_trait] impl Expression for UdfExpression { fn return_type(&self) -> DataType { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> Result { + async fn eval(&self, input: &DataChunk) -> Result { let vis = input.vis().to_bitmap(); - let columns: Vec<_> = self - .children - .iter() - .map(|c| c.eval_checked(input).map(|a| a.as_ref().into())) - .try_collect()?; + let mut columns = Vec::with_capacity(self.children.len()); + for child in &self.children { + let array = child.eval_checked(input).await?; + columns.push(array.as_ref().into()); + } let opts = arrow_array::RecordBatchOptions::default().with_row_count(Some(input.cardinality())); let input = arrow_array::RecordBatch::try_new_with_options(self.arg_schema.clone(), columns, &opts) .expect("failed to build record batch"); - let output = tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(self.client.call(&self.identifier, input)) - })?; + let output = self.client.call(&self.identifier, input).await?; let arrow_array = output .columns() .get(0) @@ -70,9 +66,9 @@ impl Expression for UdfExpression { Ok(Arc::new(array)) } - fn eval_row(&self, input: &OwnedRow) -> Result { + async fn eval_row(&self, input: &OwnedRow) -> Result { let chunk = DataChunk::from_rows(std::slice::from_ref(input), &self.arg_types); - let output_array = self.eval(&chunk)?; + let output_array = self.eval(&chunk).await?; Ok(output_array.to_datum()) } } @@ -111,16 +107,17 @@ impl<'a> TryFrom<&'a ExprNode> for UdfExpression { } #[cfg(madsim)] +#[async_trait::async_trait] impl Expression for UdfExpression { fn return_type(&self) -> DataType { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> Result { + async fn eval(&self, input: &DataChunk) -> Result { panic!("UDF is not supported in simulation yet"); } - fn eval_row(&self, input: &OwnedRow) -> Result { + async fn eval_row(&self, input: &OwnedRow) -> Result { panic!("UDF is not supported in simulation yet"); } } diff --git a/src/expr/src/expr/expr_unary.rs b/src/expr/src/expr/expr_unary.rs index b957fb74d9628..b8c3bea2bac2d 100644 --- a/src/expr/src/expr/expr_unary.rs +++ b/src/expr/src/expr/expr_unary.rs @@ -377,15 +377,15 @@ mod tests { use crate::expr::test_utils::{make_expression, make_input_ref}; use crate::vector_op::cast::{str_parse, try_cast}; - #[test] - fn test_unary() { - test_unary_bool::(|x| !x, Type::Not); - test_unary_date::(|x| try_cast(x).unwrap(), Type::Cast); - test_str_to_int16::(|x| str_parse(x).unwrap()); + #[tokio::test] + async fn test_unary() { + test_unary_bool::(|x| !x, Type::Not).await; + test_unary_date::(|x| try_cast(x).unwrap(), Type::Cast).await; + test_str_to_int16::(|x| str_parse(x).unwrap()).await; } - #[test] - fn test_i16_to_i32() { + #[tokio::test] + async fn test_i16_to_i32() { let mut input = Vec::>::new(); let mut target = Vec::>::new(); for i in 0..100i16 { @@ -412,7 +412,7 @@ mod tests { })), }; let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).unwrap(); + let res = vec_executor.eval(&data_chunk).await.unwrap(); let arr: &I32Array = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -421,14 +421,14 @@ mod tests { for i in 0..input.len() { let row = OwnedRow::new(vec![input[i].map(|int| int.to_scalar_value())]); - let result = vec_executor.eval_row(&row).unwrap(); + let result = vec_executor.eval_row(&row).await.unwrap(); let expected = target[i].map(|int| int.to_scalar_value()); assert_eq!(result, expected); } } - #[test] - fn test_neg() { + #[tokio::test] + async fn test_neg() { let mut input = Vec::>::new(); let mut target = Vec::>::new(); @@ -455,7 +455,7 @@ mod tests { })), }; let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).unwrap(); + let res = vec_executor.eval(&data_chunk).await.unwrap(); let arr: &I32Array = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -464,13 +464,13 @@ mod tests { for i in 0..input.len() { let row = OwnedRow::new(vec![input[i].map(|int| int.to_scalar_value())]); - let result = vec_executor.eval_row(&row).unwrap(); + let result = vec_executor.eval_row(&row).await.unwrap(); let expected = target[i].map(|int| int.to_scalar_value()); assert_eq!(result, expected); } } - fn test_str_to_int16(f: F) + async fn test_str_to_int16(f: F) where A: Array, for<'a> &'a A: std::convert::From<&'a ArrayImpl>, @@ -505,7 +505,7 @@ mod tests { })), }; let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).unwrap(); + let res = vec_executor.eval(&data_chunk).await.unwrap(); let arr: &A = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -517,13 +517,13 @@ mod tests { .as_ref() .cloned() .map(|str| str.to_scalar_value())]); - let result = vec_executor.eval_row(&row).unwrap(); + let result = vec_executor.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); assert_eq!(result, expected); } } - fn test_unary_bool(f: F, kind: Type) + async fn test_unary_bool(f: F, kind: Type) where A: Array, for<'a> &'a A: std::convert::From<&'a ArrayImpl>, @@ -549,7 +549,7 @@ mod tests { let data_chunk = DataChunk::new(vec![col1], 100); let expr = make_expression(kind, &[TypeName::Boolean], &[0]); let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).unwrap(); + let res = vec_executor.eval(&data_chunk).await.unwrap(); let arr: &A = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -558,13 +558,13 @@ mod tests { for i in 0..input.len() { let row = OwnedRow::new(vec![input[i].map(|b| b.to_scalar_value())]); - let result = vec_executor.eval_row(&row).unwrap(); + let result = vec_executor.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); assert_eq!(result, expected); } } - fn test_unary_date(f: F, kind: Type) + async fn test_unary_date(f: F, kind: Type) where A: Array, for<'a> &'a A: std::convert::From<&'a ArrayImpl>, @@ -588,7 +588,7 @@ mod tests { let data_chunk = DataChunk::new(vec![col1], 100); let expr = make_expression(kind, &[TypeName::Date], &[0]); let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).unwrap(); + let res = vec_executor.eval(&data_chunk).await.unwrap(); let arr: &A = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -597,7 +597,7 @@ mod tests { for i in 0..input.len() { let row = OwnedRow::new(vec![input[i].map(|d| d.to_scalar_value())]); - let result = vec_executor.eval_row(&row).unwrap(); + let result = vec_executor.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); assert_eq!(result, expected); } diff --git a/src/expr/src/expr/expr_vnode.rs b/src/expr/src/expr/expr_vnode.rs index 7f951e12accbb..c31aee110077c 100644 --- a/src/expr/src/expr/expr_vnode.rs +++ b/src/expr/src/expr/expr_vnode.rs @@ -62,12 +62,13 @@ impl<'a> TryFrom<&'a ExprNode> for VnodeExpression { } } +#[async_trait::async_trait] impl Expression for VnodeExpression { fn return_type(&self) -> DataType { DataType::Int16 } - fn eval(&self, input: &DataChunk) -> Result { + async fn eval(&self, input: &DataChunk) -> Result { let hash_values = input.get_hash_values(&self.dist_key_indices, Crc32FastBuilder); let mut builder = I16ArrayBuilder::new(input.capacity()); hash_values @@ -76,7 +77,7 @@ impl Expression for VnodeExpression { Ok(Arc::new(ArrayImpl::from(builder.finish()))) } - fn eval_row(&self, input: &OwnedRow) -> Result { + async fn eval_row(&self, input: &OwnedRow) -> Result { let vnode = input .project(&self.dist_key_indices) .hash(Crc32FastBuilder) @@ -112,8 +113,8 @@ mod tests { } } - #[test] - fn test_vnode_expr_eval() { + #[tokio::test] + async fn test_vnode_expr_eval() { let input_node1 = make_input_ref(0, TypeName::Int32); let input_node2 = make_input_ref(0, TypeName::Int64); let input_node3 = make_input_ref(0, TypeName::Varchar); @@ -129,7 +130,7 @@ mod tests { 2 32 def 3 88 ghi", ); - let actual = vnode_expr.eval(&chunk).unwrap(); + let actual = vnode_expr.eval(&chunk).await.unwrap(); actual.iter().for_each(|vnode| { let vnode = vnode.unwrap().into_int16(); assert!(vnode >= 0); @@ -137,8 +138,8 @@ mod tests { }); } - #[test] - fn test_vnode_expr_eval_row() { + #[tokio::test] + async fn test_vnode_expr_eval_row() { let input_node1 = make_input_ref(0, TypeName::Int32); let input_node2 = make_input_ref(0, TypeName::Int64); let input_node3 = make_input_ref(0, TypeName::Varchar); @@ -156,7 +157,7 @@ mod tests { ); let rows: Vec<_> = chunk.rows().map(|row| row.into_owned_row()).collect(); for row in rows { - let actual = vnode_expr.eval_row(&row).unwrap(); + let actual = vnode_expr.eval_row(&row).await.unwrap(); let vnode = actual.unwrap().into_int16(); assert!(vnode >= 0); assert!((vnode as usize) < VirtualNode::COUNT); diff --git a/src/expr/src/expr/mod.rs b/src/expr/src/expr/mod.rs index 29c3ec1c87a35..3f9ea94e36ec9 100644 --- a/src/expr/src/expr/mod.rs +++ b/src/expr/src/expr/mod.rs @@ -68,8 +68,9 @@ pub mod test_utils; use std::sync::Arc; use risingwave_common::array::{ArrayRef, DataChunk}; -use risingwave_common::row::OwnedRow; +use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::{DataType, Datum}; +use static_assertions::const_assert; pub use self::agg::AggKind; pub use self::build_expr_from_prost::build_from_prost; @@ -77,16 +78,17 @@ pub use self::expr_binary_nonnull::new_binary_expr; pub use self::expr_input_ref::InputRefExpression; pub use self::expr_literal::LiteralExpression; pub use self::expr_unary::new_unary_expr; -use super::Result; +use super::{ExprError, Result}; /// Instance of an expression +#[async_trait::async_trait] pub trait Expression: std::fmt::Debug + Sync + Send { /// Get the return data type. fn return_type(&self) -> DataType; /// Eval the result with extra checks. - fn eval_checked(&self, input: &DataChunk) -> Result { - let res = self.eval(input)?; + async fn eval_checked(&self, input: &DataChunk) -> Result { + let res = self.eval(input).await?; // TODO: Decide to use assert or debug_assert by benchmarks. assert_eq!(res.len(), input.capacity()); @@ -99,10 +101,10 @@ pub trait Expression: std::fmt::Debug + Sync + Send { /// # Arguments /// /// * `input` - input data of the Project Executor - fn eval(&self, input: &DataChunk) -> Result; + async fn eval(&self, input: &DataChunk) -> Result; /// Evaluate the expression in row-based execution. - fn eval_row(&self, input: &OwnedRow) -> Result; + async fn eval_row(&self, input: &OwnedRow) -> Result; /// Wrap the expression in a Box. fn boxed(self) -> BoxedExpression @@ -113,8 +115,52 @@ pub trait Expression: std::fmt::Debug + Sync + Send { } } +impl dyn Expression { + pub async fn eval_infallible(&self, input: &DataChunk, on_err: impl Fn(ExprError)) -> ArrayRef { + const_assert!(!STRICT_MODE); + + if let Ok(array) = self.eval(input).await { + return array; + } + + // When eval failed, recompute in row-based execution + // and pad with NULL for each failed row. + let mut array_builder = self.return_type().create_array_builder(input.cardinality()); + for row in input.rows_with_holes() { + if let Some(row) = row { + let datum = self + .eval_row_infallible(&row.into_owned_row(), &on_err) + .await; + array_builder.append_datum(&datum); + } else { + array_builder.append_null(); + } + } + Arc::new(array_builder.finish()) + } + + pub async fn eval_row_infallible(&self, input: &OwnedRow, on_err: impl Fn(ExprError)) -> Datum { + const_assert!(!STRICT_MODE); + + self.eval_row(input).await.unwrap_or_else(|err| { + on_err(err); + None + }) + } +} + /// An owned dynamically typed [`Expression`]. pub type BoxedExpression = Box; /// A reference to a dynamically typed [`Expression`]. pub type ExpressionRef = Arc; + +/// Controls the behavior when a compute error happens. +/// +/// - If set to `false`, `NULL` will be inserted. +/// - TODO: If set to `true`, The MV will be suspended and removed from further checkpoints. It can +/// still be used to serve outdated data without corruption. +/// +/// See also . +#[allow(dead_code)] +const STRICT_MODE: bool = false; diff --git a/src/expr/src/expr/template.rs b/src/expr/src/expr/template.rs index 9b3206a6c0c2b..c10ca3cf4aef3 100644 --- a/src/expr/src/expr/template.rs +++ b/src/expr/src/expr/template.rs @@ -15,6 +15,8 @@ //! Template macro to generate code for unary/binary/ternary expression. use std::fmt; +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use itertools::multizip; @@ -28,10 +30,15 @@ use crate::expr::{BoxedExpression, Expression}; macro_rules! gen_eval { { ($macro:ident, $macro_row:ident), $ty_name:ident, $OA:ty, $($arg:ident,)* } => { - fn eval(&self, data_chunk: &DataChunk) -> $crate::Result { - paste! { + fn eval<'a, 'b, 'async_trait>(&'a self, data_chunk: &'b DataChunk) + -> Pin> + Send + 'async_trait>> + where + 'a: 'async_trait, + 'b: 'async_trait, + { + Box::pin(async move { paste! { $( - let [] = self.[].eval_checked(data_chunk)?; + let [] = self.[].eval_checked(data_chunk).await?; let []: &$arg = [].as_ref().into(); )* @@ -55,22 +62,27 @@ macro_rules! gen_eval { output_array.finish().into() } })) - } + }}) } /// `eval_row()` first calls `eval_row()` on the inner expressions to get the resulting datums, /// then directly calls `$macro_row` to evaluate the current expression. - fn eval_row(&self, row: &OwnedRow) -> $crate::Result { - paste! { + fn eval_row<'a, 'b, 'async_trait>(&'a self, row: &'b OwnedRow) + -> Pin> + Send + 'async_trait>> + where + 'a: 'async_trait, + 'b: 'async_trait, + { + Box::pin(async move { paste! { $( - let [] = self.[].eval_row(row)?; + let [] = self.[].eval_row(row).await?; let [] = [].as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap()); )* let output_scalar = $macro_row!(self, $([],)*); let output_datum = output_scalar.map(|s| s.to_scalar_value()); Ok(output_datum) - } + }}) } } } @@ -285,6 +297,7 @@ macro_rules! gen_expr_nullable { } } + #[async_trait::async_trait] impl<$($arg: Array, )* OA: Array, F: Fn($(Option<$arg::RefItem<'_>>, )*) -> $crate::Result> + Sync + Send, diff --git a/src/expr/src/expr/template_fast.rs b/src/expr/src/expr/template_fast.rs index 70eeca15bf2c6..9091f351c4d74 100644 --- a/src/expr/src/expr/template_fast.rs +++ b/src/expr/src/expr/template_fast.rs @@ -71,6 +71,7 @@ where } } +#[async_trait::async_trait] impl Expression for BooleanUnaryExpression where FA: Fn(&BoolArray) -> BoolArray + Send + Sync, @@ -80,15 +81,15 @@ where DataType::Boolean } - fn eval(&self, data_chunk: &DataChunk) -> crate::Result { - let child = self.child.eval_checked(data_chunk)?; + async fn eval(&self, data_chunk: &DataChunk) -> crate::Result { + let child = self.child.eval_checked(data_chunk).await?; let a = child.as_bool(); let c = (self.f_array)(a); Ok(Arc::new(c.into())) } - fn eval_row(&self, row: &OwnedRow) -> crate::Result { - let datum = self.child.eval_row(row)?; + async fn eval_row(&self, row: &OwnedRow) -> crate::Result { + let datum = self.child.eval_row(row).await?; let scalar = datum.map(|s| *s.as_bool()); let output_scalar = (self.f_value)(scalar); let output_datum = output_scalar.map(|s| s.to_scalar_value()); @@ -127,6 +128,7 @@ where } } +#[async_trait::async_trait] impl Expression for BooleanBinaryExpression where FA: Fn(&BoolArray, &BoolArray) -> BoolArray + Send + Sync, @@ -136,18 +138,18 @@ where DataType::Boolean } - fn eval(&self, data_chunk: &DataChunk) -> crate::Result { - let left = self.left.eval_checked(data_chunk)?; - let right = self.right.eval_checked(data_chunk)?; + async fn eval(&self, data_chunk: &DataChunk) -> crate::Result { + let left = self.left.eval_checked(data_chunk).await?; + let right = self.right.eval_checked(data_chunk).await?; let a = left.as_bool(); let b = right.as_bool(); let c = (self.f_array)(a, b); Ok(Arc::new(c.into())) } - fn eval_row(&self, row: &OwnedRow) -> crate::Result { - let left = self.left.eval_row(row)?.map(|s| *s.as_bool()); - let right = self.right.eval_row(row)?.map(|s| *s.as_bool()); + async fn eval_row(&self, row: &OwnedRow) -> crate::Result { + let left = self.left.eval_row(row).await?.map(|s| *s.as_bool()); + let right = self.right.eval_row(row).await?.map(|s| *s.as_bool()); let output_scalar = (self.f_value)(left, right); let output_datum = output_scalar.map(|s| s.to_scalar_value()); Ok(output_datum) @@ -186,6 +188,7 @@ where } } +#[async_trait::async_trait] impl Expression for UnaryExpression where F: Fn(A) -> T + Send + Sync, @@ -197,8 +200,8 @@ where self.return_type.clone() } - fn eval(&self, data_chunk: &DataChunk) -> crate::Result { - let child = self.child.eval_checked(data_chunk)?; + async fn eval(&self, data_chunk: &DataChunk) -> crate::Result { + let child = self.child.eval_checked(data_chunk).await?; let bitmap = match data_chunk.visibility() { Some(vis) => vis & child.null_bitmap(), @@ -209,8 +212,8 @@ where Ok(Arc::new(c.into())) } - fn eval_row(&self, row: &OwnedRow) -> crate::Result { - let datum = self.child.eval_row(row)?; + async fn eval_row(&self, row: &OwnedRow) -> crate::Result { + let datum = self.child.eval_row(row).await?; let scalar = datum .as_ref() .map(|s| s.as_scalar_ref_impl().try_into().unwrap()); @@ -263,6 +266,7 @@ where } } +#[async_trait::async_trait] impl Expression for BinaryExpression where F: Fn(A, B) -> T + Send + Sync, @@ -276,9 +280,9 @@ where self.return_type.clone() } - fn eval(&self, data_chunk: &DataChunk) -> crate::Result { - let left = self.left.eval_checked(data_chunk)?; - let right = self.right.eval_checked(data_chunk)?; + async fn eval(&self, data_chunk: &DataChunk) -> crate::Result { + let left = self.left.eval_checked(data_chunk).await?; + let right = self.right.eval_checked(data_chunk).await?; assert_eq!(left.len(), right.len()); let mut bitmap = match data_chunk.visibility() { @@ -298,9 +302,9 @@ where Ok(Arc::new(c.into())) } - fn eval_row(&self, row: &OwnedRow) -> crate::Result { - let datum1 = self.left.eval_row(row)?; - let datum2 = self.right.eval_row(row)?; + async fn eval_row(&self, row: &OwnedRow) -> crate::Result { + let datum1 = self.left.eval_row(row).await?; + let datum2 = self.right.eval_row(row).await?; let scalar1 = datum1 .as_ref() .map(|s| s.as_scalar_ref_impl().try_into().unwrap()); @@ -352,6 +356,7 @@ where } } +#[async_trait::async_trait] impl Expression for CompareExpression where F: Fn(A, B) -> bool + Send + Sync, @@ -364,9 +369,9 @@ where DataType::Boolean } - fn eval(&self, data_chunk: &DataChunk) -> crate::Result { - let left = self.left.eval_checked(data_chunk)?; - let right = self.right.eval_checked(data_chunk)?; + async fn eval(&self, data_chunk: &DataChunk) -> crate::Result { + let left = self.left.eval_checked(data_chunk).await?; + let right = self.right.eval_checked(data_chunk).await?; assert_eq!(left.len(), right.len()); let mut bitmap = match data_chunk.visibility() { @@ -387,9 +392,9 @@ where Ok(Arc::new(c.into())) } - fn eval_row(&self, row: &OwnedRow) -> crate::Result { - let datum1 = self.left.eval_row(row)?; - let datum2 = self.right.eval_row(row)?; + async fn eval_row(&self, row: &OwnedRow) -> crate::Result { + let datum1 = self.left.eval_row(row).await?; + let datum2 = self.right.eval_row(row).await?; let scalar1 = datum1 .as_ref() .map(|s| s.as_scalar_ref_impl().try_into().unwrap()); @@ -442,6 +447,7 @@ where } } +#[async_trait::async_trait] impl Expression for IsDistinctFromExpression where F: Fn(A, B) -> bool + Send + Sync, @@ -454,9 +460,9 @@ where DataType::Boolean } - fn eval(&self, data_chunk: &DataChunk) -> crate::Result { - let left = self.left.eval_checked(data_chunk)?; - let right = self.right.eval_checked(data_chunk)?; + async fn eval(&self, data_chunk: &DataChunk) -> crate::Result { + let left = self.left.eval_checked(data_chunk).await?; + let right = self.right.eval_checked(data_chunk).await?; assert_eq!(left.len(), right.len()); let a: &PrimitiveArray = (&*left).into(); @@ -477,9 +483,9 @@ where Ok(Arc::new(c.into())) } - fn eval_row(&self, row: &OwnedRow) -> crate::Result { - let datum1 = self.left.eval_row(row)?; - let datum2 = self.right.eval_row(row)?; + async fn eval_row(&self, row: &OwnedRow) -> crate::Result { + let datum1 = self.left.eval_row(row).await?; + let datum2 = self.right.eval_row(row).await?; let scalar1 = datum1 .as_ref() .map(|s| s.as_scalar_ref_impl().try_into().unwrap()); diff --git a/src/expr/src/table_function/generate_series.rs b/src/expr/src/table_function/generate_series.rs index 5c7303a6b2c0e..96e67ffe7c57f 100644 --- a/src/expr/src/table_function/generate_series.rs +++ b/src/expr/src/table_function/generate_series.rs @@ -93,6 +93,7 @@ where } } +#[async_trait::async_trait] impl TableFunction for GenerateSeries where @@ -106,13 +107,13 @@ where self.start.return_type() } - fn eval(&self, input: &DataChunk) -> Result> { - let ret_start = self.start.eval_checked(input)?; + async fn eval(&self, input: &DataChunk) -> Result> { + let ret_start = self.start.eval_checked(input).await?; let arr_start: &T = ret_start.as_ref().into(); - let ret_stop = self.stop.eval_checked(input)?; + let ret_stop = self.stop.eval_checked(input).await?; let arr_stop: &T = ret_stop.as_ref().into(); - let ret_step = self.step.eval_checked(input)?; + let ret_step = self.step.eval_checked(input).await?; let arr_step: &S = ret_step.as_ref().into(); let bitmap = input.visibility(); @@ -188,15 +189,15 @@ mod tests { const CHUNK_SIZE: usize = 1024; - #[test] - fn test_generate_i32_series() { - generate_series_test_case(2, 4, 1); - generate_series_test_case(4, 2, -1); - generate_series_test_case(0, 9, 2); - generate_series_test_case(0, (CHUNK_SIZE * 2 + 3) as i32, 1); + #[tokio::test] + async fn test_generate_i32_series() { + generate_series_test_case(2, 4, 1).await; + generate_series_test_case(4, 2, -1).await; + generate_series_test_case(0, 9, 2).await; + generate_series_test_case(0, (CHUNK_SIZE * 2 + 3) as i32, 1).await; } - fn generate_series_test_case(start: i32, stop: i32, step: i32) { + async fn generate_series_test_case(start: i32, stop: i32, step: i32) { fn to_lit_expr(v: i32) -> BoxedExpression { LiteralExpression::new(DataType::Int32, Some(v.into())).boxed() } @@ -211,26 +212,27 @@ mod tests { let expect_cnt = ((stop - start) / step + 1) as usize; let dummy_chunk = DataChunk::new_dummy(1); - let arrays = function.eval(&dummy_chunk).unwrap(); + let arrays = function.eval(&dummy_chunk).await.unwrap(); let cnt: usize = arrays.iter().map(|a| a.len()).sum(); assert_eq!(cnt, expect_cnt); } - #[test] - fn test_generate_time_series() { + #[tokio::test] + async fn test_generate_time_series() { let start_time = str_to_timestamp("2008-03-01 00:00:00").unwrap(); let stop_time = str_to_timestamp("2008-03-09 00:00:00").unwrap(); let one_minute_step = IntervalUnit::from_minutes(1); let one_hour_step = IntervalUnit::from_minutes(60); let one_day_step = IntervalUnit::from_days(1); - generate_time_series_test_case(start_time, stop_time, one_minute_step, 60 * 24 * 8 + 1); - generate_time_series_test_case(start_time, stop_time, one_hour_step, 24 * 8 + 1); - generate_time_series_test_case(start_time, stop_time, one_day_step, 8 + 1); - generate_time_series_test_case(stop_time, start_time, -one_day_step, 8 + 1); + generate_time_series_test_case(start_time, stop_time, one_minute_step, 60 * 24 * 8 + 1) + .await; + generate_time_series_test_case(start_time, stop_time, one_hour_step, 24 * 8 + 1).await; + generate_time_series_test_case(start_time, stop_time, one_day_step, 8 + 1).await; + generate_time_series_test_case(stop_time, start_time, -one_day_step, 8 + 1).await; } - fn generate_time_series_test_case( + async fn generate_time_series_test_case( start: NaiveDateTimeWrapper, stop: NaiveDateTimeWrapper, step: IntervalUnit, @@ -248,21 +250,21 @@ mod tests { ); let dummy_chunk = DataChunk::new_dummy(1); - let arrays = function.eval(&dummy_chunk).unwrap(); + let arrays = function.eval(&dummy_chunk).await.unwrap(); let cnt: usize = arrays.iter().map(|a| a.len()).sum(); assert_eq!(cnt, expect_cnt); } - #[test] - fn test_i32_range() { - range_test_case(2, 4, 1); - range_test_case(4, 2, -1); - range_test_case(0, 9, 2); - range_test_case(0, (CHUNK_SIZE * 2 + 3) as i32, 1); + #[tokio::test] + async fn test_i32_range() { + range_test_case(2, 4, 1).await; + range_test_case(4, 2, -1).await; + range_test_case(0, 9, 2).await; + range_test_case(0, (CHUNK_SIZE * 2 + 3) as i32, 1).await; } - fn range_test_case(start: i32, stop: i32, step: i32) { + async fn range_test_case(start: i32, stop: i32, step: i32) { fn to_lit_expr(v: i32) -> BoxedExpression { LiteralExpression::new(DataType::Int32, Some(v.into())).boxed() } @@ -277,26 +279,26 @@ mod tests { let expect_cnt = ((stop - start - step.signum()) / step + 1) as usize; let dummy_chunk = DataChunk::new_dummy(1); - let arrays = function.eval(&dummy_chunk).unwrap(); + let arrays = function.eval(&dummy_chunk).await.unwrap(); let cnt: usize = arrays.iter().map(|a| a.len()).sum(); assert_eq!(cnt, expect_cnt); } - #[test] - fn test_time_range() { + #[tokio::test] + async fn test_time_range() { let start_time = str_to_timestamp("2008-03-01 00:00:00").unwrap(); let stop_time = str_to_timestamp("2008-03-09 00:00:00").unwrap(); let one_minute_step = IntervalUnit::from_minutes(1); let one_hour_step = IntervalUnit::from_minutes(60); let one_day_step = IntervalUnit::from_days(1); - time_range_test_case(start_time, stop_time, one_minute_step, 60 * 24 * 8); - time_range_test_case(start_time, stop_time, one_hour_step, 24 * 8); - time_range_test_case(start_time, stop_time, one_day_step, 8); - time_range_test_case(stop_time, start_time, -one_day_step, 8); + time_range_test_case(start_time, stop_time, one_minute_step, 60 * 24 * 8).await; + time_range_test_case(start_time, stop_time, one_hour_step, 24 * 8).await; + time_range_test_case(start_time, stop_time, one_day_step, 8).await; + time_range_test_case(stop_time, start_time, -one_day_step, 8).await; } - fn time_range_test_case( + async fn time_range_test_case( start: NaiveDateTimeWrapper, stop: NaiveDateTimeWrapper, step: IntervalUnit, @@ -314,7 +316,7 @@ mod tests { ); let dummy_chunk = DataChunk::new_dummy(1); - let arrays = function.eval(&dummy_chunk).unwrap(); + let arrays = function.eval(&dummy_chunk).await.unwrap(); let cnt: usize = arrays.iter().map(|a| a.len()).sum(); assert_eq!(cnt, expect_cnt); diff --git a/src/expr/src/table_function/mod.rs b/src/expr/src/table_function/mod.rs index 169bcdd74b61c..fa947f80dc2fc 100644 --- a/src/expr/src/table_function/mod.rs +++ b/src/expr/src/table_function/mod.rs @@ -18,7 +18,7 @@ use either::Either; use itertools::Itertools; use risingwave_common::array::{ArrayRef, DataChunk}; use risingwave_common::types::DataType; -use risingwave_pb::expr::project_set_select_item::SelectItem::*; +use risingwave_pb::expr::project_set_select_item::SelectItem; use risingwave_pb::expr::{ ProjectSetSelectItem as SelectItemProst, TableFunction as TableFunctionProst, }; @@ -40,10 +40,11 @@ use self::user_defined::*; /// /// A table function takes a row as input and returns a table. It is also known as Set-Returning /// Function. +#[async_trait::async_trait] pub trait TableFunction: std::fmt::Debug + Sync + Send { fn return_type(&self) -> DataType; - fn eval(&self, input: &DataChunk) -> Result>; + async fn eval(&self, input: &DataChunk) -> Result>; fn boxed(self) -> BoxedTableFunction where @@ -84,13 +85,14 @@ pub fn repeat_tf(expr: BoxedExpression, n: usize) -> BoxedTableFunction { n: usize, } + #[async_trait::async_trait] impl TableFunction for Mock { fn return_type(&self) -> DataType { self.expr.return_type() } - fn eval(&self, input: &DataChunk) -> Result> { - let array = self.expr.eval(input)?; + async fn eval(&self, input: &DataChunk) -> Result> { + let array = self.expr.eval(input).await?; let mut res = vec![]; for datum_ref in array.iter() { @@ -130,8 +132,8 @@ impl From for ProjectSetSelectItem { impl ProjectSetSelectItem { pub fn from_prost(prost: &SelectItemProst, chunk_size: usize) -> Result { match prost.select_item.as_ref().unwrap() { - Expr(expr) => expr_build_from_prost(expr).map(Into::into), - TableFunction(tf) => build_from_prost(tf, chunk_size).map(Into::into), + SelectItem::Expr(expr) => expr_build_from_prost(expr).map(Into::into), + SelectItem::TableFunction(tf) => build_from_prost(tf, chunk_size).map(Into::into), } } @@ -142,10 +144,10 @@ impl ProjectSetSelectItem { } } - pub fn eval(&self, input: &DataChunk) -> Result, ArrayRef>> { + pub async fn eval(&self, input: &DataChunk) -> Result, ArrayRef>> { match self { - ProjectSetSelectItem::TableFunction(tf) => tf.eval(input).map(Either::Left), - ProjectSetSelectItem::Expr(expr) => expr.eval(input).map(Either::Right), + ProjectSetSelectItem::TableFunction(tf) => tf.eval(input).await.map(Either::Left), + ProjectSetSelectItem::Expr(expr) => expr.eval(input).await.map(Either::Right), } } } diff --git a/src/expr/src/table_function/regexp_matches.rs b/src/expr/src/table_function/regexp_matches.rs index aa459de433d51..5347e2aa532d8 100644 --- a/src/expr/src/table_function/regexp_matches.rs +++ b/src/expr/src/table_function/regexp_matches.rs @@ -58,6 +58,7 @@ impl RegexpMatches { } } +#[async_trait::async_trait] impl TableFunction for RegexpMatches { fn return_type(&self) -> DataType { DataType::List { @@ -65,8 +66,8 @@ impl TableFunction for RegexpMatches { } } - fn eval(&self, input: &DataChunk) -> Result> { - let text_arr = self.text.eval_checked(input)?; + async fn eval(&self, input: &DataChunk) -> Result> { + let text_arr = self.text.eval_checked(input).await?; let text_arr: &Utf8Array = text_arr.as_ref().into(); let bitmap = input.visibility(); diff --git a/src/expr/src/table_function/unnest.rs b/src/expr/src/table_function/unnest.rs index c43463c250ece..ef3fe02f168a6 100644 --- a/src/expr/src/table_function/unnest.rs +++ b/src/expr/src/table_function/unnest.rs @@ -36,13 +36,14 @@ impl Unnest { } } +#[async_trait::async_trait] impl TableFunction for Unnest { fn return_type(&self) -> DataType { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> Result> { - let ret_list = self.list.eval_checked(input)?; + async fn eval(&self, input: &DataChunk) -> Result> { + let ret_list = self.list.eval_checked(input).await?; let arr_list: &ListArray = ret_list.as_ref().into(); let bitmap = input.visibility(); diff --git a/src/expr/src/table_function/user_defined.rs b/src/expr/src/table_function/user_defined.rs index 26b0c090f19bd..38a5f7be2d16b 100644 --- a/src/expr/src/table_function/user_defined.rs +++ b/src/expr/src/table_function/user_defined.rs @@ -33,17 +33,19 @@ pub struct UserDefinedTableFunction { } #[cfg(not(madsim))] +#[async_trait::async_trait] impl TableFunction for UserDefinedTableFunction { fn return_type(&self) -> DataType { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> Result> { - let columns: Vec<_> = self - .children - .iter() - .map(|c| c.eval_checked(input).map(|a| a.as_ref().into())) - .try_collect()?; + async fn eval(&self, input: &DataChunk) -> Result> { + let mut columns = Vec::with_capacity(self.children.len()); + for c in &self.children { + let val = c.eval_checked(input).await?.as_ref().into(); + columns.push(val); + } + let opts = arrow_array::RecordBatchOptions::default().with_row_count(Some(input.cardinality())); let input = @@ -93,12 +95,13 @@ pub fn new_user_defined( } #[cfg(madsim)] +#[async_trait::async_trait] impl TableFunction for UserDefinedTableFunction { fn return_type(&self) -> DataType { panic!("UDF is not supported in simulation yet"); } - fn eval(&self, _input: &DataChunk) -> Result> { + async fn eval(&self, _input: &DataChunk) -> Result> { panic!("UDF is not supported in simulation yet"); } } diff --git a/src/expr/src/vector_op/agg/aggregator.rs b/src/expr/src/vector_op/agg/aggregator.rs index 73b589b67831b..aff68f1e930ce 100644 --- a/src/expr/src/vector_op/agg/aggregator.rs +++ b/src/expr/src/vector_op/agg/aggregator.rs @@ -33,14 +33,15 @@ use crate::vector_op::agg::string_agg::create_string_agg_state; use crate::Result; /// An `Aggregator` supports `update` data and `output` result. +#[async_trait::async_trait] pub trait Aggregator: Send + DynClone + 'static { fn return_type(&self) -> DataType; /// `update_single` update the aggregator with a single row with type checked at runtime. - fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()>; + async fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()>; /// `update_multi` update the aggregator with multiple rows with type checked at runtime. - fn update_multi( + async fn update_multi( &mut self, input: &DataChunk, start_row_id: usize, diff --git a/src/expr/src/vector_op/agg/approx_count_distinct.rs b/src/expr/src/vector_op/agg/approx_count_distinct.rs index fde8e04c545a2..62e64b22439d6 100644 --- a/src/expr/src/vector_op/agg/approx_count_distinct.rs +++ b/src/expr/src/vector_op/agg/approx_count_distinct.rs @@ -118,18 +118,19 @@ impl ApproxCountDistinct { } } +#[async_trait::async_trait] impl Aggregator for ApproxCountDistinct { fn return_type(&self) -> DataType { self.return_type.clone() } - fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { + async fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { let array = input.column_at(self.input_col_idx).array_ref(); self.add_datum(array.value_at(row_id)); Ok(()) } - fn update_multi( + async fn update_multi( &mut self, input: &DataChunk, start_row_id: usize, @@ -176,8 +177,8 @@ mod tests { DataChunk::new(vec![col1], size) } - #[test] - fn test_update_single() { + #[tokio::test] + async fn test_update_single() { let inputs_size: [usize; 3] = [20000, 10000, 5000]; let inputs_start: [i32; 3] = [0, 20000, 30000]; @@ -187,7 +188,7 @@ mod tests { for i in 0..3 { let data_chunk = generate_data_chunk(inputs_size[i], inputs_start[i]); for row_id in 0..data_chunk.cardinality() { - agg.update_single(&data_chunk, row_id).unwrap(); + agg.update_single(&data_chunk, row_id).await.unwrap(); } agg.output(&mut builder).unwrap(); } @@ -196,8 +197,8 @@ mod tests { assert_eq!(array.len(), 3); } - #[test] - fn test_update_multi() { + #[tokio::test] + async fn test_update_multi() { let inputs_size: [usize; 3] = [20000, 10000, 5000]; let inputs_start: [i32; 3] = [0, 20000, 30000]; @@ -207,6 +208,7 @@ mod tests { for i in 0..3 { let data_chunk = generate_data_chunk(inputs_size[i], inputs_start[i]); agg.update_multi(&data_chunk, 0, data_chunk.cardinality()) + .await .unwrap(); agg.output(&mut builder).unwrap(); } diff --git a/src/expr/src/vector_op/agg/array_agg.rs b/src/expr/src/vector_op/agg/array_agg.rs index 4910e009362ec..f810149b74ab9 100644 --- a/src/expr/src/vector_op/agg/array_agg.rs +++ b/src/expr/src/vector_op/agg/array_agg.rs @@ -52,18 +52,19 @@ impl ArrayAggUnordered { } } +#[async_trait::async_trait] impl Aggregator for ArrayAggUnordered { fn return_type(&self) -> DataType { self.return_type.clone() } - fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { + async fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { let array = input.column_at(self.agg_col_idx).array_ref(); self.push(array.datum_at(row_id)); Ok(()) } - fn update_multi( + async fn update_multi( &mut self, input: &DataChunk, start_row_id: usize, @@ -71,7 +72,7 @@ impl Aggregator for ArrayAggUnordered { ) -> Result<()> { self.values.reserve(end_row_id - start_row_id); for row_id in start_row_id..end_row_id { - self.update_single(input, row_id)?; + self.update_single(input, row_id).await?; } Ok(()) } @@ -131,19 +132,20 @@ impl ArrayAggOrdered { } } +#[async_trait::async_trait] impl Aggregator for ArrayAggOrdered { fn return_type(&self) -> DataType { self.return_type.clone() } - fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { + async fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { let (row, vis) = input.row_at(row_id); assert!(vis); self.push_row(row); Ok(()) } - fn update_multi( + async fn update_multi( &mut self, input: &DataChunk, start_row_id: usize, @@ -151,7 +153,7 @@ impl Aggregator for ArrayAggOrdered { ) -> Result<()> { self.unordered_values.reserve(end_row_id - start_row_id); for row_id in start_row_id..end_row_id { - self.update_single(input, row_id)?; + self.update_single(input, row_id).await?; } Ok(()) } @@ -191,8 +193,8 @@ mod tests { use super::*; - #[test] - fn test_array_agg_basic() -> Result<()> { + #[tokio::test] + async fn test_array_agg_basic() -> Result<()> { let chunk = DataChunk::from_pretty( "i 123 @@ -204,7 +206,7 @@ mod tests { }; let mut agg = create_array_agg_state(return_type.clone(), 0, vec![])?; let mut builder = return_type.create_array_builder(0); - agg.update_multi(&chunk, 0, chunk.cardinality())?; + agg.update_multi(&chunk, 0, chunk.cardinality()).await?; agg.output(&mut builder)?; let output = builder.finish(); let actual = output.into_list(); @@ -223,8 +225,8 @@ mod tests { Ok(()) } - #[test] - fn test_array_agg_empty() -> Result<()> { + #[tokio::test] + async fn test_array_agg_empty() -> Result<()> { let return_type = DataType::List { datatype: Box::new(DataType::Int32), }; @@ -245,7 +247,7 @@ mod tests { .", ); let mut builder = return_type.create_array_builder(0); - agg.update_multi(&chunk, 0, chunk.cardinality())?; + agg.update_multi(&chunk, 0, chunk.cardinality()).await?; agg.output(&mut builder)?; let output = builder.finish(); let actual = output.into_list(); @@ -258,8 +260,8 @@ mod tests { Ok(()) } - #[test] - fn test_array_agg_with_order() -> Result<()> { + #[tokio::test] + async fn test_array_agg_with_order() -> Result<()> { let chunk = DataChunk::from_pretty( "i i 123 3 @@ -279,7 +281,7 @@ mod tests { ], )?; let mut builder = return_type.create_array_builder(0); - agg.update_multi(&chunk, 0, chunk.cardinality())?; + agg.update_multi(&chunk, 0, chunk.cardinality()).await?; agg.output(&mut builder)?; let output = builder.finish(); let actual = output.into_list(); diff --git a/src/expr/src/vector_op/agg/count_star.rs b/src/expr/src/vector_op/agg/count_star.rs index 12021cd69c548..7c67f6ef06508 100644 --- a/src/expr/src/vector_op/agg/count_star.rs +++ b/src/expr/src/vector_op/agg/count_star.rs @@ -34,19 +34,20 @@ impl CountStar { } } +#[async_trait::async_trait] impl Aggregator for CountStar { fn return_type(&self) -> DataType { self.return_type.clone() } - fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { + async fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { if let (_, true) = input.row_at(row_id) { self.result += 1; } Ok(()) } - fn update_multi( + async fn update_multi( &mut self, input: &DataChunk, start_row_id: usize, diff --git a/src/expr/src/vector_op/agg/filter.rs b/src/expr/src/vector_op/agg/filter.rs index 018bd4bdb9062..4295d9d3a6c73 100644 --- a/src/expr/src/vector_op/agg/filter.rs +++ b/src/expr/src/vector_op/agg/filter.rs @@ -13,7 +13,7 @@ // limitations under the License. use risingwave_common::array::{ArrayBuilderImpl, DataChunk}; -use risingwave_common::buffer::Bitmap; +use risingwave_common::buffer::BitmapBuilder; use risingwave_common::row::Row; use risingwave_common::types::{DataType, ScalarImpl}; @@ -37,26 +37,28 @@ impl Filter { } } +#[async_trait::async_trait] impl Aggregator for Filter { fn return_type(&self) -> DataType { self.inner.return_type() } - fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { + async fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { let (row_ref, vis) = input.row_at(row_id); assert!(vis); // cuz the input chunk is supposed to be compacted if self .condition - .eval_row(&row_ref.into_owned_row())? + .eval_row(&row_ref.into_owned_row()) + .await? .map(ScalarImpl::into_bool) .unwrap_or(false) { - self.inner.update_single(input, row_id)?; + self.inner.update_single(input, row_id).await?; } Ok(()) } - fn update_multi( + async fn update_multi( &mut self, input: &DataChunk, start_row_id: usize, @@ -64,25 +66,29 @@ impl Aggregator for Filter { ) -> Result<()> { let bitmap = if start_row_id == 0 && end_row_id == input.capacity() { // if the input if the whole chunk, use `eval` to speed up - self.condition.eval(input)?.as_bool().to_bitmap() + self.condition.eval(input).await?.as_bool().to_bitmap() } else { + let mut bitmap_builder = BitmapBuilder::default(); // otherwise, run `eval_row` on each row - (start_row_id..end_row_id) - .map(|row_id| -> Result { - let (row_ref, vis) = input.row_at(row_id); - assert!(vis); // cuz the input chunk is supposed to be compacted - Ok(self - .condition - .eval_row(&row_ref.into_owned_row())? - .map(ScalarImpl::into_bool) - .unwrap_or(false)) - }) - .try_collect::()? + for row_id in start_row_id..end_row_id { + let (row_ref, vis) = input.row_at(row_id); + assert!(vis); // cuz the input chunk is supposed to be compacted + let b = self + .condition + .eval_row(&row_ref.into_owned_row()) + .await? + .map(ScalarImpl::into_bool) + .unwrap_or(false); + bitmap_builder.append(b); + } + bitmap_builder.finish() }; if bitmap.all() { // if the bitmap is all set, meaning all rows satisfy the filter, // call `update_multi` for potential optimization - self.inner.update_multi(input, start_row_id, end_row_id) + self.inner + .update_multi(input, start_row_id, end_row_id) + .await } else { // TODO(yuchao): we might want to pass visibility bitmap to the // inner aggregator, or re-compact the input chunk after filtering. @@ -91,7 +97,7 @@ impl Aggregator for Filter { .enumerate() .filter(|(i, _)| bitmap.is_set(*i)) { - self.inner.update_single(input, row_id)?; + self.inner.update_single(input, row_id).await?; } Ok(()) } @@ -118,17 +124,18 @@ mod tests { count: Arc, } + #[async_trait::async_trait] impl Aggregator for MockAgg { fn return_type(&self) -> DataType { DataType::Int64 } - fn update_single(&mut self, _input: &DataChunk, _row_id: usize) -> Result<()> { + async fn update_single(&mut self, _input: &DataChunk, _row_id: usize) -> Result<()> { self.count.fetch_add(1, Ordering::Relaxed); Ok(()) } - fn update_multi( + async fn update_multi( &mut self, _input: &DataChunk, start_row_id: usize, @@ -144,8 +151,8 @@ mod tests { } } - #[test] - fn test_selective_agg_always_true() -> Result<()> { + #[tokio::test] + async fn test_selective_agg_always_true() -> Result<()> { let condition = Arc::from(LiteralExpression::new(DataType::Boolean, Some(true.into())).boxed()); let agg_count = Arc::new(AtomicUsize::new(0)); @@ -164,20 +171,20 @@ mod tests { 1", ); - agg.update_single(&chunk, 0)?; + agg.update_single(&chunk, 0).await?; assert_eq!(agg_count.load(Ordering::Relaxed), 1); - agg.update_multi(&chunk, 2, 4)?; + agg.update_multi(&chunk, 2, 4).await?; assert_eq!(agg_count.load(Ordering::Relaxed), 3); - agg.update_multi(&chunk, 0, chunk.capacity())?; + agg.update_multi(&chunk, 0, chunk.capacity()).await?; assert_eq!(agg_count.load(Ordering::Relaxed), 7); Ok(()) } - #[test] - fn test_selective_agg() -> Result<()> { + #[tokio::test] + async fn test_selective_agg() -> Result<()> { // filter (where $1 > 5) let condition = Arc::from( new_binary_expr( @@ -204,23 +211,23 @@ mod tests { 1", ); - agg.update_single(&chunk, 0)?; + agg.update_single(&chunk, 0).await?; assert_eq!(agg_count.load(Ordering::Relaxed), 1); - agg.update_single(&chunk, 1)?; // should be filtered out + agg.update_single(&chunk, 1).await?; // should be filtered out assert_eq!(agg_count.load(Ordering::Relaxed), 1); - agg.update_multi(&chunk, 2, 4)?; // only 6 should be applied + agg.update_multi(&chunk, 2, 4).await?; // only 6 should be applied assert_eq!(agg_count.load(Ordering::Relaxed), 2); - agg.update_multi(&chunk, 0, chunk.capacity())?; + agg.update_multi(&chunk, 0, chunk.capacity()).await?; assert_eq!(agg_count.load(Ordering::Relaxed), 4); Ok(()) } - #[test] - fn test_selective_agg_null_condition() -> Result<()> { + #[tokio::test] + async fn test_selective_agg_null_condition() -> Result<()> { let condition = Arc::from( new_binary_expr( ProstType::Equal, @@ -246,13 +253,13 @@ mod tests { 1", ); - agg.update_single(&chunk, 0)?; + agg.update_single(&chunk, 0).await?; assert_eq!(agg_count.load(Ordering::Relaxed), 0); - agg.update_multi(&chunk, 2, 4)?; + agg.update_multi(&chunk, 2, 4).await?; assert_eq!(agg_count.load(Ordering::Relaxed), 0); - agg.update_multi(&chunk, 0, chunk.capacity())?; + agg.update_multi(&chunk, 0, chunk.capacity()).await?; assert_eq!(agg_count.load(Ordering::Relaxed), 0); Ok(()) diff --git a/src/expr/src/vector_op/agg/general_agg.rs b/src/expr/src/vector_op/agg/general_agg.rs index c9e81644e6412..d68f8f97ac9ff 100644 --- a/src/expr/src/vector_op/agg/general_agg.rs +++ b/src/expr/src/vector_op/agg/general_agg.rs @@ -94,6 +94,7 @@ where macro_rules! impl_aggregator { ($input:ty, $input_variant:ident, $result:ty, $result_variant:ident) => { + #[async_trait::async_trait] impl Aggregator for GeneralAgg<$input, F, $result> where F: for<'a> RTFn<'a, $input, $result>, @@ -102,7 +103,7 @@ macro_rules! impl_aggregator { self.return_type.clone() } - fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { + async fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { if let ArrayImpl::$input_variant(i) = input.column_at(self.input_col_idx).array_ref() { @@ -112,7 +113,7 @@ macro_rules! impl_aggregator { } } - fn update_multi( + async fn update_multi( &mut self, input: &DataChunk, start_row_id: usize, @@ -187,7 +188,7 @@ mod tests { use crate::expr::AggKind; use crate::vector_op::agg::aggregator::create_agg_state_unary; - fn eval_agg( + async fn eval_agg( input_type: DataType, input: ArrayRef, agg_kind: AggKind, @@ -197,13 +198,15 @@ mod tests { let len = input.len(); let input_chunk = DataChunk::new(vec![Column::new(input)], len); let mut agg_state = create_agg_state_unary(input_type, 0, agg_kind, return_type, false)?; - agg_state.update_multi(&input_chunk, 0, input_chunk.cardinality())?; + agg_state + .update_multi(&input_chunk, 0, input_chunk.cardinality()) + .await?; agg_state.output(&mut builder)?; Ok(builder.finish()) } - #[test] - fn vec_sum_int32() -> Result<()> { + #[tokio::test] + async fn vec_sum_int32() -> Result<()> { let input = I32Array::from_iter([1, 2, 3]); let agg_kind = AggKind::Sum; let input_type = DataType::Int32; @@ -214,15 +217,16 @@ mod tests { agg_kind, return_type, ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0)), - )?; + ) + .await?; let actual = actual.as_int64(); let actual = actual.iter().collect::>(); assert_eq!(actual, &[Some(6)]); Ok(()) } - #[test] - fn vec_sum_int64() -> Result<()> { + #[tokio::test] + async fn vec_sum_int64() -> Result<()> { let input = I64Array::from_iter([1, 2, 3]); let agg_kind = AggKind::Sum; let input_type = DataType::Int64; @@ -233,15 +237,16 @@ mod tests { agg_kind, return_type, DecimalArrayBuilder::new(0).into(), - )?; + ) + .await?; let actual: DecimalArray = actual.into(); let actual = actual.iter().collect::>>(); assert_eq!(actual, vec![Some(Decimal::from(6))]); Ok(()) } - #[test] - fn vec_min_float32() -> Result<()> { + #[tokio::test] + async fn vec_min_float32() -> Result<()> { let input = F32Array::from_iter([Some(1.0.into()), Some(2.0.into()), Some(3.0.into())]); let agg_kind = AggKind::Min; let input_type = DataType::Float32; @@ -252,15 +257,16 @@ mod tests { agg_kind, return_type, ArrayBuilderImpl::Float32(F32ArrayBuilder::new(0)), - )?; + ) + .await?; let actual = actual.as_float32(); let actual = actual.iter().collect::>(); assert_eq!(actual, &[Some(1.0.into())]); Ok(()) } - #[test] - fn vec_min_char() -> Result<()> { + #[tokio::test] + async fn vec_min_char() -> Result<()> { let input = Utf8Array::from_iter(["b", "aa"]); let agg_kind = AggKind::Min; let input_type = DataType::Varchar; @@ -271,15 +277,16 @@ mod tests { agg_kind, return_type, ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)), - )?; + ) + .await?; let actual = actual.as_utf8(); let actual = actual.iter().collect::>(); assert_eq!(actual, vec![Some("aa")]); Ok(()) } - #[test] - fn vec_min_list() -> Result<()> { + #[tokio::test] + async fn vec_min_list() -> Result<()> { use risingwave_common::array; let input = ListArray::from_iter( [ @@ -307,7 +314,8 @@ mod tests { datatype: Box::new(DataType::Int32), }, )), - )?; + ) + .await?; let actual = actual.as_list(); let actual = actual.iter().collect::>(); assert_eq!( @@ -319,8 +327,8 @@ mod tests { Ok(()) } - #[test] - fn vec_max_char() -> Result<()> { + #[tokio::test] + async fn vec_max_char() -> Result<()> { let input = Utf8Array::from_iter(["b", "aa"]); let agg_kind = AggKind::Max; let input_type = DataType::Varchar; @@ -331,16 +339,17 @@ mod tests { agg_kind, return_type, ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)), - )?; + ) + .await?; let actual = actual.as_utf8(); let actual = actual.iter().collect::>(); assert_eq!(actual, vec![Some("b")]); Ok(()) } - #[test] - fn vec_count_int32() -> Result<()> { - let test_case = |input: ArrayImpl, expected: &[Option]| -> Result<()> { + #[tokio::test] + async fn vec_count_int32() -> Result<()> { + async fn test_case(input: ArrayImpl, expected: &[Option]) -> Result<()> { let agg_kind = AggKind::Count; let input_type = DataType::Int32; let return_type = DataType::Int64; @@ -350,21 +359,22 @@ mod tests { agg_kind, return_type, ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0)), - )?; + ) + .await?; let actual = actual.as_int64(); let actual = actual.iter().collect::>(); assert_eq!(actual, expected); Ok(()) - }; + } let input = I32Array::from_iter([1, 2, 3]); let expected = &[Some(3)]; - test_case(input.into(), expected)?; + test_case(input.into(), expected).await?; #[allow(clippy::needless_borrow)] let input = I32Array::from_iter(&[]); let expected = &[Some(0)]; - test_case(input.into(), expected)?; + test_case(input.into(), expected).await?; let input = I32Array::from_iter([None]); let expected = &[Some(0)]; - test_case(input.into(), expected) + test_case(input.into(), expected).await } } diff --git a/src/expr/src/vector_op/agg/general_distinct_agg.rs b/src/expr/src/vector_op/agg/general_distinct_agg.rs index 97d68565f726c..298b5d9d06817 100644 --- a/src/expr/src/vector_op/agg/general_distinct_agg.rs +++ b/src/expr/src/vector_op/agg/general_distinct_agg.rs @@ -109,6 +109,7 @@ where macro_rules! impl_aggregator { ($input:ty, $input_variant:ident, $result:ty, $result_variant:ident) => { + #[async_trait::async_trait] impl Aggregator for GeneralDistinctAgg<$input, F, $result> where F: for<'a> RTFn<'a, $input, $result>, @@ -117,7 +118,7 @@ macro_rules! impl_aggregator { self.return_type.clone() } - fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { + async fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { if let ArrayImpl::$input_variant(i) = input.column_at(self.input_col_idx).array_ref() { @@ -127,7 +128,7 @@ macro_rules! impl_aggregator { } } - fn update_multi( + async fn update_multi( &mut self, input: &DataChunk, start_row_id: usize, @@ -198,7 +199,7 @@ mod tests { use crate::expr::AggKind; use crate::vector_op::agg::aggregator::create_agg_state_unary; - fn eval_agg( + async fn eval_agg( input_type: DataType, input: ArrayRef, agg_kind: AggKind, @@ -208,13 +209,15 @@ mod tests { let len = input.len(); let input_chunk = DataChunk::new(vec![Column::new(input)], len); let mut agg_state = create_agg_state_unary(input_type, 0, agg_kind, return_type, true)?; - agg_state.update_multi(&input_chunk, 0, input_chunk.cardinality())?; + agg_state + .update_multi(&input_chunk, 0, input_chunk.cardinality()) + .await?; agg_state.output(&mut builder)?; Ok(builder.finish()) } - #[test] - fn vec_distinct_sum_int32() -> Result<()> { + #[tokio::test] + async fn vec_distinct_sum_int32() -> Result<()> { let input = I32Array::from_iter([1, 1, 3]); let agg_kind = AggKind::Sum; let input_type = DataType::Int32; @@ -225,15 +228,16 @@ mod tests { agg_kind, return_type, ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0)), - )?; + ) + .await?; let actual = actual.as_int64(); let actual = actual.iter().collect::>(); assert_eq!(actual, &[Some(4)]); Ok(()) } - #[test] - fn vec_distinct_sum_int64() -> Result<()> { + #[tokio::test] + async fn vec_distinct_sum_int64() -> Result<()> { let input = I64Array::from_iter([1, 1, 3]); let agg_kind = AggKind::Sum; let input_type = DataType::Int64; @@ -244,15 +248,16 @@ mod tests { agg_kind, return_type, DecimalArrayBuilder::new(0).into(), - )?; + ) + .await?; let actual: &DecimalArray = (&actual).into(); let actual = actual.iter().collect::>>(); assert_eq!(actual, vec![Some(Decimal::from(4))]); Ok(()) } - #[test] - fn vec_distinct_min_float32() -> Result<()> { + #[tokio::test] + async fn vec_distinct_min_float32() -> Result<()> { let input = F32Array::from_iter([Some(1.0.into()), Some(2.0.into()), Some(3.0.into())]); let agg_kind = AggKind::Min; let input_type = DataType::Float32; @@ -263,15 +268,16 @@ mod tests { agg_kind, return_type, ArrayBuilderImpl::Float32(F32ArrayBuilder::new(0)), - )?; + ) + .await?; let actual = actual.as_float32(); let actual = actual.iter().collect::>(); assert_eq!(actual, &[Some(1.0.into())]); Ok(()) } - #[test] - fn vec_distinct_min_char() -> Result<()> { + #[tokio::test] + async fn vec_distinct_min_char() -> Result<()> { let input = Utf8Array::from_iter(["b", "aa"]); let agg_kind = AggKind::Min; let input_type = DataType::Varchar; @@ -282,15 +288,16 @@ mod tests { agg_kind, return_type, ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)), - )?; + ) + .await?; let actual = actual.as_utf8(); let actual = actual.iter().collect::>(); assert_eq!(actual, vec![Some("aa")]); Ok(()) } - #[test] - fn vec_distinct_max_char() -> Result<()> { + #[tokio::test] + async fn vec_distinct_max_char() -> Result<()> { let input = Utf8Array::from_iter(["b", "aa"]); let agg_kind = AggKind::Max; let input_type = DataType::Varchar; @@ -301,16 +308,17 @@ mod tests { agg_kind, return_type, ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)), - )?; + ) + .await?; let actual = actual.as_utf8(); let actual = actual.iter().collect::>(); assert_eq!(actual, vec![Some("b")]); Ok(()) } - #[test] - fn vec_distinct_count_int32() -> Result<()> { - let test_case = |input: ArrayImpl, expected: &[Option]| -> Result<()> { + #[tokio::test] + async fn vec_distinct_count_int32() -> Result<()> { + async fn test_case(input: ArrayImpl, expected: &[Option]) -> Result<()> { let agg_kind = AggKind::Count; let input_type = DataType::Int32; let return_type = DataType::Int64; @@ -320,21 +328,22 @@ mod tests { agg_kind, return_type, ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0)), - )?; + ) + .await?; let actual = actual.as_int64(); let actual = actual.iter().collect::>(); assert_eq!(actual, expected); Ok(()) - }; + } let input = I32Array::from_iter([1, 1, 3]); let expected = &[Some(2)]; - test_case(input.into(), expected)?; + test_case(input.into(), expected).await?; #[allow(clippy::needless_borrow)] let input = I32Array::from_iter(&[]); let expected = &[None]; - test_case(input.into(), expected)?; + test_case(input.into(), expected).await?; let input = I32Array::from_iter([None]); let expected = &[Some(0)]; - test_case(input.into(), expected) + test_case(input.into(), expected).await } } diff --git a/src/expr/src/vector_op/agg/string_agg.rs b/src/expr/src/vector_op/agg/string_agg.rs index 046d3288e86a8..bf829f34dfeab 100644 --- a/src/expr/src/vector_op/agg/string_agg.rs +++ b/src/expr/src/vector_op/agg/string_agg.rs @@ -55,12 +55,13 @@ impl StringAggUnordered { } } +#[async_trait::async_trait] impl Aggregator for StringAggUnordered { fn return_type(&self) -> DataType { DataType::Varchar } - fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { + async fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { if let (ArrayImpl::Utf8(agg_col), ArrayImpl::Utf8(delim_col)) = ( input.column_at(self.agg_col_idx).array_ref(), input.column_at(self.delim_col_idx).array_ref(), @@ -76,7 +77,7 @@ impl Aggregator for StringAggUnordered { } } - fn update_multi( + async fn update_multi( &mut self, input: &DataChunk, start_row_id: usize, @@ -172,12 +173,13 @@ impl StringAggOrdered { } } +#[async_trait::async_trait] impl Aggregator for StringAggOrdered { fn return_type(&self) -> DataType { DataType::Varchar } - fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { + async fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { if let (ArrayImpl::Utf8(agg_col), ArrayImpl::Utf8(delim_col)) = ( input.column_at(self.agg_col_idx).array_ref(), input.column_at(self.delim_col_idx).array_ref(), @@ -195,7 +197,7 @@ impl Aggregator for StringAggOrdered { } } - fn update_multi( + async fn update_multi( &mut self, input: &DataChunk, start_row_id: usize, @@ -260,8 +262,8 @@ mod tests { use super::*; - #[test] - fn test_string_agg_basic() -> Result<()> { + #[tokio::test] + async fn test_string_agg_basic() -> Result<()> { let chunk = DataChunk::from_pretty( "T T aaa , @@ -271,7 +273,7 @@ mod tests { ); let mut agg = create_string_agg_state(0, 1, vec![])?; let mut builder = ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)); - agg.update_multi(&chunk, 0, chunk.cardinality())?; + agg.update_multi(&chunk, 0, chunk.cardinality()).await?; agg.output(&mut builder)?; let output = builder.finish(); let actual = output.as_utf8(); @@ -281,8 +283,8 @@ mod tests { Ok(()) } - #[test] - fn test_string_agg_complex() -> Result<()> { + #[tokio::test] + async fn test_string_agg_complex() -> Result<()> { let chunk = DataChunk::from_pretty( "T T aaa , @@ -292,7 +294,7 @@ mod tests { ); let mut agg = create_string_agg_state(0, 1, vec![])?; let mut builder = ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)); - agg.update_multi(&chunk, 0, chunk.cardinality())?; + agg.update_multi(&chunk, 0, chunk.cardinality()).await?; agg.output(&mut builder)?; let output = builder.finish(); let actual = output.as_utf8(); @@ -302,8 +304,8 @@ mod tests { Ok(()) } - #[test] - fn test_string_agg_with_order() -> Result<()> { + #[tokio::test] + async fn test_string_agg_with_order() -> Result<()> { let chunk = DataChunk::from_pretty( "T T i i _ aaa 1 3 @@ -321,7 +323,7 @@ mod tests { ], )?; let mut builder = ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)); - agg.update_multi(&chunk, 0, chunk.cardinality())?; + agg.update_multi(&chunk, 0, chunk.cardinality()).await?; agg.output(&mut builder)?; let output = builder.finish(); let actual = output.as_utf8(); diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index fde0a471b4699..cc80dc1aad5cd 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -14,6 +14,7 @@ use enum_as_inner::EnumAsInner; use fixedbitset::FixedBitSet; +use futures::FutureExt; use paste::paste; use risingwave_common::array::ListValue; use risingwave_common::error::Result as RwResult; @@ -242,15 +243,17 @@ impl ExprImpl { /// /// TODO: This is a naive implementation. We should avoid proto ser/de. /// Tracking issue: - fn eval_row(&self, input: &OwnedRow) -> RwResult { + async fn eval_row(&self, input: &OwnedRow) -> RwResult { let backend_expr = build_from_prost(&self.to_expr_proto())?; - backend_expr.eval_row(input).map_err(Into::into) + Ok(backend_expr.eval_row(input).await?) } /// Evaluate a constant expression. pub fn eval_row_const(&self) -> RwResult { assert!(self.is_const()); self.eval_row(&OwnedRow::empty()) + .now_or_never() + .expect("constant expression should not be async") } } diff --git a/src/frontend/src/scheduler/local.rs b/src/frontend/src/scheduler/local.rs index 377ecd2bdae9c..2e165e2c047c0 100644 --- a/src/frontend/src/scheduler/local.rs +++ b/src/frontend/src/scheduler/local.rs @@ -17,7 +17,6 @@ use std::collections::HashMap; use std::sync::Arc; use anyhow::Context; -use futures::executor::block_on; use futures::StreamExt; use futures_async_stream::try_stream; use itertools::Itertools; @@ -41,7 +40,6 @@ use risingwave_pb::batch_plan::{ }; use risingwave_pb::common::WorkerNode; use tokio::sync::mpsc; -use tokio::task::spawn_blocking; use tokio_stream::wrappers::ReceiverStream; use tracing::debug; use uuid::Uuid; @@ -138,11 +136,10 @@ impl LocalQueryExecution { } }; - if cfg!(madsim) { - tokio::spawn(future); - } else { - spawn_blocking(move || block_on(future)); - } + #[cfg(madsim)] + tokio::spawn(future); + #[cfg(not(madsim))] + tokio::task::spawn_blocking(move || futures::executor::block_on(future)); ReceiverStream::new(receiver) } diff --git a/src/stream/src/common/infallible_expr.rs b/src/stream/src/common/infallible_expr.rs deleted file mode 100644 index 2742c5960cb76..0000000000000 --- a/src/stream/src/common/infallible_expr.rs +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use risingwave_common::array::{ArrayRef, DataChunk}; -use risingwave_common::row::{OwnedRow, Row}; -use risingwave_common::types::Datum; -use risingwave_expr::expr::Expression; -use risingwave_expr::ExprError; -use static_assertions::const_assert; - -pub trait InfallibleExpression: Expression { - fn eval_infallible(&self, input: &DataChunk, on_err: impl Fn(ExprError)) -> ArrayRef { - const_assert!(!crate::STRICT_MODE); - - #[expect(clippy::disallowed_methods)] - self.eval(input).unwrap_or_else(|_err| { - // When eval failed, recompute in row-based execution - // and pad with NULL for each failed row. - let mut array_builder = self.return_type().create_array_builder(input.cardinality()); - for row in input.rows_with_holes() { - if let Some(row) = row { - let datum = self.eval_row_infallible(&row.into_owned_row(), &on_err); - array_builder.append_datum(&datum); - } else { - array_builder.append_null(); - } - } - Arc::new(array_builder.finish()) - }) - } - - fn eval_row_infallible(&self, input: &OwnedRow, on_err: impl Fn(ExprError)) -> Datum { - const_assert!(!crate::STRICT_MODE); - - #[expect(clippy::disallowed_methods)] - self.eval_row(input).unwrap_or_else(|err| { - on_err(err); - None - }) - } -} - -impl InfallibleExpression for E {} diff --git a/src/stream/src/common/mod.rs b/src/stream/src/common/mod.rs index 026af37353995..ebbadb96a2f57 100644 --- a/src/stream/src/common/mod.rs +++ b/src/stream/src/common/mod.rs @@ -14,9 +14,7 @@ pub use builder::*; pub use column_mapping::*; -pub use infallible_expr::*; mod builder; mod column_mapping; -mod infallible_expr; pub mod table; diff --git a/src/stream/src/executor/aggregation/mod.rs b/src/stream/src/executor/aggregation/mod.rs index 5958ccdb9b685..bbe9c4febf2db 100644 --- a/src/stream/src/executor/aggregation/mod.rs +++ b/src/stream/src/executor/aggregation/mod.rs @@ -27,7 +27,6 @@ use risingwave_storage::StateStore; use super::ActorContextRef; use crate::common::table::state_table::StateTable; -use crate::common::InfallibleExpression; use crate::executor::error::StreamExecutorResult; use crate::executor::Executor; @@ -66,7 +65,7 @@ pub fn generate_agg_schema( Schema { fields } } -pub fn agg_call_filter_res( +pub async fn agg_call_filter_res( ctx: &ActorContextRef, identity: &str, agg_call: &AggCall, @@ -90,6 +89,7 @@ pub fn agg_call_filter_res( let data_chunk = DataChunk::new(columns.to_vec(), capacity); if let Bool(filter_res) = filter .eval_infallible(&data_chunk, |err| ctx.on_compute_error(err, identity)) + .await .as_ref() { Some(filter_res.to_bitmap()) diff --git a/src/stream/src/executor/dynamic_filter.rs b/src/stream/src/executor/dynamic_filter.rs index 177f931ef8c33..f4da5696ebf8d 100644 --- a/src/stream/src/executor/dynamic_filter.rs +++ b/src/stream/src/executor/dynamic_filter.rs @@ -42,7 +42,7 @@ use super::{ ActorContextRef, BoxedExecutor, BoxedMessageStream, Executor, Message, PkIndices, PkIndicesRef, }; use crate::common::table::state_table::StateTable; -use crate::common::{InfallibleExpression, StreamChunkBuilder}; +use crate::common::StreamChunkBuilder; use crate::executor::expect_first_barrier_from_aligned_stream; pub struct DynamicFilterExecutor { @@ -93,7 +93,7 @@ impl DynamicFilterExecutor { } } - fn apply_batch( + async fn apply_batch( &mut self, data_chunk: &DataChunk, ops: Vec, @@ -104,11 +104,16 @@ impl DynamicFilterExecutor { let mut new_visibility = BitmapBuilder::with_capacity(ops.len()); let mut last_res = false; - let eval_results = condition.map(|cond| { - cond.eval_infallible(data_chunk, |err| { - self.ctx.on_compute_error(err, self.identity()) - }) - }); + let eval_results = if let Some(cond) = condition { + Some( + cond.eval_infallible(data_chunk, |err| { + self.ctx.on_compute_error(err, &self.identity) + }) + .await, + ) + } else { + None + }; for (idx, (row, op)) in data_chunk.rows().zip_eq_debug(ops.iter()).enumerate() { let left_val = row.datum_at(self.key_l).to_owned_datum(); @@ -323,7 +328,7 @@ impl DynamicFilterExecutor { let condition = dynamic_cond(right_val).transpose()?; let (new_ops, new_visibility) = - self.apply_batch(&data_chunk, ops, condition)?; + self.apply_batch(&data_chunk, ops, condition).await?; let (columns, _) = data_chunk.into_parts(); diff --git a/src/stream/src/executor/filter.rs b/src/stream/src/executor/filter.rs index 80354bd48f691..61b9495d0f8c6 100644 --- a/src/stream/src/executor/filter.rs +++ b/src/stream/src/executor/filter.rs @@ -21,52 +21,33 @@ use risingwave_common::catalog::Schema; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_expr::expr::BoxedExpression; -use super::{ - ActorContextRef, Executor, ExecutorInfo, PkIndicesRef, SimpleExecutor, SimpleExecutorWrapper, - StreamExecutorResult, Watermark, -}; -use crate::common::InfallibleExpression; - -pub type FilterExecutor = SimpleExecutorWrapper; - -impl FilterExecutor { - pub fn new( - ctx: ActorContextRef, - input: Box, - expr: BoxedExpression, - executor_id: u64, - ) -> Self { - let info = input.info(); - - SimpleExecutorWrapper { - input, - inner: SimpleFilterExecutor::new(ctx, info, expr, executor_id), - } - } -} +use super::*; /// `FilterExecutor` filters data with the `expr`. The `expr` takes a chunk of data, /// and returns a boolean array on whether each item should be retained. And then, /// `FilterExecutor` will insert, delete or update element into next executor according /// to the result of the expression. -pub struct SimpleFilterExecutor { +pub struct FilterExecutor { ctx: ActorContextRef, info: ExecutorInfo, + input: BoxedExecutor, /// Expression of the current filter, note that the filter must always have the same output for /// the same input. expr: BoxedExpression, } -impl SimpleFilterExecutor { +impl FilterExecutor { pub fn new( ctx: ActorContextRef, - input_info: ExecutorInfo, + input: Box, expr: BoxedExpression, executor_id: u64, ) -> Self { + let input_info = input.info(); Self { ctx, + input, info: ExecutorInfo { schema: input_info.schema, pk_indices: input_info.pk_indices, @@ -155,7 +136,7 @@ impl SimpleFilterExecutor { } } -impl Debug for SimpleFilterExecutor { +impl Debug for FilterExecutor { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("FilterExecutor") .field("expr", &self.expr) @@ -163,21 +144,7 @@ impl Debug for SimpleFilterExecutor { } } -impl SimpleExecutor for SimpleFilterExecutor { - fn map_filter_chunk(&self, chunk: StreamChunk) -> StreamExecutorResult> { - let chunk = chunk.compact(); - - let pred_output = self.expr.eval_infallible(chunk.data_chunk(), |err| { - self.ctx.on_compute_error(err, self.identity()) - }); - - Self::filter(chunk, pred_output) - } - - fn handle_watermark(&self, watermark: Watermark) -> StreamExecutorResult> { - Ok(vec![watermark]) - } - +impl Executor for FilterExecutor { fn schema(&self) -> &Schema { &self.info.schema } @@ -189,6 +156,40 @@ impl SimpleExecutor for SimpleFilterExecutor { fn identity(&self) -> &str { &self.info.identity } + + fn execute(self: Box) -> BoxedMessageStream { + self.execute_inner().boxed() + } +} + +impl FilterExecutor { + #[try_stream(ok = Message, error = StreamExecutorError)] + async fn execute_inner(self) { + let input = self.input.execute(); + #[for_await] + for msg in input { + let msg = msg?; + match msg { + Message::Watermark(w) => yield Message::Watermark(w), + Message::Chunk(chunk) => { + let chunk = chunk.compact(); + + let pred_output = self + .expr + .eval_infallible(chunk.data_chunk(), |err| { + self.ctx.on_compute_error(err, &self.info.identity) + }) + .await; + + match Self::filter(chunk, pred_output)? { + Some(new_chunk) => yield Message::Chunk(new_chunk), + None => continue, + } + } + m => yield m, + } + } + } } #[cfg(test)] diff --git a/src/stream/src/executor/global_simple_agg.rs b/src/stream/src/executor/global_simple_agg.rs index 827aedc385971..85d8b8abdc2fe 100644 --- a/src/stream/src/executor/global_simple_agg.rs +++ b/src/stream/src/executor/global_simple_agg.rs @@ -168,20 +168,19 @@ impl GlobalSimpleAggExecutor { let (ops, columns, visibility) = chunk.into_inner(); // Calculate the row visibility for every agg call. - let visibilities: Vec<_> = this - .agg_calls - .iter() - .map(|agg_call| { - agg_call_filter_res( - &this.actor_ctx, - &this.info.identity, - agg_call, - &columns, - visibility.as_ref(), - capacity, - ) - }) - .try_collect()?; + let mut visibilities = Vec::with_capacity(this.agg_calls.len()); + for agg_call in &this.agg_calls { + let result = agg_call_filter_res( + &this.actor_ctx, + &this.info.identity, + agg_call, + &columns, + visibility.as_ref(), + capacity, + ) + .await?; + visibilities.push(result); + } // Materialize input chunk if needed. this.storages diff --git a/src/stream/src/executor/hash_agg.rs b/src/stream/src/executor/hash_agg.rs index b6c41e5846a3b..245921beb4060 100644 --- a/src/stream/src/executor/hash_agg.rs +++ b/src/stream/src/executor/hash_agg.rs @@ -310,20 +310,19 @@ impl HashAggExecutor { let (ops, columns, visibility) = chunk.into_inner(); // Calculate the row visibility for every agg call. - let call_visibilities: Vec<_> = this - .agg_calls - .iter() - .map(|agg_call| { - agg_call_filter_res( - &this.actor_ctx, - &this.info.identity, - agg_call, - &columns, - visibility.as_ref(), - capacity, - ) - }) - .try_collect()?; + let mut call_visibilities = Vec::with_capacity(this.agg_calls.len()); + for agg_call in &this.agg_calls { + let agg_call_filter_res = agg_call_filter_res( + &this.actor_ctx, + &this.info.identity, + agg_call, + &columns, + visibility.as_ref(), + capacity, + ) + .await?; + call_visibilities.push(agg_call_filter_res); + } // Materialize input chunk if needed. this.storages diff --git a/src/stream/src/executor/hash_join.rs b/src/stream/src/executor/hash_join.rs index 5c56743217333..3c1348dca74f7 100644 --- a/src/stream/src/executor/hash_join.rs +++ b/src/stream/src/executor/hash_join.rs @@ -43,7 +43,7 @@ use super::{ Watermark, }; use crate::common::table::state_table::StateTable; -use crate::common::{InfallibleExpression, StreamChunkBuilder}; +use crate::common::StreamChunkBuilder; use crate::executor::expect_first_barrier_from_aligned_stream; use crate::executor::JoinType::LeftAnti; use crate::task::AtomicU64Ref; @@ -864,26 +864,6 @@ impl HashJoinExecutor, row_matched: &OwnedRow| -> bool { - // TODO(yuhao-su): We should find a better way to eval the expression without concat - // two rows. - // if there are non-equi expressions - if let Some(ref mut cond) = cond { - let new_row = Self::row_concat( - row_update, - side_update.start_pos, - row_matched, - side_match.start_pos, - ); - - cond.eval_row_infallible(&new_row, |err| ctx.on_compute_error(err, identity)) - .map(|s| *s.as_bool()) - .unwrap_or(false) - } else { - true - } - }; - let keys = K::build(&side_update.join_key_indices, chunk.data_chunk())?; for ((op, row), key) in chunk.rows().zip_eq_debug(keys.iter()) { let matched_rows: Option = @@ -897,7 +877,27 @@ impl HashJoinExecutor HashJoinExecutor StreamExecutorResult<()> { let capacity = chunk.capacity(); let (ops, columns, visibility) = chunk.into_inner(); - let visibilities: Vec<_> = agg_calls - .iter() - .map(|agg_call| { - agg_call_filter_res( - ctx, - identity, - agg_call, - &columns, - visibility.as_ref(), - capacity, - ) - }) - .try_collect()?; + let mut visibilities = Vec::with_capacity(agg_calls.len()); + for agg_call in agg_calls { + let result = agg_call_filter_res( + ctx, + identity, + agg_call, + &columns, + visibility.as_ref(), + capacity, + ) + .await?; + visibilities.push(result) + } agg_calls .iter() .zip_eq_fast(visibilities) @@ -118,7 +118,8 @@ impl LocalSimpleAggExecutor { match msg { Message::Watermark(_) => {} Message::Chunk(chunk) => { - Self::apply_chunk(&ctx, &info.identity, &agg_calls, &mut aggregators, chunk)?; + Self::apply_chunk(&ctx, &info.identity, &agg_calls, &mut aggregators, chunk) + .await?; is_dirty = true; } m @ Message::Barrier(_) => { diff --git a/src/stream/src/executor/mod.rs b/src/stream/src/executor/mod.rs index 57858819808e9..27e8d7583b371 100644 --- a/src/stream/src/executor/mod.rs +++ b/src/stream/src/executor/mod.rs @@ -20,6 +20,7 @@ use await_tree::InstrumentAwait; use enum_as_inner::EnumAsInner; use futures::stream::BoxStream; use futures::{Stream, StreamExt}; +use futures_async_stream::try_stream; use itertools::Itertools; use minitrace::prelude::*; use risingwave_common::array::column::Column; @@ -46,7 +47,6 @@ use risingwave_pb::stream_plan::{ }; use smallvec::SmallVec; -use crate::common::InfallibleExpression; use crate::error::StreamResult; use crate::task::{ActorId, FragmentId}; @@ -81,7 +81,6 @@ mod project_set; mod rearranged_chain; mod receiver; pub mod row_id_gen; -mod simple; mod sink; mod sort; mod sort_buffer; @@ -126,7 +125,6 @@ pub use project_set::*; pub use rearranged_chain::RearrangedChainExecutor; pub use receiver::ReceiverExecutor; use risingwave_pb::source::{ConnectorSplit, ConnectorSplits}; -use simple::{SimpleExecutor, SimpleExecutorWrapper}; pub use sink::SinkExecutor; pub use sort::SortExecutor; pub use source::*; @@ -591,7 +589,7 @@ impl Watermark { } } - pub fn transform_with_expr( + pub async fn transform_with_expr( self, expr: &BoxedExpression, new_col_idx: usize, @@ -607,7 +605,7 @@ impl Watermark { row[col_idx] = Some(val); OwnedRow::new(row) }; - let val = expr.eval_row_infallible(&row, on_err)?; + let val = expr.eval_row_infallible(&row, on_err).await?; Some(Self { col_idx: new_col_idx, data_type, diff --git a/src/stream/src/executor/project.rs b/src/stream/src/executor/project.rs index b691ac4e7b71f..86ecb80a5feda 100644 --- a/src/stream/src/executor/project.rs +++ b/src/stream/src/executor/project.rs @@ -21,39 +21,17 @@ use risingwave_common::array::StreamChunk; use risingwave_common::catalog::{Field, Schema}; use risingwave_expr::expr::BoxedExpression; -use super::{ - ActorContextRef, Executor, ExecutorInfo, PkIndices, PkIndicesRef, SimpleExecutor, - SimpleExecutorWrapper, StreamExecutorResult, Watermark, -}; -use crate::common::InfallibleExpression; - -pub type ProjectExecutor = SimpleExecutorWrapper; - -impl ProjectExecutor { - pub fn new( - ctx: ActorContextRef, - input: Box, - pk_indices: PkIndices, - exprs: Vec, - execuotr_id: u64, - watermark_derivations: MultiMap, - ) -> Self { - let info = ExecutorInfo { - schema: input.schema().to_owned(), - pk_indices, - identity: "Project".to_owned(), - }; - SimpleExecutorWrapper { - input, - inner: SimpleProjectExecutor::new(ctx, info, exprs, execuotr_id, watermark_derivations), - } - } -} +use super::*; /// `ProjectExecutor` project data with the `expr`. The `expr` takes a chunk of data, /// and returns a new data chunk. And then, `ProjectExecutor` will insert, delete /// or update element into next operator according to the result of the expression. -pub struct SimpleProjectExecutor { +pub struct ProjectExecutor { + input: BoxedExecutor, + inner: Inner, +} + +struct Inner { ctx: ActorContextRef, info: ExecutorInfo, @@ -64,14 +42,21 @@ pub struct SimpleProjectExecutor { watermark_derivations: MultiMap, } -impl SimpleProjectExecutor { +impl ProjectExecutor { pub fn new( ctx: ActorContextRef, - input_info: ExecutorInfo, + input: Box, + pk_indices: PkIndices, exprs: Vec, executor_id: u64, watermark_derivations: MultiMap, ) -> Self { + let info = ExecutorInfo { + schema: input.schema().to_owned(), + pk_indices, + identity: "Project".to_owned(), + }; + let schema = Schema { fields: exprs .iter() @@ -79,47 +64,73 @@ impl SimpleProjectExecutor { .collect_vec(), }; Self { - ctx, - info: ExecutorInfo { - schema, - pk_indices: input_info.pk_indices, - identity: format!("ProjectExecutor {:X}", executor_id), + input, + inner: Inner { + ctx, + info: ExecutorInfo { + schema, + pk_indices: info.pk_indices, + identity: format!("ProjectExecutor {:X}", executor_id), + }, + exprs, + watermark_derivations, }, - exprs, - watermark_derivations, } } } -impl Debug for SimpleProjectExecutor { +impl Debug for ProjectExecutor { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("ProjectExecutor") - .field("exprs", &self.exprs) + .field("exprs", &self.inner.exprs) .finish() } } -impl SimpleExecutor for SimpleProjectExecutor { - fn map_filter_chunk(&self, chunk: StreamChunk) -> StreamExecutorResult> { +impl Executor for ProjectExecutor { + fn schema(&self) -> &Schema { + &self.inner.info.schema + } + + fn pk_indices(&self) -> PkIndicesRef<'_> { + &self.inner.info.pk_indices + } + + fn identity(&self) -> &str { + &self.inner.info.identity + } + + fn execute(self: Box) -> BoxedMessageStream { + self.inner.execute(self.input).boxed() + } +} + +impl Inner { + async fn map_filter_chunk( + &self, + chunk: StreamChunk, + ) -> StreamExecutorResult> { let chunk = chunk.compact(); let (data_chunk, ops) = chunk.into_parts(); - let projected_columns = self - .exprs - .iter() - .map(|expr| { - Column::new(expr.eval_infallible(&data_chunk, |err| { + let mut projected_columns = Vec::new(); + + for expr in &self.exprs { + let evaluated_expr = expr + .eval_infallible(&data_chunk, |err| { self.ctx.on_compute_error(err, &self.info.identity) - })) - }) - .collect(); + }) + .await; + let new_column = Column::new(evaluated_expr); + projected_columns.push(new_column); + } let new_chunk = StreamChunk::new(ops, projected_columns, None); Ok(Some(new_chunk)) } - fn handle_watermark(&self, watermark: Watermark) -> StreamExecutorResult> { + async fn handle_watermark(&self, watermark: Watermark) -> StreamExecutorResult> { let out_col_indices = match self.watermark_derivations.get_vec(&watermark.col_idx) { Some(v) => v, None => return Ok(vec![]), @@ -127,16 +138,15 @@ impl SimpleExecutor for SimpleProjectExecutor { let mut ret = vec![]; for out_col_idx in out_col_indices { let out_col_idx = *out_col_idx; - let derived_watermark = watermark.clone().transform_with_expr( - &self.exprs[out_col_idx], - out_col_idx, - |err| { + let derived_watermark = watermark + .clone() + .transform_with_expr(&self.exprs[out_col_idx], out_col_idx, |err| { self.ctx.on_compute_error( err, &(self.info.identity.to_string() + "(when computing watermark)"), ) - }, - ); + }) + .await; if let Some(derived_watermark) = derived_watermark { ret.push(derived_watermark); } else { @@ -149,16 +159,25 @@ impl SimpleExecutor for SimpleProjectExecutor { Ok(ret) } - fn schema(&self) -> &Schema { - &self.info.schema - } - - fn pk_indices(&self) -> PkIndicesRef<'_> { - &self.info.pk_indices - } - - fn identity(&self) -> &str { - &self.info.identity + #[try_stream(ok = Message, error = StreamExecutorError)] + async fn execute(self, input: BoxedExecutor) { + #[for_await] + for msg in input.execute() { + let msg = msg?; + match msg { + Message::Watermark(w) => { + let watermarks = self.handle_watermark(w).await?; + for watermark in watermarks { + yield Message::Watermark(watermark) + } + } + Message::Chunk(chunk) => match self.map_filter_chunk(chunk).await? { + Some(new_chunk) => yield Message::Chunk(new_chunk), + None => continue, + }, + m => yield m, + } + } } } diff --git a/src/stream/src/executor/project_set.rs b/src/stream/src/executor/project_set.rs index 2d5c1643b9303..3d908d9e55df4 100644 --- a/src/stream/src/executor/project_set.rs +++ b/src/stream/src/executor/project_set.rs @@ -126,11 +126,12 @@ impl ProjectSetExecutor { .collect_vec(); let mut ret_ops = vec![]; - let results: Vec<_> = self - .select_list - .iter() - .map(|select_item| select_item.eval(&data_chunk)) - .try_collect()?; + let mut results = Vec::with_capacity(self.select_list.len()); + for select_item in &self.select_list { + let result = select_item.eval(&data_chunk).await?; + results.push(result); + } + assert!( results .iter() diff --git a/src/stream/src/executor/simple.rs b/src/stream/src/executor/simple.rs deleted file mode 100644 index ccbdbabecba2f..0000000000000 --- a/src/stream/src/executor/simple.rs +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use futures::StreamExt; -use futures_async_stream::try_stream; -use risingwave_common::catalog::Schema; - -use super::error::{StreamExecutorError, StreamExecutorResult}; -use super::{ - BoxedExecutor, BoxedMessageStream, Executor, Message, PkIndicesRef, StreamChunk, Watermark, -}; - -/// Executor which can handle [`StreamChunk`]s one by one. -pub trait SimpleExecutor: Send + Sync + 'static { - /// convert a single chunk to zero or one chunks. - fn map_filter_chunk(&self, chunk: StreamChunk) -> StreamExecutorResult>; - - /// convert a single chunk to zero or one chunks. - fn handle_watermark(&self, watermark: Watermark) -> StreamExecutorResult>; - - /// See [`super::Executor::schema`]. - fn schema(&self) -> &Schema; - - /// See [`super::Executor::pk_indices`]. - fn pk_indices(&self) -> PkIndicesRef<'_>; - - /// See [`super::Executor::identity`]. - fn identity(&self) -> &str; -} - -/// The struct wraps a [`SimpleExecutor`], and implements the interface of [`Executor`]. -pub struct SimpleExecutorWrapper { - pub(super) input: BoxedExecutor, - pub(super) inner: E, -} - -impl Executor for SimpleExecutorWrapper -where - E: SimpleExecutor, -{ - fn schema(&self) -> &Schema { - self.inner.schema() - } - - fn pk_indices(&self) -> PkIndicesRef<'_> { - self.inner.pk_indices() - } - - fn identity(&self) -> &str { - self.inner.identity() - } - - fn execute(self: Box) -> BoxedMessageStream { - self.execute_inner().boxed() - } -} - -impl SimpleExecutorWrapper -where - E: SimpleExecutor, -{ - #[try_stream(ok = Message, error = StreamExecutorError)] - async fn execute_inner(self) { - let input = self.input.execute(); - let inner = self.inner; - #[for_await] - for msg in input { - let msg = msg?; - match msg { - Message::Watermark(w) => { - let watermarks = inner.handle_watermark(w)?; - for watermark in watermarks { - yield Message::Watermark(watermark) - } - } - Message::Chunk(chunk) => match inner.map_filter_chunk(chunk)? { - Some(new_chunk) => yield Message::Chunk(new_chunk), - None => continue, - }, - m => yield m, - } - } - } -} diff --git a/src/stream/src/executor/temporal_join.rs b/src/stream/src/executor/temporal_join.rs index 4dfa7fd7ebbd7..6b119047447cb 100644 --- a/src/stream/src/executor/temporal_join.rs +++ b/src/stream/src/executor/temporal_join.rs @@ -21,7 +21,7 @@ use futures::{StreamExt, TryStreamExt}; use futures_async_stream::try_stream; use local_stats_alloc::{SharedStatsAlloc, StatsAlloc}; use lru::DefaultHasher; -use risingwave_common::array::{Op, RowRef, StreamChunk}; +use risingwave_common::array::{Op, StreamChunk}; use risingwave_common::catalog::Schema; use risingwave_common::row::{OwnedRow, Row, RowExt}; use risingwave_common::util::iter_util::ZipEqFast; @@ -32,7 +32,7 @@ use risingwave_storage::StateStore; use super::{Barrier, Executor, Message, MessageStream, StreamExecutorError, StreamExecutorResult}; use crate::cache::{new_with_hasher_in, ManagedLruCache}; -use crate::common::{InfallibleExpression, StreamChunkBuilder}; +use crate::common::StreamChunkBuilder; use crate::executor::monitor::StreamingMetrics; use crate::executor::{ActorContextRef, BoxedExecutor, JoinType, JoinTypePrimitive, PkIndices}; use crate::task::AtomicU64Ref; @@ -227,19 +227,6 @@ impl TemporalJoinExecutor { self.right.schema().len(), ); - let mut check_join_condition = |left_row: &RowRef<'_>, right_row: &OwnedRow| -> bool { - if let Some(ref mut cond) = self.condition { - let concat_row = left_row.chain(right_row).into_owned_row(); - cond.eval_row_infallible(&concat_row, |err| { - self.ctx.on_compute_error(err, self.identity.as_str()) - }) - .map(|s| *s.as_bool()) - .unwrap_or(false) - } else { - true - } - }; - let mut prev_epoch = None; #[for_await] for msg in align_input(self.left, self.right) { @@ -261,9 +248,23 @@ impl TemporalJoinExecutor { { continue; } - if let Some(right_row) = self.right_table.lookup(key, epoch).await? && check_join_condition(&left_row, &right_row) { - if let Some(chunk) = builder.append_row(op, left_row, &right_row) { - yield Message::Chunk(chunk); + if let Some(right_row) = self.right_table.lookup(key, epoch).await? { + // check join condition + let ok = if let Some(ref mut cond) = self.condition { + let concat_row = left_row.chain(&right_row).into_owned_row(); + cond.eval_row_infallible(&concat_row, |err| { + self.ctx.on_compute_error(err, self.identity.as_str()) + }) + .await + .map(|s| *s.as_bool()) + .unwrap_or(false) + } else { + true + }; + if ok { + if let Some(chunk) = builder.append_row(op, left_row, &right_row) { + yield Message::Chunk(chunk); + } } } else if T == JoinType::LeftOuter { if let Some(chunk) = builder.append_row_update(op, left_row) { diff --git a/src/stream/src/executor/watermark_filter.rs b/src/stream/src/executor/watermark_filter.rs index 98e9fc82d4c89..f43b17b75c7c3 100644 --- a/src/stream/src/executor/watermark_filter.rs +++ b/src/stream/src/executor/watermark_filter.rs @@ -30,12 +30,11 @@ use risingwave_pb::expr::expr_node::Type; use risingwave_storage::StateStore; use super::error::StreamExecutorError; -use super::filter::SimpleFilterExecutor; +use super::filter::FilterExecutor; use super::{ ActorContextRef, BoxedExecutor, Executor, ExecutorInfo, Message, StreamExecutorResult, }; use crate::common::table::state_table::StateTable; -use crate::common::InfallibleExpression; use crate::executor::{expect_first_barrier, Watermark}; /// The executor will generate a `Watermark` after each chunk. @@ -146,7 +145,8 @@ impl WatermarkFilterExecutor { let watermark_array = watermark_expr .eval_infallible(chunk.data_chunk(), |err| { ctx.on_compute_error(err, &info.identity) - }); + }) + .await; // Build the expression to calculate watermark filter. let watermark_filter_expr = Self::build_watermark_filter_expr( @@ -167,9 +167,10 @@ impl WatermarkFilterExecutor { let pred_output = watermark_filter_expr .eval_infallible(chunk.data_chunk(), |err| { ctx.on_compute_error(err, &info.identity) - }); + }) + .await; - if let Some(output_chunk) = SimpleFilterExecutor::filter(chunk, pred_output)? { + if let Some(output_chunk) = FilterExecutor::filter(chunk, pred_output)? { yield Message::Chunk(output_chunk); }; diff --git a/src/stream/src/lib.rs b/src/stream/src/lib.rs index 20885d8311975..aea2f732e72d3 100644 --- a/src/stream/src/lib.rs +++ b/src/stream/src/lib.rs @@ -49,13 +49,3 @@ pub mod error; pub mod executor; mod from_proto; pub mod task; - -/// Controls the behavior when a compute error happens. -/// -/// - If set to `false`, `NULL` will be inserted. -/// - TODO: If set to `true`, The MV will be suspended and removed from further checkpoints. It can -/// still be used to serve outdated data without corruption. -/// -/// See also . -#[expect(dead_code)] -const STRICT_MODE: bool = false;