diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index f5bfbd147ca6..d1c33e33dee6 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -2032,10 +2032,10 @@ macro_rules! typed_compares { /// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT macro_rules! typed_dict_cmp { - ($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt) => {{ + ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_BOOL: expr, $KT: tt) => {{ match ($LEFT.value_type(), $RIGHT.value_type()) { (DataType::Boolean, DataType::Boolean) => { - cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP) + cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP_BOOL) } (DataType::Int8, DataType::Int8) => { cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP) @@ -2141,49 +2141,49 @@ macro_rules! typed_dict_cmp { macro_rules! typed_dict_compares { // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray` - ($LEFT: expr, $RIGHT: expr, $OP: expr) => {{ + ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_BOOL: expr) => {{ match ($LEFT.data_type(), $RIGHT.data_type()) { (DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> { match (left_key_type.as_ref(), right_key_type.as_ref()) { (DataType::Int8, DataType::Int8) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, Int8Type) + typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int8Type) } (DataType::Int16, DataType::Int16) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, Int16Type) + typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int16Type) } (DataType::Int32, DataType::Int32) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, Int32Type) + typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int32Type) } (DataType::Int64, DataType::Int64) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, Int64Type) + typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int64Type) } (DataType::UInt8, DataType::UInt8) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, UInt8Type) + typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt8Type) } (DataType::UInt16, DataType::UInt16) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, UInt16Type) + typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt16Type) } (DataType::UInt32, DataType::UInt32) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, UInt32Type) + typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt32Type) } (DataType::UInt64, DataType::UInt64) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, UInt64Type) + typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt64Type) } (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( "Comparing dictionary arrays of type {} is not yet implemented", @@ -2317,7 +2317,7 @@ where pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { - typed_dict_compares!(left, right, |a, b| a == b) + typed_dict_compares!(left, right, |a, b| a == b, |a, b| a == b) } _ => typed_compares!(left, right, eq_bool, eq, eq_utf8, eq_binary), } @@ -2340,7 +2340,12 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// assert_eq!(BooleanArray::from(vec![Some(false), None, Some(true)]), result); /// ``` pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { - typed_compares!(left, right, neq_bool, neq, neq_utf8, neq_binary) + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_compares!(left, right, |a, b| a != b, |a, b| a != b) + } + _ => typed_compares!(left, right, neq_bool, neq, neq_utf8, neq_binary), + } } /// Perform `left < right` operation on two (dynamic) [`Array`]s. @@ -2358,8 +2363,14 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// let result = lt_dyn(&array1, &array2).unwrap(); /// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result); /// ``` +#[allow(clippy::bool_comparison)] pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result { - typed_compares!(left, right, lt_bool, lt, lt_utf8, lt_binary) + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_compares!(left, right, |a, b| a < b, |a, b| a < b) + } + _ => typed_compares!(left, right, lt_bool, lt, lt_utf8, lt_binary), + } } /// Perform `left <= right` operation on two (dynamic) [`Array`]s. @@ -2378,7 +2389,12 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// assert_eq!(BooleanArray::from(vec![Some(false), Some(true), Some(true), None]), result); /// ``` pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { - typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8, lt_eq_binary) + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_compares!(left, right, |a, b| a <= b, |a, b| a <= b) + } + _ => typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8, lt_eq_binary), + } } /// Perform `left > right` operation on two (dynamic) [`Array`]s. @@ -2395,8 +2411,14 @@ pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// let result = gt_dyn(&array1, &array2).unwrap(); /// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result); /// ``` +#[allow(clippy::bool_comparison)] pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { - typed_compares!(left, right, gt_bool, gt, gt_utf8, gt_binary) + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_compares!(left, right, |a, b| a > b, |a, b| a > b) + } + _ => typed_compares!(left, right, gt_bool, gt, gt_utf8, gt_binary), + } } /// Perform `left >= right` operation on two (dynamic) [`Array`]s. @@ -2414,7 +2436,12 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// assert_eq!(BooleanArray::from(vec![Some(false), Some(true), None]), result); /// ``` pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { - typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8, gt_eq_binary) + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_compares!(left, right, |a, b| a >= b, |a, b| a >= b) + } + _ => typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8, gt_eq_binary), + } } /// Perform `left == right` operation on two [`PrimitiveArray`]s. @@ -4663,7 +4690,7 @@ mod tests { } #[test] - fn test_eq_dyn_dictionary_i8_array() { + fn test_eq_dyn_neq_dyn_dictionary_i8_array() { // Construct a value array let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); @@ -4673,12 +4700,17 @@ mod tests { let dict_array2 = DictionaryArray::try_new(&keys2, &values).unwrap(); let result = eq_dyn(&dict_array1, &dict_array2); - assert!(result.is_ok()); assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true])); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![false, true, false]) + ); } #[test] - fn test_eq_dyn_dictionary_u64_array() { + fn test_eq_dyn_neq_dyn_dictionary_u64_array() { let values = UInt64Array::from_iter_values([10_u64, 11, 12, 13, 14, 15, 16, 17]); let keys1 = UInt64Array::from_iter_values([1_u64, 3, 4]); @@ -4689,15 +4721,17 @@ mod tests { DictionaryArray::::try_new(&keys2, &values).unwrap(); let result = eq_dyn(&dict_array1, &dict_array2); - assert!(result.is_ok()); assert_eq!( result.unwrap(), BooleanArray::from(vec![false, true, false]) ); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true])); } #[test] - fn test_eq_dyn_dictionary_utf8_array() { + fn test_eq_dyn_neq_dyn_dictionary_utf8_array() { let test1 = vec!["a", "a", "b", "c"]; let test2 = vec!["a", "b", "b", "c"]; @@ -4711,15 +4745,20 @@ mod tests { .collect(); let result = eq_dyn(&dict_array1, &dict_array2); - assert!(result.is_ok()); assert_eq!( result.unwrap(), BooleanArray::from(vec![Some(true), None, None, Some(true)]) ); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, None, Some(false)]) + ); } #[test] - fn test_eq_dyn_dictionary_binary_array() { + fn test_eq_dyn_neq_dyn_dictionary_binary_array() { let values: BinaryArray = ["hello", "", "parquet"] .into_iter() .map(|b| Some(b.as_bytes())) @@ -4733,15 +4772,17 @@ mod tests { DictionaryArray::::try_new(&keys2, &values).unwrap(); let result = eq_dyn(&dict_array1, &dict_array2); - assert!(result.is_ok()); assert_eq!( result.unwrap(), BooleanArray::from(vec![true, false, false]) ); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); } #[test] - fn test_eq_dyn_dictionary_interval_array() { + fn test_eq_dyn_neq_dyn_dictionary_interval_array() { let values = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]); let keys1 = UInt64Array::from_iter_values([1_u64, 0, 3]); @@ -4752,12 +4793,17 @@ mod tests { DictionaryArray::::try_new(&keys2, &values).unwrap(); let result = eq_dyn(&dict_array1, &dict_array2); - assert!(result.is_ok()); assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![true, false, false]) + ); } #[test] - fn test_eq_dyn_dictionary_date_array() { + fn test_eq_dyn_neq_dyn_dictionary_date_array() { let values = Date32Array::from(vec![1, 6, 10, 2, 3, 5]); let keys1 = UInt64Array::from_iter_values([1_u64, 0, 3]); @@ -4768,12 +4814,17 @@ mod tests { DictionaryArray::::try_new(&keys2, &values).unwrap(); let result = eq_dyn(&dict_array1, &dict_array2); - assert!(result.is_ok()); assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![true, false, false]) + ); } #[test] - fn test_eq_dyn_dictionary_bool_array() { + fn test_eq_dyn_neq_dyn_dictionary_bool_array() { let values = BooleanArray::from(vec![true, false]); let keys1 = UInt64Array::from_iter_values([1_u64, 1, 1]); @@ -4784,10 +4835,71 @@ mod tests { DictionaryArray::::try_new(&keys2, &values).unwrap(); let result = eq_dyn(&dict_array1, &dict_array2); - assert!(result.is_ok()); assert_eq!( result.unwrap(), BooleanArray::from(vec![false, true, false]) ); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true])); + } + + #[test] + fn test_lt_dyn_gt_dyn_dictionary_i8_array() { + // Construct a value array + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + + let keys1 = Int8Array::from_iter_values([3_i8, 4, 4]); + let keys2 = Int8Array::from_iter_values([4_i8, 3, 4]); + let dict_array1 = DictionaryArray::try_new(&keys1, &values).unwrap(); + let dict_array2 = DictionaryArray::try_new(&keys2, &values).unwrap(); + + let result = lt_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![true, false, false]) + ); + + let result = lt_eq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true])); + + let result = gt_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![false, true, false]) + ); + + let result = gt_eq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); + } + + #[test] + fn test_lt_dyn_gt_dyn_dictionary_bool_array() { + let values = BooleanArray::from(vec![true, false]); + + let keys1 = UInt64Array::from_iter_values([1_u64, 1, 0]); + let keys2 = UInt64Array::from_iter_values([0_u64, 1, 1]); + let dict_array1 = + DictionaryArray::::try_new(&keys1, &values).unwrap(); + let dict_array2 = + DictionaryArray::::try_new(&keys2, &values).unwrap(); + + let result = lt_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![true, false, false]) + ); + + let result = lt_eq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![true, true, false])); + + let result = gt_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![false, false, true]) + ); + + let result = gt_eq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); } }