2 changes: 2 additions & 0 deletions datafusion/common/src/join_type.rs
@@ -109,6 +109,8 @@ impl JoinType {
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
)
}
}
8 changes: 4 additions & 4 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -314,7 +314,7 @@ async fn test_right_mark_join_1k() {
JoinType::RightMark,
None,
)
- .run_test(&[NljHj], false)
+ .run_test(&[HjSmj, NljHj], false)
.await
}

@@ -326,7 +326,7 @@ async fn test_right_mark_join_1k_filtered() {
JoinType::RightMark,
Some(Box::new(col_lt_col_filter)),
)
- .run_test(&[NljHj], false)
+ .run_test(&[HjSmj, NljHj], false)
.await
}

@@ -555,7 +555,7 @@ async fn test_right_mark_join_1k_binary() {
JoinType::RightMark,
None,
)
- .run_test(&[NljHj], false)
+ .run_test(&[HjSmj, NljHj], false)
.await
}

@@ -567,7 +567,7 @@ async fn test_right_mark_join_1k_binary_filtered() {
JoinType::RightMark,
Some(Box::new(col_lt_col_filter)),
)
- .run_test(&[NljHj], false)
+ .run_test(&[HjSmj, NljHj], false)
.await
}

58 changes: 57 additions & 1 deletion datafusion/core/tests/physical_optimizer/join_selection.rs
@@ -369,6 +369,61 @@ async fn test_join_with_swap_semi() {
}
}

#[tokio::test]
async fn test_join_with_swap_mark() {
let join_types = [JoinType::LeftMark, JoinType::RightMark];
for join_type in join_types {
let (big, small) = create_big_and_small();

let join = HashJoinExec::try_new(
Arc::clone(&big),
Arc::clone(&small),
vec![(
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()),
)],
None,
&join_type,
None,
PartitionMode::Partitioned,
NullEquality::NullEqualsNothing,
)
.unwrap();

let original_schema = join.schema();

let optimized_join = JoinSelection::new()
.optimize(Arc::new(join), &ConfigOptions::new())
.unwrap();

let swapped_join = optimized_join
.as_any()
.downcast_ref::<HashJoinExec>()
.expect(
"A proj is not required to swap columns back to their original order",
);

assert_eq!(swapped_join.schema().fields().len(), 2);
assert_eq!(
swapped_join
.left()
.partition_statistics(None)
.unwrap()
.total_byte_size,
Precision::Inexact(8192)
);
assert_eq!(
swapped_join
.right()
.partition_statistics(None)
.unwrap()
.total_byte_size,
Precision::Inexact(2097152)
Review comment (Contributor): Can this number be a reason for a flaky test in the future?

Reply (Contributor, author): I don't think so; I copied this from the other join selection swapping tests and those have been working fine.

);
assert_eq!(original_schema, swapped_join.schema());
}
}

/// Compare the input plan with the plan after running the probe order optimizer.
macro_rules! assert_optimized {
($EXPECTED_LINES: expr, $PLAN: expr) => {
@@ -576,7 +631,8 @@ async fn test_nl_join_with_swap(join_type: JoinType) {
case::left_semi(JoinType::LeftSemi),
case::left_anti(JoinType::LeftAnti),
case::right_semi(JoinType::RightSemi),
- case::right_anti(JoinType::RightAnti)
+ case::right_anti(JoinType::RightAnti),
+ case::right_mark(JoinType::RightMark)
)]
#[tokio::test]
async fn test_nl_join_with_swap_no_proj(join_type: JoinType) {
5 changes: 4 additions & 1 deletion datafusion/expr/src/logical_plan/builder.rs
@@ -1696,7 +1696,10 @@ pub fn build_join_schema(
);

let (schema1, schema2) = match join_type {
- JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (left, right),
+ JoinType::Right
+ | JoinType::RightSemi
+ | JoinType::RightAnti
+ | JoinType::RightMark => (left, right),
_ => (right, left),
};

49 changes: 47 additions & 2 deletions datafusion/optimizer/src/decorrelate_predicate_subquery.rs
@@ -31,7 +31,7 @@ use datafusion_common::{internal_err, plan_err, Column, Result};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
use datafusion_expr::logical_plan::{JoinType, Subquery};
- use datafusion_expr::utils::{conjunction, split_conjunction_owned};
+ use datafusion_expr::utils::{conjunction, expr_to_columns, split_conjunction_owned};
use datafusion_expr::{
exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
LogicalPlan, LogicalPlanBuilder, Operator,
@@ -342,7 +342,7 @@ fn build_join(
replace_qualified_name(filter, &all_correlated_cols, &alias).map(Some)
})?;

- let join_filter = match (join_filter_opt, in_predicate_opt) {
+ let join_filter = match (join_filter_opt, in_predicate_opt.clone()) {
(
Some(join_filter),
Some(Expr::BinaryExpr(BinaryExpr {
@@ -371,6 +371,51 @@
(None, None) => lit(true),
_ => return Ok(None),
};

if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) {
let right_schema = sub_query_alias.schema();

// Gather all columns needed for the join filter + predicates
let mut needed = std::collections::HashSet::new();
expr_to_columns(&join_filter, &mut needed)?;
if let Some(ref in_pred) = in_predicate_opt {
expr_to_columns(in_pred, &mut needed)?;
}

// Keep only columns that actually belong to the RIGHT child, and sort by their
// position in the right schema for deterministic order.
let mut right_cols_idx_and_col: Vec<(usize, Column)> = needed
.into_iter()
.filter_map(|c| right_schema.index_of_column(&c).ok().map(|idx| (idx, c)))
.collect();

right_cols_idx_and_col.sort_by_key(|(idx, _)| *idx);

let right_proj_exprs: Vec<Expr> = right_cols_idx_and_col
.into_iter()
.map(|(_, c)| Expr::Column(c))
.collect();

let right_projected = if !right_proj_exprs.is_empty() {
LogicalPlanBuilder::from(sub_query_alias.clone())
.project(right_proj_exprs)?
.build()?
} else {
// Degenerate case: no right columns referenced by the predicate(s)
sub_query_alias.clone()
};
let new_plan = LogicalPlanBuilder::from(left.clone())
.join_on(right_projected, join_type, Some(join_filter))?
.build()?;

debug!(
"predicate subquery optimized:\n{}",
new_plan.display_indent()
);

return Ok(Some(new_plan));
}

// join our sub query into the main plan
let new_plan = LogicalPlanBuilder::from(left.clone())
.join_on(sub_query_alias, join_type, Some(join_filter))?
6 changes: 5 additions & 1 deletion datafusion/physical-optimizer/src/join_selection.rs
@@ -514,7 +514,11 @@ pub(crate) fn swap_join_according_to_unboundedness(
match (*partition_mode, *join_type) {
(
_,
- JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full,
+ JoinType::Right
+ | JoinType::RightSemi
+ | JoinType::RightAnti
+ | JoinType::RightMark
+ | JoinType::Full,
) => internal_err!("{join_type} join cannot be swapped for unbounded input."),
(PartitionMode::Partitioned, _) => {
hash_join.swap_inputs(PartitionMode::Partitioned)
2 changes: 2 additions & 0 deletions datafusion/physical-plan/src/joins/hash_join/exec.rs
@@ -690,6 +690,8 @@ impl HashJoinExec {
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
) || self.projection.is_some()
{
Ok(Arc::new(new_join))
2 changes: 2 additions & 0 deletions datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -379,6 +379,8 @@ impl NestedLoopJoinExec {
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
) || self.projection.is_some()
{
Arc::new(new_join)
73 changes: 43 additions & 30 deletions datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
@@ -430,7 +430,7 @@ pub(super) fn get_corrected_filter_mask(
corrected_mask.append_n(expected_size - corrected_mask.len(), false);
Some(corrected_mask.finish())
}
- JoinType::LeftMark => {
+ JoinType::LeftMark | JoinType::RightMark => {
for i in 0..row_indices_length {
let last_index =
last_index_for_row(i, row_indices, batch_ids, row_indices_length);
@@ -582,6 +582,7 @@ impl Stream for SortMergeJoinStream {
| JoinType::LeftMark
| JoinType::Right
| JoinType::RightSemi
| JoinType::RightMark
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::Full
@@ -691,6 +692,7 @@
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
| JoinType::Full
)
{
@@ -718,6 +720,7 @@
| JoinType::RightAnti
| JoinType::Full
| JoinType::LeftMark
| JoinType::RightMark
)
{
let record_batch = self.filter_joined_batch()?;
@@ -1042,16 +1045,23 @@ impl SortMergeJoinStream {
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
) {
join_streamed = !self.streamed_joined;
}
}
Ordering::Equal => {
if matches!(
self.join_type,
- JoinType::LeftSemi | JoinType::LeftMark | JoinType::RightSemi
+ JoinType::LeftSemi
+ | JoinType::LeftMark
+ | JoinType::RightSemi
+ | JoinType::RightMark
) {
- mark_row_as_match = matches!(self.join_type, JoinType::LeftMark);
+ mark_row_as_match = matches!(
+ self.join_type,
+ JoinType::LeftMark | JoinType::RightMark
+ );
// if the join filter is specified then its needed to output the streamed index
// only if it has not been emitted before
// the `join_filter_matched_idxs` keeps track on if streamed index has a successful
@@ -1266,31 +1276,32 @@

// The row indices of joined buffered batch
let right_indices: UInt64Array = chunk.buffered_indices.finish();
- let mut right_columns = if matches!(self.join_type, JoinType::LeftMark) {
- vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef]
- } else if matches!(
- self.join_type,
- JoinType::LeftSemi
- | JoinType::LeftAnti
- | JoinType::RightAnti
- | JoinType::RightSemi
- ) {
- vec![]
- } else if let Some(buffered_idx) = chunk.buffered_batch_idx {
- fetch_right_columns_by_idxs(
- &self.buffered_data,
- buffered_idx,
- &right_indices,
- )?
- } else {
- // If buffered batch none, meaning it is null joined batch.
- // We need to create null arrays for buffered columns to join with streamed rows.
- create_unmatched_columns(
+ let mut right_columns =
+ if matches!(self.join_type, JoinType::LeftMark | JoinType::RightMark) {
+ vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef]
+ } else if matches!(
self.join_type,
- &self.buffered_schema,
- right_indices.len(),
- )
- };
+ JoinType::LeftSemi
+ | JoinType::LeftAnti
+ | JoinType::RightAnti
+ | JoinType::RightSemi
+ ) {
+ vec![]
+ } else if let Some(buffered_idx) = chunk.buffered_batch_idx {
+ fetch_right_columns_by_idxs(
+ &self.buffered_data,
+ buffered_idx,
+ &right_indices,
+ )?
+ } else {
+ // If buffered batch none, meaning it is null joined batch.
+ // We need to create null arrays for buffered columns to join with streamed rows.
+ create_unmatched_columns(
+ self.join_type,
+ &self.buffered_schema,
+ right_indices.len(),
+ )
+ };

// Prepare the columns we apply join filter on later.
// Only for joined rows between streamed and buffered.
@@ -1309,7 +1320,7 @@
get_filter_column(&self.filter, &left_columns, &right_cols)
} else if matches!(
self.join_type,
- JoinType::RightAnti | JoinType::RightSemi
+ JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark
) {
let right_cols = fetch_right_columns_by_idxs(
&self.buffered_data,
@@ -1375,6 +1386,7 @@
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
| JoinType::Full
) {
self.staging_output_record_batches
@@ -1475,6 +1487,7 @@
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
| JoinType::Full
))
{
@@ -1537,7 +1550,7 @@

if matches!(
self.join_type,
- JoinType::Left | JoinType::LeftMark | JoinType::Right
+ JoinType::Left | JoinType::LeftMark | JoinType::Right | JoinType::RightMark
) {
let null_mask = compute::not(corrected_mask)?;
let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?;
@@ -1658,7 +1671,7 @@ fn create_unmatched_columns(
schema: &SchemaRef,
size: usize,
) -> Vec<ArrayRef> {
- if matches!(join_type, JoinType::LeftMark) {
+ if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) {
vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef]
} else {
schema