Skip to content
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 120 additions & 37 deletions datafusion/physical-plan/src/joins/hash_join/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use std::fmt;
use std::marker::PhantomData;
use std::mem::size_of;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};
Expand All @@ -26,7 +27,9 @@ use crate::filter_pushdown::{
ChildPushdownResult, FilterDescription, FilterPushdownPhase,
FilterPushdownPropagation,
};
use crate::joins::hash_join::shared_bounds::{ColumnBounds, SharedBoundsAccumulator};
use crate::joins::hash_join::shared_bounds::{
ColumnBounds, MinMaxColumnBounds, SharedBoundsAccumulator,
};
use crate::joins::hash_join::stream::{
BuildSide, BuildSideInitialState, HashJoinStream, HashJoinStreamState,
};
Expand Down Expand Up @@ -103,7 +106,7 @@ pub(super) struct JoinLeftData {
/// The MemoryReservation ensures proper tracking of memory resources throughout the join operation's lifecycle.
_reservation: MemoryReservation,
/// Bounds computed from the build side for dynamic filter pushdown
pub(super) bounds: Option<Vec<ColumnBounds>>,
pub(super) bounds: Option<Vec<Arc<dyn ColumnBounds>>>,
}

impl JoinLeftData {
Expand All @@ -115,7 +118,7 @@ impl JoinLeftData {
visited_indices_bitmap: SharedBitmapBuilder,
probe_threads_counter: AtomicUsize,
reservation: MemoryReservation,
bounds: Option<Vec<ColumnBounds>>,
bounds: Option<Vec<Arc<dyn ColumnBounds>>>,
) -> Self {
Self {
hash_map,
Expand Down Expand Up @@ -319,7 +322,7 @@ impl JoinLeftData {
/// Note this structure includes a [`OnceAsync`] that is used to coordinate the
/// loading of the left side with the processing in each output stream.
/// Therefore it cannot be [`Clone`].
pub struct HashJoinExec {
pub struct HashJoinExec<A: CollectLeftAccumulator + 'static = MinMaxLeftAccumulator> {
/// left (build) side which gets hashed
pub left: Arc<dyn ExecutionPlan>,
/// right (probe) side which are filtered by the hash table
Expand Down Expand Up @@ -358,6 +361,8 @@ pub struct HashJoinExec {
/// Set when dynamic filter pushdown is detected in handle_child_pushdown_result.
/// HashJoinExec also needs to keep a shared bounds accumulator for coordinating updates.
dynamic_filter: Option<HashJoinExecDynamicFilter>,
/// Phantom data for the bounds accumulator type
_phantom_accumulator: PhantomData<A>,
}

#[derive(Clone)]
Expand All @@ -369,7 +374,7 @@ struct HashJoinExecDynamicFilter {
bounds_accumulator: OnceLock<Arc<SharedBoundsAccumulator>>,
}

impl fmt::Debug for HashJoinExec {
impl<A: CollectLeftAccumulator> fmt::Debug for HashJoinExec<A> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HashJoinExec")
.field("left", &self.left)
Expand All @@ -391,14 +396,14 @@ impl fmt::Debug for HashJoinExec {
}
}

impl EmbeddedProjection for HashJoinExec {
impl<A: CollectLeftAccumulator + 'static> EmbeddedProjection for HashJoinExec<A> {
fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
self.with_projection(projection)
}
}

impl HashJoinExec {
/// Tries to create a new [HashJoinExec].
impl HashJoinExec<MinMaxLeftAccumulator> {
/// Tries to create a new [HashJoinExec] with a default `MinMaxLeftAccumulator` bounds accumulator.
///
/// # Error
/// This function errors when it is not possible to join the left and right sides on keys `on`.
Expand Down Expand Up @@ -460,6 +465,75 @@ impl HashJoinExec {
null_equality,
cache,
dynamic_filter: None,
_phantom_accumulator: PhantomData,
})
}
}

impl<A: CollectLeftAccumulator + 'static> HashJoinExec<A> {
/// Tries to create a new [HashJoinExec] with a custom bounds accumulator.
///
/// # Error
/// This function errors when it is not possible to join the left and right sides on keys `on`.
#[allow(clippy::too_many_arguments)]
pub fn try_new_with_accumulator(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
on: JoinOn,
filter: Option<JoinFilter>,
join_type: &JoinType,
projection: Option<Vec<usize>>,
partition_mode: PartitionMode,
null_equality: NullEquality,
) -> Result<Self> {
let left_schema = left.schema();
let right_schema = right.schema();
if on.is_empty() {
return plan_err!("On constraints in HashJoinExec should be non-empty");
}

check_join_is_valid(&left_schema, &right_schema, &on)?;

let (join_schema, column_indices) =
build_join_schema(&left_schema, &right_schema, join_type);

let random_state = HASH_JOIN_SEED;

let join_schema = Arc::new(join_schema);

// check if the projection is valid
can_project(&join_schema, projection.as_ref())?;

let cache = Self::compute_properties(
&left,
&right,
Arc::clone(&join_schema),
*join_type,
&on,
partition_mode,
projection.as_ref(),
)?;

// Initialize both dynamic filter and bounds accumulator to None
// They will be set later if dynamic filtering is enabled

Ok(HashJoinExec {
left,
right,
on,
filter,
join_type: *join_type,
join_schema,
left_fut: Default::default(),
random_state,
mode: partition_mode,
metrics: ExecutionPlanMetricsSet::new(),
projection,
column_indices,
null_equality,
cache,
dynamic_filter: None,
_phantom_accumulator: PhantomData,
})
}

Expand Down Expand Up @@ -549,7 +623,7 @@ impl HashJoinExec {
},
None => None,
};
Self::try_new(
Self::try_new_with_accumulator(
Arc::clone(&self.left),
Arc::clone(&self.right),
self.on.clone(),
Expand Down Expand Up @@ -665,7 +739,7 @@ impl HashJoinExec {
) -> Result<Arc<dyn ExecutionPlan>> {
let left = self.left();
let right = self.right();
let new_join = HashJoinExec::try_new(
let new_join = HashJoinExec::<A>::try_new_with_accumulator(
Arc::clone(right),
Arc::clone(left),
self.on()
Expand Down Expand Up @@ -699,7 +773,7 @@ impl HashJoinExec {
}
}

impl DisplayAs for HashJoinExec {
impl<A: CollectLeftAccumulator + 'static> DisplayAs for HashJoinExec<A> {
fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
Expand Down Expand Up @@ -763,7 +837,7 @@ impl DisplayAs for HashJoinExec {
}
}

impl ExecutionPlan for HashJoinExec {
impl<A: CollectLeftAccumulator + 'static> ExecutionPlan for HashJoinExec<A> {
fn name(&self) -> &'static str {
"HashJoinExec"
}
Expand Down Expand Up @@ -833,7 +907,7 @@ impl ExecutionPlan for HashJoinExec {
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(HashJoinExec {
Ok(Arc::new(HashJoinExec::<A> {
left: Arc::clone(&children[0]),
right: Arc::clone(&children[1]),
on: self.on.clone(),
Expand All @@ -858,11 +932,12 @@ impl ExecutionPlan for HashJoinExec {
)?,
// Keep the dynamic filter, bounds accumulator will be reset
dynamic_filter: self.dynamic_filter.clone(),
_phantom_accumulator: PhantomData,
}))
}

fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(HashJoinExec {
Ok(Arc::new(HashJoinExec::<A> {
left: Arc::clone(&self.left),
right: Arc::clone(&self.right),
on: self.on.clone(),
Expand All @@ -880,6 +955,7 @@ impl ExecutionPlan for HashJoinExec {
cache: self.cache.clone(),
// Reset dynamic filter and bounds accumulator to initial state
dynamic_filter: None,
_phantom_accumulator: PhantomData,
}))
}

Expand Down Expand Up @@ -921,7 +997,7 @@ impl ExecutionPlan for HashJoinExec {
let reservation =
MemoryConsumer::new("HashJoinInput").register(context.memory_pool());

Ok(collect_left_input(
Ok(collect_left_input::<MinMaxLeftAccumulator>(
self.random_state.clone(),
left_stream,
on_left.clone(),
Expand All @@ -939,7 +1015,7 @@ impl ExecutionPlan for HashJoinExec {
MemoryConsumer::new(format!("HashJoinInput[{partition}]"))
.register(context.memory_pool());

OnceFut::new(collect_left_input(
OnceFut::new(collect_left_input::<MinMaxLeftAccumulator>(
self.random_state.clone(),
left_stream,
on_left.clone(),
Expand Down Expand Up @@ -1162,7 +1238,7 @@ impl ExecutionPlan for HashJoinExec {
Arc::downcast::<DynamicFilterPhysicalExpr>(predicate)
{
// We successfully pushed down our self filter - we need to make a new node with the dynamic filter
let new_node = Arc::new(HashJoinExec {
let new_node = Arc::new(HashJoinExec::<A> {
left: Arc::clone(&self.left),
right: Arc::clone(&self.right),
on: self.on.clone(),
Expand All @@ -1181,6 +1257,7 @@ impl ExecutionPlan for HashJoinExec {
filter: dynamic_filter,
bounds_accumulator: OnceLock::new(),
}),
_phantom_accumulator: PhantomData,
});
result = result.with_updated_node(new_node as Arc<dyn ExecutionPlan>);
}
Expand All @@ -1189,6 +1266,15 @@ impl ExecutionPlan for HashJoinExec {
}
}

/// Accumulates bounds for one join-key expression over the build side of a
/// hash join.
///
/// `BuildSideState::try_new` creates one accumulator per `on_left` expression
/// via [`CollectLeftAccumulator::try_new`], and `collect_left_input` consumes
/// each accumulator with [`CollectLeftAccumulator::evaluate`] to produce the
/// bounds stored on `JoinLeftData`.
///
/// Implementors must be `Send + Sync`.
// NOTE(review): the `execute` hunks visible in this change call
// `collect_left_input::<MinMaxLeftAccumulator>` rather than
// `collect_left_input::<A>`, so a custom accumulator chosen through
// `HashJoinExec<A>` may never be used for bound collection — confirm.
pub trait CollectLeftAccumulator: Send + Sync {
/// Creates an accumulator for the join-key expression `expr`.
///
/// `schema` is the schema handed to `BuildSideState::try_new`
/// (presumably the build-side input schema — TODO confirm).
///
/// Requires `Self: Sized` because it returns `Self` by value.
fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &SchemaRef) -> Result<Self>
where
Self: Sized;

/// Folds `batch` into the accumulator's running state.
///
/// Presumably invoked once per build-side batch; the call site is not
/// visible in this view — TODO confirm.
fn update_batch(&mut self, batch: &RecordBatch) -> Result<()>;
/// Consumes the accumulator and returns the bounds it collected.
///
/// `collect_left_input` only calls this when the build side produced at
/// least one row.
fn evaluate(self) -> Result<Arc<dyn ColumnBounds>>;
}

/// Accumulator for collecting min/max bounds from build-side data during hash join.
///
/// This struct encapsulates the logic for progressively computing column bounds
Expand All @@ -1198,7 +1284,7 @@ impl ExecutionPlan for HashJoinExec {
/// The bounds are used for dynamic filter pushdown optimization, where filters
/// based on the actual data ranges can be pushed down to the probe side to
/// eliminate unnecessary data early.
struct CollectLeftAccumulator {
pub struct MinMaxLeftAccumulator {
/// The physical expression to evaluate for each batch
expr: Arc<dyn PhysicalExpr>,
/// Accumulator for tracking the minimum value across all batches
Expand All @@ -1207,7 +1293,7 @@ struct CollectLeftAccumulator {
max: MaxAccumulator,
}

impl CollectLeftAccumulator {
impl CollectLeftAccumulator for MinMaxLeftAccumulator {
/// Creates a new accumulator for tracking bounds of a join key expression.
///
/// # Arguments
Expand Down Expand Up @@ -1261,24 +1347,24 @@ impl CollectLeftAccumulator {
///
/// # Returns
/// The `ColumnBounds` containing the minimum and maximum values observed
fn evaluate(mut self) -> Result<ColumnBounds> {
Ok(ColumnBounds::new(
fn evaluate(mut self) -> Result<Arc<dyn ColumnBounds>> {
Ok(Arc::new(MinMaxColumnBounds::new(
self.min.evaluate()?,
self.max.evaluate()?,
))
)))
}
}

/// State for collecting the build-side data during hash join
struct BuildSideState {
struct BuildSideState<A: CollectLeftAccumulator> {
batches: Vec<RecordBatch>,
num_rows: usize,
metrics: BuildProbeJoinMetrics,
reservation: MemoryReservation,
bounds_accumulators: Option<Vec<CollectLeftAccumulator>>,
bounds_accumulators: Option<Vec<A>>,
}

impl BuildSideState {
impl<A: CollectLeftAccumulator> BuildSideState<A> {
/// Create a new BuildSideState with optional accumulators for bounds computation
fn try_new(
metrics: BuildProbeJoinMetrics,
Expand All @@ -1296,9 +1382,7 @@ impl BuildSideState {
.then(|| {
on_left
.iter()
.map(|expr| {
CollectLeftAccumulator::try_new(Arc::clone(expr), schema)
})
.map(|expr| A::try_new(Arc::clone(expr), schema))
.collect::<Result<Vec<_>>>()
})
.transpose()?,
Expand Down Expand Up @@ -1335,7 +1419,7 @@ impl BuildSideState {
/// `JoinLeftData` containing the hash map, consolidated batch, join key values,
/// visited indices bitmap, and computed bounds (if requested).
#[allow(clippy::too_many_arguments)]
async fn collect_left_input(
async fn collect_left_input<A: CollectLeftAccumulator>(
random_state: RandomState,
left_stream: SendableRecordBatchStream,
on_left: Vec<PhysicalExprRef>,
Expand All @@ -1350,7 +1434,7 @@ async fn collect_left_input(
// This operation performs 2 steps at once:
// 1. creates a [JoinHashMap] of all batches from the stream
// 2. stores the batches in a vector.
let initial = BuildSideState::try_new(
let initial = BuildSideState::<A>::try_new(
metrics,
reservation,
on_left.clone(),
Expand Down Expand Up @@ -1384,7 +1468,7 @@ async fn collect_left_input(
.await?;

// Extract fields from state
let BuildSideState {
let BuildSideState::<A> {
batches,
num_rows,
metrics,
Expand Down Expand Up @@ -1459,13 +1543,12 @@ async fn collect_left_input(

// Compute bounds for dynamic filter if enabled
let bounds = match bounds_accumulators {
Some(accumulators) if num_rows > 0 => {
let bounds = accumulators
Some(accumulators) if num_rows > 0 => Some(
accumulators
.into_iter()
.map(CollectLeftAccumulator::evaluate)
.collect::<Result<Vec<_>>>()?;
Some(bounds)
}
.map(|a| a.evaluate())
.collect::<Result<Vec<_>>>()?,
),
_ => None,
};

Expand Down
2 changes: 2 additions & 0 deletions datafusion/physical-plan/src/joins/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
//! [`HashJoinExec`] Partitioned Hash Join Operator

pub use exec::HashJoinExec;
pub use exec::{CollectLeftAccumulator, MinMaxLeftAccumulator};
pub use shared_bounds::ColumnBounds;

mod exec;
mod shared_bounds;
Expand Down
Loading