Skip to content

Commit ce8e67c

Browse files
authored
Fix subtraction underflow when sorting string arrays with many nulls (#285)
1 parent 8226219 commit ce8e67c

File tree

1 file changed

+274
-11
lines changed

1 file changed

+274
-11
lines changed

arrow/src/compute/kernels/sort.rs

Lines changed: 274 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -410,24 +410,27 @@ fn sort_boolean(
410410
len = limit.min(len);
411411
}
412412
if !descending {
413-
sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1));
413+
sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
414+
cmp(a.1, b.1)
415+
});
414416
} else {
415-
sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse());
417+
sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
418+
cmp(a.1, b.1).reverse()
419+
});
416420
// reverse to keep a stable ordering
417421
nulls.reverse();
418422
}
419423

420424
// collect results directly into a buffer instead of a vec to avoid another aligned allocation
421-
let mut result = MutableBuffer::new(values.len() * std::mem::size_of::<u32>());
425+
let result_capacity = len * std::mem::size_of::<u32>();
426+
let mut result = MutableBuffer::new(result_capacity);
422427
// sets len to capacity so we can access the whole buffer as a typed slice
423-
result.resize(values.len() * std::mem::size_of::<u32>(), 0);
428+
result.resize(result_capacity, 0);
424429
let result_slice: &mut [u32] = result.typed_data_mut();
425430

426-
debug_assert_eq!(result_slice.len(), nulls_len + valids_len);
427-
428431
if options.nulls_first {
429432
let size = nulls_len.min(len);
430-
result_slice[0..nulls_len.min(len)].copy_from_slice(&nulls);
433+
result_slice[0..size].copy_from_slice(&nulls[0..size]);
431434
if nulls_len < len {
432435
insert_valid_values(result_slice, nulls_len, &valids[0..len - size]);
433436
}
@@ -626,9 +629,13 @@ where
626629
len = limit.min(len);
627630
}
628631
if !descending {
629-
sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1));
632+
sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
633+
cmp(a.1, b.1)
634+
});
630635
} else {
631-
sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse());
636+
sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
637+
cmp(a.1, b.1).reverse()
638+
});
632639
// reverse to keep a stable ordering
633640
nulls.reverse();
634641
}
@@ -689,11 +696,11 @@ where
689696
len = limit.min(len);
690697
}
691698
if !descending {
692-
sort_by(&mut valids, len - nulls_len, |a, b| {
699+
sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
693700
cmp_array(a.1.as_ref(), b.1.as_ref())
694701
});
695702
} else {
696-
sort_by(&mut valids, len - nulls_len, |a, b| {
703+
sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
697704
cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()
698705
});
699706
// reverse to keep a stable ordering
@@ -1285,6 +1292,48 @@ mod tests {
12851292
None,
12861293
vec![5, 0, 2, 1, 4, 3],
12871294
);
1295+
1296+
// valid values less than limit with extra nulls
1297+
test_sort_to_indices_primitive_arrays::<Float64Type>(
1298+
vec![Some(2.0), None, None, Some(1.0)],
1299+
Some(SortOptions {
1300+
descending: false,
1301+
nulls_first: false,
1302+
}),
1303+
Some(3),
1304+
vec![3, 0, 1],
1305+
);
1306+
1307+
test_sort_to_indices_primitive_arrays::<Float64Type>(
1308+
vec![Some(2.0), None, None, Some(1.0)],
1309+
Some(SortOptions {
1310+
descending: false,
1311+
nulls_first: true,
1312+
}),
1313+
Some(3),
1314+
vec![1, 2, 3],
1315+
);
1316+
1317+
// more nulls than limit
1318+
test_sort_to_indices_primitive_arrays::<Float64Type>(
1319+
vec![Some(1.0), None, None, None],
1320+
Some(SortOptions {
1321+
descending: false,
1322+
nulls_first: true,
1323+
}),
1324+
Some(2),
1325+
vec![1, 2],
1326+
);
1327+
1328+
test_sort_to_indices_primitive_arrays::<Float64Type>(
1329+
vec![Some(1.0), None, None, None],
1330+
Some(SortOptions {
1331+
descending: false,
1332+
nulls_first: false,
1333+
}),
1334+
Some(2),
1335+
vec![0, 1],
1336+
);
12881337
}
12891338

12901339
#[test]
@@ -1329,6 +1378,48 @@ mod tests {
13291378
Some(3),
13301379
vec![5, 0, 2],
13311380
);
1381+
1382+
// valid values less than limit with extra nulls
1383+
test_sort_to_indices_boolean_arrays(
1384+
vec![Some(true), None, None, Some(false)],
1385+
Some(SortOptions {
1386+
descending: false,
1387+
nulls_first: false,
1388+
}),
1389+
Some(3),
1390+
vec![3, 0, 1],
1391+
);
1392+
1393+
test_sort_to_indices_boolean_arrays(
1394+
vec![Some(true), None, None, Some(false)],
1395+
Some(SortOptions {
1396+
descending: false,
1397+
nulls_first: true,
1398+
}),
1399+
Some(3),
1400+
vec![1, 2, 3],
1401+
);
1402+
1403+
// more nulls than limit
1404+
test_sort_to_indices_boolean_arrays(
1405+
vec![Some(true), None, None, None],
1406+
Some(SortOptions {
1407+
descending: false,
1408+
nulls_first: true,
1409+
}),
1410+
Some(2),
1411+
vec![1, 2],
1412+
);
1413+
1414+
test_sort_to_indices_boolean_arrays(
1415+
vec![Some(true), None, None, None],
1416+
Some(SortOptions {
1417+
descending: false,
1418+
nulls_first: false,
1419+
}),
1420+
Some(2),
1421+
vec![0, 1],
1422+
);
13321423
}
13331424

13341425
#[test]
@@ -1686,6 +1777,48 @@ mod tests {
16861777
Some(3),
16871778
vec![3, 0, 2],
16881779
);
1780+
1781+
// valid values less than limit with extra nulls
1782+
test_sort_to_indices_string_arrays(
1783+
vec![Some("def"), None, None, Some("abc")],
1784+
Some(SortOptions {
1785+
descending: false,
1786+
nulls_first: false,
1787+
}),
1788+
Some(3),
1789+
vec![3, 0, 1],
1790+
);
1791+
1792+
test_sort_to_indices_string_arrays(
1793+
vec![Some("def"), None, None, Some("abc")],
1794+
Some(SortOptions {
1795+
descending: false,
1796+
nulls_first: true,
1797+
}),
1798+
Some(3),
1799+
vec![1, 2, 3],
1800+
);
1801+
1802+
// more nulls than limit
1803+
test_sort_to_indices_string_arrays(
1804+
vec![Some("def"), None, None, None],
1805+
Some(SortOptions {
1806+
descending: false,
1807+
nulls_first: true,
1808+
}),
1809+
Some(2),
1810+
vec![1, 2],
1811+
);
1812+
1813+
test_sort_to_indices_string_arrays(
1814+
vec![Some("def"), None, None, None],
1815+
Some(SortOptions {
1816+
descending: false,
1817+
nulls_first: false,
1818+
}),
1819+
Some(2),
1820+
vec![0, 1],
1821+
);
16891822
}
16901823

16911824
#[test]
@@ -1799,6 +1932,48 @@ mod tests {
17991932
Some(3),
18001933
vec![None, None, Some("sad")],
18011934
);
1935+
1936+
// valid values less than limit with extra nulls
1937+
test_sort_string_arrays(
1938+
vec![Some("def"), None, None, Some("abc")],
1939+
Some(SortOptions {
1940+
descending: false,
1941+
nulls_first: false,
1942+
}),
1943+
Some(3),
1944+
vec![Some("abc"), Some("def"), None],
1945+
);
1946+
1947+
test_sort_string_arrays(
1948+
vec![Some("def"), None, None, Some("abc")],
1949+
Some(SortOptions {
1950+
descending: false,
1951+
nulls_first: true,
1952+
}),
1953+
Some(3),
1954+
vec![None, None, Some("abc")],
1955+
);
1956+
1957+
// more nulls than limit
1958+
test_sort_string_arrays(
1959+
vec![Some("def"), None, None, None],
1960+
Some(SortOptions {
1961+
descending: false,
1962+
nulls_first: true,
1963+
}),
1964+
Some(2),
1965+
vec![None, None],
1966+
);
1967+
1968+
test_sort_string_arrays(
1969+
vec![Some("def"), None, None, None],
1970+
Some(SortOptions {
1971+
descending: false,
1972+
nulls_first: false,
1973+
}),
1974+
Some(2),
1975+
vec![Some("def"), None],
1976+
);
18021977
}
18031978

18041979
#[test]
@@ -1912,6 +2087,48 @@ mod tests {
19122087
Some(3),
19132088
vec![None, None, Some("sad")],
19142089
);
2090+
2091+
// valid values less than limit with extra nulls
2092+
test_sort_string_dict_arrays::<Int16Type>(
2093+
vec![Some("def"), None, None, Some("abc")],
2094+
Some(SortOptions {
2095+
descending: false,
2096+
nulls_first: false,
2097+
}),
2098+
Some(3),
2099+
vec![Some("abc"), Some("def"), None],
2100+
);
2101+
2102+
test_sort_string_dict_arrays::<Int16Type>(
2103+
vec![Some("def"), None, None, Some("abc")],
2104+
Some(SortOptions {
2105+
descending: false,
2106+
nulls_first: true,
2107+
}),
2108+
Some(3),
2109+
vec![None, None, Some("abc")],
2110+
);
2111+
2112+
// more nulls than limit
2113+
test_sort_string_dict_arrays::<Int16Type>(
2114+
vec![Some("def"), None, None, None],
2115+
Some(SortOptions {
2116+
descending: false,
2117+
nulls_first: true,
2118+
}),
2119+
Some(2),
2120+
vec![None, None],
2121+
);
2122+
2123+
test_sort_string_dict_arrays::<Int16Type>(
2124+
vec![Some("def"), None, None, None],
2125+
Some(SortOptions {
2126+
descending: false,
2127+
nulls_first: false,
2128+
}),
2129+
Some(2),
2130+
vec![Some("def"), None],
2131+
);
19152132
}
19162133

19172134
#[test]
@@ -1999,6 +2216,52 @@ mod tests {
19992216
vec![Some(vec![Some(1), Some(0)]), Some(vec![Some(1), Some(1)])],
20002217
None,
20012218
);
2219+
2220+
// valid values less than limit with extra nulls
2221+
test_sort_list_arrays::<Int32Type>(
2222+
vec![Some(vec![Some(1)]), None, None, Some(vec![Some(2)])],
2223+
Some(SortOptions {
2224+
descending: false,
2225+
nulls_first: false,
2226+
}),
2227+
Some(3),
2228+
vec![Some(vec![Some(1)]), Some(vec![Some(2)]), None],
2229+
None,
2230+
);
2231+
2232+
test_sort_list_arrays::<Int32Type>(
2233+
vec![Some(vec![Some(1)]), None, None, Some(vec![Some(2)])],
2234+
Some(SortOptions {
2235+
descending: false,
2236+
nulls_first: true,
2237+
}),
2238+
Some(3),
2239+
vec![None, None, Some(vec![Some(2)])],
2240+
None,
2241+
);
2242+
2243+
// more nulls than limit
2244+
test_sort_list_arrays::<Int32Type>(
2245+
vec![Some(vec![Some(1)]), None, None, None],
2246+
Some(SortOptions {
2247+
descending: false,
2248+
nulls_first: true,
2249+
}),
2250+
Some(2),
2251+
vec![None, None],
2252+
None,
2253+
);
2254+
2255+
test_sort_list_arrays::<Int32Type>(
2256+
vec![Some(vec![Some(1)]), None, None, None],
2257+
Some(SortOptions {
2258+
descending: false,
2259+
nulls_first: false,
2260+
}),
2261+
Some(2),
2262+
vec![Some(vec![Some(1)]), None],
2263+
None,
2264+
);
20022265
}
20032266

20042267
#[test]

0 commit comments

Comments
 (0)