11 changes: 7 additions & 4 deletions dask_sql/context.py
@@ -845,15 +845,19 @@ def _get_ral(self, sql):
def _compute_table_from_rel(self, rel: "LogicalPlan", return_futures: bool = True):
dc = RelConverter.convert(rel, context=self)

# Optimization might remove some alias projects. Make sure to keep them here.
select_names = [field for field in rel.getRowType().getFieldList()]

if rel.get_current_node_type() == "Explain":
return dc
if dc is None:
return

# Optimization might remove some alias projects. Make sure to keep them here.
select_names = [field for field in rel.getRowType().getFieldList()]

if select_names:
cc = dc.column_container

select_names = select_names[: len(cc.columns)]
Collaborator:

Do you know why `getFieldList` and `getFieldNames` are returning an extra field that we need to filter out here (and in `fix_column_to_row_type`)? I want to believe there's something we could do on the Rust end to avoid needing these workarounds in the Python code, but okay with filing that as a follow-up issue to address.

cc @jdye64 if you have any idea what could be going on here

Collaborator (Author):

Yeah, it seems like it's treating left semi and left anti joins like left joins, where the column merged on is included from both the left and right tables in the result, whereas in `leftsemi` and `leftanti` joins only the left-hand column should be returned.
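
To illustrate, here's a minimal pandas sketch of the difference (the frames mirror the new tests below; the `isin` calls are just one way to express semi/anti semantics, not what the plugin actually does):

```python
import pandas as pd

left = pd.DataFrame({"id": [1, 1, 2, 4], "a": ["a", "b", "c", "d"]})
right = pd.DataFrame({"id": [2, 1, 2, 3], "b": ["c", "c", "a", "c"]})

# A left join keeps columns from BOTH sides and can duplicate left
# rows when the right side has several matches (here, id=2).
left_join = left.merge(right, on="id", how="left")  # columns: id, a, b

# A left semi join keeps only LEFT columns, one row per left row
# that has at least one match on the right.
semi = left[left["id"].isin(right["id"])]   # ids 1, 1, 2

# A left anti join keeps only LEFT columns for rows with NO match.
anti = left[~left["id"].isin(right["id"])]  # id 4
```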

Collaborator:

Would you mind filing an issue around this and linking it here so we could follow up on this later if possible?

# Use FQ name if not unique and simple name if it is unique. If a join contains the same column
# names the output col is prepended with the fully qualified column name
field_counts = Counter([field.getName() for field in select_names])
@@ -864,7 +868,6 @@ def _compute_table_from_rel(self, rel: "LogicalPlan", return_futures: bool = True):
for field in select_names
]

cc = dc.column_container
cc = cc.rename(
{
df_col: select_name
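As a side note, the Counter-based disambiguation in the hunk above can be sketched in isolation like this (the `(simple, qualified)` pairs are mocked stand-ins for the Rust-backed field objects and their `getName`/`getQualifiedName` accessors):

```python
from collections import Counter

# Mocked field metadata: (simple name, fully qualified name).
fields = [("id", "lhs.id"), ("a", "lhs.a"), ("id", "rhs.id")]

field_counts = Counter(name for name, _ in fields)

# Use the fully qualified name only when the simple name collides.
select_names = [
    qualified if field_counts[name] > 1 else name
    for name, qualified in fields
]
# -> ["lhs.id", "a", "rhs.id"]
```
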
15 changes: 11 additions & 4 deletions dask_sql/physical/rel/base.py
@@ -30,7 +30,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> dd.DataFrame:

@staticmethod
def fix_column_to_row_type(
cc: ColumnContainer, row_type: "RelDataType"
cc: ColumnContainer, row_type: "RelDataType", join_type: str = None
) -> ColumnContainer:
"""
Make sure that the given column container
@@ -39,6 +39,8 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> dd.DataFrame:
and will just "blindly" rename the columns.
"""
field_names = [str(x) for x in row_type.getFieldNames()]
if join_type in ("leftsemi", "leftanti"):
field_names = field_names[: len(cc.columns)]

logger.debug(f"Renaming {cc.columns} to {field_names}")
cc = cc.rename_handle_duplicates(
@@ -84,7 +86,9 @@ def assert_inputs(
return [RelConverter.convert(input_rel, context) for input_rel in input_rels]

@staticmethod
def fix_dtype_to_row_type(dc: DataContainer, row_type: "RelDataType"):
def fix_dtype_to_row_type(
dc: DataContainer, row_type: "RelDataType", join_type: str = None
):
"""
Fix the dtype of the given data container (or: the df within it)
to the data type given as argument.
@@ -98,9 +102,12 @@ def fix_dtype_to_row_type(dc: DataContainer, row_type: "RelDataType"):
df = dc.df
cc = dc.column_container

field_list = row_type.getFieldList()
if join_type in ("leftsemi", "leftanti"):
field_list = field_list[: len(cc.columns)]

field_types = {
str(field.getQualifiedName()): field.getType()
for field in row_type.getFieldList()
str(field.getQualifiedName()): field.getType() for field in field_list
}

for field_name, field_type in field_types.items():
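The truncate-then-cast behavior added to `fix_dtype_to_row_type` reduces to roughly this pattern (the field metadata is mocked as plain tuples here; the real method pulls it from the plan's `RelDataType`):

```python
import pandas as pd

df = pd.DataFrame({"id": [1, 2], "a": ["x", "y"]})

# For leftsemi/leftanti joins the planner currently reports an extra
# right-hand field, so keep only as many fields as the frame has columns.
field_list = [("id", "int64"), ("a", "object"), ("id", "int64")]
join_type = "leftsemi"
if join_type in ("leftsemi", "leftanti"):
    field_list = field_list[: len(df.columns)]

# Cast each column to the dtype the row type expects.
for field_name, field_type in field_list:
    if df[field_name].dtype != field_type:
        df[field_name] = df[field_name].astype(field_type)
```
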
48 changes: 35 additions & 13 deletions dask_sql/physical/rel/logical/join.py
@@ -14,6 +14,7 @@
from dask_sql.physical.rel.base import BaseRelPlugin
from dask_sql.physical.rel.logical.filter import filter_or_scalar
from dask_sql.physical.rex import RexConverter
from dask_sql.utils import is_cudf_type

if TYPE_CHECKING:
import dask_sql
@@ -45,7 +46,8 @@ class DaskJoinPlugin(BaseRelPlugin):
"LEFT": "left",
"RIGHT": "right",
"FULL": "outer",
"LEFTSEMI": "inner", # TODO: Need research here! This is likely not a true inner join
"LEFTSEMI": "leftsemi",
"LEFTANTI": "leftanti",
}

def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
@@ -74,6 +76,9 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:

join_type = join.getJoinType()
join_type = self.JOIN_TYPE_MAPPING[str(join_type)]
# TODO: update with correct implementation of leftsemi
if join_type == "leftsemi" and not is_cudf_type(df_lhs_renamed):
join_type = "inner"

# 3. The join condition can have two forms, that we can understand
# (a) a = b
@@ -170,22 +175,27 @@ def merge_single_partitions(lhs_partition, rhs_partition):

# 6. So the next step is to make sure
# we have the correct column order (and to remove the temporary join columns)
correct_column_order = list(df_lhs_renamed.columns) + list(
df_rhs_renamed.columns
)
if join_type in ("leftsemi", "leftanti"):
correct_column_order = list(df_lhs_renamed.columns)
else:
correct_column_order = list(df_lhs_renamed.columns) + list(
df_rhs_renamed.columns
)
cc = ColumnContainer(df.columns).limit_to(correct_column_order)

# and to rename them like the rel specifies
row_type = rel.getRowType()
field_specifications = [str(f) for f in row_type.getFieldNames()]
if join_type in ("leftsemi", "leftanti"):
field_specifications = field_specifications[: len(cc.columns)]

cc = cc.rename(
{
from_col: to_col
for from_col, to_col in zip(cc.columns, field_specifications)
}
)
cc = self.fix_column_to_row_type(cc, row_type)
cc = self.fix_column_to_row_type(cc, row_type, join_type)
dc = DataContainer(df, cc)

# 7. Last but not least we apply any filters by and-chaining together the filters
Expand All @@ -202,7 +212,7 @@ def merge_single_partitions(lhs_partition, rhs_partition):
df = filter_or_scalar(df, filter_condition)
dc = DataContainer(df, cc)

dc = self.fix_dtype_to_row_type(dc, rel.getRowType())
dc = self.fix_dtype_to_row_type(dc, rel.getRowType(), join_type)
# # Rename underlying DataFrame column names back to their original values before returning
# df = dc.assign()
# dc = DataContainer(df, ColumnContainer(cc.columns))
@@ -227,7 +237,7 @@ def _join_on_columns(
[~df_lhs_renamed.iloc[:, index].isna() for index in lhs_on],
)
df_lhs_renamed = df_lhs_renamed[df_lhs_filter]
if join_type in ["inner", "left"]:
if join_type in ["inner", "left", "leftanti", "leftsemi"]:
df_rhs_filter = reduce(
operator.and_,
[~df_rhs_renamed.iloc[:, index].isna() for index in rhs_on],
@@ -256,12 +266,24 @@ def _join_on_columns(
"For more information refer to https://github.com/dask/dask/issues/9851"
" and https://github.com/dask/dask/issues/9870"
)
df = df_lhs_with_tmp.merge(
df_rhs_with_tmp,
on=added_columns,
how=join_type,
broadcast=broadcast,
).drop(columns=added_columns)
if join_type == "leftanti" and not is_cudf_type(df_lhs_with_tmp):
df = df_lhs_with_tmp.merge(
df_rhs_with_tmp,
on=added_columns,
how="left",
broadcast=broadcast,
indicator=True,
).drop(columns=added_columns)
df = df[df["_merge"] == "left_only"].drop(
columns=["_merge"] + list(df_rhs_with_tmp.columns), errors="ignore"
)
else:
df = df_lhs_with_tmp.merge(
df_rhs_with_tmp,
on=added_columns,
how=join_type,
broadcast=broadcast,
).drop(columns=added_columns)

return df

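For reference, the `indicator=True` fallback added to `_join_on_columns` emulates an anti join roughly as in this standalone sketch (cuDF accepts `how="leftanti"` in `merge` directly, so this path only matters for pandas/Dask-backed frames):

```python
import pandas as pd

left = pd.DataFrame({"id": [1, 1, 2, 4], "a": ["a", "b", "c", "d"]})
right = pd.DataFrame({"id": [2, 1, 2, 3], "b": ["c", "c", "a", "c"]})

# pandas has no native anti join: left-merge with an indicator column,
# keep only the rows that found no match, then drop the indicator and
# any right-hand columns that were pulled in.
merged = left.merge(right, on="id", how="left", indicator=True)
anti = merged[merged["_merge"] == "left_only"].drop(
    columns=["_merge"] + [c for c in right.columns if c != "id"]
)
# -> a single row: id=4, a="d"
```
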
50 changes: 50 additions & 0 deletions tests/integration/test_join.py
@@ -86,6 +86,56 @@ def test_join_left(c):
assert_eq(return_df, expected_df, check_index=False)


@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_join_left_anti(c, gpu):
df1 = pd.DataFrame({"id": [1, 1, 2, 4], "a": ["a", "b", "c", "d"]})
df2 = pd.DataFrame({"id": [2, 1, 2, 3], "b": ["c", "c", "a", "c"]})
c.create_table("df_1", df1, gpu=gpu)
c.create_table("df_2", df2, gpu=gpu)

return_df = c.sql(
"""
SELECT lhs.id, lhs.a
FROM df_1 AS lhs
LEFT ANTI JOIN df_2 AS rhs
ON lhs.id = rhs.id
"""
)
expected_df = pd.DataFrame(
{
"id": [4],
"a": ["d"],
}
)

assert_eq(return_df, expected_df, check_index=False)


@pytest.mark.gpu
def test_join_left_semi(c):
df1 = pd.DataFrame({"id": [1, 1, 2, 4], "a": ["a", "b", "c", "d"]})
df2 = pd.DataFrame({"id": [2, 1, 2, 3], "b": ["c", "c", "a", "c"]})
c.create_table("df_1", df1, gpu=True)
c.create_table("df_2", df2, gpu=True)

return_df = c.sql(
"""
SELECT lhs.id, lhs.a
FROM df_1 AS lhs
LEFT SEMI JOIN df_2 AS rhs
ON lhs.id = rhs.id
"""
)
expected_df = pd.DataFrame(
{
"id": [1, 1, 2],
"a": ["a", "b", "c"],
}
)

assert_eq(return_df, expected_df, check_index=False)


def test_join_right(c):
return_df = c.sql(
"""
1 change: 0 additions & 1 deletion tests/unit/test_queries.py
@@ -36,7 +36,6 @@
77,
80,
86,
87,
88,
89,
92,