2 changes: 2 additions & 0 deletions datafusion/common/src/join_type.rs
@@ -109,6 +109,8 @@ impl JoinType {
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
)
}
}
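For context on the semantics being extended throughout this PR: a mark join preserves every row of one input and attaches a boolean "mark" column saying whether that row found at least one match on the other side. Below is a minimal standalone sketch of that idea; the helper name and signature are illustrative, not part of DataFusion's API.

```rust
/// Illustrative only: for each preserved-side row, report whether any row on
/// the other side satisfies the join condition. This boolean is what a
/// LeftMark/RightMark join exposes as its extra output column.
fn mark_column<L, R>(
    preserved: &[L],
    other: &[R],
    matches: impl Fn(&L, &R) -> bool,
) -> Vec<bool> {
    preserved
        .iter()
        .map(|l| other.iter().any(|r| matches(l, r)))
        .collect()
}

fn main() {
    // Which left keys have a match on the right?
    let left = [1, 2, 3];
    let right = [2, 3, 4];
    let marks = mark_column(&left, &right, |l, r| l == r);
    assert_eq!(marks, vec![false, true, true]);
}
```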
8 changes: 4 additions & 4 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -314,7 +314,7 @@ async fn test_right_mark_join_1k() {
JoinType::RightMark,
None,
)
.run_test(&[NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

@@ -326,7 +326,7 @@ async fn test_right_mark_join_1k_filtered() {
JoinType::RightMark,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

@@ -555,7 +555,7 @@ async fn test_right_mark_join_1k_binary() {
JoinType::RightMark,
None,
)
.run_test(&[NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

@@ -567,7 +567,7 @@ async fn test_right_mark_join_1k_binary_filtered() {
JoinType::RightMark,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[NljHj], false)
.run_test(&[HjSmj, NljHj], false)
.await
}

58 changes: 57 additions & 1 deletion datafusion/core/tests/physical_optimizer/join_selection.rs
@@ -369,6 +369,61 @@ async fn test_join_with_swap_semi() {
}
}

#[tokio::test]
async fn test_join_with_swap_mark() {
let join_types = [JoinType::LeftMark, JoinType::RightMark];
for join_type in join_types {
let (big, small) = create_big_and_small();

let join = HashJoinExec::try_new(
Arc::clone(&big),
Arc::clone(&small),
vec![(
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()),
)],
None,
&join_type,
None,
PartitionMode::Partitioned,
NullEquality::NullEqualsNothing,
)
.unwrap();

let original_schema = join.schema();

let optimized_join = JoinSelection::new()
.optimize(Arc::new(join), &ConfigOptions::new())
.unwrap();

let swapped_join = optimized_join
.as_any()
.downcast_ref::<HashJoinExec>()
.expect(
"A proj is not required to swap columns back to their original order",
);

assert_eq!(swapped_join.schema().fields().len(), 2);
assert_eq!(
swapped_join
.left()
.partition_statistics(None)
.unwrap()
.total_byte_size,
Precision::Inexact(8192)
);
assert_eq!(
swapped_join
.right()
.partition_statistics(None)
.unwrap()
.total_byte_size,
Precision::Inexact(2097152)
Contributor: Can this number be a reason for a flaky test in the future?

Contributor (Author): I don't think so; I copied this from the other join selection swapping tests and those have been working fine.

);
assert_eq!(original_schema, swapped_join.schema());
}
}

/// Compare the input plan with the plan after running the probe order optimizer.
macro_rules! assert_optimized {
($EXPECTED_LINES: expr, $PLAN: expr) => {
@@ -576,7 +631,8 @@ async fn test_nl_join_with_swap(join_type: JoinType) {
case::left_semi(JoinType::LeftSemi),
case::left_anti(JoinType::LeftAnti),
case::right_semi(JoinType::RightSemi),
case::right_anti(JoinType::RightAnti)
case::right_anti(JoinType::RightAnti),
case::right_mark(JoinType::RightMark)
)]
#[tokio::test]
async fn test_nl_join_with_swap_no_proj(join_type: JoinType) {
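The assertions in test_join_with_swap_mark above check that after JoinSelection runs, the statistically smaller input (8192 bytes) ends up on the left (build) side of the hash join and the larger one (2097152 bytes) on the probe side. A minimal sketch of that size-based preference, assuming a plain byte-count comparison; the function name is hypothetical and the real optimizer works with Statistics/Precision values rather than raw integers.

```rust
// Hypothetical simplification of the heuristic exercised by
// test_join_with_swap_mark: build the hash table on the smaller input,
// swapping the join's children when the left side is estimated to be larger.
fn should_swap_inputs(left_estimated_bytes: u64, right_estimated_bytes: u64) -> bool {
    left_estimated_bytes > right_estimated_bytes
}

fn main() {
    // Big input on the left, small on the right: the inputs should be swapped.
    assert!(should_swap_inputs(2_097_152, 8_192));
}
```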
5 changes: 4 additions & 1 deletion datafusion/expr/src/logical_plan/builder.rs
@@ -1696,7 +1696,10 @@ pub fn build_join_schema(
);

let (schema1, schema2) = match join_type {
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (left, right),
JoinType::Right
| JoinType::RightSemi
| JoinType::RightAnti
| JoinType::RightMark => (left, right),
_ => (right, left),
};

49 changes: 47 additions & 2 deletions datafusion/optimizer/src/decorrelate_predicate_subquery.rs
@@ -31,7 +31,7 @@ use datafusion_common::{internal_err, plan_err, Column, Result};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
use datafusion_expr::logical_plan::{JoinType, Subquery};
use datafusion_expr::utils::{conjunction, split_conjunction_owned};
use datafusion_expr::utils::{conjunction, expr_to_columns, split_conjunction_owned};
use datafusion_expr::{
exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
LogicalPlan, LogicalPlanBuilder, Operator,
@@ -342,7 +342,7 @@ fn build_join(
replace_qualified_name(filter, &all_correlated_cols, &alias).map(Some)
})?;

let join_filter = match (join_filter_opt, in_predicate_opt) {
let join_filter = match (join_filter_opt, in_predicate_opt.clone()) {
(
Some(join_filter),
Some(Expr::BinaryExpr(BinaryExpr {
@@ -371,6 +371,51 @@
(None, None) => lit(true),
_ => return Ok(None),
};

if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) {
let right_schema = sub_query_alias.schema();

// Gather all columns needed for the join filter + predicates
let mut needed = std::collections::HashSet::new();
expr_to_columns(&join_filter, &mut needed)?;
if let Some(ref in_pred) = in_predicate_opt {
expr_to_columns(in_pred, &mut needed)?;
}

// Keep only columns that actually belong to the RIGHT child, and sort by their
// position in the right schema for deterministic order.
let mut right_cols_idx_and_col: Vec<(usize, Column)> = needed
.into_iter()
.filter_map(|c| right_schema.index_of_column(&c).ok().map(|idx| (idx, c)))
.collect();

right_cols_idx_and_col.sort_by_key(|(idx, _)| *idx);

let right_proj_exprs: Vec<Expr> = right_cols_idx_and_col
.into_iter()
.map(|(_, c)| Expr::Column(c))
.collect();

let right_projected = if !right_proj_exprs.is_empty() {
LogicalPlanBuilder::from(sub_query_alias.clone())
.project(right_proj_exprs)?
.build()?
} else {
// Degenerate case: no right columns referenced by the predicate(s)
sub_query_alias.clone()
};
let new_plan = LogicalPlanBuilder::from(left.clone())
.join_on(right_projected, join_type, Some(join_filter))?
.build()?;

debug!(
"predicate subquery optimized:\n{}",
new_plan.display_indent()
);

return Ok(Some(new_plan));
}

// join our sub query into the main plan
let new_plan = LogicalPlanBuilder::from(left.clone())
.join_on(sub_query_alias, join_type, Some(join_filter))?
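To summarize the new mark-join branch above: before building the mark join, the rewrite gathers the columns referenced by the join filter (and the IN predicate, if present), keeps only those that belong to the subquery side, and projects the subquery down to them in schema order. Below is a standalone sketch of that pruning step, using plain strings in place of DataFusion's Column/DFSchema types; all names here are hypothetical.

```rust
use std::collections::HashSet;

// Illustrative sketch of the projection pruning performed for mark joins:
// keep only the referenced columns that exist in the right (subquery) schema,
// ordered by their position in that schema so the projection is deterministic.
fn prune_right_columns(needed: &HashSet<String>, right_schema: &[&str]) -> Vec<String> {
    let mut idx_and_col: Vec<(usize, String)> = needed
        .iter()
        .filter_map(|col| {
            right_schema
                .iter()
                .position(|field| *field == col.as_str())
                .map(|idx| (idx, col.clone()))
        })
        .collect();
    // Deterministic order: sort by the column's index in the right schema.
    idx_and_col.sort_by_key(|(idx, _)| *idx);
    idx_and_col.into_iter().map(|(_, col)| col).collect()
}

fn main() {
    let needed: HashSet<String> =
        ["c_id".to_string(), "o_id".to_string()].into_iter().collect();
    // Only columns present in the right schema survive, in schema order.
    assert_eq!(prune_right_columns(&needed, &["c_name", "c_id"]), vec!["c_id"]);
}
```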
6 changes: 5 additions & 1 deletion datafusion/physical-optimizer/src/join_selection.rs
@@ -514,7 +514,11 @@ pub(crate) fn swap_join_according_to_unboundedness(
match (*partition_mode, *join_type) {
(
_,
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full,
JoinType::Right
| JoinType::RightSemi
| JoinType::RightAnti
| JoinType::RightMark
| JoinType::Full,
) => internal_err!("{join_type} join cannot be swapped for unbounded input."),
(PartitionMode::Partitioned, _) => {
hash_join.swap_inputs(PartitionMode::Partitioned)
2 changes: 2 additions & 0 deletions datafusion/physical-plan/src/joins/hash_join/exec.rs
@@ -690,6 +690,8 @@ impl HashJoinExec {
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
) || self.projection.is_some()
{
Ok(Arc::new(new_join))
2 changes: 2 additions & 0 deletions datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -379,6 +379,8 @@ impl NestedLoopJoinExec {
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
) || self.projection.is_some()
{
Arc::new(new_join)
73 changes: 43 additions & 30 deletions datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
@@ -430,7 +430,7 @@ pub(super) fn get_corrected_filter_mask(
corrected_mask.append_n(expected_size - corrected_mask.len(), false);
Some(corrected_mask.finish())
}
JoinType::LeftMark => {
JoinType::LeftMark | JoinType::RightMark => {
for i in 0..row_indices_length {
let last_index =
last_index_for_row(i, row_indices, batch_ids, row_indices_length);
@@ -582,6 +582,7 @@ impl Stream for SortMergeJoinStream {
| JoinType::LeftMark
| JoinType::Right
| JoinType::RightSemi
| JoinType::RightMark
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::Full
@@ -691,6 +692,7 @@
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
| JoinType::Full
)
{
@@ -718,6 +720,7 @@
| JoinType::RightAnti
| JoinType::Full
| JoinType::LeftMark
| JoinType::RightMark
)
{
let record_batch = self.filter_joined_batch()?;
@@ -1042,16 +1045,23 @@ impl SortMergeJoinStream {
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
) {
join_streamed = !self.streamed_joined;
}
}
Ordering::Equal => {
if matches!(
self.join_type,
JoinType::LeftSemi | JoinType::LeftMark | JoinType::RightSemi
JoinType::LeftSemi
| JoinType::LeftMark
| JoinType::RightSemi
| JoinType::RightMark
) {
mark_row_as_match = matches!(self.join_type, JoinType::LeftMark);
mark_row_as_match = matches!(
self.join_type,
JoinType::LeftMark | JoinType::RightMark
);
// if the join filter is specified then its needed to output the streamed index
// only if it has not been emitted before
// the `join_filter_matched_idxs` keeps track on if streamed index has a successful
@@ -1266,31 +1276,32 @@ impl SortMergeJoinStream {

// The row indices of joined buffered batch
let right_indices: UInt64Array = chunk.buffered_indices.finish();
let mut right_columns = if matches!(self.join_type, JoinType::LeftMark) {
vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef]
} else if matches!(
self.join_type,
JoinType::LeftSemi
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::RightSemi
) {
vec![]
} else if let Some(buffered_idx) = chunk.buffered_batch_idx {
fetch_right_columns_by_idxs(
&self.buffered_data,
buffered_idx,
&right_indices,
)?
} else {
// If buffered batch none, meaning it is null joined batch.
// We need to create null arrays for buffered columns to join with streamed rows.
create_unmatched_columns(
self.join_type,
&self.buffered_schema,
right_indices.len(),
)
};
let mut right_columns =
if matches!(self.join_type, JoinType::LeftMark | JoinType::RightMark) {
vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef]
} else if matches!(
self.join_type,
JoinType::LeftSemi
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::RightSemi
) {
vec![]
} else if let Some(buffered_idx) = chunk.buffered_batch_idx {
fetch_right_columns_by_idxs(
&self.buffered_data,
buffered_idx,
&right_indices,
)?
} else {
// If buffered batch none, meaning it is null joined batch.
// We need to create null arrays for buffered columns to join with streamed rows.
create_unmatched_columns(
self.join_type,
&self.buffered_schema,
right_indices.len(),
)
};

// Prepare the columns we apply join filter on later.
// Only for joined rows between streamed and buffered.
@@ -1309,7 +1320,7 @@
get_filter_column(&self.filter, &left_columns, &right_cols)
} else if matches!(
self.join_type,
JoinType::RightAnti | JoinType::RightSemi
JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark
) {
let right_cols = fetch_right_columns_by_idxs(
&self.buffered_data,
@@ -1375,6 +1386,7 @@
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
| JoinType::Full
) {
self.staging_output_record_batches
@@ -1475,6 +1487,7 @@
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
| JoinType::Full
))
{
@@ -1537,7 +1550,7 @@

if matches!(
self.join_type,
JoinType::Left | JoinType::LeftMark | JoinType::Right
JoinType::Left | JoinType::LeftMark | JoinType::Right | JoinType::RightMark
) {
let null_mask = compute::not(corrected_mask)?;
let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?;
@@ -1658,7 +1671,7 @@ fn create_unmatched_columns(
schema: &SchemaRef,
size: usize,
) -> Vec<ArrayRef> {
if matches!(join_type, JoinType::LeftMark) {
if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) {
vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef]
} else {
schema
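As a small usage note for create_unmatched_columns above: when a streamed row has no buffered match, a mark join still emits the row, carrying a false mark rather than null-padded buffered columns. A minimal sketch of building such a column with arrow, mirroring the BooleanArray call already shown in the diff (the function name is illustrative):

```rust
use std::sync::Arc;
use arrow::array::{ArrayRef, BooleanArray};

// For mark joins, unmatched streamed rows get an all-false mark column
// instead of null arrays for the buffered side's columns.
fn unmatched_mark_column(size: usize) -> ArrayRef {
    Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef
}

fn main() {
    let col = unmatched_mark_column(3);
    assert_eq!(col.len(), 3);
}
```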