Skip to content

Commit 807508b

Browse files
kczimm authored and Jeadie committed
Infer placeholder datatype after LIMIT clause as DataType::Int64 (#81)
UPSTREAM NOTE: An upstream PR (apache#15980) has been created but is not merged yet. The change should be available in DF49.
1 parent a9dec78 commit 807508b

File tree

5 files changed

+55
-10
lines changed

5 files changed

+55
-10
lines changed

datafusion/core/src/datasource/listing/table.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuil
3535
use datafusion_datasource::metadata::MetadataColumn;
3636
use datafusion_expr::dml::InsertOp;
3737
use datafusion_expr::{Expr, SortExpr, TableProviderFilterPushDown, TableType};
38-
use datafusion_physical_expr::create_physical_expr;
3938
use datafusion_physical_plan::empty::EmptyExec;
4039
use datafusion_physical_plan::{ExecutionPlan, Statistics};
4140

datafusion/core/tests/sql/path_partition.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,15 @@ use datafusion_catalog::memory::DataSourceExec;
4040
use datafusion_catalog::TableProvider;
4141
use datafusion_common::stats::Precision;
4242
use datafusion_common::test_util::batches_to_sort_string;
43-
use datafusion_common::{Column, ScalarValue};
43+
use datafusion_common::ScalarValue;
4444
use datafusion_datasource::file_scan_config::FileScanConfig;
4545
use datafusion_datasource::metadata::MetadataColumn;
46-
use datafusion_datasource_parquet::source::ParquetSource;
4746
use datafusion_execution::config::SessionConfig;
4847

4948
use async_trait::async_trait;
5049
use bytes::Bytes;
5150
use chrono::{TimeZone, Utc};
52-
use datafusion_expr::{col, lit, BinaryExpr, Expr, Operator};
51+
use datafusion_expr::{col, lit, Expr};
5352
use futures::stream::{self, BoxStream};
5453
use insta::assert_snapshot;
5554
use object_store::{

datafusion/expr/src/expr.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1985,7 +1985,7 @@ impl Expr {
19851985
&Expr::Column(Column {
19861986
relation: None,
19871987
name: first_field.name().clone(),
1988-
spans: Spans::new(),
1988+
spans: Spans::default(),
19891989
}),
19901990
schema,
19911991
)?;
@@ -3525,7 +3525,7 @@ mod test {
35253525
let subquery_filter = Expr::BinaryExpr(BinaryExpr {
35263526
left: Box::new(col("B")),
35273527
op: Operator::Gt,
3528-
right: Box::new(Expr::Literal(ScalarValue::Int32(Some(3)))),
3528+
right: Box::new(Expr::Literal(ScalarValue::Int32(Some(3)), None)),
35293529
});
35303530

35313531
let subquery_scan = LogicalPlan::TableScan(TableScan {

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,6 +1497,14 @@ impl LogicalPlan {
14971497
let mut param_types: HashMap<String, Option<DataType>> = HashMap::new();
14981498

14991499
self.apply_with_subqueries(|plan| {
1500+
if let LogicalPlan::Limit(Limit { fetch: Some(e), .. }) = plan {
1501+
if let Expr::Placeholder(Placeholder { id, data_type }) = &**e {
1502+
param_types.insert(
1503+
id.clone(),
1504+
Some(data_type.clone().unwrap_or(DataType::Int64)),
1505+
);
1506+
}
1507+
}
15001508
plan.apply_expressions(|expr| {
15011509
expr.apply(|expr| {
15021510
if let Expr::Placeholder(Placeholder { id, data_type }) = expr {
@@ -1510,6 +1518,9 @@ impl LogicalPlan {
15101518
(_, Some(dt)) => {
15111519
param_types.insert(id.clone(), Some(dt.clone()));
15121520
}
1521+
(Some(Some(_)), None) => {
1522+
// we have already inferred the datatype
1523+
}
15131524
_ => {
15141525
param_types.insert(id.clone(), None);
15151526
}
@@ -5613,6 +5624,42 @@ mod tests {
56135624
"USING join should have all fields"
56145625
);
56155626
assert_eq!(using_join.join_constraint, JoinConstraint::Using);
5627+
Ok(())
5628+
}
5629+
5630+
#[test]
5631+
fn test_resolved_placeholder_limit() -> Result<()> {
5632+
let schema = Arc::new(Schema::new(vec![Field::new("A", DataType::Int32, true)]));
5633+
let source = Arc::new(LogicalTableSource::new(schema.clone()));
5634+
5635+
let placeholder_value = "$1";
5636+
5637+
// SELECT * FROM my_table LIMIT $1
5638+
let plan = LogicalPlan::Limit(Limit {
5639+
skip: None,
5640+
fetch: Some(Box::new(Expr::Placeholder(Placeholder {
5641+
id: placeholder_value.to_string(),
5642+
data_type: None,
5643+
}))),
5644+
input: Arc::new(LogicalPlan::TableScan(TableScan {
5645+
table_name: TableReference::from("my_table"),
5646+
source,
5647+
projected_schema: Arc::new(DFSchema::try_from(schema.clone())?),
5648+
projection: None,
5649+
filters: vec![],
5650+
fetch: None,
5651+
})),
5652+
});
5653+
5654+
let params = plan.get_parameter_types().expect("to infer type");
5655+
assert_eq!(params.len(), 1);
5656+
5657+
let parameter_type = params
5658+
.clone()
5659+
.get(placeholder_value)
5660+
.expect("to get type")
5661+
.clone();
5662+
assert_eq!(parameter_type, Some(DataType::Int64));
56165663

56175664
Ok(())
56185665
}

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ fn roundtrip_crossjoin() -> Result<()> {
284284
plan_roundtrip,
285285
@r"
286286
Projection: j1.j1_id, j2.j2_string
287-
Cross Join:
287+
Cross Join:
288288
TableScan: j1
289289
TableScan: j2
290290
"
@@ -1977,7 +1977,7 @@ fn test_complex_order_by_with_grouping() -> Result<()> {
19771977
}, {
19781978
assert_snapshot!(
19791979
sql,
1980-
@r#"SELECT j1.j1_id, j1.j1_string, lochierarchy FROM (SELECT j1.j1_id, j1.j1_string, (grouping(j1.j1_id) + grouping(j1.j1_string)) AS lochierarchy, grouping(j1.j1_string), grouping(j1.j1_id) FROM j1 GROUP BY ROLLUP (j1.j1_id, j1.j1_string) ORDER BY (grouping(j1.j1_id) + grouping(j1.j1_string)) DESC NULLS FIRST, CASE WHEN ((grouping(j1.j1_id) + grouping(j1.j1_string)) = 0) THEN j1.j1_id END ASC NULLS LAST) LIMIT 100"#
1980+
@"SELECT j1_id, j1_string, lochierarchy FROM (SELECT j1.j1_id, j1.j1_string, (grouping(j1.j1_id) + grouping(j1.j1_string)) AS lochierarchy, grouping(j1.j1_string), grouping(j1.j1_id) FROM j1 GROUP BY ROLLUP (j1.j1_id, j1.j1_string) ORDER BY (grouping(j1.j1_id) + grouping(j1.j1_string)) DESC NULLS FIRST, CASE WHEN ((grouping(j1.j1_id) + grouping(j1.j1_string)) = 0) THEN j1.j1_id END ASC NULLS LAST) LIMIT 100"
19811981
);
19821982
});
19831983

@@ -2682,7 +2682,7 @@ fn test_struct_expr() {
26822682
);
26832683
assert_snapshot!(
26842684
statement,
2685-
@r#"SELECT test."metadata".product FROM (SELECT {product: {"name": 'Product Name'}} AS "metadata") AS test WHERE (test."metadata".product."name" = 'Product Name')"#
2685+
@r#"SELECT product FROM (SELECT {product: {"name": 'Product Name'}} AS "metadata") AS test WHERE (test."metadata".product."name" = 'Product Name')"#
26862686
);
26872687

26882688
let statement = generate_round_trip_statement(
@@ -2691,7 +2691,7 @@ fn test_struct_expr() {
26912691
);
26922692
assert_snapshot!(
26932693
statement,
2694-
@r#"SELECT test."metadata".product FROM (SELECT {product: {"name": 'Product Name'}} AS "metadata") AS test WHERE (test."metadata".product."name" = 'Product Name')"#
2694+
@r#"SELECT product FROM (SELECT {product: {"name": 'Product Name'}} AS "metadata") AS test WHERE (test."metadata".product."name" = 'Product Name')"#
26952695
);
26962696
}
26972697

0 commit comments

Comments
 (0)