Skip to content

Commit

Permalink
refactor(optimizer): move some methods into core struct && refactor t…
Browse files Browse the repository at this point in the history
…he join's predicate push down (risingwavelabs#8455)
  • Loading branch information
st1page authored Mar 10, 2023
1 parent 64d80d2 commit 79b499c
Show file tree
Hide file tree
Showing 6 changed files with 302 additions and 209 deletions.
141 changes: 140 additions & 1 deletion src/frontend/src/optimizer/plan_node/generic/hop_window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@
// limitations under the License.

use std::fmt;
use std::num::NonZeroUsize;

use itertools::Itertools;
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::error::Result;
use risingwave_common::types::{DataType, IntervalUnit, IntervalUnitDisplay};
use risingwave_common::util::column_index_mapping::ColIndexMapping;
use risingwave_expr::ExprError;

use super::super::utils::IndicesDisplay;
use super::{GenericPlanNode, GenericPlanRef};
use crate::expr::{InputRef, InputRefDisplay};
use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef, InputRefDisplay, Literal};
use crate::optimizer::optimizer_context::OptimizerContextRef;

/// [`HopWindow`] implements Hop Table Function.
Expand Down Expand Up @@ -104,6 +108,141 @@ impl<PlanRef: GenericPlanRef> HopWindow<PlanRef> {
)
}

/// Returns the index of the window-start column in the internal column layout.
///
/// The value equals the number of columns in the input schema.
pub fn internal_window_start_col_idx(&self) -> usize {
    let input_schema = self.input.schema();
    input_schema.len()
}

/// Returns the index of the window-end column in the internal column layout.
///
/// It is the slot immediately after the one returned by
/// `internal_window_start_col_idx`.
pub fn internal_window_end_col_idx(&self) -> usize {
    let window_start_idx = self.internal_window_start_col_idx();
    window_start_idx + 1
}

/// Returns the output-to-input column mapping.
///
/// It is the composition of `output2internal_col_mapping` followed by
/// `internal2input_col_mapping`.
pub fn o2i_col_mapping(&self) -> ColIndexMapping {
    let output_to_internal = self.output2internal_col_mapping();
    let internal_to_input = self.internal2input_col_mapping();
    output_to_internal.composite(&internal_to_input)
}

/// Returns the input-to-output column mapping.
///
/// It is the composition of `input2internal_col_mapping` followed by
/// `internal2output_col_mapping`.
pub fn i2o_col_mapping(&self) -> ColIndexMapping {
    let input_to_internal = self.input2internal_col_mapping();
    let internal_to_output = self.internal2output_col_mapping();
    input_to_internal.composite(&internal_to_output)
}

/// Returns the number of columns in the internal column layout.
///
/// This is the window-start column index plus two, i.e. it accounts for the
/// window-start and window-end slots that follow the input columns.
pub fn internal_column_num(&self) -> usize {
    let window_start_idx = self.internal_window_start_col_idx();
    window_start_idx + 2
}

/// Returns the output-to-internal column mapping, obtained by inverting
/// `internal2output_col_mapping`.
pub fn output2internal_col_mapping(&self) -> ColIndexMapping {
    let internal_to_output = self.internal2output_col_mapping();
    internal_to_output.inverse()
}

/// Returns the internal-to-output column mapping.
///
/// It is built by `ColIndexMapping::with_remaining_columns` from this node's
/// `output_indices` and the internal column count.
pub fn internal2output_col_mapping(&self) -> ColIndexMapping {
    let internal_len = self.internal_column_num();
    ColIndexMapping::with_remaining_columns(&self.output_indices, internal_len)
}

/// Returns the input-to-internal column mapping.
///
/// It is built by `ColIndexMapping::identity_or_none`, with the input column
/// count as the first argument and the internal column count as the second.
pub fn input2internal_col_mapping(&self) -> ColIndexMapping {
    let input_len = self.internal_window_start_col_idx();
    let internal_len = self.internal_column_num();
    ColIndexMapping::identity_or_none(input_len, internal_len)
}

/// Returns the internal-to-input column mapping.
///
/// It is built by `ColIndexMapping::identity_or_none`, with the internal column
/// count as the first argument and the input column count as the second.
pub fn internal2input_col_mapping(&self) -> ColIndexMapping {
    let internal_len = self.internal_column_num();
    let input_len = self.internal_window_start_col_idx();
    ColIndexMapping::identity_or_none(internal_len, input_len)
}

/// Builds the expressions computing the start and end of each hop window.
///
/// Let `units = window_size / window_slide`. The shared base expression is
/// `hop_start = tumble_start(time_col - (window_size - window_slide), window_slide)`.
/// For every `i` in `0..units` the function emits
///
/// * a start expression `hop_start + window_slide * i`, and
/// * an end expression `hop_start + window_slide * (i + units)`.
///
/// Returns `(window_start_exprs, window_end_exprs)`, each holding `units`
/// expressions in increasing order of `i`.
///
/// # Errors
///
/// * `ExprError::InvalidParam` if `window_size` is not exactly divisible by
///   `window_slide`, or the quotient is not a non-zero `usize`.
/// * `ExprError::InvalidParam` if `window_slide` cannot be multiplied by one of
///   the required factors (presumably on overflow; see `checked_mul_int`).
/// * Any error returned by `FunctionCall::new` while building an expression.
pub fn derive_window_start_and_end_exprs(&self) -> Result<(Vec<ExprImpl>, Vec<ExprImpl>)> {
    let window_size = &self.window_size;
    let window_slide = &self.window_slide;
    let time_col = &self.time_col;

    // How many slides fit into one window; must be a positive integer.
    let units = window_size
        .exact_div(window_slide)
        .and_then(|ratio| NonZeroUsize::new(usize::try_from(ratio).ok()?))
        .ok_or_else(|| ExprError::InvalidParam {
            name: "window",
            reason: format!(
                "window_size {} cannot be divided by window_slide {}",
                window_size, window_slide
            ),
        })?
        .get();

    let size_literal = Literal::new(Some((*window_size).into()), DataType::Interval).into();
    let slide_literal: ExprImpl =
        Literal::new(Some((*window_slide).into()), DataType::Interval).into();

    // window_size - window_slide
    let size_minus_slide =
        FunctionCall::new(ExprType::Subtract, vec![size_literal, slide_literal.clone()])?.into();

    // time_col - (window_size - window_slide)
    let shifted_time_col = FunctionCall::new(
        ExprType::Subtract,
        vec![ExprImpl::InputRef(Box::new(time_col.clone())), size_minus_slide],
    )?
    .into();

    // Base expression that every window boundary is an offset from.
    let hop_start: ExprImpl =
        FunctionCall::new(ExprType::TumbleStart, vec![shifted_time_col, slide_literal])?.into();

    // Builds `hop_start + window_slide * factor`.
    let hop_start_plus_slides = |factor: usize| -> Result<ExprImpl> {
        let offset = window_slide
            .checked_mul_int(factor)
            .ok_or_else(|| ExprError::InvalidParam {
                name: "window",
                reason: format!(
                    "window_slide {} cannot be multiplied by {}",
                    window_slide, factor
                ),
            })?;
        let offset_literal = Literal::new(Some(offset.into()), DataType::Interval).into();
        Ok(FunctionCall::new(ExprType::Add, vec![hop_start.clone(), offset_literal])?.into())
    };

    let mut window_start_exprs = Vec::with_capacity(units);
    let mut window_end_exprs = Vec::with_capacity(units);
    for i in 0..units {
        // The start is built before the end so that errors surface in the same order.
        window_start_exprs.push(hop_start_plus_slides(i)?);
        window_end_exprs.push(hop_start_plus_slides(i + units)?);
    }
    assert_eq!(window_start_exprs.len(), window_end_exprs.len());
    Ok((window_start_exprs, window_end_exprs))
}

pub fn fmt_fields_with_builder(&self, builder: &mut fmt::DebugStruct<'_, '_>) {
let output_type = DataType::window_of(&self.time_col.data_type).unwrap();
builder.field(
Expand Down
129 changes: 129 additions & 0 deletions src/frontend/src/optimizer/plan_node/generic/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,132 @@ impl<PlanRef: GenericPlanRef> Join<PlanRef> {
}
}
}

/// Tries to split `predicate` and push it down into a join: into the join's left input, its
/// right input, and its `on` condition.
///
/// Returns `(left, right, on)`, the parts that were pushed. Every pushed conjunction is removed
/// from `predicate`; whatever cannot be pushed for join type `ty` stays there.
///
/// `InputRef`s in the right pushed condition are indexed by the right child's output schema.
pub fn push_down_into_join(
    predicate: &mut Condition,
    left_col_num: usize,
    right_col_num: usize,
    ty: JoinType,
) -> (Condition, Condition, Condition) {
    let push_left = can_push_left_from_filter(ty);
    let push_right = can_push_right_from_filter(ty);
    let (left, right) =
        push_down_to_inputs(predicate, left_col_num, right_col_num, push_left, push_right);

    if !can_push_on_from_filter(ty) {
        return (left, right, Condition::true_cond());
    }

    let mut remaining = std::mem::take(&mut predicate.conjunctions);
    // Do not push `now` expressions into the on-clause, they would be pulled up into a filter
    // again. Only conjunctions without any `now` are moved; the rest stay in `predicate`.
    let on = Condition {
        conjunctions: remaining
            .drain_filter(|expr| expr.count_nows() == 0)
            .collect(),
    };
    predicate.conjunctions = remaining;

    (left, right, on)
}

/// Tries to push parts of the join condition down to the join's inputs.
///
/// Returns `(left, right)`, the parts that were pushed. Every pushed conjunction is removed from
/// `on_condition`. Which sides may receive conjunctions is decided by `can_push_left_from_on` and
/// `can_push_right_from_on` for join type `ty`.
///
/// `InputRef`s in the right pushed condition are indexed by the right child's output schema.
pub fn push_down_join_condition(
    on_condition: &mut Condition,
    left_col_num: usize,
    right_col_num: usize,
    ty: JoinType,
) -> (Condition, Condition) {
    let push_left = can_push_left_from_on(ty);
    let push_right = can_push_right_from_on(ty);
    push_down_to_inputs(on_condition, left_col_num, right_col_num, push_left, push_right)
}

/// Tries to split `predicate` and push it down into a join's left and right children.
///
/// The predicate is split by `Condition::split` into a left part, a right part and the rest.
/// A part is returned only when its flag (`push_left` / `push_right`) is set; otherwise it is
/// put back into `predicate` and an always-true condition is returned for that side.
///
/// Returns `(left, right)`, the parts that were pushed. On return `predicate` holds everything
/// that was not pushed.
///
/// `InputRef`s in the right `Condition` are shifted by `-left_col_num`.
fn push_down_to_inputs(
    predicate: &mut Condition,
    left_col_num: usize,
    right_col_num: usize,
    push_left: bool,
    push_right: bool,
) -> (Condition, Condition) {
    let whole = Condition {
        conjunctions: std::mem::take(&mut predicate.conjunctions),
    };
    let (left_part, right_part, mut remaining) = whole.split(left_col_num, right_col_num);

    let pushed_left = if push_left {
        left_part
    } else {
        remaining.conjunctions.extend(left_part);
        Condition::true_cond()
    };

    let pushed_right = if push_right {
        // Re-index the column references relative to the right child's schema.
        let total_col_num = left_col_num + right_col_num;
        let shift = -(left_col_num as isize);
        let mut to_right_input = ColIndexMapping::with_shift_offset(total_col_num, shift);
        right_part.rewrite_expr(&mut to_right_input)
    } else {
        remaining.conjunctions.extend(right_part);
        Condition::true_cond()
    };

    predicate.conjunctions = remaining.conjunctions;

    (pushed_left, pushed_right)
}

pub fn can_push_left_from_filter(ty: JoinType) -> bool {
matches!(
ty,
JoinType::Inner | JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti
)
}

pub fn can_push_right_from_filter(ty: JoinType) -> bool {
matches!(
ty,
JoinType::Inner | JoinType::RightOuter | JoinType::RightSemi | JoinType::RightAnti
)
}

pub fn can_push_on_from_filter(ty: JoinType) -> bool {
matches!(
ty,
JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi
)
}

pub fn can_push_left_from_on(ty: JoinType) -> bool {
matches!(
ty,
JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi
)
}

pub fn can_push_right_from_on(ty: JoinType) -> bool {
matches!(
ty,
JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi
)
}
26 changes: 5 additions & 21 deletions src/frontend/src/optimizer/plan_node/logical_apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use risingwave_common::catalog::Schema;
use risingwave_common::error::{ErrorCode, Result, RwError};
use risingwave_pb::plan_common::JoinType;

use super::generic::{self, GenericPlanNode};
use super::generic::{self, push_down_into_join, push_down_join_condition, GenericPlanNode};
use super::{
ColPrunable, LogicalJoin, LogicalProject, PlanBase, PlanRef, PlanTreeNodeBinary,
PredicatePushdown, ToBatch, ToStream,
Expand Down Expand Up @@ -318,28 +318,12 @@ impl PredicatePushdown for LogicalApply {
let right_col_num = self.right().schema().len();
let join_type = self.join_type();

let (left_from_filter, right_from_filter, on) = LogicalJoin::push_down(
&mut predicate,
left_col_num,
right_col_num,
LogicalJoin::can_push_left_from_filter(join_type),
LogicalJoin::can_push_right_from_filter(join_type),
LogicalJoin::can_push_on_from_filter(join_type),
);
let (left_from_filter, right_from_filter, on) =
push_down_into_join(&mut predicate, left_col_num, right_col_num, join_type);

let mut new_on = self.on.clone().and(on);
let (left_from_on, right_from_on, on) = LogicalJoin::push_down(
&mut new_on,
left_col_num,
right_col_num,
LogicalJoin::can_push_left_from_on(join_type),
LogicalJoin::can_push_right_from_on(join_type),
false,
);
assert!(
on.always_true(),
"On-clause should not be pushed to on-clause."
);
let (left_from_on, right_from_on) =
push_down_join_condition(&mut new_on, left_col_num, right_col_num, join_type);

let left_predicate = left_from_filter.and(left_from_on);
let right_predicate = right_from_filter.and(right_from_on);
Expand Down
Loading

0 comments on commit 79b499c

Please sign in to comment.