Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 186 additions & 1 deletion arrow-arith/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

//! Defines aggregations over Arrow arrays.

use arrow_array::cast::*;
use arrow_array::cast::{*};
use arrow_array::iterator::ArrayIter;
use arrow_array::*;
use arrow_buffer::{ArrowNativeType, NullBuffer};
Expand Down Expand Up @@ -573,6 +573,37 @@ where

Some(sum)
}
DataType::RunEndEncoded(run_field, _) => {
    // Sum run by run: every physical value contributes `value * run_length`
    // instead of being visited once per logical row.
    //
    // `AnyRunArray::new` yields `None` for unsupported run-end types, which is
    // reported as "no result", exactly like before.
    let ree = AnyRunArray::new(&array, run_field.data_type().clone())?;

    let values = ree.values();
    // A values array of an unexpected primitive type yields `None` rather than a panic.
    let values_array = values.as_any().downcast_ref::<PrimitiveArray<T>>()?;
    let values_data = values_array.values();

    // Logical window covered by this (possibly sliced) array, expressed in the
    // coordinates used by the physical run ends.
    let window_start = array.offset();
    let window_end = window_start + array.len();

    let mut sum = T::default_value();
    let mut saw_valid = false;
    let mut run_start = 0;
    for i in 0..ree.run_ends_len() {
        let run_end = ree.run_ends_value(i);
        // Clamp the run to the window so sliced arrays only count their own rows.
        let start = run_start.max(window_start);
        let end = run_end.min(window_end);
        run_start = run_end;

        // Skip runs outside the window and runs whose value is null.
        if start >= end || values_array.is_null(i) {
            continue;
        }

        // `usize_as` follows `as` semantics: it never fails (unlike `from_usize`,
        // which returns `None` for floating point types) and truncation is
        // harmless for a wrapping product.
        let run_length_native = T::Native::usize_as(end - start);
        sum = sum.add_wrapping(values_data[i].mul_wrapping(run_length_native));
        saw_valid = true;
    }

    // Mirror the primitive kernels: no valid row means no sum.
    saw_valid.then_some(sum)
}
_ => sum::<T>(as_primitive_array(&array)),
}
}
Expand Down Expand Up @@ -609,6 +640,42 @@ where

Ok(Some(sum))
}
DataType::RunEndEncoded(run_field, _) => {
    // Checked sum over a run-end encoded array, computed run by run.
    //
    // Unsupported run-end types keep reporting "no result".
    let Some(ree) = AnyRunArray::new(&array, run_field.data_type().clone()) else {
        return Ok(None);
    };

    let values = ree.values();
    // Report a mismatching values type as an error instead of panicking.
    let Some(values_array) = values.as_any().downcast_ref::<PrimitiveArray<T>>() else {
        return Err(arrow_schema::ArrowError::InvalidArgumentError(format!(
            "Cannot sum run-end encoded values of type {} as {}",
            values.data_type(),
            T::DATA_TYPE
        )));
    };
    let values_data = values_array.values();

    // `from_usize` only succeeds for integer types; floats need the `as`-style conversion.
    let is_float = matches!(
        T::DATA_TYPE,
        DataType::Float16 | DataType::Float32 | DataType::Float64
    );

    // Logical window covered by this (possibly sliced) array, expressed in the
    // coordinates used by the physical run ends.
    let window_start = array.offset();
    let window_end = window_start + array.len();

    let mut sum = T::default_value();
    let mut saw_valid = false;
    let mut run_start = 0;
    for i in 0..ree.run_ends_len() {
        let run_end = ree.run_ends_value(i);
        // Clamp the run to the window so sliced arrays only count their own rows.
        let start = run_start.max(window_start);
        let end = run_end.min(window_end);
        run_start = run_end;

        // Skip runs outside the window and runs whose value is null.
        if start >= end || values_array.is_null(i) {
            continue;
        }

        let value = values_data[i];
        let run_length = end - start;
        let multiplier = if is_float {
            Some(T::Native::usize_as(run_length))
        } else {
            T::Native::from_usize(run_length)
        };

        // Fast path: one multiplication and one addition per run.
        let fast = multiplier
            .and_then(|m| value.mul_checked(m).ok())
            .and_then(|product| sum.add_checked(product).ok());

        sum = match fast {
            Some(total) => total,
            None => {
                // The shortcut overflowed (or the run length does not fit the
                // native type). Replay the run element by element so the verdict
                // and the error match what a plain array would produce.
                let mut acc = sum;
                for _ in 0..run_length {
                    acc = acc.add_checked(value)?;
                }
                acc
            }
        };
        saw_valid = true;
    }

    // Mirror the primitive kernels: no valid row means no sum.
    Ok(saw_valid.then_some(sum))
}
_ => sum_checked::<T>(as_primitive_array(&array)),
}
}
Expand Down Expand Up @@ -645,6 +712,9 @@ where
{
match array.data_type() {
DataType::Dictionary(_, _) => min_max_helper::<T::Native, _, _>(array, cmp),
// Run-end encoded input takes the same generic accessor path as dictionaries.
// NOTE(review): presumably `min_max_helper` decides validity from
// `null_count()` / `nulls()`; run arrays may only expose nulls through
// `logical_nulls()` — confirm that null runs are skipped here.
DataType::RunEndEncoded(_, _) => {
min_max_helper::<T::Native, _, _>(array, cmp)
}
_ => m(as_primitive_array(&array)),
}
}
Expand Down Expand Up @@ -1701,4 +1771,119 @@ mod tests {
sum_checked(&a).expect_err("overflow should be detected");
sum_array_checked::<Int32Type, _>(&a).expect_err("overflow should be detected");
}
/// Aggregation kernels applied to run-end encoded (REE) input.
mod ree_aggregation {
    use super::*;
    use arrow_array::{RunArray, Int32Array, Int64Array, Float64Array};
    use arrow_array::types::{Int32Type, Int64Type, Float64Type};

    #[test]
    fn test_ree_sum_array_basic() {
        // Logical content: [10, 10, 20, 30, 30, 30] (logical length 6)
        let run_ends = Int32Array::from(vec![2, 3, 6]);
        let values = Int32Array::from(vec![10, 20, 30]);
        let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();

        let typed_array = run_array.downcast::<Int32Array>().unwrap();

        let result = sum_array::<Int32Type, _>(typed_array);
        assert_eq!(result, Some(130)); // 10+10+20+30+30+30 = 130
    }

    #[test]
    fn test_ree_sum_array_with_nulls() {
        // Logical content: [10, NULL, 20, NULL, 30], using real null slots
        // rather than zeroes standing in for them.
        let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let values = Int32Array::from(vec![Some(10), None, Some(20), None, Some(30)]);
        let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();

        let typed_array = run_array.downcast::<Int32Array>().unwrap();
        let result = sum_array::<Int32Type, _>(typed_array);
        assert_eq!(result, Some(60)); // 10+20+30 = 60 (nulls ignored)
    }

    #[test]
    fn test_ree_sum_array_large_values() {
        // Int64 run ends and values: [1000, 1000, 2000, 3000, 3000]
        let run_ends = Int64Array::from(vec![2, 3, 5]);
        let values = Int64Array::from(vec![1000, 2000, 3000]);
        let run_array = RunArray::<Int64Type>::try_new(&run_ends, &values).unwrap();

        let typed_array = run_array.downcast::<Int64Array>().unwrap();
        let result = sum_array::<Int64Type, _>(typed_array);
        assert_eq!(result, Some(10000)); // 1000+1000+2000+3000+3000 = 10000
    }

    #[test]
    fn test_ree_sum_checked_array_basic() {
        // Logical content: [5, 5, 10, 15, 15] (logical length 5)
        let run_ends = Int32Array::from(vec![2, 3, 5]);
        let values = Int32Array::from(vec![5, 10, 15]);
        let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();

        let typed_array = run_array.downcast::<Int32Array>().unwrap();
        let result = sum_array_checked::<Int32Type, _>(typed_array);
        assert_eq!(result.unwrap(), Some(50)); // 5+5+10+15+15 = 50
    }

    #[test]
    fn test_ree_sum_checked_array_overflow() {
        // Logical content: [i32::MAX, i32::MAX, 1]; the first run alone overflows.
        let run_ends = Int32Array::from(vec![2, 3]);
        let values = Int32Array::from(vec![i32::MAX, 1]);
        let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();

        let typed_array = run_array.downcast::<Int32Array>().unwrap();
        let result = sum_array_checked::<Int32Type, _>(typed_array);
        assert!(result.is_err()); // Should detect overflow
    }

    #[test]
    fn test_ree_min_array_basic() {
        // Logical content: [30, 30, 10, 20, 20] (logical length 5)
        let run_ends = Int32Array::from(vec![2, 3, 5]);
        let values = Int32Array::from(vec![30, 10, 20]);
        let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();

        let typed_array = run_array.downcast::<Int32Array>().unwrap();
        let result = min_array::<Int32Type, _>(typed_array);
        assert_eq!(result, Some(10)); // min(30, 30, 10, 20, 20) = 10
    }

    #[test]
    fn test_ree_min_array_float() {
        // Logical content: [5.5, 5.5, 2.1, 8.9, 8.9]
        let run_ends = Int32Array::from(vec![2, 3, 5]);
        let values = Float64Array::from(vec![5.5, 2.1, 8.9]);
        let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();

        let typed_array = run_array.downcast::<Float64Array>().unwrap();
        let result = min_array::<Float64Type, _>(typed_array);
        assert_eq!(result, Some(2.1)); // min(5.5, 5.5, 2.1, 8.9, 8.9) = 2.1
    }

    #[test]
    fn test_ree_max_array_basic() {
        // Logical content: [10, 10, 30, 20, 20] (logical length 5)
        let run_ends = Int32Array::from(vec![2, 3, 5]);
        let values = Int32Array::from(vec![10, 30, 20]);
        let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();

        let typed_array = run_array.downcast::<Int32Array>().unwrap();
        let result = max_array::<Int32Type, _>(typed_array);
        assert_eq!(result, Some(30)); // max(10, 10, 30, 20, 20) = 30
    }

    #[test]
    fn test_ree_max_array_float() {
        // Logical content: [2.1, 2.1, 8.9, 5.5, 5.5]
        let run_ends = Int32Array::from(vec![2, 3, 5]);
        let values = Float64Array::from(vec![2.1, 8.9, 5.5]);
        let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();

        let typed_array = run_array.downcast::<Float64Array>().unwrap();
        let result = max_array::<Float64Type, _>(typed_array);
        assert_eq!(result, Some(8.9)); // max(2.1, 2.1, 8.9, 5.5, 5.5) = 8.9
    }
}
}
3 changes: 3 additions & 0 deletions arrow-array/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ mod run_array;

pub use run_array::*;

// Re-export the unwrap_ree_array function for public use
pub use run_array::unwrap_ree_array;

mod byte_view_array;

pub use byte_view_array::*;
Expand Down
140 changes: 135 additions & 5 deletions arrow-array/src/array/run_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,7 @@ use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType, Field};

use crate::{
builder::StringRunBuilder,
make_array,
run_iterator::RunArrayIter,
types::{Int16Type, Int32Type, Int64Type, RunEndIndexType},
Array, ArrayAccessor, ArrayRef, PrimitiveArray,
builder::StringRunBuilder, cast::AsArray, make_array, run_iterator::RunArrayIter, types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType, PrimitiveArray, Int16Array, Int32Array, Int64Array
};

/// An array of [run-end encoded values](https://arrow.apache.org/docs/format/Columnar.html#run-end-encoded-layout)
Expand Down Expand Up @@ -251,6 +247,47 @@ impl<R: RunEndIndexType> RunArray<R> {
values: self.values.clone(),
}
}
/// Expands the REE array to its logical form
///
/// `T` is the primitive type of the *values* array (not of the run ends).
/// The result has one slot per logical row of this array; rows that belong
/// to a null run are null in the output.
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] if the values of this array
/// are not a `PrimitiveArray<T>`.
pub fn expand_to_logical<T: ArrowPrimitiveType>(&self) -> Result<Box<dyn Array>, ArrowError>
where
    T::Native: Default,
{
    let typed_ree = self.downcast::<PrimitiveArray<T>>().ok_or_else(|| {
        ArrowError::InvalidArgumentError("Failed to downcast to typed REE".to_string())
    })?;

    let mut builder = PrimitiveArray::<T>::builder(typed_ree.len());
    // The typed iterator walks the runs once and yields `None` for null runs.
    // `Array::is_null` cannot be used for that: it only consults the physical
    // null buffer, which a run array does not carry.
    for item in typed_ree {
        builder.append_option(item);
    }

    // Keep the concrete data type of the values (decimal precision/scale,
    // timestamp time zone, ...) instead of the default one for `T`.
    let expanded = builder
        .finish()
        .with_data_type(self.values().data_type().clone());
    Ok(Box::new(expanded))
}
/// Unwraps a REE array into a logical array
pub fn unwrap_ree_array(array: &dyn Array) -> Option<Box<dyn Array>> {
match array.data_type() {
arrow_schema::DataType::RunEndEncoded(run_ends_field, _) => {
match run_ends_field.data_type() {
arrow_schema::DataType::Int16 => {
array.as_run_opt::<Int16Type>()
.and_then(|ree| ree.expand_to_logical::<Int16Type>().ok())
}
arrow_schema::DataType::Int32 => {
array.as_run_opt::<Int32Type>()
.and_then(|ree| ree.expand_to_logical::<Int32Type>().ok())
}
arrow_schema::DataType::Int64 => {
array.as_run_opt::<Int64Type>()
.and_then(|ree| ree.expand_to_logical::<Int64Type>().ok())
}
_ => None,
}
}
_ => None,
}
}
}

impl<R: RunEndIndexType> From<ArrayData> for RunArray<R> {
Expand Down Expand Up @@ -528,6 +565,29 @@ pub struct TypedRunArray<'a, R: RunEndIndexType, V> {
values: &'a V,
}

/// Unwraps a REE array into a logical array
pub fn unwrap_ree_array(array: &dyn Array) -> Option<Box<dyn Array>> {
match array.data_type() {
arrow_schema::DataType::RunEndEncoded(run_ends_field, _) => {
match run_ends_field.data_type() {
arrow_schema::DataType::Int16 => {
array.as_run_opt::<Int16Type>()
.and_then(|ree| ree.expand_to_logical::<Int16Type>().ok())
}
arrow_schema::DataType::Int32 => {
array.as_run_opt::<Int32Type>()
.and_then(|ree| ree.expand_to_logical::<Int32Type>().ok())
}
arrow_schema::DataType::Int64 => {
array.as_run_opt::<Int64Type>()
.and_then(|ree| ree.expand_to_logical::<Int64Type>().ok())
}
_ => None,
}
}
_ => None,
}
}
// Manually implement `Clone` to avoid `V: Clone` type constraint
impl<R: RunEndIndexType, V> Clone for TypedRunArray<'_, R, V> {
fn clone(&self) -> Self {
Expand Down Expand Up @@ -660,6 +720,76 @@ where
}
}


/// A borrowed [`RunArray`] whose run-end index type is only known at runtime.
///
/// Each variant wraps a reference to a run array with one of the three
/// supported run-end types. The methods on this type dispatch on the variant,
/// so callers that aggregate over a run-end encoded array do not need to be
/// generic over the run-end type.
#[derive(Debug, Clone, Copy)]
pub enum AnyRunArray<'a> {
    /// A RunArray with Int64 run ends
    Int64(&'a RunArray<Int64Type>),
    /// A RunArray with Int32 run ends
    Int32(&'a RunArray<Int32Type>),
    /// A RunArray with Int16 run ends
    Int16(&'a RunArray<Int16Type>),
}

impl<'a> AnyRunArray<'a> {
/// Creates a new [`AnyRunArray`] from a [`dyn Array`]
pub fn new(array: &'a dyn Array, run_ends_type: DataType) -> Option<Self> {
match run_ends_type {
DataType::Int64 => Some(AnyRunArray::Int64(array.as_run_opt::<Int64Type>().unwrap())),
DataType::Int32 => Some(AnyRunArray::Int32(array.as_run_opt::<Int32Type>().unwrap())),
DataType::Int16 => Some(AnyRunArray::Int16(array.as_run_opt::<Int16Type>().unwrap())),
_ => None,
}
}

/// Returns the run ends of this [`AnyRunArray`]
pub fn run_ends(&self) -> Arc<dyn Array> {
match self {
AnyRunArray::Int64(array) => {
let values = array.run_ends().values();
Arc::new(Int64Array::from_iter_values(values.iter().copied()))
}
AnyRunArray::Int32(array) => {
let values = array.run_ends().values();
Arc::new(Int32Array::from_iter_values(values.iter().copied()))
}
AnyRunArray::Int16(array) => {
let values = array.run_ends().values();
Arc::new(Int16Array::from_iter_values(values.iter().copied()))
}
}
}

/// Returns the values of this [`AnyRunArray`]
pub fn values(&self) -> &ArrayRef {
match self {
AnyRunArray::Int64(array) => array.values(),
AnyRunArray::Int32(array) => array.values(),
AnyRunArray::Int16(array) => array.values(),
}
}
/// Returns the run end value at the given index
pub fn run_ends_value(&self, i: usize) -> usize {
match self {
AnyRunArray::Int64(array) => array.run_ends().values()[i].as_usize(),
AnyRunArray::Int32(array) => array.run_ends().values()[i].as_usize(),
AnyRunArray::Int16(array) => array.run_ends().values()[i].as_usize(),
}
}

/// Returns the length of run ends array
pub fn run_ends_len(&self) -> usize {
match self {
AnyRunArray::Int64(array) => array.values().len(),
AnyRunArray::Int32(array) => array.values().len(),
AnyRunArray::Int16(array) => array.values().len(),
}
}

}


#[cfg(test)]
mod tests {
use rand::rng;
Expand Down
Loading