@@ -1436,33 +1436,49 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
14361436
14371437 // CASE WHEN true THEN A ... END --> A
14381438 // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
1439+ // CASE WHEN false THEN A END --> NULL
1440+ // CASE WHEN false THEN A ELSE B END --> B
1441+ // CASE WHEN X THEN A WHEN false THEN B END --> CASE WHEN X THEN A ELSE B END
14391442 Expr :: Case ( Case {
14401443 expr : None ,
1441- mut when_then_expr,
1442- else_expr : _,
1443- // if let guard is not stabilized so we can't use it yet: https://github.com/rust-lang/rust/issues/51114
1444- // Once it's supported we can avoid searching through when_then_expr twice in the below .any() and .position() calls
1445- // }) if let Some(i) = when_then_expr.iter().position(|(when, _)| is_true(when.as_ref())) => {
1444+ when_then_expr,
1445+ mut else_expr,
14461446 } ) if when_then_expr
14471447 . iter ( )
1448- . any ( |( when, _) | is_true ( when. as_ref ( ) ) ) =>
1448+ . any ( |( when, _) | is_true ( when. as_ref ( ) ) || is_false ( when . as_ref ( ) ) ) =>
14491449 {
1450- let i = when_then_expr
1451- . iter ( )
1452- . position ( |( when, _) | is_true ( when. as_ref ( ) ) )
1453- . unwrap ( ) ;
1454- let ( _, then_) = when_then_expr. swap_remove ( i) ;
1455- // CASE WHEN true THEN A ... END --> A
1456- if i == 0 {
1457- return Ok ( Transformed :: yes ( * then_) ) ;
1450+ let out_type = info. get_data_type ( & when_then_expr[ 0 ] . 1 ) ?;
1451+ let mut new_when_then_expr = Vec :: with_capacity ( when_then_expr. len ( ) ) ;
1452+
1453+ for ( when, then) in when_then_expr. into_iter ( ) {
1454+ if is_true ( when. as_ref ( ) ) {
1455+ // Skip adding the rest of the when-then expressions after WHEN true
1456+ // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
1457+ else_expr = Some ( then) ;
1458+ break ;
1459+ } else if !is_false ( when. as_ref ( ) ) {
1460+ new_when_then_expr. push ( ( when, then) ) ;
1461+ }
1462+ // else: skip WHEN false cases
1463+ }
1464+
1465+ // Exclude CASE statement altogether if there are no when-then expressions left
1466+ if new_when_then_expr. is_empty ( ) {
1467+ // CASE WHEN false THEN A ELSE B END --> B
1468+ if let Some ( else_expr) = else_expr {
1469+ return Ok ( Transformed :: yes ( * else_expr) ) ;
1470+ // CASE WHEN false THEN A END --> NULL
1471+ } else {
1472+ let null =
1473+ Expr :: Literal ( ScalarValue :: try_new_null ( & out_type) ?, None ) ;
1474+ return Ok ( Transformed :: yes ( null) ) ;
1475+ }
14581476 }
14591477
1460- // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
1461- when_then_expr. truncate ( i) ;
14621478 Transformed :: yes ( Expr :: Case ( Case {
14631479 expr : None ,
1464- when_then_expr,
1465- else_expr : Some ( then_ ) ,
1480+ when_then_expr : new_when_then_expr ,
1481+ else_expr,
14661482 } ) )
14671483 }
14681484
@@ -3810,53 +3826,53 @@ mod tests {
38103826
38113827 #[ test]
38123828 fn simplify_expr_case_when_first_true ( ) {
3813- // CASE WHEN true THEN 1 ELSE x END --> 1
3829+ // CASE WHEN true THEN 1 ELSE c1 END --> 1
38143830 assert_eq ! (
38153831 simplify( Expr :: Case ( Case :: new(
38163832 None ,
38173833 vec![ ( Box :: new( lit( true ) ) , Box :: new( lit( 1 ) ) , ) ] ,
3818- Some ( Box :: new( col( "x " ) ) ) ,
3834+ Some ( Box :: new( col( "c1 " ) ) ) ,
38193835 ) ) ) ,
38203836 lit( 1 )
38213837 ) ;
38223838
3823- // CASE WHEN true THEN col("a" ) ELSE col("b" ) END --> col("a" )
3839+ // CASE WHEN true THEN col('a' ) ELSE col('b' ) END --> col('a' )
38243840 assert_eq ! (
38253841 simplify( Expr :: Case ( Case :: new(
38263842 None ,
3827- vec![ ( Box :: new( lit( true ) ) , Box :: new( col ( "a" ) ) , ) ] ,
3828- Some ( Box :: new( col ( "b" ) ) ) ,
3843+ vec![ ( Box :: new( lit( true ) ) , Box :: new( lit ( "a" ) ) , ) ] ,
3844+ Some ( Box :: new( lit ( "b" ) ) ) ,
38293845 ) ) ) ,
3830- col ( "a" )
3846+ lit ( "a" )
38313847 ) ;
38323848
3833- // CASE WHEN true THEN col("a" ) WHEN col("x" ) > 5 THEN col("b" ) ELSE col("c" ) END --> col("a" )
3849+ // CASE WHEN true THEN col('a' ) WHEN col('x' ) > 5 THEN col('b' ) ELSE col('c' ) END --> col('a' )
38343850 assert_eq ! (
38353851 simplify( Expr :: Case ( Case :: new(
38363852 None ,
38373853 vec![
3838- ( Box :: new( lit( true ) ) , Box :: new( col ( "a" ) ) ) ,
3839- ( Box :: new( col ( "x" ) . gt( lit( 5 ) ) ) , Box :: new( col ( "b" ) ) ) ,
3854+ ( Box :: new( lit( true ) ) , Box :: new( lit ( "a" ) ) ) ,
3855+ ( Box :: new( lit ( "x" ) . gt( lit( 5 ) ) ) , Box :: new( lit ( "b" ) ) ) ,
38403856 ] ,
3841- Some ( Box :: new( col ( "c" ) ) ) ,
3857+ Some ( Box :: new( lit ( "c" ) ) ) ,
38423858 ) ) ) ,
3843- col ( "a" )
3859+ lit ( "a" )
38443860 ) ;
38453861
3846- // CASE WHEN true THEN col("a" ) END --> col("a" ) (no else clause)
3862+ // CASE WHEN true THEN col('a' ) END --> col('a' ) (no else clause)
38473863 assert_eq ! (
38483864 simplify( Expr :: Case ( Case :: new(
38493865 None ,
3850- vec![ ( Box :: new( lit( true ) ) , Box :: new( col ( "a" ) ) , ) ] ,
3866+ vec![ ( Box :: new( lit( true ) ) , Box :: new( lit ( "a" ) ) , ) ] ,
38513867 None ,
38523868 ) ) ) ,
3853- col ( "a" )
3869+ lit ( "a" )
38543870 ) ;
38553871
3856- // Negative test: CASE WHEN a THEN 1 ELSE 2 END should not be simplified
3872+ // Negative test: CASE WHEN c2 THEN 1 ELSE 2 END should not be simplified
38573873 let expr = Expr :: Case ( Case :: new (
38583874 None ,
3859- vec ! [ ( Box :: new( col( "a " ) ) , Box :: new( lit( 1 ) ) ) ] ,
3875+ vec ! [ ( Box :: new( col( "c2 " ) ) , Box :: new( lit( 1 ) ) ) ] ,
38603876 Some ( Box :: new ( lit ( 2 ) ) ) ,
38613877 ) ) ;
38623878 assert_eq ! ( simplify( expr. clone( ) ) , expr) ;
@@ -3869,87 +3885,135 @@ mod tests {
38693885 ) ) ;
38703886 assert_ne ! ( simplify( expr) , lit( 1 ) ) ;
38713887
3872- // Negative test: CASE WHEN col("x" ) > 5 THEN 1 ELSE 2 END should not be simplified
3888+ // Negative test: CASE WHEN col('c1' ) > 5 THEN 1 ELSE 2 END should not be simplified
38733889 let expr = Expr :: Case ( Case :: new (
38743890 None ,
3875- vec ! [ ( Box :: new( col( "x " ) . gt( lit( 5 ) ) ) , Box :: new( lit( 1 ) ) ) ] ,
3891+ vec ! [ ( Box :: new( col( "c1 " ) . gt( lit( 5 ) ) ) , Box :: new( lit( 1 ) ) ) ] ,
38763892 Some ( Box :: new ( lit ( 2 ) ) ) ,
38773893 ) ) ;
38783894 assert_eq ! ( simplify( expr. clone( ) ) , expr) ;
38793895 }
38803896
38813897 #[ test]
38823898 fn simplify_expr_case_when_any_true ( ) {
3883- // CASE WHEN x > 0 THEN a WHEN true THEN b ELSE c END --> CASE WHEN x > 0 THEN a ELSE b END
3899+ // CASE WHEN c3 > 0 THEN 'a' WHEN true THEN 'b' ELSE 'c' END --> CASE WHEN c3 > 0 THEN 'a' ELSE 'b' END
38843900 assert_eq ! (
38853901 simplify( Expr :: Case ( Case :: new(
38863902 None ,
38873903 vec![
3888- ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col ( "a" ) ) ) ,
3889- ( Box :: new( lit( true ) ) , Box :: new( col ( "b" ) ) ) ,
3904+ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( lit ( "a" ) ) ) ,
3905+ ( Box :: new( lit( true ) ) , Box :: new( lit ( "b" ) ) ) ,
38903906 ] ,
3891- Some ( Box :: new( col ( "c" ) ) ) ,
3907+ Some ( Box :: new( lit ( "c" ) ) ) ,
38923908 ) ) ) ,
38933909 Expr :: Case ( Case :: new(
38943910 None ,
3895- vec![ ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col ( "a" ) ) ) ] ,
3896- Some ( Box :: new( col ( "b" ) ) ) ,
3911+ vec![ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( lit ( "a" ) ) ) ] ,
3912+ Some ( Box :: new( lit ( "b" ) ) ) ,
38973913 ) )
38983914 ) ;
38993915
3900- // CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c WHEN z = 0 THEN d ELSE e END
3901- // --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END
3916+ // CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' WHEN true THEN 'c' WHEN c3 = 0 THEN 'd' ELSE 'e' END
3917+ // --> CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' ELSE 'c' END
39023918 assert_eq ! (
39033919 simplify( Expr :: Case ( Case :: new(
39043920 None ,
39053921 vec![
3906- ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col ( "a" ) ) ) ,
3907- ( Box :: new( col( "y " ) . lt( lit( 0 ) ) ) , Box :: new( col ( "b" ) ) ) ,
3908- ( Box :: new( lit( true ) ) , Box :: new( col ( "c" ) ) ) ,
3909- ( Box :: new( col( "z " ) . eq( lit( 0 ) ) ) , Box :: new( col ( "d" ) ) ) ,
3922+ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( lit ( "a" ) ) ) ,
3923+ ( Box :: new( col( "c4 " ) . lt( lit( 0 ) ) ) , Box :: new( lit ( "b" ) ) ) ,
3924+ ( Box :: new( lit( true ) ) , Box :: new( lit ( "c" ) ) ) ,
3925+ ( Box :: new( col( "c3 " ) . eq( lit( 0 ) ) ) , Box :: new( lit ( "d" ) ) ) ,
39103926 ] ,
3911- Some ( Box :: new( col ( "e" ) ) ) ,
3927+ Some ( Box :: new( lit ( "e" ) ) ) ,
39123928 ) ) ) ,
39133929 Expr :: Case ( Case :: new(
39143930 None ,
39153931 vec![
3916- ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col ( "a" ) ) ) ,
3917- ( Box :: new( col( "y " ) . lt( lit( 0 ) ) ) , Box :: new( col ( "b" ) ) ) ,
3932+ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( lit ( "a" ) ) ) ,
3933+ ( Box :: new( col( "c4 " ) . lt( lit( 0 ) ) ) , Box :: new( lit ( "b" ) ) ) ,
39183934 ] ,
3919- Some ( Box :: new( col ( "c" ) ) ) ,
3935+ Some ( Box :: new( lit ( "c" ) ) ) ,
39203936 ) )
39213937 ) ;
39223938
3923- // CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c END (no else)
3924- // --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END
3939+ // CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 WHEN true THEN 3 END (no else)
3940+ // --> CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 ELSE 3 END
39253941 assert_eq ! (
39263942 simplify( Expr :: Case ( Case :: new(
39273943 None ,
39283944 vec![
3929- ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col ( "a" ) ) ) ,
3930- ( Box :: new( col( "y " ) . lt( lit( 0 ) ) ) , Box :: new( col ( "b" ) ) ) ,
3931- ( Box :: new( lit( true ) ) , Box :: new( col ( "c" ) ) ) ,
3945+ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( lit ( 1 ) ) ) ,
3946+ ( Box :: new( col( "c4 " ) . lt( lit( 0 ) ) ) , Box :: new( lit ( 2 ) ) ) ,
3947+ ( Box :: new( lit( true ) ) , Box :: new( lit ( 3 ) ) ) ,
39323948 ] ,
39333949 None ,
39343950 ) ) ) ,
39353951 Expr :: Case ( Case :: new(
39363952 None ,
39373953 vec![
3938- ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col ( "a" ) ) ) ,
3939- ( Box :: new( col( "y " ) . lt( lit( 0 ) ) ) , Box :: new( col ( "b" ) ) ) ,
3954+ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( lit ( 1 ) ) ) ,
3955+ ( Box :: new( col( "c4 " ) . lt( lit( 0 ) ) ) , Box :: new( lit ( 2 ) ) ) ,
39403956 ] ,
3941- Some ( Box :: new( col ( "c" ) ) ) ,
3957+ Some ( Box :: new( lit ( 3 ) ) ) ,
39423958 ) )
39433959 ) ;
39443960
3945- // Negative test: CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END should not be simplified
3961+ // Negative test: CASE WHEN c3 > 0 THEN c3 WHEN c4 < 0 THEN 2 ELSE 3 END should not be simplified
39463962 let expr = Expr :: Case ( Case :: new (
39473963 None ,
39483964 vec ! [
3949- ( Box :: new( col( "x " ) . gt( lit( 0 ) ) ) , Box :: new( col( "a " ) ) ) ,
3950- ( Box :: new( col( "y " ) . lt( lit( 0 ) ) ) , Box :: new( col ( "b" ) ) ) ,
3965+ ( Box :: new( col( "c3 " ) . gt( lit( 0 ) ) ) , Box :: new( col( "c3 " ) ) ) ,
3966+ ( Box :: new( col( "c4 " ) . lt( lit( 0 ) ) ) , Box :: new( lit ( 2 ) ) ) ,
39513967 ] ,
3952- Some ( Box :: new ( col ( "c" ) ) ) ,
3968+ Some ( Box :: new ( lit ( 3 ) ) ) ,
3969+ ) ) ;
3970+ assert_eq ! ( simplify( expr. clone( ) ) , expr) ;
3971+ }
3972+
3973+ #[ test]
3974+ fn simplify_expr_case_when_any_false ( ) {
3975+ // CASE WHEN false THEN 'a' END --> NULL
3976+ assert_eq ! (
3977+ simplify( Expr :: Case ( Case :: new(
3978+ None ,
3979+ vec![ ( Box :: new( lit( false ) ) , Box :: new( lit( "a" ) ) ) ] ,
3980+ None ,
3981+ ) ) ) ,
3982+ Expr :: Literal ( ScalarValue :: Utf8 ( None ) , None )
3983+ ) ;
3984+
3985+ // CASE WHEN false THEN 2 ELSE 1 END --> 1
3986+ assert_eq ! (
3987+ simplify( Expr :: Case ( Case :: new(
3988+ None ,
3989+ vec![ ( Box :: new( lit( false ) ) , Box :: new( lit( 2 ) ) ) ] ,
3990+ Some ( Box :: new( lit( 1 ) ) ) ,
3991+ ) ) ) ,
3992+ lit( 1 ) ,
3993+ ) ;
3994+
3995+ // CASE WHEN c3 < 10 THEN 'b' WHEN false then c3 ELSE c4 END --> CASE WHEN c3 < 10 THEN b ELSE c4 END
3996+ assert_eq ! (
3997+ simplify( Expr :: Case ( Case :: new(
3998+ None ,
3999+ vec![
4000+ ( Box :: new( col( "c3" ) . lt( lit( 10 ) ) ) , Box :: new( lit( "b" ) ) ) ,
4001+ ( Box :: new( lit( false ) ) , Box :: new( col( "c3" ) ) ) ,
4002+ ] ,
4003+ Some ( Box :: new( col( "c4" ) ) ) ,
4004+ ) ) ) ,
4005+ Expr :: Case ( Case :: new(
4006+ None ,
4007+ vec![ ( Box :: new( col( "c3" ) . lt( lit( 10 ) ) ) , Box :: new( lit( "b" ) ) ) ] ,
4008+ Some ( Box :: new( col( "c4" ) ) ) ,
4009+ ) )
4010+ ) ;
4011+
4012+ // Negative test: CASE WHEN c3 = 4 THEN 1 ELSE 2 END should not be simplified
4013+ let expr = Expr :: Case ( Case :: new (
4014+ None ,
4015+ vec ! [ ( Box :: new( col( "c3" ) . eq( lit( 4 ) ) ) , Box :: new( lit( 1 ) ) ) ] ,
4016+ Some ( Box :: new ( lit ( 2 ) ) ) ,
39534017 ) ) ;
39544018 assert_eq ! ( simplify( expr. clone( ) ) , expr) ;
39554019 }
0 commit comments