Skip to content

Commit 5e41c7f

Browse files
committed
Use logical null count in case_when_with_expr
1 parent 0bd127f commit 5e41c7f

File tree

1 file changed

+81
-2
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+81
-2
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -866,12 +866,13 @@ impl CaseBody {
866866
// Since each when expression is tested against the base expression using the equality
867867
// operator, null base values can never match any when expression. `x = NULL` is falsy,
868868
// for all possible values of `x`.
869-
if base_values.null_count() > 0 {
869+
let base_null_count = base_values.logical_null_count();
870+
if base_null_count > 0 {
870871
// Use `is_not_null` since this is a cheap clone of the null buffer from 'base_value'.
871872
// We already checked there are nulls, so we can be sure a new buffer will not be
872873
// created.
873874
let base_not_nulls = is_not_null(base_values.as_ref())?;
874-
let base_all_null = base_values.null_count() == remainder_batch.num_rows();
875+
let base_all_null = base_null_count == remainder_batch.num_rows();
875876

876877
// If there is an else expression, use that as the default value for the null rows
877878
// Otherwise the default `null` value from the result builder will be used.
@@ -1545,6 +1546,84 @@ mod tests {
15451546
Ok(())
15461547
}
15471548

1549+
// Exercises `CASE <expr> WHEN ... END` over a dictionary-encoded base column in
// which one row is null only through its dictionary value: the key itself is
// valid, but it points at a null entry in the values array.
#[test]
1550+
fn case_with_expr_dictionary() -> Result<()> {
1551+
let schema = Schema::new(vec![Field::new(
1552+
"a",
1553+
DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
1554+
true,
1555+
)]);
1556+
// Every key is non-null; key 2 references the `None` entry of `values`, so
// row 2 decodes to NULL even though the keys contain no nulls.
let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1557+
let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1558+
let dictionary = DictionaryArray::new(keys, Arc::new(values));
1559+
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1560+
1561+
let schema = batch.schema();
1562+
1563+
// CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1564+
let when1 = lit("foo");
1565+
let then1 = lit(123i32);
1566+
let when2 = lit("bar");
1567+
let then2 = lit(456i32);
1568+
1569+
let expr = generate_case_when_with_type_coercion(
1570+
Some(col("a", &schema)?),
1571+
vec![(when1, then1), (when2, then2)],
1572+
None,
1573+
schema.as_ref(),
1574+
)?;
1575+
let result = expr
1576+
.evaluate(&batch)?
1577+
.into_array(batch.num_rows())
1578+
.expect("Failed to convert to array");
1579+
let result = as_int32_array(&result)?;
1580+
1581+
// Row by row: 'foo' -> 123, 'baz' matches no WHEN -> NULL (the CASE above has
// no ELSE), null dictionary value -> NULL, 'bar' -> 456.
let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1582+
1583+
assert_eq!(expected, result);
1584+
1585+
Ok(())
1586+
}
1587+
1588+
// Exercises `CASE <expr> WHEN ... END` over a dictionary-encoded base column in
// which every row decodes to NULL: all keys are valid, but they all point at the
// single null entry in the values array.
#[test]
1589+
fn case_with_expr_all_null_dictionary() -> Result<()> {
1590+
let schema = Schema::new(vec![Field::new(
1591+
"a",
1592+
DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
1593+
true,
1594+
)]);
1595+
// All four keys are non-null and equal to 2, the index of the `None` entry in
// `values`, so the whole column is null while the keys contain no nulls.
let keys = UInt8Array::from(vec![2u8, 2u8, 2u8, 2u8]);
1596+
let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1597+
let dictionary = DictionaryArray::new(keys, Arc::new(values));
1598+
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1599+
1600+
let schema = batch.schema();
1601+
1602+
// CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1603+
let when1 = lit("foo");
1604+
let then1 = lit(123i32);
1605+
let when2 = lit("bar");
1606+
let then2 = lit(456i32);
1607+
1608+
let expr = generate_case_when_with_type_coercion(
1609+
Some(col("a", &schema)?),
1610+
vec![(when1, then1), (when2, then2)],
1611+
None,
1612+
schema.as_ref(),
1613+
)?;
1614+
let result = expr
1615+
.evaluate(&batch)?
1616+
.into_array(batch.num_rows())
1617+
.expect("Failed to convert to array");
1618+
let result = as_int32_array(&result)?;
1619+
1620+
// No row can match a WHEN branch and the CASE above has no ELSE, so every
// output row is expected to be NULL.
let expected = &Int32Array::from(vec![None, None, None, None]);
1621+
1622+
assert_eq!(expected, result);
1623+
1624+
Ok(())
1625+
}
1626+
15481627
#[test]
15491628
fn case_with_expr_else() -> Result<()> {
15501629
let batch = case_test_batch()?;

0 commit comments

Comments
 (0)