diff --git a/datafusion-examples/examples/composed_extension_codec.rs b/datafusion-examples/examples/composed_extension_codec.rs index 57f2c370413aa..c4004063172ce 100644 --- a/datafusion-examples/examples/composed_extension_codec.rs +++ b/datafusion-examples/examples/composed_extension_codec.rs @@ -39,6 +39,7 @@ use datafusion::common::Result; use datafusion::execution::TaskContext; use datafusion::physical_plan::{DisplayAs, ExecutionPlan}; use datafusion::prelude::SessionContext; +use datafusion_proto::physical_plan::DecodeContext; use datafusion_proto::physical_plan::{ AsExecutionPlan, ComposedPhysicalExtensionCodec, PhysicalExtensionCodec, }; @@ -71,7 +72,7 @@ async fn main() { // deserialize proto back to execution plan let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx.task_ctx(), &composed_codec) + .try_into_physical_plan(&DecodeContext::new(&ctx.task_ctx()), &composed_codec) .expect("from proto"); // assert that the original and deserialized execution plans are equal @@ -137,7 +138,7 @@ impl PhysicalExtensionCodec for ParentPhysicalExtensionCodec { &self, buf: &[u8], inputs: &[Arc], - _ctx: &TaskContext, + _ctx: &DecodeContext, ) -> Result> { if buf == "ParentExec".as_bytes() { Ok(Arc::new(ParentExec { @@ -213,7 +214,7 @@ impl PhysicalExtensionCodec for ChildPhysicalExtensionCodec { &self, buf: &[u8], _inputs: &[Arc], - _ctx: &TaskContext, + _ctx: &DecodeContext, ) -> Result> { if buf == "ChildExec".as_bytes() { Ok(Arc::new(ChildExec {})) diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs index 48c2698a58c75..db230ce357608 100644 --- a/datafusion/ffi/src/plan_properties.rs +++ b/datafusion/ffi/src/plan_properties.rs @@ -38,7 +38,7 @@ use datafusion_proto::{ physical_plan::{ from_proto::{parse_physical_sort_exprs, parse_protobuf_partitioning}, to_proto::{serialize_partitioning, serialize_physical_sort_exprs}, - DefaultPhysicalExtensionCodec, + DecodeContext, DefaultPhysicalExtensionCodec, }, protobuf::{Partitioning, PhysicalSortExprNodeCollection}, }; @@ -183,6 +183,7 @@ impl TryFrom for PlanProperties { let default_ctx = SessionContext::new(); let task_context = default_ctx.task_ctx(); let codex = DefaultPhysicalExtensionCodec {}; + let decode_context = DecodeContext::new(&task_context); let ffi_orderings = unsafe { (ffi_props.output_ordering)(&ffi_props) }; @@ -191,7 +192,7 @@ impl TryFrom for PlanProperties { .map_err(|e| DataFusionError::External(Box::new(e)))?; let sort_exprs = parse_physical_sort_exprs( &proto_output_ordering.physical_sort_expr_nodes, - &task_context, + &decode_context, &schema, &codex, )?; @@ -203,7 +204,7 @@ impl TryFrom for PlanProperties { .map_err(|e| DataFusionError::External(Box::new(e)))?; let partitioning = parse_protobuf_partitioning( Some(&proto_output_partitioning), - &task_context, + &decode_context, &schema, &codex, )? diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index 0302c26a2e6b5..00885bc68fd04 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -35,7 +35,7 @@ use datafusion_proto::{ physical_plan::{ from_proto::{parse_physical_exprs, parse_physical_sort_exprs}, to_proto::{serialize_physical_exprs, serialize_physical_sort_exprs}, - DefaultPhysicalExtensionCodec, + DecodeContext, DefaultPhysicalExtensionCodec, }, protobuf::PhysicalAggregateExprNode, }; @@ -121,16 +121,17 @@ impl TryFrom for ForeignAccumulatorArgs { let default_ctx = SessionContext::new(); let task_ctx = default_ctx.task_ctx(); + let decode_ctx = DecodeContext::new(&task_ctx); let codex = DefaultPhysicalExtensionCodec {}; let order_bys = parse_physical_sort_exprs( &proto_def.ordering_req, - &task_ctx, + &decode_ctx, &schema, &codex, )?; - let exprs = parse_physical_exprs(&proto_def.expr, &task_ctx, &schema, &codex)?; + let exprs = parse_physical_exprs(&proto_def.expr, &decode_ctx, &schema, &codex)?; Ok(Self { return_field, diff --git a/datafusion/ffi/src/udwf/partition_evaluator_args.rs b/datafusion/ffi/src/udwf/partition_evaluator_args.rs index cd26412564374..2b41c7e98ee8c 100644 --- a/datafusion/ffi/src/udwf/partition_evaluator_args.rs +++ b/datafusion/ffi/src/udwf/partition_evaluator_args.rs @@ -35,7 +35,7 @@ use datafusion_common::exec_datafusion_err; use datafusion_proto::{ physical_plan::{ from_proto::parse_physical_expr, to_proto::serialize_physical_exprs, - DefaultPhysicalExtensionCodec, + DecodeContext, DefaultPhysicalExtensionCodec, }, protobuf::PhysicalExprNode, }; @@ -137,6 +137,8 @@ impl TryFrom for ForeignPartitionEvaluatorArgs { fn try_from(value: FFI_PartitionEvaluatorArgs) -> Result { let default_ctx = SessionContext::new(); + let task_ctx = default_ctx.task_ctx(); + let decode_ctx = DecodeContext::new(&task_ctx); let codec = DefaultPhysicalExtensionCodec {}; let schema: SchemaRef = value.schema.into(); @@ -148,9 +150,7 @@ impl TryFrom for ForeignPartitionEvaluatorArgs { .collect::, prost::DecodeError>>() .map_err(|e| exec_datafusion_err!("Failed to decode PhysicalExprNode: {e}"))? .iter() - .map(|expr_node| { - parse_physical_expr(expr_node, &default_ctx.task_ctx(), &schema, &codec) - }) + .map(|expr_node| parse_physical_expr(expr_node, &decode_ctx, &schema, &codec)) .collect::>>()?; let input_fields = input_exprs diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 11103472ae2ae..4a38b62cefbea 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -865,6 +865,14 @@ message PhysicalExprNode { UnknownColumn unknown_column = 20; } + + // Optional ID for caching during deserialization. This is used for deduplication, + // so PhysicalExprs with the same ID will be deserialized as Arcs pointing to the + // same address (instead of distinct addresses) on the deserializing machine. + // + // We use the Arc pointer address during serialization as the ID, as this by default + // indicates if a PhysicalExpr is identical to another on the serializing machine. + optional uint64 id = 21; } message PhysicalScalarUdfNode { diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 12d9938373ce6..58a6aff442463 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -21,7 +21,7 @@ use crate::logical_plan::{ self, AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec, }; use crate::physical_plan::{ - AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, + AsExecutionPlan, DecodeContext, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; use crate::protobuf; use datafusion_common::{plan_datafusion_err, Result}; @@ -313,7 +313,8 @@ pub fn physical_plan_from_json( let back: protobuf::PhysicalPlanNode = serde_json::from_str(json) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; let extension_codec = DefaultPhysicalExtensionCodec {}; - back.try_into_physical_plan(&ctx, &extension_codec) + let decode_ctx = DecodeContext::new(ctx); + back.try_into_physical_plan(&decode_ctx, &extension_codec) } /// Deserialize a PhysicalPlan from bytes @@ -333,5 +334,6 @@ pub fn physical_plan_from_bytes_with_extension_codec( ) -> Result> { let protobuf = protobuf::PhysicalPlanNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; - protobuf.try_into_physical_plan(ctx, extension_codec) + let decode_ctx = DecodeContext::new(ctx); + protobuf.try_into_physical_plan(&decode_ctx, extension_codec) } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b34da2c312de0..e1d764760ba63 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -15711,10 +15711,18 @@ impl serde::Serialize for PhysicalExprNode { { use serde::ser::SerializeStruct; let mut len = 0; + if self.id.is_some() { + len += 1; + } if self.expr_type.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExprNode", len)?; + if let Some(v) = self.id.as_ref() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("id", ToString::to_string(&v).as_str())?; + } if let Some(v) = self.expr_type.as_ref() { match v { physical_expr_node::ExprType::Column(v) => { @@ -15783,6 +15791,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "id", "column", "literal", "binary_expr", @@ -15817,6 +15826,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { #[allow(clippy::enum_variant_names)] enum GeneratedField { + Id, Column, Literal, BinaryExpr, @@ -15856,6 +15866,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { E: serde::de::Error, { match value { + "id" => Ok(GeneratedField::Id), "column" => Ok(GeneratedField::Column), "literal" => Ok(GeneratedField::Literal), "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), @@ -15893,9 +15904,18 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { where V: serde::de::MapAccess<'de>, { + let mut id__ = None; let mut expr_type__ = None; while let Some(k) = map_.next_key()? { match k { + GeneratedField::Id => { + if id__.is_some() { + return Err(serde::de::Error::duplicate_field("id")); + } + id__ = + map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) + ; + } GeneratedField::Column => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("column")); @@ -16025,6 +16045,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { } } Ok(PhysicalExprNode { + id: id__, expr_type: expr_type__, }) } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2e1c482db65c4..807c7ade647ff 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1259,6 +1259,14 @@ pub struct PhysicalExtensionNode { /// physical expressions #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalExprNode { + /// Optional ID for caching during deserialization. This is used for deduplication, + /// so PhysicalExprs with the same ID will be deserialized as Arcs pointing to the + /// same address (instead of distinct addresses) on the deserializing machine. + /// + /// We use the Arc pointer address during serialization as the ID, as this by default + /// indicates if a PhysicalExpr is identical to another on the serializing machine. + #[prost(uint64, optional, tag = "21")] + pub id: ::core::option::Option, #[prost( oneof = "physical_expr_node::ExprType", tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20" diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 7c4b9e55b8137..5020c21e2a81a 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -40,7 +40,7 @@ use datafusion_datasource_json::file_format::JsonSink; #[cfg(feature = "parquet")] use datafusion_datasource_parquet::file_format::ParquetSink; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_execution::{FunctionRegistry, TaskContext}; +use datafusion_execution::FunctionRegistry; use datafusion_expr::WindowFunctionDefinition; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion_physical_plan::expressions::{ @@ -56,7 +56,7 @@ use crate::logical_plan::{self}; use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; -use super::PhysicalExtensionCodec; +use super::{DecodeContext, PhysicalExtensionCodec}; impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { @@ -75,12 +75,12 @@ impl From<&protobuf::PhysicalColumn> for Column { /// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_sort_expr( proto: &protobuf::PhysicalSortExprNode, - ctx: &TaskContext, + decode_ctx: &DecodeContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result { if let Some(expr) = &proto.expr { - let expr = parse_physical_expr(expr.as_ref(), ctx, input_schema, codec)?; + let expr = parse_physical_expr(expr.as_ref(), decode_ctx, input_schema, codec)?; let options = SortOptions { descending: !proto.asc, nulls_first: proto.nulls_first, @@ -102,13 +102,15 @@ pub fn parse_physical_sort_expr( /// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_sort_exprs( proto: &[protobuf::PhysicalSortExprNode], - ctx: &TaskContext, + decode_ctx: &DecodeContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { proto .iter() - .map(|sort_expr| parse_physical_sort_expr(sort_expr, ctx, input_schema, codec)) + .map(|sort_expr| { + parse_physical_sort_expr(sort_expr, decode_ctx, input_schema, codec) + }) .collect() } @@ -124,15 +126,18 @@ pub fn parse_physical_sort_exprs( /// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_window_expr( proto: &protobuf::PhysicalWindowExprNode, - ctx: &TaskContext, + decode_ctx: &DecodeContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let window_node_expr = parse_physical_exprs(&proto.args, ctx, input_schema, codec)?; + let ctx = decode_ctx.task_context(); + let window_node_expr = + parse_physical_exprs(&proto.args, decode_ctx, input_schema, codec)?; let partition_by = - parse_physical_exprs(&proto.partition_by, ctx, input_schema, codec)?; + parse_physical_exprs(&proto.partition_by, decode_ctx, input_schema, codec)?; - let order_by = parse_physical_sort_exprs(&proto.order_by, ctx, input_schema, codec)?; + let order_by = + parse_physical_sort_exprs(&proto.order_by, decode_ctx, input_schema, codec)?; let window_frame = proto .window_frame @@ -183,7 +188,7 @@ pub fn parse_physical_window_expr( pub fn parse_physical_exprs<'a, I>( protos: I, - ctx: &TaskContext, + decode_ctx: &DecodeContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result>> @@ -192,7 +197,7 @@ where { protos .into_iter() - .map(|p| parse_physical_expr(p, ctx, input_schema, codec)) + .map(|p| parse_physical_expr(p, decode_ctx, input_schema, codec)) .collect::>>() } @@ -207,10 +212,19 @@ where /// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_expr( proto: &protobuf::PhysicalExprNode, - ctx: &TaskContext, + decode_ctx: &DecodeContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { + // Check cache first if an ID is present + if let Some(id) = proto.id { + if let Some(cached) = decode_ctx.get_cached_expr(id) { + return Ok(cached); + } + } + + let ctx = decode_ctx.task_context(); + let expr_type = proto .expr_type .as_ref() @@ -226,7 +240,7 @@ pub fn parse_physical_expr( ExprType::BinaryExpr(binary_expr) => Arc::new(BinaryExpr::new( parse_required_physical_expr( binary_expr.l.as_deref(), - ctx, + decode_ctx, "left", input_schema, codec, @@ -234,7 +248,7 @@ pub fn parse_physical_expr( logical_plan::from_proto::from_proto_binary_op(&binary_expr.op)?, parse_required_physical_expr( binary_expr.r.as_deref(), - ctx, + decode_ctx, "right", input_schema, codec, @@ -256,7 +270,7 @@ pub fn parse_physical_expr( ExprType::IsNullExpr(e) => { Arc::new(IsNullExpr::new(parse_required_physical_expr( e.expr.as_deref(), - ctx, + decode_ctx, "expr", input_schema, codec, @@ -265,7 +279,7 @@ pub fn parse_physical_expr( ExprType::IsNotNullExpr(e) => { Arc::new(IsNotNullExpr::new(parse_required_physical_expr( e.expr.as_deref(), - ctx, + decode_ctx, "expr", input_schema, codec, @@ -273,7 +287,7 @@ pub fn parse_physical_expr( } ExprType::NotExpr(e) => Arc::new(NotExpr::new(parse_required_physical_expr( e.expr.as_deref(), - ctx, + decode_ctx, "expr", input_schema, codec, @@ -281,7 +295,7 @@ pub fn parse_physical_expr( ExprType::Negative(e) => { Arc::new(NegativeExpr::new(parse_required_physical_expr( e.expr.as_deref(), - ctx, + decode_ctx, "expr", input_schema, codec, @@ -290,19 +304,19 @@ pub fn parse_physical_expr( ExprType::InList(e) => in_list( parse_required_physical_expr( e.expr.as_deref(), - ctx, + decode_ctx, "expr", input_schema, codec, )?, - parse_physical_exprs(&e.list, ctx, input_schema, codec)?, + parse_physical_exprs(&e.list, decode_ctx, input_schema, codec)?, &e.negated, input_schema, )?, ExprType::Case(e) => Arc::new(CaseExpr::try_new( e.expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), ctx, input_schema, codec)) + .map(|e| parse_physical_expr(e.as_ref(), decode_ctx, input_schema, codec)) .transpose()?, e.when_then_expr .iter() @@ -310,14 +324,14 @@ pub fn parse_physical_expr( Ok(( parse_required_physical_expr( e.when_expr.as_ref(), - ctx, + decode_ctx, "when_expr", input_schema, codec, )?, parse_required_physical_expr( e.then_expr.as_ref(), - ctx, + decode_ctx, "then_expr", input_schema, codec, @@ -327,13 +341,13 @@ pub fn parse_physical_expr( .collect::>>()?, e.else_expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), ctx, input_schema, codec)) + .map(|e| parse_physical_expr(e.as_ref(), decode_ctx, input_schema, codec)) .transpose()?, )?), ExprType::Cast(e) => Arc::new(CastExpr::new( parse_required_physical_expr( e.expr.as_deref(), - ctx, + decode_ctx, "expr", input_schema, codec, @@ -344,7 +358,7 @@ pub fn parse_physical_expr( ExprType::TryCast(e) => Arc::new(TryCastExpr::new( parse_required_physical_expr( e.expr.as_deref(), - ctx, + decode_ctx, "expr", input_schema, codec, @@ -360,7 +374,7 @@ pub fn parse_physical_expr( }; let scalar_fun_def = Arc::clone(&udf); - let args = parse_physical_exprs(&e.args, ctx, input_schema, codec)?; + let args = parse_physical_exprs(&e.args, decode_ctx, input_schema, codec)?; let config_options = Arc::clone(ctx.session_config().options()); @@ -385,14 +399,14 @@ pub fn parse_physical_expr( like_expr.case_insensitive, parse_required_physical_expr( like_expr.expr.as_deref(), - ctx, + decode_ctx, "expr", input_schema, codec, )?, parse_required_physical_expr( like_expr.pattern.as_deref(), - ctx, + decode_ctx, "pattern", input_schema, codec, @@ -402,37 +416,46 @@ pub fn parse_physical_expr( let inputs: Vec> = extension .inputs .iter() - .map(|e| parse_physical_expr(e, ctx, input_schema, codec)) + .map(|e| parse_physical_expr(e, decode_ctx, input_schema, codec)) .collect::>()?; (codec.try_decode_expr(extension.expr.as_slice(), &inputs)?) as _ } }; - Ok(pexpr) + // Insert into cache if an ID is present + if let Some(id) = proto.id { + Ok(decode_ctx.insert_cached_expr(id, pexpr)) + } else { + Ok(pexpr) + } } fn parse_required_physical_expr( expr: Option<&protobuf::PhysicalExprNode>, - ctx: &TaskContext, + decode_ctx: &DecodeContext, field: &str, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { - expr.map(|e| parse_physical_expr(e, ctx, input_schema, codec)) + expr.map(|e| parse_physical_expr(e, decode_ctx, input_schema, codec)) .transpose()? .ok_or_else(|| internal_datafusion_err!("Missing required field {field:?}")) } pub fn parse_protobuf_hash_partitioning( partitioning: Option<&protobuf::PhysicalHashRepartition>, - ctx: &TaskContext, + decode_ctx: &DecodeContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { match partitioning { Some(hash_part) => { - let expr = - parse_physical_exprs(&hash_part.hash_expr, ctx, input_schema, codec)?; + let expr = parse_physical_exprs( + &hash_part.hash_expr, + decode_ctx, + input_schema, + codec, + )?; Ok(Some(Partitioning::Hash( expr, @@ -445,7 +468,7 @@ pub fn parse_protobuf_hash_partitioning( pub fn parse_protobuf_partitioning( partitioning: Option<&protobuf::Partitioning>, - ctx: &TaskContext, + decode_ctx: &DecodeContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { @@ -459,7 +482,7 @@ pub fn parse_protobuf_partitioning( Some(protobuf::partitioning::PartitionMethod::Hash(hash_repartition)) => { parse_protobuf_hash_partitioning( Some(hash_repartition), - ctx, + decode_ctx, input_schema, codec, ) @@ -483,7 +506,7 @@ pub fn parse_protobuf_file_scan_schema( pub fn parse_protobuf_file_scan_config( proto: &protobuf::FileScanExecConf, - ctx: &TaskContext, + decode_ctx: &DecodeContext, codec: &dyn PhysicalExtensionCodec, file_source: Arc, ) -> Result { @@ -534,7 +557,7 @@ pub fn parse_protobuf_file_scan_config( for node_collection in &proto.output_ordering { let sort_exprs = parse_physical_sort_exprs( &node_collection.physical_sort_expr_nodes, - ctx, + decode_ctx, &schema, codec, )?; diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index cd9bd672851d0..4039773351539 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::fmt::Debug; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use self::from_proto::parse_protobuf_partitioning; use self::to_proto::{serialize_partitioning, serialize_physical_expr}; @@ -101,6 +102,49 @@ use datafusion_physical_plan::{ExecutionPlan, InputOrderMode, PhysicalExpr, Wind use prost::bytes::BufMut; use prost::Message; +/// Context for decoding physical expressions with caching support. +/// +/// This struct wraps a `TaskContext` and maintains a cache of previously deserialized +/// physical expressions. The cache is keyed by the expression's ID (derived from the +/// Arc pointer during serialization), allowing duplicate expressions in a plan to be +/// deserialized only once. +pub struct DecodeContext<'a> { + task_context: &'a TaskContext, + cache: Mutex>>, +} + +impl<'a> DecodeContext<'a> { + /// Create a new DecodeContext wrapping the given TaskContext. + pub fn new(task_context: &'a TaskContext) -> Self { + Self { + task_context, + cache: Mutex::new(HashMap::new()), + } + } + + /// Get the underlying TaskContext reference. + pub fn task_context(&self) -> &'a TaskContext { + self.task_context + } + + /// Attempt to retrieve a cached physical expression by its ID. + pub fn get_cached_expr(&self, id: u64) -> Option> { + self.cache.lock().unwrap().get(&id).cloned() + } + + /// Insert a physical expression into the cache with the given ID. + /// Returns the inserted expression or an existing one if the ID was already present. + #[must_use] + pub fn insert_cached_expr( + &self, + id: u64, + expr: Arc, + ) -> Arc { + let mut cache = self.cache.lock().unwrap(); + Arc::clone(cache.entry(id).or_insert(expr)) + } +} + pub mod from_proto; pub mod to_proto; @@ -126,8 +170,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { fn try_into_physical_plan( &self, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let plan = self.physical_plan_type.as_ref().ok_or_else(|| { @@ -137,117 +180,130 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { })?; match plan { PhysicalPlanType::Explain(explain) => { - self.try_into_explain_physical_plan(explain, ctx, extension_codec) - } - PhysicalPlanType::Projection(projection) => { - self.try_into_projection_physical_plan(projection, ctx, extension_codec) + self.try_into_explain_physical_plan(explain, decode_ctx, extension_codec) } + PhysicalPlanType::Projection(projection) => self + .try_into_projection_physical_plan( + projection, + decode_ctx, + extension_codec, + ), PhysicalPlanType::Filter(filter) => { - self.try_into_filter_physical_plan(filter, ctx, extension_codec) + self.try_into_filter_physical_plan(filter, decode_ctx, extension_codec) } PhysicalPlanType::CsvScan(scan) => { - self.try_into_csv_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_csv_scan_physical_plan(scan, decode_ctx, extension_codec) } PhysicalPlanType::JsonScan(scan) => { - self.try_into_json_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_json_scan_physical_plan(scan, decode_ctx, extension_codec) } #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] - PhysicalPlanType::ParquetScan(scan) => { - self.try_into_parquet_scan_physical_plan(scan, ctx, extension_codec) - } + PhysicalPlanType::ParquetScan(scan) => self + .try_into_parquet_scan_physical_plan(scan, decode_ctx, extension_codec), #[cfg_attr(not(feature = "avro"), allow(unused_variables))] PhysicalPlanType::AvroScan(scan) => { - self.try_into_avro_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_avro_scan_physical_plan(scan, decode_ctx, extension_codec) } PhysicalPlanType::MemoryScan(scan) => { - self.try_into_memory_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_memory_scan_physical_plan(scan, decode_ctx, extension_codec) } PhysicalPlanType::CoalesceBatches(coalesce_batches) => self .try_into_coalesce_batches_physical_plan( coalesce_batches, - ctx, + decode_ctx, extension_codec, ), PhysicalPlanType::Merge(merge) => { - self.try_into_merge_physical_plan(merge, ctx, extension_codec) - } - PhysicalPlanType::Repartition(repart) => { - self.try_into_repartition_physical_plan(repart, ctx, extension_codec) - } - PhysicalPlanType::GlobalLimit(limit) => { - self.try_into_global_limit_physical_plan(limit, ctx, extension_codec) - } - PhysicalPlanType::LocalLimit(limit) => { - self.try_into_local_limit_physical_plan(limit, ctx, extension_codec) - } - PhysicalPlanType::Window(window_agg) => { - self.try_into_window_physical_plan(window_agg, ctx, extension_codec) - } - PhysicalPlanType::Aggregate(hash_agg) => { - self.try_into_aggregate_physical_plan(hash_agg, ctx, extension_codec) - } - PhysicalPlanType::HashJoin(hashjoin) => { - self.try_into_hash_join_physical_plan(hashjoin, ctx, extension_codec) + self.try_into_merge_physical_plan(merge, decode_ctx, extension_codec) } + PhysicalPlanType::Repartition(repart) => self + .try_into_repartition_physical_plan(repart, decode_ctx, extension_codec), + PhysicalPlanType::GlobalLimit(limit) => self + .try_into_global_limit_physical_plan(limit, decode_ctx, extension_codec), + PhysicalPlanType::LocalLimit(limit) => self + .try_into_local_limit_physical_plan(limit, decode_ctx, extension_codec), + PhysicalPlanType::Window(window_agg) => self.try_into_window_physical_plan( + window_agg, + decode_ctx, + extension_codec, + ), + PhysicalPlanType::Aggregate(hash_agg) => self + .try_into_aggregate_physical_plan(hash_agg, decode_ctx, extension_codec), + PhysicalPlanType::HashJoin(hashjoin) => self + .try_into_hash_join_physical_plan(hashjoin, decode_ctx, extension_codec), PhysicalPlanType::SymmetricHashJoin(sym_join) => self .try_into_symmetric_hash_join_physical_plan( sym_join, - ctx, + decode_ctx, extension_codec, ), PhysicalPlanType::Union(union) => { - self.try_into_union_physical_plan(union, ctx, extension_codec) - } - PhysicalPlanType::Interleave(interleave) => { - self.try_into_interleave_physical_plan(interleave, ctx, extension_codec) - } - PhysicalPlanType::CrossJoin(crossjoin) => { - self.try_into_cross_join_physical_plan(crossjoin, ctx, extension_codec) + self.try_into_union_physical_plan(union, decode_ctx, extension_codec) } + PhysicalPlanType::Interleave(interleave) => self + .try_into_interleave_physical_plan( + interleave, + decode_ctx, + extension_codec, + ), + PhysicalPlanType::CrossJoin(crossjoin) => self + .try_into_cross_join_physical_plan( + crossjoin, + decode_ctx, + extension_codec, + ), PhysicalPlanType::Empty(empty) => { - self.try_into_empty_physical_plan(empty, ctx, extension_codec) + self.try_into_empty_physical_plan(empty, decode_ctx, extension_codec) } PhysicalPlanType::PlaceholderRow(placeholder) => self .try_into_placeholder_row_physical_plan( placeholder, - ctx, + decode_ctx, extension_codec, ), PhysicalPlanType::Sort(sort) => { - self.try_into_sort_physical_plan(sort, ctx, extension_codec) + self.try_into_sort_physical_plan(sort, decode_ctx, extension_codec) } PhysicalPlanType::SortPreservingMerge(sort) => self - .try_into_sort_preserving_merge_physical_plan(sort, ctx, extension_codec), - PhysicalPlanType::Extension(extension) => { - self.try_into_extension_physical_plan(extension, ctx, extension_codec) - } - PhysicalPlanType::NestedLoopJoin(join) => { - self.try_into_nested_loop_join_physical_plan(join, ctx, extension_codec) - } + .try_into_sort_preserving_merge_physical_plan( + sort, + decode_ctx, + extension_codec, + ), + PhysicalPlanType::Extension(extension) => self + .try_into_extension_physical_plan(extension, decode_ctx, extension_codec), + PhysicalPlanType::NestedLoopJoin(join) => self + .try_into_nested_loop_join_physical_plan( + join, + decode_ctx, + extension_codec, + ), PhysicalPlanType::Analyze(analyze) => { - self.try_into_analyze_physical_plan(analyze, ctx, extension_codec) + self.try_into_analyze_physical_plan(analyze, decode_ctx, extension_codec) } PhysicalPlanType::JsonSink(sink) => { - self.try_into_json_sink_physical_plan(sink, ctx, extension_codec) + self.try_into_json_sink_physical_plan(sink, decode_ctx, extension_codec) } PhysicalPlanType::CsvSink(sink) => { - self.try_into_csv_sink_physical_plan(sink, ctx, extension_codec) + self.try_into_csv_sink_physical_plan(sink, decode_ctx, extension_codec) } #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] - PhysicalPlanType::ParquetSink(sink) => { - self.try_into_parquet_sink_physical_plan(sink, ctx, extension_codec) - } + PhysicalPlanType::ParquetSink(sink) => self + .try_into_parquet_sink_physical_plan(sink, decode_ctx, extension_codec), PhysicalPlanType::Unnest(unnest) => { - self.try_into_unnest_physical_plan(unnest, ctx, extension_codec) - } - PhysicalPlanType::Cooperative(cooperative) => { - self.try_into_cooperative_physical_plan(cooperative, ctx, extension_codec) + self.try_into_unnest_physical_plan(unnest, decode_ctx, extension_codec) } + PhysicalPlanType::Cooperative(cooperative) => self + .try_into_cooperative_physical_plan( + cooperative, + decode_ctx, + extension_codec, + ), PhysicalPlanType::GenerateSeries(generate_series) => { self.try_into_generate_series_physical_plan(generate_series) } PhysicalPlanType::SortMergeJoin(sort_join) => { - self.try_into_sort_join(sort_join, ctx, extension_codec) + self.try_into_sort_join(sort_join, decode_ctx, extension_codec) } } } @@ -492,7 +548,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_explain_physical_plan( &self, explain: &protobuf::ExplainExecNode, - _ctx: &TaskContext, + _decode_ctx: &DecodeContext, _extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { @@ -510,12 +566,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_projection_physical_plan( &self, projection: &protobuf::ProjectionExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let input: Arc = - into_physical_plan(&projection.input, ctx, extension_codec)?; + into_physical_plan(&projection.input, decode_ctx, extension_codec)?; let exprs = projection .expr .iter() @@ -524,7 +579,7 @@ impl protobuf::PhysicalPlanNode { Ok(( parse_physical_expr( expr, - ctx, + decode_ctx, input.schema().as_ref(), extension_codec, )?, @@ -542,18 +597,23 @@ impl protobuf::PhysicalPlanNode { fn try_into_filter_physical_plan( &self, filter: &protobuf::FilterExecNode, - ctx: &TaskContext, + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let input: Arc = - into_physical_plan(&filter.input, ctx, extension_codec)?; + into_physical_plan(&filter.input, decode_ctx, extension_codec)?; let predicate = filter .expr .as_ref() .map(|expr| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + parse_physical_expr( + expr, + decode_ctx, + input.schema().as_ref(), + extension_codec, + ) }) .transpose()? .ok_or_else(|| { @@ -590,7 +650,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_csv_scan_physical_plan( &self, scan: &protobuf::CsvScanExecNode, - ctx: &TaskContext, + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { @@ -624,7 +684,7 @@ impl protobuf::PhysicalPlanNode { let conf = FileScanConfigBuilder::from(parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), - ctx, + decode_ctx, extension_codec, source, )?) @@ -637,13 +697,13 @@ impl protobuf::PhysicalPlanNode { fn try_into_json_scan_physical_plan( &self, scan: &protobuf::JsonScanExecNode, - ctx: &TaskContext, + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let scan_conf = parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), - ctx, + decode_ctx, extension_codec, Arc::new(JsonSource::new()), )?; @@ -654,7 +714,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_parquet_scan_physical_plan( &self, scan: &protobuf::ParquetScanExecNode, - ctx: &TaskContext, + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { @@ -684,7 +744,7 @@ impl protobuf::PhysicalPlanNode { .map(|expr| { parse_physical_expr( expr, - ctx, + decode_ctx, predicate_schema.as_ref(), extension_codec, ) @@ -702,7 +762,7 @@ impl protobuf::PhysicalPlanNode { } let base_config = parse_protobuf_file_scan_config( base_conf, - ctx, + decode_ctx, extension_codec, Arc::new(source), )?; @@ -716,7 +776,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_avro_scan_physical_plan( &self, scan: &protobuf::AvroScanExecNode, - ctx: &TaskContext, + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { @@ -724,7 +784,7 @@ impl protobuf::PhysicalPlanNode { { let conf = parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), - ctx, + decode_ctx, extension_codec, Arc::new(AvroSource::new()), )?; @@ -737,7 +797,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_memory_scan_physical_plan( &self, scan: &protobuf::MemoryScanExecNode, - ctx: &TaskContext, + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { @@ -767,7 +827,7 @@ impl protobuf::PhysicalPlanNode { for ordering in &scan.sort_information { let sort_exprs = parse_physical_sort_exprs( &ordering.physical_sort_expr_nodes, - ctx, + decode_ctx, &schema, extension_codec, )?; @@ -786,12 +846,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_coalesce_batches_physical_plan( &self, coalesce_batches: &protobuf::CoalesceBatchesExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let input: Arc = - into_physical_plan(&coalesce_batches.input, ctx, extension_codec)?; + into_physical_plan(&coalesce_batches.input, decode_ctx, extension_codec)?; Ok(Arc::new( CoalesceBatchesExec::new(input, coalesce_batches.target_batch_size as usize) .with_fetch(coalesce_batches.fetch.map(|f| f as usize)), @@ -801,12 +860,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_merge_physical_plan( &self, merge: &protobuf::CoalescePartitionsExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let input: Arc = - into_physical_plan(&merge.input, ctx, extension_codec)?; + into_physical_plan(&merge.input, decode_ctx, extension_codec)?; Ok(Arc::new( CoalescePartitionsExec::new(input) .with_fetch(merge.fetch.map(|f| f as usize)), @@ -816,15 +874,14 @@ impl protobuf::PhysicalPlanNode { fn try_into_repartition_physical_plan( &self, repart: &protobuf::RepartitionExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let input: Arc = - into_physical_plan(&repart.input, ctx, extension_codec)?; + into_physical_plan(&repart.input, decode_ctx, extension_codec)?; let partitioning = parse_protobuf_partitioning( repart.partitioning.as_ref(), - ctx, + decode_ctx, input.schema().as_ref(), extension_codec, )?; @@ -837,12 +894,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_global_limit_physical_plan( &self, limit: &protobuf::GlobalLimitExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let input: Arc = - into_physical_plan(&limit.input, ctx, extension_codec)?; + into_physical_plan(&limit.input, decode_ctx, extension_codec)?; let fetch = if limit.fetch >= 0 { Some(limit.fetch as usize) } else { @@ -858,24 +914,22 @@ impl protobuf::PhysicalPlanNode { fn try_into_local_limit_physical_plan( &self, limit: &protobuf::LocalLimitExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let input: Arc = - into_physical_plan(&limit.input, ctx, extension_codec)?; + into_physical_plan(&limit.input, decode_ctx, extension_codec)?; Ok(Arc::new(LocalLimitExec::new(input, limit.fetch as usize))) } fn try_into_window_physical_plan( &self, window_agg: &protobuf::WindowAggExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let input: Arc = - into_physical_plan(&window_agg.input, ctx, extension_codec)?; + into_physical_plan(&window_agg.input, decode_ctx, extension_codec)?; let input_schema = input.schema(); let physical_window_expr: Vec> = window_agg @@ -884,7 +938,7 @@ impl protobuf::PhysicalPlanNode { .map(|window_expr| { parse_physical_window_expr( window_expr, - ctx, + decode_ctx, input_schema.as_ref(), extension_codec, ) @@ -895,7 +949,12 @@ impl protobuf::PhysicalPlanNode { .partition_keys .iter() .map(|expr| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + parse_physical_expr( + expr, + decode_ctx, + input.schema().as_ref(), + extension_codec, + ) }) .collect::>>>()?; @@ -928,12 +987,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_aggregate_physical_plan( &self, hash_agg: &protobuf::AggregateExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let input: Arc = - into_physical_plan(&hash_agg.input, ctx, extension_codec)?; + into_physical_plan(&hash_agg.input, decode_ctx, extension_codec)?; let mode = protobuf::AggregateMode::try_from(hash_agg.mode).map_err(|_| { proto_error(format!( "Received a AggregateNode message with unknown AggregateMode {}", @@ -957,8 +1015,13 @@ impl protobuf::PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) - .map(|expr| (expr, name.to_string())) + parse_physical_expr( + expr, + decode_ctx, + input.schema().as_ref(), + extension_codec, + ) + .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -967,8 +1030,13 @@ impl protobuf::PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) - .map(|expr| (expr, name.to_string())) + parse_physical_expr( + expr, + decode_ctx, + input.schema().as_ref(), + extension_codec, + ) + .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -994,7 +1062,12 @@ impl protobuf::PhysicalPlanNode { expr.expr .as_ref() .map(|e| { - parse_physical_expr(e, ctx, &physical_schema, extension_codec) + parse_physical_expr( + e, + decode_ctx, + &physical_schema, + extension_codec, + ) }) .transpose() }) @@ -1017,7 +1090,7 @@ impl protobuf::PhysicalPlanNode { .map(|e| { parse_physical_expr( e, - ctx, + decode_ctx, &physical_schema, extension_codec, ) @@ -1029,7 +1102,7 @@ impl protobuf::PhysicalPlanNode { .map(|e| { parse_physical_sort_expr( e, - ctx, + decode_ctx, &physical_schema, extension_codec, ) @@ -1043,10 +1116,13 @@ impl protobuf::PhysicalPlanNode { let agg_udf = match &agg_node.fun_definition { Some(buf) => extension_codec .try_decode_udaf(udaf_name, buf)?, - None => ctx.udaf(udaf_name).or_else(|_| { - extension_codec - .try_decode_udaf(udaf_name, &[]) - })?, + None => decode_ctx + .task_context() + .udaf(udaf_name) + .or_else(|_| { + extension_codec + .try_decode_udaf(udaf_name, &[]) + })?, }; AggregateExprBuilder::new(agg_udf, input_phy_expr) @@ -1094,14 +1170,13 @@ impl protobuf::PhysicalPlanNode { fn try_into_hash_join_physical_plan( &self, hashjoin: &protobuf::HashJoinExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let left: Arc = - into_physical_plan(&hashjoin.left, ctx, extension_codec)?; + into_physical_plan(&hashjoin.left, decode_ctx, extension_codec)?; let right: Arc = - into_physical_plan(&hashjoin.right, ctx, extension_codec)?; + into_physical_plan(&hashjoin.right, decode_ctx, extension_codec)?; let left_schema = left.schema(); let right_schema = right.schema(); let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = hashjoin @@ -1110,13 +1185,13 @@ impl protobuf::PhysicalPlanNode { .map(|col| { let left = parse_physical_expr( &col.left.clone().unwrap(), - ctx, + decode_ctx, left_schema.as_ref(), extension_codec, )?; let right = parse_physical_expr( &col.right.clone().unwrap(), - ctx, + decode_ctx, right_schema.as_ref(), extension_codec, )?; @@ -1151,7 +1226,8 @@ impl protobuf::PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - ctx, &schema, + decode_ctx, + &schema, extension_codec, )?; let column_indices = f.column_indices @@ -1212,12 +1288,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_symmetric_hash_join_physical_plan( &self, sym_join: &protobuf::SymmetricHashJoinExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let left = into_physical_plan(&sym_join.left, ctx, extension_codec)?; - let right = into_physical_plan(&sym_join.right, ctx, extension_codec)?; + let left = into_physical_plan(&sym_join.left, decode_ctx, extension_codec)?; + let right = into_physical_plan(&sym_join.right, decode_ctx, extension_codec)?; let left_schema = left.schema(); let right_schema = right.schema(); let on = sym_join @@ -1226,13 +1301,13 @@ impl protobuf::PhysicalPlanNode { .map(|col| { let left = parse_physical_expr( &col.left.clone().unwrap(), - ctx, + decode_ctx, left_schema.as_ref(), extension_codec, )?; let right = parse_physical_expr( &col.right.clone().unwrap(), - ctx, + decode_ctx, right_schema.as_ref(), extension_codec, )?; @@ -1267,7 +1342,7 @@ impl protobuf::PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - ctx, &schema, + decode_ctx, &schema, extension_codec, )?; let column_indices = f.column_indices @@ -1292,7 +1367,7 @@ impl protobuf::PhysicalPlanNode { let left_sort_exprs = parse_physical_sort_exprs( &sym_join.left_sort_exprs, - ctx, + decode_ctx, &left_schema, extension_codec, )?; @@ -1300,7 +1375,7 @@ impl protobuf::PhysicalPlanNode { let right_sort_exprs = parse_physical_sort_exprs( &sym_join.right_sort_exprs, - ctx, + decode_ctx, &right_schema, extension_codec, )?; @@ -1340,13 +1415,12 @@ impl protobuf::PhysicalPlanNode { fn try_into_union_physical_plan( &self, union: &protobuf::UnionExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let mut inputs: Vec> = vec![]; for input in &union.inputs { - inputs.push(input.try_into_physical_plan(ctx, extension_codec)?); + inputs.push(input.try_into_physical_plan(decode_ctx, extension_codec)?); } UnionExec::try_new(inputs) } @@ -1354,13 +1428,12 @@ impl protobuf::PhysicalPlanNode { fn try_into_interleave_physical_plan( &self, interleave: &protobuf::InterleaveExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let mut inputs: Vec> = vec![]; for input in &interleave.inputs { - inputs.push(input.try_into_physical_plan(ctx, extension_codec)?); + inputs.push(input.try_into_physical_plan(decode_ctx, extension_codec)?); } Ok(Arc::new(InterleaveExec::try_new(inputs)?)) } @@ -1368,21 +1441,20 @@ impl protobuf::PhysicalPlanNode { fn try_into_cross_join_physical_plan( &self, crossjoin: &protobuf::CrossJoinExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let left: Arc = - into_physical_plan(&crossjoin.left, ctx, extension_codec)?; + into_physical_plan(&crossjoin.left, decode_ctx, extension_codec)?; let right: Arc = - into_physical_plan(&crossjoin.right, ctx, extension_codec)?; + into_physical_plan(&crossjoin.right, decode_ctx, extension_codec)?; Ok(Arc::new(CrossJoinExec::new(left, right))) } fn try_into_empty_physical_plan( &self, empty: &protobuf::EmptyExecNode, - _ctx: &TaskContext, + _decode_ctx: &DecodeContext, _extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { @@ -1393,7 +1465,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_placeholder_row_physical_plan( &self, placeholder: &protobuf::PlaceholderRowExecNode, - _ctx: &TaskContext, + _decode_ctx: &DecodeContext, _extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { @@ -1404,11 +1476,10 @@ impl protobuf::PhysicalPlanNode { fn try_into_sort_physical_plan( &self, sort: &protobuf::SortExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let input = into_physical_plan(&sort.input, ctx, extension_codec)?; + let input = into_physical_plan(&sort.input, decode_ctx, extension_codec)?; let exprs = sort .expr .iter() @@ -1429,7 +1500,7 @@ impl protobuf::PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec)?, + expr: parse_physical_expr(expr, decode_ctx, input.schema().as_ref(), extension_codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -1456,11 +1527,10 @@ impl protobuf::PhysicalPlanNode { fn try_into_sort_preserving_merge_physical_plan( &self, sort: &protobuf::SortPreservingMergeExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let input = into_physical_plan(&sort.input, ctx, extension_codec)?; + let input = into_physical_plan(&sort.input, decode_ctx, extension_codec)?; let exprs = sort .expr .iter() @@ -1483,7 +1553,7 @@ impl protobuf::PhysicalPlanNode { Ok(PhysicalSortExpr { expr: parse_physical_expr( expr, - ctx, + decode_ctx, input.schema().as_ref(), extension_codec, )?, @@ -1509,18 +1579,18 @@ impl protobuf::PhysicalPlanNode { fn try_into_extension_physical_plan( &self, extension: &protobuf::PhysicalExtensionNode, - ctx: &TaskContext, + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let inputs: Vec> = extension .inputs .iter() - .map(|i| i.try_into_physical_plan(ctx, extension_codec)) + .map(|i| i.try_into_physical_plan(decode_ctx, extension_codec)) .collect::>()?; let extension_node = - extension_codec.try_decode(extension.node.as_slice(), &inputs, ctx)?; + extension_codec.try_decode(extension.node.as_slice(), &inputs, decode_ctx)?; Ok(extension_node) } @@ -1528,14 +1598,13 @@ impl protobuf::PhysicalPlanNode { fn try_into_nested_loop_join_physical_plan( &self, join: &protobuf::NestedLoopJoinExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let left: Arc = - into_physical_plan(&join.left, ctx, extension_codec)?; + into_physical_plan(&join.left, decode_ctx, extension_codec)?; let right: Arc = - into_physical_plan(&join.right, ctx, extension_codec)?; + into_physical_plan(&join.right, decode_ctx, extension_codec)?; let join_type = protobuf::JoinType::try_from(join.join_type).map_err(|_| { proto_error(format!( "Received a NestedLoopJoinExecNode message with unknown JoinType {}", @@ -1556,7 +1625,7 @@ impl protobuf::PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - ctx, &schema, + decode_ctx, &schema, extension_codec, )?; let column_indices = f.column_indices @@ -1602,12 +1671,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_analyze_physical_plan( &self, analyze: &protobuf::AnalyzeExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let input: Arc = - into_physical_plan(&analyze.input, ctx, extension_codec)?; + into_physical_plan(&analyze.input, decode_ctx, extension_codec)?; Ok(Arc::new(AnalyzeExec::new( analyze.verbose, analyze.show_statistics, @@ -1620,11 +1688,10 @@ impl protobuf::PhysicalPlanNode { fn try_into_json_sink_physical_plan( &self, sink: &protobuf::JsonSinkExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, decode_ctx, extension_codec)?; let data_sink: JsonSink = sink .sink @@ -1638,7 +1705,7 @@ impl protobuf::PhysicalPlanNode { .map(|collection| { parse_physical_sort_exprs( &collection.physical_sort_expr_nodes, - ctx, + decode_ctx, &sink_schema, extension_codec, ) @@ -1658,11 +1725,10 @@ impl protobuf::PhysicalPlanNode { fn try_into_csv_sink_physical_plan( &self, sink: &protobuf::CsvSinkExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, decode_ctx, extension_codec)?; let data_sink: CsvSink = sink .sink @@ -1676,7 +1742,7 @@ impl protobuf::PhysicalPlanNode { .map(|collection| { parse_physical_sort_exprs( &collection.physical_sort_expr_nodes, - ctx, + decode_ctx, &sink_schema, extension_codec, ) @@ -1697,13 +1763,12 @@ impl protobuf::PhysicalPlanNode { fn try_into_parquet_sink_physical_plan( &self, sink: &protobuf::ParquetSinkExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { #[cfg(feature = "parquet")] { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, decode_ctx, extension_codec)?; let data_sink: ParquetSink = sink .sink @@ -1717,7 +1782,7 @@ impl protobuf::PhysicalPlanNode { .map(|collection| { parse_physical_sort_exprs( &collection.physical_sort_expr_nodes, - ctx, + decode_ctx, &sink_schema, extension_codec, ) @@ -1740,11 +1805,10 @@ impl protobuf::PhysicalPlanNode { fn try_into_unnest_physical_plan( &self, unnest: &protobuf::UnnestExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let input = into_physical_plan(&unnest.input, ctx, extension_codec)?; + let input = into_physical_plan(&unnest.input, decode_ctx, extension_codec)?; Ok(Arc::new(UnnestExec::new( input, @@ -1771,13 +1835,13 @@ impl protobuf::PhysicalPlanNode { fn try_into_sort_join( &self, sort_join: &SortMergeJoinExecNode, - ctx: &TaskContext, + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let left = into_physical_plan(&sort_join.left, ctx, extension_codec)?; + let left = into_physical_plan(&sort_join.left, decode_ctx, extension_codec)?; let left_schema = left.schema(); - let right = into_physical_plan(&sort_join.right, ctx, extension_codec)?; + let right = into_physical_plan(&sort_join.right, decode_ctx, extension_codec)?; let right_schema = right.schema(); let filter = sort_join @@ -1794,7 +1858,7 @@ impl protobuf::PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - ctx, + decode_ctx, &schema, extension_codec, )?; @@ -1855,13 +1919,13 @@ impl protobuf::PhysicalPlanNode { .map(|col| { let left = parse_physical_expr( &col.left.clone().unwrap(), - ctx, + decode_ctx, left_schema.as_ref(), extension_codec, )?; let right = parse_physical_expr( &col.right.clone().unwrap(), - ctx, + decode_ctx, right_schema.as_ref(), extension_codec, )?; @@ -1948,11 +2012,10 @@ impl protobuf::PhysicalPlanNode { fn try_into_cooperative_physical_plan( &self, field_stream: &protobuf::CooperativeExecNode, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let input = into_physical_plan(&field_stream.input, ctx, extension_codec)?; + let input = into_physical_plan(&field_stream.input, decode_ctx, extension_codec)?; Ok(Arc::new(CooperativeExec::new(input))) } @@ -2736,6 +2799,7 @@ impl protobuf::PhysicalPlanNode { nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { + id: None, expr_type: Some(ExprType::Sort(sort_expr)), }) }) @@ -2812,6 +2876,7 @@ impl protobuf::PhysicalPlanNode { nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { + id: None, expr_type: Some(ExprType::Sort(sort_expr)), }) }) @@ -3220,8 +3285,7 @@ pub trait AsExecutionPlan: Debug + Send + Sync + Clone { fn try_into_physical_plan( &self, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result>; @@ -3238,7 +3302,7 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { &self, buf: &[u8], inputs: &[Arc], - ctx: &TaskContext, + decode_ctx: &DecodeContext, ) -> Result>; fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()>; @@ -3294,7 +3358,7 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { &self, _buf: &[u8], _inputs: &[Arc], - _ctx: &TaskContext, + _decode_ctx: &DecodeContext, ) -> Result> { not_impl_err!("PhysicalExtensionCodec is not provided") } @@ -3394,9 +3458,11 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { &self, buf: &[u8], inputs: &[Arc], - ctx: &TaskContext, + decode_ctx: &DecodeContext, ) -> Result> { - self.decode_protobuf(buf, |codec, data| codec.try_decode(data, inputs, ctx)) + self.decode_protobuf(buf, |codec, data| { + codec.try_decode(data, inputs, decode_ctx) + }) } fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { @@ -3422,12 +3488,11 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { fn into_physical_plan( node: &Option>, - ctx: &TaskContext, - + decode_ctx: &DecodeContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { if let Some(field) = node { - field.try_into_physical_plan(ctx, extension_codec) + field.try_into_physical_plan(decode_ctx, extension_codec) } else { Err(proto_error("Missing required field in protobuf")) } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 399c234191aa7..1a174d65f343b 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -56,6 +56,9 @@ pub fn serialize_physical_aggr_expr( aggr_expr: Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { + // Calculate the ID from the Arc pointer for caching during deserialization + let expr_id = Arc::as_ptr(&aggr_expr).addr() as u64; + let expressions = serialize_physical_exprs(&aggr_expr.expressions(), codec)?; let order_bys = serialize_physical_sort_exprs(aggr_expr.order_bys().iter().cloned(), codec)?; @@ -64,6 +67,7 @@ pub fn serialize_physical_aggr_expr( let mut buf = Vec::new(); codec.try_encode_udaf(aggr_expr.fun(), &mut buf)?; Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), @@ -221,6 +225,10 @@ pub fn serialize_physical_expr( value: &Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { + // Calculate the ID from the Arc pointer for caching during deserialization + // Use the data pointer (not the vtable) for fat pointers + let expr_id = Arc::as_ptr(value) as *const () as usize as u64; + // Snapshot the expr in case it has dynamic predicate state so // it can be serialized let value = snapshot_physical_expr(Arc::clone(value))?; @@ -228,6 +236,7 @@ pub fn serialize_physical_expr( if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::Column( protobuf::PhysicalColumn { name: expr.name().to_string(), @@ -237,6 +246,7 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::UnknownColumn( protobuf::UnknownColumn { name: expr.name().to_string(), @@ -251,12 +261,14 @@ pub fn serialize_physical_expr( }); Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( binary_expr, )), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some( protobuf::physical_expr_node::ExprType::Case( Box::new( @@ -288,6 +300,7 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( protobuf::PhysicalNot { expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), @@ -296,6 +309,7 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( Box::new(protobuf::PhysicalIsNull { expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), @@ -304,6 +318,7 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( Box::new(protobuf::PhysicalIsNotNull { expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), @@ -312,6 +327,7 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::InList(Box::new( protobuf::PhysicalInListNode { expr: Some(Box::new(serialize_physical_expr(expr.expr(), codec)?)), @@ -322,6 +338,7 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( protobuf::PhysicalNegativeNode { expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), @@ -330,12 +347,14 @@ pub fn serialize_physical_expr( }) } else if let Some(lit) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( lit.value().try_into()?, )), }) } else if let Some(cast) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( protobuf::PhysicalCastNode { expr: Some(Box::new(serialize_physical_expr(cast.expr(), codec)?)), @@ -345,6 +364,7 @@ pub fn serialize_physical_expr( }) } else if let Some(cast) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( protobuf::PhysicalTryCastNode { expr: Some(Box::new(serialize_physical_expr(cast.expr(), codec)?)), @@ -356,6 +376,7 @@ pub fn serialize_physical_expr( let mut buf = Vec::new(); codec.try_encode_udf(expr.fun(), &mut buf)?; Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( protobuf::PhysicalScalarUdfNode { name: expr.name().to_string(), @@ -372,6 +393,7 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr(Box::new( protobuf::PhysicalLikeExprNode { negated: expr.negated(), @@ -394,6 +416,7 @@ pub fn serialize_physical_expr( .map(|e| serialize_physical_expr(e, codec)) .collect::>()?; Ok(protobuf::PhysicalExprNode { + id: Some(expr_id), expr_type: Some(protobuf::physical_expr_node::ExprType::Extension( protobuf::PhysicalExtensionExprNode { expr: buf, inputs }, )), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index e6cfcb95805a1..7b503d05bfc64 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -56,7 +56,6 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::datasource::sink::DataSinkExec; use datafusion::datasource::source::DataSourceExec; -use datafusion::execution::TaskContext; use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::sum::sum_udaf; use datafusion::functions_window::nth_value::nth_value_udwf; @@ -114,7 +113,7 @@ use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; use datafusion_functions_aggregate::string_agg::string_agg_udaf; use datafusion_proto::physical_plan::{ - AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, + AsExecutionPlan, DecodeContext, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; use datafusion_proto::protobuf::{self, PhysicalPlanNode}; @@ -142,8 +141,10 @@ fn roundtrip_test_and_return( let proto: protobuf::PhysicalPlanNode = protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), codec) .expect("to proto"); + let task_ctx = ctx.task_ctx(); + let decode_ctx = DecodeContext::new(&task_ctx); let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx.task_ctx(), codec) + .try_into_physical_plan(&decode_ctx, codec) .expect("from proto"); pretty_assertions::assert_eq!( @@ -774,6 +775,71 @@ fn roundtrip_filter_with_not_and_in_list() -> Result<()> { )?)) } +#[test] +fn roundtrip_expr_deduplication() -> Result<()> { + use datafusion_proto::physical_plan::DecodeContext; + use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; + + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a single column expression that will be reused + let col_expr = col("a", &schema)?; + + // Use the same expression twice in a binary expression (a = a) + let binary_expr = binary( + Arc::clone(&col_expr), + Operator::Eq, + Arc::clone(&col_expr), + &schema, + )?; + + let filter_exec = Arc::new(FilterExec::try_new( + binary_expr, + Arc::new(EmptyExec::new(schema.clone())), + )?); + + // Perform roundtrip + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto = + protobuf::PhysicalPlanNode::try_from_physical_plan(filter_exec.clone(), &codec)?; + + let task_ctx = ctx.task_ctx(); + let decode_ctx = DecodeContext::new(&task_ctx); + let result_exec_plan = proto.try_into_physical_plan(&decode_ctx, &codec)?; + + // Verify the plans are equivalent + pretty_assertions::assert_eq!( + format!("{filter_exec:?}"), + format!("{result_exec_plan:?}") + ); + + // Verify deduplication: extract the binary expression from the filter + let result_filter = result_exec_plan + .as_any() + .downcast_ref::() + .expect("should be FilterExec"); + + let result_binary = result_filter + .predicate() + .as_any() + .downcast_ref::() + .expect("should be BinaryExpr"); + + // Check that left and right expressions share the same Arc pointer + // (this proves deduplication worked) + let left_ptr = Arc::as_ptr(result_binary.left()).addr(); + let right_ptr = Arc::as_ptr(result_binary.right()).addr(); + + assert_eq!( + left_ptr, right_ptr, + "Left and right expressions should share the same Arc after deduplication" + ); + + Ok(()) +} + #[test] fn roundtrip_sort() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); @@ -1027,7 +1093,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { &self, _buf: &[u8], _inputs: &[Arc], - _ctx: &TaskContext, + _ctx: &DecodeContext, ) -> Result> { unreachable!() } @@ -1135,7 +1201,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { &self, _buf: &[u8], _inputs: &[Arc], - _ctx: &TaskContext, + _ctx: &DecodeContext, ) -> Result> { not_impl_err!("No extension codec provided") } @@ -1738,8 +1804,10 @@ async fn roundtrip_coalesce() -> Result<()> { )?; let node = PhysicalPlanNode::decode(node.encode_to_vec().as_slice()) .map_err(|e| DataFusionError::External(Box::new(e)))?; + let task_ctx = ctx.task_ctx(); + let decode_ctx = DecodeContext::new(&task_ctx); let restored = - node.try_into_physical_plan(&ctx.task_ctx(), &DefaultPhysicalExtensionCodec {})?; + node.try_into_physical_plan(&decode_ctx, &DefaultPhysicalExtensionCodec {})?; assert_eq!( plan.schema(), @@ -1774,8 +1842,10 @@ async fn roundtrip_generate_series() -> Result<()> { )?; let node = PhysicalPlanNode::decode(node.encode_to_vec().as_slice()) .map_err(|e| DataFusionError::External(Box::new(e)))?; + let task_ctx = ctx.task_ctx(); + let decode_ctx = DecodeContext::new(&task_ctx); let restored = - node.try_into_physical_plan(&ctx.task_ctx(), &DefaultPhysicalExtensionCodec {})?; + node.try_into_physical_plan(&decode_ctx, &DefaultPhysicalExtensionCodec {})?; assert_eq!( plan.schema(), @@ -1896,8 +1966,10 @@ async fn roundtrip_physical_plan_node() { PhysicalPlanNode::try_from_physical_plan(plan, &DefaultPhysicalExtensionCodec {}) .unwrap(); + let task_ctx = ctx.task_ctx(); + let decode_ctx = DecodeContext::new(&task_ctx); let plan = node - .try_into_physical_plan(&ctx.task_ctx(), &DefaultPhysicalExtensionCodec {}) + .try_into_physical_plan(&decode_ctx, &DefaultPhysicalExtensionCodec {}) .unwrap(); let _ = plan.execute(0, ctx.task_ctx()).unwrap(); @@ -1976,8 +2048,9 @@ async fn test_serialize_deserialize_tpch_queries() -> Result<()> { PhysicalPlanNode::try_from_physical_plan(physical_plan.clone(), &codec)?; // deserialize the physical plan - let _deserialized_plan = - proto.try_into_physical_plan(&ctx.task_ctx(), &codec)?; + let task_ctx = ctx.task_ctx(); + let decode_ctx = DecodeContext::new(&task_ctx); + let _deserialized_plan = proto.try_into_physical_plan(&decode_ctx, &codec)?; } } @@ -2096,7 +2169,9 @@ async fn test_tpch_part_in_list_query_with_real_parquet_data() -> Result<()> { let proto = PhysicalPlanNode::try_from_physical_plan(physical_plan.clone(), &codec)?; // This will fail with the bug, but should succeed when fixed - let _deserialized_plan = proto.try_into_physical_plan(&ctx.task_ctx(), &codec)?; + let task_ctx = ctx.task_ctx(); + let decode_ctx = DecodeContext::new(&task_ctx); + let _deserialized_plan = proto.try_into_physical_plan(&decode_ctx, &codec)?; Ok(()) } @@ -2124,8 +2199,10 @@ async fn analyze_roundtrip_unoptimized() -> Result<()> { let node = PhysicalPlanNode::decode(node.encode_to_vec().as_slice()) .map_err(|e| DataFusionError::External(Box::new(e)))?; + let task_ctx = ctx.task_ctx(); + let decode_ctx = DecodeContext::new(&task_ctx); let unoptimized = - node.try_into_physical_plan(&ctx.task_ctx(), &DefaultPhysicalExtensionCodec {})?; + node.try_into_physical_plan(&decode_ctx, &DefaultPhysicalExtensionCodec {})?; let physical_planner = datafusion::physical_planner::DefaultPhysicalPlanner::default();