# Flattened unified diff (git format): the line breaks and indentation inside
# each hunk have been collapsed to single spaces, so this text documents the
# change but will not apply with `git apply` / `patch` as-is.
#
# Files touched, in order:
#   1. dask_sql/physical/rel/logical/join.py, _split_join_condition: the
#      _extract_lhs_rhs() attempt (still guarded by `except AssertionError`)
#      is now only made when the operand is an
#      org.apache.calcite.rex.RexCall.
#      NOTE(review): indentation is lost in this view, so whether
#      filter_condition.append(operand) sits inside or after the new `if`
#      cannot be told from here -- confirm against the real file.
#   2. planner/.../DaskPlanner.java, DaskPlanner(): removes the
#      JoinPushThroughJoinRule.RIGHT and JoinPushThroughJoinRule.LEFT rules.
#   3. planner/.../DaskProgram.java, FixedRulesProgram.RULES: the list is
#      re-wrapped to one rule per line and
#      CoreRules.PROJECT_REDUCE_EXPRESSIONS is no longer listed.
#   4. tests/integration/test_join.py: imports dask_sql.Context and adds
#      test_conditional_join and test_join_case_projection_subquery.
#   5. tests/integration/test_select.py: adds test_multi_case_when.
#
# NOTE(review): test_conditional_join filters the expected frame with
#   `expected_df["b"] != None` (silenced with noqa E711). pandas is expected
#   to evaluate that comparison as True for every row, None included, so the
#   filter probably keeps everything; `.notna()` would state the intent --
#   confirm.
# NOTE(review): test_join_case_projection_subquery stores the query result in
#   actual_df, but the hunk contains no assertion on it; as written it only
#   checks that the query can be planned and computed.
# NOTE(review): test_multi_case_when references np.int32, and this diff adds
#   no numpy import to test_select.py, so it presumably relies on one that is
#   already in the file -- TODO confirm.
diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index b84373d1f..4a2efb558 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -259,15 +259,16 @@ def _split_join_condition( filter_condition = [] for operand in operands: - try: - lhs_on_part, rhs_on_part = self._extract_lhs_rhs(operand) - lhs_on.append(lhs_on_part) - rhs_on.append(rhs_on_part) - continue - except AssertionError: - pass - - filter_condition.append(operand) + if isinstance(operand, org.apache.calcite.rex.RexCall): + try: + lhs_on_part, rhs_on_part = self._extract_lhs_rhs(operand) + lhs_on.append(lhs_on_part) + rhs_on.append(rhs_on_part) + continue + except AssertionError: + pass + + filter_condition.append(operand) if lhs_on and rhs_on: return lhs_on, rhs_on, filter_condition diff --git a/planner/src/main/java/com/dask/sql/application/DaskPlanner.java b/planner/src/main/java/com/dask/sql/application/DaskPlanner.java index 4de993283..878da8440 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskPlanner.java +++ b/planner/src/main/java/com/dask/sql/application/DaskPlanner.java @@ -68,8 +68,6 @@ public DaskPlanner() { addRule(CoreRules.JOIN_COMMUTE); addRule(CoreRules.FILTER_INTO_JOIN); addRule(CoreRules.PROJECT_JOIN_TRANSPOSE); - addRule(JoinPushThroughJoinRule.RIGHT); - addRule(JoinPushThroughJoinRule.LEFT); addRule(CoreRules.SORT_PROJECT_TRANSPOSE); addRule(CoreRules.SORT_JOIN_TRANSPOSE); addRule(CoreRules.SORT_UNION_TRANSPOSE); diff --git a/planner/src/main/java/com/dask/sql/application/DaskProgram.java b/planner/src/main/java/com/dask/sql/application/DaskProgram.java index 054a6e7c4..b9d2f4387 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskProgram.java +++ b/planner/src/main/java/com/dask/sql/application/DaskProgram.java @@ -87,13 +87,21 @@ public RelNode run(RelNode rel, RelTraitSet relTraitSet) { * FixedRulesProgram applies a fixed set of conversion rules, which we always */ 
private static class FixedRulesProgram implements Program { - static private final List RULES = Arrays.asList(CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, - CoreRules.AGGREGATE_ANY_PULL_UP_CONSTANTS, CoreRules.AGGREGATE_PROJECT_MERGE, - CoreRules.AGGREGATE_REDUCE_FUNCTIONS, CoreRules.AGGREGATE_MERGE, - CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, CoreRules.AGGREGATE_JOIN_REMOVE, - CoreRules.PROJECT_MERGE, CoreRules.FILTER_MERGE, CoreRules.PROJECT_REMOVE, - CoreRules.PROJECT_REDUCE_EXPRESSIONS, CoreRules.FILTER_REDUCE_EXPRESSIONS, - CoreRules.FILTER_EXPAND_IS_NOT_DISTINCT_FROM, CoreRules.PROJECT_TO_LOGICAL_PROJECT_AND_WINDOW); + static private final List RULES = Arrays.asList( + CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, + CoreRules.AGGREGATE_ANY_PULL_UP_CONSTANTS, + CoreRules.AGGREGATE_PROJECT_MERGE, + CoreRules.AGGREGATE_REDUCE_FUNCTIONS, + CoreRules.AGGREGATE_MERGE, + CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, + CoreRules.AGGREGATE_JOIN_REMOVE, + CoreRules.PROJECT_MERGE, + CoreRules.FILTER_MERGE, + CoreRules.PROJECT_REMOVE, + CoreRules.FILTER_REDUCE_EXPRESSIONS, + CoreRules.FILTER_EXPAND_IS_NOT_DISTINCT_FROM, + CoreRules.PROJECT_TO_LOGICAL_PROJECT_AND_WINDOW + ); @Override public RelNode run(RelOptPlanner planner, RelNode rel, RelTraitSet requiredOutputTraits, diff --git a/tests/integration/test_join.py b/tests/integration/test_join.py index e8984f355..2e1bb8fe7 100644 --- a/tests/integration/test_join.py +++ b/tests/integration/test_join.py @@ -2,6 +2,8 @@ import pandas as pd from pandas.testing import assert_frame_equal +from dask_sql import Context + def test_join(c): df = c.sql( @@ -183,3 +185,67 @@ def test_join_literal(c): df_expected = pd.DataFrame({"user_id": [], "b": [], "user_id0": [], "c": []}) assert_frame_equal(df.reset_index(), df_expected.reset_index(), check_dtype=False) + + +def test_conditional_join(c): + df1 = pd.DataFrame({"a": [1, 2, 2, 5, 6], "b": ["w", "x", "y", "z", None]}) + df2 = pd.DataFrame({"c": [None, 3, 
2, 5], "d": ["h", "i", "j", "k"]}) + + expected_df = pd.merge(df1, df2, how="inner", left_on=["a"], right_on=["c"]) + expected_df = expected_df[expected_df["b"] != None]["a"] # noqa: E711 + + c.create_table("df1", df1) + c.create_table("df2", df2) + + actual_df = c.sql( + """ + SELECT a FROM df1 + INNER JOIN df2 ON + ( + a = c + AND b IS NOT NULL + ) + """ + ).compute() + + assert_frame_equal( + actual_df.reset_index(), expected_df.reset_index(), check_dtype=False + ) + + +def test_join_case_projection_subquery(): + c = Context() + + # Tables for query + demo = pd.DataFrame({"demo_sku": [], "hd_dep_count": []}) + site_page = pd.DataFrame({"site_page_sk": [], "site_char_count": []}) + sales = pd.DataFrame( + {"sales_hdemo_sk": [], "sales_page_sk": [], "sold_time_sk": []} + ) + t_dim = pd.DataFrame({"t_time_sk": [], "t_hour": []}) + + c.create_table("demos", demo, persist=False) + c.create_table("site_page", site_page, persist=False) + c.create_table("sales", sales, persist=False) + c.create_table("t_dim", t_dim, persist=False) + + actual_df = c.sql( + """ + SELECT CASE WHEN pmc > 0.0 THEN CAST (amc AS DOUBLE) / CAST (pmc AS DOUBLE) ELSE -1.0 END AS am_pm_ratio + FROM + ( + SELECT SUM(amc1) AS amc, SUM(pmc1) AS pmc + FROM + ( + SELECT + CASE WHEN t_hour BETWEEN 7 AND 8 THEN COUNT(1) ELSE 0 END AS amc1, + CASE WHEN t_hour BETWEEN 19 AND 20 THEN COUNT(1) ELSE 0 END AS pmc1 + FROM sales ws + JOIN demos hd ON (hd.demo_sku = ws.sales_hdemo_sk and hd.hd_dep_count = 5) + JOIN site_page sp ON (sp.site_page_sk = ws.sales_page_sk and sp.site_char_count BETWEEN 5000 AND 6000) + JOIN t_dim td ON (td.t_time_sk = ws.sold_time_sk and td.t_hour IN (7,8,19,20)) + GROUP BY t_hour + ) cnt_am_pm + ) sum_am_pm + """ + ).compute() diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index 2897ed15a..437631cef 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -116,3 +116,19 @@ def test_timezones(c, datetime_table): result_df 
= result_df.compute() assert_frame_equal(result_df, datetime_table) + + +def test_multi_case_when(c): + df = pd.DataFrame({"a": [1, 6, 7, 8, 9]}) + c.create_table("df", df) + + actual_df = c.sql( + """ + SELECT + CASE WHEN a BETWEEN 6 AND 8 THEN 1 ELSE 0 END AS C + FROM df + """ + ).compute() + + expected_df = pd.DataFrame({"C": [0, 1, 1, 1, 0]}, dtype=np.int32) + assert_frame_equal(actual_df, expected_df)