diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index 1c97d22ec79c1..b5eb36c3fac7c 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -207,8 +207,11 @@ impl GroupsAccumulatorAdapter { let state = &mut self.states[group_idx]; sizes_pre += state.size(); - let values_to_accumulate = - slice_and_maybe_filter(&values, opt_filter.as_ref(), offsets)?; + let values_to_accumulate = slice_and_maybe_filter( + &values, + opt_filter.as_ref().map(|f| f.as_boolean()), + offsets, + )?; (f)(state.accumulator.as_mut(), &values_to_accumulate)?; // clear out the state so they are empty for next @@ -290,6 +293,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { result } + // filtered_null_mask(opt_filter, &values); fn state(&mut self, emit_to: EmitTo) -> Result> { let vec_size_pre = self.states.allocated_size(); let states = emit_to.take_needed(&mut self.states); @@ -348,6 +352,46 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { fn size(&self) -> usize { self.allocation_bytes } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let num_rows = values[0].len(); + + // Each row has its respective group + let mut results = vec![]; + for row_idx in 0..num_rows { + // Create the empty accumulator for converting + let mut converted_accumulator = (self.factory)()?; + + // Convert row to states + let values_to_accumulate = + slice_and_maybe_filter(values, opt_filter, &[row_idx, row_idx + 1])?; + converted_accumulator.update_batch(&values_to_accumulate)?; + let states = converted_accumulator.state()?; + + // Resize results to have enough columns according to the converted states + results.resize_with(states.len(), || Vec::with_capacity(num_rows)); + + // Add the states to results + for 
(idx, state_val) in states.into_iter().enumerate() { + results[idx].push(state_val); + } + } + + let arrays = results + .into_iter() + .map(ScalarValue::iter_to_array) + .collect::>>()?; + + Ok(arrays) + } + + fn supports_convert_to_state(&self) -> bool { + true + } } /// Extension trait for [`Vec`] to account for allocations. @@ -384,7 +428,7 @@ fn get_filter_at_indices( // Copied from physical-plan pub(crate) fn slice_and_maybe_filter( aggr_array: &[ArrayRef], - filter_opt: Option<&ArrayRef>, + filter_opt: Option<&BooleanArray>, offsets: &[usize], ) -> Result> { let (offset, length) = (offsets[0], offsets[1] - offsets[0]); @@ -394,13 +438,12 @@ pub(crate) fn slice_and_maybe_filter( .collect(); if let Some(f) = filter_opt { - let filter_array = f.slice(offset, length); - let filter_array = filter_array.as_boolean(); + let filter = f.slice(offset, length); sliced_arrays .iter() .map(|array| { - compute::filter(array, filter_array).map_err(|e| arrow_datafusion_err!(e)) + compute::filter(&array, &filter).map_err(|e| arrow_datafusion_err!(e)) }) .collect() } else { diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt index ab1c7e78f1ffc..a2e51cffacf7e 100644 --- a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt +++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt @@ -133,6 +133,51 @@ GROUP BY 1, 2 ORDER BY 1 LIMIT 5; -2117946883 d -2117946883 NULL NULL NULL -2098805236 c -2098805236 NULL NULL NULL +query ITIIII +SELECT c5, c1, + MEDIAN(c5), + MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END), + MEDIAN(c5) FILTER (WHERE c1 = 'b'), + MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b') +FROM aggregate_test_100 +GROUP BY 1, 2 ORDER BY 1 LIMIT 5; +---- +-2141999138 c -2141999138 NULL NULL NULL +-2141451704 a -2141451704 -2141451704 NULL NULL +-2138770630 b -2138770630 NULL -2138770630 NULL +-2117946883 d -2117946883 NULL NULL NULL 
+-2098805236 c -2098805236 NULL NULL NULL + +query ITIIII +SELECT c5, c1, + APPROX_MEDIAN(c5), + APPROX_MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END), + APPROX_MEDIAN(c5) FILTER (WHERE c1 = 'b'), + APPROX_MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b') +FROM aggregate_test_100 +GROUP BY 1, 2 ORDER BY 1 LIMIT 5; +---- +-2141999138 c -2141999138 NULL NULL NULL +-2141451704 a -2141451704 -2141451704 NULL NULL +-2138770630 b -2138770630 NULL -2138770630 NULL +-2117946883 d -2117946883 NULL NULL NULL +-2098805236 c -2098805236 NULL NULL NULL + +query ITIIII +SELECT c5, c1, + APPROX_DISTINCT(c5), + APPROX_DISTINCT(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END), + APPROX_DISTINCT(c5) FILTER (WHERE c1 = 'b'), + APPROX_DISTINCT(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b') +FROM aggregate_test_100 +GROUP BY 1, 2 ORDER BY 1 LIMIT 5; +---- +-2141999138 c 1 0 0 0 +-2141451704 a 1 1 0 0 +-2138770630 b 1 0 1 0 +-2117946883 d 1 0 0 0 +-2098805236 c 1 0 0 0 + # FIXME: add bool_and(v3) column when issue fixed # ISSUE https://github.com/apache/datafusion/issues/11846 query TBBB rowsort @@ -222,6 +267,36 @@ SELECT c2, sum(c5), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; 4 16155718643 9.531112968922 5 6449337880 7.074412226677 +# Test median for int / float +query IIR +SELECT c2, median(c5), median(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 23971150 0.5922606 +2 -562486880 0.43422085 +3 240273900 0.40199697 +4 762932956 0.48515016 +5 604973998 0.49842384 + +# Test approx_median for int / float +query IIR +SELECT c2, approx_median(c5), approx_median(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 191655437 0.59926736 +2 -587831330 0.43230486 +3 240273900 0.40199697 +4 762932956 0.48515016 +5 593204320 0.5156586 + +# Test approx_distinct for varchar / int +query III +SELECT c2, approx_distinct(c1), approx_distinct(c5) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 5 22 +2 5 22 +3 5 19 
+4 5 23 +5 5 14 + # Test count with nullable fields query III SELECT c2, count(c3), count(c11) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; @@ -252,6 +327,36 @@ SELECT c2, sum(c3), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; 4 29 9.531112968922 5 -194 7.074412226677 +# Test median with nullable fields +query IIR +SELECT c2, median(c3), median(c11) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 12 0.6067944 +2 1 0.46076488 +3 14 0.40154034 +4 -17 0.48515016 +5 -35 0.5536642 + +# Test approx_median with nullable fields +query IIR +SELECT c2, approx_median(c3), approx_median(c11) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 12 0.6067944 +2 1 0.46076488 +3 14 0.40154034 +4 -7 0.48515016 +5 -39 0.5536642 + +# Test approx_distinct with nullable fields +query II +SELECT c2, approx_distinct(c3) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 19 +2 16 +3 13 +4 16 +5 12 + # Test avg for tinyint / float query TRR SELECT @@ -338,6 +443,48 @@ FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; 4 417 5 284 +# Test approx_distinct with filter +query III +SELECT + c2, + approx_distinct(c3) FILTER (WHERE c3 > 0), + approx_distinct(c3) FILTER (WHERE c11 > 10) +FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 13 0 +2 12 0 +3 11 0 +4 13 0 +5 5 0 + +# Test median with filter +query III +SELECT + c2, + median(c3) FILTER (WHERE c3 > 0), + median(c3) FILTER (WHERE c3 < 0) +FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 57 -56 +2 52 -60 +3 71 -74 +4 65 -69 +5 64 -59 + +# Test approx_median with filter +query III +SELECT + c2, + approx_median(c3) FILTER (WHERE c3 > 0), + approx_median(c3) FILTER (WHERE c3 < 0) +FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 57 -56 +2 52 -60 +3 71 -76 +4 65 -64 +5 64 -59 + # Test count with nullable fields and filter query III SELECT c2, @@ -421,6 +568,79 @@ FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; 4 -171 56 2.10740506649 1.939846396446 5 -86 -76 
1.8741710186 1.600569307804 +# Test approx_distinct with nullable fields and filter +query II +SELECT c2, + approx_distinct(c3) FILTER (WHERE c5 > 0) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 11 +2 6 +3 6 +4 11 +5 8 + +# Test approx_distinct with nullable fields and nullable filter +query II +SELECT c2, + approx_distinct(c3) FILTER (WHERE c11 > 0.5) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 10 +2 6 +3 3 +4 3 +5 6 + +# Test median with nullable fields and filter +query IIR +SELECT c2, + median(c3) FILTER (WHERE c5 > 0), + median(c11) FILTER (WHERE c5 < 0) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 -5 0.6623719 +2 15 0.52930677 +3 13 0.32792538 +4 -38 0.49774808 +5 -18 0.49842384 + +# Test median with nullable fields and nullable filter +query II +SELECT c2, + median(c3) FILTER (WHERE c11 > 0.5) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 33 +2 -29 +3 22 +4 -90 +5 -22 + +# Test approx_median with nullable fields and filter +query IIR +SELECT c2, + approx_median(c3) FILTER (WHERE c5 > 0), + approx_median(c11) FILTER (WHERE c5 < 0) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 -5 0.6623719 +2 12 0.52930677 +3 13 0.32792538 +4 -38 0.49774808 +5 -21 0.47652745 + +# Test approx_median with nullable fields and nullable filter +query II +SELECT c2, + approx_median(c3) FILTER (WHERE c11 > 0.5) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 35 +2 -29 +3 22 +4 -90 +5 -32 statement ok DROP TABLE aggregate_test_100_null;