Skip to content

Commit

Permalink
feat(frontend): support infer param in binder (risingwavelabs#8453)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZENOTME authored Mar 14, 2023
1 parent 9d5ff78 commit 305c864
Show file tree
Hide file tree
Showing 10 changed files with 212 additions and 11 deletions.
7 changes: 6 additions & 1 deletion src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use risingwave_sqlparser::ast::{
};

use crate::binder::Binder;
use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, SubqueryKind};
use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, Parameter, SubqueryKind};

mod binary_op;
mod column;
Expand Down Expand Up @@ -123,6 +123,7 @@ impl Binder {
start,
count,
} => self.bind_overlay(*expr, *new_substring, *start, count),
Expr::Parameter { index } => self.bind_parameter(index),
_ => Err(ErrorCode::NotImplemented(
format!("unsupported expression {:?}", expr),
112.into(),
Expand Down Expand Up @@ -297,6 +298,10 @@ impl Binder {
FunctionCall::new(ExprType::Overlay, args).map(|f| f.into())
}

/// Bind a parameter placeholder with the given `index` into a [`Parameter`] expression.
///
/// The new expression receives a clone of the binder's `param_types` handle.
fn bind_parameter(&mut self, index: u64) -> Result<ExprImpl> {
    let shared_types = self.param_types.clone();
    let parameter = Parameter::new(index, shared_types);
    Ok(parameter.into())
}

/// Bind `expr (not) between low and high`
pub(super) fn bind_between(
&mut self,
Expand Down
100 changes: 96 additions & 4 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
// limitations under the License.

use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{Arc, RwLock};

use itertools::Itertools;
use risingwave_common::error::Result;
use risingwave_common::session_config::SearchPath;
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_sqlparser::ast::Statement;

mod bind_context;
Expand Down Expand Up @@ -90,10 +93,94 @@ pub struct Binder {

/// `ShareId`s identifying shared views.
shared_views: HashMap<ViewId, ShareId>,

param_types: ParameterTypes,
}

/// `ParameterTypes` records the type of each parameter seen during binding, keyed by the
/// parameter index. It works according to the following rules:
/// 1. At the beginning, it contains the parameter types specified by the user (indexed from 1).
/// 2. When the binder encounters a parameter that does not exist in `ParameterTypes` yet, the
/// parameter is recorded as unknown (call `record_new_param`).
/// 3. When the binder encounters a cast on a parameter whose type is still unknown, the cast
/// function records the target type as the inferred type of that parameter (call
/// `record_infer_type`). If the parameter has already been inferred, the cast function acts as
/// a normal cast.
/// 4. After binding has finished:
/// (a) An index that is missing from `ParameterTypes` means that the user didn't specify that
/// parameter and it didn't occur in the query. `export` returns an error if there is such a
/// missing parameter. This rule is compatible with PostgreSQL.
/// (b) A parameter that is `None` has an unknown type: the user didn't specify it and we
/// couldn't infer it from the query. We finally treat it as VARCHAR. This rule is compatible
/// with PostgreSQL.
/// (c) A parameter that is `Some` has a known type.
///
/// The map lives behind an `Arc<RwLock<..>>`, so clones of a `ParameterTypes` share the same
/// underlying records.
#[derive(Clone, Debug)]
pub struct ParameterTypes(Arc<RwLock<HashMap<u64, Option<DataType>>>>);

impl ParameterTypes {
    /// Creates the record from the parameter types specified by the user.
    ///
    /// The entry at position `i` (0-based) of `specified_param_types` is registered under the
    /// parameter index `i + 1`.
    pub fn new(specified_param_types: Vec<DataType>) -> Self {
        let map = specified_param_types
            .into_iter()
            .enumerate()
            .map(|(index, data_type)| ((index + 1) as u64, Some(data_type)))
            .collect::<HashMap<u64, Option<DataType>>>();
        Self(Arc::new(RwLock::new(map)))
    }

    /// Returns `true` if the parameter at `index` already has a known type.
    ///
    /// # Panics
    ///
    /// Panics if `index` has never been recorded.
    pub fn has_infer(&self, index: u64) -> bool {
        self.0
            .read()
            .unwrap()
            .get(&index)
            .expect("the parameter must be recorded before its type is queried")
            .is_some()
    }

    /// Returns the type recorded for the parameter at `index`, or `None` if it is still unknown.
    ///
    /// # Panics
    ///
    /// Panics if `index` has never been recorded.
    pub fn read_type(&self, index: u64) -> Option<DataType> {
        self.0
            .read()
            .unwrap()
            .get(&index)
            .expect("the parameter must be recorded before its type is queried")
            .clone()
    }

    /// Records the parameter at `index` as unknown, unless an entry for it already exists.
    pub fn record_new_param(&mut self, index: u64) {
        self.0.write().unwrap().entry(index).or_insert(None);
    }

    /// Records `data_type` as the inferred type of the parameter at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` has never been recorded, or if the parameter already has a type.
    pub fn record_infer_type(&mut self, index: u64, data_type: DataType) {
        // Check and update under a single write lock. The map is shared between clones, so
        // checking under a read lock and writing under a separately acquired write lock would
        // leave a window in which another holder could infer the same parameter.
        let mut types = self.0.write().unwrap();
        let slot = types
            .get_mut(&index)
            .expect("the parameter must be recorded before its type is inferred");
        assert!(
            slot.is_none(),
            "The parameter has been inferred, should not be inferred again."
        );
        *slot = Some(data_type);
    }

    /// Exports the parameter types ordered by parameter index.
    ///
    /// Parameters whose type is still unknown are exported as `DataType::Varchar`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorCode::InvalidInputSyntax` if the recorded indices are not exactly
    /// `1..=n`; the message names the first index of that sequence that does not line up.
    pub fn export(&self) -> Result<Vec<DataType>> {
        // Copy the entries out while sorting instead of cloning the whole map first.
        let types = self
            .0
            .read()
            .unwrap()
            .iter()
            .map(|(index, data_type)| (*index, data_type.clone()))
            .sorted_by_key(|(index, _)| *index)
            .collect::<Vec<_>>();

        // Check that the recorded indices are contiguous and start from 1. A gap means that a
        // parameter was neither specified by the user nor referenced in the query.
        for ((index, _), expect_index) in types.iter().zip_eq_debug(1_u64..=types.len() as u64) {
            if *index != expect_index {
                return Err(ErrorCode::InvalidInputSyntax(format!(
                    "Cannot infer the type of the parameter {}.",
                    expect_index
                ))
                .into());
            }
        }

        Ok(types
            .into_iter()
            .map(|(_, data_type)| data_type.unwrap_or(DataType::Varchar))
            .collect::<Vec<_>>())
    }
}

impl Binder {
fn new_inner(session: &SessionImpl, in_create_mv: bool) -> Binder {
fn new_inner(session: &SessionImpl, in_create_mv: bool, param_types: Vec<DataType>) -> Binder {
let now_ms = session
.env()
.hummock_snapshot_manager()
Expand All @@ -114,22 +201,27 @@ impl Binder {
search_path: session.config().get_search_path(),
in_create_mv,
shared_views: HashMap::new(),
param_types: ParameterTypes::new(param_types),
}
}

pub fn new(session: &SessionImpl) -> Binder {
Self::new_inner(session, false)
Self::new_inner(session, false, vec![])
}

pub fn new_for_stream(session: &SessionImpl) -> Binder {
Self::new_inner(session, true)
Self::new_inner(session, true, vec![])
}

/// Bind a [`Statement`] into a [`BoundStatement`].
///
/// This is a thin entry point that delegates to `bind_statement`; any error it returns is
/// propagated unchanged.
pub fn bind(&mut self, stmt: Statement) -> Result<BoundStatement> {
self.bind_statement(stmt)
}

/// Export the parameter types recorded by this binder.
///
/// Delegates to `ParameterTypes::export` on the binder's `param_types` and returns its
/// result, including its error, unchanged.
pub fn export_param_types(&self) -> Result<Vec<DataType>> {
self.param_types.export()
}

fn push_context(&mut self) {
let new_context = std::mem::take(&mut self.context);
self.context.cte_to_relation = new_context.cte_to_relation.clone();
Expand Down
4 changes: 3 additions & 1 deletion src/frontend/src/expr/expr_mutator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use super::{
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Subquery,
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Parameter, Subquery,
TableFunction, UserDefinedFunction, WindowFunction,
};

Expand All @@ -30,6 +30,7 @@ pub trait ExprMutator {
ExprImpl::TableFunction(inner) => self.visit_table_function(inner),
ExprImpl::WindowFunction(inner) => self.visit_window_function(inner),
ExprImpl::UserDefinedFunction(inner) => self.visit_user_defined_function(inner),
ExprImpl::Parameter(inner) => self.visit_parameter(inner),
}
}
fn visit_function_call(&mut self, func_call: &mut FunctionCall) {
Expand All @@ -47,6 +48,7 @@ pub trait ExprMutator {
agg_call.filter_mut().visit_expr_mut(self);
}
/// Default no-op: the literal is left untouched.
fn visit_literal(&mut self, _: &mut Literal) {}
/// Default no-op: the parameter is left untouched.
fn visit_parameter(&mut self, _: &mut Parameter) {}
/// Default no-op: the input reference is left untouched.
fn visit_input_ref(&mut self, _: &mut InputRef) {}
/// Default no-op: the subquery is left untouched and is not descended into.
fn visit_subquery(&mut self, _: &mut Subquery) {}
/// Default no-op: the correlated input reference is left untouched.
fn visit_correlated_input_ref(&mut self, _: &mut CorrelatedInputRef) {}
Expand Down
6 changes: 5 additions & 1 deletion src/frontend/src/expr/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use super::{
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Subquery,
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Parameter, Subquery,
TableFunction, UserDefinedFunction, WindowFunction,
};

Expand All @@ -32,6 +32,7 @@ pub trait ExprRewriter {
ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner),
ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner),
ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner),
ExprImpl::Parameter(inner) => self.rewrite_parameter(*inner),
}
}
fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
Expand All @@ -54,6 +55,9 @@ pub trait ExprRewriter {
.unwrap()
.into()
}
/// Default rewrite: returns the parameter unchanged, converted back into an `ExprImpl`.
fn rewrite_parameter(&mut self, parameter: Parameter) -> ExprImpl {
parameter.into()
}
/// Default rewrite: returns the literal unchanged, converted back into an `ExprImpl`.
fn rewrite_literal(&mut self, literal: Literal) -> ExprImpl {
literal.into()
}
Expand Down
6 changes: 5 additions & 1 deletion src/frontend/src/expr/expr_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use super::{
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Subquery,
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Parameter, Subquery,
TableFunction, UserDefinedFunction, WindowFunction,
};

Expand Down Expand Up @@ -42,6 +42,7 @@ pub trait ExprVisitor<R: Default> {
ExprImpl::TableFunction(inner) => self.visit_table_function(inner),
ExprImpl::WindowFunction(inner) => self.visit_window_function(inner),
ExprImpl::UserDefinedFunction(inner) => self.visit_user_defined_function(inner),
ExprImpl::Parameter(inner) => self.visit_parameter(inner),
}
}
fn visit_function_call(&mut self, func_call: &FunctionCall) -> R {
Expand All @@ -63,6 +64,9 @@ pub trait ExprVisitor<R: Default> {
r = Self::merge(r, agg_call.filter().visit_expr(self));
r
}
/// Default visit: ignores the parameter and returns `R::default()`.
fn visit_parameter(&mut self, _: &Parameter) -> R {
R::default()
}
/// Default visit: ignores the literal and returns `R::default()`.
fn visit_literal(&mut self, _: &Literal) -> R {
R::default()
}
Expand Down
6 changes: 5 additions & 1 deletion src/frontend/src/expr/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,14 @@ impl FunctionCall {

/// Create a cast expr over `child` to `target` type in `allows` context.
pub fn new_cast(
child: ExprImpl,
mut child: ExprImpl,
target: DataType,
allows: CastContext,
) -> Result<ExprImpl, CastError> {
if let ExprImpl::Parameter(expr) = &mut child && !expr.has_infer() {
expr.cast_infer_type(target);
return Ok(child);
}
if is_row_function(&child) {
// Row function will have empty fields in Datatype::Struct at this point. Therefore,
// we will need to take some special care to generate the cast types. For normal struct
Expand Down
11 changes: 10 additions & 1 deletion src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ mod correlated_input_ref;
mod function_call;
mod input_ref;
mod literal;
mod parameter;
mod subquery;
mod table_function;
mod user_defined_function;
Expand All @@ -50,6 +51,7 @@ pub use expr_visitor::ExprVisitor;
pub use function_call::{is_row_function, FunctionCall, FunctionCallDisplay};
pub use input_ref::{input_ref_to_column_indices, InputRef, InputRefDisplay};
pub use literal::Literal;
pub use parameter::Parameter;
pub use risingwave_pb::expr::expr_node::Type as ExprType;
pub use session_timezone::SessionTimezone;
pub use subquery::{Subquery, SubqueryKind};
Expand Down Expand Up @@ -96,7 +98,8 @@ impl_expr_impl!(
Subquery,
TableFunction,
WindowFunction,
UserDefinedFunction
UserDefinedFunction,
Parameter
);

impl ExprImpl {
Expand Down Expand Up @@ -174,6 +177,7 @@ impl ExprImpl {
/// Check whether the type of `self` is still unknown.
///
/// This is the case for a literal whose return type is `DataType::Varchar` (a literal NULL or
/// a literal string), and for a parameter whose type has not been inferred yet.
pub fn is_unknown(&self) -> bool {
    match self {
        ExprImpl::Literal(literal) => literal.return_type() == DataType::Varchar,
        ExprImpl::Parameter(parameter) => !parameter.has_infer(),
        _ => false,
    }
}

/// Shorthand to create cast expr to `target` type in implicit context.
Expand Down Expand Up @@ -761,6 +765,7 @@ impl Expr for ExprImpl {
ExprImpl::TableFunction(expr) => expr.return_type(),
ExprImpl::WindowFunction(expr) => expr.return_type(),
ExprImpl::UserDefinedFunction(expr) => expr.return_type(),
ExprImpl::Parameter(expr) => expr.return_type(),
}
}

Expand All @@ -779,6 +784,7 @@ impl Expr for ExprImpl {
unreachable!("Window function should not be converted to ExprNode")
}
ExprImpl::UserDefinedFunction(e) => e.to_expr_proto(),
ExprImpl::Parameter(e) => e.to_expr_proto(),
}
}
}
Expand Down Expand Up @@ -813,6 +819,7 @@ impl std::fmt::Debug for ExprImpl {
Self::UserDefinedFunction(arg0) => {
f.debug_tuple("UserDefinedFunction").field(arg0).finish()
}
Self::Parameter(arg0) => f.debug_tuple("Parameter").field(arg0).finish(),
};
}
match self {
Expand All @@ -825,6 +832,7 @@ impl std::fmt::Debug for ExprImpl {
Self::TableFunction(x) => write!(f, "{:?}", x),
Self::WindowFunction(x) => write!(f, "{:?}", x),
Self::UserDefinedFunction(x) => write!(f, "{:?}", x),
Self::Parameter(x) => write!(f, "{:?}", x),
}
}
}
Expand Down Expand Up @@ -867,6 +875,7 @@ impl std::fmt::Debug for ExprDisplay<'_> {
write!(f, "{:?}", x)
}
ExprImpl::UserDefinedFunction(x) => write!(f, "{:?}", x),
ExprImpl::Parameter(x) => write!(f, "{:?}", x),
}
}
}
Expand Down
Loading

0 comments on commit 305c864

Please sign in to comment.