diff --git a/dask_sql/physical/utils/filter.py b/dask_sql/physical/utils/filter.py
index 5309289c4..7098b31b5 100644
--- a/dask_sql/physical/utils/filter.py
+++ b/dask_sql/physical/utils/filter.py
@@ -1,6 +1,7 @@
 import itertools
 import logging
 import operator
+from typing import List, Optional
 
 import dask.dataframe as dd
 import numpy as np
@@ -12,7 +13,11 @@
 logger = logging.getLogger(__name__)
 
 
-def attempt_predicate_pushdown(ddf: dd.DataFrame) -> dd.DataFrame:
+def attempt_predicate_pushdown(
+    ddf: dd.DataFrame,
+    conjunctive_filters: Optional[List[tuple]] = None,
+    disjunctive_filters: Optional[List[tuple]] = None,
+) -> dd.DataFrame:
     """Use graph information to update IO-level filters
 
     The original `ddf` will be returned if/when the
@@ -24,6 +29,9 @@ def attempt_predicate_pushdown(ddf: dd.DataFrame) -> dd.DataFrame:
     is due to the fact that `npartitions` and `divisions`
     may change when this optimization is applied (invalidating
     npartition/divisions-specific logic in following Layers).
+
+    Additionally applies any provided conjunctive and disjunctive
+    filters to the pushed-down filter set.
     """
 
     # Check that we have a supported `ddf` object
@@ -87,6 +95,11 @@ def attempt_predicate_pushdown(ddf: dd.DataFrame) -> dd.DataFrame:
         )
         return ddf
 
+    # Each provided disjunctive filter becomes its own OR branch
+    filters.extend([f] for f in disjunctive_filters or [])
+    # AND the provided conjunctive filters onto every branch
+    for f in filters:
+        f.extend(conjunctive_filters or [])
     # Regenerate collection with filtered IO layer
     try:
         return dsk.layers[name]._regenerate_collection(
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 41f918558..e705be2bd 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -1,7 +1,9 @@
 import pandas as pd
 import pytest
 from dask import dataframe as dd
+from dask.utils_test import hlg_layer
 
+from dask_sql.physical.utils.filter import attempt_predicate_pushdown
 from dask_sql.utils import Pluggable, is_frame
 
 
@@ -52,3 +54,96 @@ def test_overwrite():
 
     assert PluginTest1.get_plugin("some_key") == "value_2"
     assert PluginTest1().get_plugin("some_key") == "value_2"
+
+
+def test_predicate_pushdown(parquet_ddf):
+    filtered_df = parquet_ddf[parquet_ddf["a"] > 1]
+    pushdown_df = attempt_predicate_pushdown(filtered_df)
+    got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][
+        "filters"
+    ]
+    got_filters = frozenset(frozenset(v) for v in got_filters)
+    expected_filters = [[("a", ">", 1)]]
+    expected_filters = frozenset(frozenset(v) for v in expected_filters)
+    assert got_filters == expected_filters
+
+    filtered_df = parquet_ddf[
+        (parquet_ddf["a"] > 1) & (parquet_ddf["b"] < 2) | (parquet_ddf["a"] == -1)
+    ]
+
+    pushdown_df = attempt_predicate_pushdown(filtered_df)
+    got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][
+        "filters"
+    ]
+    got_filters = frozenset(frozenset(v) for v in got_filters)
+    expected_filters = [[("a", ">", 1), ("b", "<", 2)], [("a", "==", -1)]]
+    expected_filters = frozenset(frozenset(v) for v in expected_filters)
+    assert got_filters == expected_filters
+
+    disjunctive_filters = [("c", "in", ("A", "B", "C"))]
+    pushdown_df = attempt_predicate_pushdown(
+        filtered_df, disjunctive_filters=disjunctive_filters
+    )
+    got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][
+        "filters"
+    ]
+    got_filters = frozenset(frozenset(v) for v in got_filters)
+    expected_filters = [
+        [("b", "<", 2), ("a", ">", 1)],
+        [("a", "==", -1)],
+        [("c", "in", ("A", "B", "C"))],
+    ]
+    expected_filters = frozenset(frozenset(v) for v in expected_filters)
+    assert got_filters == expected_filters
+
+    disjunctive_filters = [("c", "in", ("A", "B", "C")), ("b", "in", (5, 6, 7))]
+    pushdown_df = attempt_predicate_pushdown(
+        filtered_df, disjunctive_filters=disjunctive_filters
+    )
+    got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][
+        "filters"
+    ]
+    got_filters = frozenset(frozenset(v) for v in got_filters)
+    expected_filters = [
+        [("b", "<", 2), ("a", ">", 1)],
+        [("a", "==", -1)],
+        [("c", "in", ("A", "B", "C"))],
+        [("b", "in", (5, 6, 7))],
+    ]
+    expected_filters = frozenset(frozenset(v) for v in expected_filters)
+    assert got_filters == expected_filters
+
+    conjunctive_filters = [("c", "in", ("A", "B", "C"))]
+    pushdown_df = attempt_predicate_pushdown(
+        filtered_df, conjunctive_filters=conjunctive_filters
+    )
+    got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][
+        "filters"
+    ]
+    got_filters = frozenset(frozenset(v) for v in got_filters)
+    expected_filters = [
+        [("b", "<", 2), ("a", ">", 1), ("c", "in", ("A", "B", "C"))],
+        [("a", "==", -1), ("c", "in", ("A", "B", "C"))],
+    ]
+    expected_filters = frozenset(frozenset(v) for v in expected_filters)
+    assert got_filters == expected_filters
+
+    conjunctive_filters = [("c", "in", ("A", "B", "C")), ("a", "<=", 100)]
+    disjunctive_filters = [("b", "in", (5, 6, 7)), ("a", ">=", 100)]
+    pushdown_df = attempt_predicate_pushdown(
+        filtered_df,
+        conjunctive_filters=conjunctive_filters,
+        disjunctive_filters=disjunctive_filters,
+    )
+    got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][
+        "filters"
+    ]
+    got_filters = frozenset(frozenset(v) for v in got_filters)
+    expected_filters = [
+        [("b", "<", 2), ("a", ">", 1), ("c", "in", ("A", "B", "C")), ("a", "<=", 100)],
+        [("a", "==", -1), ("c", "in", ("A", "B", "C")), ("a", "<=", 100)],
+        [("b", "in", (5, 6, 7)), ("c", "in", ("A", "B", "C")), ("a", "<=", 100)],
+        [("a", ">=", 100), ("c", "in", ("A", "B", "C")), ("a", "<=", 100)],
+    ]
+    expected_filters = frozenset(frozenset(v) for v in expected_filters)
+    assert got_filters == expected_filters
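
For reference, the filter-combination logic this diff adds to `attempt_predicate_pushdown` is small enough to illustrate standalone. The sketch below is hypothetical (the `combine_filters` helper is not part of this PR); it assumes the pyarrow-style disjunctive-normal-form convention used above, where a filter set is a list of AND-conjunctions that are OR'd together:

```python
from typing import List, Optional, Tuple

# A single predicate, e.g. ("a", ">", 1); the third element may be any value.
Predicate = Tuple[str, str, object]


def combine_filters(
    filters: List[List[Predicate]],
    conjunctive_filters: Optional[List[Predicate]] = None,
    disjunctive_filters: Optional[List[Predicate]] = None,
) -> List[List[Predicate]]:
    """Illustrative mirror of the DNF expansion added in this diff."""
    # Copy the existing OR branches so the caller's list is not mutated.
    combined = [list(conjunction) for conjunction in filters]
    # Each disjunctive filter becomes its own single-predicate OR branch.
    combined.extend([f] for f in disjunctive_filters or [])
    # AND every conjunctive filter onto every OR branch, old and new alike.
    for conjunction in combined:
        conjunction.extend(conjunctive_filters or [])
    return combined


# Reproduces the expected_filters of the final test case above:
print(
    combine_filters(
        [[("b", "<", 2), ("a", ">", 1)], [("a", "==", -1)]],
        conjunctive_filters=[("c", "in", ("A", "B", "C")), ("a", "<=", 100)],
        disjunctive_filters=[("b", "in", (5, 6, 7)), ("a", ">=", 100)],
    )
)
```

Note that the disjunctive branches are appended before the conjunctive predicates are AND'ed on, which is why the last test case expects `("c", "in", ("A", "B", "C"))` and `("a", "<=", 100)` in every conjunction, including the two branches introduced by `disjunctive_filters`.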