diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 3e7a084cf334..b4de7744298e 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -1102,7 +1102,7 @@ where | DataType::UInt32 | DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, eq_scalar)} _ => Err(ArrowError::ComputeError( - "Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + "eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), )) } DataType::Int8 @@ -1116,7 +1116,43 @@ where dyn_compare_scalar!(&left, right, eq_scalar) } _ => Err(ArrowError::ComputeError( - "Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + "eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + )) + } +} + +/// Perform `left < right` operation on an array and a numeric scalar +/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values +pub fn lt_dyn_scalar(left: Arc, right: T) -> Result +where + T: TryInto + Copy + std::fmt::Debug, +{ + match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, lt_scalar)} + _ => Err(ArrowError::ComputeError( + "lt_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + )) + } + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { + dyn_compare_scalar!(&left, right, lt_scalar) + } + _ => Err(ArrowError::ComputeError( + "lt_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), )) } } @@ -2973,6 +3009,33 @@ mod tests { ); } #[test] + fn test_lt_dyn_scalar() { + let array = Int32Array::from(vec![6, 7, 8, 8, 10]); + let array = Arc::new(array); + let a_eq = lt_dyn_scalar(array, 8).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(false), Some(false)] + ) + ); + } + #[test] + fn test_lt_dyn_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(123).unwrap(); + builder.append_null().unwrap(); + builder.append(23).unwrap(); + let array = Arc::new(builder.finish()); + let a_eq = lt_dyn_scalar(array, 123).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), None, Some(true)]) + ); + } + #[test] fn test_eq_dyn_utf8_scalar() { let array = StringArray::from(vec!["abc", "def", "xyz"]); let array = Arc::new(array);