Merged
Changes from all commits (25 commits)
1667eaf
add basic isin workaround
rjzamora May 26, 2023
46e9b97
add case - but blocked by upstream-dask bug (I think)
rjzamora May 26, 2023
c45da57
link dask PR
rjzamora May 26, 2023
120994b
Merge branch 'main' into isin-filtering
rjzamora May 30, 2023
16db287
Merge branch 'main' into isin-filtering
rjzamora Jun 1, 2023
ba0189f
update test to skip for released dask versions
rjzamora Jun 1, 2023
9f80b62
add isna support
rjzamora Jun 1, 2023
4253c9a
use from_pandas instead of from_dict
rjzamora Jun 2, 2023
b8da938
use from_pandas instead of from_dict (second attempt)
rjzamora Jun 2, 2023
a270b17
add more dask-version checks
rjzamora Jun 2, 2023
a1afe52
use None instead of np.nan
rjzamora Jun 2, 2023
ac40f24
bail if existing filters, and move version check
rjzamora Jun 5, 2023
ef28899
remove check for pre-existing filters (CI experiment)
rjzamora Jun 5, 2023
79162c4
revert last commit
rjzamora Jun 5, 2023
df52f6d
minor _blockwise_comparison_dnf fix
rjzamora Jun 5, 2023
822a33f
Merge branch 'main' into isin-filtering
rjzamora Jun 6, 2023
2099961
drop non-tuples
rjzamora Jun 6, 2023
dc77c4c
Merge remote-tracking branch 'upstream/main' into isin-filtering
rjzamora Jun 6, 2023
42bcb77
Merge branch 'isin-filtering' of https://github.com/rjzamora/dask-sql…
rjzamora Jun 6, 2023
5b71359
bail for logical operation containing literal boolean operand
rjzamora Jun 6, 2023
58f9a69
bail on TypeError in addition to ValueError
rjzamora Jun 6, 2023
aca5d48
Merge branch 'main' into isin-filtering
rjzamora Jun 6, 2023
b7cb68d
add tests to test_utils
rjzamora Jun 7, 2023
7b223df
Merge remote-tracking branch 'upstream/main' into isin-filtering
rjzamora Jun 7, 2023
5193718
Merge branch 'isin-filtering' of https://github.com/rjzamora/dask-sql…
rjzamora Jun 7, 2023
4 changes: 4 additions & 0 deletions dask_sql/_compat.py
@@ -14,3 +14,7 @@

# TODO: remove when dask min version gets bumped
BROADCAST_JOIN_SUPPORT_WORKING = _dask_version > parseVersion("2023.1.0")

# Parquet predicate-support version checks
PQ_NOT_IN_SUPPORT = parseVersion(dask.__version__) > parseVersion("2023.5.1")
PQ_IS_SUPPORT = parseVersion(dask.__version__) >= parseVersion("2023.3.1")
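These two flags gate which parquet predicate operators the optimizer may emit on the installed dask version. A minimal sketch of the intended gating pattern (the `require_not_in` helper is hypothetical, for illustration only):

```python
import dask
from packaging.version import parse as parseVersion

# Mirrors the flags above: "not in" needs dask > 2023.5.1,
# "is"/"is not" needs dask >= 2023.3.1.
PQ_NOT_IN_SUPPORT = parseVersion(dask.__version__) > parseVersion("2023.5.1")

def require_not_in():
    # Hypothetical guard: callers bail out (keeping the unfiltered
    # graph) instead of emitting a predicate that the installed dask
    # cannot handle in read_parquet.
    if not PQ_NOT_IN_SUPPORT:
        raise ValueError("This version of dask does not support 'not in'")
```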
160 changes: 135 additions & 25 deletions dask_sql/physical/utils/filter.py
@@ -5,10 +5,12 @@
import dask.dataframe as dd
import numpy as np
from dask.blockwise import Blockwise
from dask.highlevelgraph import HighLevelGraph
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
from dask.layers import DataFrameIOLayer
from dask.utils import M, apply, is_arraylike

from dask_sql._compat import PQ_IS_SUPPORT, PQ_NOT_IN_SUPPORT

logger = logging.getLogger(__name__)


@@ -59,6 +61,13 @@ def attempt_predicate_pushdown(ddf: dd.DataFrame) -> dd.DataFrame:
return ddf
io_layer = io_layer.pop()

# Bail if any filters are already present in ddf
existing_filters = (
ddf.dask.layers[io_layer].creation_info.get("kwargs", {}).get("filters")
)
if existing_filters:
return ddf

# Start by converting the HLG to a `RegenerableGraph`.
# Succeeding here means that all layers in the graph
# are regenerable.
@@ -79,7 +88,7 @@ def attempt_predicate_pushdown(ddf: dd.DataFrame) -> dd.DataFrame:
# No filters encountered
return ddf
filters = filters.to_list_tuple()
except ValueError:
except (ValueError, TypeError):
# DNF dispatching failed for 1+ layers
logger.warning(
"Predicate pushdown optimization skipped. One or more "
@@ -142,13 +151,16 @@ def to_dnf(expr):

# Credit: https://stackoverflow.com/a/58372345
if not isinstance(expr, (Or, And)):
if not isinstance(expr, tuple):
raise TypeError(f"expected tuple, got {expr}")
result = Or((And((expr,)),))
elif isinstance(expr, Or):
result = Or(se for e in expr for se in to_dnf(e))
elif isinstance(expr, And):
total = []
for c in itertools.product(*[to_dnf(e) for e in expr]):
total.append(And(se for e in c for se in e))
conjunction = [se for e in c for se in e if isinstance(se, tuple)]
total.append(And(conjunction))
result = Or(total)
return result
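A worked example of the distribution `to_dnf` performs, as a sketch assuming the set-like `Or`/`And` helpers and `to_dnf` above are importable from `dask_sql.physical.utils.filter` (element order may vary, since the helpers are frozenset-based):

```python
from dask_sql.physical.utils.filter import And, Or, to_dnf

# (x > 1) & ((y == 2) | (y == 3)) distributes to
# ((x > 1) & (y == 2)) | ((x > 1) & (y == 3))
expr = And([("x", ">", 1), Or([("y", "==", 2), ("y", "==", 3)])])
dnf = to_dnf(expr)

# to_list_tuple() yields the list-of-lists "filters" format that
# dd.read_parquet expects: outer list = OR, inner lists = AND.
print(dnf.to_list_tuple())
# e.g. [[('x', '>', 1), ('y', '==', 2)], [('x', '>', 1), ('y', '==', 3)]]
```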

@@ -170,20 +182,65 @@ def to_dnf(expr):
np.not_equal: "!=",
}

# Define all regenerable "pass-through" ops
# that do not affect filters.
_pass_through_ops = {M.fillna, M.astype}

# Define set of all "regenerable" operations.
# Predicate pushdown is supported for graphs
# comprised of `Blockwise` layers based on these
# operations
_regenerable_ops = set(_comparison_symbols.keys()) | {
operator.and_,
operator.or_,
operator.getitem,
M.fillna,
}
_regenerable_ops = (
set(_comparison_symbols.keys())
| {
operator.and_,
operator.or_,
operator.getitem,
operator.inv,
M.isin,
M.isna,
}
| _pass_through_ops
)

# Specify functions that must be generated with
# a different API at the dataframe-collection level
_special_op_mappings = {M.fillna: dd._Frame.fillna}
_special_op_mappings = {
M.fillna: dd._Frame.fillna,
M.isin: dd._Frame.isin,
M.isna: dd._Frame.isna,
M.astype: dd._Frame.astype,
}

# Convert _pass_through_ops to respect "special" mappings
_pass_through_ops = {_special_op_mappings.get(op, op) for op in _pass_through_ops}


def _preprocess_layers(input_layers):
# NOTE: This is a Layer-specific work-around to deal with
# the fact that `dd._Frame.isin(values)` will add a distinct
# `MaterializedLayer` for the `values` argument.
# See: https://github.com/dask-contrib/dask-sql/issues/607
skip = set()
layers = input_layers.copy()
for key, layer in layers.items():
if key.startswith("isin-") and isinstance(layer, Blockwise):
indices = list(layer.indices)
for i, (k, ind) in enumerate(layer.indices):
if (
ind is None
and isinstance(layers.get(k), MaterializedLayer)
and isinstance(layers[k].get(k), (np.ndarray, tuple))
):
# Replace `indices[i]` with a literal value and
# make sure we skip the `MaterializedLayer` that
# we are now fusing into the `isin`
value = layers[k][k]
value = value[0](*value[1:]) if callable(value[0]) else value
indices[i] = (value, None)
skip.add(k)
layer.indices = tuple(indices)
return {k: v for k, v in layers.items() if k not in skip}
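For context, a sketch of the graph shape this handles, on dask versions contemporary with this PR (layer names are token-suffixed and may differ):

```python
import pandas as pd
import dask.dataframe as dd

ddf = dd.from_pandas(pd.DataFrame({"b": range(6)}), npartitions=2)
mask = ddf["b"].isin([1, 3])

# The graph holds a Blockwise "isin-*" layer plus a separate
# MaterializedLayer carrying the lookup values; _preprocess_layers
# fuses those values back into the isin layer's indices so the whole
# graph remains regenerable.
print(sorted(mask.dask.layers))
```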


class RegenerableLayer:
@@ -261,8 +318,14 @@ def _dnf_filter_expression(self, dsk):
func = _blockwise_logical_dnf
elif op == operator.getitem:
func = _blockwise_getitem_dnf
elif op == dd._Frame.fillna:
func = _blockwise_fillna_dnf
elif op == dd._Frame.isin:
func = _blockwise_isin_dnf
elif op == dd._Frame.isna:
func = _blockwise_isna_dnf
elif op == operator.inv:
func = _blockwise_inv_dnf
elif op in _pass_through_ops:
func = _blockwise_pass_through_dnf
else:
raise ValueError(f"No DNF expression for {op}")

@@ -288,7 +351,7 @@ def from_hlg(cls, hlg: HighLevelGraph):
raise TypeError(f"Expected HighLevelGraph, got {type(hlg)}")

_layers = {}
for key, layer in hlg.layers.items():
for key, layer in _preprocess_layers(hlg.layers).items():
regenerable_layer = None
if isinstance(layer, DataFrameIOLayer):
regenerable_layer = RegenerableLayer(layer, layer.creation_info or {})
@@ -335,23 +398,30 @@ def _get_blockwise_input(input_index, indices: list, dsk: RegenerableGraph):
return dsk.layers[key]._dnf_filter_expression(dsk)


def _inv(symbol: str):
if symbol == "in" and not PQ_NOT_IN_SUPPORT:
raise ValueError("This version of dask does not support 'not in'")
return {
">": "<",
"<": ">",
">=": "<=",
"<=": ">=",
"in": "not in",
"not in": "in",
"is": "is not",
"is not": "is",
}.get(symbol, symbol)


def _blockwise_comparison_dnf(op, indices: list, dsk: RegenerableGraph):
# Return DNF expression pattern for a simple comparison
left = _get_blockwise_input(0, indices, dsk)
right = _get_blockwise_input(1, indices, dsk)

def _inv(symbol: str):
return {
">": "<",
"<": ">",
">=": "<=",
"<=": ">=",
}.get(symbol, symbol)

if is_arraylike(left) and hasattr(left, "item") and left.size == 1:
left = left.item()
# Need inverse comparison in read_parquet
return (right, _inv(_comparison_symbols[op]), left)
return to_dnf((right, _inv(_comparison_symbols[op]), left))
if is_arraylike(right) and hasattr(right, "item") and right.size == 1:
right = right.item()
return to_dnf((left, _comparison_symbols[op], right))
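A quick sanity check on why the symbol is mirrored rather than negated when the literal lands on the left-hand side (a sketch; `5 > s` dispatches to `np.greater(5, s)`, with the column as the right operand):

```python
import pandas as pd

s = pd.Series([3, 7])

# Parquet filters always name the column first, so `5 > x` becomes
# ("x", "<", 5): the same comparison read in the other direction,
# not its logical negation.
assert ((5 > s) == (s < 5)).all()
```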
@@ -361,10 +431,17 @@ def _blockwise_logical_dnf(op, indices: list, dsk: RegenerableGraph):
# Return DNF expression pattern for logical "and" or "or"
left = _get_blockwise_input(0, indices, dsk)
right = _get_blockwise_input(1, indices, dsk)

vals = []
for val in [left, right]:
if not isinstance(val, (tuple, Or, And)):
raise TypeError(f"Invalid logical operand: {val}")
vals.append(to_dnf(val))

if op == operator.or_:
return to_dnf(Or([left, right]))
return to_dnf(Or(vals))
elif op == operator.and_:
return to_dnf(And([left, right]))
return to_dnf(And(vals))
else:
raise ValueError

@@ -375,6 +452,39 @@ def _blockwise_getitem_dnf(op, indices: list, dsk: RegenerableGraph):
return key


def _blockwise_fillna_dnf(op, indices: list, dsk: RegenerableGraph):
def _blockwise_pass_through_dnf(op, indices: list, dsk: RegenerableGraph):
# Return dnf of input collection
return _get_blockwise_input(0, indices, dsk)


def _blockwise_isin_dnf(op, indices: list, dsk: RegenerableGraph):
# Return DNF expression pattern for a simple "in" comparison
left = _get_blockwise_input(0, indices, dsk)
right = _get_blockwise_input(1, indices, dsk)
return to_dnf((left, "in", tuple(right)))


def _blockwise_isna_dnf(op, indices: list, dsk: RegenerableGraph):
# Return DNF expression pattern for `isna`
if not PQ_IS_SUPPORT:
raise ValueError("This version of dask does not support 'is' predicates.")
left = _get_blockwise_input(0, indices, dsk)
return to_dnf((left, "is", None))


def _blockwise_inv_dnf(op, indices: list, dsk: RegenerableGraph):
# Return DNF expression pattern for the inverse of a comparison
expr = _get_blockwise_input(0, indices, dsk).to_list_tuple()
new_expr = []
count = 0
for conjunction in expr:
new_conjunction = []
for col, op, val in conjunction:
count += 1
new_conjunction.append((col, _inv(op), val))
new_expr.append(And(new_conjunction))
if count > 1:
# Haven't taken the time to think through
# general inversion yet.
raise ValueError("inv(DNF) case not implemented.")
return to_dnf(Or(new_expr))
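Taken together, these handlers emit predicates in the list-of-lists form accepted by `dd.read_parquet`. A sketch of the equivalent hand-written call (the dataset path is hypothetical, and the `not in`/`is not` operators require the dask versions gated by the compat flags above):

```python
import dask.dataframe as dd

# Outer list = OR of conjunctions; inner list = AND of predicates.
ddf = dd.read_parquet(
    "data/table/",  # hypothetical dataset path
    filters=[[("b", "not in", (1, 3, 5, 6)), ("a", "is not", None)]],
)
```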
77 changes: 73 additions & 4 deletions tests/integration/test_filter.py
@@ -5,6 +5,7 @@
from dask.utils_test import hlg_layer
from packaging.version import parse as parseVersion

from dask_sql._compat import PQ_IS_SUPPORT, PQ_NOT_IN_SUPPORT
from tests.utils import assert_eq

DASK_GT_2022_4_2 = parseVersion(dask.__version__) >= parseVersion("2022.4.2")
@@ -162,10 +163,26 @@ def test_filter_year(c):
),
pytest.param(
"SELECT * FROM parquet_ddf WHERE b IN (1, 3, 5, 6)",
lambda x: x[(x["b"] == 1) | (x["b"] == 3) | (x["b"] == 5) | (x["b"] == 6)],
[[("b", "==", 1)], [("b", "==", 3)], [("b", "==", 5)], [("b", "==", 6)]],
marks=pytest.mark.xfail(
reason="WIP https://github.com/dask-contrib/dask-sql/issues/607"
lambda x: x[x["b"].isin([1, 3, 5, 6])],
[[("b", "in", (1, 3, 5, 6))]],
),
pytest.param(
"SELECT * FROM parquet_ddf WHERE c IN ('A', 'B', 'C', 'D')",
lambda x: x[x["c"].isin(["A", "B", "C", "D"])],
[[("c", "in", ("A", "B", "C", "D"))]],
),
pytest.param(
"SELECT * FROM parquet_ddf WHERE b NOT IN (1, 6)",
lambda x: x[(x["b"] != 1) & (x["b"] != 6)],
[[("b", "!=", 1), ("b", "!=", 6)]],
),
pytest.param(
"SELECT * FROM parquet_ddf WHERE b NOT IN (1, 3, 5, 6)",
lambda x: x[~x["b"].isin([1, 3, 5, 6])],
[[("b", "not in", (1, 3, 5, 6))]],
marks=pytest.mark.skipif(
not PQ_NOT_IN_SUPPORT,
reason="Requires https://github.com/dask/dask/pull/10320",
),
),
(
@@ -296,3 +313,55 @@ def test_filter_decimal(c, gpu):

assert_eq(result_df, expected_df, check_index=False)
c.drop_table("df")


@pytest.mark.skipif(
not PQ_IS_SUPPORT,
reason="Requires https://github.com/dask/dask/pull/10320",
)
def test_predicate_pushdown_isna(tmpdir):
from dask_sql.context import Context

c = Context()

path = str(tmpdir)
dd.from_pandas(
pd.DataFrame(
{
"a": [1, 2, None] * 5,
"b": range(15),
"index": range(15),
}
),
npartitions=3,
).to_parquet(path + "/df1")
df1 = dd.read_parquet(path + "/df1", index="index")
c.create_table("df1", df1)

dd.from_pandas(
pd.DataFrame(
{
"a": [None, 2, 3] * 5,
"b": range(15),
"index": range(15),
},
),
npartitions=3,
).to_parquet(path + "/df2")
df2 = dd.read_parquet(path + "/df2", index="index")
c.create_table("df2", df2)

return_df = c.sql("SELECT df1.a FROM df1, df2 WHERE df1.a = df2.a")

# Check for predicate pushdown
filters = [[("a", "is not", None)]]
got_filters = hlg_layer(return_df.dask, "read-parquet").creation_info["kwargs"][
"filters"
]

got_filters = frozenset(frozenset(v) for v in got_filters)
expect_filters = frozenset(frozenset(v) for v in filters)

assert got_filters == expect_filters
assert all(return_df.compute() == 2)
assert len(return_df) == 25