diff --git a/src/frontend/src/handler/extended_handle.rs b/src/frontend/src/handler/extended_handle.rs new file mode 100644 index 0000000000000..359c8e4ebe574 --- /dev/null +++ b/src/frontend/src/handler/extended_handle.rs @@ -0,0 +1,94 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use bytes::Bytes; +use pgwire::types::Format; +use risingwave_common::error::{ErrorCode, Result}; +use risingwave_common::types::DataType; +use risingwave_sqlparser::ast::Statement; + +use super::{query, HandlerArgs, RwPgResponse}; +use crate::binder::BoundStatement; +use crate::session::SessionImpl; + +pub struct PrepareStatement { + pub statement: Statement, + pub bound_statement: BoundStatement, + pub param_types: Vec, +} + +pub struct Portal { + pub statement: Statement, + pub bound_statement: BoundStatement, + pub result_formats: Vec, +} + +pub fn handle_parse( + session: Arc, + stmt: Statement, + specific_param_types: Vec, +) -> Result { + session.clear_cancel_query_flag(); + let str_sql = stmt.to_string(); + let handler_args = HandlerArgs::new(session, &stmt, &str_sql)?; + match stmt { + Statement::Query(_) + | Statement::Insert { .. } + | Statement::Delete { .. } + | Statement::Update { .. 
} => query::handle_parse(handler_args, stmt, specific_param_types), + _ => Err(ErrorCode::NotSupported( + format!("Can't support {} in extended query mode now", str_sql,), + "".to_string(), + ) + .into()), + } +} + +pub fn handle_bind( + prepare_statement: PrepareStatement, + params: Vec, + param_formats: Vec, + result_formats: Vec, +) -> Result { + let PrepareStatement { + statement, + bound_statement, + .. + } = prepare_statement; + let bound_statement = bound_statement.bind_parameter(params, param_formats)?; + Ok(Portal { + statement, + bound_statement, + result_formats, + }) +} + +pub async fn handle_execute(session: Arc, portal: Portal) -> Result { + session.clear_cancel_query_flag(); + let str_sql = portal.statement.to_string(); + let handler_args = HandlerArgs::new(session, &portal.statement, &str_sql)?; + match &portal.statement { + Statement::Query(_) + | Statement::Insert { .. } + | Statement::Delete { .. } + | Statement::Update { .. } => query::handle_execute(handler_args, portal).await, + _ => Err(ErrorCode::NotSupported( + format!("Can't support {} in extended query mode now", str_sql,), + "".to_string(), + ) + .into()), + } +} diff --git a/src/frontend/src/handler/mod.rs b/src/frontend/src/handler/mod.rs index 550985a591672..01cd757f21c39 100644 --- a/src/frontend/src/handler/mod.rs +++ b/src/frontend/src/handler/mod.rs @@ -59,6 +59,7 @@ pub mod drop_table; pub mod drop_user; mod drop_view; pub mod explain; +pub mod extended_handle; mod flush; pub mod handle_privilege; pub mod privilege; diff --git a/src/frontend/src/handler/query.rs b/src/frontend/src/handler/query.rs index 4d437b304515b..7b5a4ae708511 100644 --- a/src/frontend/src/handler/query.rs +++ b/src/frontend/src/handler/query.rs @@ -24,8 +24,10 @@ use postgres_types::FromSql; use risingwave_common::catalog::Schema; use risingwave_common::error::{ErrorCode, Result, RwError}; use risingwave_common::session_config::QueryMode; +use risingwave_common::types::DataType; use 
risingwave_sqlparser::ast::{SetExpr, Statement};
 
+use super::extended_handle::{Portal, PrepareStatement};
 use super::{PgResponseStream, RwPgResponse};
 use crate::binder::{Binder, BoundSetExpr, BoundStatement};
 use crate::handler::flush::do_flush;
@@ -74,6 +76,21 @@ fn must_run_in_distributed_mode(stmt: &Statement) -> Result<bool> {
     ) | is_insert_using_select(stmt))
 }
 
+/// Returns `true` when `bound` is a query whose body is a `SELECT` with a `FROM` relation that
+/// contains a system table.
+///
+/// `handle_execute` forces `QueryMode::Local` when this returns `true`.
+fn must_run_in_local_mode(bound: &BoundStatement) -> bool {
+    if let BoundStatement::Query(query) = bound
+        && let BoundSetExpr::Select(select) = &query.body
+        && let Some(relation) = &select.from
+    {
+        relation.contains_sys_table()
+    } else {
+        false
+    }
+}
+
 pub fn gen_batch_query_plan(
     session: &SessionImpl,
     context: OptimizerContextRef,
@@ -89,16 +106,9 @@ pub fn gen_batch_query_plan(
     let check_items = resolve_privileges(&bound);
     session.check_privileges(&check_items)?;
 
-    let mut planner = Planner::new(context);
+    let must_local = must_run_in_local_mode(&bound);
 
-    let mut must_local = false;
-    if let BoundStatement::Query(query) = &bound {
-        if let BoundSetExpr::Select(select) = &query.body
-            && let Some(relation) = &select.from
-            && relation.contains_sys_table() {
-                must_local = true;
-        }
-    }
+    let mut planner = Planner::new(context);
 
     let mut logical = planner.plan(bound)?;
     let schema = logical.schema();
@@ -339,3 +349,256 @@ pub async fn local_execute(
 
     Ok(execution.stream_rows())
 }
+
+/// Binds `statement` for extended query mode and packages it as a [`PrepareStatement`].
+///
+/// `specific_param_types` is passed to `Binder::new_with_param_types`; the parameter types
+/// stored in the result are those reported by `Binder::export_param_types` after
+/// binding.
+///
+/// # Errors
+///
+/// Fails if binding fails, if the session's privilege check rejects the bound statement, or if
+/// the parameter types cannot be exported.
+pub fn handle_parse(
+    handler_args: HandlerArgs,
+    statement: Statement,
+    specific_param_types: Vec<DataType>,
+) -> Result<PrepareStatement> {
+    let session = handler_args.session;
+    let mut binder = Binder::new_with_param_types(&session, specific_param_types);
+    let bound_statement = binder.bind(statement.clone())?;
+
+    let check_items = resolve_privileges(&bound_statement);
+    session.check_privileges(&check_items)?;
+
+    let param_types = binder.export_param_types()?;
+
+    Ok(PrepareStatement {
+        statement,
+        bound_statement,
+        param_types,
+    })
+}
+
+/// Plans and executes the bound statement held by `portal`, returning a streaming response.
+///
+/// The query runs in `Distributed` mode when `must_run_in_distributed_mode` requires it, in
+/// `Local` mode when `must_run_in_local_mode` requires it, and otherwise in the session's
+/// configured mode, with `Auto` resolved by `determine_query_mode`. For `INSERT`, `DELETE` and
+/// `UPDATE` without `RETURNING`, the first row set is consumed here to read the affected-row
+/// count. The response's callback flushes implicitly (when enabled and the statement is DML)
+/// and records latency and counter metrics.
+///
+/// # Errors
+///
+/// Fails if the statement type cannot be inferred, if both query modes are required at once, if
+/// planning, fragmenting, snapshot acquisition or execution start-up fails, or if a DML
+/// statement produces no row set.
+///
+/// # Panics
+///
+/// Panics if the affected-row cell is absent, cannot be decoded as `INT8` (binary format) or
+/// UTF-8 (text format), or does not fit in `i32`.
+pub async fn handle_execute(handler_args: HandlerArgs, portal: Portal) -> Result<RwPgResponse> {
+    let Portal {
+        statement,
+        bound_statement,
+        result_formats,
+    } = portal;
+
+    let stmt_type = StatementType::infer_from_statement(&statement)
+        .map_err(|err| RwError::from(ErrorCode::InvalidInputSyntax(err)))?;
+    let session = handler_args.session.clone();
+    let query_start_time = Instant::now();
+    let only_checkpoint_visible = handler_args.session.config().only_checkpoint_visible();
+    let mut notice = String::new();
+
+    // Subblock to make sure PlanRef (an Rc) is dropped before `await` below.
+    let (plan_fragmenter, query_mode, output_schema) = {
+        let context = OptimizerContext::from_handler_args(handler_args);
+
+        let must_dist = must_run_in_distributed_mode(&statement)?;
+        let must_local = must_run_in_local_mode(&bound_statement);
+
+        let mut planner = Planner::new(context.into());
+
+        let mut logical = planner.plan(bound_statement)?;
+        let schema = logical.schema();
+        let batch_plan = logical.gen_batch_plan()?;
+
+        let query_mode = match (must_dist, must_local) {
+            (true, true) => {
+                return Err(ErrorCode::InternalError(
+                    "the query is forced to both local and distributed mode by optimizer"
+                        .to_owned(),
+                )
+                .into())
+            }
+            (true, false) => QueryMode::Distributed,
+            (false, true) => QueryMode::Local,
+            (false, false) => match session.config().get_query_mode() {
+                QueryMode::Auto => determine_query_mode(batch_plan.clone()),
+                QueryMode::Local => QueryMode::Local,
+                QueryMode::Distributed => QueryMode::Distributed,
+            },
+        };
+
+        let physical = match query_mode {
+            QueryMode::Auto => unreachable!(),
+            QueryMode::Local => logical.gen_batch_local_plan(batch_plan)?,
+            QueryMode::Distributed => logical.gen_batch_distributed_plan(batch_plan)?,
+        };
+
+        let context = physical.plan_base().ctx.clone();
+        tracing::trace!(
+            "Generated query plan: {:?}, query_mode:{:?}",
+            physical.explain_to_string()?,
+            query_mode
+        );
+        let plan_fragmenter = BatchPlanFragmenter::new(
+            session.env().worker_node_manager_ref(),
+            session.env().catalog_reader().clone(),
+            session.config().get_batch_parallelism(),
+            physical,
+        )?;
+        context.append_notice(&mut notice);
+        (plan_fragmenter, query_mode, schema)
+    };
+    let query = plan_fragmenter.generate_complete_query().await?;
+    tracing::trace!("Generated query after plan fragmenter: {:?}", &query);
+
+    let pg_descs = output_schema
+        .fields()
+        .iter()
+        .map(to_pg_field)
+        .collect_vec();
+    let column_types = output_schema
+        .fields()
+        .iter()
+        .map(|f| f.data_type())
+        .collect_vec();
+
+    // Used in counting row count.
+    let first_field_format = result_formats.first().copied().unwrap_or(Format::Text);
+
+    let mut row_stream = {
+        let query_epoch = session.config().get_query_epoch();
+        let query_snapshot = if let Some(query_epoch) = query_epoch {
+            PinnedHummockSnapshot::Other(query_epoch)
+        } else {
+            // Acquire hummock snapshot for execution.
+            // TODO: if there's no table scan, we don't need to acquire snapshot.
+            let hummock_snapshot_manager = session.env().hummock_snapshot_manager();
+            let query_id = query.query_id().clone();
+            let pinned_snapshot = hummock_snapshot_manager.acquire(&query_id).await?;
+            PinnedHummockSnapshot::FrontendPinned(pinned_snapshot, only_checkpoint_visible)
+        };
+        match query_mode {
+            QueryMode::Auto => unreachable!(),
+            QueryMode::Local => PgResponseStream::LocalQuery(DataChunkToRowSetAdapter::new(
+                local_execute(session.clone(), query, query_snapshot).await?,
+                column_types,
+                result_formats,
+                session.clone(),
+            )),
+            // Local mode does not support cancelling tasks.
+            QueryMode::Distributed => {
+                PgResponseStream::DistributedQuery(DataChunkToRowSetAdapter::new(
+                    distribute_execute(session.clone(), query, query_snapshot).await?,
+                    column_types,
+                    result_formats,
+                    session.clone(),
+                ))
+            }
+        }
+    };
+
+    let rows_count: Option<i32> = match stmt_type {
+        StatementType::SELECT
+        | StatementType::INSERT_RETURNING
+        | StatementType::DELETE_RETURNING
+        | StatementType::UPDATE_RETURNING => None,
+
+        StatementType::INSERT | StatementType::DELETE | StatementType::UPDATE => {
+            let first_row_set = row_stream.next().await;
+            let first_row_set = match first_row_set {
+                None => {
+                    return Err(RwError::from(ErrorCode::InternalError(
+                        "no affected rows in output".to_string(),
+                    )))
+                }
+                Some(row) => {
+                    row.map_err(|err| RwError::from(ErrorCode::InternalError(err.to_string())))?
+                }
+            };
+            let affected_rows_str = first_row_set[0].values()[0]
+                .as_ref()
+                .expect("compute node should return affected rows in output");
+            if let Format::Binary = first_field_format {
+                Some(
+                    i64::from_sql(&postgres_types::Type::INT8, affected_rows_str)
+                        .unwrap()
+                        .try_into()
+                        .expect("affected rows count large than i32"),
+                )
+            } else {
+                Some(
+                    String::from_utf8(affected_rows_str.to_vec())
+                        .unwrap()
+                        .parse()
+                        .unwrap_or_default(),
+                )
+            }
+        }
+        _ => unreachable!(),
+    };
+
+    // We need to do some post work after the query is finished and before the `Complete` response
+    // is sent. This is achieved by the `callback` in `PgResponse`.
+    let callback = async move {
+        // Implicitly flush the writes.
+        if session.config().get_implicit_flush() && stmt_type.is_dml() {
+            do_flush(&session).await?;
+        }
+
+        // Update some metrics.
+        match query_mode {
+            QueryMode::Auto => unreachable!(),
+            QueryMode::Local => {
+                session
+                    .env()
+                    .frontend_metrics
+                    .latency_local_execution
+                    .observe(query_start_time.elapsed().as_secs_f64());
+
+                session
+                    .env()
+                    .frontend_metrics
+                    .query_counter_local_execution
+                    .inc();
+            }
+            QueryMode::Distributed => {
+                session
+                    .env()
+                    .query_manager()
+                    .query_metrics
+                    .query_latency
+                    .observe(query_start_time.elapsed().as_secs_f64());
+
+                session
+                    .env()
+                    .query_manager()
+                    .query_metrics
+                    .completed_query_counter
+                    .inc();
+            }
+        }
+
+        Ok(())
+    };
+
+    Ok(PgResponse::new_for_stream_extra(
+        stmt_type, rows_count, row_stream, pg_descs, notice, callback,
+    ))
+}