Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Bushy tree join ordering #8316

Merged
merged 14 commits into from
Mar 18, 2023
21 changes: 20 additions & 1 deletion src/common/src/session_config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use crate::util::epoch::Epoch;

// This is a hack, &'static str is not allowed as a const generics argument.
// TODO: refine this using the adt_const_params feature.
const CONFIG_KEYS: [&str; 21] = [
const CONFIG_KEYS: [&str; 22] = [
"RW_IMPLICIT_FLUSH",
"CREATE_COMPACTION_GROUP_FOR_MV",
"QUERY_MODE",
Expand All @@ -56,6 +56,7 @@ const CONFIG_KEYS: [&str; 21] = [
"RW_ENABLE_SHARE_PLAN",
"INTERVALSTYLE",
"BATCH_PARALLELISM",
"RW_STREAMING_ENABLE_BUSHY_JOIN",
];

// MUST HAVE 1v1 relationship to CONFIG_KEYS. e.g. CONFIG_KEYS[IMPLICIT_FLUSH] =
Expand All @@ -81,6 +82,7 @@ const FORCE_TWO_PHASE_AGG: usize = 17;
const RW_ENABLE_SHARE_PLAN: usize = 18;
const INTERVAL_STYLE: usize = 19;
const BATCH_PARALLELISM: usize = 20;
const STREAMING_ENABLE_BUSHY_JOIN: usize = 21;

trait ConfigEntry: Default + for<'a> TryFrom<&'a [&'a str], Error = RwError> {
fn entry_name() -> &'static str;
Expand Down Expand Up @@ -277,6 +279,7 @@ type QueryEpoch = ConfigU64<QUERY_EPOCH, 0>;
type Timezone = ConfigString<TIMEZONE>;
type StreamingParallelism = ConfigU64<STREAMING_PARALLELISM, 0>;
type StreamingEnableDeltaJoin = ConfigBool<STREAMING_ENABLE_DELTA_JOIN, false>;
type StreamingEnableBushyJoin = ConfigBool<STREAMING_ENABLE_BUSHY_JOIN, false>;
type EnableTwoPhaseAgg = ConfigBool<ENABLE_TWO_PHASE_AGG, true>;
type ForceTwoPhaseAgg = ConfigBool<FORCE_TWO_PHASE_AGG, false>;
type EnableSharePlan = ConfigBool<RW_ENABLE_SHARE_PLAN, true>;
Expand Down Expand Up @@ -342,6 +345,9 @@ pub struct ConfigMap {
/// Enable delta join in streaming query. Defaults to false.
streaming_enable_delta_join: StreamingEnableDeltaJoin,

/// Enable bushy join in the streaming query. Defaults to false.
streaming_enable_bushy_join: StreamingEnableBushyJoin,
KveinAxel marked this conversation as resolved.
Show resolved Hide resolved

/// Enable two phase agg optimization. Defaults to true.
/// Setting this to true will always set `FORCE_TWO_PHASE_AGG` to false.
enable_two_phase_agg: EnableTwoPhaseAgg,
Expand Down Expand Up @@ -402,6 +408,8 @@ impl ConfigMap {
self.streaming_parallelism = val.as_slice().try_into()?;
} else if key.eq_ignore_ascii_case(StreamingEnableDeltaJoin::entry_name()) {
self.streaming_enable_delta_join = val.as_slice().try_into()?;
} else if key.eq_ignore_ascii_case(StreamingEnableBushyJoin::entry_name()) {
self.streaming_enable_bushy_join = val.as_slice().try_into()?;
} else if key.eq_ignore_ascii_case(EnableTwoPhaseAgg::entry_name()) {
self.enable_two_phase_agg = val.as_slice().try_into()?;
if !*self.enable_two_phase_agg {
Expand Down Expand Up @@ -458,6 +466,8 @@ impl ConfigMap {
Ok(self.streaming_parallelism.to_string())
} else if key.eq_ignore_ascii_case(StreamingEnableDeltaJoin::entry_name()) {
Ok(self.streaming_enable_delta_join.to_string())
} else if key.eq_ignore_ascii_case(StreamingEnableBushyJoin::entry_name()) {
Ok(self.streaming_enable_bushy_join.to_string())
} else if key.eq_ignore_ascii_case(EnableTwoPhaseAgg::entry_name()) {
Ok(self.enable_two_phase_agg.to_string())
} else if key.eq_ignore_ascii_case(ForceTwoPhaseAgg::entry_name()) {
Expand Down Expand Up @@ -550,6 +560,11 @@ impl ConfigMap {
setting : self.streaming_enable_delta_join.to_string(),
description: String::from("Enable delta join in streaming query.")
},
VariableInfo{
name : StreamingEnableBushyJoin::entry_name().to_lowercase(),
setting : self.streaming_enable_bushy_join.to_string(),
description: String::from("Enable bushy join in streaming query.")
},
VariableInfo{
name : EnableTwoPhaseAgg::entry_name().to_lowercase(),
setting : self.enable_two_phase_agg.to_string(),
Expand Down Expand Up @@ -648,6 +663,10 @@ impl ConfigMap {
*self.streaming_enable_delta_join
}

/// Returns the current value of the `streaming_enable_bushy_join` session setting,
/// i.e. whether bushy join is enabled in the streaming query.
pub fn get_streaming_enable_bushy_join(&self) -> bool {
*self.streaming_enable_bushy_join
}

/// Returns the current value of the `enable_two_phase_agg` session setting.
pub fn get_enable_two_phase_agg(&self) -> bool {
*self.enable_two_phase_agg
}
Expand Down
10 changes: 8 additions & 2 deletions src/frontend/src/optimizer/logical_optimization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,12 @@ lazy_static! {
ApplyOrder::TopDown,
);

// Optimization stage that applies `ReorderMultiJoinRuleStreaming` top-down.
// The stage name is passed as a plain string literal, like the neighbouring stages,
// instead of allocating a `String` at the call site.
static ref JOIN_REORDER_STREAM: OptimizationStage = OptimizationStage::new(
    "Join Reorder Stream",
    vec![ReorderMultiJoinRuleStreaming::create()],
    ApplyOrder::TopDown,
);
KveinAxel marked this conversation as resolved.
Show resolved Hide resolved

static ref FILTER_WITH_NOW_TO_JOIN: OptimizationStage = OptimizationStage::new(
"Push down filter with now into a left semijoin",
vec![FilterWithNowToJoinRule::create()],
Expand Down Expand Up @@ -365,8 +371,8 @@ impl LogicalOptimizer {
// their relevant joins.
plan = plan.optimize_by_rules(&TO_MULTI_JOIN);

// Reorder multijoin into left-deep join tree.
plan = plan.optimize_by_rules(&JOIN_REORDER);
// Reorder multijoin into join tree.
plan = plan.optimize_by_rules(&JOIN_REORDER_STREAM);

// Predicate Push-down: apply filter pushdown rules again since we pullup all join
// conditions into a filter above the multijoin.
Expand Down
238 changes: 238 additions & 0 deletions src/frontend/src/optimizer/plan_node/logical_multi_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::Ordering;
use std::collections::{BTreeMap, BTreeSet, VecDeque};
use std::fmt;

use itertools::Itertools;
Expand Down Expand Up @@ -483,6 +485,242 @@ impl LogicalMultiJoin {
Ok(join_ordering)
}

/// Rewrites this multi-join into a tree of binary inner joins that is not restricted
/// to a left-deep shape, preferring the candidate tree with the smallest height.
///
/// Outline of what the body does:
/// 1. Creates one graph node per input and connects nodes using the first component
///    returned by `split_by_input_col_nums` on `self.on`.
/// 2. Explores merge sequences breadth-first: in every state one node is removed and
///    merged into each of its neighbours in turn, producing one successor state per
///    neighbour. A state with a single remaining node yields a candidate tree.
/// 3. Turns the chosen tree into nested `LogicalJoin`s (`JoinType::Inner`,
///    `Condition::true_cond()`), appends the nodes recorded as isolated, restores the
///    original column order with a projection, re-applies `self.on` as a filter and
///    finally projects `self.output_indices`.
///
/// # Errors
///
/// Returns an `InternalError` if neither a tree nor any isolated node is left
/// ("no plan remain"), or if the internal join tree is malformed.
///
/// # Panics
///
/// The column reorder step calls `unwrap()` on every entry of `reorder_mapping`, so it
/// panics if some column was not reached through `join_ordering`.
pub fn as_bushy_tree_join(&self) -> Result<PlanRef> {
// Internal representation of a join tree. A leaf carries `idx` (an index into
// `self.inputs`) and no children; a merged node carries both children and no `idx`.
#[derive(Clone, Default, Debug)]
struct JoinTreeNode {
idx: Option<usize>,
left: Option<Box<JoinTreeNode>>,
right: Option<Box<JoinTreeNode>>,
// 0 for a leaf, `max(left.height, right.height) + 1` for a merged node.
height: usize,
}

// Internal representation of a join graph node: its id, the join tree built so far
// for it, and the ids of the graph nodes it is connected to.
#[derive(Clone, Debug)]
struct GraphNode {
id: usize,
join_tree: JoinTreeNode,
// use BTreeSet for a deterministic iteration order
relations: BTreeSet<usize>,
}

// One graph node per input, keyed by its position (`enumerate` yields the same value
// as `idx`). Every node starts as a leaf without relations.
let mut nodes: BTreeMap<_, _> = (0..self.inputs.len())
.map(|idx| GraphNode {
id: idx,
relations: BTreeSet::new(),
join_tree: JoinTreeNode {
idx: Some(idx),
left: None,
right: None,
height: 0,
},
})
.enumerate()
.collect();
// Only the first component of the split result is used to build the graph.
let (eq_join_conditions, _) = self
.on
.clone()
.split_by_input_col_nums(&self.input_col_nums(), true);

// Each key `(src, dst)` becomes an undirected edge between the two graph nodes.
// presumably `src`/`dst` are input indices — they are used as keys of `nodes`.
for ((src, dst), _) in eq_join_conditions {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only using the existing join conditions is not enough, we need to do some derivation to gain more equal conditions. For example: a == b && a == c => b == c

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the derived conditions manually and in this way, the algorithm can generate the optimal bushy tree ordering.

nodes.get_mut(&src).unwrap().relations.insert(dst);
nodes.get_mut(&dst).unwrap().relations.insert(src);
}

// Isolated nodes (no relation at all) can be joined anywhere.
let iso_nodes = nodes
.iter()
.filter_map(|n| {
if n.1.relations.is_empty() {
Some(*n.0)
} else {
None
}
})
.collect_vec();

// Connect every isolated node to all other nodes.
// NOTE(review): only the direction `n -> adj` is inserted here; `adj`'s relation set
// does not receive `n`. The merge loop below calls `nodes.get_mut(..).unwrap()` for
// every id found in a relation set, which presumes that id is still present in the
// map — confirm whether the reverse edge should be inserted as well.
for n in iso_nodes {
for adj in 0..nodes.len() {
if adj != n {
nodes.get_mut(&n).unwrap().relations.insert(adj);
}
}
}

// Best complete tree found so far (smallest height).
let mut optimized_bushy_tree = None;
// Queue of graph states still to be explored, processed front to back.
let mut que = VecDeque::from([nodes]);
// Ids of nodes that were removed while having no relations; this set is shared by
// all explored states.
let mut isolated = BTreeSet::new();

while let Some(mut nodes) = que.pop_front() {
// A single remaining node is a complete candidate; it replaces the current best
// only when its height is strictly smaller.
if nodes.len() == 1 {
let node = nodes.into_values().next().unwrap();
optimized_bushy_tree = Some(optimized_bushy_tree.map_or(
node.clone(),
|old_tree: GraphNode| {
if node.join_tree.height < old_tree.join_tree.height {
node
} else {
old_tree
}
},
));
continue;
}

// Pick the node with the fewest relations; ties are broken by the smaller join
// tree height.
let (idx, _) = nodes
.iter()
.min_by(
|(_, x), (_, y)| match x.relations.len().cmp(&y.relations.len()) {
Ordering::Less => Ordering::Less,
Ordering::Greater => Ordering::Greater,
Ordering::Equal => x.join_tree.height.cmp(&y.join_tree.height),
},
)
.unwrap();
let n = nodes.remove(&idx.clone()).unwrap();

// Nothing to merge with in this state: record the id and continue with the
// remaining nodes. Only `n.id` is kept; `n.join_tree` is not used afterwards.
if n.relations.is_empty() {
isolated.insert(n.id);
que.push_back(nodes);
continue;
}

// Try merging `n` into each of its neighbours; every choice is explored on its own
// copy of the state.
for merge_node in &n.relations {
KveinAxel marked this conversation as resolved.
Show resolved Hide resolved
let mut nodes = nodes.clone();
// Re-wire the other neighbours of `n`: they lose the edge to `n` and gain an edge
// to `merge_node` (inserted in both directions).
for adjacent_node in &n.relations {
if *adjacent_node != *merge_node {
nodes
.get_mut(adjacent_node)
.unwrap()
.relations
.remove(&n.id);
nodes
.get_mut(adjacent_node)
.unwrap()
.relations
.insert(*merge_node);
nodes
.get_mut(merge_node)
.unwrap()
.relations
.insert(*adjacent_node);
}
}
let mut merge_graph_node = nodes.get_mut(merge_node).unwrap();
merge_graph_node.relations.remove(&n.id);
// `n`'s tree becomes the left child, the merge target's current tree the right one.
let l_tree = n.join_tree.clone();
let r_tree = std::mem::take(&mut merge_graph_node.join_tree);
let new_height = usize::max(l_tree.height, r_tree.height) + 1;

// Prune: a complete tree with a strictly smaller height is already known, so this
// state is dropped without being queued.
if let Some(min_height) = optimized_bushy_tree.as_ref().map(|t| t.join_tree.height) && min_height < new_height {
continue;
}

merge_graph_node.join_tree = JoinTreeNode {
idx: None,
left: Some(Box::new(l_tree)),
right: Some(Box::new(r_tree)),
height: new_height,
};
que.push_back(nodes);
}
}

// Recursively converts a `JoinTreeNode` into a plan. Leaves resolve to
// `s.inputs[idx]` and append `idx` to `join_ordering`; nodes with two children become
// an inner join with a `true` condition. Any other shape is an internal error.
fn create_logical_join(
s: &LogicalMultiJoin,
mut join_tree: JoinTreeNode,
join_ordering: &mut Vec<usize>,
) -> Result<PlanRef> {
Ok(match (join_tree.left.take(), join_tree.right.take()) {
(Some(l), Some(r)) => LogicalJoin::new(
create_logical_join(s, *l, join_ordering)?,
create_logical_join(s, *r, join_ordering)?,
JoinType::Inner,
Condition::true_cond(),
)
.into(),
(None, None) => {
if let Some(idx) = join_tree.idx {
join_ordering.push(idx);
s.inputs[idx].clone()
} else {
return Err(RwError::from(ErrorCode::InternalError(
"id of the leaf node not found in the join tree".into(),
)));
}
}
(_, _) => {
return Err(RwError::from(ErrorCode::InternalError(
"only leaf node can have None subtree".into(),
)))
}
})
}

let isolated = isolated.into_iter().collect_vec();
// Inputs in the order in which they appear in the produced join tree.
let mut join_ordering = vec![];
let mut output = if let Some(optimized_bushy_tree) = optimized_bushy_tree {
let mut output =
create_logical_join(self, optimized_bushy_tree.join_tree, &mut join_ordering)?;

// Append the inputs recorded as isolated on top of the tree, one join each.
output = isolated.into_iter().fold(output, |chain, n| {
join_ordering.push(n);
LogicalJoin::new(
chain,
self.inputs[n].clone(),
JoinType::Inner,
Condition::true_cond(),
)
.into()
});
output
} else if !isolated.is_empty() {
// No tree was produced: chain the isolated inputs left to right.
let base = isolated[0];
join_ordering.push(isolated[0]);
isolated[1..]
.iter()
.fold(self.inputs[base].clone(), |chain, n| {
join_ordering.push(*n);
LogicalJoin::new(
chain,
self.inputs[*n].clone(),
JoinType::Inner,
Condition::true_cond(),
)
.into()
})
} else {
return Err(RwError::from(ErrorCode::InternalError(
"no plan remain".into(),
)));
};
let total_col_num = self.inner2output.source_size();
// `reorder_mapping[src] = Some(tar)`: `tar` is the column's position in the output of
// the reordered joins, `src` is the value `inner_i2o_mappings[input_idx]` maps the
// input column to (presumably its position in the original multi-join column layout
// — confirm).
let reorder_mapping = {
let mut reorder_mapping = vec![None; total_col_num];

join_ordering
.iter()
.cloned()
.flat_map(|input_idx| {
(0..self.inputs[input_idx].schema().len())
.map(move |col_idx| self.inner_i2o_mappings[input_idx].map(col_idx))
})
.enumerate()
.for_each(|(tar, src)| reorder_mapping[src] = Some(tar));
reorder_mapping
};
// Every entry must have been filled above, otherwise `unwrap()` panics.
output =
LogicalProject::with_out_col_idx(output, reorder_mapping.iter().map(|i| i.unwrap()))
.into();

// We will later push down all of the filters back to the individual joins via the
// `FilterJoinRule`.
output = LogicalFilter::create(output, self.on.clone());
output =
LogicalProject::with_out_col_idx(output, self.output_indices.iter().cloned()).into();
Ok(output)
}

/// Returns the schema length (number of columns) of every input, in input order.
pub(crate) fn input_col_nums(&self) -> Vec<usize> {
    let mut col_nums = Vec::with_capacity(self.inputs.len());
    for input in &self.inputs {
        col_nums.push(input.schema().len());
    }
    col_nums
}
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/optimizer/rule/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ mod top_n_on_index_rule;
pub use top_n_on_index_rule::*;
mod stream;
pub use stream::filter_with_now_to_join_rule::*;
pub use stream::reorder_multijoin_rule_streaming::*;
mod trivial_project_to_values_rule;
pub use trivial_project_to_values_rule::*;
mod union_input_values_merge_rule;
Expand Down Expand Up @@ -133,6 +134,7 @@ macro_rules! for_all_rules {
, { RewriteLikeExprRule }
, { AvoidExchangeShareRule }
, { MinMaxOnIndexRule }
, { ReorderMultiJoinRuleStreaming }
}
};
}
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/optimizer/rule/stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
// limitations under the License.

pub(crate) mod filter_with_now_to_join_rule;
pub(crate) mod reorder_multijoin_rule_streaming;
Loading