diff --git a/datafusion/ffi/src/table_provider.rs b/datafusion/ffi/src/table_provider.rs index 890511997a706..1215c05066217 100644 --- a/datafusion/ffi/src/table_provider.rs +++ b/datafusion/ffi/src/table_provider.rs @@ -195,7 +195,7 @@ fn supports_filters_pushdown_internal( let proto_filters = LogicalExprList::decode(filters_serialized) .map_err(|e| DataFusionError::Plan(e.to_string()))?; - parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)? + parse_exprs(proto_filters.expr.iter(), &default_ctx.task_ctx(), &codec)? } }; let filters_borrowed: Vec<&Expr> = filters.iter().collect(); @@ -252,7 +252,7 @@ unsafe extern "C" fn scan_fn_wrapper( rresult_return!(parse_exprs( proto_filters.expr.iter(), - &default_ctx, + &default_ctx.task_ctx(), &codec )) } diff --git a/datafusion/ffi/src/udtf.rs b/datafusion/ffi/src/udtf.rs index ceedec2599a29..f7b8c55789e05 100644 --- a/datafusion/ffi/src/udtf.rs +++ b/datafusion/ffi/src/udtf.rs @@ -97,8 +97,11 @@ unsafe extern "C" fn call_fn_wrapper( let proto_filters = rresult_return!(LogicalExprList::decode(args.as_ref())); - let args = - rresult_return!(parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)); + let args = rresult_return!(parse_exprs( + proto_filters.expr.iter(), + &default_ctx.task_ctx(), + &codec + )); let table_provider = rresult_return!(udtf.call(&args)); RResult::ROk(FFI_TableProvider::new(table_provider, false, runtime)) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index f9400d14a59c9..e94670e6e7c51 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -406,9 +406,21 @@ message LogicalExprNode { Unnest unnest = 35; + ScalarSubquery scalar_subquery = 36; + OuterReferenceColumn outer_reference_column = 37; } } +message ScalarSubquery { + LogicalPlanNode subquery = 1; + repeated LogicalExprNode outer_ref_columns = 2; +} + +message OuterReferenceColumn { + datafusion_common.Field field = 1; + 
datafusion_common.Column column = 2; +} + message Wildcard { TableReference qualifier = 1; } diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 12d9938373ce6..b538de259a6c9 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -24,25 +24,22 @@ use crate::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; use crate::protobuf; -use datafusion_common::{plan_datafusion_err, Result}; +use arrow::datatypes::{DataType, SchemaRef}; +use datafusion_catalog::TableProvider; +use datafusion_common::{plan_datafusion_err, Result, TableReference}; use datafusion_execution::TaskContext; use datafusion_expr::{ - create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LogicalPlan, Volatility, - WindowUDF, + create_udaf, create_udf, create_udwf, AggregateUDF, Expr, Extension, LogicalPlan, + ScalarUDF, Volatility, WindowUDF, }; +use datafusion_physical_plan::ExecutionPlan; +// Reexport Bytes which appears in the API use prost::{ bytes::{Bytes, BytesMut}, Message, }; use std::sync::Arc; -// Reexport Bytes which appears in the API -use datafusion_execution::registry::FunctionRegistry; -use datafusion_expr::planner::ExprPlanner; -use datafusion_physical_plan::ExecutionPlan; - -mod registry; - /// Encodes something (such as [`Expr`]) to/from a stream of /// bytes. /// @@ -66,24 +63,19 @@ pub trait Serializeable: Sized { /// Convert `bytes` (the output of [`to_bytes`]) back into an /// object. This will error if the serialized bytes contain any - /// user defined functions, in which case use - /// [`from_bytes_with_registry`] + /// user defined functions, in which case use [`from_bytes_with_ctx`]. 
/// /// [`to_bytes`]: Self::to_bytes - /// [`from_bytes_with_registry`]: Self::from_bytes_with_registry + /// [`from_bytes_with_ctx`]: Self::from_bytes_with_ctx fn from_bytes(bytes: &[u8]) -> Result { - Self::from_bytes_with_registry(bytes, &registry::NoRegistry {}) + Self::from_bytes_with_ctx(bytes, &TaskContext::default()) } - /// Convert `bytes` (the output of [`to_bytes`]) back into an - /// object resolving user defined functions with the specified - /// `registry` + /// Convert `bytes` (the output of [`to_bytes`]) back into an object + /// resolving user defined functions with the specified `ctx`. /// /// [`to_bytes`]: Self::to_bytes - fn from_bytes_with_registry( - bytes: &[u8], - registry: &dyn FunctionRegistry, - ) -> Result; + fn from_bytes_with_ctx(bytes: &[u8], ctx: &TaskContext) -> Result; } impl Serializeable for Expr { @@ -104,95 +96,106 @@ impl Serializeable for Expr { // Until the underlying prost issue ( https://github.com/tokio-rs/prost/issues/736 ) is fixed, we try to // deserialize the data here and check for errors. 
// - // Need to provide some placeholder registry because the stream may contain UDFs - struct PlaceHolderRegistry; - - impl FunctionRegistry for PlaceHolderRegistry { - fn udfs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() - } - - fn udf(&self, name: &str) -> Result> { + // Need to provide some placeholder codec because the stream may contain UDFs + // (using codec since with TaskContext we can't pass through unknown udfs; it + // requires registering them beforehand) + #[derive(Debug)] + struct PlaceholderLogicalExtensionCodec {} + impl LogicalExtensionCodec for PlaceholderLogicalExtensionCodec { + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { Ok(Arc::new(create_udf( name, vec![], - arrow::datatypes::DataType::Null, + DataType::Null, Volatility::Immutable, Arc::new(|_| unimplemented!()), ))) } - fn udaf(&self, name: &str) -> Result> { + fn try_decode_udaf( + &self, + name: &str, + _buf: &[u8], + ) -> Result> { Ok(Arc::new(create_udaf( name, - vec![arrow::datatypes::DataType::Null], - Arc::new(arrow::datatypes::DataType::Null), + vec![DataType::Null], + Arc::new(DataType::Null), Volatility::Immutable, Arc::new(|_| unimplemented!()), Arc::new(vec![]), ))) } - fn udwf(&self, name: &str) -> Result> { + fn try_decode_udwf(&self, name: &str, _buf: &[u8]) -> Result> { Ok(Arc::new(create_udwf( name, - arrow::datatypes::DataType::Null, - Arc::new(arrow::datatypes::DataType::Null), + DataType::Null, + Arc::new(DataType::Null), Volatility::Immutable, Arc::new(|| unimplemented!()), ))) } - fn register_udaf( - &mut self, - _udaf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udaf called in Placeholder Registry!" - ) - } - fn register_udf( - &mut self, - _udf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udf called in Placeholder Registry!" 
- ) - } - fn register_udwf( - &mut self, - _udaf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udwf called in Placeholder Registry!" - ) + + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &TaskContext, + ) -> Result { + unimplemented!() } - fn expr_planners(&self) -> Vec> { - vec![] + fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { + unimplemented!() } - fn udafs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: SchemaRef, + _ctx: &TaskContext, + ) -> Result> { + unimplemented!() } - fn udwfs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + unimplemented!() } } - Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; + + // Copied from from_bytes_with_ctx below but with placeholder registry instead of + // default. 
+ { + let bytes: &[u8] = &bytes; + let protobuf = protobuf::LogicalExprNode::decode(bytes).map_err(|e| { + plan_datafusion_err!("Error decoding expr as protobuf: {e}") + })?; + + let extension_codec = PlaceholderLogicalExtensionCodec {}; + logical_plan::from_proto::parse_expr( + &protobuf, + &TaskContext::default(), + &extension_codec, + ) + .map_err(|e| plan_datafusion_err!("Error parsing protobuf into Expr: {e}"))?; + } Ok(bytes) } - fn from_bytes_with_registry( - bytes: &[u8], - registry: &dyn FunctionRegistry, - ) -> Result { + fn from_bytes_with_ctx(bytes: &[u8], ctx: &TaskContext) -> Result { let protobuf = protobuf::LogicalExprNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; let extension_codec = DefaultLogicalExtensionCodec {}; - logical_plan::from_proto::parse_expr(&protobuf, registry, &extension_codec) + logical_plan::from_proto::parse_expr(&protobuf, ctx, &extension_codec) .map_err(|e| plan_datafusion_err!("Error parsing protobuf into Expr: {e}")) } } diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs deleted file mode 100644 index 087e073db21af..0000000000000 --- a/datafusion/proto/src/bytes/registry.rs +++ /dev/null @@ -1,70 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::{collections::HashSet, sync::Arc}; - -use datafusion_common::plan_err; -use datafusion_common::Result; -use datafusion_execution::registry::FunctionRegistry; -use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; - -/// A default [`FunctionRegistry`] registry that does not resolve any -/// user defined functions -pub(crate) struct NoRegistry {} - -impl FunctionRegistry for NoRegistry { - fn udfs(&self) -> HashSet { - HashSet::new() - } - - fn udf(&self, name: &str) -> Result> { - plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Function '{name}'") - } - - fn udaf(&self, name: &str) -> Result> { - plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Aggregate Function '{name}'") - } - - fn udwf(&self, name: &str) -> Result> { - plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Window Function '{name}'") - } - fn register_udaf( - &mut self, - udaf: Arc, - ) -> Result>> { - plan_err!("No function registry provided to deserialize, so can not register User Defined Aggregate Function '{}'", udaf.inner().name()) - } - fn register_udf(&mut self, udf: Arc) -> Result>> { - plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Function '{}'", udf.inner().name()) - } - fn register_udwf(&mut self, udwf: Arc) -> Result>> { - plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Window Function '{}'", udwf.inner().name()) - } - - fn expr_planners(&self) -> Vec> { - vec![] - } - - fn udafs(&self) -> HashSet { - HashSet::new() - } - - fn udwfs(&self) -> HashSet { - HashSet::new() - } -} diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 4cf834d0601e4..b014f8f0919be 
100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -11382,6 +11382,12 @@ impl serde::Serialize for LogicalExprNode { logical_expr_node::ExprType::Unnest(v) => { struct_ser.serialize_field("unnest", v)?; } + logical_expr_node::ExprType::ScalarSubquery(v) => { + struct_ser.serialize_field("scalarSubquery", v)?; + } + logical_expr_node::ExprType::OuterReferenceColumn(v) => { + struct_ser.serialize_field("outerReferenceColumn", v)?; + } } } struct_ser.end() @@ -11443,6 +11449,10 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "similarTo", "placeholder", "unnest", + "scalar_subquery", + "scalarSubquery", + "outer_reference_column", + "outerReferenceColumn", ]; #[allow(clippy::enum_variant_names)] @@ -11478,6 +11488,8 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { SimilarTo, Placeholder, Unnest, + ScalarSubquery, + OuterReferenceColumn, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11530,6 +11542,8 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "similarTo" | "similar_to" => Ok(GeneratedField::SimilarTo), "placeholder" => Ok(GeneratedField::Placeholder), "unnest" => Ok(GeneratedField::Unnest), + "scalarSubquery" | "scalar_subquery" => Ok(GeneratedField::ScalarSubquery), + "outerReferenceColumn" | "outer_reference_column" => Ok(GeneratedField::OuterReferenceColumn), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11767,6 +11781,20 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { return Err(serde::de::Error::duplicate_field("unnest")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Unnest) +; + } + GeneratedField::ScalarSubquery => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarSubquery")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarSubquery) 
+; + } + GeneratedField::OuterReferenceColumn => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("outerReferenceColumn")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::OuterReferenceColumn) ; } } @@ -13559,6 +13587,114 @@ impl<'de> serde::Deserialize<'de> for OptimizedPhysicalPlanType { deserializer.deserialize_struct("datafusion.OptimizedPhysicalPlanType", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for OuterReferenceColumn { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field.is_some() { + len += 1; + } + if self.column.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.OuterReferenceColumn", len)?; + if let Some(v) = self.field.as_ref() { + struct_ser.serialize_field("field", v)?; + } + if let Some(v) = self.column.as_ref() { + struct_ser.serialize_field("column", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for OuterReferenceColumn { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field", + "column", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Field, + Column, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "field" => 
Ok(GeneratedField::Field), + "column" => Ok(GeneratedField::Column), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = OuterReferenceColumn; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.OuterReferenceColumn") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field__ = None; + let mut column__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Field => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("field")); + } + field__ = map_.next_value()?; + } + GeneratedField::Column => { + if column__.is_some() { + return Err(serde::de::Error::duplicate_field("column")); + } + column__ = map_.next_value()?; + } + } + } + Ok(OuterReferenceColumn { + field: field__, + column: column__, + }) + } + } + deserializer.deserialize_struct("datafusion.OuterReferenceColumn", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ParquetScanExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -19998,6 +20134,115 @@ impl<'de> serde::Deserialize<'de> for RollupNode { deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ScalarSubquery { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.subquery.is_some() { + len += 1; + } + if !self.outer_ref_columns.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ScalarSubquery", len)?; + if let Some(v) = self.subquery.as_ref() { + struct_ser.serialize_field("subquery", v)?; + } + if 
!self.outer_ref_columns.is_empty() { + struct_ser.serialize_field("outerRefColumns", &self.outer_ref_columns)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarSubquery { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "subquery", + "outer_ref_columns", + "outerRefColumns", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Subquery, + OuterRefColumns, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "subquery" => Ok(GeneratedField::Subquery), + "outerRefColumns" | "outer_ref_columns" => Ok(GeneratedField::OuterRefColumns), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarSubquery; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ScalarSubquery") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut subquery__ = None; + let mut outer_ref_columns__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Subquery => { + if subquery__.is_some() { + return Err(serde::de::Error::duplicate_field("subquery")); + } + subquery__ = map_.next_value()?; + } + GeneratedField::OuterRefColumns => { + if outer_ref_columns__.is_some() { + return Err(serde::de::Error::duplicate_field("outerRefColumns")); + } + outer_ref_columns__ = Some(map_.next_value()?); + } + } + } + Ok(ScalarSubquery { + subquery: subquery__, + outer_ref_columns: outer_ref_columns__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.ScalarSubquery", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarUdfExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 12b4176274113..5cd61e9a48ba3 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -195,8 +195,8 @@ pub mod projection_node { pub struct SelectionNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, tag = "2")] - pub expr: ::core::option::Option, + #[prost(message, optional, boxed, tag = "2")] + pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortNode { @@ -382,8 +382,8 @@ pub struct JoinNode { pub right_join_key: ::prost::alloc::vec::Vec, #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] pub null_equality: i32, - #[prost(message, optional, tag = "8")] - pub filter: ::core::option::Option, + #[prost(message, optional, boxed, tag = "8")] + pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct DistinctNode { @@ -566,7 +566,7 @@ pub struct SubqueryAliasNode { pub struct LogicalExprNode { #[prost( oneof = "logical_expr_node::ExprType", - tags = "1, 2, 
3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37" )] pub expr_type: ::core::option::Option, } @@ -644,8 +644,26 @@ pub mod logical_expr_node { Placeholder(super::PlaceholderNode), #[prost(message, tag = "35")] Unnest(super::Unnest), + #[prost(message, tag = "36")] + ScalarSubquery(::prost::alloc::boxed::Box), + #[prost(message, tag = "37")] + OuterReferenceColumn(super::OuterReferenceColumn), } } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarSubquery { + #[prost(message, optional, boxed, tag = "1")] + pub subquery: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "2")] + pub outer_ref_columns: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct OuterReferenceColumn { + #[prost(message, optional, tag = "1")] + pub field: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub column: ::core::option::Option, +} #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Wildcard { #[prost(message, optional, tag = "1")] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 598a77f5420e2..4c71ca4c755f3 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -20,9 +20,10 @@ use std::sync::Arc; use arrow::datatypes::Field; use datafusion_common::{ exec_datafusion_err, internal_err, plan_datafusion_err, NullEquality, - RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, + RecursionUnnestOption, Result, ScalarValue, Spans, TableReference, UnnestOptions, }; use datafusion_execution::registry::FunctionRegistry; +use datafusion_execution::TaskContext; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{Alias, 
NullTreatment, Placeholder, Sort}; use datafusion_expr::expr::{Unnest, WildcardOptions}; @@ -34,12 +35,14 @@ use datafusion_expr::{ JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, }; -use datafusion_expr::{ExprFunctionExt, WriteOp}; +use datafusion_expr::{ExprFunctionExt, Subquery, WriteOp}; use datafusion_proto_common::{from_proto::FromOptionalField, FromProtoError as Error}; +use crate::logical_plan::AsLogicalPlan; use crate::protobuf::plan_type::PlanTypeEnum::{ FinalPhysicalPlanWithSchema, InitialPhysicalPlanWithSchema, }; +use crate::protobuf::OuterReferenceColumn; use crate::protobuf::{ self, plan_type::PlanTypeEnum::{ @@ -255,7 +258,7 @@ impl From for NullTreatment { pub fn parse_expr( proto: &protobuf::LogicalExprNode, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result { use protobuf::{logical_expr_node::ExprType, window_expr_node}; @@ -268,7 +271,7 @@ pub fn parse_expr( match expr_type { ExprType::BinaryExpr(binary_expr) => { let op = from_proto_binary_op(&binary_expr.op)?; - let operands = parse_exprs(&binary_expr.operands, registry, codec)?; + let operands = parse_exprs(&binary_expr.operands, ctx, codec)?; if operands.len() < 2 { return Err(proto_error( @@ -295,8 +298,8 @@ pub fn parse_expr( .window_function .as_ref() .ok_or_else(|| Error::required("window_function"))?; - let partition_by = parse_exprs(&expr.partition_by, registry, codec)?; - let mut order_by = parse_sorts(&expr.order_by, registry, codec)?; + let partition_by = parse_exprs(&expr.partition_by, ctx, codec)?; + let mut order_by = parse_sorts(&expr.order_by, ctx, codec)?; let window_frame = expr .window_frame .as_ref() @@ -328,7 +331,7 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, - None => registry + None => ctx .udaf(udaf_name) .or_else(|_| 
codec.try_decode_udaf(udaf_name, &[]))?, }; @@ -337,7 +340,7 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, - None => registry + None => ctx .udwf(udwf_name) .or_else(|_| codec.try_decode_udwf(udwf_name, &[]))?, }; @@ -345,7 +348,7 @@ pub fn parse_expr( } }; - let args = parse_exprs(&expr.exprs, registry, codec)?; + let args = parse_exprs(&expr.exprs, ctx, codec)?; let mut builder = Expr::from(WindowFunction::new(agg_fn, args)) .partition_by(partition_by) .order_by(order_by) @@ -356,8 +359,7 @@ pub fn parse_expr( builder = builder.distinct(); }; - if let Some(filter) = - parse_optional_expr(expr.filter.as_deref(), registry, codec)? + if let Some(filter) = parse_optional_expr(expr.filter.as_deref(), ctx, codec)? { builder = builder.filter(filter); } @@ -365,7 +367,7 @@ pub fn parse_expr( builder.build().map_err(Error::DataFusionError) } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( - parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(alias.expr.as_deref(), ctx, "expr", codec)?, alias .relation .first() @@ -375,69 +377,69 @@ pub fn parse_expr( ))), ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr( is_null.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::IsNotNull(Box::new( - parse_required_expr(is_not_null.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(is_not_null.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::NotExpr(not) => Ok(Expr::Not(Box::new(parse_required_expr( not.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsTrue(msg) => Ok(Expr::IsTrue(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsFalse(msg) => Ok(Expr::IsFalse(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", 
codec, )?))), ExprType::IsUnknown(msg) => Ok(Expr::IsUnknown(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotTrue(msg) => Ok(Expr::IsNotTrue(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotFalse(msg) => Ok(Expr::IsNotFalse(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotUnknown(msg) => Ok(Expr::IsNotUnknown(Box::new( - parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(msg.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::Between(between) => Ok(Expr::Between(Between::new( Box::new(parse_required_expr( between.expr.as_deref(), - registry, + ctx, "expr", codec, )?), between.negated, Box::new(parse_required_expr( between.low.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( between.high.as_deref(), - registry, + ctx, "expr", codec, )?), @@ -446,13 +448,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -463,13 +465,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -480,13 +482,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -500,13 +502,13 @@ pub fn parse_expr( .map(|e| { let when_expr = parse_required_expr( e.when_expr.as_ref(), - registry, + ctx, "when_expr", codec, )?; let then_expr = parse_required_expr( e.then_expr.as_ref(), - registry, + ctx, "then_expr", codec, )?; @@ -514,16 +516,15 
@@ pub fn parse_expr( }) .collect::, Box)>, Error>>()?; Ok(Expr::Case(Case::new( - parse_optional_expr(case.expr.as_deref(), registry, codec)?.map(Box::new), + parse_optional_expr(case.expr.as_deref(), ctx, codec)?.map(Box::new), when_then_expr, - parse_optional_expr(case.else_expr.as_deref(), registry, codec)? - .map(Box::new), + parse_optional_expr(case.else_expr.as_deref(), ctx, codec)?.map(Box::new), ))) } ExprType::Cast(cast) => { let expr = Box::new(parse_required_expr( cast.expr.as_deref(), - registry, + ctx, "expr", codec, )?); @@ -533,7 +534,7 @@ pub fn parse_expr( ExprType::TryCast(cast) => { let expr = Box::new(parse_required_expr( cast.expr.as_deref(), - registry, + ctx, "expr", codec, )?); @@ -541,10 +542,10 @@ pub fn parse_expr( Ok(Expr::TryCast(TryCast::new(expr, data_type))) } ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( - parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(negative.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::Unnest(unnest) => { - let mut exprs = parse_exprs(&unnest.exprs, registry, codec)?; + let mut exprs = parse_exprs(&unnest.exprs, ctx, codec)?; if exprs.len() != 1 { return Err(proto_error("Unnest must have exactly one expression")); } @@ -553,11 +554,11 @@ pub fn parse_expr( ExprType::InList(in_list) => Ok(Expr::InList(InList::new( Box::new(parse_required_expr( in_list.expr.as_deref(), - registry, + ctx, "expr", codec, )?), - parse_exprs(&in_list.list, registry, codec)?, + parse_exprs(&in_list.list, ctx, codec)?, in_list.negated, ))), ExprType::Wildcard(protobuf::Wildcard { qualifier }) => { @@ -575,19 +576,19 @@ pub fn parse_expr( }) => { let scalar_fn = match fun_definition { Some(buf) => codec.try_decode_udf(fun_name, buf)?, - None => registry + None => ctx .udf(fun_name.as_str()) .or_else(|_| codec.try_decode_udf(fun_name, &[]))?, }; Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, - parse_exprs(args, registry, codec)?, + parse_exprs(args, 
ctx, codec)?, ))) } ExprType::AggregateUdfExpr(pb) => { let agg_fn = match &pb.fun_definition { Some(buf) => codec.try_decode_udaf(&pb.fun_name, buf)?, - None => registry + None => ctx .udaf(&pb.fun_name) .or_else(|_| codec.try_decode_udaf(&pb.fun_name, &[]))?, }; @@ -606,26 +607,25 @@ pub fn parse_expr( Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, - parse_exprs(&pb.args, registry, codec)?, + parse_exprs(&pb.args, ctx, codec)?, pb.distinct, - parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), - parse_sorts(&pb.order_by, registry, codec)?, + parse_optional_expr(pb.filter.as_deref(), ctx, codec)?.map(Box::new), + parse_sorts(&pb.order_by, ctx, codec)?, null_treatment, ))) } - ExprType::GroupingSet(GroupingSetNode { expr }) => { Ok(Expr::GroupingSet(GroupingSets( expr.iter() - .map(|expr_list| parse_exprs(&expr_list.expr, registry, codec)) + .map(|expr_list| parse_exprs(&expr_list.expr, ctx, codec)) .collect::, Error>>()?, ))) } ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube( - parse_exprs(expr, registry, codec)?, + parse_exprs(expr, ctx, codec)?, ))), ExprType::Rollup(RollupNode { expr }) => Ok(Expr::GroupingSet( - GroupingSet::Rollup(parse_exprs(expr, registry, codec)?), + GroupingSet::Rollup(parse_exprs(expr, ctx, codec)?), )), ExprType::Placeholder(PlaceholderNode { id, @@ -647,13 +647,33 @@ pub fn parse_expr( ))) } }, + ExprType::ScalarSubquery(scalar_subquery) => { + let subquery = scalar_subquery + .subquery + .as_ref() + .ok_or_else(|| Error::required("subquery"))?; + Ok(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(subquery.try_into_logical_plan(ctx, codec)?), + outer_ref_columns: parse_exprs( + &scalar_subquery.outer_ref_columns, + ctx, + codec, + )?, + spans: Spans::new(), + })) + } + ExprType::OuterReferenceColumn(OuterReferenceColumn { field, column }) => { + let column = column.to_owned().ok_or_else(|| Error::required("column"))?; + let field = 
field.as_ref().required("field")?; + Ok(Expr::OuterReferenceColumn(Arc::new(field), column.into())) + } } } /// Parse a vector of `protobuf::LogicalExprNode`s. pub fn parse_exprs<'a, I>( protos: I, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> where @@ -662,7 +682,7 @@ where let res = protos .into_iter() .map(|elem| { - parse_expr(elem, registry, codec).map_err(|e| plan_datafusion_err!("{}", e)) + parse_expr(elem, ctx, codec).map_err(|e| plan_datafusion_err!("{}", e)) }) .collect::>>()?; Ok(res) @@ -670,7 +690,7 @@ where pub fn parse_sorts<'a, I>( protos: I, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> where @@ -678,17 +698,17 @@ where { protos .into_iter() - .map(|sort| parse_sort(sort, registry, codec)) + .map(|sort| parse_sort(sort, ctx, codec)) .collect::, Error>>() } pub fn parse_sort( sort: &protobuf::SortExprNode, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result { Ok(Sort::new( - parse_required_expr(sort.expr.as_ref(), registry, "expr", codec)?, + parse_required_expr(sort.expr.as_ref(), ctx, "expr", codec)?, sort.asc, sort.nulls_first, )) @@ -740,23 +760,23 @@ pub fn from_proto_binary_op(op: &str) -> Result { fn parse_optional_expr( p: Option<&protobuf::LogicalExprNode>, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> { match p { - Some(expr) => parse_expr(expr, registry, codec).map(Some), + Some(expr) => parse_expr(expr, ctx, codec).map(Some), None => Ok(None), } } fn parse_required_expr( p: Option<&protobuf::LogicalExprNode>, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, field: impl Into, codec: &dyn LogicalExtensionCodec, ) -> Result { match p { - Some(expr) => parse_expr(expr, registry, codec), + Some(expr) => parse_expr(expr, ctx, codec), None => Err(Error::required(field)), } } diff --git 
a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 9644c9f69feae..2176f7919c1f5 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1230,10 +1230,10 @@ impl AsLogicalPlan for LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Selection(Box::new( protobuf::SelectionNode { input: Some(Box::new(input)), - expr: Some(serialize_expr( + expr: Some(Box::new(serialize_expr( &filter.predicate, extension_codec, - )?), + )?)), }, ))), }) @@ -1350,7 +1350,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filter = filter .as_ref() .map(|e| serialize_expr(e, extension_codec)) - .map_or(Ok(None), |v| v.map(Some))?; + .map_or(Ok(None), |v| v.map(|v| Some(Box::new(v))))?; Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( protobuf::JoinNode { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2774b5b6ba7c3..28453eda0dcb7 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -27,14 +27,14 @@ use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, NullTreatment, Placeholder, ScalarFunction, Unnest, }; -use datafusion_expr::WriteOp; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, Expr, JoinConstraint, JoinType, SortExpr, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use datafusion_expr::{Subquery, WriteOp}; -use crate::protobuf::RecursionUnnestOption; +use crate::logical_plan::AsLogicalPlan; use crate::protobuf::{ self, plan_type::PlanTypeEnum::{ @@ -48,6 +48,7 @@ use crate::protobuf::{ OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, ToProtoError as Error, }; +use crate::protobuf::{OuterReferenceColumn, RecursionUnnestOption}; use super::LogicalExtensionCodec; @@ -576,13 
+577,30 @@ pub fn serialize_expr( qualifier: qualifier.to_owned().map(|x| x.into()), })), }, - Expr::ScalarSubquery(_) - | Expr::InSubquery(_) - | Expr::Exists { .. } - | Expr::OuterReferenceColumn { .. } => { - // we would need to add logical plan operators to datafusion.proto to support this - // see discussion in https://github.com/apache/datafusion/issues/2565 - return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string())); + Expr::ScalarSubquery(Subquery { + subquery, + outer_ref_columns, + spans: _, + }) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::ScalarSubquery(Box::new( + protobuf::ScalarSubquery { + subquery: Some(Box::new( + protobuf::LogicalPlanNode::try_from_logical_plan(subquery, codec) + .map_err(|e| Error::General(format!("Proto serialization error: Failed to serialize Scalar Subquery: {e}")))?, + )), + outer_ref_columns: serialize_exprs(outer_ref_columns, codec)?, + }, + ))), + }, + Expr::OuterReferenceColumn(field, column) => { + protobuf::LogicalExprNode { + expr_type: Some(ExprType::OuterReferenceColumn(OuterReferenceColumn{ + field: Some(field.as_ref().try_into()?), column: Some(column.into()) + })) + } + } + Expr::InSubquery(_) | Expr::Exists { .. } => { + return Err(Error::NotImplemented("Proto serialization error: Expr::InSubquery(_) | Expr::Exists { .. 
} not supported".to_string())); } Expr::GroupingSet(GroupingSet::Cube(exprs)) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Cube(CubeNode { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index bfd693e6a0f83..cc717e161b882 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -132,7 +132,8 @@ fn roundtrip_expr_test_with_codec( ) { let proto: protobuf::LogicalExprNode = serialize_expr(&initial_struct, codec) .unwrap_or_else(|e| panic!("Error serializing expression: {e:?}")); - let round_trip: Expr = from_proto::parse_expr(&proto, &ctx, codec).unwrap(); + let round_trip: Expr = + from_proto::parse_expr(&proto, &ctx.task_ctx(), codec).unwrap(); assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}")); @@ -2349,8 +2350,8 @@ fn roundtrip_scalar_udf_extension_codec() { let test_expr = udf.call(vec!["foo".lit()]); let ctx = SessionContext::new(); let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); - let round_trip = - from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); + let round_trip = from_proto::parse_expr(&proto, &ctx.task_ctx(), &UDFExtensionCodec) + .expect("parse expr"); assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); roundtrip_json_test(&proto); @@ -2362,8 +2363,8 @@ fn roundtrip_aggregate_udf_extension_codec() { let test_expr = udf.call(vec![42.lit()]); let ctx = SessionContext::new(); let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); - let round_trip = - from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); + let round_trip = from_proto::parse_expr(&proto, &ctx.task_ctx(), &UDFExtensionCodec) + .expect("parse expr"); assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); roundtrip_json_test(&proto); @@ -2879,3 +2880,26 @@ async fn 
roundtrip_mixed_case_table_reference() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn roundtrip_scalar_subquery_with_outer_reference_column() -> Result<()> { + let query = "SELECT t1.a, t1.b + FROM t1 + WHERE t1.a + 1 = ( + SELECT t2.a FROM t2 WHERE t1.a > 0 LIMIT 1 + )"; + + let ctx = SessionContext::new(); + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + ctx.register_csv("t2", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + let dataframe = ctx.sql(query).await?; + let plan = dataframe.logical_plan(); + + let bytes = logical_plan_to_bytes(plan)?; + + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + assert_eq!(plan, &logical_round_trip); + Ok(()) +} diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index f45a62e948740..5d81e11d39a5a 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -77,7 +77,7 @@ fn udf_roundtrip_with_registry() { .call(vec![lit("")]); let bytes = expr.to_bytes().unwrap(); - let deserialized_expr = Expr::from_bytes_with_registry(&bytes, &ctx).unwrap(); + let deserialized_expr = Expr::from_bytes_with_ctx(&bytes, &ctx.task_ctx()).unwrap(); assert_eq!(expr, deserialized_expr); } @@ -281,7 +281,7 @@ fn test_expression_serialization_roundtrip() { let extension_codec = DefaultLogicalExtensionCodec {}; let proto = serialize_expr(&expr, &extension_codec).unwrap(); - let deserialize = parse_expr(&proto, &ctx, &extension_codec).unwrap(); + let deserialize = parse_expr(&proto, &ctx.task_ctx(), &extension_codec).unwrap(); let serialize_name = extract_function_name(&expr); let deserialize_name = extract_function_name(&deserialize);