@@ -21,8 +21,8 @@ use arrow_array::cast::AsArray;
2121use arrow_array:: types:: * ;
2222use arrow_array:: * ;
2323use arrow_buffer:: { ArrowNativeType , NullBuffer } ;
24- use arrow_schema:: { ArrowError , SortOptions } ;
25- use std:: cmp:: Ordering ;
24+ use arrow_schema:: { ArrowError , DataType , SortOptions } ;
25+ use std:: { cmp:: Ordering , collections :: HashMap } ;
2626
2727/// Compare the values at two arbitrary indices in two arrays.
2828pub type DynComparator = Box < dyn Fn ( usize , usize ) -> Ordering + Send + Sync > ;
@@ -296,6 +296,78 @@ fn compare_struct(
296296 Ok ( f)
297297}
298298
299+ fn compare_union (
300+ left : & dyn Array ,
301+ right : & dyn Array ,
302+ opts : SortOptions ,
303+ ) -> Result < DynComparator , ArrowError > {
304+ let left = left. as_union ( ) ;
305+ let right = right. as_union ( ) ;
306+
307+ let ( left_fields, left_mode) = match left. data_type ( ) {
308+ DataType :: Union ( fields, mode) => ( fields, mode) ,
309+ _ => unreachable ! ( ) ,
310+ } ;
311+ let ( right_fields, right_mode) = match right. data_type ( ) {
312+ DataType :: Union ( fields, mode) => ( fields, mode) ,
313+ _ => unreachable ! ( ) ,
314+ } ;
315+
316+ if left_fields != right_fields {
317+ return Err ( ArrowError :: InvalidArgumentError ( format ! (
318+ "Cannot compare UnionArrays with different fields: left={:?}, right={:?}" ,
319+ left_fields, right_fields
320+ ) ) ) ;
321+ }
322+
323+ if left_mode != right_mode {
324+ return Err ( ArrowError :: InvalidArgumentError ( format ! (
325+ "Cannot compare UnionArrays with different modes: left={:?}, right={:?}" ,
326+ left_mode, right_mode
327+ ) ) ) ;
328+ }
329+
330+ let c_opts = child_opts ( opts) ;
331+
332+ let mut field_comparators = HashMap :: with_capacity ( left_fields. len ( ) ) ;
333+
334+ for ( type_id, _field) in left_fields. iter ( ) {
335+ let left_child = left. child ( type_id) ;
336+ let right_child = right. child ( type_id) ;
337+ let cmp = make_comparator ( left_child. as_ref ( ) , right_child. as_ref ( ) , c_opts) ?;
338+
339+ field_comparators. insert ( type_id, cmp) ;
340+ }
341+
342+ let left_type_ids = left. type_ids ( ) . clone ( ) ;
343+ let right_type_ids = right. type_ids ( ) . clone ( ) ;
344+
345+ let left_offsets = left. offsets ( ) . cloned ( ) ;
346+ let right_offsets = right. offsets ( ) . cloned ( ) ;
347+
348+ let f = compare ( left, right, opts, move |i, j| {
349+ let left_type_id = left_type_ids[ i] ;
350+ let right_type_id = right_type_ids[ j] ;
351+
352+ // first, compare by type_id
353+ match left_type_id. cmp ( & right_type_id) {
354+ Ordering :: Equal => {
355+ // second, compare by values
356+ let left_offset = left_offsets. as_ref ( ) . map ( |o| o[ i] as usize ) . unwrap_or ( i) ;
357+ let right_offset = right_offsets. as_ref ( ) . map ( |o| o[ j] as usize ) . unwrap_or ( j) ;
358+
359+ let cmp = field_comparators
360+ . get ( & left_type_id)
361+ . expect ( "type id not found in field_comparators" ) ;
362+
363+ cmp ( left_offset, right_offset)
364+ }
365+ other => other,
366+ }
367+ } ) ;
368+ Ok ( f)
369+ }
370+
299371/// Returns a comparison function that compares two values at two different positions
300372/// between the two arrays.
301373///
@@ -412,6 +484,7 @@ pub fn make_comparator(
412484 }
413485 } ,
414486 ( Map ( _, _) , Map ( _, _) ) => compare_map( left, right, opts) ,
487+ ( Union ( _, _) , Union ( _, _) ) => compare_union( left, right, opts) ,
415488 ( lhs, rhs) => Err ( ArrowError :: InvalidArgumentError ( match lhs == rhs {
416489 true => format!( "The data type type {lhs:?} has no natural order" ) ,
417490 false => "Can't compare arrays of different types" . to_string( ) ,
@@ -423,8 +496,8 @@ pub fn make_comparator(
423496mod tests {
424497 use super :: * ;
425498 use arrow_array:: builder:: { Int32Builder , ListBuilder , MapBuilder , StringBuilder } ;
426- use arrow_buffer:: { IntervalDayTime , OffsetBuffer , i256} ;
427- use arrow_schema:: { DataType , Field , Fields } ;
499+ use arrow_buffer:: { IntervalDayTime , OffsetBuffer , ScalarBuffer , i256} ;
500+ use arrow_schema:: { DataType , Field , Fields , UnionFields } ;
428501 use half:: f16;
429502 use std:: sync:: Arc ;
430503
@@ -1189,4 +1262,243 @@ mod tests {
11891262 }
11901263 }
11911264 }
1265+
/// Dense unions compare by type id first and by the selected child value
/// second, under both ascending and descending sort options.
#[test]
fn test_dense_union() {
    // Shared layout: Int32 lives under type_id 0, Utf8 under type_id 1.
    let fields: UnionFields = [
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("B", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();

    // lhs values:  [1, "b", 2, "a", 3]
    // type ids:    [0, 1, 0, 1, 0]
    // offsets:     [0, 0, 1, 1, 2] into [1, 2, 3] and ["b", "a"]
    let lhs_type_ids: ScalarBuffer<i8> = [0, 1, 0, 1, 0].into_iter().collect();
    let lhs_offsets: ScalarBuffer<i32> = [0, 0, 1, 1, 2].into_iter().collect();
    let lhs_children = vec![
        Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
        Arc::new(StringArray::from(vec!["b", "a"])),
    ];
    let lhs = UnionArray::try_new(
        fields.clone(),
        lhs_type_ids,
        Some(lhs_offsets),
        lhs_children,
    )
    .unwrap();

    // rhs values:  [2, "a", 1, "c"]
    // type ids:    [0, 1, 0, 1]
    // offsets:     [0, 0, 1, 1] into [2, 1] and ["a", "c"]
    let rhs_type_ids: ScalarBuffer<i8> = [0, 1, 0, 1].into_iter().collect();
    let rhs_offsets: ScalarBuffer<i32> = [0, 0, 1, 1].into_iter().collect();
    let rhs_children = vec![
        Arc::new(Int32Array::from(vec![2, 1])) as ArrayRef,
        Arc::new(StringArray::from(vec!["a", "c"])),
    ];
    let rhs = UnionArray::try_new(fields, rhs_type_ids, Some(rhs_offsets), rhs_children).unwrap();

    let ascending = make_comparator(
        &lhs,
        &rhs,
        SortOptions {
            descending: false,
            nulls_first: true,
        },
    )
    .unwrap();

    // (lhs index, rhs index, expected ordering)
    let ascending_cases = [
        (0, 0, Ordering::Less),    // 1 < 2
        (0, 1, Ordering::Less),    // type_id 0 < type_id 1
        (1, 1, Ordering::Greater), // "b" > "a"
        (2, 0, Ordering::Equal),   // 2 == 2
        (3, 1, Ordering::Equal),   // "a" == "a"
        (1, 3, Ordering::Less),    // "b" < "c"
    ];
    for (i, j, expected) in ascending_cases {
        assert_eq!(ascending(i, j), expected, "ascending cmp({i}, {j})");
    }

    let descending = make_comparator(
        &lhs,
        &rhs,
        SortOptions {
            descending: true,
            nulls_first: true,
        },
    )
    .unwrap();

    // Every ascending result flips, including the type id ordering.
    let descending_cases = [
        (0, 0, Ordering::Greater), // 1 vs 2, reversed
        (0, 1, Ordering::Greater), // type_id 0 vs 1, reversed
        (1, 1, Ordering::Less),    // "b" vs "a", reversed
    ];
    for (i, j, expected) in descending_cases {
        assert_eq!(descending(i, j), expected, "descending cmp({i}, {j})");
    }
}
1347+
/// Sparse unions (no offsets buffer) compare by type id, then by the child
/// value found at the same index as the union element.
#[test]
fn test_sparse_union() {
    let fields: UnionFields = [
        (0, Arc::new(Field::new("a", DataType::Int32, false))),
        (1, Arc::new(Field::new("b", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();

    // Logical values: [1, "b", 3]. Each child is as long as the union itself,
    // with a null placeholder wherever the other child is selected.
    let type_ids: ScalarBuffer<i8> = [0, 1, 0].into_iter().collect();
    let children = vec![
        Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as ArrayRef,
        Arc::new(StringArray::from(vec![None, Some("b"), None])),
    ];
    let union = UnionArray::try_new(fields, type_ids, None, children).unwrap();

    let cmp = make_comparator(&union, &union, SortOptions::default()).unwrap();

    let cases = [
        (0, 2, Ordering::Less), // same type_id: 1 < 3
        (0, 1, Ordering::Less), // type_id 0 < type_id 1
    ];
    for (i, j, expected) in cases {
        assert_eq!(cmp(i, j), expected, "cmp({i}, {j})");
    }
}
1376+
/// Calling the comparator with an index past the end of the union panics
/// with an out-of-bounds message.
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_union_out_of_bounds() {
    let fields: UnionFields = [
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("B", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();

    // Dense union with exactly three elements: [1, "a", 2].
    let type_ids: ScalarBuffer<i8> = [0, 1, 0].into_iter().collect();
    let offsets: ScalarBuffer<i32> = [0, 0, 1].into_iter().collect();
    let children = vec![
        Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
        Arc::new(StringArray::from(vec!["a"])),
    ];
    let union = UnionArray::try_new(fields, type_ids, Some(offsets), children).unwrap();

    let cmp = make_comparator(&union, &union, SortOptions::default()).unwrap();

    // Index 3 is one past the last valid element on the right-hand side.
    let _ = cmp(0, 3);
}
1404+
/// Building a comparator for two unions whose fields differ must fail with an
/// `InvalidArgumentError` that reports both field sets.
#[test]
fn test_union_incompatible_fields() {
    // First union: Int32 under type_id 0, Utf8 under type_id 1.
    let union_fields1: UnionFields = [
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("B", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();

    let type_ids1: ScalarBuffer<i8> = [0, 1].into_iter().collect();
    let offsets1: ScalarBuffer<i32> = [0, 0].into_iter().collect();
    let children1 = vec![
        Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
        Arc::new(StringArray::from(vec!["a", "b"])),
    ];
    let array1 = UnionArray::try_new(
        union_fields1.clone(),
        type_ids1,
        Some(offsets1),
        children1,
    )
    .unwrap();

    // Second union: Int32 under type_id 0, Float64 under type_id 1, so the
    // field sets are incompatible.
    let union_fields2: UnionFields = [
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("C", DataType::Float64, false))),
    ]
    .into_iter()
    .collect();

    let type_ids2: ScalarBuffer<i8> = [0, 1].into_iter().collect();
    let offsets2: ScalarBuffer<i32> = [0, 0].into_iter().collect();
    let children2 = vec![
        Arc::new(Int32Array::from(vec![3, 4])) as ArrayRef,
        Arc::new(Float64Array::from(vec![1.0, 2.0])),
    ];
    let array2 = UnionArray::try_new(
        union_fields2.clone(),
        type_ids2,
        Some(offsets2),
        children2,
    )
    .unwrap();

    let opts = SortOptions::default();

    let Err(ArrowError::InvalidArgumentError(out)) = make_comparator(&array1, &array2, opts)
    else {
        panic!("expected error when making comparator of incompatible union arrays");
    };

    // Render the expected text from the field sets themselves rather than a
    // hand-written literal: a literal has to track the exact `Debug` layout of
    // `Field`, and any stray whitespace inside it makes the assertion fail.
    let expected = format!(
        "Cannot compare UnionArrays with different fields: left={union_fields1:?}, right={union_fields2:?}"
    );
    assert_eq!(out, expected);
}
1458+
/// A dense union and a sparse union with identical fields cannot be compared;
/// building the comparator must fail with a message naming both modes.
#[test]
fn test_union_incompatible_modes() {
    let fields: UnionFields = [
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("B", DataType::Utf8, false))),
    ]
    .into_iter()
    .collect();

    // Dense: an offsets buffer is supplied.
    let dense_type_ids: ScalarBuffer<i8> = [0, 1].into_iter().collect();
    let dense_offsets: ScalarBuffer<i32> = [0, 0].into_iter().collect();
    let dense_children = vec![
        Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
        Arc::new(StringArray::from(vec!["a", "b"])),
    ];
    let dense = UnionArray::try_new(
        fields.clone(),
        dense_type_ids,
        Some(dense_offsets),
        dense_children,
    )
    .unwrap();

    // Sparse: same fields, no offsets buffer, children as long as the union.
    let sparse_type_ids: ScalarBuffer<i8> = [0, 1].into_iter().collect();
    let sparse_children = vec![
        Arc::new(Int32Array::from(vec![Some(3), None])) as ArrayRef,
        Arc::new(StringArray::from(vec![None, Some("c")])),
    ];
    let sparse = UnionArray::try_new(fields, sparse_type_ids, None, sparse_children).unwrap();

    match make_comparator(&dense, &sparse, SortOptions::default()) {
        Err(ArrowError::InvalidArgumentError(message)) => assert_eq!(
            message,
            "Cannot compare UnionArrays with different modes: left=Dense, right=Sparse"
        ),
        _ => panic!("expected error when making comparator of union arrays with different modes"),
    }
}
11921504}
0 commit comments