From cff822e7d5bb3a96b226f8e8e52986c7a7583842 Mon Sep 17 00:00:00 2001 From: kamille Date: Sat, 7 Dec 2024 19:30:19 +0800 Subject: [PATCH 01/21] draft of `MedianGroupAccumulator`. --- datafusion/functions-aggregate/src/median.rs | 118 ++++++++++++++++++- 1 file changed, 117 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 70f192c32ae1..7afed3fa57d4 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -20,7 +20,10 @@ use std::fmt::{Debug, Formatter}; use std::mem::{size_of, size_of_val}; use std::sync::Arc; -use arrow::array::{downcast_integer, ArrowNumericType}; +use arrow::array::{ + downcast_integer, ArrowNumericType, BooleanArray, GenericListBuilder, + GenericListViewArray, +}; use arrow::{ array::{ArrayRef, AsArray}, datatypes::{ @@ -39,8 +42,10 @@ use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; +use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_functions_aggregate_common::utils::Hashable; use datafusion_macros::user_doc; +use datafusion_physical_expr::NullState; make_udaf_expr_and_func!( Median, @@ -230,6 +235,117 @@ impl Accumulator for MedianAccumulator { } } +/// The median accumulator accumulates the raw input values +/// as `ScalarValue`s +/// +/// The intermediate state is represented as a List of scalar values updated by +/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values +/// in the final evaluation step so that we avoid expensive conversions and +/// allocations during `update_batch`. +#[derive(Debug)] +struct MedianGroupAccumulator { + data_type: DataType, + group_values: Vec>, + null_state: NullState, +} + +impl MedianGroupAccumulator { + pub fn new(data_type: DataType) -> Self { + Self { + data_type, + group_values: Vec::new(), + null_state: NullState::new(), + } + } +} + +impl GroupsAccumulator for MedianGroupAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values[0].as_primitive::(); + + // increment counts, update sums + self.group_values.resize(total_num_groups, Vec::new()); + self.null_state.accumulate( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value| { + self.group_values[group_index].push(new_value); + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + // Since aggregate filter should be applied in partial stage, in final stage there should be no filter + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "one argument to merge_batch"); + + // The merged values should be organized like as a `ListArray` like: + // + // ```text + // group 0: [1, 2, 3] + // group 1: [4, 5] + // group 2: [6, 7, 8] + // ... + // group n: [...] + // ``` + // + let input_group_values = values[0].as_list::(); + + // Adds the counts with the partial counts + group_indices + .iter() + .zip(input_group_values.iter()) + .for_each(|(&group_index, values_opt)| { + if let Some(values) = values_opt { + let values = values.as_primitive::(); + self.group_values[group_index].extend(values.values().iter()); + } + }); + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + todo!() + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + todo!() + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + todo!() + } + + fn supports_convert_to_state(&self) -> bool { + todo!() + } + + fn size(&self) -> usize { + todo!() + } +} + /// The distinct median accumulator accumulates the raw input values /// as `ScalarValue`s /// From 7f10006434a4bb038c3b17f401f2460bb75005e5 Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 8 Dec 2024 01:53:39 +0800 Subject: [PATCH 02/21] impl `state`. --- datafusion/functions-aggregate/src/median.rs | 70 ++++++++++++++------ 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 7afed3fa57d4..4b7ff565d1ae 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -22,8 +22,9 @@ use std::sync::Arc; use arrow::array::{ downcast_integer, ArrowNumericType, BooleanArray, GenericListBuilder, - GenericListViewArray, + GenericListViewArray, ListArray, ListBuilder, PrimitiveArray, PrimitiveBuilder, }; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::{ array::{ArrayRef, AsArray}, datatypes::{ @@ -235,21 +236,23 @@ impl Accumulator for MedianAccumulator { } } -/// The median accumulator accumulates the raw input values -/// as `ScalarValue`s +/// The median groups accumulator accumulates the raw input values +/// +/// For calculating the accurate medians of groups, we need to store all values +/// of groups before final evaluation. +/// And values in each group will be stored in a `Vec`, so the total group values +/// will be actually organized as a `Vec>`. +/// +/// In partial aggregation stage, the `values` /// -/// The intermediate state is represented as a List of scalar values updated by -/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values -/// in the final evaluation step so that we avoid expensive conversions and -/// allocations during `update_batch`. #[derive(Debug)] -struct MedianGroupAccumulator { +struct MedianGroupsAccumulator { data_type: DataType, group_values: Vec>, null_state: NullState, } -impl MedianGroupAccumulator { +impl MedianGroupsAccumulator { pub fn new(data_type: DataType) -> Self { Self { data_type, @@ -259,7 +262,7 @@ impl MedianGroupAccumulator { } } -impl GroupsAccumulator for MedianGroupAccumulator { +impl GroupsAccumulator for MedianGroupsAccumulator { fn update_batch( &mut self, values: &[ArrayRef], @@ -295,7 +298,7 @@ impl GroupsAccumulator for MedianGroupAccumulator ) -> Result<()> { assert_eq!(values.len(), 1, "one argument to merge_batch"); - // The merged values should be organized like as a `ListArray` like: + // The merged values should be organized like as a `non-nullable ListArray` like: // // ```text // group 0: [1, 2, 3] @@ -306,26 +309,55 @@ impl GroupsAccumulator for MedianGroupAccumulator // ``` // let input_group_values = values[0].as_list::(); + assert!(input_group_values.null_count() == 0); + + // Ensure group values big enough + self.group_values.resize(total_num_groups, Vec::new()); - // Adds the counts with the partial counts + // Extend values to related groups group_indices .iter() .zip(input_group_values.iter()) .for_each(|(&group_index, values_opt)| { - if let Some(values) = values_opt { - let values = values.as_primitive::(); - self.group_values[group_index].extend(values.values().iter()); - } + let values = values_opt.unwrap(); + let values = values.as_primitive::(); + self.group_values[group_index].extend(values.values().iter()); }); Ok(()) } - fn evaluate(&mut self, emit_to: EmitTo) -> Result { - todo!() + fn state(&mut self, emit_to: EmitTo) -> Result> { + let emit_group_values = emit_to.take_needed(&mut self.group_values); + + // Build offsets + let mut offsets = Vec::with_capacity(self.group_values.len() + 1); + offsets.push(0); + let mut cur_len = 0; + for group_value in &emit_group_values { + cur_len += group_value.len() as i32; + offsets.push(cur_len); + } + let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets)); + + // Build inner array + let flatten_group_values = + emit_group_values.into_iter().flatten().collect::>(); + let group_values_array = + PrimitiveArray::::new(ScalarBuffer::from(flatten_group_values), None); + + // Build the result list array + let result_list_array = ListArray::new( + Arc::new(Field::new_list_field(self.data_type.clone(), false)), + offsets, + Arc::new(group_values_array), + None, + ); + + Ok(vec![Arc::new(result_list_array)]) } - fn state(&mut self, emit_to: EmitTo) -> Result> { + fn evaluate(&mut self, emit_to: EmitTo) -> Result { todo!() } From 6f172ef5c9fb0193cae92619b26a3c1b78d57ad3 Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 8 Dec 2024 23:40:57 +0800 Subject: [PATCH 03/21] impl rest methods of `MedianGroupsAccumulator`. --- datafusion/functions-aggregate/src/median.rs | 64 ++++++++++++++++---- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 4b7ff565d1ae..e5f203f17e7b 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -21,8 +21,8 @@ use std::mem::{size_of, size_of_val}; use std::sync::Arc; use arrow::array::{ - downcast_integer, ArrowNumericType, BooleanArray, GenericListBuilder, - GenericListViewArray, ListArray, ListBuilder, PrimitiveArray, PrimitiveBuilder, + downcast_integer, ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, + PrimitiveBuilder, }; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::{ @@ -44,9 +44,10 @@ use datafusion_expr::{ Documentation, Signature, Volatility, }; use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; use datafusion_functions_aggregate_common::utils::Hashable; use datafusion_macros::user_doc; -use datafusion_physical_expr::NullState; make_udaf_expr_and_func!( Median, @@ -249,7 +250,6 @@ impl Accumulator for MedianAccumulator { struct MedianGroupsAccumulator { data_type: DataType, group_values: Vec>, - null_state: NullState, } impl MedianGroupsAccumulator { @@ -257,7 +257,6 @@ impl MedianGroupsAccumulator { Self { data_type, group_values: Vec::new(), - null_state: NullState::new(), } } } @@ -275,11 +274,10 @@ impl GroupsAccumulator for MedianGroupsAccumulator GroupsAccumulator for MedianGroupsAccumulator Result> { + // Emit values let emit_group_values = emit_to.take_needed(&mut self.group_values); // Build offsets @@ -358,7 +357,17 @@ impl GroupsAccumulator for MedianGroupsAccumulator Result { - todo!() + // Emit values + let emit_group_values = emit_to.take_needed(&mut self.group_values); + + // Calculate median for each group + let mut evaluate_result_builder = PrimitiveBuilder::::new(); + for values in emit_group_values { + let median = calculate_median::(values); + evaluate_result_builder.append_option(median); + } + + Ok(Arc::new(evaluate_result_builder.finish())) } fn convert_to_state( @@ -366,15 +375,48 @@ impl GroupsAccumulator for MedianGroupsAccumulator, ) -> Result> { - todo!() + assert_eq!(values.len(), 1, "one argument to merge_batch"); + + let input_array = values[0].as_primitive::(); + + // Directly convert the input array to states, each row will be + // seen as a respective group. + // For detail, the `input_array` will be converted to a `ListArray`. + // And if row is `not null + not filtered`, it will be converted to a list + // with only one element; otherwise, this row in `ListArray` will be set + // to null. + + // Reuse values buffer in `input_array` to build `values` in `ListArray` + let values = PrimitiveArray::::new(input_array.values().clone(), None); + + // `offsets` in `ListArray`, each row as a list element + let offsets = (0..=input_array.len() as i32) + .into_iter() + .collect::>(); + let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets)); + + // `nulls` for converted `ListArray` + let nulls = filtered_null_mask(opt_filter, input_array); + + let converted_list_array = Arc::new(ListArray::new( + Arc::new(Field::new_list_field(self.data_type.clone(), false)), + offsets, + Arc::new(values), + nulls, + )); + + Ok(vec![converted_list_array]) } fn supports_convert_to_state(&self) -> bool { - todo!() + true } fn size(&self) -> usize { - todo!() + self.group_values + .iter() + .map(|values| values.capacity() * size_of::()) + .sum::() } } From cacc693fb63a0fbd4165e0e1e93d67ed73a5f398 Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 9 Dec 2024 00:34:56 +0800 Subject: [PATCH 04/21] improve comments. --- datafusion/functions-aggregate/src/median.rs | 31 +++++++++++++------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index e5f203f17e7b..c91e0c6786d8 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -241,11 +241,9 @@ impl Accumulator for MedianAccumulator { /// /// For calculating the accurate medians of groups, we need to store all values /// of groups before final evaluation. -/// And values in each group will be stored in a `Vec`, so the total group values +/// So values in each group will be stored in a `Vec`, so the total group values /// will be actually organized as a `Vec>`. /// -/// In partial aggregation stage, the `values` -/// #[derive(Debug)] struct MedianGroupsAccumulator { data_type: DataType, @@ -272,7 +270,7 @@ impl GroupsAccumulator for MedianGroupsAccumulator(); - // increment counts, update sums + // Push the `not nulls + not filtered` row into its group self.group_values.resize(total_num_groups, Vec::new()); accumulate( group_indices, @@ -296,30 +294,43 @@ impl GroupsAccumulator for MedianGroupsAccumulator Result<()> { assert_eq!(values.len(), 1, "one argument to merge_batch"); - // The merged values should be organized like as a `non-nullable ListArray` like: + // The merged values should be organized like as a `ListArray` which is nullable, + // but `values` in it is `non-nullable`(`values` with nulls usually generated + // from `convert_to_state`). + // + // Following is the possible and impossible input `values`: // + // # Possible values // ```text // group 0: [1, 2, 3] - // group 1: [4, 5] + // group 1: null (list array is nullable) // group 2: [6, 7, 8] // ... // group n: [...] // ``` // + // # Impossible values + // ```text + // group x: [1, 2, null] (values in list array is non-nullable) + // ``` + // let input_group_values = values[0].as_list::(); - assert!(input_group_values.null_count() == 0); // Ensure group values big enough self.group_values.resize(total_num_groups, Vec::new()); // Extend values to related groups + // TODO: avoid using iterator of the `ListArray`, this will lead to + // many calls of `slice` of its `values` array, and `slice` is not + // so efficient. group_indices .iter() .zip(input_group_values.iter()) .for_each(|(&group_index, values_opt)| { - let values = values_opt.unwrap(); - let values = values.as_primitive::(); - self.group_values[group_index].extend(values.values().iter()); + if let Some(values) = values_opt { + let values = values.as_primitive::(); + self.group_values[group_index].extend(values.values().iter()); + } }); Ok(()) From 11e675309656bd2dddd951147725ade1fc4b1b52 Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 9 Dec 2024 01:25:06 +0800 Subject: [PATCH 05/21] use `MedianGroupsAccumulator`. --- datafusion/functions-aggregate/src/median.rs | 42 +++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index c91e0c6786d8..29166a07b2ad 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -37,7 +37,8 @@ use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType}; -use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue}; +use datafusion_common::{internal_err, DataFusionError, HashSet, Result, ScalarValue}; +use datafusion_doc::DocSection; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, @@ -172,6 +173,45 @@ impl AggregateUDFImpl for Median { } } + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + !args.is_distinct + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let num_args = args.exprs.len(); + if num_args != 1 { + return internal_err!( + "median should only have 1 arg, but found num args:{}", + args.exprs.len() + ); + } + + let dt = args.exprs[0].data_type(args.schema)?; + + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt))) + }; + } + + downcast_integer! { + dt => (helper, dt), + DataType::Float16 => helper!(Float16Type, dt), + DataType::Float32 => helper!(Float32Type, dt), + DataType::Float64 => helper!(Float64Type, dt), + DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), + DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), + _ => Err(DataFusionError::NotImplemented(format!( + "MedianGroupsAccumulator not supported for {} with {}", + args.name, + dt, + ))), + } + } + fn aliases(&self) -> &[String] { &[] } From 955036f499e7e37563a4eabb90f5378ac4cc2e44 Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 9 Dec 2024 01:28:17 +0800 Subject: [PATCH 06/21] remove unused import. --- datafusion/functions-aggregate/src/median.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 29166a07b2ad..7189e00c4554 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -38,7 +38,6 @@ use arrow::array::ArrowNativeTypeOp; use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType}; use datafusion_common::{internal_err, DataFusionError, HashSet, Result, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, From 17bd90bb4fed7ad2ca66d6da602f83cd5ab40067 Mon Sep 17 00:00:00 2001 From: kamille Date: Sat, 21 Dec 2024 13:32:53 +0800 Subject: [PATCH 07/21] add `group_median_table` to test group median. --- datafusion/functions-aggregate/src/median.rs | 6 ++--- .../sqllogictest/test_files/aggregate.slt | 25 +++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 7189e00c4554..2010b5f6c41f 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -448,14 +448,14 @@ impl GroupsAccumulator for MedianGroupsAccumulator bool { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index bd3b40089519..d86a1a11a842 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -67,6 +67,31 @@ statement ok CREATE TABLE test (c1 BIGINT,c2 BIGINT) as values (0,null), (1,1), (null,1), (3,2), (3,2) +statement ok +CREATE TABLE group_median_table ( + col_group STRING, + col_i8 TINYINT, + col_i16 SMALLINT, + col_i32 INT, + col_i64 BIGINT, + col_u8 TINYINT UNSIGNED, + col_u16 SMALLINT UNSIGNED, + col_u32 INT UNSIGNED, + col_u64 BIGINT UNSIGNED, + col_f32 FLOAT, + col_f64 DOUBLE, + col_f64_nan DOUBLE +) as VALUES +( "group0", -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1 ), +( "group0", -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, arrow_cast('NAN','Float64') ), +( "group0", 100, 100, 100, arrow_cast(100,'Int64'), 100, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64') ), +( "group0", 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64') ), +( "group1", -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1 ), +( "group1", -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, arrow_cast('NAN','Float64') ), +( "group1", 100, 100, 100, arrow_cast(100,'Int64'), 101, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64') ), +( "group1", 125, 32766, 2147483646, arrow_cast(9223372036854775806,'Int64'), 100, 101, 4294967294, arrow_cast(100,'UInt64'), 3.2, 5.5, arrow_cast('NAN','Float64') ), +( "group1", 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64') ) + ####### # Error tests ####### From 28d871636a03014f6d3ff35ef3e263389a29540f Mon Sep 17 00:00:00 2001 From: kamille Date: Wed, 22 Jan 2025 17:24:19 +0800 Subject: [PATCH 08/21] complete group median test cases in aggregate slt. --- .../sqllogictest/test_files/aggregate.slt | 99 +++++++++++++++++-- 1 file changed, 90 insertions(+), 9 deletions(-) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index d86a1a11a842..2274851e8939 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -82,15 +82,15 @@ CREATE TABLE group_median_table ( col_f64 DOUBLE, col_f64_nan DOUBLE ) as VALUES -( "group0", -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1 ), -( "group0", -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, arrow_cast('NAN','Float64') ), -( "group0", 100, 100, 100, arrow_cast(100,'Int64'), 100, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64') ), -( "group0", 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64') ), -( "group1", -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1 ), -( "group1", -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, arrow_cast('NAN','Float64') ), -( "group1", 100, 100, 100, arrow_cast(100,'Int64'), 101, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64') ), -( "group1", 125, 32766, 2147483646, arrow_cast(9223372036854775806,'Int64'), 100, 101, 4294967294, arrow_cast(100,'UInt64'), 3.2, 5.5, arrow_cast('NAN','Float64') ), -( "group1", 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64') ) +( 'group0', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1 ), +( 'group0', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, 1.1 ), +( 'group0', 100, 100, 100, arrow_cast(100,'Int64'), 100, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64') ), +( 'group0', 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64') ), +( 'group1', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1 ), +( 'group1', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, arrow_cast('NAN','Float64') ), +( 'group1', 100, 100, 100, arrow_cast(100,'Int64'), 101, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64') ), +( 'group1', 125, 32766, 2147483646, arrow_cast(9223372036854775806,'Int64'), 100, 101, 4294967294, arrow_cast(100,'UInt64'), 3.2, 5.5, arrow_cast('NAN','Float64') ), +( 'group1', 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64') ) ####### # Error tests @@ -6228,3 +6228,84 @@ physical_plan 14)--------------PlaceholderRowExec 15)------------ProjectionExec: expr=[1 as id, 2 as foo] 16)--------------PlaceholderRowExec + +####### +# Group median test +####### + +# group median i8 +query TI rowsort +SELECT col_group, median(col_i8) FROM group_median_table GROUP BY col_group +---- +group0 -14 +group1 100 + +# group median i16 +query TI +SELECT col_group, median(col_i16) FROM group_median_table GROUP BY col_group +---- +group0 -16334 +group1 100 + +# group median i32 +query TI +SELECT col_group, median(col_i32) FROM group_median_table GROUP BY col_group +---- +group0 -1073741774 +group1 100 + +# group median i64 +query TI +SELECT col_group, median(col_i64) FROM group_median_table GROUP BY col_group +---- +group0 -4611686018427387854 +group1 100 + +# group median u8 +query TI rowsort +SELECT col_group, median(col_u8) FROM group_median_table GROUP BY col_group +---- +group0 50 +group1 100 + +# group median u16 +query TI +SELECT col_group, median(col_u16) FROM group_median_table GROUP BY col_group +---- +group0 50 +group1 100 + +# group median u32 +query TI +SELECT col_group, median(col_u32) FROM group_median_table GROUP BY col_group +---- +group0 50 +group1 100 + +# group median u64 +query TI +SELECT col_group, median(col_u64) FROM group_median_table GROUP BY col_group +---- +group0 50 +group1 100 + +# group median f32 +query TR +SELECT col_group, median(col_f32) FROM group_median_table GROUP BY col_group +---- +group0 2.75 +group1 3.2 + +# group median f64 +query TR +SELECT col_group, median(col_f64) FROM group_median_table GROUP BY col_group +---- +group0 2.75 +group1 3.3 + +# group median f64_nan +query TR +SELECT col_group, median(col_f64_nan) FROM group_median_table GROUP BY col_group +---- +group0 NaN +group1 NaN From 1244df43d84a1283af49abb989634ea319035cd2 Mon Sep 17 00:00:00 2001 From: kamille Date: Wed, 22 Jan 2025 17:24:50 +0800 Subject: [PATCH 09/21] fix type of state. --- datafusion/functions-aggregate/src/median.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 2010b5f6c41f..56f2ee3f0e33 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -397,7 +397,7 @@ impl GroupsAccumulator for MedianGroupsAccumulator GroupsAccumulator for MedianGroupsAccumulator Date: Thu, 23 Jan 2025 13:08:40 +0100 Subject: [PATCH 10/21] Clippy --- datafusion/functions-aggregate/src/median.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 56f2ee3f0e33..810c3388eb87 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -441,7 +441,6 @@ impl GroupsAccumulator for MedianGroupsAccumulator>(); let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets)); From fdc9b3358ac7666f1a8f4db3803bd6712ef34b15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Thu, 23 Jan 2025 13:12:41 +0100 Subject: [PATCH 11/21] Fmt --- datafusion/functions-aggregate/src/median.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 810c3388eb87..38bf29487c2b 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -440,8 +440,7 @@ impl GroupsAccumulator for MedianGroupsAccumulator::new(input_array.values().clone(), None); // `offsets` in `ListArray`, each row as a list element - let offsets = (0..=input_array.len() as i32) - .collect::>(); + let offsets = (0..=input_array.len() as i32).collect::>(); let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets)); // `nulls` for converted `ListArray` From e2f384fe56f07ae676947b598212c8403c2654a5 Mon Sep 17 00:00:00 2001 From: kamille Date: Fri, 24 Jan 2025 13:49:06 +0800 Subject: [PATCH 12/21] add fuzzy tests for median. --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 09d0c8d5ca2e..1ccc673e2f27 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -19,10 +19,11 @@ use std::sync::Arc; use arrow::array::{Array, ArrayRef, AsArray, Int64Array}; use arrow::compute::{concat_batches, SortOptions}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Decimal128Type}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use arrow_array::types::Int64Type; +use arrow_ipc::Decimal; use arrow_schema::{ IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, @@ -148,6 +149,26 @@ async fn test_count() { .await; } +#[tokio::test(flavor = "multi_thread")] +async fn test_median() { + let data_gen_config = baseline_config(); + + // Queries like SELECT median(a), median(distinct) FROM fuzz_table GROUP BY b + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("median") + .with_distinct_aggregate_function("median") + // median only works on numeric columns + .with_aggregate_arguments(data_gen_config.numeric_columns()) + .set_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; +} + /// Return a standard set of columns for testing data generation /// /// Includes numeric and string types From 4b8a4adc21e4a73ab09b7e85d250d6fc3364676b Mon Sep 17 00:00:00 2001 From: kamille Date: Fri, 24 Jan 2025 13:49:19 +0800 Subject: [PATCH 13/21] fix decimal. --- datafusion/functions-aggregate/src/median.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 38bf29487c2b..2f74fd2e2f49 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -393,7 +393,8 @@ impl GroupsAccumulator for MedianGroupsAccumulator>(); let group_values_array = - PrimitiveArray::::new(ScalarBuffer::from(flatten_group_values), None); + PrimitiveArray::::new(ScalarBuffer::from(flatten_group_values), None) + .with_data_type(self.data_type.clone()); // Build the result list array let result_list_array = ListArray::new( @@ -411,7 +412,8 @@ impl GroupsAccumulator for MedianGroupsAccumulator::new(); + let mut evaluate_result_builder = + PrimitiveBuilder::::new().with_data_type(self.data_type.clone()); for values in emit_group_values { let median = calculate_median::(values); evaluate_result_builder.append_option(median); @@ -437,7 +439,8 @@ impl GroupsAccumulator for MedianGroupsAccumulator::new(input_array.values().clone(), None); + let values = PrimitiveArray::::new(input_array.values().clone(), None) + .with_data_type(self.data_type.clone()); // `offsets` in `ListArray`, each row as a list element let offsets = (0..=input_array.len() as i32).collect::>(); From 5603bc0e5189516afa0c0dfbf8ddc0244e421c5f Mon Sep 17 00:00:00 2001 From: kamille Date: Fri, 24 Jan 2025 17:02:48 +0800 Subject: [PATCH 14/21] fix clippy. --- datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 1ccc673e2f27..bcd88bae739a 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -19,11 +19,10 @@ use std::sync::Arc; use arrow::array::{Array, ArrayRef, AsArray, Int64Array}; use arrow::compute::{concat_batches, SortOptions}; -use arrow::datatypes::{DataType, Decimal128Type}; +use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use arrow_array::types::Int64Type; -use arrow_ipc::Decimal; use arrow_schema::{ IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, From 7e6a73aff7da92ecdafd933a7f7cc8b04772b773 Mon Sep 17 00:00:00 2001 From: kamille Date: Fri, 24 Jan 2025 17:10:49 +0800 Subject: [PATCH 15/21] improve comments. --- datafusion/functions-aggregate/src/median.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 2f74fd2e2f49..f3dc8ca47f53 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -280,7 +280,7 @@ impl Accumulator for MedianAccumulator { /// /// For calculating the accurate medians of groups, we need to store all values /// of groups before final evaluation. -/// So values in each group will be stored in a `Vec`, so the total group values +/// So values in each group will be stored in a `Vec`, and the total group values /// will be actually organized as a `Vec>`. /// #[derive(Debug)] @@ -333,9 +333,9 @@ impl GroupsAccumulator for MedianGroupsAccumulator Result<()> { assert_eq!(values.len(), 1, "one argument to merge_batch"); - // The merged values should be organized like as a `ListArray` which is nullable, - // but `values` in it is `non-nullable`(`values` with nulls usually generated - // from `convert_to_state`). + // The merged values should be organized like as a `ListArray` which is nullable + // (input with nulls usually generated from `convert_to_state`), but `inner array` of + // `ListArray` is `non-nullable`. // // Following is the possible and impossible input `values`: // @@ -360,8 +360,8 @@ impl GroupsAccumulator for MedianGroupsAccumulator Date: Sun, 26 Jan 2025 23:24:20 +0800 Subject: [PATCH 16/21] add median cases with nulls. --- .../sqllogictest/test_files/aggregate.slt | 223 +++++++++++++++--- 1 file changed, 189 insertions(+), 34 deletions(-) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 2274851e8939..4838911649bd 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -68,8 +68,35 @@ CREATE TABLE test (c1 BIGINT,c2 BIGINT) as values (0,null), (1,1), (null,1), (3,2), (3,2) statement ok -CREATE TABLE group_median_table ( - col_group STRING, +CREATE TABLE group_median_table_non_nullable ( + col_group STRING NOT NULL, + col_i8 TINYINT NOT NULL, + col_i16 SMALLINT NOT NULL, + col_i32 INT NOT NULL, + col_i64 BIGINT NOT NULL, + col_u8 TINYINT UNSIGNED NOT NULL, + col_u16 SMALLINT UNSIGNED NOT NULL, + col_u32 INT UNSIGNED NOT NULL, + col_u64 BIGINT UNSIGNED NOT NULL, + col_f32 FLOAT NOT NULL, + col_f64 DOUBLE NOT NULL, + col_f64_nan DOUBLE NOT NULL, + col_decimal128 DECIMAL(10, 4) NOT NULL, + col_decimal256 NUMERIC(10, 4) NOT NULL +) as VALUES +( 'group0', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1, 0.0001, 0.0001 ), +( 'group0', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, 1.1, 0.0002, 0.0002 ), +( 'group0', 100, 100, 100, arrow_cast(100,'Int64'), 100, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64'), 0.0003, 0.0003 ), +( 'group0', 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64'), 0.0004, 0.0004 ), +( 'group1', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1, 0.0001, 0.0001 ), +( 'group1', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, arrow_cast('NAN','Float64'), 0.0002, 0.0002 ), +( 'group1', 100, 100, 100, arrow_cast(100,'Int64'), 101, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64'), 0.0003, 0.0003 ), +( 'group1', 125, 32766, 2147483646, arrow_cast(9223372036854775806,'Int64'), 100, 101, 4294967294, arrow_cast(100,'UInt64'), 3.2, 5.5, arrow_cast('NAN','Float64'), 0.0004, 0.0004 ), +( 'group1', 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64'), 0.0005, 0.0005 ) + +statement ok +CREATE TABLE group_median_table_nullable ( + col_group STRING NOT NULL, col_i8 TINYINT, col_i16 SMALLINT, col_i32 INT, @@ -80,17 +107,21 @@ CREATE TABLE group_median_table ( col_u64 BIGINT UNSIGNED, col_f32 FLOAT, col_f64 DOUBLE, - col_f64_nan DOUBLE + col_f64_nan DOUBLE, + col_decimal128 DECIMAL(10, 4), + col_decimal256 NUMERIC(10, 4) ) as VALUES -( 'group0', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1 ), -( 'group0', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, 1.1 ), -( 'group0', 100, 100, 100, arrow_cast(100,'Int64'), 100, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64') ), -( 'group0', 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64') ), -( 'group1', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1 ), -( 'group1', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, arrow_cast('NAN','Float64') ), -( 'group1', 100, 100, 100, arrow_cast(100,'Int64'), 101, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64') ), -( 'group1', 125, 32766, 2147483646, arrow_cast(9223372036854775806,'Int64'), 100, 101, 4294967294, arrow_cast(100,'UInt64'), 3.2, 5.5, arrow_cast('NAN','Float64') ), -( 'group1', 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64') ) +( 'group0', NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL ), +( 'group0', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1, 0.0001, 0.0001 ), +( 'group0', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, 1.1, 0.0002, 0.0002 ), +( 'group0', 100, 100, 100, arrow_cast(100,'Int64'), 100, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64'), 0.0003, 0.0003 ), +( 'group0', 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64'), 0.0004, 0.0004 ), +( 'group1', NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL ), +( 'group1', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1, 0.0001, 0.0001 ), +( 'group1', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, arrow_cast('NAN','Float64'), 0.0002, 0.0002 ), +( 'group1', 100, 100, 100, arrow_cast(100,'Int64'), 101, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64'), 0.0003, 0.0003 ), +( 'group1', 125, 32766, 2147483646, arrow_cast(9223372036854775806,'Int64'), 100, 101, 4294967294, arrow_cast(100,'UInt64'), 3.2, 5.5, arrow_cast('NAN','Float64'), 0.0004, 0.0004 ), +( 'group1', 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64'), 0.0005, 0.0005 ) ####### # Error tests @@ -6233,79 +6264,203 @@ physical_plan # Group median test ####### -# group median i8 +# group median i8 non-nullable query TI rowsort -SELECT col_group, median(col_i8) FROM group_median_table GROUP BY col_group +SELECT col_group, median(col_i8) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 -14 group1 100 -# group median i16 +# group median i16 non-nullable query TI -SELECT col_group, median(col_i16) FROM group_median_table GROUP BY col_group +SELECT col_group, median(col_i16) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 -16334 group1 100 -# group median i32 +# group median i32 non-nullable query TI -SELECT col_group, median(col_i32) FROM group_median_table GROUP BY col_group +SELECT col_group, median(col_i32) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 -1073741774 group1 100 -# group median i64 +# group median i64 non-nullable query TI -SELECT col_group, median(col_i64) FROM group_median_table GROUP BY col_group +SELECT col_group, median(col_i64) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 -4611686018427387854 group1 100 -# group median u8 +# group median u8 non-nullable query TI rowsort -SELECT col_group, median(col_u8) FROM group_median_table GROUP BY col_group +SELECT col_group, median(col_u8) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 50 group1 100 -# group median u16 +# group median u16 non-nullable query TI -SELECT col_group, median(col_u16) FROM group_median_table GROUP BY col_group +SELECT col_group, median(col_u16) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 50 group1 100 -# group median u32 +# group median u32 non-nullable query TI -SELECT col_group, median(col_u32) FROM group_median_table GROUP BY col_group +SELECT col_group, median(col_u32) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 50 group1 100 -# group median u64 +# group median u64 non-nullable query TI -SELECT col_group, median(col_u64) FROM group_median_table GROUP BY col_group +SELECT col_group, median(col_u64) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 50 group1 100 -# group median f32 +# group median f32 non-nullable query TR -SELECT col_group, median(col_f32) FROM group_median_table GROUP BY col_group +SELECT col_group, median(col_f32) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 2.75 group1 3.2 -# group median f64 +# group median f64 non-nullable query TR -SELECT col_group, median(col_f64) FROM group_median_table GROUP BY col_group +SELECT col_group, median(col_f64) FROM group_median_table_non_nullable GROUP BY col_group ---- group0 2.75 group1 3.3 -# group median f64_nan +# group median f64_nan non-nullable +query TR +SELECT col_group, median(col_f64_nan) FROM group_median_table_non_nullable GROUP BY col_group +---- +group0 NaN +group1 NaN + +# group median decimal128 non-nullable +query TR +SELECT col_group, median(col_decimal128) FROM group_median_table_non_nullable GROUP BY col_group +---- +group0 0.0002 +group1 0.0003 + +# group median decimal256 non-nullable query TR -SELECT col_group, median(col_f64_nan) FROM group_median_table GROUP BY col_group +SELECT col_group, median(col_decimal256) FROM group_median_table_non_nullable GROUP BY col_group +---- +group0 0.0002 +group1 0.0003 + +# group median i8 nullable +query TI rowsort +SELECT col_group, median(col_i8) FROM group_median_table_nullable GROUP BY col_group +---- +group0 -14 +group1 100 + +# group median i16 nullable +query TI rowsort +SELECT col_group, median(col_i16) FROM group_median_table_nullable GROUP BY col_group +---- +group0 -16334 +group1 100 + +# group median i32 nullable +query TI rowsort +SELECT col_group, median(col_i32) FROM group_median_table_nullable GROUP BY col_group +---- +group0 -1073741774 +group1 100 + +# group median i64 nullable +query TI rowsort +SELECT col_group, median(col_i64) FROM group_median_table_nullable GROUP BY col_group +---- +group0 -4611686018427387854 +group1 100 + +# group median u8 nullable +query TI rowsort +SELECT col_group, median(col_u8) FROM group_median_table_nullable GROUP BY col_group +---- +group0 50 +group1 100 + +# group median u16 nullable +query TI rowsort +SELECT col_group, median(col_u16) FROM group_median_table_nullable GROUP BY col_group +---- +group0 50 +group1 100 + +# group median u32 nullable +query TI rowsort +SELECT col_group, median(col_u32) FROM group_median_table_nullable GROUP BY col_group +---- +group0 50 +group1 100 + +# group median u64 nullable +query TI rowsort +SELECT col_group, median(col_u64) FROM group_median_table_nullable GROUP BY col_group +---- +group0 50 +group1 100 + +# group median f32 nullable +query TR rowsort +SELECT col_group, median(col_f32) FROM group_median_table_nullable GROUP BY col_group +---- +group0 2.75 +group1 3.2 + +# group median f64 nullable +query TR rowsort +SELECT col_group, median(col_f64) FROM group_median_table_nullable GROUP BY col_group +---- +group0 2.75 +group1 3.3 + +# group median f64_nan nullable +query TR rowsort +SELECT col_group, median(col_f64_nan) FROM group_median_table_nullable GROUP BY col_group ---- group0 NaN group1 NaN + +# group median decimal128 nullable +query TR rowsort +SELECT col_group, median(col_decimal128) FROM group_median_table_nullable GROUP BY col_group +---- +group0 0.0002 +group1 0.0003 + +# group median decimal256 nullable +query TR rowsort +SELECT col_group, median(col_decimal256) FROM group_median_table_nullable GROUP BY col_group +---- +group0 0.0002 +group1 0.0003 + +# median with all nulls +statement ok +create table group_median_all_nulls( + a STRING NOT NULL, + b INT +) AS VALUES +( 'group0', NULL), +( 'group0', NULL), +( 'group0', NULL), +( 'group1', NULL), +( 'group1', NULL), +( 'group1', NULL) + +query TIT rowsort +SELECT a, median(b), arrow_typeof(median(b)) FROM group_median_all_nulls GROUP BY a +---- +group0 NULL Int32 +group1 NULL Int32 From 5eb7711b570c8eee141875567f2307b04eae4519 Mon Sep 17 00:00:00 2001 From: kamille <3144148605@qq.com> Date: Tue, 28 Jan 2025 01:45:39 +0800 Subject: [PATCH 17/21] Update datafusion/functions-aggregate/src/median.rs Co-authored-by: Andrew Lamb --- datafusion/functions-aggregate/src/median.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index f3dc8ca47f53..0ceb83fcfbd1 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -468,6 +468,8 @@ impl GroupsAccumulator for MedianGroupsAccumulator()) .sum::() + // account for size of self.grou_values too + + self.group_values.capacity() * size_of::>() } } From 5a52e7c439523e5b748a5517e77e74655e4a1e57 Mon Sep 17 00:00:00 2001 From: kamille Date: Tue, 28 Jan 2025 05:33:33 +0800 Subject: [PATCH 18/21] use `OffsetBuffer::new_unchecked` in `convert_to_state`. --- datafusion/functions-aggregate/src/median.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 0ceb83fcfbd1..11962651f386 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -382,7 +382,7 @@ impl GroupsAccumulator for MedianGroupsAccumulator GroupsAccumulator for MedianGroupsAccumulator>(); - let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets)); + // Safety: all checks in `OffsetBuffer::new` are ensured to pass + let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) }; // `nulls` for converted `ListArray` let nulls = filtered_null_mask(opt_filter, input_array); From 6f56a6353bfc15b20582c6683698e72c8617a14b Mon Sep 17 00:00:00 2001 From: kamille Date: Tue, 28 Jan 2025 05:50:11 +0800 Subject: [PATCH 19/21] add todo. --- datafusion/functions-aggregate/src/median.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 11962651f386..5a5f84b0d937 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -387,6 +387,13 @@ impl GroupsAccumulator for MedianGroupsAccumulator Date: Wed, 29 Jan 2025 13:53:20 +0800 Subject: [PATCH 20/21] remove assert and switch to i32 try from. --- datafusion/functions-aggregate/src/median.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 5a5f84b0d937..5cb0f39ba7b3 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -450,8 +450,8 @@ impl GroupsAccumulator for MedianGroupsAccumulator>(); + let offset_end = i32::try_from(input_array.len()).unwrap(); + let offsets = (0..=offset_end).collect::>(); // Safety: all checks in `OffsetBuffer::new` are ensured to pass let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) }; From 5fd9d8e09d927170163382404b3a7ec44bd3c33a Mon Sep 17 00:00:00 2001 From: kamille Date: Wed, 29 Jan 2025 16:22:01 +0800 Subject: [PATCH 21/21] return error when try from failed. --- datafusion/functions-aggregate/src/median.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 5cb0f39ba7b3..defbbe737a9d 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -37,7 +37,9 @@ use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType}; -use datafusion_common::{internal_err, DataFusionError, HashSet, Result, ScalarValue}; +use datafusion_common::{ + internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue, +}; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, @@ -450,7 +452,11 @@ impl GroupsAccumulator for MedianGroupsAccumulator>(); // Safety: all checks in `OffsetBuffer::new` are ensured to pass let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };