Skip to content
280 changes: 272 additions & 8 deletions datafusion/physical-expr/src/aggregate/count_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,32 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{DataType, Field};
use arrow::datatypes::{DataType, Field, TimeUnit};
use arrow_array::types::{
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::PrimitiveArray;

use std::any::Any;
use std::cmp::Eq;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;

use ahash::RandomState;
use arrow::array::{Array, ArrayRef};
use std::collections::HashSet;

use crate::aggregate::utils::down_cast_any_ref;
use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;

type DistinctScalarValues = ScalarValue;
Expand Down Expand Up @@ -60,6 +71,18 @@ impl DistinctCount {
}
}

/// Expands to `Ok(Box::new(..))` around a freshly created
/// `NativeDistinctCountAccumulator` specialised for the Arrow primitive type
/// named by `$ARROW_TYPE` (for example `Int32Type`).
macro_rules! native_distinct_count_accumulator {
    ($ARROW_TYPE:ident) => {{
        let accumulator = NativeDistinctCountAccumulator::<$ARROW_TYPE>::new();
        Ok(Box::new(accumulator))
    }};
}

/// Expands to `Ok(Box::new(..))` around a freshly created
/// `FloatDistinctCountAccumulator` specialised for the Arrow float type named
/// by `$ARROW_TYPE` (for example `Float64Type`).
macro_rules! float_distinct_count_accumulator {
    ($ARROW_TYPE:ident) => {{
        let accumulator = FloatDistinctCountAccumulator::<$ARROW_TYPE>::new();
        Ok(Box::new(accumulator))
    }};
}

impl AggregateExpr for DistinctCount {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
Expand All @@ -83,10 +106,57 @@ impl AggregateExpr for DistinctCount {
}

/// Creates the accumulator used to evaluate this aggregate, selected by
/// `self.state_data_type`:
///
/// * integer, decimal, date, time and timestamp types use
///   `NativeDistinctCountAccumulator`;
/// * `Float16` / `Float32` / `Float64` use `FloatDistinctCountAccumulator`;
/// * every other type falls back to `DistinctCountAccumulator`.
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(DistinctCountAccumulator {
values: HashSet::default(),
state_data_type: self.state_data_type.clone(),
}))
use DataType::*;
use TimeUnit::*;

match &self.state_data_type {
// Types handled by the accumulator that stores native values directly.
Int8 => native_distinct_count_accumulator!(Int8Type),
Int16 => native_distinct_count_accumulator!(Int16Type),
Int32 => native_distinct_count_accumulator!(Int32Type),
Int64 => native_distinct_count_accumulator!(Int64Type),
UInt8 => native_distinct_count_accumulator!(UInt8Type),
UInt16 => native_distinct_count_accumulator!(UInt16Type),
UInt32 => native_distinct_count_accumulator!(UInt32Type),
UInt64 => native_distinct_count_accumulator!(UInt64Type),
Decimal128(_, _) => native_distinct_count_accumulator!(Decimal128Type),
Decimal256(_, _) => native_distinct_count_accumulator!(Decimal256Type),

Date32 => native_distinct_count_accumulator!(Date32Type),
Date64 => native_distinct_count_accumulator!(Date64Type),
Time32(Millisecond) => {
native_distinct_count_accumulator!(Time32MillisecondType)
}
Time32(Second) => {
native_distinct_count_accumulator!(Time32SecondType)
}
Time64(Microsecond) => {
native_distinct_count_accumulator!(Time64MicrosecondType)
}
Time64(Nanosecond) => {
native_distinct_count_accumulator!(Time64NanosecondType)
}
Timestamp(Microsecond, _) => {
native_distinct_count_accumulator!(TimestampMicrosecondType)
}
Timestamp(Millisecond, _) => {
native_distinct_count_accumulator!(TimestampMillisecondType)
}
Timestamp(Nanosecond, _) => {
native_distinct_count_accumulator!(TimestampNanosecondType)
}
Timestamp(Second, _) => {
native_distinct_count_accumulator!(TimestampSecondType)
}

// Float types use the accumulator that wraps each value in `Hashable`.
Float16 => float_distinct_count_accumulator!(Float16Type),
Float32 => float_distinct_count_accumulator!(Float32Type),
Float64 => float_distinct_count_accumulator!(Float64Type),
Copy link
Contributor

@Dandandan Dandandan Jan 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for now, but if we would like to do it for strings / bytes, we could do use a datastructure like this to get maximal performance:

/// Contains hashes and offsets for given hash (+ potential collisions), use `RawTable` for extra speed
uniques: HashMap<u64, SmallVec<u64; 1>>,
/// actual string/byte data, can be emitted cheaply / free
values: BufferBuilder<u8>,

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is similar to the idea in #7064

Maybe we can eventually use the same data structure (specialized for storing string values not using a String)


// Any data type not matched above uses the generic accumulator.
_ => Ok(Box::new(DistinctCountAccumulator {
values: HashSet::default(),
state_data_type: self.state_data_type.clone(),
})),
}
}

fn name(&self) -> &str {
Expand Down Expand Up @@ -192,6 +262,164 @@ impl Accumulator for DistinctCountAccumulator {
}
}

/// Distinct-count accumulator for Arrow primitive types whose native value is
/// itself `Eq + Hash`, so the values can be kept directly in a hash set.
#[derive(Debug)]
struct NativeDistinctCountAccumulator<T: ArrowPrimitiveType + Send>
where
    T::Native: Eq + Hash,
{
    /// Distinct native values collected so far.
    values: HashSet<T::Native, RandomState>,
}

impl<T> NativeDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
T::Native: Eq + Hash,
{
fn new() -> Self {
Self {
values: HashSet::default(),
}
}
}

impl<T> Accumulator for NativeDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send + Debug,
T::Native: Eq + Hash,
{
/// Returns the intermediate state as a one-element vector holding a single
/// `ScalarValue::List` built from the distinct values collected so far.
fn state(&self) -> Result<Vec<ScalarValue>> {
// Copy the distinct values into a primitive array; the order is whatever
// the hash set iteration yields.
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
self.values.iter().cloned(),
)) as ArrayRef;
// Wrap the primitive array via `array_into_list_array` so it can be carried
// in a `ScalarValue::List`.
let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👏 for @jayzhan211 for switching the native implementation of ScalarValue::List to use an ArrayRef

Ok(vec![ScalarValue::List(list)])
}

/// Adds every non-null value of the first input column to the distinct set.
///
/// An empty `values` slice is a no-op. Returns an error if the first column
/// cannot be downcast to `PrimitiveArray<T>`.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
    if let Some(column) = values.first() {
        let typed = as_primitive_array::<T>(column)?;
        // Flattening the `Option` items skips the null slots.
        for native in typed.iter().flatten() {
            self.values.insert(native);
        }
    }
    Ok(())
}

/// Merges partial states (as produced by `state`) into this accumulator.
///
/// An empty `states` slice is a no-op.
///
/// # Panics
///
/// Panics if `states` is non-empty and does not contain exactly one array.
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
    if states.is_empty() {
        return Ok(());
    }
    assert_eq!(
        states.len(),
        1,
        "count_distinct states must be single array"
    );

    let partial_states = as_list_array(&states[0])?;
    // Null list entries carry no values and are skipped.
    for partial in partial_states.iter().flatten() {
        let natives = as_primitive_array::<T>(&partial)?;
        self.values.extend(natives.values());
    }
    Ok(())
}

/// Returns the number of distinct values collected so far as an `Int64` scalar.
fn evaluate(&self) -> Result<ScalarValue> {
    let distinct = self.values.len() as i64;
    Ok(ScalarValue::Int64(Some(distinct)))
}

/// Estimates the memory used by this accumulator in bytes: the fixed-size
/// parts plus one native value per slot of hash-set capacity.
fn size(&self) -> usize {
    // NOTE(review): `size_of_val(self)` already includes the inline `values`
    // header, so that header is counted twice; kept as-is so the estimate is
    // unchanged.
    let fixed = std::mem::size_of_val(self) + std::mem::size_of_val(&self.values);
    let per_slot = std::mem::size_of::<T::Native>();
    fixed + per_slot * self.values.capacity()
}
}

/// Distinct-count accumulator for float types: each value is wrapped in
/// `Hashable` so that it can be stored in a hash set.
#[derive(Debug)]
struct FloatDistinctCountAccumulator<T: ArrowPrimitiveType + Send> {
    /// Distinct values collected so far, each wrapped in `Hashable`.
    values: HashSet<Hashable<T::Native>, RandomState>,
}

impl<T> FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
{
fn new() -> Self {
Self {
values: HashSet::default(),
}
}
}

impl<T> Accumulator for FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send + Debug,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
self.values.iter().map(|v| v.0),
)) as ArrayRef;
let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
Ok(vec![ScalarValue::List(list)])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}

let arr = as_primitive_array::<T>(&values[0])?;
arr.iter().for_each(|value| {
if let Some(value) = value {
self.values.insert(Hashable(value));
}
});

Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);

let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
let list = as_primitive_array::<T>(&list)?;
self.values
.extend(list.values().iter().map(|v| Hashable(*v)));
};
Ok(())
})
}

fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
}

fn size(&self) -> usize {
std::mem::size_of_val(self)
+ std::mem::size_of_val(&self.values)
+ (std::mem::size_of::<T::Native>() * self.values.capacity())
}
}

#[cfg(test)]
mod tests {
use crate::expressions::NoOp;
Expand All @@ -206,6 +434,8 @@ mod tests {
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::Decimal256Array;
use arrow_buffer::i256;
use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array};
use datafusion_common::internal_err;
use datafusion_common::DataFusionError;
Expand Down Expand Up @@ -367,6 +597,35 @@ mod tests {
}};
}

/// Runs `run_update_batch` on a `$ARRAY_TYPE` column of `i256` values that
/// contains the values 1, 2 and 3 with repeats and nulls, then checks that the
/// state holds exactly those three values and that the count is 3.
macro_rules! test_count_distinct_update_batch_bigint {
    ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
        let input: Vec<Option<$PRIM_TYPE>> = vec![
            Some(i256::from(1)),
            Some(i256::from(1)),
            None,
            Some(i256::from(3)),
            Some(i256::from(2)),
            None,
            Some(i256::from(2)),
            Some(i256::from(3)),
            Some(i256::from(1)),
        ];

        let column: ArrayRef = Arc::new($ARRAY_TYPE::from(input));
        let arrays = vec![column];

        let (states, result) = run_update_batch(&arrays)?;

        // The state is sorted before comparing so the check does not depend
        // on the order in which the accumulator emitted the values.
        let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE);
        state_vec.sort();

        let expected = vec![i256::from(1), i256::from(2), i256::from(3)];
        assert_eq!(states.len(), 1);
        assert_eq!(state_vec, expected);
        assert_eq!(result, ScalarValue::Int64(Some(3)));

        Ok(())
    }};
}

#[test]
fn count_distinct_update_batch_i8() -> Result<()> {
test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8)
Expand Down Expand Up @@ -417,6 +676,11 @@ mod tests {
test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64)
}

// Exercises the distinct count over a `Decimal256Array`, whose native value
// type is `i256`.
#[test]
fn count_distinct_update_batch_i256() -> Result<()> {
test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256)
}

#[test]
fn count_distinct_update_batch_boolean() -> Result<()> {
let get_count = |data: BooleanArray| -> Result<(Vec<bool>, i64)> {
Expand Down
22 changes: 2 additions & 20 deletions datafusion/physical-expr/src/aggregate/sum_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ use arrow::array::{Array, ArrayRef};
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::{ArrowNativeTypeOp, ArrowPrimitiveType};
use arrow_buffer::{ArrowNativeType, ToByteSlice};
use arrow_buffer::ArrowNativeType;
use std::collections::HashSet;

use crate::aggregate::sum::downcast_sum;
use crate::aggregate::utils::down_cast_any_ref;
use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::type_coercion::aggregates::sum_return_type;
Expand Down Expand Up @@ -119,24 +119,6 @@ impl PartialEq<dyn Any> for DistinctSum {
}
}

/// A newtype wrapper that provides `Hash` and `Eq` for the wrapped value so
/// that it can be stored in hash-based collections; introduced for floats.
#[derive(Copy, Clone)]
struct Hashable<T>(T);

impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
    /// Hashes the byte-slice representation of the wrapped value.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let Self(inner) = self;
        std::hash::Hash::hash(inner.to_byte_slice(), state)
    }
}

impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
fn eq(&self, other: &Self) -> bool {
self.0.is_eq(other.0)
}
}

// Marker impl so `Hashable` can be used as a hash-set element.
// NOTE(review): this relies on `is_eq` being reflexive for every value
// (including NaN) — confirm against the `ArrowNativeTypeOp` documentation.
impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}

struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
values: HashSet<Hashable<T::Native>, RandomState>,
data_type: DataType,
Expand Down
Loading