Skip to content

Commit 87c792b

Browse files
compare_union
1 parent 2bc269c commit 87c792b

1 file changed

Lines changed: 180 additions & 4 deletions

File tree

arrow-ord/src/ord.rs

Lines changed: 180 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ use arrow_array::cast::AsArray;
2121
use arrow_array::types::*;
2222
use arrow_array::*;
2323
use arrow_buffer::{ArrowNativeType, NullBuffer};
24-
use arrow_schema::{ArrowError, SortOptions};
25-
use std::cmp::Ordering;
24+
use arrow_schema::{ArrowError, DataType, SortOptions};
25+
use std::{cmp::Ordering, collections::HashMap};
2626

2727
/// Compare the values at two arbitrary indices in two arrays.
2828
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;
@@ -296,6 +296,70 @@ fn compare_struct(
296296
Ok(f)
297297
}
298298

299+
fn compare_union(
300+
left: &dyn Array,
301+
right: &dyn Array,
302+
opts: SortOptions,
303+
) -> Result<DynComparator, ArrowError> {
304+
let left = left.as_union();
305+
let right = right.as_union();
306+
307+
let (left_fields, left_mode) = match left.data_type() {
308+
DataType::Union(fields, mode) => (fields, mode),
309+
_ => unreachable!(),
310+
};
311+
let (right_fields, right_mode) = match right.data_type() {
312+
DataType::Union(fields, mode) => (fields, mode),
313+
_ => unreachable!(),
314+
};
315+
316+
if left_fields != right_fields || left_mode != right_mode {
317+
return Err(ArrowError::InvalidArgumentError(
318+
"Cannot compare UnionArrays with different fields or modes".to_string(),
319+
));
320+
}
321+
322+
let c_opts = child_opts(opts);
323+
324+
let mut field_comparators = HashMap::with_capacity(left_fields.len());
325+
326+
for (type_id, _field) in left_fields.iter() {
327+
let left_child = left.child(type_id);
328+
let right_child = right.child(type_id);
329+
let cmp = make_comparator(left_child.as_ref(), right_child.as_ref(), c_opts)?;
330+
331+
field_comparators.insert(type_id, cmp);
332+
}
333+
334+
let left_type_ids = left.type_ids().clone();
335+
let right_type_ids = right.type_ids().clone();
336+
337+
let left_offsets = left.offsets().cloned();
338+
let right_offsets = right.offsets().cloned();
339+
340+
let f = compare(left, right, opts, move |i, j| {
341+
let left_type_id = left_type_ids[i];
342+
let right_type_id = right_type_ids[j];
343+
344+
// first, compare by type_id
345+
match left_type_id.cmp(&right_type_id) {
346+
Ordering::Equal => {
347+
// second, compare by values
348+
let left_offset = left_offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i);
349+
let right_offset = right_offsets.as_ref().map(|o| o[j] as usize).unwrap_or(j);
350+
351+
let cmp = field_comparators
352+
.get(&left_type_id)
353+
.expect("type id not found in field_comparators");
354+
355+
cmp(left_offset, right_offset)
356+
}
357+
other => other,
358+
}
359+
});
360+
Ok(f)
361+
}
362+
299363
/// Returns a comparison function that compares two values at two different positions
300364
/// between the two arrays.
301365
///
@@ -412,6 +476,7 @@ pub fn make_comparator(
412476
}
413477
},
414478
(Map(_, _), Map(_, _)) => compare_map(left, right, opts),
479+
(Union(_, _), Union(_, _)) => compare_union(left, right, opts),
415480
(lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
416481
true => format!("The data type type {lhs:?} has no natural order"),
417482
false => "Can't compare arrays of different types".to_string(),
@@ -423,8 +488,8 @@ pub fn make_comparator(
423488
mod tests {
424489
use super::*;
425490
use arrow_array::builder::{Int32Builder, ListBuilder, MapBuilder, StringBuilder};
426-
use arrow_buffer::{IntervalDayTime, OffsetBuffer, i256};
427-
use arrow_schema::{DataType, Field, Fields};
491+
use arrow_buffer::{IntervalDayTime, OffsetBuffer, ScalarBuffer, i256};
492+
use arrow_schema::{DataType, Field, Fields, UnionFields};
428493
use half::f16;
429494
use std::sync::Arc;
430495

@@ -1189,4 +1254,115 @@ mod tests {
11891254
}
11901255
}
11911256
}
1257+
1258+
/// Dense unions: ordering is by type id first, then by the child value
/// reached through the offsets buffer; `descending` flips the result.
#[test]
fn test_dense_union() {
    // Shared layout: type id 0 is an Int32 child "A", type id 1 is a Utf8 child "B".
    let union_fields = [
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("B", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect::<UnionFields>();

    // Left array, logical values [1, "b", 2, "a", 3]:
    //   type ids [0, 1, 0, 1, 0]
    //   offsets  [0, 0, 1, 1, 2] into [1, 2, 3] and ["b", "a"]
    let array1 = UnionArray::try_new(
        union_fields.clone(),
        ScalarBuffer::<i8>::from(vec![0, 1, 0, 1, 0]),
        Some(ScalarBuffer::<i32>::from(vec![0, 0, 1, 1, 2])),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
            Arc::new(StringArray::from(vec!["b", "a"])),
        ],
    )
    .unwrap();

    // Right array, logical values [2, "a", 1, "c"]:
    //   type ids [0, 1, 0, 1]
    //   offsets  [0, 0, 1, 1] into [2, 1] and ["a", "c"]
    let array2 = UnionArray::try_new(
        union_fields,
        ScalarBuffer::<i8>::from(vec![0, 1, 0, 1]),
        Some(ScalarBuffer::<i32>::from(vec![0, 0, 1, 1])),
        vec![
            Arc::new(Int32Array::from(vec![2, 1])) as ArrayRef,
            Arc::new(StringArray::from(vec!["a", "c"])),
        ],
    )
    .unwrap();

    let ascending = SortOptions {
        descending: false,
        nulls_first: true,
    };
    let cmp = make_comparator(&array1, &array2, ascending).unwrap();

    // left[0] = Int32 1, right[0] = Int32 2
    assert_eq!(cmp(0, 0), Ordering::Less); // 1 < 2

    // left[0] = Int32 1, right[1] = Utf8 "a"
    assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1

    // left[1] = Utf8 "b", right[1] = Utf8 "a"
    assert_eq!(cmp(1, 1), Ordering::Greater); // "b" > "a"

    // left[2] = Int32 2, right[0] = Int32 2
    assert_eq!(cmp(2, 0), Ordering::Equal); // 2 == 2

    // left[3] = Utf8 "a", right[1] = Utf8 "a"
    assert_eq!(cmp(3, 1), Ordering::Equal); // "a" == "a"

    // left[1] = Utf8 "b", right[3] = Utf8 "c"
    assert_eq!(cmp(1, 3), Ordering::Less); // "b" < "c"

    // Same pairs under a descending sort: every non-equal result is reversed.
    let descending = SortOptions {
        descending: true,
        nulls_first: true,
    };
    let cmp_desc = make_comparator(&array1, &array2, descending).unwrap();

    assert_eq!(cmp_desc(0, 0), Ordering::Greater); // 1 vs 2, reversed
    assert_eq!(cmp_desc(0, 1), Ordering::Greater); // type_id 0 vs 1, reversed
    assert_eq!(cmp_desc(1, 1), Ordering::Less); // "b" vs "a", reversed
}
1339+
1340+
/// Sparse unions: there is no offsets buffer, so each child is indexed by the
/// slot position and has the same length as the union itself.
#[test]
fn test_sparse_union() {
    // Type id 0 is an Int32 child "a", type id 1 is a Utf8 child "b".
    let union_fields = [
        (0, Arc::new(Field::new("a", DataType::Int32, false))),
        (1, Arc::new(Field::new("b", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect::<UnionFields>();

    // Logical values [1, "b", 3]; the slots a child does not own hold None.
    let array = UnionArray::try_new(
        union_fields,
        ScalarBuffer::<i8>::from(vec![0, 1, 0]),
        None,
        vec![
            Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as ArrayRef,
            Arc::new(StringArray::from(vec![None, Some("b"), None])),
        ],
    )
    .unwrap();

    // The array is compared against itself with the default sort options.
    let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();

    // array[0] = Int32 1, array[2] = Int32 3
    assert_eq!(cmp(0, 2), Ordering::Less); // 1 < 3
    // array[0] = Int32 1, array[1] = Utf8 "b"
    assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1
}
11921368
}

0 commit comments

Comments
 (0)