@@ -565,11 +565,12 @@ where
565565 let sums = std:: mem:: take ( & mut self . sums ) ;
566566 let nulls = self . null_state . build ( ) ;
567567
568+ assert_eq ! ( nulls. len( ) , sums. len( ) ) ;
568569 assert_eq ! ( counts. len( ) , sums. len( ) ) ;
569570
570571 // don't evaluate averages with null inputs to avoid errors on null values
571- let array : PrimitiveArray < T > = if let Some ( nulls ) = nulls . as_ref ( ) {
572- assert_eq ! ( nulls. len ( ) , sums . len ( ) ) ;
572+
573+ let array : PrimitiveArray < T > = if nulls. null_count ( ) > 0 {
573574 let mut builder = PrimitiveBuilder :: < T > :: with_capacity ( nulls. len ( ) ) ;
574575 let iter = sums. into_iter ( ) . zip ( counts. into_iter ( ) ) . zip ( nulls. iter ( ) ) ;
575576
@@ -587,7 +588,7 @@ where
587588 . zip ( counts. into_iter ( ) )
588589 . map ( |( sum, count) | ( self . avg_fn ) ( sum, count) )
589590 . collect :: < Result < Vec < _ > > > ( ) ?;
590- PrimitiveArray :: new ( averages. into ( ) , nulls) // no copy
591+ PrimitiveArray :: new ( averages. into ( ) , Some ( nulls) ) // no copy
591592 } ;
592593
593594 // fix up decimal precision and scale for decimals
@@ -598,9 +599,9 @@ where
598599
599600 // return arrays for sums and counts
600601 fn state ( & mut self ) -> Result < Vec < ArrayRef > > {
601- let nulls = self . null_state . build ( ) ;
602+ let nulls = Some ( self . null_state . build ( ) ) ;
602603 let counts = std:: mem:: take ( & mut self . counts ) ;
603- let counts = UInt64Array :: from ( counts) ; // zero copy
604+ let counts = UInt64Array :: new ( counts. into ( ) , nulls . clone ( ) ) ; // zero copy
604605
605606 let sums = std:: mem:: take ( & mut self . sums ) ;
606607 let sums = PrimitiveArray :: < T > :: new ( sums. into ( ) , nulls) ; // zero copy
0 commit comments