diff --git a/datafusion/core/tests/data/recursive_cte/closure.csv b/datafusion/core/tests/data/recursive_cte/closure.csv
new file mode 100644
index 0000000000000..a31e2bfbf36b6
--- /dev/null
+++ b/datafusion/core/tests/data/recursive_cte/closure.csv
@@ -0,0 +1,6 @@
+start,end
+1,2
+2,3
+2,4
+2,4
+4,1
\ No newline at end of file
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index b9afd894d77d3..b291717d1d8e8 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -55,7 +55,7 @@ use datafusion_common::display::ToStringifiedPlan;
 use datafusion_common::file_options::file_type::FileType;
 use datafusion_common::metadata::FieldMetadata;
 use datafusion_common::{
-    exec_err, get_target_functional_dependencies, internal_datafusion_err, not_impl_err,
+    exec_err, get_target_functional_dependencies, internal_datafusion_err,
     plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, NullEquality,
     Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
 };
@@ -179,12 +179,6 @@ impl LogicalPlanBuilder {
         recursive_term: LogicalPlan,
         is_distinct: bool,
     ) -> Result<Self> {
-        // TODO: we need to do a bunch of validation here. Maybe more.
-        if is_distinct {
-            return not_impl_err!(
-                "Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported"
-            );
-        }
         // Ensure that the static term and the recursive term have the same number of fields
         let static_fields_len = self.plan.schema().fields().len();
         let recursive_fields_len = recursive_term.schema().fields().len();
diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs
index 7b966ed23dbde..e2df8f9578f97 100644
--- a/datafusion/physical-plan/src/recursive_query.rs
+++ b/datafusion/physical-plan/src/recursive_query.rs
@@ -22,13 +22,18 @@ use std::sync::Arc;
 use std::task::{Context, Poll};
 
 use super::work_table::{ReservedBatches, WorkTable};
+use crate::aggregates::group_values::{new_group_values, GroupValues};
+use crate::aggregates::order::GroupOrdering;
 use crate::execution_plan::{Boundedness, EmissionType};
+use crate::metrics::{
+    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
+};
 use crate::{
-    metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
-    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
+    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
+    SendableRecordBatchStream, Statistics,
 };
-use crate::{DisplayAs, DisplayFormatType, ExecutionPlan};
-
+use arrow::array::{BooleanArray, BooleanBuilder};
+use arrow::compute::filter_record_batch;
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
@@ -195,8 +200,9 @@ impl ExecutionPlan for RecursiveQueryExec {
             Arc::clone(&self.work_table),
             Arc::clone(&self.recursive_term),
             static_stream,
+            self.is_distinct,
             baseline_metrics,
-        )))
+        )?))
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
@@ -267,8 +273,10 @@ struct RecursiveQueryStream {
     buffer: Vec<RecordBatch>,
     /// Tracks the memory used by the buffer
     reservation: MemoryReservation,
-    // /// Metrics.
-    _baseline_metrics: BaselineMetrics,
+    /// If the distinct flag is set, then we use this hash table to remove duplicates from result and work tables
+    distinct_deduplicator: Option<DistinctDeduplicator>,
+    /// Metrics.
+    baseline_metrics: BaselineMetrics,
 }
 
 impl RecursiveQueryStream {
     fn new(
         task_context: Arc<TaskContext>,
         work_table: Arc<WorkTable>,
         recursive_term: Arc<dyn ExecutionPlan>,
         static_stream: SendableRecordBatchStream,
+        is_distinct: bool,
         baseline_metrics: BaselineMetrics,
-    ) -> Self {
+    ) -> Result<Self> {
         let schema = static_stream.schema();
         let reservation =
             MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
-        Self {
+        let distinct_deduplicator = is_distinct
+            .then(|| DistinctDeduplicator::new(Arc::clone(&schema), &task_context))
+            .transpose()?;
+        Ok(Self {
             task_context,
             work_table,
             recursive_term,
@@ -292,21 +304,28 @@ impl RecursiveQueryStream {
             schema,
             buffer: vec![],
             reservation,
-            _baseline_metrics: baseline_metrics,
-        }
+            distinct_deduplicator,
+            baseline_metrics,
+        })
     }
 
     /// Push a clone of the given batch to the in memory buffer, and then return
     /// a poll with it.
     fn push_batch(
         mut self: std::pin::Pin<&mut Self>,
-        batch: RecordBatch,
+        mut batch: RecordBatch,
     ) -> Poll<Option<Result<RecordBatch>>> {
+        let baseline_metrics = self.baseline_metrics.clone();
+        if let Some(deduplicator) = &mut self.distinct_deduplicator {
+            let _timer_guard = baseline_metrics.elapsed_compute().timer();
+            batch = deduplicator.deduplicate(&batch)?;
+        }
+
         if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
             return Poll::Ready(Some(Err(e)));
         }
 
         self.buffer.push(batch.clone());
-
+        (&batch).record_output(&baseline_metrics);
         Poll::Ready(Some(Ok(batch)))
     }
@@ -391,7 +410,6 @@ impl Stream for RecursiveQueryStream {
         mut self: std::pin::Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        // TODO: we should use this poll to record some metrics!
         if let Some(static_stream) = &mut self.static_stream {
             // While the static term's stream is available, we'll be forwarding the batches from it (also
             // saving them for the initial iteration of the recursive term).
@@ -428,5 +446,61 @@ impl RecordBatchStream for RecursiveQueryStream {
     }
 }
 
+/// Deduplicator based on a hash table.
+struct DistinctDeduplicator {
+    /// Grouped rows used for distinct
+    group_values: Box<dyn GroupValues>,
+    reservation: MemoryReservation,
+    intern_output_buffer: Vec<usize>,
+}
+
+impl DistinctDeduplicator {
+    fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
+        let group_values = new_group_values(schema, &GroupOrdering::None)?;
+        let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
+            .register(task_context.memory_pool());
+        Ok(Self {
+            group_values,
+            reservation,
+            intern_output_buffer: Vec::new(),
+        })
+    }
+
+    /// Remove duplicated rows from the given batch, keeping a state between batches.
+    ///
+    /// We use a hash table to allocate new group ids for the new rows.
+    /// [`GroupValues`] allocates increasing group ids.
+    /// Hence, if groups (i.e., rows) are new, they have ids >= the length before interning, and we keep them.
+    /// We also detect duplicates by enforcing that group ids are increasing.
+    fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
+        let size_before = self.group_values.len();
+        self.intern_output_buffer.reserve(batch.num_rows());
+        self.group_values
+            .intern(batch.columns(), &mut self.intern_output_buffer)?;
+        let mask = new_groups_mask(&self.intern_output_buffer, size_before);
+        self.intern_output_buffer.clear();
+        // We update the reservation to reflect the new size of the hash table.
+        self.reservation.try_resize(self.group_values.size())?;
+        Ok(filter_record_batch(batch, &mask)?)
+    }
+}
+
+/// Return a mask, each element being true if, and only if, the element is greater than all previous elements and greater than or equal to the provided max_already_seen_group_id
+fn new_groups_mask(
+    values: &[usize],
+    mut max_already_seen_group_id: usize,
+) -> BooleanArray {
+    let mut output = BooleanBuilder::with_capacity(values.len());
+    for value in values {
+        if *value >= max_already_seen_group_id {
+            output.append_value(true);
+            max_already_seen_group_id = *value + 1; // We want to be increasing
+        } else {
+            output.append_value(false);
+        }
+    }
+    output.finish()
+}
+
 #[cfg(test)]
 mod tests {}
diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt
index 03900a608e6a8..fe9077b7f8dc9 100644
--- a/datafusion/sqllogictest/test_files/cte.slt
+++ b/datafusion/sqllogictest/test_files/cte.slt
@@ -58,18 +58,6 @@ WITH RECURSIVE nodes AS (
 statement ok
 set datafusion.execution.enable_recursive_ctes = true;
 
-
-# DISTINCT UNION is not supported
-query error DataFusion error: This feature is not implemented: Recursive queries with a distinct 'UNION' \(in which the previous iteration's results will be de\-duplicated\) is not supported
-WITH RECURSIVE nodes AS (
-    SELECT 1 as id
-    UNION
-    SELECT id + 1 as id
-    FROM nodes
-    WHERE id < 3
-) SELECT * FROM nodes
-
-
 # trivial recursive CTE works
 query I rowsort
 WITH RECURSIVE nodes AS (
@@ -121,6 +109,22 @@ physical_plan
 07)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
 08)----------WorkTableExec: name=nodes
 
+# simple deduplicating recursive CTE works
+query I
+WITH RECURSIVE nodes AS (
+    SELECT id from (VALUES (1), (2)) nodes(id)
+    UNION
+    SELECT id + 1 as id
+    FROM nodes
+    WHERE id < 4
+)
+SELECT * FROM nodes
+----
+1
+2
+3
+4
+
 # setup
 statement ok
 CREATE EXTERNAL TABLE balance STORED as CSV LOCATION '../core/tests/data/recursive_cte/balance.csv' OPTIONS ('format.has_header' 'true');
@@ -1044,6 +1048,66 @@ physical_plan
 05)----SortExec: TopK(fetch=1), expr=[v@1 ASC NULLS LAST], preserve_partitioning=[false]
 06)------WorkTableExec: name=r
 
+# setup
+statement ok
+CREATE EXTERNAL TABLE closure STORED as CSV LOCATION '../core/tests/data/recursive_cte/closure.csv' OPTIONS ('format.has_header' 'true');
+
+# transitive closure with loop
+query II
+WITH RECURSIVE trans AS (
+    SELECT * FROM closure
+    UNION
+    SELECT l.start, r.end
+    FROM trans as l, closure AS r
+    WHERE l.end = r.start
+) SELECT * FROM trans ORDER BY start, end
+----
+1 1
+1 2
+1 3
+1 4
+2 1
+2 2
+2 3
+2 4
+4 1
+4 2
+4 3
+4 4
+
+query TT
+EXPLAIN WITH RECURSIVE trans AS (
+    SELECT * FROM closure
+    UNION
+    SELECT l.start, r.end
+    FROM trans as l, closure AS r
+    WHERE l.end = r.start
+) SELECT * FROM trans
+----
+logical_plan
+01)SubqueryAlias: trans
+02)--RecursiveQuery: is_distinct=true
+03)----Projection: closure.start, closure.end
+04)------TableScan: closure
+05)----Projection: l.start, r.end
+06)------Inner Join: l.end = r.start
+07)--------SubqueryAlias: l
+08)----------TableScan: trans
+09)--------SubqueryAlias: r
+10)----------TableScan: closure
+physical_plan
+01)RecursiveQueryExec: name=trans, is_distinct=true
+02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/recursive_cte/closure.csv]]}, projection=[start, end], file_type=csv, has_header=true
+03)--CoalescePartitionsExec
+04)----CoalesceBatchesExec: target_batch_size=8182
+05)------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(end@1, start@0)], projection=[start@0, end@3]
+06)--------CoalesceBatchesExec: target_batch_size=8182
+07)----------RepartitionExec: partitioning=Hash([end@1], 4), input_partitions=1
+08)------------WorkTableExec: name=trans
+09)--------CoalesceBatchesExec: target_batch_size=8182
+10)----------RepartitionExec: partitioning=Hash([start@0], 4), input_partitions=1
+11)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/recursive_cte/closure.csv]]}, projection=[start, end], file_type=csv, has_header=true
+
 statement count 0
 set datafusion.execution.enable_recursive_ctes = false;
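
Note (not part of the patch): the deduplication added above relies on one property of GroupValues::intern, namely that previously unseen rows receive increasing group ids. A row is therefore new exactly when its id is at least the number of groups that existed before the current batch was interned, and the ids of kept rows keep increasing, which is what new_groups_mask checks. Below is a minimal, standalone sketch of that masking rule using a plain Vec<bool> instead of Arrow's BooleanArray; the keep_mask name and the example ids are illustrative only, not DataFusion APIs.

// Standalone illustration of the rule implemented by `new_groups_mask`:
// keep a row only if its group id is >= the number of groups seen before
// this batch, and force the ids of kept rows to be strictly increasing.
fn keep_mask(group_ids: &[usize], mut next_expected_id: usize) -> Vec<bool> {
    group_ids
        .iter()
        .map(|&id| {
            if id >= next_expected_id {
                next_expected_id = id + 1; // kept ids must keep increasing
                true
            } else {
                false // duplicate of an earlier batch, or repeated within this batch
            }
        })
        .collect()
}

fn main() {
    // Suppose 3 groups existed before this batch; interning the batch yields these ids.
    // Ids below 3 belong to rows already emitted by earlier iterations; the repeated 4
    // is a duplicate inside the batch itself.
    let ids = [3, 0, 4, 4, 5];
    assert_eq!(keep_mask(&ids, 3), vec![true, false, true, false, true]);
}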