-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Add scalar comparison kernels for DictionaryArray #984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
c5c7f35
155d60f
4896762
0d5d1b1
c147869
b5f04c5
ee7997c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,8 +27,9 @@ use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuff | |
| use crate::compute::binary_boolean_kernel; | ||
| use crate::compute::util::combine_option_bitmap; | ||
| use crate::datatypes::{ | ||
| ArrowNumericType, DataType, Float32Type, Float64Type, Int16Type, Int32Type, | ||
| Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, | ||
| ArrowNativeType, ArrowNumericType, ArrowPrimitiveType, DataType, Float32Type, | ||
| Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, | ||
| UInt64Type, UInt8Type, | ||
| }; | ||
| use crate::error::{ArrowError, Result}; | ||
| use crate::util::bit_util; | ||
|
|
@@ -200,6 +201,54 @@ macro_rules! compare_op_scalar_primitive { | |
| }}; | ||
| } | ||
|
|
||
| /// Compares every slot of a dictionary array `$left` against the scalar `$right` | ||
| /// using `$op`, and evaluates to `Ok(BooleanArray)`. | ||
| /// | ||
| /// `$op` is applied once per dictionary value; the per-slot result is then looked | ||
| /// up through the keys. The validity bitmap of `$left` is carried over unchanged. | ||
| /// | ||
| /// # Panics | ||
| /// | ||
| /// Panics if the dictionary values are not a `PrimitiveArray<$T>`, or if a valid | ||
| /// key cannot be converted to `usize` or points outside the dictionary values. | ||
| macro_rules! compare_dict_op_scalar { | ||
|     ($left:expr, $T:ident, $right:expr, $op:expr) => {{ | ||
|         // Evaluate the `$left` expression exactly once. | ||
|         let left = $left; | ||
|
|
||
|         let null_bit_buffer = left | ||
|             .data() | ||
|             .null_buffer() | ||
|             .map(|b| b.bit_slice(left.offset(), left.len())); | ||
|
|
||
|         let values = left.values(); | ||
|         let array = values | ||
|             .as_any() | ||
|             .downcast_ref::<PrimitiveArray<$T>>() | ||
|             .unwrap(); | ||
|
|
||
|         // One comparison per distinct dictionary value. | ||
|         // SAFETY: `i` ranges over `0..array.len()`, so it is always in bounds | ||
|         // for `array` (the dictionary values, not the keys). | ||
|         let comparison: Vec<bool> = (0..array.len()) | ||
|             .map(|i| unsafe { $op(array.value_unchecked(i), $right) }) | ||
|             .collect(); | ||
|
|
||
|         let keys = left.keys(); | ||
|         let result: Vec<bool> = (0..keys.len()) | ||
|             .map(|i| { | ||
|                 // The key stored in a null slot is unspecified and may not be a | ||
|                 // valid dictionary index; the slot is masked by the validity | ||
|                 // bitmap anyway, so skip the lookup instead of risking a panic. | ||
|                 if !keys.is_valid(i) { | ||
|                     return false; | ||
|                 } | ||
|                 let index = keys.value(i); | ||
|                 // Build the panic message lazily, only on failure. | ||
|                 let index = index | ||
|                     .to_usize() | ||
|                     .unwrap_or_else(|| panic!("Failed at idx {:?}", index)); | ||
|                 comparison[index] | ||
|             }) | ||
|             .collect(); | ||
|
|
||
|         // SAFETY: `Vec::into_iter` reports an exact length (one entry per key). | ||
|         let buffer = | ||
|             unsafe { MutableBuffer::from_trusted_len_iter_bool(result.into_iter()) }; | ||
|
|
||
|         // SAFETY: the value buffer holds one bit per key and the validity bitmap | ||
|         // was sliced to `left.len()` bits above. | ||
|         // NOTE(review): assumes `keys.len() == left.len()` - confirm. | ||
|         let data = unsafe { | ||
|             ArrayData::new_unchecked( | ||
|                 DataType::Boolean, | ||
|                 left.len(), | ||
|                 None, | ||
|                 null_bit_buffer, | ||
|                 0, | ||
|                 vec![Buffer::from(buffer)], | ||
|                 vec![], | ||
|             ) | ||
|         }; | ||
|         Ok(BooleanArray::from(data)) | ||
|     }}; | ||
| } | ||
|
|
||
| /// Evaluate `op(left, right)` for [`PrimitiveArray`]s using a specified | ||
| /// comparison function. | ||
| pub fn no_simd_compare_op<T, F>( | ||
|
|
@@ -693,6 +742,14 @@ pub fn eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>( | |
| compare_op_scalar!(left, right, |a, b| a == b) | ||
| } | ||
|
|
||
| // Perform `left == right` operation on [`DictionaryArray`] and a scalar (draft, currently disabled). | ||
| // pub fn eq_dict<OffsetSize: ArrowPrimitiveType>( | ||
| // left: &DictionaryArray<OffsetSize>, | ||
| // right: &str, | ||
| // ) -> Result<BooleanArray> { | ||
| // compare_dict_op!(left, right, |a, b| a == b) | ||
| // } | ||
|
|
||
| #[inline] | ||
| fn binary_boolean_op<F>( | ||
| left: &BooleanArray, | ||
|
|
@@ -1200,6 +1257,29 @@ where | |
| return compare_op_scalar!(left, right, |a, b| a == b); | ||
| } | ||
|
|
||
| /// Perform `left == right` operation on a [`DictionaryArray`] and a scalar value. | ||
| /// | ||
| /// The comparison is delegated to `compare_dict_op_scalar!`, which downcasts the | ||
| /// dictionary values to `PrimitiveArray<T>`; the values must therefore have the | ||
| /// same primitive type `T` as the keys. | ||
| /// | ||
| /// # Panics | ||
| /// | ||
| /// Panics if the dictionary values are not a `PrimitiveArray<T>`. | ||
| pub fn eq_dict_scalar<T>( | ||
|     left: &DictionaryArray<T>, | ||
|     right: T::Native, | ||
| ) -> Result<BooleanArray> | ||
| where | ||
|     T: ArrowNumericType, | ||
| { | ||
|     // No SIMD-specific path exists for dictionaries, so the same scalar kernel | ||
|     // is used regardless of the `simd` feature. | ||
|     compare_dict_op_scalar!(left, T, right, |a, b| a == b) | ||
| } | ||
|
||
|
|
||
| // pub fn eq_dict_utf8_scalar<OffsetSize>( | ||
| // left: &DictionaryArray<OffsetSize>, | ||
| // right: &str, | ||
| // ) -> Result<BooleanArray> | ||
| // where | ||
| // OffsetSize: StringOffsetSizeTrait + ArrowPrimitiveType, | ||
| // { | ||
| // #[cfg(not(feature = "simd"))] | ||
| // return compare_dict_op_scalar!(left, OffsetSize, right, |a, b| a == b); | ||
| // } | ||
| /// Perform `left != right` operation on two [`PrimitiveArray`]s. | ||
| pub fn neq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray> | ||
| where | ||
|
|
@@ -2032,6 +2112,44 @@ mod tests { | |
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_dict_eq_scalar() { | ||
|     // Dictionary with `u8` keys and `u8` values: [123, null, 223]. | ||
|     let mut dict_builder = PrimitiveDictionaryBuilder::new( | ||
|         PrimitiveBuilder::<UInt8Type>::new(3), | ||
|         PrimitiveBuilder::<UInt8Type>::new(2), | ||
|     ); | ||
|     dict_builder.append(123).unwrap(); | ||
|     dict_builder.append_null().unwrap(); | ||
|     dict_builder.append(223).unwrap(); | ||
|     let dict = dict_builder.finish(); | ||
|     // The null slot must stay null in the comparison result. | ||
|     let expected = BooleanArray::from(vec![Some(true), None, Some(false)]); | ||
|     let actual = eq_dict_scalar(&dict, 123).unwrap(); | ||
|     assert_eq!(actual, expected); | ||
| } | ||
|
|
||
| // #[test] | ||
| // fn test_dict_eq_utf8_scalar() { | ||
| // let a: DictionaryArray<Int8Type> = vec!["a", "b", "c"].into_iter().collect(); | ||
| // let a_eq = eq_dict_utf8_scalar(&a, "b").unwrap(); | ||
| // assert_eq!(a_eq, BooleanArray::from(vec![false, true, false])); | ||
| // } | ||
| // #[test] | ||
| // fn test_dict_neq_scalar() { | ||
| // let a: DictionaryArray<Int8Type> = | ||
| // vec!["hi","hello", "world"].into_iter().collect(); | ||
| // let a_eq = neq_dict_scalar(&a, "hello").unwrap(); | ||
| // assert_eq!(a_eq, BooleanArray::from(vec![true, false, true])); | ||
| // } | ||
|
|
||
| // #[test] | ||
| // fn test_dict_lt_scalar() { | ||
| // let a: DictionaryArray<Int8Type> = | ||
| // vec!["hi","hello", "world"].into_iter().collect(); | ||
| // let a_eq = lt_dict_scalar(&a, "hi").unwrap(); | ||
| // assert_eq!(a_eq, BooleanArray::from(vec![false, true, false])); | ||
| // } | ||
|
|
||
| macro_rules! test_utf8_scalar { | ||
| ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { | ||
| #[test] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't think the `values()` (the dictionary size) has to be the same as the size of the overall array 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you're referring to the safety comment, I just hadn't removed that yet.