From acb5d08cabc4078a4cdebed5ee0802e9c86b2f55 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Tue, 15 Jul 2025 15:32:16 -0400 Subject: [PATCH 1/3] implements Sum,sum_checked,min,max,is Distict,inverse for REE. ALso includes helper functions for expanding REE into logical represention --- arrow-arith/src/aggregate.rs | 208 +++++++++++++++++++++++++++++ arrow-array/src/array/mod.rs | 3 + arrow-array/src/array/run_array.rs | 94 ++++++++++++- arrow-ord/src/cmp.rs | 135 ++++++++++++++++++- arrow-ord/src/lib.rs | 1 + 5 files changed, 432 insertions(+), 9 deletions(-) diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs index ef0fddeb0b8e..117b45632fa0 100644 --- a/arrow-arith/src/aggregate.rs +++ b/arrow-arith/src/aggregate.rs @@ -23,6 +23,7 @@ use arrow_array::*; use arrow_buffer::{ArrowNativeType, NullBuffer}; use arrow_data::bit_iterator::try_for_each_valid_idx; use arrow_schema::*; +use num::cast; use std::borrow::BorrowMut; use std::cmp::{self, Ordering}; use std::ops::{BitAnd, BitOr, BitXor}; @@ -573,6 +574,26 @@ where Some(sum) } + DataType::RunEndEncoded(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + // Expand REE array to its logical form and recursively call sum_array + if let Some(expanded_array) = arrow_array::unwrap_ree_array(&array) { + // Cast the expanded array to the appropriate type and call sum_array recursively + if let Some(primitive_array) = expanded_array.as_any().downcast_ref::>() { + sum::(primitive_array) + } else { + // If we can't downcast, return None + None + } + } else { + None + } + } _ => sum::(as_primitive_array(&array)), } } @@ -609,6 +630,26 @@ where Ok(Some(sum)) } + DataType::RunEndEncoded(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return Ok(None); + } + + // Expand REE array to its logical form and recursively call sum_array_checked + if let Some(expanded_array) = arrow_array::unwrap_ree_array(&array) { + // Cast the expanded array to the appropriate type and call sum_checked recursively + if let Some(primitive_array) = expanded_array.as_any().downcast_ref::>() { + sum_checked::(primitive_array) + } else { + // If we can't downcast, return None + Ok(None) + } + } else { + Ok(None) + } + } _ => sum_checked::(as_primitive_array(&array)), } } @@ -645,6 +686,26 @@ where { match array.data_type() { DataType::Dictionary(_, _) => min_max_helper::(array, cmp), + DataType::RunEndEncoded(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + // Expand REE array to its logical form and recursively call min_max_array_helper + if let Some(expanded_array) = arrow_array::unwrap_ree_array(&array) { + // Cast the expanded array to the appropriate type and call min_max_helper recursively + if let Some(primitive_array) = expanded_array.as_any().downcast_ref::>() { + min_max_helper::(primitive_array, cmp) + } else { + // If we can't downcast, return None + None + } + } else { + None + } + } _ => m(as_primitive_array(&array)), } } @@ -1701,4 +1762,151 @@ mod tests { sum_checked(&a).expect_err("overflow should be detected"); sum_array_checked::(&a).expect_err("overflow should be detected"); } + // ... existing code ... + // REE (RunEndEncodedArray) Tests + mod ree_aggregation { + use super::*; + use arrow_array::{RunArray, Int32Array, Int64Array, Float64Array}; + use arrow_array::types::{Int32Type, Int64Type, Float64Type}; + + #[test] + fn test_ree_sum_array_basic() { + // REE array: [10, 10, 20, 30, 30] (logical length 5) + let run_ends = Int32Array::from(vec![2, 3, 5]); + let values = Int32Array::from(vec![10, 20, 30]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + // Expand to logical form and test + let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); + let primitive_array = expanded.as_any().downcast_ref::().unwrap(); + let result = sum_array::(primitive_array); + assert_eq!(result, Some(100)); // 10+10+20+30+30 = 100 + } + + #[test] + fn test_ree_sum_array_with_nulls() { + // REE array with nulls: [10, NULL, 20, NULL, 30] + let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); + let values = Int32Array::from(vec![10, -1, 20, -1, 30]); // -1 represents null + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + // Expand to logical form and test + let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); + let primitive_array = expanded.as_any().downcast_ref::().unwrap(); + let result = sum_array::(primitive_array); + assert_eq!(result, Some(60)); // 10+20+30 = 60 (nulls ignored) + } + + #[test] + fn test_ree_sum_array_checked_basic() { + // REE array: [5, 5, 10, 15, 15] (logical length 5) + let run_ends = Int32Array::from(vec![2, 3, 5]); + let values = Int32Array::from(vec![5, 10, 15]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + // Expand to logical form and test + let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); + let primitive_array = expanded.as_any().downcast_ref::().unwrap(); + let result = sum_array_checked::(primitive_array).unwrap(); + assert_eq!(result, Some(50)); // 5+5+10+15+15 = 50 + } + + #[test] + fn test_ree_sum_array_checked_overflow() { + // REE array that will overflow: [i32::MAX, i32::MAX, 1] + let run_ends = Int32Array::from(vec![2, 3]); + let values = Int32Array::from(vec![i32::MAX, 1]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + // Expand to logical form and test + let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); + let primitive_array = expanded.as_any().downcast_ref::().unwrap(); + let result = sum_array_checked::(primitive_array); + assert!(result.is_err()); // Should overflow + } + + #[test] + fn test_ree_min_array_basic() { + // REE array: [50, 50, 10, 30, 30] (logical length 5) + let run_ends = Int32Array::from(vec![2, 3, 5]); + let values = Int32Array::from(vec![50, 10, 30]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + // Expand to logical form and test + let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); + let primitive_array = expanded.as_any().downcast_ref::().unwrap(); + let result = min_array::(primitive_array); + assert_eq!(result, Some(10)); // Minimum value is 10 + } + + #[test] + fn test_ree_min_array_with_nulls() { + // REE array with nulls: [100, NULL, 5, NULL, 200] + let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); + let values = Int32Array::from(vec![100, -1, 5, -1, 200]); // -1 represents null + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + // Expand to logical form and test + let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); + let primitive_array = expanded.as_any().downcast_ref::().unwrap(); + let result = min_array::(primitive_array); + assert_eq!(result, Some(5)); // Minimum non-null value is 5 + } + + #[test] + fn test_ree_max_array_basic() { + // REE array: [10, 10, 50, 20, 20] (logical length 5) + let run_ends = Int32Array::from(vec![2, 3, 5]); + let values = Int32Array::from(vec![10, 50, 20]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + // Expand to logical form and test + let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); + let primitive_array = expanded.as_any().downcast_ref::().unwrap(); + let result = max_array::(primitive_array); + assert_eq!(result, Some(50)); // Maximum value is 50 + } + + #[test] + fn test_ree_max_array_with_nulls() { + // REE array with nulls: [5, NULL, 500, NULL, 10] + let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); + let values = Int32Array::from(vec![5, -1, 500, -1, 10]); // -1 represents null + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + // Expand to logical form and test + let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); + let primitive_array = expanded.as_any().downcast_ref::().unwrap(); + let result = max_array::(primitive_array); + assert_eq!(result, Some(500)); // Maximum non-null value is 500 + } + + #[test] + fn test_ree_sum_array_large_values() { + // REE array with large values: [1000000, 1000000, 2000000, 3000000, 3000000] + let run_ends = Int64Array::from(vec![2, 3, 5]); + let values = Int64Array::from(vec![1000000, 2000000, 3000000]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + // Expand to logical form and test + let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); + let primitive_array = expanded.as_any().downcast_ref::().unwrap(); + let result = sum_array::(primitive_array); + assert_eq!(result, Some(10000000)); // 1M+1M+2M+3M+3M = 10M + } + + #[test] + fn test_ree_max_array_float_values() { + // REE array with float values: [1.5, 1.5, 3.7, 2.1, 2.1] + let run_ends = Int32Array::from(vec![2, 3, 5]); + let values = Float64Array::from(vec![1.5, 3.7, 2.1]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + // Expand to logical form and test + let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); + let primitive_array = expanded.as_any().downcast_ref::().unwrap(); + let result = max_array::(primitive_array); + assert_eq!(result, Some(3.7)); // Maximum value is 3.7 + } + } } diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs index e41a3a1d719a..0e0ca8a124ae 100644 --- a/arrow-array/src/array/mod.rs +++ b/arrow-array/src/array/mod.rs @@ -68,6 +68,9 @@ mod run_array; pub use run_array::*; +// Re-export the unwrap_ree_array function for public use +pub use run_array::unwrap_ree_array; + mod byte_view_array; pub use byte_view_array::*; diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs index b305025706bc..872148329f97 100644 --- a/arrow-array/src/array/run_array.rs +++ b/arrow-array/src/array/run_array.rs @@ -23,11 +23,7 @@ use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field}; use crate::{ - builder::StringRunBuilder, - make_array, - run_iterator::RunArrayIter, - types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}, - Array, ArrayAccessor, ArrayRef, PrimitiveArray, + builder::StringRunBuilder, cast::AsArray, make_array, run_iterator::RunArrayIter, types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType, PrimitiveArray }; /// An array of [run-end encoded values](https://arrow.apache.org/docs/format/Columnar.html#run-end-encoded-layout) @@ -251,6 +247,47 @@ impl RunArray { values: self.values.clone(), } } + /// Expands the REE array to its logical form + pub fn expand_to_logical(&self) -> Result, ArrowError> + where + T::Native: Default, + { + let typed_ree = self.downcast::>() + .ok_or_else(|| ArrowError::InvalidArgumentError("Failed to downcast to typed REE".to_string()))?; + + let mut builder = PrimitiveArray::::builder(typed_ree.len()); + for i in 0..typed_ree.len() { + if typed_ree.is_null(i) { + builder.append_null(); + } else { + builder.append_value(typed_ree.value(i)); + } + } + Ok(Box::new(builder.finish())) + } + /// Unwraps a REE array into a logical array + pub fn unwrap_ree_array(array: &dyn Array) -> Option> { + match array.data_type() { + arrow_schema::DataType::RunEndEncoded(run_ends_field, _) => { + match run_ends_field.data_type() { + arrow_schema::DataType::Int16 => { + array.as_run_opt::() + .and_then(|ree| ree.expand_to_logical::().ok()) + } + arrow_schema::DataType::Int32 => { + array.as_run_opt::() + .and_then(|ree| ree.expand_to_logical::().ok()) + } + arrow_schema::DataType::Int64 => { + array.as_run_opt::() + .and_then(|ree| ree.expand_to_logical::().ok()) + } + _ => None, + } + } + _ => None, + } +} } impl From for RunArray { @@ -528,6 +565,29 @@ pub struct TypedRunArray<'a, R: RunEndIndexType, V> { values: &'a V, } +/// Unwraps a REE array into a logical array +pub fn unwrap_ree_array(array: &dyn Array) -> Option> { + match array.data_type() { + arrow_schema::DataType::RunEndEncoded(run_ends_field, _) => { + match run_ends_field.data_type() { + arrow_schema::DataType::Int16 => { + array.as_run_opt::() + .and_then(|ree| ree.expand_to_logical::().ok()) + } + arrow_schema::DataType::Int32 => { + array.as_run_opt::() + .and_then(|ree| ree.expand_to_logical::().ok()) + } + arrow_schema::DataType::Int64 => { + array.as_run_opt::() + .and_then(|ree| ree.expand_to_logical::().ok()) + } + _ => None, + } + } + _ => None, + } +} // Manually implement `Clone` to avoid `V: Clone` type constraint impl Clone for TypedRunArray<'_, R, V> { fn clone(&self) -> Self { @@ -618,6 +678,30 @@ impl Array for TypedRunArray<'_, R, V> { } } +// ArrayAccessor implementation for RunArray itself +// This allows RunArray to be used directly with aggregation functions +impl<'a, R: RunEndIndexType> ArrayAccessor for &'a RunArray { + type Item = ArrayRef; + + fn value(&self, logical_index: usize) -> Self::Item { + assert!( + logical_index < self.len(), + "Trying to access an element at index {} from a RunArray of length {}", + logical_index, + self.len() + ); + unsafe { self.value_unchecked(logical_index) } + } + + unsafe fn value_unchecked(&self, logical_index: usize) -> Self::Item { + let physical_index = self.get_physical_index(logical_index); + // Return the value at the physical index as an ArrayRef + // This is a single-element array containing the value + let value_array = self.values().slice(physical_index, 1); + value_array + } +} + // Array accessor converts the index of logical array to the index of the physical array // using binary search. The time complexity is O(log N) where N is number of runs. impl<'a, R, V> ArrayAccessor for TypedRunArray<'a, R, V> diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index 2727ff996150..55734cedcdb7 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -24,16 +24,16 @@ //! use arrow_array::cast::AsArray; -use arrow_array::types::{ByteArrayType, ByteViewType}; +use arrow_array::types::{ByteArrayType, ByteViewType, Int16Type, Int32Type, Int64Type}; use arrow_array::{ - downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum, - FixedSizeBinaryArray, GenericByteArray, GenericByteViewArray, + array, downcast_primitive_array, AnyDictionaryArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, BooleanArray, Datum, FixedSizeBinaryArray, GenericByteArray, GenericByteViewArray, PrimitiveArray }; use arrow_buffer::bit_util::ceil; use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; -use arrow_schema::ArrowError; +use arrow_schema::{ArrowError, DataType}; use arrow_select::take::take; use std::ops::Not; +use std::sync::Arc; #[derive(Debug, Copy, Clone)] enum Op { @@ -224,6 +224,14 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result Result::try_new(&run_ends, &values).unwrap(); + + let run_ends = Int32Array::from(vec![2, 3, 5]); + let values = Int32Array::from(vec![10, 20, 30]); + let other_array = RunArray::::try_new(&run_ends, &values).unwrap(); + // Test distinct operations + let result = distinct(&run_array, &other_array).unwrap(); + assert_eq!(result, BooleanArray::from(vec![false,false,false,false,false])); + + let result = not_distinct(&run_array, &other_array).unwrap(); + // Expected: [true, false, true, false, true] (opposite of distinct) + assert_eq!(result, BooleanArray::from(vec![true,true,true,true,true])); + } + + #[test] + fn test_ree_distinct_mismatched_values() { + // [10, 10, 20] + let run_ends = Int32Array::from(vec![2, 3]); + let values = Int32Array::from(vec![10, 20]); + let lhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + // [10, 10, 99] + let run_ends = Int32Array::from(vec![2, 3]); + let values = Int32Array::from(vec![10, 99]); + let rhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let result = distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![false, false, true])); + + let result = not_distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![true, true, false])); + } + + #[test] + fn test_ree_distinct_all_different() { + // [1, 2, 3] + let run_ends = Int32Array::from(vec![1, 2, 3]); + let values = Int32Array::from(vec![1, 2, 3]); + let lhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + // [4, 5, 6] + let run_ends = Int32Array::from(vec![1, 2, 3]); + let values = Int32Array::from(vec![4, 5, 6]); + let rhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let result = distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![true, true, true])); + + let result = not_distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![false, false, false])); + } + + //failed + #[test] + fn test_ree_distinct_mixed_values() { + // [10, 10, 20, 30, 30] + let lhs = RunArray::::try_new( + &Int32Array::from(vec![2, 3, 5]), + &Int32Array::from(vec![10, 20, 30]), + ).unwrap(); + + // [10, 99, 20, 99, 30] + let rhs = RunArray::::try_new( + &Int32Array::from(vec![1, 2, 3, 4, 5]), + &Int32Array::from(vec![10, 99, 20, 99, 30]), + ).unwrap(); + + let result = distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![false, true, false, true, false])); + + let result = not_distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![true, false, true, false, true])); + } + + #[test] + fn test_distinct_with_nulls_on_expanded_ree() { + // Simulate REE-expanded arrays with nulls + let lhs = Int32Array::from(vec![Some(1), None, Some(3)]); + let rhs = Int32Array::from(vec![Some(1), Some(2), None]); + + let result = distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![false, true, true])); + + let result = not_distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![true, false, false])); + } + + #[test] + fn test_ree_distinct_with_nulls_and_values() { + // Logical: [10, NULL, 20, NULL, 30] + let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); + let values = Int32Array::from(vec![10, -1, 20, -1, 30]); + let lhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + // Logical: [10, NULL, 99, NULL, 30] + let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); + let values = Int32Array::from(vec![10, -1, 99, -1, 30]); + let rhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let result = distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![false, false, true, false, false])); + + let result = not_distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![true, true, false, true, true])); + } + } } diff --git a/arrow-ord/src/lib.rs b/arrow-ord/src/lib.rs index 99b0451992cf..0c68846b06ee 100644 --- a/arrow-ord/src/lib.rs +++ b/arrow-ord/src/lib.rs @@ -49,6 +49,7 @@ )] #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![warn(missing_docs)] +pub use arrow_array::downcast_primitive_array; pub mod cmp; #[doc(hidden)] pub mod comparison; From 72bd81a66bf236e7c3f27de0f6e62de9a88dcaa2 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Wed, 16 Jul 2025 15:47:01 -0400 Subject: [PATCH 2/3] Implementing optmizations, to remove unneeded RunEnd unpacking. Summation is the only function left --- arrow-arith/src/aggregate.rs | 227 +++++++++++++---------------- arrow-array/src/array/run_array.rs | 96 ++++++++---- arrow-ord/src/cmp.rs | 7 +- 3 files changed, 176 insertions(+), 154 deletions(-) diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs index 117b45632fa0..d22fa4d5689d 100644 --- a/arrow-arith/src/aggregate.rs +++ b/arrow-arith/src/aggregate.rs @@ -17,13 +17,12 @@ //! Defines aggregations over Arrow arrays. -use arrow_array::cast::*; +use arrow_array::cast::{*}; use arrow_array::iterator::ArrayIter; use arrow_array::*; use arrow_buffer::{ArrowNativeType, NullBuffer}; use arrow_data::bit_iterator::try_for_each_valid_idx; use arrow_schema::*; -use num::cast; use std::borrow::BorrowMut; use std::cmp::{self, Ordering}; use std::ops::{BitAnd, BitOr, BitXor}; @@ -574,22 +573,33 @@ where Some(sum) } - DataType::RunEndEncoded(_, _) => { + DataType::RunEndEncoded(run_field, _) => { let null_count = array.null_count(); if null_count == array.len() { return None; } - - // Expand REE array to its logical form and recursively call sum_array - if let Some(expanded_array) = arrow_array::unwrap_ree_array(&array) { - // Cast the expanded array to the appropriate type and call sum_array recursively - if let Some(primitive_array) = expanded_array.as_any().downcast_ref::>() { - sum::(primitive_array) - } else { - // If we can't downcast, return None - None + let ree = match run_field.data_type() { + DataType::Int64 => AnyRunArray::new(&array, DataType::Int64), + DataType::Int32 => AnyRunArray::new(&array, DataType::Int32), + DataType::Int16 => AnyRunArray::new(&array, DataType::Int16), + _ => return None, + }; + if let Some(ree) = ree { + let mut sum = T::default_value(); + + let values = ree.values(); + let values_array = values.as_any().downcast_ref::>().unwrap(); + let values_data = values_array.values(); + let mut prev_end = 0; + for i in 0..ree.run_ends_len() { + let end = ree.run_ends_value(i); + let run_length = end - prev_end; + let run_length_native = T::Native::from_usize(run_length).unwrap(); + sum = sum.add_wrapping(values_data[i].mul_wrapping(run_length_native)); + prev_end = end; } + Some(sum) } else { None } @@ -630,25 +640,41 @@ where Ok(Some(sum)) } - DataType::RunEndEncoded(_, _) => { + DataType::RunEndEncoded(run_field, _) => { let null_count = array.null_count(); if null_count == array.len() { return Ok(None); } - // Expand REE array to its logical form and recursively call sum_array_checked - if let Some(expanded_array) = arrow_array::unwrap_ree_array(&array) { - // Cast the expanded array to the appropriate type and call sum_checked recursively - if let Some(primitive_array) = expanded_array.as_any().downcast_ref::>() { - sum_checked::(primitive_array) - } else { - // If we can't downcast, return None - Ok(None) + let ree = match run_field.data_type() { + DataType::Int64 => AnyRunArray::new(&array, DataType::Int64), + DataType::Int32 => AnyRunArray::new(&array, DataType::Int32), + DataType::Int16 => AnyRunArray::new(&array, DataType::Int16), + _ => return Ok(None), + }; + + if let Some(ree) = ree { + let mut sum = T::default_value(); + + let values = ree.values(); + let values_array = values.as_any().downcast_ref::>().unwrap(); + let values_data = values_array.values(); + + let mut prev_end = 0; + for i in 0..ree.run_ends_len() { + let end = ree.run_ends_value(i); + let run_length = end - prev_end; + let run_length_native = T::Native::from_usize(run_length).unwrap(); + sum = sum.add_checked(values_data[i].mul_checked(run_length_native)?)?; + prev_end = end; } + + Ok(Some(sum)) } else { Ok(None) } + } _ => sum_checked::(as_primitive_array(&array)), } @@ -687,24 +713,7 @@ where match array.data_type() { DataType::Dictionary(_, _) => min_max_helper::(array, cmp), DataType::RunEndEncoded(_, _) => { - let null_count = array.null_count(); - - if null_count == array.len() { - return None; - } - - // Expand REE array to its logical form and recursively call min_max_array_helper - if let Some(expanded_array) = arrow_array::unwrap_ree_array(&array) { - // Cast the expanded array to the appropriate type and call min_max_helper recursively - if let Some(primitive_array) = expanded_array.as_any().downcast_ref::>() { - min_max_helper::(primitive_array, cmp) - } else { - // If we can't downcast, return None - None - } - } else { - None - } + min_max_helper::(array, cmp) } _ => m(as_primitive_array(&array)), } @@ -1762,8 +1771,6 @@ mod tests { sum_checked(&a).expect_err("overflow should be detected"); sum_array_checked::(&a).expect_err("overflow should be detected"); } - // ... existing code ... - // REE (RunEndEncodedArray) Tests mod ree_aggregation { use super::*; use arrow_array::{RunArray, Int32Array, Int64Array, Float64Array}; @@ -1771,142 +1778,112 @@ mod tests { #[test] fn test_ree_sum_array_basic() { - // REE array: [10, 10, 20, 30, 30] (logical length 5) - let run_ends = Int32Array::from(vec![2, 3, 5]); + // REE array: [10, 10, 20, 30, 30,30] (logical length 6) + let run_ends = Int32Array::from(vec![2, 3, 6]); let values = Int32Array::from(vec![10, 20, 30]); let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); - // Expand to logical form and test - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); - let primitive_array = expanded.as_any().downcast_ref::().unwrap(); - let result = sum_array::(primitive_array); - assert_eq!(result, Some(100)); // 10+10+20+30+30 = 100 + + let typed_array = run_array.downcast::().unwrap(); + + let result = sum_array::(typed_array); + assert_eq!(result, Some(130)); // 10+10+20+30+30+30 = 130 } #[test] fn test_ree_sum_array_with_nulls() { // REE array with nulls: [10, NULL, 20, NULL, 30] let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); - let values = Int32Array::from(vec![10, -1, 20, -1, 30]); // -1 represents null + let values = Int32Array::from(vec![10, 0, 20, 0, 30]); // 0 represents null let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); - // Expand to logical form and test - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); - let primitive_array = expanded.as_any().downcast_ref::().unwrap(); - let result = sum_array::(primitive_array); + let typed_array = run_array.downcast::().unwrap(); + let result = sum_array::(typed_array); assert_eq!(result, Some(60)); // 10+20+30 = 60 (nulls ignored) } #[test] - fn test_ree_sum_array_checked_basic() { + fn test_ree_sum_array_large_values() { + // REE array with large values: [1000, 1000, 2000, 3000, 3000] + let run_ends = Int64Array::from(vec![2, 3, 5]); + let values = Int64Array::from(vec![1000, 2000, 3000]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + let typed_array = run_array.downcast::().unwrap(); + let result = sum_array::(typed_array); + assert_eq!(result, Some(10000)); // 1000+1000+2000+3000+3000 = 10000 + } + + #[test] + fn test_ree_sum_checked_array_basic() { // REE array: [5, 5, 10, 15, 15] (logical length 5) let run_ends = Int32Array::from(vec![2, 3, 5]); let values = Int32Array::from(vec![5, 10, 15]); let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); - // Expand to logical form and test - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); - let primitive_array = expanded.as_any().downcast_ref::().unwrap(); - let result = sum_array_checked::(primitive_array).unwrap(); - assert_eq!(result, Some(50)); // 5+5+10+15+15 = 50 + let typed_array = run_array.downcast::().unwrap(); + let result = sum_array_checked::(typed_array); + assert_eq!(result.unwrap(), Some(50)); // 5+5+10+15+15 = 50 } #[test] - fn test_ree_sum_array_checked_overflow() { - // REE array that will overflow: [i32::MAX, i32::MAX, 1] + fn test_ree_sum_checked_array_overflow() { + // REE array that will cause overflow: [i32::MAX, i32::MAX, 1] let run_ends = Int32Array::from(vec![2, 3]); let values = Int32Array::from(vec![i32::MAX, 1]); let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); - // Expand to logical form and test - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); - let primitive_array = expanded.as_any().downcast_ref::().unwrap(); - let result = sum_array_checked::(primitive_array); - assert!(result.is_err()); // Should overflow + let typed_array = run_array.downcast::().unwrap(); + let result = sum_array_checked::(typed_array); + assert!(result.is_err()); // Should detect overflow } #[test] fn test_ree_min_array_basic() { - // REE array: [50, 50, 10, 30, 30] (logical length 5) + // REE array: [30, 30, 10, 20, 20] (logical length 5) let run_ends = Int32Array::from(vec![2, 3, 5]); - let values = Int32Array::from(vec![50, 10, 30]); + let values = Int32Array::from(vec![30, 10, 20]); let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); - // Expand to logical form and test - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); - let primitive_array = expanded.as_any().downcast_ref::().unwrap(); - let result = min_array::(primitive_array); - assert_eq!(result, Some(10)); // Minimum value is 10 + let typed_array = run_array.downcast::().unwrap(); + let result = min_array::(typed_array); + assert_eq!(result, Some(10)); // min(30, 30, 10, 20, 20) = 10 } #[test] - fn test_ree_min_array_with_nulls() { - // REE array with nulls: [100, NULL, 5, NULL, 200] - let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); - let values = Int32Array::from(vec![100, -1, 5, -1, 200]); // -1 represents null + fn test_ree_min_array_float() { + // REE array with floats: [5.5, 5.5, 2.1, 8.9, 8.9] + let run_ends = Int32Array::from(vec![2, 3, 5]); + let values = Float64Array::from(vec![5.5, 2.1, 8.9]); let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); - // Expand to logical form and test - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); - let primitive_array = expanded.as_any().downcast_ref::().unwrap(); - let result = min_array::(primitive_array); - assert_eq!(result, Some(5)); // Minimum non-null value is 5 + let typed_array = run_array.downcast::().unwrap(); + let result = min_array::(typed_array); + assert_eq!(result, Some(2.1)); // min(5.5, 5.5, 2.1, 8.9, 8.9) = 2.1 } #[test] fn test_ree_max_array_basic() { - // REE array: [10, 10, 50, 20, 20] (logical length 5) + // REE array: [10, 10, 30, 20, 20] (logical length 5) let run_ends = Int32Array::from(vec![2, 3, 5]); - let values = Int32Array::from(vec![10, 50, 20]); + let values = Int32Array::from(vec![10, 30, 20]); let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); - // Expand to logical form and test - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); - let primitive_array = expanded.as_any().downcast_ref::().unwrap(); - let result = max_array::(primitive_array); - assert_eq!(result, Some(50)); // Maximum value is 50 - } - - #[test] - fn test_ree_max_array_with_nulls() { - // REE array with nulls: [5, NULL, 500, NULL, 10] - let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); - let values = Int32Array::from(vec![5, -1, 500, -1, 10]); // -1 represents null - let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); - - // Expand to logical form and test - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); - let primitive_array = expanded.as_any().downcast_ref::().unwrap(); - let result = max_array::(primitive_array); - assert_eq!(result, Some(500)); // Maximum non-null value is 500 - } - - #[test] - fn test_ree_sum_array_large_values() { - // REE array with large values: [1000000, 1000000, 2000000, 3000000, 3000000] - let run_ends = Int64Array::from(vec![2, 3, 5]); - let values = Int64Array::from(vec![1000000, 2000000, 3000000]); - let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); - - // Expand to logical form and test - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); - let primitive_array = expanded.as_any().downcast_ref::().unwrap(); - let result = sum_array::(primitive_array); - assert_eq!(result, Some(10000000)); // 1M+1M+2M+3M+3M = 10M + let typed_array = run_array.downcast::().unwrap(); + let result = max_array::(typed_array); + assert_eq!(result, Some(30)); // max(10, 10, 30, 20, 20) = 30 } #[test] - fn test_ree_max_array_float_values() { - // REE array with float values: [1.5, 1.5, 3.7, 2.1, 2.1] + fn test_ree_max_array_float() { + // REE array with floats: [2.1, 2.1, 8.9, 5.5, 5.5] let run_ends = Int32Array::from(vec![2, 3, 5]); - let values = Float64Array::from(vec![1.5, 3.7, 2.1]); + let values = Float64Array::from(vec![2.1, 8.9, 5.5]); let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); - // Expand to logical form and test - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); - let primitive_array = expanded.as_any().downcast_ref::().unwrap(); - let result = max_array::(primitive_array); - assert_eq!(result, Some(3.7)); // Maximum value is 3.7 + let typed_array = run_array.downcast::().unwrap(); + let result = max_array::(typed_array); + assert_eq!(result, Some(8.9)); // max(2.1, 2.1, 8.9, 5.5, 5.5) = 8.9 } } } diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs index 872148329f97..a254c81ba601 100644 --- a/arrow-array/src/array/run_array.rs +++ b/arrow-array/src/array/run_array.rs @@ -23,7 +23,7 @@ use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field}; use crate::{ - builder::StringRunBuilder, cast::AsArray, make_array, run_iterator::RunArrayIter, types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType, PrimitiveArray + builder::StringRunBuilder, cast::AsArray, make_array, run_iterator::RunArrayIter, types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType, PrimitiveArray, Int16Array, Int32Array, Int64Array }; /// An array of [run-end encoded values](https://arrow.apache.org/docs/format/Columnar.html#run-end-encoded-layout) @@ -678,30 +678,6 @@ impl Array for TypedRunArray<'_, R, V> { } } -// ArrayAccessor implementation for RunArray itself -// This allows RunArray to be used directly with aggregation functions -impl<'a, R: RunEndIndexType> ArrayAccessor for &'a RunArray { - type Item = ArrayRef; - - fn value(&self, logical_index: usize) -> Self::Item { - assert!( - logical_index < self.len(), - "Trying to access an element at index {} from a RunArray of length {}", - logical_index, - self.len() - ); - unsafe { self.value_unchecked(logical_index) } - } - - unsafe fn value_unchecked(&self, logical_index: usize) -> Self::Item { - let physical_index = self.get_physical_index(logical_index); - // Return the value at the physical index as an ArrayRef - // This is a single-element array containing the value - let value_array = self.values().slice(physical_index, 1); - value_array - } -} - // Array accessor converts the index of logical array to the index of the physical array // using binary search. The time complexity is O(log N) where N is number of runs. impl<'a, R, V> ArrayAccessor for TypedRunArray<'a, R, V> @@ -744,6 +720,76 @@ where } } + +/// An AnyRunArray is a wrapper around a RunArray that can be used to aggregate over a RunEndEncodedArray +/// This is used to avoid the need to downcast the RunEndEncodedArray to a specific type +pub enum AnyRunArray<'a> { + /// A RunArray with Int64 run ends + Int64(&'a RunArray), + /// A RunArray with Int32 run ends + Int32(&'a RunArray), + /// A RunArray with Int16 run ends + Int16(&'a RunArray), +} + +impl<'a> AnyRunArray<'a> { + /// Creates a new [`AnyRunArray`] from a [`dyn Array`] + pub fn new(array: &'a dyn Array, run_ends_type: DataType) -> Option { + match run_ends_type { + DataType::Int64 => Some(AnyRunArray::Int64(array.as_run_opt::().unwrap())), + DataType::Int32 => Some(AnyRunArray::Int32(array.as_run_opt::().unwrap())), + DataType::Int16 => Some(AnyRunArray::Int16(array.as_run_opt::().unwrap())), + _ => None, + } + } + + /// Returns the run ends of this [`AnyRunArray`] + pub fn run_ends(&self) -> Arc { + match self { + AnyRunArray::Int64(array) => { + let values = array.run_ends().values(); + Arc::new(Int64Array::from_iter_values(values.iter().copied())) + } + AnyRunArray::Int32(array) => { + let values = array.run_ends().values(); + Arc::new(Int32Array::from_iter_values(values.iter().copied())) + } + AnyRunArray::Int16(array) => { + let values = array.run_ends().values(); + Arc::new(Int16Array::from_iter_values(values.iter().copied())) + } + } + } + + /// Returns the values of this [`AnyRunArray`] + pub fn values(&self) -> &ArrayRef { + match self { + AnyRunArray::Int64(array) => array.values(), + AnyRunArray::Int32(array) => array.values(), + AnyRunArray::Int16(array) => array.values(), + } + } + /// Returns the run end value at the given index + pub fn run_ends_value(&self, i: usize) -> usize { + match self { + AnyRunArray::Int64(array) => array.run_ends().values()[i].as_usize(), + AnyRunArray::Int32(array) => array.run_ends().values()[i].as_usize(), + AnyRunArray::Int16(array) => array.run_ends().values()[i].as_usize(), + } + } + + /// Returns the length of run ends array + pub fn run_ends_len(&self) -> usize { + match self { + AnyRunArray::Int64(array) => array.values().len(), + AnyRunArray::Int32(array) => array.values().len(), + AnyRunArray::Int16(array) => array.values().len(), + } + } + +} + + #[cfg(test)] mod tests { use rand::rng; diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index 55734cedcdb7..897638da395c 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -24,16 +24,15 @@ //! use arrow_array::cast::AsArray; -use arrow_array::types::{ByteArrayType, ByteViewType, Int16Type, Int32Type, Int64Type}; +use arrow_array::types::{ByteArrayType, ByteViewType}; use arrow_array::{ - array, downcast_primitive_array, AnyDictionaryArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, BooleanArray, Datum, FixedSizeBinaryArray, GenericByteArray, GenericByteViewArray, PrimitiveArray + downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum, FixedSizeBinaryArray, GenericByteArray, GenericByteViewArray }; use arrow_buffer::bit_util::ceil; use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; -use arrow_schema::{ArrowError, DataType}; +use arrow_schema::{ArrowError}; use arrow_select::take::take; use std::ops::Not; -use std::sync::Arc; #[derive(Debug, Copy, Clone)] enum Op { From 9d00687072f6536a42ac81c22a1279479a914ed4 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Wed, 23 Jul 2025 18:24:13 -0400 Subject: [PATCH 3/3] Add ree_distinct function that compares Run-End Encoded arrays directly without expanding to logical form. --- arrow-arith/src/aggregate.rs | 30 +-- arrow-array/src/array/mod.rs | 3 - arrow-array/src/array/run_array.rs | 203 +++++++++++++++----- arrow-ord/src/cmp.rs | 298 ++++++++++++++++++++++++++--- arrow-ord/src/lib.rs | 1 - 5 files changed, 444 insertions(+), 91 deletions(-) diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs index d22fa4d5689d..4dcdd1dd5eae 100644 --- a/arrow-arith/src/aggregate.rs +++ b/arrow-arith/src/aggregate.rs @@ -17,7 +17,7 @@ //! Defines aggregations over Arrow arrays. -use arrow_array::cast::{*}; +use arrow_array::cast::*; use arrow_array::iterator::ArrayIter; use arrow_array::*; use arrow_buffer::{ArrowNativeType, NullBuffer}; @@ -653,7 +653,7 @@ where DataType::Int16 => AnyRunArray::new(&array, DataType::Int16), _ => return Ok(None), }; - + if let Some(ree) = ree { let mut sum = T::default_value(); @@ -666,7 +666,7 @@ where let end = ree.run_ends_value(i); let run_length = end - prev_end; let run_length_native = T::Native::from_usize(run_length).unwrap(); - sum = sum.add_checked(values_data[i].mul_checked(run_length_native)?)?; + sum = sum.add_checked(values_data[i].mul_checked(run_length_native)?)?; prev_end = end; } @@ -674,7 +674,6 @@ where } else { Ok(None) } - } _ => sum_checked::(as_primitive_array(&array)), } @@ -712,9 +711,7 @@ where { match array.data_type() { DataType::Dictionary(_, _) => min_max_helper::(array, cmp), - DataType::RunEndEncoded(_, _) => { - min_max_helper::(array, cmp) - } + DataType::RunEndEncoded(_, _) => min_max_helper::(array, cmp), _ => m(as_primitive_array(&array)), } } @@ -1773,8 +1770,8 @@ mod tests { } mod ree_aggregation { use super::*; - use arrow_array::{RunArray, Int32Array, Int64Array, Float64Array}; - use arrow_array::types::{Int32Type, Int64Type, Float64Type}; + use arrow_array::types::{Float64Type, Int32Type, Int64Type}; + use arrow_array::{Float64Array, Int32Array, Int64Array, RunArray}; #[test] fn test_ree_sum_array_basic() { @@ -1783,7 +1780,6 @@ mod tests { let values = Int32Array::from(vec![10, 20, 30]); let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); - let typed_array = run_array.downcast::().unwrap(); let result = sum_array::(typed_array); @@ -1794,7 +1790,7 @@ mod tests { fn test_ree_sum_array_with_nulls() { // REE array with nulls: [10, NULL, 20, NULL, 30] let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); - let values = Int32Array::from(vec![10, 0, 20, 0, 30]); // 0 represents null + let values = Int32Array::from(vec![Some(10), None, Some(20), None, Some(30)]); let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); let typed_array = run_array.downcast::().unwrap(); @@ -1802,6 +1798,18 @@ mod tests { assert_eq!(result, Some(60)); // 10+20+30 = 60 (nulls ignored) } + #[test] + fn test_ree_sum_array_with_only_nulls() { + // REE array: [None, None, None, None, None] (logical length 5) + let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); + let values = Int32Array::from(vec![None, None, None, None, None]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + let typed_array = run_array.downcast::().unwrap(); + let result = sum_array::(typed_array); + assert_eq!(result, Some(0)); // 0 + } + #[test] fn test_ree_sum_array_large_values() { // REE array with large values: [1000, 1000, 2000, 3000, 3000] diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs index 0e0ca8a124ae..e41a3a1d719a 100644 --- a/arrow-array/src/array/mod.rs +++ b/arrow-array/src/array/mod.rs @@ -68,9 +68,6 @@ mod run_array; pub use run_array::*; -// Re-export the unwrap_ree_array function for public use -pub use run_array::unwrap_ree_array; - mod byte_view_array; pub use byte_view_array::*; diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs index a254c81ba601..14d8d3b10dbf 100644 --- a/arrow-array/src/array/run_array.rs +++ b/arrow-array/src/array/run_array.rs @@ -18,14 +18,24 @@ use std::any::Any; use std::sync::Arc; +use crate::{ + builder::StringRunBuilder, + cast::AsArray, + make_array, + run_iterator::RunArrayIter, + types::{ + Date32Type, Date64Type, Decimal128Type, Decimal256Type, DurationNanosecondType, + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, + RunEndIndexType, Time32MillisecondType, Time64NanosecondType, TimestampMicrosecondType, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, + }, + Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType, BooleanArray, Int16Array, Int32Array, + Int64Array, PrimitiveArray, StringArray, +}; use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, RunEndBuffer}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field}; -use crate::{ - builder::StringRunBuilder, cast::AsArray, make_array, run_iterator::RunArrayIter, types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType, PrimitiveArray, Int16Array, Int32Array, Int64Array -}; - /// An array of [run-end encoded values](https://arrow.apache.org/docs/format/Columnar.html#run-end-encoded-layout) /// /// This encoding is variation on [run-length encoding (RLE)](https://en.wikipedia.org/wiki/Run-length_encoding) @@ -247,14 +257,16 @@ impl RunArray { values: self.values.clone(), } } + /// Expands the REE array to its logical form - pub fn expand_to_logical(&self) -> Result, ArrowError> + pub fn expand_to_logical(&self) -> Result, ArrowError> where T::Native: Default, { - let typed_ree = self.downcast::>() - .ok_or_else(|| ArrowError::InvalidArgumentError("Failed to downcast to typed REE".to_string()))?; - + let typed_ree = self.downcast::>().ok_or_else(|| { + ArrowError::InvalidArgumentError("Failed to downcast to typed REE".to_string()) + })?; + let mut builder = PrimitiveArray::::builder(typed_ree.len()); for i in 0..typed_ree.len() { if typed_ree.is_null(i) { @@ -265,29 +277,6 @@ impl RunArray { } Ok(Box::new(builder.finish())) } - /// Unwraps a REE array into a logical array - pub fn unwrap_ree_array(array: &dyn Array) -> Option> { - match array.data_type() { - arrow_schema::DataType::RunEndEncoded(run_ends_field, _) => { - match run_ends_field.data_type() { - arrow_schema::DataType::Int16 => { - array.as_run_opt::() - .and_then(|ree| ree.expand_to_logical::().ok()) - } - arrow_schema::DataType::Int32 => { - array.as_run_opt::() - .and_then(|ree| ree.expand_to_logical::().ok()) - } - arrow_schema::DataType::Int64 => { - array.as_run_opt::() - .and_then(|ree| ree.expand_to_logical::().ok()) - } - _ => None, - } - } - _ => None, - } -} } impl From for RunArray { @@ -566,28 +555,140 @@ pub struct TypedRunArray<'a, R: RunEndIndexType, V> { } /// Unwraps a REE array into a logical array -pub fn unwrap_ree_array(array: &dyn Array) -> Option> { +pub fn ree_to_expanded_array(array: &dyn Array) -> Option> { match array.data_type() { arrow_schema::DataType::RunEndEncoded(run_ends_field, _) => { match run_ends_field.data_type() { - arrow_schema::DataType::Int16 => { - array.as_run_opt::() - .and_then(|ree| ree.expand_to_logical::().ok()) - } - arrow_schema::DataType::Int32 => { - array.as_run_opt::() - .and_then(|ree| ree.expand_to_logical::().ok()) - } - arrow_schema::DataType::Int64 => { - array.as_run_opt::() - .and_then(|ree| ree.expand_to_logical::().ok()) - } + arrow_schema::DataType::Int16 => array + .as_run_opt::() + .and_then(|ree| ree.expand_to_logical::().ok()), + arrow_schema::DataType::Int32 => array + .as_run_opt::() + .and_then(|ree| ree.expand_to_logical::().ok()), + arrow_schema::DataType::Int64 => array + .as_run_opt::() + .and_then(|ree| ree.expand_to_logical::().ok()), _ => None, } } _ => None, } } + +/// Generate a boolean array that indicates if two run arrays are equal +pub fn ree_distinct( + lhs: &AnyRunArray, + rhs: &AnyRunArray, + size: usize, + flag: bool, +) -> Option { + // Iterate through both run arrays and compare their logical indices + // we know that the run arrays of the exact same size. + let lhs_vals = lhs.values(); + let rhs_vals = rhs.values(); + if lhs_vals.data_type() != rhs_vals.data_type() { + return None; + } + match lhs_vals.data_type() { + // Integer types + DataType::Int8 => ree_distinct_primitive::(lhs, rhs, size, flag), + DataType::Int16 => ree_distinct_primitive::(lhs, rhs, size, flag), + DataType::Int32 => ree_distinct_primitive::(lhs, rhs, size, flag), + DataType::Int64 => ree_distinct_primitive::(lhs, rhs, size, flag), + DataType::UInt8 => ree_distinct_primitive::(lhs, rhs, size, flag), + DataType::UInt16 => ree_distinct_primitive::(lhs, rhs, size, flag), + DataType::UInt32 => ree_distinct_primitive::(lhs, rhs, size, flag), + DataType::UInt64 => ree_distinct_primitive::(lhs, rhs, size, flag), + + // Floating point + DataType::Float32 => ree_distinct_primitive::(lhs, rhs, size, flag), + DataType::Float64 => ree_distinct_primitive::(lhs, rhs, size, flag), + + // Temporal + DataType::Date32 => ree_distinct_primitive::(lhs, rhs, size, flag), + DataType::Date64 => ree_distinct_primitive::(lhs, rhs, size, flag), + DataType::Timestamp(_, _) => { + ree_distinct_primitive::(lhs, rhs, size, flag) + } + DataType::Time32(_) => { + ree_distinct_primitive::(lhs, rhs, size, flag) + } + DataType::Time64(_) => ree_distinct_primitive::(lhs, rhs, size, flag), + DataType::Duration(_) => { + ree_distinct_primitive::(lhs, rhs, size, flag) + } + DataType::Interval(_) => { + ree_distinct_primitive::(lhs, rhs, size, flag) + } + + // Decimals + DataType::Decimal128(_, _) => { + ree_distinct_primitive::(lhs, rhs, size, flag) + } + DataType::Decimal256(_, _) => { + ree_distinct_primitive::(lhs, rhs, size, flag) + } + // Strings arent a primitive type, so we need to handle them separately + DataType::Utf8 => ree_distinct_string(lhs, rhs, size, flag), + + // Not yet supported or complex + _ => None, + } +} + +fn ree_distinct_primitive( + lhs: &AnyRunArray, + rhs: &AnyRunArray, + size: usize, + flag: bool, +) -> Option { + let lhs_vals = lhs.values().as_any().downcast_ref::>()?; + let rhs_vals = rhs.values().as_any().downcast_ref::>()?; + let mut builder = BooleanBufferBuilder::new(size); + for i in 0..size { + let li = lhs.get_physical_index(i); + let ri = rhs.get_physical_index(i); + + let mut is_same = match (lhs_vals.is_null(li), rhs_vals.is_null(ri)) { + (true, true) => true, + (true, false) | (false, true) => false, // If one is null, result depends on flag + (false, false) => lhs_vals.value(li) == rhs_vals.value(ri), + }; + if flag { + is_same = !is_same; + } + builder.append(is_same); + } + Some(BooleanArray::from(builder.finish())) +} + +fn ree_distinct_string( + lhs: &AnyRunArray, + rhs: &AnyRunArray, + size: usize, + flag: bool, +) -> Option { + let lhs_vals = lhs.values().as_any().downcast_ref::()?; + let rhs_vals = rhs.values().as_any().downcast_ref::()?; + + let mut builder = BooleanBufferBuilder::new(size); + for i in 0..size { + let li = lhs.get_physical_index(i); + let ri = rhs.get_physical_index(i); + + let mut is_same = match (lhs_vals.is_null(li), rhs_vals.is_null(ri)) { + (true, true) => true, + (true, false) | (false, true) => false, + (false, false) => lhs_vals.value(li) == rhs_vals.value(ri), + }; + if flag { + is_same = !is_same; + } + builder.append(is_same); + } + Some(BooleanArray::from(builder.finish())) +} + // Manually implement `Clone` to avoid `V: Clone` type constraint impl Clone for TypedRunArray<'_, R, V> { fn clone(&self) -> Self { @@ -720,7 +821,6 @@ where } } - /// An AnyRunArray is a wrapper around a RunArray that can be used to aggregate over a RunEndEncodedArray /// This is used to avoid the need to downcast the RunEndEncodedArray to a specific type pub enum AnyRunArray<'a> { @@ -777,7 +877,7 @@ impl<'a> AnyRunArray<'a> { AnyRunArray::Int16(array) => array.run_ends().values()[i].as_usize(), } } - + /// Returns the length of run ends array pub fn run_ends_len(&self) -> usize { match self { @@ -786,9 +886,16 @@ impl<'a> AnyRunArray<'a> { AnyRunArray::Int16(array) => array.values().len(), } } - -} + /// Returns the physical index for the given logical index + pub fn get_physical_index(&self, logical_index: usize) -> usize { + match self { + AnyRunArray::Int64(array) => array.get_physical_index(logical_index), + AnyRunArray::Int32(array) => array.get_physical_index(logical_index), + AnyRunArray::Int16(array) => array.get_physical_index(logical_index), + } + } +} #[cfg(test)] mod tests { diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index 897638da395c..249d47c0a6aa 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -23,14 +23,16 @@ //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. //! +use arrow_array::array::*; use arrow_array::cast::AsArray; use arrow_array::types::{ByteArrayType, ByteViewType}; use arrow_array::{ - downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum, FixedSizeBinaryArray, GenericByteArray, GenericByteViewArray + downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum, + FixedSizeBinaryArray, GenericByteArray, GenericByteViewArray, }; use arrow_buffer::bit_util::ceil; use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; -use arrow_schema::{ArrowError}; +use arrow_schema::{ArrowError, DataType}; use arrow_select::take::take; use std::ops::Not; @@ -223,13 +225,25 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + let l_any = AnyRunArray::new(l, run_field.data_type().clone()); + let r_any = AnyRunArray::new(r, run_field2.data_type().clone()); + let flag = match op { + Op::Distinct => true, + Op::NotDistinct => false, + _ => false, + }; + match (l_any, r_any) { + (Some(l), Some(r)) => ree_distinct(&l, &r, len, flag), + _ => None, + } + } + _ => None, + }; + if let Some(result) = _ree_distinct { + return Ok(result); + } let l_v = l.as_any_dictionary_opt(); let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l); @@ -239,7 +253,6 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result::try_new(&run_ends, &values).unwrap(); - + let run_ends = Int32Array::from(vec![2, 3, 5]); let values = Int32Array::from(vec![10, 20, 30]); let other_array = RunArray::::try_new(&run_ends, &values).unwrap(); // Test distinct operations let result = distinct(&run_array, &other_array).unwrap(); - assert_eq!(result, BooleanArray::from(vec![false,false,false,false,false])); - + assert_eq!( + result, + BooleanArray::from(vec![false, false, false, false, false]) + ); + let result = not_distinct(&run_array, &other_array).unwrap(); - // Expected: [true, false, true, false, true] (opposite of distinct) - assert_eq!(result, BooleanArray::from(vec![true,true,true,true,true])); - } + assert_eq!( + result, + BooleanArray::from(vec![true, true, true, true, true]) + ); + } #[test] fn test_ree_distinct_mismatched_values() { @@ -907,7 +925,7 @@ mod tests { let result = not_distinct(&lhs, &rhs).unwrap(); assert_eq!(result, BooleanArray::from(vec![true, true, false])); } - + #[test] fn test_ree_distinct_all_different() { // [1, 2, 3] @@ -927,28 +945,35 @@ mod tests { assert_eq!(result, BooleanArray::from(vec![false, false, false])); } - //failed #[test] fn test_ree_distinct_mixed_values() { // [10, 10, 20, 30, 30] let lhs = RunArray::::try_new( &Int32Array::from(vec![2, 3, 5]), &Int32Array::from(vec![10, 20, 30]), - ).unwrap(); + ) + .unwrap(); // [10, 99, 20, 99, 30] let rhs = RunArray::::try_new( &Int32Array::from(vec![1, 2, 3, 4, 5]), &Int32Array::from(vec![10, 99, 20, 99, 30]), - ).unwrap(); + ) + .unwrap(); let result = distinct(&lhs, &rhs).unwrap(); - assert_eq!(result, BooleanArray::from(vec![false, true, false, true, false])); + assert_eq!( + result, + BooleanArray::from(vec![false, true, false, true, false]) + ); let result = not_distinct(&lhs, &rhs).unwrap(); - assert_eq!(result, BooleanArray::from(vec![true, false, true, false, true])); + assert_eq!( + result, + BooleanArray::from(vec![true, false, true, false, true]) + ); } - + #[test] fn test_distinct_with_nulls_on_expanded_ree() { // Simulate REE-expanded arrays with nulls @@ -967,7 +992,7 @@ mod tests { // Logical: [10, NULL, 20, NULL, 30] let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); let values = Int32Array::from(vec![10, -1, 20, -1, 30]); - let lhs = RunArray::::try_new(&run_ends, &values).unwrap(); + let lhs = RunArray::::try_new(&run_ends, &values).unwrap(); // Logical: [10, NULL, 99, NULL, 30] let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); @@ -975,10 +1000,227 @@ mod tests { let rhs = RunArray::::try_new(&run_ends, &values).unwrap(); let result = distinct(&lhs, &rhs).unwrap(); - assert_eq!(result, BooleanArray::from(vec![false, false, true, false, false])); + assert_eq!( + result, + BooleanArray::from(vec![false, false, true, false, false]) + ); + + let result = not_distinct(&lhs, &rhs).unwrap(); + assert_eq!( + result, + BooleanArray::from(vec![true, true, false, true, true]) + ); + } + + #[test] + fn test_ree_distinct_edge_case_empty_arrays() { + // Empty REE arrays + let run_ends = Int32Array::from(vec![] as Vec); + let values = Int32Array::from(vec![] as Vec); + let lhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let run_ends = Int32Array::from(vec![] as Vec); + let values = Int32Array::from(vec![] as Vec); + let rhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let result = distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![] as Vec)); + + let result = not_distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![] as Vec)); + } + + #[test] + fn test_ree_distinct_edge_case_single_run() { + // Single run with same value repeated + let run_ends = Int32Array::from(vec![5]); + let values = Int32Array::from(vec![42]); + let lhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let run_ends = Int32Array::from(vec![5]); + let values = Int32Array::from(vec![42]); + let rhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let result = distinct(&lhs, &rhs).unwrap(); + assert_eq!( + result, + BooleanArray::from(vec![false, false, false, false, false]) + ); + + let result = not_distinct(&lhs, &rhs).unwrap(); + assert_eq!( + result, + BooleanArray::from(vec![true, true, true, true, true]) + ); + } + + #[test] + fn test_ree_distinct_edge_case_all_nulls() { + // All null values -> [NONE, NONE, NONE] + let run_ends = Int32Array::from(vec![3]); + let values = Int32Array::from(vec![None]); + let lhs = RunArray::::try_new(&run_ends, &values).unwrap(); + // All null values -> [NONE, NONE, NONE] + let run_ends = Int32Array::from(vec![3]); + let values = Int32Array::from(vec![None]); + let rhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let result = distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![false, false, false])); + + let result = not_distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![true, true, true])); + } + + #[test] + fn test_ree_distinct_edge_case_mixed_nulls_and_values() { + use arrow_array::BooleanArray; + use arrow_array::{types::Int32Type, Int32Array, RunArray}; + + // Logical LHS: [NULL, NULL, 10, 10, NULL] + let run_ends = Int32Array::from(vec![2, 4, 5]); + let values = Int32Array::from(vec![None, Some(10), None]); + let lhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + // Logical RHS: [NULL, Some(99), 10, 10, NULL] + let run_ends = Int32Array::from(vec![1, 2, 4, 5]); + let values = Int32Array::from(vec![None, Some(99), Some(10), None]); + let rhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let result = distinct(&lhs, &rhs).unwrap(); + assert_eq!( + result, + BooleanArray::from(vec![ + false, // NULL vs NULL + true, // NULL vs 99 → distinct + false, // 10 vs 10 + false, // 10 vs 10 + false // NULL vs NULL + ]) + ); + + let result = not_distinct(&lhs, &rhs).unwrap(); + assert_eq!( + result, + BooleanArray::from(vec![true, false, true, true, true]) + ); + } + + #[test] + fn test_ree_distinct_float64_type() { + // Test with Float64 type + let run_ends = Int32Array::from(vec![2, 4]); + let values = arrow_array::Float64Array::from(vec![1.5, 2.5]); + let lhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let run_ends = Int32Array::from(vec![2, 4]); + let values = arrow_array::Float64Array::from(vec![1.5, 3.5]); + let rhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let result = distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![false, false, true, true])); + + let result = not_distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![true, true, false, false])); + } + + #[test] + fn test_ree_distinct_timestamp_type() { + // Test with Timestamp type + use arrow_array::TimestampMicrosecondArray; + + let run_ends = Int32Array::from(vec![2, 3]); + let values = TimestampMicrosecondArray::from(vec![1000000, 2000000]); // 1s, 2s + let lhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let run_ends = Int32Array::from(vec![2, 3]); + let values = TimestampMicrosecondArray::from(vec![1000000, 3000000]); // 1s, 3s + let rhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let result = distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![false, false, true])); + + let result = not_distinct(&lhs, &rhs).unwrap(); + assert_eq!(result, BooleanArray::from(vec![true, true, false])); + } + + #[test] + fn test_ree_is_distinct_from_int_mixed() { + // LHS: [10, 10, NULL, 20, 30] + let run_ends = Int32Array::from(vec![2, 3, 4, 5]); + let values = Int32Array::from(vec![Some(10), None, Some(20), Some(30)]); + let lhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + // RHS: [10, 99, NULL, 25, 30] + let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); + let values = Int32Array::from(vec![Some(10), Some(99), None, Some(25), Some(30)]); + let rhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + let result = distinct(&lhs, &rhs).unwrap(); + assert_eq!( + result, + BooleanArray::from(vec![ + false, // 10 == 10 + true, // 10 != 99 + false, // NULL == NULL + true, // 20 != 25 + false // 30 == 30 + ]) + ); + } + + #[test] + fn test_ree_is_not_distinct_from_int_mixed() { + // LHS: [NULL, 50, 50, NULL, 100] + let run_ends = Int32Array::from(vec![1, 3, 4, 5]); + let values = Int32Array::from(vec![None, Some(50), None, Some(100)]); + let lhs = RunArray::::try_new(&run_ends, &values).unwrap(); + + // RHS: [NULL, 50, 51, NULL, 100] + let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); + let values = Int32Array::from(vec![None, Some(50), Some(51), None, Some(100)]); + let rhs = RunArray::::try_new(&run_ends, &values).unwrap(); let result = not_distinct(&lhs, &rhs).unwrap(); - assert_eq!(result, BooleanArray::from(vec![true, true, false, true, true])); + assert_eq!( + result, + BooleanArray::from(vec![ + true, // NULL == NULL + true, // 50 == 50 + false, // 50 != 51 + true, // NULL == NULL + true // 100 == 100 + ]) + ); + } + + #[test] + fn test_ree_is_distinct_from_utf8() { + use arrow_array::{Int32Array, RunArray, StringArray}; + + // LHS: ["foo", NULL, "bar", "baz", "baz"] + let run_ends = Int32Array::from(vec![1, 2, 3, 5]); + let values = StringArray::from(vec![Some("foo"), None, Some("bar"), Some("baz")]); + let lhs = + RunArray::::try_new(&run_ends, &values).unwrap(); + + // RHS: ["foo", "missing", NULL, "baz", "baz"] + let run_ends = Int32Array::from(vec![1, 2, 3, 5]); + let values = StringArray::from(vec![Some("foo"), Some("missing"), None, Some("baz")]); + let rhs = + RunArray::::try_new(&run_ends, &values).unwrap(); + + let result = distinct(&lhs, &rhs).unwrap(); + assert_eq!( + result, + BooleanArray::from(vec![ + false, // "foo" == "foo" + true, // NULL vs "missing" + true, // "bar" vs NULL + false, // "baz" == "baz" + false, // "baz" == "baz" + ]) + ); } } } diff --git a/arrow-ord/src/lib.rs b/arrow-ord/src/lib.rs index 0c68846b06ee..99b0451992cf 100644 --- a/arrow-ord/src/lib.rs +++ b/arrow-ord/src/lib.rs @@ -49,7 +49,6 @@ )] #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![warn(missing_docs)] -pub use arrow_array::downcast_primitive_array; pub mod cmp; #[doc(hidden)] pub mod comparison;