diff --git a/datafusion/physical-plan/src/sorts/cursor.rs b/datafusion/physical-plan/src/sorts/cursor.rs
index 17033e6a3142..b85cefb96652 100644
--- a/datafusion/physical-plan/src/sorts/cursor.rs
+++ b/datafusion/physical-plan/src/sorts/cursor.rs
@@ -288,6 +288,64 @@ impl CursorArray for StringViewArray {
     }
 }
+/// TODO: use the arrow-rs side API once it is released.
+/// Builds a 128-bit composite key for an inline value:
+///
+/// - High 96 bits: the inline data in big-endian byte order (for correct lexicographical sorting).
+/// - Low 32 bits: the length in big-endian byte order, acting as a tiebreaker so shorter strings
+///   (or those with fewer meaningful bytes) always numerically sort before longer ones.
+///
+/// This function extracts the length and the 12-byte inline string data from the raw
+/// little-endian `u128` representation, converts them to big-endian ordering, and packs them
+/// into a single `u128` value suitable for fast, branchless comparisons.
+///
+/// ### Why include length?
+///
+/// A pure 96-bit content comparison can’t distinguish between two values whose inline bytes
+/// compare equal—either because one is a true prefix of the other or because zero-padding
+/// hides extra bytes. By tucking the 32-bit length into the lower bits, a single `u128` compare
+/// handles both content and length in one go.
+///
+/// Example: comparing "bar" (3 bytes) vs "bar\0" (4 bytes)
+///
+/// | String     | Bytes 0–4 (length LE) | Bytes 4–16 (data + padding) |
+/// |------------|-----------------------|-----------------------------|
+/// | `"bar"`    | `03 00 00 00`         | `62 61 72` + 9 × `00`       |
+/// | `"bar\0"`  | `04 00 00 00`         | `62 61 72 00` + 8 × `00`    |
+///
+/// Both inline parts become `62 61 72 00…00`, so they tie on content. The length field
+/// then differentiates:
+///
+/// ```text
+/// key("bar")   = 0x0000000000000000000062617200000003
+/// key("bar\0") = 0x0000000000000000000062617200000004
+/// ⇒ key("bar") < key("bar\0")
+/// ```
+#[inline(always)]
+pub fn inline_key_fast(raw: u128) -> u128 {
+    // Convert the raw u128 (little-endian) into bytes for manipulation
+    let raw_bytes = raw.to_le_bytes();
+
+    // Extract the length (first 4 bytes), convert to big-endian u32, and promote to u128
+    let len_le = &raw_bytes[0..4];
+    let len_be = u32::from_le_bytes(len_le.try_into().unwrap()).to_be() as u128;
+
+    // Extract the inline string bytes (next 12 bytes), place them into the lower 12 bytes of a 16-byte array,
+    // padding the upper 4 bytes with zero to form a little-endian u128 value
+    let mut inline_bytes = [0u8; 16];
+    inline_bytes[0..12].copy_from_slice(&raw_bytes[4..16]);
+
+    // Convert to big-endian to ensure correct lexical ordering
+    let inline_u128 = u128::from_le_bytes(inline_bytes).to_be();
+
+    // Shift right by 32 bits to discard the zero padding (upper 4 bytes),
+    // so that the inline string occupies the high 96 bits
+    let inline_part = inline_u128 >> 32;
+
+    // Combine the inline string part (high 96 bits) and length (low 32 bits) into the final key
+    (inline_part << 32) | len_be
+}
+
 impl CursorValues for StringViewArray {
     fn len(&self) -> usize {
         self.views().len()
     }
@@ -302,7 +360,7 @@ impl CursorValues for StringViewArray {
         let l_view = unsafe { l.views().get_unchecked(l_idx) };
         let r_view = unsafe { r.views().get_unchecked(r_idx) };
         if l.data_buffers().is_empty() && r.data_buffers().is_empty() {
-            return l_view.eq(r_view);
+            return l_view == r_view;
         }
 
         let l_len = *l_view as u32;
@@ -322,12 +380,12 @@ impl CursorValues for StringViewArray {
         let l_view = unsafe { cursor.views().get_unchecked(idx) };
         let r_view = unsafe { cursor.views().get_unchecked(idx - 1) };
         if cursor.data_buffers().is_empty() {
-            return l_view.eq(r_view);
+            return l_view == r_view;
         }
 
         let l_len = *l_view as u32;
-        let r_len = *r_view as u32;
+
         if l_len != r_len {
             return false;
         }
@@ -345,11 +403,7 @@ impl CursorValues for StringViewArray {
         if l.data_buffers().is_empty() && r.data_buffers().is_empty() {
             let l_view = unsafe { l.views().get_unchecked(l_idx) };
             let r_view = unsafe { r.views().get_unchecked(r_idx) };
-            let l_len = *l_view as u32;
-            let r_len = *r_view as u32;
-            let l_data = unsafe { StringViewArray::inline_value(l_view, l_len as usize) };
-            let r_data = unsafe { StringViewArray::inline_value(r_view, r_len as usize) };
-            return l_data.cmp(r_data);
+            return inline_key_fast(*l_view).cmp(&inline_key_fast(*r_view));
         }
 
         unsafe { GenericByteViewArray::compare_unchecked(l, l_idx, r, r_idx) }
@@ -444,11 +498,11 @@ impl CursorValues for ArrayValues {
 
 #[cfg(test)]
 mod tests {
-    use std::sync::Arc;
-
+    use arrow::array::GenericBinaryArray;
     use datafusion_execution::memory_pool::{
         GreedyMemoryPool, MemoryConsumer, MemoryPool,
     };
+    use std::sync::Arc;
 
     use super::*;
@@ -609,4 +663,67 @@ mod tests {
         b.advance();
         assert_eq!(a.cmp(&b), Ordering::Less);
     }
+
+    /// Integration tests for `inline_key_fast` covering:
+    ///
+    /// 1. Monotonic ordering across increasing lengths and lexical variations.
+    /// 2. Cross-check against `GenericBinaryArray` comparison to ensure semantic equivalence.
+    ///
+    /// This also includes a specific test for the “bar” vs. “bar\0” case, demonstrating why
+    /// the length field is required even when all inline bytes fit in 12 bytes.
+    #[test]
+    fn test_inline_key_fast_various_lengths_and_lexical() {
+        /// Helper to create a raw u128 value representing an inline ByteView
+        /// - `length`: number of meaningful bytes (≤ 12)
+        /// - `data`: the actual inline data
+        fn make_raw_inline(length: u32, data: &[u8]) -> u128 {
+            assert!(length as usize <= 12, "Inline length must be ≤ 12");
+            assert!(data.len() == length as usize, "Data must match length");
+
+            let mut raw_bytes = [0u8; 16];
+            raw_bytes[0..4].copy_from_slice(&length.to_le_bytes()); // little-endian length
+            raw_bytes[4..(4 + data.len())].copy_from_slice(data); // inline data
+            u128::from_le_bytes(raw_bytes)
+        }
+
+        // Test inputs: include the specific "bar" vs "bar\0" case, plus length and lexical variations
+        let test_inputs: Vec<&[u8]> = vec![
+            b"a",
+            b"aa",
+            b"aaa",
+            b"aab",
+            b"abcd",
+            b"abcde",
+            b"abcdef",
+            b"abcdefg",
+            b"abcdefgh",
+            b"abcdefghi",
+            b"abcdefghij",
+            b"abcdefghijk",
+            b"abcdefghijkl", // 12 bytes, max inline
+            b"bar",
+            b"bar\0", // special case to test length tiebreaker
+            b"xyy",
+            b"xyz",
+        ];
+
+        // Monotonic key order: content then length, and cross-check against GenericBinaryArray comparison
+        let array: GenericBinaryArray<i32> = GenericBinaryArray::from(
+            test_inputs.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
+        );
+
+        for i in 0..array.len() - 1 {
+            let v1 = array.value(i);
+            let v2 = array.value(i + 1);
+            // Ensure lexical ordering matches
+            assert!(v1 < v2, "Array compare failed: {v1:?} !< {v2:?}");
+            // Ensure fast key compare matches
+            let key1 = inline_key_fast(make_raw_inline(v1.len() as u32, v1));
+            let key2 = inline_key_fast(make_raw_inline(v2.len() as u32, v2));
+            assert!(
+                key1 < key2,
+                "Key compare failed: key({v1:?})=0x{key1:032x} !< key({v2:?})=0x{key2:032x}",
+            );
+        }
+    }
 }
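For readers who want to poke at the key construction outside the patch, the sketch below replays the "bar" vs "bar\0" walkthrough from the doc comment above. It is a standalone illustration under the layout assumptions stated there (bytes 0..4 of an inline view hold the little-endian length, bytes 4..16 hold the zero-padded data); `make_inline_view` and `view_key` are hypothetical helpers written for this sketch that mirror, rather than call, the `inline_key_fast` added in cursor.rs.

```rust
// Standalone sketch, not part of the patch: mirrors the inline_key_fast logic
// so the "bar" vs "bar\0" example can be run on its own.
fn make_inline_view(data: &[u8]) -> u128 {
    assert!(data.len() <= 12, "inline views hold at most 12 bytes");
    let mut raw = [0u8; 16];
    raw[0..4].copy_from_slice(&(data.len() as u32).to_le_bytes()); // length, little-endian
    raw[4..4 + data.len()].copy_from_slice(data); // inline data, rest stays zero
    u128::from_le_bytes(raw)
}

// Same steps as inline_key_fast: inline bytes (big-endian) in the high 96 bits,
// length tiebreaker in the low 32 bits.
fn view_key(raw: u128) -> u128 {
    let raw_bytes = raw.to_le_bytes();
    let len_be = u32::from_le_bytes(raw_bytes[0..4].try_into().unwrap()).to_be() as u128;
    let mut inline_bytes = [0u8; 16];
    inline_bytes[0..12].copy_from_slice(&raw_bytes[4..16]);
    ((u128::from_le_bytes(inline_bytes).to_be() >> 32) << 32) | len_be
}

fn main() {
    let bar = view_key(make_inline_view(b"bar"));
    let bar_nul = view_key(make_inline_view(b"bar\0"));

    // The content parts tie (62 61 72 00 ...), so the length field decides.
    println!("key(\"bar\")   = 0x{bar:032x}");
    println!("key(\"bar\\0\") = 0x{bar_nul:032x}");
    assert!(bar < bar_nul);
}
```

Running it prints two keys that differ only in the low 32 bits, which is exactly the length tiebreaker the doc comment describes.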
diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs
index 0c18a3b6c703..ca2d5f2105f2 100644
--- a/datafusion/physical-plan/src/sorts/merge.rs
+++ b/datafusion/physical-plan/src/sorts/merge.rs
@@ -493,13 +493,12 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
         if self.enable_round_robin_tie_breaker && cmp_node == 1 {
             match (&self.cursors[winner], &self.cursors[challenger]) {
                 (Some(ac), Some(bc)) => {
-                    let ord = ac.cmp(bc);
-                    if ord.is_eq() {
+                    if ac == bc {
                         self.handle_tie(cmp_node, &mut winner, challenger);
                     } else {
                         // Ends of tie breaker
                         self.round_robin_tie_breaker_mode = false;
-                        if ord.is_gt() {
+                        if ac > bc {
                             self.update_winner(cmp_node, &mut winner, challenger);
                         }
                     }
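A note on the merge.rs hunk: dropping the stored `Ordering` in favour of direct `==` / `>` checks is behavior-preserving as long as the cursor type's `PartialEq`/`PartialOrd` agree with its `Ord` implementation, which is what Rust's `Ord` contract requires. The sketch below is a generic, standalone illustration of that equivalence (plain integers, not DataFusion's cursor type); the branch labels simply name which call the merge loop would make.

```rust
// Standalone sketch: branching on `==` / `>` picks the same branch as computing
// an Ordering first, for any type whose comparisons are consistent with Ord.
fn branch_via_ordering<T: Ord>(winner: &T, challenger: &T) -> &'static str {
    let ord = winner.cmp(challenger);
    if ord.is_eq() {
        "handle_tie"
    } else if ord.is_gt() {
        "update_winner"
    } else {
        "keep_winner"
    }
}

fn branch_via_operators<T: Ord>(winner: &T, challenger: &T) -> &'static str {
    if winner == challenger {
        "handle_tie"
    } else if winner > challenger {
        "update_winner"
    } else {
        "keep_winner"
    }
}

fn main() {
    for (w, c) in [(1, 2), (2, 1), (3, 3)] {
        assert_eq!(branch_via_ordering(&w, &c), branch_via_operators(&w, &c));
    }
    println!("both formulations choose the same branch");
}
```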