Skip to content
19 changes: 10 additions & 9 deletions dask_sql/physical/rel/logical/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,16 @@ def _split_join_condition(
filter_condition = []

for operand in operands:
try:
lhs_on_part, rhs_on_part = self._extract_lhs_rhs(operand)
lhs_on.append(lhs_on_part)
rhs_on.append(rhs_on_part)
continue
except AssertionError:
pass

filter_condition.append(operand)
if isinstance(operand, org.apache.calcite.rex.RexCall):
try:
lhs_on_part, rhs_on_part = self._extract_lhs_rhs(operand)
lhs_on.append(lhs_on_part)
rhs_on.append(rhs_on_part)
continue
except AssertionError:
pass

filter_condition.append(operand)

if lhs_on and rhs_on:
return lhs_on, rhs_on, filter_condition
Expand Down
20 changes: 20 additions & 0 deletions tests/integration/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,23 @@ def test_join_literal(c):
df_expected = pd.DataFrame({"user_id": [], "b": [], "user_id0": [], "c": []})

assert_frame_equal(df.reset_index(), df_expected.reset_index(), check_dtype=False)


def test_conditional_join(c):
df1 = pd.DataFrame({"a": [1, 2, 2], "b": ["x", "y", "z"]})
df2 = pd.DataFrame({"c": [2, 3, 5], "d": ["i", "j", "k"]})

c.create_table("df1", df1)
c.create_table("df2", df2)

query = """
SELECT
a
FROM df1
INNER JOIN df2 ON
(
a = c
AND b IS NOT NULL
)
"""
ddf = c.sql(query)