Skip to content

Commit 36bfee0

Browse files
committed
Use create_async_then_emit in no grouping aggregation
1 parent f2db9a5 commit 36bfee0

File tree

3 files changed

+86
-127
lines changed

3 files changed

+86
-127
lines changed

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,9 @@ impl AggregateExec {
608608
}
609609

610610
// grouping by something else and we need to just materialize all results
611-
Ok(StreamType::GroupedHash(row_hash::aggregate_stream(self, context, partition)?))
611+
Ok(StreamType::GroupedHash(row_hash::aggregate_stream(
612+
self, context, partition,
613+
)?))
612614
}
613615

614616
/// Finds the DataType and SortDirection for this Aggregate, if there is one
@@ -1260,7 +1262,7 @@ pub fn create_accumulators(
12601262
/// final value (mode = Final, FinalPartitioned and Single) or states (mode = Partial)
12611263
pub fn finalize_aggregation(
12621264
accumulators: &mut [AccumulatorItem],
1263-
mode: &AggregateMode,
1265+
mode: AggregateMode,
12641266
) -> Result<Vec<ArrayRef>> {
12651267
match mode {
12661268
AggregateMode::Partial => {

datafusion/physical-plan/src/aggregates/no_grouping.rs

Lines changed: 74 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -17,49 +17,47 @@
1717

1818
//! Aggregate without grouping columns
1919
20+
use super::AggregateExec;
2021
use crate::aggregates::{
2122
aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem,
2223
AggregateMode,
2324
};
25+
use crate::filter::batch_filter;
2426
use crate::metrics::{BaselineMetrics, RecordOutput};
25-
use crate::{RecordBatchStream, SendableRecordBatchStream};
27+
use crate::stream::RecordBatchStreamAdapter;
28+
use crate::SendableRecordBatchStream;
2629
use arrow::datatypes::SchemaRef;
2730
use arrow::record_batch::RecordBatch;
2831
use datafusion_common::Result;
32+
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
2933
use datafusion_execution::TaskContext;
3034
use datafusion_physical_expr::PhysicalExpr;
31-
use futures::stream::BoxStream;
35+
use futures::stream::StreamExt;
36+
use futures::{stream, TryStreamExt};
3237
use std::borrow::Cow;
38+
use std::future::Future;
39+
use std::pin::Pin;
3340
use std::sync::Arc;
34-
use std::task::{Context, Poll};
35-
36-
use crate::filter::batch_filter;
37-
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
38-
use futures::stream::{Stream, StreamExt};
39-
use super::AggregateExec;
41+
use std::task::{ready, Context, Poll};
4042

4143
pub fn aggregate_stream(
4244
agg: &AggregateExec,
4345
context: Arc<TaskContext>,
4446
partition: usize,
4547
) -> Result<SendableRecordBatchStream> {
46-
Ok(Box::pin(AggregateStream::new(agg, context, partition)?))
47-
}
48-
49-
/// stream struct for aggregation without grouping columns
50-
struct AggregateStream {
51-
stream: BoxStream<'static, Result<RecordBatch>>,
52-
schema: SchemaRef,
48+
let aggregate = Aggregate::new(agg, context, partition)?;
49+
50+
// Spawn a task the first time the stream is polled for the aggregation phase.
51+
// This ensures the consumer of the aggregate does not poll unnecessarily
52+
// while the aggregation is ongoing
53+
Ok(crate::stream::create_async_then_emit(
54+
Arc::clone(&agg.schema),
55+
aggregate,
56+
))
5357
}
5458

55-
/// Actual implementation of [`AggregateStream`].
56-
///
57-
/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
58-
/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
59-
/// [`futures::stream::unfold`].
60-
///
61-
/// The latter requires a state object, which is [`AggregateStreamInner`].
62-
struct AggregateStreamInner {
59+
/// The state of the aggregation.
60+
struct Aggregate {
6361
schema: SchemaRef,
6462
mode: AggregateMode,
6563
input: SendableRecordBatchStream,
@@ -68,17 +66,14 @@ struct AggregateStreamInner {
6866
filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
6967
accumulators: Vec<AccumulatorItem>,
7068
reservation: MemoryReservation,
71-
finished: bool,
7269
}
7370

74-
impl AggregateStream {
75-
/// Create a new AggregateStream
76-
pub fn new(
71+
impl Aggregate {
72+
fn new(
7773
agg: &AggregateExec,
7874
context: Arc<TaskContext>,
7975
partition: usize,
8076
) -> Result<Self> {
81-
let agg_schema = Arc::clone(&agg.schema);
8277
let agg_filter_expr = agg.filter_expr.clone();
8378

8479
let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
@@ -98,7 +93,7 @@ impl AggregateStream {
9893
let reservation = MemoryConsumer::new(format!("AggregateStream[{partition}]"))
9994
.register(context.memory_pool());
10095

101-
let inner = AggregateStreamInner {
96+
Ok(Self {
10297
schema: Arc::clone(&agg.schema),
10398
mode: agg.mode,
10499
input,
@@ -107,91 +102,55 @@ impl AggregateStream {
107102
filter_expressions,
108103
accumulators,
109104
reservation,
110-
finished: false,
111-
};
112-
let stream = futures::stream::unfold(inner, |mut this| async move {
113-
if this.finished {
114-
return None;
115-
}
116-
117-
let elapsed_compute = this.baseline_metrics.elapsed_compute();
118-
119-
loop {
120-
let result = match this.input.next().await {
121-
Some(Ok(batch)) => {
122-
let timer = elapsed_compute.timer();
123-
let result = aggregate_batch(
124-
&this.mode,
125-
batch,
126-
&mut this.accumulators,
127-
&this.aggregate_expressions,
128-
&this.filter_expressions,
129-
);
130-
131-
timer.done();
132-
133-
// allocate memory
134-
// This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
135-
// overshooting a bit. Also this means we either store the whole record batch or not.
136-
match result
137-
.and_then(|allocated| this.reservation.try_grow(allocated))
138-
{
139-
Ok(_) => continue,
140-
Err(e) => Err(e),
141-
}
142-
}
143-
Some(Err(e)) => Err(e),
144-
None => {
145-
this.finished = true;
146-
let timer = this.baseline_metrics.elapsed_compute().timer();
147-
let result =
148-
finalize_aggregation(&mut this.accumulators, &this.mode)
149-
.and_then(|columns| {
150-
RecordBatch::try_new(
151-
Arc::clone(&this.schema),
152-
columns,
153-
)
154-
.map_err(Into::into)
155-
})
156-
.record_output(&this.baseline_metrics);
157-
158-
timer.done();
159-
160-
result
161-
}
162-
};
163-
164-
this.finished = true;
165-
return Some((result, this));
166-
}
167-
});
168-
169-
// seems like some consumers call this stream even after it returned `None`, so let's fuse the stream.
170-
let stream = stream.fuse();
171-
let stream = Box::pin(stream);
172-
173-
Ok(Self {
174-
schema: agg_schema,
175-
stream,
176105
})
177106
}
178107
}
179108

180-
impl Stream for AggregateStream {
181-
type Item = Result<RecordBatch>;
109+
impl Future for Aggregate {
110+
type Output = Result<SendableRecordBatchStream>;
182111

183-
fn poll_next(
184-
mut self: std::pin::Pin<&mut Self>,
185-
cx: &mut Context<'_>,
186-
) -> Poll<Option<Self::Item>> {
187-
let this = &mut *self;
188-
this.stream.poll_next_unpin(cx)
189-
}
190-
}
112+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
113+
let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
114+
115+
loop {
116+
match ready!(self.input.poll_next_unpin(cx)) {
117+
Some(Ok(batch)) => {
118+
let timer = elapsed_compute.timer();
119+
120+
let result = aggregate_batch(&mut self, &batch);
121+
122+
timer.done();
191123

192-
impl RecordBatchStream for AggregateStream {
193-
fn schema(&self) -> SchemaRef {
194-
Arc::clone(&self.schema)
124+
// allocate memory
125+
// This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
126+
// overshooting a bit. Also this means we either store the whole record batch or not.
127+
match result
128+
.and_then(|allocated| self.reservation.try_grow(allocated))
129+
{
130+
Ok(_) => continue,
131+
Err(e) => return Poll::Ready(Err(e)),
132+
}
133+
}
134+
Some(Err(e)) => return Poll::Ready(Err(e)),
135+
None => {
136+
let timer = elapsed_compute.timer();
137+
let mode = self.mode;
138+
let result = finalize_aggregation(&mut self.accumulators, mode)
139+
.and_then(|columns| {
140+
RecordBatch::try_new(Arc::clone(&self.schema), columns)
141+
.map_err(Into::into)
142+
})
143+
.record_output(&self.baseline_metrics);
144+
145+
timer.done();
146+
147+
return Poll::Ready(Ok(Box::pin(RecordBatchStreamAdapter::new(
148+
Arc::clone(&self.schema),
149+
stream::iter(vec![result]),
150+
))));
151+
}
152+
};
153+
}
195154
}
196155
}
197156

@@ -200,13 +159,7 @@ impl RecordBatchStream for AggregateStream {
200159
/// If successful, this returns the additional number of bytes that were allocated during this process.
201160
///
202161
/// TODO: Make this a member function
203-
fn aggregate_batch(
204-
mode: &AggregateMode,
205-
batch: RecordBatch,
206-
accumulators: &mut [AccumulatorItem],
207-
expressions: &[Vec<Arc<dyn PhysicalExpr>>],
208-
filters: &[Option<Arc<dyn PhysicalExpr>>],
209-
) -> Result<usize> {
162+
fn aggregate_batch(agg: &mut Aggregate, batch: &RecordBatch) -> Result<usize> {
210163
let mut allocated = 0usize;
211164

212165
// 1.1 iterate accumulators and respective expressions together
@@ -215,15 +168,15 @@ fn aggregate_batch(
215168
// 1.4 update / merge accumulators with the expressions' values
216169

217170
// 1.1
218-
accumulators
171+
agg.accumulators
219172
.iter_mut()
220-
.zip(expressions)
221-
.zip(filters)
173+
.zip(&agg.aggregate_expressions)
174+
.zip(&agg.filter_expressions)
222175
.try_for_each(|((accum, expr), filter)| {
223176
// 1.2
224177
let batch = match filter {
225-
Some(filter) => Cow::Owned(batch_filter(&batch, filter)?),
226-
None => Cow::Borrowed(&batch),
178+
Some(filter) => Cow::Owned(batch_filter(batch, filter)?),
179+
None => Cow::Borrowed(batch),
227180
};
228181

229182
let n_rows = batch.num_rows();
@@ -236,7 +189,7 @@ fn aggregate_batch(
236189

237190
// 1.4
238191
let size_pre = accum.size();
239-
let res = match mode {
192+
let res = match agg.mode {
240193
AggregateMode::Partial
241194
| AggregateMode::Single
242195
| AggregateMode::SinglePartitioned => accum.update_batch(&values),

datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,14 @@ use futures::ready;
5454
use futures::stream::{Stream, StreamExt};
5555
use log::debug;
5656

57-
pub fn aggregate_stream(agg: &AggregateExec,
58-
context: Arc<TaskContext>,
59-
partition: usize,) -> Result<SendableRecordBatchStream> {
60-
Ok(Box::pin(GroupedHashAggregateStream::new(agg, context, partition)?))
57+
pub fn aggregate_stream(
58+
agg: &AggregateExec,
59+
context: Arc<TaskContext>,
60+
partition: usize,
61+
) -> Result<SendableRecordBatchStream> {
62+
Ok(Box::pin(GroupedHashAggregateStream::new(
63+
agg, context, partition,
64+
)?))
6165
}
6266

6367
#[derive(Debug, Clone)]

0 commit comments

Comments
 (0)