refactor(expr): generate build-from-prost with procedural macros #8499

Merged on Mar 24, 2023 (55 commits)
Changes from 7 commits
Commits
c6ec937
add expr macro crate
wangrunji0408 Mar 9, 2023
19cb654
add function annotation to vector ops
wangrunji0408 Mar 10, 2023
f79710c
basic macros
wangrunji0408 Mar 12, 2023
ad12df1
parse writer and batch
wangrunji0408 Mar 12, 2023
1a5d300
generate build from prost
wangrunji0408 Mar 12, 2023
a5268b3
refactor code
wangrunji0408 Mar 12, 2023
850c8c0
remove build functions
wangrunji0408 Mar 13, 2023
257f262
add license header
wangrunji0408 Mar 13, 2023
b5ada1b
make array_to_string and array_distinct functions
wangrunji0408 Mar 13, 2023
fb161ee
move build function of some-all to its mod
wangrunji0408 Mar 13, 2023
4f8fb9b
add function to global registry
wangrunji0408 Mar 13, 2023
74b861f
fix build for expr crate
wangrunji0408 Mar 13, 2023
c525b4e
build and/or/is_null/is_not_null
wangrunji0408 Mar 13, 2023
30ba885
remove Result if possible
wangrunji0408 Mar 13, 2023
6eded32
remove some tests
wangrunji0408 Mar 14, 2023
99272ec
consider return type
wangrunji0408 Mar 14, 2023
f5b24b5
add "not"
wangrunji0408 Mar 14, 2023
a4b7ed9
fix build for test
wangrunji0408 Mar 14, 2023
3a9da1c
fix is_distinct_from
wangrunji0408 Mar 14, 2023
d4f2d6e
pass all unit tests
wangrunji0408 Mar 14, 2023
7da23ec
Merge remote-tracking branch 'origin/main' into wrj/expr-proce-macro
wangrunji0408 Mar 15, 2023
85b65d7
Merge remote-tracking branch 'origin/main' into wrj/expr-proce-macro
wangrunji0408 Mar 20, 2023
b8dc6c3
refactor for new_binary_expr
wangrunji0408 Mar 21, 2023
a4ab1f9
revert FuncSign
wangrunji0408 Mar 21, 2023
391899c
change build_from_prost to build function
wangrunji0408 Mar 21, 2023
3407e39
fix all build
wangrunji0408 Mar 21, 2023
70ed256
fix unit test
wangrunji0408 Mar 21, 2023
0080c3f
build expressions using `build`
wangrunji0408 Mar 21, 2023
f477c60
fix panic in unit test
wangrunji0408 Mar 21, 2023
4e65bf6
simplify unnest with the help of ChatGPT
wangrunji0408 Mar 21, 2023
0c19fbe
fix cast with list and struct
wangrunji0408 Mar 21, 2023
6932fa0
fix array_to_string
wangrunji0408 Mar 22, 2023
c7a1840
fix to_timestamp1
wangrunji0408 Mar 22, 2023
70cc436
remove ensure length in build functions
wangrunji0408 Mar 22, 2023
2094ec7
fix clippy
wangrunji0408 Mar 22, 2023
d7cc3a8
Merge remote-tracking branch 'origin/main' into wrj/expr-proce-macro
wangrunji0408 Mar 22, 2023
de31d64
add docs for function signature registry
wangrunji0408 Mar 22, 2023
b79710e
support `fn(T) -> Result<Option<T>>`
wangrunji0408 Mar 22, 2023
13095dd
debug
wangrunji0408 Mar 22, 2023
4d4ec94
fix char_length
wangrunji0408 Mar 22, 2023
759ce3f
fix jsonb_access
wangrunji0408 Mar 22, 2023
48c2ac6
move build function into sub-module
wangrunji0408 Mar 23, 2023
66165cf
use build_function for `now()`. fix no input function
wangrunji0408 Mar 23, 2023
4d306f3
avoid RUSTFLAGS override problem in simulation
wangrunji0408 Mar 23, 2023
ff7b732
fix now
wangrunji0408 Mar 23, 2023
5110437
minor change
wangrunji0408 Mar 23, 2023
65a83b6
fix planner test
wangrunji0408 Mar 23, 2023
9951f5b
fix clippy
wangrunji0408 Mar 23, 2023
1421553
Merge remote-tracking branch 'origin/main' into wrj/expr-proce-macro
wangrunji0408 Mar 23, 2023
b28e325
move unit test to corresponding modules
wangrunji0408 Mar 24, 2023
66866f7
merge all trim functions into one file
wangrunji0408 Mar 24, 2023
22a0dfe
fix expr bench
wangrunji0408 Mar 24, 2023
8375393
Merge remote-tracking branch 'origin/main' into wrj/expr-proce-macro
wangrunji0408 Mar 24, 2023
539e87b
simplify array_length
wangrunji0408 Mar 24, 2023
f691ca5
fix clippy
wangrunji0408 Mar 24, 2023
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -11,6 +11,7 @@ members = [
"src/connector",
"src/ctl",
"src/expr",
"src/expr/macro",
"src/frontend",
"src/frontend/planner_test",
"src/java_binding",
1 change: 1 addition & 0 deletions src/expr/Cargo.toml
@@ -31,6 +31,7 @@ parse-display = "0.6"
paste = "1"
regex = "1"
risingwave_common = { path = "../common" }
risingwave_expr_macro = { path = "macro" }
risingwave_pb = { path = "../prost" }
risingwave_udf = { path = "../udf" }
speedate = "0.7.0"
16 changes: 16 additions & 0 deletions src/expr/macro/Cargo.toml
@@ -0,0 +1,16 @@
[package]
name = "risingwave_expr_macro"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[lib]
proc-macro = true

[dependencies]
proc-macro-error = "1"
proc-macro2 = "1"
quote = "1"
syn = "1"
itertools = "0.10"
232 changes: 232 additions & 0 deletions src/expr/macro/src/gen.rs
@@ -0,0 +1,232 @@
//! Generate code for the functions.

use itertools::Itertools;
use proc_macro2::Span;
use quote::{format_ident, quote};

use super::*;

impl FunctionAttr {
    /// Generate descriptors of the function.
    ///
    /// If the function arguments or return type contain wildcards, it will generate a descriptor
    /// for each combination of concrete types.
    pub fn generate_descriptors(&self) -> Result<TokenStream2> {
        let args = self.args.iter().map(|ty| types::expand_type_wildcard(&ty));
        let ret = types::expand_type_wildcard(&self.ret);
        let mut tokens = TokenStream2::new();
        for (args, mut ret) in args.multi_cartesian_product().cartesian_product(ret) {
            if ret == "auto" {
                ret = types::min_compatible_type(&args);
            }
            let attr = FunctionAttr {
                name: self.name.clone(),
                args: args.iter().map(|s| s.to_string()).collect(),
                ret: ret.to_string(),
                batch: self.batch.clone(),
                user_fn: self.user_fn.clone(),
            };
            tokens.extend(attr.generate_descriptor_one()?);
        }
        Ok(tokens)
    }

    /// Generate a descriptor of the function.
    ///
    /// The types of arguments and return value should not contain wildcards.
    fn generate_descriptor_one(&self) -> Result<TokenStream2> {
        let name = self.name.clone();

        fn to_data_type_name(ty: &str) -> Result<TokenStream2> {
            let variant = format_ident!(
                "{}",
                types::to_data_type_name(ty).ok_or_else(|| Error::new(
                    Span::call_site(),
                    format!("unknown type: {}", ty),
                ))?
            );
            Ok(quote! { risingwave_common::types::DataTypeName::#variant })
        }
        let mut args = Vec::with_capacity(self.args.len());
        for ty in &self.args {
            args.push(to_data_type_name(ty)?);
        }
        let ret = to_data_type_name(&self.ret)?;

        let pb_type = format_ident!("{}", utils::to_camel_case(&name));
        let descriptor_name = format_ident!("{}_{}_{}", self.name, self.args.join("_"), self.ret);
        let descriptor_type = quote! { crate::sig::func::FunctionDescriptor };
        let build_fn = self.generate_build_fn()?;
        Ok(quote! {
            static #descriptor_name: #descriptor_type = #descriptor_type {
                name: #name,
                ty: risingwave_pb::expr::expr_node::Type::#pb_type,
                args: &[#(#args),*],
                ret: #ret,
                build_from_prost: #build_fn,
            };
        })
    }

    fn generate_build_fn(&self) -> Result<TokenStream2> {
        let num_args = self.args.len();
        let i = 0..self.args.len();
        let fn_name = format_ident!("{}", self.user_fn.name);
        let arg_arrays = self
            .args
            .iter()
            .map(|t| format_ident!("{}", types::to_array_type(t)));
        let ret_array = format_ident!("{}", types::to_array_type(&self.ret));
        let arg_types = self
            .args
            .iter()
            .map(|t| types::to_data_type(t).parse::<TokenStream2>().unwrap());
        let ret_type = types::to_data_type(&self.ret)
            .parse::<TokenStream2>()
            .unwrap();

        let prepare = quote! {
            use risingwave_common::array::*;
            use risingwave_common::types::*;
            use risingwave_pb::expr::expr_node::RexNode;

            let return_type = DataType::from(prost.get_return_type().unwrap());
            let RexNode::FuncCall(func_call) = prost.get_rex_node().unwrap() else {
                crate::bail!("Expected RexNode::FuncCall");
            };
            let children = func_call.get_children();
            crate::ensure!(children.len() == #num_args);
            let exprs = [#(crate::expr::build_from_prost(&children[#i])?),*];
        };

        let build_expr = if self.ret == "varchar" && self.user_fn.is_writer_style() {
            let template_struct = match num_args {
                1 => format_ident!("UnaryBytesExpression"),
                2 => format_ident!("BinaryBytesExpression"),
                3 => format_ident!("TernaryBytesExpression"),
                4 => format_ident!("QuaternaryBytesExpression"),
                _ => return Err(Error::new(Span::call_site(), "unsupported arguments")),
            };
            let i = 0..self.args.len();
            quote! {
                Ok(Box::new(crate::expr::template::#template_struct::<#(#arg_arrays),*, _>::new(
                    #(exprs[#i]),*,
                    return_type,
                    #fn_name,
                )))
            }
        } else if self.args.iter().all(|t| t == "boolean")
            && self.ret == "boolean"
            && !self.user_fn.return_result
            && self.batch.is_some()
        {
            let template_struct = match num_args {
                1 => format_ident!("BooleanUnaryExpression"),
                2 => format_ident!("BooleanBinaryExpression"),
                _ => return Err(Error::new(Span::call_site(), "unsupported arguments")),
            };
            let batch = format_ident!("{}", self.batch.as_ref().unwrap());
            let i = 0..self.args.len();
            let func = if self.user_fn.arg_option && self.user_fn.return_option {
                quote! { #fn_name }
            } else if self.user_fn.arg_option {
                let args = (0..num_args).map(|i| format_ident!("x{i}"));
                let args1 = args.clone();
                quote! { |#(#args),*| Some(#fn_name(#(#args1),*)) }
            } else {
                let args = (0..num_args).map(|i| format_ident!("x{i}"));
                let args1 = args.clone();
                let args2 = args.clone();
                let args3 = args.clone();
                quote! {
                    |#(#args),*| match (#(#args1),*) {
                        (#(Some(#args2)),*) => Some(#fn_name(#(#args3),*)),
                        _ => None,
                    }
                }
            };
            quote! {
                Ok(Box::new(crate::expr::template_fast::#template_struct::new(
                    #(exprs[#i]),*, #batch, #func,
                )))
            }
        } else if self.args.len() == 2 && self.ret == "boolean" && self.user_fn.is_pure() {
            let compatible_type = types::to_data_type(types::min_compatible_type(&self.args))
                .parse::<TokenStream2>()
                .unwrap();
            quote! {
                Ok(Box::new(crate::expr::template_fast::CompareExpression::<_, #(#arg_arrays),*>::new(
                    exprs[0], exprs[1], #fn_name::<#(#arg_types),*, #compatible_type>,
                )))
            }
        } else if self.args.iter().all(|t| types::is_primitive(t)) && self.user_fn.is_pure() {
            let template_struct = match num_args {
                1 => format_ident!("UnaryExpression"),
                2 => format_ident!("BinaryExpression"),
                _ => return Err(Error::new(Span::call_site(), "unsupported arguments")),
            };
            let i = 0..self.args.len();
            quote! {
                Ok(Box::new(crate::expr::template_fast::#template_struct::<_, #(#arg_types),*, #ret_type>::new(
                    #(exprs[#i]),*,
                    return_type,
                    #fn_name,
                )))
            }
        } else if self.user_fn.arg_option {
            let template_struct = match num_args {
                1 => format_ident!("UnaryNullableExpression"),
                2 => format_ident!("BinaryNullableExpression"),
                3 => format_ident!("TernaryNullableExpression"),
                _ => return Err(Error::new(Span::call_site(), "unsupported arguments")),
            };
            let i = 0..self.args.len();
            let func = if self.user_fn.return_result {
                quote! { #fn_name }
            } else if self.user_fn.return_option {
                let args = (0..num_args).map(|i| format_ident!("x{i}"));
                let args1 = args.clone();
                quote! { |#(#args),*| Ok(#fn_name(#(#args1),*)) }
            } else {
                let args = (0..num_args).map(|i| format_ident!("x{i}"));
                let args1 = args.clone();
                quote! { |#(#args),*| Ok(Some(#fn_name(#(#args1),*))) }
            };
            quote! {
                Ok(Box::new(crate::expr::template::#template_struct::<#(#arg_arrays),*, #ret_array, _>::new(
                    #(exprs[#i]),*,
                    return_type,
                    #func,
                )))
            }
        } else {
            let template_struct = match num_args {
                1 => format_ident!("UnaryExpression"),
                2 => format_ident!("BinaryExpression"),
                3 => format_ident!("TernaryExpression"),
                _ => return Err(Error::new(Span::call_site(), "unsupported arguments")),
            };
            let i = 0..self.args.len();
            let func = if self.user_fn.return_result {
                quote! { #fn_name }
            } else {
                let args = (0..num_args).map(|i| format_ident!("x{i}"));
                let args1 = args.clone();
                quote! { |#(#args),*| Ok(#fn_name(#(#args1),*)) }
            };
            quote! {
                Ok(Box::new(crate::expr::template::#template_struct::<#(#arg_arrays),*, #ret_array, _>::new(
                    #(exprs[#i]),*,
                    return_type,
                    #func,
                )))
            }
        };
        Ok(quote! {
            |prost| {
                #prepare
                #build_expr
            }
        })
    }
}
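
For readers following `generate_descriptors`, here is a minimal std-only sketch of the wildcard expansion it performs: every argument expands to a list of concrete types, the lists are combined as a cartesian product, and an `auto` return type is resolved per combination. The `*int` wildcard, its expansion to `int16`/`int32`/`int64`, and the "widest integer wins" rule are assumptions made for illustration; the real `types::expand_type_wildcard` and `types::min_compatible_type` are not part of the diff shown here, and the PR itself uses `itertools` rather than the hand-written product below.

```rust
// Hypothetical wildcard table; the real `types::expand_type_wildcard` is not shown in this diff.
fn expand_type_wildcard(ty: &str) -> Vec<&str> {
    match ty {
        "*int" => vec!["int16", "int32", "int64"],
        other => vec![other],
    }
}

// Cartesian product over the per-argument candidate lists, using only std.
fn cartesian_product<'a>(lists: &[Vec<&'a str>]) -> Vec<Vec<&'a str>> {
    let mut acc: Vec<Vec<&'a str>> = vec![Vec::new()];
    for list in lists {
        let mut next = Vec::with_capacity(acc.len() * list.len());
        for prefix in &acc {
            for &item in list {
                let mut combo = prefix.clone();
                combo.push(item);
                next.push(combo);
            }
        }
        acc = next;
    }
    acc
}

// Stand-in for `types::min_compatible_type`: picks the widest integer (assumed ordering).
fn min_compatible_type<'a>(args: &[&'a str]) -> &'a str {
    const ORDER: [&str; 3] = ["int16", "int32", "int64"];
    args.iter()
        .copied()
        .max_by_key(|t| ORDER.iter().position(|o| o == t).unwrap_or(0))
        .expect("at least one argument")
}

// Expand one annotated signature into concrete (args, ret) pairs, one per descriptor.
fn expand_signature<'a>(args: &[&'a str], ret: &'a str) -> Vec<(Vec<&'a str>, &'a str)> {
    let lists: Vec<Vec<&str>> = args.iter().map(|&t| expand_type_wildcard(t)).collect();
    let mut out = Vec::new();
    for combo in cartesian_product(&lists) {
        for r in expand_type_wildcard(ret) {
            // `auto` is resolved separately for every argument combination.
            let r = if r == "auto" { min_compatible_type(&combo) } else { r };
            out.push((combo.clone(), r));
        }
    }
    out
}

fn main() {
    for (args, ret) in expand_signature(&["*int", "*int"], "auto") {
        println!("add({}) -> {}", args.join(", "), ret);
    }
}
```

Under these assumptions a single two-argument annotation with wildcards on both sides yields nine concrete signatures, which is why one `#[function]` attribute can emit many `FunctionDescriptor` statics.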
60 changes: 60 additions & 0 deletions src/expr/macro/src/lib.rs
@@ -0,0 +1,60 @@
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::ToTokens;
use syn::{parse_macro_input, Error, Result};

mod gen;
mod parse;
mod types;
mod utils;

#[proc_macro_attribute]
pub fn function(attr: TokenStream, item: TokenStream) -> TokenStream {
    let attr = parse_macro_input!(attr as syn::AttributeArgs);
    let item = parse_macro_input!(item as syn::ItemFn);
    match parse_function(attr, item) {
        Ok(tokens) => tokens.into(),
        Err(e) => e.to_compile_error().into(),
    }
}

fn parse_function(attr: syn::AttributeArgs, item: syn::ItemFn) -> Result<TokenStream2> {
    let fn_attr = FunctionAttr::parse(&attr, &item)?;

    let mut tokens = item.into_token_stream();
    tokens.extend(fn_attr.generate_descriptors()?);
    Ok(tokens)
}

#[derive(Debug)]
struct FunctionAttr {
    name: String,
    args: Vec<String>,
    ret: String,
    batch: Option<String>,
    user_fn: UserFunctionAttr,
}

#[derive(Debug, Clone)]
struct UserFunctionAttr {
    /// Function name
    name: String,
    /// The last argument type is `&mut dyn Write`.
    write: bool,
    /// The argument types are `Option`s.
    arg_option: bool,
    /// The return type is `Option`.
    return_option: bool,
    /// The return type is `Result`.
    return_result: bool,
}

impl UserFunctionAttr {
    fn is_writer_style(&self) -> bool {
        self.write && !self.arg_option && self.return_result
    }

    fn is_pure(&self) -> bool {
        !self.write && !self.arg_option && !self.return_option && !self.return_result
    }
}
}
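
To make the `UserFunctionAttr` flags concrete, the sketch below shows one hypothetical user function per shape (pure, `return_result`, `arg_option`, writer style) and a hand-written adapter that lifts a strict function to NULL-aware inputs. All function names, the `String` error type, and `lift_strict` are illustrative and not taken from the PR. How the template structs treat NULL inputs is not visible in this diff, so strict NULL propagation is an assumption modeled on the `match` closure generated in the boolean branch of `generate_build_fn`.

```rust
use std::fmt::Write;

// Simplified error type for the sketch; the real crate has its own `Result`.
type Result<T> = std::result::Result<T, String>;

/// Pure: no writer, no `Option`, no `Result` (all four flags false).
fn add(a: i32, b: i32) -> i32 {
    a.wrapping_add(b)
}

/// `return_result`: the function can fail, so it returns `Result`.
fn div(a: i32, b: i32) -> Result<i32> {
    a.checked_div(b)
        .ok_or_else(|| "division by zero or overflow".to_string())
}

/// `arg_option`: the function sees NULL as `None` and handles it itself.
fn is_null(a: Option<i32>) -> bool {
    a.is_none()
}

/// Writer style: `write` and `return_result` set, `arg_option` unset.
fn upper(s: &str, writer: &mut dyn Write) -> Result<()> {
    for c in s.chars() {
        writer
            .write_char(c.to_ascii_uppercase())
            .map_err(|e| e.to_string())?;
    }
    Ok(())
}

/// Lift a fallible binary function to NULL-aware inputs.
/// Assumption: any NULL input produces a NULL output.
fn lift_strict<F>(f: F) -> impl Fn(Option<i32>, Option<i32>) -> Result<Option<i32>>
where
    F: Fn(i32, i32) -> Result<i32>,
{
    move |x0, x1| match (x0, x1) {
        (Some(x0), Some(x1)) => f(x0, x1).map(Some),
        _ => Ok(None),
    }
}

fn main() {
    // An infallible function is first wrapped in `Ok(..)`, as the generated closures do.
    let add_expr = lift_strict(|a, b| Ok(add(a, b)));
    let div_expr = lift_strict(div);
    println!("{:?}", add_expr(Some(1), Some(2)));
    println!("{:?}", add_expr(Some(1), None));
    println!("{:?}", div_expr(Some(1), Some(0)));
    println!("{:?}", is_null(None));

    let mut out = String::new();
    upper("abc", &mut out).unwrap();
    println!("{}", out);
}
```

The point of classifying functions this way is that the macro only needs a thin closure (`Ok(..)`, `Some(..)`, or a `match` on `Option`s) to adapt each shape to the signature a template struct expects, so authors can write the simplest signature that fits their function.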