Skip to content

Commit d92d584

Browse files
compare union apache#8838
1 parent 2bc269c commit d92d584

File tree

1 file changed

+316
-4
lines changed

1 file changed

+316
-4
lines changed

arrow-ord/src/ord.rs

Lines changed: 316 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ use arrow_array::cast::AsArray;
2121
use arrow_array::types::*;
2222
use arrow_array::*;
2323
use arrow_buffer::{ArrowNativeType, NullBuffer};
24-
use arrow_schema::{ArrowError, SortOptions};
25-
use std::cmp::Ordering;
24+
use arrow_schema::{ArrowError, DataType, SortOptions};
25+
use std::{cmp::Ordering, collections::HashMap};
2626

2727
/// Compare the values at two arbitrary indices in two arrays.
2828
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;
@@ -296,6 +296,78 @@ fn compare_struct(
296296
Ok(f)
297297
}
298298

299+
fn compare_union(
300+
left: &dyn Array,
301+
right: &dyn Array,
302+
opts: SortOptions,
303+
) -> Result<DynComparator, ArrowError> {
304+
let left = left.as_union();
305+
let right = right.as_union();
306+
307+
let (left_fields, left_mode) = match left.data_type() {
308+
DataType::Union(fields, mode) => (fields, mode),
309+
_ => unreachable!(),
310+
};
311+
let (right_fields, right_mode) = match right.data_type() {
312+
DataType::Union(fields, mode) => (fields, mode),
313+
_ => unreachable!(),
314+
};
315+
316+
if left_fields != right_fields {
317+
return Err(ArrowError::InvalidArgumentError(format!(
318+
"Cannot compare UnionArrays with different fields: left={:?}, right={:?}",
319+
left_fields, right_fields
320+
)));
321+
}
322+
323+
if left_mode != right_mode {
324+
return Err(ArrowError::InvalidArgumentError(format!(
325+
"Cannot compare UnionArrays with different modes: left={:?}, right={:?}",
326+
left_mode, right_mode
327+
)));
328+
}
329+
330+
let c_opts = child_opts(opts);
331+
332+
let mut field_comparators = HashMap::with_capacity(left_fields.len());
333+
334+
for (type_id, _field) in left_fields.iter() {
335+
let left_child = left.child(type_id);
336+
let right_child = right.child(type_id);
337+
let cmp = make_comparator(left_child.as_ref(), right_child.as_ref(), c_opts)?;
338+
339+
field_comparators.insert(type_id, cmp);
340+
}
341+
342+
let left_type_ids = left.type_ids().clone();
343+
let right_type_ids = right.type_ids().clone();
344+
345+
let left_offsets = left.offsets().cloned();
346+
let right_offsets = right.offsets().cloned();
347+
348+
let f = compare(left, right, opts, move |i, j| {
349+
let left_type_id = left_type_ids[i];
350+
let right_type_id = right_type_ids[j];
351+
352+
// first, compare by type_id
353+
match left_type_id.cmp(&right_type_id) {
354+
Ordering::Equal => {
355+
// second, compare by values
356+
let left_offset = left_offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i);
357+
let right_offset = right_offsets.as_ref().map(|o| o[j] as usize).unwrap_or(j);
358+
359+
let cmp = field_comparators
360+
.get(&left_type_id)
361+
.expect("type id not found in field_comparators");
362+
363+
cmp(left_offset, right_offset)
364+
}
365+
other => other,
366+
}
367+
});
368+
Ok(f)
369+
}
370+
299371
/// Returns a comparison function that compares two values at two different positions
300372
/// between the two arrays.
301373
///
@@ -412,6 +484,7 @@ pub fn make_comparator(
412484
}
413485
},
414486
(Map(_, _), Map(_, _)) => compare_map(left, right, opts),
487+
(Union(_, _), Union(_, _)) => compare_union(left, right, opts),
415488
(lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
416489
true => format!("The data type type {lhs:?} has no natural order"),
417490
false => "Can't compare arrays of different types".to_string(),
@@ -423,8 +496,8 @@ pub fn make_comparator(
423496
mod tests {
424497
use super::*;
425498
use arrow_array::builder::{Int32Builder, ListBuilder, MapBuilder, StringBuilder};
426-
use arrow_buffer::{IntervalDayTime, OffsetBuffer, i256};
427-
use arrow_schema::{DataType, Field, Fields};
499+
use arrow_buffer::{IntervalDayTime, OffsetBuffer, ScalarBuffer, i256};
500+
use arrow_schema::{DataType, Field, Fields, UnionFields};
428501
use half::f16;
429502
use std::sync::Arc;
430503

@@ -1189,4 +1262,243 @@ mod tests {
11891262
}
11901263
}
11911264
}
1265+
1266+
#[test]
fn test_dense_union() {
    // Both unions share one schema: Int32 under type_id 0, Utf8 under type_id 1.
    let union_fields: UnionFields = [
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("B", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();

    // lhs logical values: [1, "b", 2, "a", 3]
    //   type ids: [0, 1, 0, 1, 0]
    //   offsets:  [0, 0, 1, 1, 2] into [1, 2, 3] and ["b", "a"]
    let lhs_type_ids: ScalarBuffer<i8> = [0, 1, 0, 1, 0].into_iter().collect();
    let lhs_offsets: ScalarBuffer<i32> = [0, 0, 1, 1, 2].into_iter().collect();
    let lhs_children = vec![
        Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
        Arc::new(StringArray::from(vec!["b", "a"])),
    ];
    let lhs = UnionArray::try_new(
        union_fields.clone(),
        lhs_type_ids,
        Some(lhs_offsets),
        lhs_children,
    )
    .unwrap();

    // rhs logical values: [2, "a", 1, "c"]
    //   type ids: [0, 1, 0, 1]
    //   offsets:  [0, 0, 1, 1] into [2, 1] and ["a", "c"]
    let rhs_type_ids: ScalarBuffer<i8> = [0, 1, 0, 1].into_iter().collect();
    let rhs_offsets: ScalarBuffer<i32> = [0, 0, 1, 1].into_iter().collect();
    let rhs_children = vec![
        Arc::new(Int32Array::from(vec![2, 1])) as ArrayRef,
        Arc::new(StringArray::from(vec!["a", "c"])),
    ];
    let rhs =
        UnionArray::try_new(union_fields, rhs_type_ids, Some(rhs_offsets), rhs_children).unwrap();

    // Ascending order.
    let ascending = SortOptions {
        descending: false,
        nulls_first: true,
    };
    let cmp = make_comparator(&lhs, &rhs, ascending).unwrap();

    // lhs[0] = (0, 1) vs rhs[0] = (0, 2): same type id, 1 < 2
    assert_eq!(cmp(0, 0), Ordering::Less);

    // lhs[0] = (0, 1) vs rhs[1] = (1, "a"): type id 0 < 1 decides
    assert_eq!(cmp(0, 1), Ordering::Less);

    // lhs[1] = (1, "b") vs rhs[1] = (1, "a"): "b" > "a"
    assert_eq!(cmp(1, 1), Ordering::Greater);

    // lhs[2] = (0, 2) vs rhs[0] = (0, 2): 2 == 2
    assert_eq!(cmp(2, 0), Ordering::Equal);

    // lhs[3] = (1, "a") vs rhs[1] = (1, "a"): "a" == "a"
    assert_eq!(cmp(3, 1), Ordering::Equal);

    // lhs[1] = (1, "b") vs rhs[3] = (1, "c"): "b" < "c"
    assert_eq!(cmp(1, 3), Ordering::Less);

    // Descending order flips every non-equal outcome above.
    let descending = SortOptions {
        descending: true,
        nulls_first: true,
    };
    let cmp_desc = make_comparator(&lhs, &rhs, descending).unwrap();

    // value comparison 1 vs 2, reversed
    assert_eq!(cmp_desc(0, 0), Ordering::Greater);
    // type id comparison 0 vs 1, reversed
    assert_eq!(cmp_desc(0, 1), Ordering::Greater);
    // value comparison "b" vs "a", reversed
    assert_eq!(cmp_desc(1, 1), Ordering::Less);
}
1347+
1348+
#[test]
fn test_sparse_union() {
    // Sparse union over Int32 (type_id 0) and Utf8 (type_id 1) holding
    // [1, "b", 3]. No offsets buffer is supplied, and each child array has
    // the same length as the union itself.
    let union_fields: UnionFields = [
        (0, Arc::new(Field::new("a", DataType::Int32, false))),
        (1, Arc::new(Field::new("b", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();

    let type_ids: ScalarBuffer<i8> = [0, 1, 0].into_iter().collect();
    let children = vec![
        Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as ArrayRef,
        Arc::new(StringArray::from(vec![None, Some("b"), None])),
    ];

    let union = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();

    // Compare the array against itself with default sort options.
    let cmp = make_comparator(&union, &union, SortOptions::default()).unwrap();

    // union[0] = (0, 1) vs union[2] = (0, 3): 1 < 3
    assert_eq!(cmp(0, 2), Ordering::Less);
    // union[0] = (0, 1) vs union[1] = (1, "b"): type id 0 < 1
    assert_eq!(cmp(0, 1), Ordering::Less);
}
1376+
1377+
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_union_out_of_bounds() {
    // Dense union with three slots: two Int32 values and one Utf8 value.
    let union_fields: UnionFields = [
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("B", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();

    let type_ids: ScalarBuffer<i8> = [0, 1, 0].into_iter().collect();
    let offsets: ScalarBuffer<i32> = [0, 0, 1].into_iter().collect();
    let children = vec![
        Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
        Arc::new(StringArray::from(vec!["a"])),
    ];

    let union = UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();

    let cmp = make_comparator(&union, &union, SortOptions::default()).unwrap();

    // Index 3 is one past the last slot, so the comparator must panic.
    cmp(0, 3);
}
1404+
1405+
#[test]
fn test_union_incompatible_fields() {
    // Left union: Int32 under type_id 0, Utf8 under type_id 1.
    let lhs_fields: UnionFields = [
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("B", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();
    let lhs_type_ids: ScalarBuffer<i8> = [0, 1].into_iter().collect();
    let lhs_offsets: ScalarBuffer<i32> = [0, 0].into_iter().collect();
    let lhs_children = vec![
        Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
        Arc::new(StringArray::from(vec!["a", "b"])),
    ];
    let lhs =
        UnionArray::try_new(lhs_fields, lhs_type_ids, Some(lhs_offsets), lhs_children).unwrap();

    // Right union: Int32 under type_id 0 but Float64 under type_id 1, so its
    // fields differ from the left union's.
    let rhs_fields: UnionFields = [
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("C", DataType::Float64, false))),
    ]
    .into_iter()
    .collect();
    let rhs_type_ids: ScalarBuffer<i8> = [0, 1].into_iter().collect();
    let rhs_offsets: ScalarBuffer<i32> = [0, 0].into_iter().collect();
    let rhs_children = vec![
        Arc::new(Int32Array::from(vec![3, 4])) as ArrayRef,
        Arc::new(Float64Array::from(vec![1.0, 2.0])),
    ];
    let rhs =
        UnionArray::try_new(rhs_fields, rhs_type_ids, Some(rhs_offsets), rhs_children).unwrap();

    // Building the comparator must fail with an invalid-argument error that
    // spells out both field lists.
    match make_comparator(&lhs, &rhs, SortOptions::default()) {
        Err(ArrowError::InvalidArgumentError(msg)) => assert_eq!(
            msg.as_str(),
            "Cannot compare UnionArrays with different fields: left=[(0, Field { name: \"A\", data_type: Int32 }), (1, Field { name: \"B\", data_type: Utf8 })], right=[(0, Field { name: \"A\", data_type: Int32 }), (1, Field { name: \"C\", data_type: Float64 })]"
        ),
        _ => panic!("expected error when making comparator of incompatible union arrays"),
    }
}
1458+
1459+
#[test]
fn test_union_incompatible_modes() {
    // One schema shared by both unions: Int32 (type_id 0) and Utf8 (type_id 1).
    let union_fields: UnionFields = [
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("B", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();

    // Dense union: built with an offsets buffer.
    let dense_type_ids: ScalarBuffer<i8> = [0, 1].into_iter().collect();
    let dense_offsets: ScalarBuffer<i32> = [0, 0].into_iter().collect();
    let dense_children = vec![
        Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
        Arc::new(StringArray::from(vec!["a", "b"])),
    ];
    let dense = UnionArray::try_new(
        union_fields.clone(),
        dense_type_ids,
        Some(dense_offsets),
        dense_children,
    )
    .unwrap();

    // Sparse union over the same fields: built without an offsets buffer.
    let sparse_type_ids: ScalarBuffer<i8> = [0, 1].into_iter().collect();
    let sparse_children = vec![
        Arc::new(Int32Array::from(vec![Some(3), None])) as ArrayRef,
        Arc::new(StringArray::from(vec![None, Some("c")])),
    ];
    let sparse = UnionArray::try_new(union_fields, sparse_type_ids, None, sparse_children).unwrap();

    // Identical fields but different modes must be rejected.
    match make_comparator(&dense, &sparse, SortOptions::default()) {
        Err(ArrowError::InvalidArgumentError(msg)) => assert_eq!(
            msg.as_str(),
            "Cannot compare UnionArrays with different modes: left=Dense, right=Sparse"
        ),
        _ => panic!("expected error when making comparator of union arrays with different modes"),
    }
}
11921504
}

0 commit comments

Comments
 (0)