diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 1ad7a2c3afaf8..528f977a94d3a 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -751,6 +751,7 @@ message AggregateUDFExprNode { message ScalarUDFExprNode { string fun_name = 1; repeated LogicalExprNode args = 2; + optional bytes fun_definition = 3; } enum BuiltInWindowFunction { diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index d4abb9ed9c6f6..610c533d574cb 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -16,6 +16,7 @@ // under the License. //! Serialization / Deserialization to Bytes +use crate::logical_plan::to_proto::serialize_expr; use crate::logical_plan::{ self, AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec, }; @@ -87,8 +88,8 @@ pub trait Serializeable: Sized { impl Serializeable for Expr { fn to_bytes(&self) -> Result { let mut buffer = BytesMut::new(); - let protobuf: protobuf::LogicalExprNode = self - .try_into() + let extension_codec = DefaultLogicalExtensionCodec {}; + let protobuf: protobuf::LogicalExprNode = serialize_expr(self, &extension_codec) .map_err(|e| plan_datafusion_err!("Error encoding expr as protobuf: {e}"))?; protobuf @@ -177,7 +178,8 @@ impl Serializeable for Expr { let protobuf = protobuf::LogicalExprNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; - logical_plan::from_proto::parse_expr(&protobuf, registry) + let extension_codec = DefaultLogicalExtensionCodec {}; + logical_plan::from_proto::parse_expr(&protobuf, registry, &extension_codec) .map_err(|e| plan_datafusion_err!("Error parsing protobuf into Expr: {e}")) } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 33ebdf310ae01..d6ee204d5cf3c 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -23381,6 +23381,9 @@ impl serde::Serialize for ScalarUdfExprNode { if !self.args.is_empty() { len += 1; } + if self.fun_definition.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.ScalarUDFExprNode", len)?; if !self.fun_name.is_empty() { struct_ser.serialize_field("funName", &self.fun_name)?; @@ -23388,6 +23391,10 @@ impl serde::Serialize for ScalarUdfExprNode { if !self.args.is_empty() { struct_ser.serialize_field("args", &self.args)?; } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } struct_ser.end() } } @@ -23401,12 +23408,15 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { "fun_name", "funName", "args", + "fun_definition", + "funDefinition", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { FunName, Args, + FunDefinition, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -23430,6 +23440,7 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { match value { "funName" | "fun_name" => Ok(GeneratedField::FunName), "args" => Ok(GeneratedField::Args), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -23451,6 +23462,7 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { { let mut fun_name__ = None; let mut args__ = None; + let mut fun_definition__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::FunName => { @@ -23465,11 +23477,20 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { } args__ = Some(map_.next_value()?); } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } } } Ok(ScalarUdfExprNode { fun_name: fun_name__.unwrap_or_default(), args: args__.unwrap_or_default(), + fun_definition: fun_definition__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2d21f15570ddc..432dd4a8a6a07 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -895,6 +895,8 @@ pub struct ScalarUdfExprNode { pub fun_name: ::prost::alloc::string::String, #[prost(message, repeated, tag = "2")] pub args: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", optional, tag = "3")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ab7065cfbd857..b1fd128d0833b 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -76,6 +76,8 @@ use datafusion_expr::{ expr::{Alias, Placeholder}, }; +use super::LogicalExtensionCodec; + #[derive(Debug)] pub enum Error { General(String), @@ -976,6 +978,7 @@ pub fn parse_i32_to_aggregate_function(value: &i32) -> Result Result { use protobuf::{logical_expr_node::ExprType, window_expr_node, ScalarFunction}; @@ -990,7 +993,7 @@ pub fn parse_expr( let operands = binary_expr .operands .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?; if operands.len() < 2 { @@ -1009,8 +1012,12 @@ pub fn parse_expr( .expect("Binary expression could not be reduced to a single expression.")) } ExprType::GetIndexedField(get_indexed_field) => { - let expr = - parse_required_expr(get_indexed_field.expr.as_deref(), registry, "expr")?; + let expr = parse_required_expr( + get_indexed_field.expr.as_deref(), + registry, + "expr", + codec, + )?; let field = match &get_indexed_field.field { Some(protobuf::get_indexed_field::Field::NamedStructField( named_struct_field, @@ -1027,6 +1034,7 @@ pub fn parse_expr( list_index.key.as_deref(), registry, "key", + codec, )?), } } @@ -1036,16 +1044,19 @@ pub fn parse_expr( list_range.start.as_deref(), registry, "start", + codec, )?), stop: Box::new(parse_required_expr( list_range.stop.as_deref(), registry, "stop", + codec, )?), stride: Box::new(parse_required_expr( list_range.stride.as_deref(), registry, "stride", + codec, )?), } } @@ -1070,12 +1081,12 @@ pub fn parse_expr( let partition_by = expr .partition_by .iter() - .map(|e| parse_expr(e, registry)) + .map(|e| parse_expr(e, registry, codec)) .collect::, _>>()?; let mut order_by = expr .order_by .iter() - .map(|e| parse_expr(e, registry)) + .map(|e| parse_expr(e, registry, codec)) .collect::, _>>()?; let window_frame = expr .window_frame @@ -1103,7 +1114,7 @@ pub fn parse_expr( datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction( aggr_function, ), - vec![parse_required_expr(expr.expr.as_deref(), registry, "expr")?], + vec![parse_required_expr(expr.expr.as_deref(), registry, "expr", codec)?], partition_by, order_by, window_frame, @@ -1115,9 +1126,10 @@ pub fn parse_expr( .map_err(|_| Error::unknown("BuiltInWindowFunction", *i))? .into(); - let args = parse_optional_expr(expr.expr.as_deref(), registry)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = + parse_optional_expr(expr.expr.as_deref(), registry, codec)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction( @@ -1132,9 +1144,10 @@ pub fn parse_expr( } window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = registry.udaf(udaf_name)?; - let args = parse_optional_expr(expr.expr.as_deref(), registry)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = + parse_optional_expr(expr.expr.as_deref(), registry, codec)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF( udaf_function, @@ -1148,9 +1161,10 @@ pub fn parse_expr( } window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = registry.udwf(udwf_name)?; - let args = parse_optional_expr(expr.expr.as_deref(), registry)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = + parse_optional_expr(expr.expr.as_deref(), registry, codec)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( datafusion_expr::expr::WindowFunctionDefinition::WindowUDF( udwf_function, @@ -1171,15 +1185,16 @@ pub fn parse_expr( fun, expr.expr .iter() - .map(|e| parse_expr(e, registry)) + .map(|e| parse_expr(e, registry, codec)) .collect::, _>>()?, expr.distinct, - parse_optional_expr(expr.filter.as_deref(), registry)?.map(Box::new), - parse_vec_expr(&expr.order_by, registry)?, + parse_optional_expr(expr.filter.as_deref(), registry, codec)? + .map(Box::new), + parse_vec_expr(&expr.order_by, registry, codec)?, ))) } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( - parse_required_expr(alias.expr.as_deref(), registry, "expr")?, + parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, alias .relation .first() @@ -1191,90 +1206,118 @@ pub fn parse_expr( is_null.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::IsNotNull(Box::new( - parse_required_expr(is_not_null.expr.as_deref(), registry, "expr")?, + parse_required_expr(is_not_null.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::NotExpr(not) => Ok(Expr::Not(Box::new(parse_required_expr( not.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsTrue(msg) => Ok(Expr::IsTrue(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsFalse(msg) => Ok(Expr::IsFalse(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsUnknown(msg) => Ok(Expr::IsUnknown(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsNotTrue(msg) => Ok(Expr::IsNotTrue(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsNotFalse(msg) => Ok(Expr::IsNotFalse(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsNotUnknown(msg) => Ok(Expr::IsNotUnknown(Box::new( - parse_required_expr(msg.expr.as_deref(), registry, "expr")?, + parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::Between(between) => Ok(Expr::Between(Between::new( Box::new(parse_required_expr( between.expr.as_deref(), registry, "expr", + codec, )?), between.negated, Box::new(parse_required_expr( between.low.as_deref(), registry, "expr", + codec, )?), Box::new(parse_required_expr( between.high.as_deref(), registry, "expr", + codec, )?), ))), ExprType::Like(like) => Ok(Expr::Like(Like::new( like.negated, - Box::new(parse_required_expr(like.expr.as_deref(), registry, "expr")?), + Box::new(parse_required_expr( + like.expr.as_deref(), + registry, + "expr", + codec, + )?), Box::new(parse_required_expr( like.pattern.as_deref(), registry, "pattern", + codec, )?), parse_escape_char(&like.escape_char)?, false, ))), ExprType::Ilike(like) => Ok(Expr::Like(Like::new( like.negated, - Box::new(parse_required_expr(like.expr.as_deref(), registry, "expr")?), + Box::new(parse_required_expr( + like.expr.as_deref(), + registry, + "expr", + codec, + )?), Box::new(parse_required_expr( like.pattern.as_deref(), registry, "pattern", + codec, )?), parse_escape_char(&like.escape_char)?, true, ))), ExprType::SimilarTo(like) => Ok(Expr::SimilarTo(Like::new( like.negated, - Box::new(parse_required_expr(like.expr.as_deref(), registry, "expr")?), + Box::new(parse_required_expr( + like.expr.as_deref(), + registry, + "expr", + codec, + )?), Box::new(parse_required_expr( like.pattern.as_deref(), registry, "pattern", + codec, )?), parse_escape_char(&like.escape_char)?, false, @@ -1284,44 +1327,66 @@ pub fn parse_expr( .when_then_expr .iter() .map(|e| { - let when_expr = - parse_required_expr(e.when_expr.as_ref(), registry, "when_expr")?; - let then_expr = - parse_required_expr(e.then_expr.as_ref(), registry, "then_expr")?; + let when_expr = parse_required_expr( + e.when_expr.as_ref(), + registry, + "when_expr", + codec, + )?; + let then_expr = parse_required_expr( + e.then_expr.as_ref(), + registry, + "then_expr", + codec, + )?; Ok((Box::new(when_expr), Box::new(then_expr))) }) .collect::, Box)>, Error>>()?; Ok(Expr::Case(Case::new( - parse_optional_expr(case.expr.as_deref(), registry)?.map(Box::new), + parse_optional_expr(case.expr.as_deref(), registry, codec)?.map(Box::new), when_then_expr, - parse_optional_expr(case.else_expr.as_deref(), registry)?.map(Box::new), + parse_optional_expr(case.else_expr.as_deref(), registry, codec)? + .map(Box::new), ))) } ExprType::Cast(cast) => { - let expr = - Box::new(parse_required_expr(cast.expr.as_deref(), registry, "expr")?); + let expr = Box::new(parse_required_expr( + cast.expr.as_deref(), + registry, + "expr", + codec, + )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; Ok(Expr::Cast(Cast::new(expr, data_type))) } ExprType::TryCast(cast) => { - let expr = - Box::new(parse_required_expr(cast.expr.as_deref(), registry, "expr")?); + let expr = Box::new(parse_required_expr( + cast.expr.as_deref(), + registry, + "expr", + codec, + )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; Ok(Expr::TryCast(TryCast::new(expr, data_type))) } ExprType::Sort(sort) => Ok(Expr::Sort(Sort::new( - Box::new(parse_required_expr(sort.expr.as_deref(), registry, "expr")?), + Box::new(parse_required_expr( + sort.expr.as_deref(), + registry, + "expr", + codec, + )?), sort.asc, sort.nulls_first, ))), ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( - parse_required_expr(negative.expr.as_deref(), registry, "expr")?, + parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::Unnest(unnest) => { let exprs = unnest .exprs .iter() - .map(|e| parse_expr(e, registry)) + .map(|e| parse_expr(e, registry, codec)) .collect::, _>>()?; Ok(Expr::Unnest(Unnest { exprs })) } @@ -1330,11 +1395,12 @@ pub fn parse_expr( in_list.expr.as_deref(), registry, "expr", + codec, )?), in_list .list .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, in_list.negated, ))), @@ -1352,330 +1418,370 @@ pub fn parse_expr( match scalar_function { ScalarFunction::Unknown => Err(proto_error("Unknown scalar function")), - ScalarFunction::Asinh => Ok(asinh(parse_expr(&args[0], registry)?)), - ScalarFunction::Acosh => Ok(acosh(parse_expr(&args[0], registry)?)), + ScalarFunction::Asinh => { + Ok(asinh(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Acosh => { + Ok(acosh(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::Array => Ok(array( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::ArrayAppend => Ok(array_append( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArraySort => Ok(array_sort( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayPopFront => { - Ok(array_pop_front(parse_expr(&args[0], registry)?)) + Ok(array_pop_front(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayPopBack => { - Ok(array_pop_back(parse_expr(&args[0], registry)?)) + Ok(array_pop_back(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayPrepend => Ok(array_prepend( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayConcat => Ok(array_concat( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::ArrayExcept => Ok(array_except( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayHasAll => Ok(array_has_all( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayHasAny => Ok(array_has_any( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayHas => Ok(array_has( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayIntersect => Ok(array_intersect( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayPosition => Ok(array_position( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayPositions => Ok(array_positions( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayRepeat => Ok(array_repeat( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayRemove => Ok(array_remove( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayRemoveN => Ok(array_remove_n( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayRemoveAll => Ok(array_remove_all( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayReplace => Ok(array_replace( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayReplaceN => Ok(array_replace_n( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, - parse_expr(&args[3], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, + parse_expr(&args[3], registry, codec)?, )), ScalarFunction::ArrayReplaceAll => Ok(array_replace_all( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayReverse => { - Ok(array_reverse(parse_expr(&args[0], registry)?)) + Ok(array_reverse(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArraySlice => Ok(array_slice( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, - parse_expr(&args[3], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, + parse_expr(&args[3], registry, codec)?, )), ScalarFunction::Cardinality => { - Ok(cardinality(parse_expr(&args[0], registry)?)) + Ok(cardinality(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayLength => Ok(array_length( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayDims => { - Ok(array_dims(parse_expr(&args[0], registry)?)) + Ok(array_dims(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayDistinct => { - Ok(array_distinct(parse_expr(&args[0], registry)?)) + Ok(array_distinct(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayElement => Ok(array_element( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayEmpty => { - Ok(array_empty(parse_expr(&args[0], registry)?)) + Ok(array_empty(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayNdims => { - Ok(array_ndims(parse_expr(&args[0], registry)?)) + Ok(array_ndims(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayUnion => Ok(array_union( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayResize => Ok(array_resize( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), - ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)), - ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry)?)), - ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)), - ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], registry)?)), - ScalarFunction::Tan => Ok(tan(parse_expr(&args[0], registry)?)), - ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry)?)), - ScalarFunction::Sinh => Ok(sinh(parse_expr(&args[0], registry)?)), - ScalarFunction::Cosh => Ok(cosh(parse_expr(&args[0], registry)?)), - ScalarFunction::Tanh => Ok(tanh(parse_expr(&args[0], registry)?)), - ScalarFunction::Atanh => Ok(atanh(parse_expr(&args[0], registry)?)), - ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], registry)?)), - ScalarFunction::Degrees => Ok(degrees(parse_expr(&args[0], registry)?)), - ScalarFunction::Radians => Ok(radians(parse_expr(&args[0], registry)?)), - ScalarFunction::Log2 => Ok(log2(parse_expr(&args[0], registry)?)), - ScalarFunction::Ln => Ok(ln(parse_expr(&args[0], registry)?)), - ScalarFunction::Log10 => Ok(log10(parse_expr(&args[0], registry)?)), - ScalarFunction::Floor => Ok(floor(parse_expr(&args[0], registry)?)), + ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Tan => Ok(tan(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Sinh => Ok(sinh(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Cosh => Ok(cosh(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Tanh => Ok(tanh(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Atanh => { + Ok(atanh(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Degrees => { + Ok(degrees(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Radians => { + Ok(radians(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Log2 => Ok(log2(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Ln => Ok(ln(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Log10 => { + Ok(log10(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Floor => { + Ok(floor(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::Factorial => { - Ok(factorial(parse_expr(&args[0], registry)?)) + Ok(factorial(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry)?)), + ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Round => Ok(round( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Trunc => Ok(trunc( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), - ScalarFunction::Signum => Ok(signum(parse_expr(&args[0], registry)?)), + ScalarFunction::Signum => { + Ok(signum(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::OctetLength => { - Ok(octet_length(parse_expr(&args[0], registry)?)) + Ok(octet_length(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Lower => { + Ok(lower(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Upper => { + Ok(upper(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Trim => Ok(trim(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Ltrim => { + Ok(ltrim(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Rtrim => { + Ok(rtrim(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Lower => Ok(lower(parse_expr(&args[0], registry)?)), - ScalarFunction::Upper => Ok(upper(parse_expr(&args[0], registry)?)), - ScalarFunction::Trim => Ok(trim(parse_expr(&args[0], registry)?)), - ScalarFunction::Ltrim => Ok(ltrim(parse_expr(&args[0], registry)?)), - ScalarFunction::Rtrim => Ok(rtrim(parse_expr(&args[0], registry)?)), ScalarFunction::DatePart => Ok(date_part( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::DateTrunc => Ok(date_trunc( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::DateBin => Ok(date_bin( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), - ScalarFunction::Sha224 => Ok(sha224(parse_expr(&args[0], registry)?)), - ScalarFunction::Sha256 => Ok(sha256(parse_expr(&args[0], registry)?)), - ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0], registry)?)), - ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0], registry)?)), - ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry)?)), + ScalarFunction::Sha224 => { + Ok(sha224(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Sha256 => { + Ok(sha256(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Sha384 => { + Ok(sha384(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Sha512 => { + Ok(sha512(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Digest => Ok(digest( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Ascii => Ok(ascii(parse_expr(&args[0], registry)?)), + ScalarFunction::Ascii => { + Ok(ascii(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::BitLength => { - Ok(bit_length(parse_expr(&args[0], registry)?)) + Ok(bit_length(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::CharacterLength => { - Ok(character_length(parse_expr(&args[0], registry)?)) + Ok(character_length(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::InitCap => { + Ok(initcap(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], registry)?)), - ScalarFunction::InitCap => Ok(initcap(parse_expr(&args[0], registry)?)), ScalarFunction::InStr => Ok(instr( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Gcd => Ok(gcd( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Lcm => Ok(lcm( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Left => Ok(left( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Random => Ok(random()), ScalarFunction::Uuid => Ok(uuid()), ScalarFunction::Repeat => Ok(repeat( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Replace => Ok(replace( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), - ScalarFunction::Reverse => Ok(reverse(parse_expr(&args[0], registry)?)), + ScalarFunction::Reverse => { + Ok(reverse(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::Right => Ok(right( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Concat => Ok(concat_expr( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::ConcatWithSeparator => Ok(concat_ws_expr( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Lpad => Ok(lpad( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Rpad => Ok(rpad( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::RegexpLike => Ok(regexp_like( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::RegexpReplace => Ok(regexp_replace( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Btrim => Ok(btrim( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::SplitPart => Ok(split_part( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::StartsWith => Ok(starts_with( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::EndsWith => Ok(ends_with( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Strpos => Ok(strpos( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Substr => { if args.len() > 2 { assert_eq!(args.len(), 3); Ok(substring( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )) } else { Ok(substr( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )) } } ScalarFunction::Levenshtein => Ok(levenshtein( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)), + ScalarFunction::ToHex => { + Ok(to_hex(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::MakeDate => { let args: Vec<_> = args .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::>()?; Ok(Expr::ScalarFunction(expr::ScalarFunction::new( BuiltinScalarFunction::MakeDate, @@ -1685,7 +1791,7 @@ pub fn parse_expr( ScalarFunction::ToChar => { let args: Vec<_> = args .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::>()?; Ok(Expr::ScalarFunction(expr::ScalarFunction::new( BuiltinScalarFunction::ToChar, @@ -1694,75 +1800,86 @@ pub fn parse_expr( } ScalarFunction::Now => Ok(now()), ScalarFunction::Translate => Ok(translate( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::Coalesce => Ok(coalesce( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Pi => Ok(pi()), ScalarFunction::Power => Ok(power( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Log => Ok(log( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::FromUnixtime => { - Ok(from_unixtime(parse_expr(&args[0], registry)?)) + Ok(from_unixtime(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::Atan2 => Ok(atan2( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::CurrentDate => Ok(current_date()), ScalarFunction::CurrentTime => Ok(current_time()), - ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry)?)), + ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Nanvl => Ok(nanvl( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Iszero => Ok(iszero(parse_expr(&args[0], registry)?)), + ScalarFunction::Iszero => { + Ok(iszero(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::ArrowTypeof => { - Ok(arrow_typeof(parse_expr(&args[0], registry)?)) + Ok(arrow_typeof(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Flatten => { + Ok(flatten(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Flatten => Ok(flatten(parse_expr(&args[0], registry)?)), ScalarFunction::StringToArray => Ok(string_to_array( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::OverLay => Ok(overlay( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::SubstrIndex => Ok(substr_index( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::FindInSet => Ok(find_in_set( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::StructFun => { - Ok(struct_fun(parse_expr(&args[0], registry)?)) + Ok(struct_fun(parse_expr(&args[0], registry, codec)?)) } } } - ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => { - let scalar_fn = registry.udf(fun_name.as_str())?; + ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { + fun_name, + args, + fun_definition, + }) => { + let scalar_fn = match fun_definition { + Some(buf) => codec.try_decode_udf(fun_name, buf)?, + None => registry.udf(fun_name.as_str())?, + }; Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, args.iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>()?, ))) } @@ -1773,11 +1890,11 @@ pub fn parse_expr( agg_fn, pb.args .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>()?, false, - parse_optional_expr(pb.filter.as_deref(), registry)?.map(Box::new), - parse_vec_expr(&pb.order_by, registry)?, + parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), + parse_vec_expr(&pb.order_by, registry, codec)?, ))) } @@ -1788,7 +1905,7 @@ pub fn parse_expr( expr_list .expr .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>() }) .collect::, Error>>()?, @@ -1796,13 +1913,13 @@ pub fn parse_expr( } ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube( expr.iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>()?, ))), ExprType::Rollup(RollupNode { expr }) => { Ok(Expr::GroupingSet(GroupingSet::Rollup( expr.iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>()?, ))) } @@ -1870,10 +1987,13 @@ pub fn from_proto_binary_op(op: &str) -> Result { fn parse_vec_expr( p: &[protobuf::LogicalExprNode], registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, ) -> Result>, Error> { let res = p .iter() - .map(|elem| parse_expr(elem, registry).map_err(|e| plan_datafusion_err!("{}", e))) + .map(|elem| { + parse_expr(elem, registry, codec).map_err(|e| plan_datafusion_err!("{}", e)) + }) .collect::>>()?; // Convert empty vector to None. Ok((!res.is_empty()).then_some(res)) @@ -1882,9 +2002,10 @@ fn parse_vec_expr( fn parse_optional_expr( p: Option<&protobuf::LogicalExprNode>, registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, ) -> Result, Error> { match p { - Some(expr) => parse_expr(expr, registry).map(Some), + Some(expr) => parse_expr(expr, registry, codec).map(Some), None => Ok(None), } } @@ -1893,9 +2014,10 @@ fn parse_required_expr( p: Option<&protobuf::LogicalExprNode>, registry: &dyn FunctionRegistry, field: impl Into, + codec: &dyn LogicalExtensionCodec, ) -> Result { match p { - Some(expr) => parse_expr(expr, registry), + Some(expr) => parse_expr(expr, registry, codec), None => Err(Error::required(field)), } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index f107af757a711..7c9ead27e3b58 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -17,6 +17,7 @@ use arrow::csv::WriterBuilder; use datafusion_common::file_options::arrow_writer::ArrowWriterOptions; +use datafusion_expr::ScalarUDF; use std::collections::HashMap; use std::fmt::Debug; use std::str::FromStr; @@ -72,6 +73,8 @@ use datafusion_expr::dml::CopyOptions; use prost::bytes::BufMut; use prost::Message; +use self::to_proto::serialize_expr; + pub mod from_proto; pub mod to_proto; @@ -133,6 +136,14 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { node: Arc, buf: &mut Vec, ) -> Result<()>; + + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } } #[derive(Debug, Clone)] @@ -241,7 +252,9 @@ impl AsLogicalPlan for LogicalPlanNode { .chunks_exact(n_cols) .map(|r| { r.iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| { + from_proto::parse_expr(expr, ctx, extension_codec) + }) .collect::, from_proto::Error>>() }) .collect::, _>>() @@ -255,7 +268,7 @@ impl AsLogicalPlan for LogicalPlanNode { let expr: Vec = projection .expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let new_proj = project(input, expr)?; @@ -277,7 +290,7 @@ impl AsLogicalPlan for LogicalPlanNode { let expr: Expr = selection .expr .as_ref() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .transpose()? .ok_or_else(|| { DataFusionError::Internal("expression required".to_string()) @@ -291,7 +304,7 @@ impl AsLogicalPlan for LogicalPlanNode { let window_expr = window .window_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; LogicalPlanBuilder::from(input).window(window_expr)?.build() } @@ -301,12 +314,12 @@ impl AsLogicalPlan for LogicalPlanNode { let group_expr = aggregate .group_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let aggr_expr = aggregate .aggr_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; LogicalPlanBuilder::from(input) .aggregate(group_expr, aggr_expr)? @@ -328,7 +341,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filters = scan .filters .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let mut all_sort_orders = vec![]; @@ -336,7 +349,7 @@ impl AsLogicalPlan for LogicalPlanNode { let file_sort_order = order .logical_expr_nodes .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; all_sort_orders.push(file_sort_order) } @@ -436,7 +449,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filters = scan .filters .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let provider = extension_codec.try_decode_table_provider( &scan.custom_table_data, @@ -461,7 +474,7 @@ impl AsLogicalPlan for LogicalPlanNode { let sort_expr: Vec = sort .expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; LogicalPlanBuilder::from(input).sort(sort_expr)?.build() } @@ -483,7 +496,9 @@ impl AsLogicalPlan for LogicalPlanNode { }) => Partitioning::Hash( pb_hash_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| { + from_proto::parse_expr(expr, ctx, extension_codec) + }) .collect::, _>>()?, *partition_count as usize, ), @@ -527,7 +542,7 @@ impl AsLogicalPlan for LogicalPlanNode { let order_expr = expr .logical_expr_nodes .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; order_exprs.push(order_expr) } @@ -535,7 +550,7 @@ impl AsLogicalPlan for LogicalPlanNode { let mut column_defaults = HashMap::with_capacity(create_extern_table.column_defaults.len()); for (col_name, expr) in &create_extern_table.column_defaults { - let expr = from_proto::parse_expr(expr, ctx)?; + let expr = from_proto::parse_expr(expr, ctx, extension_codec)?; column_defaults.insert(col_name.clone(), expr); } @@ -663,12 +678,12 @@ impl AsLogicalPlan for LogicalPlanNode { let left_keys: Vec = join .left_join_key .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let right_keys: Vec = join .right_join_key .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let join_type = protobuf::JoinType::try_from(join.join_type).map_err(|_| { @@ -689,7 +704,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filter: Option = join .filter .as_ref() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .map_or(Ok(None), |v| v.map(Some))?; let builder = LogicalPlanBuilder::from(into_logical_plan!( @@ -769,12 +784,12 @@ impl AsLogicalPlan for LogicalPlanNode { let on_expr = distinct_on .on_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let select_expr = distinct_on .select_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let sort_expr = match distinct_on.sort_expr.len() { 0 => None, @@ -782,7 +797,9 @@ impl AsLogicalPlan for LogicalPlanNode { distinct_on .sort_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| { + from_proto::parse_expr(expr, ctx, extension_codec) + }) .collect::, _>>()?, ), }; @@ -944,7 +961,7 @@ impl AsLogicalPlan for LogicalPlanNode { let values_list = values .iter() .flatten() - .map(|v| v.try_into()) + .map(|v| serialize_expr(v, extension_codec)) .collect::, _>>()?; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Values( @@ -982,7 +999,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filters: Vec = filters .iter() - .map(|filter| filter.try_into()) + .map(|filter| serialize_expr(filter, extension_codec)) .collect::, _>>()?; if let Some(listing_table) = source.downcast_ref::() { @@ -1039,7 +1056,7 @@ impl AsLogicalPlan for LogicalPlanNode { let expr_vec = LogicalExprNodeCollection { logical_expr_nodes: order .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, to_proto::Error>>()?, }; exprs_vec.push(expr_vec); @@ -1120,7 +1137,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), expr: expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, to_proto::Error>>()?, optional_alias: None, }, @@ -1137,7 +1154,10 @@ impl AsLogicalPlan for LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Selection(Box::new( protobuf::SelectionNode { input: Some(Box::new(input)), - expr: Some((&filter.predicate).try_into()?), + expr: Some(serialize_expr( + &filter.predicate, + extension_codec, + )?), }, ))), }) @@ -1172,7 +1192,7 @@ impl AsLogicalPlan for LogicalPlanNode { None => vec![], Some(sort_expr) => sort_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, }; Ok(protobuf::LogicalPlanNode { @@ -1180,11 +1200,11 @@ impl AsLogicalPlan for LogicalPlanNode { protobuf::DistinctOnNode { on_expr: on_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, select_expr: select_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, sort_expr, input: Some(Box::new(input)), @@ -1206,7 +1226,7 @@ impl AsLogicalPlan for LogicalPlanNode { input: Some(Box::new(input)), window_expr: window_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, }, ))), @@ -1229,11 +1249,11 @@ impl AsLogicalPlan for LogicalPlanNode { input: Some(Box::new(input)), group_expr: group_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, aggr_expr: aggr_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, }, ))), @@ -1261,7 +1281,12 @@ impl AsLogicalPlan for LogicalPlanNode { )?; let (left_join_key, right_join_key) = on .iter() - .map(|(l, r)| Ok((l.try_into()?, r.try_into()?))) + .map(|(l, r)| { + Ok(( + serialize_expr(l, extension_codec)?, + serialize_expr(r, extension_codec)?, + )) + }) .collect::, to_proto::Error>>()? .into_iter() .unzip(); @@ -1270,7 +1295,7 @@ impl AsLogicalPlan for LogicalPlanNode { join_constraint.to_owned().into(); let filter = filter .as_ref() - .map(|e| e.try_into()) + .map(|e| serialize_expr(e, extension_codec)) .map_or(Ok(None), |v| v.map(Some))?; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( @@ -1329,7 +1354,7 @@ impl AsLogicalPlan for LogicalPlanNode { )?; let selection_expr: Vec = expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, to_proto::Error>>()?; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Sort(Box::new( @@ -1361,7 +1386,7 @@ impl AsLogicalPlan for LogicalPlanNode { PartitionMethod::Hash(protobuf::HashRepartition { hash_expr: exprs .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, to_proto::Error>>()?, partition_count: *partition_count as u64, }) @@ -1416,9 +1441,8 @@ impl AsLogicalPlan for LogicalPlanNode { let temp = LogicalExprNodeCollection { logical_expr_nodes: order .iter() - .map(|expr| expr.try_into()) - .collect::, to_proto::Error>>( - )?, + .map(|expr| serialize_expr(expr, extension_codec)) + .collect::, to_proto::Error>>()?, }; converted_order_exprs.push(temp); } @@ -1426,7 +1450,8 @@ impl AsLogicalPlan for LogicalPlanNode { let mut converted_column_defaults = HashMap::with_capacity(column_defaults.len()); for (col_name, expr) in column_defaults { - converted_column_defaults.insert(col_name.clone(), expr.try_into()?); + converted_column_defaults + .insert(col_name.clone(), serialize_expr(expr, extension_codec)?); } let file_compression_type = diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index c913119ff9edb..d238884374fb2 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -56,6 +56,8 @@ use datafusion_expr::{ TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use super::LogicalExtensionCodec; + #[derive(Debug)] pub enum Error { General(String), @@ -480,615 +482,612 @@ impl TryFrom<&WindowFrame> for protobuf::WindowFrame { } } -impl TryFrom<&Expr> for protobuf::LogicalExprNode { - type Error = Error; +pub fn serialize_expr( + expr: &Expr, + codec: &dyn LogicalExtensionCodec, +) -> Result { + use protobuf::logical_expr_node::ExprType; + + let expr_node = match expr { + Expr::Column(c) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::Column(c.into())), + }, + Expr::Alias(Alias { + expr, + relation, + name, + }) => { + let alias = Box::new(protobuf::AliasNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + relation: relation + .to_owned() + .map(|r| vec![r.into()]) + .unwrap_or(vec![]), + alias: name.to_owned(), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Alias(alias)), + } + } + Expr::Literal(value) => { + let pb_value: protobuf::ScalarValue = value.try_into()?; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Literal(pb_value)), + } + } + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + // Try to linerize a nested binary expression tree of the same operator + // into a flat vector of expressions. + let mut exprs = vec![right.as_ref()]; + let mut current_expr = left.as_ref(); + while let Expr::BinaryExpr(BinaryExpr { + left, + op: current_op, + right, + }) = current_expr + { + if current_op == op { + exprs.push(right.as_ref()); + current_expr = left.as_ref(); + } else { + break; + } + } + exprs.push(current_expr); - fn try_from(expr: &Expr) -> Result { - use protobuf::logical_expr_node::ExprType; + let binary_expr = protobuf::BinaryExprNode { + // We need to reverse exprs since operands are expected to be + // linearized from left innermost to right outermost (but while + // traversing the chain we do the exact opposite). + operands: exprs + .into_iter() + .rev() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + op: format!("{op:?}"), + }; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::BinaryExpr(binary_expr)), + } + } + Expr::Like(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => { + if *case_insensitive { + let pb = Box::new(protobuf::ILikeNode { + negated: *negated, + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + pattern: Some(Box::new(serialize_expr(pattern.as_ref(), codec)?)), + escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), + }); - let expr_node = match expr { - Expr::Column(c) => Self { - expr_type: Some(ExprType::Column(c.into())), - }, - Expr::Alias(Alias { - expr, - relation, - name, - }) => { - let alias = Box::new(protobuf::AliasNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - relation: relation - .to_owned() - .map(|r| vec![r.into()]) - .unwrap_or(vec![]), - alias: name.to_owned(), + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Ilike(pb)), + } + } else { + let pb = Box::new(protobuf::LikeNode { + negated: *negated, + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + pattern: Some(Box::new(serialize_expr(pattern.as_ref(), codec)?)), + escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), }); - Self { - expr_type: Some(ExprType::Alias(alias)), + + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Like(pb)), } } - Expr::Literal(value) => { - let pb_value: protobuf::ScalarValue = value.try_into()?; - Self { - expr_type: Some(ExprType::Literal(pb_value)), - } + } + Expr::SimilarTo(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive: _, + }) => { + let pb = Box::new(protobuf::SimilarToNode { + negated: *negated, + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + pattern: Some(Box::new(serialize_expr(pattern.as_ref(), codec)?)), + escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::SimilarTo(pb)), } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - // Try to linerize a nested binary expression tree of the same operator - // into a flat vector of expressions. - let mut exprs = vec![right.as_ref()]; - let mut current_expr = left.as_ref(); - while let Expr::BinaryExpr(BinaryExpr { - left, - op: current_op, - right, - }) = current_expr - { - if current_op == op { - exprs.push(right.as_ref()); - current_expr = left.as_ref(); - } else { - break; - } + } + Expr::WindowFunction(expr::WindowFunction { + ref fun, + ref args, + ref partition_by, + ref order_by, + ref window_frame, + // TODO: support null treatment in proto + null_treatment: _, + }) => { + let window_function = match fun { + WindowFunctionDefinition::AggregateFunction(fun) => { + protobuf::window_expr_node::WindowFunction::AggrFunction( + protobuf::AggregateFunction::from(fun).into(), + ) } - exprs.push(current_expr); - - let binary_expr = protobuf::BinaryExprNode { - // We need to reverse exprs since operands are expected to be - // linearized from left innermost to right outermost (but while - // traversing the chain we do the exact opposite). - operands: exprs - .into_iter() - .rev() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - op: format!("{op:?}"), - }; - Self { - expr_type: Some(ExprType::BinaryExpr(binary_expr)), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + protobuf::window_expr_node::WindowFunction::BuiltInFunction( + protobuf::BuiltInWindowFunction::from(fun).into(), + ) } - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { - if *case_insensitive { - let pb = Box::new(protobuf::ILikeNode { - negated: *negated, - expr: Some(Box::new(expr.as_ref().try_into()?)), - pattern: Some(Box::new(pattern.as_ref().try_into()?)), - escape_char: escape_char - .map(|ch| ch.to_string()) - .unwrap_or_default(), - }); - - Self { - expr_type: Some(ExprType::Ilike(pb)), - } - } else { - let pb = Box::new(protobuf::LikeNode { - negated: *negated, - expr: Some(Box::new(expr.as_ref().try_into()?)), - pattern: Some(Box::new(pattern.as_ref().try_into()?)), - escape_char: escape_char - .map(|ch| ch.to_string()) - .unwrap_or_default(), - }); - - Self { - expr_type: Some(ExprType::Like(pb)), - } + WindowFunctionDefinition::AggregateUDF(aggr_udf) => { + protobuf::window_expr_node::WindowFunction::Udaf( + aggr_udf.name().to_string(), + ) } - } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) => { - let pb = Box::new(protobuf::SimilarToNode { - negated: *negated, - expr: Some(Box::new(expr.as_ref().try_into()?)), - pattern: Some(Box::new(pattern.as_ref().try_into()?)), - escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), - }); - Self { - expr_type: Some(ExprType::SimilarTo(pb)), + WindowFunctionDefinition::WindowUDF(window_udf) => { + protobuf::window_expr_node::WindowFunction::Udwf( + window_udf.name().to_string(), + ) } + }; + let arg_expr: Option> = if !args.is_empty() { + let arg = &args[0]; + Some(Box::new(serialize_expr(arg, codec)?)) + } else { + None + }; + let partition_by = partition_by + .iter() + .map(|e| serialize_expr(e, codec)) + .collect::, _>>()?; + let order_by = order_by + .iter() + .map(|e| serialize_expr(e, codec)) + .collect::, _>>()?; + + let window_frame: Option = + Some(window_frame.try_into()?); + let window_expr = Box::new(protobuf::WindowExprNode { + expr: arg_expr, + window_function: Some(window_function), + partition_by, + order_by, + window_frame, + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::WindowExpr(window_expr)), } - Expr::WindowFunction(expr::WindowFunction { - ref fun, - ref args, - ref partition_by, - ref order_by, - ref window_frame, - // TODO: support null treatment in proto - null_treatment: _, - }) => { - let window_function = match fun { - WindowFunctionDefinition::AggregateFunction(fun) => { - protobuf::window_expr_node::WindowFunction::AggrFunction( - protobuf::AggregateFunction::from(fun).into(), - ) + } + Expr::AggregateFunction(expr::AggregateFunction { + ref func_def, + ref args, + ref distinct, + ref filter, + ref order_by, + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let aggr_function = match fun { + AggregateFunction::ApproxDistinct => { + protobuf::AggregateFunction::ApproxDistinct + } + AggregateFunction::ApproxPercentileCont => { + protobuf::AggregateFunction::ApproxPercentileCont } - WindowFunctionDefinition::BuiltInWindowFunction(fun) => { - protobuf::window_expr_node::WindowFunction::BuiltInFunction( - protobuf::BuiltInWindowFunction::from(fun).into(), - ) + AggregateFunction::ApproxPercentileContWithWeight => { + protobuf::AggregateFunction::ApproxPercentileContWithWeight } - WindowFunctionDefinition::AggregateUDF(aggr_udf) => { - protobuf::window_expr_node::WindowFunction::Udaf( - aggr_udf.name().to_string(), - ) + AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, + AggregateFunction::Min => protobuf::AggregateFunction::Min, + AggregateFunction::Max => protobuf::AggregateFunction::Max, + AggregateFunction::Sum => protobuf::AggregateFunction::Sum, + AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, + AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, + AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, + AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, + AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, + AggregateFunction::Avg => protobuf::AggregateFunction::Avg, + AggregateFunction::Count => protobuf::AggregateFunction::Count, + AggregateFunction::Variance => protobuf::AggregateFunction::Variance, + AggregateFunction::VariancePop => { + protobuf::AggregateFunction::VariancePop } - WindowFunctionDefinition::WindowUDF(window_udf) => { - protobuf::window_expr_node::WindowFunction::Udwf( - window_udf.name().to_string(), - ) + AggregateFunction::Covariance => { + protobuf::AggregateFunction::Covariance + } + AggregateFunction::CovariancePop => { + protobuf::AggregateFunction::CovariancePop + } + AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, + AggregateFunction::StddevPop => { + protobuf::AggregateFunction::StddevPop + } + AggregateFunction::Correlation => { + protobuf::AggregateFunction::Correlation + } + AggregateFunction::RegrSlope => { + protobuf::AggregateFunction::RegrSlope + } + AggregateFunction::RegrIntercept => { + protobuf::AggregateFunction::RegrIntercept + } + AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, + AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, + AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, + AggregateFunction::RegrCount => { + protobuf::AggregateFunction::RegrCount + } + AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, + AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, + AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, + AggregateFunction::ApproxMedian => { + protobuf::AggregateFunction::ApproxMedian + } + AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, + AggregateFunction::Median => protobuf::AggregateFunction::Median, + AggregateFunction::FirstValue => { + protobuf::AggregateFunction::FirstValueAgg + } + AggregateFunction::LastValue => { + protobuf::AggregateFunction::LastValueAgg + } + AggregateFunction::NthValue => { + protobuf::AggregateFunction::NthValueAgg + } + AggregateFunction::StringAgg => { + protobuf::AggregateFunction::StringAgg } }; - let arg_expr: Option> = if !args.is_empty() { - let arg = &args[0]; - Some(Box::new(arg.try_into()?)) - } else { - None + + let aggregate_expr = protobuf::AggregateExprNode { + aggr_function: aggr_function.into(), + expr: args + .iter() + .map(|v| serialize_expr(v, codec)) + .collect::, _>>()?, + distinct: *distinct, + filter: match filter { + Some(e) => Some(Box::new(serialize_expr(e, codec)?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, _>>()?, + None => vec![], + }, }; - let partition_by = partition_by - .iter() - .map(|e| e.try_into()) - .collect::, _>>()?; - let order_by = order_by - .iter() - .map(|e| e.try_into()) - .collect::, _>>()?; - - let window_frame: Option = - Some(window_frame.try_into()?); - let window_expr = Box::new(protobuf::WindowExprNode { - expr: arg_expr, - window_function: Some(window_function), - partition_by, - order_by, - window_frame, - }); - Self { - expr_type: Some(ExprType::WindowExpr(window_expr)), + protobuf::LogicalExprNode { + expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), } } - Expr::AggregateFunction(expr::AggregateFunction { - ref func_def, - ref args, - ref distinct, - ref filter, - ref order_by, - }) => { - match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - let aggr_function = match fun { - AggregateFunction::ApproxDistinct => { - protobuf::AggregateFunction::ApproxDistinct - } - AggregateFunction::ApproxPercentileCont => { - protobuf::AggregateFunction::ApproxPercentileCont - } - AggregateFunction::ApproxPercentileContWithWeight => { - protobuf::AggregateFunction::ApproxPercentileContWithWeight - } - AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::Sum => protobuf::AggregateFunction::Sum, - AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, - AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, - AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, - AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, - AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, - AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, - AggregateFunction::Variance => protobuf::AggregateFunction::Variance, - AggregateFunction::VariancePop => { - protobuf::AggregateFunction::VariancePop - } - AggregateFunction::Covariance => { - protobuf::AggregateFunction::Covariance - } - AggregateFunction::CovariancePop => { - protobuf::AggregateFunction::CovariancePop - } - AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, - AggregateFunction::StddevPop => { - protobuf::AggregateFunction::StddevPop - } - AggregateFunction::Correlation => { - protobuf::AggregateFunction::Correlation - } - AggregateFunction::RegrSlope => { - protobuf::AggregateFunction::RegrSlope - } - AggregateFunction::RegrIntercept => { - protobuf::AggregateFunction::RegrIntercept - } - AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, - AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, - AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, - AggregateFunction::RegrCount => { - protobuf::AggregateFunction::RegrCount - } - AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, - AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, - AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, - AggregateFunction::ApproxMedian => { - protobuf::AggregateFunction::ApproxMedian - } - AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::Median => protobuf::AggregateFunction::Median, - AggregateFunction::FirstValue => { - protobuf::AggregateFunction::FirstValueAgg - } - AggregateFunction::LastValue => { - protobuf::AggregateFunction::LastValueAgg - } - AggregateFunction::NthValue => { - protobuf::AggregateFunction::NthValueAgg - } - AggregateFunction::StringAgg => { - protobuf::AggregateFunction::StringAgg - } - }; - - let aggregate_expr = protobuf::AggregateExprNode { - aggr_function: aggr_function.into(), - expr: args + AggregateFunctionDefinition::UDF(fun) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::AggregateUdfExpr(Box::new( + protobuf::AggregateUdfExprNode { + fun_name: fun.name().to_string(), + args: args + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + filter: match filter { + Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), + None => None, + }, + order_by: match order_by { + Some(e) => e .iter() - .map(|v| v.try_into()) + .map(|expr| serialize_expr(expr, codec)) .collect::, _>>()?, - distinct: *distinct, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], - }, - }; - Self { - expr_type: Some(ExprType::AggregateExpr(Box::new( - aggregate_expr, - ))), - } - } - AggregateFunctionDefinition::UDF(fun) => Self { - expr_type: Some(ExprType::AggregateUdfExpr(Box::new( - protobuf::AggregateUdfExprNode { - fun_name: fun.name().to_string(), - args: args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], - }, - }, - ))), + None => vec![], + }, }, - AggregateFunctionDefinition::Name(_) => { - return Err(Error::NotImplemented( + ))), + }, + AggregateFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( "Proto serialization error: Trying to serialize a unresolved function" .to_string(), )); - } - } } + }, - Expr::ScalarVariable(_, _) => { - return Err(Error::General( - "Proto serialization error: Scalar Variable not supported" - .to_string(), - )) - } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - let args = args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?; - match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { - let fun: protobuf::ScalarFunction = fun.try_into()?; - Self { - expr_type: Some(ExprType::ScalarFunction( - protobuf::ScalarFunctionNode { - fun: fun.into(), - args, - }, - )), - } + Expr::ScalarVariable(_, _) => { + return Err(Error::General( + "Proto serialization error: Scalar Variable not supported".to_string(), + )) + } + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let args = args + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?; + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let fun: protobuf::ScalarFunction = fun.try_into()?; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::ScalarFunction( + protobuf::ScalarFunctionNode { + fun: fun.into(), + args, + }, + )), } - ScalarFunctionDefinition::UDF(fun) => Self { + } + ScalarFunctionDefinition::UDF(fun) => { + let mut buf = Vec::new(); + let _ = codec.try_encode_udf(fun.as_ref(), &mut buf); + + let fun_definition = if buf.is_empty() { None } else { Some(buf) }; + + protobuf::LogicalExprNode { expr_type: Some(ExprType::ScalarUdfExpr( protobuf::ScalarUdfExprNode { fun_name: fun.name().to_string(), + fun_definition, args, }, )), - }, - ScalarFunctionDefinition::Name(_) => { - return Err(Error::NotImplemented( + } + } + ScalarFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( "Proto serialization error: Trying to serialize a unresolved function" .to_string(), )); - } } } - Expr::Not(expr) => { - let expr = Box::new(protobuf::Not { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::NotExpr(expr)), - } + } + Expr::Not(expr) => { + let expr = Box::new(protobuf::Not { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::NotExpr(expr)), } - Expr::IsNull(expr) => { - let expr = Box::new(protobuf::IsNull { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNullExpr(expr)), - } + } + Expr::IsNull(expr) => { + let expr = Box::new(protobuf::IsNull { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNullExpr(expr)), } - Expr::IsNotNull(expr) => { - let expr = Box::new(protobuf::IsNotNull { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNotNullExpr(expr)), - } + } + Expr::IsNotNull(expr) => { + let expr = Box::new(protobuf::IsNotNull { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNotNullExpr(expr)), } - Expr::IsTrue(expr) => { - let expr = Box::new(protobuf::IsTrue { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsTrue(expr)), - } + } + Expr::IsTrue(expr) => { + let expr = Box::new(protobuf::IsTrue { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsTrue(expr)), } - Expr::IsFalse(expr) => { - let expr = Box::new(protobuf::IsFalse { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsFalse(expr)), - } + } + Expr::IsFalse(expr) => { + let expr = Box::new(protobuf::IsFalse { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsFalse(expr)), } - Expr::IsUnknown(expr) => { - let expr = Box::new(protobuf::IsUnknown { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsUnknown(expr)), - } + } + Expr::IsUnknown(expr) => { + let expr = Box::new(protobuf::IsUnknown { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsUnknown(expr)), } - Expr::IsNotTrue(expr) => { - let expr = Box::new(protobuf::IsNotTrue { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNotTrue(expr)), - } + } + Expr::IsNotTrue(expr) => { + let expr = Box::new(protobuf::IsNotTrue { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNotTrue(expr)), } - Expr::IsNotFalse(expr) => { - let expr = Box::new(protobuf::IsNotFalse { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNotFalse(expr)), - } + } + Expr::IsNotFalse(expr) => { + let expr = Box::new(protobuf::IsNotFalse { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNotFalse(expr)), } - Expr::IsNotUnknown(expr) => { - let expr = Box::new(protobuf::IsNotUnknown { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNotUnknown(expr)), - } + } + Expr::IsNotUnknown(expr) => { + let expr = Box::new(protobuf::IsNotUnknown { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNotUnknown(expr)), } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - let expr = Box::new(protobuf::BetweenNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - negated: *negated, - low: Some(Box::new(low.as_ref().try_into()?)), - high: Some(Box::new(high.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::Between(expr)), - } + } + Expr::Between(Between { + expr, + negated, + low, + high, + }) => { + let expr = Box::new(protobuf::BetweenNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + negated: *negated, + low: Some(Box::new(serialize_expr(low.as_ref(), codec)?)), + high: Some(Box::new(serialize_expr(high.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Between(expr)), } - Expr::Case(case) => { - let when_then_expr = case - .when_then_expr - .iter() - .map(|(w, t)| { - Ok(protobuf::WhenThen { - when_expr: Some(w.as_ref().try_into()?), - then_expr: Some(t.as_ref().try_into()?), - }) + } + Expr::Case(case) => { + let when_then_expr = case + .when_then_expr + .iter() + .map(|(w, t)| { + Ok(protobuf::WhenThen { + when_expr: Some(serialize_expr(w.as_ref(), codec)?), + then_expr: Some(serialize_expr(t.as_ref(), codec)?), }) - .collect::, Error>>()?; - let expr = Box::new(protobuf::CaseNode { - expr: match &case.expr { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - when_then_expr, - else_expr: match &case.else_expr { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - }); - Self { - expr_type: Some(ExprType::Case(expr)), - } + }) + .collect::, Error>>()?; + let expr = Box::new(protobuf::CaseNode { + expr: match &case.expr { + Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), + None => None, + }, + when_then_expr, + else_expr: match &case.else_expr { + Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), + None => None, + }, + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Case(expr)), } - Expr::Cast(Cast { expr, data_type }) => { - let expr = Box::new(protobuf::CastNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - arrow_type: Some(data_type.try_into()?), - }); - Self { - expr_type: Some(ExprType::Cast(expr)), - } + } + Expr::Cast(Cast { expr, data_type }) => { + let expr = Box::new(protobuf::CastNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + arrow_type: Some(data_type.try_into()?), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Cast(expr)), } - Expr::TryCast(TryCast { expr, data_type }) => { - let expr = Box::new(protobuf::TryCastNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - arrow_type: Some(data_type.try_into()?), - }); - Self { - expr_type: Some(ExprType::TryCast(expr)), - } + } + Expr::TryCast(TryCast { expr, data_type }) => { + let expr = Box::new(protobuf::TryCastNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + arrow_type: Some(data_type.try_into()?), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::TryCast(expr)), } - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let expr = Box::new(protobuf::SortExprNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - asc: *asc, - nulls_first: *nulls_first, - }); - Self { - expr_type: Some(ExprType::Sort(expr)), - } + } + Expr::Sort(Sort { + expr, + asc, + nulls_first, + }) => { + let expr = Box::new(protobuf::SortExprNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + asc: *asc, + nulls_first: *nulls_first, + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Sort(expr)), } - Expr::Negative(expr) => { - let expr = Box::new(protobuf::NegativeNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::Negative(expr)), - } + } + Expr::Negative(expr) => { + let expr = Box::new(protobuf::NegativeNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Negative(expr)), } - Expr::Unnest(Unnest { exprs }) => { - let expr = protobuf::Unnest { - exprs: exprs.iter().map(|expr| expr.try_into()).collect::, - Error, - >>( - )?, - }; - Self { - expr_type: Some(ExprType::Unnest(expr)), - } + } + Expr::Unnest(Unnest { exprs }) => { + let expr = protobuf::Unnest { + exprs: exprs + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + }; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Unnest(expr)), } - Expr::InList(InList { - expr, - list, - negated, - }) => { - let expr = Box::new(protobuf::InListNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - list: list - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - negated: *negated, - }); - Self { - expr_type: Some(ExprType::InList(expr)), - } + } + Expr::InList(InList { + expr, + list, + negated, + }) => { + let expr = Box::new(protobuf::InListNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + list: list + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + negated: *negated, + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::InList(expr)), } - Expr::Wildcard { qualifier } => Self { - expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { - qualifier: qualifier.clone().unwrap_or("".to_string()), - })), - }, - Expr::ScalarSubquery(_) - | Expr::InSubquery(_) - | Expr::Exists { .. } - | Expr::OuterReferenceColumn { .. } => { - // we would need to add logical plan operators to datafusion.proto to support this - // see discussion in https://github.com/apache/arrow-datafusion/issues/2565 - return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string())); - } - Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let field = match field { - GetFieldAccess::NamedStructField { name } => { - protobuf::get_indexed_field::Field::NamedStructField( - protobuf::NamedStructField { - name: Some(name.try_into()?), - }, - ) - } - GetFieldAccess::ListIndex { key } => { - protobuf::get_indexed_field::Field::ListIndex(Box::new( - protobuf::ListIndex { - key: Some(Box::new(key.as_ref().try_into()?)), - }, - )) - } - GetFieldAccess::ListRange { - start, - stop, - stride, - } => protobuf::get_indexed_field::Field::ListRange(Box::new( - protobuf::ListRange { - start: Some(Box::new(start.as_ref().try_into()?)), - stop: Some(Box::new(stop.as_ref().try_into()?)), - stride: Some(Box::new(stride.as_ref().try_into()?)), + } + Expr::Wildcard { qualifier } => protobuf::LogicalExprNode { + expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { + qualifier: qualifier.clone().unwrap_or("".to_string()), + })), + }, + Expr::ScalarSubquery(_) + | Expr::InSubquery(_) + | Expr::Exists { .. } + | Expr::OuterReferenceColumn { .. } => { + // we would need to add logical plan operators to datafusion.proto to support this + // see discussion in https://github.com/apache/arrow-datafusion/issues/2565 + return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string())); + } + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + let field = match field { + GetFieldAccess::NamedStructField { name } => { + protobuf::get_indexed_field::Field::NamedStructField( + protobuf::NamedStructField { + name: Some(name.try_into()?), }, - )), - }; - - Self { - expr_type: Some(ExprType::GetIndexedField(Box::new( - protobuf::GetIndexedField { - expr: Some(Box::new(expr.as_ref().try_into()?)), - field: Some(field), + ) + } + GetFieldAccess::ListIndex { key } => { + protobuf::get_indexed_field::Field::ListIndex(Box::new( + protobuf::ListIndex { + key: Some(Box::new(serialize_expr(key.as_ref(), codec)?)), }, - ))), + )) } + GetFieldAccess::ListRange { + start, + stop, + stride, + } => protobuf::get_indexed_field::Field::ListRange(Box::new( + protobuf::ListRange { + start: Some(Box::new(serialize_expr(start.as_ref(), codec)?)), + stop: Some(Box::new(serialize_expr(stop.as_ref(), codec)?)), + stride: Some(Box::new(serialize_expr(stride.as_ref(), codec)?)), + }, + )), + }; + + protobuf::LogicalExprNode { + expr_type: Some(ExprType::GetIndexedField(Box::new( + protobuf::GetIndexedField { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + field: Some(field), + }, + ))), } + } - Expr::GroupingSet(GroupingSet::Cube(exprs)) => Self { - expr_type: Some(ExprType::Cube(CubeNode { - expr: exprs.iter().map(|expr| expr.try_into()).collect::, - Self::Error, - >>( - )?, - })), - }, - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => Self { - expr_type: Some(ExprType::Rollup(RollupNode { - expr: exprs.iter().map(|expr| expr.try_into()).collect::, - Self::Error, - >>( - )?, - })), - }, - Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => Self { + Expr::GroupingSet(GroupingSet::Cube(exprs)) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::Cube(CubeNode { + expr: exprs + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + })), + }, + Expr::GroupingSet(GroupingSet::Rollup(exprs)) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::Rollup(RollupNode { + expr: exprs + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + })), + }, + Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => { + protobuf::LogicalExprNode { expr_type: Some(ExprType::GroupingSet(GroupingSetNode { expr: exprs .iter() @@ -1096,29 +1095,29 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Ok(LogicalExprList { expr: expr_list .iter() - .map(|expr| expr.try_into()) - .collect::, Self::Error>>()?, + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, }) }) - .collect::, Self::Error>>()?, + .collect::, Error>>()?, })), - }, - Expr::Placeholder(Placeholder { id, data_type }) => { - let data_type = match data_type { - Some(data_type) => Some(data_type.try_into()?), - None => None, - }; - Self { - expr_type: Some(ExprType::Placeholder(PlaceholderNode { - id: id.clone(), - data_type, - })), - } } - }; + } + Expr::Placeholder(Placeholder { id, data_type }) => { + let data_type = match data_type { + Some(data_type) => Some(data_type.try_into()?), + None => None, + }; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Placeholder(PlaceholderNode { + id: id.clone(), + data_type, + })), + } + } + }; - Ok(expr_node) - } + Ok(expr_node) } impl TryFrom<&ScalarValue> for protobuf::ScalarValue { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index d2961875d89ad..a20baeb4e941a 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -60,6 +60,7 @@ use datafusion::physical_plan::{ WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_expr::ScalarUDF; use prost::bytes::BufMut; use prost::Message; @@ -1911,6 +1912,14 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { ) -> Result>; fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()>; + + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("PhysicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } } #[derive(Debug)] diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index e3bd2cb1dc47c..0ec44190ef7a5 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -28,6 +28,8 @@ use arrow::datatypes::{ }; use datafusion_common::file_options::arrow_writer::ArrowWriterOptions; +use datafusion_expr::{ScalarUDF, ScalarUDFImpl}; +use datafusion_proto::logical_plan::to_proto::serialize_expr; use prost::Message; use datafusion::datasource::provider::TableProviderFactory; @@ -62,8 +64,8 @@ use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, }; -use datafusion_proto::logical_plan::from_proto; use datafusion_proto::logical_plan::LogicalExtensionCodec; +use datafusion_proto::logical_plan::{from_proto, DefaultLogicalExtensionCodec}; use datafusion_proto::protobuf; #[cfg(feature = "json")] @@ -78,13 +80,15 @@ fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) {} // Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test // equality. -fn roundtrip_expr_test(initial_struct: T, ctx: SessionContext) -where - for<'a> &'a T: TryInto + Debug, - E: Debug, -{ - let proto: protobuf::LogicalExprNode = (&initial_struct).try_into().unwrap(); - let round_trip: Expr = from_proto::parse_expr(&proto, &ctx).unwrap(); +fn roundtrip_expr_test(initial_struct: Expr, ctx: SessionContext) { + let extension_codec = DefaultLogicalExtensionCodec {}; + let proto: protobuf::LogicalExprNode = + match serialize_expr(&initial_struct, &extension_codec) { + Ok(p) => p, + Err(e) => panic!("Error serializing expression: {:?}", e), + }; + let round_trip: Expr = + from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap(); assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}")); @@ -631,6 +635,12 @@ pub mod proto { #[prost(uint64, tag = "1")] pub k: u64, } + + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct MyRegexUdfNode { + #[prost(string, tag = "1")] + pub pattern: String, + } } #[derive(PartialEq, Eq, Hash)] @@ -707,7 +717,7 @@ impl LogicalExtensionCodec for TopKExtensionCodec { let node = TopKPlanNode::new( proto.k as usize, input.clone(), - from_proto::parse_expr(expr, ctx)?, + from_proto::parse_expr(expr, ctx, self)?, ); Ok(Extension { @@ -725,7 +735,7 @@ impl LogicalExtensionCodec for TopKExtensionCodec { if let Some(exec) = node.node.as_any().downcast_ref::() { let proto = proto::TopKPlanProto { k: exec.k as u64, - expr: Some((&exec.expr).try_into()?), + expr: Some(serialize_expr(&exec.expr, self)?), }; proto.encode(buf).map_err(|e| { @@ -756,6 +766,109 @@ impl LogicalExtensionCodec for TopKExtensionCodec { } } +#[derive(Debug)] +struct MyRegexUdf { + signature: Signature, + // regex as original string + pattern: String, +} + +impl MyRegexUdf { + fn new(pattern: String) -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Int32], + Volatility::Immutable, + ), + pattern, + } + } +} + +/// Implement the ScalarUDFImpl trait for MyRegexUdf +impl ScalarUDFImpl for MyRegexUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "regex_udf" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, args: &[DataType]) -> Result { + if !matches!(args.first(), Some(&DataType::Utf8)) { + return plan_err!("regex_udf only accepts Utf8 arguments"); + } + Ok(DataType::Int32) + } + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } +} + +#[derive(Debug)] +pub struct ScalarUDFExtensionCodec {} + +impl LogicalExtensionCodec for ScalarUDFExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &SessionContext, + ) -> Result { + not_impl_err!("No extension codec provided") + } + + fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { + not_impl_err!("No extension codec provided") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _schema: SchemaRef, + _ctx: &SessionContext, + ) -> Result> { + internal_err!("unsupported plan type") + } + + fn try_encode_table_provider( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + internal_err!("unsupported plan type") + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "regex_udf" { + let proto = proto::MyRegexUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to decode regex_udf: {}", err)) + })?; + + Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new( + proto.pattern, + )))) + } else { + not_impl_err!("unrecognized scalar UDF implementation, cannot decode") + } + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + let udf = binding.as_any().downcast_ref::().unwrap(); + let proto = proto::MyRegexUdfNode { + pattern: udf.pattern.clone(), + }; + proto.encode(buf).map_err(|e| { + DataFusionError::Internal(format!("failed to encode udf: {e:?}")) + })?; + Ok(()) + } +} + #[test] fn round_trip_scalar_values() { let should_pass: Vec = vec![ @@ -1664,6 +1777,30 @@ fn roundtrip_scalar_udf() { roundtrip_expr_test(test_expr, ctx); } +#[test] +fn roundtrip_scalar_udf_extension_codec() { + let pattern = ".*"; + let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); + let test_expr = + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf.clone()), vec![])); + + let ctx = SessionContext::new(); + ctx.register_udf(udf); + + let extension_codec = ScalarUDFExtensionCodec {}; + let proto: protobuf::LogicalExprNode = + match serialize_expr(&test_expr, &extension_codec) { + Ok(p) => p, + Err(e) => panic!("Error serializing expression: {:?}", e), + }; + let round_trip: Expr = + from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap(); + + assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); + + roundtrip_json_test(&proto); +} + #[test] fn roundtrip_grouping_sets() { let test_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index 7dd0333909ee8..d4a1ab44a6ea8 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -25,6 +25,8 @@ use datafusion::prelude::SessionContext; use datafusion_expr::{col, create_udf, lit, ColumnarValue}; use datafusion_expr::{Expr, Volatility}; use datafusion_proto::bytes::Serializeable; +use datafusion_proto::logical_plan::to_proto::serialize_expr; +use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec; #[test] #[should_panic( @@ -252,7 +254,6 @@ fn test_expression_serialization_roundtrip() { use datafusion_expr::expr::ScalarFunction; use datafusion_expr::BuiltinScalarFunction; use datafusion_proto::logical_plan::from_proto::parse_expr; - use datafusion_proto::protobuf::LogicalExprNode; use strum::IntoEnumIterator; let ctx = SessionContext::new(); @@ -266,8 +267,9 @@ fn test_expression_serialization_roundtrip() { let args: Vec<_> = std::iter::repeat(&lit).take(num_args).cloned().collect(); let expr = Expr::ScalarFunction(ScalarFunction::new(builtin_fun, args)); - let proto = LogicalExprNode::try_from(&expr).unwrap(); - let deserialize = parse_expr(&proto, &ctx).unwrap(); + let extension_codec = DefaultLogicalExtensionCodec {}; + let proto = serialize_expr(&expr, &extension_codec).unwrap(); + let deserialize = parse_expr(&proto, &ctx, &extension_codec).unwrap(); let serialize_name = extract_function_name(&expr); let deserialize_name = extract_function_name(&deserialize);