From 2f4bda0cbb56b4695f8972ee1d25138f7ccf61e3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 7 Dec 2022 11:52:28 -0700 Subject: [PATCH 1/2] Change API to return LogicalPlan instead of DataFrame --- src/consumer.rs | 292 +++++++++++++++++++++++++++------------------ src/producer.rs | 215 ++++++++++++++++++++++++--------- src/serializer.rs | 13 +- tests/roundtrip.rs | 37 +++--- tests/serialize.rs | 3 +- 5 files changed, 359 insertions(+), 201 deletions(-) diff --git a/src/consumer.rs b/src/consumer.rs index 1677cdc..1bb826c 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -1,13 +1,13 @@ use async_recursion::async_recursion; use datafusion::common::{DFField, DFSchema, DFSchemaRef}; -use datafusion::logical_expr::{LogicalPlan, aggregate_function}; +use datafusion::logical_expr::{aggregate_function, LogicalPlan, LogicalPlanBuilder}; use datafusion::logical_plan::build_join_schema; use datafusion::prelude::JoinType; use datafusion::{ error::{DataFusionError, Result}, logical_plan::{Expr, Operator}, optimizer::utils::split_conjunction, - prelude::{Column, DataFrame, SessionContext}, + prelude::{Column, SessionContext}, scalar::ScalarValue, }; @@ -15,17 +15,14 @@ use datafusion::sql::TableReference; use substrait::protobuf::{ aggregate_function::AggregationInvocation, expression::{ - field_reference::ReferenceType::DirectReference, - literal::LiteralType, - MaskExpression, - reference_segment::ReferenceType::StructField, - RexType, + field_reference::ReferenceType::DirectReference, literal::LiteralType, + reference_segment::ReferenceType::StructField, MaskExpression, RexType, }, extensions::simple_extension_declaration::MappingType, function_argument::ArgType, read_rel::ReadType, rel::RelType, - sort_field::{SortKind::*, SortDirection}, + sort_field::{SortDirection, SortKind::*}, AggregateFunction, Expression, Plan, Rel, }; @@ -70,18 +67,22 @@ pub fn name_to_op(name: &str) -> Result { } /// Convert Substrait Plan to DataFusion DataFrame -pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Result> { +pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Result { // Register function extension - let function_extension = plan.extensions + let function_extension = plan + .extensions .iter() .map(|e| match &e.mapping_type { - Some(ext) => { - match ext { - MappingType::ExtensionFunction(ext_f) => Ok((ext_f.function_anchor, &ext_f.name)), - _ => Err(DataFusionError::NotImplemented(format!("Extension type not supported: {:?}", ext))) - } - } - None => Err(DataFusionError::NotImplemented("Cannot parse empty extension".to_string())) + Some(ext) => match ext { + MappingType::ExtensionFunction(ext_f) => Ok((ext_f.function_anchor, &ext_f.name)), + _ => Err(DataFusionError::NotImplemented(format!( + "Extension type not supported: {:?}", + ext + ))), + }, + None => Err(DataFusionError::NotImplemented( + "Cannot parse empty extension".to_string(), + )), }) .collect::>>()?; // Parse relations @@ -98,7 +99,6 @@ pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Resul }, None => Err(DataFusionError::Internal("Cannot parse plan relation: None".to_string())) } - }, _ => Err(DataFusionError::NotImplemented(format!( "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}", @@ -109,17 +109,22 @@ pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Resul /// Convert Substrait Rel to DataFusion DataFrame #[async_recursion] -pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: &HashMap) -> Result> { +pub async fn from_substrait_rel( + ctx: &mut SessionContext, + rel: &Rel, + extensions: &HashMap, +) -> Result { match &rel.rel_type { Some(RelType::Project(p)) => { if let Some(input) = p.input.as_ref() { - let input = from_substrait_rel(ctx, input, extensions).await?; + let input = + LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?); let mut exprs: Vec = vec![]; for e in &p.expressions { let x = from_substrait_rex(e, &input.schema(), extensions).await?; exprs.push(x.as_ref().clone()); } - input.select(exprs) + input.project(exprs)?.build() } else { Err(DataFusionError::NotImplemented( "Projection without an input is not supported".to_string(), @@ -128,10 +133,11 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: } Some(RelType::Filter(filter)) => { if let Some(input) = filter.input.as_ref() { - let input = from_substrait_rel(ctx, input, extensions).await?; + let input = + LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?); if let Some(condition) = filter.condition.as_ref() { let expr = from_substrait_rex(condition, &input.schema(), extensions).await?; - input.filter(expr.as_ref().clone()) + input.filter(expr.as_ref().clone())?.build() } else { Err(DataFusionError::NotImplemented( "Filter without an condition is not valid".to_string(), @@ -145,10 +151,11 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: } Some(RelType::Fetch(fetch)) => { if let Some(input) = fetch.input.as_ref() { - let input = from_substrait_rel(ctx, input, extensions).await?; + let input = + LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?); let offset = fetch.offset as usize; let count = fetch.count as usize; - input.limit(offset, Some(count)) + input.limit(offset, Some(count))?.build() } else { Err(DataFusionError::NotImplemented( "Fetch without an input is not valid".to_string(), @@ -157,16 +164,17 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: } Some(RelType::Sort(sort)) => { if let Some(input) = sort.input.as_ref() { - let input = from_substrait_rel(ctx, input, extensions).await?; + let input = + LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?); let mut sorts: Vec = vec![]; for s in &sort.sorts { - let expr = from_substrait_rex(&s.expr.as_ref().unwrap(), &input.schema(), extensions).await?; + let expr = + from_substrait_rex(&s.expr.as_ref().unwrap(), &input.schema(), extensions) + .await?; let asc_nullfirst = match &s.sort_kind { Some(k) => match k { Direction(d) => { - let direction : SortDirection = unsafe { - ::std::mem::transmute(*d) - }; + let direction: SortDirection = unsafe { ::std::mem::transmute(*d) }; match direction { SortDirection::AscNullsFirst => Ok((true, true)), SortDirection::AscNullsLast => Ok((true, false)), @@ -174,32 +182,34 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: SortDirection::DescNullsLast => Ok((false, false)), SortDirection::Clustered => { Err(DataFusionError::NotImplemented( - "Sort with direction clustered is not yet supported".to_string(), - )) - }, + "Sort with direction clustered is not yet supported" + .to_string(), + )) + } SortDirection::Unspecified => { Err(DataFusionError::NotImplemented( "Unspecified sort direction is invalid".to_string(), - )) + )) } } } - ComparisonFunctionReference(_) => { - Err(DataFusionError::NotImplemented( - "Sort using comparison function reference is not supported".to_string(), - )) - }, - }, - None => { - Err(DataFusionError::NotImplemented( - "Sort without sort kind is invalid".to_string(), - )) + ComparisonFunctionReference(_) => Err(DataFusionError::NotImplemented( + "Sort using comparison function reference is not supported" + .to_string(), + )), }, + None => Err(DataFusionError::NotImplemented( + "Sort without sort kind is invalid".to_string(), + )), }; let (asc, nulls_first) = asc_nullfirst.unwrap(); - sorts.push(Expr::Sort { expr: Box::new(expr.as_ref().clone()), asc: asc, nulls_first: nulls_first }); + sorts.push(Expr::Sort { + expr: Box::new(expr.as_ref().clone()), + asc: asc, + nulls_first: nulls_first, + }); } - input.sort(sorts) + input.sort(sorts)?.build() } else { Err(DataFusionError::NotImplemented( "Sort without an input is not valid".to_string(), @@ -208,17 +218,16 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: } Some(RelType::Aggregate(agg)) => { if let Some(input) = agg.input.as_ref() { - let input = from_substrait_rel(ctx, input, extensions).await?; + let input = + LogicalPlanBuilder::from(from_substrait_rel(ctx, input, extensions).await?); let mut group_expr = vec![]; let mut aggr_expr = vec![]; let groupings = match agg.groupings.len() { - 1 => { Ok(&agg.groupings[0]) }, - _ => { - Err(DataFusionError::NotImplemented( - "Aggregate with multiple grouping sets is not supported".to_string(), - )) - } + 1 => Ok(&agg.groupings[0]), + _ => Err(DataFusionError::NotImplemented( + "Aggregate with multiple grouping sets is not supported".to_string(), + )), }; for e in &groupings?.grouping_expressions { @@ -228,18 +237,30 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: for m in &agg.measures { let filter = match &m.filter { - Some(fil) => Some(Box::new(from_substrait_rex(fil, &input.schema(), extensions).await?.as_ref().clone())), - None => None + Some(fil) => Some(Box::new( + from_substrait_rex(fil, &input.schema(), extensions) + .await? + .as_ref() + .clone(), + )), + None => None, }; let agg_func = match &m.measure { Some(f) => { - let distinct = match f.invocation { + let distinct = match f.invocation { _ if f.invocation == AggregationInvocation::Distinct as i32 => true, _ if f.invocation == AggregationInvocation::All as i32 => false, - _ => false + _ => false, }; - from_substrait_agg_func(&f, &input.schema(), extensions, filter, distinct).await - }, + from_substrait_agg_func( + &f, + &input.schema(), + extensions, + filter, + distinct, + ) + .await + } None => Err(DataFusionError::NotImplemented( "Aggregate without aggregate function is not supported".to_string(), )), @@ -247,7 +268,7 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: aggr_expr.push(agg_func?.as_ref().clone()); } - input.aggregate(group_expr, aggr_expr) + input.aggregate(group_expr, aggr_expr)?.build() } else { Err(DataFusionError::NotImplemented( "Aggregate without an input is not valid".to_string(), @@ -255,7 +276,9 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: } } Some(RelType::Join(join)) => { - let left = from_substrait_rel(ctx, &join.left.as_ref().unwrap(), extensions).await?; + let left = LogicalPlanBuilder::from( + from_substrait_rel(ctx, &join.left.as_ref().unwrap(), extensions).await?, + ); let right = from_substrait_rel(ctx, &join.right.as_ref().unwrap(), extensions).await?; let join_type = match join.r#type { 1 => JoinType::Inner, @@ -268,9 +291,10 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: }; let mut predicates = vec![]; let schema = build_join_schema(&left.schema(), &right.schema(), &JoinType::Inner)?; - let on = from_substrait_rex(&join.expression.as_ref().unwrap(), &schema, extensions).await?; + let on = + from_substrait_rex(&join.expression.as_ref().unwrap(), &schema, extensions).await?; split_conjunction(&on, &mut predicates); - let pairs = predicates + let pairs: Vec<(Column, Column)> = predicates .iter() .map(|p| match p { Expr::BinaryExpr { @@ -278,7 +302,7 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: op: Operator::Eq, right, } => match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => Ok((l.flat_name(), r.flat_name())), + (Expr::Column(l), Expr::Column(r)) => Ok((l.clone(), r.clone())), _ => { return Err(DataFusionError::Internal( "invalid join condition".to_string(), @@ -292,9 +316,10 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: } }) .collect::>>()?; - let left_cols: Vec<&str> = pairs.iter().map(|(l, _)| l.as_str()).collect(); - let right_cols: Vec<&str> = pairs.iter().map(|(_, r)| r.as_str()).collect(); - left.join(right, join_type, &left_cols, &right_cols, None) + let left_cols: Vec = pairs.iter().map(|(l, _)| l.clone()).collect(); + let right_cols: Vec = pairs.iter().map(|(_, r)| r.clone()).collect(); + left.join(&right, join_type, (left_cols, right_cols), None)? + .build() } Some(RelType::Read(read)) => match &read.as_ref().read_type { Some(ReadType::NamedTable(nt)) => { @@ -317,7 +342,7 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: table: &nt.names[2], }, }; - let t = ctx.table(table_reference)?; + let t = ctx.table(table_reference)?.to_logical_plan()?; match &read.projection { Some(MaskExpression { select, .. }) => match &select.as_ref() { Some(projection) => { @@ -326,7 +351,7 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: .iter() .map(|item| item.field as usize) .collect(); - match t.to_logical_plan()? { + match &t { LogicalPlan::TableScan(scan) => { let mut scan = scan.clone(); let fields: Vec = column_indices @@ -337,8 +362,7 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: scan.projected_schema = DFSchemaRef::new( DFSchema::new_with_metadata(fields, HashMap::new())?, ); - let plan = LogicalPlan::TableScan(scan); - Ok(Arc::new(DataFrame::new(ctx.state.clone(), &plan))) + Ok(LogicalPlan::TableScan(scan)) } _ => Err(DataFusionError::Internal( "unexpected plan for table".to_string(), @@ -367,15 +391,15 @@ pub async fn from_substrait_agg_func( input_schema: &DFSchema, extensions: &HashMap, filter: Option>, - distinct: bool + distinct: bool, ) -> Result> { let mut args: Vec = vec![]; for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => from_substrait_rex(e, input_schema, extensions).await, _ => Err(DataFusionError::NotImplemented( - "Aggregated function argument non-Value type not supported".to_string(), - )) + "Aggregated function argument non-Value type not supported".to_string(), + )), }; args.push(arg_expr?.as_ref().clone()); } @@ -383,27 +407,26 @@ pub async fn from_substrait_agg_func( let fun = match extensions.get(&f.function_reference) { Some(function_name) => aggregate_function::AggregateFunction::from_str(function_name), None => Err(DataFusionError::NotImplemented(format!( - "Aggregated function not found: function anchor = {:?}", - f.function_reference - ) - )) + "Aggregated function not found: function anchor = {:?}", + f.function_reference + ))), }; - Ok( - Arc::new( - Expr::AggregateFunction { - fun: fun.unwrap(), - args: args, - distinct: distinct, - filter: filter - } - ) - ) + Ok(Arc::new(Expr::AggregateFunction { + fun: fun.unwrap(), + args: args, + distinct: distinct, + filter: filter, + })) } /// Convert Substrait Rex to DataFusion Expr #[async_recursion] -pub async fn from_substrait_rex(e: &Expression, input_schema: &DFSchema, extensions: &HashMap) -> Result> { +pub async fn from_substrait_rex( + e: &Expression, + input_schema: &DFSchema, + extensions: &HashMap, +) -> Result> { match &e.rex_type { Some(RexType::Selection(field_ref)) => match &field_ref.reference_type { Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { @@ -413,14 +436,12 @@ pub async fn from_substrait_rex(e: &Expression, input_schema: &DFSchema, extensi )), None => Ok(Arc::new(Expr::Column(Column { relation: None, - name: input_schema - .field(x.field as usize) - .name() - .to_string(), + name: input_schema.field(x.field as usize).name().to_string(), }))), }, _ => Err(DataFusionError::NotImplemented( - "Direct reference with types other than StructField is not supported".to_string(), + "Direct reference with types other than StructField is not supported" + .to_string(), )), }, _ => Err(DataFusionError::NotImplemented( @@ -436,43 +457,82 @@ pub async fn from_substrait_rex(e: &Expression, input_schema: &DFSchema, extensi if i == 0 { // Check if the first element is type base expression if if_expr.then.is_none() { - expr = Some(Box::new(from_substrait_rex(&if_expr.r#if.as_ref().unwrap(), input_schema, extensions).await?.as_ref().clone())); + expr = Some(Box::new( + from_substrait_rex( + &if_expr.r#if.as_ref().unwrap(), + input_schema, + extensions, + ) + .await? + .as_ref() + .clone(), + )); continue; } } - when_then_expr.push( - ( - Box::new(from_substrait_rex(&if_expr.r#if.as_ref().unwrap(), input_schema, extensions).await?.as_ref().clone()), - Box::new(from_substrait_rex(&if_expr.then.as_ref().unwrap(), input_schema, extensions).await?.as_ref().clone()) + when_then_expr.push(( + Box::new( + from_substrait_rex( + &if_expr.r#if.as_ref().unwrap(), + input_schema, + extensions, + ) + .await? + .as_ref() + .clone(), ), - ); + Box::new( + from_substrait_rex( + &if_expr.then.as_ref().unwrap(), + input_schema, + extensions, + ) + .await? + .as_ref() + .clone(), + ), + )); } // Parse `else` let else_expr = match &if_then.r#else { Some(e) => Some(Box::new( - from_substrait_rex(&e, input_schema, extensions).await?.as_ref().clone(), - )), - None => None + from_substrait_rex(&e, input_schema, extensions) + .await? + .as_ref() + .clone(), + )), + None => None, }; - Ok(Arc::new(Expr::Case { expr: expr, when_then_expr: when_then_expr, else_expr: else_expr })) - }, + Ok(Arc::new(Expr::Case { + expr: expr, + when_then_expr: when_then_expr, + else_expr: else_expr, + })) + } Some(RexType::ScalarFunction(f)) => { assert!(f.arguments.len() == 2); let op = match extensions.get(&f.function_reference) { - Some(fname) => name_to_op(fname), - None => Err(DataFusionError::NotImplemented(format!( - "Aggregated function not found: function reference = {:?}", - f.function_reference - ) - )) + Some(fname) => name_to_op(fname), + None => Err(DataFusionError::NotImplemented(format!( + "Aggregated function not found: function reference = {:?}", + f.function_reference + ))), }; match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) { (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { Ok(Arc::new(Expr::BinaryExpr { - left: Box::new(from_substrait_rex(l, input_schema, extensions).await?.as_ref().clone()), + left: Box::new( + from_substrait_rex(l, input_schema, extensions) + .await? + .as_ref() + .clone(), + ), op: op?, right: Box::new( - from_substrait_rex(r, input_schema, extensions).await?.as_ref().clone(), + from_substrait_rex(r, input_schema, extensions) + .await? + .as_ref() + .clone(), ), })) } @@ -507,9 +567,9 @@ pub async fn from_substrait_rex(e: &Expression, input_schema: &DFSchema, extensi Some(LiteralType::Fp64(f)) => { Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f))))) } - Some(LiteralType::String(s)) => Ok(Arc::new(Expr::Literal(ScalarValue::Utf8( - Some(s.clone()), - )))), + Some(LiteralType::String(s)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(Some(s.clone()))))) + } Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(ScalarValue::Binary(Some( b.clone(), ))))), diff --git a/src/producer.rs b/src/producer.rs index d809ea7..cfce50a 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -15,36 +15,36 @@ use substrait::protobuf::{ if_then::IfClause, literal::LiteralType, mask_expression::{StructItem, StructSelect}, - reference_segment, - FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, ScalarFunction, + reference_segment, FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, + RexType, ScalarFunction, + }, + extensions::{ + self, + simple_extension_declaration::{ExtensionFunction, MappingType}, }, - extensions::{self, simple_extension_declaration::{MappingType, ExtensionFunction}}, function_argument::ArgType, plan_rel, read_rel::{NamedTable, ReadType}, rel::RelType, - sort_field::{ - SortDirection, - SortKind, - }, - AggregateRel, Expression, FetchRel, FilterRel, FunctionArgument, JoinRel, NamedStruct, ProjectRel, ReadRel, SortField, SortRel, - PlanRel, - Plan, Rel, RelRoot, AggregateFunction, + sort_field::{SortDirection, SortKind}, + AggregateFunction, AggregateRel, Expression, FetchRel, FilterRel, FunctionArgument, JoinRel, + NamedStruct, Plan, PlanRel, ProjectRel, ReadRel, Rel, RelRoot, SortField, SortRel, }; /// Convert DataFusion LogicalPlan to Substrait Plan pub fn to_substrait_plan(plan: &LogicalPlan) -> Result> { // Parse relation nodes - let mut extension_info: (Vec, HashMap) = (vec![], HashMap::new()); + let mut extension_info: ( + Vec, + HashMap, + ) = (vec![], HashMap::new()); // Generate PlanRel(s) // Note: Only 1 relation tree is currently supported let plan_rels = vec![PlanRel { - rel_type: Some(plan_rel::RelType::Root( - RelRoot { - input: Some(*to_substrait_rel(plan, &mut extension_info)?), - names: plan.schema().field_names(), - } - )) + rel_type: Some(plan_rel::RelType::Root(RelRoot { + input: Some(*to_substrait_rel(plan, &mut extension_info)?), + names: plan.schema().field_names(), + })), }]; let (function_extensions, _) = extension_info; @@ -57,11 +57,16 @@ pub fn to_substrait_plan(plan: &LogicalPlan) -> Result> { advanced_extensions: None, expected_type_urls: vec![], })) - } /// Convert DataFusion LogicalPlan to Substrait Rel -pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec, HashMap)) -> Result> { +pub fn to_substrait_rel( + plan: &LogicalPlan, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result> { match plan { LogicalPlan::TableScan(scan) => { let projection = scan.projection.as_ref().map(|p| { @@ -121,7 +126,8 @@ pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec { let input = to_substrait_rel(filter.input.as_ref(), extension_info)?; - let filter_expr = to_substrait_rex(&filter.predicate, filter.input.schema(), extension_info)?; + let filter_expr = + to_substrait_rex(&filter.predicate, filter.input.schema(), extension_info)?; Ok(Box::new(Rel { rel_type: Some(RelType::Filter(Box::new(FilterRel { common: None, @@ -176,12 +182,14 @@ pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec>>()?; - + Ok(Box::new(Rel { rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, input: Some(input), - groupings: vec![Grouping { grouping_expressions: grouping }], //groupings, + groupings: vec![Grouping { + grouping_expressions: grouping, + }], //groupings, measures: measures, advanced_extension: None, }))), @@ -199,7 +207,9 @@ pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec &'static str { } } -pub fn to_substrait_agg_measure(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut (Vec, HashMap)) -> Result { +pub fn to_substrait_agg_measure( + expr: &Expr, + schema: &DFSchemaRef, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { match expr { - Expr::AggregateFunction { fun, args, distinct, filter } => { + Expr::AggregateFunction { + fun, + args, + distinct, + filter, + } => { let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, extension_info)?)) }); + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(to_substrait_rex( + arg, + schema, + extension_info, + )?)), + }); } let function_name = fun.to_string().to_lowercase(); let function_anchor = _register_function(function_name, extension_info); @@ -329,10 +361,10 @@ pub fn to_substrait_agg_measure(expr: &Expr, schema: &DFSchemaRef, extension_inf }), filter: match filter { Some(f) => Some(to_substrait_rex(f, schema, extension_info)?), - None => None - } + None => None, + }, }) - }, + } _ => Err(DataFusionError::Internal(format!( "Expression must be compatible with aggregation. Unsupported expression: {:?}", expr @@ -340,7 +372,13 @@ pub fn to_substrait_agg_measure(expr: &Expr, schema: &DFSchemaRef, extension_inf } } -fn _register_function(function_name: String, extension_info: &mut (Vec, HashMap)) -> u32 { +fn _register_function( + function_name: String, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> u32 { let (function_extensions, function_set) = extension_info; let function_name = function_name.to_lowercase(); // To prevent ambiguous references between ScalarFunctions and AggregateFunctions, @@ -351,7 +389,7 @@ fn _register_function(function_name: String, extension_info: &mut (Vec { // Function has been registered *function_anchor - }, + } None => { // Function has NOT been registered let function_anchor = function_set.len() as u32; @@ -369,14 +407,21 @@ fn _register_function(function_name: String, extension_info: &mut (Vec, HashMap)) -> Expression { +pub fn make_binary_op_scalar_func( + lhs: &Expression, + rhs: &Expression, + op: Operator, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Expression { let function_name = operator_to_name(op).to_string().to_lowercase(); let function_anchor = _register_function(function_name, extension_info); Expression { @@ -397,29 +442,71 @@ pub fn make_binary_op_scalar_func(lhs: &Expression, rhs: &Expression, op: Operat } /// Convert DataFusion Expr to Substrait Rex -pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut (Vec, HashMap)) -> Result { +pub fn to_substrait_rex( + expr: &Expr, + schema: &DFSchemaRef, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { match expr { - Expr::Between { expr, negated, low, high } => { + Expr::Between { + expr, + negated, + low, + high, + } => { if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; let substrait_low = to_substrait_rex(low, schema, extension_info)?; let substrait_high = to_substrait_rex(high, schema, extension_info)?; - let l_expr = make_binary_op_scalar_func(&substrait_expr, &substrait_low, Operator::Lt, extension_info); - let r_expr = make_binary_op_scalar_func(&substrait_high, &substrait_expr, Operator::Lt, extension_info); + let l_expr = make_binary_op_scalar_func( + &substrait_expr, + &substrait_low, + Operator::Lt, + extension_info, + ); + let r_expr = make_binary_op_scalar_func( + &substrait_high, + &substrait_expr, + Operator::Lt, + extension_info, + ); - Ok(make_binary_op_scalar_func(&l_expr, &r_expr, Operator::Or, extension_info)) + Ok(make_binary_op_scalar_func( + &l_expr, + &r_expr, + Operator::Or, + extension_info, + )) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; let substrait_low = to_substrait_rex(low, schema, extension_info)?; let substrait_high = to_substrait_rex(high, schema, extension_info)?; - let l_expr = make_binary_op_scalar_func(&substrait_low, &substrait_expr, Operator::LtEq, extension_info); - let r_expr = make_binary_op_scalar_func(&substrait_expr, &substrait_high, Operator::LtEq, extension_info); + let l_expr = make_binary_op_scalar_func( + &substrait_low, + &substrait_expr, + Operator::LtEq, + extension_info, + ); + let r_expr = make_binary_op_scalar_func( + &substrait_expr, + &substrait_high, + Operator::LtEq, + extension_info, + ); - Ok(make_binary_op_scalar_func(&l_expr, &r_expr, Operator::And, extension_info)) + Ok(make_binary_op_scalar_func( + &l_expr, + &r_expr, + Operator::And, + extension_info, + )) } } Expr::Column(col) => { @@ -432,10 +519,15 @@ pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) } - Expr::Case { expr, when_then_expr, else_expr } => { + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { let mut ifs: Vec = vec![]; // Parse base - if let Some(e) = expr { // Base expression exists + if let Some(e) = expr { + // Base expression exists ifs.push(IfClause { r#if: Some(to_substrait_rex(e, schema, extension_info)?), then: None, @@ -454,11 +546,11 @@ pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut Some(e) => Some(Box::new(to_substrait_rex(e, schema, extension_info)?)), None => None, }; - + Ok(Expression { rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs: ifs, - r#else: r#else + r#else: r#else, }))), }) } @@ -491,9 +583,7 @@ pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut })), }) } - Expr::Alias(expr, _alias) => { - to_substrait_rex(expr, schema, extension_info) - } + Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema, extension_info), _ => Err(DataFusionError::NotImplemented(format!( "Unsupported expression: {:?}", expr @@ -501,9 +591,20 @@ pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut } } -fn substrait_sort_field(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut (Vec, HashMap)) -> Result { +fn substrait_sort_field( + expr: &Expr, + schema: &DFSchemaRef, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { match expr { - Expr::Sort { expr, asc, nulls_first } => { + Expr::Sort { + expr, + asc, + nulls_first, + } => { let e = to_substrait_rex(expr, schema, extension_info)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, @@ -515,7 +616,7 @@ fn substrait_sort_field(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut expr: Some(e), sort_kind: Some(SortKind::Direction(d as i32)), }) - }, + } _ => Err(DataFusionError::NotImplemented(format!( "Expecting sort expression but got {:?}", expr @@ -527,12 +628,12 @@ fn substrait_field_ref(index: usize) -> Result { Ok(Expression { rex_type: Some(RexType::Selection(Box::new(FieldReference { reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { - reference_type: Some(reference_segment::ReferenceType::StructField( - Box::new(reference_segment::StructField { + reference_type: Some(reference_segment::ReferenceType::StructField(Box::new( + reference_segment::StructField { field: index as i32, child: None, - }), - )), + }, + ))), })), root_type: None, }))), diff --git a/src/serializer.rs b/src/serializer.rs index 8662ad5..a3d6e4f 100644 --- a/src/serializer.rs +++ b/src/serializer.rs @@ -7,7 +7,7 @@ use prost::Message; use substrait::protobuf::Plan; use std::fs::OpenOptions; -use std::io::{Write, Read}; +use std::io::{Read, Write}; pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> { let df = ctx.sql(sql).await?; @@ -16,10 +16,7 @@ pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<() let mut protobuf_out = Vec::::new(); proto.encode(&mut protobuf_out).unwrap(); - let mut file = OpenOptions::new() - .create(true) - .write(true) - .open(path)?; + let mut file = OpenOptions::new().create(true).write(true).open(path)?; file.write_all(&protobuf_out)?; Ok(()) } @@ -27,14 +24,10 @@ pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<() pub async fn deserialize(path: &str) -> Result> { let mut protobuf_in = Vec::::new(); - let mut file = OpenOptions::new() - .read(true) - .open(path)?; + let mut file = OpenOptions::new().read(true).open(path)?; file.read_to_end(&mut protobuf_in)?; let proto = Message::decode(&*protobuf_in).unwrap(); Ok(Box::new(proto)) } - - diff --git a/tests/roundtrip.rs b/tests/roundtrip.rs index 35f314f..eb6d967 100644 --- a/tests/roundtrip.rs +++ b/tests/roundtrip.rs @@ -1,7 +1,6 @@ use datafusion_substrait::consumer; use datafusion_substrait::producer; - #[cfg(test)] mod tests { @@ -79,7 +78,8 @@ mod tests { test_alias( "SELECT * FROM (SELECT distinct a FROM data)", // `SELECT *` is used to add `projection` at the root "SELECT a FROM data GROUP BY a", - ).await + ) + .await } #[tokio::test] @@ -87,15 +87,13 @@ mod tests { test_alias( "SELECT * FROM (SELECT distinct a, b FROM data)", // `SELECT *` is used to add `projection` at the root "SELECT a, b FROM data GROUP BY a, b", - ).await + ) + .await } #[tokio::test] async fn simple_alias() -> Result<()> { - test_alias( - "SELECT d1.a, d1.b FROM data d1", - "SELECT a, b FROM data", - ).await + test_alias("SELECT d1.a, d1.b FROM data d1", "SELECT a, b FROM data").await } #[tokio::test] @@ -111,7 +109,7 @@ mod tests { async fn between_integers() -> Result<()> { test_alias( "SELECT * FROM data WHERE a BETWEEN 2 AND 6", - "SELECT * FROM data WHERE a >= 2 AND a <= 6" + "SELECT * FROM data WHERE a >= 2 AND a <= 6", ) .await } @@ -120,7 +118,7 @@ mod tests { async fn not_between_integers() -> Result<()> { test_alias( "SELECT * FROM data WHERE a NOT BETWEEN 2 AND 6", - "SELECT * FROM data WHERE a < 2 OR a > 6" + "SELECT * FROM data WHERE a < 2 OR a > 6", ) .await } @@ -132,11 +130,14 @@ mod tests { #[tokio::test] async fn case_with_base_expression() -> Result<()> { - roundtrip("SELECT (CASE a + roundtrip( + "SELECT (CASE a WHEN 0 THEN 'zero' WHEN 1 THEN 'one' ELSE 'other' - END) FROM data").await + END) FROM data", + ) + .await } #[tokio::test] @@ -193,11 +194,15 @@ mod tests { let df_a = ctx.sql(sql_with_alias).await?; let proto_a = to_substrait_plan(&df_a.to_logical_plan()?)?; - let plan_with_alias = from_substrait_plan(&mut ctx, &proto_a).await?.to_logical_plan()?; + let plan_with_alias = from_substrait_plan(&mut ctx, &proto_a) + .await? + .to_logical_plan()?; let df = ctx.sql(sql_no_alias).await?; let proto = to_substrait_plan(&df.to_logical_plan()?)?; - let plan = from_substrait_plan(&mut ctx, &proto).await?.to_logical_plan()?; + let plan = from_substrait_plan(&mut ctx, &proto) + .await? + .to_logical_plan()?; println!("{:#?}", plan_with_alias); println!("{:#?}", plan); @@ -226,7 +231,7 @@ mod tests { Ok(()) } - async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { + async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; let plan = df.to_logical_plan()?; @@ -237,12 +242,12 @@ mod tests { for e in &proto.extensions { let (function_anchor, function_name) = match e.mapping_type.as_ref().unwrap() { MappingType::ExtensionFunction(ext_f) => (ext_f.function_anchor, &ext_f.name), - _ => unreachable!("Producer does not generate a non-function extension") + _ => unreachable!("Producer does not generate a non-function extension"), }; function_names.push(function_name.to_string()); function_anchors.push(function_anchor); } - + Ok((function_names, function_anchors)) } diff --git a/tests/serialize.rs b/tests/serialize.rs index 3b5b8b0..618bb9c 100644 --- a/tests/serialize.rs +++ b/tests/serialize.rs @@ -1,4 +1,3 @@ - #[cfg(test)] mod tests { @@ -43,4 +42,4 @@ mod tests { .await?; Ok(ctx) } -} \ No newline at end of file +} From e79468a1bb4d735175329ea10f9201757cad3ef2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 7 Dec 2022 13:27:24 -0700 Subject: [PATCH 2/2] apply suggestion from review --- src/consumer.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/consumer.rs b/src/consumer.rs index 1bb826c..79c8ede 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -316,8 +316,7 @@ pub async fn from_substrait_rel( } }) .collect::>>()?; - let left_cols: Vec = pairs.iter().map(|(l, _)| l.clone()).collect(); - let right_cols: Vec = pairs.iter().map(|(_, r)| r.clone()).collect(); + let (left_cols, right_cols): (Vec<_>, Vec<_>) = pairs.iter().cloned().unzip(); left.join(&right, join_type, (left_cols, right_cols), None)? .build() }