diff --git a/Cargo.lock b/Cargo.lock index 3dff9e714bdad..05cf79234a2ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6144,6 +6144,7 @@ dependencies = [ "chrono", "chrono-tz", "criterion", + "ctor", "dyn-clone", "either", "futures-util", @@ -6156,14 +6157,28 @@ dependencies = [ "paste", "regex", "risingwave_common", + "risingwave_expr_macro", "risingwave_pb", "risingwave_udf", + "serde_json", "speedate", "static_assertions", "thiserror", + "tracing", "workspace-hack", ] +[[package]] +name = "risingwave_expr_macro" +version = "0.1.0" +dependencies = [ + "itertools", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "risingwave_frontend" version = "0.2.0-alpha" diff --git a/Cargo.toml b/Cargo.toml index 042c61a240113..640214a213d91 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "src/connector", "src/ctl", "src/expr", + "src/expr/macro", "src/frontend", "src/frontend/planner_test", "src/java_binding", diff --git a/Makefile.toml b/Makefile.toml index 11ad8dd45f5c5..7bbdac7b13358 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -622,12 +622,13 @@ tar xf ${TARGET_PATH} -C "${PREFIX_BIN}/connector-node" category = "RiseDev - Build in simulation mode" description = "Build in simulation mode" dependencies = ["warn-on-missing-tools"] -env = { RUSTFLAGS = "-Ctarget-cpu=native --cfg tokio_unstable --cfg madsim", RUSTDOCFLAGS = "--cfg madsim", CARGO_TARGET_DIR = "target/sim" } +env = { CARGO_TARGET_DIR = "target/sim" } script = """ #!/usr/bin/env bash set -e cargo build \ + --config "target.'cfg(all())'.rustflags = ['--cfg=madsim']" \ -p risingwave_batch \ -p risingwave_common \ -p risingwave_compute \ @@ -646,12 +647,13 @@ cargo build \ category = "RiseDev - Deterministic Simulation Test" description = "Run unit tests in deterministic simulation mode" dependencies = ["warn-on-missing-tools"] -env = { RUSTFLAGS = "-Ctarget-cpu=native --cfg tokio_unstable --cfg madsim", RUSTDOCFLAGS = "--cfg madsim", CARGO_TARGET_DIR = "target/sim" } +env = { CARGO_TARGET_DIR = "target/sim" } script = """ #!/usr/bin/env bash set -e cargo nextest run \ + --config "target.'cfg(all())'.rustflags = ['--cfg=madsim']" \ -p risingwave_batch \ -p risingwave_common \ -p risingwave_compute \ @@ -670,12 +672,13 @@ cargo nextest run \ category = "RiseDev - Simulation scaling tests" description = "Run integration scaling tests in deterministic simulation mode" dependencies = ["warn-on-missing-tools"] -env = { RUSTFLAGS = "-Ctarget-cpu=native --cfg tokio_unstable --cfg madsim", RUSTDOCFLAGS = "--cfg madsim", CARGO_TARGET_DIR = "target/sim" } +env = { CARGO_TARGET_DIR = "target/sim" } script = """ #!/usr/bin/env bash set -e cargo nextest run \ + --config "target.'cfg(all())'.rustflags = ['--cfg=madsim']" \ -p risingwave_simulation \ "$@" """ @@ -684,12 +687,13 @@ cargo nextest run \ category = "RiseDev - Simulation scaling tests" description = "Archive integration scaling tests in deterministic simulation mode" dependencies = ["warn-on-missing-tools"] -env = { RUSTFLAGS = "-Ctarget-cpu=native --cfg tokio_unstable --cfg madsim", RUSTDOCFLAGS = "--cfg madsim", CARGO_TARGET_DIR = "target/sim" } +env = { CARGO_TARGET_DIR = "target/sim" } script = """ #!/usr/bin/env bash set -e cargo nextest archive \ + --config "target.'cfg(all())'.rustflags = ['--cfg=madsim']" \ -p risingwave_simulation \ --archive-file scale-test.tar.zst \ "$@" @@ -699,48 +703,58 @@ cargo nextest archive \ category = "RiseDev - Deterministic Simulation End-to-end Test" description = "Run cargo check in deterministic simulation mode" dependencies = ["warn-on-missing-tools"] -env = { RUSTFLAGS = "-Ctarget-cpu=native --cfg tokio_unstable --cfg madsim", RUSTDOCFLAGS = "--cfg madsim", CARGO_TARGET_DIR = "target/sim" } +env = { CARGO_TARGET_DIR = "target/sim" } script = """ #!/usr/bin/env bash set -e -cargo check -p risingwave_simulation --all-targets "$@" +cargo check \ + --config "target.'cfg(all())'.rustflags = ['--cfg=madsim']" \ + -p risingwave_simulation --all-targets "$@" """ [tasks.sslt] category = "RiseDev - Deterministic Simulation End-to-end Test" description = "Run e2e tests in deterministic simulation mode" dependencies = ["warn-on-missing-tools"] -env = { RUSTFLAGS = "-Ctarget-cpu=native --cfg tokio_unstable --cfg madsim", RUSTDOCFLAGS = "--cfg madsim", CARGO_TARGET_DIR = "target/sim" } +env = { CARGO_TARGET_DIR = "target/sim" } script = """ #!/usr/bin/env bash set -e -cargo run -p risingwave_simulation "$@" +cargo run \ + --config "target.'cfg(all())'.rustflags = ['--cfg=madsim']" \ + -p risingwave_simulation "$@" """ [tasks.sslt-build-all] category = "RiseDev - Deterministic Simulation End-to-end Test" description = "Build deterministic simulation runner and tests" dependencies = ["warn-on-missing-tools"] -env = { RUSTFLAGS = "-Ctarget-cpu=native --cfg tokio_unstable --cfg madsim", RUSTDOCFLAGS = "--cfg madsim", CARGO_TARGET_DIR = "target/sim" } +env = { CARGO_TARGET_DIR = "target/sim" } script = """ #!/usr/bin/env bash set -e -cargo build -p risingwave_simulation --tests "$@" +cargo build \ + --config "target.'cfg(all())'.rustflags = ['--cfg=madsim']" \ + -p risingwave_simulation \ + --tests "$@" """ [tasks.sslt-cov] category = "RiseDev - Deterministic Simulation End-to-end Test" description = "Run e2e tests in deterministic simulation mode and report code coverage" dependencies = ["warn-on-missing-tools"] -env = { RUSTFLAGS = "-Ctarget-cpu=native --cfg tokio_unstable --cfg madsim", RUSTDOCFLAGS = "--cfg madsim", CARGO_TARGET_DIR = "target/sim-cov" } +env = { CARGO_TARGET_DIR = "target/sim-cov" } script = """ #!/usr/bin/env bash set -e -cargo llvm-cov run -p risingwave_simulation --html "$@" +cargo llvm-cov run \ + --config "target.'cfg(all())'.rustflags = ['--cfg=madsim']" \ + -p risingwave_simulation \ + --html "$@" """ [tasks.check-java] diff --git a/src/batch/src/executor/join/hash_join.rs b/src/batch/src/executor/join/hash_join.rs index e5eac41512a4f..36959b7253977 100644 --- a/src/batch/src/executor/join/hash_join.rs +++ b/src/batch/src/executor/join/hash_join.rs @@ -1799,8 +1799,8 @@ mod tests { use risingwave_common::test_prelude::DataChunkTestExt; use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqDebug; - use risingwave_expr::expr::{new_binary_expr, BoxedExpression, InputRefExpression}; - use risingwave_pb::expr::expr_node::Type; + use risingwave_expr::expr::{build, BoxedExpression, Expression, InputRefExpression}; + use risingwave_pb::expr::expr_node::PbType; use super::{ ChunkedData, HashJoinExecutor, JoinType, LeftNonEquiJoinState, RightNonEquiJoinState, RowId, @@ -1985,13 +1985,13 @@ mod tests { } fn create_cond() -> BoxedExpression { - let left_expr = InputRefExpression::new(DataType::Float32, 1); - let right_expr = InputRefExpression::new(DataType::Float64, 3); - new_binary_expr( - Type::LessThan, + build( + PbType::LessThan, DataType::Boolean, - Box::new(left_expr), - Box::new(right_expr), + vec![ + InputRefExpression::new(DataType::Float32, 1).boxed(), + InputRefExpression::new(DataType::Float64, 3).boxed(), + ], ) .unwrap() } diff --git a/src/batch/src/executor/join/local_lookup_join.rs b/src/batch/src/executor/join/local_lookup_join.rs index cbcd1dea651fa..2976c26340a4a 100644 --- a/src/batch/src/executor/join/local_lookup_join.rs +++ b/src/batch/src/executor/join/local_lookup_join.rs @@ -462,9 +462,9 @@ mod tests { use risingwave_common::util::chunk_coalesce::DataChunkBuilder; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_expr::expr::{ - new_binary_expr, BoxedExpression, InputRefExpression, LiteralExpression, + build, BoxedExpression, Expression, InputRefExpression, LiteralExpression, }; - use risingwave_pb::expr::expr_node::Type; + use risingwave_pb::expr::expr_node::PbType; use super::LocalLookupJoinExecutorArgs; use crate::executor::join::JoinType; @@ -676,14 +676,13 @@ mod tests { ); let condition = Some( - new_binary_expr( - Type::LessThan, + build( + PbType::LessThan, DataType::Boolean, - Box::new(LiteralExpression::new( - DataType::Int32, - Some(ScalarImpl::Int32(5)), - )), - Box::new(InputRefExpression::new(DataType::Float32, 3)), + vec![ + LiteralExpression::new(DataType::Int32, Some(ScalarImpl::Int32(5))).boxed(), + InputRefExpression::new(DataType::Float32, 3).boxed(), + ], ) .unwrap(), ); @@ -705,14 +704,13 @@ mod tests { ); let condition = Some( - new_binary_expr( - Type::LessThan, + build( + PbType::LessThan, DataType::Boolean, - Box::new(LiteralExpression::new( - DataType::Int32, - Some(ScalarImpl::Int32(5)), - )), - Box::new(InputRefExpression::new(DataType::Float32, 3)), + vec![ + LiteralExpression::new(DataType::Int32, Some(ScalarImpl::Int32(5))).boxed(), + InputRefExpression::new(DataType::Float32, 3).boxed(), + ], ) .unwrap(), ); @@ -730,14 +728,13 @@ mod tests { ); let condition = Some( - new_binary_expr( - Type::LessThan, + build( + PbType::LessThan, DataType::Boolean, - Box::new(LiteralExpression::new( - DataType::Int32, - Some(ScalarImpl::Int32(5)), - )), - Box::new(InputRefExpression::new(DataType::Float32, 3)), + vec![ + LiteralExpression::new(DataType::Int32, Some(ScalarImpl::Int32(5))).boxed(), + InputRefExpression::new(DataType::Float32, 3).boxed(), + ], ) .unwrap(), ); @@ -756,14 +753,13 @@ mod tests { ); let condition = Some( - new_binary_expr( - Type::LessThan, + build( + PbType::LessThan, DataType::Boolean, - Box::new(LiteralExpression::new( - DataType::Int32, - Some(ScalarImpl::Int32(5)), - )), - Box::new(InputRefExpression::new(DataType::Float32, 3)), + vec![ + LiteralExpression::new(DataType::Int32, Some(ScalarImpl::Int32(5))).boxed(), + InputRefExpression::new(DataType::Float32, 3).boxed(), + ], ) .unwrap(), ); diff --git a/src/batch/src/executor/join/nested_loop_join.rs b/src/batch/src/executor/join/nested_loop_join.rs index 5c229fa77ee5c..8a9f8d94aed9a 100644 --- a/src/batch/src/executor/join/nested_loop_join.rs +++ b/src/batch/src/executor/join/nested_loop_join.rs @@ -470,8 +470,8 @@ mod tests { use risingwave_common::array::*; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::DataType; - use risingwave_expr::expr::{new_binary_expr, InputRefExpression}; - use risingwave_pb::expr::expr_node::Type; + use risingwave_expr::expr::{build, InputRefExpression}; + use risingwave_pb::expr::expr_node::PbType; use crate::executor::join::nested_loop_join::NestedLoopJoinExecutor; use crate::executor::join::JoinType; @@ -587,11 +587,13 @@ mod tests { }; Box::new(NestedLoopJoinExecutor::new( - new_binary_expr( - Type::Equal, + build( + PbType::Equal, DataType::Boolean, - Box::new(InputRefExpression::new(DataType::Int32, 0)), - Box::new(InputRefExpression::new(DataType::Int32, 2)), + vec![ + Box::new(InputRefExpression::new(DataType::Int32, 0)), + Box::new(InputRefExpression::new(DataType::Int32, 2)), + ], ) .unwrap(), join_type, diff --git a/src/common/src/array/jsonb_array.rs b/src/common/src/array/jsonb_array.rs index e882c49bc3cc2..f58608c9b92d7 100644 --- a/src/common/src/array/jsonb_array.rs +++ b/src/common/src/array/jsonb_array.rs @@ -274,6 +274,26 @@ impl JsonbRef<'_> { } } +impl FromIterator> for JsonbArray { + fn from_iter>>(iter: I) -> Self { + let iter = iter.into_iter(); + let mut builder = ::Builder::new(iter.size_hint().0); + for i in iter { + match i { + Some(x) => builder.append(Some(x.as_scalar_ref())), + None => builder.append(None), + } + } + builder.finish() + } +} + +impl FromIterator for JsonbArray { + fn from_iter>(iter: I) -> Self { + iter.into_iter().map(Some).collect() + } +} + #[derive(Debug)] pub struct JsonbArrayBuilder { bitmap: BitmapBuilder, diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs index a90d3fb750673..5170937555359 100644 --- a/src/common/src/array/list_array.rs +++ b/src/common/src/array/list_array.rs @@ -156,7 +156,7 @@ pub struct ListArray { bitmap: Bitmap, pub(super) offsets: Vec, pub(super) value: Box, - value_type: DataType, + pub(super) value_type: DataType, } impl Array for ListArray { diff --git a/src/common/src/array/mod.rs b/src/common/src/array/mod.rs index e75c2095a7b78..dcabfac81a674 100644 --- a/src/common/src/array/mod.rs +++ b/src/common/src/array/mod.rs @@ -59,13 +59,13 @@ pub use list_array::{ListArray, ListArrayBuilder, ListRef, ListValue}; use paste::paste; pub use primitive_array::{PrimitiveArray, PrimitiveArrayBuilder, PrimitiveArrayItemType}; use risingwave_pb::data::{PbArray, PbArrayType}; +pub use serial_array::{Serial, SerialArray, SerialArrayBuilder}; pub use stream_chunk::{Op, StreamChunk, StreamChunkTestExt}; pub use struct_array::{StructArray, StructArrayBuilder, StructRef, StructValue}; pub use utf8_array::*; pub use vis::{Vis, VisRef}; pub use self::error::ArrayError; -use crate::array::serial_array::{Serial, SerialArray, SerialArrayBuilder}; use crate::buffer::Bitmap; use crate::types::*; use crate::util::iter_util::ZipEqFast; diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 7ab0108dd5cb0..6868d98de7ef8 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -20,6 +20,7 @@ use bytes::{Buf, BufMut, Bytes}; use num_traits::Float; use parse_display::{Display, FromStr}; use postgres_types::FromSql; +use risingwave_pb::data::data_type::PbTypeName; use risingwave_pb::data::PbDataType; use serde::{Deserialize, Serialize}; @@ -35,7 +36,6 @@ use std::fmt::Debug; use std::str::{FromStr, Utf8Error}; pub use native_type::*; -use risingwave_pb::data::data_type::TypeName; pub use scalar_impl::*; pub use successor::*; pub mod chrono_wrapper; @@ -210,32 +210,57 @@ pub fn unnested_list_type(datatype: DataType) -> DataType { impl From<&PbDataType> for DataType { fn from(proto: &PbDataType) -> DataType { match proto.get_type_name().expect("missing type field") { - TypeName::Int16 => DataType::Int16, - TypeName::Int32 => DataType::Int32, - TypeName::Int64 => DataType::Int64, - TypeName::Serial => DataType::Serial, - TypeName::Float => DataType::Float32, - TypeName::Double => DataType::Float64, - TypeName::Boolean => DataType::Boolean, - TypeName::Varchar => DataType::Varchar, - TypeName::Date => DataType::Date, - TypeName::Time => DataType::Time, - TypeName::Timestamp => DataType::Timestamp, - TypeName::Timestamptz => DataType::Timestamptz, - TypeName::Decimal => DataType::Decimal, - TypeName::Interval => DataType::Interval, - TypeName::Bytea => DataType::Bytea, - TypeName::Jsonb => DataType::Jsonb, - TypeName::Struct => { + PbTypeName::Int16 => DataType::Int16, + PbTypeName::Int32 => DataType::Int32, + PbTypeName::Int64 => DataType::Int64, + PbTypeName::Serial => DataType::Serial, + PbTypeName::Float => DataType::Float32, + PbTypeName::Double => DataType::Float64, + PbTypeName::Boolean => DataType::Boolean, + PbTypeName::Varchar => DataType::Varchar, + PbTypeName::Date => DataType::Date, + PbTypeName::Time => DataType::Time, + PbTypeName::Timestamp => DataType::Timestamp, + PbTypeName::Timestamptz => DataType::Timestamptz, + PbTypeName::Decimal => DataType::Decimal, + PbTypeName::Interval => DataType::Interval, + PbTypeName::Bytea => DataType::Bytea, + PbTypeName::Jsonb => DataType::Jsonb, + PbTypeName::Struct => { let fields: Vec = proto.field_type.iter().map(|f| f.into()).collect_vec(); let field_names: Vec = proto.field_names.iter().cloned().collect_vec(); DataType::new_struct(fields, field_names) } - TypeName::List => DataType::List { + PbTypeName::List => DataType::List { // The first (and only) item is the list element type. datatype: Box::new((&proto.field_type[0]).into()), }, - TypeName::TypeUnspecified => unreachable!(), + PbTypeName::TypeUnspecified => unreachable!(), + } + } +} + +impl From for PbTypeName { + fn from(type_name: DataTypeName) -> Self { + match type_name { + DataTypeName::Boolean => PbTypeName::Boolean, + DataTypeName::Int16 => PbTypeName::Int16, + DataTypeName::Int32 => PbTypeName::Int32, + DataTypeName::Int64 => PbTypeName::Int64, + DataTypeName::Serial => PbTypeName::Serial, + DataTypeName::Float32 => PbTypeName::Float, + DataTypeName::Float64 => PbTypeName::Double, + DataTypeName::Varchar => PbTypeName::Varchar, + DataTypeName::Date => PbTypeName::Date, + DataTypeName::Timestamp => PbTypeName::Timestamp, + DataTypeName::Timestamptz => PbTypeName::Timestamptz, + DataTypeName::Time => PbTypeName::Time, + DataTypeName::Interval => PbTypeName::Interval, + DataTypeName::Decimal => PbTypeName::Decimal, + DataTypeName::Bytea => PbTypeName::Bytea, + DataTypeName::Jsonb => PbTypeName::Jsonb, + DataTypeName::Struct => PbTypeName::Struct, + DataTypeName::List => PbTypeName::List, } } } @@ -273,26 +298,26 @@ impl DataType { } } - pub fn prost_type_name(&self) -> TypeName { + pub fn prost_type_name(&self) -> PbTypeName { match self { - DataType::Int16 => TypeName::Int16, - DataType::Int32 => TypeName::Int32, - DataType::Int64 => TypeName::Int64, - DataType::Serial => TypeName::Serial, - DataType::Float32 => TypeName::Float, - DataType::Float64 => TypeName::Double, - DataType::Boolean => TypeName::Boolean, - DataType::Varchar => TypeName::Varchar, - DataType::Date => TypeName::Date, - DataType::Time => TypeName::Time, - DataType::Timestamp => TypeName::Timestamp, - DataType::Timestamptz => TypeName::Timestamptz, - DataType::Decimal => TypeName::Decimal, - DataType::Interval => TypeName::Interval, - DataType::Jsonb => TypeName::Jsonb, - DataType::Struct { .. } => TypeName::Struct, - DataType::List { .. } => TypeName::List, - DataType::Bytea => TypeName::Bytea, + DataType::Int16 => PbTypeName::Int16, + DataType::Int32 => PbTypeName::Int32, + DataType::Int64 => PbTypeName::Int64, + DataType::Serial => PbTypeName::Serial, + DataType::Float32 => PbTypeName::Float, + DataType::Float64 => PbTypeName::Double, + DataType::Boolean => PbTypeName::Boolean, + DataType::Varchar => PbTypeName::Varchar, + DataType::Date => PbTypeName::Date, + DataType::Time => PbTypeName::Time, + DataType::Timestamp => PbTypeName::Timestamp, + DataType::Timestamptz => PbTypeName::Timestamptz, + DataType::Decimal => PbTypeName::Decimal, + DataType::Interval => PbTypeName::Interval, + DataType::Jsonb => PbTypeName::Jsonb, + DataType::Struct { .. } => PbTypeName::Struct, + DataType::List { .. } => PbTypeName::List, + DataType::Bytea => PbTypeName::Bytea, } } diff --git a/src/expr/Cargo.toml b/src/expr/Cargo.toml index c17dde73e7100..c1e8c66846ecf 100644 --- a/src/expr/Cargo.toml +++ b/src/expr/Cargo.toml @@ -22,6 +22,7 @@ arrow-schema = "34" async-trait = "0.1" chrono = { version = "0.4", default-features = false, features = ["clock", "std"] } chrono-tz = { version = "0.7", features = ["case-insensitive"] } +ctor = "0.1" dyn-clone = "1" either = "1" futures-util = "0.3" @@ -33,18 +34,21 @@ parse-display = "0.6" paste = "1" regex = "1" risingwave_common = { path = "../common" } +risingwave_expr_macro = { path = "macro" } risingwave_pb = { path = "../prost" } risingwave_udf = { path = "../udf" } speedate = "0.7.0" static_assertions = "1" thiserror = "1" tokio = { version = "0.2", package = "madsim-tokio", features = ["rt", "rt-multi-thread", "sync", "macros", "time", "signal"] } +tracing = "0.1" [target.'cfg(not(madsim))'.dependencies] workspace-hack = { path = "../workspace-hack" } [dev-dependencies] criterion = "0.4" +serde_json = "1" [[bench]] name = "expr" diff --git a/src/expr/benches/expr.rs b/src/expr/benches/expr.rs index eacbb54436616..33d54ce13e438 100644 --- a/src/expr/benches/expr.rs +++ b/src/expr/benches/expr.rs @@ -29,14 +29,12 @@ use risingwave_common::types::{ DataType, DataTypeName, Decimal, IntervalUnit, NaiveDateTimeWrapper, NaiveDateWrapper, NaiveTimeWrapper, OrderedF32, OrderedF64, }; -use risingwave_expr::expr::test_utils::{make_expression, make_string_literal}; use risingwave_expr::expr::*; use risingwave_expr::sig::agg::agg_func_sigs; -use risingwave_expr::sig::cast::cast_sigs; use risingwave_expr::sig::func::func_sigs; use risingwave_expr::vector_op::agg::create_agg_state_unary; use risingwave_expr::ExprError; -use risingwave_pb::expr::expr_node::{RexNode, Type as ExprType}; +use risingwave_pb::expr::expr_node::PbType; criterion_group!(benches, bench_expr, bench_raw); criterion_main!(benches); @@ -157,6 +155,13 @@ fn bench_expr(c: &mut Criterion) { .take(CHUNK_SIZE), ) .into(), + // 25: serial array + SerialArray::from_iter((1..=CHUNK_SIZE).map(|i| Serial::from(i as i64))).into(), + // 26: jsonb array + JsonbArray::from_iter( + (1..=CHUNK_SIZE).map(|i| JsonbVal::from_serde(serde_json::Value::Number(i.into()))), + ) + .into(), ], CHUNK_SIZE, ); @@ -165,6 +170,7 @@ fn bench_expr(c: &mut Criterion) { InputRefExpression::new(DataType::Int16, 1), InputRefExpression::new(DataType::Int32, 2), InputRefExpression::new(DataType::Int64, 3), + InputRefExpression::new(DataType::Serial, 25), InputRefExpression::new(DataType::Float32, 4), InputRefExpression::new(DataType::Float64, 5), InputRefExpression::new(DataType::Decimal, 6), @@ -175,12 +181,14 @@ fn bench_expr(c: &mut Criterion) { InputRefExpression::new(DataType::Interval, 11), InputRefExpression::new(DataType::Varchar, 12), InputRefExpression::new(DataType::Bytea, 13), + InputRefExpression::new(DataType::Jsonb, 26), ]; - let inputref_for_type = |ty: DataType| { + let input_index_for_type = |ty: DataType| { inputrefs .iter() .find(|r| r.return_type() == ty) - .expect("expression not found") + .unwrap_or_else(|| panic!("expression not found for {ty:?}")) + .index() }; const TIMEZONE: usize = 14; const TIME_FIELD: usize = 15; @@ -210,66 +218,74 @@ fn bench_expr(c: &mut Criterion) { }); let sigs = func_sigs(); - let sigs = sigs.sorted_by_cached_key(|sig| sig.to_string_no_return()); - for sig in sigs { + let sigs = sigs.sorted_by_cached_key(|sig| format!("{sig:?}")); + 'sig: for sig in sigs { if sig .inputs_type .iter() .any(|t| matches!(t, DataTypeName::Struct | DataTypeName::List)) { // TODO: support struct and list - println!("todo: {}", sig.to_string_no_return()); + println!("todo: {sig:?}"); continue; } - let mut prost = make_expression( - sig.func, - &sig.inputs_type - .iter() - .map(|t| DataType::from(*t).prost_type_name()) - .collect_vec(), - &sig.inputs_type - .iter() - .enumerate() - .map(|(idx, t)| match (sig.func, idx) { - (ExprType::AtTimeZone, 1) => TIMEZONE, - (ExprType::DateTrunc, 0) => TIME_FIELD, - (ExprType::DateTrunc, 2) => TIMEZONE, - (ExprType::Extract, 0) => match sig.inputs_type[1] { - DataTypeName::Date => EXTRACT_FIELD_DATE, - DataTypeName::Time => EXTRACT_FIELD_TIME, - DataTypeName::Timestamp => EXTRACT_FIELD_TIMESTAMP, - DataTypeName::Timestamptz => EXTRACT_FIELD_TIMESTAMPTZ, - t => panic!("unexpected type: {t:?}"), - }, - _ => inputref_for_type((*t).into()).index(), - }) - .collect_vec(), - ); - if sig.func == ExprType::ToChar { - let RexNode::FuncCall(f) = prost.rex_node.as_mut().unwrap() else { unreachable!() }; - f.children[1] = make_string_literal("YYYY/MM/DD HH:MM:SS"); + let mut children = vec![]; + for (i, t) in sig.inputs_type.iter().enumerate() { + use DataTypeName::*; + let idx = match (sig.func, i) { + (PbType::ToChar, 1) => { + children.push( + LiteralExpression::new( + DataType::Varchar, + Some("YYYY/MM/DD HH:MM:SS".into()), + ) + .boxed(), + ); + continue; + } + (PbType::Cast, 0) if *t == DataTypeName::Varchar => match sig.ret_type { + Boolean => BOOL_STRING, + Int16 | Int32 | Int64 | Float32 | Float64 | Decimal => NUMBER_STRING, + Date => DATE_STRING, + Time => TIME_STRING, + Timestamp => TIMESTAMP_STRING, + Timestamptz => TIMESTAMPTZ_STRING, + Interval => INTERVAL_STRING, + Bytea => NUMBER_STRING, // any + _ => { + println!("todo: {sig:?}"); + continue 'sig; + } + }, + (PbType::AtTimeZone, 1) => TIMEZONE, + (PbType::DateTrunc, 0) => TIME_FIELD, + (PbType::DateTrunc, 2) => TIMEZONE, + (PbType::Extract, 0) => match sig.inputs_type[1] { + Date => EXTRACT_FIELD_DATE, + Time => EXTRACT_FIELD_TIME, + Timestamp => EXTRACT_FIELD_TIMESTAMP, + Timestamptz => EXTRACT_FIELD_TIMESTAMPTZ, + t => panic!("unexpected type: {t:?}"), + }, + _ => input_index_for_type((*t).into()), + }; + children.push(InputRefExpression::new(DataType::from(*t), idx).boxed()); } - let expr = match build_from_prost(&prost) { - Ok(expr) => expr, - Err(e) => { - println!("error: {e}"); - continue; - } - }; - c.bench_function(&sig.to_string_no_return(), |bencher| { + let expr = build(sig.func, sig.ret_type.into(), children).unwrap(); + c.bench_function(&format!("{sig:?}"), |bencher| { bencher.to_async(FuturesExecutor).iter(|| expr.eval(&input)) }); } for sig in agg_func_sigs() { if sig.inputs_type.len() != 1 { - println!("todo: {}", sig.to_string_no_return()); + println!("todo: {sig:?}"); continue; } let agg = match create_agg_state_unary( sig.inputs_type[0].into(), - inputref_for_type(sig.inputs_type[0].into()).index(), + input_index_for_type(sig.inputs_type[0].into()), sig.func, sig.ret_type.into(), false, @@ -292,65 +308,6 @@ fn bench_expr(c: &mut Criterion) { }) }); } - - for sig in cast_sigs() { - let expr = match new_unary_expr( - ExprType::Cast, - sig.to_type.into(), - if matches!(sig.from_type, DataTypeName::Varchar) { - use DataTypeName::*; - let idx = match sig.to_type { - Boolean => BOOL_STRING, - Int16 | Int32 | Int64 | Float32 | Float64 | Decimal => NUMBER_STRING, - Date => DATE_STRING, - Time => TIME_STRING, - Timestamp => TIMESTAMP_STRING, - Timestamptz => TIMESTAMPTZ_STRING, - Interval => INTERVAL_STRING, - Bytea => NUMBER_STRING, // any - _ => { - println!("todo: {}", sig.to_string_no_return()); - continue; - } - }; - InputRefExpression::new(DataType::Varchar, idx).boxed() - } else { - inputref_for_type(sig.from_type.into()).clone().boxed() - }, - ) { - Ok(expr) => expr, - Err(e) => { - println!("error: {e}"); - continue; - } - }; - c.bench_function(&sig.to_string_no_return(), |bencher| { - bencher.to_async(FuturesExecutor).iter(|| expr.eval(&input)) - }); - } - - // ~360ns - // This should be the optimization goal for our add expression. - c.bench_function("TBD/add(int32,int32)", |bencher| { - bencher.iter(|| { - let a = input.column_at(2).array_ref().as_int32(); - let b = input.column_at(2).array_ref().as_int32(); - assert_eq!(a.len(), b.len()); - let mut c = (a.raw_iter()) - .zip(b.raw_iter()) - .map(|(a, b)| a + b) - .collect::(); - let mut overflow = false; - for ((a, b), c) in a.raw_iter().zip(b.raw_iter()).zip(c.raw_iter()) { - overflow |= (c ^ a) & (c ^ b) < 0; - } - if overflow { - return Err(ExprError::NumericOverflow); - } - c.set_bitmap(a.null_bitmap() & b.null_bitmap()); - Ok(c) - }) - }); } /// Evaluate on raw Rust array. @@ -457,4 +414,27 @@ fn bench_raw(c: &mut Criterion) { }) }, ); + + // ~360ns + // This should be the optimization goal for our add expression. + c.bench_function("TBD/add(int32,int32)", |bencher| { + bencher.iter(|| { + let a = (0..CHUNK_SIZE as i32).collect::(); + let b = (0..CHUNK_SIZE as i32).collect::(); + assert_eq!(a.len(), b.len()); + let mut c = (a.raw_iter()) + .zip(b.raw_iter()) + .map(|(a, b)| a + b) + .collect::(); + let mut overflow = false; + for ((a, b), c) in a.raw_iter().zip(b.raw_iter()).zip(c.raw_iter()) { + overflow |= (c ^ a) & (c ^ b) < 0; + } + if overflow { + return Err(ExprError::NumericOverflow); + } + c.set_bitmap(a.null_bitmap() & b.null_bitmap()); + Ok(c) + }) + }); } diff --git a/src/expr/macro/Cargo.toml b/src/expr/macro/Cargo.toml new file mode 100644 index 0000000000000..524f144053930 --- /dev/null +++ b/src/expr/macro/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "risingwave_expr_macro" +version = "0.1.0" +edition = "2021" +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +proc-macro = true + +[dependencies] +itertools = "0.10" +proc-macro-error = "1" +proc-macro2 = "1" +quote = "1" +syn = "1" diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs new file mode 100644 index 0000000000000..09c4d3ba753c8 --- /dev/null +++ b/src/expr/macro/src/gen.rs @@ -0,0 +1,286 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Generate code for the functions. + +use itertools::Itertools; +use proc_macro2::Span; +use quote::{format_ident, quote}; + +use super::*; + +impl FunctionAttr { + /// Generate descriptors of the function. + /// + /// If the function arguments or return type contains wildcard, it will generate descriptors for + /// each of them. + pub fn generate_descriptors(&self, build_fn: bool) -> Result { + let args = self.args.iter().map(|ty| types::expand_type_wildcard(ty)); + let ret = types::expand_type_wildcard(&self.ret); + let mut tokens = TokenStream2::new(); + // multi_cartesian_product should emit an empty set if the input is empty. + let args_cartesian_product = + args.multi_cartesian_product() + .chain(match self.args.is_empty() { + true => vec![vec![]], + false => vec![], + }); + for (args, mut ret) in args_cartesian_product.cartesian_product(ret) { + if ret == "auto" { + ret = types::min_compatible_type(&args); + } + let attr = FunctionAttr { + name: self.name.clone(), + args: args.iter().map(|s| s.to_string()).collect(), + ret: ret.to_string(), + batch: self.batch.clone(), + user_fn: self.user_fn.clone(), + }; + tokens.extend(attr.generate_descriptor_one(build_fn)?); + } + Ok(tokens) + } + + /// Generate a descriptor of the function. + /// + /// The types of arguments and return value should not contain wildcard. + fn generate_descriptor_one(&self, build_fn: bool) -> Result { + let name = self.name.clone(); + + fn to_data_type_name(ty: &str) -> Result { + let variant = format_ident!( + "{}", + types::to_data_type_name(ty).ok_or_else(|| Error::new( + Span::call_site(), + format!("unknown type: {}", ty), + ))? + ); + Ok(quote! { risingwave_common::types::DataTypeName::#variant }) + } + let mut args = Vec::with_capacity(self.args.len()); + for ty in &self.args { + args.push(to_data_type_name(ty)?); + } + let ret = to_data_type_name(&self.ret)?; + + let pb_type = format_ident!("{}", utils::to_camel_case(&name)); + let ctor_name = format_ident!("{}_{}_{}", self.name, self.args.join("_"), self.ret); + let descriptor_type = quote! { crate::sig::func::FuncSign }; + let build_fn = if build_fn { + let name = format_ident!("{}", self.user_fn.name); + quote! { #name } + } else { + self.generate_build_fn()? + }; + Ok(quote! { + #[ctor::ctor] + fn #ctor_name() { + unsafe { crate::sig::func::_register(#descriptor_type { + name: #name, + func: risingwave_pb::expr::expr_node::Type::#pb_type, + inputs_type: &[#(#args),*], + ret_type: #ret, + build: #build_fn, + }) }; + } + }) + } + + fn generate_build_fn(&self) -> Result { + let num_args = self.args.len(); + let fn_name = format_ident!("{}", self.user_fn.name); + let arg_arrays = self + .args + .iter() + .map(|t| format_ident!("{}", types::to_array_type(t))); + let ret_array = format_ident!("{}", types::to_array_type(&self.ret)); + let arg_types = self + .args + .iter() + .map(|t| types::to_data_type(t).parse::().unwrap()); + let ret_type = types::to_data_type(&self.ret) + .parse::() + .unwrap(); + let exprs = (0..num_args) + .map(|i| format_ident!("e{i}")) + .collect::>(); + let exprs0 = exprs.clone(); + + let build_expr = if self.ret == "varchar" && self.user_fn.is_writer_style() { + let template_struct = match num_args { + 1 => format_ident!("UnaryBytesExpression"), + 2 => format_ident!("BinaryBytesExpression"), + 3 => format_ident!("TernaryBytesExpression"), + 4 => format_ident!("QuaternaryBytesExpression"), + _ => return Err(Error::new(Span::call_site(), "unsupported arguments")), + }; + let args = (0..=num_args).map(|i| format_ident!("x{i}")); + let args1 = args.clone(); + let func = match self.user_fn.return_type { + ReturnType::T => quote! { Ok(#fn_name(#(#args1),*)) }, + ReturnType::Result => quote! { #fn_name(#(#args1),*) }, + _ => todo!("returning Option is not supported yet"), + }; + quote! { + Ok(Box::new(crate::expr::template::#template_struct::<#(#arg_arrays),*, _>::new( + #(#exprs),*, + return_type, + |#(#args),*| #func, + ))) + } + } else if self.args.iter().all(|t| t == "boolean") + && self.ret == "boolean" + && !self.user_fn.return_type.contains_result() + && self.batch.is_some() + { + let template_struct = match num_args { + 1 => format_ident!("BooleanUnaryExpression"), + 2 => format_ident!("BooleanBinaryExpression"), + _ => return Err(Error::new(Span::call_site(), "unsupported arguments")), + }; + let batch = format_ident!("{}", self.batch.as_ref().unwrap()); + let args = (0..num_args).map(|i| format_ident!("x{i}")); + let args1 = args.clone(); + let func = if self.user_fn.arg_option && self.user_fn.return_type == ReturnType::Option + { + quote! { #fn_name(#(#args1),*) } + } else if self.user_fn.arg_option { + quote! { Some(#fn_name(#(#args1),*)) } + } else { + let args2 = args.clone(); + let args3 = args.clone(); + quote! { + match (#(#args1),*) { + (#(Some(#args2)),*) => Some(#fn_name(#(#args3),*)), + _ => None, + } + } + }; + quote! { + Ok(Box::new(crate::expr::template_fast::#template_struct::new( + #(#exprs),*, + #batch, + |#(#args),*| #func, + ))) + } + } else if self.args.len() == 2 && self.ret == "boolean" && self.user_fn.is_pure() { + let compatible_type = types::to_data_type(types::min_compatible_type(&self.args)) + .parse::() + .unwrap(); + let args = (0..num_args).map(|i| format_ident!("x{i}")); + let args1 = args.clone(); + let generic = if self.user_fn.generic == 3 { + // XXX: for generic compare functions, we need to specify the compatible type + quote! { ::<_, _, #compatible_type> } + } else { + quote! {} + }; + quote! { + Ok(Box::new(crate::expr::template_fast::CompareExpression::<_, #(#arg_arrays),*>::new( + #(#exprs),*, + |#(#args),*| #fn_name #generic(#(#args1),*), + ))) + } + } else if self.args.iter().all(|t| types::is_primitive(t)) && self.user_fn.is_pure() { + let template_struct = match num_args { + 1 => format_ident!("UnaryExpression"), + 2 => format_ident!("BinaryExpression"), + _ => return Err(Error::new(Span::call_site(), "unsupported arguments")), + }; + quote! { + Ok(Box::new(crate::expr::template_fast::#template_struct::<_, #(#arg_types),*, #ret_type>::new( + #(#exprs),*, + return_type, + #fn_name, + ))) + } + } else if self.user_fn.arg_option || self.user_fn.return_type.contains_option() { + let template_struct = match num_args { + 1 => format_ident!("UnaryNullableExpression"), + 2 => format_ident!("BinaryNullableExpression"), + 3 => format_ident!("TernaryNullableExpression"), + _ => return Err(Error::new(Span::call_site(), "unsupported arguments")), + }; + let args = (0..num_args).map(|i| format_ident!("x{i}")); + let args1 = args.clone(); + let generic = if self.user_fn.generic == 3 { + // XXX: for generic compare functions, we need to specify the compatible type + let compatible_type = types::to_data_type(types::min_compatible_type(&self.args)) + .parse::() + .unwrap(); + quote! { ::<_, _, #compatible_type> } + } else { + quote! {} + }; + let mut func = quote! { #fn_name #generic(#(#args1),*) }; + func = match self.user_fn.return_type { + ReturnType::T => quote! { Ok(Some(#func)) }, + ReturnType::Option => quote! { Ok(#func) }, + ReturnType::Result => quote! { #func.map(Some) }, + ReturnType::ResultOption => quote! { #func }, + }; + if !self.user_fn.arg_option { + let args2 = args.clone(); + let args3 = args.clone(); + func = quote! { + match (#(#args2),*) { + (#(Some(#args3)),*) => #func, + _ => Ok(None), + } + }; + }; + quote! { + Ok(Box::new(crate::expr::template::#template_struct::<#(#arg_arrays),*, #ret_array, _>::new( + #(#exprs),*, + return_type, + |#(#args),*| #func, + ))) + } + } else { + let template_struct = match num_args { + 1 => format_ident!("UnaryExpression"), + 2 => format_ident!("BinaryExpression"), + 3 => format_ident!("TernaryExpression"), + _ => return Err(Error::new(Span::call_site(), "unsupported arguments")), + }; + let args = (0..num_args).map(|i| format_ident!("x{i}")); + let args1 = args.clone(); + let func = match self.user_fn.return_type { + ReturnType::T => quote! { Ok(#fn_name(#(#args1),*)) }, + ReturnType::Result => quote! { #fn_name(#(#args1),*) }, + _ => panic!("return type should not contain Option"), + }; + quote! { + Ok(Box::new(crate::expr::template::#template_struct::<#(#arg_arrays),*, #ret_array, _>::new( + #(#exprs),*, + return_type, + |#(#args),*| #func, + ))) + } + }; + Ok(quote! { + |return_type, children| { + use risingwave_common::array::*; + use risingwave_common::types::*; + use risingwave_pb::expr::expr_node::RexNode; + + crate::ensure!(children.len() == #num_args); + let mut iter = children.into_iter(); + #(let #exprs0 = iter.next().unwrap();)* + + #build_expr + } + }) + } +} diff --git a/src/expr/macro/src/lib.rs b/src/expr/macro/src/lib.rs new file mode 100644 index 0000000000000..9e00e6edb6505 --- /dev/null +++ b/src/expr/macro/src/lib.rs @@ -0,0 +1,114 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::ToTokens; +use syn::{parse_macro_input, Error, Result}; + +mod gen; +mod parse; +mod types; +mod utils; + +#[proc_macro_attribute] +pub fn function(attr: TokenStream, item: TokenStream) -> TokenStream { + let attr = parse_macro_input!(attr as syn::AttributeArgs); + let item = parse_macro_input!(item as syn::ItemFn); + + fn inner(attr: syn::AttributeArgs, mut item: syn::ItemFn) -> Result { + let fn_attr = FunctionAttr::parse(&attr, &mut item)?; + + let mut tokens = item.into_token_stream(); + tokens.extend(fn_attr.generate_descriptors(false)?); + Ok(tokens) + } + match inner(attr, item) { + Ok(tokens) => tokens.into(), + Err(e) => e.to_compile_error().into(), + } +} + +#[proc_macro_attribute] +pub fn build_function(attr: TokenStream, item: TokenStream) -> TokenStream { + let attr = parse_macro_input!(attr as syn::AttributeArgs); + let item = parse_macro_input!(item as syn::ItemFn); + + fn inner(attr: syn::AttributeArgs, mut item: syn::ItemFn) -> Result { + let fn_attr = FunctionAttr::parse(&attr, &mut item)?; + + let mut tokens = item.into_token_stream(); + tokens.extend(fn_attr.generate_descriptors(true)?); + Ok(tokens) + } + match inner(attr, item) { + Ok(tokens) => tokens.into(), + Err(e) => e.to_compile_error().into(), + } +} + +#[derive(Debug)] +struct FunctionAttr { + name: String, + args: Vec, + ret: String, + batch: Option, + user_fn: UserFunctionAttr, +} + +#[derive(Debug, Clone)] +struct UserFunctionAttr { + /// Function name + name: String, + /// The last argument type is `&mut dyn Write`. + write: bool, + /// The argument type are `Option`s. + arg_option: bool, + /// The return type. + return_type: ReturnType, + /// The number of generic types. + generic: usize, + // /// `#[list(0)]` in arguments. + // list: Vec<(usize, usize)>, + // /// `#[struct(0)]` in arguments. + // struct_: Vec<(usize, usize)>, +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +enum ReturnType { + T, + Option, + Result, + ResultOption, +} + +impl ReturnType { + fn contains_result(&self) -> bool { + matches!(self, ReturnType::Result | ReturnType::ResultOption) + } + + fn contains_option(&self) -> bool { + matches!(self, ReturnType::Option | ReturnType::ResultOption) + } +} + +impl UserFunctionAttr { + fn is_writer_style(&self) -> bool { + self.write && !self.arg_option + } + + fn is_pure(&self) -> bool { + !self.write && !self.arg_option && self.return_type == ReturnType::T + } +} diff --git a/src/expr/macro/src/parse.rs b/src/expr/macro/src/parse.rs new file mode 100644 index 0000000000000..50a577328e27e --- /dev/null +++ b/src/expr/macro/src/parse.rs @@ -0,0 +1,160 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Parse the tokens of the macro. + +use proc_macro2::Span; + +use super::*; + +impl FunctionAttr { + /// Parse the attribute of the function macro. + pub fn parse(attr: &syn::AttributeArgs, item: &mut syn::ItemFn) -> Result { + let sig = attr.get(0).ok_or_else(|| { + Error::new( + Span::call_site(), + "expected #[function(\"name(arg1, arg2) -> ret\")]", + ) + })?; + + let sig_str = match sig { + syn::NestedMeta::Lit(syn::Lit::Str(lit_str)) => lit_str.value(), + _ => return Err(Error::new_spanned(sig, "expected string literal")), + }; + + let (name_args, ret) = sig_str + .split_once("->") + .ok_or_else(|| Error::new_spanned(sig, "expected '->'"))?; + let (name, args) = name_args + .split_once('(') + .ok_or_else(|| Error::new_spanned(sig, "expected '('"))?; + let args = args.trim_start().trim_end_matches([')', ' ']); + + let batch = attr.iter().find_map(|n| { + let syn::NestedMeta::Meta(syn::Meta::NameValue(nv)) = n else { return None }; + if !nv.path.is_ident("batch") { + return None; + }; + let syn::Lit::Str(ref lit_str) = nv.lit else { return None }; + Some(lit_str.value()) + }); + + let user_fn = UserFunctionAttr::parse(item)?; + + Ok(FunctionAttr { + name: name.trim().to_string(), + args: if args.is_empty() { + vec![] + } else { + args.split(',').map(|s| s.trim().to_string()).collect() + }, + ret: ret.trim().to_string(), + batch, + user_fn, + }) + } +} + +impl UserFunctionAttr { + fn parse(item: &mut syn::ItemFn) -> Result { + Ok(UserFunctionAttr { + name: item.sig.ident.to_string(), + write: last_arg_is_write(item), + arg_option: args_are_all_option(item), + return_type: return_type(item), + generic: item.sig.generics.params.len(), + // prebuild: extract_prebuild_arg(item), + }) + } +} + +/// Check if the last argument is `&mut dyn Write`. +fn last_arg_is_write(item: &syn::ItemFn) -> bool { + let Some(syn::FnArg::Typed(arg)) = item.sig.inputs.last() else { return false }; + let syn::Type::Reference(syn::TypeReference { elem, .. }) = arg.ty.as_ref() else { return false }; + let syn::Type::TraitObject(syn::TypeTraitObject { bounds, .. }) = elem.as_ref() else { return false }; + let Some(syn::TypeParamBound::Trait(syn::TraitBound { path, .. })) = bounds.first() else { return false }; + path.segments.last().map_or(false, |s| s.ident == "Write") +} + +/// Check if all arguments are `Option`s. +fn args_are_all_option(item: &syn::ItemFn) -> bool { + if item.sig.inputs.is_empty() { + return false; + } + for arg in &item.sig.inputs { + let syn::FnArg::Typed(arg) = arg else { return false }; + let syn::Type::Path(path) = arg.ty.as_ref() else { return false }; + let Some(seg) = path.path.segments.last() else { return false }; + if seg.ident != "Option" { + return false; + } + } + true +} + +/// Check the return type. +fn return_type(item: &syn::ItemFn) -> ReturnType { + if return_value_is_result_option(item) { + ReturnType::ResultOption + } else if return_value_is(item, "Result") { + ReturnType::Result + } else if return_value_is(item, "Option") { + ReturnType::Option + } else { + ReturnType::T + } +} + +/// Check if the return value is `type_`. +fn return_value_is(item: &syn::ItemFn, type_: &str) -> bool { + let syn::ReturnType::Type(_, ty) = &item.sig.output else { return false }; + let syn::Type::Path(path) = ty.as_ref() else { return false }; + let Some(seg) = path.path.segments.last() else { return false }; + seg.ident == type_ +} + +/// Check if the return value is `Result>`. +fn return_value_is_result_option(item: &syn::ItemFn) -> bool { + let syn::ReturnType::Type(_, ty) = &item.sig.output else { return false }; + let syn::Type::Path(path) = ty.as_ref() else { return false }; + let Some(seg) = path.path.segments.last() else { return false }; + if seg.ident != "Result" { + return false; + } + let syn::PathArguments::AngleBracketed(args) = &seg.arguments else { return false }; + let Some(syn::GenericArgument::Type(ty)) = args.args.first() else { return false }; + let syn::Type::Path(path) = ty else { return false }; + let Some(seg) = path.path.segments.last() else { return false }; + seg.ident == "Option" +} + +/// Extract `#[prebuild("function_name")]` from arguments. +fn _extract_prebuild_arg(item: &mut syn::ItemFn) -> Option<(usize, String)> { + for (i, arg) in item.sig.inputs.iter_mut().enumerate() { + let syn::FnArg::Typed(arg) = arg else { continue }; + if let Some(idx) = arg + .attrs + .iter_mut() + .position(|att| att.path.is_ident("prebuild")) + { + let attr = arg.attrs.remove(idx); + // XXX: this is a hack to parse a string literal from token stream + let s = attr.tokens.to_string(); + let s = s.trim_start_matches("(\"").trim_end_matches("\")"); + return Some((i, s.to_string())); + } + } + None +} diff --git a/src/expr/macro/src/types.rs b/src/expr/macro/src/types.rs new file mode 100644 index 0000000000000..9c011496c959e --- /dev/null +++ b/src/expr/macro/src/types.rs @@ -0,0 +1,189 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! This module provides utility functions for SQL data type conversion and manipulation. + +/// Expands a type wildcard string into a list of concrete types. +pub fn expand_type_wildcard(ty: &str) -> Vec<&str> { + match ty { + "*" => vec![ + "boolean", + "int16", + "int32", + "int64", + "float32", + "float64", + "decimal", + "serial", + "date", + "time", + "timestamp", + "timestamptz", + "interval", + "varchar", + "bytea", + "jsonb", + "struct", + "list", + ], + "*int" => vec!["int16", "int32", "int64"], + "*number" => vec!["int16", "int32", "int64", "float32", "float64", "decimal"], + _ => vec![ty], + } +} + +/// Maps a data type to its corresponding type name. +pub fn to_data_type_name(ty: &str) -> Option<&str> { + Some(match ty { + "boolean" => "Boolean", + "int16" => "Int16", + "int32" => "Int32", + "int64" => "Int64", + "float32" => "Float32", + "float64" => "Float64", + "decimal" => "Decimal", + "serial" => "Serial", + "date" => "Date", + "time" => "Time", + "timestamp" => "Timestamp", + "timestamptz" => "Timestamptz", + "interval" => "Interval", + "varchar" => "Varchar", + "bytea" => "Bytea", + "jsonb" => "Jsonb", + "struct" => "Struct", + "list" => "List", + _ => return None, + }) +} + +/// Computes the minimal compatible type between a pair of data types. +pub fn min_compatible_type(types: &[impl AsRef]) -> &str { + if types.len() == 1 { + return types[0].as_ref(); + } + assert_eq!(types.len(), 2); + match (types[0].as_ref(), types[1].as_ref()) { + (a, b) if a == b => a, + + ("int16", "int16") => "int16", + ("int16", "int32") => "int32", + ("int16", "int64") => "int64", + ("int16", "float32") => "float64", + ("int16", "float64") => "float64", + ("int16", "decimal") => "decimal", + + ("int32", "int16") => "int32", + ("int32", "int32") => "int32", + ("int32", "int64") => "int64", + ("int32", "float32") => "float64", + ("int32", "float64") => "float64", + ("int32", "decimal") => "decimal", + + ("int64", "int16") => "int64", + ("int64", "int32") => "int64", + ("int64", "int64") => "int64", + ("int64", "float32") => "float64", + ("int64", "float64") => "float64", + ("int64", "decimal") => "decimal", + + ("float32", "int16") => "float64", + ("float32", "int32") => "float64", + ("float32", "int64") => "float64", + ("float32", "float32") => "float32", + ("float32", "float64") => "float64", + ("float32", "decimal") => "float64", + + ("float64", "int16") => "float64", + ("float64", "int32") => "float64", + ("float64", "int64") => "float64", + ("float64", "float32") => "float64", + ("float64", "float64") => "float64", + ("float64", "decimal") => "float64", + + ("decimal", "int16") => "decimal", + ("decimal", "int32") => "decimal", + ("decimal", "int64") => "decimal", + ("decimal", "float32") => "float64", + ("decimal", "float64") => "float64", + ("decimal", "decimal") => "decimal", + + ("date", "timestamp") => "timestamp", + ("timestamp", "date") => "timestamp", + ("time", "interval") => "interval", + ("interval", "time") => "interval", + + (a, b) => panic!("unknown minimal compatible type for {a:?} and {b:?}"), + } +} + +/// Maps a data type to its corresponding array type name. +pub fn to_array_type(ty: &str) -> &str { + match ty { + "boolean" => "BoolArray", + "int16" => "I16Array", + "int32" => "I32Array", + "int64" => "I64Array", + "float32" => "F32Array", + "float64" => "F64Array", + "decimal" => "DecimalArray", + "serial" => "SerialArray", + "date" => "NaiveDateArray", + "time" => "NaiveTimeArray", + "timestamp" => "NaiveDateTimeArray", + "timestamptz" => "I64Array", + "interval" => "IntervalArray", + "varchar" => "Utf8Array", + "bytea" => "BytesArray", + "jsonb" => "JsonbArray", + "struct" => "StructArray", + "list" => "ListArray", + _ => panic!("unknown type: {ty:?}"), + } +} + +/// Maps a data type to its corresponding `ScalarRef` type name. +pub fn to_data_type(ty: &str) -> &str { + match ty { + "boolean" => "bool", + "int16" => "i16", + "int32" => "i32", + "int64" => "i64", + "float32" => "OrderedF32", + "float64" => "OrderedF64", + "decimal" => "Decimal", + "serial" => "Serial", + "date" => "NaiveDateWrapper", + "time" => "NaiveTimeWrapper", + "timestamp" => "NaiveDateTimeWrapper", + "timestamptz" => "i64", + "interval" => "IntervalUnit", + "varchar" => "&str", + "bytea" => "&[u8]", + "jsonb" => "JsonbRef<'_>", + "struct" => "StructRef<'_>", + "list" => "ListRef<'_>", + _ => panic!("unknown type: {ty:?}"), + } +} + +/// Checks if a data type is primitive. +pub fn is_primitive(ty: &str) -> bool { + match ty { + "int16" | "int32" | "int64" | "float32" | "float64" | "decimal" | "date" | "time" + | "timestamp" | "timestamptz" | "interval" | "serial" => true, + "boolean" | "varchar" | "bytea" | "jsonb" | "struct" | "list" => false, + _ => panic!("unknown type: {ty:?}"), + } +} diff --git a/src/expr/macro/src/utils.rs b/src/expr/macro/src/utils.rs new file mode 100644 index 0000000000000..788d09857cc93 --- /dev/null +++ b/src/expr/macro/src/utils.rs @@ -0,0 +1,29 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/// Convert a string from `snake_case` to `CamelCase`. +pub fn to_camel_case(input: &str) -> String { + input + .split('_') + .map(|word| { + let mut chars = word.chars(); + match chars.next() { + None => String::new(), + Some(first_char) => { + format!("{}{}", first_char.to_uppercase(), chars.as_str()) + } + } + }) + .collect() +} diff --git a/src/expr/src/error.rs b/src/expr/src/error.rs index 1d03962c66772..fe48467fe18a9 100644 --- a/src/expr/src/error.rs +++ b/src/expr/src/error.rs @@ -66,6 +66,9 @@ pub enum ExprError { #[error("UDF error: {0}")] Udf(#[from] risingwave_udf::Error), + + #[error("not a constant")] + NotConstant, } impl From for RwError { diff --git a/src/expr/src/expr/build_expr_from_prost.rs b/src/expr/src/expr/build_expr_from_prost.rs index 7c86957ed1dd9..ff0abcc5b605d 100644 --- a/src/expr/src/expr/build_expr_from_prost.rs +++ b/src/expr/src/expr/build_expr_from_prost.rs @@ -12,116 +12,73 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::try_match_expand; +use itertools::Itertools; use risingwave_common::types::DataType; -use risingwave_common::util::value_encoding::deserialize_datum; -use risingwave_pb::expr::expr_node::{RexNode, Type}; -use risingwave_pb::expr::{ExprNode, FunctionCall}; +use risingwave_pb::expr::expr_node::{PbType, RexNode}; +use risingwave_pb::expr::ExprNode; use super::expr_array_concat::ArrayConcatExpression; -use super::expr_binary_bytes::{ - new_ltrim_characters, new_repeat, new_rtrim_characters, new_substr_start, new_to_char, - new_trim_characters, -}; -use super::expr_binary_nonnull::{ - new_binary_expr, new_date_trunc_expr, new_like_default, new_to_timestamp, -}; -use super::expr_binary_nullable::new_nullable_binary_expr; use super::expr_case::CaseExpression; use super::expr_coalesce::CoalesceExpression; use super::expr_concat_ws::ConcatWsExpression; use super::expr_field::FieldExpression; use super::expr_in::InExpression; use super::expr_nested_construct::NestedConstructExpression; -use super::expr_quaternary_bytes::new_overlay_for_exp; use super::expr_regexp::RegexpMatchExpression; use super::expr_some_all::SomeAllExpression; -use super::expr_ternary_bytes::{ - new_overlay_exp, new_replace_expr, new_split_part_expr, new_substr_start_end, - new_translate_expr, -}; -use super::expr_to_char_const_tmpl::{ExprToCharConstTmpl, ExprToCharConstTmplContext}; -use super::expr_to_timestamp_const_tmpl::{ - ExprToTimestampConstTmpl, ExprToTimestampConstTmplContext, -}; use super::expr_udf::UdfExpression; -use super::expr_unary::{ - new_length_default, new_ltrim_expr, new_rtrim_expr, new_trim_expr, new_unary_expr, -}; use super::expr_vnode::VnodeExpression; -use crate::expr::expr_array_distinct::ArrayDistinctExpression; -use crate::expr::expr_array_length::ArrayLengthExpression; -use crate::expr::expr_array_to_string::ArrayToStringExpression; -use crate::expr::expr_binary_nonnull::new_tumble_start; -use crate::expr::expr_ternary::new_tumble_start_offset; -use crate::expr::{ - build_from_prost as expr_build_from_prost, BoxedExpression, Expression, InputRefExpression, - LiteralExpression, -}; -use crate::vector_op::to_char::compile_pattern_to_chrono; -use crate::{bail, ensure, ExprError, Result}; +use crate::expr::{BoxedExpression, Expression, InputRefExpression, LiteralExpression}; +use crate::sig::func::FUNC_SIG_MAP; +use crate::{bail, ExprError, Result}; +/// Build an expression from protobuf. pub fn build_from_prost(prost: &ExprNode) -> Result { - use risingwave_pb::expr::expr_node::Type::*; - - match prost.get_expr_type().unwrap() { - // Fixed number of arguments and based on `Unary/Binary/Ternary/...Expression` - Cast | Upper | Lower | Md5 | Not | IsTrue | IsNotTrue | IsFalse | IsNotFalse | IsNull - | IsNotNull | Neg | Ascii | Abs | Ceil | Floor | Round | Exp | BitwiseNot | CharLength - | BoolOut | OctetLength | BitLength | ToTimestamp | JsonbTypeof | JsonbArrayLength => { - build_unary_expr_prost(prost) - } - Equal | NotEqual | LessThan | LessThanOrEqual | GreaterThan | GreaterThanOrEqual | Add - | Subtract | Multiply | Divide | Modulus | Extract | RoundDigit | Pow | Position - | BitwiseShiftLeft | BitwiseShiftRight | BitwiseAnd | BitwiseOr | BitwiseXor | ConcatOp - | AtTimeZone | CastWithTimeZone | JsonbAccessInner | JsonbAccessStr => { - build_binary_expr_prost(prost) - } - And | Or | IsDistinctFrom | IsNotDistinctFrom | ArrayAccess | FormatType => { - build_nullable_binary_expr_prost(prost) + use PbType as E; + + if let Some(RexNode::FuncCall(call)) = &prost.rex_node { + let args = call + .children + .iter() + .map(|c| DataType::from(c.get_return_type().unwrap()).into()) + .collect_vec(); + let return_type = DataType::from(prost.get_return_type().unwrap()); + + if let Some(desc) = FUNC_SIG_MAP.get(prost.expr_type(), &args, (&return_type).into()) { + let RexNode::FuncCall(func_call) = prost.get_rex_node().unwrap() else { + bail!("Expected RexNode::FuncCall"); + }; + + let children = func_call + .get_children() + .iter() + .map(build_from_prost) + .try_collect()?; + return (desc.build)(return_type, children); } - ToChar => build_to_char_expr(prost), - ToTimestamp1 => build_to_timestamp_expr(prost), - Length => build_length_expr(prost), - Replace => build_replace_expr(prost), - Like => build_like_expr(prost), - Repeat => build_repeat_expr(prost), - SplitPart => build_split_part_expr(prost), - Translate => build_translate_expr(prost), - - // Variable number of arguments and based on `Unary/Binary/Ternary/...Expression` - TumbleStart => build_tumble_start_expr(prost), - Substr => build_substr_expr(prost), - Overlay => build_overlay_expr(prost), - Trim => build_trim_expr(prost), - Ltrim => build_ltrim_expr(prost), - Rtrim => build_rtrim_expr(prost), - DateTrunc => build_date_trunc_expr(prost), + } + match prost.expr_type() { // Dedicated types - All | Some => build_some_all_expr_prost(prost), - In => InExpression::try_from(prost).map(Expression::boxed), - Case => CaseExpression::try_from(prost).map(Expression::boxed), - Coalesce => CoalesceExpression::try_from(prost).map(Expression::boxed), - ConcatWs => ConcatWsExpression::try_from(prost).map(Expression::boxed), - ConstantValue => LiteralExpression::try_from(prost).map(Expression::boxed), - InputRef => InputRefExpression::try_from(prost).map(Expression::boxed), - Field => FieldExpression::try_from(prost).map(Expression::boxed), - Array => NestedConstructExpression::try_from(prost).map(Expression::boxed), - Row => NestedConstructExpression::try_from(prost).map(Expression::boxed), - RegexpMatch => RegexpMatchExpression::try_from(prost).map(Expression::boxed), - ArrayCat | ArrayAppend | ArrayPrepend => { + E::All | E::Some => SomeAllExpression::try_from(prost).map(Expression::boxed), + E::In => InExpression::try_from(prost).map(Expression::boxed), + E::Case => CaseExpression::try_from(prost).map(Expression::boxed), + E::Coalesce => CoalesceExpression::try_from(prost).map(Expression::boxed), + E::ConcatWs => ConcatWsExpression::try_from(prost).map(Expression::boxed), + E::ConstantValue => LiteralExpression::try_from(prost).map(Expression::boxed), + E::InputRef => InputRefExpression::try_from(prost).map(Expression::boxed), + E::Field => FieldExpression::try_from(prost).map(Expression::boxed), + E::Array => NestedConstructExpression::try_from(prost).map(Expression::boxed), + E::Row => NestedConstructExpression::try_from(prost).map(Expression::boxed), + E::RegexpMatch => RegexpMatchExpression::try_from(prost).map(Expression::boxed), + E::ArrayCat | E::ArrayAppend | E::ArrayPrepend => { // Now we implement these three functions as a single expression for the // sake of simplicity. If performance matters at some time, we can split // the implementation to improve performance. ArrayConcatExpression::try_from(prost).map(Expression::boxed) } - ArrayToString => ArrayToStringExpression::try_from(prost).map(Expression::boxed), - ArrayDistinct => ArrayDistinctExpression::try_from(prost).map(Expression::boxed), - ArrayLength => ArrayLengthExpression::try_from(prost).map(Expression::boxed), - Vnode => VnodeExpression::try_from(prost).map(Expression::boxed), - Now => build_now_expr(prost), - Udf => UdfExpression::try_from(prost).map(Expression::boxed), + E::Vnode => VnodeExpression::try_from(prost).map(Expression::boxed), + E::Udf => UdfExpression::try_from(prost).map(Expression::boxed), _ => Err(ExprError::UnsupportedFunction(format!( "{:?}", prost.get_expr_type() @@ -129,458 +86,34 @@ pub fn build_from_prost(prost: &ExprNode) -> Result { } } -fn get_children_and_return_type(prost: &ExprNode) -> Result<(Vec, DataType)> { +/// Build an expression. +pub fn build( + func: PbType, + ret_type: DataType, + children: Vec, +) -> Result { + let args = children + .iter() + .map(|c| c.return_type().into()) + .collect_vec(); + let desc = FUNC_SIG_MAP + .get(func, &args, (&ret_type).into()) + .ok_or_else(|| { + ExprError::UnsupportedFunction(format!( + "{:?}({}) -> {:?}", + func, + args.iter().map(|t| format!("{:?}", t)).join(", "), + ret_type + )) + })?; + (desc.build)(ret_type, children) +} + +pub(super) fn get_children_and_return_type(prost: &ExprNode) -> Result<(&[ExprNode], DataType)> { let ret_type = DataType::from(prost.get_return_type().unwrap()); if let RexNode::FuncCall(func_call) = prost.get_rex_node().unwrap() { - Ok((func_call.get_children().to_vec(), ret_type)) + Ok((func_call.get_children(), ret_type)) } else { bail!("Expected RexNode::FuncCall"); } } - -fn build_unary_expr_prost(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - let [child]: [_; 1] = children.try_into().unwrap(); - let child_expr = expr_build_from_prost(&child)?; - new_unary_expr(prost.get_expr_type().unwrap(), ret_type, child_expr) -} - -fn build_binary_expr_prost(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - let [left_child, right_child]: [_; 2] = children.try_into().unwrap(); - let left_expr = expr_build_from_prost(&left_child)?; - let right_expr = expr_build_from_prost(&right_child)?; - new_binary_expr( - prost.get_expr_type().unwrap(), - ret_type, - left_expr, - right_expr, - ) -} - -fn build_nullable_binary_expr_prost(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - let [left_child, right_child]: [_; 2] = children.try_into().unwrap(); - let left_expr = expr_build_from_prost(&left_child)?; - let right_expr = expr_build_from_prost(&right_child)?; - new_nullable_binary_expr( - prost.get_expr_type().unwrap(), - ret_type, - left_expr, - right_expr, - ) -} - -fn build_overlay_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - ensure!(children.len() == 3 || children.len() == 4); - - let s = expr_build_from_prost(&children[0])?; - let new_sub_str = expr_build_from_prost(&children[1])?; - let start = expr_build_from_prost(&children[2])?; - - if children.len() == 3 { - Ok(new_overlay_exp(s, new_sub_str, start, ret_type)) - } else if children.len() == 4 { - let count = expr_build_from_prost(&children[3])?; - Ok(new_overlay_for_exp(s, new_sub_str, start, count, ret_type)) - } else { - unreachable!() - } -} - -fn build_repeat_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - let [left_child, right_child]: [_; 2] = children.try_into().unwrap(); - let left_expr = expr_build_from_prost(&left_child)?; - let right_expr = expr_build_from_prost(&right_child)?; - Ok(new_repeat(left_expr, right_expr, ret_type)) -} - -fn build_substr_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - let child = expr_build_from_prost(&children[0])?; - ensure!(children.len() == 2 || children.len() == 3); - if children.len() == 2 { - let off = expr_build_from_prost(&children[1])?; - Ok(new_substr_start(child, off, ret_type)) - } else if children.len() == 3 { - let off = expr_build_from_prost(&children[1])?; - let len = expr_build_from_prost(&children[2])?; - Ok(new_substr_start_end(child, off, len, ret_type)) - } else { - unreachable!() - } -} - -fn build_trim_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - ensure!(!children.is_empty() && children.len() <= 2); - let original = expr_build_from_prost(&children[0])?; - match children.len() { - 1 => Ok(new_trim_expr(original, ret_type)), - 2 => { - let characters = expr_build_from_prost(&children[1])?; - Ok(new_trim_characters(original, characters, ret_type)) - } - _ => unreachable!(), - } -} - -fn build_ltrim_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - ensure!(!children.is_empty() && children.len() <= 2); - let original = expr_build_from_prost(&children[0])?; - match children.len() { - 1 => Ok(new_ltrim_expr(original, ret_type)), - 2 => { - let characters = expr_build_from_prost(&children[1])?; - Ok(new_ltrim_characters(original, characters, ret_type)) - } - _ => unreachable!(), - } -} - -fn build_rtrim_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - ensure!(!children.is_empty() && children.len() <= 2); - let original = expr_build_from_prost(&children[0])?; - match children.len() { - 1 => Ok(new_rtrim_expr(original, ret_type)), - 2 => { - let characters = expr_build_from_prost(&children[1])?; - Ok(new_rtrim_characters(original, characters, ret_type)) - } - _ => unreachable!(), - } -} - -fn build_replace_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - ensure!(children.len() == 3); - let s = expr_build_from_prost(&children[0])?; - let from_str = expr_build_from_prost(&children[1])?; - let to_str = expr_build_from_prost(&children[2])?; - Ok(new_replace_expr(s, from_str, to_str, ret_type)) -} - -fn build_date_trunc_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - ensure!(children.len() == 2 || children.len() == 3); - let field = expr_build_from_prost(&children[0])?; - let source = expr_build_from_prost(&children[1])?; - let time_zone = if let Some(child) = children.get(2) { - Some((expr_build_from_prost(child)?, expr_build_from_prost(child)?)) - } else { - None - }; - Ok(new_date_trunc_expr(ret_type, field, source, time_zone)) -} - -fn build_tumble_start_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - ensure!(children.len() == 2 || children.len() == 3); - let time = expr_build_from_prost(&children[0])?; - let window_size = expr_build_from_prost(&children[1])?; - if children.len() == 2 { - new_tumble_start(time, window_size, ret_type) - } else if children.len() == 3 { - let offset = expr_build_from_prost(&children[2])?; - new_tumble_start_offset(time, window_size, offset, ret_type) - } else { - unreachable!() - } -} - -fn build_length_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - // TODO: add encoding length expr - let [child]: [_; 1] = children.try_into().unwrap(); - let child = expr_build_from_prost(&child)?; - Ok(new_length_default(child, ret_type)) -} - -fn build_like_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - ensure!(children.len() == 2); - let expr_ia1 = expr_build_from_prost(&children[0])?; - let expr_ia2 = expr_build_from_prost(&children[1])?; - Ok(new_like_default(expr_ia1, expr_ia2, ret_type)) -} - -fn build_translate_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - ensure!(children.len() == 3); - let s = expr_build_from_prost(&children[0])?; - let match_str = expr_build_from_prost(&children[1])?; - let replace_str = expr_build_from_prost(&children[2])?; - Ok(new_translate_expr(s, match_str, replace_str, ret_type)) -} - -fn build_split_part_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - ensure!(children.len() == 3); - let string_expr = expr_build_from_prost(&children[0])?; - let delimiter_expr = expr_build_from_prost(&children[1])?; - let nth_expr = expr_build_from_prost(&children[2])?; - Ok(new_split_part_expr( - string_expr, - delimiter_expr, - nth_expr, - ret_type, - )) -} - -fn build_to_char_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - ensure!(children.len() == 2); - let data_expr = expr_build_from_prost(&children[0])?; - let tmpl_node = &children[1]; - if let RexNode::Constant(tmpl_value) = tmpl_node.get_rex_node().unwrap() - && let Ok(Some(tmpl)) = deserialize_datum(tmpl_value.get_body().as_slice(), &DataType::from(tmpl_node.get_return_type().unwrap())) - { - let tmpl = tmpl.as_utf8(); - let pattern = compile_pattern_to_chrono(tmpl); - - Ok(ExprToCharConstTmpl { - ctx: ExprToCharConstTmplContext { - chrono_pattern: pattern, - }, - child: data_expr, - }.boxed()) - } else { - let tmpl_expr = expr_build_from_prost(&children[1])?; - Ok(new_to_char(data_expr, tmpl_expr, ret_type)) - } -} - -pub fn build_now_expr(prost: &ExprNode) -> Result { - let rex_node = try_match_expand!(prost.get_rex_node(), Ok)?; - let RexNode::FuncCall(func_call_node) = rex_node else { - bail!("Expected RexNode::FuncCall in Now"); - }; - let Some(bind_timestamp) = func_call_node.children.first() else { - bail!("Expected epoch timestamp bound into Now"); - }; - LiteralExpression::try_from(bind_timestamp).map(Expression::boxed) -} - -pub fn build_to_timestamp_expr(prost: &ExprNode) -> Result { - let (children, ret_type) = get_children_and_return_type(prost)?; - ensure!(children.len() == 2); - let data_expr = expr_build_from_prost(&children[0])?; - let tmpl_node = &children[1]; - if let RexNode::Constant(tmpl_value) = tmpl_node.get_rex_node().unwrap() - && let Ok(Some(tmpl)) = deserialize_datum(tmpl_value.get_body().as_slice(), &DataType::from(tmpl_node.get_return_type().unwrap())) - { - let tmpl = tmpl.as_utf8(); - let pattern = compile_pattern_to_chrono(tmpl); - - Ok(ExprToTimestampConstTmpl { - ctx: ExprToTimestampConstTmplContext { - chrono_pattern: pattern, - }, - child: data_expr, - }.boxed()) - } else { - let tmpl_expr = expr_build_from_prost(&children[1])?; - Ok(new_to_timestamp(data_expr, tmpl_expr, ret_type)) - } -} - -pub fn build_some_all_expr_prost(prost: &ExprNode) -> Result { - let outer_expr_type = prost.get_expr_type().unwrap(); - let (outer_children, outer_return_type) = get_children_and_return_type(prost)?; - ensure!(matches!(outer_return_type, DataType::Boolean)); - - let mut inner_expr_type = outer_children[0].get_expr_type().unwrap(); - let (mut inner_children, mut inner_return_type) = - get_children_and_return_type(&outer_children[0])?; - let mut stack = vec![]; - while inner_children.len() != 2 { - stack.push((inner_expr_type, inner_return_type)); - inner_expr_type = inner_children[0].get_expr_type().unwrap(); - (inner_children, inner_return_type) = get_children_and_return_type(&inner_children[0])?; - } - - let [left_child, right_child]: [_; 2] = inner_children.try_into().unwrap(); - let left_expr = expr_build_from_prost(&left_child)?; - let right_expr = expr_build_from_prost(&right_child)?; - - let DataType::List { datatype: right_expr_return_type } = right_expr.return_type() else { - bail!("Expect Array Type"); - }; - - let eval_func = { - let left_expr_input_ref = ExprNode { - expr_type: Type::InputRef as i32, - return_type: Some(left_expr.return_type().to_protobuf()), - rex_node: Some(RexNode::InputRef(0)), - }; - let right_expr_input_ref = ExprNode { - expr_type: Type::InputRef as i32, - return_type: Some(right_expr_return_type.to_protobuf()), - rex_node: Some(RexNode::InputRef(1)), - }; - let mut root_expr_node = ExprNode { - expr_type: inner_expr_type as i32, - return_type: Some(inner_return_type.to_protobuf()), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: vec![left_expr_input_ref, right_expr_input_ref], - })), - }; - while let Some((expr_type, return_type)) = stack.pop() { - root_expr_node = ExprNode { - expr_type: expr_type as i32, - return_type: Some(return_type.to_protobuf()), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: vec![root_expr_node], - })), - } - } - expr_build_from_prost(&root_expr_node)? - }; - - Ok(Box::new(SomeAllExpression::new( - left_expr, - right_expr, - outer_expr_type, - eval_func, - ))) -} - -#[cfg(test)] -mod tests { - use std::vec; - - use risingwave_common::array::{ArrayImpl, DataChunk, Utf8Array}; - use risingwave_common::types::Scalar; - use risingwave_common::util::value_encoding::serialize_datum; - use risingwave_pb::data::data_type::TypeName; - use risingwave_pb::data::{PbDataType, PbDatum}; - use risingwave_pb::expr::expr_node::{RexNode, Type}; - use risingwave_pb::expr::{ExprNode, FunctionCall}; - - use super::*; - - #[tokio::test] - async fn test_array_access_expr() { - let values = FunctionCall { - children: vec![ - ExprNode { - expr_type: Type::ConstantValue as i32, - return_type: Some(PbDataType { - type_name: TypeName::Varchar as i32, - ..Default::default() - }), - rex_node: Some(RexNode::Constant(PbDatum { - body: serialize_datum(Some("foo".into()).as_ref()), - })), - }, - ExprNode { - expr_type: Type::ConstantValue as i32, - return_type: Some(PbDataType { - type_name: TypeName::Varchar as i32, - ..Default::default() - }), - rex_node: Some(RexNode::Constant(PbDatum { - body: serialize_datum(Some("bar".into()).as_ref()), - })), - }, - ], - }; - let array_index = FunctionCall { - children: vec![ - ExprNode { - expr_type: Type::Array as i32, - return_type: Some(PbDataType { - type_name: TypeName::List as i32, - field_type: vec![PbDataType { - type_name: TypeName::Varchar as i32, - ..Default::default() - }], - ..Default::default() - }), - rex_node: Some(RexNode::FuncCall(values)), - }, - ExprNode { - expr_type: Type::ConstantValue as i32, - return_type: Some(PbDataType { - type_name: TypeName::Int32 as i32, - ..Default::default() - }), - rex_node: Some(RexNode::Constant(PbDatum { - body: serialize_datum(Some(1_i32.to_scalar_value()).as_ref()), - })), - }, - ], - }; - let access = ExprNode { - expr_type: Type::ArrayAccess as i32, - return_type: Some(PbDataType { - type_name: TypeName::Varchar as i32, - ..Default::default() - }), - rex_node: Some(RexNode::FuncCall(array_index)), - }; - let expr = build_nullable_binary_expr_prost(&access); - assert!(expr.is_ok()); - - let res = expr.unwrap().eval(&DataChunk::new_dummy(1)).await.unwrap(); - assert_eq!(*res, ArrayImpl::Utf8(Utf8Array::from_iter(["foo"]))); - } - - #[test] - fn test_build_extract_expr() { - let left = ExprNode { - expr_type: Type::ConstantValue as i32, - return_type: Some(PbDataType { - type_name: TypeName::Varchar as i32, - precision: 11, - ..Default::default() - }), - rex_node: Some(RexNode::Constant(PbDatum { - body: serialize_datum(Some("DAY".into()).as_ref()), - })), - }; - let right_date = ExprNode { - expr_type: Type::ConstantValue as i32, - return_type: Some(PbDataType { - type_name: TypeName::Date as i32, - ..Default::default() - }), - rex_node: None, - }; - let right_time = ExprNode { - expr_type: Type::ConstantValue as i32, - return_type: Some(PbDataType { - type_name: TypeName::Timestamp as i32, - ..Default::default() - }), - rex_node: None, - }; - - let expr = ExprNode { - expr_type: Type::Extract as i32, - return_type: Some(PbDataType { - type_name: TypeName::Int64 as i32, - ..Default::default() - }), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: vec![left.clone(), right_date], - })), - }; - assert!(build_binary_expr_prost(&expr).is_ok()); - let expr = ExprNode { - expr_type: Type::Extract as i32, - return_type: Some(PbDataType { - type_name: TypeName::Int64 as i32, - ..Default::default() - }), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: vec![left, right_time], - })), - }; - assert!(build_binary_expr_prost(&expr).is_ok()); - } -} diff --git a/src/expr/src/expr/data_types.rs b/src/expr/src/expr/data_types.rs index 489dfa45ed676..eebb7b63655a5 100644 --- a/src/expr/src/expr/data_types.rs +++ b/src/expr/src/expr/data_types.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +// We may remove the entire file once procedual macros are ready for aggregations. +#![allow(unused_imports)] + //! Macros containing all necessary information for a logical type. //! //! Each type macro will call the `$macro` with multiple parameters: diff --git a/src/expr/src/expr/expr_array_distinct.rs b/src/expr/src/expr/expr_array_distinct.rs index d9f0c3ad3427a..e0fc29bfa5988 100644 --- a/src/expr/src/expr/expr_array_distinct.rs +++ b/src/expr/src/expr/expr_array_distinct.rs @@ -12,18 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use itertools::Itertools; use risingwave_common::array::*; -use risingwave_common::row::OwnedRow; -use risingwave_common::types::{DataType, Datum, DatumRef, ScalarRefImpl, ToDatumRef}; -use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_pb::expr::expr_node::{RexNode, Type}; -use risingwave_pb::expr::ExprNode; - -use crate::expr::{build_from_prost, BoxedExpression, Expression}; -use crate::{bail, ensure, ExprError, Result}; +use risingwave_common::types::ScalarRefImpl; +use risingwave_expr_macro::function; /// Returns a new array removing all the duplicates from the input array /// @@ -58,204 +50,28 @@ use crate::{bail, ensure, ExprError, Result}; /// select array_distinct(null); /// ``` -#[derive(Debug)] -pub struct ArrayDistinctExpression { - array: BoxedExpression, - return_type: DataType, -} - -impl<'a> TryFrom<&'a ExprNode> for ArrayDistinctExpression { - type Error = ExprError; - - fn try_from(prost: &'a ExprNode) -> Result { - ensure!(prost.get_expr_type().unwrap() == Type::ArrayDistinct); - let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { - bail!("Expected RexNode::FuncCall"); - }; - let children = func_call_node.get_children(); - ensure!(children.len() == 1); - let array = build_from_prost(&children[0])?; - let return_type = array.return_type(); - Ok(Self { array, return_type }) - } -} - -#[async_trait::async_trait] -impl Expression for ArrayDistinctExpression { - fn return_type(&self) -> DataType { - self.return_type.clone() - } - - async fn eval(&self, input: &DataChunk) -> Result { - let array = self.array.eval_checked(input).await?; - let mut builder = self.return_type.create_array_builder(array.len()); - for (vis, arr) in input.vis().iter().zip_eq_fast(array.iter()) { - if !vis { - builder.append_null(); - } else { - builder.append_datum(&self.evaluate(arr)); - } - } - Ok(Arc::new(builder.finish())) - } - - async fn eval_row(&self, input: &OwnedRow) -> Result { - let array_data = self.array.eval_row(input).await?; - Ok(self.evaluate(array_data.to_datum_ref())) - } -} - -impl ArrayDistinctExpression { - fn evaluate(&self, array: DatumRef<'_>) -> Datum { - match array { - Some(ScalarRefImpl::List(array)) => Some( - ListValue::new( - array - .values_ref() - .into_iter() - .map(|x| x.map(ScalarRefImpl::into_scalar_impl)) - .unique() - .collect(), - ) - .into(), - ), - None => None, - Some(_) => unreachable!("the operand must be a list type"), - } - } +#[function("array_distinct(list) -> list")] +pub fn array_distinct(list: ListRef<'_>) -> ListValue { + ListValue::new( + list.values_ref() + .into_iter() + .map(|x| x.map(ScalarRefImpl::into_scalar_impl)) + .unique() + .collect(), + ) } #[cfg(test)] mod tests { - - use itertools::Itertools; - use risingwave_common::array::DataChunk; - use risingwave_common::types::ScalarImpl; - use risingwave_pb::data::PbDatum; - use risingwave_pb::expr::expr_node::{PbType, RexNode}; - use risingwave_pb::expr::{ExprNode, FunctionCall}; + use risingwave_common::types::Scalar; use super::*; - use crate::expr::{Expression, LiteralExpression}; - - fn make_i64_expr_node(value: i64) -> ExprNode { - ExprNode { - expr_type: PbType::ConstantValue as i32, - return_type: Some(DataType::Int64.to_protobuf()), - rex_node: Some(RexNode::Constant(PbDatum { - body: value.to_be_bytes().to_vec(), - })), - } - } - - fn make_i64_array_expr_node(values: Vec) -> ExprNode { - ExprNode { - expr_type: PbType::Array as i32, - return_type: Some( - DataType::List { - datatype: Box::new(DataType::Int64), - } - .to_protobuf(), - ), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: values.into_iter().map(make_i64_expr_node).collect(), - })), - } - } - - fn make_i64_array_array_expr_node(values: Vec>) -> ExprNode { - ExprNode { - expr_type: PbType::Array as i32, - return_type: Some( - DataType::List { - datatype: Box::new(DataType::List { - datatype: Box::new(DataType::Int64), - }), - } - .to_protobuf(), - ), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: values.into_iter().map(make_i64_array_expr_node).collect(), - })), - } - } #[test] - fn test_array_distinct_try_from() { - { - let array = make_i64_array_expr_node(vec![12]); - let expr = ExprNode { - expr_type: PbType::ArrayDistinct as i32, - return_type: Some( - DataType::List { - datatype: Box::new(DataType::Int64), - } - .to_protobuf(), - ), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: vec![array], - })), - }; - assert!(ArrayDistinctExpression::try_from(&expr).is_ok()); - } - - { - let array = make_i64_array_array_expr_node(vec![vec![42], vec![42]]); - let expr = ExprNode { - expr_type: PbType::ArrayDistinct as i32, - return_type: Some( - DataType::List { - datatype: Box::new(DataType::Int64), - } - .to_protobuf(), - ), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: vec![array], - })), - }; - assert!(ArrayDistinctExpression::try_from(&expr).is_ok()); - } - } - - fn make_i64_array_expr(values: Vec) -> BoxedExpression { - LiteralExpression::new( - DataType::List { - datatype: Box::new(DataType::Int64), - }, - Some(ListValue::new(values.into_iter().map(|x| Some(x.into())).collect()).into()), - ) - .boxed() - } - - #[tokio::test] - async fn test_array_distinct_array_of_primitives() { - let array = make_i64_array_expr(vec![42, 43, 42]); - let expr = ArrayDistinctExpression { - return_type: DataType::List { - datatype: Box::new(DataType::Int64), - }, - array, - }; - - let chunk = DataChunk::new_dummy(4) - .with_visibility([true, false, true, true].into_iter().collect()); - let expected_array = Some(ScalarImpl::List(ListValue::new(vec![ - Some(42i64.into()), - Some(43i64.into()), - ]))); - let expected = vec![ - expected_array.clone(), - None, - expected_array.clone(), - expected_array, - ]; - let actual = expr - .eval(&chunk) - .await - .unwrap() - .iter() - .map(|v| v.map(|s| s.into_scalar_impl())) - .collect_vec(); + fn test_array_distinct_array_of_primitives() { + let array = ListValue::new([42, 43, 42].into_iter().map(|x| Some(x.into())).collect()); + let expected = ListValue::new([42, 43].into_iter().map(|x| Some(x.into())).collect()); + let actual = array_distinct(array.as_scalar_ref()); assert_eq!(actual, expected); } diff --git a/src/expr/src/expr/expr_array_length.rs b/src/expr/src/expr/expr_array_length.rs index 94faf1ae91242..76d062d7be134 100644 --- a/src/expr/src/expr/expr_array_length.rs +++ b/src/expr/src/expr/expr_array_length.rs @@ -12,17 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - -use risingwave_common::array::{ArrayRef, DataChunk}; -use risingwave_common::row::OwnedRow; -use risingwave_common::types::{DataType, Datum, DatumRef, ScalarImpl, ScalarRefImpl, ToDatumRef}; -use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_pb::expr::expr_node::{RexNode, Type}; -use risingwave_pb::expr::ExprNode; - -use crate::expr::{build_from_prost, BoxedExpression, Expression}; -use crate::{bail, ensure, ExprError, Result}; +use risingwave_common::array::ListRef; +use risingwave_expr_macro::function; /// Returns the length of an array. /// @@ -66,213 +57,7 @@ use crate::{bail, ensure, ExprError, Result}; /// query error unknown type /// select array_length(null); /// ``` - -#[derive(Debug)] -pub struct ArrayLengthExpression { - array: BoxedExpression, - return_type: DataType, -} - -impl<'a> TryFrom<&'a ExprNode> for ArrayLengthExpression { - type Error = ExprError; - - fn try_from(prost: &'a ExprNode) -> Result { - ensure!(prost.get_expr_type().unwrap() == Type::ArrayLength); - let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { - bail!("Expected RexNode:FunctionCall") - }; - let children = func_call_node.get_children(); - ensure!(children.len() == 1); - let array = build_from_prost(&children[0])?; - let return_type = DataType::Int64; - Ok(Self { array, return_type }) - } -} - -#[async_trait::async_trait] -impl Expression for ArrayLengthExpression { - fn return_type(&self) -> DataType { - self.return_type.clone() - } - - async fn eval(&self, input: &DataChunk) -> Result { - let array = self.array.eval_checked(input).await?; - let mut builder = self.return_type.create_array_builder(array.len()); - - for (vis, input_array) in input.vis().iter().zip_eq_fast(array.iter()) { - if vis { - builder.append_datum(self.evaluate(input_array)); - } else { - builder.append_null(); - } - } - - Ok(Arc::new(builder.finish())) - } - - async fn eval_row(&self, input: &OwnedRow) -> Result { - let array_data = self.array.eval_row(input).await?; - Ok(self.evaluate(array_data.to_datum_ref())) - } -} - -impl ArrayLengthExpression { - fn evaluate(&self, array: DatumRef<'_>) -> Datum { - match array { - Some(ScalarRefImpl::List(array)) => Some(ScalarImpl::Int64( - array.values_ref().len().try_into().unwrap(), - )), - None => None, - _ => { - panic!("The array should be a valid array"); - } - } - } -} - -#[cfg(test)] -mod tests { - use itertools::Itertools; - use risingwave_common::array::{DataChunk, ListValue}; - use risingwave_common::types::{DataType, ScalarImpl}; - use risingwave_pb::data::Datum as ProstDatum; - use risingwave_pb::expr::expr_node::{RexNode, Type as ProstType}; - use risingwave_pb::expr::{ExprNode, FunctionCall}; - - use crate::expr::expr_array_length::ArrayLengthExpression; - use crate::expr::{BoxedExpression, Expression, LiteralExpression}; - - fn make_i64_expr_node(value: i64) -> ExprNode { - ExprNode { - expr_type: ProstType::ConstantValue as i32, - return_type: Some(DataType::Int64.to_protobuf()), - rex_node: Some(RexNode::Constant(ProstDatum { - body: value.to_be_bytes().to_vec(), - })), - } - } - - fn make_i64_array_expr_node(values: Vec) -> ExprNode { - ExprNode { - expr_type: ProstType::Array as i32, - return_type: Some( - DataType::List { - datatype: Box::new(DataType::Int64), - } - .to_protobuf(), - ), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: values.into_iter().map(make_i64_expr_node).collect(), - })), - } - } - - fn make_i64_array_array_expr_node(values: Vec>) -> ExprNode { - ExprNode { - expr_type: ProstType::Array as i32, - return_type: Some( - DataType::List { - datatype: Box::new(DataType::List { - datatype: Box::new(DataType::Int64), - }), - } - .to_protobuf(), - ), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: values.into_iter().map(make_i64_array_expr_node).collect(), - })), - } - } - - #[test] - fn test_array_length_try_from() { - { - let array = make_i64_expr_node(1); - let expr = ExprNode { - expr_type: ProstType::ArrayLength as i32, - return_type: Some( - DataType::List { - datatype: Box::new(DataType::Int64), - } - .to_protobuf(), - ), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: vec![array], - })), - }; - - assert!(ArrayLengthExpression::try_from(&expr).is_ok()); - } - - { - let array = make_i64_array_expr_node(vec![1, 2, 3]); - let expr = ExprNode { - expr_type: ProstType::ArrayLength as i32, - return_type: Some( - DataType::List { - datatype: Box::new(DataType::Int64), - } - .to_protobuf(), - ), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: vec![array], - })), - }; - - assert!(ArrayLengthExpression::try_from(&expr).is_ok()); - } - - { - let array = make_i64_array_array_expr_node(vec![vec![1, 2, 3]]); - let expr = ExprNode { - expr_type: ProstType::ArrayLength as i32, - return_type: Some( - DataType::List { - datatype: Box::new(DataType::Int64), - } - .to_protobuf(), - ), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: vec![array], - })), - }; - - assert!(ArrayLengthExpression::try_from(&expr).is_ok()); - } - } - - fn make_i64_array_expr(values: Vec) -> BoxedExpression { - LiteralExpression::new( - DataType::List { - datatype: Box::new(DataType::Int64), - }, - Some(ListValue::new(values.into_iter().map(|x| Some(x.into())).collect()).into()), - ) - .boxed() - } - - #[tokio::test] - async fn test_array_length_of_primitives() { - let array = make_i64_array_expr(vec![1, 2, 3]); - let expr = ArrayLengthExpression { - array, - return_type: DataType::Int64, - }; - - let chunk = - DataChunk::new_dummy(3).with_visibility(([false, true, true]).into_iter().collect()); - let expected_length = Some(ScalarImpl::Int64(3)); - - let expected = vec![None, expected_length.clone(), expected_length]; - - let actual = expr - .eval(&chunk) - .await - .unwrap() - .iter() - .map(|v| v.map(|s| s.into_scalar_impl())) - .collect_vec(); - - assert_eq!(actual, expected); - } +#[function("array_length(list) -> int64")] +fn array_length(array: ListRef<'_>) -> i64 { + array.values_ref().len() as _ } diff --git a/src/expr/src/expr/expr_array_to_string.rs b/src/expr/src/expr/expr_array_to_string.rs index 0e991b0b44a97..7ebc5a30e6251 100644 --- a/src/expr/src/expr/expr_array_to_string.rs +++ b/src/expr/src/expr/expr_array_to_string.rs @@ -12,18 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. +#![allow(clippy::unit_arg)] + use std::fmt::Write; -use std::sync::Arc; use risingwave_common::array::*; -use risingwave_common::row::OwnedRow; use risingwave_common::types::to_text::ToText; -use risingwave_common::types::{DataType, Datum}; -use risingwave_pb::expr::expr_node::{RexNode, Type}; -use risingwave_pb::expr::ExprNode; - -use crate::expr::{build_from_prost, Expression}; -use crate::{bail, ensure, ExprError, Result}; +use risingwave_common::types::DataType; +use risingwave_expr_macro::build_function; + +use super::template::{BinaryBytesExpression, TernaryBytesExpression}; +use super::{BoxedExpression, Result}; + +#[build_function("array_to_string(list, varchar) -> varchar")] +fn build_array_to_string( + return_type: DataType, + children: Vec, +) -> Result { + let mut iter = children.into_iter(); + let list = iter.next().unwrap(); + let delimiter = iter.next().unwrap(); + let elem_type = match list.return_type() { + DataType::List { datatype } => *datatype, + _ => panic!("expected list type"), + }; + let expr = BinaryBytesExpression::::new( + list, + delimiter, + return_type, + move |a, d, writer| Ok(array_to_string(a, &elem_type, d, writer)), + ); + Ok(Box::new(expr)) +} /// Converts each array element to its text representation, and concatenates those /// separated by the delimiter string. If `null_string` is given and is not NULL, @@ -74,163 +94,65 @@ use crate::{bail, ensure, ExprError, Result}; /// query error polymorphic type /// select array_to_string(null, ','); /// ``` -#[derive(Debug)] -pub struct ArrayToStringExpression { - array: Box, - element_data_type: DataType, - delimiter: Box, - null_string: Option>, -} - -impl<'a> TryFrom<&'a ExprNode> for ArrayToStringExpression { - type Error = ExprError; - - fn try_from(prost: &'a ExprNode) -> Result { - ensure!(prost.get_expr_type().unwrap() == Type::ArrayToString); - let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { - bail!("Expected RexNode::FuncCall"); - }; - let mut children = func_call_node.children.iter(); - let Some(array_node) = children.next() else { - bail!("Expected argument 'array'"); - }; - let array = build_from_prost(array_node)?; - - let element_data_type = match array.return_type() { - DataType::List { datatype } => *datatype, - _ => bail!("Expected argument 'array' to be of type List"), - }; - - let Some(delim_node) = children.next() else { - bail!("Expected argument 'delimiter'"); - }; - let delimiter = build_from_prost(delim_node)?; - - let null_string = if let Some(null_string_node) = children.next() { - Some(build_from_prost(null_string_node)?) - } else { - None - }; - - Ok(Self { - array, - element_data_type, - delimiter, - null_string, - }) - } -} - -#[async_trait::async_trait] -impl Expression for ArrayToStringExpression { - fn return_type(&self) -> DataType { - DataType::Varchar - } - - async fn eval(&self, input: &DataChunk) -> Result { - let list_array = self.array.eval_checked(input).await?; - let list_array = list_array.as_list(); - - let delim_array = self.delimiter.eval_checked(input).await?; - let delim_array = delim_array.as_utf8(); - - let null_string_array = if let Some(expr) = &self.null_string { - let null_string_array = expr.eval_checked(input).await?; - Some(null_string_array) +fn array_to_string( + array: ListRef<'_>, + element_data_type: &DataType, + delimiter: &str, + mut writer: &mut dyn Write, +) { + let mut first = true; + for element in array.values_ref().iter().flat_map(|f| f.iter()) { + if !first { + write!(writer, "{}", delimiter).unwrap(); } else { - None - }; - let null_string_array = null_string_array.as_ref().map(|a| a.as_utf8()); - - let mut output = Utf8ArrayBuilder::with_meta(input.capacity(), ArrayMeta::Simple); - - for (i, vis) in input.vis().iter().enumerate() { - if !vis { - output.append_null(); - continue; - } - let array = list_array.value_at(i); - let delim = delim_array.value_at(i); - let null_string = if let Some(a) = null_string_array { - a.value_at(i) - } else { - None - }; - - if let Some(array) = array && let Some(delim) = delim { - let mut writer = output.writer().begin(); - if let Some(null_string) = null_string { - self.evaluate_with_nulls(array, delim, null_string, &mut writer); - } else { - self.evaluate(array, delim, &mut writer); - } - writer.finish(); - } else { - output.append_null(); - } + first = false; } - Ok(Arc::new(output.finish().into())) + element + .write_with_type(element_data_type, &mut writer) + .unwrap(); } +} - async fn eval_row(&self, input: &OwnedRow) -> Result { - let array = self.array.eval_row(input).await?; - let delimiter = self.delimiter.eval_row(input).await?; - - let result = if let Some(array) = array && let Some(delimiter) = delimiter { - let null_string = if let Some(e) = &self.null_string { - e.eval_row(input).await? - } else { - None - }; - let mut writer = String::new(); - if let Some(null_string) = null_string { - self.evaluate_with_nulls(array.as_scalar_ref_impl().into_list(), delimiter.as_utf8(), null_string.as_utf8(), &mut writer); - } else { - self.evaluate(array.as_scalar_ref_impl().into_list(), delimiter.as_utf8(), &mut writer); - } - Some(writer) - } else { - None - }; - Ok(result.map(|r| r.into())) - } +#[build_function("array_to_string(list, varchar, varchar) -> varchar")] +fn build_array_to_string_with_null( + return_type: DataType, + children: Vec, +) -> Result { + let mut iter = children.into_iter(); + let list = iter.next().unwrap(); + let delimiter = iter.next().unwrap(); + let null_string = iter.next().unwrap(); + let elem_type = match list.return_type() { + DataType::List { datatype } => *datatype, + _ => panic!("expected list type"), + }; + let expr = TernaryBytesExpression::::new( + list, + delimiter, + null_string, + return_type, + move |a, d, n, writer| Ok(array_to_string_with_null(a, &elem_type, d, n, writer)), + ); + Ok(Box::new(expr)) } -impl ArrayToStringExpression { - fn evaluate(&self, array: ListRef<'_>, delimiter: &str, mut writer: &mut dyn Write) { - let mut first = true; - for element in array.values_ref().iter().flat_map(|f| f.iter()) { - if !first { - write!(writer, "{}", delimiter).unwrap(); - } else { - first = false; - } - element - .write_with_type(&self.element_data_type, &mut writer) - .unwrap(); +fn array_to_string_with_null( + array: ListRef<'_>, + element_data_type: &DataType, + delimiter: &str, + null_string: &str, + mut writer: &mut dyn Write, +) { + let mut first = true; + for element in array.values_ref() { + if !first { + write!(writer, "{}", delimiter).unwrap(); + } else { + first = false; } - } - - fn evaluate_with_nulls( - &self, - array: ListRef<'_>, - delimiter: &str, - null_string: &str, - mut writer: &mut dyn Write, - ) { - let mut first = true; - for element in array.values_ref() { - if !first { - write!(writer, "{}", delimiter).unwrap(); - } else { - first = false; - } - match element { - Some(s) => s - .write_with_type(&self.element_data_type, &mut writer) - .unwrap(), - None => write!(writer, "{}", null_string).unwrap(), - } + match element { + Some(s) => s.write_with_type(element_data_type, &mut writer).unwrap(), + None => write!(writer, "{}", null_string).unwrap(), } } } diff --git a/src/expr/src/expr/expr_binary_bytes.rs b/src/expr/src/expr/expr_binary_bytes.rs deleted file mode 100644 index 01a85163be2d3..0000000000000 --- a/src/expr/src/expr/expr_binary_bytes.rs +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! For expression that only accept two arguments + 1 bytes writer as input. - -use risingwave_common::array::{I32Array, NaiveDateTimeArray, Utf8Array}; -use risingwave_common::types::DataType; - -use super::Expression; -use crate::expr::template::BinaryBytesExpression; -use crate::expr::BoxedExpression; -use crate::vector_op::concat_op::concat_op; -use crate::vector_op::repeat::repeat; -use crate::vector_op::substr::*; -use crate::vector_op::to_char::to_char_timestamp; -use crate::vector_op::trim_characters::{ltrim_characters, rtrim_characters, trim_characters}; - -pub fn new_substr_start( - expr_ia1: BoxedExpression, - expr_ia2: BoxedExpression, - return_type: DataType, -) -> BoxedExpression { - BinaryBytesExpression::::new( - expr_ia1, - expr_ia2, - return_type, - substr_start, - ) - .boxed() -} - -#[cfg_attr(not(test), expect(dead_code))] -pub fn new_substr_for( - expr_ia1: BoxedExpression, - expr_ia2: BoxedExpression, - return_type: DataType, -) -> BoxedExpression { - BinaryBytesExpression::::new( - expr_ia1, - expr_ia2, - return_type, - substr_for, - ) - .boxed() -} - -// TODO: Support more `to_char` types. -pub fn new_to_char( - expr_ia1: BoxedExpression, - expr_ia2: BoxedExpression, - return_type: DataType, -) -> BoxedExpression { - BinaryBytesExpression::::new( - expr_ia1, - expr_ia2, - return_type, - to_char_timestamp, - ) - .boxed() -} - -pub fn new_repeat( - expr_ia1: BoxedExpression, - expr_ia2: BoxedExpression, - return_type: DataType, -) -> BoxedExpression { - BinaryBytesExpression::::new(expr_ia1, expr_ia2, return_type, repeat) - .boxed() -} - -macro_rules! impl_utf8_utf8 { - ($({ $func_name:ident, $method:ident }),*) => { - $(pub fn $func_name( - expr_ia1: BoxedExpression, - expr_ia2: BoxedExpression, - return_type: DataType, - ) -> BoxedExpression { - BinaryBytesExpression::::new( - expr_ia1, - expr_ia2, - return_type, - $method, - ) - .boxed() - })* - }; -} - -macro_rules! for_all_utf8_utf8_op { - ($macro:ident) => { - $macro! { - { new_trim_characters, trim_characters }, - { new_ltrim_characters, ltrim_characters }, - { new_rtrim_characters, rtrim_characters }, - { new_concat_op, concat_op } - } - }; -} - -for_all_utf8_utf8_op! { impl_utf8_utf8 } - -#[cfg(test)] -mod tests { - use risingwave_common::array::DataChunk; - use risingwave_common::row::OwnedRow; - use risingwave_common::types::{Datum, ScalarImpl}; - - use super::*; - use crate::expr::LiteralExpression; - - fn create_str_i32_binary_expr( - f: fn(BoxedExpression, BoxedExpression, DataType) -> BoxedExpression, - str_arg: Datum, - i32_arg: Datum, - ) -> BoxedExpression { - f( - Box::new(LiteralExpression::new(DataType::Varchar, str_arg)), - Box::new(LiteralExpression::new(DataType::Int32, i32_arg)), - DataType::Varchar, - ) - } - - async fn test_evals_dummy(expr: &BoxedExpression, expected: Datum) { - let res = expr.eval(&DataChunk::new_dummy(1)).await.unwrap(); - assert_eq!(res.to_datum(), expected); - - let res = expr.eval_row(&OwnedRow::new(vec![])).await.unwrap(); - assert_eq!(res, expected); - } - - #[tokio::test] - async fn test_substr() { - let text = "quick brown"; - let start_pos = 3; - let for_pos = 4; - - let substr_start_normal = create_str_i32_binary_expr( - new_substr_start, - Some(ScalarImpl::from(String::from(text))), - Some(ScalarImpl::Int32(start_pos)), - ); - test_evals_dummy( - &substr_start_normal, - Some(ScalarImpl::from(String::from( - &text[start_pos as usize - 1..], - ))), - ) - .await; - - let substr_start_i32_none = create_str_i32_binary_expr( - new_substr_start, - Some(ScalarImpl::from(String::from(text))), - None, - ); - test_evals_dummy(&substr_start_i32_none, None).await; - - let substr_for_normal = create_str_i32_binary_expr( - new_substr_for, - Some(ScalarImpl::from(String::from(text))), - Some(ScalarImpl::Int32(for_pos)), - ); - test_evals_dummy( - &substr_for_normal, - Some(ScalarImpl::from(String::from(&text[..for_pos as usize]))), - ) - .await; - - let substr_for_str_none = - create_str_i32_binary_expr(new_substr_for, None, Some(ScalarImpl::Int32(for_pos))); - test_evals_dummy(&substr_for_str_none, None).await; - } -} diff --git a/src/expr/src/expr/expr_binary_nonnull.rs b/src/expr/src/expr/expr_binary_nonnull.rs index 2e4a273cf05b7..26808bd3a56c8 100644 --- a/src/expr/src/expr/expr_binary_nonnull.rs +++ b/src/expr/src/expr/expr_binary_nonnull.rs @@ -12,843 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::array::{ - Array, BoolArray, DecimalArray, F64Array, I32Array, I64Array, IntervalArray, JsonbArrayBuilder, - ListArray, NaiveDateArray, NaiveDateTimeArray, NaiveTimeArray, StructArray, Utf8Array, - Utf8ArrayBuilder, -}; -use risingwave_common::types::*; -use risingwave_pb::expr::expr_node::Type; - -use super::Expression; -use crate::expr::expr_binary_bytes::new_concat_op; -use crate::expr::expr_jsonb_access::{ - jsonb_array_element, jsonb_object_field, JsonbAccessExpression, -}; -use crate::expr::template::{BinaryBytesExpression, BinaryExpression}; -use crate::expr::{template_fast, BoxedExpression}; -use crate::vector_op::arithmetic_op::*; -use crate::vector_op::bitwise_op::*; -use crate::vector_op::cmp::*; -use crate::vector_op::date_trunc::{date_trunc_interval, date_trunc_timestamp}; -use crate::vector_op::extract::{ - extract_from_date, extract_from_time, extract_from_timestamp, extract_from_timestamptz, -}; -use crate::vector_op::like::like_default; -use crate::vector_op::position::position; -use crate::vector_op::round::round_digits; -use crate::vector_op::timestamptz::{ - str_to_timestamptz, timestamp_at_time_zone, timestamptz_at_time_zone, timestamptz_to_string, -}; -use crate::vector_op::to_timestamp::to_timestamp; -use crate::vector_op::tumble::{ - tumble_start_date, tumble_start_date_time, tumble_start_timestamptz, -}; -use crate::{for_all_cmp_variants, ExprError, Result}; - -/// This macro helps create arithmetic expression. -/// It receive all the combinations of `gen_binary_expr` and generate corresponding match cases -/// In [], the parameters are for constructing new expression -/// * $l: left expression -/// * $r: right expression -/// * $ret: return array type -/// In ()*, the parameters are for generating match cases -/// * $i1: left array type -/// * $i2: right array type -/// * $rt: The return type in that the operation will calculate -/// * $The scalar function for expression, it's a generic function and specialized by the type of -/// `$i1, $i2, $rt` -macro_rules! gen_atm_impl { - ([$l:expr, $r:expr, $ret:expr], $( { $i1:ident, $i2:ident, $rt:ident, $func:ident },)*) => { - match ($l.return_type(), $r.return_type()) { - $( - ($i1! { type_match_pattern }, $i2! { type_match_pattern }) => { - Box::new( - BinaryExpression::< - $i1! { type_array }, - $i2! { type_array }, - $rt! { type_array }, - _ - >::new( - $l, - $r, - $ret, - $func::< <$i1! { type_array } as Array>::OwnedItem, <$i2! { type_array } as Array>::OwnedItem, <$rt! { type_array } as Array>::OwnedItem>, - ) - ) as BoxedExpression - }, - )* - _ => { - return Err(ExprError::UnsupportedFunction(format!( - "{:?} atm {:?}", - $l.return_type(), $r.return_type() - ))); - } - } - }; -} - -macro_rules! gen_atm_impl_fast { - ([$l:expr, $r:expr, $ret:expr], $( { $i1:ident, $i2:ident, $rt:ident, $func:ident },)*) => { - match ($l.return_type(), $r.return_type()) { - $( - ($i1! { type_match_pattern }, $i2! { type_match_pattern }) => { - template_fast::BinaryExpression::new( - $l, $r, $ret, - $func::< - <$i1! { type_array } as Array>::OwnedItem, - <$i2! { type_array } as Array>::OwnedItem, - <$rt! { type_array } as Array>::OwnedItem - >, - ).boxed() - }, - )* - _ => { - return Err(ExprError::UnsupportedFunction(format!( - "{:?} atm {:?}", - $l.return_type(), $r.return_type() - ))); - } - } - }; -} - -/// This macro helps create comparison expression. Its output array is a bool array -/// Similar to `gen_atm_impl`. -macro_rules! gen_cmp_impl { - ([$l:expr, $r:expr, $ret:expr], $( { $i1:ident, $i2:ident, $cast:ident, $func:ident} ),* $(,)?) => { - match ($l.return_type(), $r.return_type()) { - $( - ($i1! { type_match_pattern }, $i2! { type_match_pattern }) => { - template_fast::CompareExpression::new( - $l, - $r, - $func::< - <$i1! { type_array } as Array>::OwnedItem, - <$i2! { type_array } as Array>::OwnedItem, - <$cast! { type_array } as Array>::OwnedItem - >, - ).boxed() - } - ),* - _ => { - return Err(ExprError::UnsupportedFunction(format!( - "{:?} cmp {:?}", - $l.return_type(), $r.return_type() - ))); - } - } - }; -} - -/// This macro helps create bitwise shift expression. The Output type is same as LHS of the -/// expression and the RHS of the expression is being match into u32. Similar to `gen_atm_impl`. -macro_rules! gen_shift_impl { - ([$l:expr, $r:expr, $ret:expr], $( { $i1:ident, $i2:ident, $func:ident },)*) => { - match ($l.return_type(), $r.return_type()) { - $( - ($i1! { type_match_pattern }, $i2! { type_match_pattern }) => { - Box::new( - BinaryExpression::< - $i1! { type_array }, - $i2! { type_array }, - $i1! { type_array }, - _ - >::new( - $l, - $r, - $ret, - $func::< - <$i1! { type_array } as Array>::OwnedItem, - <$i2! { type_array } as Array>::OwnedItem>, - - ) - ) as BoxedExpression - }, - )* - _ => { - return Err(ExprError::UnsupportedFunction(format!( - "{:?} shift {:?}", - $l.return_type(), $r.return_type() - ))); - } - } - }; -} - -/// Based on the data type of `$l`, `$r`, `$ret`, return corresponding expression struct with scalar -/// function inside. -/// * `$l`: left expression -/// * `$r`: right expression -/// * `$ret`: returned expression -/// * `macro`: a macro helps create expression -/// * `general_f`: generic cmp function (require a common ``TryInto`` type for two input). -/// * `boolean_f`: boolean cmp function -/// * `str_f`: cmp function between str -macro_rules! gen_binary_expr_cmp { - ($macro:ident, $general_f:ident, $boolean_f:ident, $op:ident, $l:expr, $r:expr, $ret:expr) => { - match ($l.return_type(), $r.return_type()) { - (DataType::Boolean, DataType::Boolean) => { - template_fast::BooleanBinaryExpression::new($l, $r, $boolean_f, |l, r| { - match (l, r) { - (Some(l), Some(r)) => Some($general_f::(l, r)), - _ => None, - } - }) - .boxed() - } - (DataType::Varchar, DataType::Varchar) => { - Box::new(BinaryExpression::::new( - $l, - $r, - $ret, - gen_str_cmp($op), - )) as BoxedExpression - } - (DataType::Struct { .. }, DataType::Struct { .. }) => { - Box::new( - BinaryExpression::::new( - $l, - $r, - $ret, - gen_struct_cmp($op), - ), - ) - } - (DataType::List { .. }, DataType::List { .. }) => { - Box::new(BinaryExpression::::new( - $l, - $r, - $ret, - gen_list_cmp($op), - )) - } - _ => { - for_all_cmp_variants! {$macro, $l, $r, $ret, $general_f} - } - } - }; -} - -/// `gen_binary_expr_atm` is similar to `gen_binary_expr_cmp`. -/// `atm` means arithmetic here. -/// They are differentiate cuz one type may not support atm and cmp at the same time. For example, -/// Varchar can support compare but not arithmetic. -/// * `$general_f`: generic atm function (require a common ``TryInto`` type for two input) -/// * `$i1`, `$i2`, `$rt`, `$func`: extra list passed to `$macro` directly -macro_rules! gen_binary_expr_atm { - ( - $macro:ident, - $l:expr, - $r:expr, - $ret:expr, - $general_f:ident, - { - $( { $i1:ident, $i2:ident, $rt:ident, $func:ident }, )* - } $(,)? - ) => { - $macro! { - [$l, $r, $ret], - { int16, int16, int16, $general_f }, - { int16, int32, int32, $general_f }, - { int16, int64, int64, $general_f }, - { int16, float32, float64, $general_f }, - { int16, float64, float64, $general_f }, - { int32, int16, int32, $general_f }, - { int32, int32, int32, $general_f }, - { int32, int64, int64, $general_f }, - { int32, float32, float64, $general_f }, - { int32, float64, float64, $general_f }, - { int64, int16,int64, $general_f }, - { int64, int32,int64, $general_f }, - { int64, int64, int64, $general_f }, - { int64, float32, float64 , $general_f}, - { int64, float64, float64, $general_f }, - { float32, int16, float64, $general_f }, - { float32, int32, float64, $general_f }, - { float32, int64, float64 , $general_f}, - { float32, float32, float32, $general_f }, - { float32, float64, float64, $general_f }, - { float64, int16, float64, $general_f }, - { float64, int32, float64, $general_f }, - { float64, int64, float64, $general_f }, - { float64, float32, float64, $general_f }, - { float64, float64, float64, $general_f }, - { decimal, int16, decimal, $general_f }, - { decimal, int32, decimal, $general_f }, - { decimal, int64, decimal, $general_f }, - { decimal, float32, float64, $general_f }, - { decimal, float64, float64, $general_f }, - { int16, decimal, decimal, $general_f }, - { int32, decimal, decimal, $general_f }, - { int64, decimal, decimal, $general_f }, - { decimal, decimal, decimal, $general_f }, - { float32, decimal, float64, $general_f }, - { float64, decimal, float64, $general_f }, - $( - { $i1, $i2, $rt, $func }, - )* - } - }; -} - -/// `gen_binary_expr_bitwise` is similar to `gen_binary_expr_atm`. -/// They are differentiate because bitwise operation only supports integral datatype. -/// * `$general_f`: generic atm function (require a common ``TryInto`` type for two input) -/// * `$i1`, `$i2`, `$rt`, `$func`: extra list passed to `$macro` directly -macro_rules! gen_binary_expr_bitwise { - ( - $macro:ident, - $l:expr, - $r:expr, - $ret:expr, - $general_f:ident, - { - $( { $i1:ident, $i2:ident, $rt:ident, $func:ident }, )* - } $(,)? - ) => { - $macro! { - [$l, $r, $ret], - { int16, int16, int16, $general_f }, - { int16, int32, int32, $general_f }, - { int16, int64, int64, $general_f }, - { int32, int16, int32, $general_f }, - { int32, int32, int32, $general_f }, - { int32, int64, int64, $general_f }, - { int64, int16, int64, $general_f }, - { int64, int32, int64, $general_f }, - { int64, int64, int64, $general_f }, - $( - { $i1, $i2, $rt, $func }, - )* - } - }; -} - -/// `gen_binary_expr_shift` is similar to `gen_binary_expr_bitwise`. -/// They are differentiate because shift operation have different typing rules. -/// * `$general_f`: generic atm function -/// `$rt` is not required because Type of the output is same as the Type of LHS of expression. -/// * `$i1`, `$i2`, `$func`: extra list passed to `$macro` directly -macro_rules! gen_binary_expr_shift { - ( - $macro:ident, - $l:expr, - $r:expr, - $ret:expr, - $general_f:ident, - { - $( { $i1:ident, $i2:ident, $func:ident }, )* - } $(,)? - ) => { - $macro! { - [$l, $r, $ret], - { int16, int16, $general_f }, - { int32, int16, $general_f }, - { int16, int32, $general_f }, - { int32, int32, $general_f }, - { int64, int16, $general_f }, - { int64, int32, $general_f }, - $( - { $i1, $i2, $func }, - )* - } - }; -} - -fn build_extract_expr( - ret: DataType, - l: BoxedExpression, - r: BoxedExpression, -) -> Result { - let expr: BoxedExpression = - match r.return_type() { - DataType::Date => Box::new(BinaryExpression::< - Utf8Array, - NaiveDateArray, - DecimalArray, - _, - >::new(l, r, ret, extract_from_date)), - DataType::Timestamp => Box::new(BinaryExpression::< - Utf8Array, - NaiveDateTimeArray, - DecimalArray, - _, - >::new(l, r, ret, extract_from_timestamp)), - DataType::Timestamptz => Box::new(BinaryExpression::< - Utf8Array, - I64Array, - DecimalArray, - _, - >::new( - l, r, ret, extract_from_timestamptz - )), - DataType::Time => Box::new(BinaryExpression::< - Utf8Array, - NaiveTimeArray, - DecimalArray, - _, - >::new(l, r, ret, extract_from_time)), - _ => { - return Err(ExprError::UnsupportedFunction(format!( - "Extract ( {:?} ) is not supported yet!", - r.return_type() - ))) - } - }; - Ok(expr) -} - -fn build_at_time_zone_expr( - ret: DataType, - l: BoxedExpression, - r: BoxedExpression, -) -> Result { - let expr: BoxedExpression = match l.return_type() { - DataType::Timestamp => Box::new(BinaryExpression::< - NaiveDateTimeArray, - Utf8Array, - I64Array, - _, - >::new(l, r, ret, timestamp_at_time_zone)), - DataType::Timestamptz => Box::new(BinaryExpression::< - I64Array, - Utf8Array, - NaiveDateTimeArray, - _, - >::new(l, r, ret, timestamptz_at_time_zone)), - _ => { - return Err(ExprError::UnsupportedFunction(format!( - "{:?} AT TIME ZONE is not supported yet!", - l.return_type() - ))) - } - }; - Ok(expr) -} - -fn build_cast_with_time_zone_expr( - ret: DataType, - l: BoxedExpression, - r: BoxedExpression, -) -> Result { - let expr: BoxedExpression = match (ret.clone(), l.return_type()) { - (DataType::Varchar, DataType::Timestamptz) => Box::new(BinaryBytesExpression::< - I64Array, - Utf8Array, - _, - >::new( - l, r, ret, timestamptz_to_string - )), - (DataType::Timestamptz, DataType::Varchar) => { - Box::new(BinaryExpression::::new( - l, - r, - ret, - str_to_timestamptz, - )) - } - _ => { - return Err(ExprError::UnsupportedFunction(format!( - "cannot cast at time zone (input type: {:?}, output type: {:?}", - l.return_type(), - ret, - ))) - } - }; - Ok(expr) -} - -pub fn new_date_trunc_expr( - ret: DataType, - field: BoxedExpression, - source: BoxedExpression, - timezone: Option<(BoxedExpression, BoxedExpression)>, -) -> BoxedExpression { - match source.return_type() { - DataType::Timestamp => BinaryExpression::< - Utf8Array, - NaiveDateTimeArray, - NaiveDateTimeArray, - _, - >::new(field, source, ret, date_trunc_timestamp).boxed(), - DataType::Timestamptz => { - // timestamptz AT TIME ZONE zone -> timestamp - // truncate(field, timestamp) -> timestamp - // timestamp AT TIME ZONE zone -> timestamptz - let (timezone1, timezone2) = timezone - .expect("A time zone must be specified when processing timestamp with time zone"); - let timestamp = BinaryExpression::::new( - source, - timezone1, - DataType::Timestamp, - timestamptz_at_time_zone, - ).boxed(); - let truncated = BinaryExpression::< - Utf8Array, - NaiveDateTimeArray, - NaiveDateTimeArray, - _, - >::new( - field, - timestamp, - DataType::Timestamp, - date_trunc_timestamp, - ).boxed(); - BinaryExpression::::new( - truncated, - timezone2, - DataType::Timestamptz, - timestamp_at_time_zone, - ).boxed() - } - DataType::Interval => BinaryExpression::< - Utf8Array, - IntervalArray, - IntervalArray, - _, - >::new(field, source, ret, date_trunc_interval).boxed(), - _ => panic!("source must be a value expression of type timestamp, timestamp with time zone, or interval."), - } -} - -/// Create a new binary expression. -pub fn new_binary_expr( - expr_type: Type, - ret: DataType, - l: BoxedExpression, - r: BoxedExpression, -) -> Result { - use crate::expr::data_types::*; - let expr = match expr_type { - Type::Equal => { - gen_binary_expr_cmp! {gen_cmp_impl, general_eq, boolean_eq, EQ, l, r, ret} - } - Type::NotEqual => { - gen_binary_expr_cmp! {gen_cmp_impl, general_ne, boolean_ne, NE, l, r, ret} - } - Type::LessThan => { - gen_binary_expr_cmp! {gen_cmp_impl, general_lt, boolean_lt, LT, l, r, ret} - } - Type::GreaterThan => { - gen_binary_expr_cmp! {gen_cmp_impl, general_gt, boolean_gt, GT, l, r, ret} - } - Type::GreaterThanOrEqual => { - gen_binary_expr_cmp! {gen_cmp_impl, general_ge, boolean_ge, GE, l, r, ret} - } - Type::LessThanOrEqual => { - gen_binary_expr_cmp! {gen_cmp_impl, general_le, boolean_le, LE, l, r, ret} - } - Type::Add => { - gen_binary_expr_atm! { - gen_atm_impl, - l, r, ret, - general_add, - { - { timestamptz, interval, timestamptz, timestamptz_interval_add }, - { interval, timestamptz, timestamptz, interval_timestamptz_add }, - { timestamp, interval, timestamp, timestamp_interval_add }, - { interval, timestamp, timestamp, interval_timestamp_add }, - { interval, date, timestamp, interval_date_add }, - { interval, time, time, interval_time_add }, - { date, interval, timestamp, date_interval_add }, - { date, int32, date, date_int_add }, - { int32, date, date, int_date_add }, - { date, time, timestamp, date_time_add }, - { time, date, timestamp, time_date_add }, - { interval, interval, interval, general_add }, - { time, interval, time, time_interval_add }, - }, - } - } - Type::Subtract => { - gen_binary_expr_atm! { - gen_atm_impl, - l, r, ret, - general_sub, - { - { timestamptz, interval, timestamptz, timestamptz_interval_sub }, - { timestamp, timestamp, interval, timestamp_timestamp_sub }, - { timestamp, interval, timestamp, timestamp_interval_sub }, - { date, date, int32, date_date_sub }, - { date, interval, timestamp, date_interval_sub }, - { time, time, interval, time_time_sub }, - { time, interval, time, time_interval_sub }, - { interval, interval, interval, general_sub }, - { date, int32, date, date_int_sub }, - }, - } - } - Type::Multiply => { - gen_binary_expr_atm! { - gen_atm_impl, - l, r, ret, - general_mul, - { - { interval, int16, interval, interval_int_mul }, - { interval, int32, interval, interval_int_mul }, - { interval, int64, interval, interval_int_mul }, - { interval, float32, interval, interval_float_mul }, - { interval, float64, interval, interval_float_mul }, - { interval, decimal, interval, interval_float_mul }, - - { int16, interval, interval, int_interval_mul }, - { int32, interval, interval, int_interval_mul }, - { int64, interval, interval, int_interval_mul }, - { float32, interval, interval, float_interval_mul }, - { float64, interval, interval, float_interval_mul }, - { decimal, interval, interval, float_interval_mul }, - }, - } - } - Type::Divide => { - gen_binary_expr_atm! { - gen_atm_impl, - l, r, ret, - general_div, - { - { interval, int16, interval, interval_float_div }, - { interval, int32, interval, interval_float_div }, - { interval, int64, interval, interval_float_div }, - { interval, float32, interval, interval_float_div }, - { interval, float64, interval, interval_float_div }, - { interval, decimal, interval, interval_float_div }, - }, - } - } - Type::Modulus => { - gen_binary_expr_atm! { - gen_atm_impl, - l, r, ret, - general_mod, - { - }, - } - } - // BitWise Operation - Type::BitwiseShiftLeft => { - gen_binary_expr_shift! { - gen_shift_impl, - l, r, ret, - general_shl, - { - - }, - } - } - Type::BitwiseShiftRight => { - gen_binary_expr_shift! { - gen_shift_impl, - l, r, ret, - general_shr, - { - - }, - } - } - Type::BitwiseAnd => { - gen_binary_expr_bitwise! { - gen_atm_impl_fast, - l, r, ret, - general_bitand, - { - }, - } - } - Type::BitwiseOr => { - gen_binary_expr_bitwise! { - gen_atm_impl_fast, - l, r, ret, - general_bitor, - { - }, - } - } - Type::BitwiseXor => { - gen_binary_expr_bitwise! { - gen_atm_impl_fast, - l, r, ret, - general_bitxor, - { - }, - } - } - Type::Pow => Box::new(BinaryExpression::::new( - l, r, ret, pow_f64, - )), - Type::Extract => build_extract_expr(ret, l, r)?, - Type::AtTimeZone => build_at_time_zone_expr(ret, l, r)?, - Type::CastWithTimeZone => build_cast_with_time_zone_expr(ret, l, r)?, - Type::RoundDigit => Box::new(template_fast::BinaryExpression::new( - l, - r, - ret, - round_digits::, - )), - Type::Position => Box::new(BinaryExpression::::new( - l, r, ret, position, - )), - Type::ConcatOp => new_concat_op(l, r, ret), - Type::JsonbAccessInner => match r.return_type() { - DataType::Varchar => { - JsonbAccessExpression::::new_expr( - l, - r, - jsonb_object_field, - ) - .boxed() - } - DataType::Int32 => JsonbAccessExpression::::new_expr( - l, - r, - jsonb_array_element, - ) - .boxed(), - t => return Err(ExprError::UnsupportedFunction(format!("jsonb -> {t}"))), - }, - Type::JsonbAccessStr => match r.return_type() { - DataType::Varchar => JsonbAccessExpression::::new_expr( - l, - r, - jsonb_object_field, - ) - .boxed(), - DataType::Int32 => JsonbAccessExpression::::new_expr( - l, - r, - jsonb_array_element, - ) - .boxed(), - t => return Err(ExprError::UnsupportedFunction(format!("jsonb ->> {t}"))), - }, - tp => { - return Err(ExprError::UnsupportedFunction(format!( - "{:?}({:?}, {:?})", - tp, - l.return_type(), - r.return_type(), - ))); - } - }; - Ok(expr) -} - -pub fn new_tumble_start( - expr_ia1: BoxedExpression, - expr_ia2: BoxedExpression, - return_type: DataType, -) -> Result { - let expr: BoxedExpression = match expr_ia1.return_type() { - DataType::Date => Box::new(BinaryExpression::< - NaiveDateArray, - IntervalArray, - NaiveDateTimeArray, - _, - >::new( - expr_ia1, expr_ia2, return_type, tumble_start_date - )), - DataType::Timestamp => Box::new(BinaryExpression::< - NaiveDateTimeArray, - IntervalArray, - NaiveDateTimeArray, - _, - >::new( - expr_ia1, expr_ia2, return_type, tumble_start_date_time - )), - DataType::Timestamptz => Box::new( - BinaryExpression::::new( - expr_ia1, - expr_ia2, - return_type, - tumble_start_timestamptz, - ), - ), - _ => { - return Err(ExprError::UnsupportedFunction(format!( - "tumble_start is not supported for {:?}", - expr_ia1.return_type() - ))) - } - }; - Ok(expr) -} - -pub fn new_like_default( - expr_ia1: BoxedExpression, - expr_ia2: BoxedExpression, - return_type: DataType, -) -> BoxedExpression { - Box::new(BinaryExpression::::new( - expr_ia1, - expr_ia2, - return_type, - like_default, - )) -} - -pub fn new_to_timestamp( - expr_ia1: BoxedExpression, - expr_ia2: BoxedExpression, - return_type: DataType, -) -> BoxedExpression { - BinaryExpression::::new( - expr_ia1, - expr_ia2, - return_type, - to_timestamp, - ) - .boxed() -} - -fn boolean_eq(l: &BoolArray, r: &BoolArray) -> BoolArray { - let data = !(l.data() ^ r.data()); - let bitmap = l.null_bitmap() & r.null_bitmap(); - BoolArray::new(data, bitmap) -} - -fn boolean_ne(l: &BoolArray, r: &BoolArray) -> BoolArray { - let data = l.data() ^ r.data(); - let bitmap = l.null_bitmap() & r.null_bitmap(); - BoolArray::new(data, bitmap) -} - -fn boolean_gt(l: &BoolArray, r: &BoolArray) -> BoolArray { - let data = l.data() & !r.data(); - let bitmap = l.null_bitmap() & r.null_bitmap(); - BoolArray::new(data, bitmap) -} - -fn boolean_lt(l: &BoolArray, r: &BoolArray) -> BoolArray { - let data = !l.data() & r.data(); - let bitmap = l.null_bitmap() & r.null_bitmap(); - BoolArray::new(data, bitmap) -} - -fn boolean_ge(l: &BoolArray, r: &BoolArray) -> BoolArray { - let data = l.data() | !r.data(); - let bitmap = l.null_bitmap() & r.null_bitmap(); - BoolArray::new(data, bitmap) -} - -fn boolean_le(l: &BoolArray, r: &BoolArray) -> BoolArray { - let data = !l.data() | r.data(); - let bitmap = l.null_bitmap() & r.null_bitmap(); - BoolArray::new(data, bitmap) -} - #[cfg(test)] mod tests { use risingwave_common::array::interval_array::IntervalArray; use risingwave_common::array::*; use risingwave_common::types::test_utils::IntervalUnitTestExt; - use risingwave_common::types::{ - Decimal, IntervalUnit, NaiveDateTimeWrapper, NaiveDateWrapper, Scalar, - }; - use risingwave_pb::data::data_type::TypeName; + use risingwave_common::types::{Decimal, IntervalUnit, NaiveDateWrapper, Scalar}; use risingwave_pb::expr::expr_node::Type; use super::super::*; - use crate::expr::test_utils::make_expression; use crate::vector_op::arithmetic_op::{date_interval_add, date_interval_sub}; #[tokio::test] @@ -874,18 +46,12 @@ mod tests { test_binary_decimal::(|x, y| x < y, Type::LessThan).await; test_binary_decimal::(|x, y| x <= y, Type::LessThanOrEqual).await; test_binary_interval::( - |x, y| { - date_interval_add::(x, y) - .unwrap() - }, + |x, y| date_interval_add(x, y).unwrap(), Type::Add, ) .await; test_binary_interval::( - |x, y| { - date_interval_sub::(x, y) - .unwrap() - }, + |x, y| date_interval_sub(x, y).unwrap(), Type::Subtract, ) .await; @@ -924,9 +90,19 @@ mod tests { let col1 = I32Array::from_iter(&lhs).into(); let col2 = I32Array::from_iter(&rhs).into(); let data_chunk = DataChunk::new(vec![col1, col2], 100); - let expr = make_expression(kind, &[TypeName::Int32, TypeName::Int32], &[0, 1]); - let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).await.unwrap(); + let expr = build( + kind, + match kind { + Type::Add | Type::Subtract | Type::Multiply | Type::Divide => DataType::Int32, + _ => DataType::Boolean, + }, + vec![ + InputRefExpression::new(DataType::Int32, 0).boxed(), + InputRefExpression::new(DataType::Int32, 1).boxed(), + ], + ) + .unwrap(); + let res = expr.eval(&data_chunk).await.unwrap(); let arr: &A = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -938,7 +114,7 @@ mod tests { lhs[i].map(|int| int.to_scalar_value()), rhs[i].map(|int| int.to_scalar_value()), ]); - let result = vec_executor.eval_row(&row).await.unwrap(); + let result = expr.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); assert_eq!(result, expected); } @@ -972,9 +148,16 @@ mod tests { let col1 = NaiveDateArray::from_iter(&lhs).into(); let col2 = IntervalArray::from_iter(&rhs).into(); let data_chunk = DataChunk::new(vec![col1, col2], 100); - let expr = make_expression(kind, &[TypeName::Date, TypeName::Interval], &[0, 1]); - let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).await.unwrap(); + let expr = build( + kind, + DataType::Timestamp, + vec![ + InputRefExpression::new(DataType::Date, 0).boxed(), + InputRefExpression::new(DataType::Interval, 1).boxed(), + ], + ) + .unwrap(); + let res = expr.eval(&data_chunk).await.unwrap(); let arr: &A = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -986,7 +169,7 @@ mod tests { lhs[i].map(|date| date.to_scalar_value()), rhs[i].map(|date| date.to_scalar_value()), ]); - let result = vec_executor.eval_row(&row).await.unwrap(); + let result = expr.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); assert_eq!(result, expected); } @@ -1025,9 +208,19 @@ mod tests { let col1 = DecimalArray::from_iter(&lhs).into(); let col2 = DecimalArray::from_iter(&rhs).into(); let data_chunk = DataChunk::new(vec![col1, col2], 100); - let expr = make_expression(kind, &[TypeName::Decimal, TypeName::Decimal], &[0, 1]); - let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).await.unwrap(); + let expr = build( + kind, + match kind { + Type::Add | Type::Subtract | Type::Multiply | Type::Divide => DataType::Decimal, + _ => DataType::Boolean, + }, + vec![ + InputRefExpression::new(DataType::Decimal, 0).boxed(), + InputRefExpression::new(DataType::Decimal, 1).boxed(), + ], + ) + .unwrap(); + let res = expr.eval(&data_chunk).await.unwrap(); let arr: &A = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -1039,7 +232,7 @@ mod tests { lhs[i].map(|dec| dec.to_scalar_value()), rhs[i].map(|dec| dec.to_scalar_value()), ]); - let result = vec_executor.eval_row(&row).await.unwrap(); + let result = expr.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); assert_eq!(result, expected); } diff --git a/src/expr/src/expr/expr_binary_nullable.rs b/src/expr/src/expr/expr_binary_nullable.rs index 2d26232fb1291..35732741fcffb 100644 --- a/src/expr/src/expr/expr_binary_nullable.rs +++ b/src/expr/src/expr/expr_binary_nullable.rs @@ -16,51 +16,16 @@ use std::sync::Arc; -use risingwave_common::array::serial_array::SerialArray; use risingwave_common::array::*; use risingwave_common::buffer::Bitmap; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum, Scalar}; +use risingwave_expr_macro::build_function; use risingwave_pb::expr::expr_node::Type; use super::{BoxedExpression, Expression}; -use crate::expr::template::BinaryNullableExpression; -use crate::expr::template_fast; -use crate::vector_op::array_access::array_access; -use crate::vector_op::cmp::{ - general_is_distinct_from, general_is_not_distinct_from, general_ne, str_is_distinct_from, - str_is_not_distinct_from, -}; use crate::vector_op::conjunction::{and, or}; -use crate::vector_op::format_type::format_type; -use crate::{for_all_cmp_variants, ExprError, Result}; - -macro_rules! gen_is_distinct_from_impl { - ([$l:expr, $r:expr, $ret:expr], $( { $i1:ident, $i2:ident, $cast:ident, $func:ident} ),* $(,)?) => { - match ($l.return_type(), $r.return_type()) { - $( - ($i1! { type_match_pattern }, $i2! { type_match_pattern }) => { - template_fast::IsDistinctFromExpression::new( - $l, - $r, - general_ne::< - <$i1! { type_array } as Array>::OwnedItem, - <$i2! { type_array } as Array>::OwnedItem, - <$cast! { type_array } as Array>::OwnedItem - >, - $func, - ).boxed() - } - ),* - _ => { - return Err(ExprError::UnsupportedFunction(format!( - "{:?} cmp {:?}", - $l.return_type(), $r.return_type() - ))); - } - } - }; -} +use crate::Result; pub struct BinaryShortCircuitExpression { expr_ia1: BoxedExpression, @@ -138,180 +103,40 @@ impl Expression for BinaryShortCircuitExpression { } let ret_ia2 = self.expr_ia2.eval_row(input).await?.map(|x| x.into_bool()); match self.expr_type { - Type::Or => Ok(or(ret_ia1, ret_ia2)?.map(|x| x.to_scalar_value())), - Type::And => Ok(and(ret_ia1, ret_ia2)?.map(|x| x.to_scalar_value())), + Type::Or => Ok(or(ret_ia1, ret_ia2).map(|x| x.to_scalar_value())), + Type::And => Ok(and(ret_ia1, ret_ia2).map(|x| x.to_scalar_value())), _ => unimplemented!(), } } } -impl BinaryShortCircuitExpression { - pub fn new(expr_ia1: BoxedExpression, expr_ia2: BoxedExpression, expr_type: Type) -> Self { - Self { - expr_ia1, - expr_ia2, - expr_type, - } - } +#[build_function("and(boolean, boolean) -> boolean")] +fn build_and_expr(_: DataType, children: Vec) -> Result { + let mut iter = children.into_iter(); + Ok(Box::new(BinaryShortCircuitExpression { + expr_ia1: iter.next().unwrap(), + expr_ia2: iter.next().unwrap(), + expr_type: Type::And, + })) } -pub fn new_nullable_binary_expr( - expr_type: Type, - ret: DataType, - l: BoxedExpression, - r: BoxedExpression, -) -> Result { - let expr = match expr_type { - Type::ArrayAccess => build_array_access_expr(ret, l, r), - Type::And => Box::new(BinaryShortCircuitExpression::new(l, r, expr_type)), - Type::Or => Box::new(BinaryShortCircuitExpression::new(l, r, expr_type)), - Type::IsDistinctFrom => new_distinct_from_expr(l, r, ret)?, - Type::IsNotDistinctFrom => new_not_distinct_from_expr(l, r, ret)?, - Type::FormatType => new_format_type_expr(l, r, ret), - tp => { - return Err(ExprError::UnsupportedFunction(format!( - "{:?}({:?}, {:?})", - tp, - l.return_type(), - r.return_type(), - ))); - } - }; - Ok(expr) -} - -fn build_array_access_expr( - ret: DataType, - l: BoxedExpression, - r: BoxedExpression, -) -> BoxedExpression { - macro_rules! array_access_expression { - ($array:ty) => { - Box::new( - BinaryNullableExpression::::new( - l, - r, - ret, - array_access, - ), - ) - }; - } - - match ret { - DataType::Boolean => array_access_expression!(BoolArray), - DataType::Int16 => array_access_expression!(I16Array), - DataType::Int32 => array_access_expression!(I32Array), - DataType::Int64 => array_access_expression!(I64Array), - DataType::Serial => array_access_expression!(SerialArray), - DataType::Float32 => array_access_expression!(F32Array), - DataType::Float64 => array_access_expression!(F64Array), - DataType::Decimal => array_access_expression!(DecimalArray), - DataType::Date => array_access_expression!(NaiveDateArray), - DataType::Varchar => array_access_expression!(Utf8Array), - DataType::Bytea => array_access_expression!(BytesArray), - DataType::Time => array_access_expression!(NaiveTimeArray), - DataType::Timestamp => array_access_expression!(NaiveDateTimeArray), - DataType::Timestamptz => array_access_expression!(PrimitiveArray::), - DataType::Interval => array_access_expression!(IntervalArray), - DataType::Jsonb => array_access_expression!(JsonbArray), - DataType::Struct { .. } => array_access_expression!(StructArray), - DataType::List { .. } => array_access_expression!(ListArray), - } -} - -pub fn new_distinct_from_expr( - l: BoxedExpression, - r: BoxedExpression, - ret: DataType, -) -> Result { - use crate::expr::data_types::*; - - let expr: BoxedExpression = match (l.return_type(), r.return_type()) { - (DataType::Boolean, DataType::Boolean) => template_fast::BooleanBinaryExpression::new( - l, - r, - |l, r| { - let data = ((l.data() ^ r.data()) & (l.null_bitmap() & r.null_bitmap())) - | (l.null_bitmap() ^ r.null_bitmap()); - BoolArray::new(data, Bitmap::ones(l.len())) - }, - |l, r| Some(general_is_distinct_from::(l, r)), - ) - .boxed(), - (DataType::Varchar, DataType::Varchar) => Box::new(BinaryNullableExpression::< - Utf8Array, - Utf8Array, - BoolArray, - _, - >::new( - l, r, ret, str_is_distinct_from - )), - _ => { - for_all_cmp_variants! { gen_is_distinct_from_impl, l, r, ret, false } - } - }; - Ok(expr) -} - -pub fn new_not_distinct_from_expr( - l: BoxedExpression, - r: BoxedExpression, - ret: DataType, -) -> Result { - use crate::expr::data_types::*; - - let expr: BoxedExpression = match (l.return_type(), r.return_type()) { - (DataType::Boolean, DataType::Boolean) => template_fast::BooleanBinaryExpression::new( - l, - r, - |l, r| { - let data = !(((l.data() ^ r.data()) & (l.null_bitmap() & r.null_bitmap())) - | (l.null_bitmap() ^ r.null_bitmap())); - BoolArray::new(data, Bitmap::ones(l.len())) - }, - |l, r| Some(general_is_not_distinct_from::(l, r)), - ) - .boxed(), - (DataType::Varchar, DataType::Varchar) => Box::new(BinaryNullableExpression::< - Utf8Array, - Utf8Array, - BoolArray, - _, - >::new( - l, r, ret, str_is_not_distinct_from - )), - _ => { - for_all_cmp_variants! { gen_is_distinct_from_impl, l, r, ret, true } - } - }; - Ok(expr) -} - -pub fn new_format_type_expr( - expr_ia1: BoxedExpression, - expr_ia2: BoxedExpression, - return_type: DataType, -) -> BoxedExpression { - Box::new( - BinaryNullableExpression::::new( - expr_ia1, - expr_ia2, - return_type, - format_type, - ), - ) +#[build_function("or(boolean, boolean) -> boolean")] +fn build_or_expr(_: DataType, children: Vec) -> Result { + let mut iter = children.into_iter(); + Ok(Box::new(BinaryShortCircuitExpression { + expr_ia1: iter.next().unwrap(), + expr_ia2: iter.next().unwrap(), + expr_type: Type::Or, + })) } #[cfg(test)] mod tests { use risingwave_common::row::OwnedRow; - use risingwave_common::types::Scalar; - use risingwave_pb::data::data_type::TypeName; + use risingwave_common::types::{DataType, Scalar}; use risingwave_pb::expr::expr_node::Type; - use crate::expr::build_from_prost; - use crate::expr::test_utils::make_expression; + use crate::expr::{build, Expression, InputRefExpression}; #[tokio::test] async fn test_and() { @@ -349,15 +174,22 @@ mod tests { None, ]; - let expr = make_expression(Type::And, &[TypeName::Boolean, TypeName::Boolean], &[0, 1]); - let vec_executor = build_from_prost(&expr).unwrap(); + let expr = build( + Type::And, + DataType::Boolean, + vec![ + InputRefExpression::new(DataType::Boolean, 0).boxed(), + InputRefExpression::new(DataType::Boolean, 1).boxed(), + ], + ) + .unwrap(); for i in 0..lhs.len() { let row = OwnedRow::new(vec![ lhs[i].map(|x| x.to_scalar_value()), rhs[i].map(|x| x.to_scalar_value()), ]); - let res = vec_executor.eval_row(&row).await.unwrap(); + let res = expr.eval_row(&row).await.unwrap(); let expected = target[i].map(|x| x.to_scalar_value()); assert_eq!(res, expected); } @@ -399,15 +231,22 @@ mod tests { None, ]; - let expr = make_expression(Type::Or, &[TypeName::Boolean, TypeName::Boolean], &[0, 1]); - let vec_executor = build_from_prost(&expr).unwrap(); + let expr = build( + Type::Or, + DataType::Boolean, + vec![ + InputRefExpression::new(DataType::Boolean, 0).boxed(), + InputRefExpression::new(DataType::Boolean, 1).boxed(), + ], + ) + .unwrap(); for i in 0..lhs.len() { let row = OwnedRow::new(vec![ lhs[i].map(|x| x.to_scalar_value()), rhs[i].map(|x| x.to_scalar_value()), ]); - let res = vec_executor.eval_row(&row).await.unwrap(); + let res = expr.eval_row(&row).await.unwrap(); let expected = target[i].map(|x| x.to_scalar_value()); assert_eq!(res, expected); } @@ -419,19 +258,22 @@ mod tests { let rhs = vec![None, Some(1), None, Some(2), Some(4)]; let target = vec![Some(false), Some(true), Some(true), Some(false), Some(true)]; - let expr = make_expression( + let expr = build( Type::IsDistinctFrom, - &[TypeName::Int32, TypeName::Int32], - &[0, 1], - ); - let vec_executor = build_from_prost(&expr).unwrap(); + DataType::Boolean, + vec![ + InputRefExpression::new(DataType::Int32, 0).boxed(), + InputRefExpression::new(DataType::Int32, 1).boxed(), + ], + ) + .unwrap(); for i in 0..lhs.len() { let row = OwnedRow::new(vec![ lhs[i].map(|x| x.to_scalar_value()), rhs[i].map(|x| x.to_scalar_value()), ]); - let res = vec_executor.eval_row(&row).await.unwrap(); + let res = expr.eval_row(&row).await.unwrap(); let expected = target[i].map(|x| x.to_scalar_value()); assert_eq!(res, expected); } @@ -449,19 +291,22 @@ mod tests { Some(false), ]; - let expr = make_expression( + let expr = build( Type::IsNotDistinctFrom, - &[TypeName::Int32, TypeName::Int32], - &[0, 1], - ); - let vec_executor = build_from_prost(&expr).unwrap(); + DataType::Boolean, + vec![ + InputRefExpression::new(DataType::Int32, 0).boxed(), + InputRefExpression::new(DataType::Int32, 1).boxed(), + ], + ) + .unwrap(); for i in 0..lhs.len() { let row = OwnedRow::new(vec![ lhs[i].map(|x| x.to_scalar_value()), rhs[i].map(|x| x.to_scalar_value()), ]); - let res = vec_executor.eval_row(&row).await.unwrap(); + let res = expr.eval_row(&row).await.unwrap(); let expected = target[i].map(|x| x.to_scalar_value()); assert_eq!(res, expected); } @@ -477,19 +322,22 @@ mod tests { Some("???".into()), None, ]; - let expr = make_expression( + let expr = build( Type::FormatType, - &[TypeName::Int32, TypeName::Int32], - &[0, 1], - ); - let vec_executor = build_from_prost(&expr).unwrap(); + DataType::Varchar, + vec![ + InputRefExpression::new(DataType::Int32, 0).boxed(), + InputRefExpression::new(DataType::Int32, 1).boxed(), + ], + ) + .unwrap(); for i in 0..l.len() { let row = OwnedRow::new(vec![ l[i].map(|x| x.to_scalar_value()), r[i].map(|x| x.to_scalar_value()), ]); - let res = vec_executor.eval_row(&row).await.unwrap(); + let res = expr.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().map(|x| x.into()); assert_eq!(res, expected); } diff --git a/src/expr/src/expr/expr_case.rs b/src/expr/src/expr/expr_case.rs index 96876b33b67bd..5680e2c648e29 100644 --- a/src/expr/src/expr/expr_case.rs +++ b/src/expr/src/expr/expr_case.rs @@ -18,7 +18,7 @@ use risingwave_common::array::{ArrayRef, DataChunk, Vis}; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum}; use risingwave_common::{bail, ensure}; -use risingwave_pb::expr::expr_node::{RexNode, Type}; +use risingwave_pb::expr::expr_node::{PbType, RexNode}; use risingwave_pb::expr::ExprNode; use crate::expr::{build_from_prost, BoxedExpression, Expression}; @@ -26,14 +26,8 @@ use crate::{ExprError, Result}; #[derive(Debug)] pub struct WhenClause { - pub when: BoxedExpression, - pub then: BoxedExpression, -} - -impl WhenClause { - pub fn new(when: BoxedExpression, then: BoxedExpression) -> Self { - WhenClause { when, then } - } + when: BoxedExpression, + then: BoxedExpression, } #[derive(Debug)] @@ -122,7 +116,7 @@ impl<'a> TryFrom<&'a ExprNode> for CaseExpression { type Error = ExprError; fn try_from(prost: &'a ExprNode) -> Result { - ensure!(prost.get_expr_type().unwrap() == Type::Case); + ensure!(prost.get_expr_type().unwrap() == PbType::Case); let ret_type = DataType::from(prost.get_return_type().unwrap()); let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { @@ -152,7 +146,10 @@ impl<'a> TryFrom<&'a ExprNode> for CaseExpression { if then_expr.return_type() != ret_type { bail!("Type mismatched between then clause and case"); } - let when_clause = WhenClause::new(when_expr, then_expr); + let when_clause = WhenClause { + when: when_expr, + then: then_expr, + }; when_clauses.push(when_clause); } Ok(CaseExpression::new(ret_type, when_clauses, else_clause)) @@ -163,47 +160,10 @@ impl<'a> TryFrom<&'a ExprNode> for CaseExpression { mod tests { use risingwave_common::test_prelude::DataChunkTestExt; use risingwave_common::types::Scalar; - use risingwave_pb::data::data_type::TypeName; - use risingwave_pb::data::PbDataType; - use risingwave_pb::expr::expr_node::Type; - use risingwave_pb::expr::FunctionCall; + use risingwave_pb::expr::expr_node::PbType; use super::*; - use crate::expr::expr_binary_nonnull::new_binary_expr; - use crate::expr::{InputRefExpression, LiteralExpression}; - - #[test] - fn test_case_expr() { - let call = FunctionCall { - children: vec![ - ExprNode { - expr_type: Type::ConstantValue as i32, - return_type: Some(PbDataType { - type_name: TypeName::Boolean as i32, - ..Default::default() - }), - rex_node: None, - }, - ExprNode { - expr_type: Type::ConstantValue as i32, - return_type: Some(PbDataType { - type_name: TypeName::Int32 as i32, - ..Default::default() - }), - rex_node: None, - }, - ], - }; - let p = ExprNode { - expr_type: Type::Case as i32, - return_type: Some(PbDataType { - type_name: TypeName::Int32 as i32, - ..Default::default() - }), - rex_node: Some(RexNode::FuncCall(call)), - }; - assert!(CaseExpression::try_from(&p).is_ok()); - } + use crate::expr::{build, InputRefExpression, LiteralExpression}; async fn test_eval_row(expr: CaseExpression, row_inputs: Vec, expected: Vec>) { for (i, row_input) in row_inputs.iter().enumerate() { @@ -218,19 +178,21 @@ mod tests { async fn test_eval_searched_case() { let ret_type = DataType::Float32; // when x <= 2 then 3.1 - let when_clauses = vec![WhenClause::new( - new_binary_expr( - Type::LessThanOrEqual, + let when_clauses = vec![WhenClause { + when: build( + PbType::LessThanOrEqual, DataType::Boolean, - Box::new(InputRefExpression::new(DataType::Int32, 0)), - Box::new(LiteralExpression::new(DataType::Float32, Some(2f32.into()))), + vec![ + Box::new(InputRefExpression::new(DataType::Int32, 0)), + Box::new(LiteralExpression::new(DataType::Float32, Some(2f32.into()))), + ], ) .unwrap(), - Box::new(LiteralExpression::new( + then: Box::new(LiteralExpression::new( DataType::Float32, Some(3.1f32.into()), )), - )]; + }]; // else 4.1 let els = Box::new(LiteralExpression::new( DataType::Float32, @@ -257,19 +219,21 @@ mod tests { async fn test_eval_without_else() { let ret_type = DataType::Float32; // when x <= 3 then 3.1 - let when_clauses = vec![WhenClause::new( - new_binary_expr( - Type::LessThanOrEqual, + let when_clauses = vec![WhenClause { + when: build( + PbType::LessThanOrEqual, DataType::Boolean, - Box::new(InputRefExpression::new(DataType::Int32, 0)), - Box::new(LiteralExpression::new(DataType::Float32, Some(3f32.into()))), + vec![ + Box::new(InputRefExpression::new(DataType::Int32, 0)), + Box::new(LiteralExpression::new(DataType::Float32, Some(3f32.into()))), + ], ) .unwrap(), - Box::new(LiteralExpression::new( + then: Box::new(LiteralExpression::new( DataType::Float32, Some(3.1f32.into()), )), - )]; + }]; let searched_case_expr = CaseExpression::new(ret_type, when_clauses, None); let input = DataChunk::from_pretty( "i @@ -289,19 +253,21 @@ mod tests { async fn test_eval_row_searched_case() { let ret_type = DataType::Float32; // when x <= 2 then 3.1 - let when_clauses = vec![WhenClause::new( - new_binary_expr( - Type::LessThanOrEqual, + let when_clauses = vec![WhenClause { + when: build( + PbType::LessThanOrEqual, DataType::Boolean, - Box::new(InputRefExpression::new(DataType::Int32, 0)), - Box::new(LiteralExpression::new(DataType::Float32, Some(2f32.into()))), + vec![ + Box::new(InputRefExpression::new(DataType::Int32, 0)), + Box::new(LiteralExpression::new(DataType::Float32, Some(2f32.into()))), + ], ) .unwrap(), - Box::new(LiteralExpression::new( + then: Box::new(LiteralExpression::new( DataType::Float32, Some(3.1f32.into()), )), - )]; + }]; // else 4.1 let els = Box::new(LiteralExpression::new( DataType::Float32, @@ -325,19 +291,21 @@ mod tests { async fn test_eval_row_without_else() { let ret_type = DataType::Float32; // when x <= 3 then 3.1 - let when_clauses = vec![WhenClause::new( - new_binary_expr( - Type::LessThanOrEqual, + let when_clauses = vec![WhenClause { + when: build( + PbType::LessThanOrEqual, DataType::Boolean, - Box::new(InputRefExpression::new(DataType::Int32, 0)), - Box::new(LiteralExpression::new(DataType::Float32, Some(3f32.into()))), + vec![ + Box::new(InputRefExpression::new(DataType::Int32, 0)), + Box::new(LiteralExpression::new(DataType::Float32, Some(3f32.into()))), + ], ) .unwrap(), - Box::new(LiteralExpression::new( + then: Box::new(LiteralExpression::new( DataType::Float32, Some(3.1f32.into()), )), - )]; + }]; let searched_case_expr = CaseExpression::new(ret_type, when_clauses, None); let row_inputs = vec![2, 3, 4, 5]; diff --git a/src/expr/src/expr/expr_is_null.rs b/src/expr/src/expr/expr_is_null.rs index 95c91920f9f74..3d6bf5081f4db 100644 --- a/src/expr/src/expr/expr_is_null.rs +++ b/src/expr/src/expr/expr_is_null.rs @@ -18,6 +18,7 @@ use risingwave_common::array::{ArrayImpl, ArrayRef, BoolArray, DataChunk}; use risingwave_common::buffer::Bitmap; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum, Scalar}; +use risingwave_expr_macro::build_function; use crate::expr::{BoxedExpression, Expression}; use crate::Result; @@ -33,13 +34,13 @@ pub struct IsNotNullExpression { } impl IsNullExpression { - pub(crate) fn new(child: BoxedExpression) -> Self { + fn new(child: BoxedExpression) -> Self { Self { child } } } impl IsNotNullExpression { - pub(crate) fn new(child: BoxedExpression) -> Self { + fn new(child: BoxedExpression) -> Self { Self { child } } } @@ -88,6 +89,20 @@ impl Expression for IsNotNullExpression { } } +#[build_function("is_null(*) -> boolean")] +fn build_is_null_expr(_: DataType, children: Vec) -> Result { + Ok(Box::new(IsNullExpression::new( + children.into_iter().next().unwrap(), + ))) +} + +#[build_function("is_not_null(*) -> boolean")] +fn build_is_not_null_expr(_: DataType, children: Vec) -> Result { + Ok(Box::new(IsNotNullExpression::new( + children.into_iter().next().unwrap(), + ))) +} + #[cfg(test)] mod tests { use std::str::FromStr; diff --git a/src/expr/src/expr/expr_jsonb_access.rs b/src/expr/src/expr/expr_jsonb_access.rs index 8445976c74591..e8f159ad0070a 100644 --- a/src/expr/src/expr/expr_jsonb_access.rs +++ b/src/expr/src/expr/expr_jsonb_access.rs @@ -14,14 +14,16 @@ use either::Either; use risingwave_common::array::{ - Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, JsonbArray, JsonbArrayBuilder, JsonbRef, - Utf8ArrayBuilder, + Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, I32Array, JsonbArray, JsonbArrayBuilder, + JsonbRef, Utf8Array, Utf8ArrayBuilder, }; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum, Scalar, ScalarRef}; use risingwave_common::util::iter_util::ZipEqFast; +use risingwave_expr_macro::build_function; use super::{BoxedExpression, Expression}; +use crate::Result; /// This is forked from [`BinaryExpression`] for the following reasons: /// * Optimize for the case when rhs path is const. (not implemented yet) @@ -221,3 +223,149 @@ impl AccessOutput for Utf8ArrayBuilder { } } } + +#[build_function("jsonb_access_inner(jsonb, varchar) -> jsonb")] +fn build_jsonb_access_object_field( + _return_type: DataType, + children: Vec, +) -> Result { + let mut iter = children.into_iter(); + let l = iter.next().unwrap(); + let r = iter.next().unwrap(); + Ok( + JsonbAccessExpression::::new_expr( + l, + r, + jsonb_object_field, + ) + .boxed(), + ) +} + +#[build_function("jsonb_access_inner(jsonb, int32) -> jsonb")] +fn build_jsonb_access_array_element( + _return_type: DataType, + children: Vec, +) -> Result { + let mut iter = children.into_iter(); + let l = iter.next().unwrap(); + let r = iter.next().unwrap(); + Ok( + JsonbAccessExpression::::new_expr( + l, + r, + jsonb_array_element, + ) + .boxed(), + ) +} + +#[build_function("jsonb_access_str(jsonb, varchar) -> varchar")] +fn build_jsonb_access_object_field_str( + _return_type: DataType, + children: Vec, +) -> Result { + let mut iter = children.into_iter(); + let l = iter.next().unwrap(); + let r = iter.next().unwrap(); + Ok( + JsonbAccessExpression::::new_expr(l, r, jsonb_object_field) + .boxed(), + ) +} + +#[build_function("jsonb_access_str(jsonb, int32) -> varchar")] +fn build_jsonb_access_array_element_str( + _return_type: DataType, + children: Vec, +) -> Result { + let mut iter = children.into_iter(); + let l = iter.next().unwrap(); + let r = iter.next().unwrap(); + Ok( + JsonbAccessExpression::::new_expr(l, r, jsonb_array_element) + .boxed(), + ) +} + +#[cfg(test)] +mod tests { + use std::vec; + + use risingwave_common::array::{ArrayImpl, DataChunk, Utf8Array}; + use risingwave_common::types::Scalar; + use risingwave_common::util::value_encoding::serialize_datum; + use risingwave_pb::data::data_type::TypeName; + use risingwave_pb::data::{DataType as ProstDataType, Datum as ProstDatum}; + use risingwave_pb::expr::expr_node::{RexNode, Type}; + use risingwave_pb::expr::{ExprNode, FunctionCall}; + + use crate::expr::build_from_prost; + + #[tokio::test] + async fn test_array_access_expr() { + let values = FunctionCall { + children: vec![ + ExprNode { + expr_type: Type::ConstantValue as i32, + return_type: Some(ProstDataType { + type_name: TypeName::Varchar as i32, + ..Default::default() + }), + rex_node: Some(RexNode::Constant(ProstDatum { + body: serialize_datum(Some("foo".into()).as_ref()), + })), + }, + ExprNode { + expr_type: Type::ConstantValue as i32, + return_type: Some(ProstDataType { + type_name: TypeName::Varchar as i32, + ..Default::default() + }), + rex_node: Some(RexNode::Constant(ProstDatum { + body: serialize_datum(Some("bar".into()).as_ref()), + })), + }, + ], + }; + let array_index = FunctionCall { + children: vec![ + ExprNode { + expr_type: Type::Array as i32, + return_type: Some(ProstDataType { + type_name: TypeName::List as i32, + field_type: vec![ProstDataType { + type_name: TypeName::Varchar as i32, + ..Default::default() + }], + ..Default::default() + }), + rex_node: Some(RexNode::FuncCall(values)), + }, + ExprNode { + expr_type: Type::ConstantValue as i32, + return_type: Some(ProstDataType { + type_name: TypeName::Int32 as i32, + ..Default::default() + }), + rex_node: Some(RexNode::Constant(ProstDatum { + body: serialize_datum(Some(1_i32.to_scalar_value()).as_ref()), + })), + }, + ], + }; + let access = ExprNode { + expr_type: Type::ArrayAccess as i32, + return_type: Some(ProstDataType { + type_name: TypeName::Varchar as i32, + ..Default::default() + }), + rex_node: Some(RexNode::FuncCall(array_index)), + }; + let expr = build_from_prost(&access); + assert!(expr.is_ok()); + + let res = expr.unwrap().eval(&DataChunk::new_dummy(1)).await.unwrap(); + assert_eq!(*res, ArrayImpl::Utf8(Utf8Array::from_iter(["foo"]))); + } +} diff --git a/src/expr/src/expr/expr_literal.rs b/src/expr/src/expr/expr_literal.rs index 55da473a73098..c1b3bb1beebb4 100644 --- a/src/expr/src/expr/expr_literal.rs +++ b/src/expr/src/expr/expr_literal.rs @@ -71,6 +71,10 @@ impl Expression for LiteralExpression { async fn eval_row(&self, _input: &OwnedRow) -> Result { Ok(self.literal.as_ref().cloned()) } + + fn eval_const(&self) -> Result { + Ok(self.literal.clone()) + } } impl LiteralExpression { diff --git a/src/expr/src/expr/expr_now.rs b/src/expr/src/expr/expr_now.rs new file mode 100644 index 0000000000000..5791083787a72 --- /dev/null +++ b/src/expr/src/expr/expr_now.rs @@ -0,0 +1,24 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use risingwave_common::types::DataType; +use risingwave_expr_macro::build_function; + +use super::{BoxedExpression, Result}; + +#[build_function("now(timestamptz) -> timestamptz")] +fn build_now_expr(_: DataType, children: Vec) -> Result { + // there should be exact 1 child containing a timestamp literal + Ok(children.into_iter().next().unwrap()) +} diff --git a/src/expr/src/expr/expr_quaternary_bytes.rs b/src/expr/src/expr/expr_quaternary_bytes.rs deleted file mode 100644 index b1a2c097a54f2..0000000000000 --- a/src/expr/src/expr/expr_quaternary_bytes.rs +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! For expression that accept 4 arguments + 1 bytes writer as input. - -use risingwave_common::array::{I32Array, Utf8Array}; -use risingwave_common::types::DataType; - -use crate::expr::template::QuaternaryBytesExpression; -use crate::expr::BoxedExpression; -use crate::vector_op::overlay::overlay_for; - -pub fn new_overlay_for_exp( - s: BoxedExpression, - new_sub_str: BoxedExpression, - start: BoxedExpression, - count: BoxedExpression, - return_type: DataType, -) -> BoxedExpression { - Box::new(QuaternaryBytesExpression::< - Utf8Array, - Utf8Array, - I32Array, - I32Array, - _, - >::new( - s, new_sub_str, start, count, return_type, overlay_for - )) -} - -#[cfg(test)] -mod tests { - use risingwave_common::array::DataChunk; - use risingwave_common::row::OwnedRow; - use risingwave_common::types::{Datum, ScalarImpl}; - - use super::*; - use crate::expr::LiteralExpression; - - async fn test_evals_dummy(expr: BoxedExpression, expected: Datum, is_negative_len: bool) { - let res = expr.eval(&DataChunk::new_dummy(1)).await; - if is_negative_len { - assert!(res.is_err()); - } else { - assert_eq!(res.unwrap().to_datum(), expected); - } - - let res = expr.eval_row(&OwnedRow::new(vec![])).await; - if is_negative_len { - assert!(res.is_err()); - } else { - assert_eq!(res.unwrap(), expected); - } - } - - #[tokio::test] - async fn test_overlay() { - let cases = vec![ - ("aaa", "XY", 1, 0, "XYaaa"), - ("aaa_aaa", "XYZ", 4, 1, "aaaXYZaaa"), - ("aaaaaa", "XYZ", 4, 0, "aaaXYZaaa"), - ("aaa___aaa", "X", 4, 3, "aaaXaaa"), - ("aaa", "X", 4, -123, "aaaX"), - ("aaa_", "X", 4, 123, "aaaX"), - ]; - - for (s, new_sub_str, start, count, expected) in cases { - let expr = new_overlay_for_exp( - Box::new(LiteralExpression::new( - DataType::Varchar, - Some(ScalarImpl::from(String::from(s))), - )), - Box::new(LiteralExpression::new( - DataType::Varchar, - Some(ScalarImpl::from(String::from(new_sub_str))), - )), - Box::new(LiteralExpression::new( - DataType::Int32, - Some(ScalarImpl::from(start)), - )), - Box::new(LiteralExpression::new( - DataType::Int32, - Some(ScalarImpl::from(count)), - )), - DataType::Varchar, - ); - - test_evals_dummy(expr, Some(ScalarImpl::from(String::from(expected))), false).await; - } - } -} diff --git a/src/expr/src/expr/expr_some_all.rs b/src/expr/src/expr/expr_some_all.rs index 749ce3951f386..b36d16f1a881b 100644 --- a/src/expr/src/expr/expr_some_all.rs +++ b/src/expr/src/expr/expr_some_all.rs @@ -19,10 +19,13 @@ use risingwave_common::array::{Array, ArrayMeta, ArrayRef, BoolArray, DataChunk} use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum, Scalar, ScalarImpl, ScalarRefImpl}; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_pb::expr::expr_node::Type; +use risingwave_common::{bail, ensure}; +use risingwave_pb::expr::expr_node::{RexNode, Type}; +use risingwave_pb::expr::{ExprNode, FunctionCall}; -use super::{BoxedExpression, Expression}; -use crate::Result; +use super::build_expr_from_prost::get_children_and_return_type; +use super::{build_from_prost, BoxedExpression, Expression}; +use crate::{ExprError, Result}; #[derive(Debug)] pub struct SomeAllExpression { @@ -194,3 +197,67 @@ impl Expression for SomeAllExpression { } } } + +impl<'a> TryFrom<&'a ExprNode> for SomeAllExpression { + type Error = ExprError; + + fn try_from(prost: &'a ExprNode) -> Result { + let outer_expr_type = prost.get_expr_type().unwrap(); + let (outer_children, outer_return_type) = get_children_and_return_type(prost)?; + ensure!(matches!(outer_return_type, DataType::Boolean)); + + let mut inner_expr_type = outer_children[0].get_expr_type().unwrap(); + let (mut inner_children, mut inner_return_type) = + get_children_and_return_type(&outer_children[0])?; + let mut stack = vec![]; + while inner_children.len() != 2 { + stack.push((inner_expr_type, inner_return_type)); + inner_expr_type = inner_children[0].get_expr_type().unwrap(); + (inner_children, inner_return_type) = get_children_and_return_type(&inner_children[0])?; + } + + let left_expr = build_from_prost(&inner_children[0])?; + let right_expr = build_from_prost(&inner_children[1])?; + + let DataType::List { datatype: right_expr_return_type } = right_expr.return_type() else { + bail!("Expect Array Type"); + }; + + let eval_func = { + let left_expr_input_ref = ExprNode { + expr_type: Type::InputRef as i32, + return_type: Some(left_expr.return_type().to_protobuf()), + rex_node: Some(RexNode::InputRef(0)), + }; + let right_expr_input_ref = ExprNode { + expr_type: Type::InputRef as i32, + return_type: Some(right_expr_return_type.to_protobuf()), + rex_node: Some(RexNode::InputRef(1)), + }; + let mut root_expr_node = ExprNode { + expr_type: inner_expr_type as i32, + return_type: Some(inner_return_type.to_protobuf()), + rex_node: Some(RexNode::FuncCall(FunctionCall { + children: vec![left_expr_input_ref, right_expr_input_ref], + })), + }; + while let Some((expr_type, return_type)) = stack.pop() { + root_expr_node = ExprNode { + expr_type: expr_type as i32, + return_type: Some(return_type.to_protobuf()), + rex_node: Some(RexNode::FuncCall(FunctionCall { + children: vec![root_expr_node], + })), + } + } + build_from_prost(&root_expr_node)? + }; + + Ok(SomeAllExpression::new( + left_expr, + right_expr, + outer_expr_type, + eval_func, + )) + } +} diff --git a/src/expr/src/expr/expr_ternary.rs b/src/expr/src/expr/expr_ternary.rs deleted file mode 100644 index e249a5ee73c1b..0000000000000 --- a/src/expr/src/expr/expr_ternary.rs +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use risingwave_common::array::{I64Array, IntervalArray, NaiveDateArray, NaiveDateTimeArray}; -use risingwave_common::types::DataType; - -use super::template::TernaryExpression; -use super::BoxedExpression; -use crate::vector_op::tumble::{ - tumble_start_offset_date, tumble_start_offset_date_time, tumble_start_offset_timestamptz, -}; -use crate::{ExprError, Result}; - -pub(crate) fn new_tumble_start_offset( - time: BoxedExpression, - window_size: BoxedExpression, - offset: BoxedExpression, - return_type: DataType, -) -> Result { - let expr: BoxedExpression = match time.return_type() { - DataType::Date => Box::new(TernaryExpression::< - NaiveDateArray, - IntervalArray, - IntervalArray, - NaiveDateTimeArray, - _, - >::new( - time, - window_size, - offset, - return_type, - tumble_start_offset_date, - )), - DataType::Timestamp => Box::new(TernaryExpression::< - NaiveDateTimeArray, - IntervalArray, - IntervalArray, - NaiveDateTimeArray, - _, - >::new( - time, - window_size, - offset, - return_type, - tumble_start_offset_date_time, - )), - DataType::Timestamptz => Box::new(TernaryExpression::< - I64Array, - IntervalArray, - IntervalArray, - I64Array, - _, - >::new( - time, - window_size, - offset, - return_type, - tumble_start_offset_timestamptz, - )), - _ => { - return Err(ExprError::UnsupportedFunction(format!( - "tumble_start_offset is not supported for {:?}", - time.return_type() - ))) - } - }; - - Ok(expr) -} diff --git a/src/expr/src/expr/expr_ternary_bytes.rs b/src/expr/src/expr/expr_ternary_bytes.rs deleted file mode 100644 index 5e0a63cceef89..0000000000000 --- a/src/expr/src/expr/expr_ternary_bytes.rs +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! For expression that accept 3 arguments + 1 bytes writer as input. - -use risingwave_common::array::{I32Array, Utf8Array}; -use risingwave_common::types::DataType; - -use crate::expr::template::TernaryBytesExpression; -use crate::expr::BoxedExpression; -use crate::vector_op::overlay::overlay; -use crate::vector_op::replace::replace; -use crate::vector_op::split_part::split_part; -use crate::vector_op::substr::substr_start_for; -use crate::vector_op::translate::translate; - -pub fn new_substr_start_end( - items: BoxedExpression, - off: BoxedExpression, - len: BoxedExpression, - return_type: DataType, -) -> BoxedExpression { - Box::new( - TernaryBytesExpression::::new( - items, - off, - len, - return_type, - substr_start_for, - ), - ) -} - -pub fn new_replace_expr( - s: BoxedExpression, - from_str: BoxedExpression, - to_str: BoxedExpression, - return_type: DataType, -) -> BoxedExpression { - Box::new( - TernaryBytesExpression::::new( - s, - from_str, - to_str, - return_type, - replace, - ), - ) -} - -pub fn new_translate_expr( - s: BoxedExpression, - match_str: BoxedExpression, - replace_str: BoxedExpression, - return_type: DataType, -) -> BoxedExpression { - Box::new( - TernaryBytesExpression::::new( - s, - match_str, - replace_str, - return_type, - translate, - ), - ) -} - -pub fn new_split_part_expr( - string_expr: BoxedExpression, - delimiter_expr: BoxedExpression, - nth_expr: BoxedExpression, - return_type: DataType, -) -> BoxedExpression { - Box::new( - TernaryBytesExpression::::new( - string_expr, - delimiter_expr, - nth_expr, - return_type, - split_part, - ), - ) -} - -pub fn new_overlay_exp( - s: BoxedExpression, - new_sub_str: BoxedExpression, - start: BoxedExpression, - return_type: DataType, -) -> BoxedExpression { - Box::new( - TernaryBytesExpression::::new( - s, - new_sub_str, - start, - return_type, - overlay, - ), - ) -} - -#[cfg(test)] -mod tests { - use risingwave_common::array::DataChunk; - use risingwave_common::row::OwnedRow; - use risingwave_common::types::{Datum, ScalarImpl}; - - use super::*; - use crate::expr::LiteralExpression; - - async fn test_evals_dummy(expr: BoxedExpression, expected: Datum, is_negative_len: bool) { - let res = expr.eval(&DataChunk::new_dummy(1)).await; - if is_negative_len { - assert!(res.is_err()); - } else { - assert_eq!(res.unwrap().to_datum(), expected); - } - - let res = expr.eval_row(&OwnedRow::new(vec![])).await; - if is_negative_len { - assert!(res.is_err()); - } else { - assert_eq!(res.unwrap(), expected); - } - } - - #[tokio::test] - async fn test_substr_start_end() { - let text = "quick brown"; - let cases = [ - ( - Some(ScalarImpl::Int32(4)), - Some(ScalarImpl::Int32(2)), - Some(ScalarImpl::from(String::from("ck"))), - ), - ( - Some(ScalarImpl::Int32(-1)), - Some(ScalarImpl::Int32(5)), - Some(ScalarImpl::from(String::from("qui"))), - ), - ( - Some(ScalarImpl::Int32(0)), - Some(ScalarImpl::Int32(20)), - Some(ScalarImpl::from(String::from("quick brown"))), - ), - ( - Some(ScalarImpl::Int32(12)), - Some(ScalarImpl::Int32(20)), - Some(ScalarImpl::from(String::from(""))), - ), - ( - Some(ScalarImpl::Int32(5)), - Some(ScalarImpl::Int32(0)), - Some(ScalarImpl::from(String::from(""))), - ), - ( - Some(ScalarImpl::Int32(5)), - Some(ScalarImpl::Int32(-1)), - Some(ScalarImpl::from(String::from(""))), - ), - (Some(ScalarImpl::Int32(12)), None, None), - (None, Some(ScalarImpl::Int32(20)), None), - (None, None, None), - ]; - - for (start, len, expected) in cases { - let is_negative_len = matches!(len, Some(ScalarImpl::Int32(len_i32)) if len_i32 < 0); - let expr = new_substr_start_end( - Box::new(LiteralExpression::new( - DataType::Varchar, - Some(ScalarImpl::from(String::from(text))), - )), - Box::new(LiteralExpression::new(DataType::Int32, start)), - Box::new(LiteralExpression::new(DataType::Int32, len)), - DataType::Varchar, - ); - - test_evals_dummy(expr, expected, is_negative_len).await; - } - } - - #[tokio::test] - async fn test_replace() { - let cases = [ - ("hello, word", "我的", "world", "hello, word"), - ("hello, word", "", "world", "hello, word"), - ("hello, word", "word", "world", "hello, world"), - ("hello, world", "world", "", "hello, "), - ("你是❤️,是暖,是希望", "是", "非", "你非❤️,非暖,非希望"), - ("👴笑了", "👴", "爷爷", "爷爷笑了"), - ( - "НОЧЬ НА ОЧКРАИНЕ МОСКВЫ", - "ОЧ", - "Ы", - "НЫЬ НА ЫКРАИНЕ МОСКВЫ", - ), - ]; - - for (text, pattern, replacement, expected) in cases { - let expr = new_replace_expr( - Box::new(LiteralExpression::new( - DataType::Varchar, - Some(ScalarImpl::from(String::from(text))), - )), - Box::new(LiteralExpression::new( - DataType::Varchar, - Some(ScalarImpl::from(String::from(pattern))), - )), - Box::new(LiteralExpression::new( - DataType::Varchar, - Some(ScalarImpl::from(String::from(replacement))), - )), - DataType::Varchar, - ); - - test_evals_dummy(expr, Some(ScalarImpl::from(String::from(expected))), false).await; - } - } - - #[tokio::test] - async fn test_overlay() { - let cases = vec![ - ("aaa__aaa", "XY", 4, "aaaXYaaa"), - ("aaa", "XY", 3, "aaXY"), - ("aaa", "XY", 4, "aaaXY"), - ("aaa", "XY", -123, "XYa"), - ("aaa", "XY", 123, "aaaXY"), - ]; - - for (s, new_sub_str, start, expected) in cases { - let expr = new_overlay_exp( - Box::new(LiteralExpression::new( - DataType::Varchar, - Some(ScalarImpl::from(String::from(s))), - )), - Box::new(LiteralExpression::new( - DataType::Varchar, - Some(ScalarImpl::from(String::from(new_sub_str))), - )), - Box::new(LiteralExpression::new( - DataType::Int32, - Some(ScalarImpl::from(start)), - )), - DataType::Varchar, - ); - - test_evals_dummy(expr, Some(ScalarImpl::from(String::from(expected))), false).await; - } - } -} diff --git a/src/expr/src/expr/expr_to_char_const_tmpl.rs b/src/expr/src/expr/expr_to_char_const_tmpl.rs index 4544ee9947e00..3e77fa1d784b3 100644 --- a/src/expr/src/expr/expr_to_char_const_tmpl.rs +++ b/src/expr/src/expr/expr_to_char_const_tmpl.rs @@ -19,19 +19,21 @@ use risingwave_common::array::{Array, ArrayBuilder, NaiveDateTimeArray, Utf8Arra use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum, ScalarImpl}; use risingwave_common::util::iter_util::ZipEqFast; +use risingwave_expr_macro::build_function; -use super::Expression; -use crate::vector_op::to_char::ChronoPattern; +use super::{BoxedExpression, Expression, Result}; +use crate::expr::template::BinaryBytesExpression; +use crate::vector_op::to_char::{compile_pattern_to_chrono, to_char_timestamp, ChronoPattern}; #[derive(Debug)] -pub(crate) struct ExprToCharConstTmplContext { - pub(crate) chrono_pattern: ChronoPattern, +struct ExprToCharConstTmplContext { + chrono_pattern: ChronoPattern, } #[derive(Debug)] -pub(crate) struct ExprToCharConstTmpl { - pub(crate) child: Box, - pub(crate) ctx: ExprToCharConstTmplContext, +struct ExprToCharConstTmpl { + child: Box, + ctx: ExprToCharConstTmplContext, } #[async_trait::async_trait] @@ -79,3 +81,34 @@ impl Expression for ExprToCharConstTmpl { }) } } + +#[build_function("to_char(timestamp, varchar) -> varchar")] +fn build_to_char_expr( + return_type: DataType, + children: Vec, +) -> Result { + use risingwave_common::array::*; + + let mut iter = children.into_iter(); + let data_expr = iter.next().unwrap(); + let tmpl_expr = iter.next().unwrap(); + + Ok(if let Ok(Some(tmpl)) = tmpl_expr.eval_const() { + ExprToCharConstTmpl { + ctx: ExprToCharConstTmplContext { + chrono_pattern: compile_pattern_to_chrono(tmpl.as_utf8()), + }, + child: data_expr, + } + .boxed() + } else { + BinaryBytesExpression::::new( + data_expr, + tmpl_expr, + return_type, + #[allow(clippy::unit_arg)] + |a, b, w| Ok(to_char_timestamp(a, b, w)), + ) + .boxed() + }) +} diff --git a/src/expr/src/expr/expr_to_timestamp_const_tmpl.rs b/src/expr/src/expr/expr_to_timestamp_const_tmpl.rs index c5223437904ad..3b7f9f4f82a1c 100644 --- a/src/expr/src/expr/expr_to_timestamp_const_tmpl.rs +++ b/src/expr/src/expr/expr_to_timestamp_const_tmpl.rs @@ -18,20 +18,22 @@ use risingwave_common::array::{Array, ArrayBuilder, NaiveDateTimeArrayBuilder, U use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum, ScalarImpl}; use risingwave_common::util::iter_util::ZipEqFast; +use risingwave_expr_macro::build_function; -use super::Expression; -use crate::vector_op::to_char::ChronoPattern; -use crate::vector_op::to_timestamp::to_timestamp_const_tmpl; +use super::{BoxedExpression, Expression, Result}; +use crate::expr::template::BinaryExpression; +use crate::vector_op::to_char::{compile_pattern_to_chrono, ChronoPattern}; +use crate::vector_op::to_timestamp::{to_timestamp, to_timestamp_const_tmpl}; #[derive(Debug)] -pub(crate) struct ExprToTimestampConstTmplContext { - pub(crate) chrono_pattern: ChronoPattern, +struct ExprToTimestampConstTmplContext { + chrono_pattern: ChronoPattern, } #[derive(Debug)] -pub(crate) struct ExprToTimestampConstTmpl { - pub(crate) child: Box, - pub(crate) ctx: ExprToTimestampConstTmplContext, +struct ExprToTimestampConstTmpl { + child: Box, + ctx: ExprToTimestampConstTmplContext, } #[async_trait::async_trait] @@ -71,3 +73,33 @@ impl Expression for ExprToTimestampConstTmpl { }) } } + +#[build_function("to_timestamp1(varchar, varchar) -> timestamp")] +fn build_to_timestamp_expr( + return_type: DataType, + children: Vec, +) -> Result { + use risingwave_common::array::*; + + let mut iter = children.into_iter(); + let data_expr = iter.next().unwrap(); + let tmpl_expr = iter.next().unwrap(); + + Ok(if let Ok(Some(tmpl)) = tmpl_expr.eval_const() { + ExprToTimestampConstTmpl { + ctx: ExprToTimestampConstTmplContext { + chrono_pattern: compile_pattern_to_chrono(tmpl.as_utf8()), + }, + child: data_expr, + } + .boxed() + } else { + BinaryExpression::::new( + data_expr, + tmpl_expr, + return_type, + to_timestamp, + ) + .boxed() + }) +} diff --git a/src/expr/src/expr/expr_unary.rs b/src/expr/src/expr/expr_unary.rs index bacf418b5fb40..f5d77c083751e 100644 --- a/src/expr/src/expr/expr_unary.rs +++ b/src/expr/src/expr/expr_unary.rs @@ -14,373 +14,20 @@ //! For expression that only accept one value as input (e.g. CAST) -use risingwave_common::array::*; -use risingwave_common::buffer::Bitmap; -use risingwave_common::types::*; -use risingwave_pb::expr::expr_node::PbType; - -use super::expr_is_null::{IsNotNullExpression, IsNullExpression}; -use super::template::{UnaryBytesExpression, UnaryExpression}; -use super::template_fast::BooleanUnaryExpression; -use super::{template_fast, BoxedExpression, Expression}; -use crate::vector_op::arithmetic_op::{decimal_abs, general_abs, general_neg}; -use crate::vector_op::ascii::ascii; -use crate::vector_op::bitwise_op::general_bitnot; -use crate::vector_op::cast::*; -use crate::vector_op::cmp::{is_false, is_not_false, is_not_true, is_true}; -use crate::vector_op::conjunction; -use crate::vector_op::exp::exp_f64; -use crate::vector_op::jsonb_info::{jsonb_array_length, jsonb_typeof}; -use crate::vector_op::length::{bit_length, length_default, octet_length}; -use crate::vector_op::lower::lower; -use crate::vector_op::ltrim::ltrim; -use crate::vector_op::md5::md5; -use crate::vector_op::round::*; -use crate::vector_op::rtrim::rtrim; -use crate::vector_op::timestamptz::f64_sec_to_timestamptz; -use crate::vector_op::trim::trim; -use crate::vector_op::upper::upper; -use crate::{for_all_cast_variants, ExprError, Result}; - -/// This macro helps to create unary expression. -/// In [], the parameters are for constructing new expression -/// * $`expr_name`: expression name, used for print error message -/// * $child: child expression -/// * $ret: return array type -/// In ()*, the parameters are for generating match cases -/// * $input: child array type -/// * $rt: The return type in that the operation will calculate -/// * $func: The scalar function for expression -macro_rules! gen_unary_impl { - ([$expr_name: literal, $child:expr, $ret:expr], $( { $input:ident, $rt: ident, $func:ident },)*) => { - match ($child.return_type()) { - $( - $input! { type_match_pattern } => Box::new( - UnaryExpression::<$input! { type_array}, $rt! {type_array}, _>::new( - $child, - $ret.clone(), - $func, - ) - ), - )* - _ => { - return Err(ExprError::UnsupportedFunction(format!("{}({:?}) -> {:?}", $expr_name, $child.return_type(), $ret))); - } - } - }; -} - -macro_rules! gen_unary_impl_fast { - ([$expr_name: literal, $child:expr, $ret:expr], $( { $input:ident, $rt: ident, $func:expr },)*) => { - match ($child.return_type()) { - $( - $input! { type_match_pattern } => template_fast::UnaryExpression::new($child, $ret, $func).boxed(), - )* - _ => { - return Err(ExprError::UnsupportedFunction(format!("{}({:?}) -> {:?}", $expr_name, $child.return_type(), $ret))); - } - } - }; -} - -macro_rules! gen_unary_atm_expr { - ( - $expr_name: literal, - $child:expr, - $ret:expr, - $general_func:ident, - { - $( { $input:ident, $rt:ident, $func:ident }, )* - } $(,)? - ) => { - gen_unary_impl! { - [$expr_name, $child, $ret], - { int16, int16, $general_func }, - { int32, int32, $general_func }, - { int64, int64, $general_func }, - { float32, float32, $general_func }, - { float64, float64, $general_func }, - $( - { $input, $rt, $func }, - )* - } - }; -} - -macro_rules! gen_round_expr { - ( - $expr_name:literal, - $child:expr, - $ret:expr, - $float64_round_func:ident, - $decimal_round_func:ident - ) => { - gen_unary_impl_fast! { - [$expr_name, $child, $ret], - { float64, float64, $float64_round_func }, - { decimal, decimal, $decimal_round_func }, - } - }; -} - -/// Create a new unary expression. -pub fn new_unary_expr( - expr_type: PbType, - return_type: DataType, - child_expr: BoxedExpression, -) -> Result { - use crate::expr::data_types::*; - - let expr: BoxedExpression = match (expr_type, return_type.clone(), child_expr.return_type()) { - ( - PbType::Cast, - DataType::List { - datatype: target_elem_type, - }, - DataType::Varchar, - ) => Box::new(UnaryExpression::::new( - child_expr, - return_type, - move |input| str_to_list(input, &target_elem_type), - )), - (PbType::Cast, DataType::Struct(rty), DataType::Struct(lty)) => { - Box::new(UnaryExpression::::new( - child_expr, - return_type, - move |input| struct_cast(input, <y, &rty), - )) - } - ( - PbType::Cast, - DataType::List { - datatype: target_elem_type, - }, - DataType::List { - datatype: source_elem_type, - }, - ) => Box::new(UnaryExpression::::new( - child_expr, - return_type, - move |input| list_cast(input, &source_elem_type, &target_elem_type), - )), - (PbType::Cast, _, _) => { - macro_rules! gen_cast_impl { - ($( { $input:ident, $cast:ident, $func:expr, $infallible:ident } ),*) => { - match (child_expr.return_type(), return_type.clone()) { - $( - ($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func, $infallible), - )* - _ => { - return Err(ExprError::UnsupportedCast(child_expr.return_type(), return_type)); - } - } - }; - (arm: $input:ident, varchar, $func:expr, false) => { - UnaryBytesExpression::< $input! { type_array }, _>::new( - child_expr, - return_type.clone(), - $func - ).boxed() - }; - (arm: $input:ident, $cast:ident, $func:expr, false) => { - UnaryExpression::< $input! { type_array }, $cast! { type_array }, _>::new( - child_expr, - return_type.clone(), - $func - ).boxed() - }; - (arm: $input:ident, $cast:ident, $func:expr, true) => { - template_fast::UnaryExpression::new( - child_expr, - return_type.clone(), - $func - ).boxed() - }; - } - - for_all_cast_variants! { gen_cast_impl } - } - (PbType::BoolOut, _, DataType::Boolean) => Box::new( - UnaryBytesExpression::::new(child_expr, return_type, bool_out), - ), - (PbType::Not, _, _) => Box::new(BooleanUnaryExpression::new( - child_expr, - |a| BoolArray::new(!a.data() & a.null_bitmap(), a.null_bitmap().clone()), - conjunction::not, - )), - (PbType::IsTrue, _, _) => Box::new(BooleanUnaryExpression::new( - child_expr, - |a| BoolArray::new(a.to_bitmap(), Bitmap::ones(a.len())), - is_true, - )), - (PbType::IsNotTrue, _, _) => Box::new(BooleanUnaryExpression::new( - child_expr, - |a| BoolArray::new(!a.to_bitmap(), Bitmap::ones(a.len())), - is_not_true, - )), - (PbType::IsFalse, _, _) => Box::new(BooleanUnaryExpression::new( - child_expr, - |a| BoolArray::new(!a.data() & a.null_bitmap(), Bitmap::ones(a.len())), - is_false, - )), - (PbType::IsNotFalse, _, _) => Box::new(BooleanUnaryExpression::new( - child_expr, - |a| BoolArray::new(a.data() | !a.null_bitmap(), Bitmap::ones(a.len())), - is_not_false, - )), - (PbType::IsNull, _, _) => Box::new(IsNullExpression::new(child_expr)), - (PbType::IsNotNull, _, _) => Box::new(IsNotNullExpression::new(child_expr)), - (PbType::Upper, _, _) => Box::new(UnaryBytesExpression::::new( - child_expr, - return_type, - upper, - )), - (PbType::Lower, _, _) => Box::new(UnaryBytesExpression::::new( - child_expr, - return_type, - lower, - )), - (PbType::Md5, _, _) => Box::new(UnaryBytesExpression::::new( - child_expr, - return_type, - md5, - )), - (PbType::Ascii, _, _) => Box::new(UnaryExpression::::new( - child_expr, - return_type, - ascii, - )), - (PbType::CharLength, _, _) => Box::new(UnaryExpression::::new( - child_expr, - return_type, - length_default, - )), - (PbType::OctetLength, _, _) => Box::new(UnaryExpression::::new( - child_expr, - return_type, - octet_length, - )), - (PbType::BitLength, _, _) => Box::new(UnaryExpression::::new( - child_expr, - return_type, - bit_length, - )), - (PbType::Neg, _, _) => { - gen_unary_atm_expr! { "Neg", child_expr, return_type, general_neg, - { - { decimal, decimal, general_neg }, - } - } - } - (PbType::Abs, _, _) => { - gen_unary_atm_expr! { "Abs", child_expr, return_type, general_abs, - { - {decimal, decimal, decimal_abs}, - } - } - } - (PbType::BitwiseNot, _, _) => { - gen_unary_impl_fast! { - [ "BitwiseNot", child_expr, return_type], - { int16, int16, general_bitnot:: }, - { int32, int32, general_bitnot:: }, - { int64, int64, general_bitnot:: }, - } - } - (PbType::Ceil, _, _) => { - gen_round_expr! {"Ceil", child_expr, return_type, ceil_f64, ceil_decimal} - } - (PbType::Floor, DataType::Float64, DataType::Float64) => { - gen_round_expr! {"Floor", child_expr, return_type, floor_f64, floor_decimal} - } - (PbType::Round, _, _) => { - gen_round_expr! {"Ceil", child_expr, return_type, round_f64, round_decimal} - } - (PbType::Exp, _, _) => Box::new(UnaryExpression::::new( - child_expr, - return_type, - exp_f64, - )), - (PbType::ToTimestamp, DataType::Timestamptz, DataType::Float64) => { - Box::new(UnaryExpression::::new( - child_expr, - return_type, - f64_sec_to_timestamptz, - )) - } - (PbType::JsonbTypeof, DataType::Varchar, DataType::Jsonb) => { - UnaryBytesExpression::::new(child_expr, return_type, jsonb_typeof) - .boxed() - } - (PbType::JsonbArrayLength, DataType::Int32, DataType::Jsonb) => { - UnaryExpression::::new( - child_expr, - return_type, - jsonb_array_length, - ) - .boxed() - } - (expr, ret, child) => { - return Err(ExprError::UnsupportedFunction(format!( - "{:?}({:?}) -> {:?}", - expr, child, ret - ))); - } - }; - - Ok(expr) -} - -pub fn new_length_default(expr_ia1: BoxedExpression, return_type: DataType) -> BoxedExpression { - Box::new(UnaryExpression::::new( - expr_ia1, - return_type, - length_default, - )) -} - -pub fn new_trim_expr(expr_ia1: BoxedExpression, return_type: DataType) -> BoxedExpression { - Box::new(UnaryBytesExpression::::new( - expr_ia1, - return_type, - trim, - )) -} - -pub fn new_ltrim_expr(expr_ia1: BoxedExpression, return_type: DataType) -> BoxedExpression { - Box::new(UnaryBytesExpression::::new( - expr_ia1, - return_type, - ltrim, - )) -} - -pub fn new_rtrim_expr(expr_ia1: BoxedExpression, return_type: DataType) -> BoxedExpression { - Box::new(UnaryBytesExpression::::new( - expr_ia1, - return_type, - rtrim, - )) -} - #[cfg(test)] mod tests { use itertools::Itertools; use risingwave_common::array::*; use risingwave_common::types::{NaiveDateWrapper, Scalar}; - use risingwave_pb::data::data_type::TypeName; - use risingwave_pb::data::DataType; - use risingwave_pb::expr::expr_node::{RexNode, Type}; - use risingwave_pb::expr::{ExprNode, FunctionCall}; + use risingwave_pb::expr::expr_node::PbType; use super::super::*; - use crate::expr::test_utils::{make_expression, make_input_ref}; use crate::vector_op::cast::{str_parse, try_cast}; #[tokio::test] async fn test_unary() { - test_unary_bool::(|x| !x, Type::Not).await; - test_unary_date::(|x| try_cast(x).unwrap(), Type::Cast).await; + test_unary_bool::(|x| !x, PbType::Not).await; + test_unary_date::(|x| try_cast(x).unwrap(), PbType::Cast).await; test_str_to_int16::(|x| str_parse(x).unwrap()).await; } @@ -399,20 +46,13 @@ mod tests { } let col1 = I16Array::from_iter(&input).into(); let data_chunk = DataChunk::new(vec![col1], 100); - let return_type = DataType { - type_name: TypeName::Int32 as i32, - is_nullable: false, - ..Default::default() - }; - let expr = ExprNode { - expr_type: Type::Cast as i32, - return_type: Some(return_type), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: vec![make_input_ref(0, TypeName::Int16)], - })), - }; - let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).await.unwrap(); + let expr = build( + PbType::Cast, + DataType::Int32, + vec![InputRefExpression::new(DataType::Int16, 0).boxed()], + ) + .unwrap(); + let res = expr.eval(&data_chunk).await.unwrap(); let arr: &I32Array = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -421,7 +61,7 @@ mod tests { for i in 0..input.len() { let row = OwnedRow::new(vec![input[i].map(|int| int.to_scalar_value())]); - let result = vec_executor.eval_row(&row).await.unwrap(); + let result = expr.eval_row(&row).await.unwrap(); let expected = target[i].map(|int| int.to_scalar_value()); assert_eq!(result, expected); } @@ -442,20 +82,13 @@ mod tests { let col1 = I32Array::from_iter(&input).into(); let data_chunk = DataChunk::new(vec![col1], 3); - let return_type = DataType { - type_name: TypeName::Int32 as i32, - is_nullable: false, - ..Default::default() - }; - let expr = ExprNode { - expr_type: Type::Neg as i32, - return_type: Some(return_type), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: vec![make_input_ref(0, TypeName::Int32)], - })), - }; - let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).await.unwrap(); + let expr = build( + PbType::Neg, + DataType::Int32, + vec![InputRefExpression::new(DataType::Int32, 0).boxed()], + ) + .unwrap(); + let res = expr.eval(&data_chunk).await.unwrap(); let arr: &I32Array = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -464,7 +97,7 @@ mod tests { for i in 0..input.len() { let row = OwnedRow::new(vec![input[i].map(|int| int.to_scalar_value())]); - let result = vec_executor.eval_row(&row).await.unwrap(); + let result = expr.eval_row(&row).await.unwrap(); let expected = target[i].map(|int| int.to_scalar_value()); assert_eq!(result, expected); } @@ -492,20 +125,13 @@ mod tests { let col1_data = &input.iter().map(|x| x.as_ref().map(|x| &**x)).collect_vec(); let col1 = Utf8Array::from_iter(col1_data).into(); let data_chunk = DataChunk::new(vec![col1], 1); - let return_type = DataType { - type_name: TypeName::Int16 as i32, - is_nullable: false, - ..Default::default() - }; - let expr = ExprNode { - expr_type: Type::Cast as i32, - return_type: Some(return_type), - rex_node: Some(RexNode::FuncCall(FunctionCall { - children: vec![make_input_ref(0, TypeName::Varchar)], - })), - }; - let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).await.unwrap(); + let expr = build( + PbType::Cast, + DataType::Int16, + vec![InputRefExpression::new(DataType::Varchar, 0).boxed()], + ) + .unwrap(); + let res = expr.eval(&data_chunk).await.unwrap(); let arr: &A = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -517,13 +143,13 @@ mod tests { .as_ref() .cloned() .map(|str| str.to_scalar_value())]); - let result = vec_executor.eval_row(&row).await.unwrap(); + let result = expr.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); assert_eq!(result, expected); } } - async fn test_unary_bool(f: F, kind: Type) + async fn test_unary_bool(f: F, kind: PbType) where A: Array, for<'a> &'a A: std::convert::From<&'a ArrayImpl>, @@ -547,9 +173,13 @@ mod tests { let col1 = BoolArray::from_iter(&input).into(); let data_chunk = DataChunk::new(vec![col1], 100); - let expr = make_expression(kind, &[TypeName::Boolean], &[0]); - let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).await.unwrap(); + let expr = build( + kind, + DataType::Boolean, + vec![InputRefExpression::new(DataType::Boolean, 0).boxed()], + ) + .unwrap(); + let res = expr.eval(&data_chunk).await.unwrap(); let arr: &A = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -558,13 +188,13 @@ mod tests { for i in 0..input.len() { let row = OwnedRow::new(vec![input[i].map(|b| b.to_scalar_value())]); - let result = vec_executor.eval_row(&row).await.unwrap(); + let result = expr.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); assert_eq!(result, expected); } } - async fn test_unary_date(f: F, kind: Type) + async fn test_unary_date(f: F, kind: PbType) where A: Array, for<'a> &'a A: std::convert::From<&'a ArrayImpl>, @@ -586,9 +216,13 @@ mod tests { let col1 = NaiveDateArray::from_iter(&input).into(); let data_chunk = DataChunk::new(vec![col1], 100); - let expr = make_expression(kind, &[TypeName::Date], &[0]); - let vec_executor = build_from_prost(&expr).unwrap(); - let res = vec_executor.eval(&data_chunk).await.unwrap(); + let expr = build( + kind, + DataType::Timestamp, + vec![InputRefExpression::new(DataType::Date, 0).boxed()], + ) + .unwrap(); + let res = expr.eval(&data_chunk).await.unwrap(); let arr: &A = res.as_ref().into(); for (idx, item) in arr.iter().enumerate() { let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); @@ -597,7 +231,7 @@ mod tests { for i in 0..input.len() { let row = OwnedRow::new(vec![input[i].map(|d| d.to_scalar_value())]); - let result = vec_executor.eval_row(&row).await.unwrap(); + let result = expr.eval_row(&row).await.unwrap(); let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); assert_eq!(result, expected); } diff --git a/src/expr/src/expr/mod.rs b/src/expr/src/expr/mod.rs index 7a63da4ffe105..d88008f07c6b2 100644 --- a/src/expr/src/expr/mod.rs +++ b/src/expr/src/expr/mod.rs @@ -36,7 +36,6 @@ mod expr_array_concat; mod expr_array_distinct; mod expr_array_length; mod expr_array_to_string; -mod expr_binary_bytes; mod expr_binary_nonnull; mod expr_binary_nullable; mod expr_case; @@ -49,11 +48,9 @@ mod expr_is_null; mod expr_jsonb_access; mod expr_literal; mod expr_nested_construct; -mod expr_quaternary_bytes; +mod expr_now; pub mod expr_regexp; mod expr_some_all; -mod expr_ternary; -mod expr_ternary_bytes; mod expr_to_char_const_tmpl; mod expr_to_timestamp_const_tmpl; mod expr_udf; @@ -63,8 +60,8 @@ mod expr_vnode; mod agg; mod build_expr_from_prost; pub(crate) mod data_types; -mod template; -mod template_fast; +pub(crate) mod template; +pub(crate) mod template_fast; pub mod test_utils; use std::sync::Arc; @@ -75,11 +72,9 @@ use risingwave_common::types::{DataType, Datum}; use static_assertions::const_assert; pub use self::agg::AggKind; -pub use self::build_expr_from_prost::build_from_prost; -pub use self::expr_binary_nonnull::new_binary_expr; +pub use self::build_expr_from_prost::{build, build_from_prost}; pub use self::expr_input_ref::InputRefExpression; pub use self::expr_literal::LiteralExpression; -pub use self::expr_unary::new_unary_expr; use super::{ExprError, Result}; /// Instance of an expression @@ -108,6 +103,11 @@ pub trait Expression: std::fmt::Debug + Sync + Send { /// Evaluate the expression in row-based execution. async fn eval_row(&self, input: &OwnedRow) -> Result; + /// Evaluate if the expression is constant. + fn eval_const(&self) -> Result { + Err(ExprError::NotConstant) + } + /// Wrap the expression in a Box. fn boxed(self) -> BoxedExpression where diff --git a/src/expr/src/expr/template.rs b/src/expr/src/expr/template.rs index 68d3a80027e47..3fe8c885f234c 100644 --- a/src/expr/src/expr/template.rs +++ b/src/expr/src/expr/template.rs @@ -319,6 +319,7 @@ macro_rules! gen_expr_nullable { > $ty_name<$($arg, )* OA, F> { // Compile failed due to some GAT lifetime issues so make this field private. // Check issues #742. + #[allow(dead_code)] pub fn new( $([]: BoxedExpression, )* return_type: DataType, @@ -345,61 +346,5 @@ gen_expr_bytes!(BinaryBytesExpression, { IA1, IA2 }); gen_expr_bytes!(TernaryBytesExpression, { IA1, IA2, IA3 }); gen_expr_bytes!(QuaternaryBytesExpression, { IA1, IA2, IA3, IA4 }); +gen_expr_nullable!(UnaryNullableExpression, { IA1 }); gen_expr_nullable!(BinaryNullableExpression, { IA1, IA2 }); - -/// `for_all_cmp_types` helps in matching and casting types when building comparison expressions -/// such as <= or IS DISTINCT FROM. -#[macro_export] -macro_rules! for_all_cmp_variants { - ($macro:ident, $l:expr, $r:expr, $ret:expr, $general_f:ident) => { - $macro! { - [$l, $r, $ret], - { int16, int16, int16, $general_f }, - { int16, int32, int32, $general_f }, - { int16, int64, int64, $general_f }, - { int16, float32, float64, $general_f }, - { int16, float64, float64, $general_f }, - { int32, int16, int32, $general_f }, - { int32, int32, int32, $general_f }, - { int32, int64, int64, $general_f }, - { int32, float32, float64, $general_f }, - { int32, float64, float64, $general_f }, - { int64, int16,int64, $general_f }, - { int64, int32,int64, $general_f }, - { int64, int64, int64, $general_f }, - { int64, float32, float64 , $general_f}, - { int64, float64, float64, $general_f }, - { float32, int16, float64, $general_f }, - { float32, int32, float64, $general_f }, - { float32, int64, float64 , $general_f}, - { float32, float32, float32, $general_f }, - { float32, float64, float64, $general_f }, - { float64, int16, float64, $general_f }, - { float64, int32, float64, $general_f }, - { float64, int64, float64, $general_f }, - { float64, float32, float64, $general_f }, - { float64, float64, float64, $general_f }, - { decimal, int16, decimal, $general_f }, - { decimal, int32, decimal, $general_f }, - { decimal, int64, decimal, $general_f }, - { decimal, float32, float64, $general_f }, - { decimal, float64, float64, $general_f }, - { int16, decimal, decimal, $general_f }, - { int32, decimal, decimal, $general_f }, - { int64, decimal, decimal, $general_f }, - { decimal, decimal, decimal, $general_f }, - { float32, decimal, float64, $general_f }, - { float64, decimal, float64, $general_f }, - { timestamptz, timestamptz, timestamptz, $general_f }, - { timestamp, timestamp, timestamp, $general_f }, - { interval, interval, interval, $general_f }, - { time, time, time, $general_f }, - { date, date, date, $general_f }, - { timestamp, date, timestamp, $general_f }, - { date, timestamp, timestamp, $general_f }, - { interval, time, interval, $general_f }, - { time, interval, interval, $general_f }, - { serial, serial, serial, $general_f }, - } - }; -} diff --git a/src/expr/src/expr/template_fast.rs b/src/expr/src/expr/template_fast.rs index 9091f351c4d74..7601f21920ca3 100644 --- a/src/expr/src/expr/template_fast.rs +++ b/src/expr/src/expr/template_fast.rs @@ -340,11 +340,11 @@ impl fmt::Debug for CompareExpression { impl CompareExpression where - F: Fn(A, B) -> bool + Send + Sync, - A: PrimitiveArrayItemType, - B: PrimitiveArrayItemType, - for<'a> &'a PrimitiveArray: From<&'a ArrayImpl>, - for<'a> &'a PrimitiveArray: From<&'a ArrayImpl>, + F: Fn(A::RefItem<'_>, B::RefItem<'_>) -> bool + Send + Sync, + A: Array, + B: Array, + for<'a> &'a A: std::convert::From<&'a ArrayImpl>, + for<'a> &'a B: std::convert::From<&'a ArrayImpl>, { pub fn new(left: BoxedExpression, right: BoxedExpression, func: F) -> Self { CompareExpression { @@ -359,11 +359,11 @@ where #[async_trait::async_trait] impl Expression for CompareExpression where - F: Fn(A, B) -> bool + Send + Sync, - A: PrimitiveArrayItemType, - B: PrimitiveArrayItemType, - for<'a> &'a PrimitiveArray: From<&'a ArrayImpl>, - for<'a> &'a PrimitiveArray: From<&'a ArrayImpl>, + F: Fn(A::RefItem<'_>, B::RefItem<'_>) -> bool + Send + Sync, + A: Array, + B: Array, + for<'a> &'a A: std::convert::From<&'a ArrayImpl>, + for<'a> &'a B: std::convert::From<&'a ArrayImpl>, { fn return_type(&self) -> DataType { DataType::Boolean @@ -380,8 +380,8 @@ where }; bitmap &= left.null_bitmap(); bitmap &= right.null_bitmap(); - let a: &PrimitiveArray = (&*left).into(); - let b: &PrimitiveArray = (&*right).into(); + let a: &A = (&*left).into(); + let b: &B = (&*right).into(); let c = BoolArray::new( a.raw_iter() .zip(b.raw_iter()) @@ -430,12 +430,13 @@ impl fmt::Debug for IsDistinctFromExpression { impl IsDistinctFromExpression where - F: Fn(A, B) -> bool + Send + Sync, - A: PrimitiveArrayItemType, - B: PrimitiveArrayItemType, - for<'a> &'a PrimitiveArray: From<&'a ArrayImpl>, - for<'a> &'a PrimitiveArray: From<&'a ArrayImpl>, + F: Fn(A::RefItem<'_>, B::RefItem<'_>) -> bool + Send + Sync, + A: Array, + B: Array, + for<'a> &'a A: std::convert::From<&'a ArrayImpl>, + for<'a> &'a B: std::convert::From<&'a ArrayImpl>, { + #[allow(dead_code)] pub fn new(left: BoxedExpression, right: BoxedExpression, ne: F, not: bool) -> Self { IsDistinctFromExpression { left, @@ -450,11 +451,11 @@ where #[async_trait::async_trait] impl Expression for IsDistinctFromExpression where - F: Fn(A, B) -> bool + Send + Sync, - A: PrimitiveArrayItemType, - B: PrimitiveArrayItemType, - for<'a> &'a PrimitiveArray: From<&'a ArrayImpl>, - for<'a> &'a PrimitiveArray: From<&'a ArrayImpl>, + F: Fn(A::RefItem<'_>, B::RefItem<'_>) -> bool + Send + Sync, + A: Array, + B: Array, + for<'a> &'a A: std::convert::From<&'a ArrayImpl>, + for<'a> &'a B: std::convert::From<&'a ArrayImpl>, { fn return_type(&self) -> DataType { DataType::Boolean @@ -465,8 +466,8 @@ where let right = self.right.eval_checked(data_chunk).await?; assert_eq!(left.len(), right.len()); - let a: &PrimitiveArray = (&*left).into(); - let b: &PrimitiveArray = (&*right).into(); + let a: &A = (&*left).into(); + let b: &B = (&*right).into(); let mut data: Bitmap = a .raw_iter() diff --git a/src/expr/src/expr/test_utils.rs b/src/expr/src/expr/test_utils.rs index 5b80387c359ac..767a86961b9d1 100644 --- a/src/expr/src/expr/test_utils.rs +++ b/src/expr/src/expr/test_utils.rs @@ -18,7 +18,6 @@ use std::num::NonZeroUsize; use num_traits::CheckedSub; use risingwave_common::types::{DataType, IntervalUnit, ScalarImpl}; -use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::util::value_encoding::serialize_datum; use risingwave_pb::data::data_type::TypeName; use risingwave_pb::data::{PbDataType, PbDatum}; @@ -26,26 +25,17 @@ use risingwave_pb::expr::expr_node::Type::{Field, InputRef}; use risingwave_pb::expr::expr_node::{self, RexNode, Type}; use risingwave_pb::expr::{ExprNode, FunctionCall}; -use super::expr_ternary::new_tumble_start_offset; -use super::{ - new_binary_expr, BoxedExpression, Expression, InputRefExpression, LiteralExpression, Result, -}; +use super::{build_from_prost, BoxedExpression, Result}; use crate::ExprError; -pub fn make_expression(kind: Type, rets: &[TypeName], indices: &[usize]) -> ExprNode { - let mut exprs = Vec::new(); - for (idx, ret) in indices.iter().zip_eq_fast(rets.iter()) { - exprs.push(make_input_ref(*idx, *ret)); - } - let function_call = FunctionCall { children: exprs }; - let return_type = PbDataType { - type_name: TypeName::Timestamp as i32, - ..Default::default() - }; +pub fn make_expression(kind: Type, ret: TypeName, children: Vec) -> ExprNode { ExprNode { expr_type: kind as i32, - return_type: Some(return_type), - rex_node: Some(RexNode::FuncCall(function_call)), + return_type: Some(PbDataType { + type_name: ret as i32, + ..Default::default() + }), + rex_node: Some(RexNode::FuncCall(FunctionCall { children })), } } @@ -73,15 +63,15 @@ pub fn make_i32_literal(data: i32) -> ExprNode { } } -pub fn make_string_literal(data: &str) -> ExprNode { +fn make_interval_literal(data: IntervalUnit) -> ExprNode { ExprNode { expr_type: Type::ConstantValue as i32, return_type: Some(PbDataType { - type_name: TypeName::Varchar as i32, + type_name: TypeName::Interval as i32, ..Default::default() }), rex_node: Some(RexNode::Constant(PbDatum { - body: serialize_datum(Some(ScalarImpl::Utf8(data.into())).as_ref()), + body: serialize_datum(Some(ScalarImpl::Interval(data)).as_ref()), })), } } @@ -116,52 +106,40 @@ pub fn make_hop_window_expression( })? .get(); - let output_type = DataType::window_of(&time_col_data_type).unwrap(); - let get_hop_window_start = || -> Result { - let time_col_ref = InputRefExpression::new(time_col_data_type, time_col_idx).boxed(); + let output_type = DataType::window_of(&time_col_data_type) + .unwrap() + .to_protobuf() + .type_name(); - let window_slide_expr = - LiteralExpression::new(DataType::Interval, Some(ScalarImpl::Interval(window_slide))) - .boxed(); - let window_offset_expr = LiteralExpression::new( - DataType::Interval, - Some(ScalarImpl::Interval(window_offset)), - ) - .boxed(); + let time_col_ref = make_input_ref(time_col_idx, time_col_data_type.to_protobuf().type_name()); - // The first window_start of hop window should be: - // tumble_start(`time_col` - (`window_size` - `window_slide`), `window_slide`, - // `window_offset`). Let's pre calculate (`window_size` - `window_slide`). - let window_size_sub_slide = - window_size - .checked_sub(&window_slide) - .ok_or_else(|| ExprError::InvalidParam { - name: "window", - reason: format!( - "window_size {} cannot be subtracted by window_slide {}", - window_size, window_slide - ), - })?; - let window_size_sub_slide_expr = LiteralExpression::new( - DataType::Interval, - Some(ScalarImpl::Interval(window_size_sub_slide)), - ) - .boxed(); - - let hop_start = new_tumble_start_offset( - new_binary_expr( + // The first window_start of hop window should be: + // tumble_start(`time_col` - (`window_size` - `window_slide`), `window_slide`, `window_offset`). + // Let's pre calculate (`window_size` - `window_slide`). + let window_size_sub_slide = window_size + .checked_sub(&window_slide) + .ok_or_else(|| ExprError::InvalidParam { + name: "window", + reason: format!( + "window_size {} cannot be subtracted by window_slide {}", + window_size, window_slide + ), + }) + .unwrap(); + + let hop_window_start = make_expression( + expr_node::Type::TumbleStart, + output_type, + vec![ + make_expression( expr_node::Type::Subtract, - output_type.clone(), - time_col_ref, - window_size_sub_slide_expr, - )?, - window_slide_expr, - window_offset_expr, - output_type.clone(), - )?; - - Ok(hop_start) - }; + output_type, + vec![time_col_ref, make_interval_literal(window_size_sub_slide)], + ), + make_interval_literal(window_slide), + make_interval_literal(window_offset), + ], + ); let mut window_start_exprs = Vec::with_capacity(units); let mut window_end_exprs = Vec::with_capacity(units); @@ -176,11 +154,6 @@ pub fn make_hop_window_expression( window_slide, i ), })?; - let window_start_offset_expr = LiteralExpression::new( - DataType::Interval, - Some(ScalarImpl::Interval(window_start_offset)), - ) - .boxed(); let window_end_offset = window_slide .checked_mul_int(i + units) @@ -191,25 +164,24 @@ pub fn make_hop_window_expression( window_slide, i ), })?; - let window_end_offset_expr = LiteralExpression::new( - DataType::Interval, - Some(ScalarImpl::Interval(window_end_offset)), - ) - .boxed(); - let window_start_expr = new_binary_expr( + let window_start_expr = make_expression( expr_node::Type::Add, - output_type.clone(), - get_hop_window_start.clone()()?, - window_start_offset_expr, - )?; - window_start_exprs.push(window_start_expr); - let window_end_expr = new_binary_expr( + output_type, + vec![ + hop_window_start.clone(), + make_interval_literal(window_start_offset), + ], + ); + window_start_exprs.push(build_from_prost(&window_start_expr).unwrap()); + let window_end_expr = make_expression( expr_node::Type::Add, - output_type.clone(), - get_hop_window_start.clone()()?, - window_end_offset_expr, - )?; - window_end_exprs.push(window_end_expr); + output_type, + vec![ + hop_window_start.clone(), + make_interval_literal(window_end_offset), + ], + ); + window_end_exprs.push(build_from_prost(&window_end_expr).unwrap()); } Ok((window_start_exprs, window_end_exprs)) } diff --git a/src/expr/src/sig/cast.rs b/src/expr/src/sig/cast.rs index e5f85fd7d0fb4..705e27d58d8b2 100644 --- a/src/expr/src/sig/cast.rs +++ b/src/expr/src/sig/cast.rs @@ -36,13 +36,6 @@ pub enum CastContext { Explicit, } -impl CastSig { - /// Returns a string describing the cast. - pub fn to_string_no_return(&self) -> String { - format!("cast({:?}->{:?})", self.from_type, self.to_type).to_lowercase() - } -} - pub type CastMap = BTreeMap<(DataTypeName, DataTypeName), CastContext>; pub fn cast_sigs() -> impl Iterator { diff --git a/src/expr/src/sig/func.rs b/src/expr/src/sig/func.rs index 23cace6b0dd2e..498f97c5d42c4 100644 --- a/src/expr/src/sig/func.rs +++ b/src/expr/src/sig/func.rs @@ -15,358 +15,95 @@ //! Function signatures. use std::collections::HashMap; +use std::fmt; use std::ops::Deref; use std::sync::LazyLock; -use itertools::iproduct; -use risingwave_common::types::DataTypeName; -use risingwave_pb::expr::expr_node::Type as ExprType; +use itertools::Itertools; +use risingwave_common::types::{DataType, DataTypeName}; +use risingwave_pb::expr::expr_node::PbType; -pub static FUNC_SIG_MAP: LazyLock = LazyLock::new(build_type_derive_map); +use crate::error::Result; +use crate::expr::BoxedExpression; + +pub static FUNC_SIG_MAP: LazyLock = LazyLock::new(|| unsafe { + let mut map = FuncSigMap::default(); + tracing::info!("{} function signatures loaded.", FUNC_SIG_MAP_INIT.len()); + for desc in FUNC_SIG_MAP_INIT.drain(..) { + map.insert(desc); + } + map +}); /// The table of function signatures. pub fn func_sigs() -> impl Iterator { FUNC_SIG_MAP.0.values().flatten() } -/// A function signature. -#[derive(PartialEq, Eq, Hash, Clone, Debug)] -pub struct FuncSign { - pub func: ExprType, - pub inputs_type: Vec, - pub ret_type: DataTypeName, -} - -impl FuncSign { - /// Returns a string describing the function without return type. - pub fn to_string_no_return(&self) -> String { - format!( - "{}({})", - self.func.as_str_name(), - self.inputs_type - .iter() - .map(|t| format!("{t:?}")) - .collect::>() - .join(",") - ) - .to_lowercase() - } -} - -#[derive(Default)] -pub struct FuncSigMap(HashMap<(ExprType, usize), Vec>); - -impl Deref for FuncSigMap { - type Target = HashMap<(ExprType, usize), Vec>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} +#[derive(Default, Clone, Debug)] +pub struct FuncSigMap(HashMap<(PbType, usize), Vec>); impl FuncSigMap { - pub fn insert( - &mut self, - func: ExprType, - param_types: Vec, - ret_type: DataTypeName, - ) { - let arity = param_types.len(); - let inputs_type = param_types.into_iter().map(Into::into).collect(); - let sig = FuncSign { - func, - inputs_type, - ret_type, - }; - self.0.entry((func, arity)).or_default().push(sig) - } -} - -/// This function builds type derived map for all built-in functions that take a fixed number -/// of arguments. They can be determined to have one or more type signatures since some are -/// compatible with more than one type. -/// Type signatures and arities of variadic functions are checked -/// [elsewhere](crate::expr::FunctionCall::new). -fn build_type_derive_map() -> FuncSigMap { - use {DataTypeName as T, ExprType as E}; - let mut map = FuncSigMap::default(); - let all_types = [ - T::Boolean, - T::Int16, - T::Int32, - T::Int64, - T::Serial, - T::Decimal, - T::Float32, - T::Float64, - T::Varchar, - T::Date, - T::Timestamp, - T::Timestamptz, - T::Time, - T::Interval, - T::Jsonb, - ]; - let num_types = [ - T::Int16, - T::Int32, - T::Int64, - T::Serial, - T::Decimal, - T::Float32, - T::Float64, - ]; - - // logical expressions - for e in [E::Not, E::IsTrue, E::IsNotTrue, E::IsFalse, E::IsNotFalse] { - map.insert(e, vec![T::Boolean], T::Boolean); - } - for e in [E::And, E::Or] { - map.insert(e, vec![T::Boolean, T::Boolean], T::Boolean); + /// Inserts a function signature. + pub fn insert(&mut self, desc: FuncSign) { + self.0 + .entry((desc.func, desc.inputs_type.len())) + .or_default() + .push(desc) } - map.insert(E::BoolOut, vec![T::Boolean], T::Varchar); - // comparison expressions - for e in [E::IsNull, E::IsNotNull] { - for t in all_types { - map.insert(e, vec![t], T::Boolean); - } + /// Returns a function signature with the same type, argument types and return type. + pub fn get(&self, ty: PbType, args: &[DataTypeName], ret: DataTypeName) -> Option<&FuncSign> { + let v = self.0.get(&(ty, args.len()))?; + v.iter() + .find(|d| d.inputs_type == args && d.ret_type == ret) } - let cmp_exprs = &[ - E::Equal, - E::NotEqual, - E::LessThan, - E::LessThanOrEqual, - E::GreaterThan, - E::GreaterThanOrEqual, - E::IsDistinctFrom, - E::IsNotDistinctFrom, - ]; - build_binary_cmp_funcs(&mut map, cmp_exprs, &num_types); - build_binary_cmp_funcs(&mut map, cmp_exprs, &[T::Struct]); - build_binary_cmp_funcs( - &mut map, - cmp_exprs, - &[T::Date, T::Timestamp, T::Timestamptz], - ); - build_binary_cmp_funcs(&mut map, cmp_exprs, &[T::Time, T::Interval]); - for e in cmp_exprs { - for t in [T::Boolean, T::Varchar] { - map.insert(*e, vec![t, t], T::Boolean); - } - } - - let unary_atm_exprs = &[E::Abs, E::Neg]; - - build_unary_atm_funcs(&mut map, unary_atm_exprs, &num_types); - build_binary_atm_funcs( - &mut map, - &[E::Add, E::Subtract, E::Multiply, E::Divide], - &[T::Int16, T::Int32, T::Int64, T::Decimal], - ); - build_binary_atm_funcs( - &mut map, - &[E::Add, E::Subtract, E::Multiply, E::Divide], - &[T::Float32, T::Float64], - ); - build_binary_atm_funcs( - &mut map, - &[E::Modulus], - &[T::Int16, T::Int32, T::Int64, T::Decimal], - ); - map.insert(E::RoundDigit, vec![T::Decimal, T::Int32], T::Decimal); - map.insert(E::Pow, vec![T::Float64, T::Float64], T::Float64); - map.insert(E::Exp, vec![T::Float64], T::Float64); - - // build bitwise operator - // bitwise operator - let integral_types = [T::Int16, T::Int32, T::Int64]; // reusable for and/or/xor/not - - build_binary_atm_funcs( - &mut map, - &[E::BitwiseAnd, E::BitwiseOr, E::BitwiseXor], - &integral_types, - ); - // Shift Operator is not using `build_binary_atm_funcs` because - // allowed rhs is different from allowed lhs - // return type is lhs rather than larger of the two - for (e, lt, rt) in iproduct!( - &[E::BitwiseShiftLeft, E::BitwiseShiftRight], - &integral_types, - &[T::Int16, T::Int32] - ) { - map.insert(*e, vec![*lt, *rt], *lt); - } - - build_unary_atm_funcs(&mut map, &[E::BitwiseNot], &[T::Int16, T::Int32, T::Int64]); - - build_round_funcs(&mut map, E::Round); - build_round_funcs(&mut map, E::Ceil); - build_round_funcs(&mut map, E::Floor); - - // temporal expressions - for (base, delta) in [ - (T::Date, T::Int32), - (T::Timestamp, T::Interval), - (T::Timestamptz, T::Interval), - (T::Time, T::Interval), - ] { - build_commutative_funcs(&mut map, E::Add, base, delta, base); - map.insert(E::Subtract, vec![base, delta], base); - map.insert(E::Subtract, vec![base, base], delta); - } - map.insert(E::Add, vec![T::Interval, T::Interval], T::Interval); - map.insert(E::Subtract, vec![T::Interval, T::Interval], T::Interval); - - // date + interval = timestamp, date - interval = timestamp - build_commutative_funcs(&mut map, E::Add, T::Date, T::Interval, T::Timestamp); - map.insert(E::Subtract, vec![T::Date, T::Interval], T::Timestamp); - // date + time = timestamp - build_commutative_funcs(&mut map, E::Add, T::Date, T::Time, T::Timestamp); - // interval * float8 = interval, interval / float8 = interval - for t in num_types { - build_commutative_funcs(&mut map, E::Multiply, T::Interval, t, T::Interval); - map.insert(E::Divide, vec![T::Interval, t], T::Interval); - } - - for t in [T::Timestamptz, T::Timestamp, T::Time, T::Date] { - map.insert(E::Extract, vec![T::Varchar, t], T::Decimal); - } - for t in [T::Timestamp, T::Date] { - map.insert(E::TumbleStart, vec![t, T::Interval], T::Timestamp); - map.insert( - E::TumbleStart, - vec![t, T::Interval, T::Interval], - T::Timestamp, - ); - } - map.insert( - E::TumbleStart, - vec![T::Timestamptz, T::Interval], - T::Timestamptz, - ); - map.insert( - E::TumbleStart, - vec![T::Timestamptz, T::Interval, T::Interval], - T::Timestamptz, - ); - map.insert(E::ToTimestamp, vec![T::Float64], T::Timestamptz); - map.insert(E::ToTimestamp1, vec![T::Varchar, T::Varchar], T::Timestamp); - map.insert( - E::AtTimeZone, - vec![T::Timestamp, T::Varchar], - T::Timestamptz, - ); - map.insert( - E::AtTimeZone, - vec![T::Timestamptz, T::Varchar], - T::Timestamp, - ); - map.insert(E::DateTrunc, vec![T::Varchar, T::Timestamp], T::Timestamp); - map.insert( - E::DateTrunc, - vec![T::Varchar, T::Timestamptz, T::Varchar], - T::Timestamptz, - ); - map.insert(E::DateTrunc, vec![T::Varchar, T::Interval], T::Interval); - - // string expressions - for e in [E::Trim, E::Ltrim, E::Rtrim, E::Lower, E::Upper, E::Md5] { - map.insert(e, vec![T::Varchar], T::Varchar); - } - for e in [E::Trim, E::Ltrim, E::Rtrim] { - map.insert(e, vec![T::Varchar, T::Varchar], T::Varchar); - } - for e in [E::Repeat, E::Substr] { - map.insert(e, vec![T::Varchar, T::Int32], T::Varchar); - } - map.insert(E::Substr, vec![T::Varchar, T::Int32, T::Int32], T::Varchar); - for e in [E::Replace, E::Translate] { - map.insert(e, vec![T::Varchar, T::Varchar, T::Varchar], T::Varchar); - } - map.insert(E::FormatType, vec![T::Int32, T::Int32], T::Varchar); - map.insert( - E::Overlay, - vec![T::Varchar, T::Varchar, T::Int32], - T::Varchar, - ); - map.insert( - E::Overlay, - vec![T::Varchar, T::Varchar, T::Int32, T::Int32], - T::Varchar, - ); - for e in [ - E::Length, - E::Ascii, - E::CharLength, - E::OctetLength, - E::BitLength, - ] { - map.insert(e, vec![T::Varchar], T::Int32); - } - map.insert(E::Position, vec![T::Varchar, T::Varchar], T::Int32); - map.insert(E::Like, vec![T::Varchar, T::Varchar], T::Boolean); - map.insert( - E::SplitPart, - vec![T::Varchar, T::Varchar, T::Int32], - T::Varchar, - ); - // TODO: Support more `to_char` types. - map.insert(E::ToChar, vec![T::Timestamp, T::Varchar], T::Varchar); - // array_to_string - map.insert(E::ArrayToString, vec![T::List, T::Varchar], T::Varchar); - map.insert( - E::ArrayToString, - vec![T::List, T::Varchar, T::Varchar], - T::Varchar, - ); - - map.insert(E::JsonbAccessInner, vec![T::Jsonb, T::Int32], T::Jsonb); - map.insert(E::JsonbAccessInner, vec![T::Jsonb, T::Varchar], T::Jsonb); - map.insert(E::JsonbAccessStr, vec![T::Jsonb, T::Int32], T::Varchar); - map.insert(E::JsonbAccessStr, vec![T::Jsonb, T::Varchar], T::Varchar); - map.insert(E::JsonbTypeof, vec![T::Jsonb], T::Varchar); - map.insert(E::JsonbArrayLength, vec![T::Jsonb], T::Int32); - - map -} - -fn build_binary_cmp_funcs(map: &mut FuncSigMap, exprs: &[ExprType], args: &[DataTypeName]) { - for (e, lt, rt) in iproduct!(exprs, args, args) { - map.insert(*e, vec![*lt, *rt], DataTypeName::Boolean); + /// Returns all function signatures with the same type and number of arguments. + pub fn get_with_arg_nums(&self, ty: PbType, nargs: usize) -> &[FuncSign] { + self.0.get(&(ty, nargs)).map_or(&[], Deref::deref) } } -fn build_binary_atm_funcs(map: &mut FuncSigMap, exprs: &[ExprType], args: &[DataTypeName]) { - for e in exprs { - for (li, lt) in args.iter().enumerate() { - for (ri, rt) in args.iter().enumerate() { - let ret = if li <= ri { rt } else { lt }; - map.insert(*e, vec![*lt, *rt], *ret); - } - } - } +/// A function signature. +#[derive(Clone)] +pub struct FuncSign { + pub name: &'static str, + pub func: PbType, + pub inputs_type: &'static [DataTypeName], + pub ret_type: DataTypeName, + pub build: fn(return_type: DataType, children: Vec) -> Result, } -fn build_unary_atm_funcs(map: &mut FuncSigMap, exprs: &[ExprType], args: &[DataTypeName]) { - for (e, arg) in iproduct!(exprs, args) { - map.insert(*e, vec![*arg], *arg); +impl fmt::Debug for FuncSign { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = format!( + "{}({})->{:?}", + self.func.as_str_name(), + self.inputs_type.iter().map(|t| format!("{t:?}")).join(","), + self.ret_type + ) + .to_lowercase(); + f.write_str(&s) } } -fn build_commutative_funcs( - map: &mut FuncSigMap, - expr: ExprType, - arg0: DataTypeName, - arg1: DataTypeName, - ret: DataTypeName, -) { - map.insert(expr, vec![arg0, arg1], ret); - map.insert(expr, vec![arg1, arg0], ret); +/// Register a function into global registry. +/// +/// # Safety +/// +/// This function must be called sequentially. +/// +/// It is designed to be used by `#[function]` macro. +/// Users SHOULD NOT call this function. +#[doc(hidden)] +pub unsafe fn _register(desc: FuncSign) { + FUNC_SIG_MAP_INIT.push(desc) } -fn build_round_funcs(map: &mut FuncSigMap, expr: ExprType) { - map.insert(expr, vec![DataTypeName::Float64], DataTypeName::Float64); - map.insert(expr, vec![DataTypeName::Decimal], DataTypeName::Decimal); -} +/// The global registry of function signatures on initialization. +/// +/// `#[function]` macro will generate a `#[ctor]` function to register the signature into this +/// vector. The calls are guaranteed to be sequential. The vector will be drained and moved into +/// `FUNC_SIG_MAP` on the first access of `FUNC_SIG_MAP`. +static mut FUNC_SIG_MAP_INIT: Vec = Vec::new(); diff --git a/src/expr/src/vector_op/agg/filter.rs b/src/expr/src/vector_op/agg/filter.rs index 71aa96dffdd75..7eab746f2777c 100644 --- a/src/expr/src/vector_op/agg/filter.rs +++ b/src/expr/src/vector_op/agg/filter.rs @@ -117,7 +117,7 @@ mod tests { use risingwave_pb::expr::expr_node::PbType; use super::*; - use crate::expr::{new_binary_expr, Expression, InputRefExpression, LiteralExpression}; + use crate::expr::{build, Expression, InputRefExpression, LiteralExpression}; #[derive(Clone)] struct MockAgg { @@ -186,15 +186,16 @@ mod tests { #[tokio::test] async fn test_selective_agg() -> Result<()> { // filter (where $1 > 5) - let condition = Arc::from( - new_binary_expr( - PbType::GreaterThan, - DataType::Boolean, + let expr = build( + PbType::GreaterThan, + DataType::Boolean, + vec![ InputRefExpression::new(DataType::Int64, 0).boxed(), - LiteralExpression::new(DataType::Int64, Some((5_i64).into())).boxed(), - ) - .unwrap(), - ); + LiteralExpression::new(DataType::Int64, Some(ScalarImpl::Int64(5))).boxed(), + ], + ) + .unwrap(); + let condition = Arc::from(expr); let agg_count = Arc::new(AtomicUsize::new(0)); let mut agg = Filter::new( condition, @@ -228,18 +229,18 @@ mod tests { #[tokio::test] async fn test_selective_agg_null_condition() -> Result<()> { - let condition = Arc::from( - new_binary_expr( - PbType::Equal, - DataType::Boolean, + let expr = build( + PbType::Equal, + DataType::Boolean, + vec![ InputRefExpression::new(DataType::Int64, 0).boxed(), LiteralExpression::new(DataType::Int64, None).boxed(), - ) - .unwrap(), - ); + ], + ) + .unwrap(); let agg_count = Arc::new(AtomicUsize::new(0)); let mut agg = Filter::new( - condition, + Arc::from(expr), Box::new(MockAgg { count: agg_count.clone(), }), diff --git a/src/expr/src/vector_op/arithmetic_op.rs b/src/expr/src/vector_op/arithmetic_op.rs index 13be501bbea5d..added7cbda7e3 100644 --- a/src/expr/src/vector_op/arithmetic_op.rs +++ b/src/expr/src/vector_op/arithmetic_op.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![expect(clippy::extra_unused_type_parameters, reason = "used by macro")] - use std::convert::TryInto; use std::fmt::Debug; @@ -24,10 +22,12 @@ use risingwave_common::types::{ CheckedAdd, Decimal, IntervalUnit, NaiveDateTimeWrapper, NaiveDateWrapper, NaiveTimeWrapper, OrderedF64, }; +use risingwave_expr_macro::function; use crate::{ExprError, Result}; -#[inline(always)] +#[function("add(*number, *number) -> auto")] +#[function("add(interval, interval) -> interval")] pub fn general_add(l: T1, r: T2) -> Result where T1: Into + Debug, @@ -39,7 +39,8 @@ where }) } -#[inline(always)] +#[function("subtract(*number, *number) -> auto")] +#[function("subtract(interval, interval) -> interval")] pub fn general_sub(l: T1, r: T2) -> Result where T1: Into + Debug, @@ -51,7 +52,7 @@ where }) } -#[inline(always)] +#[function("multiply(*number, *number) -> auto")] pub fn general_mul(l: T1, r: T2) -> Result where T1: Into + Debug, @@ -63,7 +64,7 @@ where }) } -#[inline(always)] +#[function("divide(*number, *number) -> auto")] pub fn general_div(l: T1, r: T2) -> Result where T1: Into + Debug, @@ -81,7 +82,7 @@ where }) } -#[inline(always)] +#[function("modulus(*number, *number) -> auto")] pub fn general_mod(l: T1, r: T2) -> Result where T1: Into + Debug, @@ -93,12 +94,21 @@ where }) } -#[inline(always)] +#[function("neg(int16) -> int16")] +#[function("neg(int32) -> int32")] +#[function("neg(int64) -> int64")] +#[function("neg(float32) -> float32")] +#[function("neg(float64) -> float64")] +#[function("neg(decimal) -> decimal")] pub fn general_neg(expr: T1) -> Result { expr.checked_neg().ok_or(ExprError::NumericOutOfRange) } -#[inline(always)] +#[function("abs(int16) -> int16")] +#[function("abs(int32) -> int32")] +#[function("abs(int64) -> int64")] +#[function("abs(float32) -> float32")] +#[function("abs(float64) -> float64")] pub fn general_abs(expr: T1) -> Result { if expr.is_negative() { general_neg(expr) @@ -107,10 +117,12 @@ pub fn general_abs(expr: T1) -> Result { } } +#[function("abs(decimal) -> decimal")] pub fn decimal_abs(decimal: Decimal) -> Result { Ok(Decimal::abs(&decimal)) } +#[function("pow(float64, float64) -> float64")] pub fn pow_f64(l: OrderedF64, r: OrderedF64) -> Result { let res = l.powf(r); if res.is_infinite() { @@ -121,7 +133,7 @@ pub fn pow_f64(l: OrderedF64, r: OrderedF64) -> Result { } #[inline(always)] -pub fn general_atm(l: T1, r: T2, atm: F) -> Result +fn general_atm(l: T1, r: T2, atm: F) -> Result where T1: Into + Debug, T2: Into + Debug, @@ -130,8 +142,8 @@ where atm(l.into(), r.into()) } -#[inline(always)] -pub fn timestamp_timestamp_sub( +#[function("subtract(timestamp, timestamp) -> interval")] +pub fn timestamp_timestamp_sub( l: NaiveDateTimeWrapper, r: NaiveDateTimeWrapper, ) -> Result { @@ -143,55 +155,43 @@ pub fn timestamp_timestamp_sub( Ok(IntervalUnit::from_month_day_usec(0, days as i32, usecs)) } -#[inline(always)] -pub fn date_date_sub(l: NaiveDateWrapper, r: NaiveDateWrapper) -> Result { +#[function("subtract(date, date) -> int32")] +pub fn date_date_sub(l: NaiveDateWrapper, r: NaiveDateWrapper) -> Result { Ok((l.0 - r.0).num_days() as i32) // this does not overflow or underflow } -#[inline(always)] -pub fn interval_timestamp_add( +#[function("add(interval, timestamp) -> timestamp")] +pub fn interval_timestamp_add( l: IntervalUnit, r: NaiveDateTimeWrapper, ) -> Result { r.checked_add(l).ok_or(ExprError::NumericOutOfRange) } -#[inline(always)] -pub fn interval_date_add( - l: IntervalUnit, - r: NaiveDateWrapper, -) -> Result { - interval_timestamp_add::(l, r.into()) +#[function("add(interval, date) -> timestamp")] +pub fn interval_date_add(l: IntervalUnit, r: NaiveDateWrapper) -> Result { + interval_timestamp_add(l, r.into()) } -#[inline(always)] -pub fn interval_time_add( - l: IntervalUnit, - r: NaiveTimeWrapper, -) -> Result { - time_interval_add::(r, l) +#[function("add(interval, time) -> time")] +pub fn interval_time_add(l: IntervalUnit, r: NaiveTimeWrapper) -> Result { + time_interval_add(r, l) } -#[inline(always)] -pub fn date_interval_add( - l: NaiveDateWrapper, - r: IntervalUnit, -) -> Result { - interval_date_add::(r, l) +#[function("add(date, interval) -> timestamp")] +pub fn date_interval_add(l: NaiveDateWrapper, r: IntervalUnit) -> Result { + interval_date_add(r, l) } -#[inline(always)] -pub fn date_interval_sub( - l: NaiveDateWrapper, - r: IntervalUnit, -) -> Result { +#[function("subtract(date, interval) -> timestamp")] +pub fn date_interval_sub(l: NaiveDateWrapper, r: IntervalUnit) -> Result { // TODO: implement `checked_sub` for `NaiveDateTimeWrapper` to handle the edge case of negation // overflowing. - interval_date_add::(r.checked_neg().ok_or(ExprError::NumericOutOfRange)?, l) + interval_date_add(r.checked_neg().ok_or(ExprError::NumericOutOfRange)?, l) } -#[inline(always)] -pub fn date_int_add(l: NaiveDateWrapper, r: i32) -> Result { +#[function("add(date, int32) -> date")] +pub fn date_int_add(l: NaiveDateWrapper, r: i32) -> Result { let date = l.0; let date_wrapper = date .checked_add_signed(chrono::Duration::days(r as i64)) @@ -200,13 +200,13 @@ pub fn date_int_add(l: NaiveDateWrapper, r: i32) -> Result(l: i32, r: NaiveDateWrapper) -> Result { - date_int_add::(r, l) +#[function("add(int32, date) -> date")] +pub fn int_date_add(l: i32, r: NaiveDateWrapper) -> Result { + date_int_add(r, l) } -#[inline(always)] -pub fn date_int_sub(l: NaiveDateWrapper, r: i32) -> Result { +#[function("subtract(date, int32) -> date")] +pub fn date_int_sub(l: NaiveDateWrapper, r: i32) -> Result { let date = l.0; let date_wrapper = date .checked_sub_signed(chrono::Duration::days(r as i64)) @@ -215,35 +215,35 @@ pub fn date_int_sub(l: NaiveDateWrapper, r: i32) -> Result( +#[function("add(timestamp, interval) -> timestamp")] +pub fn timestamp_interval_add( l: NaiveDateTimeWrapper, r: IntervalUnit, ) -> Result { - interval_timestamp_add::(r, l) + interval_timestamp_add(r, l) } -#[inline(always)] -pub fn timestamp_interval_sub( +#[function("subtract(timestamp, interval) -> timestamp")] +pub fn timestamp_interval_sub( l: NaiveDateTimeWrapper, r: IntervalUnit, ) -> Result { - interval_timestamp_add::(r.checked_neg().ok_or(ExprError::NumericOutOfRange)?, l) + interval_timestamp_add(r.checked_neg().ok_or(ExprError::NumericOutOfRange)?, l) } -#[inline(always)] -pub fn timestamptz_interval_add(l: i64, r: IntervalUnit) -> Result { +#[function("add(timestamptz, interval) -> timestamptz")] +pub fn timestamptz_interval_add(l: i64, r: IntervalUnit) -> Result { timestamptz_interval_inner(l, r, i64::checked_add) } -#[inline(always)] -pub fn timestamptz_interval_sub(l: i64, r: IntervalUnit) -> Result { +#[function("subtract(timestamptz, interval) -> timestamptz")] +pub fn timestamptz_interval_sub(l: i64, r: IntervalUnit) -> Result { timestamptz_interval_inner(l, r, i64::checked_sub) } -#[inline(always)] -pub fn interval_timestamptz_add(l: IntervalUnit, r: i64) -> Result { - timestamptz_interval_add::(r, l) +#[function("add(interval, timestamptz) -> timestamptz")] +pub fn interval_timestamptz_add(l: IntervalUnit, r: i64) -> Result { + timestamptz_interval_add(r, l) } #[inline(always)] @@ -267,41 +267,29 @@ fn timestamptz_interval_inner( result.ok_or(ExprError::NumericOutOfRange) } -#[inline(always)] -pub fn interval_int_mul(l: IntervalUnit, r: T2) -> Result -where - T2: TryInto + Debug, -{ +#[function("multiply(interval, *int) -> interval")] +pub fn interval_int_mul(l: IntervalUnit, r: impl TryInto + Debug) -> Result { l.checked_mul_int(r).ok_or(ExprError::NumericOutOfRange) } -#[inline(always)] -pub fn int_interval_mul(l: T1, r: IntervalUnit) -> Result -where - T1: TryInto + Debug, -{ - interval_int_mul::(r, l) +#[function("multiply(*int, interval) -> interval")] +pub fn int_interval_mul(l: impl TryInto + Debug, r: IntervalUnit) -> Result { + interval_int_mul(r, l) } -#[inline(always)] -pub fn date_time_add( - l: NaiveDateWrapper, - r: NaiveTimeWrapper, -) -> Result { +#[function("add(date, time) -> timestamp")] +pub fn date_time_add(l: NaiveDateWrapper, r: NaiveTimeWrapper) -> Result { let date_time = NaiveDateTime::new(l.0, r.0); Ok(NaiveDateTimeWrapper::new(date_time)) } -#[inline(always)] -pub fn time_date_add( - l: NaiveTimeWrapper, - r: NaiveDateWrapper, -) -> Result { - date_time_add::(r, l) +#[function("add(time, date) -> timestamp")] +pub fn time_date_add(l: NaiveTimeWrapper, r: NaiveDateWrapper) -> Result { + date_time_add(r, l) } -#[inline(always)] -pub fn time_time_sub(l: NaiveTimeWrapper, r: NaiveTimeWrapper) -> Result { +#[function("subtract(time, time) -> interval")] +pub fn time_time_sub(l: NaiveTimeWrapper, r: NaiveTimeWrapper) -> Result { let tmp = l.0 - r.0; // this does not overflow or underflow let usecs = tmp .num_microseconds() @@ -309,11 +297,8 @@ pub fn time_time_sub(l: NaiveTimeWrapper, r: NaiveTimeWrapper) -> Re Ok(IntervalUnit::from_month_day_usec(0, 0, usecs)) } -#[inline(always)] -pub fn time_interval_sub( - l: NaiveTimeWrapper, - r: IntervalUnit, -) -> Result { +#[function("subtract(time, interval) -> time")] +pub fn time_interval_sub(l: NaiveTimeWrapper, r: IntervalUnit) -> Result { let time = l.0; let (new_time, ignored) = time.overflowing_sub_signed(Duration::microseconds(r.get_usecs())); if ignored == 0 { @@ -323,11 +308,8 @@ pub fn time_interval_sub( } } -#[inline(always)] -pub fn time_interval_add( - l: NaiveTimeWrapper, - r: IntervalUnit, -) -> Result { +#[function("add(time, interval) -> time")] +pub fn time_interval_add(l: NaiveTimeWrapper, r: IntervalUnit) -> Result { let time = l.0; let (new_time, ignored) = time.overflowing_add_signed(Duration::microseconds(r.get_usecs())); if ignored == 0 { @@ -337,24 +319,28 @@ pub fn time_interval_add( } } -#[inline(always)] -pub fn interval_float_div(l: IntervalUnit, r: T2) -> Result +#[function("divide(interval, *number) -> interval")] +pub fn interval_float_div(l: IntervalUnit, r: T2) -> Result where T2: TryInto + Debug, { l.div_float(r).ok_or(ExprError::NumericOutOfRange) } -#[inline(always)] -pub fn interval_float_mul(l: IntervalUnit, r: T2) -> Result +#[function("multiply(interval, float32) -> interval")] +#[function("multiply(interval, float64) -> interval")] +#[function("multiply(interval, decimal) -> interval")] +pub fn interval_float_mul(l: IntervalUnit, r: T2) -> Result where T2: TryInto + Debug, { l.mul_float(r).ok_or(ExprError::NumericOutOfRange) } -#[inline(always)] -pub fn float_interval_mul(l: T1, r: IntervalUnit) -> Result +#[function("multiply(float32, interval) -> interval")] +#[function("multiply(float64, interval) -> interval")] +#[function("multiply(decimal, interval) -> interval")] +pub fn float_interval_mul(l: T1, r: IntervalUnit) -> Result where T1: TryInto + Debug, { @@ -365,9 +351,12 @@ where mod tests { use std::str::FromStr; - use risingwave_common::types::Decimal; + use risingwave_common::types::test_utils::IntervalUnitTestExt; + use risingwave_common::types::{ + Decimal, IntervalUnit, NaiveDateTimeWrapper, NaiveDateWrapper, OrderedF32, OrderedF64, + }; - use crate::vector_op::arithmetic_op::general_add; + use super::*; #[test] fn test() { @@ -376,4 +365,114 @@ mod tests { Decimal::from_str("2").unwrap() ); } + + #[test] + fn test_arithmetic() { + assert_eq!( + general_add::(dec("1.0"), 1).unwrap(), + dec("2.0") + ); + assert_eq!( + general_sub::(dec("1.0"), 2).unwrap(), + dec("-1.0") + ); + assert_eq!( + general_mul::(dec("1.0"), 2).unwrap(), + dec("2.0") + ); + assert_eq!( + general_div::(dec("2.0"), 2).unwrap(), + dec("1.0") + ); + assert_eq!( + general_mod::(dec("2.0"), 2).unwrap(), + dec("0") + ); + assert_eq!(general_neg::(dec("1.0")).unwrap(), dec("-1.0")); + assert_eq!(general_add::(1i16, 1i32).unwrap(), 2i32); + assert_eq!(general_sub::(1i16, 1i32).unwrap(), 0i32); + assert_eq!(general_mul::(1i16, 1i32).unwrap(), 1i32); + assert_eq!(general_div::(1i16, 1i32).unwrap(), 1i32); + assert_eq!(general_mod::(1i16, 1i32).unwrap(), 0i32); + assert_eq!(general_neg::(1i16).unwrap(), -1i16); + + assert_eq!( + general_add::(dec("1.0"), -1f32).unwrap(), + dec("0.0") + ); + assert_eq!( + general_sub::(dec("1.0"), 1f32).unwrap(), + dec("0.0") + ); + assert_eq!( + general_div::(dec("0.0"), 1f32).unwrap(), + dec("0.0") + ); + assert_eq!( + general_mul::(dec("0.0"), 1f32).unwrap(), + dec("0.0") + ); + assert_eq!( + general_mod::(dec("0.0"), 1f32).unwrap(), + dec("0.0") + ); + assert!( + general_add::(-1i32, 1f32.into()) + .unwrap() + .is_zero() + ); + assert!( + general_sub::(1i32, 1f32.into()) + .unwrap() + .is_zero() + ); + assert!( + general_mul::(0i32, 1f32.into()) + .unwrap() + .is_zero() + ); + assert!( + general_div::(0i32, 1f32.into()) + .unwrap() + .is_zero() + ); + assert_eq!( + general_neg::(1f32.into()).unwrap(), + OrderedF32::from(-1f32) + ); + assert_eq!( + date_interval_add( + NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1), + IntervalUnit::from_month(12) + ) + .unwrap(), + NaiveDateTimeWrapper::new( + NaiveDateTime::parse_from_str("1995-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap() + ) + ); + assert_eq!( + interval_date_add( + IntervalUnit::from_month(12), + NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1) + ) + .unwrap(), + NaiveDateTimeWrapper::new( + NaiveDateTime::parse_from_str("1995-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap() + ) + ); + assert_eq!( + date_interval_sub( + NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1), + IntervalUnit::from_month(12) + ) + .unwrap(), + NaiveDateTimeWrapper::new( + NaiveDateTime::parse_from_str("1993-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap() + ) + ); + } + + fn dec(s: &str) -> Decimal { + Decimal::from_str(s).unwrap() + } } diff --git a/src/expr/src/vector_op/array_access.rs b/src/expr/src/vector_op/array_access.rs index 1e460acfddb9d..41101c6d13229 100644 --- a/src/expr/src/vector_op/array_access.rs +++ b/src/expr/src/vector_op/array_access.rs @@ -14,22 +14,21 @@ use risingwave_common::array::ListRef; use risingwave_common::types::{Scalar, ToOwnedDatum}; +use risingwave_expr_macro::function; use crate::Result; -#[inline(always)] -pub fn array_access(l: Option>, r: Option) -> Result> { - match (l, r) { - // index must be greater than 0 following a one-based numbering convention for arrays - (Some(list), Some(index)) if index > 0 => { - let datumref = list.value_at(index as usize)?; - if let Some(scalar) = datumref.to_owned_datum() { - Ok(Some(scalar.try_into()?)) - } else { - Ok(None) - } - } - _ => Ok(None), +#[function("array_access(list, int32) -> *")] +pub fn array_access(list: ListRef<'_>, index: i32) -> Result> { + // index must be greater than 0 following a one-based numbering convention for arrays + if index < 1 { + return Ok(None); + } + let datumref = list.value_at(index as usize)?; + if let Some(scalar) = datumref.to_owned_datum() { + Ok(Some(scalar.try_into()?)) + } else { + Ok(None) } } @@ -50,10 +49,10 @@ mod tests { ]); let l1 = ListRef::ValueRef { val: &v1 }; - assert_eq!(array_access::(Some(l1), Some(1)).unwrap(), Some(1)); - assert_eq!(array_access::(Some(l1), Some(-1)).unwrap(), None); - assert_eq!(array_access::(Some(l1), Some(0)).unwrap(), None); - assert_eq!(array_access::(Some(l1), Some(4)).unwrap(), None); + assert_eq!(array_access::(l1, 1).unwrap(), Some(1)); + assert_eq!(array_access::(l1, -1).unwrap(), None); + assert_eq!(array_access::(l1, 0).unwrap(), None); + assert_eq!(array_access::(l1, 4).unwrap(), None); } #[test] @@ -75,15 +74,15 @@ mod tests { let l3 = ListRef::ValueRef { val: &v3 }; assert_eq!( - array_access::>(Some(l1), Some(1)).unwrap(), + array_access::>(l1, 1).unwrap(), Some("来自".into()) ); assert_eq!( - array_access::>(Some(l2), Some(2)).unwrap(), + array_access::>(l2, 2).unwrap(), Some("荷兰".into()) ); assert_eq!( - array_access::>(Some(l3), Some(3)).unwrap(), + array_access::>(l3, 3).unwrap(), Some("的爱".into()) ); } @@ -102,7 +101,7 @@ mod tests { ]); let l = ListRef::ValueRef { val: &v }; assert_eq!( - array_access::(Some(l), Some(1)).unwrap(), + array_access::(l, 1).unwrap(), Some(ListValue::new(vec![ Some(ScalarImpl::Utf8("foo".into())), Some(ScalarImpl::Utf8("bar".into())), diff --git a/src/expr/src/vector_op/ascii.rs b/src/expr/src/vector_op/ascii.rs index e0421d39e0952..0f902f740fd2d 100644 --- a/src/expr/src/vector_op/ascii.rs +++ b/src/expr/src/vector_op/ascii.rs @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::Result; +use risingwave_expr_macro::function; -#[inline(always)] -pub fn ascii(s: &str) -> Result { - Ok(s.as_bytes().first().map(|x| *x as i32).unwrap_or(0)) +#[function("ascii(varchar) -> int32")] +pub fn ascii(s: &str) -> i32 { + s.as_bytes().first().map(|x| *x as i32).unwrap_or(0) } #[cfg(test)] @@ -27,7 +27,7 @@ mod tests { fn test_ascii() { let cases = [("hello", 104), ("你好", 228), ("", 0)]; for (s, expected) in cases { - assert_eq!(ascii(s).unwrap(), expected) + assert_eq!(ascii(s), expected) } } } diff --git a/src/expr/src/vector_op/bitwise_op.rs b/src/expr/src/vector_op/bitwise_op.rs index 9e45cc8e3fa21..b89a0d1bae0b9 100644 --- a/src/expr/src/vector_op/bitwise_op.rs +++ b/src/expr/src/vector_op/bitwise_op.rs @@ -17,6 +17,7 @@ use std::fmt::Debug; use std::ops::{BitAnd, BitOr, BitXor, Not}; use num_traits::{CheckedShl, CheckedShr}; +use risingwave_expr_macro::function; use crate::{ExprError, Result}; @@ -25,7 +26,12 @@ use crate::{ExprError, Result}; // undefined behaviour. If the RHS is negative, instead of having an unexpected answer, we return an // error. If PG had clearly defined behavior rather than relying on UB of C, we would follow it even // when it is different from rust std. -#[inline(always)] +#[function("bitwise_shift_left(int16, int16) -> int16")] +#[function("bitwise_shift_left(int16, int32) -> int16")] +#[function("bitwise_shift_left(int32, int16) -> int32")] +#[function("bitwise_shift_left(int32, int32) -> int32")] +#[function("bitwise_shift_left(int64, int16) -> int64")] +#[function("bitwise_shift_left(int64, int32) -> int64")] pub fn general_shl(l: T1, r: T2) -> Result where T1: CheckedShl + Debug, @@ -36,7 +42,12 @@ where }) } -#[inline(always)] +#[function("bitwise_shift_right(int16, int16) -> int16")] +#[function("bitwise_shift_right(int16, int32) -> int16")] +#[function("bitwise_shift_right(int32, int16) -> int32")] +#[function("bitwise_shift_right(int32, int32) -> int32")] +#[function("bitwise_shift_right(int64, int16) -> int64")] +#[function("bitwise_shift_right(int64, int32) -> int64")] pub fn general_shr(l: T1, r: T2) -> Result where T1: CheckedShr + Debug, @@ -48,7 +59,7 @@ where } #[inline(always)] -pub fn general_shift(l: T1, r: T2, atm: F) -> Result +fn general_shift(l: T1, r: T2, atm: F) -> Result where T1: Debug, T2: TryInto + Debug, @@ -61,7 +72,7 @@ where atm(l, r) } -#[inline(always)] +#[function("bitwise_and(*int, *int) -> auto")] pub fn general_bitand(l: T1, r: T2) -> T3 where T1: Into + Debug, @@ -71,7 +82,7 @@ where l.into() & r.into() } -#[inline(always)] +#[function("bitwise_or(*int, *int) -> auto")] pub fn general_bitor(l: T1, r: T2) -> T3 where T1: Into + Debug, @@ -81,7 +92,7 @@ where l.into() | r.into() } -#[inline(always)] +#[function("bitwise_xor(*int, *int) -> auto")] pub fn general_bitxor(l: T1, r: T2) -> T3 where T1: Into + Debug, @@ -91,7 +102,44 @@ where l.into() ^ r.into() } -#[inline(always)] +#[function("bitwise_not(*int) -> auto")] pub fn general_bitnot>(expr: T1) -> T1 { !expr } + +#[cfg(test)] +mod tests { + use std::assert_matches::assert_matches; + + use super::*; + + #[test] + fn test_bitwise() { + // check the boundary + assert_eq!(general_shl::(1i32, 0i32).unwrap(), 1i32); + assert_eq!(general_shl::(1i64, 31i32).unwrap(), 2147483648i64); + assert_matches!( + general_shl::(1i32, 32i32).unwrap_err(), + ExprError::NumericOutOfRange, + ); + assert_eq!( + general_shr::(-2147483648i64, 31i32).unwrap(), + -1i64 + ); + assert_eq!(general_shr::(1i64, 0i32).unwrap(), 1i64); + // truth table + assert_eq!( + general_bitand::(0b0011u32, 0b0101u32), + 0b1u64 + ); + assert_eq!( + general_bitor::(0b0011u32, 0b0101u32), + 0b0111u64 + ); + assert_eq!( + general_bitxor::(0b0011u32, 0b0101u32), + 0b0110u64 + ); + assert_eq!(general_bitnot::(0b01i32), -2i32); + } +} diff --git a/src/expr/src/vector_op/cast.rs b/src/expr/src/vector_op/cast.rs index 4681de4136f11..2a77e50a66974 100644 --- a/src/expr/src/vector_op/cast.rs +++ b/src/expr/src/vector_op/cast.rs @@ -13,22 +13,30 @@ // limitations under the License. use std::any::type_name; -use std::fmt::Write; +use std::fmt::{Debug, Write}; use std::str::FromStr; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; +use futures_util::FutureExt; use itertools::Itertools; use num_traits::ToPrimitive; -use risingwave_common::array::{Array, JsonbRef, ListRef, ListValue, StructRef, StructValue}; +use risingwave_common::array::{ + JsonbRef, ListArray, ListRef, ListValue, StructArray, StructRef, StructValue, Utf8Array, +}; +use risingwave_common::row::OwnedRow; use risingwave_common::types::struct_type::StructType; use risingwave_common::types::to_text::ToText; use risingwave_common::types::{ DataType, Decimal, IntervalUnit, NaiveDateTimeWrapper, NaiveDateWrapper, NaiveTimeWrapper, - OrderedF32, OrderedF64, Scalar, ScalarImpl, ScalarRefImpl, + OrderedF32, OrderedF64, ScalarImpl, }; use risingwave_common::util::iter_util::ZipEqFast; +use risingwave_expr_macro::{build_function, function}; +use risingwave_pb::expr::expr_node::PbType; use speedate::{Date as SpeedDate, DateTime as SpeedDateTime, Time as SpeedTime}; +use crate::expr::template::UnaryExpression; +use crate::expr::{build, BoxedExpression, Expression, InputRefExpression}; use crate::{ExprError, Result}; /// String literals for bool type. @@ -49,17 +57,17 @@ const PARSE_ERROR_STR_TO_TIME: &str = const PARSE_ERROR_STR_TO_DATE: &str = "Can't cast string to date (expected format is YYYY-MM-DD)"; const PARSE_ERROR_STR_TO_BYTEA: &str = "Invalid Bytea syntax"; -#[inline(always)] +#[function("cast(varchar) -> date")] pub fn str_to_date(elem: &str) -> Result { Ok(NaiveDateWrapper::new(parse_naive_date(elem)?)) } -#[inline(always)] +#[function("cast(varchar) -> time")] pub fn str_to_time(elem: &str) -> Result { Ok(NaiveTimeWrapper::new(parse_naive_time(elem)?)) } -#[inline(always)] +#[function("cast(varchar) -> timestamp")] pub fn str_to_timestamp(elem: &str) -> Result { Ok(NaiveDateTimeWrapper::new(parse_naive_datetime(elem)?)) } @@ -184,7 +192,7 @@ pub fn i64_to_timestamptz(t: i64) -> Result { } } -#[inline(always)] +#[function("cast(varchar) -> bytea")] pub fn str_to_bytea(elem: &str) -> Result> { // Padded with whitespace str is not allowed. if elem.starts_with(' ') && elem.trim().starts_with("\\x") { @@ -250,7 +258,9 @@ pub fn parse_bytes_traditional(s: &str) -> Result> { Ok(out) } -#[inline(always)] +#[function("cast(varchar) -> *number")] +#[function("cast(varchar) -> interval")] +#[function("cast(varchar) -> jsonb")] pub fn str_parse(elem: &str) -> Result where T: FromStr, @@ -261,67 +271,74 @@ where .map_err(|_| ExprError::Parse(type_name::().into())) } -/// Define the cast function to primitive types. -/// -/// Due to the orphan rule, some data can't implement `TryFrom` trait for basic type. -/// We can only use [`ToPrimitive`] trait. -/// -/// Note: this might be lossy according to the docs from [`ToPrimitive`]: -/// > On the other hand, conversions with possible precision loss or truncation -/// are admitted, like an `f32` with a decimal part to an integer type, or -/// even a large `f64` saturating to `f32` infinity. -macro_rules! define_cast_to_primitive { - ($ty:ty) => { - define_cast_to_primitive! { $ty, $ty } - }; - ($ty:ty, $wrapper_ty:ty) => { - paste::paste! { - #[inline(always)] - pub fn [](elem: T) -> Result<$wrapper_ty> - where - T: ToPrimitive + std::fmt::Debug, - { - elem.[]() - .ok_or_else(|| { - ExprError::CastOutOfRange( - std::any::type_name::<$ty>() - ) - }) - .map(Into::into) - } - } - }; +// Define the cast function to primitive types. +// +// Due to the orphan rule, some data can't implement `TryFrom` trait for basic type. +// We can only use [`ToPrimitive`] trait. +// +// Note: this might be lossy according to the docs from [`ToPrimitive`]: +// > On the other hand, conversions with possible precision loss or truncation +// are admitted, like an `f32` with a decimal part to an integer type, or +// even a large `f64` saturating to `f32` infinity. + +#[function("cast(float32) -> int16")] +#[function("cast(float64) -> int16")] +pub fn to_i16(elem: T) -> Result { + elem.to_i16().ok_or(ExprError::CastOutOfRange("i16")) } -define_cast_to_primitive! { i16 } -define_cast_to_primitive! { i32 } -define_cast_to_primitive! { i64 } -define_cast_to_primitive! { f32, OrderedF32 } -define_cast_to_primitive! { f64, OrderedF64 } +#[function("cast(float32) -> int32")] +#[function("cast(float64) -> int32")] +pub fn to_i32(elem: T) -> Result { + elem.to_i32().ok_or(ExprError::CastOutOfRange("i32")) +} + +#[function("cast(float32) -> int64")] +#[function("cast(float64) -> int64")] +pub fn to_i64(elem: T) -> Result { + elem.to_i64().ok_or(ExprError::CastOutOfRange("i64")) +} + +#[function("cast(int32) -> float32")] +#[function("cast(int64) -> float32")] +#[function("cast(float64) -> float32")] +#[function("cast(decimal) -> float32")] +pub fn to_f32(elem: T) -> Result { + elem.to_f32() + .map(Into::into) + .ok_or(ExprError::CastOutOfRange("f32")) +} + +#[function("cast(decimal) -> float64")] +pub fn to_f64(elem: T) -> Result { + elem.to_f64() + .map(Into::into) + .ok_or(ExprError::CastOutOfRange("f64")) +} // In postgresSql, the behavior of casting decimal to integer is rounding. // We should write them separately -#[inline(always)] +#[function("cast(decimal) -> int16")] pub fn dec_to_i16(elem: Decimal) -> Result { to_i16(elem.round_dp(0)) } -#[inline(always)] +#[function("cast(decimal) -> int32")] pub fn dec_to_i32(elem: Decimal) -> Result { to_i32(elem.round_dp(0)) } -#[inline(always)] +#[function("cast(decimal) -> int64")] pub fn dec_to_i64(elem: Decimal) -> Result { to_i64(elem.round_dp(0)) } -#[inline(always)] +#[function("cast(jsonb) -> boolean")] pub fn jsonb_to_bool(v: JsonbRef<'_>) -> Result { v.as_bool().map_err(|e| ExprError::Parse(e.into())) } -#[inline(always)] +#[function("cast(jsonb) -> decimal")] pub fn jsonb_to_dec(v: JsonbRef<'_>) -> Result { v.as_number() .map_err(|e| ExprError::Parse(e.into())) @@ -335,38 +352,38 @@ pub fn jsonb_to_dec(v: JsonbRef<'_>) -> Result { /// Note that PostgreSQL casts JSON numbers from arbitrary precision `numeric` but we use `f64`. /// This is less powerful but still meets RFC 8259 interoperability. macro_rules! define_jsonb_to_number { - ($ty:ty) => { - define_jsonb_to_number! { $ty, $ty } + ($ty:ty, $sig:literal) => { + define_jsonb_to_number! { $ty, $ty, $sig } }; - ($ty:ty, $wrapper_ty:ty) => { + ($ty:ty, $wrapper_ty:ty, $sig:literal) => { paste::paste! { - #[inline(always)] + #[function($sig)] pub fn [](v: JsonbRef<'_>) -> Result<$wrapper_ty> { v.as_number().map_err(|e| ExprError::Parse(e.into())).and_then([]) } } }; } -define_jsonb_to_number! { i16 } -define_jsonb_to_number! { i32 } -define_jsonb_to_number! { i64 } -define_jsonb_to_number! { f32, OrderedF32 } -define_jsonb_to_number! { f64, OrderedF64 } +define_jsonb_to_number! { i16, "cast(jsonb) -> int16" } +define_jsonb_to_number! { i32, "cast(jsonb) -> int32" } +define_jsonb_to_number! { i64, "cast(jsonb) -> int64" } +define_jsonb_to_number! { f32, OrderedF32, "cast(jsonb) -> float32" } +define_jsonb_to_number! { f64, OrderedF64, "cast(jsonb) -> float64" } /// In `PostgreSQL`, casting from timestamp to date discards the time part. -#[inline(always)] +#[function("cast(timestamp) -> date")] pub fn timestamp_to_date(elem: NaiveDateTimeWrapper) -> NaiveDateWrapper { NaiveDateWrapper(elem.0.date()) } /// In `PostgreSQL`, casting from timestamp to time discards the date part. -#[inline(always)] +#[function("cast(timestamp) -> time")] pub fn timestamp_to_time(elem: NaiveDateTimeWrapper) -> NaiveTimeWrapper { NaiveTimeWrapper(elem.0.time()) } /// In `PostgreSQL`, casting from interval to time discards the days part. -#[inline(always)] +#[function("cast(interval) -> time")] pub fn interval_to_time(elem: IntervalUnit) -> NaiveTimeWrapper { let usecs = elem.get_usecs_of_day(); let secs = (usecs / 1_000_000) as u32; @@ -374,7 +391,11 @@ pub fn interval_to_time(elem: IntervalUnit) -> NaiveTimeWrapper { NaiveTimeWrapper::from_num_seconds_from_midnight_uncheck(secs, nano) } -#[inline(always)] +#[function("cast(boolean) -> int32")] +#[function("cast(int32) -> int16")] +#[function("cast(int64) -> int16")] +#[function("cast(int64) -> int32")] +#[function("cast(int64) -> float64")] pub fn try_cast(elem: T1) -> Result where T1: TryInto + std::fmt::Debug + Copy, @@ -384,7 +405,21 @@ where .map_err(|_| ExprError::CastOutOfRange(std::any::type_name::())) } -#[inline(always)] +#[function("cast(int16) -> int32")] +#[function("cast(int16) -> int64")] +#[function("cast(int16) -> float32")] +#[function("cast(int16) -> float64")] +#[function("cast(int16) -> decimal")] +#[function("cast(int32) -> int64")] +#[function("cast(int32) -> float64")] +#[function("cast(int32) -> decimal")] +#[function("cast(int64) -> decimal")] +#[function("cast(float32) -> float64")] +#[function("cast(float32) -> decimal")] +#[function("cast(float64) -> decimal")] +#[function("cast(date) -> timestamp")] +#[function("cast(time) -> interval")] +#[function("cast(varchar) -> varchar")] pub fn cast(elem: T1) -> T2 where T1: Into, @@ -392,7 +427,7 @@ where elem.into() } -#[inline(always)] +#[function("cast(varchar) -> boolean")] pub fn str_to_bool(input: &str) -> Result { let trimmed_input = input.trim(); if TRUE_BOOL_LITERALS @@ -410,17 +445,26 @@ pub fn str_to_bool(input: &str) -> Result { } } +#[function("cast(int32) -> boolean")] pub fn int32_to_bool(input: i32) -> Result { Ok(input != 0) } // For most of the types, cast them to varchar is similar to return their text format. // So we use this function to cast type to varchar. +#[function("cast(*number) -> varchar")] +#[function("cast(time) -> varchar")] +#[function("cast(date) -> varchar")] +#[function("cast(interval) -> varchar")] +#[function("cast(timestamp) -> varchar")] +#[function("cast(jsonb) -> varchar")] +#[function("cast(list) -> varchar")] pub fn general_to_text(elem: impl ToText, mut writer: &mut dyn Write) -> Result<()> { elem.write(&mut writer).unwrap(); Ok(()) } +#[function("cast(boolean) -> varchar")] pub fn bool_to_varchar(input: bool, writer: &mut dyn Write) -> Result<()> { writer .write_str(if input { "true" } else { "false" }) @@ -430,6 +474,7 @@ pub fn bool_to_varchar(input: bool, writer: &mut dyn Write) -> Result<()> { /// `bool_out` is different from `general_to_string` to produce a single char. `PostgreSQL` /// uses different variants of bool-to-string in different situations. +#[function("bool_out(boolean) -> varchar")] pub fn bool_out(input: bool, writer: &mut dyn Write) -> Result<()> { writer.write_str(if input { "t" } else { "f" }).unwrap(); Ok(()) @@ -471,292 +516,184 @@ pub fn literal_parsing( Ok(scalar) } -/// It accepts a macro whose input is `{ $input:ident, $cast:ident, $func:expr }` tuples -/// -/// * `$input`: input type -/// * `$cast`: The cast type in that the operation will calculate -/// * `$func`: The scalar function for expression, it's a generic function and specialized by the -/// type of `$input, $cast` -/// * `$infallible`: Whether the cast is infallible -#[macro_export] -macro_rules! for_all_cast_variants { - ($macro:ident) => { - $macro! { - { varchar, date, str_to_date, false }, - { varchar, time, str_to_time, false }, - { varchar, interval, str_parse, false }, - { varchar, timestamp, str_to_timestamp, false }, - { varchar, int16, str_parse, false }, - { varchar, int32, str_parse, false }, - { varchar, int64, str_parse, false }, - { varchar, float32, str_parse, false }, - { varchar, float64, str_parse, false }, - { varchar, decimal, str_parse, false }, - { varchar, boolean, str_to_bool, false }, - { varchar, bytea, str_to_bytea, false }, - { varchar, jsonb, str_parse, false }, - // `str_to_list` requires `target_elem_type` and is handled elsewhere - - { boolean, varchar, bool_to_varchar, false }, - { int16, varchar, general_to_text, false }, - { int32, varchar, general_to_text, false }, - { int64, varchar, general_to_text, false }, - { float32, varchar, general_to_text, false }, - { float64, varchar, general_to_text, false }, - { decimal, varchar, general_to_text, false }, - { time, varchar, general_to_text, false }, - { interval, varchar, general_to_text, false }, - { date, varchar, general_to_text, false }, - { timestamp, varchar, general_to_text, false }, - { jsonb, varchar, |x, w| general_to_text(x, w), false }, - { list, varchar, |x, w| general_to_text(x, w), false }, - - { jsonb, boolean, jsonb_to_bool, false }, - { jsonb, int16, jsonb_to_i16, false }, - { jsonb, int32, jsonb_to_i32, false }, - { jsonb, int64, jsonb_to_i64, false }, - { jsonb, decimal, jsonb_to_dec, false }, - { jsonb, float32, jsonb_to_f32, false }, - { jsonb, float64, jsonb_to_f64, false }, - - { boolean, int32, try_cast, false }, - { int32, boolean, int32_to_bool, false }, - - { int16, int32, cast::, true }, - { int16, int64, cast::, true }, - { int16, float32, cast::, true }, - { int16, float64, cast::, true }, - { int16, decimal, cast::, true }, - { int32, int16, try_cast, false }, - { int32, int64, cast::, true }, - { int32, float32, to_f32, false }, // lossy - { int32, float64, cast::, true }, - { int32, decimal, cast::, true }, - { int64, int16, try_cast, false }, - { int64, int32, try_cast, false }, - { int64, float32, to_f32, false }, // lossy - { int64, float64, to_f64, false }, // lossy - { int64, decimal, cast::, true }, - - { float32, float64, cast::, true }, - { float32, decimal, cast::, true }, - { float32, int16, to_i16, false }, - { float32, int32, to_i32, false }, - { float32, int64, to_i64, false }, - { float64, decimal, cast::, true }, - { float64, int16, to_i16, false }, - { float64, int32, to_i32, false }, - { float64, int64, to_i64, false }, - { float64, float32, to_f32, false }, // lossy - - { decimal, int16, dec_to_i16, false }, - { decimal, int32, dec_to_i32, false }, - { decimal, int64, dec_to_i64, false }, - { decimal, float32, to_f32, false }, - { decimal, float64, to_f64, false }, - - { date, timestamp, cast::, true }, - { time, interval, cast::, true }, - { timestamp, date, timestamp_to_date, true }, - { timestamp, time, timestamp_to_time, true }, - { interval, time, interval_to_time, true } - } - }; -} - // TODO(nanderstabel): optimize for multidimensional List. Depth can be given as a parameter to this // function. -fn unnest(input: &str) -> Result> { - // Trim input +/// Takes a string input in the form of a comma-separated list enclosed in braces, and returns a +/// vector of strings containing the list items. +/// +/// # Examples +/// - "{1, 2, 3}" => ["1", "2", "3"] +/// - "{1, {2, 3}}" => ["1", "{2, 3}"] +fn unnest(input: &str) -> Result> { let trimmed = input.trim(); - - let mut chars = trimmed.chars(); - if chars.next() != Some('{') || chars.next_back() != Some('}') { + if !trimmed.starts_with('{') || !trimmed.ends_with('}') { return Err(ExprError::Parse("Input must be braced".into())); } + let trimmed = &trimmed[1..trimmed.len() - 1]; let mut items = Vec::new(); - while let Some(c) = chars.next() { + let mut depth = 0; + let mut start = 0; + for (i, c) in trimmed.chars().enumerate() { match c { - '{' => { - let mut string = String::from(c); - let mut depth = 1; - while depth != 0 { - let c = match chars.next() { - Some(c) => { - if c == '{' { - depth += 1; - } else if c == '}' { - depth -= 1; - } - c - } - None => { - return Err(ExprError::Parse( - "Missing closing brace '}}' character".into(), - )) - } - }; - string.push(c); - } - items.push(string); - } - '}' => { - return Err(ExprError::Parse( - "Unexpected closing brace '}}' character".into(), - )) + '{' => depth += 1, + '}' => depth -= 1, + ',' if depth == 0 => { + let item = trimmed[start..i].trim(); + items.push(item); + start = i + 1; } - ',' => {} - c if c.is_whitespace() => {} - c => items.push(format!( - "{}{}", - c, - chars.take_while_ref(|&c| c != ',').collect::() - )), + _ => {} } } + if depth != 0 { + return Err(ExprError::Parse("Unbalanced braces".into())); + } + let last = trimmed[start..].trim(); + if !last.is_empty() { + items.push(last); + } Ok(items) } -#[inline(always)] -pub fn str_to_list(input: &str, target_elem_type: &DataType) -> Result { - // Return a new ListValue. - // For each &str in the comma separated input a ScalarRefImpl is initialized which in turn - // is cast into the target DataType. If the target DataType is of type Varchar, then - // no casting is needed. - Ok(ListValue::new( - unnest(input)? - .iter() - .map(|s| { - Some(ScalarRefImpl::Utf8(s.trim())) - .map(|scalar_ref| match target_elem_type { - DataType::Varchar => Ok(scalar_ref.into_scalar_impl()), - _ => scalar_cast(scalar_ref, &DataType::Varchar, target_elem_type), - }) - .transpose() - }) - .try_collect()?, - )) +#[build_function("cast(varchar) -> list")] +fn build_cast_str_to_list( + return_type: DataType, + children: Vec, +) -> Result { + let elem_type = match &return_type { + DataType::List { datatype } => (**datatype).clone(), + _ => panic!("expected list type"), + }; + let child = children.into_iter().next().unwrap(); + Ok(Box::new(UnaryExpression::::new( + child, + return_type, + move |x| str_to_list(x, &elem_type), + ))) +} + +fn str_to_list(input: &str, target_elem_type: &DataType) -> Result { + let cast = build( + PbType::Cast, + target_elem_type.clone(), + vec![InputRefExpression::new(DataType::Varchar, 0).boxed()], + ) + .unwrap(); + let mut values = vec![]; + for item in unnest(input)? { + let v = cast + .eval_row(&OwnedRow::new(vec![Some(item.to_string().into())])) // TODO: optimize + .now_or_never() + .unwrap()?; + values.push(v); + } + Ok(ListValue::new(values)) +} + +#[build_function("cast(list) -> list")] +fn build_cast_list_to_list( + return_type: DataType, + children: Vec, +) -> Result { + let child = children.into_iter().next().unwrap(); + let source_elem_type = match child.return_type() { + DataType::List { datatype } => (*datatype).clone(), + _ => panic!("expected list type"), + }; + let target_elem_type = match &return_type { + DataType::List { datatype } => (**datatype).clone(), + _ => panic!("expected list type"), + }; + Ok(Box::new(UnaryExpression::::new( + child, + return_type, + move |x| list_cast(x, &source_elem_type, &target_elem_type), + ))) } /// Cast array with `source_elem_type` into array with `target_elem_type` by casting each element. -/// -/// TODO: `.map(scalar_cast)` is not a preferred pattern and we should avoid it if possible. -pub fn list_cast( +fn list_cast( input: ListRef<'_>, source_elem_type: &DataType, target_elem_type: &DataType, ) -> Result { - Ok(ListValue::new( - input - .values_ref() - .into_iter() - .map(|datum_ref| { - datum_ref - .map(|scalar_ref| scalar_cast(scalar_ref, source_elem_type, target_elem_type)) - .transpose() - }) - .try_collect()?, + let cast = build( + PbType::Cast, + target_elem_type.clone(), + vec![InputRefExpression::new(source_elem_type.clone(), 0).boxed()], + ) + .unwrap(); + let elements = input.values_ref(); + let mut values = Vec::with_capacity(elements.len()); + for item in elements { + let v = cast + .eval_row(&OwnedRow::new(vec![item.map(|s| s.into_scalar_impl())])) // TODO: optimize + .now_or_never() + .unwrap()?; + values.push(v); + } + Ok(ListValue::new(values)) +} + +#[build_function("cast(struct) -> struct")] +fn build_cast_struct_to_struct( + return_type: DataType, + children: Vec, +) -> Result { + let child = children.into_iter().next().unwrap(); + let source_elem_type = match child.return_type() { + DataType::Struct(s) => (*s).clone(), + _ => panic!("expected struct type"), + }; + let target_elem_type = match &return_type { + DataType::Struct(s) => (**s).clone(), + _ => panic!("expected struct type"), + }; + Ok(Box::new( + UnaryExpression::::new(child, return_type, move |x| { + struct_cast(x, &source_elem_type, &target_elem_type) + }), )) } /// Cast struct of `source_elem_type` to `target_elem_type` by casting each element. -pub fn struct_cast( +fn struct_cast( input: StructRef<'_>, source_elem_type: &StructType, target_elem_type: &StructType, ) -> Result { - Ok(StructValue::new( - input - .fields_ref() - .into_iter() - .zip_eq_fast(source_elem_type.fields.iter()) - .zip_eq_fast(target_elem_type.fields.iter()) - .map(|((datum_ref, source_elem_type), target_elem_type)| { - if source_elem_type == target_elem_type { - return Ok(datum_ref.map(|scalar_ref| scalar_ref.into_scalar_impl())); - } - datum_ref - .map(|scalar_ref| scalar_cast(scalar_ref, source_elem_type, target_elem_type)) - .transpose() - }) - .try_collect()?, - )) -} - -/// Cast scalar ref with `source_type` into owned scalar with `target_type`. This function forms a -/// mutual recursion with `list_cast` so that we can cast nested lists (e.g., varchar[][] to -/// int[][]). -fn scalar_cast( - source: ScalarRefImpl<'_>, - source_type: &DataType, - target_type: &DataType, -) -> Result { - use crate::expr::data_types::*; - - match (source_type, target_type) { - (DataType::Struct(source_type), DataType::Struct(target_type)) => { - Ok(struct_cast(source.try_into()?, source_type, target_type)?.to_scalar_value()) - } - ( - DataType::List { - datatype: source_elem_type, - }, - DataType::List { - datatype: target_elem_type, - }, - ) => list_cast(source.try_into()?, source_elem_type, target_elem_type) - .map(Scalar::to_scalar_value), - ( - DataType::Varchar, - DataType::List { - datatype: target_elem_type, - }, - ) => str_to_list(source.try_into()?, target_elem_type).map(Scalar::to_scalar_value), - (source_type, target_type) => { - macro_rules! gen_cast_impl { - ($( { $input:ident, $cast:ident, $func:expr, $infallible:ident } ),*) => { - match (source_type, target_type) { - $( - ($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func, $infallible), - )* - _ => { - return Err(ExprError::UnsupportedCast(source_type.clone(), target_type.clone())); - } - } - }; - (arm: $input:ident, varchar, $func:expr, false) => { - { - let source: <$input! { type_array } as Array>::RefItem<'_> = source.try_into()?; - let mut writer = String::new(); - let target: Result<()> = $func(source, &mut writer); - target.map(|_| Scalar::to_scalar_value(writer.into_boxed_str())) - } - }; - (arm: $input:ident, $cast:ident, $func:expr, false) => { - { - let source: <$input! { type_array } as Array>::RefItem<'_> = source.try_into()?; - let target: Result<<$cast! { type_array } as Array>::OwnedItem> = $func(source); - target.map(Scalar::to_scalar_value) - } - }; - (arm: $input:ident, $cast:ident, $func:expr, true) => { - { - let source: <$input! { type_array } as Array>::RefItem<'_> = source.try_into()?; - let target: Result<<$cast! { type_array } as Array>::OwnedItem> = Ok($func(source)); - target.map(Scalar::to_scalar_value) - } - }; + let fields = (input.fields_ref().into_iter()) + .zip_eq_fast(source_elem_type.fields.iter()) + .zip_eq_fast(target_elem_type.fields.iter()) + .map(|((datum_ref, source_field_type), target_field_type)| { + if source_field_type == target_field_type { + return Ok(datum_ref.map(|scalar_ref| scalar_ref.into_scalar_impl())); } - for_all_cast_variants!(gen_cast_impl) - } - } + let cast = build( + PbType::Cast, + target_field_type.clone(), + vec![InputRefExpression::new(source_field_type.clone(), 0).boxed()], + ) + .unwrap(); + let value = match datum_ref { + Some(scalar_ref) => cast + .eval_row(&OwnedRow::new(vec![Some(scalar_ref.into_scalar_impl())])) + .now_or_never() + .unwrap()?, + None => None, + }; + Ok(value) as Result<_> + }) + .try_collect()?; + Ok(StructValue::new(fields)) } #[cfg(test)] mod tests { use num_traits::FromPrimitive; + use risingwave_common::types::Scalar; use super::*; @@ -866,6 +803,7 @@ mod tests { #[test] fn test_unnest() { + assert_eq!(unnest("{ }").unwrap(), vec![] as Vec); assert_eq!( unnest("{1, 2, 3}").unwrap(), vec!["1".to_string(), "2".to_string(), "3".to_string()] @@ -1076,4 +1014,15 @@ mod tests { let timestamp2 = str_to_timestamp(str2).unwrap(); assert_eq!(timestamp2.0.timestamp_micros(), -1); } + + #[test] + fn test_timestamp() { + assert_eq!( + try_cast::<_, NaiveDateTimeWrapper>(NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1)) + .unwrap(), + NaiveDateTimeWrapper::new( + NaiveDateTime::parse_from_str("1994-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap() + ) + ) + } } diff --git a/src/expr/src/vector_op/cmp.rs b/src/expr/src/vector_op/cmp.rs index f329595b4d8e3..dde6808adeb7a 100644 --- a/src/expr/src/vector_op/cmp.rs +++ b/src/expr/src/vector_op/cmp.rs @@ -14,11 +14,24 @@ use std::fmt::Debug; -use risingwave_common::array::{ListRef, StructRef}; +use risingwave_common::array::{Array, BoolArray}; +use risingwave_common::buffer::Bitmap; +use risingwave_expr_macro::function; -use crate::Result; - -#[inline(always)] +#[function("equal(*number, *number) -> boolean")] +#[function("equal(serial, serial) -> boolean")] +#[function("equal(date, date) -> boolean")] +#[function("equal(time, time) -> boolean")] +#[function("equal(interval, interval) -> boolean")] +#[function("equal(timestamp, timestamp) -> boolean")] +#[function("equal(timestamptz, timestamptz) -> boolean")] +#[function("equal(date, timestamp) -> boolean")] +#[function("equal(timestamp, date) -> boolean")] +#[function("equal(time, interval) -> boolean")] +#[function("equal(interval, time) -> boolean")] +#[function("equal(varchar, varchar) -> boolean")] +#[function("equal(list, list) -> boolean")] +#[function("equal(struct, struct) -> boolean")] pub fn general_eq(l: T1, r: T2) -> bool where T1: Into + Debug, @@ -28,7 +41,20 @@ where l.into() == r.into() } -#[inline(always)] +#[function("not_equal(*number, *number) -> boolean")] +#[function("not_equal(serial, serial) -> boolean")] +#[function("not_equal(date, date) -> boolean")] +#[function("not_equal(time, time) -> boolean")] +#[function("not_equal(interval, interval) -> boolean")] +#[function("not_equal(timestamp, timestamp) -> boolean")] +#[function("not_equal(timestamptz, timestamptz) -> boolean")] +#[function("not_equal(date, timestamp) -> boolean")] +#[function("not_equal(timestamp, date) -> boolean")] +#[function("not_equal(time, interval) -> boolean")] +#[function("not_equal(interval, time) -> boolean")] +#[function("not_equal(varchar, varchar) -> boolean")] +#[function("not_equal(list, list) -> boolean")] +#[function("not_equal(struct, struct) -> boolean")] pub fn general_ne(l: T1, r: T2) -> bool where T1: Into + Debug, @@ -38,7 +64,20 @@ where l.into() != r.into() } -#[inline(always)] +#[function("greater_than_or_equal(*number, *number) -> boolean")] +#[function("greater_than_or_equal(serial, serial) -> boolean")] +#[function("greater_than_or_equal(date, date) -> boolean")] +#[function("greater_than_or_equal(time, time) -> boolean")] +#[function("greater_than_or_equal(interval, interval) -> boolean")] +#[function("greater_than_or_equal(timestamp, timestamp) -> boolean")] +#[function("greater_than_or_equal(timestamptz, timestamptz) -> boolean")] +#[function("greater_than_or_equal(date, timestamp) -> boolean")] +#[function("greater_than_or_equal(timestamp, date) -> boolean")] +#[function("greater_than_or_equal(time, interval) -> boolean")] +#[function("greater_than_or_equal(interval, time) -> boolean")] +#[function("greater_than_or_equal(varchar, varchar) -> boolean")] +#[function("greater_than_or_equal(list, list) -> boolean")] +#[function("greater_than_or_equal(struct, struct) -> boolean")] pub fn general_ge(l: T1, r: T2) -> bool where T1: Into + Debug, @@ -48,7 +87,20 @@ where l.into() >= r.into() } -#[inline(always)] +#[function("greater_than(*number, *number) -> boolean")] +#[function("greater_than(serial, serial) -> boolean")] +#[function("greater_than(date, date) -> boolean")] +#[function("greater_than(time, time) -> boolean")] +#[function("greater_than(interval, interval) -> boolean")] +#[function("greater_than(timestamp, timestamp) -> boolean")] +#[function("greater_than(timestamptz, timestamptz) -> boolean")] +#[function("greater_than(date, timestamp) -> boolean")] +#[function("greater_than(timestamp, date) -> boolean")] +#[function("greater_than(time, interval) -> boolean")] +#[function("greater_than(interval, time) -> boolean")] +#[function("greater_than(varchar, varchar) -> boolean")] +#[function("greater_than(list, list) -> boolean")] +#[function("greater_than(struct, struct) -> boolean")] pub fn general_gt(l: T1, r: T2) -> bool where T1: Into + Debug, @@ -58,7 +110,20 @@ where l.into() > r.into() } -#[inline(always)] +#[function("less_than_or_equal(*number, *number) -> boolean")] +#[function("less_than_or_equal(serial, serial) -> boolean")] +#[function("less_than_or_equal(date, date) -> boolean")] +#[function("less_than_or_equal(time, time) -> boolean")] +#[function("less_than_or_equal(interval, interval) -> boolean")] +#[function("less_than_or_equal(timestamp, timestamp) -> boolean")] +#[function("less_than_or_equal(timestamptz, timestamptz) -> boolean")] +#[function("less_than_or_equal(date, timestamp) -> boolean")] +#[function("less_than_or_equal(timestamp, date) -> boolean")] +#[function("less_than_or_equal(time, interval) -> boolean")] +#[function("less_than_or_equal(interval, time) -> boolean")] +#[function("less_than_or_equal(varchar, varchar) -> boolean")] +#[function("less_than_or_equal(list, list) -> boolean")] +#[function("less_than_or_equal(struct, struct) -> boolean")] pub fn general_le(l: T1, r: T2) -> bool where T1: Into + Debug, @@ -68,7 +133,20 @@ where l.into() <= r.into() } -#[inline(always)] +#[function("less_than(*number, *number) -> boolean")] +#[function("less_than(serial, serial) -> boolean")] +#[function("less_than(date, date) -> boolean")] +#[function("less_than(time, time) -> boolean")] +#[function("less_than(interval, interval) -> boolean")] +#[function("less_than(timestamp, timestamp) -> boolean")] +#[function("less_than(timestamptz, timestamptz) -> boolean")] +#[function("less_than(date, timestamp) -> boolean")] +#[function("less_than(timestamp, date) -> boolean")] +#[function("less_than(time, interval) -> boolean")] +#[function("less_than(interval, time) -> boolean")] +#[function("less_than(varchar, varchar) -> boolean")] +#[function("less_than(list, list) -> boolean")] +#[function("less_than(struct, struct) -> boolean")] pub fn general_lt(l: T1, r: T2) -> bool where T1: Into + Debug, @@ -78,136 +156,197 @@ where l.into() < r.into() } +#[function("is_distinct_from(*number, *number) -> boolean")] +#[function("is_distinct_from(serial, serial) -> boolean")] +#[function("is_distinct_from(date, date) -> boolean")] +#[function("is_distinct_from(time, time) -> boolean")] +#[function("is_distinct_from(interval, interval) -> boolean")] +#[function("is_distinct_from(timestamp, timestamp) -> boolean")] +#[function("is_distinct_from(timestamptz, timestamptz) -> boolean")] +#[function("is_distinct_from(date, timestamp) -> boolean")] +#[function("is_distinct_from(timestamp, date) -> boolean")] +#[function("is_distinct_from(time, interval) -> boolean")] +#[function("is_distinct_from(interval, time) -> boolean")] +#[function("is_distinct_from(varchar, varchar) -> boolean")] +#[function("is_distinct_from(list, list) -> boolean")] +#[function("is_distinct_from(struct, struct) -> boolean")] pub fn general_is_distinct_from(l: Option, r: Option) -> bool where T1: Into + Debug, T2: Into + Debug, T3: Ord, { - match (l, r) { - (Some(lv), Some(rv)) => general_ne::(lv, rv), - (Some(_), None) => true, - (None, Some(_)) => true, - (None, None) => false, - } + l.map(Into::into) != r.map(Into::into) } +#[function("is_not_distinct_from(*number, *number) -> boolean")] +#[function("is_not_distinct_from(serial, serial) -> boolean")] +#[function("is_not_distinct_from(date, date) -> boolean")] +#[function("is_not_distinct_from(time, time) -> boolean")] +#[function("is_not_distinct_from(interval, interval) -> boolean")] +#[function("is_not_distinct_from(timestamp, timestamp) -> boolean")] +#[function("is_not_distinct_from(timestamptz, timestamptz) -> boolean")] +#[function("is_not_distinct_from(date, timestamp) -> boolean")] +#[function("is_not_distinct_from(timestamp, date) -> boolean")] +#[function("is_not_distinct_from(time, interval) -> boolean")] +#[function("is_not_distinct_from(interval, time) -> boolean")] +#[function("is_not_distinct_from(varchar, varchar) -> boolean")] +#[function("is_not_distinct_from(list, list) -> boolean")] +#[function("is_not_distinct_from(struct, struct) -> boolean")] pub fn general_is_not_distinct_from(l: Option, r: Option) -> bool where T1: Into + Debug, T2: Into + Debug, T3: Ord, { - match (l, r) { - (Some(lv), Some(rv)) => general_eq::(lv, rv), - (Some(_), None) => false, - (None, Some(_)) => false, - (None, None) => true, - } + l.map(Into::into) == r.map(Into::into) } -#[derive(Clone, Copy, Debug)] -pub enum Comparison { - Eq, - Ne, - Lt, - Gt, - Le, - Ge, -} - -pub(crate) static EQ: Comparison = Comparison::Eq; -pub(crate) static NE: Comparison = Comparison::Ne; -pub(crate) static LT: Comparison = Comparison::Lt; -pub(crate) static GT: Comparison = Comparison::Gt; -pub(crate) static LE: Comparison = Comparison::Le; -pub(crate) static GE: Comparison = Comparison::Ge; - -#[inline(always)] -pub fn gen_struct_cmp(op: Comparison) -> fn(StructRef<'_>, StructRef<'_>) -> Result { - use crate::gen_cmp; - gen_cmp!(op) -} - -#[inline(always)] -pub fn gen_list_cmp(op: Comparison) -> fn(ListRef<'_>, ListRef<'_>) -> Result { - use crate::gen_cmp; - gen_cmp!(op) -} - -#[inline(always)] -pub fn gen_str_cmp(op: Comparison) -> fn(&str, &str) -> Result { - use crate::gen_cmp; - gen_cmp!(op) -} - -#[macro_export] -macro_rules! gen_cmp { - ($op:expr) => { - match $op { - Comparison::Eq => |l, r| Ok(l == r), - Comparison::Ne => |l, r| Ok(l != r), - Comparison::Lt => |l, r| Ok(l < r), - Comparison::Gt => |l, r| Ok(l > r), - Comparison::Le => |l, r| Ok(l <= r), - Comparison::Ge => |l, r| Ok(l >= r), - } - }; -} - -pub fn str_is_distinct_from(l: Option<&str>, r: Option<&str>) -> Result> { - match (l, r) { - (Some(lv), Some(rv)) => Ok(Some(lv != rv)), - (Some(_), None) => Ok(Some(true)), - (None, Some(_)) => Ok(Some(true)), - (None, None) => Ok(Some(false)), - } +#[function("equal(boolean, boolean) -> boolean", batch = "boolarray_eq")] +pub fn boolean_eq(l: bool, r: bool) -> bool { + l == r } -pub fn str_is_not_distinct_from(l: Option<&str>, r: Option<&str>) -> Result> { - match (l, r) { - (Some(lv), Some(rv)) => Ok(Some(lv == rv)), - (Some(_), None) => Ok(Some(false)), - (None, Some(_)) => Ok(Some(false)), - (None, None) => Ok(Some(true)), - } +#[function("not_equal(boolean, boolean) -> boolean", batch = "boolarray_ne")] +pub fn boolean_ne(l: bool, r: bool) -> bool { + l != r +} + +#[function( + "greater_than_or_equal(boolean, boolean) -> boolean", + batch = "boolarray_ge" +)] +pub fn boolean_ge(l: bool, r: bool) -> bool { + l >= r +} + +#[allow(clippy::bool_comparison)] +#[function("greater_than(boolean, boolean) -> boolean", batch = "boolarray_gt")] +pub fn boolean_gt(l: bool, r: bool) -> bool { + l > r +} + +#[function( + "less_than_or_equal(boolean, boolean) -> boolean", + batch = "boolarray_le" +)] +pub fn boolean_le(l: bool, r: bool) -> bool { + l <= r } -#[inline(always)] -pub fn is_true(v: Option) -> Option { - Some(v == Some(true)) +#[allow(clippy::bool_comparison)] +#[function("less_than(boolean, boolean) -> boolean", batch = "boolarray_lt")] +pub fn boolean_lt(l: bool, r: bool) -> bool { + l < r } -#[inline(always)] -pub fn is_not_true(v: Option) -> Option { - Some(v != Some(true)) +#[function( + "is_distinct_from(boolean, boolean) -> boolean", + batch = "boolarray_is_distinct_from" +)] +pub fn boolean_is_distinct_from(l: Option, r: Option) -> bool { + l != r } -#[inline(always)] -pub fn is_false(v: Option) -> Option { - Some(v == Some(false)) +#[function( + "is_not_distinct_from(boolean, boolean) -> boolean", + batch = "boolarray_is_not_distinct_from" +)] +pub fn boolean_is_not_distinct_from(l: Option, r: Option) -> bool { + l == r +} + +#[function("is_true(boolean) -> boolean", batch = "boolarray_is_true")] +pub fn is_true(v: Option) -> bool { + v == Some(true) +} + +#[function("is_not_true(boolean) -> boolean", batch = "boolarray_is_not_true")] +pub fn is_not_true(v: Option) -> bool { + v != Some(true) +} + +#[function("is_false(boolean) -> boolean", batch = "boolarray_is_false")] +pub fn is_false(v: Option) -> bool { + v == Some(false) +} + +#[function("is_not_false(boolean) -> boolean", batch = "boolarray_is_not_false")] +pub fn is_not_false(v: Option) -> bool { + v != Some(false) +} + +// optimized functions for bool arrays + +fn boolarray_eq(l: &BoolArray, r: &BoolArray) -> BoolArray { + let data = !(l.data() ^ r.data()); + let bitmap = l.null_bitmap() & r.null_bitmap(); + BoolArray::new(data, bitmap) } -#[inline(always)] -pub fn is_not_false(v: Option) -> Option { - Some(v != Some(false)) +fn boolarray_ne(l: &BoolArray, r: &BoolArray) -> BoolArray { + let data = l.data() ^ r.data(); + let bitmap = l.null_bitmap() & r.null_bitmap(); + BoolArray::new(data, bitmap) } -#[inline(always)] -pub fn is_unknown(v: Option) -> Option { - Some(v.is_none()) +fn boolarray_gt(l: &BoolArray, r: &BoolArray) -> BoolArray { + let data = l.data() & !r.data(); + let bitmap = l.null_bitmap() & r.null_bitmap(); + BoolArray::new(data, bitmap) } -#[inline(always)] -pub fn is_not_unknown(v: Option) -> Option { - Some(v.is_some()) +fn boolarray_lt(l: &BoolArray, r: &BoolArray) -> BoolArray { + let data = !l.data() & r.data(); + let bitmap = l.null_bitmap() & r.null_bitmap(); + BoolArray::new(data, bitmap) +} + +fn boolarray_ge(l: &BoolArray, r: &BoolArray) -> BoolArray { + let data = l.data() | !r.data(); + let bitmap = l.null_bitmap() & r.null_bitmap(); + BoolArray::new(data, bitmap) +} + +fn boolarray_le(l: &BoolArray, r: &BoolArray) -> BoolArray { + let data = !l.data() | r.data(); + let bitmap = l.null_bitmap() & r.null_bitmap(); + BoolArray::new(data, bitmap) +} + +fn boolarray_is_distinct_from(l: &BoolArray, r: &BoolArray) -> BoolArray { + let data = ((l.data() ^ r.data()) & (l.null_bitmap() & r.null_bitmap())) + | (l.null_bitmap() ^ r.null_bitmap()); + BoolArray::new(data, Bitmap::ones(l.len())) +} + +fn boolarray_is_not_distinct_from(l: &BoolArray, r: &BoolArray) -> BoolArray { + let data = !(((l.data() ^ r.data()) & (l.null_bitmap() & r.null_bitmap())) + | (l.null_bitmap() ^ r.null_bitmap())); + BoolArray::new(data, Bitmap::ones(l.len())) +} + +fn boolarray_is_true(a: &BoolArray) -> BoolArray { + BoolArray::new(a.to_bitmap(), Bitmap::ones(a.len())) +} + +fn boolarray_is_not_true(a: &BoolArray) -> BoolArray { + BoolArray::new(!a.to_bitmap(), Bitmap::ones(a.len())) +} + +fn boolarray_is_false(a: &BoolArray) -> BoolArray { + BoolArray::new(!a.data() & a.null_bitmap(), Bitmap::ones(a.len())) +} + +fn boolarray_is_not_false(a: &BoolArray) -> BoolArray { + BoolArray::new(a.data() | !a.null_bitmap(), Bitmap::ones(a.len())) } #[cfg(test)] mod tests { use std::str::FromStr; - use risingwave_common::types::Decimal; + use risingwave_common::types::{Decimal, OrderedF32, OrderedF64}; use super::*; @@ -218,4 +357,71 @@ mod tests { 1.1f32 )) } + + #[test] + fn test_comparison() { + assert!(general_eq::(dec("1.0"), 1)); + assert!(general_eq::(dec("1.0"), 1.0)); + assert!(!general_ne::(dec("1.0"), 1)); + assert!(!general_ne::(dec("1.0"), 1.0)); + assert!(!general_gt::(dec("1.0"), 2)); + assert!(!general_gt::(dec("1.0"), 2.0)); + assert!(general_le::(dec("1.0"), 2)); + assert!(general_le::(dec("1.0"), 2.1)); + assert!(!general_ge::(dec("1.0"), 2)); + assert!(!general_ge::(dec("1.0"), 2.1)); + assert!(general_lt::(dec("1.0"), 2)); + assert!(general_lt::(dec("1.0"), 2.1)); + assert!(general_is_distinct_from::( + Some(dec("1.0")), + Some(2) + )); + assert!(general_is_distinct_from::( + Some(dec("1.0")), + Some(2.0) + )); + assert!(general_is_distinct_from::( + Some(dec("1.0")), + None + )); + assert!(general_is_distinct_from::( + None, + Some(1) + )); + assert!(!general_is_distinct_from::( + Some(dec("1.0")), + Some(1) + )); + assert!(!general_is_distinct_from::( + Some(dec("1.0")), + Some(1.0) + )); + assert!(!general_is_distinct_from::( + None, None + )); + assert!(general_eq::(1.0.into(), 1)); + assert!(!general_ne::(1.0.into(), 1)); + assert!(!general_lt::(1.0.into(), 1)); + assert!(general_le::(1.0.into(), 1)); + assert!(!general_gt::(1.0.into(), 1)); + assert!(general_ge::(1.0.into(), 1)); + assert!(!general_is_distinct_from::( + Some(1.0.into()), + Some(1) + )); + assert!(general_eq::(1i64, 1)); + assert!(!general_ne::(1i64, 1)); + assert!(!general_lt::(1i64, 1)); + assert!(general_le::(1i64, 1)); + assert!(!general_gt::(1i64, 1)); + assert!(general_ge::(1i64, 1)); + assert!(!general_is_distinct_from::( + Some(1i64), + Some(1) + )); + } + + fn dec(s: &str) -> Decimal { + Decimal::from_str(s).unwrap() + } } diff --git a/src/expr/src/vector_op/concat_op.rs b/src/expr/src/vector_op/concat_op.rs index 0dc6da2e3f138..b6d362538a51f 100644 --- a/src/expr/src/vector_op/concat_op.rs +++ b/src/expr/src/vector_op/concat_op.rs @@ -14,13 +14,12 @@ use std::fmt::Write; -use crate::Result; +use risingwave_expr_macro::function; -#[inline(always)] -pub fn concat_op(left: &str, right: &str, writer: &mut dyn Write) -> Result<()> { +#[function("concat_op(varchar, varchar) -> varchar")] +pub fn concat_op(left: &str, right: &str, writer: &mut dyn Write) { writer.write_str(left).unwrap(); writer.write_str(right).unwrap(); - Ok(()) } #[cfg(test)] @@ -30,7 +29,7 @@ mod tests { #[test] fn test_concat_op() { let mut s = String::new(); - concat_op("114", "514", &mut s).unwrap(); + concat_op("114", "514", &mut s); assert_eq!(s, "114514") } } diff --git a/src/expr/src/vector_op/conjunction.rs b/src/expr/src/vector_op/conjunction.rs index 2cf0f6c5eda33..c644b4bfba0fa 100644 --- a/src/expr/src/vector_op/conjunction.rs +++ b/src/expr/src/vector_op/conjunction.rs @@ -12,58 +12,66 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::Result; +use risingwave_expr_macro::function; -#[inline(always)] -pub fn and(l: Option, r: Option) -> Result> { +// see BinaryShortCircuitExpression +// #[function("and(boolean, boolean) -> boolean")] +pub fn and(l: Option, r: Option) -> Option { match (l, r) { - (Some(lb), Some(lr)) => Ok(Some(lb & lr)), - (Some(true), None) => Ok(None), - (None, Some(true)) => Ok(None), - (Some(false), None) => Ok(Some(false)), - (None, Some(false)) => Ok(Some(false)), - (None, None) => Ok(None), + (Some(lb), Some(lr)) => Some(lb & lr), + (Some(true), None) => None, + (None, Some(true)) => None, + (Some(false), None) => Some(false), + (None, Some(false)) => Some(false), + (None, None) => None, } } -#[inline(always)] -pub fn or(l: Option, r: Option) -> Result> { +// see BinaryShortCircuitExpression +// #[function("or(boolean, boolean) -> boolean")] +pub fn or(l: Option, r: Option) -> Option { match (l, r) { - (Some(lb), Some(lr)) => Ok(Some(lb | lr)), - (Some(true), None) => Ok(Some(true)), - (None, Some(true)) => Ok(Some(true)), - (Some(false), None) => Ok(None), - (None, Some(false)) => Ok(None), - (None, None) => Ok(None), + (Some(lb), Some(lr)) => Some(lb | lr), + (Some(true), None) => Some(true), + (None, Some(true)) => Some(true), + (Some(false), None) => None, + (None, Some(false)) => None, + (None, None) => None, } } -#[inline(always)] -pub fn not(l: Option) -> Option { - l.map(|v| !v) +#[function("not(boolean) -> boolean")] +pub fn not(v: bool) -> bool { + !v } #[cfg(test)] mod tests { - use crate::vector_op::conjunction::{and, or}; + use super::*; #[test] fn test_and() { - assert_eq!(Some(true), and(Some(true), Some(true)).unwrap()); - assert_eq!(Some(false), and(Some(true), Some(false)).unwrap()); - assert_eq!(Some(false), and(Some(false), Some(false)).unwrap()); - assert_eq!(None, and(Some(true), None).unwrap()); - assert_eq!(Some(false), and(Some(false), None).unwrap()); - assert_eq!(None, and(None, None).unwrap()); + assert_eq!(Some(true), and(Some(true), Some(true))); + assert_eq!(Some(false), and(Some(true), Some(false))); + assert_eq!(Some(false), and(Some(false), Some(false))); + assert_eq!(None, and(Some(true), None)); + assert_eq!(Some(false), and(Some(false), None)); + assert_eq!(None, and(None, None)); } #[test] fn test_or() { - assert_eq!(Some(true), or(Some(true), Some(true)).unwrap()); - assert_eq!(Some(true), or(Some(true), Some(false)).unwrap()); - assert_eq!(Some(false), or(Some(false), Some(false)).unwrap()); - assert_eq!(Some(true), or(Some(true), None).unwrap()); - assert_eq!(None, or(Some(false), None).unwrap()); - assert_eq!(None, or(None, None).unwrap()); + assert_eq!(Some(true), or(Some(true), Some(true))); + assert_eq!(Some(true), or(Some(true), Some(false))); + assert_eq!(Some(false), or(Some(false), Some(false))); + assert_eq!(Some(true), or(Some(true), None)); + assert_eq!(None, or(Some(false), None)); + assert_eq!(None, or(None, None)); + } + + #[test] + fn test_not() { + assert!(!not(true)); + assert!(not(false)); } } diff --git a/src/expr/src/vector_op/date_trunc.rs b/src/expr/src/vector_op/date_trunc.rs index 0acb8f14a7453..76676c7aa4fa0 100644 --- a/src/expr/src/vector_op/date_trunc.rs +++ b/src/expr/src/vector_op/date_trunc.rs @@ -13,10 +13,12 @@ // limitations under the License. use risingwave_common::types::{IntervalUnit, NaiveDateTimeWrapper}; +use risingwave_expr_macro::function; +use super::timestamptz::{timestamp_at_time_zone, timestamptz_at_time_zone}; use crate::{ExprError, Result}; -#[inline] +#[function("date_trunc(varchar, timestamp) -> timestamp")] pub fn date_trunc_timestamp(field: &str, ts: NaiveDateTimeWrapper) -> Result { Ok(match field.to_ascii_lowercase().as_str() { "microseconds" => ts.truncate_micros(), @@ -36,7 +38,19 @@ pub fn date_trunc_timestamp(field: &str, ts: NaiveDateTimeWrapper) -> Result timestamptz")] +pub fn date_trunc_timestamptz(_field: &str, _ts: i64) -> Result { + todo!("date_trunc_timestamptz") +} + +#[function("date_trunc(varchar, timestamptz, varchar) -> timestamptz")] +pub fn date_trunc_timestamptz_at_timezone(field: &str, ts: i64, timezone: &str) -> Result { + let timestamp = timestamptz_at_time_zone(ts, timezone)?; + let truncated = date_trunc_timestamp(field, timestamp)?; + timestamp_at_time_zone(truncated, timezone) +} + +#[function("date_trunc(varchar, interval) -> interval")] pub fn date_trunc_interval(field: &str, interval: IntervalUnit) -> Result { Ok(match field.to_ascii_lowercase().as_str() { "microseconds" => interval, diff --git a/src/expr/src/vector_op/exp.rs b/src/expr/src/vector_op/exp.rs index b9ae2fb8fc4da..b4b6377570043 100644 --- a/src/expr/src/vector_op/exp.rs +++ b/src/expr/src/vector_op/exp.rs @@ -14,9 +14,11 @@ use num_traits::{Float, Zero}; use risingwave_common::types::OrderedF64; +use risingwave_expr_macro::function; use crate::{ExprError, Result}; +#[function("exp(float64) -> float64")] pub fn exp_f64(input: OrderedF64) -> Result { // The cases where the exponent value is Inf or NaN can be handled explicitly and without // evaluating the `exp` operation. diff --git a/src/expr/src/vector_op/extract.rs b/src/expr/src/vector_op/extract.rs index a8605b36ba0ef..c0882e5052135 100644 --- a/src/expr/src/vector_op/extract.rs +++ b/src/expr/src/vector_op/extract.rs @@ -14,6 +14,7 @@ use chrono::{Datelike, Timelike}; use risingwave_common::types::{Decimal, NaiveDateTimeWrapper, NaiveDateWrapper, NaiveTimeWrapper}; +use risingwave_expr_macro::function; use crate::{ExprError, Result}; @@ -51,10 +52,12 @@ fn invalid_unit(name: &'static str, unit: &str) -> ExprError { } } +#[function("extract(varchar, date) -> decimal")] pub fn extract_from_date(unit: &str, date: NaiveDateWrapper) -> Result { extract_date(date.0, unit).ok_or_else(|| invalid_unit("date unit", unit)) } +#[function("extract(varchar, timestamp) -> decimal")] pub fn extract_from_timestamp(unit: &str, timestamp: NaiveDateTimeWrapper) -> Result { let time = timestamp.0; @@ -63,6 +66,7 @@ pub fn extract_from_timestamp(unit: &str, timestamp: NaiveDateTimeWrapper) -> Re .ok_or_else(|| invalid_unit("timestamp unit", unit)) } +#[function("extract(varchar, timestamptz) -> decimal")] pub fn extract_from_timestamptz(unit: &str, usecs: i64) -> Result { match unit { "EPOCH" => Ok(Decimal::from(usecs) / 1_000_000.into()), @@ -71,6 +75,7 @@ pub fn extract_from_timestamptz(unit: &str, usecs: i64) -> Result { } } +#[function("extract(varchar, time) -> decimal")] pub fn extract_from_time(unit: &str, time: NaiveTimeWrapper) -> Result { extract_time(time.0, unit).ok_or_else(|| invalid_unit("time unit", unit)) } diff --git a/src/expr/src/vector_op/format_type.rs b/src/expr/src/vector_op/format_type.rs index b18b734989cf6..a8e497e6aaacf 100644 --- a/src/expr/src/vector_op/format_type.rs +++ b/src/expr/src/vector_op/format_type.rs @@ -13,15 +13,14 @@ // limitations under the License. use risingwave_common::types::DataType; +use risingwave_expr_macro::function; -use crate::Result; - -#[inline(always)] -pub fn format_type(oid: Option, _typemod: Option) -> Result>> { +#[function("format_type(int32, int32) -> varchar")] +pub fn format_type(oid: Option, _typemod: Option) -> Option> { // since we don't support type modifier, ignore it. - Ok(oid.map(|i| { + oid.map(|i| { DataType::from_oid(i) .map(|dt| format!("{}", dt).into_boxed_str()) .unwrap_or("???".into()) - })) + }) } diff --git a/src/expr/src/vector_op/jsonb_info.rs b/src/expr/src/vector_op/jsonb_info.rs index 064e54959b3bd..68c96de378ef0 100644 --- a/src/expr/src/vector_op/jsonb_info.rs +++ b/src/expr/src/vector_op/jsonb_info.rs @@ -15,17 +15,16 @@ use std::fmt::Write; use risingwave_common::array::JsonbRef; +use risingwave_expr_macro::function; use crate::{ExprError, Result}; -#[inline(always)] -pub fn jsonb_typeof(v: JsonbRef<'_>, writer: &mut dyn Write) -> Result<()> { - writer - .write_str(v.type_name()) - .map_err(|e| ExprError::Internal(e.into())) +#[function("jsonb_typeof(jsonb) -> varchar")] +pub fn jsonb_typeof(v: JsonbRef<'_>, writer: &mut dyn Write) { + writer.write_str(v.type_name()).unwrap() } -#[inline(always)] +#[function("jsonb_array_length(jsonb) -> int32")] pub fn jsonb_array_length(v: JsonbRef<'_>) -> Result { v.array_len() .map(|n| n as i32) diff --git a/src/expr/src/vector_op/length.rs b/src/expr/src/vector_op/length.rs index dc071b807afb0..e0ea6ff69d8da 100644 --- a/src/expr/src/vector_op/length.rs +++ b/src/expr/src/vector_op/length.rs @@ -12,21 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::Result; +use risingwave_expr_macro::function; -#[inline(always)] -pub fn length_default(s: &str) -> Result { - Ok(s.chars().count() as i32) +#[function("length(varchar) -> int32")] +#[function("char_length(varchar) -> int32")] +pub fn length(s: &str) -> i32 { + s.chars().count() as i32 } -#[inline(always)] -pub fn octet_length(s: &str) -> Result { - Ok(s.as_bytes().len() as i32) +#[function("octet_length(varchar) -> int32")] +pub fn octet_length(s: &str) -> i32 { + s.as_bytes().len() as i32 } -#[inline(always)] -pub fn bit_length(s: &str) -> Result { - octet_length(s).map(|n| n * 8) +#[function("bit_length(varchar) -> int32")] +pub fn bit_length(s: &str) -> i32 { + octet_length(s) * 8 } #[cfg(test)] @@ -39,7 +40,7 @@ mod tests { let cases = [("hello world", 11), ("hello rust", 10)]; for (s, expected) in cases { - assert_eq!(length_default(s).unwrap(), expected) + assert_eq!(length(s), expected); } } @@ -48,7 +49,7 @@ mod tests { let cases = [("hello world", 11), ("你好", 6), ("😇哈哈hhh", 13)]; for (s, expected) in cases { - assert_eq!(octet_length(s).unwrap(), expected) + assert_eq!(octet_length(s), expected); } } @@ -61,7 +62,7 @@ mod tests { ]; for (s, expected) in cases { - assert_eq!(bit_length(s).unwrap(), expected) + assert_eq!(bit_length(s), expected); } } } diff --git a/src/expr/src/vector_op/like.rs b/src/expr/src/vector_op/like.rs index c6778ab3e1016..3b94990c68279 100644 --- a/src/expr/src/vector_op/like.rs +++ b/src/expr/src/vector_op/like.rs @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::Result; +use risingwave_expr_macro::function; -#[inline(always)] -pub fn like_default(s: &str, p: &str) -> Result { +#[function("like(varchar, varchar) -> boolean")] +pub fn like_default(s: &str, p: &str) -> bool { let (mut px, mut sx) = (0, 0); let (mut next_px, mut next_sx) = (0, 0); let (pbytes, sbytes) = (p.as_bytes(), s.as_bytes()); @@ -50,9 +50,9 @@ pub fn like_default(s: &str, p: &str) -> Result { sx = next_sx; continue; } - return Ok(false); + return false; } - Ok(true) + true } #[cfg(test)] @@ -84,7 +84,7 @@ mod tests { #[test] fn test_like() { for (target, pattern, expected) in CASES { - let output = like_default(target, pattern).unwrap(); + let output = like_default(target, pattern); assert_eq!( output, expected.unwrap(), diff --git a/src/expr/src/vector_op/lower.rs b/src/expr/src/vector_op/lower.rs index c895c10c97f11..c76de5b129c27 100644 --- a/src/expr/src/vector_op/lower.rs +++ b/src/expr/src/vector_op/lower.rs @@ -14,14 +14,13 @@ use std::fmt::Write; -use crate::Result; +use risingwave_expr_macro::function; -#[inline(always)] -pub fn lower(s: &str, writer: &mut dyn Write) -> Result<()> { +#[function("lower(varchar) -> varchar")] +pub fn lower(s: &str, writer: &mut dyn Write) { for c in s.chars() { writer.write_char(c.to_ascii_lowercase()).unwrap(); } - Ok(()) } #[cfg(test)] @@ -29,7 +28,7 @@ mod tests { use super::*; #[test] - fn test_lower() -> Result<()> { + fn test_lower() { let cases = [ ("HELLO WORLD", "hello world"), ("hello RUST", "hello rust"), @@ -38,9 +37,8 @@ mod tests { for (s, expected) in cases { let mut writer = String::new(); - lower(s, &mut writer)?; + lower(s, &mut writer); assert_eq!(writer, expected); } - Ok(()) } } diff --git a/src/expr/src/vector_op/ltrim.rs b/src/expr/src/vector_op/ltrim.rs deleted file mode 100644 index c542d9039ffff..0000000000000 --- a/src/expr/src/vector_op/ltrim.rs +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::fmt::Write; - -use crate::Result; - -/// Note: the behavior of `ltrim` in `PostgreSQL` and `trim_start` (or `trim_left`) in Rust -/// are actually different when the string is in right-to-left languages like Arabic or Hebrew. -/// Since we would like to simplify the implementation, currently we omit this case. -#[inline(always)] -pub fn ltrim(s: &str, writer: &mut dyn Write) -> Result<()> { - writer.write_str(s.trim_start()).unwrap(); - Ok(()) -} - -#[cfg(test)] -mod tests { - - use super::*; - - #[test] - fn test_ltrim() -> Result<()> { - let cases = [ - (" \tHello\tworld\t", "Hello\tworld\t"), - (" \t空I ❤️ databases空 ", "空I ❤️ databases空 "), - ]; - - for (s, expected) in cases { - let mut writer = String::new(); - ltrim(s, &mut writer)?; - assert_eq!(writer, expected); - } - Ok(()) - } -} diff --git a/src/expr/src/vector_op/md5.rs b/src/expr/src/vector_op/md5.rs index 780526d893c84..44c6b4ea1a7eb 100644 --- a/src/expr/src/vector_op/md5.rs +++ b/src/expr/src/vector_op/md5.rs @@ -14,12 +14,11 @@ use std::fmt::Write; -use crate::Result; +use risingwave_expr_macro::function; -#[inline(always)] -pub fn md5(s: &str, writer: &mut dyn Write) -> Result<()> { +#[function("md5(varchar) -> varchar")] +pub fn md5(s: &str, writer: &mut dyn Write) { write!(writer, "{:x}", ::md5::compute(s)).unwrap(); - Ok(()) } #[cfg(test)] @@ -27,7 +26,7 @@ mod tests { use super::*; #[test] - fn test_md5() -> Result<()> { + fn test_md5() { let cases = [ ("hello world", "5eb63bbbe01eeed093cb22bb8f5acdc3"), ("hello RUST", "917b821a0a5f23ab0cfdb36056d2eb9d"), @@ -39,9 +38,8 @@ mod tests { for (s, expected) in cases { let mut writer = String::new(); - md5(s, &mut writer)?; + md5(s, &mut writer); assert_eq!(writer, expected); } - Ok(()) } } diff --git a/src/expr/src/vector_op/mod.rs b/src/expr/src/vector_op/mod.rs index 164e433ad9dde..389703a12bf27 100644 --- a/src/expr/src/vector_op/mod.rs +++ b/src/expr/src/vector_op/mod.rs @@ -29,14 +29,12 @@ pub mod jsonb_info; pub mod length; pub mod like; pub mod lower; -pub mod ltrim; pub mod md5; pub mod overlay; pub mod position; pub mod repeat; pub mod replace; pub mod round; -pub mod rtrim; pub mod split_part; pub mod substr; pub mod timestamptz; @@ -44,9 +42,5 @@ pub mod to_char; pub mod to_timestamp; pub mod translate; pub mod trim; -pub mod trim_characters; pub mod tumble; pub mod upper; - -#[cfg(test)] -mod tests; diff --git a/src/expr/src/vector_op/overlay.rs b/src/expr/src/vector_op/overlay.rs index e83251f7ff362..928d573c47485 100644 --- a/src/expr/src/vector_op/overlay.rs +++ b/src/expr/src/vector_op/overlay.rs @@ -14,15 +14,17 @@ use std::fmt::Write; +use risingwave_expr_macro::function; + use crate::{ExprError, Result}; -#[inline(always)] +#[function("overlay(varchar, varchar, int32) -> varchar")] pub fn overlay(s: &str, new_sub_str: &str, start: i32, writer: &mut dyn Write) -> Result<()> { // If count is omitted, it defaults to the length of new_sub_str. overlay_for(s, new_sub_str, start, new_sub_str.len() as i32, writer) } -#[inline(always)] +#[function("overlay(varchar, varchar, int32, int32) -> varchar")] pub fn overlay_for( s: &str, new_sub_str: &str, diff --git a/src/expr/src/vector_op/position.rs b/src/expr/src/vector_op/position.rs index 5e2721687f24e..9049e743a4f70 100644 --- a/src/expr/src/vector_op/position.rs +++ b/src/expr/src/vector_op/position.rs @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +use risingwave_expr_macro::function; + use crate::Result; -#[inline(always)] /// Location of specified substring /// /// Note: According to pgsql, position will return 0 rather -1 when substr is not in the target str +#[function("position(varchar, varchar) -> int32")] pub fn position(str: &str, sub_str: &str) -> Result { match str.find(sub_str) { Some(byte_idx) => Ok((str[..byte_idx].chars().count() + 1) as i32), diff --git a/src/expr/src/vector_op/repeat.rs b/src/expr/src/vector_op/repeat.rs index bd5f490de6008..b46076b8eeefd 100644 --- a/src/expr/src/vector_op/repeat.rs +++ b/src/expr/src/vector_op/repeat.rs @@ -14,14 +14,13 @@ use std::fmt::Write; -use crate::Result; +use risingwave_expr_macro::function; -#[inline(always)] -pub fn repeat(s: &str, count: i32, writer: &mut dyn Write) -> Result<()> { +#[function("repeat(varchar, int32) -> varchar")] +pub fn repeat(s: &str, count: i32, writer: &mut dyn Write) { for _ in 0..count { writer.write_str(s).unwrap(); } - Ok(()) } #[cfg(test)] @@ -29,7 +28,7 @@ mod tests { use super::*; #[test] - fn test_repeat() -> Result<()> { + fn test_repeat() { let cases = vec![ ("hello, world", 1, "hello, world"), ("114514", 3, "114514114514114514"), @@ -39,9 +38,8 @@ mod tests { for (s, count, expected) in cases { let mut writer = String::new(); - repeat(s, count, &mut writer)?; + repeat(s, count, &mut writer); assert_eq!(writer, expected); } - Ok(()) } } diff --git a/src/expr/src/vector_op/replace.rs b/src/expr/src/vector_op/replace.rs index ef530d4ba02bb..02eeefdc8490e 100644 --- a/src/expr/src/vector_op/replace.rs +++ b/src/expr/src/vector_op/replace.rs @@ -14,13 +14,13 @@ use std::fmt::Write; -use crate::Result; +use risingwave_expr_macro::function; -#[inline(always)] -pub fn replace(s: &str, from_str: &str, to_str: &str, writer: &mut dyn Write) -> Result<()> { +#[function("replace(varchar, varchar, varchar) -> varchar")] +pub fn replace(s: &str, from_str: &str, to_str: &str, writer: &mut dyn Write) { if from_str.is_empty() { writer.write_str(s).unwrap(); - return Ok(()); + return; } let mut last = 0; while let Some(mut start) = s[last..].find(from_str) { @@ -30,7 +30,6 @@ pub fn replace(s: &str, from_str: &str, to_str: &str, writer: &mut dyn Write) -> last = start + from_str.len(); } writer.write_str(&s[last..]).unwrap(); - Ok(()) } #[cfg(test)] @@ -38,7 +37,7 @@ mod tests { use super::*; #[test] - fn test_replace() -> Result<()> { + fn test_replace() { let cases = vec![ ("hello, word", "我的", "world", "hello, word"), ("hello, word", "", "world", "hello, word"), @@ -50,9 +49,8 @@ mod tests { for (s, from_str, to_str, expected) in cases { let mut writer = String::new(); - replace(s, from_str, to_str, &mut writer)?; + replace(s, from_str, to_str, &mut writer); assert_eq!(writer, expected); } - Ok(()) } } diff --git a/src/expr/src/vector_op/round.rs b/src/expr/src/vector_op/round.rs index 307a47583816d..1c0b9b8fa1b73 100644 --- a/src/expr/src/vector_op/round.rs +++ b/src/expr/src/vector_op/round.rs @@ -13,8 +13,9 @@ // limitations under the License. use risingwave_common::types::{Decimal, OrderedF64}; +use risingwave_expr_macro::function; -#[inline(always)] +#[function("round_digit(decimal, int32) -> decimal")] pub fn round_digits>(input: Decimal, digits: D) -> Decimal { let digits = digits.into(); if digits < 0 { @@ -25,37 +26,38 @@ pub fn round_digits>(input: Decimal, digits: D) -> Decimal { } } -#[inline(always)] +#[function("ceil(float64) -> float64")] pub fn ceil_f64(input: OrderedF64) -> OrderedF64 { f64::ceil(input.0).into() } -#[inline(always)] +#[function("ceil(decimal) -> decimal")] pub fn ceil_decimal(input: Decimal) -> Decimal { input.ceil() } -#[inline(always)] +#[function("floor(float64) -> float64")] pub fn floor_f64(input: OrderedF64) -> OrderedF64 { f64::floor(input.0).into() } -#[inline(always)] +#[function("floor(decimal) -> decimal")] pub fn floor_decimal(input: Decimal) -> Decimal { input.floor() } // Ties are broken by rounding away from zero -#[inline(always)] +#[function("round(float64) -> float64")] pub fn round_f64(input: OrderedF64) -> OrderedF64 { f64::round(input.0).into() } // Ties are broken by rounding away from zero -#[inline(always)] +#[function("round(decimal) -> decimal")] pub fn round_decimal(input: Decimal) -> Decimal { input.round_dp(0) } + #[cfg(test)] mod tests { use std::str::FromStr; diff --git a/src/expr/src/vector_op/rtrim.rs b/src/expr/src/vector_op/rtrim.rs deleted file mode 100644 index 3257044edeb5d..0000000000000 --- a/src/expr/src/vector_op/rtrim.rs +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::fmt::Write; - -use crate::Result; - -/// Note: the behavior of `rtrim` in `PostgreSQL` and `trim_end` (or `trim_right`) in Rust -/// are actually different when the string is in right-to-left languages like Arabic or Hebrew. -/// Since we would like to simplify the implementation, currently we omit this case. -#[inline(always)] -pub fn rtrim(s: &str, writer: &mut dyn Write) -> Result<()> { - writer.write_str(s.trim_end()).unwrap(); - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_rtrim() -> Result<()> { - let cases = [ - (" \tHello\tworld\t ", " \tHello\tworld"), - (" \t空I ❤️ databases空\t ", " \t空I ❤️ databases空"), - ]; - - for (s, expected) in cases { - let mut writer = String::new(); - rtrim(s, &mut writer)?; - assert_eq!(writer, expected); - } - Ok(()) - } -} diff --git a/src/expr/src/vector_op/split_part.rs b/src/expr/src/vector_op/split_part.rs index a3d4a33b6c6a3..d64108bef9f97 100644 --- a/src/expr/src/vector_op/split_part.rs +++ b/src/expr/src/vector_op/split_part.rs @@ -14,9 +14,11 @@ use std::fmt::Write; +use risingwave_expr_macro::function; + use crate::{ExprError, Result}; -#[inline(always)] +#[function("split_part(varchar, varchar, int32) -> varchar")] pub fn split_part( string_expr: &str, delimiter_expr: &str, diff --git a/src/expr/src/vector_op/substr.rs b/src/expr/src/vector_op/substr.rs index c3f81cae2c063..5175be1eca452 100644 --- a/src/expr/src/vector_op/substr.rs +++ b/src/expr/src/vector_op/substr.rs @@ -15,23 +15,25 @@ use std::cmp::{max, min}; use std::fmt::Write; +use risingwave_expr_macro::function; + use crate::{bail, Result}; -#[inline(always)] +#[function("substr(varchar, int32) -> varchar")] pub fn substr_start(s: &str, start: i32, writer: &mut dyn Write) -> Result<()> { let start = (start.saturating_sub(1).max(0) as usize).min(s.len()); writer.write_str(&s[start..]).unwrap(); Ok(()) } -#[inline(always)] +// #[function("substr(varchar, 0, int32) -> varchar")] pub fn substr_for(s: &str, count: i32, writer: &mut dyn Write) -> Result<()> { let end = min(count as usize, s.len()); writer.write_str(&s[..end]).unwrap(); Ok(()) } -#[inline(always)] +#[function("substr(varchar, int32, int32) -> varchar")] pub fn substr_start_for(s: &str, start: i32, count: i32, writer: &mut dyn Write) -> Result<()> { if count < 0 { bail!("length in substr should be non-negative: {}", count); @@ -56,19 +58,19 @@ mod tests { let s = "cxscgccdd"; let cases = [ - (s.to_owned(), Some(4), None, "cgccdd"), - (s.to_owned(), None, Some(3), "cxs"), - (s.to_owned(), Some(4), Some(-2), "[unused result]"), - (s.to_owned(), Some(4), Some(2), "cg"), - (s.to_owned(), Some(-1), Some(-5), "[unused result]"), - (s.to_owned(), Some(-1), Some(5), "cxs"), + (s, Some(4), None, "cgccdd"), + (s, None, Some(3), "cxs"), + (s, Some(4), Some(-2), "[unused result]"), + (s, Some(4), Some(2), "cg"), + (s, Some(-1), Some(-5), "[unused result]"), + (s, Some(-1), Some(5), "cxs"), ]; for (s, off, len, expected) in cases { let mut writer = String::new(); match (off, len) { (Some(off), Some(len)) => { - let result = substr_start_for(&s, off, len, &mut writer); + let result = substr_start_for(s, off, len, &mut writer); if len < 0 { assert!(result.is_err()); continue; @@ -76,8 +78,8 @@ mod tests { result? } } - (Some(off), None) => substr_start(&s, off, &mut writer)?, - (None, Some(len)) => substr_for(&s, len, &mut writer)?, + (Some(off), None) => substr_start(s, off, &mut writer)?, + (None, Some(len)) => substr_for(s, len, &mut writer)?, _ => unreachable!(), } assert_eq!(writer, expected); diff --git a/src/expr/src/vector_op/tests.rs b/src/expr/src/vector_op/tests.rs deleted file mode 100644 index 35d146b189ee1..0000000000000 --- a/src/expr/src/vector_op/tests.rs +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::assert_matches::assert_matches; -use std::str::FromStr; - -use chrono::NaiveDateTime; -use risingwave_common::types::test_utils::IntervalUnitTestExt; -use risingwave_common::types::{ - Decimal, IntervalUnit, NaiveDateTimeWrapper, NaiveDateWrapper, OrderedF32, OrderedF64, -}; - -use crate::vector_op::arithmetic_op::*; -use crate::vector_op::bitwise_op::*; -use crate::vector_op::cast::try_cast; -use crate::vector_op::cmp::*; -use crate::vector_op::conjunction::*; -use crate::ExprError; - -#[test] -fn test_arithmetic() { - assert_eq!( - general_add::(dec("1.0"), 1).unwrap(), - dec("2.0") - ); - assert_eq!( - general_sub::(dec("1.0"), 2).unwrap(), - dec("-1.0") - ); - assert_eq!( - general_mul::(dec("1.0"), 2).unwrap(), - dec("2.0") - ); - assert_eq!( - general_div::(dec("2.0"), 2).unwrap(), - dec("1.0") - ); - assert_eq!( - general_mod::(dec("2.0"), 2).unwrap(), - dec("0") - ); - assert_eq!(general_neg::(dec("1.0")).unwrap(), dec("-1.0")); - assert_eq!(general_add::(1i16, 1i32).unwrap(), 2i32); - assert_eq!(general_sub::(1i16, 1i32).unwrap(), 0i32); - assert_eq!(general_mul::(1i16, 1i32).unwrap(), 1i32); - assert_eq!(general_div::(1i16, 1i32).unwrap(), 1i32); - assert_eq!(general_mod::(1i16, 1i32).unwrap(), 0i32); - assert_eq!(general_neg::(1i16).unwrap(), -1i16); - - assert_eq!( - general_add::(dec("1.0"), -1f32).unwrap(), - dec("0.0") - ); - assert_eq!( - general_sub::(dec("1.0"), 1f32).unwrap(), - dec("0.0") - ); - assert_eq!( - general_div::(dec("0.0"), 1f32).unwrap(), - dec("0.0") - ); - assert_eq!( - general_mul::(dec("0.0"), 1f32).unwrap(), - dec("0.0") - ); - assert_eq!( - general_mod::(dec("0.0"), 1f32).unwrap(), - dec("0.0") - ); - assert!( - general_add::(-1i32, 1f32.into()) - .unwrap() - .abs() - < f64::EPSILON - ); - assert!( - general_sub::(1i32, 1f32.into()) - .unwrap() - .abs() - < f64::EPSILON - ); - assert!( - general_mul::(0i32, 1f32.into()) - .unwrap() - .abs() - < f64::EPSILON - ); - assert!( - general_div::(0i32, 1f32.into()) - .unwrap() - .abs() - < f64::EPSILON - ); - assert_eq!( - general_neg::(1f32.into()).unwrap(), - OrderedF32::from(-1f32) - ); - assert_eq!( - date_interval_add::( - NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1), - IntervalUnit::from_month(12) - ) - .unwrap(), - NaiveDateTimeWrapper::new( - NaiveDateTime::parse_from_str("1995-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap() - ) - ); - assert_eq!( - interval_date_add::( - IntervalUnit::from_month(12), - NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1) - ) - .unwrap(), - NaiveDateTimeWrapper::new( - NaiveDateTime::parse_from_str("1995-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap() - ) - ); - assert_eq!( - date_interval_sub::( - NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1), - IntervalUnit::from_month(12) - ) - .unwrap(), - NaiveDateTimeWrapper::new( - NaiveDateTime::parse_from_str("1993-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap() - ) - ); -} - -#[test] -fn test_bitwise() { - // check the boundary - assert_eq!(general_shl::(1i32, 0i32).unwrap(), 1i32); - assert_eq!(general_shl::(1i64, 31i32).unwrap(), 2147483648i64); - assert_matches!( - general_shl::(1i32, 32i32).unwrap_err(), - ExprError::NumericOutOfRange, - ); - assert_eq!( - general_shr::(-2147483648i64, 31i32).unwrap(), - -1i64 - ); - assert_eq!(general_shr::(1i64, 0i32).unwrap(), 1i64); - // truth table - assert_eq!( - general_bitand::(0b0011u32, 0b0101u32), - 0b1u64 - ); - assert_eq!( - general_bitor::(0b0011u32, 0b0101u32), - 0b0111u64 - ); - assert_eq!( - general_bitxor::(0b0011u32, 0b0101u32), - 0b0110u64 - ); - assert_eq!(general_bitnot::(0b01i32), -2i32); -} - -#[test] -fn test_comparison() { - assert!(general_eq::(dec("1.0"), 1)); - assert!(general_eq::(dec("1.0"), 1.0)); - assert!(!general_ne::(dec("1.0"), 1)); - assert!(!general_ne::(dec("1.0"), 1.0)); - assert!(!general_gt::(dec("1.0"), 2)); - assert!(!general_gt::(dec("1.0"), 2.0)); - assert!(general_le::(dec("1.0"), 2)); - assert!(general_le::(dec("1.0"), 2.1)); - assert!(!general_ge::(dec("1.0"), 2)); - assert!(!general_ge::(dec("1.0"), 2.1)); - assert!(general_lt::(dec("1.0"), 2)); - assert!(general_lt::(dec("1.0"), 2.1)); - assert!(general_is_distinct_from::( - Some(dec("1.0")), - Some(2) - )); - assert!(general_is_distinct_from::( - Some(dec("1.0")), - Some(2.0) - )); - assert!(general_is_distinct_from::( - Some(dec("1.0")), - None - )); - assert!(general_is_distinct_from::( - None, - Some(1) - )); - assert!(!general_is_distinct_from::( - Some(dec("1.0")), - Some(1) - )); - assert!(!general_is_distinct_from::( - Some(dec("1.0")), - Some(1.0) - )); - assert!(!general_is_distinct_from::( - None, None - )); - assert!(general_eq::(1.0.into(), 1)); - assert!(!general_ne::(1.0.into(), 1)); - assert!(!general_lt::(1.0.into(), 1)); - assert!(general_le::(1.0.into(), 1)); - assert!(!general_gt::(1.0.into(), 1)); - assert!(general_ge::(1.0.into(), 1)); - assert!(!general_is_distinct_from::( - Some(1.0.into()), - Some(1) - )); - assert!(general_eq::(1i64, 1)); - assert!(!general_ne::(1i64, 1)); - assert!(!general_lt::(1i64, 1)); - assert!(general_le::(1i64, 1)); - assert!(!general_gt::(1i64, 1)); - assert!(general_ge::(1i64, 1)); - assert!(!general_is_distinct_from::( - Some(1i64), - Some(1) - )); -} - -#[test] -fn test_conjunction() { - assert!(not(Some(false)).unwrap()); - assert!(!and(Some(true), Some(false)).unwrap().unwrap()); - assert!(or(Some(true), Some(false)).unwrap().unwrap()); -} -#[test] -fn test_cast() { - assert_eq!( - try_cast::<_, NaiveDateTimeWrapper>(NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1)) - .unwrap(), - NaiveDateTimeWrapper::new( - NaiveDateTime::parse_from_str("1994-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap() - ) - ) -} - -fn dec(s: &str) -> Decimal { - Decimal::from_str(s).unwrap() -} diff --git a/src/expr/src/vector_op/timestamptz.rs b/src/expr/src/vector_op/timestamptz.rs index 12a7ce3d74657..83b85e9560e43 100644 --- a/src/expr/src/vector_op/timestamptz.rs +++ b/src/expr/src/vector_op/timestamptz.rs @@ -18,6 +18,7 @@ use chrono::{TimeZone, Utc}; use chrono_tz::Tz; use num_traits::ToPrimitive; use risingwave_common::types::{NaiveDateTimeWrapper, OrderedF64}; +use risingwave_expr_macro::function; use crate::vector_op::cast::{str_to_timestamp, str_with_time_zone_to_timestamptz}; use crate::{ExprError, Result}; @@ -31,7 +32,7 @@ fn lookup_time_zone(time_zone: &str) -> Result { }) } -#[inline(always)] +#[function("to_timestamp(float64) -> timestamptz")] pub fn f64_sec_to_timestamptz(elem: OrderedF64) -> Result { // TODO(#4515): handle +/- infinity (elem * 1e6) @@ -40,7 +41,7 @@ pub fn f64_sec_to_timestamptz(elem: OrderedF64) -> Result { .ok_or(ExprError::NumericOutOfRange) } -#[inline(always)] +#[function("at_time_zone(timestamp, varchar) -> timestamptz")] pub fn timestamp_at_time_zone(input: NaiveDateTimeWrapper, time_zone: &str) -> Result { let time_zone = lookup_time_zone(time_zone)?; // https://www.postgresql.org/docs/current/datetime-invalid-input.html @@ -65,6 +66,7 @@ pub fn timestamp_at_time_zone(input: NaiveDateTimeWrapper, time_zone: &str) -> R Ok(usec) } +#[function("cast_with_time_zone(timestamptz, varchar) -> varchar")] pub fn timestamptz_to_string(elem: i64, time_zone: &str, writer: &mut dyn Write) -> Result<()> { let time_zone = lookup_time_zone(time_zone)?; let secs = elem.div_euclid(1_000_000); @@ -82,12 +84,13 @@ pub fn timestamptz_to_string(elem: i64, time_zone: &str, writer: &mut dyn Write) // Tries to interpret the string with a timezone, and if failing, tries to interpret the string as a // timestamp and then adjusts it with the session timezone. +#[function("cast_with_time_zone(varchar, varchar) -> timestamptz")] pub fn str_to_timestamptz(elem: &str, time_zone: &str) -> Result { str_with_time_zone_to_timestamptz(elem) .or_else(|_| timestamp_at_time_zone(str_to_timestamp(elem)?, time_zone)) } -#[inline(always)] +#[function("at_time_zone(timestamptz, varchar) -> timestamp")] pub fn timestamptz_at_time_zone(input: i64, time_zone: &str) -> Result { let time_zone = lookup_time_zone(time_zone)?; let secs = input.div_euclid(1_000_000); diff --git a/src/expr/src/vector_op/to_char.rs b/src/expr/src/vector_op/to_char.rs index 5e2d02c186ab7..c5be0913b5a4f 100644 --- a/src/expr/src/vector_op/to_char.rs +++ b/src/expr/src/vector_op/to_char.rs @@ -20,8 +20,6 @@ use chrono::format::StrftimeItems; use ouroboros::self_referencing; use risingwave_common::types::NaiveDateTimeWrapper; -use crate::Result; - #[self_referencing] pub struct ChronoPattern { pub(crate) tmpl: String, @@ -70,18 +68,9 @@ pub fn compile_pattern_to_chrono(tmpl: &str) -> ChronoPattern { .build() } -#[inline(always)] -pub fn to_char_timestamp( - data: NaiveDateTimeWrapper, - tmpl: &str, - writer: &mut dyn Write, -) -> Result<()> { +// #[function("to_char(timestamp, varchar) -> varchar")] +pub fn to_char_timestamp(data: NaiveDateTimeWrapper, tmpl: &str, writer: &mut dyn Write) { let pattern = compile_pattern_to_chrono(tmpl); - write!( - writer, - "{}", - data.0.format_with_items(pattern.borrow_items().iter()) - ) - .unwrap(); - Ok(()) + let format = data.0.format_with_items(pattern.borrow_items().iter()); + write!(writer, "{}", format).unwrap(); } diff --git a/src/expr/src/vector_op/to_timestamp.rs b/src/expr/src/vector_op/to_timestamp.rs index 2ccfc66df8921..cad1e95f5b3ad 100644 --- a/src/expr/src/vector_op/to_timestamp.rs +++ b/src/expr/src/vector_op/to_timestamp.rs @@ -15,6 +15,7 @@ use chrono::format::Parsed; use risingwave_common::types::NaiveDateTimeWrapper; +// use risingwave_expr_macro::function; use super::to_char::{compile_pattern_to_chrono, ChronoPattern}; use crate::Result; @@ -65,7 +66,7 @@ pub fn to_timestamp_const_tmpl(s: &str, tmpl: &ChronoPattern) -> Result timestamp")] pub fn to_timestamp(s: &str, tmpl: &str) -> Result { let pattern = compile_pattern_to_chrono(tmpl); to_timestamp_const_tmpl(s, &pattern) diff --git a/src/expr/src/vector_op/translate.rs b/src/expr/src/vector_op/translate.rs index b2c44aa43d9f9..8dcec75307946 100644 --- a/src/expr/src/vector_op/translate.rs +++ b/src/expr/src/vector_op/translate.rs @@ -15,15 +15,10 @@ use std::collections::HashMap; use std::fmt::Write; -use crate::Result; +use risingwave_expr_macro::function; -#[inline(always)] -pub fn translate( - s: &str, - match_str: &str, - replace_str: &str, - writer: &mut dyn Write, -) -> Result<()> { +#[function("translate(varchar, varchar, varchar) -> varchar")] +pub fn translate(s: &str, match_str: &str, replace_str: &str, writer: &mut dyn Write) { let mut char_map = HashMap::new(); let mut match_chars = match_str.chars(); let mut replace_chars = replace_str.chars(); @@ -46,7 +41,6 @@ pub fn translate( for c in iter { writer.write_char(c).unwrap(); } - Ok(()) } #[cfg(test)] @@ -54,7 +48,7 @@ mod tests { use super::*; #[test] - fn test_translate() -> Result<()> { + fn test_translate() { let cases = [ ("hello world", "lo", "12", "he112 w2r1d"), ( @@ -73,9 +67,8 @@ mod tests { for (s, match_str, replace_str, expected) in cases { let mut writer = String::new(); - translate(s, match_str, replace_str, &mut writer)?; + translate(s, match_str, replace_str, &mut writer); assert_eq!(writer, expected); } - Ok(()) } } diff --git a/src/expr/src/vector_op/trim.rs b/src/expr/src/vector_op/trim.rs index 165991fa04618..754d90650f69c 100644 --- a/src/expr/src/vector_op/trim.rs +++ b/src/expr/src/vector_op/trim.rs @@ -14,12 +14,48 @@ use std::fmt::Write; -use crate::Result; +use risingwave_expr_macro::function; -#[inline(always)] -pub fn trim(s: &str, writer: &mut dyn Write) -> Result<()> { +#[function("trim(varchar) -> varchar")] +pub fn trim(s: &str, writer: &mut dyn Write) { writer.write_str(s.trim()).unwrap(); - Ok(()) +} + +/// Note: the behavior of `ltrim` in `PostgreSQL` and `trim_start` (or `trim_left`) in Rust +/// are actually different when the string is in right-to-left languages like Arabic or Hebrew. +/// Since we would like to simplify the implementation, currently we omit this case. +#[function("ltrim(varchar) -> varchar")] +pub fn ltrim(s: &str, writer: &mut dyn Write) { + writer.write_str(s.trim_start()).unwrap(); +} + +/// Note: the behavior of `rtrim` in `PostgreSQL` and `trim_end` (or `trim_right`) in Rust +/// are actually different when the string is in right-to-left languages like Arabic or Hebrew. +/// Since we would like to simplify the implementation, currently we omit this case. +#[function("rtrim(varchar) -> varchar")] +pub fn rtrim(s: &str, writer: &mut dyn Write) { + writer.write_str(s.trim_end()).unwrap(); +} + +#[function("trim(varchar, varchar) -> varchar")] +pub fn trim_characters(s: &str, characters: &str, writer: &mut dyn Write) { + let pattern = |c| characters.chars().any(|ch| ch == c); + // We remark that feeding a &str and a slice of chars into trim_left/right_matches + // means different, one is matching with the entire string and the other one is matching + // with any char in the slice. + writer.write_str(s.trim_matches(pattern)).unwrap(); +} + +#[function("ltrim(varchar, varchar) -> varchar")] +pub fn ltrim_characters(s: &str, characters: &str, writer: &mut dyn Write) { + let pattern = |c| characters.chars().any(|ch| ch == c); + writer.write_str(s.trim_start_matches(pattern)).unwrap(); +} + +#[function("rtrim(varchar, varchar) -> varchar")] +pub fn rtrim_characters(s: &str, characters: &str, writer: &mut dyn Write) { + let pattern = |c| characters.chars().any(|ch| ch == c); + writer.write_str(s.trim_end_matches(pattern)).unwrap(); } #[cfg(test)] @@ -27,7 +63,7 @@ mod tests { use super::*; #[test] - fn test_trim() -> Result<()> { + fn test_trim() { let cases = [ (" Hello\tworld\t", "Hello\tworld"), (" 空I ❤️ databases空 ", "空I ❤️ databases空"), @@ -35,9 +71,81 @@ mod tests { for (s, expected) in cases { let mut writer = String::new(); - trim(s, &mut writer)?; + trim(s, &mut writer); + assert_eq!(writer, expected); + } + } + + #[test] + fn test_ltrim() { + let cases = [ + (" \tHello\tworld\t", "Hello\tworld\t"), + (" \t空I ❤️ databases空 ", "空I ❤️ databases空 "), + ]; + + for (s, expected) in cases { + let mut writer = String::new(); + ltrim(s, &mut writer); + assert_eq!(writer, expected); + } + } + + #[test] + fn test_rtrim() { + let cases = [ + (" \tHello\tworld\t ", " \tHello\tworld"), + (" \t空I ❤️ databases空\t ", " \t空I ❤️ databases空"), + ]; + + for (s, expected) in cases { + let mut writer = String::new(); + rtrim(s, &mut writer); + assert_eq!(writer, expected); + } + } + + #[test] + fn test_trim_characters() { + let cases = [ + ("Hello world", "Hdl", "ello wor"), + ("abcde", "aae", "bcd"), + ("zxy", "yxz", ""), + ]; + + for (s, characters, expected) in cases { + let mut writer = String::new(); + trim_characters(s, characters, &mut writer); + assert_eq!(writer, expected); + } + } + + #[test] + fn test_ltrim_characters() { + let cases = [ + ("Hello world", "Hdl", "ello world"), + ("abcde", "aae", "bcde"), + ("zxy", "yxz", ""), + ]; + + for (s, characters, expected) in cases { + let mut writer = String::new(); + ltrim_characters(s, characters, &mut writer); + assert_eq!(writer, expected); + } + } + + #[test] + fn test_rtrim_characters() { + let cases = [ + ("Hello world", "Hdl", "Hello wor"), + ("abcde", "aae", "abcd"), + ("zxy", "yxz", ""), + ]; + + for (s, characters, expected) in cases { + let mut writer = String::new(); + rtrim_characters(s, characters, &mut writer); assert_eq!(writer, expected); } - Ok(()) } } diff --git a/src/expr/src/vector_op/trim_characters.rs b/src/expr/src/vector_op/trim_characters.rs deleted file mode 100644 index e61f72600b645..0000000000000 --- a/src/expr/src/vector_op/trim_characters.rs +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::fmt::Write; - -use crate::Result; - -macro_rules! gen_trim { - ($( { $func_name:ident, $method:ident }),*) => { - $(#[inline(always)] - pub fn $func_name(s: &str, characters: &str, writer: &mut dyn Write) -> Result<()> { - let pattern = |c| characters.chars().any(|ch| ch == c); - // We remark that feeding a &str and a slice of chars into trim_left/right_matches - // means different, one is matching with the entire string and the other one is matching - // with any char in the slice. - Ok(writer.write_str(s.$method(pattern)).unwrap()) - })* - } -} - -gen_trim! { - { trim_characters, trim_matches }, - { ltrim_characters, trim_start_matches }, - { rtrim_characters, trim_end_matches } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_trim_characters() -> Result<()> { - let cases = [ - ("Hello world", "Hdl", "ello wor"), - ("abcde", "aae", "bcd"), - ("zxy", "yxz", ""), - ]; - - for (s, characters, expected) in cases { - let mut writer = String::new(); - trim_characters(s, characters, &mut writer)?; - assert_eq!(writer, expected); - } - Ok(()) - } - - #[test] - fn test_ltrim_characters() -> Result<()> { - let cases = [ - ("Hello world", "Hdl", "ello world"), - ("abcde", "aae", "bcde"), - ("zxy", "yxz", ""), - ]; - - for (s, characters, expected) in cases { - let mut writer = String::new(); - ltrim_characters(s, characters, &mut writer)?; - assert_eq!(writer, expected); - } - Ok(()) - } - - #[test] - fn test_rtrim_characters() -> Result<()> { - let cases = [ - ("Hello world", "Hdl", "Hello wor"), - ("abcde", "aae", "abcd"), - ("zxy", "yxz", ""), - ]; - - for (s, characters, expected) in cases { - let mut writer = String::new(); - rtrim_characters(s, characters, &mut writer)?; - assert_eq!(writer, expected); - } - Ok(()) - } -} diff --git a/src/expr/src/vector_op/tumble.rs b/src/expr/src/vector_op/tumble.rs index f6e0c29120d45..092194cd92002 100644 --- a/src/expr/src/vector_op/tumble.rs +++ b/src/expr/src/vector_op/tumble.rs @@ -16,6 +16,7 @@ use num_traits::Zero; use risingwave_common::types::{ IntervalUnit, NaiveDateTimeWrapper, NaiveDateWrapper, USECS_PER_DAY, USECS_PER_MONTH, }; +use risingwave_expr_macro::function; use crate::Result; @@ -24,7 +25,7 @@ fn interval_unit_to_micro_second(t: IntervalUnit) -> i64 { t.get_months() as i64 * USECS_PER_MONTH + t.get_days() as i64 * USECS_PER_DAY + t.get_usecs() } -#[inline(always)] +#[function("tumble_start(date, interval) -> timestamp")] pub fn tumble_start_date( timestamp: NaiveDateWrapper, window_size: IntervalUnit, @@ -32,7 +33,7 @@ pub fn tumble_start_date( tumble_start_date_time(timestamp.into(), window_size) } -#[inline(always)] +#[function("tumble_start(timestamp, interval) -> timestamp")] pub fn tumble_start_date_time( timestamp: NaiveDateTimeWrapper, window_size: IntervalUnit, @@ -45,7 +46,7 @@ pub fn tumble_start_date_time( )) } -#[inline(always)] +#[function("tumble_start(timestamptz, interval) -> timestamptz")] pub fn tumble_start_timestamptz( timestamp_micro_second: i64, window_size: IntervalUnit, @@ -61,7 +62,7 @@ fn get_window_start(timestamp_micro_second: i64, window_size: IntervalUnit) -> R get_window_start_with_offset(timestamp_micro_second, window_size, IntervalUnit::zero()) } -#[inline(always)] +#[function("tumble_start(date, interval, interval) -> timestamp")] pub fn tumble_start_offset_date( timestamp_date: NaiveDateWrapper, window_size: IntervalUnit, @@ -70,7 +71,7 @@ pub fn tumble_start_offset_date( tumble_start_offset_date_time(timestamp_date.into(), window_size, offset) } -#[inline(always)] +#[function("tumble_start(timestamp, interval, interval) -> timestamp")] pub fn tumble_start_offset_date_time( time: NaiveDateTimeWrapper, window_size: IntervalUnit, @@ -104,7 +105,7 @@ fn get_window_start_with_offset( } } -#[inline(always)] +#[function("tumble_start(timestamptz, interval, interval) -> timestamptz")] pub fn tumble_start_offset_timestamptz( timestamp_micro_second: i64, window_size: IntervalUnit, diff --git a/src/expr/src/vector_op/upper.rs b/src/expr/src/vector_op/upper.rs index 9bd54ba13cb4d..45cf51ce9e327 100644 --- a/src/expr/src/vector_op/upper.rs +++ b/src/expr/src/vector_op/upper.rs @@ -14,14 +14,13 @@ use std::fmt::Write; -use crate::Result; +use risingwave_expr_macro::function; -#[inline(always)] -pub fn upper(s: &str, writer: &mut dyn Write) -> Result<()> { +#[function("upper(varchar) -> varchar")] +pub fn upper(s: &str, writer: &mut dyn Write) { for c in s.chars() { writer.write_char(c.to_ascii_uppercase()).unwrap(); } - Ok(()) } #[cfg(test)] @@ -29,7 +28,7 @@ mod tests { use super::*; #[test] - fn test_upper() -> Result<()> { + fn test_upper() { let cases = [ ("hello world", "HELLO WORLD"), ("hello RUST", "HELLO RUST"), @@ -38,9 +37,8 @@ mod tests { for (s, expected) in cases { let mut writer = String::new(); - upper(s, &mut writer)?; + upper(s, &mut writer); assert_eq!(writer, expected); } - Ok(()) } } diff --git a/src/frontend/planner_test/tests/testdata/expr.yaml b/src/frontend/planner_test/tests/testdata/expr.yaml index ab14b016a8831..1674071a270d0 100644 --- a/src/frontend/planner_test/tests/testdata/expr.yaml +++ b/src/frontend/planner_test/tests/testdata/expr.yaml @@ -377,14 +377,12 @@ sql: | select array[1] < SOME(null); binder_error: |- - Feature is not yet implemented: LessThan[List, unknown] - Tracking issue: https://github.com/risingwavelabs/risingwave/issues/112 + Bind error: array/struct on left are not supported yet - name: array of array/struct on right not supported yet 5808 sql: | select null < SOME(array[array[1]]); binder_error: |- - Feature is not yet implemented: LessThan[unknown, List] - Tracking issue: https://github.com/risingwavelabs/risingwave/issues/112 + Bind error: array of array/struct on right are not supported yet - sql: | select 1 < SOME(array[null]::integer[]); logical_plan: | diff --git a/src/frontend/planner_test/tests/testdata/time_window.yaml b/src/frontend/planner_test/tests/testdata/time_window.yaml index db258b5eca7a6..d49329aae87b5 100644 --- a/src/frontend/planner_test/tests/testdata/time_window.yaml +++ b/src/frontend/planner_test/tests/testdata/time_window.yaml @@ -135,7 +135,7 @@ create table t (v1 varchar, v2 timestamp, v3 float); select v1, window_end, avg(v3) as avg from hop( t, v2, interval '1' minute, interval '10' minute) group by v1, window_end; logical_plan: | - LogicalProject { exprs: [t.v1, window_end, (sum(t.v3) / count(t.v3)::Float64) as $expr1] } + LogicalProject { exprs: [t.v1, window_end, (sum(t.v3) / count(t.v3)) as $expr1] } └─LogicalAgg { group_key: [t.v1, window_end], aggs: [sum(t.v3), count(t.v3)] } └─LogicalProject { exprs: [t.v1, window_end, t.v3] } └─LogicalHopWindow { time_col: t.v2, slide: 00:01:00, size: 00:10:00, output: all } @@ -143,7 +143,7 @@ └─LogicalScan { table: t, columns: [t.v1, t.v2, t.v3, t._row_id] } batch_plan: | BatchExchange { order: [], dist: Single } - └─BatchProject { exprs: [t.v1, window_end, (sum(t.v3) / count(t.v3)::Float64) as $expr1] } + └─BatchProject { exprs: [t.v1, window_end, (sum(t.v3) / count(t.v3)) as $expr1] } └─BatchHashAgg { group_key: [t.v1, window_end], aggs: [sum(t.v3), count(t.v3)] } └─BatchHopWindow { time_col: t.v2, slide: 00:01:00, size: 00:10:00, output: [t.v1, t.v3, window_end] } └─BatchExchange { order: [], dist: HashShard(t.v1) } @@ -151,7 +151,7 @@ └─BatchScan { table: t, columns: [t.v1, t.v2, t.v3], distribution: SomeShard } stream_plan: | StreamMaterialize { columns: [v1, window_end, avg], pk_columns: [v1, window_end], pk_conflict: "no check" } - └─StreamProject { exprs: [t.v1, window_end, (sum(t.v3) / count(t.v3)::Float64) as $expr1] } + └─StreamProject { exprs: [t.v1, window_end, (sum(t.v3) / count(t.v3)) as $expr1] } └─StreamHashAgg { group_key: [t.v1, window_end], aggs: [sum(t.v3), count(t.v3), count] } └─StreamExchange { dist: HashShard(t.v1, window_end) } └─StreamHopWindow { time_col: t.v2, slide: 00:01:00, size: 00:10:00, output: [t.v1, t.v3, window_end, t._row_id] } diff --git a/src/frontend/src/expr/function_call.rs b/src/frontend/src/expr/function_call.rs index 1ec6593d2c79e..4c3fffb0b1a36 100644 --- a/src/frontend/src/expr/function_call.rs +++ b/src/frontend/src/expr/function_call.rs @@ -223,18 +223,15 @@ impl FunctionCall { let expr_type = func_types.remove(0); match expr_type { ExprType::Some | ExprType::All => { - let ensure_return_boolean = |return_type: &DataType| { - if &DataType::Boolean == return_type { - Ok(()) - } else { - Err(ErrorCode::BindError( - "op ANY/ALL (array) requires operator to yield boolean".to_string(), - )) - } - }; - let return_type = infer_some_all(func_types, &mut inputs)?; - ensure_return_boolean(&return_type)?; + + if return_type != DataType::Boolean { + return Err(ErrorCode::BindError(format!( + "op ANY/ALL (array) requires operator to yield boolean, but got {:?}", + return_type + )) + .into()); + } Ok(FunctionCall::new_unchecked(expr_type, inputs, return_type).into()) } diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index 2a7875a49a1f6..c72bf9cdbe235 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -44,7 +44,7 @@ pub fn infer_type(func_type: ExprType, inputs: &mut Vec) -> Result( func_type: ExprType, inputs: &[Option], ) -> Result<&'a FuncSign> { - let candidates = sig_map - .get(&(func_type, inputs.len())) - .map(std::ops::Deref::deref) - .unwrap_or_default(); + let candidates = sig_map.get_with_arg_nums(func_type, inputs.len()); // Binary operators have a special `unknown` handling rule for exact match. We do not // distinguish operators from functions as of now. @@ -1048,108 +1062,114 @@ mod tests { let testcases = [ ( "Binary special rule prefers arguments of same type.", - vec![ - vec![T::Int32, T::Int32], - vec![T::Int32, T::Varchar], - vec![T::Int32, T::Float64], - ], - &[Some(T::Int32), None] as &[_], - Ok(&[T::Int32, T::Int32] as &[_]), + &[ + &[T::Int32, T::Int32][..], + &[T::Int32, T::Varchar], + &[T::Int32, T::Float64], + ][..], + &[Some(T::Int32), None][..], + Ok(&[T::Int32, T::Int32][..]), ), ( "Without binary special rule, Rule 4e selects varchar.", - vec![ - vec![T::Int32, T::Int32, T::Int32], - vec![T::Int32, T::Int32, T::Varchar], - vec![T::Int32, T::Int32, T::Float64], + &[ + &[T::Int32, T::Int32, T::Int32], + &[T::Int32, T::Int32, T::Varchar], + &[T::Int32, T::Int32, T::Float64], ], - &[Some(T::Int32), Some(T::Int32), None] as &[_], - Ok(&[T::Int32, T::Int32, T::Varchar] as &[_]), + &[Some(T::Int32), Some(T::Int32), None], + Ok(&[T::Int32, T::Int32, T::Varchar]), ), ( "Without binary special rule, Rule 4e selects preferred type.", - vec![ - vec![T::Int32, T::Int32, T::Int32], - vec![T::Int32, T::Int32, T::Float64], + &[ + &[T::Int32, T::Int32, T::Int32], + &[T::Int32, T::Int32, T::Float64], ], - &[Some(T::Int32), Some(T::Int32), None] as &[_], - Ok(&[T::Int32, T::Int32, T::Float64] as &[_]), + &[Some(T::Int32), Some(T::Int32), None], + Ok(&[T::Int32, T::Int32, T::Float64]), ), ( "Without binary special rule, Rule 4f treats exact-match and cast-match equally.", - vec![ - vec![T::Int32, T::Int32, T::Int32], - vec![T::Int32, T::Int32, T::Float32], + &[ + &[T::Int32, T::Int32, T::Int32], + &[T::Int32, T::Int32, T::Float32], ], - &[Some(T::Int32), Some(T::Int32), None] as &[_], + &[Some(T::Int32), Some(T::Int32), None], Err("not unique"), ), ( "`top_matches` ranks by exact count then preferred count", - vec![ - vec![T::Float64, T::Float64, T::Float64, T::Timestamptz], /* 0 exact 3 preferred */ - vec![T::Float64, T::Int32, T::Float32, T::Timestamp], // 1 exact 1 preferred - vec![T::Float32, T::Float32, T::Int32, T::Timestamptz], // 1 exact 0 preferred - vec![T::Int32, T::Float64, T::Float32, T::Timestamptz], // 1 exact 1 preferred - vec![T::Int32, T::Int16, T::Int32, T::Timestamptz], // 2 exact 1 non-castable - vec![T::Int32, T::Float64, T::Float32, T::Date], // 1 exact 1 preferred + &[ + &[T::Float64, T::Float64, T::Float64, T::Timestamptz], /* 0 exact 3 preferred */ + &[T::Float64, T::Int32, T::Float32, T::Timestamp], // 1 exact 1 preferred + &[T::Float32, T::Float32, T::Int32, T::Timestamptz], // 1 exact 0 preferred + &[T::Int32, T::Float64, T::Float32, T::Timestamptz], // 1 exact 1 preferred + &[T::Int32, T::Int16, T::Int32, T::Timestamptz], // 2 exact 1 non-castable + &[T::Int32, T::Float64, T::Float32, T::Date], // 1 exact 1 preferred ], - &[Some(T::Int32), Some(T::Int32), Some(T::Int32), None] as &[_], - Ok(&[T::Int32, T::Float64, T::Float32, T::Timestamptz] as &[_]), + &[Some(T::Int32), Some(T::Int32), Some(T::Int32), None], + Ok(&[T::Int32, T::Float64, T::Float32, T::Timestamptz]), ), ( "Rule 4e fails and Rule 4f unique.", - vec![ - vec![T::Int32, T::Int32, T::Time], - vec![T::Int32, T::Int32, T::Int32], + &[ + &[T::Int32, T::Int32, T::Time], + &[T::Int32, T::Int32, T::Int32], ], - &[None, Some(T::Int32), None] as &[_], - Ok(&[T::Int32, T::Int32, T::Int32] as &[_]), + &[None, Some(T::Int32), None], + Ok(&[T::Int32, T::Int32, T::Int32]), ), ( "Rule 4e empty and Rule 4f unique.", - vec![ - vec![T::Int32, T::Int32, T::Varchar], - vec![T::Int32, T::Int32, T::Int32], - vec![T::Varchar, T::Int32, T::Int32], + &[ + &[T::Int32, T::Int32, T::Varchar], + &[T::Int32, T::Int32, T::Int32], + &[T::Varchar, T::Int32, T::Int32], ], - &[None, Some(T::Int32), None] as &[_], - Ok(&[T::Int32, T::Int32, T::Int32] as &[_]), + &[None, Some(T::Int32), None], + Ok(&[T::Int32, T::Int32, T::Int32]), ), ( "Rule 4e varchar resolves prior category conflict.", - vec![ - vec![T::Int32, T::Int32, T::Float32], - vec![T::Time, T::Int32, T::Int32], - vec![T::Varchar, T::Int32, T::Int32], + &[ + &[T::Int32, T::Int32, T::Float32], + &[T::Time, T::Int32, T::Int32], + &[T::Varchar, T::Int32, T::Int32], ], - &[None, Some(T::Int32), None] as &[_], - Ok(&[T::Varchar, T::Int32, T::Int32] as &[_]), + &[None, Some(T::Int32), None], + Ok(&[T::Varchar, T::Int32, T::Int32]), ), ( "Rule 4f fails.", - vec![ - vec![T::Float32, T::Float32, T::Float32, T::Float32], - vec![T::Decimal, T::Decimal, T::Int64, T::Decimal], + &[ + &[T::Float32, T::Float32, T::Float32, T::Float32], + &[T::Decimal, T::Decimal, T::Int64, T::Decimal], ], - &[Some(T::Int16), Some(T::Int32), None, Some(T::Int64)] as &[_], + &[Some(T::Int16), Some(T::Int32), None, Some(T::Int64)], Err("not unique"), ), ( "Rule 4f all unknown.", - vec![ - vec![T::Float32, T::Float32, T::Float32, T::Float32], - vec![T::Decimal, T::Decimal, T::Int64, T::Decimal], + &[ + &[T::Float32, T::Float32, T::Float32, T::Float32], + &[T::Decimal, T::Decimal, T::Int64, T::Decimal], ], - &[None, None, None, None] as &[_], + &[None, None, None, None], Err("not unique"), ), ]; for (desc, candidates, inputs, expected) in testcases { let mut sig_map = FuncSigMap::default(); - candidates - .into_iter() - .for_each(|formals| sig_map.insert(DUMMY_FUNC, formals, DUMMY_RET)); + for formals in candidates { + sig_map.insert(FuncSign { + name: "add", + func: DUMMY_FUNC, + inputs_type: formals, + ret_type: DUMMY_RET, + build: |_, _| unreachable!(), + }); + } let result = infer_type_name(&sig_map, DUMMY_FUNC, inputs); match (expected, result) { (Ok(expected), Ok(found)) => { diff --git a/src/meta/src/rpc/election_client.rs b/src/meta/src/rpc/election_client.rs index 1d9ea210bce3b..0195d211cea10 100644 --- a/src/meta/src/rpc/election_client.rs +++ b/src/meta/src/rpc/election_client.rs @@ -352,7 +352,7 @@ mod tests { let handle = tokio::spawn(async move { let addr = "0.0.0.0:2388".parse().unwrap(); let mut builder = etcd_client::SimServer::builder(); - builder.serve(addr).await; + builder.serve(addr).await.unwrap(); }); let mut clients: Vec<(watch::Sender<()>, Arc)> = vec![]; diff --git a/src/stream/src/executor/dynamic_filter.rs b/src/stream/src/executor/dynamic_filter.rs index 18a7437d70e1b..c8f7a1f1bf612 100644 --- a/src/stream/src/executor/dynamic_filter.rs +++ b/src/stream/src/executor/dynamic_filter.rs @@ -25,9 +25,7 @@ use risingwave_common::hash::VnodeBitmapExt; use risingwave_common::row::{once, OwnedRow as RowData, Row}; use risingwave_common::types::{DataType, Datum, ScalarImpl, ToDatumRef, ToOwnedDatum}; use risingwave_common::util::iter_util::ZipEqDebug; -use risingwave_expr::expr::{ - new_binary_expr, BoxedExpression, InputRefExpression, LiteralExpression, -}; +use risingwave_expr::expr::{build, BoxedExpression, InputRefExpression, LiteralExpression}; use risingwave_pb::expr::expr_node::Type as ExprNodeType; use risingwave_pb::expr::expr_node::Type::{ GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, @@ -267,11 +265,13 @@ impl DynamicFilterExecutor { assert_eq!(l_data_type, r_data_type); let dynamic_cond = move |literal: Datum| { literal.map(|scalar| { - new_binary_expr( + build( self.comparator, DataType::Boolean, - Box::new(InputRefExpression::new(l_data_type.clone(), self.key_l)), - Box::new(LiteralExpression::new(r_data_type.clone(), Some(scalar))), + vec![ + Box::new(InputRefExpression::new(l_data_type.clone(), self.key_l)), + Box::new(LiteralExpression::new(r_data_type.clone(), Some(scalar))), + ], ) }) }; diff --git a/src/stream/src/executor/filter.rs b/src/stream/src/executor/filter.rs index 61b9495d0f8c6..cdbc7fd466e50 100644 --- a/src/stream/src/executor/filter.rs +++ b/src/stream/src/executor/filter.rs @@ -199,8 +199,8 @@ mod tests { use risingwave_common::array::StreamChunk; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::DataType; - use risingwave_expr::expr::{new_binary_expr, InputRefExpression}; - use risingwave_pb::expr::expr_node::Type; + use risingwave_expr::expr::{build, InputRefExpression}; + use risingwave_pb::expr::expr_node::PbType; use super::super::test_utils::MockSource; use super::super::*; @@ -234,15 +234,16 @@ mod tests { }; let source = MockSource::with_chunks(schema, PkIndices::new(), vec![chunk1, chunk2]); - let left_expr = InputRefExpression::new(DataType::Int64, 0); - let right_expr = InputRefExpression::new(DataType::Int64, 1); - let test_expr = new_binary_expr( - Type::GreaterThan, + let test_expr = build( + PbType::GreaterThan, DataType::Boolean, - Box::new(left_expr), - Box::new(right_expr), + vec![ + Box::new(InputRefExpression::new(DataType::Int64, 0)), + Box::new(InputRefExpression::new(DataType::Int64, 1)), + ], ) .unwrap(); + let filter = Box::new(FilterExecutor::new( ActorContext::create(123), Box::new(source), diff --git a/src/stream/src/executor/hash_join.rs b/src/stream/src/executor/hash_join.rs index 4bbca3ac74148..4aff917c8cc9c 100644 --- a/src/stream/src/executor/hash_join.rs +++ b/src/stream/src/executor/hash_join.rs @@ -1057,8 +1057,8 @@ mod tests { use risingwave_common::hash::{Key128, Key64}; use risingwave_common::types::ScalarImpl; use risingwave_common::util::sort_util::OrderType; - use risingwave_expr::expr::{new_binary_expr, InputRefExpression}; - use risingwave_pb::expr::expr_node::Type; + use risingwave_expr::expr::{build, InputRefExpression}; + use risingwave_pb::expr::expr_node::PbType; use risingwave_storage::memory::MemoryStateStore; use super::*; @@ -1111,13 +1111,13 @@ mod tests { } fn create_cond() -> BoxedExpression { - let left_expr = InputRefExpression::new(DataType::Int64, 1); - let right_expr = InputRefExpression::new(DataType::Int64, 3); - new_binary_expr( - Type::LessThan, + build( + PbType::LessThan, DataType::Boolean, - Box::new(left_expr), - Box::new(right_expr), + vec![ + Box::new(InputRefExpression::new(DataType::Int64, 1)), + Box::new(InputRefExpression::new(DataType::Int64, 3)), + ], ) .unwrap() } diff --git a/src/stream/src/executor/project.rs b/src/stream/src/executor/project.rs index 86ecb80a5feda..0dec2ee6af94a 100644 --- a/src/stream/src/executor/project.rs +++ b/src/stream/src/executor/project.rs @@ -188,8 +188,8 @@ mod tests { use risingwave_common::array::StreamChunk; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::DataType; - use risingwave_expr::expr::{new_binary_expr, InputRefExpression, LiteralExpression}; - use risingwave_pb::expr::expr_node::Type; + use risingwave_expr::expr::{build, Expression, InputRefExpression, LiteralExpression}; + use risingwave_pb::expr::expr_node::PbType; use super::super::test_utils::MockSource; use super::super::*; @@ -216,13 +216,13 @@ mod tests { }; let source = MockSource::with_chunks(schema, PkIndices::new(), vec![chunk1, chunk2]); - let left_expr = InputRefExpression::new(DataType::Int64, 0); - let right_expr = InputRefExpression::new(DataType::Int64, 1); - let test_expr = new_binary_expr( - Type::Add, + let test_expr = build( + PbType::Add, DataType::Int64, - Box::new(left_expr), - Box::new(right_expr), + vec![ + InputRefExpression::new(DataType::Int64, 0).boxed(), + InputRefExpression::new(DataType::Int64, 1).boxed(), + ], ) .unwrap(); @@ -269,23 +269,26 @@ mod tests { }; let (mut tx, source) = MockSource::channel(schema, PkIndices::new()); - let a_left_expr = InputRefExpression::new(DataType::Int64, 0); - let a_right_expr = LiteralExpression::new(DataType::Int64, Some(ScalarImpl::Int64(1))); - let a_expr = new_binary_expr( - Type::Add, + let a_expr = build( + PbType::Add, DataType::Int64, - Box::new(a_left_expr), - Box::new(a_right_expr), + vec![ + InputRefExpression::new(DataType::Int64, 0).boxed(), + LiteralExpression::new(DataType::Int64, Some(ScalarImpl::Int64(1))).boxed(), + ], ) .unwrap(); - let b_left_expr = InputRefExpression::new(DataType::Int64, 0); - let b_right_expr = LiteralExpression::new(DataType::Int64, Some(ScalarImpl::Int64(1))); - let b_expr = new_binary_expr( - Type::Subtract, + let b_expr = build( + PbType::Subtract, DataType::Int64, - Box::new(b_left_expr), - Box::new(b_right_expr), + vec![ + Box::new(InputRefExpression::new(DataType::Int64, 0)), + Box::new(LiteralExpression::new( + DataType::Int64, + Some(ScalarImpl::Int64(1)), + )), + ], ) .unwrap(); diff --git a/src/stream/src/executor/project_set.rs b/src/stream/src/executor/project_set.rs index 3d908d9e55df4..046db5003c56e 100644 --- a/src/stream/src/executor/project_set.rs +++ b/src/stream/src/executor/project_set.rs @@ -210,11 +210,9 @@ mod tests { use risingwave_common::array::StreamChunk; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::DataType; - use risingwave_expr::expr::{ - new_binary_expr, Expression, InputRefExpression, LiteralExpression, - }; + use risingwave_expr::expr::{build, Expression, InputRefExpression, LiteralExpression}; use risingwave_expr::table_function::repeat_tf; - use risingwave_pb::expr::expr_node::Type; + use risingwave_pb::expr::expr_node::PbType; use super::super::test_utils::MockSource; use super::super::*; @@ -243,15 +241,16 @@ mod tests { }; let source = MockSource::with_chunks(schema, PkIndices::new(), vec![chunk1, chunk2]); - let left_expr = InputRefExpression::new(DataType::Int64, 0); - let right_expr = InputRefExpression::new(DataType::Int64, 1); - let test_expr = new_binary_expr( - Type::Add, + let test_expr = build( + PbType::Add, DataType::Int64, - Box::new(left_expr), - Box::new(right_expr), + vec![ + Box::new(InputRefExpression::new(DataType::Int64, 0)), + Box::new(InputRefExpression::new(DataType::Int64, 1)), + ], ) .unwrap(); + let tf1 = repeat_tf( LiteralExpression::new(DataType::Int32, Some(1_i32.into())).boxed(), 1, diff --git a/src/stream/src/executor/watermark_filter.rs b/src/stream/src/executor/watermark_filter.rs index 90496b14011f6..515d3acb97754 100644 --- a/src/stream/src/executor/watermark_filter.rs +++ b/src/stream/src/executor/watermark_filter.rs @@ -23,7 +23,7 @@ use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::{DataType, ScalarImpl}; use risingwave_common::{bail, row}; use risingwave_expr::expr::{ - new_binary_expr, BoxedExpression, Expression, InputRefExpression, LiteralExpression, + build, BoxedExpression, Expression, InputRefExpression, LiteralExpression, }; use risingwave_expr::Result as ExprResult; use risingwave_pb::expr::expr_node::Type; @@ -236,11 +236,13 @@ impl WatermarkFilterExecutor { event_time_col_idx: usize, watermark: ScalarImpl, ) -> ExprResult { - new_binary_expr( + build( Type::GreaterThanOrEqual, DataType::Boolean, - InputRefExpression::new(watermark_type.clone(), event_time_col_idx).boxed(), - LiteralExpression::new(watermark_type, Some(watermark)).boxed(), + vec![ + InputRefExpression::new(watermark_type.clone(), event_time_col_idx).boxed(), + LiteralExpression::new(watermark_type, Some(watermark)).boxed(), + ], ) } @@ -333,17 +335,19 @@ mod tests { ], }; - let watermark_expr = new_binary_expr( + let watermark_expr = build( Type::Subtract, WATERMARK_TYPE.clone(), - InputRefExpression::new(WATERMARK_TYPE.clone(), 1).boxed(), - LiteralExpression::new( - interval_type, - Some(ScalarImpl::Interval(IntervalUnit::from_month_day_usec( - 0, 1, 0, - ))), - ) - .boxed(), + vec![ + InputRefExpression::new(WATERMARK_TYPE.clone(), 1).boxed(), + LiteralExpression::new( + interval_type, + Some(ScalarImpl::Interval(IntervalUnit::from_month_day_usec( + 0, 1, 0, + ))), + ) + .boxed(), + ], ) .unwrap();