Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion-examples/examples/custom_datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ impl ExecutionPlan for CustomExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}

Expand Down
135 changes: 111 additions & 24 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ pub trait TreeNode: Sized {
/// TreeNodeVisitor::f_up(ChildNode2)
/// TreeNodeVisitor::f_up(ParentNode)
/// ```
fn visit<V: TreeNodeVisitor<Node = Self>>(
&self,
fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(
&'n self,
visitor: &mut V,
) -> Result<TreeNodeRecursion> {
visitor
Expand Down Expand Up @@ -190,12 +190,12 @@ pub trait TreeNode: Sized {
/// # See Also
/// * [`Self::transform_down`] for the equivalent transformation API.
/// * [`Self::visit`] for both top-down and bottom up traversal.
fn apply<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
mut f: F,
) -> Result<TreeNodeRecursion> {
fn apply_impl<N: TreeNode, F: FnMut(&N) -> Result<TreeNodeRecursion>>(
node: &N,
fn apply_impl<'n, N: TreeNode, F: FnMut(&'n N) -> Result<TreeNodeRecursion>>(
node: &'n N,
f: &mut F,
) -> Result<TreeNodeRecursion> {
f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f)))
Expand Down Expand Up @@ -427,8 +427,8 @@ pub trait TreeNode: Sized {
///
/// Description: Apply `f` to inspect node's children (but not the node
/// itself).
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion>;

Expand Down Expand Up @@ -466,19 +466,19 @@ pub trait TreeNode: Sized {
///
/// # See Also:
/// * [`TreeNode::rewrite`] to rewrite owned `TreeNode`s
pub trait TreeNodeVisitor: Sized {
pub trait TreeNodeVisitor<'n>: Sized {
/// The node type which is visitable.
type Node: TreeNode;

/// Invoked while traversing down the tree, before any children are visited.
/// Default implementation continues the recursion.
fn f_down(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
fn f_down(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
}

/// Invoked while traversing up the tree after children are visited. Default
/// implementation continues the recursion.
fn f_up(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
fn f_up(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
}
}
Expand Down Expand Up @@ -855,7 +855,7 @@ impl<T> TransformedResult<T> for Result<Transformed<T>> {
/// its related `Arc<dyn T>` will automatically implement [`TreeNode`].
pub trait DynTreeNode {
/// Returns all children of the specified `TreeNode`.
fn arc_children(&self) -> Vec<Arc<Self>>;
fn arc_children(&self) -> Vec<&Arc<Self>>;

/// Constructs a new node with the specified children.
fn with_new_arc_children(
Expand All @@ -868,11 +868,11 @@ pub trait DynTreeNode {
/// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
/// (such as [`Arc<dyn PhysicalExpr>`]).
impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.arc_children().iter().apply_until_stop(f)
self.arc_children().into_iter().apply_until_stop(f)
}

fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
Expand All @@ -881,7 +881,10 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
) -> Result<Transformed<Self>> {
let children = self.arc_children();
if !children.is_empty() {
let new_children = children.into_iter().map_until_stop_and_collect(f)?;
let new_children = children
.into_iter()
.cloned()
.map_until_stop_and_collect(f)?;
// Propagate up `new_children.transformed` and `new_children.tnr`
// along with the node containing transformed children.
if new_children.transformed {
Expand Down Expand Up @@ -913,8 +916,8 @@ pub trait ConcreteTreeNode: Sized {
}

impl<T: ConcreteTreeNode> TreeNode for T {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.children().into_iter().apply_until_stop(f)
Expand All @@ -938,6 +941,7 @@ impl<T: ConcreteTreeNode> TreeNode for T {

#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::fmt::Display;

use crate::tree_node::{
Expand All @@ -946,7 +950,7 @@ mod tests {
};
use crate::Result;

#[derive(PartialEq, Debug)]
#[derive(Debug, Eq, Hash, PartialEq)]
struct TestTreeNode<T> {
children: Vec<TestTreeNode<T>>,
data: T,
Expand All @@ -959,8 +963,8 @@ mod tests {
}

impl<T> TreeNode for TestTreeNode<T> {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.children.iter().apply_until_stop(f)
Expand Down Expand Up @@ -1459,15 +1463,15 @@ mod tests {
}
}

impl<T: Display> TreeNodeVisitor for TestVisitor<T> {
impl<'n, T: Display> TreeNodeVisitor<'n> for TestVisitor<T> {
type Node = TestTreeNode<T>;

fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
self.visits.push(format!("f_down({})", node.data));
(*self.f_down)(node)
}

fn f_up(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
self.visits.push(format!("f_up({})", node.data));
(*self.f_up)(node)
}
Expand Down Expand Up @@ -1912,4 +1916,87 @@ mod tests {
TreeNodeRecursion::Stop
)
);

// F
// / | \
// / | \
// E C A
// | / \
// C B D
// / \ |
// B D A
// |
// A
#[test]
fn test_apply_and_visit_references() -> Result<()> {
let node_a = TestTreeNode::new(vec![], "a".to_string());
let node_b = TestTreeNode::new(vec![], "b".to_string());
let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
let node_a_2 = TestTreeNode::new(vec![], "a".to_string());
let node_b_2 = TestTreeNode::new(vec![], "b".to_string());
let node_d_2 = TestTreeNode::new(vec![node_a_2], "d".to_string());
let node_c_2 = TestTreeNode::new(vec![node_b_2, node_d_2], "c".to_string());
let node_a_3 = TestTreeNode::new(vec![], "a".to_string());
let tree = TestTreeNode::new(vec![node_e, node_c_2, node_a_3], "f".to_string());

let node_f_ref = &tree;
let node_e_ref = &node_f_ref.children[0];
let node_c_ref = &node_e_ref.children[0];
let node_b_ref = &node_c_ref.children[0];
let node_d_ref = &node_c_ref.children[1];
let node_a_ref = &node_d_ref.children[0];

let mut m: HashMap<&TestTreeNode<String>, usize> = HashMap::new();
tree.apply(|e| {
*m.entry(e).or_insert(0) += 1;
Ok(TreeNodeRecursion::Continue)
})?;

let expected = HashMap::from([
(node_f_ref, 1),
(node_e_ref, 1),
(node_c_ref, 2),
(node_d_ref, 2),
(node_b_ref, 2),
(node_a_ref, 3),
]);
assert_eq!(m, expected);

struct TestVisitor<'n> {
m: HashMap<&'n TestTreeNode<String>, (usize, usize)>,
}

impl<'n> TreeNodeVisitor<'n> for TestVisitor<'n> {
type Node = TestTreeNode<String>;

fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
let (down_count, _) = self.m.entry(node).or_insert((0, 0));
*down_count += 1;
Ok(TreeNodeRecursion::Continue)
}

fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
let (_, up_count) = self.m.entry(node).or_insert((0, 0));
*up_count += 1;
Ok(TreeNodeRecursion::Continue)
}
}

let mut visitor = TestVisitor { m: HashMap::new() };
tree.visit(&mut visitor)?;

let expected = HashMap::from([
(node_f_ref, (1, 1)),
(node_e_ref, (1, 1)),
(node_c_ref, (2, 2)),
(node_d_ref, (2, 2)),
(node_b_ref, (2, 2)),
(node_a_ref, (3, 3)),
]);
assert_eq!(visitor.m, expected);

Ok(())
}
}
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/physical_plan/arrow_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl ExecutionPlan for ArrowExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
Vec::new()
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/physical_plan/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl ExecutionPlan for AvroExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
Vec::new()
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/physical_plan/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ impl ExecutionPlan for CsvExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
// this is a leaf node and has no children
vec![]
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/physical_plan/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ impl ExecutionPlan for NdJsonExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
Vec::new()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ impl ExecutionPlan for ParquetExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
// this is a leaf node and has no children
vec![]
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2441,10 +2441,10 @@ impl<'a> BadPlanVisitor<'a> {
}
}

impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> {
impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> {
type Node = LogicalPlan;

fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
match node {
LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => {
plan_err!("DDL not supported: {}", ddl.name())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>>
return Some(child);
}
}
if let [ref childrens_child] = child.children().as_slice() {
if let [childrens_child] = child.children().as_slice() {
child = Arc::clone(childrens_child);
} else {
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1375,8 +1375,8 @@ pub(crate) mod tests {
vec![false]
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.input.clone()]
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}

// model that it requires the output ordering of its input
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/enforce_sorting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ fn remove_corresponding_sort_from_sub_plan(
// Replace with variants that do not preserve order.
if is_sort_preserving_merge(&node.plan) {
node.children = node.children.swap_remove(0).children;
node.plan = node.plan.children().swap_remove(0);
node.plan = node.plan.children().swap_remove(0).clone();
} else if let Some(repartition) =
node.plan.as_any().downcast_ref::<RepartitionExec>()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,15 @@ impl LimitedDistinctAggregation {
let mut is_global_limit = false;
if let Some(local_limit) = plan.as_any().downcast_ref::<LocalLimitExec>() {
limit = local_limit.fetch();
children = local_limit.children();
children = local_limit.children().into_iter().cloned().collect();
} else if let Some(global_limit) = plan.as_any().downcast_ref::<GlobalLimitExec>()
{
global_fetch = global_limit.fetch();
global_fetch?;
global_skip = global_limit.skip();
// the aggregate must read at least fetch+skip number of rows
limit = global_fetch.unwrap() + global_skip;
children = global_limit.children();
children = global_limit.children().into_iter().cloned().collect();
is_global_limit = true
} else {
return None;
Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/src/physical_optimizer/output_requirements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ impl ExecutionPlan for OutputRequirementExec {
vec![true]
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.input.clone()]
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}

fn required_input_ordering(&self) -> Vec<Option<Vec<PhysicalSortRequirement>>> {
Expand Down Expand Up @@ -273,7 +273,7 @@ fn require_top_ordering_helper(
// When an operator requires an ordering, any `SortExec` below can not
// be responsible for (i.e. the originator of) the global ordering.
let (new_child, is_changed) =
require_top_ordering_helper(children.swap_remove(0))?;
require_top_ordering_helper(children.swap_remove(0).clone())?;
Ok((plan.with_new_children(vec![new_child])?, is_changed))
} else {
// Stop searching, there is no global ordering desired for the query.
Expand Down
Loading