@@ -20,13 +20,14 @@ use ahash::RandomState;
2020use arrow:: compute:: cast;
2121use arrow:: record_batch:: RecordBatch ;
2222use arrow:: row:: { RowConverter , Rows , SortField } ;
23- use arrow_array:: { Array , ArrayRef } ;
23+ use arrow_array:: { Array , ArrayRef , ListArray , StructArray } ;
2424use arrow_schema:: { DataType , SchemaRef } ;
2525use datafusion_common:: hash_utils:: create_hashes;
2626use datafusion_common:: { DataFusionError , Result } ;
2727use datafusion_execution:: memory_pool:: proxy:: { RawTableAllocExt , VecAllocExt } ;
2828use datafusion_expr:: EmitTo ;
2929use hashbrown:: raw:: RawTable ;
30+ use std:: sync:: Arc ;
3031
3132/// A [`GroupValues`] making use of [`Rows`]
3233pub struct GroupValuesRows {
@@ -230,6 +231,11 @@ impl GroupValues for GroupValuesRows {
230231 }
231232 * array = cast ( array. as_ref ( ) , expected) ?;
232233 }
234+
235+ if expected. is_nested ( ) && needs_nested_dictionary_encoding ( expected, array) ?
236+ {
237+ * array = dictionary_encode_nested ( array. clone ( ) , expected) ?;
238+ }
233239 }
234240
235241 self . group_values = Some ( group_values) ;
@@ -249,3 +255,93 @@ impl GroupValues for GroupValuesRows {
249255 self . hashes_buffer . shrink_to ( count) ;
250256 }
251257}
258+
259+ fn needs_nested_dictionary_encoding (
260+ expected : & DataType ,
261+ actual : & ArrayRef ,
262+ ) -> Result < bool > {
263+ match ( expected, actual. data_type ( ) ) {
264+ (
265+ & DataType :: Struct ( ref expected_fields) ,
266+ & DataType :: Struct ( ref actual_fields) ,
267+ ) => {
268+ if expected_fields. len ( ) != actual_fields. len ( ) {
269+ return Err ( DataFusionError :: Internal ( format ! (
270+ "Converted group rows expected struct of {} fields got {}" ,
271+ expected_fields. len( ) ,
272+ actual_fields. len( ) ,
273+ ) ) ) ;
274+ }
275+
276+ let actual_struct = actual. as_any ( ) . downcast_ref :: < StructArray > ( ) . unwrap ( ) ;
277+ Ok ( expected_fields
278+ . iter ( )
279+ . zip ( actual_struct. columns ( ) . iter ( ) )
280+ . map ( |( expected_field, actual_column) | {
281+ // Propagate the result of needs_nested_dictionary_encoding
282+ needs_nested_dictionary_encoding (
283+ expected_field. data_type ( ) ,
284+ actual_column,
285+ )
286+ } )
287+ . try_fold ( false , |acc, needs_nested| {
288+ Ok :: < bool , DataFusionError > ( acc || needs_nested?)
289+ } ) ?)
290+ }
291+ ( & DataType :: List ( ref expected_field) , & DataType :: List ( _) ) => {
292+ let actual_list = actual. as_any ( ) . downcast_ref :: < ListArray > ( ) . unwrap ( ) ;
293+ needs_nested_dictionary_encoding (
294+ expected_field. data_type ( ) ,
295+ actual_list. values ( ) ,
296+ )
297+ }
298+ ( & DataType :: Dictionary ( _, ref value) , _) => {
299+ let actual_data_type = actual. data_type ( ) ;
300+ let value_data_type = value. as_ref ( ) ;
301+ if value_data_type != actual_data_type {
302+ return Err ( DataFusionError :: Internal ( format ! (
303+ "Converted group rows expected dictionary of {value_data_type} got {actual_data_type}"
304+ ) ) ) ;
305+ }
306+
307+ Ok ( true )
308+ }
309+ ( _, _) => Ok ( false ) ,
310+ }
311+ }
312+
313+ fn dictionary_encode_nested ( array : ArrayRef , expected : & DataType ) -> Result < ArrayRef > {
314+ match ( expected, array. data_type ( ) ) {
315+ ( & DataType :: Struct ( ref expected_fields) , _) => {
316+ let struct_array = array. as_any ( ) . downcast_ref :: < StructArray > ( ) . unwrap ( ) ;
317+ let arrays = expected_fields
318+ . iter ( )
319+ . zip ( struct_array. columns ( ) )
320+ . map ( |( expected_field, column) | {
321+ dictionary_encode_nested ( column. clone ( ) , expected_field. data_type ( ) )
322+ } )
323+ . collect :: < Result < Vec < _ > > > ( ) ?;
324+
325+ Ok ( Arc :: new ( StructArray :: try_new (
326+ expected_fields. clone ( ) ,
327+ arrays,
328+ struct_array. nulls ( ) . cloned ( ) ,
329+ ) ?) )
330+ }
331+ ( & DataType :: List ( ref expected_field) , & DataType :: List ( _) ) => {
332+ let list = array. as_any ( ) . downcast_ref :: < ListArray > ( ) . unwrap ( ) ;
333+
334+ Ok ( Arc :: new ( ListArray :: try_new (
335+ expected_field. clone ( ) ,
336+ list. offsets ( ) . clone ( ) ,
337+ dictionary_encode_nested (
338+ list. values ( ) . clone ( ) ,
339+ expected_field. data_type ( ) ,
340+ ) ?,
341+ list. nulls ( ) . cloned ( ) ,
342+ ) ?) )
343+ }
344+ ( & DataType :: Dictionary ( _, _) , _) => Ok ( cast ( array. as_ref ( ) , expected) ?) ,
345+ ( _, _) => Ok ( array. clone ( ) ) ,
346+ }
347+ }
0 commit comments