Skip to content

Commit

Permalink
refactor(common): remove unwrap from memcmp_encoding and improve `encode_row` (risingwavelabs#9087)
Browse files Browse the repository at this point in the history

Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc authored Apr 10, 2023
1 parent 490161d commit 7088fd3
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 50 deletions.
2 changes: 1 addition & 1 deletion src/batch/src/executor/group_top_n.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ impl<K: HashKey> GroupTopNExecutor<K> {
let chunk = Arc::new(chunk?.compact());
let keys = K::build(self.group_key.as_slice(), &chunk)?;

for (row_id, (encoded_row, key)) in encode_chunk(&chunk, &self.column_orders)
for (row_id, (encoded_row, key)) in encode_chunk(&chunk, &self.column_orders)?
.into_iter()
.zip_eq_fast(keys.into_iter())
.enumerate()
Expand Down
2 changes: 1 addition & 1 deletion src/batch/src/executor/order_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl SortExecutor {
}

for chunk in &chunks {
let encoded_chunk = encode_chunk(chunk, &self.column_orders);
let encoded_chunk = encode_chunk(chunk, &self.column_orders)?;
encoded_rows.extend(
encoded_chunk
.into_iter()
Expand Down
2 changes: 1 addition & 1 deletion src/batch/src/executor/top_n.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ impl TopNExecutor {
#[for_await]
for chunk in self.child.execute() {
let chunk = Arc::new(chunk?.compact());
for (row_id, encoded_row) in encode_chunk(&chunk, &self.column_orders)
for (row_id, encoded_row) in encode_chunk(&chunk, &self.column_orders)?
.into_iter()
.enumerate()
{
Expand Down
69 changes: 38 additions & 31 deletions src/common/src/util/memcmp_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ use bytes::{Buf, BufMut};
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use super::iter_util::ZipEqFast;
use super::iter_util::{ZipEqDebug, ZipEqFast};
use crate::array::serial_array::Serial;
use crate::array::{ArrayImpl, DataChunk};
use crate::error::Result;
use crate::row::Row;
use crate::types::{DataType, Date, Datum, ScalarImpl, Time, Timestamp, ToDatumRef, F32, F64};
use crate::util::sort_util::{ColumnOrder, OrderType};
Expand Down Expand Up @@ -180,18 +179,22 @@ fn calculate_encoded_size_inner(
Ok(deserializer.position() - base_position)
}

pub fn encode_value(value: impl ToDatumRef, order: OrderType) -> Result<Vec<u8>> {
pub fn encode_value(value: impl ToDatumRef, order: OrderType) -> memcomparable::Result<Vec<u8>> {
let mut serializer = memcomparable::Serializer::new(vec![]);
serialize_datum(value, order, &mut serializer)?;
Ok(serializer.into_inner())
}

pub fn decode_value(ty: &DataType, encoded_value: &[u8], order: OrderType) -> Result<Datum> {
pub fn decode_value(
ty: &DataType,
encoded_value: &[u8],
order: OrderType,
) -> memcomparable::Result<Datum> {
let mut deserializer = memcomparable::Deserializer::new(encoded_value);
Ok(deserialize_datum(ty, order, &mut deserializer)?)
deserialize_datum(ty, order, &mut deserializer)
}

pub fn encode_array(array: &ArrayImpl, order: OrderType) -> Result<Vec<Vec<u8>>> {
pub fn encode_array(array: &ArrayImpl, order: OrderType) -> memcomparable::Result<Vec<Vec<u8>>> {
let mut data = Vec::with_capacity(array.len());
for datum in array.iter() {
data.push(encode_value(datum, order)?);
Expand All @@ -202,11 +205,14 @@ pub fn encode_array(array: &ArrayImpl, order: OrderType) -> Result<Vec<Vec<u8>>>
/// This function is used to accelerate the comparison of tuples. It takes a data chunk and the
/// user-defined column orders as input, and yields an order-preserving encoded byte string for
/// each tuple in the data chunk.
pub fn encode_chunk(chunk: &DataChunk, column_orders: &[ColumnOrder]) -> Vec<Vec<u8>> {
let encoded_columns = column_orders
pub fn encode_chunk(
chunk: &DataChunk,
column_orders: &[ColumnOrder],
) -> memcomparable::Result<Vec<Vec<u8>>> {
let encoded_columns: Vec<_> = column_orders
.iter()
.map(|o| encode_array(chunk.column_at(o.column_index).array_ref(), o.order_type).unwrap())
.collect_vec();
.map(|o| encode_array(chunk.column_at(o.column_index).array_ref(), o.order_type))
.try_collect()?;

let mut encoded_chunk = vec![vec![]; chunk.capacity()];
for encoded_column in encoded_columns {
Expand All @@ -215,16 +221,16 @@ pub fn encode_chunk(chunk: &DataChunk, column_orders: &[ColumnOrder]) -> Vec<Vec
}
}

encoded_chunk
Ok(encoded_chunk)
}

/// Encode a row into memcomparable format.
pub fn encode_row(row: impl Row, column_orders: &[ColumnOrder]) -> Vec<u8> {
let mut encoded_row = vec![];
column_orders.iter().for_each(|o| {
encoded_row.extend(encode_value(row.datum_at(o.column_index), o.order_type).unwrap());
});
encoded_row
pub fn encode_row(row: impl Row, order_types: &[OrderType]) -> memcomparable::Result<Vec<u8>> {
let mut serializer = memcomparable::Serializer::new(vec![]);
row.iter()
.zip_eq_debug(order_types)
.try_for_each(|(datum, order)| serialize_datum(datum, *order, &mut serializer))?;
Ok(serializer.into_inner())
}

#[cfg(test)]
Expand All @@ -236,7 +242,7 @@ mod tests {

use super::*;
use crate::array::{DataChunk, ListValue, StructValue};
use crate::row::OwnedRow;
use crate::row::{OwnedRow, RowExt};
use crate::types::{DataType, ScalarImpl, F32};
use crate::util::sort_util::{ColumnOrder, OrderType};

Expand Down Expand Up @@ -501,12 +507,10 @@ mod tests {

let row1 = OwnedRow::new(vec![v10, v11, v12]);
let row2 = OwnedRow::new(vec![v20, v21, v22]);
let column_orders = vec![
ColumnOrder::new(0, OrderType::ascending()),
ColumnOrder::new(1, OrderType::descending()),
];
let order_col_indices = vec![0, 1];
let order_types = vec![OrderType::ascending(), OrderType::descending()];

let encoded_row1 = encode_row(&row1, &column_orders);
let encoded_row1 = encode_row(row1.project(&order_col_indices), &order_types).unwrap();
let encoded_v10 = encode_value(
v10_cloned.as_ref().map(|x| x.as_scalar_ref_impl()),
OrderType::ascending(),
Expand All @@ -523,7 +527,7 @@ mod tests {
.collect_vec();
assert_eq!(encoded_row1, concated_encoded_row1);

let encoded_row2 = encode_row(&row2, &column_orders);
let encoded_row2 = encode_row(row2.project(&order_col_indices), &order_types).unwrap();
assert!(encoded_row1 < encoded_row2);
}

Expand All @@ -542,14 +546,17 @@ mod tests {
&[row1.clone(), row2.clone()],
&[DataType::Int32, DataType::Varchar, DataType::Float32],
);
let column_orders = vec![
ColumnOrder::new(0, OrderType::ascending()),
ColumnOrder::new(1, OrderType::descending()),
];
let order_col_indices = vec![0, 1];
let order_types = vec![OrderType::ascending(), OrderType::descending()];
let column_orders = order_col_indices
.iter()
.zip_eq_fast(&order_types)
.map(|(i, o)| ColumnOrder::new(*i, *o))
.collect_vec();

let encoded_row1 = encode_row(&row1, &column_orders);
let encoded_row2 = encode_row(&row2, &column_orders);
let encoded_chunk = encode_chunk(&chunk, &column_orders);
let encoded_row1 = encode_row(row1.project(&order_col_indices), &order_types).unwrap();
let encoded_row2 = encode_row(row2.project(&order_col_indices), &order_types).unwrap();
let encoded_chunk = encode_chunk(&chunk, &column_orders).unwrap();
assert_eq!(&encoded_chunk, &[encoded_row1, encoded_row2]);
}
}
26 changes: 18 additions & 8 deletions src/expr/src/vector_op/agg/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::anyhow;
use risingwave_common::array::{ArrayBuilder, ArrayBuilderImpl, DataChunk, ListValue, RowRef};
use risingwave_common::bail;
use risingwave_common::row::Row;
use risingwave_common::row::{Row, RowExt};
use risingwave_common::types::{DataType, Datum, Scalar, ToOwnedDatum};
use risingwave_common::util::memcmp_encoding;
use risingwave_common::util::sort_util::ColumnOrder;
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};

use crate::vector_op::agg::aggregator::Aggregator;
use crate::Result;
use crate::{ExprError, Result};

#[derive(Clone)]
struct ArrayAggUnordered {
Expand Down Expand Up @@ -97,25 +98,34 @@ type OrderKey = Vec<u8>;
struct ArrayAggOrdered {
return_type: DataType,
agg_col_idx: usize,
column_orders: Vec<ColumnOrder>,
order_col_indices: Vec<usize>,
order_types: Vec<OrderType>,
unordered_values: Vec<(OrderKey, Datum)>,
}

impl ArrayAggOrdered {
fn new(return_type: DataType, agg_col_idx: usize, column_orders: Vec<ColumnOrder>) -> Self {
assert!(matches!(return_type, DataType::List { datatype: _ }));
let (order_col_indices, order_types) = column_orders
.into_iter()
.map(|c| (c.column_index, c.order_type))
.unzip();
ArrayAggOrdered {
return_type,
agg_col_idx,
column_orders,
order_col_indices,
order_types,
unordered_values: vec![],
}
}

fn push_row(&mut self, row: RowRef<'_>) {
let key = memcmp_encoding::encode_row(row, &self.column_orders);
fn push_row(&mut self, row: RowRef<'_>) -> Result<()> {
let key =
memcmp_encoding::encode_row(row.project(&self.order_col_indices), &self.order_types)
.map_err(|e| ExprError::Internal(anyhow!("failed to encode row, error: {}", e)))?;
let datum = row.datum_at(self.agg_col_idx).to_owned_datum();
self.unordered_values.push((key, datum));
Ok(())
}

fn get_result_and_reset(&mut self) -> ListValue {
Expand All @@ -134,7 +144,7 @@ impl Aggregator for ArrayAggOrdered {
async fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> {
let (row, vis) = input.row_at(row_id);
assert!(vis);
self.push_row(row);
self.push_row(row)?;
Ok(())
}

Expand Down
27 changes: 19 additions & 8 deletions src/expr/src/vector_op/agg/string_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::anyhow;
use risingwave_common::array::{
Array, ArrayBuilder, ArrayBuilderImpl, ArrayImpl, DataChunk, RowRef,
};
use risingwave_common::bail;
use risingwave_common::row::RowExt;
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::memcmp_encoding;
use risingwave_common::util::sort_util::ColumnOrder;
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};

use crate::vector_op::agg::aggregator::Aggregator;
use crate::Result;
use crate::{ExprError, Result};

#[derive(Clone)]
struct StringAggUnordered {
Expand Down Expand Up @@ -124,29 +126,38 @@ struct StringAggData {
struct StringAggOrdered {
agg_col_idx: usize,
delim_col_idx: usize,
column_orders: Vec<ColumnOrder>,
order_col_indices: Vec<usize>,
order_types: Vec<OrderType>,
unordered_values: Vec<(OrderKey, StringAggData)>,
}

impl StringAggOrdered {
fn new(agg_col_idx: usize, delim_col_idx: usize, column_orders: Vec<ColumnOrder>) -> Self {
let (order_col_indices, order_types) = column_orders
.into_iter()
.map(|c| (c.column_index, c.order_type))
.unzip();
Self {
agg_col_idx,
delim_col_idx,
column_orders,
order_col_indices,
order_types,
unordered_values: vec![],
}
}

fn push_row(&mut self, value: &str, delim: &str, row: RowRef<'_>) {
let key = memcmp_encoding::encode_row(row, &self.column_orders);
fn push_row(&mut self, value: &str, delim: &str, row: RowRef<'_>) -> Result<()> {
let key =
memcmp_encoding::encode_row(row.project(&self.order_col_indices), &self.order_types)
.map_err(|e| ExprError::Internal(anyhow!("failed to encode row, error: {}", e)))?;
self.unordered_values.push((
key,
StringAggData {
value: value.to_string(),
delim: delim.to_string(),
},
));
Ok(())
}

fn get_result_and_reset(&mut self) -> Option<String> {
Expand Down Expand Up @@ -181,7 +192,7 @@ impl Aggregator for StringAggOrdered {
let delim = delim_col.value_at(row_id).unwrap_or("");
let (row_ref, vis) = input.row_at(row_id);
assert!(vis);
self.push_row(value, delim, row_ref);
self.push_row(value, delim, row_ref)?;
}
Ok(())
} else {
Expand Down Expand Up @@ -209,7 +220,7 @@ impl Aggregator for StringAggOrdered {
{
let (row_ref, vis) = input.row_at(row_id);
assert!(vis);
self.push_row(value.unwrap(), delim.unwrap_or(""), row_ref);
self.push_row(value.unwrap(), delim.unwrap_or(""), row_ref)?;
}
Ok(())
} else {
Expand Down

0 comments on commit 7088fd3

Please sign in to comment.