Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(plan_node): Further simplify join-related plan nodes #8905

Merged
merged 5 commits into from
Mar 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 35 additions & 57 deletions src/frontend/src/optimizer/plan_node/logical_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,24 +154,18 @@ impl LogicalJoin {

/// Clone with new output indices
pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
Self::with_output_indices(
self.left(),
self.right(),
self.join_type(),
self.on().clone(),
Self::with_core(generic::Join {
output_indices,
)
..self.core.clone()
})
}

/// Clone with new `on` condition
pub fn clone_with_cond(&self, cond: Condition) -> Self {
Self::with_output_indices(
self.left(),
self.right(),
self.join_type(),
cond,
self.output_indices().clone(),
)
pub fn clone_with_cond(&self, on: Condition) -> Self {
Self::with_core(generic::Join {
on,
..self.core.clone()
})
}

pub fn is_left_join(&self) -> bool {
Expand Down Expand Up @@ -287,9 +281,9 @@ impl LogicalJoin {
fn to_batch_lookup_join_with_index_selection(
&self,
predicate: EqJoinPredicate,
logical_join: LogicalJoin,
logical_join: generic::Join<PlanRef>,
) -> Option<BatchLookupJoin> {
match logical_join.join_type() {
match logical_join.join_type {
JoinType::Inner | JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti => {}
_ => return None,
};
Expand All @@ -312,10 +306,8 @@ impl LogicalJoin {
if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
let index_scan: PlanRef = index_scan.into();
let that = self.clone_with_left_right(self.left(), index_scan.clone());
let new_logical_join = logical_join.clone_with_left_right(
logical_join.left(),
index_scan.to_batch().expect("index scan failed to batch"),
);
let mut new_logical_join = logical_join.clone();
new_logical_join.right = index_scan.to_batch().expect("index scan failed to batch");

// Lookup covered index.
if let Some(lookup_join) =
Expand Down Expand Up @@ -343,9 +335,9 @@ impl LogicalJoin {
fn to_batch_lookup_join(
&self,
predicate: EqJoinPredicate,
logical_join: LogicalJoin,
logical_join: generic::Join<PlanRef>,
) -> Option<BatchLookupJoin> {
match logical_join.join_type() {
match logical_join.join_type {
JoinType::Inner | JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti => {}
_ => return None,
};
Expand All @@ -369,15 +361,15 @@ impl LogicalJoin {
max_pos,
order_key
.iter()
.position(|x| *x == d)
.position(|&x| x == d)
.expect("dist_key must in order_key"),
);
}
max_pos + 1
};

// Reorder the join equal predicate to match the order key.
let mut reorder_idx = vec![];
let mut reorder_idx = Vec::with_capacity(at_least_prefix_len);
for order_col_id in order_col_ids {
for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
if order_col_id == output_column_ids[eq_idx] {
Expand All @@ -403,7 +395,7 @@ impl LogicalJoin {
} else {
(0..logical_scan.output_col_idx().len()).collect_vec()
};
let left_schema_len = logical_join.left().schema().len();
let left_schema_len = logical_join.left.schema().len();

let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
offset: left_schema_len,
Expand Down Expand Up @@ -440,10 +432,9 @@ impl LogicalJoin {
// Rewrite the join output indices and all output indices referred to the old scan need to
// rewrite.
let new_join_output_indices = logical_join
.output_indices()
.clone()
.into_iter()
.map(|x| {
.output_indices
.iter()
.map(|&x| {
if x < left_schema_len {
x
} else {
Expand All @@ -456,10 +447,10 @@ impl LogicalJoin {

// Construct a new logical join, because we have change its RHS.
let new_logical_join = generic::Join::new(
logical_join.left(),
logical_join.left,
new_scan.into(),
new_join_on,
logical_join.join_type(),
logical_join.join_type,
new_join_output_indices,
);

Expand Down Expand Up @@ -488,13 +479,11 @@ impl PlanTreeNodeBinary for LogicalJoin {
}

fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
Self::with_output_indices(
Self::with_core(generic::Join {
left,
right,
self.join_type(),
self.on().clone(),
self.output_indices().clone(),
)
..self.core.clone()
})
}

#[must_use]
Expand Down Expand Up @@ -1179,11 +1168,6 @@ impl LogicalJoin {
}
}

fn into_batch_hash_join(self, predicate: EqJoinPredicate) -> Result<PlanRef> {
assert!(predicate.has_eq());
Ok(BatchHashJoin::new(self.core, predicate).into())
}

pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<PlanRef> {
let predicate = EqJoinPredicate::create(
self.left().schema().len(),
Expand All @@ -1192,20 +1176,15 @@ impl LogicalJoin {
);
assert!(predicate.has_eq());

let left = self.left().to_batch()?;
let right = self.right().to_batch()?;
let logical_join = self.clone_with_left_right(left, right);
let mut logical_join = self.core.clone();
logical_join.left = logical_join.left.to_batch()?;
logical_join.right = logical_join.right.to_batch()?;

Ok(self
.to_batch_lookup_join(predicate, logical_join)
.expect("Fail to convert to lookup join")
.into())
}

fn into_batch_nested_loop_join(self, predicate: EqJoinPredicate) -> Result<PlanRef> {
assert!(!predicate.has_eq());
Ok(BatchNestedLoopJoin::new(self.core).into())
}
}

impl ToBatch for LogicalJoin {
Expand All @@ -1216,9 +1195,9 @@ impl ToBatch for LogicalJoin {
self.on().clone(),
);

let left = self.left().to_batch()?;
let right = self.right().to_batch()?;
let logical_join = self.clone_with_left_right(left, right);
let mut logical_join = self.core.clone();
logical_join.left = logical_join.left.to_batch()?;
logical_join.right = logical_join.right.to_batch()?;

let config = self.base.ctx.session_ctx().config();

Expand All @@ -1238,10 +1217,10 @@ impl ToBatch for LogicalJoin {
}
}

logical_join.into_batch_hash_join(predicate)
Ok(BatchHashJoin::new(logical_join, predicate).into())
} else {
// Convert to Nested-loop Join for non-equal joins
logical_join.into_batch_nested_loop_join(predicate)
Ok(BatchNestedLoopJoin::new(logical_join).into())
}
}
}
Expand Down Expand Up @@ -1315,9 +1294,8 @@ impl ToStream for LogicalJoin {
let mut right_to_add = right
.logical_pk()
.iter()
.cloned()
.filter(|i| r2o.try_map(*i).is_none())
.map(|i| i + left_len)
.filter(|&&i| r2o.try_map(i).is_none())
.map(|&i| i + left_len)
.collect_vec();

// NOTE(st1page): add join keys in the pk_indices a work around before we really have stream
Expand Down