Skip to content

Commit

Permalink
Add Container trait and its blanket implementations, remove `map_un…
Browse files Browse the repository at this point in the history
…til_stop_and_collect` macro, simplify apply and map logic with `Container`s where possible
  • Loading branch information
peter-toth committed Nov 18, 2024
1 parent 6b0570b commit 2036c09
Show file tree
Hide file tree
Showing 9 changed files with 676 additions and 594 deletions.
329 changes: 279 additions & 50 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

//! [`TreeNode`] for visiting and rewriting expression and plan trees
use crate::Result;
use recursive::recursive;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Arc;

use crate::Result;

/// These macros are used to determine continuation during transforming traversals.
macro_rules! handle_transform_recursion {
($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{
Expand Down Expand Up @@ -769,6 +770,263 @@ impl<T> Transformed<T> {
}
}

/// [`Container`] contains elements that a function can be applied on or mapped. The
/// elements of the container are siblings so the continuation rules are similar to
/// [`TreeNodeRecursion::visit_sibling`] / [`Transformed::transform_sibling`].
///
/// `T` is the element type and `'a` is the lifetime of the references handed to the
/// closure of [`Container::apply_elements`].
pub trait Container<'a, T: 'a>: Sized {
/// Applies `f` to the elements of this container by reference.
///
/// Returns the [`TreeNodeRecursion`] that tells the caller how to continue, or the
/// error produced by `f`.
fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
&'a self,
f: F,
) -> Result<TreeNodeRecursion>;

/// Consumes the container and maps its elements with `f`.
///
/// Returns the rebuilt container wrapped in [`Transformed`], or the error produced
/// by `f`.
fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
self,
f: F,
) -> Result<Transformed<Self>>;
}

impl<'a, T: 'a, C: Container<'a, T>> Container<'a, T> for Box<C> {
    /// Forwards the visit to the boxed container.
    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
        &'a self,
        f: F,
    ) -> Result<TreeNodeRecursion> {
        let inner: &'a C = &**self;
        inner.apply_elements(f)
    }

    /// Moves the container out of the box, maps it and boxes the result again.
    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
        self,
        f: F,
    ) -> Result<Transformed<Self>> {
        let unboxed: C = *self;
        let mapped = unboxed.map_elements(f)?;
        mapped.map_data(|inner| Ok(Box::new(inner)))
    }
}

impl<'a, T: 'a, C: Container<'a, T> + Clone> Container<'a, T> for Arc<C> {
    /// Forwards the visit to the shared container.
    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
        &'a self,
        f: F,
    ) -> Result<TreeNodeRecursion> {
        let inner: &'a C = &**self;
        inner.apply_elements(f)
    }

    /// Takes the container out of the `Arc` (cloning it only if the `Arc` is shared),
    /// maps it and wraps the result into a fresh `Arc`.
    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
        self,
        f: F,
    ) -> Result<Transformed<Self>> {
        let owned: C = Arc::unwrap_or_clone(self);
        let mapped = owned.map_elements(f)?;
        mapped.map_data(|inner| Ok(Arc::new(inner)))
    }
}

impl<'a, T: 'a, C: Container<'a, T>> Container<'a, T> for Option<C> {
    /// Visits the inner container if there is one; `None` has nothing to visit and
    /// reports `Continue`.
    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
        &'a self,
        f: F,
    ) -> Result<TreeNodeRecursion> {
        let Some(inner) = self else {
            return Ok(TreeNodeRecursion::Continue);
        };
        inner.apply_elements(f)
    }

    /// Maps the inner container if there is one; `None` is returned as an
    /// untransformed `None`.
    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
        self,
        f: F,
    ) -> Result<Transformed<Self>> {
        match self {
            None => Ok(Transformed::no(None)),
            Some(inner) => {
                let mapped = inner.map_elements(f)?;
                mapped.map_data(|new_inner| Ok(Some(new_inner)))
            }
        }
    }
}

impl<'a, T: 'a, C: Container<'a, T>> Container<'a, T> for Vec<C> {
    /// Visits the items in order and stops as soon as one of them reports `Stop`.
    /// Returns the recursion state of the last visited item, or `Continue` for an
    /// empty vector.
    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
        &'a self,
        mut f: F,
    ) -> Result<TreeNodeRecursion> {
        let mut last = TreeNodeRecursion::Continue;
        for item in self.iter() {
            last = item.apply_elements(&mut f)?;
            if matches!(last, TreeNodeRecursion::Stop) {
                break;
            }
        }
        Ok(last)
    }

    /// Maps the items in order. Once an item reports `Stop`, the remaining items are
    /// moved to the output untouched. The first error aborts the whole mapping.
    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
        self,
        mut f: F,
    ) -> Result<Transformed<Self>> {
        let mut last = TreeNodeRecursion::Continue;
        let mut any_transformed = false;
        let mut mapped_items = Vec::with_capacity(self.len());
        for item in self {
            match last {
                // A previous sibling asked to stop: keep this item as it is.
                TreeNodeRecursion::Stop => mapped_items.push(item),
                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
                    let result = item.map_elements(&mut f)?;
                    last = result.tnr;
                    any_transformed |= result.transformed;
                    mapped_items.push(result.data);
                }
            }
        }
        Ok(Transformed::new(mapped_items, any_transformed, last))
    }
}

impl<'a, T: 'a, K: Eq + Hash, C: Container<'a, T>> Container<'a, T> for HashMap<K, C> {
    /// Visits the values in the map's iteration order and stops as soon as one of
    /// them reports `Stop`. Returns the recursion state of the last visited value, or
    /// `Continue` for an empty map.
    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
        &'a self,
        mut f: F,
    ) -> Result<TreeNodeRecursion> {
        let mut last = TreeNodeRecursion::Continue;
        for value in self.values() {
            last = value.apply_elements(&mut f)?;
            if matches!(last, TreeNodeRecursion::Stop) {
                break;
            }
        }
        Ok(last)
    }

    /// Maps the values in the map's iteration order while keeping their keys. Once a
    /// value reports `Stop`, the remaining entries are moved to the output untouched.
    /// The first error aborts the whole mapping.
    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
        self,
        mut f: F,
    ) -> Result<Transformed<Self>> {
        let mut last = TreeNodeRecursion::Continue;
        let mut any_transformed = false;
        let mut mapped_entries = HashMap::with_capacity(self.len());
        for (key, value) in self {
            match last {
                // A previous sibling asked to stop: keep this entry as it is.
                TreeNodeRecursion::Stop => {
                    mapped_entries.insert(key, value);
                }
                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
                    let result = value.map_elements(&mut f)?;
                    last = result.tnr;
                    any_transformed |= result.transformed;
                    mapped_entries.insert(key, result.data);
                }
            }
        }
        Ok(Transformed::new(mapped_entries, any_transformed, last))
    }
}

impl<'a, T: 'a, C0: Container<'a, T>, C1: Container<'a, T>> Container<'a, T>
    for (C0, C1)
{
    /// Visits the first component, then hands the second one to `visit_sibling`.
    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
        &'a self,
        mut f: F,
    ) -> Result<TreeNodeRecursion> {
        let (first, second) = self;
        let after_first = first.apply_elements(&mut f)?;
        after_first.visit_sibling(|| second.apply_elements(&mut f))
    }

    /// Maps the first component, then hands the second one to `transform_sibling`
    /// and reassembles the pair.
    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
        self,
        mut f: F,
    ) -> Result<Transformed<Self>> {
        let (c0, c1) = self;
        let first_done = c0.map_elements(&mut f)?;
        // Carry the still unmapped second component along with the mapped first one.
        let pair = first_done.map_data(|new_c0| Ok((new_c0, c1)))?;
        pair.transform_sibling(|(new_c0, c1)| {
            let second_done = c1.map_elements(&mut f)?;
            second_done.map_data(|new_c1| Ok((new_c0, new_c1)))
        })
    }
}

impl<'a, T: 'a, C0: Container<'a, T>, C1: Container<'a, T>, C2: Container<'a, T>>
    Container<'a, T> for (C0, C1, C2)
{
    /// Visits the first component, then hands the second and the third one to
    /// `visit_sibling` in turn.
    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
        &'a self,
        mut f: F,
    ) -> Result<TreeNodeRecursion> {
        let (first, second, third) = self;
        let after_first = first.apply_elements(&mut f)?;
        let after_second = after_first.visit_sibling(|| second.apply_elements(&mut f))?;
        after_second.visit_sibling(|| third.apply_elements(&mut f))
    }

    /// Maps the first component, then hands the second and the third one to
    /// `transform_sibling` in turn and reassembles the triple.
    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
        self,
        mut f: F,
    ) -> Result<Transformed<Self>> {
        let (c0, c1, c2) = self;
        let first_done = c0.map_elements(&mut f)?;
        // Carry the still unmapped components along with the mapped first one.
        let triple = first_done.map_data(|new_c0| Ok((new_c0, c1, c2)))?;
        let triple = triple.transform_sibling(|(new_c0, c1, c2)| {
            let second_done = c1.map_elements(&mut f)?;
            second_done.map_data(|new_c1| Ok((new_c0, new_c1, c2)))
        })?;
        triple.transform_sibling(|(new_c0, new_c1, c2)| {
            let third_done = c2.map_elements(&mut f)?;
            third_done.map_data(|new_c2| Ok((new_c0, new_c1, new_c2)))
        })
    }
}

/// [`RefContainer`] contains references to elements that a function can be applied on.
/// The elements of the container are siblings so the continuation rules are similar to
/// [`TreeNodeRecursion::visit_sibling`].
///
/// Unlike [`Container::apply_elements`], the container itself is borrowed for an
/// anonymous lifetime; only the elements passed to `f` live for `'a`.
pub trait RefContainer<'a, T: 'a>: Sized {
/// Applies `f` to the referenced elements.
///
/// Returns the [`TreeNodeRecursion`] that tells the caller how to continue, or the
/// error produced by `f`.
fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
&self,
f: F,
) -> Result<TreeNodeRecursion>;
}

impl<'a, T: 'a, C: Container<'a, T>> RefContainer<'a, T> for Vec<&'a C> {
    /// Visits the referenced containers in order and stops as soon as one of them
    /// reports `Stop`. Returns the recursion state of the last visited container, or
    /// `Continue` for an empty vector.
    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
        &self,
        mut f: F,
    ) -> Result<TreeNodeRecursion> {
        let mut last = TreeNodeRecursion::Continue;
        for &container in self.iter() {
            last = container.apply_elements(&mut f)?;
            if matches!(last, TreeNodeRecursion::Stop) {
                break;
            }
        }
        Ok(last)
    }
}

impl<'a, T: 'a, C0: Container<'a, T>, C1: Container<'a, T>> RefContainer<'a, T>
    for (&'a C0, &'a C1)
{
    /// Visits the first referenced container, then hands the second one to
    /// `visit_sibling`.
    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
        &self,
        mut f: F,
    ) -> Result<TreeNodeRecursion> {
        // A tuple of shared references is `Copy`, so it can be destructured by value.
        let (first, second) = *self;
        let after_first = first.apply_elements(&mut f)?;
        after_first.visit_sibling(|| second.apply_elements(&mut f))
    }
}

impl<'a, T: 'a, C0: Container<'a, T>, C1: Container<'a, T>, C2: Container<'a, T>>
    RefContainer<'a, T> for (&'a C0, &'a C1, &'a C2)
{
    /// Visits the first referenced container, then hands the second and the third
    /// one to `visit_sibling` in turn.
    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
        &self,
        mut f: F,
    ) -> Result<TreeNodeRecursion> {
        // A tuple of shared references is `Copy`, so it can be destructured by value.
        let (first, second, third) = *self;
        let after_first = first.apply_elements(&mut f)?;
        let after_second = after_first.visit_sibling(|| second.apply_elements(&mut f))?;
        after_second.visit_sibling(|| third.apply_elements(&mut f))
    }
}

/// Transformation helper to process a sequence of iterable tree nodes that are siblings.
pub trait TreeNodeIterator: Iterator {
/// Applies `f` to each item in this iterator
Expand Down Expand Up @@ -843,50 +1101,6 @@ impl<I: Iterator> TreeNodeIterator for I {
}
}

/// Transformation helper to process a heterogeneous sequence of tree node containing
/// expressions.
///
/// This macro is very similar to [TreeNodeIterator::map_until_stop_and_collect] to
/// process nodes that are siblings, but it accepts an initial transformation (`F0`) and
/// a sequence of pairs. Each pair is made of an expression (`EXPR`) and its
/// transformation (`F`).
///
/// The macro builds up a tuple that contains `Transformed.data` result of `F0` as the
/// first element and further elements from the sequence of pairs. An element from a pair
/// is either the value of `EXPR` or the `Transformed.data` result of `F`, depending on
/// the `Transformed.tnr` result of previous `F`s (`F0` initially).
///
/// # Returns
/// Error if any of the transformations returns an error
///
/// Ok(Transformed<(data0, ..., dataN)>) such that:
/// 1. `transformed` is true if any of the transformations had transformed true
/// 2. `(data0, ..., dataN)`, where `data0` is the `Transformed.data` from `F0` and
/// `data1` ... `dataN` are from either `EXPR` or the `Transformed.data` of `F`
/// 3. `tnr` from `F0` or the last invocation of `F`
#[macro_export]
macro_rules! map_until_stop_and_collect {
($F0:expr, $($EXPR:expr, $F:expr),*) => {{
// `F0` is evaluated first; if it is an error, `and_then` skips everything below.
$F0.and_then(|Transformed { data: data0, mut transformed, mut tnr }| {
// The tuple elements are evaluated left to right, so every pair sees the `tnr`
// left behind by the transformations before it.
let all_datas = (
data0,
$(
if tnr == TreeNodeRecursion::Continue || tnr == TreeNodeRecursion::Jump {
// Run the transformation and fold its flags into the running state. The
// `?` returns from the enclosing closure, so the first error ends the macro.
$F.map(|result| {
tnr = result.tnr;
transformed |= result.transformed;
result.data
})?
} else {
// A previous transformation asked to stop: use the expression as it is.
$EXPR
},
)*
);
Ok(Transformed::new(all_datas, transformed, tnr))
})
}}
}

/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
///
/// # Example
Expand Down Expand Up @@ -1021,7 +1235,7 @@ pub(crate) mod tests {
use std::fmt::Display;

use crate::tree_node::{
Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter,
Container, Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
TreeNodeVisitor,
};
use crate::Result;
Expand Down Expand Up @@ -1054,7 +1268,7 @@ pub(crate) mod tests {
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.children.iter().apply_until_stop(f)
self.children.apply_elements(f)
}

fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
Expand All @@ -1063,15 +1277,30 @@ pub(crate) mod tests {
) -> Result<Transformed<Self>> {
Ok(self
.children
.into_iter()
.map_until_stop_and_collect(f)?
.map_elements(f)?
.update_data(|new_children| Self {
children: new_children,
..self
}))
}
}

// Lets a `TestTreeNode` be used as a container of itself: both methods simply
// invoke `f` once, on the node.
impl<'a, T: 'a> Container<'a, Self> for TestTreeNode<T> {
// Calls `f` with a reference to this node and returns its result unchanged.
fn apply_elements<F: FnMut(&'a Self) -> Result<TreeNodeRecursion>>(
&'a self,
mut f: F,
) -> Result<TreeNodeRecursion> {
f(self)
}

// Calls `f` with this node by value and returns its result unchanged.
fn map_elements<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
mut f: F,
) -> Result<Transformed<Self>> {
f(self)
}
}

// J
// |
// I
Expand Down
Loading

0 comments on commit 2036c09

Please sign in to comment.