1717
1818//! Aggregate without grouping columns
1919
20+ use super :: AggregateExec ;
2021use crate :: aggregates:: {
2122 aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem ,
2223 AggregateMode ,
2324} ;
25+ use crate :: filter:: batch_filter;
2426use crate :: metrics:: { BaselineMetrics , RecordOutput } ;
25- use crate :: { RecordBatchStream , SendableRecordBatchStream } ;
27+ use crate :: stream:: RecordBatchStreamAdapter ;
28+ use crate :: SendableRecordBatchStream ;
2629use arrow:: datatypes:: SchemaRef ;
2730use arrow:: record_batch:: RecordBatch ;
2831use datafusion_common:: Result ;
32+ use datafusion_execution:: memory_pool:: { MemoryConsumer , MemoryReservation } ;
2933use datafusion_execution:: TaskContext ;
3034use datafusion_physical_expr:: PhysicalExpr ;
31- use futures:: stream:: BoxStream ;
35+ use futures:: stream:: StreamExt ;
36+ use futures:: { stream, TryStreamExt } ;
3237use std:: borrow:: Cow ;
38+ use std:: future:: Future ;
39+ use std:: pin:: Pin ;
3340use std:: sync:: Arc ;
34- use std:: task:: { Context , Poll } ;
35-
36- use crate :: filter:: batch_filter;
37- use datafusion_execution:: memory_pool:: { MemoryConsumer , MemoryReservation } ;
38- use futures:: stream:: { Stream , StreamExt } ;
39- use super :: AggregateExec ;
41+ use std:: task:: { ready, Context , Poll } ;
4042
/// Builds the output stream for an aggregation without grouping columns.
///
/// Constructs the [`Aggregate`] state for `partition` (propagating any error
/// returned by `Aggregate::new`) and passes it, together with the aggregate's
/// schema, to `crate::stream::create_async_then_emit`.
///
/// NOTE(review): `create_async_then_emit` lives in another module; presumably
/// it awaits the `Aggregate` future and then forwards the stream that future
/// yields — confirm against its definition.
4143pub fn aggregate_stream (
4244 agg : & AggregateExec ,
4345 context : Arc < TaskContext > ,
4446 partition : usize ,
4547) -> Result < SendableRecordBatchStream > {
46- Ok ( Box :: pin ( AggregateStream :: new ( agg, context, partition) ?) )
47- }
48-
49- /// stream struct for aggregation without grouping columns
50- struct AggregateStream {
51- stream : BoxStream < ' static , Result < RecordBatch > > ,
52- schema : SchemaRef ,
48+ let aggregate = Aggregate :: new ( agg, context, partition) ?;
49+
50+ // Spawn a task the first time the stream is polled for the aggregation phase.
51+ // This ensures the consumer of the aggregate does not poll unnecessarily
52+ // while the aggregation is ongoing.
53+ Ok ( crate :: stream:: create_async_then_emit (
54+ Arc :: clone ( & agg. schema ) ,
55+ aggregate,
56+ ) )
5357}
5458
55- /// Actual implementation of [`AggregateStream`].
56- ///
57- /// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
58- /// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
59- /// [`futures::stream::unfold`].
60- ///
61- /// The latter requires a state object, which is [`AggregateStreamInner`].
62- struct AggregateStreamInner {
59+ /// The state of the aggregation.
///
/// This type is polled as a [`Future`]: it reads batches from `input`, applies
/// each one to `accumulators`, and finally produces a single output batch that
/// uses `schema`.
60+ struct Aggregate {
    // Schema used to build the final output batch and the returned stream.
6361 schema : SchemaRef ,
    // Selects how a batch is applied to an accumulator (e.g. `update_batch` for
    // `Partial`, `Single` and `SinglePartitioned`); also passed to finalization.
6462 mode : AggregateMode ,
    // Stream of input batches that is drained while the future is polled.
6563 input : SendableRecordBatchStream ,
@@ -68,17 +66,14 @@ struct AggregateStreamInner {
    // Optional filter per accumulator; when present, the batch is filtered
    // before that accumulator sees it.
6866 filter_expressions : Vec < Option < Arc < dyn PhysicalExpr > > > ,
    // Accumulators, paired positionally with the aggregate and filter expressions.
6967 accumulators : Vec < AccumulatorItem > ,
    // Grown after every batch by the number of bytes the accumulators reported
    // as newly allocated.
7068 reservation : MemoryReservation ,
71- finished : bool ,
7269}
7370
74- impl AggregateStream {
75- /// Create a new AggregateStream
76- pub fn new (
71+ impl Aggregate {
    /// Builds the aggregation state for `partition` of `agg`.
    ///
    /// Visible steps: clones the filter expressions of `agg`, creates the
    /// baseline metrics for `partition`, and registers a memory consumer with
    /// the task's memory pool.
    ///
    /// NOTE(review): how `input`, the aggregate expressions and the
    /// accumulators are created is not shown in this view — confirm which of
    /// those steps can fail.
72+ fn new (
7773 agg : & AggregateExec ,
7874 context : Arc < TaskContext > ,
7975 partition : usize ,
8076 ) -> Result < Self > {
81- let agg_schema = Arc :: clone ( & agg. schema ) ;
8277 let agg_filter_expr = agg. filter_expr . clone ( ) ;
8378
8479 let baseline_metrics = BaselineMetrics :: new ( & agg. metrics , partition) ;
@@ -98,7 +93,7 @@ impl AggregateStream {
        // The consumer name embeds the partition number; the resulting
        // reservation is later grown as batches are aggregated.
9893 let reservation = MemoryConsumer :: new ( format ! ( "AggregateStream[{partition}]" ) )
9994 . register ( context. memory_pool ( ) ) ;
10095
101- let inner = AggregateStreamInner {
96+ Ok ( Self {
10297 schema : Arc :: clone ( & agg. schema ) ,
10398 mode : agg. mode ,
10499 input,
@@ -107,91 +102,55 @@ impl AggregateStream {
107102 filter_expressions,
108103 accumulators,
109104 reservation,
110- finished : false ,
111- } ;
112- let stream = futures:: stream:: unfold ( inner, |mut this| async move {
113- if this. finished {
114- return None ;
115- }
116-
117- let elapsed_compute = this. baseline_metrics . elapsed_compute ( ) ;
118-
119- loop {
120- let result = match this. input . next ( ) . await {
121- Some ( Ok ( batch) ) => {
122- let timer = elapsed_compute. timer ( ) ;
123- let result = aggregate_batch (
124- & this. mode ,
125- batch,
126- & mut this. accumulators ,
127- & this. aggregate_expressions ,
128- & this. filter_expressions ,
129- ) ;
130-
131- timer. done ( ) ;
132-
133- // allocate memory
134- // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
135- // overshooting a bit. Also this means we either store the whole record batch or not.
136- match result
137- . and_then ( |allocated| this. reservation . try_grow ( allocated) )
138- {
139- Ok ( _) => continue ,
140- Err ( e) => Err ( e) ,
141- }
142- }
143- Some ( Err ( e) ) => Err ( e) ,
144- None => {
145- this. finished = true ;
146- let timer = this. baseline_metrics . elapsed_compute ( ) . timer ( ) ;
147- let result =
148- finalize_aggregation ( & mut this. accumulators , & this. mode )
149- . and_then ( |columns| {
150- RecordBatch :: try_new (
151- Arc :: clone ( & this. schema ) ,
152- columns,
153- )
154- . map_err ( Into :: into)
155- } )
156- . record_output ( & this. baseline_metrics ) ;
157-
158- timer. done ( ) ;
159-
160- result
161- }
162- } ;
163-
164- this. finished = true ;
165- return Some ( ( result, this) ) ;
166- }
167- } ) ;
168-
169- // seems like some consumers call this stream even after it returned `None`, so let's fuse the stream.
170- let stream = stream. fuse ( ) ;
171- let stream = Box :: pin ( stream) ;
172-
173- Ok ( Self {
174- schema : agg_schema,
175- stream,
176105 } )
177106 }
178107}
179108
180- impl Stream for AggregateStream {
181- type Item = Result < RecordBatch > ;
/// Runs the whole aggregation when polled.
///
/// The future resolves once the input stream is exhausted (or an error
/// occurs). On success it yields a stream containing exactly one item: the
/// result of finalizing the accumulators.
109+ impl Future for Aggregate {
110+ type Output = Result < SendableRecordBatchStream > ;
182111
183- fn poll_next (
184- mut self : std:: pin:: Pin < & mut Self > ,
185- cx : & mut Context < ' _ > ,
186- ) -> Poll < Option < Self :: Item > > {
187- let this = & mut * self ;
188- this. stream . poll_next_unpin ( cx)
189- }
190- }
    /// Consumes input batches until the input is pending, fails, or ends.
    ///
    /// Returns `Poll::Pending` whenever the input stream is pending (via
    /// `ready!`), so the loop resumes from the next unread batch on the
    /// following poll.
112+ fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
113+ let elapsed_compute = self . baseline_metrics . elapsed_compute ( ) . clone ( ) ;
114+
115+ loop {
116+ match ready ! ( self . input. poll_next_unpin( cx) ) {
117+ Some ( Ok ( batch) ) => {
                    // Only the accumulator update itself is timed as compute.
118+ let timer = elapsed_compute. timer ( ) ;
119+
120+ let result = aggregate_batch ( & mut self , & batch) ;
121+
122+ timer. done ( ) ;
191123
192- impl RecordBatchStream for AggregateStream {
193- fn schema ( & self ) -> SchemaRef {
194- Arc :: clone ( & self . schema )
124+ // allocate memory
125+ // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
126+ // overshooting a bit. Also this means we either store the whole record batch or not.
127+ match result
128+ . and_then ( |allocated| self . reservation . try_grow ( allocated) )
129+ {
130+ Ok ( _) => continue ,
                        // A failed update or a refused reservation ends the future with that error.
131+ Err ( e) => return Poll :: Ready ( Err ( e) ) ,
132+ }
133+ }
                // An error from the input stream ends the future with that error.
134+ Some ( Err ( e) ) => return Poll :: Ready ( Err ( e) ) ,
                // Input exhausted: turn the accumulators into the single output batch.
135+ None => {
136+ let timer = elapsed_compute. timer ( ) ;
137+ let mode = self . mode ;
138+ let result = finalize_aggregation ( & mut self . accumulators , mode)
139+ . and_then ( |columns| {
140+ RecordBatch :: try_new ( Arc :: clone ( & self . schema ) , columns)
141+ . map_err ( Into :: into)
142+ } )
143+ . record_output ( & self . baseline_metrics ) ;
144+
145+ timer. done ( ) ;
146+
                    // `result` (success or error) becomes the only item of the returned
                    // stream; the future itself resolves to `Ok` in both cases.
147+ return Poll :: Ready ( Ok ( Box :: pin ( RecordBatchStreamAdapter :: new (
148+ Arc :: clone ( & self . schema ) ,
149+ stream:: iter ( vec ! [ result] ) ,
150+ ) ) ) ) ;
151+ }
152+ } ;
153+ }
195154 }
196155}
197156
@@ -200,13 +159,7 @@ impl RecordBatchStream for AggregateStream {
200159/// If successful, this returns the additional number of bytes that were allocated during this process.
201160///
202161/// TODO: Make this a member function
203- fn aggregate_batch (
204- mode : & AggregateMode ,
205- batch : RecordBatch ,
206- accumulators : & mut [ AccumulatorItem ] ,
207- expressions : & [ Vec < Arc < dyn PhysicalExpr > > ] ,
208- filters : & [ Option < Arc < dyn PhysicalExpr > > ] ,
209- ) -> Result < usize > {
162+ fn aggregate_batch ( agg : & mut Aggregate , batch : & RecordBatch ) -> Result < usize > {
210163 let mut allocated = 0usize ;
211164
212165 // 1.1 iterate accumulators and respective expressions together
@@ -215,15 +168,15 @@ fn aggregate_batch(
215168 // 1.4 update / merge accumulators with the expressions' values
216169
217170 // 1.1
218- accumulators
171+ agg . accumulators
219172 . iter_mut ( )
220- . zip ( expressions )
221- . zip ( filters )
173+ . zip ( & agg . aggregate_expressions )
174+ . zip ( & agg . filter_expressions )
222175 . try_for_each ( |( ( accum, expr) , filter) | {
223176 // 1.2
224177 let batch = match filter {
225- Some ( filter) => Cow :: Owned ( batch_filter ( & batch, filter) ?) ,
226- None => Cow :: Borrowed ( & batch) ,
178+ Some ( filter) => Cow :: Owned ( batch_filter ( batch, filter) ?) ,
179+ None => Cow :: Borrowed ( batch) ,
227180 } ;
228181
229182 let n_rows = batch. num_rows ( ) ;
@@ -236,7 +189,7 @@ fn aggregate_batch(
236189
237190 // 1.4
238191 let size_pre = accum. size ( ) ;
239- let res = match mode {
192+ let res = match agg . mode {
240193 AggregateMode :: Partial
241194 | AggregateMode :: Single
242195 | AggregateMode :: SinglePartitioned => accum. update_batch ( & values) ,
0 commit comments