@@ -32,10 +32,33 @@ use datafusion_common::cast::as_boolean_array;
3232use datafusion_common:: { exec_err, internal_err, DataFusionError , Result , ScalarValue } ;
3333use datafusion_expr:: ColumnarValue ;
3434
35+ use datafusion_physical_expr_common:: expressions:: column:: Column ;
36+ use datafusion_physical_expr_common:: expressions:: Literal ;
3537use itertools:: Itertools ;
3638
3739type WhenThen = ( Arc < dyn PhysicalExpr > , Arc < dyn PhysicalExpr > ) ;
3840
/// Strategy used to evaluate a `CaseExpr`.
///
/// The enum carries no data, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum EvalMethod {
    /// CASE without a base expression; every WHEN is a condition:
    ///
    /// ```text
    /// CASE WHEN condition THEN result
    ///      [WHEN ...]
    ///      [ELSE result]
    /// END
    /// ```
    NoExpression,
    /// CASE with a base expression that each WHEN value is matched against:
    ///
    /// ```text
    /// CASE expression
    ///     WHEN value THEN result
    ///     [WHEN ...]
    ///     [ELSE result]
    /// END
    /// ```
    WithExpression,
    /// Fast path for a single WHEN whose THEN expression is infallible and
    /// cheap to compute for the entire record batch, rather than only for the
    /// rows where the predicate is true:
    ///
    /// ```text
    /// CASE WHEN condition THEN column [ELSE NULL] END
    /// ```
    InfallibleExprOrNull,
}
61+
3962/// The CASE expression is similar to a series of nested if/else and there are two forms that
4063/// can be used. The first form consists of a series of boolean "when" expressions with
4164/// corresponding "then" expressions, and an optional "else" expression.
@@ -61,6 +84,8 @@ pub struct CaseExpr {
6184 when_then_expr : Vec < WhenThen > ,
6285 /// Optional "else" expression
6386 else_expr : Option < Arc < dyn PhysicalExpr > > ,
87+ /// Evaluation method to use
88+ eval_method : EvalMethod ,
6489}
6590
6691impl std:: fmt:: Display for CaseExpr {
@@ -79,20 +104,51 @@ impl std::fmt::Display for CaseExpr {
79104 }
80105}
81106
107+ /// This is a specialization for a specific use case where we can take a fast path
108+ /// for expressions that are infallible and can be cheaply computed for the entire
109+ /// record batch rather than just for the rows where the predicate is true. For now,
110+ /// this is limited to use with Column expressions but could potentially be used for other
111+ /// expressions in the future
112+ fn is_cheap_and_infallible ( expr : & Arc < dyn PhysicalExpr > ) -> bool {
113+ expr. as_any ( ) . is :: < Column > ( )
114+ }
115+
82116impl CaseExpr {
83117 /// Create a new CASE WHEN expression
84118 pub fn try_new (
85119 expr : Option < Arc < dyn PhysicalExpr > > ,
86120 when_then_expr : Vec < WhenThen > ,
87121 else_expr : Option < Arc < dyn PhysicalExpr > > ,
88122 ) -> Result < Self > {
123+ // normalize null literals to None in the else_expr (this already happens
124+ // during SQL planning, but not necessarily for other use cases)
125+ let else_expr = match & else_expr {
126+ Some ( e) => match e. as_any ( ) . downcast_ref :: < Literal > ( ) {
127+ Some ( lit) if lit. value ( ) . is_null ( ) => None ,
128+ _ => else_expr,
129+ } ,
130+ _ => else_expr,
131+ } ;
132+
89133 if when_then_expr. is_empty ( ) {
90134 exec_err ! ( "There must be at least one WHEN clause" )
91135 } else {
136+ let eval_method = if expr. is_some ( ) {
137+ EvalMethod :: WithExpression
138+ } else if when_then_expr. len ( ) == 1
139+ && is_cheap_and_infallible ( & ( when_then_expr[ 0 ] . 1 ) )
140+ && else_expr. is_none ( )
141+ {
142+ EvalMethod :: InfallibleExprOrNull
143+ } else {
144+ EvalMethod :: NoExpression
145+ } ;
146+
92147 Ok ( Self {
93148 expr,
94149 when_then_expr,
95150 else_expr,
151+ eval_method,
96152 } )
97153 }
98154 }
@@ -256,6 +312,38 @@ impl CaseExpr {
256312
257313 Ok ( ColumnarValue :: Array ( current_value) )
258314 }
315+
316+ /// This function evaluates the specialized case of:
317+ ///
318+ /// CASE WHEN condition THEN column
319+ /// [ELSE NULL]
320+ /// END
321+ ///
322+ /// Note that this function is only safe to use for "then" expressions
323+ /// that are infallible because the expression will be evaluated for all
324+ /// rows in the input batch.
325+ fn case_column_or_null ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
326+ let when_expr = & self . when_then_expr [ 0 ] . 0 ;
327+ let then_expr = & self . when_then_expr [ 0 ] . 1 ;
328+ if let ColumnarValue :: Array ( bit_mask) = when_expr. evaluate ( batch) ? {
329+ let bit_mask = bit_mask
330+ . as_any ( )
331+ . downcast_ref :: < BooleanArray > ( )
332+ . expect ( "predicate should evaluate to a boolean array" ) ;
333+ // invert the bitmask
334+ let bit_mask = not ( bit_mask) ?;
335+ match then_expr. evaluate ( batch) ? {
336+ ColumnarValue :: Array ( array) => {
337+ Ok ( ColumnarValue :: Array ( nullif ( & array, & bit_mask) ?) )
338+ }
339+ ColumnarValue :: Scalar ( _) => {
340+ internal_err ! ( "expression did not evaluate to an array" )
341+ }
342+ }
343+ } else {
344+ internal_err ! ( "predicate did not evaluate to an array" )
345+ }
346+ }
259347}
260348
261349impl PhysicalExpr for CaseExpr {
@@ -303,14 +391,21 @@ impl PhysicalExpr for CaseExpr {
303391 }
304392
305393 fn evaluate ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
306- if self . expr . is_some ( ) {
307- // this use case evaluates "expr" and then compares the values with the "when"
308- // values
309- self . case_when_with_expr ( batch)
310- } else {
311- // The "when" conditions all evaluate to boolean in this use case and can be
312- // arbitrary expressions
313- self . case_when_no_expr ( batch)
394+ match self . eval_method {
395+ EvalMethod :: WithExpression => {
396+ // this use case evaluates "expr" and then compares the values with the "when"
397+ // values
398+ self . case_when_with_expr ( batch)
399+ }
400+ EvalMethod :: NoExpression => {
401+ // The "when" conditions all evaluate to boolean in this use case and can be
402+ // arbitrary expressions
403+ self . case_when_no_expr ( batch)
404+ }
405+ EvalMethod :: InfallibleExprOrNull => {
406+ // Specialization for CASE WHEN expr THEN column [ELSE NULL] END
407+ self . case_column_or_null ( batch)
408+ }
314409 }
315410 }
316411
@@ -409,7 +504,7 @@ pub fn case(
409504#[ cfg( test) ]
410505mod tests {
411506 use super :: * ;
412- use crate :: expressions:: { binary, cast, col, lit} ;
507+ use crate :: expressions:: { binary, cast, col, lit, BinaryExpr } ;
413508
414509 use arrow:: buffer:: Buffer ;
415510 use arrow:: datatypes:: DataType :: Float64 ;
@@ -419,6 +514,7 @@ mod tests {
419514 use datafusion_common:: tree_node:: { Transformed , TransformedResult , TreeNode } ;
420515 use datafusion_expr:: type_coercion:: binary:: comparison_coercion;
421516 use datafusion_expr:: Operator ;
517+ use datafusion_physical_expr_common:: expressions:: Literal ;
422518
423519 #[ test]
424520 fn case_with_expr ( ) -> Result < ( ) > {
@@ -998,6 +1094,53 @@ mod tests {
9981094 Ok ( ( ) )
9991095 }
10001096
#[test]
fn test_column_or_null_specialization() -> Result<()> {
    // Input: c1 = 0..1000; c2 = "string {i}", with every 7th row NULL.
    let mut ints = Int32Builder::new();
    let mut strings = StringBuilder::new();
    for i in 0..1000 {
        ints.append_value(i);
        match i % 7 {
            0 => strings.append_null(),
            _ => strings.append_value(format!("string {i}")),
        }
    }
    let c1 = Arc::new(ints.finish());
    let c2 = Arc::new(strings.finish());
    let schema = Arc::new(Schema::new(vec![
        Field::new("c1", DataType::Int32, true),
        Field::new("c2", DataType::Utf8, true),
    ]));
    let batch = RecordBatch::try_new(schema, vec![c1, c2]).unwrap();

    // CASE WHEN c1 <= 250 THEN c2 END must select the specialized path
    let predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        make_col("c1", 0),
        Operator::LtEq,
        make_lit_i32(250),
    ));
    let when_then = vec![(predicate, make_col("c2", 1))];
    let expr = CaseExpr::try_new(None, when_then, None)?;
    assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));

    // 749 rows fail the predicate (c1 in 251..=999); of the 251 rows that
    // pass, 36 already hold a NULL c2 (multiples of 7 in 0..=245): 785 total.
    if let ColumnarValue::Array(array) = expr.evaluate(&batch)? {
        assert_eq!(1000, array.len());
        assert_eq!(785, array.null_count());
    } else {
        unreachable!()
    }
    Ok(())
}
1135+
1136+ fn make_col ( name : & str , index : usize ) -> Arc < dyn PhysicalExpr > {
1137+ Arc :: new ( Column :: new ( name, index) )
1138+ }
1139+
1140+ fn make_lit_i32 ( n : i32 ) -> Arc < dyn PhysicalExpr > {
1141+ Arc :: new ( Literal :: new ( ScalarValue :: Int32 ( Some ( n) ) ) )
1142+ }
1143+
10011144 fn generate_case_when_with_type_coercion (
10021145 expr : Option < Arc < dyn PhysicalExpr > > ,
10031146 when_thens : Vec < WhenThen > ,
0 commit comments