diff --git a/dask_sql/_compat.py b/dask_sql/_compat.py
index 20d69852c..0ed1c8c37 100644
--- a/dask_sql/_compat.py
+++ b/dask_sql/_compat.py
@@ -14,3 +14,7 @@
 
 # TODO: remove when dask min version gets bumped
 BROADCAST_JOIN_SUPPORT_WORKING = _dask_version > parseVersion("2023.1.0")
+
+# Parquet predicate-support version checks
+PQ_NOT_IN_SUPPORT = parseVersion(dask.__version__) > parseVersion("2023.5.1")
+PQ_IS_SUPPORT = parseVersion(dask.__version__) >= parseVersion("2023.3.1")
diff --git a/dask_sql/physical/utils/filter.py b/dask_sql/physical/utils/filter.py
index 5309289c4..5ac82fdd0 100644
--- a/dask_sql/physical/utils/filter.py
+++ b/dask_sql/physical/utils/filter.py
@@ -5,10 +5,12 @@
 import dask.dataframe as dd
 import numpy as np
 from dask.blockwise import Blockwise
-from dask.highlevelgraph import HighLevelGraph
+from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
 from dask.layers import DataFrameIOLayer
 from dask.utils import M, apply, is_arraylike
 
+from dask_sql._compat import PQ_IS_SUPPORT, PQ_NOT_IN_SUPPORT
+
 logger = logging.getLogger(__name__)
 
 
@@ -59,6 +61,13 @@ def attempt_predicate_pushdown(ddf: dd.DataFrame) -> dd.DataFrame:
         return ddf
     io_layer = io_layer.pop()
 
+    # Bail if any filters are already present in ddf
+    existing_filters = (
+        ddf.dask.layers[io_layer].creation_info.get("kwargs", {}).get("filters")
+    )
+    if existing_filters:
+        return ddf
+
     # Start by converting the HLG to a `RegenerableGraph`.
     # Succeeding here means that all layers in the graph
     # are regenerable.
@@ -79,7 +88,7 @@ def attempt_predicate_pushdown(ddf: dd.DataFrame) -> dd.DataFrame:
             # No filters encountered
             return ddf
         filters = filters.to_list_tuple()
-    except ValueError:
+    except (ValueError, TypeError):
        # DNF dispatching failed for 1+ layers
        logger.warning(
            "Predicate pushdown optimization skipped. One or more "
@@ -142,13 +151,16 @@ def to_dnf(expr):
     # Credit: https://stackoverflow.com/a/58372345
     if not isinstance(expr, (Or, And)):
+        if not isinstance(expr, tuple):
+            raise TypeError(f"expected tuple, got {expr}")
         result = Or((And((expr,)),))
     elif isinstance(expr, Or):
         result = Or(se for e in expr for se in to_dnf(e))
     elif isinstance(expr, And):
         total = []
         for c in itertools.product(*[to_dnf(e) for e in expr]):
-            total.append(And(se for e in c for se in e))
+            conjunction = [se for e in c for se in e if isinstance(se, tuple)]
+            total.append(And(conjunction))
         result = Or(total)
     return result
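# --- Illustration (not part of the patch): a minimal, self-contained sketch
# of the DNF ("disjunctive normal form") rewrite that `to_dnf` performs. The
# frozenset-backed `Or`/`And` stand-ins below mirror the classes already
# defined in dask_sql/physical/utils/filter.py, but are re-declared here only
# so the example runs on its own; their exact internals are an assumption.
import itertools


class Or(frozenset):
    """A disjunction of `And` terms."""


class And(frozenset):
    """A conjunction of (column, op, value) predicate tuples."""


def to_dnf_sketch(expr):
    # Same recursion as the patched `to_dnf`: a bare predicate tuple becomes a
    # one-term conjunction, and `And` is distributed over nested `Or` terms.
    if not isinstance(expr, (Or, And)):
        if not isinstance(expr, tuple):
            # The new TypeError guard: operands that are not predicate tuples
            # (e.g. bare column selections) cannot appear in a parquet filter.
            raise TypeError(f"expected tuple, got {expr}")
        return Or((And((expr,)),))
    if isinstance(expr, Or):
        return Or(se for e in expr for se in to_dnf_sketch(e))
    total = []
    for c in itertools.product(*[to_dnf_sketch(e) for e in expr]):
        total.append(And(se for e in c for se in e if isinstance(se, tuple)))
    return Or(total)


# (a > 1) AND ((b < 2) OR (b == 5)) distributes to
# ((a > 1) AND (b < 2)) OR ((a > 1) AND (b == 5)):
expr = And((("a", ">", 1), Or((And((("b", "<", 2),)), And((("b", "==", 5),))))))
assert to_dnf_sketch(expr) == Or(
    (And((("a", ">", 1), ("b", "<", 2))), And((("a", ">", 1), ("b", "==", 5))))
)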
@@ -170,20 +182,65 @@ def to_dnf(expr):
     np.not_equal: "!=",
 }
 
+# Define all regenerable "pass-through" ops
+# that do not affect filters.
+_pass_through_ops = {M.fillna, M.astype}
+
 # Define set of all "regenerable" operations.
 # Predicate pushdown is supported for graphs
 # comprised of `Blockwise` layers based on these
 # operations
-_regenerable_ops = set(_comparison_symbols.keys()) | {
-    operator.and_,
-    operator.or_,
-    operator.getitem,
-    M.fillna,
-}
+_regenerable_ops = (
+    set(_comparison_symbols.keys())
+    | {
+        operator.and_,
+        operator.or_,
+        operator.getitem,
+        operator.inv,
+        M.isin,
+        M.isna,
+    }
+    | _pass_through_ops
+)
 
 # Specify functions that must be generated with
 # a different API at the dataframe-collection level
-_special_op_mappings = {M.fillna: dd._Frame.fillna}
+_special_op_mappings = {
+    M.fillna: dd._Frame.fillna,
+    M.isin: dd._Frame.isin,
+    M.isna: dd._Frame.isna,
+    M.astype: dd._Frame.astype,
+}
+
+# Convert _pass_through_ops to respect "special" mappings
+_pass_through_ops = {_special_op_mappings.get(op, op) for op in _pass_through_ops}
+
+
+def _preprocess_layers(input_layers):
+    # NOTE: This is a Layer-specific work-around to deal with
+    # the fact that `dd._Frame.isin(values)` will add a distinct
+    # `MaterializedLayer` for the `values` argument.
+    # See: https://github.com/dask-contrib/dask-sql/issues/607
+    skip = set()
+    layers = input_layers.copy()
+    for key, layer in layers.items():
+        if key.startswith("isin-") and isinstance(layer, Blockwise):
+            indices = list(layer.indices)
+            for i, (k, ind) in enumerate(layer.indices):
+                if (
+                    ind is None
+                    and isinstance(layers.get(k), MaterializedLayer)
+                    and isinstance(layers[k].get(k), (np.ndarray, tuple))
+                ):
+                    # Replace `indices[i]` with a literal value and
+                    # make sure we skip the `MaterializedLayer` that
+                    # we are now fusing into the `isin`
+                    value = layers[k][k]
+                    value = value[0](*value[1:]) if callable(value[0]) else value
+                    indices[i] = (value, None)
+                    skip.add(k)
+            layer.indices = tuple(indices)
+    return {k: v for k, v in layers.items() if k not in skip}
 
 
 class RegenerableLayer:
@@ -261,8 +318,14 @@ def _dnf_filter_expression(self, dsk):
             func = _blockwise_logical_dnf
         elif op == operator.getitem:
             func = _blockwise_getitem_dnf
-        elif op == dd._Frame.fillna:
-            func = _blockwise_fillna_dnf
+        elif op == dd._Frame.isin:
+            func = _blockwise_isin_dnf
+        elif op == dd._Frame.isna:
+            func = _blockwise_isna_dnf
+        elif op == operator.inv:
+            func = _blockwise_inv_dnf
+        elif op in _pass_through_ops:
+            func = _blockwise_pass_through_dnf
         else:
             raise ValueError(f"No DNF expression for {op}")
 
@@ -288,7 +351,7 @@ def from_hlg(cls, hlg: HighLevelGraph):
             raise TypeError(f"Expected HighLevelGraph, got {type(hlg)}")
 
         _layers = {}
-        for key, layer in hlg.layers.items():
+        for key, layer in _preprocess_layers(hlg.layers).items():
             regenerable_layer = None
             if isinstance(layer, DataFrameIOLayer):
                 regenerable_layer = RegenerableLayer(layer, layer.creation_info or {})
@@ -335,23 +398,30 @@ def _get_blockwise_input(input_index, indices: list, dsk: RegenerableGraph):
     return dsk.layers[key]._dnf_filter_expression(dsk)
 
 
+def _inv(symbol: str):
+    if symbol == "in" and not PQ_NOT_IN_SUPPORT:
+        raise ValueError("This version of dask does not support 'not in'")
+    return {
+        ">": "<",
+        "<": ">",
+        ">=": "<=",
+        "<=": ">=",
+        "in": "not in",
+        "not in": "in",
+        "is": "is not",
+        "is not": "is",
+    }.get(symbol, symbol)
+
+
 def _blockwise_comparison_dnf(op, indices: list, dsk: RegenerableGraph):
     # Return DNF expression pattern for a simple comparison
     left = _get_blockwise_input(0, indices, dsk)
     right = _get_blockwise_input(1, indices, dsk)
 
-    def _inv(symbol: str):
-        return {
-            ">": "<",
-            "<": ">",
-            ">=": "<=",
-            "<=": ">=",
-        }.get(symbol, symbol)
-
     if is_arraylike(left) and hasattr(left, "item") and left.size == 1:
         left = left.item()
         # Need inverse comparison in read_parquet
-        return (right, _inv(_comparison_symbols[op]), left)
+        return to_dnf((right, _inv(_comparison_symbols[op]), left))
     if is_arraylike(right) and hasattr(right, "item") and right.size == 1:
         right = right.item()
     return to_dnf((left, _comparison_symbols[op], right))
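# --- Illustration (not part of the patch): why `_inv` is needed in
# `_blockwise_comparison_dnf`. When the scalar ends up as the *left* operand
# of a comparison, the predicate must be rewritten in (column, op, value)
# order, which flips the operator. A standalone copy of the new mapping, for
# demonstration only:
def _inv_sketch(symbol: str) -> str:
    return {
        ">": "<",
        "<": ">",
        ">=": "<=",
        "<=": ">=",
        "in": "not in",
        "not in": "in",
        "is": "is not",
        "is not": "is",
    }.get(symbol, symbol)


# A predicate that arrives as (1, "<", "a") is emitted as ("a", ">", 1):
assert _inv_sketch("<") == ">"
# Self-inverse operators such as "==" fall through `.get` unchanged:
assert _inv_sketch("==") == "=="
# The new pairs serve `_blockwise_inv_dnf` below (`~isin` -> "not in",
# `~isna` -> "is not"); the real `_inv` additionally raises for "in" unless
# PQ_NOT_IN_SUPPORT, since "not in" filters need dask > 2023.5.1:
assert _inv_sketch("in") == "not in" and _inv_sketch("is") == "is not"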
@@ -361,10 +431,17 @@ def _blockwise_logical_dnf(op, indices: list, dsk: RegenerableGraph):
     # Return DNF expression pattern for logical "and" or "or"
     left = _get_blockwise_input(0, indices, dsk)
     right = _get_blockwise_input(1, indices, dsk)
+
+    vals = []
+    for val in [left, right]:
+        if not isinstance(val, (tuple, Or, And)):
+            raise TypeError(f"Invalid logical operand: {val}")
+        vals.append(to_dnf(val))
+
     if op == operator.or_:
-        return to_dnf(Or([left, right]))
+        return to_dnf(Or(vals))
     elif op == operator.and_:
-        return to_dnf(And([left, right]))
+        return to_dnf(And(vals))
     else:
         raise ValueError
 
@@ -375,6 +452,39 @@ def _blockwise_getitem_dnf(op, indices: list, dsk: RegenerableGraph):
     return key
 
 
-def _blockwise_fillna_dnf(op, indices: list, dsk: RegenerableGraph):
+def _blockwise_pass_through_dnf(op, indices: list, dsk: RegenerableGraph):
     # Return dnf of input collection
     return _get_blockwise_input(0, indices, dsk)
+
+
+def _blockwise_isin_dnf(op, indices: list, dsk: RegenerableGraph):
+    # Return DNF expression pattern for a simple "in" comparison
+    left = _get_blockwise_input(0, indices, dsk)
+    right = _get_blockwise_input(1, indices, dsk)
+    return to_dnf((left, "in", tuple(right)))
+
+
+def _blockwise_isna_dnf(op, indices: list, dsk: RegenerableGraph):
+    # Return DNF expression pattern for `isna`
+    if not PQ_IS_SUPPORT:
+        raise ValueError("This version of dask does not support 'is' predicates.")
+    left = _get_blockwise_input(0, indices, dsk)
+    return to_dnf((left, "is", None))
+
+
+def _blockwise_inv_dnf(op, indices: list, dsk: RegenerableGraph):
+    # Return DNF expression pattern for the inverse of a comparison
+    expr = _get_blockwise_input(0, indices, dsk).to_list_tuple()
+    new_expr = []
+    count = 0
+    for conjunction in expr:
+        new_conjunction = []
+        for col, op, val in conjunction:
+            count += 1
+            new_conjunction.append((col, _inv(op), val))
+        new_expr.append(And(new_conjunction))
+    if count > 1:
+        # Haven't taken the time to think through
+        # general inversion yet.
+        raise ValueError("inv(DNF) case not implemented.")
+    return to_dnf(Or(new_expr))
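# --- Illustration (not part of the patch): `_blockwise_inv_dnf` handles
# `operator.inv` (the `~` produced by expressions like `~df["b"].isin(...)`).
# Negating a single predicate is a pure operator flip; negating a general
# OR-of-ANDs would require a De Morgan expansion back into DNF, so that case
# raises ValueError, which `attempt_predicate_pushdown` catches and treats as
# "skip the optimization". A hypothetical helper showing the supported
# single-predicate flip:
def negate_single_predicate(filters):
    # `filters` is a list-of-lists filter holding exactly one predicate,
    # e.g. [[("b", "in", (1, 3, 5, 6))]].
    flip = {"in": "not in", "not in": "in", "is": "is not", "is not": "is"}
    [[(col, op, val)]] = filters
    return [[(col, flip.get(op, op), val)]]


# NOT IN pushes down as the flipped membership test:
assert negate_single_predicate([[("b", "in", (1, 3, 5, 6))]]) == [
    [("b", "not in", (1, 3, 5, 6))]
]
# ~isna() pushes down as an "is not" null check:
assert negate_single_predicate([[("a", "is", None)]]) == [[("a", "is not", None)]]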
diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py
index 0388aced8..9de072d5e 100644
--- a/tests/integration/test_filter.py
+++ b/tests/integration/test_filter.py
@@ -5,6 +5,7 @@
 from dask.utils_test import hlg_layer
 from packaging.version import parse as parseVersion
 
+from dask_sql._compat import PQ_IS_SUPPORT, PQ_NOT_IN_SUPPORT
 from tests.utils import assert_eq
 
 DASK_GT_2022_4_2 = parseVersion(dask.__version__) >= parseVersion("2022.4.2")
@@ -162,10 +163,26 @@ def test_filter_year(c):
         ),
         pytest.param(
             "SELECT * FROM parquet_ddf WHERE b IN (1, 3, 5, 6)",
-            lambda x: x[(x["b"] == 1) | (x["b"] == 3) | (x["b"] == 5) | (x["b"] == 6)],
-            [[("b", "==", 1)], [("b", "==", 3)], [("b", "==", 5)], [("b", "==", 6)]],
-            marks=pytest.mark.xfail(
-                reason="WIP https://github.com/dask-contrib/dask-sql/issues/607"
+            lambda x: x[x["b"].isin([1, 3, 5, 6])],
+            [[("b", "in", (1, 3, 5, 6))]],
+        ),
+        pytest.param(
+            "SELECT * FROM parquet_ddf WHERE c IN ('A', 'B', 'C', 'D')",
+            lambda x: x[x["c"].isin(["A", "B", "C", "D"])],
+            [[("c", "in", ("A", "B", "C", "D"))]],
+        ),
+        pytest.param(
+            "SELECT * FROM parquet_ddf WHERE b NOT IN (1, 6)",
+            lambda x: x[(x["b"] != 1) & (x["b"] != 6)],
+            [[("b", "!=", 1), ("b", "!=", 6)]],
+        ),
+        pytest.param(
+            "SELECT * FROM parquet_ddf WHERE b NOT IN (1, 3, 5, 6)",
+            lambda x: x[~x["b"].isin([1, 3, 5, 6])],
+            [[("b", "not in", (1, 3, 5, 6))]],
+            marks=pytest.mark.skipif(
+                not PQ_NOT_IN_SUPPORT,
+                reason="Requires https://github.com/dask/dask/pull/10320",
             ),
         ),
         (
@@ -296,3 +313,55 @@ def test_filter_decimal(c, gpu):
     assert_eq(result_df, expected_df, check_index=False)
 
     c.drop_table("df")
+
+
+@pytest.mark.skipif(
+    not PQ_IS_SUPPORT,
+    reason="Requires dask>=2023.3.1",
+)
+def test_predicate_pushdown_isna(tmpdir):
+    from dask_sql.context import Context
+
+    c = Context()
+
+    path = str(tmpdir)
+    dd.from_pandas(
+        pd.DataFrame(
+            {
+                "a": [1, 2, None] * 5,
+                "b": range(15),
+                "index": range(15),
+            }
+        ),
+        npartitions=3,
+    ).to_parquet(path + "/df1")
+    df1 = dd.read_parquet(path + "/df1", index="index")
+    c.create_table("df1", df1)
+
+    dd.from_pandas(
+        pd.DataFrame(
+            {
+                "a": [None, 2, 3] * 5,
+                "b": range(15),
+                "index": range(15),
+            },
+        ),
+        npartitions=3,
+    ).to_parquet(path + "/df2")
+    df2 = dd.read_parquet(path + "/df2", index="index")
+    c.create_table("df2", df2)
+
+    return_df = c.sql("SELECT df1.a FROM df1, df2 WHERE df1.a = df2.a")
+
+    # Check for predicate pushdown
+    filters = [[("a", "is not", None)]]
+    got_filters = hlg_layer(return_df.dask, "read-parquet").creation_info["kwargs"][
+        "filters"
+    ]
+
+    got_filters = frozenset(frozenset(v) for v in got_filters)
+    expect_filters = frozenset(frozenset(v) for v in filters)
+
+    assert got_filters == expect_filters
+    assert all(return_df.compute() == 2)
+    assert len(return_df) == 25
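# --- Illustration (not part of the patch): the `filters` values asserted on
# above follow `dd.read_parquet`'s DNF convention -- an OR-list of AND-lists
# of (column, op, value) predicates. A minimal round trip (the file name is
# arbitrary, and "is"/"is not" predicates additionally require
# PQ_IS_SUPPORT, i.e. dask >= 2023.3.1):
import pandas as pd
import dask.dataframe as dd

pd.DataFrame({"a": [1, 2, 3, 4], "b": list("wxyz")}).to_parquet("example.parquet")
ddf = dd.read_parquet(
    "example.parquet",
    # (a > 1 AND a in {2, 3}) OR b == "z"
    filters=[[("a", ">", 1), ("a", "in", (2, 3))], [("b", "==", "z")]],
)
# Data not matching the predicate is pruned at read time (at minimum at
# row-group granularity), which is the point of pushing filters into the
# IO layer instead of evaluating them as separate graph layers.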
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 41f918558..00eb5a885 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -1,7 +1,10 @@
 import pandas as pd
 import pytest
 from dask import dataframe as dd
+from dask.utils_test import hlg_layer
 
+from dask_sql._compat import PQ_IS_SUPPORT, PQ_NOT_IN_SUPPORT
+from dask_sql.physical.utils.filter import attempt_predicate_pushdown
 from dask_sql.utils import Pluggable, is_frame
 
@@ -52,3 +55,78 @@ def test_overwrite():
 
     assert PluginTest1.get_plugin("some_key") == "value_2"
     assert PluginTest1().get_plugin("some_key") == "value_2"
+
+
+def test_predicate_pushdown_simple(parquet_ddf):
+    filtered_df = parquet_ddf[parquet_ddf["a"] > 1]
+    pushdown_df = attempt_predicate_pushdown(filtered_df)
+    got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][
+        "filters"
+    ]
+    got_filters = frozenset(frozenset(v) for v in got_filters)
+    expected_filters = [[("a", ">", 1)]]
+    expected_filters = frozenset(frozenset(v) for v in expected_filters)
+    assert got_filters == expected_filters
+
+
+def test_predicate_pushdown_logical(parquet_ddf):
+    filtered_df = parquet_ddf[
+        (parquet_ddf["a"] > 1) & (parquet_ddf["b"] < 2) | (parquet_ddf["a"] == -1)
+    ]
+
+    pushdown_df = attempt_predicate_pushdown(filtered_df)
+    got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][
+        "filters"
+    ]
+    got_filters = frozenset(frozenset(v) for v in got_filters)
+    expected_filters = [[("a", ">", 1), ("b", "<", 2)], [("a", "==", -1)]]
+    expected_filters = frozenset(frozenset(v) for v in expected_filters)
+    assert got_filters == expected_filters
+
+
+@pytest.mark.skipif(
+    not PQ_NOT_IN_SUPPORT,
+    reason="Requires https://github.com/dask/dask/pull/10320",
+)
+def test_predicate_pushdown_in(parquet_ddf):
+    filtered_df = parquet_ddf[
+        (parquet_ddf["a"] > 1) & (parquet_ddf["b"] < 2)
+        | (parquet_ddf["a"] == -1) & parquet_ddf["c"].isin(("A", "B", "C"))
+        | ~parquet_ddf["b"].isin((5, 6, 7))
+    ]
+    pushdown_df = attempt_predicate_pushdown(filtered_df)
+    got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][
+        "filters"
+    ]
+    got_filters = frozenset(frozenset(v) for v in got_filters)
+    expected_filters = [
+        [("b", "<", 2), ("a", ">", 1)],
+        [("a", "==", -1), ("c", "in", ("A", "B", "C"))],
+        [("b", "not in", (5, 6, 7))],
+    ]
+    expected_filters = frozenset(frozenset(v) for v in expected_filters)
+    assert got_filters == expected_filters
+
+
+@pytest.mark.skipif(
+    not PQ_IS_SUPPORT,
+    reason="Requires dask>=2023.3.1",
+)
+def test_predicate_pushdown_isna(parquet_ddf):
+    filtered_df = parquet_ddf[
+        (parquet_ddf["a"] > 1) & (parquet_ddf["b"] < 2)
+        | (parquet_ddf["a"] == -1) & ~parquet_ddf["c"].isna()
+        | parquet_ddf["b"].isna()
+    ]
+    pushdown_df = attempt_predicate_pushdown(filtered_df)
+    got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][
+        "filters"
+    ]
+    got_filters = frozenset(frozenset(v) for v in got_filters)
+    expected_filters = [
+        [("b", "<", 2), ("a", ">", 1)],
+        [("a", "==", -1), ("c", "is not", None)],
+        [("b", "is", None)],
+    ]
+    expected_filters = frozenset(frozenset(v) for v in expected_filters)
+    assert got_filters == expected_filters
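# --- Illustration (not part of the patch): the unit tests above rely on a
# `parquet_ddf` fixture defined elsewhere in the test suite. A rough,
# self-contained equivalent of what they exercise (the file name here is
# arbitrary):
import pandas as pd
import dask.dataframe as dd
from dask.utils_test import hlg_layer

from dask_sql.physical.utils.filter import attempt_predicate_pushdown

pd.DataFrame({"a": [0, 1, 2, 3], "b": [3, 2, 1, 0]}).to_parquet("pushdown.parquet")
ddf = dd.read_parquet("pushdown.parquet")

# Rewrite the graph so the selection is pushed into the IO layer as a
# `filters=` kwarg; the `ddf[ddf["a"] > 1]` graph is regenerated rather
# than mutated in place.
pushed = attempt_predicate_pushdown(ddf[ddf["a"] > 1])
print(hlg_layer(pushed.dask, "read-parquet").creation_info["kwargs"]["filters"])
# A DNF filter equivalent to [[("a", ">", 1)]], modulo list/tuple nesting.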