-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Add union to opaque comparisons #8896
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,9 +36,10 @@ fn child_opts(opts: SortOptions) -> SortOptions { | |
| } | ||
| } | ||
|
|
||
| fn compare<A, F>(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator | ||
| fn compare<A, B, F>(l: &A, r: &B, opts: SortOptions, cmp: F) -> DynComparator | ||
| where | ||
| A: Array + Clone, | ||
| A: Array + ?Sized, | ||
| B: Array + ?Sized, | ||
| F: Fn(usize, usize) -> Ordering + Send + Sync + 'static, | ||
| { | ||
| let l = l.logical_nulls().filter(|x| x.null_count() > 0); | ||
|
|
@@ -368,6 +369,52 @@ fn compare_union( | |
| Ok(f) | ||
| } | ||
|
|
||
| fn compare_union_to_opaque( | ||
| union_array: &dyn Array, | ||
| opaque_array: &dyn Array, | ||
| opts: SortOptions, | ||
| ) -> Result<DynComparator, ArrowError> { | ||
| let union_array = union_array.as_union(); | ||
|
|
||
| let DataType::Union(union_fields, _) = union_array.data_type() else { | ||
| unreachable!() | ||
| }; | ||
|
|
||
| let opaque_type_id = union_fields | ||
| .iter() | ||
| .find_map(|(i, f)| (f.data_type() == opaque_array.data_type()).then_some(i)) | ||
| .ok_or_else(|| { | ||
| ArrowError::InvalidArgumentError(format!( | ||
| "cannot compare union with {} array: type not found in union fields", | ||
| opaque_array.data_type(), | ||
| )) | ||
| })?; | ||
|
|
||
| let c_opts = child_opts(opts); | ||
|
|
||
| let opaque_field_comparator = { | ||
| let union_child = union_array.child(opaque_type_id); | ||
| make_comparator(union_child.as_ref(), opaque_array, c_opts)? | ||
| }; | ||
|
|
||
| let union_type_ids = union_array.type_ids().clone(); | ||
| let union_offsets = union_array.offsets().cloned(); | ||
|
|
||
| let f = compare(union_array, opaque_array, opts, move |i, j| { | ||
| let union_type_id = union_type_ids[i]; | ||
|
|
||
| match union_type_id.cmp(&opaque_type_id) { | ||
| Ordering::Equal => { | ||
| let union_offset = union_offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i); | ||
| opaque_field_comparator(union_offset, j) | ||
| } | ||
| other => other, | ||
| } | ||
| }); | ||
|
|
||
| Ok(f) | ||
| } | ||
|
|
||
| /// Returns a comparison function that compares two values at two different positions | ||
| /// between the two arrays. | ||
| /// | ||
|
|
@@ -485,6 +532,8 @@ pub fn make_comparator( | |
| }, | ||
| (Map(_, _), Map(_, _)) => compare_map(left, right, opts), | ||
| (Union(_, _), Union(_, _)) => compare_union(left, right, opts), | ||
| (Union(_, _), _) => compare_union_to_opaque(left, right, opts), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there any other example of comparing one array to an array of another type? I think this kernel expects the types to match exactly. I do see the argument about how UnionArray could be treated as a special case (as some sort of runtime-typed thing) 🤔, but I am not sure UnionArrays are required to behave that way. I wrote up another way that your use case might be handled. Let me know what you think. |
||
| (_, Union(_, _)) => compare_union_to_opaque(right, left, opts), | ||
| (lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs { | ||
| true => format!("The data type type {lhs:?} has no natural order"), | ||
| false => "Can't compare arrays of different types".to_string(), | ||
|
|
@@ -1501,4 +1550,151 @@ mod tests { | |
| "Cannot compare UnionArrays with different modes: left=Dense, right=Sparse" | ||
| ); | ||
| } | ||
|
|
||
/// A dense union compared against an `Int32Array`: slots holding the Int32
/// field are compared by value, other slots by type id.
#[test]
fn test_union_to_opaque_int32() {
    let fields: UnionFields = [
        (0, Arc::new(Field::new("ints", DataType::Int32, false))),
        (1, Arc::new(Field::new("strings", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();
    let ids: ScalarBuffer<i8> = [0, 1, 0, 1, 0].into_iter().collect();
    let child_offsets: ScalarBuffer<i32> = [0, 0, 1, 1, 2].into_iter().collect();

    let ints = Int32Array::from(vec![1, 2, 3]);
    let strings = StringArray::from(vec!["a", "b"]);
    let children = vec![Arc::new(ints) as ArrayRef, Arc::new(strings)];

    // union: [1, "a", 2, "b", 3], opaque: [2, 6, 7]
    let union = UnionArray::try_new(fields, ids, Some(child_offsets), children).unwrap();
    let rhs = Int32Array::from(vec![2, 6, 7]);

    let comparator = make_comparator(&union, &rhs, SortOptions::default()).unwrap();

    let expectations = [
        // 1 < 2
        (0, 0, Ordering::Less),
        // union slot has type id 1, which is greater than the Int32 type id 0
        (1, 0, Ordering::Greater),
        // 2 == 2
        (2, 0, Ordering::Equal),
        // 3 < 6
        (4, 1, Ordering::Less),
    ];
    for (i, j, expected) in expectations {
        assert_eq!(comparator(i, j), expected);
    }
}
|
|
||
/// A sparse union (no offsets buffer) compared against a `StringArray`.
#[test]
fn test_union_to_opaque_string() {
    let fields: UnionFields = [
        (0, Arc::new(Field::new("ints", DataType::Int32, false))),
        (1, Arc::new(Field::new("strings", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();
    let ids: ScalarBuffer<i8> = [1, 0, 1].into_iter().collect();

    // In a sparse union every child has one slot per union slot; the slots
    // that are not selected by the type ids are left as `None`.
    let ints = Int32Array::from(vec![None, Some(67), None]);
    let strings = StringArray::from(vec![Some("apple"), None, Some("pork")]);
    let children = vec![Arc::new(ints) as ArrayRef, Arc::new(strings)];

    // sparse union: ["apple", 67, "pork"], opaque: ["howdy", "john", "pork"]
    let union = UnionArray::try_new(fields, ids, None, children).unwrap();
    let rhs = StringArray::from(vec!["howdy", "john", "pork"]);

    let comparator = make_comparator(&union, &rhs, SortOptions::default()).unwrap();

    let expectations = [
        // "apple" < "howdy"
        (0, 0, Ordering::Less),
        // union slot has type id 0, which is less than the Utf8 type id 1
        (1, 0, Ordering::Less),
        // "pork" == "pork"
        (2, 2, Ordering::Equal),
    ];
    for (i, j, expected) in expectations {
        assert_eq!(comparator(i, j), expected);
    }
}
|
|
||
/// Null handling with `nulls_first: true` when either side holds a null.
#[test]
fn test_union_to_opaque_with_nulls() {
    let fields: UnionFields = [
        (0, Arc::new(Field::new("ints", DataType::Int32, false))),
        (1, Arc::new(Field::new("strings", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();
    let ids: ScalarBuffer<i8> = [0, 1, 0, 1, 0].into_iter().collect();
    let child_offsets: ScalarBuffer<i32> = [0, 0, 1, 1, 2].into_iter().collect();

    let ints = Int32Array::from(vec![Some(1), None, Some(3)]);
    let strings = StringArray::from(vec![Some("a"), Some("b")]);
    let children = vec![Arc::new(ints) as ArrayRef, Arc::new(strings)];

    // union: [1, "a", null, "b", 3], opaque: [2, null, 1]
    let union = UnionArray::try_new(fields, ids, Some(child_offsets), children).unwrap();
    let rhs = Int32Array::from(vec![Some(2), None, Some(1)]);

    let nulls_first_ascending = SortOptions {
        descending: false,
        nulls_first: true,
    };
    let comparator = make_comparator(&union, &rhs, nulls_first_ascending).unwrap();

    let expectations = [
        // 1 vs null: nulls sort first, so the non-null value is greater
        (0, 1, Ordering::Greater),
        // null vs 2: nulls sort first, so the null is less
        (2, 0, Ordering::Less),
        // null vs null
        (2, 1, Ordering::Equal),
    ];
    for (i, j, expected) in expectations {
        assert_eq!(comparator(i, j), expected);
    }
}
|
|
||
/// Value comparisons are reversed when `descending: true` is requested.
#[test]
fn test_union_to_opaque_descending() {
    let fields: UnionFields = [
        (0, Arc::new(Field::new("ints", DataType::Int32, false))),
        (1, Arc::new(Field::new("strings", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();
    let ids: ScalarBuffer<i8> = [0, 1, 0].into_iter().collect();
    let child_offsets: ScalarBuffer<i32> = [0, 0, 1].into_iter().collect();

    let ints = Int32Array::from(vec![1, 2, 3]);
    let strings = StringArray::from(vec!["a", "b"]);
    let children = vec![Arc::new(ints) as ArrayRef, Arc::new(strings)];

    // union: [1, "a", 2], opaque: [2, 1]
    let union = UnionArray::try_new(fields, ids, Some(child_offsets), children).unwrap();
    let rhs = Int32Array::from(vec![2, 1]);

    let descending_nulls_last = SortOptions {
        descending: true,
        nulls_first: false,
    };
    let comparator = make_comparator(&union, &rhs, descending_nulls_last).unwrap();

    let expectations = [
        // 1 vs 2: Less in natural order, reversed to Greater by `descending`
        (0, 0, Ordering::Greater),
        // 2 == 2
        (2, 0, Ordering::Equal),
        // 1 == 1
        (0, 1, Ordering::Equal),
    ];
    for (i, j, expected) in expectations {
        assert_eq!(comparator(i, j), expected);
    }
}
|
|
||
/// Building a comparator fails when no union field has the opaque array's type.
#[test]
fn test_union_to_opaque_incompatible_type() {
    let fields: UnionFields = [
        (0, Arc::new(Field::new("ints", DataType::Int32, false))),
        (1, Arc::new(Field::new("strings", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();
    let ids: ScalarBuffer<i8> = [0, 1].into_iter().collect();
    let child_offsets: ScalarBuffer<i32> = [0, 0].into_iter().collect();

    let ints = Int32Array::from(vec![1, 2]);
    let strings = StringArray::from(vec!["a", "b"]);
    let children = vec![Arc::new(ints) as ArrayRef, Arc::new(strings)];
    let union = UnionArray::try_new(fields, ids, Some(child_offsets), children).unwrap();

    // Float64 is not one of the union's field types (Int32, Utf8).
    let rhs = Float64Array::from(vec![1.0, 2.0]);

    let err = match make_comparator(&union, &rhs, SortOptions::default()) {
        Err(e) => e,
        Ok(_) => panic!("expected err"),
    };

    assert!(err.to_string().contains("cannot compare union with"));
}
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure about this -- it seems like it will compare the union array to the first field that has the same type.
What if the union array has multiple repeated types? (I realize that is not a common use case, but I think it is possible.)