Skip to content
Merged
19 changes: 10 additions & 9 deletions dask_sql/physical/rel/logical/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,16 @@ def _split_join_condition(
filter_condition = []

for operand in operands:
try:
lhs_on_part, rhs_on_part = self._extract_lhs_rhs(operand)
lhs_on.append(lhs_on_part)
rhs_on.append(rhs_on_part)
continue
except AssertionError:
pass

filter_condition.append(operand)
if isinstance(operand, org.apache.calcite.rex.RexCall):
try:
lhs_on_part, rhs_on_part = self._extract_lhs_rhs(operand)
lhs_on.append(lhs_on_part)
rhs_on.append(rhs_on_part)
continue
except AssertionError:
pass

filter_condition.append(operand)

if lhs_on and rhs_on:
return lhs_on, rhs_on, filter_condition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ private static class FixedRulesProgram implements Program {
CoreRules.AGGREGATE_REDUCE_FUNCTIONS, CoreRules.AGGREGATE_MERGE,
CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, CoreRules.AGGREGATE_JOIN_REMOVE,
CoreRules.PROJECT_MERGE, CoreRules.FILTER_MERGE, CoreRules.PROJECT_REMOVE,
CoreRules.PROJECT_REDUCE_EXPRESSIONS, CoreRules.FILTER_REDUCE_EXPRESSIONS,
CoreRules.FILTER_EXPAND_IS_NOT_DISTINCT_FROM, CoreRules.PROJECT_TO_LOGICAL_PROJECT_AND_WINDOW);
CoreRules.FILTER_REDUCE_EXPRESSIONS, CoreRules.FILTER_EXPAND_IS_NOT_DISTINCT_FROM,
CoreRules.PROJECT_TO_LOGICAL_PROJECT_AND_WINDOW);

@Override
public RelNode run(RelOptPlanner planner, RelNode rel, RelTraitSet requiredOutputTraits,
Expand Down
26 changes: 26 additions & 0 deletions tests/integration/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,29 @@ def test_join_literal(c):
df_expected = pd.DataFrame({"user_id": [], "b": [], "user_id0": [], "c": []})

assert_frame_equal(df.reset_index(), df_expected.reset_index(), check_dtype=False)


def test_conditional_join(c):
    """Check an inner join whose ON clause mixes a key match with a filter.

    The query joins ``df1`` and ``df2`` on ``a = c`` and additionally
    requires ``b IS NOT NULL``. The expected result is built with pandas:
    an inner merge on ``a``/``c`` restricted to rows whose ``b`` is not
    null, keeping only column ``a``.
    """
    df1 = pd.DataFrame({"a": [1, 2, 2, 5, 6], "b": ["w", "x", "y", "z", None]})
    df2 = pd.DataFrame({"c": [None, 3, 2, 5], "d": ["h", "i", "j", "k"]})

    expected_df = pd.merge(df1, df2, how="inner", left_on=["a"], right_on=["c"])
    # ``!= None`` is evaluated element-wise by pandas and is True for every
    # row, so it never drops nulls; ``notna()`` is the actual counterpart of
    # SQL's ``IS NOT NULL``.
    expected_df = expected_df[expected_df["b"].notna()]["a"]

    c.create_table("df1", df1)
    c.create_table("df2", df2)

    actual_df = c.sql(
        """
SELECT a FROM df1
INNER JOIN df2 ON
(
a = c
AND b IS NOT NULL
)
"""
    ).compute()

    # Compare with a fresh index on both sides; dtypes may legitimately differ.
    assert_frame_equal(
        actual_df.reset_index(), expected_df.reset_index(), check_dtype=False
    )
8 changes: 8 additions & 0 deletions tests/integration/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,11 @@ def test_timezones(c, datetime_table):
result_df = result_df.compute()

assert_frame_equal(result_df, datetime_table)


def test_multi_case_when(c):
    """Smoke-test a ``CASE WHEN ... BETWEEN ...`` projection.

    Registers a one-column table as ``df`` and computes the query. Nothing
    is asserted about the values; the test passes as long as the query can
    be computed without raising.
    """
    c.create_table("df", pd.DataFrame({"a": [1, 6, 7, 8, 9]}))

    c.sql("SELECT CASE WHEN a BETWEEN 6 AND 8 THEN 1 ELSE 0 END FROM df").compute()