diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 6820ba04f0e9..3ed1309f1544 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -128,7 +128,7 @@ impl CommonSubexprEliminate { fn rewrite_exprs_list( &self, exprs_list: &[&[Expr]], - arrays_list: &[&[Vec<(usize, String)>]], + arrays_list: &[&[IdArray]], expr_stats: &ExprStats, common_exprs: &mut CommonExprs, ) -> Result>> { @@ -159,7 +159,7 @@ impl CommonSubexprEliminate { fn rewrite_expr( &self, exprs_list: &[&[Expr]], - arrays_list: &[&[Vec<(usize, String)>]], + arrays_list: &[&[IdArray]], input: &LogicalPlan, expr_stats: &ExprStats, config: &dyn OptimizerConfig, @@ -480,7 +480,7 @@ fn to_arrays( input_schema: DFSchemaRef, expr_stats: &mut ExprStats, expr_mask: ExprMask, -) -> Result>> { +) -> Result> { expr.iter() .map(|e| { let mut id_array = vec![]; @@ -739,7 +739,7 @@ fn expr_identifier(expr: &Expr, sub_expr_identifier: Identifier) -> Identifier { fn expr_to_identifier( expr: &Expr, expr_stats: &mut ExprStats, - id_array: &mut Vec<(usize, Identifier)>, + id_array: &mut IdArray, input_schema: DFSchemaRef, expr_mask: ExprMask, ) -> Result<()> { @@ -769,15 +769,28 @@ struct CommonSubexprRewriter<'a> { common_exprs: &'a mut CommonExprs, // preorder index, starts from 0. down_index: usize, + // how many aliases have we seen so far + alias_counter: usize, } impl TreeNodeRewriter for CommonSubexprRewriter<'_> { type Node = Expr; + fn f_up(&mut self, expr: Expr) -> Result> { + if matches!(expr, Expr::Alias(_)) { + self.alias_counter -= 1 + } + Ok(Transformed::no(expr)) + } + fn f_down(&mut self, expr: Expr) -> Result> { // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate // the `id_array`, which records the expr's identifier used to rewrite expr. So if we // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. + if matches!(expr, Expr::Alias(_)) { + self.alias_counter += 1; + } + if expr.short_circuits() || expr.is_volatile()? { return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } @@ -801,15 +814,16 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { let expr_name = expr.display_name()?; self.common_exprs.insert(expr_id.clone(), expr); - // Alias this `Column` expr to it original "expr name", - // `projection_push_down` optimizer use "expr name" to eliminate useless - // projections. - // TODO: do we really need to alias here? - Ok(Transformed::new( - col(expr_id).alias(expr_name), - true, - TreeNodeRecursion::Jump, - )) + + // alias the expressions without an `Alias` ancestor node + let rewritten = if self.alias_counter > 0 { + col(expr_id) + } else { + self.alias_counter += 1; + col(expr_id).alias(expr_name) + }; + + Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)) } else { Ok(Transformed::no(expr)) } @@ -829,6 +843,7 @@ fn replace_common_expr( id_array, common_exprs, down_index: 0, + alias_counter: 0, }) .data() } @@ -962,6 +977,26 @@ mod test { Ok(()) } + #[test] + fn nested_aliases() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .project(vec![ + (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")), + col("a") + col("b"), + ])? + .build()?; + + let expected = "Projection: {test.a + test.b|{test.b}|{test.a}} - test.c AS alias1 * {test.a + test.b|{test.b}|{test.a}} AS test.a + test.b, {test.a + test.b|{test.b}|{test.a}} AS test.a + test.b\ + \n Projection: test.a + test.b AS {test.a + test.b|{test.b}|{test.a}}, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, &plan); + + Ok(()) + } + #[test] fn aggregate() -> Result<()> { let table_scan = test_table_scan()?; @@ -1006,7 +1041,7 @@ mod test { )? .build()?; - let expected = "Projection: {AVG(test.a)|{test.a}} AS AVG(test.a) AS col1, {AVG(test.a)|{test.a}} AS AVG(test.a) AS col2, col3, {AVG(test.c)} AS AVG(test.c), {my_agg(test.a)|{test.a}} AS my_agg(test.a) AS col4, {my_agg(test.a)|{test.a}} AS my_agg(test.a) AS col5, col6, {my_agg(test.c)} AS my_agg(test.c)\ + let expected = "Projection: {AVG(test.a)|{test.a}} AS col1, {AVG(test.a)|{test.a}} AS col2, col3, {AVG(test.c)} AS AVG(test.c), {my_agg(test.a)|{test.a}} AS col4, {my_agg(test.a)|{test.a}} AS col5, col6, {my_agg(test.c)} AS my_agg(test.c)\ \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS {AVG(test.a)|{test.a}}, my_agg(test.a) AS {my_agg(test.a)|{test.a}}, AVG(test.b) AS col3, AVG(test.c) AS {AVG(test.c)}, my_agg(test.b) AS col6, my_agg(test.c) AS {my_agg(test.c)}]]\ \n TableScan: test"; @@ -1042,7 +1077,7 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n TableScan: test"; + let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col2]]\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n TableScan: test"; assert_optimized_plan_eq(expected, &plan); @@ -1057,7 +1092,7 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\ + let expected = "Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col2]]\ \n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; @@ -1078,8 +1113,8 @@ mod test { )? .build()?; - let expected = "Projection: UInt32(1) + test.a, UInt32(1) + {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a) AS col1, UInt32(1) - {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a) AS col2, {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a), UInt32(1) + {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS my_agg(UInt32(1) + test.a) AS col3, UInt32(1) - {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS my_agg(UInt32(1) + test.a) AS col4, {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS my_agg(UInt32(1) + test.a)\ - \n Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}}, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}}]]\ + let expected = "Projection: UInt32(1) + test.a, UInt32(1) + {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col1, UInt32(1) - {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col2, {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)} AS AVG(UInt32(1) + test.a), UInt32(1) + {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col3, UInt32(1) - {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col4, {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)} AS my_agg(UInt32(1) + test.a)\ + \n Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}, AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)}, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)}]]\ \n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; @@ -1126,7 +1161,7 @@ mod test { ])? .build()?; - let expected = "Projection: {Int32(1) + test.a|{test.a}|{Int32(1)}} AS Int32(1) + test.a AS first, {Int32(1) + test.a|{test.a}|{Int32(1)}} AS Int32(1) + test.a AS second\ + let expected = "Projection: {Int32(1) + test.a|{test.a}|{Int32(1)}} AS first, {Int32(1) + test.a|{test.a}|{Int32(1)}} AS second\ \n Projection: Int32(1) + test.a AS {Int32(1) + test.a|{test.a}|{Int32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 24a301d4a700..9e8a2450e0a5 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4538,7 +4538,7 @@ CREATE EXTERNAL TABLE timestamp_table ( c2 INT, ) STORED AS CSV -LOCATION 'test_files/scratch/group_by/timestamp_table' +LOCATION 'test_files/scratch/group_by/timestamp_table' OPTIONS ('format.has_header' 'true'); # Group By using date_trunc @@ -4611,7 +4611,7 @@ DROP TABLE timestamp_table; # Table with an int column and Dict column: statement ok -CREATE TABLE int8_dict AS VALUES +CREATE TABLE int8_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int8, Utf8)')), (2, arrow_cast('B', 'Dictionary(Int8, Utf8)')), (2, arrow_cast('A', 'Dictionary(Int8, Utf8)')), @@ -4649,7 +4649,7 @@ DROP TABLE int8_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE int16_dict AS VALUES +CREATE TABLE int16_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int16, Utf8)')), (2, arrow_cast('B', 'Dictionary(Int16, Utf8)')), (2, arrow_cast('A', 'Dictionary(Int16, Utf8)')), @@ -4687,7 +4687,7 @@ DROP TABLE int16_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE int32_dict AS VALUES +CREATE TABLE int32_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('B', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('A', 'Dictionary(Int32, Utf8)')), @@ -4725,7 +4725,7 @@ DROP TABLE int32_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE int64_dict AS VALUES +CREATE TABLE int64_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int64, Utf8)')), (2, arrow_cast('B', 'Dictionary(Int64, Utf8)')), (2, arrow_cast('A', 'Dictionary(Int64, Utf8)')), @@ -4763,7 +4763,7 @@ DROP TABLE int64_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE uint8_dict AS VALUES +CREATE TABLE uint8_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt8, Utf8)')), (2, arrow_cast('B', 'Dictionary(UInt8, Utf8)')), (2, arrow_cast('A', 'Dictionary(UInt8, Utf8)')), @@ -4801,7 +4801,7 @@ DROP TABLE uint8_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE uint16_dict AS VALUES +CREATE TABLE uint16_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt16, Utf8)')), (2, arrow_cast('B', 'Dictionary(UInt16, Utf8)')), (2, arrow_cast('A', 'Dictionary(UInt16, Utf8)')), @@ -4839,7 +4839,7 @@ DROP TABLE uint16_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE uint32_dict AS VALUES +CREATE TABLE uint32_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt32, Utf8)')), (2, arrow_cast('B', 'Dictionary(UInt32, Utf8)')), (2, arrow_cast('A', 'Dictionary(UInt32, Utf8)')), @@ -4877,7 +4877,7 @@ DROP TABLE uint32_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE uint64_dict AS VALUES +CREATE TABLE uint64_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt64, Utf8)')), (2, arrow_cast('B', 'Dictionary(UInt64, Utf8)')), (2, arrow_cast('A', 'Dictionary(UInt64, Utf8)')), diff --git a/datafusion/sqllogictest/test_files/tpch/q1.slt.part b/datafusion/sqllogictest/test_files/tpch/q1.slt.part index 0583c6ef07a7..5e0930b99228 100644 --- a/datafusion/sqllogictest/test_files/tpch/q1.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q1.slt.part @@ -42,7 +42,7 @@ explain select logical_plan 01)Sort: lineitem.l_returnflag ASC NULLS LAST, lineitem.l_linestatus ASC NULLS LAST 02)--Projection: lineitem.l_returnflag, lineitem.l_linestatus, sum(lineitem.l_quantity) AS sum_qty, sum(lineitem.l_extendedprice) AS sum_base_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS sum_disc_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax) AS sum_charge, AVG(lineitem.l_quantity) AS avg_qty, AVG(lineitem.l_extendedprice) AS avg_price, AVG(lineitem.l_discount) AS avg_disc, COUNT(*) AS count_order -03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}} AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}} AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(Int64(1)) AS COUNT(*)]] +03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}}) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}} * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(Int64(1)) AS COUNT(*)]] 04)------Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS {lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}}, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_tax, lineitem.l_returnflag, lineitem.l_linestatus 05)--------Filter: lineitem.l_shipdate <= Date32("1998-09-02") 06)----------TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], partial_filters=[lineitem.l_shipdate <= Date32("1998-09-02")] @@ -80,7 +80,7 @@ group by l_linestatus order by l_returnflag, - l_linestatus; + l_linestatus; ---- A F 3774200 5320753880.69 5054096266.6828 5256751331.449234 25.537587 36002.123829 0.050144 147790 N F 95257 133737795.84 127132372.6512 132286291.229445 25.300664 35521.326916 0.049394 3765