diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 293a694d68605..07a946c1add9f 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -623,15 +623,25 @@ mod tests { // cast(tag as Utf8) = Utf8('value') => tag = arrow_cast('value', 'Dictionary') let expr_input = cast(col("tag"), DataType::Utf8).eq(lit("value")); - let expected = col("tag").eq(lit(dict)); + let expected = col("tag").eq(lit(dict.clone())); assert_eq!(optimize_test(expr_input, &schema), expected); + // Verify reversed argument order + // arrow_cast('value', 'Dictionary') = cast(str1 as Dictionary) => Utf8('value') = str1 + let expr_input = lit(dict.clone()).eq(cast(col("str1"), dict.data_type())); + let expected = lit("value").eq(col("str1")); + assert_eq!(optimize_test(expr_input, &schema), expected); + } + + #[test] + fn test_unwrap_cast_comparison_large_string() { + let schema = expr_test_schema(); // cast(largestr as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = LargeUtf8('value1') let dict = ScalarValue::Dictionary( Box::new(DataType::Int32), Box::new(ScalarValue::LargeUtf8(Some("value".to_owned()))), ); - let expr_input = cast(col("largestr"), dict.data_type()).eq(lit(dict)); + let expr_input = cast(col("largestr"), dict.data_type()).eq(lit(dict.clone())); let expected = col("largestr").eq(lit(ScalarValue::LargeUtf8(Some("value".to_owned())))); assert_eq!(optimize_test(expr_input, &schema), expected); diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index 7e45f5e444d17..ec8a514885647 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -414,6 +414,19 @@ physical_plan 02)--FilterExec: column2@1 = 1 03)----MemoryExec: partitions=1, partition_sizes=[1] +# try literal = col to verify order doesn't matter 
+# filter should not cast column2 +query TT +explain SELECT * from test where '1' = column2 +---- +logical_plan +01)Filter: test.column2 = Dictionary(Int32, Utf8("1")) +02)--TableScan: test projection=[column1, column2] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: column2@1 = 1 +03)----MemoryExec: partitions=1, partition_sizes=[1] + # Now query using an integer which must be coerced into a dictionary string query T?