Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 102 additions & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1623,6 +1623,26 @@ impl Expr {
| Expr::SimilarTo(Like { expr, pattern, .. }) => {
rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?;
}
Expr::InSubquery(InSubquery {
expr,
subquery,
negated: _,
}) => {
let subquery_schema = subquery.subquery.schema();
let fields = subquery_schema.fields();

// only supports subquery with exactly 1 field
if let [first_field] = &fields[..] {
rewrite_placeholder(
expr.as_mut(),
&Expr::Column(Column {
relation: None,
name: first_field.name().clone(),
}),
schema,
)?;
}
}
Expr::Placeholder(_) => {
has_placeholder = true;
}
Expand Down Expand Up @@ -2801,7 +2821,8 @@ mod test {
use crate::expr_fn::col;
use crate::{
case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue,
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility,
LogicalPlan, LogicalTableSource, Projection, ScalarFunctionArgs, ScalarUDF,
ScalarUDFImpl, TableScan, Volatility,
};
use arrow::datatypes::{Field, Schema};
use sqlparser::ast;
Expand Down Expand Up @@ -2863,6 +2884,86 @@ mod test {
}
}

#[test]
fn infer_placeholder_in_subquery() -> Result<()> {
// Schema for my_table: A (Int32), B (Int32)
let schema = Arc::new(Schema::new(vec![
Field::new("A", DataType::Int32, true),
Field::new("B", DataType::Int32, true),
]));

let source = Arc::new(LogicalTableSource::new(schema.clone()));

// Simulate: SELECT * FROM my_table WHERE $1 IN (SELECT A FROM my_table WHERE B > 3);
let placeholder = Expr::Placeholder(Placeholder {
id: "$1".to_string(),
data_type: None,
});

// Subquery: SELECT A FROM my_table WHERE B > 3
let subquery_filter = Expr::BinaryExpr(BinaryExpr {
left: Box::new(col("B")),
op: Operator::Gt,
right: Box::new(Expr::Literal(ScalarValue::Int32(Some(3)))),
});

let subquery_scan = LogicalPlan::TableScan(TableScan {
table_name: TableReference::from("my_table"),
source,
projected_schema: Arc::new(DFSchema::try_from(schema.clone())?),
projection: None,
filters: vec![subquery_filter.clone()],
fetch: None,
});

let projected_fields = vec![Field::new("A", DataType::Int32, true)];
let projected_schema = Arc::new(DFSchema::from_unqualified_fields(
projected_fields.into(),
Default::default(),
)?);

let subquery = Subquery {
subquery: Arc::new(LogicalPlan::Projection(Projection {
expr: vec![col("A")],
input: Arc::new(subquery_scan),
schema: projected_schema,
})),
outer_ref_columns: vec![],
};

let in_subquery = Expr::InSubquery(InSubquery {
expr: Box::new(placeholder),
subquery,
negated: false,
});

let df_schema = DFSchema::try_from(schema)?;

let (inferred_expr, contains_placeholder) =
in_subquery.infer_placeholder_types(&df_schema)?;

assert!(
contains_placeholder,
"Expression should contain a placeholder"
);

match inferred_expr {
Expr::InSubquery(in_subquery) => match *in_subquery.expr {
Expr::Placeholder(placeholder) => {
assert_eq!(
placeholder.data_type,
Some(DataType::Int32),
"Placeholder $1 should infer Int32"
);
}
_ => panic!("Expected Placeholder expression in InSubquery"),
},
_ => panic!("Expected InSubquery expression"),
}

Ok(())
}

#[test]
fn infer_placeholder_like_and_similar_to() {
// name LIKE $1
Expand Down