diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index c8a9e1f349d1..63631e07e657 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1623,6 +1623,26 @@ impl Expr { | Expr::SimilarTo(Like { expr, pattern, .. }) => { rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?; } + Expr::InSubquery(InSubquery { + expr, + subquery, + negated: _, + }) => { + let subquery_schema = subquery.subquery.schema(); + let fields = subquery_schema.fields(); + + // only supports subquery with exactly 1 field + if let [first_field] = &fields[..] { + rewrite_placeholder( + expr.as_mut(), + &Expr::Column(Column { + relation: None, + name: first_field.name().clone(), + }), + schema, + )?; + } + } Expr::Placeholder(_) => { has_placeholder = true; } @@ -2801,7 +2821,8 @@ mod test { use crate::expr_fn::col; use crate::{ case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue, - ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility, + LogicalPlan, LogicalTableSource, Projection, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, TableScan, Volatility, }; use arrow::datatypes::{Field, Schema}; use sqlparser::ast; @@ -2863,6 +2884,86 @@ mod test { } } + #[test] + fn infer_placeholder_in_subquery() -> Result<()> { + // Schema for my_table: A (Int32), B (Int32) + let schema = Arc::new(Schema::new(vec![ + Field::new("A", DataType::Int32, true), + Field::new("B", DataType::Int32, true), + ])); + + let source = Arc::new(LogicalTableSource::new(schema.clone())); + + // Simulate: SELECT * FROM my_table WHERE $1 IN (SELECT A FROM my_table WHERE B > 3); + let placeholder = Expr::Placeholder(Placeholder { + id: "$1".to_string(), + data_type: None, + }); + + // Subquery: SELECT A FROM my_table WHERE B > 3 + let subquery_filter = Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("B")), + op: Operator::Gt, + right: Box::new(Expr::Literal(ScalarValue::Int32(Some(3)))), + }); + + let subquery_scan = LogicalPlan::TableScan(TableScan { + table_name: TableReference::from("my_table"), + source, + projected_schema: Arc::new(DFSchema::try_from(schema.clone())?), + projection: None, + filters: vec![subquery_filter.clone()], + fetch: None, + }); + + let projected_fields = vec![Field::new("A", DataType::Int32, true)]; + let projected_schema = Arc::new(DFSchema::from_unqualified_fields( + projected_fields.into(), + Default::default(), + )?); + + let subquery = Subquery { + subquery: Arc::new(LogicalPlan::Projection(Projection { + expr: vec![col("A")], + input: Arc::new(subquery_scan), + schema: projected_schema, + })), + outer_ref_columns: vec![], + }; + + let in_subquery = Expr::InSubquery(InSubquery { + expr: Box::new(placeholder), + subquery, + negated: false, + }); + + let df_schema = DFSchema::try_from(schema)?; + + let (inferred_expr, contains_placeholder) = + in_subquery.infer_placeholder_types(&df_schema)?; + + assert!( + contains_placeholder, + "Expression should contain a placeholder" + ); + + match inferred_expr { + Expr::InSubquery(in_subquery) => match *in_subquery.expr { + Expr::Placeholder(placeholder) => { + assert_eq!( + placeholder.data_type, + Some(DataType::Int32), + "Placeholder $1 should infer Int32" + ); + } + _ => panic!("Expected Placeholder expression in InSubquery"), + }, + _ => panic!("Expected InSubquery expression"), + } + + Ok(()) + } + #[test] fn infer_placeholder_like_and_similar_to() { // name LIKE $1