@@ -21,8 +21,8 @@ use arrow_array::cast::AsArray;
2121use arrow_array:: types:: * ;
2222use arrow_array:: * ;
2323use arrow_buffer:: { ArrowNativeType , NullBuffer } ;
24- use arrow_schema:: { ArrowError , SortOptions } ;
25- use std:: cmp:: Ordering ;
24+ use arrow_schema:: { ArrowError , DataType , SortOptions } ;
25+ use std:: { cmp:: Ordering , collections :: HashMap } ;
2626
2727/// Compare the values at two arbitrary indices in two arrays.
2828pub type DynComparator = Box < dyn Fn ( usize , usize ) -> Ordering + Send + Sync > ;
@@ -296,6 +296,70 @@ fn compare_struct(
296296 Ok ( f)
297297}
298298
299+ fn compare_union (
300+ left : & dyn Array ,
301+ right : & dyn Array ,
302+ opts : SortOptions ,
303+ ) -> Result < DynComparator , ArrowError > {
304+ let left = left. as_union ( ) ;
305+ let right = right. as_union ( ) ;
306+
307+ let ( left_fields, left_mode) = match left. data_type ( ) {
308+ DataType :: Union ( fields, mode) => ( fields, mode) ,
309+ _ => unreachable ! ( ) ,
310+ } ;
311+ let ( right_fields, right_mode) = match right. data_type ( ) {
312+ DataType :: Union ( fields, mode) => ( fields, mode) ,
313+ _ => unreachable ! ( ) ,
314+ } ;
315+
316+ if left_fields != right_fields || left_mode != right_mode {
317+ return Err ( ArrowError :: InvalidArgumentError (
318+ "Cannot compare UnionArrays with different fields or modes" . to_string ( ) ,
319+ ) ) ;
320+ }
321+
322+ let c_opts = child_opts ( opts) ;
323+
324+ let mut field_comparators = HashMap :: with_capacity ( left_fields. len ( ) ) ;
325+
326+ for ( type_id, _field) in left_fields. iter ( ) {
327+ let left_child = left. child ( type_id) ;
328+ let right_child = right. child ( type_id) ;
329+ let cmp = make_comparator ( left_child. as_ref ( ) , right_child. as_ref ( ) , c_opts) ?;
330+
331+ field_comparators. insert ( type_id, cmp) ;
332+ }
333+
334+ let left_type_ids = left. type_ids ( ) . clone ( ) ;
335+ let right_type_ids = right. type_ids ( ) . clone ( ) ;
336+
337+ let left_offsets = left. offsets ( ) . cloned ( ) ;
338+ let right_offsets = right. offsets ( ) . cloned ( ) ;
339+
340+ let f = compare ( left, right, opts, move |i, j| {
341+ let left_type_id = left_type_ids[ i] ;
342+ let right_type_id = right_type_ids[ j] ;
343+
344+ // first, compare by type_id
345+ match left_type_id. cmp ( & right_type_id) {
346+ Ordering :: Equal => {
347+ // second, compare by values
348+ let left_offset = left_offsets. as_ref ( ) . map ( |o| o[ i] as usize ) . unwrap_or ( i) ;
349+ let right_offset = right_offsets. as_ref ( ) . map ( |o| o[ j] as usize ) . unwrap_or ( j) ;
350+
351+ let cmp = field_comparators
352+ . get ( & left_type_id)
353+ . expect ( "type id not found in field_comparators" ) ;
354+
355+ cmp ( left_offset, right_offset)
356+ }
357+ other => other,
358+ }
359+ } ) ;
360+ Ok ( f)
361+ }
362+
299363/// Returns a comparison function that compares two values at two different positions
300364/// between the two arrays.
301365///
@@ -412,6 +476,7 @@ pub fn make_comparator(
412476 }
413477 } ,
414478 ( Map ( _, _) , Map ( _, _) ) => compare_map( left, right, opts) ,
479+ ( Union ( _, _) , Union ( _, _) ) => compare_union( left, right, opts) ,
415480 ( lhs, rhs) => Err ( ArrowError :: InvalidArgumentError ( match lhs == rhs {
416481 true => format!( "The data type type {lhs:?} has no natural order" ) ,
417482 false => "Can't compare arrays of different types" . to_string( ) ,
@@ -423,8 +488,8 @@ pub fn make_comparator(
423488mod tests {
424489 use super :: * ;
425490 use arrow_array:: builder:: { Int32Builder , ListBuilder , MapBuilder , StringBuilder } ;
426- use arrow_buffer:: { IntervalDayTime , OffsetBuffer , i256} ;
427- use arrow_schema:: { DataType , Field , Fields } ;
491+ use arrow_buffer:: { IntervalDayTime , OffsetBuffer , ScalarBuffer , i256} ;
492+ use arrow_schema:: { DataType , Field , Fields , UnionFields } ;
428493 use half:: f16;
429494 use std:: sync:: Arc ;
430495
@@ -1189,4 +1254,115 @@ mod tests {
11891254 }
11901255 }
11911256 }
1257+
1258+ #[ test]
1259+ fn test_dense_union ( ) {
1260+ // create a dense union array with Int32 (type_id = 0) and Utf8 (type_id=1)
1261+ // the values are: [1, "b", 2, "a", 3]
1262+ // type_ids are: [0, 1, 0, 1, 0]
1263+ // offsets are: [0, 0, 1, 1, 2] from [1, 2, 3] and ["b", "a"]
1264+ let int_array = Int32Array :: from ( vec ! [ 1 , 2 , 3 ] ) ;
1265+ let str_array = StringArray :: from ( vec ! [ "b" , "a" ] ) ;
1266+
1267+ let type_ids = [ 0 , 1 , 0 , 1 , 0 ] . into_iter ( ) . collect :: < ScalarBuffer < i8 > > ( ) ;
1268+ let offsets = [ 0 , 0 , 1 , 1 , 2 ] . into_iter ( ) . collect :: < ScalarBuffer < i32 > > ( ) ;
1269+
1270+ let union_fields = [
1271+ ( 0 , Arc :: new ( Field :: new ( "A" , DataType :: Int32 , false ) ) ) ,
1272+ ( 1 , Arc :: new ( Field :: new ( "B" , DataType :: Utf8 , false ) ) ) ,
1273+ ]
1274+ . into_iter ( )
1275+ . collect :: < UnionFields > ( ) ;
1276+
1277+ let children = vec ! [ Arc :: new( int_array) as ArrayRef , Arc :: new( str_array) ] ;
1278+
1279+ let array1 =
1280+ UnionArray :: try_new ( union_fields. clone ( ) , type_ids, Some ( offsets) , children) . unwrap ( ) ;
1281+
1282+ // create a second array: [2, "a", 1, "c"]
1283+ // type ids are: [0, 1, 0, 1]
1284+ // offsets are: [0, 0, 1, 1] from [2, 1] and ["a", "c"]
1285+ let int_array2 = Int32Array :: from ( vec ! [ 2 , 1 ] ) ;
1286+ let str_array2 = StringArray :: from ( vec ! [ "a" , "c" ] ) ;
1287+ let type_ids2 = [ 0 , 1 , 0 , 1 ] . into_iter ( ) . collect :: < ScalarBuffer < i8 > > ( ) ;
1288+ let offsets2 = [ 0 , 0 , 1 , 1 ] . into_iter ( ) . collect :: < ScalarBuffer < i32 > > ( ) ;
1289+
1290+ let children2 = vec ! [ Arc :: new( int_array2) as ArrayRef , Arc :: new( str_array2) ] ;
1291+
1292+ let array2 =
1293+ UnionArray :: try_new ( union_fields, type_ids2, Some ( offsets2) , children2) . unwrap ( ) ;
1294+
1295+ let opts = SortOptions {
1296+ descending : false ,
1297+ nulls_first : true ,
1298+ } ;
1299+
1300+ // comparing
1301+ // [1, "b", 2, "a", 3]
1302+ // [2, "a", 1, "c"]
1303+ let cmp = make_comparator ( & array1, & array2, opts) . unwrap ( ) ;
1304+
1305+ // array1[0] = (type_id=0, value=1)
1306+ // array2[0] = (type_id=0, value=2)
1307+ assert_eq ! ( cmp( 0 , 0 ) , Ordering :: Less ) ; // 1 < 2
1308+
1309+ // array1[0] = (type_id=0, value=1)
1310+ // array2[1] = (type_id=1, value="a")
1311+ assert_eq ! ( cmp( 0 , 1 ) , Ordering :: Less ) ; // type_id 0 < 1
1312+
1313+ // array1[1] = (type_id=1, value="b")
1314+ // array2[1] = (type_id=1, value="a")
1315+ assert_eq ! ( cmp( 1 , 1 ) , Ordering :: Greater ) ; // "b" > "a"
1316+
1317+ // array1[2] = (type_id=0, value=2)
1318+ // array2[0] = (type_id=0, value=2)
1319+ assert_eq ! ( cmp( 2 , 0 ) , Ordering :: Equal ) ; // 2 == 2
1320+
1321+ // array1[3] = (type_id=1, value="a")
1322+ // array2[1] = (type_id=1, value="a")
1323+ assert_eq ! ( cmp( 3 , 1 ) , Ordering :: Equal ) ; // "a" == "a"
1324+
1325+ // array1[1] = (type_id=1, value="b")
1326+ // array2[3] = (type_id=1, value="c")
1327+ assert_eq ! ( cmp( 1 , 3 ) , Ordering :: Less ) ; // "b" < "c"
1328+
1329+ let opts_desc = SortOptions {
1330+ descending : true ,
1331+ nulls_first : true ,
1332+ } ;
1333+ let cmp_desc = make_comparator ( & array1, & array2, opts_desc) . unwrap ( ) ;
1334+
1335+ assert_eq ! ( cmp_desc( 0 , 0 ) , Ordering :: Greater ) ; // 1 > 2 (reversed)
1336+ assert_eq ! ( cmp_desc( 0 , 1 ) , Ordering :: Greater ) ; // type_id 0 < 1, reversed to Greater
1337+ assert_eq ! ( cmp_desc( 1 , 1 ) , Ordering :: Less ) ; // "b" < "a" (reversed)
1338+ }
1339+
1340+ #[ test]
1341+ fn test_sparse_union ( ) {
1342+ // create a sparse union array with Int32 (type_id=0) and Utf8 (type_id=1)
1343+ // values: [1, "b", 3]
1344+ // note, in sparse unions, child arrays have the same length as the union
1345+ let int_array = Int32Array :: from ( vec ! [ Some ( 1 ) , None , Some ( 3 ) ] ) ;
1346+ let str_array = StringArray :: from ( vec ! [ None , Some ( "b" ) , None ] ) ;
1347+ let type_ids = [ 0 , 1 , 0 ] . into_iter ( ) . collect :: < ScalarBuffer < i8 > > ( ) ;
1348+
1349+ let union_fields = [
1350+ ( 0 , Arc :: new ( Field :: new ( "a" , DataType :: Int32 , false ) ) ) ,
1351+ ( 1 , Arc :: new ( Field :: new ( "b" , DataType :: Utf8 , false ) ) ) ,
1352+ ]
1353+ . into_iter ( )
1354+ . collect :: < UnionFields > ( ) ;
1355+
1356+ let children = vec ! [ Arc :: new( int_array) as ArrayRef , Arc :: new( str_array) ] ;
1357+
1358+ let array = UnionArray :: try_new ( union_fields, type_ids, None , children) . unwrap ( ) ;
1359+
1360+ let opts = SortOptions :: default ( ) ;
1361+ let cmp = make_comparator ( & array, & array, opts) . unwrap ( ) ;
1362+
1363+ // array[0] = (type_id=0, value=1), array[2] = (type_id=0, value=3)
1364+ assert_eq ! ( cmp( 0 , 2 ) , Ordering :: Less ) ; // 1 < 3
1365+ // array[0] = (type_id=0, value=1), array[1] = (type_id=1, value="b")
1366+ assert_eq ! ( cmp( 0 , 1 ) , Ordering :: Less ) ; // type_id 0 < 1
1367+ }
11921368}
0 commit comments