diff --git a/native/Cargo.lock b/native/Cargo.lock index 7966bb80bb..538c40ee23 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -436,7 +436,18 @@ checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", - "brotli-decompressor", + "brotli-decompressor 2.5.1", +] + +[[package]] +name = "brotli" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor 4.0.1", ] [[package]] @@ -449,6 +460,16 @@ dependencies = [ "alloc-stdlib", ] +[[package]] +name = "brotli-decompressor" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -842,6 +863,7 @@ dependencies = [ "num_cpus", "object_store", "parking_lot", + "parquet", "paste", "pin-project-lite", "rand", @@ -878,7 +900,7 @@ dependencies = [ "arrow-schema", "assertables", "async-trait", - "brotli", + "brotli 3.5.0", "bytes", "crc32fast", "criterion", @@ -914,7 +936,7 @@ dependencies = [ "tempfile", "thiserror", "tokio", - "zstd", + "zstd 0.11.2+zstd.1.5.2", ] [[package]] @@ -943,6 +965,7 @@ dependencies = [ "datafusion-physical-expr", "futures", "num", + "parquet", "rand", "regex", "thiserror", @@ -969,6 +992,7 @@ dependencies = [ "libc", "num_cpus", "object_store", + "parquet", "paste", "sqlparser", "tokio", @@ -2350,16 +2374,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dea02606ba6f5e856561d8d507dba8bac060aefca2a6c0f1aa1d361fed91ff3e" dependencies = [ "ahash", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-schema", + "arrow-select", + "base64", + "brotli 7.0.0", "bytes", "chrono", + "flate2", + "futures", "half", "hashbrown 0.14.5", + "lz4_flex", "num", "num-bigint", + "object_store", "paste", "seq-macro", + "snap", "thrift", + "tokio", "twox-hash 1.6.3", + "zstd 0.13.2", + "zstd-sys", ] [[package]] @@ -3652,7 +3693,16 @@ version = "0.11.2+zstd.1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" dependencies = [ - "zstd-safe", + "zstd-safe 5.0.2+zstd.1.5.2", +] + +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe 7.2.1", ] [[package]] @@ -3665,6 +3715,15 @@ dependencies = [ "zstd-sys", ] +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + [[package]] name = "zstd-sys" version = "2.0.13+zstd.1.5.6" diff --git a/native/Cargo.toml b/native/Cargo.toml index 4ac85479f2..bd46cf0c9f 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -39,8 +39,8 @@ arrow-buffer = { version = "53.2.0" } arrow-data = { version = "53.2.0" } arrow-schema = { version = "53.2.0" } parquet = { version = "53.2.0", default-features = false, features = ["experimental"] 
}
-datafusion-common = { version = "43.0.0" }
 datafusion = { version = "43.0.0", default-features = false, features = ["unicode_expressions", "crypto_expressions"] }
+datafusion-common = { version = "43.0.0" }
 datafusion-functions = { version = "43.0.0", features = ["crypto_expressions"] }
 datafusion-functions-nested = { version = "43.0.0", default-features = false }
 datafusion-expr = { version = "43.0.0", default-features = false }
diff --git a/native/core/src/parquet/util/test_common/mod.rs b/native/core/src/parquet/util/test_common/mod.rs
index e46d732239..d92544608e 100644
--- a/native/core/src/parquet/util/test_common/mod.rs
+++ b/native/core/src/parquet/util/test_common/mod.rs
@@ -15,10 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-pub mod file_util;
 pub mod page_util;
 pub mod rand_gen;
 
 pub use self::rand_gen::{random_bools, random_bytes, random_numbers, random_numbers_range};
 
-pub use self::file_util::{get_temp_file, get_temp_filename};
+pub use datafusion_comet_spark_expr::test_common::file_util::{get_temp_file, get_temp_filename};
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index d0bc2fd9dd..27367d83e1 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -33,7 +33,7 @@ arrow-buffer = { workspace = true }
 arrow-data = { workspace = true }
 arrow-schema = { workspace = true }
 chrono = { workspace = true }
-datafusion = { workspace = true }
+datafusion = { workspace = true, features = ["parquet"] }
 datafusion-common = { workspace = true }
 datafusion-expr = { workspace = true }
 datafusion-physical-expr = { workspace = true }
@@ -43,9 +43,11 @@ regex = { workspace = true }
 thiserror = { workspace = true }
 futures = { workspace = true }
 twox-hash = "2.0.0"
+rand = { workspace = true }
 
 [dev-dependencies]
 arrow-data = {workspace = true}
+parquet = { workspace = true, features = ["arrow"] }
 criterion = "0.5.1"
 rand = { workspace = true}
 tokio = { version = "1", features = ["rt-multi-thread"] }
diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index f62d0220c9..d96bcbbdb6 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -15,6 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
+use crate::timezone; +use crate::utils::array_with_timezone; +use crate::{EvalMode, SparkError, SparkResult}; use arrow::{ array::{ cast::AsArray, @@ -35,11 +38,18 @@ use arrow::{ use arrow_array::builder::StringBuilder; use arrow_array::{DictionaryArray, StringArray, StructArray}; use arrow_schema::{DataType, Field, Schema}; +use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{ cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue, }; use datafusion_expr::ColumnarValue; use datafusion_physical_expr::PhysicalExpr; +use num::{ + cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num, + ToPrimitive, +}; +use regex::Regex; use std::str::FromStr; use std::{ any::Any, @@ -49,19 +59,6 @@ use std::{ sync::Arc, }; -use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike}; -use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; -use num::{ - cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num, - ToPrimitive, -}; -use regex::Regex; - -use crate::timezone; -use crate::utils::array_with_timezone; - -use crate::{EvalMode, SparkError, SparkResult}; - static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); const MICROS_PER_SECOND: i64 = 1000000; @@ -141,6 +138,240 @@ pub struct Cast { pub cast_options: SparkCastOptions, } +/// Determine if Comet supports a cast, taking options such as EvalMode and Timezone into account. +pub fn cast_supported( + from_type: &DataType, + to_type: &DataType, + options: &SparkCastOptions, +) -> bool { + use DataType::*; + + let from_type = if let Dictionary(_, dt) = from_type { + dt + } else { + from_type + }; + + let to_type = if let Dictionary(_, dt) = to_type { + dt + } else { + to_type + }; + + if from_type == to_type { + return true; + } + + match (from_type, to_type) { + (Boolean, _) => can_cast_from_boolean(to_type, options), + (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64) + if options.allow_cast_unsigned_ints => + { + true + } + (Int8, _) => can_cast_from_byte(to_type, options), + (Int16, _) => can_cast_from_short(to_type, options), + (Int32, _) => can_cast_from_int(to_type, options), + (Int64, _) => can_cast_from_long(to_type, options), + (Float32, _) => can_cast_from_float(to_type, options), + (Float64, _) => can_cast_from_double(to_type, options), + (Decimal128(p, s), _) => can_cast_from_decimal(p, s, to_type, options), + (Timestamp(_, None), _) => can_cast_from_timestamp_ntz(to_type, options), + (Timestamp(_, Some(_)), _) => can_cast_from_timestamp(to_type, options), + (Utf8 | LargeUtf8, _) => can_cast_from_string(to_type, options), + (_, Utf8 | LargeUtf8) => can_cast_to_string(from_type, options), + (Struct(from_fields), Struct(to_fields)) => from_fields + .iter() + .zip(to_fields.iter()) + .all(|(a, b)| cast_supported(a.data_type(), b.data_type(), options)), + _ => false, + } +} + +fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool { + use DataType::*; + match to_type { + Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true, + Float32 | Float64 => { + // https://github.com/apache/datafusion-comet/issues/326 + // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. + // Does not support ANSI mode. 
+ options.allow_incompat + } + Decimal128(_, _) => { + // https://github.com/apache/datafusion-comet/issues/325 + // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. + // Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits + + options.allow_incompat + } + Date32 | Date64 => { + // https://github.com/apache/datafusion-comet/issues/327 + // Only supports years between 262143 BC and 262142 AD + options.allow_incompat + } + Timestamp(_, _) if options.eval_mode == EvalMode::Ansi => { + // ANSI mode not supported + false + } + Timestamp(_, Some(tz)) if tz.as_ref() != "UTC" => { + // Cast will use UTC instead of $timeZoneId + options.allow_incompat + } + Timestamp(_, _) => { + // https://github.com/apache/datafusion-comet/issues/328 + // Not all valid formats are supported + options.allow_incompat + } + _ => false, + } +} + +fn can_cast_to_string(from_type: &DataType, options: &SparkCastOptions) -> bool { + use DataType::*; + match from_type { + Boolean | Int8 | Int16 | Int32 | Int64 | Date32 | Date64 | Timestamp(_, _) => true, + Float32 | Float64 => { + // There can be differences in precision. + // For example, the input \"1.4E-45\" will produce 1.0E-45 " + + // instead of 1.4E-45")) + true + } + Decimal128(_, _) => { + // https://github.com/apache/datafusion-comet/issues/1068 + // There can be formatting differences in some case due to Spark using + // scientific notation where Comet does not + true + } + Binary => { + // https://github.com/apache/datafusion-comet/issues/377 + // Only works for binary data representing valid UTF-8 strings + options.allow_incompat + } + Struct(fields) => fields + .iter() + .all(|f| can_cast_to_string(f.data_type(), options)), + _ => false, + } +} + +fn can_cast_from_timestamp_ntz(to_type: &DataType, options: &SparkCastOptions) -> bool { + use DataType::*; + match to_type { + Timestamp(_, _) | Date32 | Date64 | Utf8 => { + // incompatible + options.allow_incompat + } + _ => { + // unsupported + false + } + } +} + +fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> bool { + use DataType::*; + match to_type { + Boolean | Int8 | Int16 => { + // https://github.com/apache/datafusion-comet/issues/352 + // this seems like an edge case that isn't important for us to support + false + } + Int64 => { + // https://github.com/apache/datafusion-comet/issues/352 + true + } + Date32 | Date64 | Utf8 | Decimal128(_, _) => true, + _ => { + // unsupported + false + } + } +} + +fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool { + use DataType::*; + matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64) +} + +fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool { + use DataType::*; + matches!( + to_type, + Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) + ) +} + +fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool { + use DataType::*; + matches!( + to_type, + Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) + ) +} + +fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool { + use DataType::*; + match to_type { + Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true, + Decimal128(_, _) => { + // incompatible: no overflow check + options.allow_incompat + } + _ => false, + } +} + +fn can_cast_from_long(to_type: 
&DataType, options: &SparkCastOptions) -> bool { + use DataType::*; + match to_type { + Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true, + Decimal128(_, _) => { + // incompatible: no overflow check + options.allow_incompat + } + _ => false, + } +} + +fn can_cast_from_float(to_type: &DataType, _: &SparkCastOptions) -> bool { + use DataType::*; + matches!( + to_type, + Boolean | Int8 | Int16 | Int32 | Int64 | Float64 | Decimal128(_, _) + ) +} + +fn can_cast_from_double(to_type: &DataType, _: &SparkCastOptions) -> bool { + use DataType::*; + matches!( + to_type, + Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Decimal128(_, _) + ) +} + +fn can_cast_from_decimal( + p1: &u8, + _s1: &i8, + to_type: &DataType, + options: &SparkCastOptions, +) -> bool { + use DataType::*; + match to_type { + Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true, + Decimal128(p2, _) => { + if p2 < p1 { + // https://github.com/apache/datafusion/issues/13492 + // Incompatible(Some("Casting to smaller precision is not supported")) + options.allow_incompat + } else { + true + } + } + _ => false, + } +} + macro_rules! cast_utf8_to_int { ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ let len = $array.len(); @@ -560,6 +791,8 @@ pub struct SparkCastOptions { pub timezone: String, /// Allow casts that are supported but not guaranteed to be 100% compatible pub allow_incompat: bool, + /// Support casting unsigned ints to signed ints (used by Parquet SchemaAdapter) + pub allow_cast_unsigned_ints: bool, } impl SparkCastOptions { @@ -568,6 +801,7 @@ impl SparkCastOptions { eval_mode, timezone: timezone.to_string(), allow_incompat, + allow_cast_unsigned_ints: false, } } @@ -576,6 +810,7 @@ impl SparkCastOptions { eval_mode, timezone: "".to_string(), allow_incompat, + allow_cast_unsigned_ints: false, } } } @@ -611,14 +846,14 @@ fn cast_array( to_type: &DataType, cast_options: &SparkCastOptions, ) -> DataFusionResult { + use DataType::*; let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?; let from_type = array.data_type().clone(); let array = match &from_type { - DataType::Dictionary(key_type, value_type) - if key_type.as_ref() == &DataType::Int32 - && (value_type.as_ref() == &DataType::Utf8 - || value_type.as_ref() == &DataType::LargeUtf8) => + Dictionary(key_type, value_type) + if key_type.as_ref() == &Int32 + && (value_type.as_ref() == &Utf8 || value_type.as_ref() == &LargeUtf8) => { let dict_array = array .as_any() @@ -631,7 +866,7 @@ fn cast_array( ); let casted_result = match to_type { - DataType::Dictionary(_, _) => Arc::new(casted_dictionary.clone()), + Dictionary(_, _) => Arc::new(casted_dictionary.clone()), _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?, }; return Ok(spark_cast_postprocess(casted_result, &from_type, to_type)); @@ -642,70 +877,66 @@ fn cast_array( let eval_mode = cast_options.eval_mode; let cast_result = match (from_type, to_type) { - (DataType::Utf8, DataType::Boolean) => spark_cast_utf8_to_boolean::(&array, eval_mode), - (DataType::LargeUtf8, DataType::Boolean) => { - spark_cast_utf8_to_boolean::(&array, eval_mode) - } - (DataType::Utf8, DataType::Timestamp(_, _)) => { + (Utf8, Boolean) => spark_cast_utf8_to_boolean::(&array, eval_mode), + (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::(&array, eval_mode), + (Utf8, Timestamp(_, _)) => { cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone) } - (DataType::Utf8, DataType::Date32) => 
cast_string_to_date(&array, to_type, eval_mode), - (DataType::Int64, DataType::Int32) - | (DataType::Int64, DataType::Int16) - | (DataType::Int64, DataType::Int8) - | (DataType::Int32, DataType::Int16) - | (DataType::Int32, DataType::Int8) - | (DataType::Int16, DataType::Int8) + (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode), + (Int64, Int32) + | (Int64, Int16) + | (Int64, Int8) + | (Int32, Int16) + | (Int32, Int8) + | (Int16, Int8) if eval_mode != EvalMode::Try => { spark_cast_int_to_int(&array, eval_mode, from_type, to_type) } - (DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64) => { + (Utf8, Int8 | Int16 | Int32 | Int64) => { cast_string_to_int::(to_type, &array, eval_mode) } - ( - DataType::LargeUtf8, - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => cast_string_to_int::(to_type, &array, eval_mode), - (DataType::Float64, DataType::Utf8) => spark_cast_float64_to_utf8::(&array, eval_mode), - (DataType::Float64, DataType::LargeUtf8) => { - spark_cast_float64_to_utf8::(&array, eval_mode) - } - (DataType::Float32, DataType::Utf8) => spark_cast_float32_to_utf8::(&array, eval_mode), - (DataType::Float32, DataType::LargeUtf8) => { - spark_cast_float32_to_utf8::(&array, eval_mode) - } - (DataType::Float32, DataType::Decimal128(precision, scale)) => { + (LargeUtf8, Int8 | Int16 | Int32 | Int64) => { + cast_string_to_int::(to_type, &array, eval_mode) + } + (Float64, Utf8) => spark_cast_float64_to_utf8::(&array, eval_mode), + (Float64, LargeUtf8) => spark_cast_float64_to_utf8::(&array, eval_mode), + (Float32, Utf8) => spark_cast_float32_to_utf8::(&array, eval_mode), + (Float32, LargeUtf8) => spark_cast_float32_to_utf8::(&array, eval_mode), + (Float32, Decimal128(precision, scale)) => { cast_float32_to_decimal128(&array, *precision, *scale, eval_mode) } - (DataType::Float64, DataType::Decimal128(precision, scale)) => { + (Float64, Decimal128(precision, scale)) => { cast_float64_to_decimal128(&array, *precision, *scale, eval_mode) } - (DataType::Float32, DataType::Int8) - | (DataType::Float32, DataType::Int16) - | (DataType::Float32, DataType::Int32) - | (DataType::Float32, DataType::Int64) - | (DataType::Float64, DataType::Int8) - | (DataType::Float64, DataType::Int16) - | (DataType::Float64, DataType::Int32) - | (DataType::Float64, DataType::Int64) - | (DataType::Decimal128(_, _), DataType::Int8) - | (DataType::Decimal128(_, _), DataType::Int16) - | (DataType::Decimal128(_, _), DataType::Int32) - | (DataType::Decimal128(_, _), DataType::Int64) + (Float32, Int8) + | (Float32, Int16) + | (Float32, Int32) + | (Float32, Int64) + | (Float64, Int8) + | (Float64, Int16) + | (Float64, Int32) + | (Float64, Int64) + | (Decimal128(_, _), Int8) + | (Decimal128(_, _), Int16) + | (Decimal128(_, _), Int32) + | (Decimal128(_, _), Int64) if eval_mode != EvalMode::Try => { spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type) } - (DataType::Struct(_), DataType::Utf8) => { - Ok(casts_struct_to_string(array.as_struct(), cast_options)?) - } - (DataType::Struct(_), DataType::Struct(_)) => Ok(cast_struct_to_struct( + (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?), + (Struct(_), Struct(_)) => Ok(cast_struct_to_struct( array.as_struct(), from_type, to_type, cast_options, )?), + (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64) + if cast_options.allow_cast_unsigned_ints => + { + Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) 
+        }
         _ if is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) => {
             // use DataFusion cast only when we know that it is compatible with Spark
             Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 8a57480587..f358731004 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -41,6 +41,9 @@ mod kernels;
 mod list;
 mod regexp;
 pub mod scalar_funcs;
+mod schema_adapter;
+pub use schema_adapter::SparkSchemaAdapterFactory;
+
 pub mod spark_hash;
 mod stddev;
 pub use stddev::Stddev;
@@ -51,6 +54,8 @@ mod negative;
 pub use negative::{create_negate_expr, NegativeExpr};
 mod normalize_nan;
 mod temporal;
+
+pub mod test_common;
 pub mod timezone;
 mod to_json;
 mod unbound;
diff --git a/native/spark-expr/src/schema_adapter.rs b/native/spark-expr/src/schema_adapter.rs
new file mode 100644
index 0000000000..161ad6f164
--- /dev/null
+++ b/native/spark-expr/src/schema_adapter.rs
@@ -0,0 +1,376 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Custom schema adapter that uses Spark-compatible casts
+
+use crate::cast::cast_supported;
+use crate::{spark_cast, SparkCastOptions};
+use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchOptions};
+use arrow_schema::{Schema, SchemaRef};
+use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper};
+use datafusion_common::plan_err;
+use datafusion_expr::ColumnarValue;
+use std::sync::Arc;
+
+/// An implementation of DataFusion's `SchemaAdapterFactory` that uses a Spark-compatible
+/// `cast` implementation.
+#[derive(Clone, Debug)]
+pub struct SparkSchemaAdapterFactory {
+    /// Spark cast options
+    cast_options: SparkCastOptions,
+}
+
+impl SparkSchemaAdapterFactory {
+    pub fn new(options: SparkCastOptions) -> Self {
+        Self {
+            cast_options: options,
+        }
+    }
+}
+
+impl SchemaAdapterFactory for SparkSchemaAdapterFactory {
+    /// Create a new factory for mapping batches from a file schema to a table
+    /// schema.
+    ///
+    /// This is a convenience for [`DefaultSchemaAdapterFactory::create`] with
+    /// the same schema for both the projected table schema and the table
+    /// schema.
+    fn create(
+        &self,
+        required_schema: SchemaRef,
+        table_schema: SchemaRef,
+    ) -> Box<dyn SchemaAdapter> {
+        Box::new(SparkSchemaAdapter {
+            required_schema,
+            table_schema,
+            cast_options: self.cast_options.clone(),
+        })
+    }
+}
+
+/// This SchemaAdapter requires both the table schema and the projected table
+/// schema.
See [`SchemaMapping`] for more details +#[derive(Clone, Debug)] +pub struct SparkSchemaAdapter { + /// The schema for the table, projected to include only the fields being output (projected) by the + /// associated ParquetExec + required_schema: SchemaRef, + /// The entire table schema for the table we're using this to adapt. + /// + /// This is used to evaluate any filters pushed down into the scan + /// which may refer to columns that are not referred to anywhere + /// else in the plan. + table_schema: SchemaRef, + /// Spark cast options + cast_options: SparkCastOptions, +} + +impl SchemaAdapter for SparkSchemaAdapter { + /// Map a column index in the table schema to a column index in a particular + /// file schema + /// + /// Panics if index is not in range for the table schema + fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { + let field = self.required_schema.field(index); + Some(file_schema.fields.find(field.name())?.0) + } + + /// Creates a `SchemaMapping` for casting or mapping the columns from the + /// file schema to the table schema. + /// + /// If the provided `file_schema` contains columns of a different type to + /// the expected `table_schema`, the method will attempt to cast the array + /// data from the file schema to the table schema where possible. + /// + /// Returns a [`SchemaMapping`] that can be applied to the output batch + /// along with an ordered list of columns to project from the file + fn map_schema( + &self, + file_schema: &Schema, + ) -> datafusion_common::Result<(Arc, Vec)> { + let mut projection = Vec::with_capacity(file_schema.fields().len()); + let mut field_mappings = vec![None; self.required_schema.fields().len()]; + + for (file_idx, file_field) in file_schema.fields.iter().enumerate() { + if let Some((table_idx, table_field)) = + self.required_schema.fields().find(file_field.name()) + { + if cast_supported( + file_field.data_type(), + table_field.data_type(), + &self.cast_options, + ) { + field_mappings[table_idx] = Some(projection.len()); + projection.push(file_idx); + } else { + return plan_err!( + "Cannot cast file schema field {} of type {:?} to required schema field of type {:?}", + file_field.name(), + file_field.data_type(), + table_field.data_type() + ); + } + } + } + + Ok(( + Arc::new(SchemaMapping { + required_schema: Arc::::clone(&self.required_schema), + field_mappings, + table_schema: Arc::::clone(&self.table_schema), + cast_options: self.cast_options.clone(), + }), + projection, + )) + } +} + +// TODO SchemaMapping is mostly copied from DataFusion but calls spark_cast +// instead of arrow cast - can we reduce the amount of code copied here and make +// the DataFusion version more extensible? + +/// The SchemaMapping struct holds a mapping from the file schema to the table +/// schema and any necessary type conversions. +/// +/// Note, because `map_batch` and `map_partial_batch` functions have different +/// needs, this struct holds two schemas: +/// +/// 1. The projected **table** schema +/// 2. The full table schema +/// +/// [`map_batch`] is used by the ParquetOpener to produce a RecordBatch which +/// has the projected schema, since that's the schema which is supposed to come +/// out of the execution of this query. Thus `map_batch` uses +/// `projected_table_schema` as it can only operate on the projected fields. 
+/// +/// [`map_partial_batch`] is used to create a RecordBatch with a schema that +/// can be used for Parquet predicate pushdown, meaning that it may contain +/// fields which are not in the projected schema (as the fields that parquet +/// pushdown filters operate can be completely distinct from the fields that are +/// projected (output) out of the ParquetExec). `map_partial_batch` thus uses +/// `table_schema` to create the resulting RecordBatch (as it could be operating +/// on any fields in the schema). +/// +/// [`map_batch`]: Self::map_batch +/// [`map_partial_batch`]: Self::map_partial_batch +#[derive(Debug)] +pub struct SchemaMapping { + /// The schema of the table. This is the expected schema after conversion + /// and it should match the schema of the query result. + required_schema: SchemaRef, + /// Mapping from field index in `projected_table_schema` to index in + /// projected file_schema. + /// + /// They are Options instead of just plain `usize`s because the table could + /// have fields that don't exist in the file. + field_mappings: Vec>, + /// The entire table schema, as opposed to the projected_table_schema (which + /// only contains the columns that we are projecting out of this query). + /// This contains all fields in the table, regardless of if they will be + /// projected out or not. + table_schema: SchemaRef, + /// Spark cast options + cast_options: SparkCastOptions, +} + +impl SchemaMapper for SchemaMapping { + /// Adapts a `RecordBatch` to match the `projected_table_schema` using the stored mapping and + /// conversions. The produced RecordBatch has a schema that contains only the projected + /// columns, so if one needs a RecordBatch with a schema that references columns which are not + /// in the projected, it would be better to use `map_partial_batch` + fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result { + let batch_rows = batch.num_rows(); + let batch_cols = batch.columns().to_vec(); + + let cols = self + .required_schema + // go through each field in the projected schema + .fields() + .iter() + // and zip it with the index that maps fields from the projected table schema to the + // projected file schema in `batch` + .zip(&self.field_mappings) + // and for each one... + .map(|(field, file_idx)| { + file_idx.map_or_else( + // If this field only exists in the table, and not in the file, then we know + // that it's null, so just return that. + || Ok(new_null_array(field.data_type(), batch_rows)), + // However, if it does exist in both, then try to cast it to the correct output + // type + |batch_idx| { + spark_cast( + ColumnarValue::Array(Arc::clone(&batch_cols[batch_idx])), + field.data_type(), + &self.cast_options, + )? + .into_array(batch_rows) + }, + ) + }) + .collect::, _>>()?; + + // Necessary to handle empty batches + let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + + let schema = Arc::::clone(&self.required_schema); + let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; + Ok(record_batch) + } + + /// Adapts a [`RecordBatch`]'s schema into one that has all the correct output types and only + /// contains the fields that exist in both the file schema and table schema. 
+ /// + /// Unlike `map_batch` this method also preserves the columns that + /// may not appear in the final output (`projected_table_schema`) but may + /// appear in push down predicates + fn map_partial_batch(&self, batch: RecordBatch) -> datafusion_common::Result { + let batch_cols = batch.columns().to_vec(); + let schema = batch.schema(); + + // for each field in the batch's schema (which is based on a file, not a table)... + let (cols, fields) = schema + .fields() + .iter() + .zip(batch_cols.iter()) + .flat_map(|(field, batch_col)| { + self.table_schema + // try to get the same field from the table schema that we have stored in self + .field_with_name(field.name()) + // and if we don't have it, that's fine, ignore it. This may occur when we've + // created an external table whose fields are a subset of the fields in this + // file, then tried to read data from the file into this table. If that is the + // case here, it's fine to ignore because we don't care about this field + // anyways + .ok() + // but if we do have it, + .map(|table_field| { + // try to cast it into the correct output type. we don't want to ignore this + // error, though, so it's propagated. + spark_cast( + ColumnarValue::Array(Arc::clone(batch_col)), + table_field.data_type(), + &self.cast_options, + )? + .into_array(batch_col.len()) + // and if that works, return the field and column. + .map(|new_col| (new_col, table_field.clone())) + }) + }) + .collect::, _>>()? + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + // Necessary to handle empty batches + let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + + let schema = Arc::new(Schema::new_with_metadata(fields, schema.metadata().clone())); + let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; + Ok(record_batch) + } +} + +#[cfg(test)] +mod test { + use crate::test_common::file_util::get_temp_filename; + use crate::{EvalMode, SparkCastOptions, SparkSchemaAdapterFactory}; + use arrow::array::{Int32Array, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use arrow_array::UInt32Array; + use arrow_schema::SchemaRef; + use datafusion::datasource::listing::PartitionedFile; + use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; + use datafusion::execution::object_store::ObjectStoreUrl; + use datafusion::execution::TaskContext; + use datafusion::physical_plan::ExecutionPlan; + use datafusion_common::DataFusionError; + use futures::StreamExt; + use parquet::arrow::ArrowWriter; + use std::fs::File; + use std::sync::Arc; + + #[tokio::test] + async fn parquet_roundtrip_int_as_string() -> Result<(), DataFusionError> { + let file_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + + let ids = Arc::new(Int32Array::from(vec![1, 2, 3])) as Arc; + let names = Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])) + as Arc; + let batch = RecordBatch::try_new(Arc::clone(&file_schema), vec![ids, names])?; + + let required_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("name", DataType::Utf8, false), + ])); + + let _ = roundtrip(&batch, required_schema).await?; + + Ok(()) + } + + #[tokio::test] + async fn parquet_roundtrip_unsigned_int() -> Result<(), DataFusionError> { + let file_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt32, false)])); + + let ids = Arc::new(UInt32Array::from(vec![1, 2, 3])) as 
Arc<dyn Array>;
+        let batch = RecordBatch::try_new(Arc::clone(&file_schema), vec![ids])?;
+
+        let required_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+        let _ = roundtrip(&batch, required_schema).await?;
+
+        Ok(())
+    }
+
+    /// Create a Parquet file containing a single batch and then read the batch back using
+    /// the specified required_schema. This will cause the SchemaAdapter code to be used.
+    async fn roundtrip(
+        batch: &RecordBatch,
+        required_schema: SchemaRef,
+    ) -> Result<RecordBatch, DataFusionError> {
+        let filename = get_temp_filename();
+        let filename = filename.as_path().as_os_str().to_str().unwrap().to_string();
+        let file = File::create(&filename)?;
+        let mut writer = ArrowWriter::try_new(file, Arc::clone(&batch.schema()), None)?;
+        writer.write(batch)?;
+        writer.close()?;
+
+        let object_store_url = ObjectStoreUrl::local_filesystem();
+        let file_scan_config = FileScanConfig::new(object_store_url, required_schema)
+            .with_file_groups(vec![vec![PartitionedFile::from_path(
+                filename.to_string(),
+            )?]]);
+
+        let mut spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+        spark_cast_options.allow_cast_unsigned_ints = true;
+
+        let parquet_exec = ParquetExec::builder(file_scan_config)
+            .with_schema_adapter_factory(Arc::new(SparkSchemaAdapterFactory::new(
+                spark_cast_options,
+            )))
+            .build();
+
+        let mut stream = parquet_exec
+            .execute(0, Arc::new(TaskContext::default()))
+            .unwrap();
+        stream.next().await.unwrap()
+    }
+}
diff --git a/native/core/src/parquet/util/test_common/file_util.rs b/native/spark-expr/src/test_common/file_util.rs
similarity index 100%
rename from native/core/src/parquet/util/test_common/file_util.rs
rename to native/spark-expr/src/test_common/file_util.rs
diff --git a/native/spark-expr/src/test_common/mod.rs b/native/spark-expr/src/test_common/mod.rs
new file mode 100644
index 0000000000..efd25a4a2a
--- /dev/null
+++ b/native/spark-expr/src/test_common/mod.rs
@@ -0,0 +1,17 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+pub mod file_util;
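
For reviewers, a minimal sketch (not part of the patch) of how the new `cast_supported` helper interacts with the `allow_cast_unsigned_ints` option used by the Parquet SchemaAdapter path. It assumes `cast_supported`, `EvalMode`, and `SparkCastOptions` are reachable from the `datafusion_comet_spark_expr` crate root; an extra re-export may be needed if `cast_supported` is only visible via the `cast` module.

use arrow_schema::DataType;
// Hypothetical imports; paths depend on what the crate re-exports at its root.
use datafusion_comet_spark_expr::{cast_supported, EvalMode, SparkCastOptions};

fn main() {
    // Plain Spark-compatible options: unsigned-to-signed integer casts are rejected,
    // so mapping a UInt32 Parquet column to an Int32 table column is not allowed.
    let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
    assert!(cast_supported(&DataType::Int32, &DataType::Utf8, &options));
    assert!(!cast_supported(&DataType::UInt32, &DataType::Int32, &options));

    // The Parquet scan path opts in via `allow_cast_unsigned_ints`, as the
    // `roundtrip` test above does, which makes the same mapping acceptable.
    let mut scan_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
    scan_options.allow_cast_unsigned_ints = true;
    assert!(cast_supported(&DataType::UInt32, &DataType::Int32, &scan_options));
}

Because `SparkSchemaAdapter::map_schema` calls `cast_supported` and returns `plan_err!` for unsupported combinations, incompatible file/table schema pairs are rejected at planning time rather than failing mid-scan.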