Skip to content

Commit

Permalink
refactor(frontend): refactor extended query mode (risingwavelabs#8919)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZENOTME authored Apr 3, 2023
1 parent 06025d5 commit 8278061
Show file tree
Hide file tree
Showing 16 changed files with 716 additions and 1,509 deletions.
8 changes: 5 additions & 3 deletions src/frontend/src/binder/bind_param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
// limitations under the License.

use bytes::Bytes;
use pgwire::types::Format;
use risingwave_common::error::{Result, RwError};
use pgwire::types::{Format, FormatIterator};
use risingwave_common::error::{ErrorCode, Result, RwError};
use risingwave_common::types::ScalarImpl;

use super::statement::RewriteExprsRecursive;
Expand Down Expand Up @@ -85,8 +85,10 @@ impl BoundStatement {
param_formats: Vec<Format>,
) -> Result<BoundStatement> {
let mut rewriter = ParamRewriter {
param_formats: FormatIterator::new(&param_formats, params.len())
.map_err(ErrorCode::BindError)?
.collect(),
params,
param_formats,
error: None,
};

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::catalog::TableId;
use crate::expr::ExprImpl;
use crate::user::UserId;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct BoundDelete {
/// Id of the table to perform deleting.
pub table_id: TableId,
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::catalog::TableId;
use crate::expr::{ExprImpl, InputRef};
use crate::user::UserId;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct BoundInsert {
/// Id of the table to perform inserting.
pub table_id: TableId,
Expand Down
7 changes: 7 additions & 0 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,13 @@ impl Binder {
Self::new_inner(session, true, vec![])
}

pub fn new_for_stream_with_param_types(
session: &SessionImpl,
param_types: Vec<DataType>,
) -> Binder {
Self::new_inner(session, true, param_types)
}

/// Bind a [`Statement`].
pub fn bind(&mut self, stmt: Statement) -> Result<BoundStatement> {
self.bind_statement(stmt)
Expand Down
23 changes: 22 additions & 1 deletion src/frontend/src/binder/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::catalog::Field;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_sqlparser::ast::Statement;

Expand All @@ -20,14 +21,34 @@ use super::update::BoundUpdate;
use crate::binder::{Binder, BoundInsert, BoundQuery};
use crate::expr::ExprRewriter;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum BoundStatement {
Insert(Box<BoundInsert>),
Delete(Box<BoundDelete>),
Update(Box<BoundUpdate>),
Query(Box<BoundQuery>),
}

impl BoundStatement {
pub fn output_fields(&self) -> Vec<Field> {
match self {
BoundStatement::Insert(i) => i.returning_schema.as_ref().map_or(
vec![Field::unnamed(risingwave_common::types::DataType::Int64)],
|s| s.fields().into(),
),
BoundStatement::Delete(d) => d.returning_schema.as_ref().map_or(
vec![Field::unnamed(risingwave_common::types::DataType::Int64)],
|s| s.fields().into(),
),
BoundStatement::Update(u) => u.returning_schema.as_ref().map_or(
vec![Field::unnamed(risingwave_common::types::DataType::Int64)],
|s| s.fields().into(),
),
BoundStatement::Query(q) => q.schema().fields().into(),
}
}
}

impl Binder {
pub(super) fn bind_statement(&mut self, stmt: Statement) -> Result<BoundStatement> {
match stmt {
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::catalog::TableId;
use crate::expr::{Expr as _, ExprImpl};
use crate::user::UserId;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct BoundUpdate {
/// Id of the table to perform updating.
pub table_id: TableId,
Expand Down
118 changes: 84 additions & 34 deletions src/frontend/src/handler/extended_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,33 @@ use bytes::Bytes;
use pgwire::types::Format;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::Statement;
use risingwave_sqlparser::ast::{Query, Statement};

use super::{query, HandlerArgs, RwPgResponse};
use super::{handle, query, HandlerArgs, RwPgResponse};
use crate::binder::BoundStatement;
use crate::session::SessionImpl;

pub struct PrepareStatement {
#[derive(Clone)]
pub enum PrepareStatement {
Prepared(PreparedResult),
PureStatement(Statement),
}

#[derive(Clone)]
pub struct PreparedResult {
pub statement: Statement,
pub bound_statement: BoundStatement,
pub param_types: Vec<DataType>,
}

pub struct Portal {
#[derive(Clone)]
pub enum Portal {
Portal(PortalResult),
PureStatement(Statement),
}

#[derive(Clone)]
pub struct PortalResult {
pub statement: Statement,
pub bound_statement: BoundStatement,
pub result_formats: Vec<Format>,
Expand All @@ -44,16 +58,38 @@ pub fn handle_parse(
session.clear_cancel_query_flag();
let str_sql = stmt.to_string();
let handler_args = HandlerArgs::new(session, &stmt, &str_sql)?;
match stmt {
match &stmt {
Statement::Query(_)
| Statement::Insert { .. }
| Statement::Delete { .. }
| Statement::Update { .. } => query::handle_parse(handler_args, stmt, specific_param_types),
_ => Err(ErrorCode::NotSupported(
format!("Can't support {} in extended query mode now", str_sql,),
"".to_string(),
)
.into()),
Statement::CreateView {
query,
..
} => {
if have_parameter_in_query(query) {
return Err(ErrorCode::NotImplemented(
"CREATE VIEW with parameters".to_string(),
None.into(),
)
.into());
}
Ok(PrepareStatement::PureStatement(stmt))
}
Statement::CreateTable {
query,
..
} => {
if let Some(query) = query && have_parameter_in_query(query) {
Err(ErrorCode::NotImplemented(
"CREATE TABLE AS SELECT with parameters".to_string(),
None.into(),
).into())
} else {
Ok(PrepareStatement::PureStatement(stmt))
}
}
_ => Ok(PrepareStatement::PureStatement(stmt)),
}
}

Expand All @@ -63,32 +99,46 @@ pub fn handle_bind(
param_formats: Vec<Format>,
result_formats: Vec<Format>,
) -> Result<Portal> {
let PrepareStatement {
statement,
bound_statement,
..
} = prepare_statement;
let bound_statement = bound_statement.bind_parameter(params, param_formats)?;
Ok(Portal {
statement,
bound_statement,
result_formats,
})
match prepare_statement {
PrepareStatement::Prepared(prepared_result) => {
let PreparedResult {
statement,
bound_statement,
..
} = prepared_result;
let bound_statement = bound_statement.bind_parameter(params, param_formats)?;
Ok(Portal::Portal(PortalResult {
statement,
bound_statement,
result_formats,
}))
}
PrepareStatement::PureStatement(stmt) => Ok(Portal::PureStatement(stmt)),
}
}

pub async fn handle_execute(session: Arc<SessionImpl>, portal: Portal) -> Result<RwPgResponse> {
session.clear_cancel_query_flag();
let str_sql = portal.statement.to_string();
let handler_args = HandlerArgs::new(session, &portal.statement, &str_sql)?;
match &portal.statement {
Statement::Query(_)
| Statement::Insert { .. }
| Statement::Delete { .. }
| Statement::Update { .. } => query::handle_execute(handler_args, portal).await,
_ => Err(ErrorCode::NotSupported(
format!("Can't support {} in extended query mode now", str_sql,),
"".to_string(),
)
.into()),
match portal {
Portal::Portal(portal) => {
session.clear_cancel_query_flag();
let str_sql = portal.statement.to_string();
let handler_args = HandlerArgs::new(session, &portal.statement, &str_sql)?;
match &portal.statement {
Statement::Query(_)
| Statement::Insert { .. }
| Statement::Delete { .. }
| Statement::Update { .. } => query::handle_execute(handler_args, portal).await,
_ => unreachable!(),
}
}
Portal::PureStatement(stmt) => {
let sql = stmt.to_string();
handle(session, stmt, &sql, vec![]).await
}
}
}

/// A quick way to check if a query contains parameters.
fn have_parameter_in_query(query: &Query) -> bool {
query.to_string().contains("$1")
}
91 changes: 53 additions & 38 deletions src/frontend/src/handler/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ use risingwave_common::session_config::QueryMode;
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{SetExpr, Statement};

use super::extended_handle::{Portal, PrepareStatement};
use super::extended_handle::{PortalResult, PrepareStatement, PreparedResult};
use super::{PgResponseStream, RwPgResponse};
use crate::binder::Binder;
use crate::binder::{Binder, BoundStatement};
use crate::catalog::TableId;
use crate::handler::flush::do_flush;
use crate::handler::privilege::resolve_privileges;
Expand Down Expand Up @@ -368,6 +368,8 @@ pub async fn local_execute(
Ok(execution.stream_rows())
}

// TODO: Following code have redundant code with `handle_query`, we may need to refactor them in
// future.
pub fn handle_parse(
handler_args: HandlerArgs,
statement: Statement,
Expand All @@ -382,15 +384,58 @@ pub fn handle_parse(

let param_types = binder.export_param_types()?;

Ok(PrepareStatement {
Ok(PrepareStatement::Prepared(PreparedResult {
statement,
bound_statement,
param_types,
})
}))
}

pub async fn handle_execute(handler_args: HandlerArgs, portal: Portal) -> Result<RwPgResponse> {
let Portal {
pub fn gen_batch_query_plan_for_bound(
session: &SessionImpl,
context: OptimizerContextRef,
stmt: Statement,
bound: BoundStatement,
) -> Result<(PlanRef, QueryMode, Schema)> {
let must_dist = must_run_in_distributed_mode(&stmt)?;

let mut planner = Planner::new(context);

let mut logical = planner.plan(bound)?;
let schema = logical.schema();
let batch_plan = logical.gen_batch_plan()?;

let must_local = must_run_in_local_mode(batch_plan.clone());

let query_mode = match (must_dist, must_local) {
(true, true) => {
return Err(ErrorCode::InternalError(
"the query is forced to both local and distributed mode by optimizer".to_owned(),
)
.into())
}
(true, false) => QueryMode::Distributed,
(false, true) => QueryMode::Local,
(false, false) => match session.config().get_query_mode() {
QueryMode::Auto => determine_query_mode(batch_plan.clone()),
QueryMode::Local => QueryMode::Local,
QueryMode::Distributed => QueryMode::Distributed,
},
};

let physical = match query_mode {
QueryMode::Auto => unreachable!(),
QueryMode::Local => logical.gen_batch_local_plan(batch_plan)?,
QueryMode::Distributed => logical.gen_batch_distributed_plan(batch_plan)?,
};
Ok((physical, query_mode, schema))
}

pub async fn handle_execute(
handler_args: HandlerArgs,
portal: PortalResult,
) -> Result<RwPgResponse> {
let PortalResult {
statement,
bound_statement,
result_formats,
Expand All @@ -407,38 +452,8 @@ pub async fn handle_execute(handler_args: HandlerArgs, portal: Portal) -> Result
let (plan_fragmenter, query_mode, output_schema) = {
let context = OptimizerContext::from_handler_args(handler_args);

let must_dist = must_run_in_distributed_mode(&statement)?;

let mut planner = Planner::new(context.into());

let mut logical = planner.plan(bound_statement)?;
let schema = logical.schema();
let batch_plan = logical.gen_batch_plan()?;

let must_local = must_run_in_local_mode(batch_plan.clone());

let query_mode = match (must_dist, must_local) {
(true, true) => {
return Err(ErrorCode::InternalError(
"the query is forced to both local and distributed mode by optimizer"
.to_owned(),
)
.into())
}
(true, false) => QueryMode::Distributed,
(false, true) => QueryMode::Local,
(false, false) => match session.config().get_query_mode() {
QueryMode::Auto => determine_query_mode(batch_plan.clone()),
QueryMode::Local => QueryMode::Local,
QueryMode::Distributed => QueryMode::Distributed,
},
};

let physical = match query_mode {
QueryMode::Auto => unreachable!(),
QueryMode::Local => logical.gen_batch_local_plan(batch_plan)?,
QueryMode::Distributed => logical.gen_batch_distributed_plan(batch_plan)?,
};
let (physical, query_mode, schema) =
gen_batch_query_plan_for_bound(&session, context.into(), statement, bound_statement)?;

let context = physical.plan_base().ctx.clone();
tracing::trace!(
Expand Down
Loading

0 comments on commit 8278061

Please sign in to comment.