diff --git a/dask_planner/src/sql/logical/table_scan.rs b/dask_planner/src/sql/logical/table_scan.rs
index 679d24c49..3b7a89e6e 100644
--- a/dask_planner/src/sql/logical/table_scan.rs
+++ b/dask_planner/src/sql/logical/table_scan.rs
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::{sync::Arc, vec};

 use datafusion_python::{
     datafusion_common::{DFSchema, ScalarValue},
@@ -19,6 +19,7 @@ pub struct PyTableScan {
     input: Arc<LogicalPlan>,
 }

+type FilterTuple = (String, String, Option<Vec<PyObject>>);
 #[pyclass(name = "FilteredResult", module = "dask_planner", subclass)]
 #[derive(Debug, Clone)]
 pub struct PyFilteredResult {
@@ -31,7 +32,7 @@ pub struct PyFilteredResult {
     // Expr(s) that can have their filtering logic performed in the pyarrow IO logic
     // are stored here in a DNF format that is expected by pyarrow.
     #[pyo3(get)]
-    pub filtered_exprs: Vec<(String, String, Vec<PyObject>)>,
+    pub filtered_exprs: Vec<(PyExpr, FilterTuple)>,
 }

 impl PyTableScan {
@@ -45,9 +46,10 @@ impl PyTableScan {
     /// it as well if needed.
     pub fn _expand_dnf_filter(
         filter: &Expr,
+        input: &Arc<LogicalPlan>,
         py: Python,
-    ) -> Result<Vec<(String, String, Vec<PyObject>)>, DaskPlannerError> {
-        let mut filter_tuple: Vec<(String, String, Vec<PyObject>)> = Vec::new();
+    ) -> Result<Vec<(PyExpr, FilterTuple)>, DaskPlannerError> {
+        let mut filter_tuple: Vec<(PyExpr, FilterTuple)> = Vec::new();

         match filter {
             Expr::InList {
@@ -100,9 +102,12 @@ impl PyTableScan {
                         .collect();

                     filter_tuple.push((
-                        ident.unwrap_or(expr.canonical_name()),
-                        op.to_string(),
-                        il?,
+                        PyExpr::from(filter.clone(), Some(vec![input.clone()])),
+                        (
+                            ident.unwrap_or(expr.canonical_name()),
+                            op.to_string(),
+                            Some(il?),
+                        ),
                     ));
                     Ok(filter_tuple)
                 } else {
@@ -110,15 +115,35 @@ impl PyTableScan {
                         "Invalid identifying column Expr instance `{}`. using in Dask instead",
                         filter
                     ));
-                    Err::<Vec<(String, String, Vec<PyObject>)>, DaskPlannerError>(er)
+                    Err::<Vec<(PyExpr, FilterTuple)>, DaskPlannerError>(er)
                 }
             }
+            Expr::IsNotNull(expr) => {
+                // Only handle simple Expr(s) for IsNotNull operations for now
+                let ident = match *expr.clone() {
+                    Expr::Column(col) => Ok(col.name),
+                    _ => Err(DaskPlannerError::InvalidIOFilter(format!(
+                        "Invalid IsNotNull Expr type `{}`. using in Dask instead",
+                        filter
+                    ))),
+                };
+
+                filter_tuple.push((
+                    PyExpr::from(filter.clone(), Some(vec![input.clone()])),
+                    (
+                        ident.unwrap_or(expr.canonical_name()),
+                        "is not".to_string(),
+                        None,
+                    ),
+                ));
+                Ok(filter_tuple)
+            }
             _ => {
                 let er = DaskPlannerError::InvalidIOFilter(format!(
                     "Unable to apply filter: `{}` to IO reader, using in Dask instead",
                     filter
                 ));
-                Err::<Vec<(String, String, Vec<PyObject>)>, DaskPlannerError>(er)
+                Err::<Vec<(PyExpr, FilterTuple)>, DaskPlannerError>(er)
             }
         }
     }
@@ -132,12 +157,12 @@ impl PyTableScan {
         filters: &[Expr],
         py: Python,
     ) -> PyFilteredResult {
-        let mut filtered_exprs: Vec<(String, String, Vec<PyObject>)> = Vec::new();
+        let mut filtered_exprs: Vec<(PyExpr, FilterTuple)> = Vec::new();
         let mut unfiltered_exprs: Vec<PyExpr> = Vec::new();

         filters
             .iter()
-            .for_each(|f| match PyTableScan::_expand_dnf_filter(f, py) {
+            .for_each(|f| match PyTableScan::_expand_dnf_filter(f, input, py) {
                 Ok(mut expanded_dnf_filter) => filtered_exprs.append(&mut expanded_dnf_filter),
                 Err(_e) => {
                     unfiltered_exprs.push(PyExpr::from(f.clone(), Some(vec![input.clone()])))
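Note on the planner change above: each predicate that can be pushed into the IO layer is now handed to Python as a pair of (original expression, pyarrow-style DNF tuple), where the tuple is (column, operator, values) and the values slot is None for IS NOT NULL. A rough Python sketch of what PyFilteredResult.filtered_exprs could look like for a query with `x IN (1, 2, 3) AND y IS NOT NULL`; the strings stand in for the planner's PyExpr objects, and the "in" operator string is assumed from op.to_string():

    filtered_exprs = [
        # (planner PyExpr, (column, operator, values)); strings stand in for PyExpr objects
        ("PyExpr<x IN (1, 2, 3)>", ("x", "in", [1, 2, 3])),
        # IS NOT NULL carries no value list, hence None in the third slot
        ("PyExpr<y IS NOT NULL>", ("y", "is not", None)),
    ]

    # table_scan.py consumes the pairs like this:
    exprs_for_dask = [f[0] for f in filtered_exprs]  # converted via RexConverter for Dask-side filtering
    filters_for_io = [f[1] for f in filtered_exprs]  # forwarded to the IO layer as add_filters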
using in Dask instead", + filter + ))), + }; + + filter_tuple.push(( + PyExpr::from(filter.clone(), Some(vec![input.clone()])), + ( + ident.unwrap_or(expr.canonical_name()), + "is not".to_string(), + None, + ), + )); + Ok(filter_tuple) + } _ => { let er = DaskPlannerError::InvalidIOFilter(format!( "Unable to apply filter: `{}` to IO reader, using in Dask instead", filter )); - Err::)>, DaskPlannerError>(er) + Err::, DaskPlannerError>(er) } } } @@ -132,12 +157,12 @@ impl PyTableScan { filters: &[Expr], py: Python, ) -> PyFilteredResult { - let mut filtered_exprs: Vec<(String, String, Vec)> = Vec::new(); + let mut filtered_exprs: Vec<(PyExpr, FilterTuple)> = Vec::new(); let mut unfiltered_exprs: Vec = Vec::new(); filters .iter() - .for_each(|f| match PyTableScan::_expand_dnf_filter(f, py) { + .for_each(|f| match PyTableScan::_expand_dnf_filter(f, input, py) { Ok(mut expanded_dnf_filter) => filtered_exprs.append(&mut expanded_dnf_filter), Err(_e) => { unfiltered_exprs.push(PyExpr::from(f.clone(), Some(vec![input.clone()]))) diff --git a/dask_sql/physical/rel/logical/filter.py b/dask_sql/physical/rel/logical/filter.py index 178121fef..d3c3f5fd3 100644 --- a/dask_sql/physical/rel/logical/filter.py +++ b/dask_sql/physical/rel/logical/filter.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, List, Union import dask.config as dask_config import dask.dataframe as dd @@ -17,7 +17,11 @@ logger = logging.getLogger(__name__) -def filter_or_scalar(df: dd.DataFrame, filter_condition: Union[np.bool_, dd.Series]): +def filter_or_scalar( + df: dd.DataFrame, + filter_condition: Union[np.bool_, dd.Series], + add_filters: List = None, +): """ Some (complex) SQL queries can lead to a strange condition which is always true or false. We do not need to filter in this case. 
diff --git a/dask_sql/physical/rel/logical/table_scan.py b/dask_sql/physical/rel/logical/table_scan.py
index 5c1718f62..b4025ec97 100644
--- a/dask_sql/physical/rel/logical/table_scan.py
+++ b/dask_sql/physical/rel/logical/table_scan.py
@@ -3,6 +3,8 @@
 from functools import reduce
 from typing import TYPE_CHECKING

+from dask.utils_test import hlg_layer
+
 from dask_sql.datacontainer import DataContainer
 from dask_sql.physical.rel.base import BaseRelPlugin
 from dask_sql.physical.rel.logical.filter import filter_or_scalar
@@ -77,16 +79,38 @@ def _apply_projections(self, table_scan, dask_table, dc):
     def _apply_filters(self, table_scan, rel, dc, context):
         df = dc.df
         cc = dc.column_container
-        filters = table_scan.getFilters()
-        # All partial filters here are applied in conjunction (&)
-        if filters:
+        all_filters = table_scan.getFilters()
+        conjunctive_dnf_filters = table_scan.getDNFFilters().filtered_exprs
+        non_dnf_filters = table_scan.getDNFFilters().io_unfilterable_exprs
+
+        if conjunctive_dnf_filters:
+            # Extract the PyExprs from the conjunctive DNF filters
+            filter_exprs = [f[0] for f in conjunctive_dnf_filters]
+            if non_dnf_filters:
+                filter_exprs.extend(non_dnf_filters)
+
+            df_condition = reduce(
+                operator.and_,
+                [
+                    RexConverter.convert(rel, rex, dc, context=context)
+                    for rex in filter_exprs
+                ],
+            )
+            df = filter_or_scalar(
+                df, df_condition, add_filters=[f[1] for f in conjunctive_dnf_filters]
+            )
+        elif all_filters:
             df_condition = reduce(
                 operator.and_,
                 [
                     RexConverter.convert(rel, rex, dc, context=context)
-                    for rex in filters
+                    for rex in all_filters
                 ],
             )
             df = filter_or_scalar(df, df_condition)
+        try:
+            logger.debug(hlg_layer(df.dask, "read-parquet").creation_info)
+        except KeyError:
+            pass

         return DataContainer(df, cc)
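The try/except around hlg_layer above only logs the read-parquet layer's creation_info so you can verify whether filters actually reached the IO layer. The same check can be done interactively; a small sketch, assuming df is a parquet-backed Dask DataFrame (a KeyError simply means the graph has no layer with that name prefix, e.g. for in-memory tables):

    from dask.utils_test import hlg_layer

    try:
        # creation_info records the call that created the IO layer, including any
        # "filters" keyword injected by predicate pushdown.
        print(hlg_layer(df.dask, "read-parquet").creation_info)
    except KeyError:
        print("no read-parquet layer in this graph")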
diff --git a/dask_sql/physical/utils/filter.py b/dask_sql/physical/utils/filter.py
index 0b4e0f40b..f99934c07 100644
--- a/dask_sql/physical/utils/filter.py
+++ b/dask_sql/physical/utils/filter.py
@@ -129,10 +129,12 @@ def attempt_predicate_pushdown(

     # Regenerate collection with filtered IO layer
     try:
+        _regen_cache = {}
         return dsk.layers[name]._regenerate_collection(
             dsk,
             # TODO: shouldn't need to specify index=False after dask#9661 is merged
             new_kwargs={io_layer: {"filters": filters, "index": False}},
+            _regen_cache=_regen_cache,
         )
     except ValueError as err:
         # Most-likely failed to apply filters in read_parquet.
@@ -195,15 +197,33 @@ def __bool__(self) -> bool:
     @classmethod
     def normalize(cls, filters: _And | _Or | list | tuple | None):
         """Convert raw filters to the `_Or(_And)` DNF representation"""
+
+        def _valid_tuple(predicate: tuple):
+            col, op, val = predicate
+            if isinstance(col, tuple):
+                raise TypeError("filters must be List[Tuple] or List[List[Tuple]]")
+            if op in ("in", "not in"):
+                return (col, op, tuple(val))
+            else:
+                return predicate
+
+        def _valid_list(conjunction: list):
+            valid = []
+            for predicate in conjunction:
+                if not isinstance(predicate, tuple):
+                    raise TypeError(f"Predicate must be a tuple, got {predicate}")
+                valid.append(_valid_tuple(predicate))
+            return valid
+
         if not filters:
             result = None
         elif isinstance(filters, list):
             conjunctions = filters if isinstance(filters[0], list) else [filters]
-            result = cls._Or([cls._And(conjunction) for conjunction in conjunctions])
+            result = cls._Or(
+                [cls._And(_valid_list(conjunction)) for conjunction in conjunctions]
+            )
         elif isinstance(filters, tuple):
-            if isinstance(filters[0], tuple):
-                raise TypeError("filters must be List[Tuple] or List[List[Tuple]]")
-            result = cls._Or((cls._And((filters,)),))
+            result = cls._Or((cls._And((_valid_tuple(filters),)),))
         elif isinstance(filters, cls._Or):
             result = cls._Or(se for e in filters for se in cls.normalize(e))
         elif isinstance(filters, cls._And):
@@ -332,7 +352,8 @@ def _regenerate_collection(

         # Return regenerated layer if the work was
         # already done
-        _regen_cache = _regen_cache or {}
+        if _regen_cache is None:
+            _regen_cache = {}
         if self.layer.output in _regen_cache:
             return _regen_cache[self.layer.output]
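To make the new normalization rules concrete, here is a small self-contained sketch that mirrors the _valid_tuple helper added above (re-implemented standalone for illustration rather than importing the private class):

    def valid_tuple(predicate: tuple):
        # Mirrors the diff: reject a tuple in the column slot, and coerce the value
        # collection of an (not) in predicate to a tuple; everything else passes through.
        col, op, val = predicate
        if isinstance(col, tuple):
            raise TypeError("filters must be List[Tuple] or List[List[Tuple]]")
        return (col, op, tuple(val)) if op in ("in", "not in") else predicate

    assert valid_tuple(("x", "in", [1, 2, 3])) == ("x", "in", (1, 2, 3))
    assert valid_tuple(("y", "is not", None)) == ("y", "is not", None)

    try:
        valid_tuple((("x", "==", 1), "and", ("y", "==", 2)))
    except TypeError as err:
        print(err)  # filters must be List[Tuple] or List[List[Tuple]]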