
Commit 230d726

Authored by jdye64, charlesbluca, and andygrove
Datafusion invalid projection (#571)
* Condition for BinaryExpr, filter, input_ref, rexcall, and rexliteral
* Updates for test_filter
* more of test_filter.py working with the exception of some date pytests
* Add workflow to keep datafusion dev branch up to date (#440)
* Include setuptools-rust in conda build recipe, in host and run
* Remove PyArrow dependency
* rebase with datafusion-sql-planner
* refactor changes that were inadvertent during rebase
* timestamp with local time zone
* Bump DataFusion version (#494)
* bump DataFusion version
* remove unnecessary downcasts and use separate structs for TableSource and TableProvider
* Include RelDataType work
* Include RelDataType work
* Introduced SqlTypeName Enum in Rust and mappings for Python
* impl PyExpr.getIndex()
* add getRowType() for logical.rs
* Introduce DaskTypeMap for storing correlating SqlTypeName and DataTypes
* use str values instead of Rust Enums, Python is unable to Hash the Rust Enums if used in a dict
* linter changes, why did that work on my local pre-commit??
* linter changes, why did that work on my local pre-commit??
* Convert final strs to SqlTypeName Enum
* removed a few print statements
* commit to share with colleague
* updates
* checkpoint
* Temporarily disable conda run_test.py script since it uses features not yet implemented
* formatting after upstream merge
* expose fromString method for SqlTypeName to use Enums instead of strings for type checking
* expanded SqlTypeName from_string() support
* accept INT as INTEGER
* tests update
* checkpoint
* checkpoint
* Refactor PyExpr by removing From trait, and using recursion to expand expression list for rex calls
* skip test that uses create statement for gpuci
* Basic DataFusion Select Functionality (#489)
* Condition for BinaryExpr, filter, input_ref, rexcall, and rexliteral
* Updates for test_filter
* more of test_filter.py working with the exception of some date pytests
* Add workflow to keep datafusion dev branch up to date (#440)
* Include setuptools-rust in conda build recipe, in host and run
* Remove PyArrow dependency
* rebase with datafusion-sql-planner
* refactor changes that were inadvertent during rebase
* timestamp with local time zone
* Include RelDataType work
* Include RelDataType work
* Introduced SqlTypeName Enum in Rust and mappings for Python
* impl PyExpr.getIndex()
* add getRowType() for logical.rs
* Introduce DaskTypeMap for storing correlating SqlTypeName and DataTypes
* use str values instead of Rust Enums, Python is unable to Hash the Rust Enums if used in a dict
* linter changes, why did that work on my local pre-commit??
* linter changes, why did that work on my local pre-commit??
* Convert final strs to SqlTypeName Enum
* removed a few print statements
* Temporarily disable conda run_test.py script since it uses features not yet implemented
* expose fromString method for SqlTypeName to use Enums instead of strings for type checking
* expanded SqlTypeName from_string() support
* accept INT as INTEGER
* Remove print statements
* Default to UTC if tz is None
* Delegate timezone handling to the arrow library
* Updates from review (Co-authored-by: Charles Blackmon-Luca <[email protected]>)
* updates for expression
* uncommented pytests
* uncommented pytests
* code cleanup for review
* code cleanup for review
* Enabled more pytests that work now
* Enabled more pytests that work now
* Output Expression as String when BinaryExpr does not contain a named alias
* Output Expression as String when BinaryExpr does not contain a named alias
* Disable 2 pytests that are causing gpuCI issues; they will be addressed in a follow-up PR
* Handle Between operation for case-when
* adjust timestamp casting
* Refactor projection _column_name() logic to the _column_name logic in expression.rs
* removed println! statements
* introduce join getCondition() logic for retrieving the combining Rex logic for joining
* Updates from review
* Add Offset and point to repo with offset in datafusion
* Introduce offset
* limit updates
* commit before upstream merge
* Code formatting
* update Cargo.toml to use Arrow-DataFusion version with LIMIT logic
* Bump DataFusion version to get changes around variant_name()
* Use map partitions for determining the offset
* Merge with upstream
* Rename underlying DataContainer's DataFrame instance to match the column container names
* Adjust ColumnContainer mapping after join.py logic to ensure the backend mapping is reset
* Add enumerate to column_{i} generation string to ensure columns exist in both dataframes
* Adjust join schema logic to perform merge instead of join on the Rust side to avoid name collisions
* Handle DataFusion COUNT(UInt8(1)) as COUNT(*)
* commit before merge
* Update function for gathering index of an expression
* Update for review check
* Adjust RelDataType to retrieve fully qualified column names
* Adjust base.py to get fully qualified column name
* Enable passing pytests in test_join.py
* Adjust keys provided by getting backend column mapping name
* Adjust output_col to not use the backend_column name for special reserved exprs
* uncomment cross join pytest which works now
* Uncomment passing pytests in test_select.py
* Review updates
* Add back complex join case condition, not just cross join but 'complex' joins
* Enable DataFusion CBO logic
* Disable EliminateFilter optimization rule
* updates
* Disable tests that hit CBO-generated plan edge cases of yet-to-be-implemented logic
* [REVIEW] - Modify sql.skip_optimize to use dask_config.get and remove unused method parameter
* [REVIEW] - change name of configuration from skip_optimize to optimize
* [REVIEW] - Add OptimizeException catch and raise statements back
* Found issue where backend column names which are results of a single aggregate resulting column, COUNT(*) for example, need to get the first agg df column since names are not valid
* Remove SQL from OptimizationException
* skip tests where CBO plan reorganization causes not-yet-implemented features to be hit
* If TableScan contains projections, use those instead of all of the TableColumns for limiting columns read during table_scan
* [REVIEW] remove compute(), remove temp row_type variable
* [REVIEW] - Add test for projection pushdown
* [REVIEW] - Add some more parametrized test combinations
* [REVIEW] - Use iterator instead of for loop and simplify contains_projections
* [REVIEW] - merge upstream and adjust imports
* [REVIEW] - Rename pytest function and remove duplicate table creation

Co-authored-by: Charles Blackmon-Luca <[email protected]>
Co-authored-by: Andy Grove <[email protected]>
1 parent a52dd7b commit 230d726
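
At the Python level, the headline change in this commit is that a SELECT over a subset of columns now pushes the projection into the TableScan itself, so only the referenced columns are read. A minimal sketch of the observable behavior, using an in-memory table (the table name and data are illustrative, borrowed from the tests below):

import pandas as pd
from dask_sql import Context

c = Context()
c.create_table("df", pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))

# Only "b" survives the plan; with this commit the TableScan reports
# the projected column set rather than the table's full row type.
result = c.sql("SELECT b FROM df")
print(list(result.columns))  # ["b"]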

File tree

5 files changed (+93 additions, -7 deletions)


dask_planner/src/sql/logical.rs

Lines changed: 7 additions & 1 deletion
@@ -9,8 +9,9 @@ mod filter;
 mod join;
 mod limit;
 mod offset;
-pub mod projection;
+mod projection;
 mod sort;
+mod table_scan;
 mod union;
 
 use datafusion_common::{Column, DFSchemaRef, DataFusionError, Result};

@@ -111,6 +112,11 @@ impl PyLogicalPlan {
         to_py_plan(self.current_node.as_ref())
     }
 
+    /// LogicalPlan::TableScan as PyTableScan
+    pub fn table_scan(&self) -> PyResult<table_scan::PyTableScan> {
+        to_py_plan(self.current_node.as_ref())
+    }
+
     /// Gets the "input" for the current LogicalPlan
     pub fn get_inputs(&mut self) -> PyResult<Vec<PyLogicalPlan>> {
         let mut py_inputs: Vec<PyLogicalPlan> = Vec::new();
dask_planner/src/sql/logical/table_scan.rs

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+use crate::sql::exceptions::py_type_err;
+use crate::sql::logical;
+use datafusion_expr::logical_plan::TableScan;
+use pyo3::prelude::*;
+
+#[pyclass(name = "TableScan", module = "dask_planner", subclass)]
+#[derive(Clone)]
+pub struct PyTableScan {
+    pub(crate) table_scan: TableScan,
+}
+
+#[pymethods]
+impl PyTableScan {
+    #[pyo3(name = "getTableScanProjects")]
+    fn scan_projects(&mut self) -> PyResult<Vec<String>> {
+        match &self.table_scan.projection {
+            Some(indices) => {
+                let schema = self.table_scan.source.schema();
+                Ok(indices
+                    .iter()
+                    .map(|i| schema.field(*i).name().to_string())
+                    .collect())
+            }
+            None => Ok(vec![]),
+        }
+    }
+
+    /// If the 'TableScan' contains columns that should be projected during the
+    /// read return True, otherwise return False
+    #[pyo3(name = "containsProjections")]
+    fn contains_projections(&self) -> bool {
+        self.table_scan.projection.is_some()
+    }
+}
+
+impl TryFrom<logical::LogicalPlan> for PyTableScan {
+    type Error = PyErr;
+
+    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {
+        match logical_plan {
+            logical::LogicalPlan::TableScan(table_scan) => Ok(PyTableScan { table_scan }),
+            _ => Err(py_type_err("unexpected plan")),
+        }
+    }
+}
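
From Python, these pymethods are reached through the table_scan() accessor added to PyLogicalPlan above. A short sketch of the intended call pattern, assuming rel is a plan node currently positioned at a TableScan (this mirrors the table_scan.py change below):

scan = rel.table_scan()                 # errors if the current node is not a TableScan
if scan.containsProjections():          # were projection indices recorded on the scan?
    cols = scan.getTableScanProjects()  # index -> name, resolved against the source schema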

dask_planner/src/sql/types/rel_data_type.rs

Lines changed: 1 addition & 2 deletions
@@ -76,8 +76,7 @@ impl RelDataType {
         self.field_list.clone()
     }
 
-    /// Returns the names of the fields in a struct type. The field count
-    /// is equal to the size of the returned list.
+    /// Returns the names of all of the columns in a given DaskTable
     #[pyo3(name = "getFieldNames")]
     pub fn field_names(&self) -> Vec<String> {
         assert!(!self.field_list.is_empty());

dask_sql/physical/rel/logical/table_scan.py

Lines changed: 11 additions & 4 deletions
@@ -33,6 +33,9 @@ def convert(
         # There should not be any input. This is the first step.
         self.assert_inputs(rel, 0)
 
+        # Rust table_scan instance handle
+        table_scan = rel.table_scan()
+
         # The table(s) we need to return
         table = rel.getTable()

@@ -48,11 +51,15 @@ def convert(
         df = dc.df
         cc = dc.column_container
 
-        # Make sure we only return the requested columns
-        row_type = table.getRowType()
-        field_specifications = [str(f) for f in row_type.getFieldNames()]
-        cc = cc.limit_to(field_specifications)
+        # If the 'TableScan' instance contains projected columns only retrieve those columns
+        # otherwise get all projected columns from the 'Projection' instance, which is contained
+        # in the 'RelDataType' instance, aka 'row_type'
+        if table_scan.containsProjections():
+            field_specifications = table_scan.getTableScanProjects()
+        else:
+            field_specifications = [str(f) for f in table.getRowType().getFieldNames()]
 
+        cc = cc.limit_to(field_specifications)
         cc = self.fix_column_to_row_type(cc, rel.getRowType())
         dc = DataContainer(df, cc)
         dc = self.fix_dtype_to_row_type(dc, rel.getRowType())
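
The new branch boils down to a small pure function. A hypothetical standalone equivalent (the helper name is invented for illustration) makes the fallback order explicit:

def resolve_scan_columns(table_scan, table):
    # Prefer the columns the TableScan was built to project...
    if table_scan.containsProjections():
        return table_scan.getTableScanProjects()
    # ...otherwise fall back to every field on the table's row type.
    return [str(f) for f in table.getRowType().getFieldNames()]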

tests/integration/test_select.py

Lines changed: 29 additions & 0 deletions
@@ -222,3 +222,32 @@ def test_case_when_no_else(c):
     expected_df = pd.DataFrame({"C": [None, 1, 1, 1, None]})
 
     assert_eq(actual_df, expected_df)
+
+
+def test_singular_column_projection_simple(c):
+    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
+    c.create_table("df", df)
+
+    wildcard_result = c.sql("SELECT * from df")
+    single_col_result = c.sql("SELECT b from df")
+
+    assert_eq(wildcard_result["b"], single_col_result["b"])
+
+
+@pytest.mark.parametrize(
+    "input_cols",
+    [
+        ["a"],
+        ["a", "b"],
+        ["a", "d"],
+        ["d", "a"],
+        ["a", "b", "d"],
+    ],
+)
+def test_multiple_column_projection(c, input_cols):
+    projection_list = ", ".join(input_cols)
+    result = c.sql(f"SELECT {projection_list} from parquet_ddf")
+
+    # There are 5 columns in the table, ensure only specified ones are read
+    assert_eq(len(result.columns), len(input_cols))
+    assert all(x in input_cols for x in result.columns)
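
Expanding one parametrized case by hand, input_cols=["a", "d"] is equivalent to the following; note the assertions check membership only, so column order is left unconstrained:

result = c.sql("SELECT a, d from parquet_ddf")
assert len(result.columns) == 2
assert all(x in ["a", "d"] for x in result.columns)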
