@@ -866,12 +866,13 @@ impl CaseBody {
866866 // Since each when expression is tested against the base expression using the equality
867867 // operator, null base values can never match any when expression. `x = NULL` is falsy,
868868 // for all possible values of `x`.
869- if base_values. null_count ( ) > 0 {
869+ let base_null_count = base_values. logical_null_count ( ) ;
870+ if base_null_count > 0 {
870871 // Use `is_not_null` since this is a cheap clone of the null buffer from 'base_value'.
871872 // We already checked there are nulls, so we can be sure a new buffer will not be
872873 // created.
873874 let base_not_nulls = is_not_null ( base_values. as_ref ( ) ) ?;
874- let base_all_null = base_values . null_count ( ) == remainder_batch. num_rows ( ) ;
875+ let base_all_null = base_null_count == remainder_batch. num_rows ( ) ;
875876
876877 // If there is an else expression, use that as the default value for the null rows
877878 // Otherwise the default `null` value from the result builder will be used.
@@ -1545,6 +1546,84 @@ mod tests {
15451546 Ok ( ( ) )
15461547 }
15471548
1549+ #[ test]
1550+ fn case_with_expr_dictionary ( ) -> Result < ( ) > {
1551+ let schema = Schema :: new ( vec ! [ Field :: new(
1552+ "a" ,
1553+ DataType :: Dictionary ( Box :: new( DataType :: UInt8 ) , Box :: new( DataType :: Utf8 ) ) ,
1554+ true ,
1555+ ) ] ) ;
1556+ let keys = UInt8Array :: from ( vec ! [ 0u8 , 1u8 , 2u8 , 3u8 ] ) ;
1557+ let values = StringArray :: from ( vec ! [ Some ( "foo" ) , Some ( "baz" ) , None , Some ( "bar" ) ] ) ;
1558+ let dictionary = DictionaryArray :: new ( keys, Arc :: new ( values) ) ;
1559+ let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( dictionary) ] ) ?;
1560+
1561+ let schema = batch. schema ( ) ;
1562+
1563+ // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1564+ let when1 = lit ( "foo" ) ;
1565+ let then1 = lit ( 123i32 ) ;
1566+ let when2 = lit ( "bar" ) ;
1567+ let then2 = lit ( 456i32 ) ;
1568+
1569+ let expr = generate_case_when_with_type_coercion (
1570+ Some ( col ( "a" , & schema) ?) ,
1571+ vec ! [ ( when1, then1) , ( when2, then2) ] ,
1572+ None ,
1573+ schema. as_ref ( ) ,
1574+ ) ?;
1575+ let result = expr
1576+ . evaluate ( & batch) ?
1577+ . into_array ( batch. num_rows ( ) )
1578+ . expect ( "Failed to convert to array" ) ;
1579+ let result = as_int32_array ( & result) ?;
1580+
1581+ let expected = & Int32Array :: from ( vec ! [ Some ( 123 ) , None , None , Some ( 456 ) ] ) ;
1582+
1583+ assert_eq ! ( expected, result) ;
1584+
1585+ Ok ( ( ) )
1586+ }
1587+
1588+ #[ test]
1589+ fn case_with_expr_all_null_dictionary ( ) -> Result < ( ) > {
1590+ let schema = Schema :: new ( vec ! [ Field :: new(
1591+ "a" ,
1592+ DataType :: Dictionary ( Box :: new( DataType :: UInt8 ) , Box :: new( DataType :: Utf8 ) ) ,
1593+ true ,
1594+ ) ] ) ;
1595+ let keys = UInt8Array :: from ( vec ! [ 2u8 , 2u8 , 2u8 , 2u8 ] ) ;
1596+ let values = StringArray :: from ( vec ! [ Some ( "foo" ) , Some ( "baz" ) , None , Some ( "bar" ) ] ) ;
1597+ let dictionary = DictionaryArray :: new ( keys, Arc :: new ( values) ) ;
1598+ let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( dictionary) ] ) ?;
1599+
1600+ let schema = batch. schema ( ) ;
1601+
1602+ // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1603+ let when1 = lit ( "foo" ) ;
1604+ let then1 = lit ( 123i32 ) ;
1605+ let when2 = lit ( "bar" ) ;
1606+ let then2 = lit ( 456i32 ) ;
1607+
1608+ let expr = generate_case_when_with_type_coercion (
1609+ Some ( col ( "a" , & schema) ?) ,
1610+ vec ! [ ( when1, then1) , ( when2, then2) ] ,
1611+ None ,
1612+ schema. as_ref ( ) ,
1613+ ) ?;
1614+ let result = expr
1615+ . evaluate ( & batch) ?
1616+ . into_array ( batch. num_rows ( ) )
1617+ . expect ( "Failed to convert to array" ) ;
1618+ let result = as_int32_array ( & result) ?;
1619+
1620+ let expected = & Int32Array :: from ( vec ! [ None , None , None , None ] ) ;
1621+
1622+ assert_eq ! ( expected, result) ;
1623+
1624+ Ok ( ( ) )
1625+ }
1626+
15481627 #[ test]
15491628 fn case_with_expr_else ( ) -> Result < ( ) > {
15501629 let batch = case_test_batch ( ) ?;
0 commit comments