Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 96 additions & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1623,6 +1623,24 @@ impl Expr {
| Expr::SimilarTo(Like { expr, pattern, .. }) => {
rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?;
}
Expr::InSubquery(InSubquery {
expr,
subquery,
negated: _,
}) => {
let subquery_schema = subquery.subquery.schema();
// Using the datatype of the first field in the subquery
if let Some(first_field) = subquery_schema.fields().first() {
rewrite_placeholder(
expr.as_mut(),
&Expr::Column(Column {
relation: None,
name: first_field.name().clone(),
}),
schema,
)?;
}
}
Expr::Placeholder(_) => {
has_placeholder = true;
}
Expand Down Expand Up @@ -2801,7 +2819,8 @@ mod test {
use crate::expr_fn::col;
use crate::{
case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue,
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility,
LogicalPlan, LogicalTableSource, Projection, ScalarFunctionArgs, ScalarUDF,
ScalarUDFImpl, TableScan, Volatility,
};
use arrow::datatypes::{Field, Schema};
use sqlparser::ast;
Expand Down Expand Up @@ -2863,6 +2882,82 @@ mod test {
}
}

#[test]
fn infer_placeholder_in_subquery() -> Result<()> {
    // Scenario under test:
    //   SELECT * FROM my_table WHERE $1 IN (SELECT A FROM my_table WHERE B > 3)
    // The untyped placeholder `$1` sits on the left-hand side of the IN-subquery;
    // after inference it is expected to carry `Int32`.

    // my_table: two nullable Int32 columns, A then B.
    let table_schema = Arc::new(Schema::new(vec![
        Field::new("A", DataType::Int32, true),
        Field::new("B", DataType::Int32, true),
    ]));
    let table_source = Arc::new(LogicalTableSource::new(Arc::clone(&table_schema)));

    // Inner plan, bottom-up: a scan of my_table carrying the `B > 3` predicate ...
    let b_gt_three = Expr::BinaryExpr(BinaryExpr {
        left: Box::new(col("B")),
        op: Operator::Gt,
        right: Box::new(Expr::Literal(ScalarValue::Int32(Some(3)))),
    });
    let inner_scan = LogicalPlan::TableScan(TableScan {
        table_name: TableReference::from("my_table"),
        source: table_source,
        projected_schema: Arc::new(DFSchema::try_from(Arc::clone(&table_schema))?),
        projection: None,
        filters: vec![b_gt_three],
        fetch: None,
    });

    // ... wrapped in a projection of column A. The projection's schema is the
    // table schema qualified with "my_table", so its first field is A.
    let inner_projection = LogicalPlan::Projection(Projection {
        expr: vec![col("A")],
        input: Arc::new(inner_scan),
        schema: Arc::new(DFSchema::try_from_qualified_schema(
            "my_table",
            &table_schema,
        )?),
    });

    // `$1 IN (<inner_projection>)`, with `$1` starting out without a data type.
    let untyped_param = Expr::Placeholder(Placeholder {
        id: "$1".to_string(),
        data_type: None,
    });
    let predicate = Expr::InSubquery(InSubquery {
        expr: Box::new(untyped_param),
        subquery: Subquery {
            subquery: Arc::new(inner_projection),
            outer_ref_columns: vec![],
        },
        negated: false,
    });

    // Run inference against the outer (unqualified) table schema.
    let outer_schema = DFSchema::try_from(table_schema)?;
    let (inferred_expr, contains_placeholder) =
        predicate.infer_placeholder_types(&outer_schema)?;

    assert!(
        contains_placeholder,
        "Expression should contain a placeholder"
    );

    // Peel the result apart one layer at a time: InSubquery -> Placeholder.
    let rewritten = match inferred_expr {
        Expr::InSubquery(inner) => inner,
        _ => panic!("Expected InSubquery expression"),
    };
    let resolved = match *rewritten.expr {
        Expr::Placeholder(param) => param,
        _ => panic!("Expected Placeholder expression in InSubquery"),
    };
    assert_eq!(
        resolved.data_type,
        Some(DataType::Int32),
        "Placeholder $1 should infer Int32"
    );

    Ok(())
}

#[test]
fn infer_placeholder_like_and_similar_to() {
// name LIKE $1
Expand Down