Skip to content
4 changes: 3 additions & 1 deletion dask_sql/physical/rel/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> dd.DataFra

@staticmethod
def fix_column_to_row_type(
cc: ColumnContainer, row_type: "RelDataType"
cc: ColumnContainer, row_type: "RelDataType", join_type: str = None
) -> ColumnContainer:
"""
Make sure that the given column container
Expand All @@ -39,6 +39,8 @@ def fix_column_to_row_type(
and will just "blindly" rename the columns.
"""
field_names = [str(x) for x in row_type.getFieldNames()]
if join_type in ("leftsemi", "leftanti"):
field_names = field_names[: len(cc.columns)]

logger.debug(f"Renaming {cc.columns} to {field_names}")
cc = cc.rename_handle_duplicates(
Expand Down
50 changes: 37 additions & 13 deletions dask_sql/physical/rel/logical/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dask_sql.physical.rel.base import BaseRelPlugin
from dask_sql.physical.rel.logical.filter import filter_or_scalar
from dask_sql.physical.rex import RexConverter
from dask_sql.utils import is_cudf_type

if TYPE_CHECKING:
import dask_sql
Expand Down Expand Up @@ -45,7 +46,8 @@ class DaskJoinPlugin(BaseRelPlugin):
"LEFT": "left",
"RIGHT": "right",
"FULL": "outer",
"LEFTSEMI": "inner", # TODO: Need research here! This is likely not a true inner join
"LEFTSEMI": "leftsemi",
"LEFTANTI": "leftanti",
}

def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
Expand Down Expand Up @@ -74,6 +76,9 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai

join_type = join.getJoinType()
join_type = self.JOIN_TYPE_MAPPING[str(join_type)]
# TODO: update with correct implementation of leftsemi
if join_type == "leftsemi" and not is_cudf_type(df_lhs_renamed):
join_type = "inner"

# 3. The join condition can have two forms, that we can understand
# (a) a = b
Expand Down Expand Up @@ -170,22 +175,27 @@ def merge_single_partitions(lhs_partition, rhs_partition):

# 6. So the next step is to make sure
# we have the correct column order (and to remove the temporary join columns)
correct_column_order = list(df_lhs_renamed.columns) + list(
df_rhs_renamed.columns
)
if join_type in ("leftsemi", "leftanti"):
correct_column_order = list(df_lhs_renamed.columns)
else:
correct_column_order = list(df_lhs_renamed.columns) + list(
df_rhs_renamed.columns
)
cc = ColumnContainer(df.columns).limit_to(correct_column_order)

# and to rename them like the rel specifies
row_type = rel.getRowType()
field_specifications = [str(f) for f in row_type.getFieldNames()]
if join_type in ("leftsemi", "leftanti"):
field_specifications = field_specifications[: len(cc.columns)]

cc = cc.rename(
{
from_col: to_col
for from_col, to_col in zip(cc.columns, field_specifications)
}
)
cc = self.fix_column_to_row_type(cc, row_type)
cc = self.fix_column_to_row_type(cc, row_type, join_type)
dc = DataContainer(df, cc)

# 7. Last but not least we apply any filters by and-chaining together the filters
Expand All @@ -202,7 +212,9 @@ def merge_single_partitions(lhs_partition, rhs_partition):
df = filter_or_scalar(df, filter_condition)
dc = DataContainer(df, cc)

dc = self.fix_dtype_to_row_type(dc, rel.getRowType())
# TODO: Debug this...
if join_type not in ("leftsemi", "leftanti"):
dc = self.fix_dtype_to_row_type(dc, rel.getRowType())
# # Rename underlying DataFrame column names back to their original values before returning
# df = dc.assign()
# dc = DataContainer(df, ColumnContainer(cc.columns))
Expand All @@ -227,7 +239,7 @@ def _join_on_columns(
[~df_lhs_renamed.iloc[:, index].isna() for index in lhs_on],
)
df_lhs_renamed = df_lhs_renamed[df_lhs_filter]
if join_type in ["inner", "left"]:
if join_type in ["inner", "left", "leftanti", "leftsemi"]:
df_rhs_filter = reduce(
operator.and_,
[~df_rhs_renamed.iloc[:, index].isna() for index in rhs_on],
Expand Down Expand Up @@ -256,12 +268,24 @@ def _join_on_columns(
"For more information refer to https://github.com/dask/dask/issues/9851"
" and https://github.com/dask/dask/issues/9870"
)
df = df_lhs_with_tmp.merge(
df_rhs_with_tmp,
on=added_columns,
how=join_type,
broadcast=broadcast,
).drop(columns=added_columns)
if join_type == "leftanti" and not is_cudf_type(df_lhs_with_tmp):
df = df_lhs_with_tmp.merge(
df_rhs_with_tmp,
on=added_columns,
how="left",
broadcast=broadcast,
indicator=True,
).drop(columns=added_columns)
df = df[df["_merge"] == "left_only"].drop(
columns=["_merge"] + list(df_rhs_with_tmp.columns), errors="ignore"
)
else:
df = df_lhs_with_tmp.merge(
df_rhs_with_tmp,
on=added_columns,
how=join_type,
broadcast=broadcast,
).drop(columns=added_columns)

return df

Expand Down
52 changes: 52 additions & 0 deletions tests/integration/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,58 @@ def test_join_left(c):
assert_eq(return_df, expected_df, check_index=False)


@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_join_left_anti(c, gpu):
    """LEFT ANTI JOIN keeps only left-side rows whose key has no match on the right."""
    # Left ids: 1, 1, 2, 4.  Right ids: 2, 1, 2, 3.  Only id=4 is unmatched.
    left = pd.DataFrame({"id": [1, 1, 2, 4], "a": ["a", "b", "c", "d"]})
    right = pd.DataFrame({"id": [2, 1, 2, 3], "b": ["c", "c", "a", "c"]})
    c.create_table("df_1", left, gpu=gpu)
    c.create_table("df_2", right, gpu=gpu)

    query = """
    SELECT lhs.id, lhs.a
    FROM df_1 AS lhs
    LEFT ANTI JOIN df_2 AS rhs
    ON lhs.id = rhs.id
    """
    return_df = c.sql(query)

    # TODO: Figure out why this returns lhs.id instead of id
    expected_df = pd.DataFrame({"lhs.id": [4], "a": ["d"]})

    assert_eq(return_df, expected_df, check_index=False)


@pytest.mark.gpu
def test_join_left_semi(c):
    """LEFT SEMI JOIN keeps left-side rows whose key matches at least one right row."""
    # Left ids: 1, 1, 2, 4.  Right ids: 2, 1, 2, 3.  Ids 1, 1, 2 have matches.
    left = pd.DataFrame({"id": [1, 1, 2, 4], "a": ["a", "b", "c", "d"]})
    right = pd.DataFrame({"id": [2, 1, 2, 3], "b": ["c", "c", "a", "c"]})
    c.create_table("df_1", left, gpu=True)
    c.create_table("df_2", right, gpu=True)

    query = """
    SELECT lhs.id, lhs.a
    FROM df_1 AS lhs
    LEFT SEMI JOIN df_2 AS rhs
    ON lhs.id = rhs.id
    """
    return_df = c.sql(query)

    # TODO: Figure out why this returns lhs.id instead of id
    expected_df = pd.DataFrame({"lhs.id": [1, 1, 2], "a": ["a", "b", "c"]})

    assert_eq(return_df, expected_df, check_index=False)


def test_join_right(c):
return_df = c.sql(
"""
Expand Down
1 change: 0 additions & 1 deletion tests/unit/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
77,
80,
86,
87,
88,
89,
92,
Expand Down