diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 07a946c1add9f..a43c64a813b82 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -33,7 +33,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; use datafusion_expr::utils::merge_schema; -use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan, Operator}; +use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan}; /// [`UnwrapCastInComparison`] attempts to remove casts from /// comparisons to literals ([`ScalarValue`]s) by applying the casts @@ -154,7 +154,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { }; is_supported_type(&left_type) && is_supported_type(&right_type) - && is_comparison_op(op) + && op.is_comparison_operator() } => { match (left.as_mut(), right.as_mut()) { @@ -270,18 +270,6 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { } } -fn is_comparison_op(op: &Operator) -> bool { - matches!( - op, - Operator::Eq - | Operator::NotEq - | Operator::Gt - | Operator::GtEq - | Operator::Lt - | Operator::LtEq - ) -} - /// Returns true if [UnwrapCastExprRewriter] supports this data type fn is_supported_type(data_type: &DataType) -> bool { is_supported_numeric_type(data_type) @@ -308,7 +296,10 @@ fn is_supported_numeric_type(data_type: &DataType) -> bool { /// Returns true if [UnwrapCastExprRewriter] supports casting this value as a string fn is_supported_string_type(data_type: &DataType) -> bool { - matches!(data_type, DataType::Utf8 | DataType::LargeUtf8) + matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) } /// Returns true if [UnwrapCastExprRewriter] supports casting this value as a dictionary @@ -481,12 +472,15 @@ fn 
try_cast_string_literal( target_type: &DataType, ) -> Option<ScalarValue> { let string_value = match lit_value { - ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) => s.clone(), + ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) | ScalarValue::Utf8View(s) => { + s.clone() + } _ => return None, }; let scalar_value = match target_type { DataType::Utf8 => ScalarValue::Utf8(string_value), DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), + DataType::Utf8View => ScalarValue::Utf8View(string_value), _ => return None, }; Some(scalar_value) diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index f8824b23d1b9c..ba82e19e9eb47 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -226,12 +226,18 @@ logical_plan 02)--Filter: test.column1_utf8view = Utf8View("Andrew") 03)----TableScan: test projection=[column1_utf8, column1_utf8view] -# should not be casting the column: https://github.com/apache/datafusion/issues/10998 query TT explain SELECT column1_utf8 from test where column1_utf8 = arrow_cast('Andrew', 'Utf8View'); ---- logical_plan -01)Filter: CAST(test.column1_utf8 AS Utf8View) = Utf8View("Andrew") +01)Filter: test.column1_utf8 = Utf8("Andrew") +02)--TableScan: test projection=[column1_utf8] + +query TT +explain SELECT column1_utf8 from test where arrow_cast('Andrew', 'Utf8View') = column1_utf8; +---- +logical_plan +01)Filter: test.column1_utf8 = Utf8("Andrew") 02)--TableScan: test projection=[column1_utf8] query TT @@ -242,6 +248,14 @@ logical_plan 02)--Filter: test.column1_utf8view = Utf8View("Andrew") 03)----TableScan: test projection=[column1_utf8, column1_utf8view] +query TT +explain SELECT column1_utf8 from test where arrow_cast('Andrew', 'Dictionary(Int32, Utf8)') = column1_utf8view; +---- +logical_plan +01)Projection: test.column1_utf8 +02)--Filter: test.column1_utf8view = Utf8View("Andrew") +03)----TableScan: test 
projection=[column1_utf8, column1_utf8view] + # compare string / stringview # Should cast string -> stringview (which is cheap), not stringview -> string (which is not) query TT