diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index b8e4204a9c9e..165d275c3012 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -1775,6 +1775,29 @@ impl Expr {
                 | Expr::SimilarTo(Like { expr, pattern, .. }) => {
                     rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?;
                 }
+                Expr::InSubquery(InSubquery {
+                    expr,
+                    subquery,
+                    negated: _,
+                }) => {
+                    let subquery_schema = subquery.subquery.schema();
+                    let fields = subquery_schema.fields();
+
+                    // only supports subquery with exactly 1 field
+                    // https://github.com/apache/datafusion/blob/main/datafusion/sql/src/expr/subquery.rs#L120
+                    if let [first_field] = &fields[..] {
+                        rewrite_placeholder(
+                            expr.as_mut(),
+                            &Expr::Column(Column {
+                                relation: None,
+                                name: first_field.name().clone(),
+                                spans: Spans::new(),
+                            }),
+                            // the column belongs to the subquery, not the outer query
+                            subquery_schema,
+                        )?;
+                    }
+                }
                 Expr::Placeholder(_) => {
                     has_placeholder = true;
                 }
@@ -3198,7 +3221,8 @@ mod test {
     use crate::expr_fn::col;
     use crate::{
         case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue,
-        ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility,
+        LogicalPlan, LogicalTableSource, Projection, ScalarFunctionArgs, ScalarUDF,
+        ScalarUDFImpl, TableScan, Volatility,
     };
     use arrow::datatypes::{Field, Schema};
     use sqlparser::ast;
@@ -3260,6 +3284,89 @@ mod test {
         }
     }
 
+    #[test]
+    fn infer_placeholder_in_subquery() -> Result<()> {
+        // Schema for my_table: A (Int32), B (Int32)
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("A", DataType::Int32, true),
+            Field::new("B", DataType::Int32, true),
+        ]));
+
+        let source = Arc::new(LogicalTableSource::new(Arc::clone(&schema)));
+
+        // Simulate: SELECT * FROM my_table WHERE $1 IN (SELECT A FROM my_table WHERE B > 3);
+        let placeholder = Expr::Placeholder(Placeholder {
+            id: "$1".to_string(),
+            data_type: None,
+        });
+
+        // Subquery: SELECT A FROM my_table WHERE B > 3
+        let subquery_filter = Expr::BinaryExpr(BinaryExpr {
+            left: Box::new(col("B")),
+            op: Operator::Gt,
+            right: Box::new(Expr::Literal(ScalarValue::Int32(Some(3)))),
+        });
+
+        let subquery_scan = LogicalPlan::TableScan(TableScan {
+            table_name: TableReference::from("my_table"),
+            source,
+            projected_schema: Arc::new(DFSchema::try_from(Arc::clone(&schema))?),
+            projection: None,
+            filters: vec![subquery_filter.clone()],
+            fetch: None,
+        });
+
+        let projected_fields = vec![Field::new("A", DataType::Int32, true)];
+        let projected_schema = Arc::new(DFSchema::from_unqualified_fields(
+            projected_fields.into(),
+            Default::default(),
+        )?);
+
+        let subquery = Subquery {
+            subquery: Arc::new(LogicalPlan::Projection(Projection {
+                expr: vec![col("A")],
+                input: Arc::new(subquery_scan),
+                schema: projected_schema,
+            })),
+            outer_ref_columns: vec![],
+            spans: Spans::new(),
+        };
+
+        let in_subquery = Expr::InSubquery(InSubquery {
+            expr: Box::new(placeholder),
+            subquery,
+            negated: false,
+        });
+
+        // The outer schema deliberately lacks column A: the placeholder type
+        // must be derived from the subquery's own schema.
+        let df_schema = DFSchema::empty();
+
+        let (inferred_expr, contains_placeholder) =
+            in_subquery.infer_placeholder_types(&df_schema)?;
+
+        assert!(
+            contains_placeholder,
+            "Expression should contain a placeholder"
+        );
+
+        match inferred_expr {
+            Expr::InSubquery(in_subquery) => match *in_subquery.expr {
+                Expr::Placeholder(placeholder) => {
+                    assert_eq!(
+                        placeholder.data_type,
+                        Some(DataType::Int32),
+                        "Placeholder $1 should infer Int32"
+                    );
+                }
+                _ => panic!("Expected Placeholder expression in InSubquery"),
+            },
+            _ => panic!("Expected InSubquery expression"),
+        }
+
+        Ok(())
+    }
+
     #[test]
     fn infer_placeholder_like_and_similar_to() {
         // name LIKE $1
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index edf5f1126be9..52ec7065d17a 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -1494,6 +1494,21 @@ impl LogicalPlan {
         let mut param_types: HashMap<String, Option<DataType>> = HashMap::new();
 
         self.apply_with_subqueries(|plan| {
+            if let LogicalPlan::Limit(Limit { fetch, skip, .. }) = plan {
+                // LIMIT and OFFSET are handled independently so that an OFFSET
+                // placeholder is still typed when the plan has no LIMIT.
+                for limit_expr in [fetch, skip].into_iter().flatten() {
+                    if let Expr::Placeholder(Placeholder { id, data_type }) =
+                        &**limit_expr
+                    {
+                        // Valid assumption, https://github.com/apache/datafusion/blob/41e7aed3a943134c40d1b18cb9d424b358b5e5b1/datafusion/optimizer/src/analyzer/type_coercion.rs#L242
+                        param_types.insert(
+                            id.clone(),
+                            Some(data_type.clone().unwrap_or(DataType::Int64)),
+                        );
+                    }
+                }
+            }
             plan.apply_expressions(|expr| {
                 expr.apply(|expr| {
                     if let Expr::Placeholder(Placeholder { id, data_type }) = expr {
@@ -1507,6 +1522,10 @@ impl LogicalPlan {
                             (_, Some(dt)) => {
                                 param_types.insert(id.clone(), Some(dt.clone()));
                             }
+                            (Some(Some(_)), None) => {
+                                // a datatype was already inferred for this id (e.g. by
+                                // the LIMIT / OFFSET handling above); keep it.
+                            }
                             _ => {
                                 param_types.insert(id.clone(), None);
                             }
@@ -4029,6 +4048,117 @@ mod tests {
             .build()
     }
 
+    #[test]
+    fn test_resolved_placeholder_limit() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("A", DataType::Int32, true)]));
+        let source = Arc::new(LogicalTableSource::new(Arc::clone(&schema)));
+
+        let placeholders = ["$1", "$2"];
+
+        // SELECT * FROM my_table LIMIT $1 OFFSET $2
+        let plan = LogicalPlan::Limit(Limit {
+            skip: Some(Box::new(Expr::Placeholder(Placeholder {
+                id: placeholders[1].to_string(),
+                data_type: None,
+            }))),
+            fetch: Some(Box::new(Expr::Placeholder(Placeholder {
+                id: placeholders[0].to_string(),
+                data_type: None,
+            }))),
+            input: Arc::new(LogicalPlan::TableScan(TableScan {
+                table_name: TableReference::from("my_table"),
+                source,
+                projected_schema: Arc::new(DFSchema::try_from(Arc::clone(&schema))?),
+                projection: None,
+                filters: vec![],
+                fetch: None,
+            })),
+        });
+
+        // try to infer the placeholder datatypes for the plan
+        let schema = DFSchema::try_from(Arc::clone(&schema))?;
+        let plan = plan
+            .map_expressions(|e| {
+                let (e, has_placeholder) = e.infer_placeholder_types(&schema)?;
+                Ok(if !has_placeholder {
+                    Transformed::no(e)
+                } else {
+                    Transformed::yes(e)
+                })
+            })
+            .expect("map expressions")
+            .data;
+
+        let LogicalPlan::Limit(Limit {
+            fetch: Some(f),
+            skip: Some(s),
+            ..
+        }) = &plan
+        else {
+            panic!("plan is not Limit with fetch and skip");
+        };
+
+        if !matches!(
+            (&**f, &**s),
+            (
+                Expr::Placeholder(Placeholder {
+                    data_type: None,
+                    ..
+                }),
+                Expr::Placeholder(Placeholder {
+                    data_type: None,
+                    ..
+                })
+            )
+        ) {
+            panic!(
+                "expected fetch and skip to be placeholders with datatypes uninferred"
+            );
+        }
+
+        let params = plan.get_parameter_types().expect("to infer type");
+        assert_eq!(params.len(), 2);
+
+        for placeholder in placeholders {
+            let parameter_type = params
+                .clone()
+                .get(placeholder)
+                .expect("to get fetch type")
+                .clone();
+            assert_eq!(parameter_type, Some(DataType::Int64));
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_resolved_placeholder_offset_without_limit() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("A", DataType::Int32, true)]));
+        let source = Arc::new(LogicalTableSource::new(Arc::clone(&schema)));
+
+        // SELECT * FROM my_table OFFSET $1
+        let plan = LogicalPlan::Limit(Limit {
+            skip: Some(Box::new(Expr::Placeholder(Placeholder {
+                id: "$1".to_string(),
+                data_type: None,
+            }))),
+            fetch: None,
+            input: Arc::new(LogicalPlan::TableScan(TableScan {
+                table_name: TableReference::from("my_table"),
+                source,
+                projected_schema: Arc::new(DFSchema::try_from(Arc::clone(&schema))?),
+                projection: None,
+                filters: vec![],
+                fetch: None,
+            })),
+        });
+
+        let params = plan.get_parameter_types()?;
+        assert_eq!(params.get("$1"), Some(&Some(DataType::Int64)));
+
+        Ok(())
+    }
+
     #[test]
     fn test_display_indent() -> Result<()> {
         let plan = display_plan()?;