1515// specific language governing permissions and limitations
1616// under the License.
1717
18+ //! Dictionary utilities for Arrow arrays
19+
20+ use std:: sync:: Arc ;
21+
22+ use crate :: filter:: filter;
1823use crate :: interleave:: interleave;
1924use ahash:: RandomState ;
2025use arrow_array:: builder:: BooleanBufferBuilder ;
@@ -23,10 +28,69 @@ use arrow_array::types::{
2328 LargeUtf8Type , Utf8Type ,
2429} ;
2530use arrow_array:: { cast:: AsArray , downcast_primitive} ;
26- use arrow_array:: { Array , ArrayRef , DictionaryArray , GenericByteArray , PrimitiveArray } ;
31+ use arrow_array:: {
32+ downcast_dictionary_array, AnyDictionaryArray , Array , ArrayRef , ArrowNativeTypeOp ,
33+ BooleanArray , DictionaryArray , GenericByteArray , PrimitiveArray ,
34+ } ;
2735use arrow_buffer:: { ArrowNativeType , BooleanBuffer , ScalarBuffer , ToByteSlice } ;
2836use arrow_schema:: { ArrowError , DataType } ;
2937
38+ /// Garbage collects a [DictionaryArray] by removing unreferenced values.
39+ ///
40+ /// Returns a new [DictionaryArray] such that there are no values
41+ /// that are not referenced by at least one key. There may still be duplicate
42+ /// values.
43+ ///
44+ /// See also [`garbage_collect_any_dictionary`] if you need to handle multiple dictionary types
45+ pub fn garbage_collect_dictionary < K : ArrowDictionaryKeyType > (
46+ dictionary : & DictionaryArray < K > ,
47+ ) -> Result < DictionaryArray < K > , ArrowError > {
48+ let keys = dictionary. keys ( ) ;
49+ let values = dictionary. values ( ) ;
50+
51+ let mask = dictionary. occupancy ( ) ;
52+
53+ // If no work to do, return the original dictionary
54+ if mask. count_set_bits ( ) == values. len ( ) {
55+ return Ok ( dictionary. clone ( ) ) ;
56+ }
57+
58+ // Create a mapping from the old keys to the new keys, use a Vec for easy indexing
59+ let mut key_remap = vec ! [ K :: Native :: ZERO ; values. len( ) ] ;
60+ for ( new_idx, old_idx) in mask. set_indices ( ) . enumerate ( ) {
61+ key_remap[ old_idx] = K :: Native :: from_usize ( new_idx)
62+ . expect ( "new index should fit in K::Native, as old index was in range" ) ;
63+ }
64+
65+ // ... and then build the new keys array
66+ let new_keys = keys. unary ( |key| {
67+ key_remap
68+ . get ( key. as_usize ( ) )
69+ . copied ( )
70+ // nulls may be present in the keys, and they will have arbitrary value; we don't care
71+ // and can safely return zero
72+ . unwrap_or ( K :: Native :: ZERO )
73+ } ) ;
74+
75+ // Create a new values array by filtering using the mask
76+ let values = filter ( dictionary. values ( ) , & BooleanArray :: new ( mask, None ) ) ?;
77+
78+ Ok ( DictionaryArray :: new ( new_keys, values) )
79+ }
80+
81+ /// Equivalent to [`garbage_collect_dictionary`] but without requiring casting to a specific key type.
82+ pub fn garbage_collect_any_dictionary (
83+ dictionary : & dyn AnyDictionaryArray ,
84+ ) -> Result < ArrayRef , ArrowError > {
85+ // FIXME: this is a workaround for MSRV Rust versions below 1.86 where trait upcasting is not stable.
86+ // From 1.86 onward, `&dyn AnyDictionaryArray` can be directly passed to `downcast_dictionary_array!`.
87+ let dictionary = & * dictionary. slice ( 0 , dictionary. len ( ) ) ;
88+ downcast_dictionary_array ! (
89+ dictionary => garbage_collect_dictionary( dictionary) . map( |dict| Arc :: new( dict) as ArrayRef ) ,
90+ _ => unreachable!( "have a dictionary array" )
91+ )
92+ }
93+
3094/// A best effort interner that maintains a fixed number of buckets
3195/// and interns keys based on their hash value
3296///
@@ -78,7 +142,7 @@ impl<'a, V> Interner<'a, V> {
78142 }
79143}
80144
81- pub struct MergedDictionaries < K : ArrowDictionaryKeyType > {
145+ pub ( crate ) struct MergedDictionaries < K : ArrowDictionaryKeyType > {
82146 /// Provides `key_mappings[`array_idx`][`old_key`] -> new_key`
83147 pub key_mappings : Vec < Vec < K :: Native > > ,
84148 /// The new values
@@ -110,7 +174,7 @@ type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
110174/// some return over the naive approach used by MutableArrayData
111175///
112176/// `len` is the total length of the merged output
113- pub fn should_merge_dictionary_values < K : ArrowDictionaryKeyType > (
177+ pub ( crate ) fn should_merge_dictionary_values < K : ArrowDictionaryKeyType > (
114178 dictionaries : & [ & DictionaryArray < K > ] ,
115179 len : usize ,
116180) -> bool {
@@ -153,7 +217,7 @@ pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
153217/// This method is meant to be very fast and the output dictionary values
154218/// may not be unique, unlike `GenericByteDictionaryBuilder` which is slower
155219/// but produces unique values
156- pub fn merge_dictionary_values < K : ArrowDictionaryKeyType > (
220+ pub ( crate ) fn merge_dictionary_values < K : ArrowDictionaryKeyType > (
157221 dictionaries : & [ & DictionaryArray < K > ] ,
158222 masks : Option < & [ BooleanBuffer ] > ,
159223) -> Result < MergedDictionaries < K > , ArrowError > {
@@ -298,13 +362,88 @@ fn masked_bytes<'a, T: ByteArrayType>(
298362
299363#[ cfg( test) ]
300364mod tests {
301- use crate :: dictionary:: merge_dictionary_values;
365+ use super :: * ;
366+
302367 use arrow_array:: cast:: as_string_array;
303368 use arrow_array:: types:: Int32Type ;
304- use arrow_array:: { DictionaryArray , Int32Array , StringArray } ;
369+ use arrow_array:: types:: Int8Type ;
370+ use arrow_array:: { DictionaryArray , Int32Array , Int8Array , StringArray } ;
305371 use arrow_buffer:: { BooleanBuffer , Buffer , NullBuffer , OffsetBuffer } ;
306372 use std:: sync:: Arc ;
307373
/// An unreferenced value is dropped and keys that pointed past it are shifted down.
#[test]
fn test_garbage_collect_i32_dictionary() {
    // Only "a", "b", "d" are referenced, "c" is not
    let input = DictionaryArray::<Int32Type>::new(
        Int32Array::from_iter_values([0, 1, 1, 3, 0, 0, 1]),
        Arc::new(StringArray::from_iter_values(["a", "b", "c", "d"])),
    );

    let collected = garbage_collect_dictionary(&input).unwrap();

    // Key 3 ("d") becomes key 2 once "c" is removed
    let want = DictionaryArray::<Int32Type>::new(
        Int32Array::from_iter_values([0, 1, 1, 2, 0, 0, 1]),
        Arc::new(StringArray::from_iter_values(["a", "b", "d"])),
    );
    assert_eq!(collected, want);
}
389+
390+ #[ test]
391+ fn test_garbage_collect_any_dictionary ( ) {
392+ let values = StringArray :: from_iter_values ( [ "a" , "b" , "c" , "d" ] ) ;
393+ let keys = Int32Array :: from_iter_values ( [ 0 , 1 , 1 , 3 , 0 , 0 , 1 ] ) ;
394+ let dict = DictionaryArray :: < Int32Type > :: new ( keys, Arc :: new ( values) ) ;
395+
396+ let gc = garbage_collect_any_dictionary ( & dict) . unwrap ( ) ;
397+
398+ let expected_values = StringArray :: from_iter_values ( [ "a" , "b" , "d" ] ) ;
399+ let expected_keys = Int32Array :: from_iter_values ( [ 0 , 1 , 1 , 2 , 0 , 0 , 1 ] ) ;
400+ let expected = DictionaryArray :: < Int32Type > :: new ( expected_keys, Arc :: new ( expected_values) ) ;
401+
402+ assert_eq ! ( gc. as_ref( ) , & expected) ;
403+ }
404+
405+ #[ test]
406+ fn test_garbage_collect_with_nulls ( ) {
407+ let values = StringArray :: from_iter_values ( [ "a" , "b" , "c" ] ) ;
408+ let keys = Int8Array :: from ( vec ! [ Some ( 2 ) , None , Some ( 0 ) ] ) ;
409+ let dict = DictionaryArray :: < Int8Type > :: new ( keys, Arc :: new ( values) ) ;
410+
411+ let gc = garbage_collect_dictionary ( & dict) . unwrap ( ) ;
412+
413+ let expected_values = StringArray :: from_iter_values ( [ "a" , "c" ] ) ;
414+ let expected_keys = Int8Array :: from ( vec ! [ Some ( 1 ) , None , Some ( 0 ) ] ) ;
415+ let expected = DictionaryArray :: < Int8Type > :: new ( expected_keys, Arc :: new ( expected_values) ) ;
416+
417+ assert_eq ! ( gc, expected) ;
418+ }
419+
420+ #[ test]
421+ fn test_garbage_collect_empty_dictionary ( ) {
422+ let values = StringArray :: from_iter_values :: < & str , _ > ( [ ] ) ;
423+ let keys = Int32Array :: from_iter_values ( [ ] ) ;
424+ let dict = DictionaryArray :: < Int32Type > :: new ( keys, Arc :: new ( values) ) ;
425+
426+ let gc = garbage_collect_dictionary ( & dict) . unwrap ( ) ;
427+
428+ assert_eq ! ( gc, dict) ;
429+ }
430+
431+ #[ test]
432+ fn test_garbage_collect_dictionary_all_unreferenced ( ) {
433+ let values = StringArray :: from_iter_values ( [ "a" , "b" , "c" ] ) ;
434+ let keys = Int32Array :: from ( vec ! [ None , None , None ] ) ;
435+ let dict = DictionaryArray :: < Int32Type > :: new ( keys, Arc :: new ( values) ) ;
436+
437+ let gc = garbage_collect_dictionary ( & dict) . unwrap ( ) ;
438+
439+ // All keys are null, so dictionary values can be empty
440+ let expected_values = StringArray :: from_iter_values :: < & str , _ > ( [ ] ) ;
441+ let expected_keys = Int32Array :: from ( vec ! [ None , None , None ] ) ;
442+ let expected = DictionaryArray :: < Int32Type > :: new ( expected_keys, Arc :: new ( expected_values) ) ;
443+
444+ assert_eq ! ( gc, expected) ;
445+ }
446+
308447 #[ test]
309448 fn test_merge_strings ( ) {
310449 let a = DictionaryArray :: < Int32Type > :: from_iter ( [ "a" , "b" , "a" , "b" , "d" , "c" , "e" ] ) ;
0 commit comments