
Commit 3ba7350

Tptalamb and Andrew Lamb authored
feat: Support recursive queries with a distinct 'UNION' (#18254)
Rely on the aggregate `GroupValues` abstraction to build a hash table of the emitted rows that is used to deduplicate. We might make things a bit more efficient by writing a hash table wrapper dedicated to deduplication, but this implementation should give a fair baseline.

## Which issue does this PR close?

- Closes #18140.

## Rationale for this change

Implements deduplicating recursive CTEs (i.e. `UNION` inside of `WITH RECURSIVE`) using a hash table. I reuse the one from aggregates to avoid rebuilding a full wrapper and type specializations. Each time a batch is returned by the static or the recursive term of the CTE, the hash table is used to remove already-seen rows before emitting the batch and keeping its rows in memory for the next recursion step. Without this, a query over cyclic data (such as the transitive-closure test below) would rediscover the same rows forever and never terminate.

## What changes are included in this PR?

Reusing `GroupValues` trait implementations inside of `RecursiveQueryExec` to get deduplication working.

## Are these changes tested?

Yes, some sqllogictests have been added, including ones that would lead to infinite recursion if deduplication were disabled.

## Are there any user-facing changes?

No

---------

Co-authored-by: Andrew Lamb <[email protected]>
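To make the deduplication scheme concrete outside of DataFusion's types, here is a minimal, self-contained sketch of the same idea (illustrative only, not code from this PR): a plain `HashMap` plays the role of `GroupValues`, handing out increasing ids as rows are interned, and a row survives only if its id is new with respect to everything interned so far.

```rust
use std::collections::HashMap;

/// Stand-in for `GroupValues` (illustrative): interns rows and hands out
/// increasing ids, so a row's id tells us whether it was seen in an earlier
/// batch or earlier in the current batch. State persists across steps.
fn deduplicate_batch(
    seen: &mut HashMap<Vec<i64>, usize>, // rows interned so far -> group id
    batch: &[Vec<i64>],                  // one "record batch" of rows
) -> Vec<Vec<i64>> {
    // Ids below this threshold were allocated by previous batches.
    let mut max_seen_id = seen.len();
    let mut kept = Vec::new();
    for row in batch {
        let next_id = seen.len();
        let id = *seen.entry(row.clone()).or_insert(next_id);
        // Keep the row iff its id is new and strictly increasing; a repeat
        // within the batch reuses an id below the moving threshold.
        if id >= max_seen_id {
            max_seen_id = id + 1;
            kept.push(row.clone());
        }
    }
    kept
}
```

Starting from an empty map, the batch `[[1,2],[1,2],[2,3]]` comes back as `[[1,2],[2,3]]`; feeding the same batch again returns nothing, which is exactly what lets the recursion reach a fixpoint.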
1 parent 10706ae commit 3ba7350

File tree

4 files changed: +171, -33

datafusion/core/tests/data/recursive_cte/closure.csv

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+start,end
+1,2
+2,3
+2,4
+2,4
+4,1

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 1 addition & 7 deletions
@@ -55,7 +55,7 @@ use datafusion_common::display::ToStringifiedPlan;
 use datafusion_common::file_options::file_type::FileType;
 use datafusion_common::metadata::FieldMetadata;
 use datafusion_common::{
-    exec_err, get_target_functional_dependencies, internal_datafusion_err, not_impl_err,
+    exec_err, get_target_functional_dependencies, internal_datafusion_err,
     plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef,
     NullEquality, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
 };
@@ -179,12 +179,6 @@ impl LogicalPlanBuilder {
         recursive_term: LogicalPlan,
         is_distinct: bool,
     ) -> Result<Self> {
-        // TODO: we need to do a bunch of validation here. Maybe more.
-        if is_distinct {
-            return not_impl_err!(
-                "Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported"
-            );
-        }
         // Ensure that the static term and the recursive term have the same number of fields
         let static_fields_len = self.plan.schema().fields().len();
         let recursive_fields_len = recursive_term.schema().fields().len();

datafusion/physical-plan/src/recursive_query.rs

Lines changed: 88 additions & 14 deletions
@@ -22,13 +22,18 @@ use std::sync::Arc;
 use std::task::{Context, Poll};

 use super::work_table::{ReservedBatches, WorkTable};
+use crate::aggregates::group_values::{new_group_values, GroupValues};
+use crate::aggregates::order::GroupOrdering;
 use crate::execution_plan::{Boundedness, EmissionType};
+use crate::metrics::{
+    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
+};
 use crate::{
-    metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
-    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
+    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
+    SendableRecordBatchStream, Statistics,
 };
-use crate::{DisplayAs, DisplayFormatType, ExecutionPlan};
-
+use arrow::array::{BooleanArray, BooleanBuilder};
+use arrow::compute::filter_record_batch;
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
@@ -195,8 +200,9 @@ impl ExecutionPlan for RecursiveQueryExec {
             Arc::clone(&self.work_table),
             Arc::clone(&self.recursive_term),
             static_stream,
+            self.is_distinct,
             baseline_metrics,
-        )))
+        )?))
     }

     fn metrics(&self) -> Option<MetricsSet> {
@@ -267,8 +273,10 @@ struct RecursiveQueryStream {
     buffer: Vec<RecordBatch>,
     /// Tracks the memory used by the buffer
     reservation: MemoryReservation,
-    // /// Metrics.
-    _baseline_metrics: BaselineMetrics,
+    /// If the distinct flag is set, we use this hash table to remove duplicates from the result and work tables
+    distinct_deduplicator: Option<DistinctDeduplicator>,
+    /// Metrics.
+    baseline_metrics: BaselineMetrics,
 }

 impl RecursiveQueryStream {
@@ -278,12 +286,16 @@ impl RecursiveQueryStream {
         work_table: Arc<WorkTable>,
         recursive_term: Arc<dyn ExecutionPlan>,
         static_stream: SendableRecordBatchStream,
+        is_distinct: bool,
         baseline_metrics: BaselineMetrics,
-    ) -> Self {
+    ) -> Result<Self> {
         let schema = static_stream.schema();
         let reservation =
             MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
-        Self {
+        let distinct_deduplicator = is_distinct
+            .then(|| DistinctDeduplicator::new(Arc::clone(&schema), &task_context))
+            .transpose()?;
+        Ok(Self {
             task_context,
             work_table,
             recursive_term,
@@ -292,21 +304,28 @@ impl RecursiveQueryStream {
             schema,
             buffer: vec![],
             reservation,
-            _baseline_metrics: baseline_metrics,
-        }
+            distinct_deduplicator,
+            baseline_metrics,
+        })
     }

     /// Push a clone of the given batch to the in memory buffer, and then return
     /// a poll with it.
     fn push_batch(
         mut self: std::pin::Pin<&mut Self>,
-        batch: RecordBatch,
+        mut batch: RecordBatch,
     ) -> Poll<Option<Result<RecordBatch>>> {
+        let baseline_metrics = self.baseline_metrics.clone();
+        if let Some(deduplicator) = &mut self.distinct_deduplicator {
+            let _timer_guard = baseline_metrics.elapsed_compute().timer();
+            batch = deduplicator.deduplicate(&batch)?;
+        }
+
         if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
             return Poll::Ready(Some(Err(e)));
         }
-
         self.buffer.push(batch.clone());
+        (&batch).record_output(&baseline_metrics);
         Poll::Ready(Some(Ok(batch)))
     }

@@ -391,7 +410,6 @@ impl Stream for RecursiveQueryStream {
         mut self: std::pin::Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        // TODO: we should use this poll to record some metrics!
         if let Some(static_stream) = &mut self.static_stream {
             // While the static term's stream is available, we'll be forwarding the batches from it (also
             // saving them for the initial iteration of the recursive term).
@@ -428,5 +446,61 @@ impl RecordBatchStream for RecursiveQueryStream {
     }
 }

+/// Deduplicator based on a hash table.
+struct DistinctDeduplicator {
+    /// Grouped rows used for distinct
+    group_values: Box<dyn GroupValues>,
+    reservation: MemoryReservation,
+    intern_output_buffer: Vec<usize>,
+}
+
+impl DistinctDeduplicator {
+    fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
+        let group_values = new_group_values(schema, &GroupOrdering::None)?;
+        let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
+            .register(task_context.memory_pool());
+        Ok(Self {
+            group_values,
+            reservation,
+            intern_output_buffer: Vec::new(),
+        })
+    }
+
+    /// Remove duplicated rows from the given batch, keeping state between batches.
+    ///
+    /// We use a hash table to allocate new group ids for the new rows.
+    /// [`GroupValues`] allocates increasing group ids.
+    /// Hence, if groups (i.e., rows) are new, they have ids >= the table's length
+    /// before interning, and we keep them.
+    /// We also detect duplicates within the batch by enforcing that group ids are increasing.
+    fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
+        let size_before = self.group_values.len();
+        self.intern_output_buffer.reserve(batch.num_rows());
+        self.group_values
+            .intern(batch.columns(), &mut self.intern_output_buffer)?;
+        let mask = new_groups_mask(&self.intern_output_buffer, size_before);
+        self.intern_output_buffer.clear();
+        // We update the reservation to reflect the new size of the hash table.
+        self.reservation.try_resize(self.group_values.size())?;
+        Ok(filter_record_batch(batch, &mask)?)
+    }
+}
+
+/// Return a mask, each element being true if, and only if, the element is greater
+/// than all previous elements and greater than or equal to the provided `max_already_seen_group_id`.
+fn new_groups_mask(
+    values: &[usize],
+    mut max_already_seen_group_id: usize,
+) -> BooleanArray {
+    let mut output = BooleanBuilder::with_capacity(values.len());
+    for value in values {
+        if *value >= max_already_seen_group_id {
+            output.append_value(true);
+            max_already_seen_group_id = *value + 1; // We want to be increasing
+        } else {
+            output.append_value(false);
+        }
+    }
+    output.finish()
+}
+
 #[cfg(test)]
 mod tests {}
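As a sanity check on `new_groups_mask` above, a test along these lines (a sketch of what could go in the currently empty `tests` module, not part of this commit) shows that ids interned by earlier batches and ids repeated within a batch are both masked out:

```rust
#[test]
fn new_groups_mask_drops_seen_and_repeated_ids() {
    // Three groups were interned before this batch, so ids 0..=2 are
    // "already seen". Id 3 is new, id 1 is old, and the second 4 repeats
    // a row first seen earlier in the same batch.
    let mask = new_groups_mask(&[3, 1, 4, 4], 3);
    assert_eq!(mask, BooleanArray::from(vec![true, false, true, false]));
}
```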

datafusion/sqllogictest/test_files/cte.slt

Lines changed: 76 additions & 12 deletions
@@ -58,18 +58,6 @@ WITH RECURSIVE nodes AS (
 statement ok
 set datafusion.execution.enable_recursive_ctes = true;

-
-# DISTINCT UNION is not supported
-query error DataFusion error: This feature is not implemented: Recursive queries with a distinct 'UNION' \(in which the previous iteration's results will be de\-duplicated\) is not supported
-WITH RECURSIVE nodes AS (
-    SELECT 1 as id
-    UNION
-    SELECT id + 1 as id
-    FROM nodes
-    WHERE id < 3
-) SELECT * FROM nodes
-
-
 # trivial recursive CTE works
 query I rowsort
 WITH RECURSIVE nodes AS (
@@ -121,6 +109,22 @@ physical_plan
 07)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
 08)----------WorkTableExec: name=nodes

+# simple deduplicating recursive CTE works
+query I
+WITH RECURSIVE nodes AS (
+    SELECT id from (VALUES (1), (2)) nodes(id)
+    UNION
+    SELECT id + 1 as id
+    FROM nodes
+    WHERE id < 4
+)
+SELECT * FROM nodes
+----
+1
+2
+3
+4
+
 # setup
 statement ok
 CREATE EXTERNAL TABLE balance STORED as CSV LOCATION '../core/tests/data/recursive_cte/balance.csv' OPTIONS ('format.has_header' 'true');
@@ -1044,6 +1048,66 @@ physical_plan
 05)----SortExec: TopK(fetch=1), expr=[v@1 ASC NULLS LAST], preserve_partitioning=[false]
 06)------WorkTableExec: name=r

+# setup
+statement ok
+CREATE EXTERNAL TABLE closure STORED as CSV LOCATION '../core/tests/data/recursive_cte/closure.csv' OPTIONS ('format.has_header' 'true');
+
+# transitive closure with loop
+query II
+WITH RECURSIVE trans AS (
+    SELECT * FROM closure
+    UNION
+    SELECT l.start, r.end
+    FROM trans as l, closure AS r
+    WHERE l.end = r.start
+) SELECT * FROM trans ORDER BY start, end
+----
+1 1
+1 2
+1 3
+1 4
+2 1
+2 2
+2 3
+2 4
+4 1
+4 2
+4 3
+4 4
+
+query TT
+EXPLAIN WITH RECURSIVE trans AS (
+    SELECT * FROM closure
+    UNION
+    SELECT l.start, r.end
+    FROM trans as l, closure AS r
+    WHERE l.end = r.start
+) SELECT * FROM trans
+----
+logical_plan
+01)SubqueryAlias: trans
+02)--RecursiveQuery: is_distinct=true
+03)----Projection: closure.start, closure.end
+04)------TableScan: closure
+05)----Projection: l.start, r.end
+06)------Inner Join: l.end = r.start
+07)--------SubqueryAlias: l
+08)----------TableScan: trans
+09)--------SubqueryAlias: r
+10)----------TableScan: closure
+physical_plan
+01)RecursiveQueryExec: name=trans, is_distinct=true
+02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/recursive_cte/closure.csv]]}, projection=[start, end], file_type=csv, has_header=true
+03)--CoalescePartitionsExec
+04)----CoalesceBatchesExec: target_batch_size=8182
+05)------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(end@1, start@0)], projection=[start@0, end@3]
+06)--------CoalesceBatchesExec: target_batch_size=8182
+07)----------RepartitionExec: partitioning=Hash([end@1], 4), input_partitions=1
+08)------------WorkTableExec: name=trans
+09)--------CoalesceBatchesExec: target_batch_size=8182
+10)----------RepartitionExec: partitioning=Hash([start@0], 4), input_partitions=1
+11)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/recursive_cte/closure.csv]]}, projection=[start, end], file_type=csv, has_header=true
+
 statement count 0
 set datafusion.execution.enable_recursive_ctes = false;
