diff --git a/datafusion/substrait/src/extensions.rs b/datafusion/substrait/src/extensions.rs new file mode 100644 index 0000000000000..459d0e0c5ae58 --- /dev/null +++ b/datafusion/substrait/src/extensions.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::{plan_err, DataFusionError}; +use std::collections::HashMap; +use substrait::proto::extensions::simple_extension_declaration::{ + ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType, +}; +use substrait::proto::extensions::SimpleExtensionDeclaration; + +/// Substrait uses [SimpleExtensions](https://substrait.io/extensions/#simple-extensions) to define +/// behavior of plans in addition to what's supported directly by the protobuf definitions. +/// That includes functions, but also provides support for custom types and variations for existing +/// types. This structs facilitates the use of these extensions in DataFusion. +/// TODO: DF doesn't yet use extensions for type variations +/// TODO: DF doesn't yet provide valid extensionUris +#[derive(Default, Debug, PartialEq)] +pub struct Extensions { + pub functions: HashMap, // anchor -> function name + pub types: HashMap, // anchor -> type name + pub type_variations: HashMap, // anchor -> type variation name +} + +impl Extensions { + /// Registers a function and returns the anchor (reference) to it. If the function has already + /// been registered, it returns the existing anchor. + /// Function names are case-insensitive (converted to lowercase). + pub fn register_function(&mut self, function_name: String) -> u32 { + let function_name = function_name.to_lowercase(); + + // Some functions are named differently in Substrait default extensions than in DF + // Rename those to match the Substrait extensions for interoperability + let function_name = match function_name.as_str() { + "substr" => "substring".to_string(), + _ => function_name, + }; + + match self.functions.iter().find(|(_, f)| *f == &function_name) { + Some((function_anchor, _)) => *function_anchor, // Function has been registered + None => { + // Function has NOT been registered + let function_anchor = self.functions.len() as u32; + self.functions + .insert(function_anchor, function_name.clone()); + function_anchor + } + } + } + + /// Registers a type and returns the anchor (reference) to it. If the type has already + /// been registered, it returns the existing anchor. + pub fn register_type(&mut self, type_name: String) -> u32 { + let type_name = type_name.to_lowercase(); + match self.types.iter().find(|(_, t)| *t == &type_name) { + Some((type_anchor, _)) => *type_anchor, // Type has been registered + None => { + // Type has NOT been registered + let type_anchor = self.types.len() as u32; + self.types.insert(type_anchor, type_name.clone()); + type_anchor + } + } + } +} + +impl TryFrom<&Vec> for Extensions { + type Error = DataFusionError; + + fn try_from( + value: &Vec, + ) -> datafusion::common::Result { + let mut functions = HashMap::new(); + let mut types = HashMap::new(); + let mut type_variations = HashMap::new(); + + for ext in value { + match &ext.mapping_type { + Some(MappingType::ExtensionFunction(ext_f)) => { + functions.insert(ext_f.function_anchor, ext_f.name.to_owned()); + } + Some(MappingType::ExtensionType(ext_t)) => { + types.insert(ext_t.type_anchor, ext_t.name.to_owned()); + } + Some(MappingType::ExtensionTypeVariation(ext_v)) => { + type_variations + .insert(ext_v.type_variation_anchor, ext_v.name.to_owned()); + } + None => return plan_err!("Cannot parse empty extension"), + } + } + + Ok(Extensions { + functions, + types, + type_variations, + }) + } +} + +impl From for Vec { + fn from(val: Extensions) -> Vec { + let mut extensions = vec![]; + for (f_anchor, f_name) in val.functions { + let function_extension = ExtensionFunction { + extension_uri_reference: u32::MAX, + function_anchor: f_anchor, + name: f_name, + }; + let simple_extension = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionFunction(function_extension)), + }; + extensions.push(simple_extension); + } + + for (t_anchor, t_name) in val.types { + let type_extension = ExtensionType { + extension_uri_reference: u32::MAX, // https://github.com/apache/datafusion/issues/11545 + type_anchor: t_anchor, + name: t_name, + }; + let simple_extension = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionType(type_extension)), + }; + extensions.push(simple_extension); + } + + for (tv_anchor, tv_name) in val.type_variations { + let type_variation_extension = ExtensionTypeVariation { + extension_uri_reference: u32::MAX, // We don't register proper extension URIs yet + type_variation_anchor: tv_anchor, + name: tv_name, + }; + let simple_extension = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionTypeVariation( + type_variation_extension, + )), + }; + extensions.push(simple_extension); + } + + extensions + } +} diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index 454f0e7b7cb99..0b1c796553c0a 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -72,6 +72,7 @@ //! # Ok(()) //! # } //! ``` +pub mod extensions; pub mod logical_plan; pub mod physical_plan; pub mod serializer; diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 1365630d5079a..5768c44bbf6c8 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -36,16 +36,21 @@ use datafusion::logical_expr::{ use substrait::proto::expression::subquery::set_predicate::PredicateOp; use url::Url; +use crate::extensions::Extensions; use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, - INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, + INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, }; +#[allow(deprecated)] +use crate::variation_const::{ + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, + INTERVAL_YEAR_MONTH_TYPE_REF, +}; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ @@ -65,7 +70,9 @@ use std::str::FromStr; use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; use substrait::proto::expression::literal::user_defined::Val; -use substrait::proto::expression::literal::{IntervalDayToSecond, IntervalYearToMonth}; +use substrait::proto::expression::literal::{ + IntervalDayToSecond, IntervalYearToMonth, UserDefined, +}; use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction}; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; @@ -78,7 +85,6 @@ use substrait::proto::{ window_function::bound::Kind as BoundKind, window_function::Bound, window_function::BoundsType, MaskExpression, RexType, }, - extensions::simple_extension_declaration::MappingType, function_argument::ArgType, join_rel, plan_rel, r#type, read_rel::ReadType, @@ -185,19 +191,10 @@ pub async fn from_substrait_plan( plan: &Plan, ) -> Result { // Register function extension - let function_extension = plan - .extensions - .iter() - .map(|e| match &e.mapping_type { - Some(ext) => match ext { - MappingType::ExtensionFunction(ext_f) => { - Ok((ext_f.function_anchor, &ext_f.name)) - } - _ => not_impl_err!("Extension type not supported: {ext:?}"), - }, - None => not_impl_err!("Cannot parse empty extension"), - }) - .collect::>>()?; + let extensions = Extensions::try_from(&plan.extensions)?; + if !extensions.type_variations.is_empty() { + return not_impl_err!("Type variation extensions are not supported"); + } // Parse relations match plan.relations.len() { @@ -205,10 +202,10 @@ pub async fn from_substrait_plan( match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { plan_rel::RelType::Rel(rel) => { - Ok(from_substrait_rel(ctx, rel, &function_extension).await?) + Ok(from_substrait_rel(ctx, rel, &extensions).await?) }, plan_rel::RelType::Root(root) => { - let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?; + let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?; if root.names.is_empty() { // Backwards compatibility for plans missing names return Ok(plan); @@ -396,7 +393,7 @@ fn make_renamed_schema( pub async fn from_substrait_rel( ctx: &SessionContext, rel: &Rel, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { match &rel.rel_type { Some(RelType::Project(p)) => { @@ -660,7 +657,7 @@ pub async fn from_substrait_rel( substrait_datafusion_err!("No base schema provided for Virtual Table") })?; - let schema = from_substrait_named_struct(base_schema)?; + let schema = from_substrait_named_struct(base_schema, extensions)?; if vt.values.is_empty() { return Ok(LogicalPlan::EmptyRelation(EmptyRelation { @@ -681,6 +678,7 @@ pub async fn from_substrait_rel( name_idx += 1; // top-level names are provided through schema Ok(Expr::Literal(from_substrait_literal( lit, + extensions, &base_schema.names, &mut name_idx, )?)) @@ -892,7 +890,7 @@ pub async fn from_substrait_sorts( ctx: &SessionContext, substrait_sorts: &Vec, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result> { let mut sorts: Vec = vec![]; for s in substrait_sorts { @@ -942,7 +940,7 @@ pub async fn from_substrait_rex_vec( ctx: &SessionContext, exprs: &Vec, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { @@ -957,7 +955,7 @@ pub async fn from_substrait_func_args( ctx: &SessionContext, arguments: &Vec, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result> { let mut args: Vec = vec![]; for arg in arguments { @@ -977,7 +975,7 @@ pub async fn from_substrait_agg_func( ctx: &SessionContext, f: &AggregateFunction, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, filter: Option>, order_by: Option>, distinct: bool, @@ -985,14 +983,14 @@ pub async fn from_substrait_agg_func( let args = from_substrait_func_args(ctx, &f.arguments, input_schema, extensions).await?; - let Some(function_name) = extensions.get(&f.function_reference) else { + let Some(function_name) = extensions.functions.get(&f.function_reference) else { return plan_err!( "Aggregate function not registered: function anchor = {:?}", f.function_reference ); }; - let function_name = substrait_fun_name((**function_name).as_str()); + let function_name = substrait_fun_name(function_name); // try udaf first, then built-in aggr fn. if let Ok(fun) = ctx.udaf(function_name) { // deal with situation that count(*) got no arguments @@ -1025,7 +1023,7 @@ pub async fn from_substrait_rex( ctx: &SessionContext, e: &Expression, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { match &e.rex_type { Some(RexType::SingularOrList(s)) => { @@ -1105,7 +1103,7 @@ pub async fn from_substrait_rex( })) } Some(RexType::ScalarFunction(f)) => { - let Some(fn_name) = extensions.get(&f.function_reference) else { + let Some(fn_name) = extensions.functions.get(&f.function_reference) else { return plan_err!( "Scalar function not found: function reference = {:?}", f.function_reference @@ -1155,7 +1153,7 @@ pub async fn from_substrait_rex( } } Some(RexType::Literal(lit)) => { - let scalar_value = from_substrait_literal_without_names(lit)?; + let scalar_value = from_substrait_literal_without_names(lit, extensions)?; Ok(Expr::Literal(scalar_value)) } Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() { @@ -1169,12 +1167,13 @@ pub async fn from_substrait_rex( ) .await?, ), - from_substrait_type_without_names(output_type)?, + from_substrait_type_without_names(output_type, extensions)?, ))), None => substrait_err!("Cast expression without output type is not allowed"), }, Some(RexType::WindowFunction(window)) => { - let Some(fn_name) = extensions.get(&window.function_reference) else { + let Some(fn_name) = extensions.functions.get(&window.function_reference) + else { return plan_err!( "Window function not found: function reference = {:?}", window.function_reference @@ -1328,12 +1327,16 @@ pub async fn from_substrait_rex( } } -pub(crate) fn from_substrait_type_without_names(dt: &Type) -> Result { - from_substrait_type(dt, &[], &mut 0) +pub(crate) fn from_substrait_type_without_names( + dt: &Type, + extensions: &Extensions, +) -> Result { + from_substrait_type(dt, extensions, &[], &mut 0) } fn from_substrait_type( dt: &Type, + extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -1416,7 +1419,7 @@ fn from_substrait_type( substrait_datafusion_err!("List type must have inner type") })?; let field = Arc::new(Field::new_list_field( - from_substrait_type(inner_type, dfs_names, name_idx)?, + from_substrait_type(inner_type, extensions, dfs_names, name_idx)?, // We ignore Substrait's nullability here to match to_substrait_literal // which always creates nullable lists true, @@ -1438,12 +1441,12 @@ fn from_substrait_type( })?; let key_field = Arc::new(Field::new( "key", - from_substrait_type(key_type, dfs_names, name_idx)?, + from_substrait_type(key_type, extensions, dfs_names, name_idx)?, false, )); let value_field = Arc::new(Field::new( "value", - from_substrait_type(value_type, dfs_names, name_idx)?, + from_substrait_type(value_type, extensions, dfs_names, name_idx)?, true, )); match map.type_variation_reference { @@ -1490,28 +1493,41 @@ fn from_substrait_type( ), }, r#type::Kind::UserDefined(u) => { - match u.type_reference { - // Kept for backwards compatibility, use IntervalYear instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::YearMonth)) - } - // Kept for backwards compatibility, use IntervalDay instead - INTERVAL_DAY_TIME_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::DayTime)) + if let Some(name) = extensions.types.get(&u.type_reference) { + match name.as_ref() { + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + _ => not_impl_err!( + "Unsupported Substrait user defined type with ref {} and variation {}", + u.type_reference, + u.type_variation_reference + ), } - // Not supported yet by Substrait - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::MonthDayNano)) - } - _ => not_impl_err!( + } else { + // Kept for backwards compatibility, new plans should include the extension instead + #[allow(deprecated)] + match u.type_reference { + // Kept for backwards compatibility, use IntervalYear instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + // Kept for backwards compatibility, use IntervalDay instead + INTERVAL_DAY_TIME_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + // Not supported yet by Substrait + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + _ => not_impl_err!( "Unsupported Substrait user defined type with ref {} and variation {}", u.type_reference, u.type_variation_reference ), + } } } r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( - s, dfs_names, name_idx, + s, extensions, dfs_names, name_idx, )?)), r#type::Kind::Varchar(_) => Ok(DataType::Utf8), r#type::Kind::FixedChar(_) => Ok(DataType::Utf8), @@ -1523,6 +1539,7 @@ fn from_substrait_type( fn from_substrait_struct_type( s: &r#type::Struct, + extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -1530,7 +1547,7 @@ fn from_substrait_struct_type( for (i, f) in s.types.iter().enumerate() { let field = Field::new( next_struct_field_name(i, dfs_names, name_idx)?, - from_substrait_type(f, dfs_names, name_idx)?, + from_substrait_type(f, extensions, dfs_names, name_idx)?, true, // We assume everything to be nullable since that's easier than ensuring it matches ); fields.push(field); @@ -1556,12 +1573,16 @@ fn next_struct_field_name( } } -fn from_substrait_named_struct(base_schema: &NamedStruct) -> Result { +fn from_substrait_named_struct( + base_schema: &NamedStruct, + extensions: &Extensions, +) -> Result { let mut name_idx = 0; let fields = from_substrait_struct_type( base_schema.r#struct.as_ref().ok_or_else(|| { substrait_datafusion_err!("Named struct must contain a struct") })?, + extensions, &base_schema.names, &mut name_idx, ); @@ -1621,12 +1642,16 @@ fn from_substrait_bound( } } -pub(crate) fn from_substrait_literal_without_names(lit: &Literal) -> Result { - from_substrait_literal(lit, &vec![], &mut 0) +pub(crate) fn from_substrait_literal_without_names( + lit: &Literal, + extensions: &Extensions, +) -> Result { + from_substrait_literal(lit, extensions, &vec![], &mut 0) } fn from_substrait_literal( lit: &Literal, + extensions: &Extensions, dfs_names: &Vec, name_idx: &mut usize, ) -> Result { @@ -1721,7 +1746,7 @@ fn from_substrait_literal( let elements = l .values .iter() - .map(|el| from_substrait_literal(el, dfs_names, name_idx)) + .map(|el| from_substrait_literal(el, extensions, dfs_names, name_idx)) .collect::>>()?; if elements.is_empty() { return substrait_err!( @@ -1744,6 +1769,7 @@ fn from_substrait_literal( Some(LiteralType::EmptyList(l)) => { let element_type = from_substrait_type( l.r#type.clone().unwrap().as_ref(), + extensions, dfs_names, name_idx, )?; @@ -1763,7 +1789,7 @@ fn from_substrait_literal( let mut builder = ScalarStructBuilder::new(); for (i, field) in s.fields.iter().enumerate() { let name = next_struct_field_name(i, dfs_names, name_idx)?; - let sv = from_substrait_literal(field, dfs_names, name_idx)?; + let sv = from_substrait_literal(field, extensions, dfs_names, name_idx)?; // We assume everything to be nullable, since Arrow's strict about things matching // and it's hard to match otherwise. builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); @@ -1771,7 +1797,7 @@ fn from_substrait_literal( builder.build()? } Some(LiteralType::Null(ntype)) => { - from_substrait_null(ntype, dfs_names, name_idx)? + from_substrait_null(ntype, extensions, dfs_names, name_idx)? } Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { days, @@ -1786,40 +1812,9 @@ fn from_substrait_literal( } Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), Some(LiteralType::UserDefined(user_defined)) => { - match user_defined.type_reference { - // Kept for backwards compatibility, use IntervalYearToMonth instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval year month value is empty"); - }; - let value_slice: [u8; 4] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval year month value" - ) - })?; - ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes(value_slice))) - } - // Kept for backwards compatibility, use IntervalDayToSecond instead - INTERVAL_DAY_TIME_TYPE_REF => { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval day time value is empty"); - }; - let value_slice: [u8; 8] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval day time value" - ) - })?; - let days = i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); - let milliseconds = - i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); - ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days, - milliseconds, - })) - } - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed + let interval_month_day_nano = + |user_defined: &UserDefined| -> Result { let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { return substrait_err!("Interval month day nano value is empty"); }; @@ -1834,17 +1829,76 @@ fn from_substrait_literal( let days = i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); let nanoseconds = i64::from_le_bytes(value_slice[8..16].try_into().unwrap()); - ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { - months, - days, - nanoseconds, - })) - } - _ => { - return not_impl_err!( - "Unsupported Substrait user defined type with ref {}", - user_defined.type_reference + Ok(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months, + days, + nanoseconds, + }, + ))) + }; + + if let Some(name) = extensions.types.get(&user_defined.type_reference) { + match name.as_ref() { + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => { + interval_month_day_nano(user_defined)? + } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type with ref {} and name {}", + user_defined.type_reference, + name ) + } + } + } else { + // Kept for backwards compatibility - new plans should include extension instead + #[allow(deprecated)] + match user_defined.type_reference { + // Kept for backwards compatibility, use IntervalYearToMonth instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval year month value is empty"); + }; + let value_slice: [u8; 4] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval year month value" + ) + })?; + ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes( + value_slice, + ))) + } + // Kept for backwards compatibility, use IntervalDayToSecond instead + INTERVAL_DAY_TIME_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval day time value is empty"); + }; + let value_slice: [u8; 8] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval day time value" + ) + })?; + let days = + i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); + let milliseconds = + i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days, + milliseconds, + })) + } + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + interval_month_day_nano(user_defined)? + } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type literal with ref {}", + user_defined.type_reference + ) + } } } } @@ -1856,6 +1910,7 @@ fn from_substrait_literal( fn from_substrait_null( null_type: &Type, + extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -1940,6 +1995,7 @@ fn from_substrait_null( let field = Field::new_list_field( from_substrait_type( l.r#type.clone().unwrap().as_ref(), + extensions, dfs_names, name_idx, )?, @@ -1958,7 +2014,8 @@ fn from_substrait_null( } } r#type::Kind::Struct(s) => { - let fields = from_substrait_struct_type(s, dfs_names, name_idx)?; + let fields = + from_substrait_struct_type(s, extensions, dfs_names, name_idx)?; Ok(ScalarStructBuilder::new_null(fields)) } _ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"), @@ -2012,7 +2069,7 @@ impl BuiltinExprBuilder { ctx: &SessionContext, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { match self.expr_name.as_str() { "like" => { @@ -2037,7 +2094,7 @@ impl BuiltinExprBuilder { fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { if f.arguments.len() != 1 { return substrait_err!("Expect one argument for {fn_name} expr"); @@ -2071,7 +2128,7 @@ impl BuiltinExprBuilder { case_insensitive: bool, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; if f.arguments.len() != 2 && f.arguments.len() != 3 { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 7849d0bd431e6..d569aaabea2d9 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -16,7 +16,6 @@ // under the License. use itertools::Itertools; -use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; @@ -33,6 +32,16 @@ use datafusion::{ scalar::ScalarValue, }; +use crate::extensions::Extensions; +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, + TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, + TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, + UNSIGNED_INTEGER_TYPE_VARIATION_REF, +}; use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; use datafusion::common::{ exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, @@ -72,10 +81,6 @@ use substrait::{ ScalarFunction, SingularOrList, Subquery, WindowFunction as SubstraitWindowFunction, }, - extensions::{ - self, - simple_extension_declaration::{ExtensionFunction, MappingType}, - }, function_argument::ArgType, join_rel, plan_rel, r#type, read_rel::{NamedTable, ReadType}, @@ -90,39 +95,24 @@ use substrait::{ version, }; -use crate::variation_const::{ - DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, - DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_URL, - LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, - TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, - TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, -}; - /// Convert DataFusion LogicalPlan to Substrait Plan pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result> { + let mut extensions = Extensions::default(); // Parse relation nodes - let mut extension_info: ( - Vec, - HashMap, - ) = (vec![], HashMap::new()); // Generate PlanRel(s) // Note: Only 1 relation tree is currently supported let plan_rels = vec![PlanRel { rel_type: Some(plan_rel::RelType::Root(RelRoot { - input: Some(*to_substrait_rel(plan, ctx, &mut extension_info)?), - names: to_substrait_named_struct(plan.schema())?.names, + input: Some(*to_substrait_rel(plan, ctx, &mut extensions)?), + names: to_substrait_named_struct(plan.schema(), &mut extensions)?.names, })), }]; - let (function_extensions, _) = extension_info; - // Return parsed plan Ok(Box::new(Plan { version: Some(version::version_with_producer("datafusion")), extension_uris: vec![], - extensions: function_extensions, + extensions: extensions.into(), relations: plan_rels, advanced_extensions: None, expected_type_urls: vec![], @@ -133,10 +123,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result, - HashMap, - ), + extensions: &mut Extensions, ) -> Result> { match plan { LogicalPlan::TableScan(scan) => { @@ -187,7 +174,7 @@ pub fn to_substrait_rel( Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: Some(to_substrait_named_struct(&e.schema)?), + base_schema: Some(to_substrait_named_struct(&e.schema, extensions)?), filter: None, best_effort_filter: None, projection: None, @@ -206,10 +193,10 @@ pub fn to_substrait_rel( let fields = row .iter() .map(|v| match v { - Expr::Literal(sv) => to_substrait_literal(sv), + Expr::Literal(sv) => to_substrait_literal(sv, extensions), Expr::Alias(alias) => match alias.expr.as_ref() { // The schema gives us the names, so we can skip aliases - Expr::Literal(sv) => to_substrait_literal(sv), + Expr::Literal(sv) => to_substrait_literal(sv, extensions), _ => Err(substrait_datafusion_err!( "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() )), @@ -225,7 +212,7 @@ pub fn to_substrait_rel( Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: Some(to_substrait_named_struct(&v.schema)?), + base_schema: Some(to_substrait_named_struct(&v.schema, extensions)?), filter: None, best_effort_filter: None, projection: None, @@ -238,25 +225,25 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extensions)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { common: None, - input: Some(to_substrait_rel(p.input.as_ref(), ctx, extension_info)?), + input: Some(to_substrait_rel(p.input.as_ref(), ctx, extensions)?), expressions, advanced_extension: None, }))), })) } LogicalPlan::Filter(filter) => { - let input = to_substrait_rel(filter.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(filter.input.as_ref(), ctx, extensions)?; let filter_expr = to_substrait_rex( ctx, &filter.predicate, filter.input.schema(), 0, - extension_info, + extensions, )?; Ok(Box::new(Rel { rel_type: Some(RelType::Filter(Box::new(FilterRel { @@ -268,7 +255,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Limit(limit) => { - let input = to_substrait_rel(limit.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(limit.input.as_ref(), ctx, extensions)?; // Since protobuf can't directly distinguish `None` vs `0` encode `None` as `MAX` let limit_fetch = limit.fetch.unwrap_or(usize::MAX); Ok(Box::new(Rel { @@ -282,13 +269,11 @@ pub fn to_substrait_rel( })) } LogicalPlan::Sort(sort) => { - let input = to_substrait_rel(sort.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(sort.input.as_ref(), ctx, extensions)?; let sort_fields = sort .expr .iter() - .map(|e| { - substrait_sort_field(ctx, e, sort.input.schema(), extension_info) - }) + .map(|e| substrait_sort_field(ctx, e, sort.input.schema(), extensions)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { @@ -300,19 +285,17 @@ pub fn to_substrait_rel( })) } LogicalPlan::Aggregate(agg) => { - let input = to_substrait_rel(agg.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?; let groupings = to_substrait_groupings( ctx, &agg.group_expr, agg.input.schema(), - extension_info, + extensions, )?; let measures = agg .aggr_expr .iter() - .map(|e| { - to_substrait_agg_measure(ctx, e, agg.input.schema(), extension_info) - }) + .map(|e| to_substrait_agg_measure(ctx, e, agg.input.schema(), extensions)) .collect::>>()?; Ok(Box::new(Rel { @@ -327,7 +310,7 @@ pub fn to_substrait_rel( } LogicalPlan::Distinct(Distinct::All(plan)) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(plan.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(plan.as_ref(), ctx, extensions)?; // Get grouping keys from the input relation's number of output fields let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) @@ -346,8 +329,8 @@ pub fn to_substrait_rel( })) } LogicalPlan::Join(join) => { - let left = to_substrait_rel(join.left.as_ref(), ctx, extension_info)?; - let right = to_substrait_rel(join.right.as_ref(), ctx, extension_info)?; + let left = to_substrait_rel(join.left.as_ref(), ctx, extensions)?; + let right = to_substrait_rel(join.right.as_ref(), ctx, extensions)?; let join_type = to_substrait_jointype(join.join_type); // we only support basic joins so return an error for anything not yet supported match join.join_constraint { @@ -364,7 +347,7 @@ pub fn to_substrait_rel( filter, &Arc::new(in_join_schema), 0, - extension_info, + extensions, )?), None => None, }; @@ -382,7 +365,7 @@ pub fn to_substrait_rel( eq_op, join.left.schema(), join.right.schema(), - extension_info, + extensions, )?; // create conjunction between `join_on` and `join_filter` to embed all join conditions, @@ -393,7 +376,7 @@ pub fn to_substrait_rel( on_expr, filter, Operator::And, - extension_info, + extensions, ))), None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist }, @@ -421,8 +404,8 @@ pub fn to_substrait_rel( right, schema: _, } = cross_join; - let left = to_substrait_rel(left.as_ref(), ctx, extension_info)?; - let right = to_substrait_rel(right.as_ref(), ctx, extension_info)?; + let left = to_substrait_rel(left.as_ref(), ctx, extensions)?; + let right = to_substrait_rel(right.as_ref(), ctx, extensions)?; Ok(Box::new(Rel { rel_type: Some(RelType::Cross(Box::new(CrossRel { common: None, @@ -435,13 +418,13 @@ pub fn to_substrait_rel( LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait - to_substrait_rel(alias.input.as_ref(), ctx, extension_info) + to_substrait_rel(alias.input.as_ref(), ctx, extensions) } LogicalPlan::Union(union) => { let input_rels = union .inputs .iter() - .map(|input| to_substrait_rel(input.as_ref(), ctx, extension_info)) + .map(|input| to_substrait_rel(input.as_ref(), ctx, extensions)) .collect::>>()? .into_iter() .map(|ptr| *ptr) @@ -456,7 +439,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Window(window) => { - let input = to_substrait_rel(window.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(window.input.as_ref(), ctx, extensions)?; // If the input is a Project relation, we can just append the WindowFunction expressions // before returning // Otherwise, wrap the input in a Project relation before appending the WindowFunction @@ -484,7 +467,7 @@ pub fn to_substrait_rel( expr, window.input.schema(), 0, - extension_info, + extensions, )?); } // Append parsed WindowFunction expressions @@ -494,8 +477,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Repartition(repartition) => { - let input = - to_substrait_rel(repartition.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(repartition.input.as_ref(), ctx, extensions)?; let partition_count = match repartition.partitioning_scheme { Partitioning::RoundRobinBatch(num) => num, Partitioning::Hash(_, num) => num, @@ -553,7 +535,7 @@ pub fn to_substrait_rel( .node .inputs() .into_iter() - .map(|plan| to_substrait_rel(plan, ctx, extension_info)) + .map(|plan| to_substrait_rel(plan, ctx, extensions)) .collect::>>()?; let rel_type = match inputs_rel.len() { 0 => RelType::ExtensionLeaf(ExtensionLeafRel { @@ -579,7 +561,10 @@ pub fn to_substrait_rel( } } -fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { +fn to_substrait_named_struct( + schema: &DFSchemaRef, + extensions: &mut Extensions, +) -> Result { // Substrait wants a list of all field names, including nested fields from structs, // also from within e.g. lists and maps. However, it does not want the list and map field names // themselves - only proper structs fields are considered to have useful names. @@ -624,7 +609,7 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { types: schema .fields() .iter() - .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) + .map(|f| to_substrait_type(f.data_type(), f.is_nullable(), extensions)) .collect::>()?, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability: r#type::Nullability::Unspecified as i32, @@ -642,30 +627,27 @@ fn to_substrait_join_expr( eq_op: Operator, left_schema: &DFSchemaRef, right_schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result> { // Only support AND conjunction for each binary expression in join conditions let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = to_substrait_rex(ctx, left, left_schema, 0, extension_info)?; + let l = to_substrait_rex(ctx, left, left_schema, 0, extensions)?; // Parse right let r = to_substrait_rex( ctx, right, right_schema, left_schema.fields().len(), // offset to return the correct index - extension_info, + extensions, )?; // AND with existing expression - exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info)); + exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extensions)); } let join_expr: Option = exprs.into_iter().reduce(|acc: Expression, e: Expression| { - make_binary_op_scalar_func(&acc, &e, Operator::And, extension_info) + make_binary_op_scalar_func(&acc, &e, Operator::And, extensions) }); Ok(join_expr) } @@ -722,14 +704,11 @@ pub fn parse_flat_grouping_exprs( ctx: &SessionContext, exprs: &[Expr], schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { let grouping_expressions = exprs .iter() - .map(|e| to_substrait_rex(ctx, e, schema, 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, 0, extensions)) .collect::>>()?; Ok(Grouping { grouping_expressions, @@ -740,10 +719,7 @@ pub fn to_substrait_groupings( ctx: &SessionContext, exprs: &[Expr], schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result> { match exprs.len() { 1 => match &exprs[0] { @@ -753,9 +729,7 @@ pub fn to_substrait_groupings( )), GroupingSet::GroupingSets(sets) => Ok(sets .iter() - .map(|set| { - parse_flat_grouping_exprs(ctx, set, schema, extension_info) - }) + .map(|set| parse_flat_grouping_exprs(ctx, set, schema, extensions)) .collect::>>()?), GroupingSet::Rollup(set) => { let mut sets: Vec> = vec![vec![]]; @@ -766,23 +740,17 @@ pub fn to_substrait_groupings( .iter() .rev() .map(|set| { - parse_flat_grouping_exprs(ctx, set, schema, extension_info) + parse_flat_grouping_exprs(ctx, set, schema, extensions) }) .collect::>>()?) } }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, - exprs, - schema, - extension_info, + ctx, exprs, schema, extensions, )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, - exprs, - schema, - extension_info, + ctx, exprs, schema, extensions, )?]), } } @@ -792,25 +760,22 @@ pub fn to_substrait_agg_measure( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by, null_treatment: _, }) => { match func_def { AggregateFunctionDefinition::BuiltIn (fun) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) }); } - let function_anchor = register_function(fun.to_string(), extension_info); + let function_anchor = extensions.register_function(fun.to_string()); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -826,22 +791,22 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extensions)?), None => None } }) } AggregateFunctionDefinition::UDF(fun) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) }); } - let function_anchor = register_function(fun.name().to_string(), extension_info); + let function_anchor = extensions.register_function(fun.name().to_string()); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -857,7 +822,7 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extensions)?), None => None } }) @@ -866,7 +831,7 @@ pub fn to_substrait_agg_measure( } Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(ctx, expr, schema, extension_info) + to_substrait_agg_measure(ctx, expr, schema, extensions) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", @@ -881,10 +846,7 @@ fn to_substrait_sort_field( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { Expr::Sort(sort) => { @@ -900,7 +862,7 @@ fn to_substrait_sort_field( sort.expr.deref(), schema, 0, - extension_info, + extensions, )?), sort_kind: Some(SortKind::Direction(sort_kind.into())), }) @@ -909,67 +871,15 @@ fn to_substrait_sort_field( } } -fn register_function( - function_name: String, - extension_info: &mut ( - Vec, - HashMap, - ), -) -> u32 { - let (function_extensions, function_set) = extension_info; - let function_name = function_name.to_lowercase(); - - // Some functions are named differently in Substrait default extensions than in DF - // Rename those to match the Substrait extensions for interoperability - let function_name = match function_name.as_str() { - "substr" => "substring".to_string(), - _ => function_name, - }; - - // To prevent ambiguous references between ScalarFunctions and AggregateFunctions, - // a plan-relative identifier starting from 0 is used as the function_anchor. - // The consumer is responsible for correctly registering - // mapping info stored in the extensions by the producer. - let function_anchor = match function_set.get(&function_name) { - Some(function_anchor) => { - // Function has been registered - *function_anchor - } - None => { - // Function has NOT been registered - let function_anchor = function_set.len() as u32; - function_set.insert(function_name.clone(), function_anchor); - - let function_extension = ExtensionFunction { - extension_uri_reference: u32::MAX, - function_anchor, - name: function_name, - }; - let simple_extension = extensions::SimpleExtensionDeclaration { - mapping_type: Some(MappingType::ExtensionFunction(function_extension)), - }; - function_extensions.push(simple_extension); - function_anchor - } - }; - - // Return function anchor - function_anchor -} - /// Return Substrait scalar function with two arguments #[allow(deprecated)] pub fn make_binary_op_scalar_func( lhs: &Expression, rhs: &Expression, op: Operator, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Expression { - let function_anchor = - register_function(operator_to_name(op).to_string(), extension_info); + let function_anchor = extensions.register_function(operator_to_name(op).to_string()); Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1010,17 +920,14 @@ pub fn make_binary_op_scalar_func( /// `col_ref(1) = col_ref(3 + 0)` /// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index /// of the join key column from `right` -/// * `extension_info` - Substrait extension info. Contains registered function information +/// * `extensions` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { Expr::InList(InList { @@ -1030,10 +937,10 @@ pub fn to_substrait_rex( }) => { let substrait_list = list .iter() - .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extension_info)) + .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extensions)) .collect::>>()?; let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -1043,8 +950,7 @@ pub fn to_substrait_rex( }; if *negated { - let function_anchor = - register_function("not".to_string(), extension_info); + let function_anchor = extensions.register_function("not".to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1070,13 +976,12 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, )?)), }); } - let function_anchor = - register_function(fun.name().to_string(), extension_info); + let function_anchor = extensions.register_function(fun.name().to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1096,58 +1001,58 @@ pub fn to_substrait_rex( if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, &substrait_low, Operator::Lt, - extension_info, + extensions, ); let r_expr = make_binary_op_scalar_func( &substrait_high, &substrait_expr, Operator::Lt, - extension_info, + extensions, ); Ok(make_binary_op_scalar_func( &l_expr, &r_expr, Operator::Or, - extension_info, + extensions, )) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_low, &substrait_expr, Operator::LtEq, - extension_info, + extensions, ); let r_expr = make_binary_op_scalar_func( &substrait_expr, &substrait_high, Operator::LtEq, - extension_info, + extensions, ); Ok(make_binary_op_scalar_func( &l_expr, &r_expr, Operator::And, - extension_info, + extensions, )) } } @@ -1156,10 +1061,10 @@ pub fn to_substrait_rex( substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extension_info)?; - let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extension_info)?; + let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extensions)?; + let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extensions)?; - Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) + Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) } Expr::Case(Case { expr, @@ -1176,7 +1081,7 @@ pub fn to_substrait_rex( e, schema, col_ref_offset, - extension_info, + extensions, )?), then: None, }); @@ -1189,14 +1094,14 @@ pub fn to_substrait_rex( r#if, schema, col_ref_offset, - extension_info, + extensions, )?), then: Some(to_substrait_rex( ctx, then, schema, col_ref_offset, - extension_info, + extensions, )?), }); } @@ -1208,7 +1113,7 @@ pub fn to_substrait_rex( e, schema, col_ref_offset, - extension_info, + extensions, )?)), None => None, }; @@ -1221,22 +1126,22 @@ pub fn to_substrait_rex( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), + r#type: Some(to_substrait_type(data_type, true, extensions)?), input: Some(Box::new(to_substrait_rex( ctx, expr, schema, col_ref_offset, - extension_info, + extensions, )?)), failure_behavior: 0, // FAILURE_BEHAVIOR_UNSPECIFIED }, ))), }) } - Expr::Literal(value) => to_substrait_literal_expr(value), + Expr::Literal(value) => to_substrait_literal_expr(value, extensions), Expr::Alias(Alias { expr, .. }) => { - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions) } Expr::WindowFunction(WindowFunction { fun, @@ -1247,7 +1152,7 @@ pub fn to_substrait_rex( null_treatment: _, }) => { // function reference - let function_anchor = register_function(fun.to_string(), extension_info); + let function_anchor = extensions.register_function(fun.to_string()); // arguments let mut arguments: Vec = vec![]; for arg in args { @@ -1257,19 +1162,19 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, )?)), }); } // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extensions)) .collect::>>()?; // order by expressions let order_by = order_by .iter() - .map(|e| substrait_sort_field(ctx, e, schema, extension_info)) + .map(|e| substrait_sort_field(ctx, e, schema, extensions)) .collect::>>()?; // window frame let bounds = to_substrait_bounds(window_frame)?; @@ -1298,7 +1203,7 @@ pub fn to_substrait_rex( *escape_char, schema, col_ref_offset, - extension_info, + extensions, ), Expr::InSubquery(InSubquery { expr, @@ -1306,10 +1211,10 @@ pub fn to_substrait_rex( negated, }) => { let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let subquery_plan = - to_substrait_rel(subquery.subquery.as_ref(), ctx, extension_info)?; + to_substrait_rel(subquery.subquery.as_ref(), ctx, extensions)?; let substrait_subquery = Expression { rex_type: Some(RexType::Subquery(Box::new(Subquery { @@ -1324,8 +1229,7 @@ pub fn to_substrait_rex( }))), }; if *negated { - let function_anchor = - register_function("not".to_string(), extension_info); + let function_anchor = extensions.register_function("not".to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1348,7 +1252,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNull(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1356,7 +1260,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1364,7 +1268,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1372,7 +1276,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1380,7 +1284,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1388,7 +1292,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1396,7 +1300,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1404,7 +1308,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1412,7 +1316,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::Negative(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1420,7 +1324,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), _ => { not_impl_err!("Unsupported expression: {expr:?}") @@ -1428,7 +1332,11 @@ pub fn to_substrait_rex( } } -fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { +fn to_substrait_type( + dt: &DataType, + nullable: bool, + extensions: &mut Extensions, +) -> Result { let nullability = if nullable { r#type::Nullability::Nullable as i32 } else { @@ -1548,7 +1456,9 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result Result { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + let inner_type = + to_substrait_type(inner.data_type(), inner.is_nullable(), extensions)?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), @@ -1599,7 +1510,8 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + let inner_type = + to_substrait_type(inner.data_type(), inner.is_nullable(), extensions)?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), @@ -1613,10 +1525,12 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result Result { let field_types = fields .iter() - .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) + .map(|field| { + to_substrait_type(field.data_type(), field.is_nullable(), extensions) + }) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { @@ -1700,21 +1616,19 @@ fn make_substrait_like_expr( escape_char: Option, schema: &DFSchemaRef, col_ref_offset: usize, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { let function_anchor = if ignore_case { - register_function("ilike".to_string(), extension_info) + extensions.register_function("ilike".to_string()) } else { - register_function("like".to_string(), extension_info) + extensions.register_function("like".to_string()) }; - let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; - let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; - let escape_char = to_substrait_literal_expr(&ScalarValue::Utf8( - escape_char.map(|c| c.to_string()), - ))?; + let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extensions)?; + let escape_char = to_substrait_literal_expr( + &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), + extensions, + )?; let arguments = vec![ FunctionArgument { arg_type: Some(ArgType::Value(expr)), @@ -1738,7 +1652,7 @@ fn make_substrait_like_expr( }; if negated { - let function_anchor = register_function("not".to_string(), extension_info); + let function_anchor = extensions.register_function("not".to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1870,7 +1784,10 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { )) } -fn to_substrait_literal(value: &ScalarValue) -> Result { +fn to_substrait_literal( + value: &ScalarValue, + extensions: &mut Extensions, +) -> Result { if value.is_null() { return Ok(Literal { nullable: true, @@ -1878,6 +1795,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { literal_type: Some(LiteralType::Null(to_substrait_type( &value.data_type(), true, + extensions, )?)), }); } @@ -1949,14 +1867,15 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { let bytes = i.to_byte_slice(); ( LiteralType::UserDefined(UserDefined { - type_reference: INTERVAL_MONTH_DAY_NANO_TYPE_REF, + type_reference: extensions + .register_type(INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()), type_parameters: vec![], val: Some(user_defined::Val::Value(ProtoAny { - type_url: INTERVAL_MONTH_DAY_NANO_TYPE_URL.to_string(), + type_url: INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string(), value: bytes.to_vec().into(), })), }), - INTERVAL_MONTH_DAY_NANO_TYPE_REF, + DEFAULT_TYPE_VARIATION_REF, ) } ScalarValue::IntervalDayTime(Some(i)) => ( @@ -1996,11 +1915,11 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { DECIMAL_128_TYPE_VARIATION_REF, ), ScalarValue::List(l) => ( - convert_array_to_literal_list(l)?, + convert_array_to_literal_list(l, extensions)?, DEFAULT_CONTAINER_TYPE_VARIATION_REF, ), ScalarValue::LargeList(l) => ( - convert_array_to_literal_list(l)?, + convert_array_to_literal_list(l, extensions)?, LARGE_CONTAINER_TYPE_VARIATION_REF, ), ScalarValue::Struct(s) => ( @@ -2009,7 +1928,10 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { .columns() .iter() .map(|col| { - to_substrait_literal(&ScalarValue::try_from_array(col, 0)?) + to_substrait_literal( + &ScalarValue::try_from_array(col, 0)?, + extensions, + ) }) .collect::>>()?, }), @@ -2030,16 +1952,26 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { fn convert_array_to_literal_list( array: &GenericListArray, + extensions: &mut Extensions, ) -> Result { assert_eq!(array.len(), 1); let nested_array = array.value(0); let values = (0..nested_array.len()) - .map(|i| to_substrait_literal(&ScalarValue::try_from_array(&nested_array, i)?)) + .map(|i| { + to_substrait_literal( + &ScalarValue::try_from_array(&nested_array, i)?, + extensions, + ) + }) .collect::>>()?; if values.is_empty() { - let et = match to_substrait_type(array.data_type(), array.is_nullable())? { + let et = match to_substrait_type( + array.data_type(), + array.is_nullable(), + extensions, + )? { substrait::proto::Type { kind: Some(r#type::Kind::List(lt)), } => lt.as_ref().to_owned(), @@ -2051,8 +1983,11 @@ fn convert_array_to_literal_list( } } -fn to_substrait_literal_expr(value: &ScalarValue) -> Result { - let literal = to_substrait_literal(value)?; +fn to_substrait_literal_expr( + value: &ScalarValue, + extensions: &mut Extensions, +) -> Result { + let literal = to_substrait_literal(value, extensions)?; Ok(Expression { rex_type: Some(RexType::Literal(literal)), }) @@ -2065,14 +2000,10 @@ fn to_substrait_unary_scalar_fn( arg: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { - let function_anchor = register_function(fn_name.to_string(), extension_info); - let substrait_expr = - to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?; + let function_anchor = extensions.register_function(fn_name.to_string()); + let substrait_expr = to_substrait_rex(ctx, arg, schema, col_ref_offset, extensions)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -2116,10 +2047,7 @@ fn substrait_sort_field( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { Expr::Sort(Sort { @@ -2127,7 +2055,7 @@ fn substrait_sort_field( asc, nulls_first, }) => { - let e = to_substrait_rex(ctx, expr, schema, 0, extension_info)?; + let e = to_substrait_rex(ctx, expr, schema, 0, extensions)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, @@ -2161,6 +2089,7 @@ fn substrait_field_ref(index: usize) -> Result { #[cfg(test)] mod test { + use super::*; use crate::logical_plan::consumer::{ from_substrait_literal_without_names, from_substrait_type_without_names, }; @@ -2168,8 +2097,7 @@ mod test { use datafusion::arrow::array::GenericListArray; use datafusion::arrow::datatypes::Field; use datafusion::common::scalar::ScalarStructBuilder; - - use super::*; + use std::collections::HashMap; #[test] fn round_trip_literals() -> Result<()> { @@ -2258,12 +2186,47 @@ mod test { fn round_trip_literal(scalar: ScalarValue) -> Result<()> { println!("Checking round trip of {scalar:?}"); - let substrait_literal = to_substrait_literal(&scalar)?; - let roundtrip_scalar = from_substrait_literal_without_names(&substrait_literal)?; + let mut extensions = Extensions::default(); + let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; + let roundtrip_scalar = + from_substrait_literal_without_names(&substrait_literal, &extensions)?; assert_eq!(scalar, roundtrip_scalar); Ok(()) } + #[test] + fn custom_type_literal_extensions() -> Result<()> { + let mut extensions = Extensions::default(); + // IntervalMonthDayNano is represented as a custom type in Substrait + let scalar = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::new( + 17, 25, 1234567890, + ))); + let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; + let roundtrip_scalar = + from_substrait_literal_without_names(&substrait_literal, &extensions)?; + assert_eq!(scalar, roundtrip_scalar); + + assert_eq!( + extensions, + Extensions { + functions: HashMap::new(), + types: HashMap::from([( + 0, + INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() + )]), + type_variations: HashMap::new(), + } + ); + + // Check we fail if we don't propagate extensions + assert!(from_substrait_literal_without_names( + &substrait_literal, + &Extensions::default() + ) + .is_err()); + Ok(()) + } + #[test] fn round_trip_types() -> Result<()> { round_trip_type(DataType::Boolean)?; @@ -2329,11 +2292,44 @@ mod test { fn round_trip_type(dt: DataType) -> Result<()> { println!("Checking round trip of {dt:?}"); + let mut extensions = Extensions::default(); + // As DataFusion doesn't consider nullability as a property of the type, but field, // it doesn't matter if we set nullability to true or false here. - let substrait = to_substrait_type(&dt, true)?; - let roundtrip_dt = from_substrait_type_without_names(&substrait)?; + let substrait = to_substrait_type(&dt, true, &mut extensions)?; + let roundtrip_dt = from_substrait_type_without_names(&substrait, &extensions)?; + assert_eq!(dt, roundtrip_dt); + Ok(()) + } + + #[test] + fn custom_type_extensions() -> Result<()> { + let mut extensions = Extensions::default(); + // IntervalMonthDayNano is represented as a custom type in Substrait + let dt = DataType::Interval(IntervalUnit::MonthDayNano); + + let substrait = to_substrait_type(&dt, true, &mut extensions)?; + let roundtrip_dt = from_substrait_type_without_names(&substrait, &extensions)?; assert_eq!(dt, roundtrip_dt); + + assert_eq!( + extensions, + Extensions { + functions: HashMap::new(), + types: HashMap::from([( + 0, + INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() + )]), + type_variations: HashMap::new(), + } + ); + + // Check we fail if we don't propagate extensions + assert!( + from_substrait_type_without_names(&substrait, &Extensions::default()) + .is_err() + ); + Ok(()) } } diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index 27f4b3ea228a6..c94ad2d669fde 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -25,13 +25,16 @@ //! - Default type reference is 0. It is used when the actual type is the same with the original type. //! - Extended variant type references start from 1, and ususlly increase by 1. //! -//! Definitions here are not the final form. All the non-system-preferred variations will be defined +//! TODO: Definitions here are not the final form. All the non-system-preferred variations will be defined //! using [simple extensions] as per the [spec of type_variations](https://substrait.io/types/type_variations/) +//! //! //! [simple extensions]: (https://substrait.io/extensions/#simple-extensions) // For [type variations](https://substrait.io/types/type_variations/#type-variations) in substrait. // Type variations are used to represent different types based on one type class. +// TODO: Define as extensions: + /// The "system-preferred" variation (i.e., no variation). pub const DEFAULT_TYPE_VARIATION_REF: u32 = 0; pub const UNSIGNED_INTEGER_TYPE_VARIATION_REF: u32 = 1; @@ -55,6 +58,7 @@ pub const DECIMAL_256_TYPE_VARIATION_REF: u32 = 1; /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth /// [`ScalarValue::IntervalYearMonth`]: datafusion::common::ScalarValue::IntervalYearMonth +#[deprecated(since = "41.0.0", note = "Use Substrait `IntervalYear` type instead")] pub const INTERVAL_YEAR_MONTH_TYPE_REF: u32 = 1; /// For [`DataType::Interval`] with [`IntervalUnit::DayTime`]. @@ -68,6 +72,7 @@ pub const INTERVAL_YEAR_MONTH_TYPE_REF: u32 = 1; /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime /// [`ScalarValue::IntervalDayTime`]: datafusion::common::ScalarValue::IntervalDayTime +#[deprecated(since = "41.0.0", note = "Use Substrait `IntervalDay` type instead")] pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2; /// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`]. @@ -82,21 +87,14 @@ pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2; /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano /// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano +#[deprecated( + since = "41.0.0", + note = "Use Substrait `UserDefinedType` with name `INTERVAL_MONTH_DAY_NANO_TYPE_NAME` instead" +)] pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3; -// For User Defined URLs -/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`]. -/// -/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval -/// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth -pub const INTERVAL_YEAR_MONTH_TYPE_URL: &str = "interval-year-month"; -/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`]. -/// -/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval -/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime -pub const INTERVAL_DAY_TIME_TYPE_URL: &str = "interval-day-time"; /// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`]. /// /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano -pub const INTERVAL_MONTH_DAY_NANO_TYPE_URL: &str = "interval-month-day-nano"; +pub const INTERVAL_MONTH_DAY_NANO_TYPE_NAME: &str = "interval-month-day-nano"; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index a7653e11d598f..5b4389c832c7c 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -38,7 +38,10 @@ use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLI use datafusion::prelude::*; use datafusion::execution::session_state::SessionStateBuilder; -use substrait::proto::extensions::simple_extension_declaration::MappingType; +use substrait::proto::extensions::simple_extension_declaration::{ + ExtensionType, MappingType, +}; +use substrait::proto::extensions::SimpleExtensionDeclaration; use substrait::proto::rel::RelType; use substrait::proto::{plan_rel, Plan, Rel}; @@ -175,15 +178,46 @@ async fn select_with_filter() -> Result<()> { #[tokio::test] async fn select_with_reused_functions() -> Result<()> { + let ctx = create_context().await?; let sql = "SELECT * FROM data WHERE a > 1 AND a < 10 AND b > 0"; - roundtrip(sql).await?; - let (mut function_names, mut function_anchors) = function_extension_info(sql).await?; - function_names.sort(); - function_anchors.sort(); + let proto = roundtrip_with_ctx(sql, ctx).await?; + let mut functions = proto + .extensions + .iter() + .map(|e| match e.mapping_type.as_ref().unwrap() { + MappingType::ExtensionFunction(ext_f) => { + (ext_f.function_anchor, ext_f.name.to_owned()) + } + _ => unreachable!("Non-function extensions not expected"), + }) + .collect::>(); + functions.sort_by_key(|(anchor, _)| *anchor); + + // Functions are encountered (and thus registered) depth-first + let expected = vec![ + (0, "gt".to_string()), + (1, "lt".to_string()), + (2, "and".to_string()), + ]; + assert_eq!(functions, expected); - assert_eq!(function_names, ["and", "gt", "lt"]); - assert_eq!(function_anchors, [0, 1, 2]); + Ok(()) +} +#[tokio::test] +async fn roundtrip_udt_extensions() -> Result<()> { + let ctx = create_context().await?; + let proto = + roundtrip_with_ctx("SELECT INTERVAL '1 YEAR 1 DAY 1 SECOND' FROM data", ctx) + .await?; + let expected_type = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionType(ExtensionType { + extension_uri_reference: u32::MAX, + type_anchor: 0, + name: "interval-month-day-nano".to_string(), + })), + }; + assert_eq!(proto.extensions, vec![expected_type]); Ok(()) } @@ -858,7 +892,8 @@ async fn roundtrip_aggregate_udf() -> Result<()> { let ctx = create_context().await?; ctx.register_udaf(dummy_agg); - roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await + roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await?; + Ok(()) } #[tokio::test] @@ -891,7 +926,8 @@ async fn roundtrip_window_udf() -> Result<()> { let ctx = create_context().await?; ctx.register_udwf(dummy_agg); - roundtrip_with_ctx("select dummy_window(a) OVER () from data", ctx).await + roundtrip_with_ctx("select dummy_window(a) OVER () from data", ctx).await?; + Ok(()) } #[tokio::test] @@ -1083,7 +1119,7 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { Ok(()) } -async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { +async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result> { let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan, &ctx)?; @@ -1102,56 +1138,25 @@ async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { assert_eq!(plan.schema(), plan2.schema()); DataFrame::new(ctx.state(), plan2).show().await?; - Ok(()) + Ok(proto) } async fn roundtrip(sql: &str) -> Result<()> { - roundtrip_with_ctx(sql, create_context().await?).await + roundtrip_with_ctx(sql, create_context().await?).await?; + Ok(()) } async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> { let ctx = create_context().await?; - let df = ctx.sql(sql).await?; - let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; - let plan2 = ctx.state().optimize(&plan2)?; - - println!("{plan:#?}"); - println!("{plan2:#?}"); - - let plan1str = format!("{plan:?}"); - let plan2str = format!("{plan2:?}"); - assert_eq!(plan1str, plan2str); - - assert_eq!(plan.schema(), plan2.schema()); + let proto = roundtrip_with_ctx(sql, ctx).await?; // verify that the join filters are None verify_post_join_filter_value(proto).await } async fn roundtrip_all_types(sql: &str) -> Result<()> { - roundtrip_with_ctx(sql, create_all_type_context().await?).await -} - -async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { - let ctx = create_context().await?; - let df = ctx.sql(sql).await?; - let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - - let mut function_names: Vec = vec![]; - let mut function_anchors: Vec = vec![]; - for e in &proto.extensions { - let (function_anchor, function_name) = match e.mapping_type.as_ref().unwrap() { - MappingType::ExtensionFunction(ext_f) => (ext_f.function_anchor, &ext_f.name), - _ => unreachable!("Producer does not generate a non-function extension"), - }; - function_names.push(function_name.to_string()); - function_anchors.push(function_anchor); - } - - Ok((function_names, function_anchors)) + roundtrip_with_ctx(sql, create_all_type_context().await?).await?; + Ok(()) } async fn create_context() -> Result {