diff --git a/Cargo.lock b/Cargo.lock index 6a3bad8472862..881ce2d78fd0b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2408,6 +2408,7 @@ dependencies = [ "ctor", "datafusion-common", "datafusion-expr", + "datafusion-expr-common", "datafusion-functions-aggregate", "datafusion-functions-window", "datafusion-functions-window-common", diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index e7e8d52070392..6fe9752fd9154 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -33,6 +33,7 @@ use arrow::datatypes::{FieldRef, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_datasource::PartitionedFile; +use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; use datafusion_physical_expr::PhysicalExprSchemaRewriter; use datafusion_physical_expr_common::physical_expr::{ is_dynamic_physical_expr, PhysicalExpr, @@ -233,7 +234,16 @@ impl FileOpener for ParquetOpener { ) .rewrite(p) .map_err(ArrowError::from) + .map(|p| { + // After rewriting to the file schema, further simplifications may be possible. + // For example, if `'a' = col_that_is_missing` becomes `'a' = NULL` that can then be simplified to `FALSE` + // and we can avoid doing any more work on the file (bloom filters, loading the page index, etc.). + PhysicalExprSimplifier::new(&physical_file_schema) + .simplify(p) + .map_err(ArrowError::from) + }) }) + .transpose()? .transpose()?; // Build predicates for this specific file diff --git a/datafusion/expr-common/src/casts.rs b/datafusion/expr-common/src/casts.rs new file mode 100644 index 0000000000000..c31d4f77c6a7f --- /dev/null +++ b/datafusion/expr-common/src/casts.rs @@ -0,0 +1,1227 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for casting scalar literals to different data types +//! +//! This module contains functions for casting ScalarValue literals +//! to different data types, originally extracted from the optimizer's +//! unwrap_cast module to be shared between logical and physical layers. 
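+//!
+//! # Example
+//!
+//! An illustrative sketch of how a caller might use [`try_cast_literal_to_type`]
+//! (assuming it is consumed through the `datafusion_expr_common::casts` path
+//! exported below); it is marked `ignore` because it is not run as a doctest:
+//!
+//! ```ignore
+//! use arrow::datatypes::DataType;
+//! use datafusion_common::ScalarValue;
+//! use datafusion_expr_common::casts::try_cast_literal_to_type;
+//!
+//! // An Int32 literal fits losslessly into Int64, so the cast succeeds.
+//! let cast = try_cast_literal_to_type(&ScalarValue::Int32(Some(123)), &DataType::Int64);
+//! assert_eq!(cast, Some(ScalarValue::Int64(Some(123))));
+//!
+//! // A value that does not fit in the target type yields `None` instead of erroring.
+//! let too_big = try_cast_literal_to_type(&ScalarValue::Int64(Some(i64::MAX)), &DataType::Int32);
+//! assert_eq!(too_big, None);
+//! ```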
+
+use std::cmp::Ordering;
+
+use arrow::datatypes::{
+    DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION,
+    MIN_DECIMAL128_FOR_EACH_PRECISION,
+};
+use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
+use datafusion_common::ScalarValue;
+
+/// Convert a literal value from one data type to another
+pub fn try_cast_literal_to_type(
+    lit_value: &ScalarValue,
+    target_type: &DataType,
+) -> Option<ScalarValue> {
+    let lit_data_type = lit_value.data_type();
+    if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) {
+        return None;
+    }
+    if lit_value.is_null() {
+        // null value can be cast to any type of null value
+        return ScalarValue::try_from(target_type).ok();
+    }
+    try_cast_numeric_literal(lit_value, target_type)
+        .or_else(|| try_cast_string_literal(lit_value, target_type))
+        .or_else(|| try_cast_dictionary(lit_value, target_type))
+        .or_else(|| try_cast_binary(lit_value, target_type))
+}
+
+/// Returns true if unwrap_cast_in_comparison supports this data type
+pub fn is_supported_type(data_type: &DataType) -> bool {
+    is_supported_numeric_type(data_type)
+        || is_supported_string_type(data_type)
+        || is_supported_dictionary_type(data_type)
+        || is_supported_binary_type(data_type)
+}
+
+/// Returns true if unwrap_cast_in_comparison supports this numeric type
+fn is_supported_numeric_type(data_type: &DataType) -> bool {
+    matches!(
+        data_type,
+        DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::Decimal128(_, _)
+            | DataType::Timestamp(_, _)
+    )
+}
+
+/// Returns true if unwrap_cast_in_comparison supports casting this value as a string
+fn is_supported_string_type(data_type: &DataType) -> bool {
+    matches!(
+        data_type,
+        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
+    )
+}
+
+/// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary
+fn is_supported_dictionary_type(data_type: &DataType) -> bool {
+    matches!(data_type,
+        DataType::Dictionary(_, inner) if is_supported_type(inner))
+}
+
+fn is_supported_binary_type(data_type: &DataType) -> bool {
+    matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_))
+}
+
+/// Convert a numeric value from one numeric data type to another
+fn try_cast_numeric_literal(
+    lit_value: &ScalarValue,
+    target_type: &DataType,
+) -> Option<ScalarValue> {
+    let lit_data_type = lit_value.data_type();
+    if !is_supported_numeric_type(&lit_data_type)
+        || !is_supported_numeric_type(target_type)
+    {
+        return None;
+    }
+
+    let mul = match target_type {
+        DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64
+        | DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64 => 1_i128,
+        DataType::Timestamp(_, _) => 1_i128,
+        DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
+        _ => return None,
+    };
+    let (target_min, target_max) = match target_type {
+        DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128),
+        DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128),
+        DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128),
+        DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128),
+        DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
+        DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
+        DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
+        DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
+        DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
+        DataType::Decimal128(precision, _) => (
+            // Different precision for decimal128 can store different range of value.
+            // For example, the precision is 3, the max of value is `999` and the min
+            // value is `-999`
+            MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
+            MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
+        ),
+        _ => return None,
+    };
+    let lit_value_target_type = match lit_value {
+        ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
+        ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
+        ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
+        ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
+        ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul),
+        ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
+        ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
+        ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
+        ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
+        ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
+        ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
+        ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul),
+        ScalarValue::Decimal128(Some(v), _, scale) => {
+            let lit_scale_mul = 10_i128.pow(*scale as u32);
+            if mul >= lit_scale_mul {
+                // Example:
+                // lit is decimal(123,3,2)
+                // target type is decimal(5,3)
+                // the lit can be converted to the decimal(1230,5,3)
+                (*v).checked_mul(mul / lit_scale_mul)
+            } else if (*v) % (lit_scale_mul / mul) == 0 {
+                // Example:
+                // lit is decimal(123000,10,3)
+                // target type is int32: the lit can be converted to INT32(123)
+                // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
+                Some(*v / (lit_scale_mul / mul))
+            } else {
+                // can't convert the lit decimal to the target data type
+                None
+            }
+        }
+        _ => None,
+    };
+
+    match lit_value_target_type {
+        None => None,
+        Some(value) => {
+            if value >= target_min && value <= target_max {
+                // the value cast from lit to the target type is in the range of the target type,
+                // so return the target type of scalar value
+                let result_scalar = match target_type {
+                    DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
+                    DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
+                    DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
+                    DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
+                    DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
+                    DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
+                    DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
+                    DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
+                    DataType::Timestamp(TimeUnit::Second, tz) => {
+                        let value = cast_between_timestamp(
+                            &lit_data_type,
+                            &DataType::Timestamp(TimeUnit::Second, tz.clone()),
+                            value,
+                        );
+                        ScalarValue::TimestampSecond(value, tz.clone())
+                    }
+                    DataType::Timestamp(TimeUnit::Millisecond, tz) => {
+                        let value = cast_between_timestamp(
+                            &lit_data_type,
+                            &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
+                            value,
+                        );
+                        ScalarValue::TimestampMillisecond(value, tz.clone())
+                    }
+                    DataType::Timestamp(TimeUnit::Microsecond, tz) => {
+                        let value = cast_between_timestamp(
+                            &lit_data_type,
+                            &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
+                            value,
+                        );
+                        ScalarValue::TimestampMicrosecond(value, tz.clone())
+                    }
+                    DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
+                        let value = cast_between_timestamp(
+                            &lit_data_type,
+                            &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
+                            value,
+                        );
+                        ScalarValue::TimestampNanosecond(value, tz.clone())
+                    }
+                    DataType::Decimal128(p, s) => {
+                        ScalarValue::Decimal128(Some(value), *p, *s)
+                    }
+                    _ => {
+                        return None;
+                    }
+                };
+                Some(result_scalar)
+            } else {
+                None
+            }
+        }
+    }
+}
+
+fn try_cast_string_literal(
+    lit_value: &ScalarValue,
+    target_type: &DataType,
+) -> Option<ScalarValue> {
+    let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
+    let scalar_value = match target_type {
+        DataType::Utf8 => ScalarValue::Utf8(string_value),
+        DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
+        DataType::Utf8View => ScalarValue::Utf8View(string_value),
+        _ => return None,
+    };
+    Some(scalar_value)
+}
+
+/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary
+fn try_cast_dictionary(
+    lit_value: &ScalarValue,
+    target_type: &DataType,
+) -> Option<ScalarValue> {
+    let lit_value_type = lit_value.data_type();
+    let result_scalar = match (lit_value, target_type) {
+        // Unwrap dictionary when inner type matches target type
+        (ScalarValue::Dictionary(_, inner_value), _)
+            if inner_value.data_type() == *target_type =>
+        {
+            (**inner_value).clone()
+        }
+        // Wrap type when target type is dictionary
+        (_, DataType::Dictionary(index_type, inner_type))
+            if **inner_type == lit_value_type =>
+        {
+            ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone()))
+        }
+        _ => {
+            return None;
+        }
+    };
+    Some(result_scalar)
+}
+
+/// Cast a timestamp value from one unit to another
+fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option<i64> {
+    let value = value as i64;
+    let from_scale = match from {
+        DataType::Timestamp(TimeUnit::Second, _) => 1,
+        DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
+        DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
+        DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
+        _ => return Some(value),
+    };
+
+    let to_scale = match to {
+        DataType::Timestamp(TimeUnit::Second, _) => 1,
+        DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
+        DataType::Timestamp(TimeUnit::Microsecond, _) =>
MICROSECONDS, + DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + _ => return Some(value), + }; + + match from_scale.cmp(&to_scale) { + Ordering::Less => value.checked_mul(to_scale / from_scale), + Ordering::Greater => Some(value / (from_scale / to_scale)), + Ordering::Equal => Some(value), + } +} + +fn try_cast_binary( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + match (lit_value, target_type) { + (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n)) + if v.len() == *n as usize => + { + Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone()))) + } + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::compute::{cast_with_options, CastOptions}; + use arrow::datatypes::{Field, Fields, TimeUnit}; + use std::sync::Arc; + + #[derive(Debug, Clone)] + enum ExpectedCast { + /// test successfully cast value and it is as specified + Value(ScalarValue), + /// test returned OK, but could not cast the value + NoValue, + } + + /// Runs try_cast_literal_to_type with the specified inputs and + /// ensure it computes the expected output, and ensures the + /// casting is consistent with the Arrow kernels + fn expect_cast( + literal: ScalarValue, + target_type: DataType, + expected_result: ExpectedCast, + ) { + let actual_value = try_cast_literal_to_type(&literal, &target_type); + + println!("expect_cast: "); + println!(" {literal:?} --> {target_type:?}"); + println!(" expected_result: {expected_result:?}"); + println!(" actual_result: {actual_value:?}"); + + match expected_result { + ExpectedCast::Value(expected_value) => { + let actual_value = + actual_value.expect("Expected cast value but got None"); + + assert_eq!(actual_value, expected_value); + + // Verify that calling the arrow + // cast kernel yields the same results + // input array + let literal_array = literal + .to_array_of_size(1) + .expect("Failed to convert to array of size"); + let expected_array = expected_value + .to_array_of_size(1) + .expect("Failed to convert to array of size"); + let cast_array = cast_with_options( + &literal_array, + &target_type, + &CastOptions::default(), + ) + .expect("Expected to be cast array with arrow cast kernel"); + + assert_eq!( + &expected_array, &cast_array, + "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}" + ); + + // Verify that for timestamp types the timezones are the same + // (ScalarValue::cmp doesn't account for timezones); + if let ( + DataType::Timestamp(left_unit, left_tz), + DataType::Timestamp(right_unit, right_tz), + ) = (actual_value.data_type(), expected_value.data_type()) + { + assert_eq!(left_unit, right_unit); + assert_eq!(left_tz, right_tz); + } + } + ExpectedCast::NoValue => { + assert!( + actual_value.is_none(), + "Expected no cast value, but got {actual_value:?}" + ); + } + } + } + + #[test] + fn test_try_cast_to_type_nulls() { + // test that nulls can be cast to/from all integer types + let scalars = vec![ + ScalarValue::Int8(None), + ScalarValue::Int16(None), + ScalarValue::Int32(None), + ScalarValue::Int64(None), + ScalarValue::UInt8(None), + ScalarValue::UInt16(None), + ScalarValue::UInt32(None), + ScalarValue::UInt64(None), + ScalarValue::Decimal128(None, 3, 0), + ScalarValue::Decimal128(None, 8, 2), + ScalarValue::Utf8(None), + ScalarValue::LargeUtf8(None), + ]; + + for s1 in &scalars { + for s2 in &scalars { + let expected_value = ExpectedCast::Value(s2.clone()); + + expect_cast(s1.clone(), s2.data_type(), expected_value); + } + } + } + + #[test] + fn 
test_try_cast_to_type_int_in_range() { + // test values that can be cast to/from all integer types + let scalars = vec![ + ScalarValue::Int8(Some(123)), + ScalarValue::Int16(Some(123)), + ScalarValue::Int32(Some(123)), + ScalarValue::Int64(Some(123)), + ScalarValue::UInt8(Some(123)), + ScalarValue::UInt16(Some(123)), + ScalarValue::UInt32(Some(123)), + ScalarValue::UInt64(Some(123)), + ScalarValue::Decimal128(Some(123), 3, 0), + ScalarValue::Decimal128(Some(12300), 8, 2), + ]; + + for s1 in &scalars { + for s2 in &scalars { + let expected_value = ExpectedCast::Value(s2.clone()); + + expect_cast(s1.clone(), s2.data_type(), expected_value); + } + } + + let max_i32 = ScalarValue::Int32(Some(i32::MAX)); + expect_cast( + max_i32, + DataType::UInt64, + ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))), + ); + + let min_i32 = ScalarValue::Int32(Some(i32::MIN)); + expect_cast( + min_i32, + DataType::Int64, + ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))), + ); + + let max_i64 = ScalarValue::Int64(Some(i64::MAX)); + expect_cast( + max_i64, + DataType::UInt64, + ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))), + ); + } + + #[test] + fn test_try_cast_to_type_int_out_of_range() { + let min_i32 = ScalarValue::Int32(Some(i32::MIN)); + let min_i64 = ScalarValue::Int64(Some(i64::MIN)); + let max_i64 = ScalarValue::Int64(Some(i64::MAX)); + let max_u64 = ScalarValue::UInt64(Some(u64::MAX)); + + expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue); + + expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue); + + expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue); + + expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue); + + expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue); + + expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue); + + // decimal out of range + expect_cast( + ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0), + DataType::Int64, + ExpectedCast::NoValue, + ); + + expect_cast( + ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1), + DataType::Int64, + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_try_decimal_cast_in_range() { + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(3, 0), + ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)), + ); + + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(8, 0), + ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)), + ); + + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(8, 5), + ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)), + ); + } + + #[test] + fn test_try_decimal_cast_out_of_range() { + // decimal would lose precision + expect_cast( + ScalarValue::Decimal128(Some(12345), 5, 2), + DataType::Decimal128(3, 0), + ExpectedCast::NoValue, + ); + + // decimal would lose precision + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(2, 0), + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_try_cast_to_type_timestamps() { + for time_unit in [ + TimeUnit::Second, + TimeUnit::Millisecond, + TimeUnit::Microsecond, + TimeUnit::Nanosecond, + ] { + let utc = Some("+00:00".into()); + // No timezone, utc timezone + let (lit_tz_none, lit_tz_utc) = match time_unit { + TimeUnit::Second => ( + ScalarValue::TimestampSecond(Some(12345), None), + ScalarValue::TimestampSecond(Some(12345), utc), + ), + + TimeUnit::Millisecond => ( + 
ScalarValue::TimestampMillisecond(Some(12345), None), + ScalarValue::TimestampMillisecond(Some(12345), utc), + ), + + TimeUnit::Microsecond => ( + ScalarValue::TimestampMicrosecond(Some(12345), None), + ScalarValue::TimestampMicrosecond(Some(12345), utc), + ), + + TimeUnit::Nanosecond => ( + ScalarValue::TimestampNanosecond(Some(12345), None), + ScalarValue::TimestampNanosecond(Some(12345), utc), + ), + }; + + // DataFusion ignores timezones for comparisons of ScalarValue + // so double check it here + assert_eq!(lit_tz_none, lit_tz_utc); + + // e.g. DataType::Timestamp(_, None) + let dt_tz_none = lit_tz_none.data_type(); + + // e.g. DataType::Timestamp(_, Some(utc)) + let dt_tz_utc = lit_tz_utc.data_type(); + + // None <--> None + expect_cast( + lit_tz_none.clone(), + dt_tz_none.clone(), + ExpectedCast::Value(lit_tz_none.clone()), + ); + + // None <--> Utc + expect_cast( + lit_tz_none.clone(), + dt_tz_utc.clone(), + ExpectedCast::Value(lit_tz_utc.clone()), + ); + + // Utc <--> None + expect_cast( + lit_tz_utc.clone(), + dt_tz_none.clone(), + ExpectedCast::Value(lit_tz_none.clone()), + ); + + // Utc <--> Utc + expect_cast( + lit_tz_utc.clone(), + dt_tz_utc.clone(), + ExpectedCast::Value(lit_tz_utc.clone()), + ); + + // timestamp to int64 + expect_cast( + lit_tz_utc.clone(), + DataType::Int64, + ExpectedCast::Value(ScalarValue::Int64(Some(12345))), + ); + + // int64 to timestamp + expect_cast( + ScalarValue::Int64(Some(12345)), + dt_tz_none.clone(), + ExpectedCast::Value(lit_tz_none.clone()), + ); + + // int64 to timestamp + expect_cast( + ScalarValue::Int64(Some(12345)), + dt_tz_utc.clone(), + ExpectedCast::Value(lit_tz_utc.clone()), + ); + + // timestamp to string (not supported yet) + expect_cast( + lit_tz_utc.clone(), + DataType::LargeUtf8, + ExpectedCast::NoValue, + ); + } + } + + #[test] + fn test_try_cast_to_type_unsupported() { + // int64 to list + expect_cast( + ScalarValue::Int64(Some(12345)), + DataType::List(Arc::new(Field::new("f", DataType::Int32, true))), + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_try_cast_literal_to_timestamp() { + // same timestamp + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123456), None) + ); + + // TimestampNanosecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123), None) + ); + + // TimestampNanosecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); + + // TimestampNanosecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None)); + + // TimestampMicrosecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap(); + + assert_eq!( + new_scalar, + 
ScalarValue::TimestampNanosecond(Some(123000), None) + ); + + // TimestampMicrosecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); + + // TimestampMicrosecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123456789), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None)); + + // TimestampMillisecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000000), None) + ); + + // TimestampMillisecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123000), None) + ); + // TimestampMillisecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123456789), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None)); + + // TimestampSecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000000000), None) + ); + + // TimestampSecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123000000), None) + ); + + // TimestampSecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMillisecond(Some(123000), None) + ); + + // overflow + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(i64::MAX), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None)); + } + + #[test] + fn test_try_cast_to_string_type() { + let scalars = vec![ + ScalarValue::from("string"), + ScalarValue::LargeUtf8(Some("string".to_owned())), + ]; + + for s1 in &scalars { + for s2 in &scalars { + let expected_value = ExpectedCast::Value(s2.clone()); + + expect_cast(s1.clone(), s2.data_type(), expected_value); + } + } + } + + #[test] + fn test_try_cast_to_dictionary_type() { + fn dictionary_type(t: DataType) -> DataType { + DataType::Dictionary(Box::new(DataType::Int32), Box::new(t)) + } + fn dictionary_value(value: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value)) + } + let scalars = vec![ + ScalarValue::from("string"), + ScalarValue::LargeUtf8(Some("string".to_owned())), + ]; + for s in &scalars { + expect_cast( + s.clone(), + 
dictionary_type(s.data_type()), + ExpectedCast::Value(dictionary_value(s.clone())), + ); + expect_cast( + dictionary_value(s.clone()), + s.data_type(), + ExpectedCast::Value(s.clone()), + ) + } + } + + #[test] + fn test_try_cast_to_fixed_size_binary() { + expect_cast( + ScalarValue::Binary(Some(vec![1, 2, 3])), + DataType::FixedSizeBinary(3), + ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))), + ) + } + + #[test] + fn test_numeric_boundary_values() { + // Test exact boundary values for signed integers + expect_cast( + ScalarValue::Int8(Some(i8::MAX)), + DataType::UInt8, + ExpectedCast::Value(ScalarValue::UInt8(Some(i8::MAX as u8))), + ); + + expect_cast( + ScalarValue::Int8(Some(i8::MIN)), + DataType::UInt8, + ExpectedCast::NoValue, + ); + + expect_cast( + ScalarValue::UInt8(Some(u8::MAX)), + DataType::Int8, + ExpectedCast::NoValue, + ); + + // Test cross-type boundary scenarios + expect_cast( + ScalarValue::Int32(Some(i32::MAX)), + DataType::Int64, + ExpectedCast::Value(ScalarValue::Int64(Some(i32::MAX as i64))), + ); + + expect_cast( + ScalarValue::Int64(Some(i64::MIN)), + DataType::UInt64, + ExpectedCast::NoValue, + ); + + // Test unsigned to signed edge cases + expect_cast( + ScalarValue::UInt32(Some(u32::MAX)), + DataType::Int32, + ExpectedCast::NoValue, + ); + + expect_cast( + ScalarValue::UInt64(Some(u64::MAX)), + DataType::Int64, + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_decimal_precision_limits() { + use arrow::datatypes::{ + MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION, + }; + + // Test maximum precision values + expect_cast( + ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0), + DataType::Decimal128(5, 0), + ExpectedCast::Value(ScalarValue::Decimal128( + Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]), + 5, + 0, + )), + ); + + // Test minimum precision values + expect_cast( + ScalarValue::Decimal128(Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0), + DataType::Decimal128(5, 0), + ExpectedCast::Value(ScalarValue::Decimal128( + Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]), + 5, + 0, + )), + ); + + // Test scale increase + expect_cast( + ScalarValue::Decimal128(Some(123), 3, 0), + DataType::Decimal128(5, 2), + ExpectedCast::Value(ScalarValue::Decimal128(Some(12300), 5, 2)), + ); + + // Test precision overflow (value too large for target precision) + expect_cast( + ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[10]), 10, 0), + DataType::Decimal128(3, 0), + ExpectedCast::NoValue, + ); + + // Test non-divisible decimal conversion (should fail) + expect_cast( + ScalarValue::Decimal128(Some(12345), 5, 3), // 12.345 + DataType::Int32, + ExpectedCast::NoValue, // Can't convert 12.345 to integer without loss + ); + + // Test edge case: scale reduction with precision loss + expect_cast( + ScalarValue::Decimal128(Some(12345), 5, 2), // 123.45 + DataType::Decimal128(3, 0), // Can only hold up to 999 + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_timestamp_overflow_scenarios() { + // Test overflow in timestamp conversions + let max_seconds = i64::MAX / 1_000_000_000; // Avoid overflow when converting to nanos + + // This should work - within safe range + expect_cast( + ScalarValue::TimestampSecond(Some(max_seconds), None), + DataType::Timestamp(TimeUnit::Nanosecond, None), + ExpectedCast::Value(ScalarValue::TimestampNanosecond( + Some(max_seconds * 1_000_000_000), + None, + )), + ); + + // Test very large nanosecond value conversion to smaller units + expect_cast( + 
ScalarValue::TimestampNanosecond(Some(i64::MAX), None), + DataType::Timestamp(TimeUnit::Second, None), + ExpectedCast::Value(ScalarValue::TimestampSecond( + Some(i64::MAX / 1_000_000_000), + None, + )), + ); + + // Test precision loss in downscaling + expect_cast( + ScalarValue::TimestampNanosecond(Some(1), None), + DataType::Timestamp(TimeUnit::Second, None), + ExpectedCast::Value(ScalarValue::TimestampSecond(Some(0), None)), + ); + + expect_cast( + ScalarValue::TimestampMicrosecond(Some(999), None), + DataType::Timestamp(TimeUnit::Millisecond, None), + ExpectedCast::Value(ScalarValue::TimestampMillisecond(Some(0), None)), + ); + } + + #[test] + fn test_string_view() { + // Test Utf8View to other string types + expect_cast( + ScalarValue::Utf8View(Some("test".to_string())), + DataType::Utf8, + ExpectedCast::Value(ScalarValue::Utf8(Some("test".to_string()))), + ); + + expect_cast( + ScalarValue::Utf8View(Some("test".to_string())), + DataType::LargeUtf8, + ExpectedCast::Value(ScalarValue::LargeUtf8(Some("test".to_string()))), + ); + + // Test other string types to Utf8View + expect_cast( + ScalarValue::Utf8(Some("hello".to_string())), + DataType::Utf8View, + ExpectedCast::Value(ScalarValue::Utf8View(Some("hello".to_string()))), + ); + + expect_cast( + ScalarValue::LargeUtf8(Some("world".to_string())), + DataType::Utf8View, + ExpectedCast::Value(ScalarValue::Utf8View(Some("world".to_string()))), + ); + + // Test empty string + expect_cast( + ScalarValue::Utf8(Some("".to_string())), + DataType::Utf8View, + ExpectedCast::Value(ScalarValue::Utf8View(Some("".to_string()))), + ); + + // Test large string + let large_string = "x".repeat(1000); + expect_cast( + ScalarValue::LargeUtf8(Some(large_string.clone())), + DataType::Utf8View, + ExpectedCast::Value(ScalarValue::Utf8View(Some(large_string))), + ); + } + + #[test] + fn test_binary_size_edge_cases() { + // Test size mismatch - too small + expect_cast( + ScalarValue::Binary(Some(vec![1, 2])), + DataType::FixedSizeBinary(3), + ExpectedCast::NoValue, + ); + + // Test size mismatch - too large + expect_cast( + ScalarValue::Binary(Some(vec![1, 2, 3, 4])), + DataType::FixedSizeBinary(3), + ExpectedCast::NoValue, + ); + + // Test empty binary + expect_cast( + ScalarValue::Binary(Some(vec![])), + DataType::FixedSizeBinary(0), + ExpectedCast::Value(ScalarValue::FixedSizeBinary(0, Some(vec![]))), + ); + + // Test exact size match + expect_cast( + ScalarValue::Binary(Some(vec![1, 2, 3])), + DataType::FixedSizeBinary(3), + ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))), + ); + + // Test single byte + expect_cast( + ScalarValue::Binary(Some(vec![42])), + DataType::FixedSizeBinary(1), + ExpectedCast::Value(ScalarValue::FixedSizeBinary(1, Some(vec![42]))), + ); + } + + #[test] + fn test_dictionary_index_types() { + // Test different dictionary index types + let string_value = ScalarValue::Utf8(Some("test".to_string())); + + // Int8 index dictionary + let dict_int8 = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + expect_cast( + string_value.clone(), + dict_int8, + ExpectedCast::Value(ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(string_value.clone()), + )), + ); + + // Int16 index dictionary + let dict_int16 = + DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)); + expect_cast( + string_value.clone(), + dict_int16, + ExpectedCast::Value(ScalarValue::Dictionary( + Box::new(DataType::Int16), + Box::new(string_value.clone()), + )), + ); + + // Int64 index dictionary 
+ let dict_int64 = + DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8)); + expect_cast( + string_value.clone(), + dict_int64, + ExpectedCast::Value(ScalarValue::Dictionary( + Box::new(DataType::Int64), + Box::new(string_value.clone()), + )), + ); + + // Test dictionary unwrapping + let dict_value = ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))), + ); + expect_cast( + dict_value, + DataType::LargeUtf8, + ExpectedCast::Value(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))), + ); + } + + #[test] + fn test_type_support_functions() { + // Test numeric type support + assert!(is_supported_numeric_type(&DataType::Int8)); + assert!(is_supported_numeric_type(&DataType::UInt64)); + assert!(is_supported_numeric_type(&DataType::Decimal128(10, 2))); + assert!(is_supported_numeric_type(&DataType::Timestamp( + TimeUnit::Nanosecond, + None + ))); + assert!(!is_supported_numeric_type(&DataType::Float32)); + assert!(!is_supported_numeric_type(&DataType::Float64)); + + // Test string type support + assert!(is_supported_string_type(&DataType::Utf8)); + assert!(is_supported_string_type(&DataType::LargeUtf8)); + assert!(is_supported_string_type(&DataType::Utf8View)); + assert!(!is_supported_string_type(&DataType::Binary)); + + // Test binary type support + assert!(is_supported_binary_type(&DataType::Binary)); + assert!(is_supported_binary_type(&DataType::FixedSizeBinary(10))); + assert!(!is_supported_binary_type(&DataType::Utf8)); + + // Test dictionary type support with nested types + assert!(is_supported_dictionary_type(&DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8) + ))); + assert!(is_supported_dictionary_type(&DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Int64) + ))); + assert!(!is_supported_dictionary_type(&DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true + )))) + ))); + + // Test overall type support + assert!(is_supported_type(&DataType::Int32)); + assert!(is_supported_type(&DataType::Utf8)); + assert!(is_supported_type(&DataType::Binary)); + assert!(is_supported_type(&DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8) + ))); + assert!(!is_supported_type(&DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true + ))))); + assert!(!is_supported_type(&DataType::Struct(Fields::empty()))); + } + + #[test] + fn test_error_conditions() { + // Test unsupported source type + expect_cast( + ScalarValue::Float32(Some(1.5)), + DataType::Int32, + ExpectedCast::NoValue, + ); + + // Test unsupported target type + expect_cast( + ScalarValue::Int32(Some(123)), + DataType::Float64, + ExpectedCast::NoValue, + ); + + // Test both types unsupported + expect_cast( + ScalarValue::Float64(Some(1.5)), + DataType::Float32, + ExpectedCast::NoValue, + ); + + // Test complex unsupported types + let list_type = + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + expect_cast( + ScalarValue::Int32(Some(123)), + list_type, + ExpectedCast::NoValue, + ); + + // Test dictionary with unsupported inner type + let bad_dict = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + expect_cast( + ScalarValue::Int32(Some(123)), + bad_dict, + ExpectedCast::NoValue, + ); + } +} diff --git a/datafusion/expr-common/src/lib.rs 
b/datafusion/expr-common/src/lib.rs index 961670a3b7f45..597bf7713be2a 100644 --- a/datafusion/expr-common/src/lib.rs +++ b/datafusion/expr-common/src/lib.rs @@ -33,6 +33,7 @@ #![deny(clippy::clone_on_ref_ptr)] pub mod accumulator; +pub mod casts; pub mod columnar_value; pub mod groups_accumulator; pub mod interval_arithmetic; diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 60358d20e2a1a..6d43ab7e9d7b2 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -45,6 +45,7 @@ arrow = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-expr-common = { workspace = true } datafusion-physical-expr = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 2be7a2b0bd6ea..26ac4a30b7047 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -46,6 +46,7 @@ use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionP use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; +use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::unwrap_cast::{ @@ -54,11 +55,8 @@ use crate::simplify_expressions::unwrap_cast::{ unwrap_cast_in_comparison_for_binary, }; use crate::simplify_expressions::SimplifyInfo; -use crate::{ - analyzer::type_coercion::TypeCoercionRewriter, - simplify_expressions::unwrap_cast::try_cast_literal_to_type, -}; use datafusion_expr::expr::FieldMetadata; +use datafusion_expr_common::casts::try_cast_literal_to_type; use indexmap::IndexSet; use regex::Regex; diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index 7c8ff8305e843..6e66e467a89de 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -55,17 +55,12 @@ //! ``` //! 
-use std::cmp::Ordering; - -use arrow::datatypes::{ - DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION, - MIN_DECIMAL128_FOR_EACH_PRECISION, -}; -use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; +use arrow::datatypes::DataType; use datafusion_common::{internal_err, tree_node::Transformed}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{lit, BinaryExpr}; use datafusion_expr::{simplify::SimplifyInfo, Cast, Expr, Operator, TryCast}; +use datafusion_expr_common::casts::{is_supported_type, try_cast_literal_to_type}; pub(super) fn unwrap_cast_in_comparison_for_binary( info: &S, @@ -192,49 +187,6 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist< true } -/// Returns true if unwrap_cast_in_comparison supports this data type -fn is_supported_type(data_type: &DataType) -> bool { - is_supported_numeric_type(data_type) - || is_supported_string_type(data_type) - || is_supported_dictionary_type(data_type) - || is_supported_binary_type(data_type) -} - -/// Returns true if unwrap_cast_in_comparison support this numeric type -fn is_supported_numeric_type(data_type: &DataType) -> bool { - matches!( - data_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Decimal128(_, _) - | DataType::Timestamp(_, _) - ) -} - -/// Returns true if unwrap_cast_in_comparison supports casting this value as a string -fn is_supported_string_type(data_type: &DataType) -> bool { - matches!( - data_type, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View - ) -} - -/// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary -fn is_supported_dictionary_type(data_type: &DataType) -> bool { - matches!(data_type, - DataType::Dictionary(_, inner) if is_supported_type(inner)) -} - -fn is_supported_binary_type(data_type: &DataType) -> bool { - matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_)) -} - ///// Tries to move a cast from an expression (such as column) to the literal other side of a comparison operator./ /// /// Specifically, rewrites @@ -281,246 +233,6 @@ fn cast_literal_to_type_with_op( } } -/// Convert a literal value from one data type to another -pub(super) fn try_cast_literal_to_type( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let lit_data_type = lit_value.data_type(); - if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) { - return None; - } - if lit_value.is_null() { - // null value can be cast to any type of null value - return ScalarValue::try_from(target_type).ok(); - } - try_cast_numeric_literal(lit_value, target_type) - .or_else(|| try_cast_string_literal(lit_value, target_type)) - .or_else(|| try_cast_dictionary(lit_value, target_type)) - .or_else(|| try_cast_binary(lit_value, target_type)) -} - -/// Convert a numeric value from one numeric data type to another -fn try_cast_numeric_literal( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let lit_data_type = lit_value.data_type(); - if !is_supported_numeric_type(&lit_data_type) - || !is_supported_numeric_type(target_type) - { - return None; - } - - let mul = match target_type { - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 => 1_i128, - DataType::Timestamp(_, _) => 1_i128, - DataType::Decimal128(_, scale) => 10_i128.pow(*scale 
as u32), - _ => return None, - }; - let (target_min, target_max) = match target_type { - DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128), - DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128), - DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128), - DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128), - DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), - DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), - DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), - DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), - DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), - DataType::Decimal128(precision, _) => ( - // Different precision for decimal128 can store different range of value. - // For example, the precision is 3, the max of value is `999` and the min - // value is `-999` - MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], - MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], - ), - _ => return None, - }; - let lit_value_target_type = match lit_value { - ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul), - ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul), - ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul), - ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul), - ScalarValue::Decimal128(Some(v), _, scale) => { - let lit_scale_mul = 10_i128.pow(*scale as u32); - if mul >= lit_scale_mul { - // Example: - // lit is decimal(123,3,2) - // target type is decimal(5,3) - // the lit can be converted to the decimal(1230,5,3) - (*v).checked_mul(mul / lit_scale_mul) - } else if (*v) % (lit_scale_mul / mul) == 0 { - // Example: - // lit is decimal(123000,10,3) - // target type is int32: the lit can be converted to INT32(123) - // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2) - Some(*v / (lit_scale_mul / mul)) - } else { - // can't convert the lit decimal to the target data type - None - } - } - _ => None, - }; - - match lit_value_target_type { - None => None, - Some(value) => { - if value >= target_min && value <= target_max { - // the value casted from lit to the target type is in the range of target type. 
- // return the target type of scalar value - let result_scalar = match target_type { - DataType::Int8 => ScalarValue::Int8(Some(value as i8)), - DataType::Int16 => ScalarValue::Int16(Some(value as i16)), - DataType::Int32 => ScalarValue::Int32(Some(value as i32)), - DataType::Int64 => ScalarValue::Int64(Some(value as i64)), - DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)), - DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)), - DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)), - DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), - DataType::Timestamp(TimeUnit::Second, tz) => { - let value = cast_between_timestamp( - &lit_data_type, - &DataType::Timestamp(TimeUnit::Second, tz.clone()), - value, - ); - ScalarValue::TimestampSecond(value, tz.clone()) - } - DataType::Timestamp(TimeUnit::Millisecond, tz) => { - let value = cast_between_timestamp( - &lit_data_type, - &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), - value, - ); - ScalarValue::TimestampMillisecond(value, tz.clone()) - } - DataType::Timestamp(TimeUnit::Microsecond, tz) => { - let value = cast_between_timestamp( - &lit_data_type, - &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), - value, - ); - ScalarValue::TimestampMicrosecond(value, tz.clone()) - } - DataType::Timestamp(TimeUnit::Nanosecond, tz) => { - let value = cast_between_timestamp( - &lit_data_type, - &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), - value, - ); - ScalarValue::TimestampNanosecond(value, tz.clone()) - } - DataType::Decimal128(p, s) => { - ScalarValue::Decimal128(Some(value), *p, *s) - } - _ => { - return None; - } - }; - Some(result_scalar) - } else { - None - } - } - } -} - -fn try_cast_string_literal( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let string_value = lit_value.try_as_str()?.map(|s| s.to_string()); - let scalar_value = match target_type { - DataType::Utf8 => ScalarValue::Utf8(string_value), - DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), - DataType::Utf8View => ScalarValue::Utf8View(string_value), - _ => return None, - }; - Some(scalar_value) -} - -/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary -fn try_cast_dictionary( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let lit_value_type = lit_value.data_type(); - let result_scalar = match (lit_value, target_type) { - // Unwrap dictionary when inner type matches target type - (ScalarValue::Dictionary(_, inner_value), _) - if inner_value.data_type() == *target_type => - { - (**inner_value).clone() - } - // Wrap type when target type is dictionary - (_, DataType::Dictionary(index_type, inner_type)) - if **inner_type == lit_value_type => - { - ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone())) - } - _ => { - return None; - } - }; - Some(result_scalar) -} - -/// Cast a timestamp value from one unit to another -fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option { - let value = value as i64; - let from_scale = match from { - DataType::Timestamp(TimeUnit::Second, _) => 1, - DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, - DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, - DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, - _ => return Some(value), - }; - - let to_scale = match to { - DataType::Timestamp(TimeUnit::Second, _) => 1, - DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, - DataType::Timestamp(TimeUnit::Microsecond, _) => 
MICROSECONDS, - DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, - _ => return Some(value), - }; - - match from_scale.cmp(&to_scale) { - Ordering::Less => value.checked_mul(to_scale / from_scale), - Ordering::Greater => Some(value / (from_scale / to_scale)), - Ordering::Equal => Some(value), - } -} - -fn try_cast_binary( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - match (lit_value, target_type) { - (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n)) - if v.len() == *n as usize => - { - Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone()))) - } - _ => None, - } -} - #[cfg(test)] mod tests { use super::*; @@ -528,8 +240,7 @@ mod tests { use std::sync::Arc; use crate::simplify_expressions::ExprSimplifier; - use arrow::compute::{cast_with_options, CastOptions}; - use arrow::datatypes::Field; + use arrow::datatypes::{Field, TimeUnit}; use datafusion_common::{DFSchema, DFSchemaRef}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::simplify::SimplifyContext; @@ -960,523 +671,4 @@ mod tests { fn dictionary_tag_type() -> DataType { DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) } - - #[test] - fn test_try_cast_to_type_nulls() { - // test that nulls can be cast to/from all integer types - let scalars = vec![ - ScalarValue::Int8(None), - ScalarValue::Int16(None), - ScalarValue::Int32(None), - ScalarValue::Int64(None), - ScalarValue::UInt8(None), - ScalarValue::UInt16(None), - ScalarValue::UInt32(None), - ScalarValue::UInt64(None), - ScalarValue::Decimal128(None, 3, 0), - ScalarValue::Decimal128(None, 8, 2), - ScalarValue::Utf8(None), - ScalarValue::LargeUtf8(None), - ]; - - for s1 in &scalars { - for s2 in &scalars { - let expected_value = ExpectedCast::Value(s2.clone()); - - expect_cast(s1.clone(), s2.data_type(), expected_value); - } - } - } - - #[test] - fn test_try_cast_to_type_int_in_range() { - // test values that can be cast to/from all integer types - let scalars = vec![ - ScalarValue::Int8(Some(123)), - ScalarValue::Int16(Some(123)), - ScalarValue::Int32(Some(123)), - ScalarValue::Int64(Some(123)), - ScalarValue::UInt8(Some(123)), - ScalarValue::UInt16(Some(123)), - ScalarValue::UInt32(Some(123)), - ScalarValue::UInt64(Some(123)), - ScalarValue::Decimal128(Some(123), 3, 0), - ScalarValue::Decimal128(Some(12300), 8, 2), - ]; - - for s1 in &scalars { - for s2 in &scalars { - let expected_value = ExpectedCast::Value(s2.clone()); - - expect_cast(s1.clone(), s2.data_type(), expected_value); - } - } - - let max_i32 = ScalarValue::Int32(Some(i32::MAX)); - expect_cast( - max_i32, - DataType::UInt64, - ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))), - ); - - let min_i32 = ScalarValue::Int32(Some(i32::MIN)); - expect_cast( - min_i32, - DataType::Int64, - ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))), - ); - - let max_i64 = ScalarValue::Int64(Some(i64::MAX)); - expect_cast( - max_i64, - DataType::UInt64, - ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))), - ); - } - - #[test] - fn test_try_cast_to_type_int_out_of_range() { - let min_i32 = ScalarValue::Int32(Some(i32::MIN)); - let min_i64 = ScalarValue::Int64(Some(i64::MIN)); - let max_i64 = ScalarValue::Int64(Some(i64::MAX)); - let max_u64 = ScalarValue::UInt64(Some(u64::MAX)); - - expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue); - - expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue); - - expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue); - - 
expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue); - - expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue); - - expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue); - - // decimal out of range - expect_cast( - ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0), - DataType::Int64, - ExpectedCast::NoValue, - ); - - expect_cast( - ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1), - DataType::Int64, - ExpectedCast::NoValue, - ); - } - - #[test] - fn test_try_decimal_cast_in_range() { - expect_cast( - ScalarValue::Decimal128(Some(12300), 5, 2), - DataType::Decimal128(3, 0), - ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)), - ); - - expect_cast( - ScalarValue::Decimal128(Some(12300), 5, 2), - DataType::Decimal128(8, 0), - ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)), - ); - - expect_cast( - ScalarValue::Decimal128(Some(12300), 5, 2), - DataType::Decimal128(8, 5), - ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)), - ); - } - - #[test] - fn test_try_decimal_cast_out_of_range() { - // decimal would lose precision - expect_cast( - ScalarValue::Decimal128(Some(12345), 5, 2), - DataType::Decimal128(3, 0), - ExpectedCast::NoValue, - ); - - // decimal would lose precision - expect_cast( - ScalarValue::Decimal128(Some(12300), 5, 2), - DataType::Decimal128(2, 0), - ExpectedCast::NoValue, - ); - } - - #[test] - fn test_try_cast_to_type_timestamps() { - for time_unit in [ - TimeUnit::Second, - TimeUnit::Millisecond, - TimeUnit::Microsecond, - TimeUnit::Nanosecond, - ] { - let utc = Some("+00:00".into()); - // No timezone, utc timezone - let (lit_tz_none, lit_tz_utc) = match time_unit { - TimeUnit::Second => ( - ScalarValue::TimestampSecond(Some(12345), None), - ScalarValue::TimestampSecond(Some(12345), utc), - ), - - TimeUnit::Millisecond => ( - ScalarValue::TimestampMillisecond(Some(12345), None), - ScalarValue::TimestampMillisecond(Some(12345), utc), - ), - - TimeUnit::Microsecond => ( - ScalarValue::TimestampMicrosecond(Some(12345), None), - ScalarValue::TimestampMicrosecond(Some(12345), utc), - ), - - TimeUnit::Nanosecond => ( - ScalarValue::TimestampNanosecond(Some(12345), None), - ScalarValue::TimestampNanosecond(Some(12345), utc), - ), - }; - - // DataFusion ignores timezones for comparisons of ScalarValue - // so double check it here - assert_eq!(lit_tz_none, lit_tz_utc); - - // e.g. DataType::Timestamp(_, None) - let dt_tz_none = lit_tz_none.data_type(); - - // e.g. 
DataType::Timestamp(_, Some(utc)) - let dt_tz_utc = lit_tz_utc.data_type(); - - // None <--> None - expect_cast( - lit_tz_none.clone(), - dt_tz_none.clone(), - ExpectedCast::Value(lit_tz_none.clone()), - ); - - // None <--> Utc - expect_cast( - lit_tz_none.clone(), - dt_tz_utc.clone(), - ExpectedCast::Value(lit_tz_utc.clone()), - ); - - // Utc <--> None - expect_cast( - lit_tz_utc.clone(), - dt_tz_none.clone(), - ExpectedCast::Value(lit_tz_none.clone()), - ); - - // Utc <--> Utc - expect_cast( - lit_tz_utc.clone(), - dt_tz_utc.clone(), - ExpectedCast::Value(lit_tz_utc.clone()), - ); - - // timestamp to int64 - expect_cast( - lit_tz_utc.clone(), - DataType::Int64, - ExpectedCast::Value(ScalarValue::Int64(Some(12345))), - ); - - // int64 to timestamp - expect_cast( - ScalarValue::Int64(Some(12345)), - dt_tz_none.clone(), - ExpectedCast::Value(lit_tz_none.clone()), - ); - - // int64 to timestamp - expect_cast( - ScalarValue::Int64(Some(12345)), - dt_tz_utc.clone(), - ExpectedCast::Value(lit_tz_utc.clone()), - ); - - // timestamp to string (not supported yet) - expect_cast( - lit_tz_utc.clone(), - DataType::LargeUtf8, - ExpectedCast::NoValue, - ); - } - } - - #[test] - fn test_try_cast_to_type_unsupported() { - // int64 to list - expect_cast( - ScalarValue::Int64(Some(12345)), - DataType::List(Arc::new(Field::new("f", DataType::Int32, true))), - ExpectedCast::NoValue, - ); - } - - #[derive(Debug, Clone)] - enum ExpectedCast { - /// test successfully cast value and it is as specified - Value(ScalarValue), - /// test returned OK, but could not cast the value - NoValue, - } - - /// Runs try_cast_literal_to_type with the specified inputs and - /// ensure it computes the expected output, and ensures the - /// casting is consistent with the Arrow kernels - fn expect_cast( - literal: ScalarValue, - target_type: DataType, - expected_result: ExpectedCast, - ) { - let actual_value = try_cast_literal_to_type(&literal, &target_type); - - println!("expect_cast: "); - println!(" {literal:?} --> {target_type:?}"); - println!(" expected_result: {expected_result:?}"); - println!(" actual_result: {actual_value:?}"); - - match expected_result { - ExpectedCast::Value(expected_value) => { - let actual_value = - actual_value.expect("Expected cast value but got None"); - - assert_eq!(actual_value, expected_value); - - // Verify that calling the arrow - // cast kernel yields the same results - // input array - let literal_array = literal - .to_array_of_size(1) - .expect("Failed to convert to array of size"); - let expected_array = expected_value - .to_array_of_size(1) - .expect("Failed to convert to array of size"); - let cast_array = cast_with_options( - &literal_array, - &target_type, - &CastOptions::default(), - ) - .expect("Expected to be cast array with arrow cast kernel"); - - assert_eq!( - &expected_array, &cast_array, - "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}" - ); - - // Verify that for timestamp types the timezones are the same - // (ScalarValue::cmp doesn't account for timezones); - if let ( - DataType::Timestamp(left_unit, left_tz), - DataType::Timestamp(right_unit, right_tz), - ) = (actual_value.data_type(), expected_value.data_type()) - { - assert_eq!(left_unit, right_unit); - assert_eq!(left_tz, right_tz); - } - } - ExpectedCast::NoValue => { - assert!( - actual_value.is_none(), - "Expected no cast value, but got {actual_value:?}" - ); - } - } - } - - #[test] - fn test_try_cast_literal_to_timestamp() { - // same timestamp - let new_scalar = 
try_cast_literal_to_type( - &ScalarValue::TimestampNanosecond(Some(123456), None), - &DataType::Timestamp(TimeUnit::Nanosecond, None), - ) - .unwrap(); - - assert_eq!( - new_scalar, - ScalarValue::TimestampNanosecond(Some(123456), None) - ); - - // TimestampNanosecond to TimestampMicrosecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampNanosecond(Some(123456), None), - &DataType::Timestamp(TimeUnit::Microsecond, None), - ) - .unwrap(); - - assert_eq!( - new_scalar, - ScalarValue::TimestampMicrosecond(Some(123), None) - ); - - // TimestampNanosecond to TimestampMillisecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampNanosecond(Some(123456), None), - &DataType::Timestamp(TimeUnit::Millisecond, None), - ) - .unwrap(); - - assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); - - // TimestampNanosecond to TimestampSecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampNanosecond(Some(123456), None), - &DataType::Timestamp(TimeUnit::Second, None), - ) - .unwrap(); - - assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None)); - - // TimestampMicrosecond to TimestampNanosecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMicrosecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Nanosecond, None), - ) - .unwrap(); - - assert_eq!( - new_scalar, - ScalarValue::TimestampNanosecond(Some(123000), None) - ); - - // TimestampMicrosecond to TimestampMillisecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMicrosecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Millisecond, None), - ) - .unwrap(); - - assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); - - // TimestampMicrosecond to TimestampSecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMicrosecond(Some(123456789), None), - &DataType::Timestamp(TimeUnit::Second, None), - ) - .unwrap(); - assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None)); - - // TimestampMillisecond to TimestampNanosecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMillisecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Nanosecond, None), - ) - .unwrap(); - assert_eq!( - new_scalar, - ScalarValue::TimestampNanosecond(Some(123000000), None) - ); - - // TimestampMillisecond to TimestampMicrosecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMillisecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Microsecond, None), - ) - .unwrap(); - assert_eq!( - new_scalar, - ScalarValue::TimestampMicrosecond(Some(123000), None) - ); - // TimestampMillisecond to TimestampSecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMillisecond(Some(123456789), None), - &DataType::Timestamp(TimeUnit::Second, None), - ) - .unwrap(); - assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None)); - - // TimestampSecond to TimestampNanosecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampSecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Nanosecond, None), - ) - .unwrap(); - assert_eq!( - new_scalar, - ScalarValue::TimestampNanosecond(Some(123000000000), None) - ); - - // TimestampSecond to TimestampMicrosecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampSecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Microsecond, None), - ) - .unwrap(); - assert_eq!( - new_scalar, - ScalarValue::TimestampMicrosecond(Some(123000000), 
None) - ); - - // TimestampSecond to TimestampMillisecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampSecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Millisecond, None), - ) - .unwrap(); - assert_eq!( - new_scalar, - ScalarValue::TimestampMillisecond(Some(123000), None) - ); - - // overflow - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampSecond(Some(i64::MAX), None), - &DataType::Timestamp(TimeUnit::Millisecond, None), - ) - .unwrap(); - assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None)); - } - - #[test] - fn test_try_cast_to_string_type() { - let scalars = vec![ - ScalarValue::from("string"), - ScalarValue::LargeUtf8(Some("string".to_owned())), - ]; - - for s1 in &scalars { - for s2 in &scalars { - let expected_value = ExpectedCast::Value(s2.clone()); - - expect_cast(s1.clone(), s2.data_type(), expected_value); - } - } - } - #[test] - fn test_try_cast_to_dictionary_type() { - fn dictionary_type(t: DataType) -> DataType { - DataType::Dictionary(Box::new(DataType::Int32), Box::new(t)) - } - fn dictionary_value(value: ScalarValue) -> ScalarValue { - ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value)) - } - let scalars = vec![ - ScalarValue::from("string"), - ScalarValue::LargeUtf8(Some("string".to_owned())), - ]; - for s in &scalars { - expect_cast( - s.clone(), - dictionary_type(s.data_type()), - ExpectedCast::Value(dictionary_value(s.clone())), - ); - expect_cast( - dictionary_value(s.clone()), - s.data_type(), - ExpectedCast::Value(s.clone()), - ) - } - } - - #[test] - fn try_cast_to_fixed_size_binary() { - expect_cast( - ScalarValue::Binary(Some(vec![1, 2, 3])), - DataType::FixedSizeBinary(3), - ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))), - ) - } } diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 3bdb9d84d8278..03fc77f156d95 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -39,6 +39,7 @@ mod physical_expr; pub mod planner; mod scalar_function; pub mod schema_rewriter; +pub mod simplifier; pub mod statistics; pub mod utils; pub mod window; diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs new file mode 100644 index 0000000000000..80d6ee0a7b914 --- /dev/null +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Simplifier for Physical Expressions + +use arrow::datatypes::Schema; +use datafusion_common::{ + tree_node::{Transformed, TreeNode, TreeNodeRewriter}, + Result, +}; +use std::sync::Arc; + +use crate::PhysicalExpr; + +pub mod unwrap_cast; + +/// Simplifies physical expressions by applying various optimizations +/// +/// This can be useful after adapting expressions from a table schema +/// to a file schema. For example, casts added to match the types may +/// potentially be unwrapped. +pub struct PhysicalExprSimplifier<'a> { + schema: &'a Schema, +} + +impl<'a> PhysicalExprSimplifier<'a> { + /// Create a new physical expression simplifier + pub fn new(schema: &'a Schema) -> Self { + Self { schema } + } + + /// Simplify a physical expression + pub fn simplify( + &mut self, + expr: Arc, + ) -> Result> { + Ok(expr.rewrite(self)?.data) + } +} + +impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> { + type Node = Arc; + + fn f_up(&mut self, node: Self::Node) -> Result> { + // Apply unwrap cast optimization + #[cfg(test)] + let original_type = node.data_type(self.schema).unwrap(); + let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, self.schema)?; + #[cfg(test)] + assert_eq!( + unwrapped.data.data_type(self.schema).unwrap(), + original_type, + "Simplified expression should have the same data type as the original" + ); + Ok(unwrapped) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, lit, BinaryExpr, CastExpr, Literal, TryCastExpr}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::ScalarValue; + use datafusion_expr::Operator; + + fn test_schema() -> Schema { + Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int64, false), + Field::new("c3", DataType::Utf8, false), + ]) + } + + #[test] + fn test_simplify() { + let schema = test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // Create: cast(c2 as INT32) != INT32(99) + let column_expr = col("c2", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int32, None)); + let literal_expr = lit(ScalarValue::Int32(Some(99))); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::NotEq, literal_expr)); + + // Apply full simplification (uses TreeNodeRewriter) + let optimized = simplifier.simplify(binary_expr).unwrap(); + + let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Should be optimized to: c2 != INT64(99) (c2 is INT64, literal cast to match) + let left_expr = optimized_binary.left(); + assert!( + left_expr.as_any().downcast_ref::().is_none() + && left_expr.as_any().downcast_ref::().is_none() + ); + let right_literal = optimized_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(99))); + } + + #[test] + fn test_nested_expression_simplification() { + let schema = test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // Create nested expression: (cast(c1 as INT64) > INT64(5)) OR (cast(c2 as INT32) <= INT32(10)) + let c1_expr = col("c1", &schema).unwrap(); + let c1_cast = Arc::new(CastExpr::new(c1_expr, DataType::Int64, None)); + let c1_literal = lit(ScalarValue::Int64(Some(5))); + let c1_binary = Arc::new(BinaryExpr::new(c1_cast, Operator::Gt, c1_literal)); + + let c2_expr = col("c2", &schema).unwrap(); + let c2_cast = Arc::new(CastExpr::new(c2_expr, DataType::Int32, None)); + let c2_literal = lit(ScalarValue::Int32(Some(10))); + let c2_binary = 
Arc::new(BinaryExpr::new(c2_cast, Operator::LtEq, c2_literal)); + + let or_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::Or, c2_binary)); + + // Apply simplification + let optimized = simplifier.simplify(or_expr).unwrap(); + + let or_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Verify left side: c1 > INT32(5) + let left_binary = or_binary + .left() + .as_any() + .downcast_ref::() + .unwrap(); + let left_left_expr = left_binary.left(); + assert!( + left_left_expr.as_any().downcast_ref::().is_none() + && left_left_expr + .as_any() + .downcast_ref::() + .is_none() + ); + let left_literal = left_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(5))); + + // Verify right side: c2 <= INT64(10) + let right_binary = or_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + let right_left_expr = right_binary.left(); + assert!( + right_left_expr + .as_any() + .downcast_ref::() + .is_none() + && right_left_expr + .as_any() + .downcast_ref::() + .is_none() + ); + let right_literal = right_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(10))); + } +} diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs new file mode 100644 index 0000000000000..d409ce9cb5bf2 --- /dev/null +++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs @@ -0,0 +1,646 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Unwrap casts in binary comparisons for physical expressions +//! +//! This module provides optimization for physical expressions similar to the logical +//! optimizer's unwrap_cast module. It attempts to remove casts from comparisons to +//! literals by applying the casts to the literals if possible. +//! +//! The optimization improves performance by: +//! 1. Reducing runtime cast operations on column data +//! 2. Enabling better predicate pushdown opportunities +//! 3. Optimizing filter expressions in physical plans +//! +//! # Example +//! +//! Physical expression: `cast(column as INT64) > INT64(10)` +//! Optimized to: `column > INT32(10)` (assuming column is INT32) + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Schema}; +use datafusion_common::{ + tree_node::{Transformed, TreeNode}, + Result, ScalarValue, +}; +use datafusion_expr::Operator; +use datafusion_expr_common::casts::try_cast_literal_to_type; + +use crate::expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr}; +use crate::PhysicalExpr; + +/// Attempts to unwrap casts in comparison expressions. 
+pub(crate) fn unwrap_cast_in_comparison( + expr: Arc, + schema: &Schema, +) -> Result>> { + expr.transform_down(|e| { + if let Some(binary) = e.as_any().downcast_ref::() { + if let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? { + return Ok(Transformed::yes(unwrapped)); + } + } + Ok(Transformed::no(e)) + }) +} + +/// Try to unwrap casts in binary expressions +fn try_unwrap_cast_binary( + binary: &BinaryExpr, + schema: &Schema, +) -> Result>> { + // Case 1: cast(left_expr) op literal + if let (Some((inner_expr, _cast_type)), Some(literal)) = ( + extract_cast_info(binary.left()), + binary.right().as_any().downcast_ref::(), + ) { + if binary.op().supports_propagation() { + if let Some(unwrapped) = try_unwrap_cast_comparison( + Arc::clone(inner_expr), + literal.value(), + *binary.op(), + schema, + )? { + return Ok(Some(unwrapped)); + } + } + } + + // Case 2: literal op cast(right_expr) + if let (Some(literal), Some((inner_expr, _cast_type))) = ( + binary.left().as_any().downcast_ref::(), + extract_cast_info(binary.right()), + ) { + // For literal op cast(expr), we need to swap the operator + if let Some(swapped_op) = binary.op().swap() { + if binary.op().supports_propagation() { + if let Some(unwrapped) = try_unwrap_cast_comparison( + Arc::clone(inner_expr), + literal.value(), + swapped_op, + schema, + )? { + return Ok(Some(unwrapped)); + } + } + } + // If the operator cannot be swapped, we skip this optimization case + // but don't prevent other optimizations + } + + Ok(None) +} + +/// Extract cast information from a physical expression +/// +/// If the expression is a CAST(expr, datatype) or TRY_CAST(expr, datatype), +/// returns Some((inner_expr, target_datatype)). Otherwise returns None. +fn extract_cast_info( + expr: &Arc, +) -> Option<(&Arc, &DataType)> { + if let Some(cast) = expr.as_any().downcast_ref::() { + Some((cast.expr(), cast.cast_type())) + } else if let Some(try_cast) = expr.as_any().downcast_ref::() { + Some((try_cast.expr(), try_cast.cast_type())) + } else { + None + } +} + +/// Try to unwrap a cast in comparison by moving the cast to the literal +fn try_unwrap_cast_comparison( + inner_expr: Arc, + literal_value: &ScalarValue, + op: Operator, + schema: &Schema, +) -> Result>> { + // Get the data type of the inner expression + let inner_type = inner_expr.data_type(schema)?; + + // Try to cast the literal to the inner expression's type + if let Some(casted_literal) = try_cast_literal_to_type(literal_value, &inner_type) { + let literal_expr = lit(casted_literal); + let binary_expr = BinaryExpr::new(inner_expr, op, literal_expr); + return Ok(Some(Arc::new(binary_expr))); + } + + Ok(None) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, lit}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::ScalarValue; + use datafusion_expr::Operator; + + /// Check if an expression is a cast expression + fn is_cast_expr(expr: &Arc) -> bool { + expr.as_any().downcast_ref::().is_some() + || expr.as_any().downcast_ref::().is_some() + } + + /// Check if a binary expression is suitable for cast unwrapping + fn is_binary_expr_with_cast_and_literal(binary: &BinaryExpr) -> bool { + // Check if left is cast and right is literal + let left_cast_right_literal = is_cast_expr(binary.left()) + && binary.right().as_any().downcast_ref::().is_some(); + + // Check if left is literal and right is cast + let left_literal_right_cast = + binary.left().as_any().downcast_ref::().is_some() + && is_cast_expr(binary.right()); + + 
left_cast_right_literal || left_literal_right_cast + } + + fn test_schema() -> Schema { + Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int64, false), + Field::new("c3", DataType::Utf8, false), + ]) + } + + #[test] + fn test_unwrap_cast_in_binary_comparison() { + let schema = test_schema(); + + // Create: cast(c1 as INT64) > INT64(10) + let column_expr = col("c1", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let literal_expr = lit(10i64); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + // The result should be: c1 > INT32(10) + let optimized = result.data; + let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Check that left side is no longer a cast + assert!(!is_cast_expr(optimized_binary.left())); + + // Check that right side is a literal with the correct type and value + let right_literal = optimized_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(10))); + } + + #[test] + fn test_unwrap_cast_with_literal_on_left() { + let schema = test_schema(); + + // Create: INT64(10) < cast(c1 as INT64) + let column_expr = col("c1", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let literal_expr = lit(10i64); + let binary_expr = + Arc::new(BinaryExpr::new(literal_expr, Operator::Lt, cast_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + // The result should be equivalent to: c1 > INT32(10) + let optimized = result.data; + let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Check the operator was swapped + assert_eq!(*optimized_binary.op(), Operator::Gt); + } + + #[test] + fn test_no_unwrap_when_types_unsupported() { + let schema = Schema::new(vec![Field::new("f1", DataType::Float32, false)]); + + // Create: cast(f1 as FLOAT64) > FLOAT64(10.5) + let column_expr = col("f1", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Float64, None)); + let literal_expr = lit(10.5f64); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should NOT be transformed (floating point types not supported) + assert!(!result.transformed); + } + + #[test] + fn test_is_binary_expr_with_cast_and_literal() { + let schema = test_schema(); + + let column_expr = col("c1", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let literal_expr = lit(10i64); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr)); + let binary_ref = binary_expr.as_any().downcast_ref::().unwrap(); + + assert!(is_binary_expr_with_cast_and_literal(binary_ref)); + } + + #[test] + fn test_unwrap_cast_literal_on_left_side() { + // Test case for: literal <= cast(column) + // This was the specific case that caused the bug + let schema = Schema::new(vec![Field::new( + "decimal_col", + DataType::Decimal128(9, 2), + true, + )]); + + // Create: Decimal128(400) <= cast(decimal_col as 
Decimal128(22, 2)) + let column_expr = col("decimal_col", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new( + column_expr, + DataType::Decimal128(22, 2), + None, + )); + let literal_expr = lit(ScalarValue::Decimal128(Some(400), 22, 2)); + let binary_expr = + Arc::new(BinaryExpr::new(literal_expr, Operator::LtEq, cast_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + // The result should be: decimal_col >= Decimal128(400, 9, 2) + let optimized = result.data; + let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Check operator was swapped correctly + assert_eq!(*optimized_binary.op(), Operator::GtEq); + + // Check that left side is the column without cast + assert!(!is_cast_expr(optimized_binary.left())); + + // Check that right side is a literal with the correct type + let right_literal = optimized_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + right_literal.value().data_type(), + DataType::Decimal128(9, 2) + ); + } + + #[test] + fn test_unwrap_cast_with_different_comparison_operators() { + let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]); + + // Test all comparison operators with literal on the left + let operators = vec![ + (Operator::Lt, Operator::Gt), + (Operator::LtEq, Operator::GtEq), + (Operator::Gt, Operator::Lt), + (Operator::GtEq, Operator::LtEq), + (Operator::Eq, Operator::Eq), + (Operator::NotEq, Operator::NotEq), + ]; + + for (original_op, expected_op) in operators { + // Create: INT64(100) op cast(int_col as INT64) + let column_expr = col("int_col", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let literal_expr = lit(100i64); + let binary_expr = + Arc::new(BinaryExpr::new(literal_expr, original_op, cast_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + let optimized = result.data; + let optimized_binary = + optimized.as_any().downcast_ref::().unwrap(); + + // Check the operator was swapped correctly + assert_eq!( + *optimized_binary.op(), + expected_op, + "Failed for operator {original_op:?} -> {expected_op:?}" + ); + + // Check that left side has no cast + assert!(!is_cast_expr(optimized_binary.left())); + + // Check that the literal was cast to the column type + let right_literal = optimized_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(100))); + } + } + + #[test] + fn test_unwrap_cast_with_decimal_types() { + // Test various decimal precision/scale combinations + let test_cases = vec![ + // (column_precision, column_scale, cast_precision, cast_scale, value) + (9, 2, 22, 2, 400), + (10, 3, 20, 3, 1000), + (5, 1, 10, 1, 99), + ]; + + for (col_p, col_s, cast_p, cast_s, value) in test_cases { + let schema = Schema::new(vec![Field::new( + "decimal_col", + DataType::Decimal128(col_p, col_s), + true, + )]); + + // Test both: cast(column) op literal AND literal op cast(column) + + // Case 1: cast(column) > literal + let column_expr = col("decimal_col", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new( + Arc::clone(&column_expr), + DataType::Decimal128(cast_p, cast_s), + None, + )); + let literal_expr = lit(ScalarValue::Decimal128(Some(value), cast_p, cast_s)); + let binary_expr = + 
Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr)); + + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + assert!(result.transformed); + + // Case 2: literal < cast(column) + let cast_expr = Arc::new(CastExpr::new( + column_expr, + DataType::Decimal128(cast_p, cast_s), + None, + )); + let literal_expr = lit(ScalarValue::Decimal128(Some(value), cast_p, cast_s)); + let binary_expr = + Arc::new(BinaryExpr::new(literal_expr, Operator::Lt, cast_expr)); + + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + assert!(result.transformed); + } + } + + #[test] + fn test_unwrap_cast_with_null_literals() { + // Test with NULL literals to ensure they're handled correctly + let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, true)]); + + // Create: cast(int_col as INT64) = NULL + let column_expr = col("int_col", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let null_literal = lit(ScalarValue::Int64(None)); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, null_literal)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + // Verify the NULL was cast to the column type + let optimized = result.data; + let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + let right_literal = optimized_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int32(None)); + } + + #[test] + fn test_unwrap_cast_with_try_cast() { + // Test that TryCast expressions are also unwrapped correctly + let schema = Schema::new(vec![Field::new("str_col", DataType::Utf8, true)]); + + // Create: try_cast(str_col as INT64) > INT64(100) + let column_expr = col("str_col", &schema).unwrap(); + let try_cast_expr = Arc::new(TryCastExpr::new(column_expr, DataType::Int64)); + let literal_expr = lit(100i64); + let binary_expr = + Arc::new(BinaryExpr::new(try_cast_expr, Operator::Gt, literal_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should NOT be transformed (string to int cast not supported) + assert!(!result.transformed); + } + + #[test] + fn test_unwrap_cast_preserves_non_comparison_operators() { + // Test that non-comparison operators in AND/OR expressions are preserved + let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]); + + // Create: cast(int_col as INT64) > INT64(10) AND cast(int_col as INT64) < INT64(20) + let column_expr = col("int_col", &schema).unwrap(); + + let cast1 = Arc::new(CastExpr::new( + Arc::clone(&column_expr), + DataType::Int64, + None, + )); + let lit1 = lit(10i64); + let compare1 = Arc::new(BinaryExpr::new(cast1, Operator::Gt, lit1)); + + let cast2 = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let lit2 = lit(20i64); + let compare2 = Arc::new(BinaryExpr::new(cast2, Operator::Lt, lit2)); + + let and_expr = Arc::new(BinaryExpr::new(compare1, Operator::And, compare2)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + // Verify the AND operator is preserved + let optimized = result.data; + let and_binary = optimized.as_any().downcast_ref::().unwrap(); + assert_eq!(*and_binary.op(), Operator::And); + + // Both sides should have their 
casts unwrapped + let left_binary = and_binary + .left() + .as_any() + .downcast_ref::() + .unwrap(); + let right_binary = and_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + + assert!(!is_cast_expr(left_binary.left())); + assert!(!is_cast_expr(right_binary.left())); + } + + #[test] + fn test_try_cast_unwrapping() { + let schema = test_schema(); + + // Create: try_cast(c1 as INT64) <= INT64(100) + let column_expr = col("c1", &schema).unwrap(); + let try_cast_expr = Arc::new(TryCastExpr::new(column_expr, DataType::Int64)); + let literal_expr = lit(100i64); + let binary_expr = + Arc::new(BinaryExpr::new(try_cast_expr, Operator::LtEq, literal_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should be transformed to: c1 <= INT32(100) + assert!(result.transformed); + + let optimized = result.data; + let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Verify the try_cast was removed + assert!(!is_cast_expr(optimized_binary.left())); + + // Verify the literal was converted + let right_literal = optimized_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(100))); + } + + #[test] + fn test_non_swappable_operator() { + // Test case with an operator that cannot be swapped + let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]); + + // Create: INT64(10) + cast(int_col as INT64) + // The Plus operator cannot be swapped, so this should not be transformed + let column_expr = col("int_col", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let literal_expr = lit(10i64); + let binary_expr = + Arc::new(BinaryExpr::new(literal_expr, Operator::Plus, cast_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should NOT be transformed because Plus cannot be swapped + assert!(!result.transformed); + } + + #[test] + fn test_cast_that_cannot_be_unwrapped_overflow() { + // Test case where the literal value would overflow the target type + let schema = Schema::new(vec![Field::new("small_int", DataType::Int8, false)]); + + // Create: cast(small_int as INT64) > INT64(1000) + // This should NOT be unwrapped because 1000 cannot fit in Int8 (max value is 127) + let column_expr = col("small_int", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let literal_expr = lit(1000i64); // Value too large for Int8 + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should NOT be transformed due to overflow + assert!(!result.transformed); + } + + #[test] + fn test_complex_nested_expression() { + let schema = test_schema(); + + // Create a more complex expression with nested casts + // (cast(c1 as INT64) > INT64(10)) AND (cast(c2 as INT32) = INT32(20)) + let c1_expr = col("c1", &schema).unwrap(); + let c1_cast = Arc::new(CastExpr::new(c1_expr, DataType::Int64, None)); + let c1_literal = lit(10i64); + let c1_binary = Arc::new(BinaryExpr::new(c1_cast, Operator::Gt, c1_literal)); + + let c2_expr = col("c2", &schema).unwrap(); + let c2_cast = Arc::new(CastExpr::new(c2_expr, DataType::Int32, None)); + let c2_literal = lit(20i32); + let c2_binary = Arc::new(BinaryExpr::new(c2_cast, Operator::Eq, 
c2_literal)); + + // Create AND expression + let and_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::And, c2_binary)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + // Verify both sides of the AND were optimized + let optimized = result.data; + let and_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Left side should be: c1 > INT32(10) + let left_binary = and_binary + .left() + .as_any() + .downcast_ref::() + .unwrap(); + assert!(!is_cast_expr(left_binary.left())); + let left_literal = left_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(10))); + + // Right side should be: c2 = INT64(20) (c2 is already INT64, literal cast to match) + let right_binary = and_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert!(!is_cast_expr(right_binary.left())); + let right_literal = right_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(20))); + } +} diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index 1551a8f79a7a8..e863f57f8d165 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -30,6 +30,7 @@ use arrow::{ }; // pub use for backwards compatibility pub use datafusion_common::pruning::PruningStatistics; +use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; use datafusion_physical_plan::metrics::Count; use log::{debug, trace}; @@ -468,6 +469,10 @@ impl PruningPredicate { &mut required_columns, &unhandled_hook, ); + let predicate_schema = required_columns.schema(); + // Simplify the newly created predicate to get rid of redundant casts, comparisons, etc. + let predicate_expr = + PhysicalExprSimplifier::new(&predicate_schema).simplify(predicate_expr)?; let literal_guarantees = LiteralGuarantee::analyze(&expr); @@ -735,6 +740,21 @@ impl RequiredColumns { } } + /// Returns a schema that describes the columns required to evaluate this + /// pruning predicate. + /// The schema contains the fields for each column in `self.columns` with + /// the appropriate data type for the statistics. + /// Order matters, this same order is used to evaluate the + /// pruning predicate. + fn schema(&self) -> Schema { + let fields = self + .columns + .iter() + .map(|(_c, _t, f)| f.clone()) + .collect::>(); + Schema::new(fields) + } + /// Returns an iterator over items in columns (see doc on /// `self.columns` for details) pub(crate) fn iter( @@ -883,7 +903,6 @@ fn build_statistics_record_batch( statistics: &S, required_columns: &RequiredColumns, ) -> Result { - let mut fields = Vec::::new(); let mut arrays = Vec::::new(); // For each needed statistics column: for (column, statistics_type, stat_field) in required_columns.iter() { @@ -912,11 +931,10 @@ fn build_statistics_record_batch( // provides timestamp statistics as "Int64") let array = arrow::compute::cast(&array, data_type)?; - fields.push(stat_field.clone()); arrays.push(array); } - let schema = Arc::new(Schema::new(fields)); + let schema = Arc::new(required_columns.schema()); // provide the count in case there were no needed statistics let mut options = RecordBatchOptions::default(); options.row_count = Some(statistics.num_containers());
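As a closing note, here is a minimal usage sketch (not part of this patch) of the new PhysicalExprSimplifier, based on the paths imported in the new modules above; the exact public re-exports are assumed. It mirrors the unwrap-cast tests: a cast wrapped around a column is removed by casting the comparison literal instead.

    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{col, lit, BinaryExpr, CastExpr};
    use datafusion_physical_expr::simplifier::PhysicalExprSimplifier;

    fn main() -> Result<()> {
        // A file schema in which `c1` is Int32.
        let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);

        // Predicate as it might look after schema adaptation:
        //     cast(c1 AS Int64) > Int64(10)
        let expr = Arc::new(BinaryExpr::new(
            Arc::new(CastExpr::new(col("c1", &schema)?, DataType::Int64, None)),
            Operator::Gt,
            lit(10i64),
        ));

        // The simplifier unwraps the cast by casting the literal instead,
        // producing `c1 > Int32(10)`, so the column data is not cast at runtime.
        let simplified = PhysicalExprSimplifier::new(&schema).simplify(expr)?;
        println!("{simplified}");
        Ok(())
    }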
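And a correspondingly small sketch of the literal-casting helper that this patch moves into datafusion-expr-common; the module path is assumed from the import used in unwrap_cast.rs above, and the in-range and overflow behaviour shown in the assertions mirrors the tests in this diff.

    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;
    use datafusion_expr_common::casts::try_cast_literal_to_type;

    fn main() {
        // An Int64 literal that fits in Int32 is rewritten losslessly.
        assert_eq!(
            try_cast_literal_to_type(&ScalarValue::Int64(Some(10)), &DataType::Int32),
            Some(ScalarValue::Int32(Some(10)))
        );

        // A value that would overflow the narrower type yields None, so the
        // surrounding cast is left in place rather than unwrapped.
        assert_eq!(
            try_cast_literal_to_type(&ScalarValue::Int64(Some(1000)), &DataType::Int8),
            None
        );
    }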