From 7088fd32636fec6314570ae4d1e1b7a6eec1425d Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Mon, 10 Apr 2023 20:20:53 +0800 Subject: [PATCH] refactor(common): remove `unwrap` from `memcmp_encoding` and improve `encode_row` (#9087) Signed-off-by: Richard Chien --- src/batch/src/executor/group_top_n.rs | 2 +- src/batch/src/executor/order_by.rs | 2 +- src/batch/src/executor/top_n.rs | 2 +- src/common/src/util/memcmp_encoding.rs | 69 +++++++++++++----------- src/expr/src/vector_op/agg/array_agg.rs | 26 ++++++--- src/expr/src/vector_op/agg/string_agg.rs | 27 +++++++--- 6 files changed, 78 insertions(+), 50 deletions(-) diff --git a/src/batch/src/executor/group_top_n.rs b/src/batch/src/executor/group_top_n.rs index 6a8c0b58bdd42..198a441cb1d59 100644 --- a/src/batch/src/executor/group_top_n.rs +++ b/src/batch/src/executor/group_top_n.rs @@ -186,7 +186,7 @@ impl GroupTopNExecutor { let chunk = Arc::new(chunk?.compact()); let keys = K::build(self.group_key.as_slice(), &chunk)?; - for (row_id, (encoded_row, key)) in encode_chunk(&chunk, &self.column_orders) + for (row_id, (encoded_row, key)) in encode_chunk(&chunk, &self.column_orders)? .into_iter() .zip_eq_fast(keys.into_iter()) .enumerate() diff --git a/src/batch/src/executor/order_by.rs b/src/batch/src/executor/order_by.rs index e32b997ee6f7d..719c1e0209832 100644 --- a/src/batch/src/executor/order_by.rs +++ b/src/batch/src/executor/order_by.rs @@ -91,7 +91,7 @@ impl SortExecutor { } for chunk in &chunks { - let encoded_chunk = encode_chunk(chunk, &self.column_orders); + let encoded_chunk = encode_chunk(chunk, &self.column_orders)?; encoded_rows.extend( encoded_chunk .into_iter() diff --git a/src/batch/src/executor/top_n.rs b/src/batch/src/executor/top_n.rs index 67c31c1f34ddd..d2ef40f0a847e 100644 --- a/src/batch/src/executor/top_n.rs +++ b/src/batch/src/executor/top_n.rs @@ -218,7 +218,7 @@ impl TopNExecutor { #[for_await] for chunk in self.child.execute() { let chunk = Arc::new(chunk?.compact()); - for (row_id, encoded_row) in encode_chunk(&chunk, &self.column_orders) + for (row_id, encoded_row) in encode_chunk(&chunk, &self.column_orders)? .into_iter() .enumerate() { diff --git a/src/common/src/util/memcmp_encoding.rs b/src/common/src/util/memcmp_encoding.rs index d48779d643c2e..43dd7e11afe74 100644 --- a/src/common/src/util/memcmp_encoding.rs +++ b/src/common/src/util/memcmp_encoding.rs @@ -16,10 +16,9 @@ use bytes::{Buf, BufMut}; use itertools::Itertools; use serde::{Deserialize, Serialize}; -use super::iter_util::ZipEqFast; +use super::iter_util::{ZipEqDebug, ZipEqFast}; use crate::array::serial_array::Serial; use crate::array::{ArrayImpl, DataChunk}; -use crate::error::Result; use crate::row::Row; use crate::types::{DataType, Date, Datum, ScalarImpl, Time, Timestamp, ToDatumRef, F32, F64}; use crate::util::sort_util::{ColumnOrder, OrderType}; @@ -180,18 +179,22 @@ fn calculate_encoded_size_inner( Ok(deserializer.position() - base_position) } -pub fn encode_value(value: impl ToDatumRef, order: OrderType) -> Result> { +pub fn encode_value(value: impl ToDatumRef, order: OrderType) -> memcomparable::Result> { let mut serializer = memcomparable::Serializer::new(vec![]); serialize_datum(value, order, &mut serializer)?; Ok(serializer.into_inner()) } -pub fn decode_value(ty: &DataType, encoded_value: &[u8], order: OrderType) -> Result { +pub fn decode_value( + ty: &DataType, + encoded_value: &[u8], + order: OrderType, +) -> memcomparable::Result { let mut deserializer = memcomparable::Deserializer::new(encoded_value); - Ok(deserialize_datum(ty, order, &mut deserializer)?) + deserialize_datum(ty, order, &mut deserializer) } -pub fn encode_array(array: &ArrayImpl, order: OrderType) -> Result>> { +pub fn encode_array(array: &ArrayImpl, order: OrderType) -> memcomparable::Result>> { let mut data = Vec::with_capacity(array.len()); for datum in array.iter() { data.push(encode_value(datum, order)?); @@ -202,11 +205,14 @@ pub fn encode_array(array: &ArrayImpl, order: OrderType) -> Result>> /// This function is used to accelerate the comparison of tuples. It takes datachunk and /// user-defined order as input, yield encoded binary string with order preserved for each tuple in /// the datachunk. -pub fn encode_chunk(chunk: &DataChunk, column_orders: &[ColumnOrder]) -> Vec> { - let encoded_columns = column_orders +pub fn encode_chunk( + chunk: &DataChunk, + column_orders: &[ColumnOrder], +) -> memcomparable::Result>> { + let encoded_columns: Vec<_> = column_orders .iter() - .map(|o| encode_array(chunk.column_at(o.column_index).array_ref(), o.order_type).unwrap()) - .collect_vec(); + .map(|o| encode_array(chunk.column_at(o.column_index).array_ref(), o.order_type)) + .try_collect()?; let mut encoded_chunk = vec![vec![]; chunk.capacity()]; for encoded_column in encoded_columns { @@ -215,16 +221,16 @@ pub fn encode_chunk(chunk: &DataChunk, column_orders: &[ColumnOrder]) -> Vec Vec { - let mut encoded_row = vec![]; - column_orders.iter().for_each(|o| { - encoded_row.extend(encode_value(row.datum_at(o.column_index), o.order_type).unwrap()); - }); - encoded_row +pub fn encode_row(row: impl Row, order_types: &[OrderType]) -> memcomparable::Result> { + let mut serializer = memcomparable::Serializer::new(vec![]); + row.iter() + .zip_eq_debug(order_types) + .try_for_each(|(datum, order)| serialize_datum(datum, *order, &mut serializer))?; + Ok(serializer.into_inner()) } #[cfg(test)] @@ -236,7 +242,7 @@ mod tests { use super::*; use crate::array::{DataChunk, ListValue, StructValue}; - use crate::row::OwnedRow; + use crate::row::{OwnedRow, RowExt}; use crate::types::{DataType, ScalarImpl, F32}; use crate::util::sort_util::{ColumnOrder, OrderType}; @@ -501,12 +507,10 @@ mod tests { let row1 = OwnedRow::new(vec![v10, v11, v12]); let row2 = OwnedRow::new(vec![v20, v21, v22]); - let column_orders = vec![ - ColumnOrder::new(0, OrderType::ascending()), - ColumnOrder::new(1, OrderType::descending()), - ]; + let order_col_indices = vec![0, 1]; + let order_types = vec![OrderType::ascending(), OrderType::descending()]; - let encoded_row1 = encode_row(&row1, &column_orders); + let encoded_row1 = encode_row(row1.project(&order_col_indices), &order_types).unwrap(); let encoded_v10 = encode_value( v10_cloned.as_ref().map(|x| x.as_scalar_ref_impl()), OrderType::ascending(), @@ -523,7 +527,7 @@ mod tests { .collect_vec(); assert_eq!(encoded_row1, concated_encoded_row1); - let encoded_row2 = encode_row(&row2, &column_orders); + let encoded_row2 = encode_row(row2.project(&order_col_indices), &order_types).unwrap(); assert!(encoded_row1 < encoded_row2); } @@ -542,14 +546,17 @@ mod tests { &[row1.clone(), row2.clone()], &[DataType::Int32, DataType::Varchar, DataType::Float32], ); - let column_orders = vec![ - ColumnOrder::new(0, OrderType::ascending()), - ColumnOrder::new(1, OrderType::descending()), - ]; + let order_col_indices = vec![0, 1]; + let order_types = vec![OrderType::ascending(), OrderType::descending()]; + let column_orders = order_col_indices + .iter() + .zip_eq_fast(&order_types) + .map(|(i, o)| ColumnOrder::new(*i, *o)) + .collect_vec(); - let encoded_row1 = encode_row(&row1, &column_orders); - let encoded_row2 = encode_row(&row2, &column_orders); - let encoded_chunk = encode_chunk(&chunk, &column_orders); + let encoded_row1 = encode_row(row1.project(&order_col_indices), &order_types).unwrap(); + let encoded_row2 = encode_row(row2.project(&order_col_indices), &order_types).unwrap(); + let encoded_chunk = encode_chunk(&chunk, &column_orders).unwrap(); assert_eq!(&encoded_chunk, &[encoded_row1, encoded_row2]); } } diff --git a/src/expr/src/vector_op/agg/array_agg.rs b/src/expr/src/vector_op/agg/array_agg.rs index b3d475b0602c5..732044558b800 100644 --- a/src/expr/src/vector_op/agg/array_agg.rs +++ b/src/expr/src/vector_op/agg/array_agg.rs @@ -12,15 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +use anyhow::anyhow; use risingwave_common::array::{ArrayBuilder, ArrayBuilderImpl, DataChunk, ListValue, RowRef}; use risingwave_common::bail; -use risingwave_common::row::Row; +use risingwave_common::row::{Row, RowExt}; use risingwave_common::types::{DataType, Datum, Scalar, ToOwnedDatum}; use risingwave_common::util::memcmp_encoding; -use risingwave_common::util::sort_util::ColumnOrder; +use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use crate::vector_op::agg::aggregator::Aggregator; -use crate::Result; +use crate::{ExprError, Result}; #[derive(Clone)] struct ArrayAggUnordered { @@ -97,25 +98,34 @@ type OrderKey = Vec; struct ArrayAggOrdered { return_type: DataType, agg_col_idx: usize, - column_orders: Vec, + order_col_indices: Vec, + order_types: Vec, unordered_values: Vec<(OrderKey, Datum)>, } impl ArrayAggOrdered { fn new(return_type: DataType, agg_col_idx: usize, column_orders: Vec) -> Self { assert!(matches!(return_type, DataType::List { datatype: _ })); + let (order_col_indices, order_types) = column_orders + .into_iter() + .map(|c| (c.column_index, c.order_type)) + .unzip(); ArrayAggOrdered { return_type, agg_col_idx, - column_orders, + order_col_indices, + order_types, unordered_values: vec![], } } - fn push_row(&mut self, row: RowRef<'_>) { - let key = memcmp_encoding::encode_row(row, &self.column_orders); + fn push_row(&mut self, row: RowRef<'_>) -> Result<()> { + let key = + memcmp_encoding::encode_row(row.project(&self.order_col_indices), &self.order_types) + .map_err(|e| ExprError::Internal(anyhow!("failed to encode row, error: {}", e)))?; let datum = row.datum_at(self.agg_col_idx).to_owned_datum(); self.unordered_values.push((key, datum)); + Ok(()) } fn get_result_and_reset(&mut self) -> ListValue { @@ -134,7 +144,7 @@ impl Aggregator for ArrayAggOrdered { async fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { let (row, vis) = input.row_at(row_id); assert!(vis); - self.push_row(row); + self.push_row(row)?; Ok(()) } diff --git a/src/expr/src/vector_op/agg/string_agg.rs b/src/expr/src/vector_op/agg/string_agg.rs index 32ef9330d1d70..2af70fe58d823 100644 --- a/src/expr/src/vector_op/agg/string_agg.rs +++ b/src/expr/src/vector_op/agg/string_agg.rs @@ -12,17 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. +use anyhow::anyhow; use risingwave_common::array::{ Array, ArrayBuilder, ArrayBuilderImpl, ArrayImpl, DataChunk, RowRef, }; use risingwave_common::bail; +use risingwave_common::row::RowExt; use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::util::memcmp_encoding; -use risingwave_common::util::sort_util::ColumnOrder; +use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use crate::vector_op::agg::aggregator::Aggregator; -use crate::Result; +use crate::{ExprError, Result}; #[derive(Clone)] struct StringAggUnordered { @@ -124,22 +126,30 @@ struct StringAggData { struct StringAggOrdered { agg_col_idx: usize, delim_col_idx: usize, - column_orders: Vec, + order_col_indices: Vec, + order_types: Vec, unordered_values: Vec<(OrderKey, StringAggData)>, } impl StringAggOrdered { fn new(agg_col_idx: usize, delim_col_idx: usize, column_orders: Vec) -> Self { + let (order_col_indices, order_types) = column_orders + .into_iter() + .map(|c| (c.column_index, c.order_type)) + .unzip(); Self { agg_col_idx, delim_col_idx, - column_orders, + order_col_indices, + order_types, unordered_values: vec![], } } - fn push_row(&mut self, value: &str, delim: &str, row: RowRef<'_>) { - let key = memcmp_encoding::encode_row(row, &self.column_orders); + fn push_row(&mut self, value: &str, delim: &str, row: RowRef<'_>) -> Result<()> { + let key = + memcmp_encoding::encode_row(row.project(&self.order_col_indices), &self.order_types) + .map_err(|e| ExprError::Internal(anyhow!("failed to encode row, error: {}", e)))?; self.unordered_values.push(( key, StringAggData { @@ -147,6 +157,7 @@ impl StringAggOrdered { delim: delim.to_string(), }, )); + Ok(()) } fn get_result_and_reset(&mut self) -> Option { @@ -181,7 +192,7 @@ impl Aggregator for StringAggOrdered { let delim = delim_col.value_at(row_id).unwrap_or(""); let (row_ref, vis) = input.row_at(row_id); assert!(vis); - self.push_row(value, delim, row_ref); + self.push_row(value, delim, row_ref)?; } Ok(()) } else { @@ -209,7 +220,7 @@ impl Aggregator for StringAggOrdered { { let (row_ref, vis) = input.row_at(row_id); assert!(vis); - self.push_row(value.unwrap(), delim.unwrap_or(""), row_ref); + self.push_row(value.unwrap(), delim.unwrap_or(""), row_ref)?; } Ok(()) } else {