diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 06f1e24ed202..fe866450b2b2 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -883,9 +883,11 @@ impl EquivalenceProperties { if self.is_expr_constant(source) && !const_exprs_contains(&projected_constants, target) { + let across_partitions = self.is_expr_constant_accross_partitions(source); // Expression evaluates to single value - projected_constants - .push(ConstExpr::from(target).with_across_partitions(true)); + projected_constants.push( + ConstExpr::from(target).with_across_partitions(across_partitions), + ); } } projected_constants @@ -1014,6 +1016,37 @@ impl EquivalenceProperties { is_constant_recurse(&normalized_constants, &normalized_expr) } + /// This function determines whether the provided expression is constant + /// across partitions based on the known constants. + /// + /// # Arguments + /// + /// - `expr`: A reference to a `Arc` representing the + /// expression to be checked. + /// + /// # Returns + /// + /// Returns `true` if the expression is constant across all partitions according + /// to equivalence group, `false` otherwise. + pub fn is_expr_constant_accross_partitions( + &self, + expr: &Arc, + ) -> bool { + // As an example, assume that we know columns `a` and `b` are constant. + // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will + // return `false`. + let const_exprs = self.constants.iter().flat_map(|const_expr| { + if const_expr.across_partitions() { + Some(Arc::clone(const_expr.expr())) + } else { + None + } + }); + let normalized_constants = self.eq_group.normalize_exprs(const_exprs); + let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); + is_constant_recurse(&normalized_constants, &normalized_expr) + } + /// Retrieves the properties for a given physical expression. /// /// This function constructs an [`ExprProperties`] object for the given diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index d5f0521407c5..a46040aa532e 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -1260,3 +1260,49 @@ limit 2; statement ok drop table ordered_table; + +query TT +EXPLAIN SELECT + CASE + WHEN name = 'name1' THEN 0.0 + WHEN name = 'name2' THEN 0.5 + END AS a +FROM ( + SELECT 'name1' AS name + UNION ALL + SELECT 'name2' +) +ORDER BY a DESC; +---- +logical_plan +01)Sort: a DESC NULLS FIRST +02)--Projection: CASE WHEN name = Utf8("name1") THEN Float64(0) WHEN name = Utf8("name2") THEN Float64(0.5) END AS a +03)----Union +04)------Projection: Utf8("name1") AS name +05)--------EmptyRelation +06)------Projection: Utf8("name2") AS name +07)--------EmptyRelation +physical_plan +01)SortPreservingMergeExec: [a@0 DESC] +02)--ProjectionExec: expr=[CASE WHEN name@0 = name1 THEN 0 WHEN name@0 = name2 THEN 0.5 END as a] +03)----UnionExec +04)------ProjectionExec: expr=[name1 as name] +05)--------PlaceholderRowExec +06)------ProjectionExec: expr=[name2 as name] +07)--------PlaceholderRowExec + +query R +SELECT + CASE + WHEN name = 'name1' THEN 0.0 + WHEN name = 'name2' THEN 0.5 + END AS a +FROM ( + SELECT 'name1' AS name + UNION ALL + SELECT 'name2' +) +ORDER BY a DESC; +---- +0.5 +0