Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 143 additions & 31 deletions arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2032,10 +2032,10 @@ macro_rules! typed_compares {

/// Applies $OP (or $OP_BOOL when the dictionary values are boolean) to $LEFT and $RIGHT, which are two dictionaries that have (the same) key type $KT
macro_rules! typed_dict_cmp {
($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt) => {{
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_BOOL: expr, $KT: tt) => {{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 nice readability improvement

match ($LEFT.value_type(), $RIGHT.value_type()) {
(DataType::Boolean, DataType::Boolean) => {
cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP)
cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP_BOOL)
}
(DataType::Int8, DataType::Int8) => {
cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP)
Expand Down Expand Up @@ -2141,49 +2141,49 @@ macro_rules! typed_dict_cmp {

macro_rules! typed_dict_compares {
// Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray`
($LEFT: expr, $RIGHT: expr, $OP: expr) => {{
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_BOOL: expr) => {{
match ($LEFT.data_type(), $RIGHT.data_type()) {
(DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> {
match (left_key_type.as_ref(), right_key_type.as_ref()) {
(DataType::Int8, DataType::Int8) => {
let left = as_dictionary_array::<Int8Type>($LEFT);
let right = as_dictionary_array::<Int8Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, Int8Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int8Type)
}
(DataType::Int16, DataType::Int16) => {
let left = as_dictionary_array::<Int16Type>($LEFT);
let right = as_dictionary_array::<Int16Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, Int16Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int16Type)
}
(DataType::Int32, DataType::Int32) => {
let left = as_dictionary_array::<Int32Type>($LEFT);
let right = as_dictionary_array::<Int32Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, Int32Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int32Type)
}
(DataType::Int64, DataType::Int64) => {
let left = as_dictionary_array::<Int64Type>($LEFT);
let right = as_dictionary_array::<Int64Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, Int64Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int64Type)
}
(DataType::UInt8, DataType::UInt8) => {
let left = as_dictionary_array::<UInt8Type>($LEFT);
let right = as_dictionary_array::<UInt8Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, UInt8Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt8Type)
}
(DataType::UInt16, DataType::UInt16) => {
let left = as_dictionary_array::<UInt16Type>($LEFT);
let right = as_dictionary_array::<UInt16Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, UInt16Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt16Type)
}
(DataType::UInt32, DataType::UInt32) => {
let left = as_dictionary_array::<UInt32Type>($LEFT);
let right = as_dictionary_array::<UInt32Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, UInt32Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt32Type)
}
(DataType::UInt64, DataType::UInt64) => {
let left = as_dictionary_array::<UInt64Type>($LEFT);
let right = as_dictionary_array::<UInt64Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, UInt64Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt64Type)
}
(t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
"Comparing dictionary arrays of type {} is not yet implemented",
Expand Down Expand Up @@ -2317,7 +2317,7 @@ where
pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
// Only the data type of `left` selects the branch; `right` is not inspected here.
match left.data_type() {
// Dictionary input: hand the `==` comparison to `typed_dict_compares!` as closures.
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a == b)
typed_dict_compares!(left, right, |a, b| a == b, |a, b| a == b)
}
// Any other data type: use the named equality kernels via `typed_compares!`.
_ => typed_compares!(left, right, eq_bool, eq, eq_utf8, eq_binary),
}
}
Expand All @@ -2340,7 +2340,12 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// assert_eq!(BooleanArray::from(vec![Some(false), None, Some(true)]), result);
/// ```
pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, neq_bool, neq, neq_utf8, neq_binary)
// Only the data type of `left` selects the branch; `right` is not inspected here.
match left.data_type() {
// Dictionary input: the two closures are the `$OP` and `$OP_BOOL` arguments of
// `typed_dict_compares!`; both express `!=`.
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a != b, |a, b| a != b)
}
// Any other data type: use the named inequality kernels via `typed_compares!`.
_ => typed_compares!(left, right, neq_bool, neq, neq_utf8, neq_binary),
}
}

/// Perform `left < right` operation on two (dynamic) [`Array`]s.
Expand All @@ -2358,8 +2363,14 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// let result = lt_dyn(&array1, &array2).unwrap();
/// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result);
/// ```
// NOTE(review): presumably required because the `a < b` closure passed as `$OP_BOOL`
// is applied to `bool` operands — confirm against the clippy output.
#[allow(clippy::bool_comparison)]
pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, lt_bool, lt, lt_utf8, lt_binary)
// Only the data type of `left` selects the branch; `right` is not inspected here.
match left.data_type() {
// Dictionary input: the two closures are the `$OP` and `$OP_BOOL` arguments of
// `typed_dict_compares!`; both express `<`.
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a < b, |a, b| a < b)
}
// Any other data type: use the named less-than kernels via `typed_compares!`.
_ => typed_compares!(left, right, lt_bool, lt, lt_utf8, lt_binary),
}
}

/// Perform `left <= right` operation on two (dynamic) [`Array`]s.
Expand All @@ -2378,7 +2389,12 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// assert_eq!(BooleanArray::from(vec![Some(false), Some(true), Some(true), None]), result);
/// ```
pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8, lt_eq_binary)
// Only the data type of `left` selects the branch; `right` is not inspected here.
match left.data_type() {
// Dictionary input: the two closures are the `$OP` and `$OP_BOOL` arguments of
// `typed_dict_compares!`; both express `<=`.
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a <= b, |a, b| a <= b)
}
// Any other data type: use the named less-than-or-equal kernels via `typed_compares!`.
_ => typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8, lt_eq_binary),
}
}

/// Perform `left > right` operation on two (dynamic) [`Array`]s.
Expand All @@ -2395,8 +2411,14 @@ pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// let result = gt_dyn(&array1, &array2).unwrap();
/// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result);
/// ```
// NOTE(review): presumably required because the `a > b` closure passed as `$OP_BOOL`
// is applied to `bool` operands — confirm against the clippy output.
#[allow(clippy::bool_comparison)]
pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, gt_bool, gt, gt_utf8, gt_binary)
// Only the data type of `left` selects the branch; `right` is not inspected here.
match left.data_type() {
// Dictionary input: the two closures are the `$OP` and `$OP_BOOL` arguments of
// `typed_dict_compares!`; both express `>`.
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a > b, |a, b| a > b)
}
// Any other data type: use the named greater-than kernels via `typed_compares!`.
_ => typed_compares!(left, right, gt_bool, gt, gt_utf8, gt_binary),
}
}

/// Perform `left >= right` operation on two (dynamic) [`Array`]s.
Expand All @@ -2414,7 +2436,12 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// assert_eq!(BooleanArray::from(vec![Some(false), Some(true), None]), result);
/// ```
pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8, gt_eq_binary)
// Only the data type of `left` selects the branch; `right` is not inspected here.
match left.data_type() {
// Dictionary input: the two closures are the `$OP` and `$OP_BOOL` arguments of
// `typed_dict_compares!`; both express `>=`.
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a >= b, |a, b| a >= b)
}
// Any other data type: use the named greater-than-or-equal kernels via `typed_compares!`.
_ => typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8, gt_eq_binary),
}
}

/// Perform `left == right` operation on two [`PrimitiveArray`]s.
Expand Down Expand Up @@ -4663,7 +4690,7 @@ mod tests {
}

#[test]
fn test_eq_dyn_dictionary_i8_array() {
fn test_eq_dyn_neq_dyn_dictionary_i8_array() {
// Construct a value array
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);

Expand All @@ -4673,12 +4700,17 @@ mod tests {
let dict_array2 = DictionaryArray::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true]));

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![false, true, false])
);
}

#[test]
fn test_eq_dyn_dictionary_u64_array() {
fn test_eq_dyn_neq_dyn_dictionary_u64_array() {
let values = UInt64Array::from_iter_values([10_u64, 11, 12, 13, 14, 15, 16, 17]);

let keys1 = UInt64Array::from_iter_values([1_u64, 3, 4]);
Expand All @@ -4689,15 +4721,17 @@ mod tests {
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![false, true, false])
);

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true]));
}

#[test]
fn test_eq_dyn_dictionary_utf8_array() {
fn test_eq_dyn_neq_dyn_dictionary_utf8_array() {
let test1 = vec!["a", "a", "b", "c"];
let test2 = vec!["a", "b", "b", "c"];

Expand All @@ -4711,15 +4745,20 @@ mod tests {
.collect();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![Some(true), None, None, Some(true)])
);

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![Some(false), None, None, Some(false)])
);
}

#[test]
fn test_eq_dyn_dictionary_binary_array() {
fn test_eq_dyn_neq_dyn_dictionary_binary_array() {
let values: BinaryArray = ["hello", "", "parquet"]
.into_iter()
.map(|b| Some(b.as_bytes()))
Expand All @@ -4733,15 +4772,17 @@ mod tests {
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![true, false, false])
);

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));
}

#[test]
fn test_eq_dyn_dictionary_interval_array() {
fn test_eq_dyn_neq_dyn_dictionary_interval_array() {
let values = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]);

let keys1 = UInt64Array::from_iter_values([1_u64, 0, 3]);
Expand All @@ -4752,12 +4793,17 @@ mod tests {
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![true, false, false])
);
}

#[test]
fn test_eq_dyn_dictionary_date_array() {
fn test_eq_dyn_neq_dyn_dictionary_date_array() {
let values = Date32Array::from(vec![1, 6, 10, 2, 3, 5]);

let keys1 = UInt64Array::from_iter_values([1_u64, 0, 3]);
Expand All @@ -4768,12 +4814,17 @@ mod tests {
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![true, false, false])
);
}

#[test]
fn test_eq_dyn_dictionary_bool_array() {
fn test_eq_dyn_neq_dyn_dictionary_bool_array() {
let values = BooleanArray::from(vec![true, false]);

let keys1 = UInt64Array::from_iter_values([1_u64, 1, 1]);
Expand All @@ -4784,10 +4835,71 @@ mod tests {
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![false, true, false])
);

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true]));
}

#[test]
fn test_lt_dyn_gt_dyn_dictionary_i8_array() {
    // One shared value array; the two dictionaries differ only in their keys.
    let dict_values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);

    let lhs_keys = Int8Array::from_iter_values([3_i8, 4, 4]);
    let rhs_keys = Int8Array::from_iter_values([4_i8, 3, 4]);
    let lhs = DictionaryArray::try_new(&lhs_keys, &dict_values).unwrap();
    let rhs = DictionaryArray::try_new(&rhs_keys, &dict_values).unwrap();

    // With keys read as indices into `dict_values`, the comparisons below are
    // [13, 14, 14] against [14, 13, 14].

    // Strictly less than.
    assert_eq!(
        lt_dyn(&lhs, &rhs).unwrap(),
        BooleanArray::from(vec![true, false, false])
    );

    // Less than or equal.
    assert_eq!(
        lt_eq_dyn(&lhs, &rhs).unwrap(),
        BooleanArray::from(vec![true, false, true])
    );

    // Strictly greater than.
    assert_eq!(
        gt_dyn(&lhs, &rhs).unwrap(),
        BooleanArray::from(vec![false, true, false])
    );

    // Greater than or equal.
    assert_eq!(
        gt_eq_dyn(&lhs, &rhs).unwrap(),
        BooleanArray::from(vec![false, true, true])
    );
}

#[test]
fn test_lt_dyn_gt_dyn_dictionary_bool_array() {
    // One shared boolean value array; the two dictionaries differ only in their keys.
    let dict_values = BooleanArray::from(vec![true, false]);

    let lhs_keys = UInt64Array::from_iter_values([1_u64, 1, 0]);
    let rhs_keys = UInt64Array::from_iter_values([0_u64, 1, 1]);
    let lhs = DictionaryArray::<UInt64Type>::try_new(&lhs_keys, &dict_values).unwrap();
    let rhs = DictionaryArray::<UInt64Type>::try_new(&rhs_keys, &dict_values).unwrap();

    // With keys read as indices into `dict_values`, the comparisons below are
    // [false, false, true] against [true, false, false], where `false < true`.

    // Strictly less than.
    assert_eq!(
        lt_dyn(&lhs, &rhs).unwrap(),
        BooleanArray::from(vec![true, false, false])
    );

    // Less than or equal.
    assert_eq!(
        lt_eq_dyn(&lhs, &rhs).unwrap(),
        BooleanArray::from(vec![true, true, false])
    );

    // Strictly greater than.
    assert_eq!(
        gt_dyn(&lhs, &rhs).unwrap(),
        BooleanArray::from(vec![false, false, true])
    );

    // Greater than or equal.
    assert_eq!(
        gt_eq_dyn(&lhs, &rhs).unwrap(),
        BooleanArray::from(vec![false, true, true])
    );
}
}