Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 198 additions & 2 deletions arrow-ord/src/ord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ fn child_opts(opts: SortOptions) -> SortOptions {
}
}

fn compare<A, F>(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator
fn compare<A, B, F>(l: &A, r: &B, opts: SortOptions, cmp: F) -> DynComparator
where
A: Array + Clone,
A: Array + ?Sized,
B: Array + ?Sized,
F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
{
let l = l.logical_nulls().filter(|x| x.null_count() > 0);
Expand Down Expand Up @@ -368,6 +369,52 @@ fn compare_union(
Ok(f)
}

fn compare_union_to_opaque(
union_array: &dyn Array,
opaque_array: &dyn Array,
opts: SortOptions,
) -> Result<DynComparator, ArrowError> {
let union_array = union_array.as_union();

let DataType::Union(union_fields, _) = union_array.data_type() else {
unreachable!()
};

let opaque_type_id = union_fields
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about this -- it seems like it will compare the union array to the first field that has the same type.

What if the union array has multiple repeated types (I realize that is not a common use case, but I think it is possible)

.iter()
.find_map(|(i, f)| (f.data_type() == opaque_array.data_type()).then_some(i))
.ok_or_else(|| {
ArrowError::InvalidArgumentError(format!(
"cannot compare union with {} array: type not found in union fields",
opaque_array.data_type(),
))
})?;

let c_opts = child_opts(opts);

let opaque_field_comparator = {
let union_child = union_array.child(opaque_type_id);
make_comparator(union_child.as_ref(), opaque_array, c_opts)?
};

let union_type_ids = union_array.type_ids().clone();
let union_offsets = union_array.offsets().cloned();

let f = compare(union_array, opaque_array, opts, move |i, j| {
let union_type_id = union_type_ids[i];

match union_type_id.cmp(&opaque_type_id) {
Ordering::Equal => {
let union_offset = union_offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i);
opaque_field_comparator(union_offset, j)
}
other => other,
}
});

Ok(f)
}

/// Returns a comparison function that compares two values at two different positions
/// between the two arrays.
///
Expand Down Expand Up @@ -485,6 +532,8 @@ pub fn make_comparator(
},
(Map(_, _), Map(_, _)) => compare_map(left, right, opts),
(Union(_, _), Union(_, _)) => compare_union(left, right, opts),
(Union(_, _), _) => compare_union_to_opaque(left, right, opts),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any other example of comparing one array to an array of another type? I think this kernel expects the types to match exactly

I do see the argument about how UnionArray could be treated as a special case (as some sort of runtime typed thing) 🤔 but I am not sure UnionArrays are required to behave that way

I wrote up another way that your usecase might be handled

Let me know what you think

(_, Union(_, _)) => compare_union_to_opaque(right, left, opts),
(lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
true => format!("The data type type {lhs:?} has no natural order"),
false => "Can't compare arrays of different types".to_string(),
Expand Down Expand Up @@ -1501,4 +1550,151 @@ mod tests {
"Cannot compare UnionArrays with different modes: left=Dense, right=Sparse"
);
}

#[test]
fn test_union_to_opaque_int32() {
let int_array = Int32Array::from(vec![1, 2, 3]);
let str_array = StringArray::from(vec!["a", "b"]);
let type_ids = [0, 1, 0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
let offsets = [0, 0, 1, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();
let union_fields = [
(0, Arc::new(Field::new("ints", DataType::Int32, false))),
(1, Arc::new(Field::new("strings", DataType::Utf8, false))),
]
.into_iter()
.collect::<UnionFields>();
let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];

// union: [1, "a", 2, "b", 3], opaque: [2, 6, 7]
let union_array =
UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
let opaque_array = Int32Array::from(vec![2, 6, 7]);
let opts = SortOptions::default();

let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap();

// 1 < 2
assert_eq!(cmp(0, 0), Ordering::Less);
// type_id 1 > 0
assert_eq!(cmp(1, 0), Ordering::Greater);
// 2 == 2
assert_eq!(cmp(2, 0), Ordering::Equal);
// 3 > 6
assert_eq!(cmp(4, 1), Ordering::Less);
}

#[test]
fn test_union_to_opaque_string() {
let str_array = StringArray::from(vec![Some("apple"), None, Some("pork")]);
let int_array = Int32Array::from(vec![None, Some(67), None]);
let type_ids = [1, 0, 1].into_iter().collect::<ScalarBuffer<i8>>();
let union_fields = [
(0, Arc::new(Field::new("ints", DataType::Int32, false))),
(1, Arc::new(Field::new("strings", DataType::Utf8, false))),
]
.into_iter()
.collect::<UnionFields>();
let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];

// sparse union: ["apple", 67, "pork"], opaque: ["howdy", "john", "pork"]
let union_array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
let opaque_array = StringArray::from(vec!["howdy", "john", "pork"]);
let opts = SortOptions::default();

let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap();

// apple < howdy
assert_eq!(cmp(0, 0), Ordering::Less);
// type_id < 1
assert_eq!(cmp(1, 0), Ordering::Less);
// "pork" == "pork"
assert_eq!(cmp(2, 2), Ordering::Equal);
}

#[test]
fn test_union_to_opaque_with_nulls() {
let int_array = Int32Array::from(vec![Some(1), None, Some(3)]);
let str_array = StringArray::from(vec![Some("a"), Some("b")]);
let type_ids = [0, 1, 0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
let offsets = [0, 0, 1, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();
let union_fields = [
(0, Arc::new(Field::new("ints", DataType::Int32, false))),
(1, Arc::new(Field::new("strings", DataType::Utf8, false))),
]
.into_iter()
.collect::<UnionFields>();
let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];

// union: [1, "a", null, "b", 3], opaque: [2, null, 1]
let union_array =
UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
let opaque_array = Int32Array::from(vec![Some(2), None, Some(1)]);
let opts = SortOptions {
descending: false,
nulls_first: true,
};
let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap();

// 1 > null
assert_eq!(cmp(0, 1), Ordering::Greater);
// null < 2
assert_eq!(cmp(2, 0), Ordering::Less);
// null == null
assert_eq!(cmp(2, 1), Ordering::Equal);
}

#[test]
fn test_union_to_opaque_descending() {
let int_array = Int32Array::from(vec![1, 2, 3]);
let str_array = StringArray::from(vec!["a", "b"]);
let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
let offsets = [0, 0, 1].into_iter().collect::<ScalarBuffer<i32>>();
let union_fields = [
(0, Arc::new(Field::new("ints", DataType::Int32, false))),
(1, Arc::new(Field::new("strings", DataType::Utf8, false))),
]
.into_iter()
.collect::<UnionFields>();
let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];

// union: [1, "a", 2], opaque: [2, 1]
let union_array =
UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
let opaque_array = Int32Array::from(vec![2, 1]);
let opts = SortOptions {
descending: true,
nulls_first: false,
};
let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap();
// 1 > 2 (descending)
assert_eq!(cmp(0, 0), Ordering::Greater);
// 2 == 2
assert_eq!(cmp(2, 0), Ordering::Equal);
// 1 == 1
assert_eq!(cmp(0, 1), Ordering::Equal);
}

#[test]
fn test_union_to_opaque_incompatible_type() {
let int_array = Int32Array::from(vec![1, 2]);
let str_array = StringArray::from(vec!["a", "b"]);
let type_ids = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
let offsets = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();
let union_fields = [
(0, Arc::new(Field::new("ints", DataType::Int32, false))),
(1, Arc::new(Field::new("strings", DataType::Utf8, false))),
]
.into_iter()
.collect::<UnionFields>();
let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
let union_array =
UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
let opaque_array = Float64Array::from(vec![1.0, 2.0]);
let opts = SortOptions::default();
let Err(err) = make_comparator(&union_array, &opaque_array, opts) else {
panic!("expected err");
};

assert!(err.to_string().contains("cannot compare union with"));
}
}
Loading