diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 9ad515087a364..5cbe1d7c014ad 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -135,6 +135,9 @@ pub struct PlannerContext { ctes: HashMap>, /// The query schema of the outer query plan, used to resolve the columns in subquery outer_query_schema: Option, + /// The joined schemas of all FROM clauses planned so far. When planning LATERAL + /// FROM clauses, this should become a suffix of the `outer_query_schema`. + outer_from_schema: Option, } impl Default for PlannerContext { @@ -150,6 +153,7 @@ impl PlannerContext { prepare_param_data_types: Arc::new(vec![]), ctes: HashMap::new(), outer_query_schema: None, + outer_from_schema: None, } } @@ -177,6 +181,29 @@ impl PlannerContext { schema } + /// Returns a clone of the outer FROM schema, if any + pub fn outer_from_schema(&self) -> Option> { + self.outer_from_schema.clone() + } + + /// Sets the outer FROM schema, returning the existing one, if any + pub fn set_outer_from_schema( + &mut self, + mut schema: Option, + ) -> Option { + std::mem::swap(&mut self.outer_from_schema, &mut schema); + schema + } + + /// Extends the outer FROM schema by joining it with `schema`, or sets it to `schema` if none exists yet + pub fn extend_outer_from_schema(&mut self, schema: &DFSchemaRef) -> Result<()> { + self.outer_from_schema = match self.outer_from_schema.as_ref() { + Some(from_schema) => Some(Arc::new(from_schema.join(schema)?)), + None => Some(Arc::clone(schema)), + }; + Ok(()) + } + /// Return the types of parameters (`$1`, `$2`, etc) if known pub fn prepare_param_data_types(&self) -> &[DataType] { &self.prepare_param_data_types diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index fb1d00b7e48a5..409533a3eaa58 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -18,7 +18,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{not_impl_err, Column, 
Result}; use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder}; -use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableWithJoins}; +use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableFactor, TableWithJoins}; use std::collections::HashSet; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -27,10 +27,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { t: TableWithJoins, planner_context: &mut PlannerContext, ) -> Result { - let mut left = self.create_relation(t.relation, planner_context)?; - for join in t.joins.into_iter() { + let mut left = if is_lateral(&t.relation) { + self.create_relation_subquery(t.relation, planner_context)? + } else { + self.create_relation(t.relation, planner_context)? + }; + let old_outer_from_schema = planner_context.outer_from_schema(); + for join in t.joins { + planner_context.extend_outer_from_schema(left.schema())?; left = self.parse_relation_join(left, join, planner_context)?; } + planner_context.set_outer_from_schema(old_outer_from_schema); Ok(left) } @@ -40,7 +47,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { join: Join, planner_context: &mut PlannerContext, ) -> Result { - let right = self.create_relation(join.relation, planner_context)?; + let right = if is_lateral_join(&join)? { + self.create_relation_subquery(join.relation, planner_context)? + } else { + self.create_relation(join.relation, planner_context)? + }; match join.join_operator { JoinOperator::LeftOuter(constraint) => { self.parse_join(left, right, constraint, JoinType::Left, planner_context) @@ -144,3 +155,33 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } } + +/// Return `true` iff the given [`TableFactor`] is lateral. +pub(crate) fn is_lateral(factor: &TableFactor) -> bool { + match factor { + TableFactor::Derived { lateral, .. } => *lateral, + TableFactor::Function { lateral, .. } => *lateral, + _ => false, + } +} + +/// Return `true` iff the given [`Join`] is lateral. 
+pub(crate) fn is_lateral_join(join: &Join) -> Result { + let is_lateral_syntax = is_lateral(&join.relation); + let is_apply_syntax = match join.join_operator { + JoinOperator::FullOuter(..) + | JoinOperator::RightOuter(..) + | JoinOperator::RightAnti(..) + | JoinOperator::RightSemi(..) + if is_lateral_syntax => + { + return not_impl_err!( + "LATERAL syntax is not supported for \ + FULL OUTER and RIGHT [OUTER | ANTI | SEMI] joins" + ); + } + JoinOperator::CrossApply | JoinOperator::OuterApply => true, + _ => false, + }; + Ok(is_lateral_syntax || is_apply_syntax) +} diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index c5fe180c23025..86e49780724b2 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -15,11 +15,15 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{not_impl_err, plan_err, DFSchema, Result, TableReference}; +use datafusion_expr::builder::subquery_alias; use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{Subquery, SubqueryAlias}; use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor}; mod join; @@ -153,6 +157,53 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(optimized_plan) } } + + pub(crate) fn create_relation_subquery( + &self, + subquery: TableFactor, + planner_context: &mut PlannerContext, + ) -> Result { + // At this point for a syntactically valid query the outer_from_schema is + // guaranteed to be set, so the `.unwrap_or_else()` fallback is never needed. This + // is the case because we only call this method for lateral table + // factors, and those can never be the first factor in a FROM list. This + // means we arrived here through the `for` loop in `plan_from_tables` or + // the `for` loop in `plan_table_with_joins`. 
+ let old_from_schema = planner_context + .set_outer_from_schema(None) + .unwrap_or_else(|| Arc::new(DFSchema::empty())); + let new_query_schema = match planner_context.outer_query_schema() { + Some(old_query_schema) => { + let mut new_query_schema = old_from_schema.as_ref().clone(); + new_query_schema.merge(old_query_schema); + Some(Arc::new(new_query_schema)) + } + None => Some(Arc::clone(&old_from_schema)), + }; + let old_query_schema = planner_context.set_outer_query_schema(new_query_schema); + + let plan = self.create_relation(subquery, planner_context)?; + let outer_ref_columns = plan.all_out_ref_exprs(); + + planner_context.set_outer_query_schema(old_query_schema); + planner_context.set_outer_from_schema(Some(old_from_schema)); + + match plan { + LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { + subquery_alias( + LogicalPlan::Subquery(Subquery { + subquery: input, + outer_ref_columns, + }), + alias, + ) + } + plan => Ok(LogicalPlan::Subquery(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + })), + } + } } fn optimize_subquery_sort(plan: LogicalPlan) -> Result> { diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 339234d9965ca..f42dec40149ff 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -496,19 +496,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match from.len() { 0 => Ok(LogicalPlanBuilder::empty(true).build()?), 1 => { - let from = from.remove(0); - self.plan_table_with_joins(from, planner_context) + let input = from.remove(0); + self.plan_table_with_joins(input, planner_context) } _ => { - let mut plans = from - .into_iter() - .map(|t| self.plan_table_with_joins(t, planner_context)); - - let mut left = LogicalPlanBuilder::from(plans.next().unwrap()?); - - for right in plans { - left = left.cross_join(right?)?; + let mut from = from.into_iter(); + + let mut left = LogicalPlanBuilder::from({ + let input = from.next().unwrap(); + 
self.plan_table_with_joins(input, planner_context)? + }); + let old_outer_from_schema = { + let left_schema = Some(Arc::clone(left.schema())); + planner_context.set_outer_from_schema(left_schema) + }; + for input in from { + // Join `input` with the current result (`left`). + let right = self.plan_table_with_joins(input, planner_context)?; + left = left.cross_join(right)?; + // Update the outer FROM schema. + let left_schema = Some(Arc::clone(left.schema())); + planner_context.set_outer_from_schema(left_schema); } + planner_context.set_outer_from_schema(old_outer_from_schema); + Ok(left.build()?) } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 7ce3565fa29f6..5685e09c9c9fb 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -3103,6 +3103,114 @@ fn join_on_complex_condition() { quick_test(sql, expected); } +#[test] +fn lateral_constant() { + let sql = "SELECT * FROM j1, LATERAL (SELECT 1) AS j2"; + let expected = "Projection: *\ + \n CrossJoin:\ + \n TableScan: j1\ + \n SubqueryAlias: j2\ + \n Subquery:\ + \n Projection: Int64(1)\ + \n EmptyRelation"; + quick_test(sql, expected); +} + +#[test] +fn lateral_comma_join() { + let sql = "SELECT j1_string, j2_string FROM + j1, \ + LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2"; + let expected = "Projection: j1.j1_string, j2.j2_string\ + \n CrossJoin:\ + \n TableScan: j1\ + \n SubqueryAlias: j2\ + \n Subquery:\ + \n Projection: *\ + \n Filter: outer_ref(j1.j1_id) < j2.j2_id\ + \n TableScan: j2"; + quick_test(sql, expected); +} + +#[test] +fn lateral_comma_join_referencing_join_rhs() { + let sql = "SELECT * FROM\ + \n j1 JOIN (j2 JOIN j3 ON(j2_id = j3_id - 2)) ON(j1_id = j2_id),\ + \n LATERAL (SELECT * FROM j3 WHERE j3_string = j2_string) as j4;"; + let expected = "Projection: *\ + \n CrossJoin:\ + \n Inner Join: Filter: j1.j1_id = j2.j2_id\ + \n TableScan: j1\ + \n Inner Join: Filter: j2.j2_id = j3.j3_id - 
Int64(2)\ + \n TableScan: j2\ + \n TableScan: j3\ + \n SubqueryAlias: j4\ + \n Subquery:\ + \n Projection: *\ + \n Filter: j3.j3_string = outer_ref(j2.j2_string)\ + \n TableScan: j3"; + quick_test(sql, expected); +} + +#[test] +fn lateral_comma_join_with_shadowing() { + // The j1_id on line 3 references the (closest) j1 definition from line 2. + let sql = "\ + SELECT * FROM j1, LATERAL (\ + SELECT * FROM j1, LATERAL (\ + SELECT * FROM j2 WHERE j1_id = j2_id\ + ) as j2\ + ) as j2;"; + let expected = "Projection: *\ + \n CrossJoin:\ + \n TableScan: j1\ + \n SubqueryAlias: j2\ + \n Subquery:\ + \n Projection: *\ + \n CrossJoin:\ + \n TableScan: j1\ + \n SubqueryAlias: j2\ + \n Subquery:\ + \n Projection: *\ + \n Filter: outer_ref(j1.j1_id) = j2.j2_id\ + \n TableScan: j2"; + quick_test(sql, expected); +} + +#[test] +fn lateral_left_join() { + let sql = "SELECT j1_string, j2_string FROM \ + j1 \ + LEFT JOIN LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2 ON(true);"; + let expected = "Projection: j1.j1_string, j2.j2_string\ + \n Left Join: Filter: Boolean(true)\ + \n TableScan: j1\ + \n SubqueryAlias: j2\ + \n Subquery:\ + \n Projection: *\ + \n Filter: outer_ref(j1.j1_id) < j2.j2_id\ + \n TableScan: j2"; + quick_test(sql, expected); +} + +#[test] +fn lateral_nested_left_join() { + let sql = "SELECT * FROM + j1, \ + (j2 LEFT JOIN LATERAL (SELECT * FROM j3 WHERE j1_id + j2_id = j3_id) AS j3 ON(true))"; + let expected = "Projection: *\ + \n CrossJoin:\ + \n TableScan: j1\ + \n Left Join: Filter: Boolean(true)\ + \n TableScan: j2\ + \n SubqueryAlias: j3\ + \n Subquery:\ + \n Projection: *\ + \n Filter: outer_ref(j1.j1_id) + outer_ref(j2.j2_id) = j3.j3_id\ + \n TableScan: j3"; + quick_test(sql, expected); +} + #[test] fn hive_aggregate_with_filter() -> Result<()> { let dialect = &HiveDialect {}; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 7af145fe3e818..0ef745a6b8e65 100644 --- 
a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4046,6 +4046,54 @@ physical_plan 05)------MemoryExec: partitions=1, partition_sizes=[1] +# Test CROSS JOIN LATERAL syntax (planning) +query TT +explain select t1_id, t1_name, i from join_t1 t1 cross join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); +---- +logical_plan +01)CrossJoin: +02)--SubqueryAlias: t1 +03)----TableScan: join_t1 projection=[t1_id, t1_name] +04)--SubqueryAlias: series +05)----Subquery: +06)------Projection: UNNEST(generate_series(Int64(1),outer_ref(t1.t1_int))) AS i +07)--------Unnest: lists[UNNEST(generate_series(Int64(1),outer_ref(t1.t1_int)))] structs[] +08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t1.t1_int) AS Int64)) AS UNNEST(generate_series(Int64(1),outer_ref(t1.t1_int))) +09)------------EmptyRelation + + +# Test CROSS JOIN LATERAL syntax (execution) +# TODO: https://github.com/apache/datafusion/issues/10048 +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t1" \}\), name: "t1_int" \}\) +select t1_id, t1_name, i from join_t1 t1 cross join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); + + +# Test INNER JOIN LATERAL syntax (planning) +query TT +explain select t1_id, t1_name, i from join_t1 t2 inner join lateral (select * from unnest(generate_series(1, t1_int))) as series(i) on(t1_id > i); +---- +logical_plan +01)Inner Join: Filter: CAST(t2.t1_id AS Int64) > series.i +02)--SubqueryAlias: t2 +03)----TableScan: join_t1 projection=[t1_id, t1_name] +04)--SubqueryAlias: series +05)----Subquery: +06)------Projection: UNNEST(generate_series(Int64(1),outer_ref(t2.t1_int))) AS i +07)--------Unnest: lists[UNNEST(generate_series(Int64(1),outer_ref(t2.t1_int)))] structs[] +08)----------Projection: generate_series(Int64(1), 
CAST(outer_ref(t2.t1_int) AS Int64)) AS UNNEST(generate_series(Int64(1),outer_ref(t2.t1_int))) +09)------------EmptyRelation + + +# Test INNER JOIN LATERAL syntax (execution) +# TODO: https://github.com/apache/datafusion/issues/10048 +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t2" \}\), name: "t1_int" \}\) +select t1_id, t1_name, i from join_t1 t2 inner join lateral (select * from unnest(generate_series(1, t1_int))) as series(i) on(t1_id > i); + +# Test RIGHT JOIN LATERAL syntax (unsupported) +query error DataFusion error: This feature is not implemented: LATERAL syntax is not supported for FULL OUTER and RIGHT \[OUTER \| ANTI \| SEMI\] joins +select t1_id, t1_name, i from join_t1 t1 right join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); + + # Functional dependencies across a join statement ok CREATE TABLE sales_global (