Merged
23 changes: 22 additions & 1 deletion datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -1209,7 +1209,14 @@ impl SMJStream {
) {
// The reverse of the selection mask. For the rows not pass join filter above,
// we need to join them (left or right) with null rows for outer joins.
let not_mask = compute::not(mask)?;
let not_mask = if mask.null_count() > 0 {
Member Author

I re-ran the test added in #9080 on a new laptop and found this bug. I'm not sure why the test previously passed both locally and in CI in #9080. 🤔

Member Author

I re-checked the results in sort_merge_join.slt and they should be correct: they are the same as those in join.slt, which are produced by the hash join operator.

// If the mask contains nulls, we need to use `prep_null_mask_filter` to
// handle the nulls in the mask as false.
compute::not(&compute::prep_null_mask_filter(mask))?
Member Author

Using the test in sort_merge_join.slt as an example:

CREATE TABLE IF NOT EXISTS t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES
(11, 'a', 1),
(22, 'b', 2),
(33, 'c', 3),
(44, 'd', 4);

CREATE TABLE IF NOT EXISTS t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES
(11, 'z', 3),
(22, 'y', 1),
(44, 'x', 3),
(55, 'w', 3);

For the query `SELECT t1_id, t1_int, t2_int FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_int >= t2_int`, `33 3 NULL` is one of the rows produced before the join filter is applied. The join filter evaluates to null on this row (its t2_int is NULL), so its value in the selection mask array is null and the row is not picked.
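
For reference, the result this query should produce can be checked with a small std-only sketch. It is a naive nested-loop model of the LEFT JOIN semantics, not the sort-merge join operator; the function name `left_join_with_filter` is hypothetical, SQL NULL is modeled as `None`, and only the id and int columns of each table are kept.

```rust
// Std-only model of the query's LEFT JOIN semantics. SQL NULL is modeled as None.
// `left_join_with_filter` is a hypothetical helper, not part of DataFusion.

/// Joins t1 rows (t1_id, t1_int) with t2 rows (t2_id, t2_int)
/// ON t1_id = t2_id AND t1_int >= t2_int, returning (t1_id, t1_int, t2_int).
fn left_join_with_filter(
    t1: &[(i32, i32)],
    t2: &[(i32, i32)],
) -> Vec<(i32, i32, Option<i32>)> {
    let mut out = Vec::new();
    for &(t1_id, t1_int) in t1 {
        let mut matched = false;
        for &(t2_id, t2_int) in t2 {
            // A pair is emitted only when both the equijoin predicate
            // and the join filter hold.
            if t1_id == t2_id && t1_int >= t2_int {
                out.push((t1_id, t1_int, Some(t2_int)));
                matched = true;
            }
        }
        // A left row with no passing pair is joined with a null row.
        if !matched {
            out.push((t1_id, t1_int, None));
        }
    }
    out
}
```

With the tables above, the model yields `11 1 NULL`, `22 2 1`, `33 3 NULL`, and `44 4 3`: row 11 fails the join filter and row 33 has no equijoin match, so both are null joined.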

Here we take the reverse of the mask. Because the mask value for this row is null, its reverse is still null, which the filter treats as false, but this row should actually be selected here. So we need to call prep_null_mask_filter on the mask before negating it.
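
The effect can be illustrated with a std-only sketch that models a nullable boolean mask as `Option<bool>` values, with `None` standing in for an Arrow null. The helpers `not_mask`, `nulls_to_false`, and `filter_rows` are hypothetical stand-ins for illustration, not the arrow-rs API; `nulls_to_false` only mimics what `prep_null_mask_filter` does to the mask.

```rust
// Std-only model of a nullable boolean mask: None stands in for an Arrow null.
// All helpers here are hypothetical stand-ins, not arrow-rs functions.

/// Three-valued NOT: a null mask value stays null.
fn not_mask(mask: &[Option<bool>]) -> Vec<Option<bool>> {
    mask.iter().map(|v| v.map(|b| !b)).collect()
}

/// Mimics the effect of `prep_null_mask_filter`: nulls become false.
fn nulls_to_false(mask: &[Option<bool>]) -> Vec<Option<bool>> {
    mask.iter().map(|v| Some(v.unwrap_or(false))).collect()
}

/// Models filtering: a row is kept only when its mask value is Some(true),
/// so a null mask value behaves like false.
fn filter_rows<T: Clone>(rows: &[T], mask: &[Option<bool>]) -> Vec<T> {
    rows.iter()
        .zip(mask)
        .filter(|(_, m)| **m == Some(true))
        .map(|(r, _)| r.clone())
        .collect()
}
```

For the four t1 rows, the join filter mask is `[false, true, null, true]`. Negating it directly selects only row 11 as null joined, while negating it after `nulls_to_false` selects rows 11 and 33, which is the intended result.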

} else {
compute::not(mask)?
};

let null_joined_batch =
compute::filter_record_batch(&output_batch, &not_mask)?;

@@ -1254,6 +1261,20 @@ impl SMJStream {

// For full join, we also need to output the null joined rows from the buffered side
if matches!(self.join_type, JoinType::Full) {
// Handle not mask for buffered side further.
// For buffered side, we want to output the rows that are not null joined with
// the streamed side. i.e. the rows that are not null in the `buffered_indices`.
let not_mask = if buffered_indices.null_count() > 0 {
Contributor

You might be able to avoid the unwrap by using

                            let not_mask = if let Some(nulls) = buffered_indices.nulls() {
                                let mask = not_mask.values() & nulls.inner();
                                BooleanArray::new(mask, None)
                            } else {
                                not_mask
                            };

let nulls = buffered_indices.nulls().unwrap();
let mask = not_mask.values() & nulls.inner();
BooleanArray::new(mask, None)
Member Author

@viirya viirya Feb 8, 2024

For a full outer join, we need to output buffered rows that fail the join filter. But in output_batch we only care about the rows whose buffered_indices entry is not null. Rows with null indices are rows that failed the equijoin predicates.
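
A std-only sketch of this masking step, assuming the not-mask has already had its nulls resolved to false so it can be modeled as plain `bool` values, and modeling `buffered_indices` as `Option<u64>` values where `None` is a null index. The function name `buffered_not_mask` is hypothetical; it only mirrors the bitwise AND of the mask values with the validity of the indices.

```rust
// Std-only model; `buffered_not_mask` is a hypothetical helper, not DataFusion code.

/// Keeps a row only if it failed the join filter (not_mask is true) AND it has a
/// non-null buffered index, i.e. it did pass the equijoin predicates.
fn buffered_not_mask(not_mask: &[bool], buffered_indices: &[Option<u64>]) -> Vec<bool> {
    not_mask
        .iter()
        .zip(buffered_indices)
        .map(|(failed_filter, idx)| *failed_filter && idx.is_some())
        .collect()
}
```

For example, with a not-mask of `[true, false, true]` and indices `[0, 1, null]`, only the first row is kept: the second passed the join filter, and the third has no buffered row at all.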

} else {
not_mask
};

let null_joined_batch =
compute::filter_record_batch(&output_batch, &not_mask)?;

let mut streamed_columns = self
.streamed_schema
.fields()