diff --git a/.github/docker-compose.yaml b/.github/docker-compose.yaml
index cfb7eb43f..56ec50b47 100644
--- a/.github/docker-compose.yaml
+++ b/.github/docker-compose.yaml
@@ -11,5 +11,7 @@ services:
     container_name: dask-worker
     image: daskdev/dask:latest
     command: dask-worker dask-scheduler:8786
+    environment:
+      EXTRA_CONDA_PACKAGES: "pyarrow>1.0.0" # required for parquet IO
     volumes:
       - /tmp:/tmp
diff --git a/dask_sql/physical/rel/logical/filter.py b/dask_sql/physical/rel/logical/filter.py
index 87c99e3e0..6e7078efd 100644
--- a/dask_sql/physical/rel/logical/filter.py
+++ b/dask_sql/physical/rel/logical/filter.py
@@ -1,12 +1,14 @@
 import logging
 from typing import TYPE_CHECKING, Union
 
+import dask.config as dask_config
 import dask.dataframe as dd
 import numpy as np
 
 from dask_sql.datacontainer import DataContainer
 from dask_sql.physical.rel.base import BaseRelPlugin
 from dask_sql.physical.rex import RexConverter
+from dask_sql.physical.utils.filter import attempt_predicate_pushdown
 
 if TYPE_CHECKING:
     import dask_sql
@@ -31,7 +33,11 @@ def filter_or_scalar(df: dd.DataFrame, filter_condition: Union[np.bool_, dd.Seri
 
     # In SQL, a NULL in a boolean is False on filtering
     filter_condition = filter_condition.fillna(False)
-    return df[filter_condition]
+    out = df[filter_condition]
+    if dask_config.get("sql.predicate_pushdown"):
+        return attempt_predicate_pushdown(out)
+    else:
+        return out
 
 
 class DaskFilterPlugin(BaseRelPlugin):
diff --git a/dask_sql/physical/utils/filter.py b/dask_sql/physical/utils/filter.py
new file mode 100644
index 000000000..67e4026f5
--- /dev/null
+++ b/dask_sql/physical/utils/filter.py
@@ -0,0 +1,368 @@
+import itertools
+import logging
+import operator
+
+import dask.dataframe as dd
+import numpy as np
+from dask.blockwise import Blockwise
+from dask.highlevelgraph import HighLevelGraph
+from dask.layers import DataFrameIOLayer
+from dask.utils import M, apply, is_arraylike
+
+logger = logging.getLogger(__name__)
+
+
+def attempt_predicate_pushdown(ddf: dd.DataFrame) -> dd.DataFrame:
+    """Use graph information to update IO-level filters
+
+    The original `ddf` will be returned if/when the
+    predicate-pushdown optimization fails.
+
+    This is a special optimization that must be called
+    eagerly on a DataFrame collection when filters are
+    applied. The "eager" requirement for this optimization
+    is due to the fact that `npartitions` and `divisions`
+    may change when this optimization is applied (invalidating
+    npartition/divisions-specific logic in following Layers).
+    """
+
+    # Check that we have a supported `ddf` object
+    if not isinstance(ddf, dd.DataFrame):
+        raise ValueError(
+            f"Predicate pushdown optimization skipped. Type {type(ddf)} "
+            f"does not support predicate pushdown."
+        )
+    elif not isinstance(ddf.dask, HighLevelGraph):
+        logger.warning(
+            f"Predicate pushdown optimization skipped. Graph must be "
+            f"a HighLevelGraph object (got {type(ddf.dask)})."
+        )
+        return ddf
+
+    # Check that we have a single IO layer with `filters` support
+    io_layer = []
+    for k, v in ddf.dask.layers.items():
+        if isinstance(v, DataFrameIOLayer):
+            io_layer.append(k)
+            creation_info = (
+                (v.creation_info or {}) if hasattr(v, "creation_info") else {}
+            )
+            if (
+                "filters" not in creation_info.get("kwargs", {})
+                or creation_info["kwargs"]["filters"] is not None
+            ):
+                # No filters support, or filters is already set
+                return ddf
+    if len(io_layer) != 1:
+        # Not a single IO layer
+        return ddf
+    io_layer = io_layer.pop()
+
+    # Start by converting the HLG to a `RegenerableGraph`.
+    # Succeeding here means that all layers in the graph
+    # are regenerable.
+    try:
+        dsk = RegenerableGraph.from_hlg(ddf.dask)
+    except (ValueError, TypeError):
+        logger.warning(
+            "Predicate pushdown optimization skipped. One or more "
+            "layers in the HighLevelGraph was not 'regenerable'."
+        )
+        return ddf
+
+    # Extract a DNF-formatted filter expression
+    name = ddf._name
+    try:
+        filters = dsk.layers[name]._dnf_filter_expression(dsk)
+        if not isinstance(filters, frozenset):
+            # No filters encountered
+            return ddf
+        filters = filters.to_list_tuple()
+    except ValueError:
+        # DNF dispatching failed for 1+ layers
+        logger.warning(
+            "Predicate pushdown optimization skipped. One or more "
+            "layers has an unknown filter expression."
+        )
+        return ddf
+
+    # We were able to extract a DNF filter expression.
+    # Regenerate collection with filtered IO layer
+    try:
+        return dsk.layers[name]._regenerate_collection(
+            dsk, new_kwargs={io_layer: {"filters": filters}},
+        )
+    except ValueError as err:
+        # Most-likely failed to apply filters in read_parquet.
+        # We can just bail on predicate pushdown, but we also
+        # raise a warning to encourage the user to file an issue.
+        logger.warning(
+            f"Predicate pushdown failed to apply filters: {filters}. "
+            f"Please open a bug report at "
+            f"https://github.com/dask-contrib/dask-sql/issues/new/choose "
+            f"and include the following error message: {err}"
+        )
+
+        return ddf
+
+
+class Or(frozenset):
+    """Helper class for 'OR' expressions"""
+
+    def to_list_tuple(self):
+        # DNF "or" is List[List[Tuple]]
+        def _maybe_list(val):
+            if isinstance(val, tuple) and val and isinstance(val[0], (tuple, list)):
+                return list(val)
+            return [val]
+
+        return [
+            _maybe_list(val.to_list_tuple())
+            if hasattr(val, "to_list_tuple")
+            else _maybe_list(val)
+            for val in self
+        ]
+
+
+class And(frozenset):
+    """Helper class for 'AND' expressions"""
+
+    def to_list_tuple(self):
+        # DNF "and" is List[Tuple]
+        return tuple(
+            val.to_list_tuple() if hasattr(val, "to_list_tuple") else val
+            for val in self
+        )
+
+
+def to_dnf(expr):
+    """Normalize a boolean filter expression to disjunctive normal form (DNF)"""
+
+    # Credit: https://stackoverflow.com/a/58372345
+    if not isinstance(expr, (Or, And)):
+        result = Or((And((expr,)),))
+    elif isinstance(expr, Or):
+        result = Or(se for e in expr for se in to_dnf(e))
+    elif isinstance(expr, And):
+        total = []
+        for c in itertools.product(*[to_dnf(e) for e in expr]):
+            total.append(And(se for e in c for se in e))
+        result = Or(total)
+    return result
+
+
+# Define all supported comparison functions
+# (and their mapping to a string expression)
+_comparison_symbols = {
+    operator.eq: "==",
+    operator.ne: "!=",
+    operator.lt: "<",
+    operator.le: "<=",
+    operator.gt: ">",
+    operator.ge: ">=",
+    np.greater: ">",
+    np.greater_equal: ">=",
+    np.less: "<",
+    np.less_equal: "<=",
+    np.equal: "==",
+    np.not_equal: "!=",
+}
+
+# Define set of all "regenerable" operations.
+# Predicate pushdown is supported for graphs
+# comprised of `Blockwise` layers based on these
+# operations
+_regenerable_ops = set(_comparison_symbols.keys()) | {
+    operator.and_,
+    operator.or_,
+    operator.getitem,
+    M.fillna,
+}
+
+# Specify functions that must be generated with
+# a different API at the dataframe-collection level
+_special_op_mappings = {M.fillna: dd._Frame.fillna}
+
+
+class RegenerableLayer:
+    """Regenerable Layer
+
+    Wraps ``dask.blockwise.Blockwise`` to ensure that a
+    ``creation_info`` attribute is defined. This class
+    also defines the necessary methods for recursive
+    layer regeneration and filter-expression generation.
+    """
+
+    def __init__(self, layer, creation_info):
+        self.layer = layer  # Original Blockwise layer reference
+        self.creation_info = creation_info
+
+    def _regenerate_collection(
+        self, dsk, new_kwargs: dict = None, _regen_cache: dict = None,
+    ):
+        """Regenerate a Dask collection for this layer using the
+        provided inputs and key-word arguments
+        """
+
+        # Return regenerated layer if the work was
+        # already done
+        _regen_cache = _regen_cache or {}
+        if self.layer.output in _regen_cache:
+            return _regen_cache[self.layer.output]
+
+        # Recursively generate necessary inputs to
+        # this layer to generate the collection
+        inputs = []
+        for key, ind in self.layer.indices:
+            if ind is None:
+                if isinstance(key, (str, tuple)) and key in dsk.layers:
+                    continue
+                inputs.append(key)
+            elif key in self.layer.io_deps:
+                continue
+            else:
+                inputs.append(
+                    dsk.layers[key]._regenerate_collection(
+                        dsk, new_kwargs=new_kwargs, _regen_cache=_regen_cache,
+                    )
+                )
+
+        # Extract the callable func and key-word args.
+        # Then return a regenerated collection
+        func = self.creation_info.get("func", None)
+        if func is None:
+            raise ValueError(
+                "`_regenerate_collection` failed. "
+                "Not all HLG layers are regenerable."
+            )
+        regen_args = self.creation_info.get("args", [])
+        # Copy the original kwargs, then apply any new overrides for this layer
+        regen_kwargs = self.creation_info.get("kwargs", {}).copy()
+        regen_kwargs.update((new_kwargs or {}).get(self.layer.output, {}))
+        result = func(*inputs, *regen_args, **regen_kwargs)
+        _regen_cache[self.layer.output] = result
+        return result
+
+    def _dnf_filter_expression(self, dsk):
+        """Return a DNF-formatted filter expression for the
+        graph terminating at this layer
+        """
+        op = self.creation_info["func"]
+        if op in _comparison_symbols.keys():
+            func = _blockwise_comparison_dnf
+        elif op in (operator.and_, operator.or_):
+            func = _blockwise_logical_dnf
+        elif op == operator.getitem:
+            func = _blockwise_getitem_dnf
+        elif op == dd._Frame.fillna:
+            func = _blockwise_fillna_dnf
+        else:
+            raise ValueError(f"No DNF expression for {op}")
+
+        return func(op, self.layer.indices, dsk)
+
+
+class RegenerableGraph:
+    """Regenerable Graph
+
+    This class is similar to ``dask.highlevelgraph.HighLevelGraph``.
+    However, all layers in a ``RegenerableGraph`` graph must be
+    ``RegenerableLayer`` objects (which wrap ``Blockwise`` layers).
+    """
+
+    def __init__(self, layers: dict):
+        self.layers = layers
+
+    @classmethod
+    def from_hlg(cls, hlg: HighLevelGraph):
+        """Construct a ``RegenerableGraph`` from a ``HighLevelGraph``"""
+
+        if not isinstance(hlg, HighLevelGraph):
+            raise TypeError(f"Expected HighLevelGraph, got {type(hlg)}")
+
+        _layers = {}
+        for key, layer in hlg.layers.items():
+            regenerable_layer = None
+            if isinstance(layer, DataFrameIOLayer):
+                regenerable_layer = RegenerableLayer(layer, layer.creation_info or {})
+            elif isinstance(layer, Blockwise):
+                tasks = list(layer.dsk.values())
+                if len(tasks) == 1 and tasks[0]:
+                    kwargs = {}
+                    if tasks[0][0] == apply:
+                        op = tasks[0][1]
+                        options = tasks[0][3]
+                        if isinstance(options, dict):
+                            kwargs = options
+                        elif (
+                            isinstance(options, tuple)
+                            and options
+                            and callable(options[0])
+                        ):
+                            kwargs = options[0](*options[1:])
+                    else:
+                        op = tasks[0][0]
+                    if op in _regenerable_ops:
+                        regenerable_layer = RegenerableLayer(
+                            layer,
+                            {
+                                "func": _special_op_mappings.get(op, op),
+                                "kwargs": kwargs,
+                            },
+                        )
+
+            if regenerable_layer is None:
+                raise ValueError(f"Graph contains non-regenerable layer: {layer}")
+
+            _layers[key] = regenerable_layer
+
+        return RegenerableGraph(_layers)
+
+
+def _get_blockwise_input(input_index, indices: list, dsk: RegenerableGraph):
+    # Simple utility to get the required input expressions
+    # for a Blockwise layer (using indices)
+    key = indices[input_index][0]
+    if indices[input_index][1] is None:
+        return key
+    return dsk.layers[key]._dnf_filter_expression(dsk)
+
+
+def _blockwise_comparison_dnf(op, indices: list, dsk: RegenerableGraph):
+    # Return DNF expression pattern for a simple comparison
+    left = _get_blockwise_input(0, indices, dsk)
+    right = _get_blockwise_input(1, indices, dsk)
+
+    def _inv(symbol: str):
+        return {">": "<", "<": ">", ">=": "<=", "<=": ">="}.get(symbol, symbol)
+
+    if is_arraylike(left) and hasattr(left, "item") and left.size == 1:
+        left = left.item()
+        # Need inverse comparison in read_parquet
+        return (right, _inv(_comparison_symbols[op]), left)
+    if is_arraylike(right) and hasattr(right, "item") and right.size == 1:
+        right = right.item()
+    return to_dnf((left, _comparison_symbols[op], right))
+
+
+def _blockwise_logical_dnf(op, indices: list, dsk: RegenerableGraph):
+    # Return DNF expression pattern for logical "and" or "or"
+    left = _get_blockwise_input(0, indices, dsk)
+    right = _get_blockwise_input(1, indices, dsk)
+    if op == operator.or_:
+        return to_dnf(Or([left, right]))
+    elif op == operator.and_:
+        return to_dnf(And([left, right]))
+    else:
+        raise ValueError(f"Unsupported logical operator: {op}")
+
+
+def _blockwise_getitem_dnf(op, indices: list, dsk: RegenerableGraph):
+    # Return dnf of key (selected by getitem)
+    key = _get_blockwise_input(1, indices, dsk)
+    return key
+
+
+def _blockwise_fillna_dnf(op, indices: list, dsk: RegenerableGraph):
+    # Return dnf of input collection
+    return _get_blockwise_input(0, indices, dsk)
diff --git a/dask_sql/sql-schema.yaml b/dask_sql/sql-schema.yaml
index 06c766854..f65e4d344 100644
--- a/dask_sql/sql-schema.yaml
+++ b/dask_sql/sql-schema.yaml
@@ -26,3 +26,8 @@ properties:
             type: boolean
             description: |
               Whether sql identifiers are considered case sensitive while parsing.
+
+      predicate_pushdown:
+        type: boolean
+        description: |
+          Whether to try pushing down filter predicates into IO (when possible).
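Note: the new `sql.predicate_pushdown` option above is read through `dask.config` when a filter is planned (see `filter_or_scalar`), so it can be toggled like any other Dask setting. A minimal sketch of how a user might exercise it, assuming a dask-sql `Context` named `c` with a parquet-backed table already registered (the table name `t` is illustrative, not part of the patch):

    import dask.config
    from dask.utils_test import hlg_layer

    # Disable the optimization globally (it defaults to True, per sql.yaml below)...
    dask.config.set({"sql.predicate_pushdown": False})

    # ...or scope the setting to a single query
    with dask.config.set({"sql.predicate_pushdown": True}):
        result = c.sql("SELECT * FROM t WHERE b < 10")

    # When pushdown succeeds, the WHERE clause shows up as `filters` on the
    # regenerated read-parquet layer (this is what test_predicate_pushdown
    # below asserts):
    print(hlg_layer(result.dask, "read-parquet").creation_info["kwargs"]["filters"])
    # e.g. [[("b", "<", 10)]]

Because the option is checked at planning time, toggling it only affects queries planned after the change.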
diff --git a/dask_sql/sql.yaml b/dask_sql/sql.yaml
index 1976e72c3..72f28c271 100644
--- a/dask_sql/sql.yaml
+++ b/dask_sql/sql.yaml
@@ -5,3 +5,5 @@ sql:
 
   identifier:
     case_sensitive: True
+
+  predicate_pushdown: True
diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py
index 056c681e0..75b98a9f7 100644
--- a/tests/integration/fixtures.py
+++ b/tests/integration/fixtures.py
@@ -114,6 +114,30 @@ def datetime_table():
     )
 
 
+@pytest.fixture()
+def parquet_ddf(tmpdir):
+
+    # Write simple parquet dataset
+    df = pd.DataFrame(
+        {
+            "a": [1, 2, 3] * 5,
+            "b": range(15),
+            "c": ["A"] * 15,
+            "d": [
+                pd.Timestamp("2013-08-01 23:00:00"),
+                pd.Timestamp("2014-09-01 23:00:00"),
+                pd.Timestamp("2015-10-01 23:00:00"),
+            ]
+            * 5,
+            "index": range(15),
+        },
+    )
+    dd.from_pandas(df, npartitions=3).to_parquet(os.path.join(tmpdir, "parquet"))
+
+    # Read back with dask
+    return dd.read_parquet(os.path.join(tmpdir, "parquet"), index="index")
+
+
 @pytest.fixture()
 def gpu_user_table_1(user_table_1):
     return cudf.from_pandas(user_table_1) if cudf else None
@@ -151,6 +175,7 @@ def c(
     user_table_nan,
     string_table,
     datetime_table,
+    parquet_ddf,
     gpu_user_table_1,
     gpu_df,
     gpu_long_table,
@@ -168,6 +193,7 @@ def c(
         "user_table_nan": user_table_nan,
         "string_table": string_table,
         "datetime_table": datetime_table,
+        "parquet_ddf": parquet_ddf,
        "gpu_user_table_1": gpu_user_table_1,
         "gpu_df": gpu_df,
         "gpu_long_table": gpu_long_table,
@@ -182,7 +208,11 @@ def c(
     for df_name, df in dfs.items():
         if df is None:
             continue
-        dask_df = dd.from_pandas(df, npartitions=3)
+        if hasattr(df, "npartitions"):
+            # df is already a dask collection
+            dask_df = df
+        else:
+            dask_df = dd.from_pandas(df, npartitions=3)
         c.create_table(df_name, dask_df)
 
     yield c
diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py
index ad98d4416..345b9d9e1 100644
--- a/tests/integration/test_filter.py
+++ b/tests/integration/test_filter.py
@@ -1,6 +1,7 @@
 import dask.dataframe as dd
 import pandas as pd
 import pytest
+from dask.utils_test import hlg_layer
 from pandas.testing import assert_frame_equal
 
 from dask_sql._compat import INT_NAN_IMPLEMENTED
@@ -122,3 +123,97 @@ def test_filter_year(c):
     expected_df = df[df["year"] < 2016]
 
     assert_frame_equal(expected_df, actual_df)
+
+
+@pytest.mark.parametrize(
+    "query,df_func,filters",
+    [
+        (
+            "SELECT * FROM parquet_ddf WHERE b < 10",
+            lambda x: x[x["b"] < 10],
+            [[("b", "<", 10)]],
+        ),
+        (
+            "SELECT * FROM parquet_ddf WHERE a < 3 AND (b > 1 AND b < 5)",
+            lambda x: x[(x["a"] < 3) & ((x["b"] > 1) & (x["b"] < 5))],
+            [[("a", "<", 3), ("b", ">", 1), ("b", "<", 5)]],
+        ),
+        (
+            "SELECT * FROM parquet_ddf WHERE (b > 5 AND b < 10) OR a = 1",
+            lambda x: x[((x["b"] > 5) & (x["b"] < 10)) | (x["a"] == 1)],
+            [[("a", "==", 1)], [("b", "<", 10), ("b", ">", 5)]],
+        ),
+        (
+            "SELECT * FROM parquet_ddf WHERE b IN (1, 6)",
+            lambda x: x[(x["b"] == 1) | (x["b"] == 6)],
+            [[("b", "<=", 1), ("b", ">=", 1)], [("b", "<=", 6), ("b", ">=", 6)]],
+        ),
+        (
+            "SELECT a FROM parquet_ddf WHERE (b > 5 AND b < 10) OR a = 1",
+            lambda x: x[((x["b"] > 5) & (x["b"] < 10)) | (x["a"] == 1)][["a"]],
+            [[("a", "==", 1)], [("b", "<", 10), ("b", ">", 5)]],
+        ),
+        (
+            # Original filters NOT in disjunctive normal form
+            "SELECT a FROM parquet_ddf WHERE (parquet_ddf.b > 3 AND parquet_ddf.b < 10 OR parquet_ddf.a = 1) AND (parquet_ddf.c = 'A')",
+            lambda x: x[
+                ((x["b"] > 3) & (x["b"] < 10) | (x["a"] == 1)) & (x["c"] == "A")
+            ][["a"]],
+            [
+                [("c", "==", "A"), ("b", ">", 3), ("b", "<", 10)],
+                [("a", "==", 1), ("c", "==", "A")],
+            ],
+        ),
+        (
+            # The predicate-pushdown optimization will be skipped here,
+            # because datetime accessors are not supported. However,
+            # the query should still succeed.
+            "SELECT * FROM parquet_ddf WHERE year(d) < 2015",
+            lambda x: x[x["d"].dt.year < 2015],
+            None,
+        ),
+    ],
+)
+def test_predicate_pushdown(c, parquet_ddf, query, df_func, filters):
+
+    # Check for predicate pushdown.
+    # We can use the `hlg_layer` utility to make sure the
+    # `filters` field has been populated in `creation_info`
+    return_df = c.sql(query)
+    expect_filters = filters
+    got_filters = hlg_layer(return_df.dask, "read-parquet").creation_info["kwargs"][
+        "filters"
+    ]
+    if expect_filters:
+        got_filters = frozenset(frozenset(v) for v in got_filters)
+        expect_filters = frozenset(frozenset(v) for v in filters)
+    assert got_filters == expect_filters
+
+    # Check computed result is correct
+    df = parquet_ddf.compute()
+    expected_df = df_func(df)
+    dd.assert_eq(return_df, expected_df)
+
+
+def test_filtered_csv(tmpdir, c):
+    # Predicate pushdown is NOT supported for CSV data.
+    # This test just checks that the "attempted"
+    # predicate-pushdown logic does not lead to
+    # any unexpected errors
+
+    # Write simple csv dataset
+    df = pd.DataFrame({"a": [1, 2, 3] * 5, "b": range(15), "c": ["A"] * 15})
+    dd.from_pandas(df, npartitions=3).to_csv(tmpdir + "/*.csv", index=False)
+
+    # Read back with dask and apply WHERE query
+    csv_ddf = dd.read_csv(tmpdir + "/*.csv")
+    try:
+        c.create_table("my_csv_table", csv_ddf)
+        return_df = c.sql("SELECT * FROM my_csv_table WHERE b < 10")
+    finally:
+        c.drop_table("my_csv_table")
+
+    # Check computed result is correct
+    df = csv_ddf.compute()
+    expected_df = df[df["b"] < 10]
+    dd.assert_eq(return_df, expected_df)
diff --git a/tests/integration/test_show.py b/tests/integration/test_show.py
index a04129489..41e315a95 100644
--- a/tests/integration/test_show.py
+++ b/tests/integration/test_show.py
@@ -43,6 +43,7 @@ def test_tables(c):
                 "user_table_nan",
                 "string_table",
                 "datetime_table",
+                "parquet_ddf",
             ]
             if cudf is None
             else [
@@ -56,6 +57,7 @@ def test_tables(c):
                 "user_table_nan",
                 "string_table",
                 "datetime_table",
+                "parquet_ddf",
                 "gpu_user_table_1",
                 "gpu_df",
                 "gpu_long_table",