-
Notifications
You must be signed in to change notification settings - Fork 1.8k
feat: Support recursive queries with a distinct 'UNION' #18254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
8539670
cd0701a
6cf4151
48e8e33
af0f140
0af5648
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| start,end | ||
| 1,2 | ||
| 2,3 | ||
| 2,4 | ||
| 2,4 | ||
| 4,1 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,13 +22,18 @@ use std::sync::Arc; | |
| use std::task::{Context, Poll}; | ||
|
|
||
| use super::work_table::{ReservedBatches, WorkTable}; | ||
| use crate::aggregates::group_values::{new_group_values, GroupValues}; | ||
| use crate::aggregates::order::GroupOrdering; | ||
| use crate::execution_plan::{Boundedness, EmissionType}; | ||
| use crate::metrics::{ | ||
| BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, | ||
| }; | ||
| use crate::{ | ||
| metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, | ||
| PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, | ||
| DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, | ||
| SendableRecordBatchStream, Statistics, | ||
| }; | ||
| use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; | ||
|
|
||
| use arrow::array::{BooleanArray, BooleanBuilder}; | ||
| use arrow::compute::filter_record_batch; | ||
| use arrow::datatypes::SchemaRef; | ||
| use arrow::record_batch::RecordBatch; | ||
| use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; | ||
|
|
@@ -195,8 +200,9 @@ impl ExecutionPlan for RecursiveQueryExec { | |
| Arc::clone(&self.work_table), | ||
| Arc::clone(&self.recursive_term), | ||
| static_stream, | ||
| self.is_distinct, | ||
| baseline_metrics, | ||
| ))) | ||
| )?)) | ||
| } | ||
|
|
||
| fn metrics(&self) -> Option<MetricsSet> { | ||
|
|
@@ -267,8 +273,10 @@ struct RecursiveQueryStream { | |
| buffer: Vec<RecordBatch>, | ||
| /// Tracks the memory used by the buffer | ||
| reservation: MemoryReservation, | ||
| // /// Metrics. | ||
| _baseline_metrics: BaselineMetrics, | ||
| /// If the distinct flag is set, then we use this hash table to remove duplicates from result and work tables | ||
| distinct_deduplicator: Option<DistinctDeduplicator>, | ||
| /// Metrics. | ||
| baseline_metrics: BaselineMetrics, | ||
| } | ||
|
|
||
| impl RecursiveQueryStream { | ||
|
|
@@ -278,12 +286,16 @@ impl RecursiveQueryStream { | |
| work_table: Arc<WorkTable>, | ||
| recursive_term: Arc<dyn ExecutionPlan>, | ||
| static_stream: SendableRecordBatchStream, | ||
| is_distinct: bool, | ||
| baseline_metrics: BaselineMetrics, | ||
| ) -> Self { | ||
| ) -> Result<Self> { | ||
| let schema = static_stream.schema(); | ||
| let reservation = | ||
| MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool()); | ||
| Self { | ||
| let distinct_deduplicator = is_distinct | ||
| .then(|| DistinctDeduplicator::new(Arc::clone(&schema), &task_context)) | ||
| .transpose()?; | ||
| Ok(Self { | ||
| task_context, | ||
| work_table, | ||
| recursive_term, | ||
|
|
@@ -292,21 +304,28 @@ impl RecursiveQueryStream { | |
| schema, | ||
| buffer: vec![], | ||
| reservation, | ||
| _baseline_metrics: baseline_metrics, | ||
| } | ||
| distinct_deduplicator, | ||
| baseline_metrics, | ||
| }) | ||
| } | ||
|
|
||
| /// Push a clone of the given batch to the in memory buffer, and then return | ||
| /// a poll with it. | ||
| fn push_batch( | ||
| mut self: std::pin::Pin<&mut Self>, | ||
| batch: RecordBatch, | ||
| mut batch: RecordBatch, | ||
| ) -> Poll<Option<Result<RecordBatch>>> { | ||
| let baseline_metrics = self.baseline_metrics.clone(); | ||
| if let Some(deduplicator) = &mut self.distinct_deduplicator { | ||
| let _timer_guard = baseline_metrics.elapsed_compute().timer(); | ||
| batch = deduplicator.deduplicate(&batch)?; | ||
| } | ||
|
|
||
| if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) { | ||
| return Poll::Ready(Some(Err(e))); | ||
| } | ||
|
|
||
| self.buffer.push(batch.clone()); | ||
| (&batch).record_output(&baseline_metrics); | ||
| Poll::Ready(Some(Ok(batch))) | ||
| } | ||
|
|
||
|
|
@@ -391,7 +410,6 @@ impl Stream for RecursiveQueryStream { | |
| mut self: std::pin::Pin<&mut Self>, | ||
| cx: &mut Context<'_>, | ||
| ) -> Poll<Option<Self::Item>> { | ||
| // TODO: we should use this poll to record some metrics! | ||
| if let Some(static_stream) = &mut self.static_stream { | ||
| // While the static term's stream is available, we'll be forwarding the batches from it (also | ||
| // saving them for the initial iteration of the recursive term). | ||
|
|
@@ -428,5 +446,58 @@ impl RecordBatchStream for RecursiveQueryStream { | |
| } | ||
| } | ||
|
|
||
| /// Deduplicator based on a hash table. | ||
| struct DistinctDeduplicator { | ||
| /// Grouped rows used for distinct | ||
| group_values: Box<dyn GroupValues>, | ||
| reservation: MemoryReservation, | ||
| intern_output_buffer: Vec<usize>, | ||
| } | ||
|
|
||
| impl DistinctDeduplicator { | ||
| fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> { | ||
| let group_values = new_group_values(schema, &GroupOrdering::None)?; | ||
| let reservation = MemoryConsumer::new("RecursiveQueryHashTable") | ||
| .register(task_context.memory_pool()); | ||
| Ok(Self { | ||
| group_values, | ||
| reservation, | ||
| intern_output_buffer: Vec::new(), | ||
| }) | ||
| } | ||
|
|
||
| /// Remove duplicated rows from the given batch, keeping a state between batches. | ||
| /// | ||
| /// We use a hash table to allocate new group ids for the new rows. | ||
| /// [`GroupValues`] allocate increasing group ids. | ||
| /// Hence, if groups (i.e., rows) are now, then they have ids >= length before interning, we keep them. | ||
Tpt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| /// We also detect duplicates by enforcing that group ids are increasing. | ||
| fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> { | ||
| let size_before = self.group_values.len(); | ||
| self.intern_output_buffer.reserve(batch.num_rows()); | ||
| self.group_values | ||
| .intern(batch.columns(), &mut self.intern_output_buffer)?; | ||
| let mask = are_increasing_mask(&self.intern_output_buffer, size_before); | ||
| self.intern_output_buffer.clear(); | ||
| // We update the reservation to reflect the new size of the hash table. | ||
| self.reservation.try_resize(self.group_values.size())?; | ||
| Ok(filter_record_batch(batch, &mask)?) | ||
| } | ||
| } | ||
|
|
||
| /// Return a mask, each element being true if, and only if, the element is greater than all previous elements and greater or equal than the provided min_value | ||
| fn are_increasing_mask(values: &[usize], mut min_value: usize) -> BooleanArray { | ||
|
||
| let mut output = BooleanBuilder::with_capacity(values.len()); | ||
| for value in values { | ||
| if *value >= min_value { | ||
| output.append_value(true); | ||
| min_value = *value + 1; // We want to be increasing | ||
| } else { | ||
| output.append_value(false); | ||
| } | ||
| } | ||
| output.finish() | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests {} | ||
Uh oh!
There was an error while loading. Please reload this page.