diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs
index 1e1c5d5424b08..3d157249d4c7c 100644
--- a/datafusion/common/src/config.rs
+++ b/datafusion/common/src/config.rs
@@ -325,11 +325,7 @@ config_namespace! {
         /// Aggregation ratio (number of distinct groups / number of input rows)
         /// threshold for skipping partial aggregation. If the value is greater
         /// then partial aggregation will skip aggregation for further input
-        pub skip_partial_aggregation_probe_ratio_threshold: f64, default = 0.8
-
-        /// Number of input rows partial aggregation partition should process, before
-        /// aggregation ratio check and trying to switch to skipping aggregation mode
-        pub skip_partial_aggregation_probe_rows_threshold: usize, default = 100_000
+        pub skip_partial_aggregation_probe_ratio_threshold: f64, default = 0.1
 
         /// Should DataFusion use row number estimates at the input to decide
         /// whether increasing parallelism is beneficial or not. By default,
diff --git a/datafusion/core/tests/data/aggregate_mixed_type.csv b/datafusion/core/tests/data/aggregate_mixed_type.csv
new file mode 100644
index 0000000000000..5481158b93fce
--- /dev/null
+++ b/datafusion/core/tests/data/aggregate_mixed_type.csv
@@ -0,0 +1,17 @@
+c1,c2
+1,'a'
+2,'b'
+3,'c'
+4,'d'
+1,'a'
+2,'b'
+3,'c'
+4,'d'
+4,'d'
+3,'c'
+3,'c'
+5,'e'
+6,'f'
+7,'g'
+8,'a'
+9,'b'
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 62e9be63983cb..5e793d898d7df 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -44,6 +44,8 @@ use rand::rngs::StdRng;
 use rand::{Rng, SeedableRng};
 use tokio::task::JoinSet;
 
+const BATCH_SIZE: usize = 50;
+
 /// Tests that streaming aggregate and batch (non streaming) aggregate produce
 /// same results
 #[tokio::test(flavor = "multi_thread")]
@@ -60,13 +62,14 @@ async fn streaming_aggregate_test() {
     ];
     let n = 300;
     let distincts = vec![10, 20];
+    let len = 1000;
     for distinct in distincts {
         let mut join_set = JoinSet::new();
         for i in 0..n {
             let test_idx = i % test_cases.len();
             let group_by_columns = test_cases[test_idx].clone();
             join_set.spawn(run_aggregate_test(
-                make_staggered_batches::<true>(1000, distinct, i as u64),
+                make_staggered_batches::<true>(len, distinct, i as u64),
                 group_by_columns,
             ));
         }
@@ -77,13 +80,19 @@
         }
     }
 }
 
+fn new_ctx() -> SessionContext {
+    let session_config = SessionConfig::new()
+        .with_batch_size(BATCH_SIZE)
+        // Ensure most of the fuzzing test cases don't skip partial aggregation
+        .with_skip_partial_aggregation_probe_ratio_threshold(1.0);
+    SessionContext::new_with_config(session_config)
+}
+
 /// Perform batch and streaming aggregation with same input
 /// and verify outputs of `AggregateExec` with pipeline breaking stream `GroupedHashAggregateStream`
 /// and non-pipeline breaking stream `BoundedAggregateStream` produces same result.
 async fn run_aggregate_test(input1: Vec<RecordBatch>, group_by_columns: Vec<&str>) {
     let schema = input1[0].schema();
-    let session_config = SessionConfig::new().with_batch_size(50);
-    let ctx = SessionContext::new_with_config(session_config);
     let mut sort_keys = vec![];
     for ordering_col in ["a", "b", "c"] {
         sort_keys.push(PhysicalSortExpr {
@@ -141,7 +150,7 @@
         .unwrap(),
     ) as Arc<dyn ExecutionPlan>;
 
-    let task_ctx = ctx.task_ctx();
+    let task_ctx = new_ctx().task_ctx();
     let collected_usual = collect(aggregate_exec_usual.clone(), task_ctx.clone())
         .await
         .unwrap();
@@ -149,6 +158,7 @@
     let collected_running = collect(aggregate_exec_running.clone(), task_ctx.clone())
        .await
        .unwrap();
+    assert!(collected_running.len() > 2);
     // Running should produce more chunk than the usual AggregateExec.
     // Otherwise it means that we cannot generate result in running mode.
 
@@ -232,7 +242,7 @@ pub(crate) fn make_staggered_batches<const STREAM: bool>(
     let mut batches = vec![];
     if STREAM {
         while remainder.num_rows() > 0 {
-            let batch_size = rng.gen_range(0..50);
+            let batch_size = rng.gen_range(0..BATCH_SIZE);
             if remainder.num_rows() < batch_size {
                 break;
             }
@@ -287,8 +297,7 @@ async fn group_by_string_test(
     let expected = compute_counts(&input, column_name);
     let schema = input[0].schema();
 
-    let session_config = SessionConfig::new().with_batch_size(50);
-    let ctx = SessionContext::new_with_config(session_config);
+    let ctx = new_ctx();
 
     let provider = MemTable::try_new(schema.clone(), vec![input]).unwrap();
     let provider = if sorted {
diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs
index cede75d21ca47..d1c7eecb217fe 100644
--- a/datafusion/execution/src/config.rs
+++ b/datafusion/execution/src/config.rs
@@ -388,6 +388,17 @@ impl SessionConfig {
         self
     }
 
+    /// Set the ratio threshold used to decide whether to skip partial aggregation
+    pub fn with_skip_partial_aggregation_probe_ratio_threshold(
+        mut self,
+        threshold: f64,
+    ) -> Self {
+        self.options
+            .execution
+            .skip_partial_aggregation_probe_ratio_threshold = threshold;
+        self
+    }
+
     /// Returns true if record batches will be examined between each operator
     /// and small batches will be coalesced into larger batches.
     pub fn coalesce_batches(&self) -> bool {
diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs
index f789af8b8a024..46461871e72c0 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs
@@ -45,6 +45,7 @@ impl<O: OffsetSizeTrait> GroupValues for GroupValuesByes<O> {
         &mut self,
         cols: &[ArrayRef],
         groups: &mut Vec<usize>,
+        _batch_hashes: &[u64],
     ) -> datafusion_common::Result<()> {
         assert_eq!(cols.len(), 1);
 
@@ -108,7 +109,7 @@ impl<O: OffsetSizeTrait> GroupValues for GroupValuesByes<O> {
         self.num_groups = 0;
 
         let mut group_indexes = vec![];
-        self.intern(&[remaining_group_values], &mut group_indexes)?;
+        self.intern(&[remaining_group_values], &mut group_indexes, &[])?;
 
         // Verify that the group indexes were assigned in the correct order
         assert_eq!(0, group_indexes[0]);
diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs
index 1a0cb90a16d47..4f9d0abf79927 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs
@@ -46,6 +46,7 @@ impl GroupValues for GroupValuesBytesView {
         &mut self,
         cols: &[ArrayRef],
         groups: &mut Vec<usize>,
+        _batch_hashes: &[u64],
     ) -> datafusion_common::Result<()> {
         assert_eq!(cols.len(), 1);
 
@@ -109,7 +110,7 @@ impl GroupValues for GroupValuesBytesView {
         self.num_groups = 0;
 
         let mut group_indexes = vec![];
-        self.intern(&[remaining_group_values], &mut group_indexes)?;
+        self.intern(&[remaining_group_values], &mut group_indexes, &[])?;
 
         // Verify that the group indexes were assigned in the correct order
         assert_eq!(0, group_indexes[0]);
diff --git a/datafusion/physical-plan/src/aggregates/group_values/column.rs b/datafusion/physical-plan/src/aggregates/group_values/column.rs
index 28f35b2bded2e..23fc1fdf70141 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/column.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/column.rs
@@ -19,7 +19,6 @@ use crate::aggregates::group_values::group_column::{
     ByteGroupValueBuilder, GroupColumn, PrimitiveGroupValueBuilder,
 };
 use crate::aggregates::group_values::GroupValues;
-use ahash::RandomState;
 use arrow::compute::cast;
 use arrow::datatypes::{
     Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
@@ -28,7 +27,6 @@ use arrow::datatypes::{
 use arrow::record_batch::RecordBatch;
 use arrow_array::{Array, ArrayRef};
 use arrow_schema::{DataType, Schema, SchemaRef};
-use datafusion_common::hash_utils::create_hashes;
 use datafusion_common::{not_impl_err, DataFusionError, Result};
 use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
 use datafusion_expr::EmitTo;
@@ -68,9 +66,6 @@ pub struct GroupValuesColumn {
 
     /// reused buffer to store hashes
     hashes_buffer: Vec<u64>,
-
-    /// Random state for creating hashes
-    random_state: RandomState,
 }
 
 impl GroupValuesColumn {
@@ -83,7 +78,6 @@
             map_size: 0,
             group_values: vec![],
             hashes_buffer: Default::default(),
-            random_state: Default::default(),
         })
     }
 
@@ -143,9 +137,12 @@ macro_rules! instantiate_primitive {
 }
 
 impl GroupValues for GroupValuesColumn {
-    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
-        let n_rows = cols[0].len();
-
+    fn intern(
+        &mut self,
+        cols: &[ArrayRef],
+        groups: &mut Vec<usize>,
+        batch_hashes: &[u64],
+    ) -> Result<()> {
         if self.group_values.is_empty() {
             let mut v = Vec::with_capacity(cols.len());
 
@@ -195,12 +192,6 @@ impl GroupValues for GroupValuesColumn {
         // tracks to which group each of the input rows belongs
         groups.clear();
 
-        // 1.1 Calculate the group keys for the group values
-        let batch_hashes = &mut self.hashes_buffer;
-        batch_hashes.clear();
-        batch_hashes.resize(n_rows, 0);
-        create_hashes(cols, &self.random_state, batch_hashes)?;
-
         for (row, &target_hash) in batch_hashes.iter().enumerate() {
             let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| {
                 // Somewhat surprisingly, this closure can be called even if the
diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
index fb7b667750924..e5758c5eb97d6 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
@@ -86,7 +86,12 @@
     /// If a row has the same value as a previous row, the same group id is
     /// assigned. If a row has a new value, the next available group id is
     /// assigned.
-    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()>;
+    fn intern(
+        &mut self,
+        cols: &[ArrayRef],
+        groups: &mut Vec<usize>,
+        batch_hashes: &[u64],
+    ) -> Result<()>;
 
     /// Returns the number of bytes of memory used by this [`GroupValues`]
     fn size(&self) -> usize;
diff --git a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs
index d5b7f1b11ac55..8fef27a49ff85 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs
@@ -111,7 +111,12 @@ impl<T: ArrowPrimitiveType> GroupValues for GroupValuesPrimitive<T>
 where
     T::Native: HashValue,
 {
-    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
+    fn intern(
+        &mut self,
+        cols: &[ArrayRef],
+        groups: &mut Vec<usize>,
+        _batch_hashes: &[u64],
+    ) -> Result<()> {
         assert_eq!(cols.len(), 1);
         groups.clear();
 
diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs
index 8ca88257bf1a7..05fa1ffeca069 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/row.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -108,7 +108,12 @@ impl GroupValuesRows {
 }
 
 impl GroupValues for GroupValuesRows {
-    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
+    fn intern(
+        &mut self,
+        cols: &[ArrayRef],
+        groups: &mut Vec<usize>,
+        _batch_hashes: &[u64],
+    ) -> Result<()> {
         // Convert the group keys into the row format
         let group_rows = &mut self.rows_buffer;
         group_rows.clear();
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index 9466ff6dd4591..f73e750c7d923 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -1330,8 +1330,18 @@ mod tests {
         )
     }
 
+    fn new_ctx() -> Arc<TaskContext> {
+        // Ensure skipping partial aggregation is not triggered
+        let session_config =
+            SessionConfig::new().with_skip_partial_aggregation_probe_ratio_threshold(1.0);
+        let task_ctx = TaskContext::default().with_session_config(session_config);
+        Arc::new(task_ctx)
+    }
+
     fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
-        let session_config = SessionConfig::new().with_batch_size(batch_size);
+        let session_config = SessionConfig::new()
+            .with_batch_size(batch_size)
+            .with_skip_partial_aggregation_probe_ratio_threshold(1.0);
         let runtime = RuntimeEnvBuilder::default()
             .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
             .build_arc()
@@ -1373,7 +1383,7 @@
             // adjust the max memory size to have the partial aggregate result for spill mode.
             new_spill_ctx(4, 500)
         } else {
-            Arc::new(TaskContext::default())
+            new_ctx()
         };
 
         let partial_aggregate = Arc::new(AggregateExec::try_new(
@@ -1521,7 +1531,7 @@
             // set to an appropriate value to trigger spill
             new_spill_ctx(2, 1600)
         } else {
-            Arc::new(TaskContext::default())
+            new_ctx()
         };
 
         let partial_aggregate = Arc::new(AggregateExec::try_new(
@@ -1821,7 +1831,12 @@
         let runtime = RuntimeEnvBuilder::default()
             .with_memory_limit(1, 1.0)
             .build_arc()?;
-        let task_ctx = TaskContext::default().with_runtime(runtime);
+        // Ensure skipping partial aggregation is not triggered
+        let session_config = SessionConfig::default()
+            .with_skip_partial_aggregation_probe_ratio_threshold(1.0);
+        let task_ctx = TaskContext::default()
+            .with_runtime(runtime)
+            .with_session_config(session_config);
         let task_ctx = Arc::new(task_ctx);
 
         let groups_none = PhysicalGroupBy::default();
@@ -2503,16 +2518,8 @@
             schema,
         )?);
 
-        let mut session_config = SessionConfig::default();
-        session_config = session_config.set(
-            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
-            &ScalarValue::Int64(Some(2)),
-        );
-        session_config = session_config.set(
-            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
-            &ScalarValue::Float64(Some(0.1)),
-        );
-
+        let session_config = SessionConfig::default()
+            .with_skip_partial_aggregation_probe_ratio_threshold(0.1);
         let ctx = TaskContext::default().with_session_config(session_config);
         let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
 
@@ -2533,96 +2540,6 @@
         Ok(())
     }
 
-    #[tokio::test]
-    async fn test_skip_aggregation_after_threshold() -> Result<()> {
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("key", DataType::Int32, true),
-            Field::new("val", DataType::Int32, true),
-        ]));
-
-        let group_by =
-            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
-
-        let aggr_expr =
-            vec![
-                AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
-                    .schema(Arc::clone(&schema))
-                    .alias(String::from("COUNT(val)"))
-                    .build()?,
-            ];
-
-        let input_data = vec![
-            RecordBatch::try_new(
-                Arc::clone(&schema),
-                vec![
-                    Arc::new(Int32Array::from(vec![1, 2, 3])),
-                    Arc::new(Int32Array::from(vec![0, 0, 0])),
-                ],
-            )
-            .unwrap(),
-            RecordBatch::try_new(
-                Arc::clone(&schema),
-                vec![
-                    Arc::new(Int32Array::from(vec![2, 3, 4])),
-                    Arc::new(Int32Array::from(vec![0, 0, 0])),
-                ],
-            )
-            .unwrap(),
-            RecordBatch::try_new(
-                Arc::clone(&schema),
-                vec![
-                    Arc::new(Int32Array::from(vec![2, 3, 4])),
-                    Arc::new(Int32Array::from(vec![0, 0, 0])),
-                ],
-            )
-            .unwrap(),
-        ];
-
-        let input = Arc::new(MemoryExec::try_new(
-            &[input_data],
-            Arc::clone(&schema),
-            None,
-        )?);
-        let aggregate_exec = Arc::new(AggregateExec::try_new(
-            AggregateMode::Partial,
-            group_by,
-            aggr_expr,
-            vec![None],
-            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
-            schema,
-        )?);
-
-        let mut session_config = SessionConfig::default();
-        session_config = session_config.set(
-            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
-            &ScalarValue::Int64(Some(5)),
-        );
-        session_config = session_config.set(
-            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
-            &ScalarValue::Float64(Some(0.1)),
-        );
-
-        let ctx = TaskContext::default().with_session_config(session_config);
-        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
-
-        let expected = [
-            "+-----+-------------------+",
-            "| key | COUNT(val)[count] |",
-            "+-----+-------------------+",
-            "| 1   | 1                 |",
-            "| 2   | 2                 |",
-            "| 3   | 2                 |",
-            "| 4   | 1                 |",
-            "| 2   | 1                 |",
-            "| 3   | 1                 |",
-            "| 4   | 1                 |",
-            "+-----+-------------------+",
-        ];
-        assert_batches_eq!(expected, &output);
-
-        Ok(())
-    }
-
     #[test]
     fn group_exprs_nullable() -> Result<()> {
         let input_schema = Arc::new(Schema::new(vec![
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 998f6184f3213..793dcc7fb234e 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -27,17 +27,19 @@ use crate::aggregates::{
     evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode,
     PhysicalGroupBy,
 };
-use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
+use crate::metrics::{BaselineMetrics, RecordOutput};
 use crate::sorts::sort::sort_batch;
 use crate::sorts::streaming_merge;
 use crate::spill::{read_spill_as_stream, spill_record_batch_by_size};
 use crate::stream::RecordBatchStreamAdapter;
-use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr};
+use crate::{aggregates, ExecutionPlan, PhysicalExpr};
 use crate::{RecordBatchStream, SendableRecordBatchStream};
 
+use ahash::RandomState;
 use arrow::array::*;
 use arrow::datatypes::SchemaRef;
 use arrow_schema::SortOptions;
+use datafusion_common::hash_utils::create_hashes;
 use datafusion_common::{internal_err, DataFusionError, Result};
 use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::proxy::VecAllocExt;
@@ -51,6 +53,7 @@ use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr};
 use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
 use futures::ready;
 use futures::stream::{Stream, StreamExt};
+use hashbrown::HashSet;
 use log::debug;
 
 use super::order::GroupOrdering;
@@ -107,15 +110,13 @@ struct SpillState {
 /// Tracks if the aggregate should skip partial aggregations
 ///
 /// See "partial aggregation" discussion on [`GroupedHashAggregateStream`]
+#[derive(Debug)]
 struct SkipAggregationProbe {
     // ========================================================================
     // PROPERTIES:
     // These fields are initialized at the start and remain constant throughout
     // the execution.
     // ========================================================================
-    /// Aggregation ratio check performed when the number of input rows exceeds
-    /// this threshold (from `SessionConfig`)
-    probe_rows_threshold: usize,
     /// Maximum ratio of `num_groups` to `input_rows` for continuing aggregation
     /// (from `SessionConfig`). If the ratio exceeds this value, aggregation
     /// is skipped and input rows are directly converted to output
@@ -128,68 +129,32 @@ struct SkipAggregationProbe {
     // ========================================================================
     /// Number of processed input rows (updated during probing)
     input_rows: usize,
-    /// Number of total group values for `input_rows` (updated during probing)
-    num_groups: usize,
-    /// Flag indicating further data aggregation may be skipped (decision made
-    /// when probing complete)
-    should_skip: bool,
-    /// Flag indicating further updates of `SkipAggregationProbe` state won't
-    /// make any effect (set either while probing or on probing completion)
-    is_locked: bool,
-
-    /// Number of rows where state was output without aggregation.
-    ///
-    /// * If 0, all input rows were aggregated (should_skip was always false)
-    ///
-    /// * if greater than zero, the number of rows which were output directly
-    ///   without aggregation
-    skipped_aggregation_rows: metrics::Count,
+    /// Number of unique hashes, which represents the cardinality of the group values
+    unique_hashes_count: HashSet<u64>,
 }
 
 impl SkipAggregationProbe {
-    fn new(
-        probe_rows_threshold: usize,
-        probe_ratio_threshold: f64,
-        skipped_aggregation_rows: metrics::Count,
-    ) -> Self {
+    fn new(probe_ratio_threshold: f64) -> Self {
         Self {
             input_rows: 0,
-            num_groups: 0,
-            probe_rows_threshold,
             probe_ratio_threshold,
-            should_skip: false,
-            is_locked: false,
-            skipped_aggregation_rows,
+            unique_hashes_count: Default::default(),
         }
     }
 
     /// Updates `SkipAggregationProbe` state:
-    /// - increments the number of input rows
-    /// - replaces the number of groups with the new value
-    /// - on `probe_rows_threshold` exceeded calculates
-    ///   aggregation ratio and sets `should_skip` flag
-    /// - if `should_skip` is set, locks further state updates
-    fn update_state(&mut self, input_rows: usize, num_groups: usize) {
-        if self.is_locked {
-            return;
+    /// inserts the hashes into the HashSet; if the ratio of unique hashes
+    /// to the total number of input rows exceeds the threshold, returns
+    /// true to indicate that aggregation should be skipped
+    fn update_state(&mut self, batch_hashes: &[u64]) -> bool {
+        for target_hash in batch_hashes.iter() {
+            self.unique_hashes_count.insert(*target_hash);
         }
-        self.input_rows += input_rows;
-        self.num_groups = num_groups;
-        if self.input_rows >= self.probe_rows_threshold {
-            self.should_skip = self.num_groups as f64 / self.input_rows as f64
-                >= self.probe_ratio_threshold;
-            self.is_locked = true;
-        }
-    }
+        self.input_rows += batch_hashes.len();
 
-    fn should_skip(&self) -> bool {
-        self.should_skip
-    }
-
-    /// Record the number of rows that were output directly without aggregation
-    fn record_skipped(&mut self, batch: &RecordBatch) {
-        self.skipped_aggregation_rows.add(batch.num_rows());
+        self.unique_hashes_count.len() as f64 / self.input_rows as f64
+            > self.probe_ratio_threshold
     }
 }
 
@@ -410,6 +375,15 @@ pub(crate) struct GroupedHashAggregateStream {
     /// current stream.
     skip_aggregation_probe: Option<SkipAggregationProbe>,
 
+    /// Indicates whether partial aggregation is being skipped
+    skip_partial_aggregation: bool,
+
+    /// Random state for creating hashes
+    random_state: RandomState,
+
+    /// Reused buffer to avoid reallocation
+    hashes_buffer: Vec<Vec<u64>>,
+
     // ========================================================================
     // EXECUTION RESOURCES:
     // Fields related to managing execution resources and monitoring performance.
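[Editor's note] To make the reworked probe in the hunk above easier to follow outside the patch, here is a minimal, self-contained sketch of the same idea. It is not the DataFusion implementation: it uses `std::collections::HashSet` instead of `hashbrown::HashSet`, takes pre-computed `u64` hashes as input, and only mirrors the field and method names from the diff for readability.

```rust
use std::collections::HashSet;

/// Decides whether partial aggregation should be skipped, based on the ratio
/// of distinct group-key hashes to the total number of input rows seen so far.
struct SkipAggregationProbe {
    probe_ratio_threshold: f64,
    input_rows: usize,
    unique_hashes: HashSet<u64>,
}

impl SkipAggregationProbe {
    fn new(probe_ratio_threshold: f64) -> Self {
        Self {
            probe_ratio_threshold,
            input_rows: 0,
            unique_hashes: HashSet::new(),
        }
    }

    /// Folds one batch of group-key hashes into the probe and returns true
    /// when (distinct hashes / input rows) exceeds the threshold, i.e. when
    /// partial aggregation is unlikely to reduce the row count much.
    fn update_state(&mut self, batch_hashes: &[u64]) -> bool {
        self.unique_hashes.extend(batch_hashes.iter().copied());
        self.input_rows += batch_hashes.len();
        self.unique_hashes.len() as f64 / self.input_rows as f64 > self.probe_ratio_threshold
    }
}

fn main() {
    let mut probe = SkipAggregationProbe::new(0.1);
    // A batch with a single repeated hash: ratio 1/20 = 0.05, keep aggregating.
    assert!(!probe.update_state(&vec![42_u64; 20]));
    // A batch of all-distinct hashes: ratio jumps to 21/40, skip from now on.
    assert!(probe.update_state(&(0_u64..20).collect::<Vec<_>>()));
    println!("probe saw {} rows", probe.input_rows);
}
```

Tracking distinct hashes instead of materialized group values keeps the probe cheap enough to run on every batch, at the cost of a small chance that hash collisions under-count the true cardinality.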
@@ -527,17 +501,9 @@ impl GroupedHashAggregateStream {
             && agg_group_by.is_single()
         {
             let options = &context.session_config().options().execution;
-            let probe_rows_threshold =
-                options.skip_partial_aggregation_probe_rows_threshold;
             let probe_ratio_threshold =
                 options.skip_partial_aggregation_probe_ratio_threshold;
-            let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics)
-                .counter("skipped_aggregation_rows", partition);
-            Some(SkipAggregationProbe::new(
-                probe_rows_threshold,
-                probe_ratio_threshold,
-                skipped_aggregation_rows,
-            ))
+            Some(SkipAggregationProbe::new(probe_ratio_threshold))
         } else {
             None
         };
@@ -562,6 +528,10 @@ impl GroupedHashAggregateStream {
             spill_state,
             group_values_soft_limit: agg.limit,
             skip_aggregation_probe,
+            random_state: Default::default(),
+            skip_partial_aggregation: false,
+            // Create with one vec, so no length check is needed in the non-skipping partial aggregation case
+            hashes_buffer: vec![Vec::new()],
         })
     }
 }
@@ -609,15 +579,56 @@ impl Stream for GroupedHashAggregateStream {
             match &self.exec_state {
                 ExecutionState::ReadingInput => 'reading_input: {
                     match ready!(self.input.poll_next_unpin(cx)) {
+                        // New batch to aggregate in partial aggregation operator with skip aggregation probe enabled
+                        Some(Ok(batch))
+                            if self.mode == AggregateMode::Partial
+                                && self.skip_aggregation_probe.is_some() =>
+                        {
+                            let timer = elapsed_compute.timer();
+
+                            // Do the grouping
+                            extract_ok!(
+                                self.group_aggregate_batch_with_skipping_partial(&batch)
+                            );
+                            if self.skip_partial_aggregation {
+                                let states = self.transform_to_states(batch)?;
+                                self.exec_state = ExecutionState::ProducingOutput(states);
+                                // make sure the exec_state just set is not overwritten below
+                                break 'reading_input;
+                            }
+
+                            // If we can begin emitting rows, do so,
+                            // otherwise keep consuming input
+                            assert!(!self.input_done);
+
+                            // If the number of group values equals or exceeds the soft limit,
+                            // emit all groups and switch to producing output
+                            if self.hit_soft_group_limit() {
+                                timer.done();
+                                extract_ok!(self.set_input_done_and_produce_output());
+                                // make sure the exec_state just set is not overwritten below
+                                break 'reading_input;
+                            }
+
+                            if let Some(to_emit) = self.group_ordering.emit_to() {
+                                let batch = extract_ok!(self.emit(to_emit, false));
+                                self.exec_state = ExecutionState::ProducingOutput(batch);
+                                timer.done();
+                                // make sure the exec_state just set is not overwritten below
+                                break 'reading_input;
+                            }
+
+                            extract_ok!(self.emit_early_if_necessary());
+
+                            timer.done();
+                        }
+
                         // New batch to aggregate in partial aggregation operator
                         Some(Ok(batch)) if self.mode == AggregateMode::Partial => {
                             let timer = elapsed_compute.timer();
-                            let input_rows = batch.num_rows();
 
                             // Do the grouping
-                            extract_ok!(self.group_aggregate_batch(batch));
-
-                            self.update_skip_aggregation_probe(input_rows);
+                            extract_ok!(self.group_aggregate_batch(&batch));
 
                             // If we can begin emitting rows, do so,
                             // otherwise keep consuming input
@@ -642,8 +653,6 @@
 
                             extract_ok!(self.emit_early_if_necessary());
 
-                            extract_ok!(self.switch_to_skip_aggregation());
-
                             timer.done();
                         }
 
@@ -656,7 +665,7 @@
                             extract_ok!(self.spill_previous_if_necessary(&batch));
 
                             // Do the grouping
-                            extract_ok!(self.group_aggregate_batch(batch));
+                            extract_ok!(self.group_aggregate_batch(&batch));
 
                             // If we can begin emitting rows, do so,
                             // otherwise keep consuming input
@@ -700,9 +709,6 @@
                     match ready!(self.input.poll_next_unpin(cx)) {
                         Some(Ok(batch)) => {
                             let _timer = elapsed_compute.timer();
-                            if let Some(probe) = self.skip_aggregation_probe.as_mut() {
-                                probe.record_skipped(&batch);
-                            }
                             let states = self.transform_to_states(batch)?;
                             return Poll::Ready(Some(Ok(
                                 states.record_output(&self.baseline_metrics)
@@ -730,9 +736,7 @@
                     }
                     // In Partial aggregation, we also need to check
                     // if we should trigger partial skipping
-                    else if self.mode == AggregateMode::Partial
-                        && self.should_skip_aggregation()
-                    {
+                    else if self.skip_partial_aggregation {
                         ExecutionState::SkippingAggregation
                     } else {
                         ExecutionState::ReadingInput
@@ -771,35 +775,158 @@ impl RecordBatchStream for GroupedHashAggregateStream {
 }
 
 impl GroupedHashAggregateStream {
-    /// Perform group-by aggregation for the given [`RecordBatch`].
-    fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> {
-        // Evaluate the grouping expressions
+    /// Group aggregation with skip partial logic
+    fn group_aggregate_batch_with_skipping_partial(
+        &mut self,
+        batch: &RecordBatch,
+    ) -> Result<()> {
         let group_by_values = if self.spill_state.is_stream_merging {
-            evaluate_group_by(&self.spill_state.merging_group_by, &batch)?
+            evaluate_group_by(&self.spill_state.merging_group_by, batch)
+        } else {
+            evaluate_group_by(&self.group_by, batch)
+        }?;
+
+        self.hashes_buffer.resize(group_by_values.len(), Vec::new());
+        for (index, group_values) in group_by_values.iter().enumerate() {
+            let n_rows = group_values[0].len();
+            let batch_hashes = &mut self.hashes_buffer[index];
+            batch_hashes.clear();
+            batch_hashes.resize(n_rows, 0);
+            create_hashes(group_values, &self.random_state, batch_hashes)?;
+
+            // This function should only be called if skip aggregation is supported
+            let probe = self.skip_aggregation_probe.as_mut().unwrap();
+            self.skip_partial_aggregation = probe.update_state(batch_hashes);
+            if self.skip_partial_aggregation {
+                return Ok(());
+            }
+        }
+
+        // Evaluate the aggregation expressions.
+        let input_values = if self.spill_state.is_stream_merging {
+            evaluate_many(&self.spill_state.merging_aggregate_arguments, batch)?
+        } else {
+            evaluate_many(&self.aggregate_arguments, batch)?
+        };
+
+        // Evaluate the filter expressions, if any, against the inputs
+        let filter_values = if self.spill_state.is_stream_merging {
+            let filter_expressions = vec![None; self.accumulators.len()];
+            evaluate_optional(&filter_expressions, batch)?
         } else {
-            evaluate_group_by(&self.group_by, &batch)?
+            evaluate_optional(&self.filter_expressions, batch)?
         };
 
+        for (index, group_values) in group_by_values.iter().enumerate() {
+            // calculate the group indices for each input row
+            let starting_num_groups = self.group_values.len();
+            self.group_values.intern(
+                group_values,
+                &mut self.current_group_indices,
+                &self.hashes_buffer[index],
+            )?;
+
+            let group_indices = &self.current_group_indices;
+
+            // Update ordering information if necessary
+            let total_num_groups = self.group_values.len();
+            if total_num_groups > starting_num_groups {
+                self.group_ordering.new_groups(
+                    group_values,
+                    group_indices,
+                    total_num_groups,
+                )?;
+            }
+
+            // Gather the inputs to call the actual accumulator
+            let t = self
+                .accumulators
+                .iter_mut()
+                .zip(input_values.iter())
+                .zip(filter_values.iter());
+
+            for ((acc, values), opt_filter) in t {
+                let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
+
+                // Call the appropriate method on each aggregator with
+                // the entire input row and the relevant group indexes
+                match self.mode {
+                    AggregateMode::Partial
+                    | AggregateMode::Single
+                    | AggregateMode::SinglePartitioned
+                        if !self.spill_state.is_stream_merging =>
+                    {
+                        acc.update_batch(
+                            values,
+                            group_indices,
+                            opt_filter,
+                            total_num_groups,
+                        )?;
+                    }
+                    _ => {
+                        // if aggregation is over intermediate states,
+                        // use merge
+                        acc.merge_batch(
+                            values,
+                            group_indices,
+                            opt_filter,
+                            total_num_groups,
+                        )?;
+                    }
+                }
+            }
+        }
+
+        match self.update_memory_reservation() {
+            // Here we can ignore `insufficient_capacity_err` because we will spill later,
+            // but at least one batch should fit in the memory
+            Err(DataFusionError::ResourcesExhausted(_))
+                if self.group_values.len() >= self.batch_size =>
+            {
+                Ok(())
+            }
+            other => other,
+        }
+    }
+
+    /// Perform group-by aggregation for the given [`RecordBatch`].
+    fn group_aggregate_batch(&mut self, batch: &RecordBatch) -> Result<()> {
+        let group_by_values = if self.spill_state.is_stream_merging {
+            evaluate_group_by(&self.spill_state.merging_group_by, batch)
+        } else {
+            evaluate_group_by(&self.group_by, batch)
+        }?;
+
         // Evaluate the aggregation expressions.
         let input_values = if self.spill_state.is_stream_merging {
-            evaluate_many(&self.spill_state.merging_aggregate_arguments, &batch)?
+            evaluate_many(&self.spill_state.merging_aggregate_arguments, batch)?
         } else {
-            evaluate_many(&self.aggregate_arguments, &batch)?
+            evaluate_many(&self.aggregate_arguments, batch)?
         };
 
         // Evaluate the filter expressions, if any, against the inputs
         let filter_values = if self.spill_state.is_stream_merging {
             let filter_expressions = vec![None; self.accumulators.len()];
-            evaluate_optional(&filter_expressions, &batch)?
+            evaluate_optional(&filter_expressions, batch)?
         } else {
-            evaluate_optional(&self.filter_expressions, &batch)?
+            evaluate_optional(&self.filter_expressions, batch)?
         };
 
-        for group_values in &group_by_values {
+        for group_values in group_by_values.iter() {
+            let n_rows = group_values[0].len();
+            let batch_hashes = &mut self.hashes_buffer[0];
+            batch_hashes.clear();
+            batch_hashes.resize(n_rows, 0);
+            create_hashes(group_values, &self.random_state, batch_hashes)?;
+
             // calculate the group indices for each input row
             let starting_num_groups = self.group_values.len();
-            self.group_values
-                .intern(group_values, &mut self.current_group_indices)?;
+            self.group_values.intern(
+                group_values,
+                &mut self.current_group_indices,
+                batch_hashes,
+            )?;
+
             let group_indices = &self.current_group_indices;
 
             // Update ordering information if necessary
@@ -1042,43 +1169,6 @@
         Ok(())
     }
 
-    /// Updates skip aggregation probe state.
-    ///
-    /// Notice: It should only be called in Partial aggregation
-    fn update_skip_aggregation_probe(&mut self, input_rows: usize) {
-        if let Some(probe) = self.skip_aggregation_probe.as_mut() {
-            // Skip aggregation probe is not supported if stream has any spills,
-            // currently spilling is not supported for Partial aggregation
-            assert!(self.spill_state.spills.is_empty());
-            probe.update_state(input_rows, self.group_values.len());
-        };
-    }
-
-    /// In case the probe indicates that aggregation may be
-    /// skipped, forces stream to produce currently accumulated output.
-    ///
-    /// Notice: It should only be called in Partial aggregation
-    fn switch_to_skip_aggregation(&mut self) -> Result<()> {
-        if let Some(probe) = self.skip_aggregation_probe.as_mut() {
-            if probe.should_skip() {
-                let batch = self.emit(EmitTo::All, false)?;
-                self.exec_state = ExecutionState::ProducingOutput(batch);
-            }
-        }
-
-        Ok(())
-    }
-
-    /// Returns true if the aggregation probe indicates that aggregation
-    /// should be skipped.
-    ///
-    /// Notice: It should only be called in Partial aggregation
-    fn should_skip_aggregation(&self) -> bool {
-        self.skip_aggregation_probe
-            .as_ref()
-            .is_some_and(|probe| probe.should_skip())
-    }
-
     /// Transforms input batch to intermediate aggregate state, without grouping it
     fn transform_to_states(&self, batch: RecordBatch) -> Result<RecordBatch> {
         let mut group_values = evaluate_group_by(&self.group_by, &batch)?;
diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
index a2e51cffacf7e..7b26a7a7579e1 100644
--- a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
+++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
@@ -57,9 +57,6 @@ SELECT
 FROM aggregate_test_100;
 
 # Prepare settings to skip partial aggregation from the beginning
-statement ok
-set datafusion.execution.skip_partial_aggregation_probe_rows_threshold = 0;
-
 statement ok
 set datafusion.execution.skip_partial_aggregation_probe_ratio_threshold = 0.0;
 
@@ -206,9 +203,6 @@ d NULL false NULL
 e true false NULL
 
 # Prepare settings to always skip aggregation after couple of batches
-statement ok
-set datafusion.execution.skip_partial_aggregation_probe_rows_threshold = 10;
-
 statement ok
 set datafusion.execution.skip_partial_aggregation_probe_ratio_threshold = 0.0;
 
@@ -668,9 +662,6 @@ statement ok
 DROP TABLE decimal_table;
 
 # Extra tests for 'bool_*()' edge cases
-statement ok
-set datafusion.execution.skip_partial_aggregation_probe_rows_threshold = 0;
-
 statement ok
 set datafusion.execution.skip_partial_aggregation_probe_ratio_threshold = 0.0;
 
@@ -701,9 +692,6 @@ SELECT bool_and(c1), bool_and(c2), bool_and(c3), bool_and(c4), bool_and(c5), boo
 ----
 true false false false false true false NULL
 
-statement ok
-set datafusion.execution.skip_partial_aggregation_probe_rows_threshold = 2;
-
 query BBBBBBBB
 SELECT bool_and(c1), bool_and(c2), bool_and(c3), bool_and(c4), bool_and(c5), bool_and(c6), bool_and(c7), bool_and(c8) FROM bool_aggregate_functions
 ----
diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt
index f561fa9e9ac8d..cc46297c38eef 100644
--- a/datafusion/sqllogictest/test_files/group_by.slt
+++ b/datafusion/sqllogictest/test_files/group_by.slt
@@ -5207,3 +5207,33 @@ NULL a 2
 
 statement ok
 drop table t;
+
+# Make sure the query runs in multiple partitions
+statement ok
+set datafusion.execution.batch_size = 8;
+
+statement ok
+CREATE EXTERNAL TABLE agg_order (
+c1 INT NOT NULL,
+c2 VARCHAR NOT NULL,
+)
+STORED AS CSV
+LOCATION '../core/tests/data/aggregate_mixed_type.csv'
+OPTIONS ('format.has_header' 'true');
+
+query ITI rowsort
+select c1, c2, count(*) from agg_order group by c1, c2;
+----
+1 'a' 2
+2 'b' 2
+3 'c' 4
+4 'd' 3
+5 'e' 1
+6 'f' 1
+7 'g' 1
+8 'a' 1
+9 'b' 1
+
+statement ok
+drop table agg_order;
+
diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt
index 7acdf25b65967..ab282bda6f5a8 100644
--- a/datafusion/sqllogictest/test_files/information_schema.slt
+++ b/datafusion/sqllogictest/test_files/information_schema.slt
@@ -206,8 +206,7 @@ datafusion.execution.parquet.statistics_enabled page
 datafusion.execution.parquet.write_batch_size 1024
 datafusion.execution.parquet.writer_version 1.0
 datafusion.execution.planning_concurrency 13
-datafusion.execution.skip_partial_aggregation_probe_ratio_threshold 0.8
-datafusion.execution.skip_partial_aggregation_probe_rows_threshold 100000
+datafusion.execution.skip_partial_aggregation_probe_ratio_threshold 0.1
 datafusion.execution.soft_max_rows_per_output_file 50000000
 datafusion.execution.sort_in_place_threshold_bytes 1048576
 datafusion.execution.sort_spill_reservation_bytes 10485760
@@ -296,8 +295,7 @@ datafusion.execution.parquet.statistics_enabled page (writing) Sets if statistic
 datafusion.execution.parquet.write_batch_size 1024 (writing) Sets write_batch_size in bytes
 datafusion.execution.parquet.writer_version 1.0 (writing) Sets parquet writer version valid values are "1.0" and "2.0"
 datafusion.execution.planning_concurrency 13 Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system
-datafusion.execution.skip_partial_aggregation_probe_ratio_threshold 0.8 Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input
-datafusion.execution.skip_partial_aggregation_probe_rows_threshold 100000 Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode
+datafusion.execution.skip_partial_aggregation_probe_ratio_threshold 0.1 Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input
 datafusion.execution.soft_max_rows_per_output_file 50000000 Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max
 datafusion.execution.sort_in_place_threshold_bytes 1048576 When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged.
 datafusion.execution.sort_spill_reservation_bytes 10485760 Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured).
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md
index f34d148f092f3..df878189ad03d 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -88,8 +88,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus
 | datafusion.execution.enable_recursive_ctes | true | Should DataFusion support recursive CTEs |
 | datafusion.execution.split_file_groups_by_statistics | false | Attempt to eliminate sorts by packing & sorting files with non-overlapping statistics into the same file groups. Currently experimental |
 | datafusion.execution.keep_partition_by_columns | false | Should DataFusion keep the columns used for partition_by in the output RecordBatches |
-| datafusion.execution.skip_partial_aggregation_probe_ratio_threshold | 0.8 | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input |
-| datafusion.execution.skip_partial_aggregation_probe_rows_threshold | 100000 | Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode |
+| datafusion.execution.skip_partial_aggregation_probe_ratio_threshold | 0.1 | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input |
 | datafusion.execution.use_row_number_estimates_to_optimize_partitioning | false | Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans. if the source of statistics is accurate. We plan to make this the default in the future. |
 | datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. |
 | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores |
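[Editor's note] A minimal usage sketch of the `SessionConfig` builder added in `datafusion/execution/src/config.rs`, mirroring the `new_ctx()` helpers introduced in the tests above. It assumes the `datafusion` prelude re-exports `SessionConfig` and `SessionContext` as in current releases, and it only compiles against a build that includes this patch (the builder does not exist in earlier releases).

```rust
use datafusion::prelude::{SessionConfig, SessionContext};

fn main() {
    // A threshold of 1.0 effectively disables skipping (the unique-hash ratio
    // can never exceed 1.0, and the check is strict), while a low value such
    // as 0.1 switches partial aggregation into skipping mode early for
    // high-cardinality group keys.
    let config = SessionConfig::new()
        .with_batch_size(50)
        .with_skip_partial_aggregation_probe_ratio_threshold(1.0);
    let _ctx = SessionContext::new_with_config(config);
}
```

The same knob is exposed in SQL as `set datafusion.execution.skip_partial_aggregation_probe_ratio_threshold = ...;`, which the sqllogictest changes above rely on.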