Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions dask_sql/physical/rel/logical/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,16 @@ def _split_join_condition(
filter_condition = []

for operand in operands:
try:
lhs_on_part, rhs_on_part = self._extract_lhs_rhs(operand)
lhs_on.append(lhs_on_part)
rhs_on.append(rhs_on_part)
continue
except AssertionError:
pass

filter_condition.append(operand)
if isinstance(operand, org.apache.calcite.rex.RexCall):
try:
lhs_on_part, rhs_on_part = self._extract_lhs_rhs(operand)
lhs_on.append(lhs_on_part)
rhs_on.append(rhs_on_part)
continue
except AssertionError:
pass

filter_condition.append(operand)

if lhs_on and rhs_on:
return lhs_on, rhs_on, filter_condition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ public DaskPlanner() {
addRule(CoreRules.JOIN_COMMUTE);
addRule(CoreRules.FILTER_INTO_JOIN);
addRule(CoreRules.PROJECT_JOIN_TRANSPOSE);
addRule(JoinPushThroughJoinRule.RIGHT);
addRule(JoinPushThroughJoinRule.LEFT);
addRule(CoreRules.SORT_PROJECT_TRANSPOSE);
addRule(CoreRules.SORT_JOIN_TRANSPOSE);
addRule(CoreRules.SORT_UNION_TRANSPOSE);
Expand Down
22 changes: 15 additions & 7 deletions planner/src/main/java/com/dask/sql/application/DaskProgram.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,21 @@ public RelNode run(RelNode rel, RelTraitSet relTraitSet) {
* FixedRulesProgram applies a fixed set of conversion rules, which we always apply.
*/
private static class FixedRulesProgram implements Program {
static private final List<RelOptRule> RULES = Arrays.asList(CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS,
CoreRules.AGGREGATE_ANY_PULL_UP_CONSTANTS, CoreRules.AGGREGATE_PROJECT_MERGE,
CoreRules.AGGREGATE_REDUCE_FUNCTIONS, CoreRules.AGGREGATE_MERGE,
CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, CoreRules.AGGREGATE_JOIN_REMOVE,
CoreRules.PROJECT_MERGE, CoreRules.FILTER_MERGE, CoreRules.PROJECT_REMOVE,
CoreRules.PROJECT_REDUCE_EXPRESSIONS, CoreRules.FILTER_REDUCE_EXPRESSIONS,
CoreRules.FILTER_EXPAND_IS_NOT_DISTINCT_FROM, CoreRules.PROJECT_TO_LOGICAL_PROJECT_AND_WINDOW);
static private final List<RelOptRule> RULES = Arrays.asList(
CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS,
CoreRules.AGGREGATE_ANY_PULL_UP_CONSTANTS,
CoreRules.AGGREGATE_PROJECT_MERGE,
CoreRules.AGGREGATE_REDUCE_FUNCTIONS,
CoreRules.AGGREGATE_MERGE,
CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN,
CoreRules.AGGREGATE_JOIN_REMOVE,
CoreRules.PROJECT_MERGE,
CoreRules.FILTER_MERGE,
CoreRules.PROJECT_REMOVE,
CoreRules.FILTER_REDUCE_EXPRESSIONS,
CoreRules.FILTER_EXPAND_IS_NOT_DISTINCT_FROM,
CoreRules.PROJECT_TO_LOGICAL_PROJECT_AND_WINDOW
);

@Override
public RelNode run(RelOptPlanner planner, RelNode rel, RelTraitSet requiredOutputTraits,
Expand Down
66 changes: 66 additions & 0 deletions tests/integration/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pandas as pd
from pandas.testing import assert_frame_equal

from dask_sql import Context


def test_join(c):
df = c.sql(
Expand Down Expand Up @@ -183,3 +185,67 @@ def test_join_literal(c):
df_expected = pd.DataFrame({"user_id": [], "b": [], "user_id0": [], "c": []})

assert_frame_equal(df.reset_index(), df_expected.reset_index(), check_dtype=False)


def test_conditional_join(c):
    """An inner join whose ON clause carries an extra (non-equi) filter
    condition should behave like an equi-join followed by that filter.

    The join filter ``b IS NOT NULL`` is emulated on the pandas side with
    ``Series.notnull()`` rather than ``!= None`` — the elementwise
    ``!= None`` comparison only works by accident on object dtype and
    needs a ``noqa: E711`` suppression.
    """
    df1 = pd.DataFrame({"a": [1, 2, 2, 5, 6], "b": ["w", "x", "y", "z", None]})
    df2 = pd.DataFrame({"c": [None, 3, 2, 5], "d": ["h", "i", "j", "k"]})

    # Equi-join on a == c, then drop rows where b is NULL; select "a" as a
    # one-column DataFrame so the comparison below is frame-to-frame.
    expected_df = pd.merge(df1, df2, how="inner", left_on=["a"], right_on=["c"])
    expected_df = expected_df[expected_df["b"].notnull()][["a"]]

    c.create_table("df1", df1)
    c.create_table("df2", df2)

    actual_df = c.sql(
        """
    SELECT a FROM df1
    INNER JOIN df2 ON
    (
        a = c
        AND b IS NOT NULL
    )
    """
    ).compute()

    # reset_index: the filtered expected frame keeps the merge's original
    # (non-contiguous) index, while the SQL result is freshly indexed.
    assert_frame_equal(
        actual_df.reset_index(), expected_df.reset_index(), check_dtype=False
    )


def test_join_case_projection_subquery():
    """Smoke test: a multi-join query with CASE projections inside nested
    subqueries must plan and execute without raising.

    No result assertion is made — the tables are empty and the point of the
    test is that the optimizer/executor handles this query shape at all.
    """
    c = Context()

    # Register empty input tables with only the columns the query touches.
    input_tables = {
        "demos": pd.DataFrame({"demo_sku": [], "hd_dep_count": []}),
        "site_page": pd.DataFrame({"site_page_sk": [], "site_char_count": []}),
        "sales": pd.DataFrame(
            {"sales_hdemo_sk": [], "sales_page_sk": [], "sold_time_sk": []}
        ),
        "t_dim": pd.DataFrame({"t_time_sk": [], "t_hour": []}),
    }
    for table_name, frame in input_tables.items():
        c.create_table(table_name, frame, persist=False)

    c.sql(
        """
    SELECT CASE WHEN pmc > 0.0 THEN CAST (amc AS DOUBLE) / CAST (pmc AS DOUBLE) ELSE -1.0 END AS am_pm_ratio
    FROM
    (
        SELECT SUM(amc1) AS amc, SUM(pmc1) AS pmc
        FROM
        (
            SELECT
                CASE WHEN t_hour BETWEEN 7 AND 8 THEN COUNT(1) ELSE 0 END AS amc1,
                CASE WHEN t_hour BETWEEN 19 AND 20 THEN COUNT(1) ELSE 0 END AS pmc1
            FROM sales ws
            JOIN demos hd ON (hd.demo_sku = ws.sales_hdemo_sk and hd.hd_dep_count = 5)
            JOIN site_page sp ON (sp.site_page_sk = ws.sales_page_sk and sp.site_char_count BETWEEN 5000 AND 6000)
            JOIN t_dim td ON (td.t_time_sk = ws.sold_time_sk and td.t_hour IN (7,8,19,20))
            GROUP BY t_hour
        ) cnt_am_pm
    ) sum_am_pm
    """
    ).compute()
16 changes: 16 additions & 0 deletions tests/integration/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,19 @@ def test_timezones(c, datetime_table):
result_df = result_df.compute()

assert_frame_equal(result_df, datetime_table)


def test_multi_case_when(c):
    """A CASE WHEN projection using BETWEEN must map each row to 1 inside
    the range and 0 outside it."""
    # Values straddling the [6, 8] window: 1 and 9 fall outside it.
    source_df = pd.DataFrame({"a": [1, 6, 7, 8, 9]})
    c.create_table("df", source_df)

    return_df = c.sql(
        """
    SELECT
        CASE WHEN a BETWEEN 6 AND 8 THEN 1 ELSE 0 END AS C
    FROM df
    """
    ).compute()

    # int32: the CASE branches are integer literals, so the engine yields
    # a 32-bit integer column.
    expected_df = pd.DataFrame({"C": [0, 1, 1, 1, 0]}, dtype=np.int32)
    assert_frame_equal(return_df, expected_df)