diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs b/datafusion/substrait/src/logical_plan/consumer/types.rs index 7bc30e433d868..4ea479e7cccd2 100644 --- a/datafusion/substrait/src/logical_plan/consumer/types.rs +++ b/datafusion/substrait/src/logical_plan/consumer/types.rs @@ -21,7 +21,8 @@ use super::SubstraitConsumer; use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF, + DEFAULT_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, @@ -213,7 +214,28 @@ pub fn from_substrait_type( r#type::Kind::IntervalYear(_) => { Ok(DataType::Interval(IntervalUnit::YearMonth)) } - r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)), + r#type::Kind::IntervalDay(i) => match i.type_variation_reference { + DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + DURATION_INTERVAL_DAY_TYPE_VARIATION_REF => { + let duration_unit = match i.precision { + Some(0) => Ok(TimeUnit::Second), + Some(3) => Ok(TimeUnit::Millisecond), + Some(6) => Ok(TimeUnit::Microsecond), + Some(9) => Ok(TimeUnit::Nanosecond), + p => { + not_impl_err!( + "Unsupported Substrait precision {p:?} for Duration" + ) + } + }?; + Ok(DataType::Duration(duration_unit)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, r#type::Kind::IntervalCompound(_) => { Ok(DataType::Interval(IntervalUnit::MonthDayNano)) } diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 61b7a79095d57..6a63bbef5d7d0 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -19,7 +19,8 @@ use crate::logical_plan::producer::utils::flatten_names; use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF, + DEFAULT_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF, }; @@ -153,7 +154,7 @@ pub(crate) fn to_substrait_type( }), IntervalUnit::DayTime => Ok(substrait::proto::Type { kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + type_variation_reference: DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF, nullability, precision: Some(3), // DayTime precision is always milliseconds })), @@ -171,6 +172,21 @@ pub(crate) fn to_substrait_type( } } } + DataType::Duration(duration_unit) => { + let precision = match duration_unit { + TimeUnit::Second => 0, + TimeUnit::Millisecond => 3, + TimeUnit::Microsecond => 6, + TimeUnit::Nanosecond => 9, + }; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { + type_variation_reference: DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, + nullability, + precision: Some(precision), + })), + }) + } DataType::Binary => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Binary(r#type::Binary { type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, @@ -388,6 +404,11 @@ mod tests { round_trip_type(DataType::Interval(IntervalUnit::MonthDayNano))?; round_trip_type(DataType::Interval(IntervalUnit::DayTime))?; + round_trip_type(DataType::Duration(TimeUnit::Second))?; + round_trip_type(DataType::Duration(TimeUnit::Millisecond))?; + round_trip_type(DataType::Duration(TimeUnit::Microsecond))?; + round_trip_type(DataType::Duration(TimeUnit::Nanosecond))?; + Ok(()) } diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index e5bebf8e11819..efde8efe509e1 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -55,6 +55,15 @@ pub const LARGE_CONTAINER_TYPE_VARIATION_REF: u32 = 1; pub const VIEW_CONTAINER_TYPE_VARIATION_REF: u32 = 2; pub const DECIMAL_128_TYPE_VARIATION_REF: u32 = 0; pub const DECIMAL_256_TYPE_VARIATION_REF: u32 = 1; +/// Used for the arrow type [`DataType::Interval`] with [`IntervalUnit::DayTime`]. +/// +/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval +/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime +pub const DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF: u32 = 0; +/// Used for the arrow type [`DataType::Duration`]. +/// +/// [`DataType::Duration`]: datafusion::arrow::datatypes::DataType::Duration +pub const DURATION_INTERVAL_DAY_TYPE_VARIATION_REF: u32 = 1; // For [user-defined types](https://substrait.io/types/type_classes/#user-defined-types). /// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`].