diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 3debe21800232..293e679a4fd00 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -82,6 +82,10 @@ pub enum AggregateFunction { RegrSYY, /// Sum of products of pairs of numbers RegrSXY, + /// Continuous percentile + QuantileCont, + /// Discrete percentile + QuantileDisc, /// Approximate continuous percentile function ApproxPercentileCont, /// Approximate continuous percentile function with weight @@ -132,6 +136,8 @@ impl AggregateFunction { RegrSXX => "REGR_SXX", RegrSYY => "REGR_SYY", RegrSXY => "REGR_SXY", + QuantileCont => "QUANTILE_CONT", + QuantileDisc => "QUANTILE_DISC", ApproxPercentileCont => "APPROX_PERCENTILE_CONT", ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", ApproxMedian => "APPROX_MEDIAN", @@ -191,6 +197,8 @@ impl FromStr for AggregateFunction { "regr_sxx" => AggregateFunction::RegrSXX, "regr_syy" => AggregateFunction::RegrSYY, "regr_sxy" => AggregateFunction::RegrSXY, + "quantile_cont" => AggregateFunction::QuantileCont, + "quantile_disc" => AggregateFunction::QuantileDisc, // approximate "approx_distinct" => AggregateFunction::ApproxDistinct, "approx_median" => AggregateFunction::ApproxMedian, @@ -293,9 +301,10 @@ impl AggregateFunction { AggregateFunction::ApproxPercentileContWithWeight => { Ok(coerced_data_types[0].clone()) } - AggregateFunction::ApproxMedian | AggregateFunction::Median => { - Ok(coerced_data_types[0].clone()) - } + AggregateFunction::ApproxMedian + | AggregateFunction::Median + | AggregateFunction::QuantileCont + | AggregateFunction::QuantileDisc => Ok(coerced_data_types[0].clone()), AggregateFunction::Grouping => Ok(DataType::Int32), AggregateFunction::FirstValue | AggregateFunction::LastValue => { Ok(coerced_data_types[0].clone()) @@ -380,6 +389,16 @@ impl AggregateFunction { | AggregateFunction::RegrSXY => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } + AggregateFunction::QuantileCont | AggregateFunction::QuantileDisc => { + // signature: quantile_*(NUMERICS, float64) + Signature::one_of( + NUMERICS + .iter() + .map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64])) + .collect(), + Volatility::Immutable, + ) + } AggregateFunction::ApproxPercentileCont => { // Accept any numeric value paired with a float64 percentile let with_tdigest_size = NUMERICS.iter().map(|t| { diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 094b6e9da4bf4..5edee1fef6b16 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -213,6 +213,21 @@ pub fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::QuantileCont | AggregateFunction::QuantileDisc => { + let valid_arg0_types = [NUMERICS.to_vec(), vec![DataType::Null]].concat(); + let valid_arg1_types = NUMERICS; + let input_types_valid = // number of input already checked before + valid_arg0_types.contains(&input_types[0]) && valid_arg1_types.contains(&input_types[1]); + if !input_types_valid { + return plan_err!( + "The function {:?} does not support inputs of type {:?}, {:?}.", + agg_fun, + input_types[0], + input_types[1] + ); + } + Ok(input_types.to_vec()) + } AggregateFunction::ApproxPercentileCont => { if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { return plan_err!( diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs 
b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs index 86d173be84d9b..3d96d2d15256b 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs @@ -17,7 +17,7 @@ use crate::aggregate::tdigest::TryIntoF64; use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE}; -use crate::aggregate::utils::down_cast_any_ref; +use crate::aggregate::utils::{down_cast_any_ref, validate_input_percentile_expr}; use crate::expressions::{format_state_name, Literal}; use crate::{AggregateExpr, PhysicalExpr}; use arrow::{ @@ -28,7 +28,6 @@ use arrow::{ datatypes::{DataType, Field}, }; use datafusion_common::internal_err; -use datafusion_common::plan_err; use datafusion_common::DataFusionError; use datafusion_common::Result; use datafusion_common::{downcast_value, ScalarValue}; @@ -132,35 +131,6 @@ impl PartialEq for ApproxPercentileCont { } } -fn validate_input_percentile_expr(expr: &Arc) -> Result { - // Extract the desired percentile literal - let lit = expr - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "desired percentile argument must be float literal".to_string(), - ) - })? - .value(); - let percentile = match lit { - ScalarValue::Float32(Some(q)) => *q as f64, - ScalarValue::Float64(Some(q)) => *q, - got => return Err(DataFusionError::NotImplemented(format!( - "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", - got.get_datatype() - ))) - }; - - // Ensure the percentile is between 0 and 1. - if !(0.0..=1.0).contains(&percentile) { - return plan_err!( - "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid" - ); - } - Ok(percentile) -} - fn validate_input_max_size_expr(expr: &Arc) -> Result { // Extract the desired percentile literal let lit = expr diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index bbccb6502665b..c615df6c03494 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -26,6 +26,7 @@ //! * Signature: see `Signature` //! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64. 
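Taken together, the signature and coercion hunks above admit any numeric value column paired with a `Float64` percentile literal. A hypothetical end-to-end smoke test of that contract — the `SessionContext` setup and table `t` are illustrative assumptions, not code from this diff:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

async fn quantile_signature_smoke() -> Result<()> {
    let ctx = SessionContext::new();
    // An Int64 column; the VALUES-derived column is named `column1`.
    ctx.sql("CREATE TABLE t AS VALUES (0), (50), (100)").await?;
    // Int64 value + Float64 percentile matches one of the
    // TypeSignature::Exact pairs declared in the aggregate signature.
    let df = ctx
        .sql("SELECT quantile_cont(column1, 0.2), quantile_disc(column1, 0.2) FROM t")
        .await?;
    df.show().await?; // expected: 20 (interpolated) and 0 (nearest data point)
    Ok(())
}
```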
+use crate::aggregate::percentile::PercentileInterpolationType; use crate::aggregate::regr::RegrType; use crate::{expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr}; use arrow::datatypes::Schema; @@ -329,6 +330,26 @@ pub fn create_aggregate_expr( fun ))); } + (AggregateFunction::QuantileCont, false) => Arc::new(expressions::Quantile::new( + name, + PercentileInterpolationType::Continuous, + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + rt_type, + )?), + (AggregateFunction::QuantileDisc, false) => Arc::new(expressions::Quantile::new( + name, + PercentileInterpolationType::Discrete, + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + rt_type, + )?), + (AggregateFunction::QuantileDisc | AggregateFunction::QuantileCont, true) => { + return Err(DataFusionError::NotImplemented(format!( + "{}(DISTINCT) aggregations are not available", + fun + ))); + } (AggregateFunction::ApproxPercentileCont, false) => { if input_phy_exprs.len() == 2 { Arc::new(expressions::ApproxPercentileCont::new( diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs deleted file mode 100644 index 2f60966093192..0000000000000 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ /dev/null @@ -1,399 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! # Median - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::{Array, ArrayRef, UInt32Array}; -use arrow::compute::sort_to_indices; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::internal_err; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// MEDIAN aggregate expression. This uses a lot of memory because all values need to be -/// stored in memory before a result can be computed. If an approximation is sufficient -/// then APPROX_MEDIAN provides a much more efficient solution. 
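To illustrate the dispatch above: `QUANTILE_CONT` builds a `Quantile` with `PercentileInterpolationType::Continuous`, `QUANTILE_DISC` with `Discrete`, and both reject `DISTINCT`. A minimal construction sketch of the resulting physical expression — `schema`, the column name `v`, and the `Float64` return type are illustrative assumptions:

```rust
use std::sync::Arc;
use arrow::datatypes::{DataType, Schema};
use datafusion_common::{Result, ScalarValue};
use crate::aggregate::percentile::PercentileInterpolationType;
use crate::expressions::{col, Literal, Quantile};

fn example_quantile_expr(schema: &Schema) -> Result<Arc<Quantile>> {
    Ok(Arc::new(Quantile::new(
        "quantile_cont(v, 0.9)",                 // display name
        PercentileInterpolationType::Continuous, // QUANTILE_CONT
        col("v", schema)?,                       // value expression
        Arc::new(Literal::new(ScalarValue::Float64(Some(0.9)))), // percentile
        DataType::Float64,                       // return type = input type
    )?))
}
```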
-#[derive(Debug)] -pub struct Median { - name: String, - expr: Arc, - data_type: DataType, -} - -impl Median { - /// Create a new MEDIAN aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - } - } -} - -impl AggregateExpr for Median { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(MedianAccumulator { - data_type: self.data_type.clone(), - arrays: vec![], - all_values: vec![], - })) - } - - fn state_fields(&self) -> Result> { - //Intermediate state is a list of the elements we have collected so far - let field = Field::new("item", self.data_type.clone(), true); - let data_type = DataType::List(Arc::new(field)); - - Ok(vec![Field::new( - format_state_name(&self.name, "median"), - data_type, - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Median { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -/// The median accumulator accumulates the raw input values -/// as `ScalarValue`s -/// -/// The intermediate state is represented as a List of scalar values updated by -/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values -/// in the final evaluation step so that we avoid expensive conversions and -/// allocations during `update_batch`. -struct MedianAccumulator { - data_type: DataType, - arrays: Vec, - all_values: Vec, -} - -impl Accumulator for MedianAccumulator { - fn state(&self) -> Result> { - let all_values = to_scalar_values(&self.arrays)?; - let state = ScalarValue::new_list(Some(all_values), self.data_type.clone()); - - Ok(vec![state]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - assert_eq!(values.len(), 1); - let array = &values[0]; - - // Defer conversions to scalar values to final evaluation. - assert_eq!(array.data_type(), &self.data_type); - self.arrays.push(array.clone()); - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - assert_eq!(states.len(), 1); - - let array = &states[0]; - assert!(matches!(array.data_type(), DataType::List(_))); - for index in 0..array.len() { - match ScalarValue::try_from_array(array, index)? { - ScalarValue::List(Some(mut values), _) => { - self.all_values.append(&mut values); - } - ScalarValue::List(None, _) => {} // skip empty state - v => { - return internal_err!( - "unexpected state in median. 
Expected DataType::List, got {v:?}" - ) - } - } - } - Ok(()) - } - - fn evaluate(&self) -> Result { - let batch_values = to_scalar_values(&self.arrays)?; - - if !self - .all_values - .iter() - .chain(batch_values.iter()) - .any(|v| !v.is_null()) - { - return ScalarValue::try_from(&self.data_type); - } - - // Create an array of all the non null values and find the - // sorted indexes - let array = ScalarValue::iter_to_array( - self.all_values - .iter() - .chain(batch_values.iter()) - // ignore null values - .filter(|v| !v.is_null()) - .cloned(), - )?; - - // find the mid point - let len = array.len(); - let mid = len / 2; - - // only sort up to the top size/2 elements - let limit = Some(mid + 1); - let options = None; - let indices = sort_to_indices(&array, options, limit)?; - - // pick the relevant indices in the original arrays - let result = if len >= 2 && len % 2 == 0 { - // even number of values, average the two mid points - let s1 = scalar_at_index(&array, &indices, mid - 1)?; - let s2 = scalar_at_index(&array, &indices, mid)?; - match s1.add(s2)? { - ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(v / 2)), - ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(v / 2)), - ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(v / 2)), - ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(v / 2)), - ScalarValue::UInt8(Some(v)) => ScalarValue::UInt8(Some(v / 2)), - ScalarValue::UInt16(Some(v)) => ScalarValue::UInt16(Some(v / 2)), - ScalarValue::UInt32(Some(v)) => ScalarValue::UInt32(Some(v / 2)), - ScalarValue::UInt64(Some(v)) => ScalarValue::UInt64(Some(v / 2)), - ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(v / 2.0)), - ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(v / 2.0)), - ScalarValue::Decimal128(Some(v), p, s) => { - ScalarValue::Decimal128(Some(v / 2), p, s) - } - v => { - return internal_err!("Unsupported type in MedianAccumulator: {v:?}") - } - } - } else { - // odd number of values, pick that one - scalar_at_index(&array, &indices, mid)? 
- }; - - Ok(result) - } - - fn size(&self) -> usize { - let arrays_size: usize = self.arrays.iter().map(|a| a.len()).sum(); - - std::mem::size_of_val(self) - + ScalarValue::size_of_vec(&self.all_values) - + arrays_size - - std::mem::size_of_val(&self.all_values) - + self.data_type.size() - - std::mem::size_of_val(&self.data_type) - } -} - -fn to_scalar_values(arrays: &[ArrayRef]) -> Result> { - let num_values: usize = arrays.iter().map(|a| a.len()).sum(); - let mut all_values = Vec::with_capacity(num_values); - - for array in arrays { - for index in 0..array.len() { - all_values.push(ScalarValue::try_from_array(&array, index)?); - } - } - - Ok(all_values) -} - -/// Given a returns `array[indicies[indicie_index]]` as a `ScalarValue` -fn scalar_at_index( - array: &dyn Array, - indices: &UInt32Array, - indicies_index: usize, -) -> Result { - let array_index = indices - .value(indicies_index) - .try_into() - .expect("Convert uint32 to usize"); - ScalarValue::try_from_array(array, array_index) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op; - use arrow::record_batch::RecordBatch; - use arrow::{array::*, datatypes::*}; - use datafusion_common::Result; - - #[test] - fn median_decimal() -> Result<()> { - // test median - let array: ArrayRef = Arc::new( - (1..7) - .map(Some) - .collect::() - .with_precision_and_scale(10, 4)?, - ); - - generic_test_op!( - array, - DataType::Decimal128(10, 4), - Median, - ScalarValue::Decimal128(Some(3), 10, 4) - ) - } - - #[test] - fn median_decimal_with_nulls() -> Result<()> { - let array: ArrayRef = Arc::new( - (1..6) - .map(|i| if i == 2 { None } else { Some(i) }) - .collect::() - .with_precision_and_scale(10, 4)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 4), - Median, - ScalarValue::Decimal128(Some(3), 10, 4) - ) - } - - #[test] - fn median_decimal_all_nulls() -> Result<()> { - // test median - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(6) - .collect::() - .with_precision_and_scale(10, 4)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 4), - Median, - ScalarValue::Decimal128(None, 10, 4) - ) - } - - #[test] - fn median_i32_odd() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3_i32)) - } - - #[test] - fn median_i32_even() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])); - generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3_i32)) - } - - #[test] - fn median_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3i32)) - } - - #[test] - fn median_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, Median, ScalarValue::Int32(None)) - } - - #[test] - fn median_u32_odd() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, Median, ScalarValue::from(3u32)) - } - - #[test] - fn median_u32_even() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from(vec![ - 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, 6_u32, - ])); - generic_test_op!(a, DataType::UInt32, Median, ScalarValue::from(3u32)) - } - - #[test] - fn median_f32_odd() -> 
Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, Median, ScalarValue::from(3_f32)) - } - - #[test] - fn median_f32_even() -> Result<()> { - let a: ArrayRef = Arc::new(Float32Array::from(vec![ - 1_f32, 2_f32, 3_f32, 4_f32, 5_f32, 6_f32, - ])); - generic_test_op!(a, DataType::Float32, Median, ScalarValue::from(3.5_f32)) - } - - #[test] - fn median_f64_odd() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Median, ScalarValue::from(3_f64)) - } - - #[test] - fn median_f64_even() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![ - 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, 6_f64, - ])); - generic_test_op!(a, DataType::Float64, Median, ScalarValue::from(3.5_f64)) - } -} diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 69918cfac268c..dd97675ab9eb3 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -42,7 +42,7 @@ pub(crate) mod count_distinct; pub(crate) mod covariance; pub(crate) mod first_last; pub(crate) mod grouping; -pub(crate) mod median; +pub(crate) mod percentile; #[macro_use] pub(crate) mod min_max; pub mod build_in; diff --git a/datafusion/physical-expr/src/aggregate/percentile.rs b/datafusion/physical-expr/src/aggregate/percentile.rs new file mode 100644 index 0000000000000..29a0b02ae2b8d --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/percentile.rs @@ -0,0 +1,712 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Percentile + +use crate::aggregate::utils::{down_cast_any_ref, validate_input_percentile_expr}; +use crate::expressions::format_state_name; +use crate::{AggregateExpr, PhysicalExpr}; +use arrow::array::{Array, ArrayRef, UInt32Array}; +use arrow::compute::sort_to_indices; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::Accumulator; +use std::any::Any; +use std::convert::TryFrom; +use std::sync::Arc; + +#[derive(PartialEq, Debug, Clone, Copy)] +/// Enum representing if interpolation is used for the percentile aggregate expression. +pub enum PercentileInterpolationType { + /// Interpolates between adjacent values if the desired percentile lies between them. + Continuous, + /// Always returns an actual data point from the dataset. + Discrete, +} + +/// QUANTILE_CONT/QUANTILE_DISC expression +/// +/// This uses a lot of memory because all values need to be +/// stored in memory before a result can be computed. 
If an approximation is sufficient +/// then APPROX_PERCENTILE_CONT provides a much more efficient solution. +#[derive(Debug)] +pub struct Quantile { + name: String, + quantile_type: PercentileInterpolationType, + expr_value: Arc, + percentile_score: f64, + data_type: DataType, +} + +impl Quantile { + pub fn new( + name: impl Into, + quantile_type: PercentileInterpolationType, + expr_value: Arc, + expr_percentile_score: Arc, + data_type: DataType, + ) -> Result { + let percentile_score = validate_input_percentile_expr(&expr_percentile_score)?; + + Ok(Self { + name: name.into(), + quantile_type, + expr_value, + percentile_score, + data_type, + }) + } +} + +impl AggregateExpr for Quantile { + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(PercentileAccumulator { + percentile_score: self.percentile_score, + interpolation_type: self.quantile_type, + data_type: self.data_type.clone(), + arrays: vec![], + all_values: vec![], + })) + } + + fn state_fields(&self) -> Result> { + //Intermediate state is a list of the elements we have collected so far + let field = Field::new("item", self.data_type.clone(), true); + let data_type = DataType::List(Arc::new(field)); + + Ok(vec![Field::new( + format_state_name(&self.name, "median"), + data_type, + true, + )]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr_value.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq for Quantile { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.data_type == x.data_type + && self.expr_value.eq(&x.expr_value) + && self.quantile_type == x.quantile_type + && self.percentile_score == x.percentile_score + }) + .unwrap_or(false) + } +} + +/// MEDIAN aggregate expression. +/// MEDIAN(x) is equivalent to QUANTILE_CONT(x, 0.5) +/// +/// This uses a lot of memory because all values need to be +/// stored in memory before a result can be computed. If an approximation is sufficient +/// then APPROX_MEDIAN provides a much more efficient solution. 
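The `Quantile` expression above validates its percentile once at construction time and delegates the actual computation to the shared `PercentileAccumulator`. As a reference for the dp1/dp2 bracketing that `PercentileAccumulator::evaluate()` performs further down, here is a standalone f64-only sketch of the same arithmetic (the `interpolate_logic!` macro generalizes it across the numeric types):

```rust
/// Continuous quantile over already-sorted, non-empty f64 data; a simplified
/// model of the index math in `PercentileAccumulator::evaluate()`.
fn quantile_cont_sorted(sorted: &[f64], q: f64) -> f64 {
    assert!(!sorted.is_empty() && (0.0..=1.0).contains(&q));
    let n = sorted.len();
    if n == 1 || q == 1.0 {
        // a single point, or the maximum: nothing to interpolate
        return sorted[(q * (n - 1) as f64) as usize];
    }
    // dp1 and dp2 bracket the target: q lies in [p1, p2)
    let dp1 = (q * (n - 1) as f64) as usize;
    let dp2 = dp1 + 1;
    let gap = 1.0 / (n - 1) as f64; // percentile distance between neighbors
    let (p1, p2) = (dp1 as f64 * gap, dp2 as f64 * gap);
    sorted[dp1] + (q - p1) / (p2 - p1) * (sorted[dp2] - sorted[dp1])
}
```

For example, `quantile_cont_sorted(&[0.0, 50.0, 100.0], 0.2)` yields `20.0`, matching the sqllogictest expectations added later in this diff.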
+#[derive(Debug)]
+pub struct Median {
+    name: String,
+    expr: Arc<dyn PhysicalExpr>,
+    data_type: DataType,
+}
+
+impl Median {
+    /// Create a new MEDIAN aggregate function
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            expr,
+            data_type,
+        }
+    }
+}
+
+impl AggregateExpr for Median {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, self.data_type.clone(), true))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(PercentileAccumulator {
+            percentile_score: 0.5,
+            interpolation_type: PercentileInterpolationType::Continuous,
+            data_type: self.data_type.clone(),
+            arrays: vec![],
+            all_values: vec![],
+        }))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        // Intermediate state is a list of the elements we have collected so far
+        let field = Field::new("item", self.data_type.clone(), true);
+        let data_type = DataType::List(Arc::new(field));
+
+        Ok(vec![Field::new(
+            format_state_name(&self.name, "median"),
+            data_type,
+            true,
+        )])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl PartialEq<dyn Any> for Median {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| {
+                self.name == x.name
+                    && self.data_type == x.data_type
+                    && self.expr.eq(&x.expr)
+            })
+            .unwrap_or(false)
+    }
+}
+
+#[derive(Debug)]
+/// The accumulator for the median/quantile_cont/quantile_disc aggregate functions.
+/// It accumulates the raw input values as `ScalarValue`s
+///
+/// The intermediate state is represented as a List of scalar values updated by
+/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values
+/// in the final evaluation step so that we avoid expensive conversions and
+/// allocations during `update_batch`.
+struct PercentileAccumulator {
+    percentile_score: f64,
+    interpolation_type: PercentileInterpolationType,
+    data_type: DataType,
+    arrays: Vec<ArrayRef>,
+    all_values: Vec<ScalarValue>,
+}
+
+macro_rules! safe_average {
+    (f32, $v1:expr, $v2:expr) => {
+        $v1 / 2.0 + $v2 / 2.0
+    };
+    (f64, $v1:expr, $v2:expr) => {
+        $v1 / 2.0 + $v2 / 2.0
+    };
+    ($val_type:ty, $v1:expr, $v2:expr) => {
+        match $v1.checked_add($v2) {
+            Some(sum) => sum / (2 as $val_type),
+            None => $v1 / (2 as $val_type) + $v2 / (2 as $val_type),
+        }
+    };
+}
+
+// Example: `target_percentile` is 0.12 and lands between dp1 and dp2:
+// dp1 has percentile 0.10 and value 0
+// dp2 has percentile 0.20 and value 100
+// `quantile_cont()` does linear interpolation:
+//     interpolation result = 0 + (0.12 - 0.10) / (0.20 - 0.10) * (100 - 0)
+//                          = 20
+// `quantile_disc()` chooses the closer dp (picking the one with the lower
+// percentile if equally close):
+//     `target_percentile` is closer to dp1's percentile, so the result = 0
+macro_rules!
interpolate_logic { + ($data_type:ident, $val_type:ident, $dp1_val:expr, $dp2_val:expr, $dp1_percentile:expr, $dp2_percentile:expr, $target_percentile:expr, $interpolation_type: expr) => {{ + if $dp1_percentile == $target_percentile { + ScalarValue::$data_type(Some($dp1_val)) + } else { + match $interpolation_type { + PercentileInterpolationType::Continuous => { + let (v1, v2) = ($dp1_val as $val_type, $dp2_val as $val_type); + let result = if $target_percentile == 0.5 { + // HACK: special-case median() + // float arithmetic for interpolation (in else branch) might get very lossy for + // $val_type like i8 + safe_average!($val_type, v1, v2) + } else { + v1 + (($target_percentile - $dp1_percentile) + / ($dp2_percentile - $dp1_percentile) + * ((v2 - v1) as f64)) as $val_type + }; + ScalarValue::$data_type(Some(result as $val_type)) + }, + PercentileInterpolationType::Discrete => { + let dp1_to_target = $target_percentile - $dp1_percentile; + let target_to_dp2 = $dp2_percentile - $target_percentile; + if dp1_to_target <= target_to_dp2 { + ScalarValue::$data_type(Some($dp1_val)) + } else { + ScalarValue::$data_type(Some($dp2_val)) + } + } + } + } + }}; +} + +impl Accumulator for PercentileAccumulator { + fn state(&self) -> Result> { + let all_values = to_scalar_values(&self.arrays)?; + let state = ScalarValue::new_list(Some(all_values), self.data_type.clone()); + + Ok(vec![state]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + assert_eq!(values.len(), 1); + let array = &values[0]; + + // Defer conversions to scalar values to final evaluation. + assert_eq!(array.data_type(), &self.data_type); + self.arrays.push(array.clone()); + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + assert_eq!(states.len(), 1); + + let array = &states[0]; + assert!(matches!(array.data_type(), DataType::List(_))); + for index in 0..array.len() { + match ScalarValue::try_from_array(array, index)? { + ScalarValue::List(Some(mut values), _) => { + self.all_values.append(&mut values); + } + ScalarValue::List(None, _) => {} // skip empty state + v => { + return internal_err!( + "unexpected state in median. 
Expected DataType::List, got {v:?}" + ) + } + } + } + Ok(()) + } + + fn evaluate(&self) -> Result { + let batch_values = to_scalar_values(&self.arrays)?; + + if !self + .all_values + .iter() + .chain(batch_values.iter()) + .any(|v| !v.is_null()) + { + return ScalarValue::try_from(&self.data_type); + } + + // Create an array of all the non null values and find the + // sorted indexes + let array = ScalarValue::iter_to_array( + self.all_values + .iter() + .chain(batch_values.iter()) + // ignore null values + .filter(|v| !v.is_null()) + .cloned(), + )?; + + let len = array.len(); + if len == 1 { + return ScalarValue::try_from_array(&array, 0); + } + + // Suppose target percentile score land between dp1 and dp2 in the sorted array + // self.percentile_score is in [dp1_percentile, dp2_percentile) + let dp1_index = (self.percentile_score * (len as f64 - 1_f64)) as usize; + let dp2_index = dp1_index + 1; + let percentile_per_gap = 1_f64 / ((len - 1) as f64); + let (dp1_percentile, dp2_percentile) = ( + dp1_index as f64 * percentile_per_gap, + dp2_index as f64 * percentile_per_gap, + ); + + // only sort up to the top len * self.percentile_score elements + let limit = Some(dp1_index + 2); + let options = None; + let indices = sort_to_indices(&array, options, limit)?; + + if self.percentile_score == 1.0 { + return scalar_at_index(&array, &indices, dp1_index); + } + + // pick the relevant indices in the original arrays + let result = { + let s1 = scalar_at_index(&array, &indices, dp1_index)?; + let s2 = scalar_at_index(&array, &indices, dp2_index)?; + match (s1, s2) { + (ScalarValue::Int8(Some(dp1_val)), ScalarValue::Int8(Some(dp2_val))) => { + interpolate_logic!( + Int8, + i8, + dp1_val, + dp2_val, + dp1_percentile, + dp2_percentile, + self.percentile_score, + self.interpolation_type + ) + } + ( + ScalarValue::Int16(Some(dp1_val)), + ScalarValue::Int16(Some(dp2_val)), + ) => { + interpolate_logic!( + Int16, + i16, + dp1_val, + dp2_val, + dp1_percentile, + dp2_percentile, + self.percentile_score, + self.interpolation_type + ) + } + (ScalarValue::Int32(Some(v1)), ScalarValue::Int32(Some(v2))) => { + interpolate_logic!( + Int32, + i32, + v1, + v2, + dp1_percentile, + dp2_percentile, + self.percentile_score, + self.interpolation_type + ) + } + (ScalarValue::Int64(Some(v1)), ScalarValue::Int64(Some(v2))) => { + interpolate_logic!( + Int64, + i64, + v1, + v2, + dp1_percentile, + dp2_percentile, + self.percentile_score, + self.interpolation_type + ) + } + + (ScalarValue::UInt8(Some(v1)), ScalarValue::UInt8(Some(v2))) => { + interpolate_logic!( + UInt8, + u8, + v1, + v2, + dp1_percentile, + dp2_percentile, + self.percentile_score, + self.interpolation_type + ) + } + + (ScalarValue::UInt16(Some(v1)), ScalarValue::UInt16(Some(v2))) => { + interpolate_logic!( + UInt16, + u16, + v1, + v2, + dp1_percentile, + dp2_percentile, + self.percentile_score, + self.interpolation_type + ) + } + + (ScalarValue::UInt32(Some(v1)), ScalarValue::UInt32(Some(v2))) => { + interpolate_logic!( + UInt32, + u32, + v1, + v2, + dp1_percentile, + dp2_percentile, + self.percentile_score, + self.interpolation_type + ) + } + + (ScalarValue::UInt64(Some(v1)), ScalarValue::UInt64(Some(v2))) => { + interpolate_logic!( + UInt64, + u64, + v1, + v2, + dp1_percentile, + dp2_percentile, + self.percentile_score, + self.interpolation_type + ) + } + + (ScalarValue::Float32(Some(v1)), ScalarValue::Float32(Some(v2))) => { + interpolate_logic!( + Float32, + f32, + v1, + v2, + dp1_percentile, + dp2_percentile, + self.percentile_score, + 
self.interpolation_type + ) + } + + (ScalarValue::Float64(Some(v1)), ScalarValue::Float64(Some(v2))) => { + interpolate_logic!( + Float64, + f64, + v1, + v2, + dp1_percentile, + dp2_percentile, + self.percentile_score, + self.interpolation_type + ) + } + + ( + s1 @ ScalarValue::Decimal128(_, _, _), + s2 @ ScalarValue::Decimal128(_, _, _), + ) => { + // HACK: Decimal is now only supported in median() aggregate function + let is_median = self.percentile_score == 0.5 + && self.interpolation_type + == PercentileInterpolationType::Continuous; + if is_median { + if let ScalarValue::Decimal128(Some(v), p, s) = s1.add(s2)? { + ScalarValue::Decimal128(Some(v / 2), p, s) + } else { + return internal_err!("{}", "Unreachable".to_string()); + } + } else { + return internal_err!("{}", "Decimal type not supported in quantile_cont() or quantile_disc() aggregate function".to_string()); + } + } + (scalar_value, _) => { + return internal_err!( + "{}", + format!( + "Unsupported type in PercentileAccumulator: {scalar_value:?}" + ) + ); + } + } + }; + + Ok(result) + } + + fn size(&self) -> usize { + let arrays_size: usize = self.arrays.iter().map(|a| a.len()).sum(); + + std::mem::size_of_val(self) + + ScalarValue::size_of_vec(&self.all_values) + + arrays_size + - std::mem::size_of_val(&self.all_values) + + self.data_type.size() + - std::mem::size_of_val(&self.data_type) + } +} + +fn to_scalar_values(arrays: &[ArrayRef]) -> Result> { + let num_values: usize = arrays.iter().map(|a| a.len()).sum(); + let mut all_values = Vec::with_capacity(num_values); + + for array in arrays { + for index in 0..array.len() { + all_values.push(ScalarValue::try_from_array(&array, index)?); + } + } + + Ok(all_values) +} + +/// Given a returns `array[indicies[indicie_index]]` as a `ScalarValue` +fn scalar_at_index( + array: &dyn Array, + indices: &UInt32Array, + indicies_index: usize, +) -> Result { + let array_index = indices + .value(indicies_index) + .try_into() + .expect("Convert uint32 to usize"); + ScalarValue::try_from_array(array, array_index) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::col; + use crate::expressions::tests::aggregate; + use crate::generic_test_op; + use arrow::record_batch::RecordBatch; + use arrow::{array::*, datatypes::*}; + use datafusion_common::Result; + + #[test] + fn median_decimal() -> Result<()> { + // test median + let array: ArrayRef = Arc::new( + (1..7) + .map(Some) + .collect::() + .with_precision_and_scale(10, 4)?, + ); + + generic_test_op!( + array, + DataType::Decimal128(10, 4), + Median, + ScalarValue::Decimal128(Some(3), 10, 4) + ) + } + + #[test] + fn median_decimal_with_nulls() -> Result<()> { + let array: ArrayRef = Arc::new( + (1..6) + .map(|i| if i == 2 { None } else { Some(i) }) + .collect::() + .with_precision_and_scale(10, 4)?, + ); + generic_test_op!( + array, + DataType::Decimal128(10, 4), + Median, + ScalarValue::Decimal128(Some(3), 10, 4) + ) + } + + #[test] + fn median_decimal_all_nulls() -> Result<()> { + // test median + let array: ArrayRef = Arc::new( + std::iter::repeat::>(None) + .take(6) + .collect::() + .with_precision_and_scale(10, 4)?, + ); + generic_test_op!( + array, + DataType::Decimal128(10, 4), + Median, + ScalarValue::Decimal128(None, 10, 4) + ) + } + + #[test] + fn median_i32_odd() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); + generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3_i32)) + } + + #[test] + fn median_i32_even() -> Result<()> { + let a: ArrayRef = 
Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])); + generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3_i32)) + } + + #[test] + fn median_i32_with_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + Some(4), + Some(5), + ])); + generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3i32)) + } + + #[test] + fn median_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); + generic_test_op!(a, DataType::Int32, Median, ScalarValue::Int32(None)) + } + + #[test] + fn median_u32_odd() -> Result<()> { + let a: ArrayRef = + Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); + generic_test_op!(a, DataType::UInt32, Median, ScalarValue::from(3u32)) + } + + #[test] + fn median_u32_even() -> Result<()> { + let a: ArrayRef = Arc::new(UInt32Array::from(vec![ + 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, 6_u32, + ])); + generic_test_op!(a, DataType::UInt32, Median, ScalarValue::from(3u32)) + } + + #[test] + fn median_f32_odd() -> Result<()> { + let a: ArrayRef = + Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); + generic_test_op!(a, DataType::Float32, Median, ScalarValue::from(3_f32)) + } + + #[test] + fn median_f32_even() -> Result<()> { + let a: ArrayRef = Arc::new(Float32Array::from(vec![ + 1_f32, 2_f32, 3_f32, 4_f32, 5_f32, 6_f32, + ])); + generic_test_op!(a, DataType::Float32, Median, ScalarValue::from(3.5_f32)) + } + + #[test] + fn median_f64_odd() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); + generic_test_op!(a, DataType::Float64, Median, ScalarValue::from(3_f64)) + } + + #[test] + fn median_f64_even() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from(vec![ + 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, 6_f64, + ])); + generic_test_op!(a, DataType::Float64, Median, ScalarValue::from(3.5_f64)) + } +} diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index e86eb1dc1fc51..6f2f492dd732c 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -17,14 +17,14 @@ //! Utilities used in aggregates -use crate::{AggregateExpr, PhysicalSortExpr}; +use crate::expressions::Literal; +use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION}; use arrow_array::cast::AsArray; use arrow_array::types::Decimal128Type; use arrow_schema::{DataType, Field}; -use datafusion_common::internal_err; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; use std::any::Any; use std::sync::Arc; @@ -208,3 +208,36 @@ pub(crate) fn ordering_fields( }) .collect() } + +/// parse and validate percentile scores like 0.5 in `quantile_cont(data, 0.5);` +pub fn validate_input_percentile_expr(expr: &Arc) -> Result { + // Extract the desired percentile literal + let lit = expr + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "desired percentile argument must be float literal".to_string(), + ) + })? 
+ .value(); + let percentile = match lit { + ScalarValue::Float32(Some(q)) => *q as f64, + ScalarValue::Float64(Some(q)) => *q, + ScalarValue::Int64(Some(q)) => *q as f64, // to support 0 or 1 + got => { + return Err(DataFusionError::NotImplemented(format!( + "Percentile value must be Float32 or Float64 literal (got data type {})", + got.get_datatype() + ))) + } + }; + + // Ensure the percentile is between 0 and 1. + if !(0.0..=1.0).contains(&percentile) { + return plan_err!( + "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid" + ); + } + Ok(percentile) +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 022e0ae02ed38..03ac9db429411 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -56,9 +56,9 @@ pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; pub use crate::aggregate::first_last::{FirstValue, LastValue}; pub use crate::aggregate::grouping::Grouping; -pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; +pub use crate::aggregate::percentile::{Median, Quantile}; pub use crate::aggregate::regr::Regr; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::stddev::{Stddev, StddevPop}; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index e4ef7b1bd4483..89d99d212360f 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -642,6 +642,8 @@ enum AggregateFunction { REGR_SXX = 32; REGR_SYY = 33; REGR_SXY = 34; + QUANTILE_CONT = 35; + QUANTILE_DISC = 36; } message AggregateExprNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f1a9e9c7bb74c..3f9aff9569c2d 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -474,6 +474,8 @@ impl serde::Serialize for AggregateFunction { Self::RegrSxx => "REGR_SXX", Self::RegrSyy => "REGR_SYY", Self::RegrSxy => "REGR_SXY", + Self::QuantileCont => "QUANTILE_CONT", + Self::QuantileDisc => "QUANTILE_DISC", }; serializer.serialize_str(variant) } @@ -520,6 +522,8 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SXX", "REGR_SYY", "REGR_SXY", + "QUANTILE_CONT", + "QUANTILE_DISC", ]; struct GeneratedVisitor; @@ -597,6 +601,8 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SXX" => Ok(AggregateFunction::RegrSxx), "REGR_SYY" => Ok(AggregateFunction::RegrSyy), "REGR_SXY" => Ok(AggregateFunction::RegrSxy), + "QUANTILE_CONT" => Ok(AggregateFunction::QuantileCont), + "QUANTILE_DISC" => Ok(AggregateFunction::QuantileDisc), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 6cf402fe66e95..6db805f37e0fa 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2664,6 +2664,8 @@ pub enum AggregateFunction { RegrSxx = 32, RegrSyy = 33, RegrSxy = 34, + QuantileCont = 35, + QuantileDisc = 36, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. 
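The relocated `validate_input_percentile_expr` above now also accepts an `Int64` literal, so bare `0` and `1` work as percentiles, and it still rejects anything outside `[0, 1]` at plan time. A hypothetical unit-test sketch (not part of this diff) if placed alongside the function in `utils.rs`:

```rust
#[cfg(test)]
mod validate_percentile_tests {
    use super::*;
    use crate::expressions::Literal;
    use crate::PhysicalExpr;
    use std::sync::Arc;

    fn lit(v: ScalarValue) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(v))
    }

    #[test]
    fn percentile_literal_validation() {
        // Float literals inside [0, 1] are accepted as-is
        let p = validate_input_percentile_expr(&lit(ScalarValue::Float64(Some(0.5))));
        assert_eq!(p.unwrap(), 0.5);
        // Int64 is allowed so that bare 0 and 1 parse as percentiles
        let p = validate_input_percentile_expr(&lit(ScalarValue::Int64(Some(1))));
        assert_eq!(p.unwrap(), 1.0);
        // out-of-range values produce a plan error
        assert!(
            validate_input_percentile_expr(&lit(ScalarValue::Float64(Some(1.1)))).is_err()
        );
    }
}
```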
@@ -2709,6 +2711,8 @@ impl AggregateFunction { AggregateFunction::RegrSxx => "REGR_SXX", AggregateFunction::RegrSyy => "REGR_SYY", AggregateFunction::RegrSxy => "REGR_SXY", + AggregateFunction::QuantileCont => "QUANTILE_CONT", + AggregateFunction::QuantileDisc => "QUANTILE_DISC", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2751,6 +2755,8 @@ impl AggregateFunction { "REGR_SXX" => Some(Self::RegrSxx), "REGR_SYY" => Some(Self::RegrSyy), "REGR_SXY" => Some(Self::RegrSxy), + "QUANTILE_CONT" => Some(Self::QuantileCont), + "QUANTILE_DISC" => Some(Self::QuantileDisc), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index d3329c696764a..34027872ee492 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -563,6 +563,8 @@ impl From for AggregateFunction { protobuf::AggregateFunction::RegrSxx => Self::RegrSXX, protobuf::AggregateFunction::RegrSyy => Self::RegrSYY, protobuf::AggregateFunction::RegrSxy => Self::RegrSXY, + protobuf::AggregateFunction::QuantileCont => Self::QuantileCont, + protobuf::AggregateFunction::QuantileDisc => Self::QuantileDisc, protobuf::AggregateFunction::ApproxPercentileCont => { Self::ApproxPercentileCont } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index cb3296438165f..da5a0f9d0ae50 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -393,6 +393,8 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::RegrSXX => Self::RegrSxx, AggregateFunction::RegrSYY => Self::RegrSyy, AggregateFunction::RegrSXY => Self::RegrSxy, + AggregateFunction::QuantileCont => Self::QuantileCont, + AggregateFunction::QuantileDisc => Self::QuantileDisc, AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, AggregateFunction::ApproxPercentileContWithWeight => { Self::ApproxPercentileContWithWeight @@ -704,6 +706,12 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, AggregateFunction::Median => protobuf::AggregateFunction::Median, + AggregateFunction::QuantileCont => { + protobuf::AggregateFunction::QuantileCont + } + AggregateFunction::QuantileDisc => { + protobuf::AggregateFunction::QuantileDisc + } AggregateFunction::FirstValue => { protobuf::AggregateFunction::FirstValueAgg } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index e881acf5755b1..c6c3af1443d3c 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2559,3 +2559,288 @@ NULL NULL 1 NULL 3 6 0 0 0 NULL NULL 1 NULL 5 15 0 0 0 3 0 2 1 5.5 16.5 0.5 4.5 1.5 3 0 3 1 6 18 2 18 6 + + + +# +# quantile_cont()/quantile_disc() tests +# + +# quantile_cont()/quantile_disc() invalid inputs +statement error +select quantile_cont(); + +statement error +select quantile_disc(1, 'foo'); + +statement error +select quantile_cont(1, 0.5, 2); + +statement error +select quantile_disc(1, 1.1); + +statement error +select quantile_cont(1, -0.1); + +statement error +select quantile_disc(1, NULL); + +statement error +select quantile_cont(NULL, NULL); + + + +# quantile_cont()/quantile_disc() scalar inputs +query IIIIII +select + quantile_cont(1, 0.0), quantile_cont(1, 0.5), quantile_cont(1, 1.0), + quantile_disc(1, 0.0), 
quantile_disc(1, 0.5), quantile_disc(1, 1.0) +---- +1 1 1 1 1 1 + +query RRRRRR +select + quantile_cont(9.9, 0.0), quantile_cont('NaN'::float, 0.5), quantile_cont('inf'::float, 1.0), + quantile_disc(9.9, 0.0), quantile_disc('NaN'::float, 0.5), quantile_disc('inf'::float, 1.0); +---- +9.9 NaN Infinity 9.9 NaN Infinity + +query II +select quantile_cont(1, 0), quantile_disc(1, 1); +---- +1 1 + + + +# quantile_cont()/quantile_disc() special floats + +# Values of 'NaN' are treated as greater than 'inf' +# exprs in select columns are all interpolation between 2 numbers +query RRRRR +SELECT + quantile_cont(column1, 0.1), + quantile_cont(column1, 0.3), + quantile_cont(column1, 0.5), + quantile_cont(column1, 0.7), + quantile_cont(column1, 0.9) +from (values (-0.0::float), (0.0::float), ('inf'::float), ('inf'::float), ('NaN'::float), ('NaN'::float)); +---- +0 Infinity Infinity NaN NaN + +query RRRRR +SELECT + quantile_disc(column1, 0.09), + quantile_disc(column1, 0.3), + quantile_disc(column1, 0.49), + quantile_disc(column1, 0.7), + quantile_disc(column1, 0.89) +from (values (-0.0::float), (0.0::float), ('inf'::float), ('inf'::float), ('NaN'::float), ('NaN'::float)); +---- +0 0 Infinity Infinity NaN + +query RRR +SELECT + quantile_cont(column1, 0.0), + quantile_cont(column1, 0.3), + quantile_cont(column1, 1.0) +from (values (0.0::float), ('NaN'::float)); +---- +0 NaN NaN + +query RRRR +SELECT + quantile_disc(column1, 0.0), + quantile_disc(column1, 0.3), + quantile_disc(column1, 0.6), + quantile_disc(column1, 1.0) +from (values (0.0::float), ('NaN'::float)); +---- +0 0 NaN NaN + + + +# quantile_cont()/quantile_disc() big number test (make sure no overflows) +statement ok +create table bignum ( + c1 integer unsigned not null, + c2 bigint not null, + c3 float not null, +); + +statement ok +insert into bignum values (4294967295, -9223372036854775808, 3.4e38), (4294967295, -9223372036854775808, 3.4e38); + +query IIIIIIIIIIRRRRR +select + quantile_cont(c1, 0.0), quantile_cont(c1, 0.1), median(c1), quantile_cont(c1, 0.7), quantile_cont(c1, 1.0), + quantile_cont(c2, 0.0), quantile_cont(c2, 0.1), median(c2), quantile_cont(c2, 0.7), quantile_cont(c2, 1.0), + quantile_cont(c3, 0.0), quantile_cont(c3, 0.1), median(c3), quantile_cont(c3, 0.7), quantile_cont(c3, 1.0) +from bignum +---- +4294967295 4294967295 4294967294 4294967295 4294967295 -9223372036854775808 -9223372036854775808 -9223372036854775808 -9223372036854775808 -9223372036854775808 340000000000000000000000000000000000000 340000000000000000000000000000000000000 340000000000000000000000000000000000000 340000000000000000000000000000000000000 340000000000000000000000000000000000000 + +query IIIIIIIIIIRRRRR +select + quantile_disc(c1, 0.0), quantile_disc(c1, 0.1), median(c1), quantile_disc(c1, 0.7), quantile_disc(c1, 1.0), + quantile_disc(c2, 0.0), quantile_disc(c2, 0.1), median(c2), quantile_disc(c2, 0.7), quantile_disc(c2, 1.0), + quantile_disc(c3, 0.0), quantile_disc(c3, 0.1), median(c3), quantile_disc(c3, 0.7), quantile_disc(c3, 1.0) +from bignum +---- +4294967295 4294967295 4294967294 4294967295 4294967295 -9223372036854775808 -9223372036854775808 -9223372036854775808 -9223372036854775808 -9223372036854775808 340000000000000000000000000000000000000 340000000000000000000000000000000000000 340000000000000000000000000000000000000 340000000000000000000000000000000000000 340000000000000000000000000000000000000 + +statement ok +drop table bignum; + + + +# quantile_cont()/quantile_disc() basic tests +query RRRRR +SELECT quantile_cont(column1, 0.0), 
quantile_cont(column1, 0.2), quantile_cont(column1, 0.5), quantile_cont(column1, 0.9), quantile_cont(column1, 1.0) +from (values (0.0::float), (100.0::float)); +---- +0 20 50 90 100 + +query RRRRR +select quantile_disc(column1, 0.0), quantile_disc(column1, 0.49), quantile_disc(column1, 0.5), quantile_disc(column1, 0.51), quantile_disc(column1, 1.0) +from (values (0.0::float), (100.0::float)); +---- +0 0 0 100 100 + +query RRRRR +SELECT quantile_cont(column1, 0.0), quantile_cont(column1, 0.2), quantile_cont(column1, 0.5), quantile_cont(column1, 0.9), quantile_cont(column1, 1.0) +from (values (0.0::float), (50.0::float), (100.0::float)); +---- +0 20 50 90 100 + +query RRRRRR +select quantile_disc(column1, 0.0), quantile_disc(column1, 0.24), quantile_disc(column1, 0.25), quantile_disc(column1, 0.5), quantile_disc(column1, 0.76), quantile_cont(column1, 1.0) +from (values (0.0::float), (50.0::float), (100.0::float)); +---- +0 0 0 50 100 100 + + + +# quantile_cont()/quantile_disc() testing merge_batch() from Accumulator's internal implementation +query IIIIIIIIIIIIIIIIIIIIIIIIRRRRRRRR +select + quantile_cont(c2, 0.3), quantile_cont(c2, 0.5), quantile_cont(c2, 0.9), quantile_cont(c2, 1.0), + quantile_cont(c3, 0.3), quantile_cont(c3, 0.5), quantile_cont(c3, 0.9), quantile_cont(c3, 1.0), + quantile_cont(c5, 0.3), quantile_cont(c5, 0.5), quantile_cont(c5, 0.9), quantile_cont(c5, 1.0), + quantile_cont(c6, 0.3), quantile_cont(c6, 0.5), quantile_cont(c6, 0.9), quantile_cont(c6, 1.0), + quantile_cont(c9, 0.3), quantile_cont(c9, 0.5), quantile_cont(c9, 0.9), quantile_cont(c9, 1.0), + quantile_cont(c10, 0.3), quantile_cont(c10, 0.5), quantile_cont(c10, 0.9), quantile_cont(c10, 1.0), + quantile_cont(c11, 0.3), quantile_cont(c11, 0.5), quantile_cont(c11, 0.9), quantile_cont(c11, 1.0), + quantile_cont(c12, 0.3), quantile_cont(c12, 0.5), quantile_cont(c12, 0.9), quantile_cont(c12, 1.0) +from aggregate_test_100; +---- +2 3 5 5 -46 15 102 125 -652929122 377164262 1991374995 2143473091 -2447222796583094622 1125553990140691277 7374910587160164343 9178511478067509438 1275641152 2365817607 3776538486 4268716378 4383947019179292752 9299860258734726870 16067113363455246373 17929716297117857676 0.34409833 0.4906719 0.8340051 0.9488028 0.295550335607 0.551390054439 0.946311358438 0.996540038759 + +query IIIIIIIIIIIIIIIIIIIIIIIIRRRRRRRR +select + quantile_disc(c2, 0.3), quantile_disc(c2, 0.5), quantile_disc(c2, 0.9), quantile_disc(c2, 1.0), + quantile_disc(c3, 0.3), quantile_disc(c3, 0.5), quantile_disc(c3, 0.9), quantile_disc(c3, 1.0), + quantile_disc(c5, 0.3), quantile_disc(c5, 0.5), quantile_disc(c5, 0.9), quantile_disc(c5, 1.0), + quantile_disc(c6, 0.3), quantile_disc(c6, 0.5), quantile_disc(c6, 0.9), quantile_disc(c6, 1.0), + quantile_disc(c9, 0.3), quantile_disc(c9, 0.5), quantile_disc(c9, 0.9), quantile_disc(c9, 1.0), + quantile_disc(c10, 0.3), quantile_disc(c10, 0.5), quantile_disc(c10, 0.9), quantile_disc(c10, 1.0), + quantile_disc(c11, 0.3), quantile_disc(c11, 0.5), quantile_disc(c11, 0.9), quantile_disc(c11, 1.0), + quantile_disc(c12, 0.3), quantile_disc(c12, 0.5), quantile_disc(c12, 0.9), quantile_disc(c12, 1.0) +from aggregate_test_100; +---- +2 3 5 5 -44 14 102 125 -644225469 370975815 1991172974 2143473091 -2390782464845307388 1080308211931669384 7373730676428214987 9178511478067509438 1289293657 2307004493 3766999078 4268716378 4403623840168496677 9135746610908713318 16060348691054629425 17929716297117857676 0.34515214 0.48515016 0.8315913 0.9488028 0.296036538665 0.543759554042 0.946309824388 
0.996540038759 + +statement ok +set datafusion.execution.batch_size = 1; + +query IIIIIIIIIIIIIIIIIIIIIIIIRRRRRRRR +select + quantile_cont(c2, 0.3), quantile_cont(c2, 0.5), quantile_cont(c2, 0.9), quantile_cont(c2, 1.0), + quantile_cont(c3, 0.3), quantile_cont(c3, 0.5), quantile_cont(c3, 0.9), quantile_cont(c3, 1.0), + quantile_cont(c5, 0.3), quantile_cont(c5, 0.5), quantile_cont(c5, 0.9), quantile_cont(c5, 1.0), + quantile_cont(c6, 0.3), quantile_cont(c6, 0.5), quantile_cont(c6, 0.9), quantile_cont(c6, 1.0), + quantile_cont(c9, 0.3), quantile_cont(c9, 0.5), quantile_cont(c9, 0.9), quantile_cont(c9, 1.0), + quantile_cont(c10, 0.3), quantile_cont(c10, 0.5), quantile_cont(c10, 0.9), quantile_cont(c10, 1.0), + quantile_cont(c11, 0.3), quantile_cont(c11, 0.5), quantile_cont(c11, 0.9), quantile_cont(c11, 1.0), + quantile_cont(c12, 0.3), quantile_cont(c12, 0.5), quantile_cont(c12, 0.9), quantile_cont(c12, 1.0) +from aggregate_test_100; +---- +2 3 5 5 -46 15 102 125 -652929122 377164262 1991374995 2143473091 -2447222796583094622 1125553990140691277 7374910587160164343 9178511478067509438 1275641152 2365817607 3776538486 4268716378 4383947019179292752 9299860258734726870 16067113363455246373 17929716297117857676 0.34409833 0.4906719 0.8340051 0.9488028 0.295550335607 0.551390054439 0.946311358438 0.996540038759 + +query IIIIIIIIIIIIIIIIIIIIIIIIRRRRRRRR +select + quantile_disc(c2, 0.3), quantile_disc(c2, 0.5), quantile_disc(c2, 0.9), quantile_disc(c2, 1.0), + quantile_disc(c3, 0.3), quantile_disc(c3, 0.5), quantile_disc(c3, 0.9), quantile_disc(c3, 1.0), + quantile_disc(c5, 0.3), quantile_disc(c5, 0.5), quantile_disc(c5, 0.9), quantile_disc(c5, 1.0), + quantile_disc(c6, 0.3), quantile_disc(c6, 0.5), quantile_disc(c6, 0.9), quantile_disc(c6, 1.0), + quantile_disc(c9, 0.3), quantile_disc(c9, 0.5), quantile_disc(c9, 0.9), quantile_disc(c9, 1.0), + quantile_disc(c10, 0.3), quantile_disc(c10, 0.5), quantile_disc(c10, 0.9), quantile_disc(c10, 1.0), + quantile_disc(c11, 0.3), quantile_disc(c11, 0.5), quantile_disc(c11, 0.9), quantile_disc(c11, 1.0), + quantile_disc(c12, 0.3), quantile_disc(c12, 0.5), quantile_disc(c12, 0.9), quantile_disc(c12, 1.0) +from aggregate_test_100; +---- +2 3 5 5 -44 14 102 125 -644225469 370975815 1991172974 2143473091 -2390782464845307388 1080308211931669384 7373730676428214987 9178511478067509438 1289293657 2307004493 3766999078 4268716378 4403623840168496677 9135746610908713318 16060348691054629425 17929716297117857676 0.34515214 0.48515016 0.8315913 0.9488028 0.296036538665 0.543759554042 0.946309824388 0.996540038759 + +statement ok +set datafusion.execution.batch_size = 2; + +query IIIIIIIIIIIIIIIIIIIIIIIIRRRRRRRR +select + quantile_cont(c2, 0.3), quantile_cont(c2, 0.5), quantile_cont(c2, 0.9), quantile_cont(c2, 1.0), + quantile_cont(c3, 0.3), quantile_cont(c3, 0.5), quantile_cont(c3, 0.9), quantile_cont(c3, 1.0), + quantile_cont(c5, 0.3), quantile_cont(c5, 0.5), quantile_cont(c5, 0.9), quantile_cont(c5, 1.0), + quantile_cont(c6, 0.3), quantile_cont(c6, 0.5), quantile_cont(c6, 0.9), quantile_cont(c6, 1.0), + quantile_cont(c9, 0.3), quantile_cont(c9, 0.5), quantile_cont(c9, 0.9), quantile_cont(c9, 1.0), + quantile_cont(c10, 0.3), quantile_cont(c10, 0.5), quantile_cont(c10, 0.9), quantile_cont(c10, 1.0), + quantile_cont(c11, 0.3), quantile_cont(c11, 0.5), quantile_cont(c11, 0.9), quantile_cont(c11, 1.0), + quantile_cont(c12, 0.3), quantile_cont(c12, 0.5), quantile_cont(c12, 0.9), quantile_cont(c12, 1.0) +from aggregate_test_100; +---- +2 3 5 5 -46 15 102 125 -652929122 
377164262 1991374995 2143473091 -2447222796583094622 1125553990140691277 7374910587160164343 9178511478067509438 1275641152 2365817607 3776538486 4268716378 4383947019179292752 9299860258734726870 16067113363455246373 17929716297117857676 0.34409833 0.4906719 0.8340051 0.9488028 0.295550335607 0.551390054439 0.946311358438 0.996540038759 + +query IIIIIIIIIIIIIIIIIIIIIIIIRRRRRRRR +select + quantile_disc(c2, 0.3), quantile_disc(c2, 0.5), quantile_disc(c2, 0.9), quantile_disc(c2, 1.0), + quantile_disc(c3, 0.3), quantile_disc(c3, 0.5), quantile_disc(c3, 0.9), quantile_disc(c3, 1.0), + quantile_disc(c5, 0.3), quantile_disc(c5, 0.5), quantile_disc(c5, 0.9), quantile_disc(c5, 1.0), + quantile_disc(c6, 0.3), quantile_disc(c6, 0.5), quantile_disc(c6, 0.9), quantile_disc(c6, 1.0), + quantile_disc(c9, 0.3), quantile_disc(c9, 0.5), quantile_disc(c9, 0.9), quantile_disc(c9, 1.0), + quantile_disc(c10, 0.3), quantile_disc(c10, 0.5), quantile_disc(c10, 0.9), quantile_disc(c10, 1.0), + quantile_disc(c11, 0.3), quantile_disc(c11, 0.5), quantile_disc(c11, 0.9), quantile_disc(c11, 1.0), + quantile_disc(c12, 0.3), quantile_disc(c12, 0.5), quantile_disc(c12, 0.9), quantile_disc(c12, 1.0) +from aggregate_test_100; +---- +2 3 5 5 -44 14 102 125 -644225469 370975815 1991172974 2143473091 -2390782464845307388 1080308211931669384 7373730676428214987 9178511478067509438 1289293657 2307004493 3766999078 4268716378 4403623840168496677 9135746610908713318 16060348691054629425 17929716297117857676 0.34515214 0.48515016 0.8315913 0.9488028 0.296036538665 0.543759554042 0.946309824388 0.996540038759 + +statement ok +set datafusion.execution.batch_size = 3; + +query IIIIIIIIIIIIIIIIIIIIIIIIRRRRRRRR +select + quantile_cont(c2, 0.3), quantile_cont(c2, 0.5), quantile_cont(c2, 0.9), quantile_cont(c2, 1.0), + quantile_cont(c3, 0.3), quantile_cont(c3, 0.5), quantile_cont(c3, 0.9), quantile_cont(c3, 1.0), + quantile_cont(c5, 0.3), quantile_cont(c5, 0.5), quantile_cont(c5, 0.9), quantile_cont(c5, 1.0), + quantile_cont(c6, 0.3), quantile_cont(c6, 0.5), quantile_cont(c6, 0.9), quantile_cont(c6, 1.0), + quantile_cont(c9, 0.3), quantile_cont(c9, 0.5), quantile_cont(c9, 0.9), quantile_cont(c9, 1.0), + quantile_cont(c10, 0.3), quantile_cont(c10, 0.5), quantile_cont(c10, 0.9), quantile_cont(c10, 1.0), + quantile_cont(c11, 0.3), quantile_cont(c11, 0.5), quantile_cont(c11, 0.9), quantile_cont(c11, 1.0), + quantile_cont(c12, 0.3), quantile_cont(c12, 0.5), quantile_cont(c12, 0.9), quantile_cont(c12, 1.0) +from aggregate_test_100; +---- +2 3 5 5 -46 15 102 125 -652929122 377164262 1991374995 2143473091 -2447222796583094622 1125553990140691277 7374910587160164343 9178511478067509438 1275641152 2365817607 3776538486 4268716378 4383947019179292752 9299860258734726870 16067113363455246373 17929716297117857676 0.34409833 0.4906719 0.8340051 0.9488028 0.295550335607 0.551390054439 0.946311358438 0.996540038759 + +query IIIIIIIIIIIIIIIIIIIIIIIIRRRRRRRR +select + quantile_disc(c2, 0.3), quantile_disc(c2, 0.5), quantile_disc(c2, 0.9), quantile_disc(c2, 1.0), + quantile_disc(c3, 0.3), quantile_disc(c3, 0.5), quantile_disc(c3, 0.9), quantile_disc(c3, 1.0), + quantile_disc(c5, 0.3), quantile_disc(c5, 0.5), quantile_disc(c5, 0.9), quantile_disc(c5, 1.0), + quantile_disc(c6, 0.3), quantile_disc(c6, 0.5), quantile_disc(c6, 0.9), quantile_disc(c6, 1.0), + quantile_disc(c9, 0.3), quantile_disc(c9, 0.5), quantile_disc(c9, 0.9), quantile_disc(c9, 1.0), + quantile_disc(c10, 0.3), quantile_disc(c10, 0.5), quantile_disc(c10, 0.9), quantile_disc(c10, 1.0), + 
quantile_disc(c11, 0.3), quantile_disc(c11, 0.5), quantile_disc(c11, 0.9), quantile_disc(c11, 1.0),
+    quantile_disc(c12, 0.3), quantile_disc(c12, 0.5), quantile_disc(c12, 0.9), quantile_disc(c12, 1.0)
+from aggregate_test_100;
+----
+2 3 5 5 -44 14 102 125 -644225469 370975815 1991172974 2143473091 -2390782464845307388 1080308211931669384 7373730676428214987 9178511478067509438 1289293657 2307004493 3766999078 4268716378 4403623840168496677 9135746610908713318 16060348691054629425 17929716297117857676 0.34515214 0.48515016 0.8315913 0.9488028 0.296036538665 0.543759554042 0.946309824388 0.996540038759
+
+statement ok
+set datafusion.execution.batch_size = 8192;
diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md
index 427a7bf130a77..1a0ab4eb75cc1 100644
--- a/docs/source/user-guide/sql/aggregate_functions.md
+++ b/docs/source/user-guide/sql/aggregate_functions.md
@@ -254,6 +254,8 @@ last_value(expression [ORDER BY expression])
 - [regr_sxx](#regr_sxx)
 - [regr_syy](#regr_syy)
 - [regr_sxy](#regr_sxy)
+- [quantile_cont](#quantile_cont)
+- [quantile_disc](#quantile_disc)
 
 ### `corr`
 
@@ -529,6 +531,32 @@ regr_sxy(expression_y, expression_x)
 - **expression_x**: Independent variable. Can be a constant, column, or function, and any combination of arithmetic operators.
 
+### `quantile_cont`
+
+Returns the interpolated quantile value of a dataset at the given percentile. If the target percentile falls between two data points, the result is linearly interpolated between them.
+
+```
+quantile_cont(expression_data, percentile)
+```
+
+#### Arguments
+
+- **expression_data**: Can be a constant, column, or function, and any combination of arithmetic operators.
+- **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive).
+
+### `quantile_disc`
+
+Returns the exact quantile value of a dataset at the given percentile. Instead of interpolating between data points, it returns the data point whose percentile is closest to the target percentile. (If the target percentile is equally close to two data points, the data point with the smaller percentile is returned.)
+
+```
+quantile_disc(expression_data, percentile)
+```
+
+#### Arguments
+
+- **expression_data**: Can be a constant, column, or function, and any combination of arithmetic operators.
+- **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive).
+
 ## Approximate
 
 - [approx_distinct](#approx_distinct)
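To make the continuous/discrete contrast concrete, a short usage example; the results follow the sqllogictest cases added earlier in this diff:

```sql
-- Two data points: 0.0 and 100.0
SELECT quantile_cont(column1, 0.5), quantile_disc(column1, 0.5)
FROM (VALUES (0.0), (100.0));
-- quantile_cont interpolates between the two points      -> 50.0
-- quantile_disc returns an existing data point; ties
-- resolve toward the lower percentile                    -> 0.0
```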