diff --git a/arrow-select/src/dictionary.rs b/arrow-select/src/dictionary.rs
index c5773b16a486..ff1198cf7098 100644
--- a/arrow-select/src/dictionary.rs
+++ b/arrow-select/src/dictionary.rs
@@ -15,6 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! Dictionary utilities for Arrow arrays
+
+use std::sync::Arc;
+
+use crate::filter::filter;
 use crate::interleave::interleave;
 use ahash::RandomState;
 use arrow_array::builder::BooleanBufferBuilder;
@@ -23,10 +28,69 @@ use arrow_array::types::{
     LargeUtf8Type, Utf8Type,
 };
 use arrow_array::{cast::AsArray, downcast_primitive};
-use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray, PrimitiveArray};
+use arrow_array::{
+    downcast_dictionary_array, AnyDictionaryArray, Array, ArrayRef, ArrowNativeTypeOp,
+    BooleanArray, DictionaryArray, GenericByteArray, PrimitiveArray,
+};
 use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer, ToByteSlice};
 use arrow_schema::{ArrowError, DataType};
 
+/// Garbage collects a [DictionaryArray] by removing unreferenced values.
+///
+/// Returns a new [DictionaryArray] such that there are no values
+/// that are not referenced by at least one key. There may still be duplicate
+/// values.
+///
+/// See also [`garbage_collect_any_dictionary`] if you need to handle multiple dictionary types
+pub fn garbage_collect_dictionary<K: ArrowDictionaryKeyType>(
+    dictionary: &DictionaryArray<K>,
+) -> Result<DictionaryArray<K>, ArrowError> {
+    let keys = dictionary.keys();
+    let values = dictionary.values();
+
+    let mask = dictionary.occupancy();
+
+    // If no work to do, return the original dictionary
+    if mask.count_set_bits() == values.len() {
+        return Ok(dictionary.clone());
+    }
+
+    // Create a mapping from the old keys to the new keys, use a Vec for easy indexing
+    let mut key_remap = vec![K::Native::ZERO; values.len()];
+    for (new_idx, old_idx) in mask.set_indices().enumerate() {
+        key_remap[old_idx] = K::Native::from_usize(new_idx)
+            .expect("new index should fit in K::Native, as old index was in range");
+    }
+
+    // ... and then build the new keys array
+    let new_keys = keys.unary(|key| {
+        key_remap
+            .get(key.as_usize())
+            .copied()
+            // nulls may be present in the keys, and they will have arbitrary value; we don't care
+            // and can safely return zero
+            .unwrap_or(K::Native::ZERO)
+    });
+
+    // Create a new values array by filtering using the mask
+    let values = filter(dictionary.values(), &BooleanArray::new(mask, None))?;
+
+    Ok(DictionaryArray::new(new_keys, values))
+}
+
+/// Equivalent to [`garbage_collect_dictionary`] but without requiring casting to a specific key type.
+pub fn garbage_collect_any_dictionary(
+    dictionary: &dyn AnyDictionaryArray,
+) -> Result<ArrayRef, ArrowError> {
+    // FIXME: this is a workaround for MSRV Rust versions below 1.86 where trait upcasting is not stable.
+    // From 1.86 onward, `&dyn AnyDictionaryArray` can be directly passed to `downcast_dictionary_array!`.
+    let dictionary = &*dictionary.slice(0, dictionary.len());
+    downcast_dictionary_array!(
+        dictionary => garbage_collect_dictionary(dictionary).map(|dict| Arc::new(dict) as ArrayRef),
+        _ => unreachable!("have a dictionary array")
+    )
+}
+
 /// A best effort interner that maintains a fixed number of buckets
 /// and interns keys based on their hash value
 ///
@@ -78,7 +142,7 @@ impl<'a, V> Interner<'a, V> {
     }
 }
 
-pub struct MergedDictionaries<K: ArrowDictionaryKeyType> {
+pub(crate) struct MergedDictionaries<K: ArrowDictionaryKeyType> {
     /// Provides `key_mappings[`array_idx`][`old_key`] -> new_key`
     pub key_mappings: Vec<Vec<K::Native>>,
     /// The new values
@@ -110,7 +174,7 @@ type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
 /// some return over the naive approach used by MutableArrayData
 ///
 /// `len` is the total length of the merged output
-pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
+pub(crate) fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
     dictionaries: &[&DictionaryArray<K>],
     len: usize,
 ) -> bool {
@@ -153,7 +217,7 @@ pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
 /// This method is meant to be very fast and the output dictionary values
 /// may not be unique, unlike `GenericByteDictionaryBuilder` which is slower
 /// but produces unique values
-pub fn merge_dictionary_values<K: ArrowDictionaryKeyType>(
+pub(crate) fn merge_dictionary_values<K: ArrowDictionaryKeyType>(
     dictionaries: &[&DictionaryArray<K>],
     masks: Option<&[BooleanBuffer]>,
 ) -> Result<MergedDictionaries<K>, ArrowError> {
@@ -298,13 +362,88 @@ fn masked_bytes<'a, T: ByteArrayType>(
 
 #[cfg(test)]
 mod tests {
-    use crate::dictionary::merge_dictionary_values;
+    use super::*;
+
     use arrow_array::cast::as_string_array;
     use arrow_array::types::Int32Type;
-    use arrow_array::{DictionaryArray, Int32Array, StringArray};
+    use arrow_array::types::Int8Type;
+    use arrow_array::{DictionaryArray, Int32Array, Int8Array, StringArray};
     use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer};
     use std::sync::Arc;
 
+    #[test]
+    fn test_garbage_collect_i32_dictionary() {
+        let values = StringArray::from_iter_values(["a", "b", "c", "d"]);
+        let keys = Int32Array::from_iter_values([0, 1, 1, 3, 0, 0, 1]);
+        let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));
+
+        // Only "a", "b", "d" are referenced, "c" is not
+        let gc = garbage_collect_dictionary(&dict).unwrap();
+
+        let expected_values = StringArray::from_iter_values(["a", "b", "d"]);
+        let expected_keys = Int32Array::from_iter_values([0, 1, 1, 2, 0, 0, 1]);
+        let expected = DictionaryArray::<Int32Type>::new(expected_keys, Arc::new(expected_values));
+
+        assert_eq!(gc, expected);
+    }
+
+    #[test]
+    fn test_garbage_collect_any_dictionary() {
+        let values = StringArray::from_iter_values(["a", "b", "c", "d"]);
+        let keys = Int32Array::from_iter_values([0, 1, 1, 3, 0, 0, 1]);
+        let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));
+
+        let gc = garbage_collect_any_dictionary(&dict).unwrap();
+
+        let expected_values = StringArray::from_iter_values(["a", "b", "d"]);
+        let expected_keys = Int32Array::from_iter_values([0, 1, 1, 2, 0, 0, 1]);
+        let expected = DictionaryArray::<Int32Type>::new(expected_keys, Arc::new(expected_values));
+
+        assert_eq!(gc.as_ref(), &expected);
+    }
+
+    #[test]
+    fn test_garbage_collect_with_nulls() {
+        let values = StringArray::from_iter_values(["a", "b", "c"]);
+        let keys = Int8Array::from(vec![Some(2), None, Some(0)]);
+        let dict = DictionaryArray::<Int8Type>::new(keys, Arc::new(values));
+
+        let gc = garbage_collect_dictionary(&dict).unwrap();
+
+        let expected_values = StringArray::from_iter_values(["a", "c"]);
+        let expected_keys = Int8Array::from(vec![Some(1), None, Some(0)]);
+        let expected = DictionaryArray::<Int8Type>::new(expected_keys, Arc::new(expected_values));
+
+        assert_eq!(gc, expected);
+    }
+
+    #[test]
+    fn test_garbage_collect_empty_dictionary() {
+        let values = StringArray::from_iter_values::<&str, _>([]);
+        let keys = Int32Array::from_iter_values([]);
+        let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));
+
+        let gc = garbage_collect_dictionary(&dict).unwrap();
+
+        assert_eq!(gc, dict);
+    }
+
+    #[test]
+    fn test_garbage_collect_dictionary_all_unreferenced() {
+        let values = StringArray::from_iter_values(["a", "b", "c"]);
+        let keys = Int32Array::from(vec![None, None, None]);
+        let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));
+
+        let gc = garbage_collect_dictionary(&dict).unwrap();
+
+        // All keys are null, so dictionary values can be empty
+        let expected_values = StringArray::from_iter_values::<&str, _>([]);
+        let expected_keys = Int32Array::from(vec![None, None, None]);
+        let expected = DictionaryArray::<Int32Type>::new(expected_keys, Arc::new(expected_values));
+
+        assert_eq!(gc, expected);
+    }
+
     #[test]
     fn test_merge_strings() {
         let a = DictionaryArray::<Int32Type>::from_iter(["a", "b", "a", "b", "d", "c", "e"]);
diff --git a/arrow-select/src/lib.rs b/arrow-select/src/lib.rs
index a2ddff351c9a..f755a05e3da1 100644
--- a/arrow-select/src/lib.rs
+++ b/arrow-select/src/lib.rs
@@ -26,7 +26,7 @@
 
 pub mod coalesce;
 pub mod concat;
-mod dictionary;
+pub mod dictionary;
 pub mod filter;
 pub mod interleave;
 pub mod nullif;