Skip to content

Commit 28fa74b

Browse files
authored
feat: Optimize CASE expression for "column or null" use case (#11534)
1 parent 5f0dfbb commit 28fa74b

File tree

5 files changed

+242
-14
lines changed

5 files changed

+242
-14
lines changed

datafusion/core/example.parquet

976 Bytes
Binary file not shown.

datafusion/physical-expr/benches/case_when.rs

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,29 @@ fn criterion_benchmark(c: &mut Criterion) {
4040
// create input data
4141
let mut c1 = Int32Builder::new();
4242
let mut c2 = StringBuilder::new();
43+
let mut c3 = StringBuilder::new();
4344
for i in 0..1000 {
4445
c1.append_value(i);
4546
if i % 7 == 0 {
4647
c2.append_null();
4748
} else {
4849
c2.append_value(&format!("string {i}"));
4950
}
51+
if i % 9 == 0 {
52+
c3.append_null();
53+
} else {
54+
c3.append_value(&format!("other string {i}"));
55+
}
5056
}
5157
let c1 = Arc::new(c1.finish());
5258
let c2 = Arc::new(c2.finish());
59+
let c3 = Arc::new(c3.finish());
5360
let schema = Schema::new(vec![
5461
Field::new("c1", DataType::Int32, true),
5562
Field::new("c2", DataType::Utf8, true),
63+
Field::new("c3", DataType::Utf8, true),
5664
]);
57-
let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
65+
let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap();
5866

5967
// use same predicate for all benchmarks
6068
let predicate = Arc::new(BinaryExpr::new(
@@ -63,7 +71,7 @@ fn criterion_benchmark(c: &mut Criterion) {
6371
make_lit_i32(500),
6472
));
6573

66-
// CASE WHEN expr THEN 1 ELSE 0 END
74+
// CASE WHEN c1 <= 500 THEN 1 ELSE 0 END
6775
c.bench_function("case_when: scalar or scalar", |b| {
6876
let expr = Arc::new(
6977
CaseExpr::try_new(
@@ -76,13 +84,38 @@ fn criterion_benchmark(c: &mut Criterion) {
7684
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
7785
});
7886

79-
// CASE WHEN expr THEN col ELSE null END
87+
// CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END
8088
c.bench_function("case_when: column or null", |b| {
89+
let expr = Arc::new(
90+
CaseExpr::try_new(None, vec![(predicate.clone(), make_col("c2", 1))], None)
91+
.unwrap(),
92+
);
93+
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
94+
});
95+
96+
// CASE WHEN c1 <= 500 THEN c2 ELSE c3 END
97+
c.bench_function("case_when: expr or expr", |b| {
8198
let expr = Arc::new(
8299
CaseExpr::try_new(
83100
None,
84101
vec![(predicate.clone(), make_col("c2", 1))],
85-
Some(Arc::new(Literal::new(ScalarValue::Utf8(None)))),
102+
Some(make_col("c3", 2)),
103+
)
104+
.unwrap(),
105+
);
106+
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
107+
});
108+
109+
// CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END
110+
c.bench_function("case_when: CASE expr", |b| {
111+
let expr = Arc::new(
112+
CaseExpr::try_new(
113+
Some(make_col("c1", 0)),
114+
vec![
115+
(make_lit_i32(1), make_col("c2", 1)),
116+
(make_lit_i32(2), make_col("c3", 2)),
117+
],
118+
None,
86119
)
87120
.unwrap(),
88121
);

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 152 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,33 @@ use datafusion_common::cast::as_boolean_array;
3232
use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue};
3333
use datafusion_expr::ColumnarValue;
3434

35+
use datafusion_physical_expr_common::expressions::column::Column;
36+
use datafusion_physical_expr_common::expressions::Literal;
3537
use itertools::Itertools;
3638

3739
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
3840

41+
#[derive(Debug, Hash)]
42+
enum EvalMethod {
43+
/// CASE WHEN condition THEN result
44+
/// [WHEN ...]
45+
/// [ELSE result]
46+
/// END
47+
NoExpression,
48+
/// CASE expression
49+
/// WHEN value THEN result
50+
/// [WHEN ...]
51+
/// [ELSE result]
52+
/// END
53+
WithExpression,
54+
/// This is a specialization for a specific use case where we can take a fast path
55+
/// for expressions that are infallible and can be cheaply computed for the entire
56+
/// record batch rather than just for the rows where the predicate is true.
57+
///
58+
/// CASE WHEN condition THEN column [ELSE NULL] END
59+
InfallibleExprOrNull,
60+
}
61+
3962
/// The CASE expression is similar to a series of nested if/else and there are two forms that
4063
/// can be used. The first form consists of a series of boolean "when" expressions with
4164
/// corresponding "then" expressions, and an optional "else" expression.
@@ -61,6 +84,8 @@ pub struct CaseExpr {
6184
when_then_expr: Vec<WhenThen>,
6285
/// Optional "else" expression
6386
else_expr: Option<Arc<dyn PhysicalExpr>>,
87+
/// Evaluation method to use
88+
eval_method: EvalMethod,
6489
}
6590

6691
impl std::fmt::Display for CaseExpr {
@@ -79,20 +104,51 @@ impl std::fmt::Display for CaseExpr {
79104
}
80105
}
81106

107+
/// This is a specialization for a specific use case where we can take a fast path
108+
/// for expressions that are infallible and can be cheaply computed for the entire
109+
/// record batch rather than just for the rows where the predicate is true. For now,
110+
/// this is limited to use with Column expressions but could potentially be used for other
111+
/// expressions in the future
112+
fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
113+
expr.as_any().is::<Column>()
114+
}
115+
82116
impl CaseExpr {
83117
/// Create a new CASE WHEN expression
84118
pub fn try_new(
85119
expr: Option<Arc<dyn PhysicalExpr>>,
86120
when_then_expr: Vec<WhenThen>,
87121
else_expr: Option<Arc<dyn PhysicalExpr>>,
88122
) -> Result<Self> {
123+
// normalize null literals to None in the else_expr (this already happens
124+
// during SQL planning, but not necessarily for other use cases)
125+
let else_expr = match &else_expr {
126+
Some(e) => match e.as_any().downcast_ref::<Literal>() {
127+
Some(lit) if lit.value().is_null() => None,
128+
_ => else_expr,
129+
},
130+
_ => else_expr,
131+
};
132+
89133
if when_then_expr.is_empty() {
90134
exec_err!("There must be at least one WHEN clause")
91135
} else {
136+
let eval_method = if expr.is_some() {
137+
EvalMethod::WithExpression
138+
} else if when_then_expr.len() == 1
139+
&& is_cheap_and_infallible(&(when_then_expr[0].1))
140+
&& else_expr.is_none()
141+
{
142+
EvalMethod::InfallibleExprOrNull
143+
} else {
144+
EvalMethod::NoExpression
145+
};
146+
92147
Ok(Self {
93148
expr,
94149
when_then_expr,
95150
else_expr,
151+
eval_method,
96152
})
97153
}
98154
}
@@ -256,6 +312,38 @@ impl CaseExpr {
256312

257313
Ok(ColumnarValue::Array(current_value))
258314
}
315+
316+
/// This function evaluates the specialized case of:
317+
///
318+
/// CASE WHEN condition THEN column
319+
/// [ELSE NULL]
320+
/// END
321+
///
322+
/// Note that this function is only safe to use for "then" expressions
323+
/// that are infallible because the expression will be evaluated for all
324+
/// rows in the input batch.
325+
fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
326+
let when_expr = &self.when_then_expr[0].0;
327+
let then_expr = &self.when_then_expr[0].1;
328+
if let ColumnarValue::Array(bit_mask) = when_expr.evaluate(batch)? {
329+
let bit_mask = bit_mask
330+
.as_any()
331+
.downcast_ref::<BooleanArray>()
332+
.expect("predicate should evaluate to a boolean array");
333+
// invert the bitmask
334+
let bit_mask = not(bit_mask)?;
335+
match then_expr.evaluate(batch)? {
336+
ColumnarValue::Array(array) => {
337+
Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
338+
}
339+
ColumnarValue::Scalar(_) => {
340+
internal_err!("expression did not evaluate to an array")
341+
}
342+
}
343+
} else {
344+
internal_err!("predicate did not evaluate to an array")
345+
}
346+
}
259347
}
260348

261349
impl PhysicalExpr for CaseExpr {
@@ -303,14 +391,21 @@ impl PhysicalExpr for CaseExpr {
303391
}
304392

305393
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
306-
if self.expr.is_some() {
307-
// this use case evaluates "expr" and then compares the values with the "when"
308-
// values
309-
self.case_when_with_expr(batch)
310-
} else {
311-
// The "when" conditions all evaluate to boolean in this use case and can be
312-
// arbitrary expressions
313-
self.case_when_no_expr(batch)
394+
match self.eval_method {
395+
EvalMethod::WithExpression => {
396+
// this use case evaluates "expr" and then compares the values with the "when"
397+
// values
398+
self.case_when_with_expr(batch)
399+
}
400+
EvalMethod::NoExpression => {
401+
// The "when" conditions all evaluate to boolean in this use case and can be
402+
// arbitrary expressions
403+
self.case_when_no_expr(batch)
404+
}
405+
EvalMethod::InfallibleExprOrNull => {
406+
// Specialization for CASE WHEN expr THEN column [ELSE NULL] END
407+
self.case_column_or_null(batch)
408+
}
314409
}
315410
}
316411

@@ -409,7 +504,7 @@ pub fn case(
409504
#[cfg(test)]
410505
mod tests {
411506
use super::*;
412-
use crate::expressions::{binary, cast, col, lit};
507+
use crate::expressions::{binary, cast, col, lit, BinaryExpr};
413508

414509
use arrow::buffer::Buffer;
415510
use arrow::datatypes::DataType::Float64;
@@ -419,6 +514,7 @@ mod tests {
419514
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
420515
use datafusion_expr::type_coercion::binary::comparison_coercion;
421516
use datafusion_expr::Operator;
517+
use datafusion_physical_expr_common::expressions::Literal;
422518

423519
#[test]
424520
fn case_with_expr() -> Result<()> {
@@ -998,6 +1094,53 @@ mod tests {
9981094
Ok(())
9991095
}
10001096

1097+
#[test]
1098+
fn test_column_or_null_specialization() -> Result<()> {
1099+
// create input data
1100+
let mut c1 = Int32Builder::new();
1101+
let mut c2 = StringBuilder::new();
1102+
for i in 0..1000 {
1103+
c1.append_value(i);
1104+
if i % 7 == 0 {
1105+
c2.append_null();
1106+
} else {
1107+
c2.append_value(&format!("string {i}"));
1108+
}
1109+
}
1110+
let c1 = Arc::new(c1.finish());
1111+
let c2 = Arc::new(c2.finish());
1112+
let schema = Schema::new(vec![
1113+
Field::new("c1", DataType::Int32, true),
1114+
Field::new("c2", DataType::Utf8, true),
1115+
]);
1116+
let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
1117+
1118+
// CaseWhenExprOrNull should produce same results as CaseExpr
1119+
let predicate = Arc::new(BinaryExpr::new(
1120+
make_col("c1", 0),
1121+
Operator::LtEq,
1122+
make_lit_i32(250),
1123+
));
1124+
let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
1125+
assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));
1126+
match expr.evaluate(&batch)? {
1127+
ColumnarValue::Array(array) => {
1128+
assert_eq!(1000, array.len());
1129+
assert_eq!(785, array.null_count());
1130+
}
1131+
_ => unreachable!(),
1132+
}
1133+
Ok(())
1134+
}
1135+
1136+
fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
1137+
Arc::new(Column::new(name, index))
1138+
}
1139+
1140+
fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
1141+
Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
1142+
}
1143+
10011144
fn generate_case_when_with_type_coercion(
10021145
expr: Option<Arc<dyn PhysicalExpr>>,
10031146
when_thens: Vec<WhenThen>,

datafusion/sqllogictest/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ In order to run the sqllogictests running against a previously running Postgres
133133
PG_COMPAT=true PG_URI="postgresql://[email protected]/postgres" cargo test --features=postgres --test sqllogictests
134134
```
135135

136-
The environemnt variables:
136+
The environment variables:
137137

138138
1. `PG_COMPAT` instructs sqllogictest to run against Postgres (not DataFusion)
139139
2. `PG_URI` contains a `libpq` style connection string, whose format is described in
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# create test data
19+
statement ok
20+
create table foo (a int, b int) as values (1, 2), (3, 4), (5, 6);
21+
22+
# CASE WHEN with condition
23+
query T
24+
SELECT CASE a WHEN 1 THEN 'one' WHEN 3 THEN 'three' ELSE '?' END FROM foo
25+
----
26+
one
27+
three
28+
?
29+
30+
# CASE WHEN with no condition
31+
query I
32+
SELECT CASE WHEN a > 2 THEN a ELSE b END FROM foo
33+
----
34+
2
35+
3
36+
5
37+
38+
# column or explicit null
39+
query I
40+
SELECT CASE WHEN a > 2 THEN b ELSE null END FROM foo
41+
----
42+
NULL
43+
4
44+
6
45+
46+
# column or implicit null
47+
query I
48+
SELECT CASE WHEN a > 2 THEN b END FROM foo
49+
----
50+
NULL
51+
4
52+
6

0 commit comments

Comments
 (0)