diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index d9a1478f02387..e59a165df398d 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -74,6 +74,15 @@ pub enum JoinType { RightMark, } +// Semi/anti joins intentionally omitted: they only emit subsets of the left input and thus do +// not "preserve" every row for dynamic filter purposes. +const LEFT_PRESERVING: &[JoinType] = + &[JoinType::Left, JoinType::Full, JoinType::LeftMark]; + +// Symmetric rationale applies on the right: semi/anti joins do not preserve all right rows. +const RIGHT_PRESERVING: &[JoinType] = + &[JoinType::Right, JoinType::Full, JoinType::RightMark]; + impl JoinType { pub fn is_outer(self) -> bool { self == JoinType::Left || self == JoinType::Right || self == JoinType::Full @@ -111,6 +120,31 @@ impl JoinType { | JoinType::RightAnti ) } + + /// Returns true if this join type preserves all rows from the specified `side`. + pub fn preserves(self, side: JoinSide) -> bool { + match side { + JoinSide::Left => LEFT_PRESERVING.contains(&self), + JoinSide::Right => RIGHT_PRESERVING.contains(&self), + JoinSide::None => false, + } + } + + /// Returns true if this join type preserves all rows from its left input. + /// + /// For [`JoinType::Left`], [`JoinType::Full`], and [`JoinType::LeftMark`] joins + /// every row from the left side will appear in the output at least once. + pub fn preserves_left(self) -> bool { + self.preserves(JoinSide::Left) + } + + /// Returns true if this join type preserves all rows from its right input. + /// + /// For [`JoinType::Right`], [`JoinType::Full`], and [`JoinType::RightMark`] joins + /// every row from the right side will appear in the output at least once. + pub fn preserves_right(self) -> bool { + self.preserves(JoinSide::Right) + } } impl Display for JoinType { @@ -194,3 +228,57 @@ impl JoinSide { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_join_type_swap() { + assert_eq!(JoinType::Inner.swap(), JoinType::Inner); + assert_eq!(JoinType::Left.swap(), JoinType::Right); + assert_eq!(JoinType::Right.swap(), JoinType::Left); + assert_eq!(JoinType::Full.swap(), JoinType::Full); + assert_eq!(JoinType::LeftSemi.swap(), JoinType::RightSemi); + assert_eq!(JoinType::RightSemi.swap(), JoinType::LeftSemi); + assert_eq!(JoinType::LeftAnti.swap(), JoinType::RightAnti); + assert_eq!(JoinType::RightAnti.swap(), JoinType::LeftAnti); + assert_eq!(JoinType::LeftMark.swap(), JoinType::RightMark); + assert_eq!(JoinType::RightMark.swap(), JoinType::LeftMark); + } + + #[test] + fn test_join_type_supports_swap() { + use JoinType::*; + let supported = [ + Inner, Left, Right, Full, LeftSemi, RightSemi, LeftAnti, RightAnti, + ]; + for jt in supported { + assert!(jt.supports_swap(), "{jt:?} should support swap"); + } + let not_supported = [LeftMark, RightMark]; + for jt in not_supported { + assert!(!jt.supports_swap(), "{jt:?} should not support swap"); + } + } + + #[test] + fn test_preserves_sides() { + use JoinSide::*; + + assert!(JoinType::Left.preserves(Left)); + assert!(JoinType::Full.preserves(Left)); + assert!(JoinType::LeftMark.preserves(Left)); + assert!(!JoinType::LeftSemi.preserves(Left)); + + assert!(JoinType::Right.preserves(Right)); + assert!(JoinType::Full.preserves(Right)); + assert!(JoinType::RightMark.preserves(Right)); + assert!(!JoinType::RightSemi.preserves(Right)); + + assert!(!JoinType::LeftAnti.preserves(Left)); + assert!(!JoinType::LeftAnti.preserves(Right)); + 
assert!(!JoinType::RightAnti.preserves(Left)); + assert!(!JoinType::RightAnti.preserves(Right)); + } +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs index 9f588519ecac3..0f478c926cc9d 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs @@ -18,7 +18,7 @@ use std::sync::{Arc, LazyLock}; use arrow::{ - array::record_batch, + array::{record_batch, Int32Array, RecordBatch}, datatypes::{DataType, Field, Schema, SchemaRef}, util::pretty::pretty_format_batches, }; @@ -33,7 +33,7 @@ use datafusion::{ prelude::{ParquetReadOptions, SessionConfig, SessionContext}, scalar::ScalarValue, }; -use datafusion_common::config::ConfigOptions; +use datafusion_common::{config::ConfigOptions, JoinType}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::ScalarUDF; use datafusion_functions::math::random::RandomFunc; @@ -41,7 +41,10 @@ use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::{ aggregate::AggregateExprBuilder, Partitioning, ScalarFunctionExpr, }; -use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr::{ + expressions::{col, lit, DynamicFilterPhysicalExpr}, + LexOrdering, PhysicalSortExpr, +}; use datafusion_physical_optimizer::{ filter_pushdown::FilterPushdown, PhysicalOptimizerRule, }; @@ -58,7 +61,10 @@ use datafusion_physical_plan::{ use futures::StreamExt; use object_store::{memory::InMemory, ObjectStore}; -use util::{format_plan_for_test, OptimizationTest, TestNode, TestScanBuilder}; +use util::{ + build_hash_join, build_topk, format_plan_for_test, sort_expr, OptimizationTest, + TestNode, TestScanBuilder, +}; mod util; @@ -166,8 +172,7 @@ fn test_pushdown_into_scan_with_config_options() { #[tokio::test] async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { - use datafusion_common::JoinType; - use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + use datafusion_physical_plan::joins::PartitionMode; // Create build side with limited values let build_batches = vec![record_batch!( @@ -176,15 +181,6 @@ async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { ("c", Float64, [1.0, 2.0]) ) .unwrap()]; - let build_side_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8View, false), - Field::new("c", DataType::Float64, false), - ])); - let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) - .with_support(true) - .with_batches(build_batches) - .build(); // Create probe side with more values let probe_batches = vec![record_batch!( @@ -193,44 +189,19 @@ async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { ("f", Float64, [1.0, 2.0, 3.0, 4.0]) ) .unwrap()]; - let probe_side_schema = Arc::new(Schema::new(vec![ - Field::new("d", DataType::Utf8, false), - Field::new("e", DataType::Utf8View, false), - Field::new("f", DataType::Float64, false), - ])); - let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) - .with_support(true) - .with_batches(probe_batches) - .build(); // Create HashJoinExec - let on = vec![( - col("a", &build_side_schema).unwrap(), - col("d", &probe_side_schema).unwrap(), - )]; - let join = Arc::new( - HashJoinExec::try_new( - build_scan, - probe_scan, - on, - None, - &JoinType::Inner, - None, - PartitionMode::Partitioned, - 
datafusion_common::NullEquality::NullEqualsNothing, - ) - .unwrap(), + let join = build_hash_join( + build_batches, + probe_batches, + vec![("a", "d")], + JoinType::Inner, + PartitionMode::Partitioned, ); - let join_schema = join.schema(); - // Finally let's add a SortExec on the outside to test pushdown of dynamic filters - let sort_expr = - PhysicalSortExpr::new(col("e", &join_schema).unwrap(), SortOptions::default()); - let plan = Arc::new( - SortExec::new(LexOrdering::new(vec![sort_expr]).unwrap(), join) - .with_fetch(Some(2)), - ) as Arc; + let sort_expr = sort_expr("e", &join.schema(), SortOptions::default()); + let plan = build_topk(join, vec![sort_expr], 2); let mut config = ConfigOptions::default(); config.optimizer.enable_dynamic_filter_pushdown = true; @@ -282,9 +253,19 @@ async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { // Dynamic filters arise in cases such as nested inner joins or TopK -> HashJoinExec -> Scan setups. #[tokio::test] async fn test_static_filter_pushdown_through_hash_join() { - use datafusion_common::JoinType; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8View, false), + Field::new("c", DataType::Float64, false), + ])); + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("d", DataType::Utf8, false), + Field::new("e", DataType::Utf8View, false), + Field::new("f", DataType::Float64, false), + ])); + // Create build side with limited values let build_batches = vec![record_batch!( ("a", Utf8, ["aa", "ab"]), @@ -292,15 +273,6 @@ async fn test_static_filter_pushdown_through_hash_join() { ("c", Float64, [1.0, 2.0]) ) .unwrap()]; - let build_side_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8View, false), - Field::new("c", DataType::Float64, false), - ])); - let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) - .with_support(true) - .with_batches(build_batches) - .build(); // Create probe side with more values let probe_batches = vec![record_batch!( @@ -309,33 +281,14 @@ async fn test_static_filter_pushdown_through_hash_join() { ("f", Float64, [1.0, 2.0, 3.0, 4.0]) ) .unwrap()]; - let probe_side_schema = Arc::new(Schema::new(vec![ - Field::new("d", DataType::Utf8, false), - Field::new("e", DataType::Utf8View, false), - Field::new("f", DataType::Float64, false), - ])); - let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) - .with_support(true) - .with_batches(probe_batches) - .build(); // Create HashJoinExec - let on = vec![( - col("a", &build_side_schema).unwrap(), - col("d", &probe_side_schema).unwrap(), - )]; - let join = Arc::new( - HashJoinExec::try_new( - build_scan, - probe_scan, - on, - None, - &JoinType::Inner, - None, - PartitionMode::Partitioned, - datafusion_common::NullEquality::NullEqualsNothing, - ) - .unwrap(), + let join = build_hash_join( + build_batches, + probe_batches, + vec![("a", "d")], + JoinType::Inner, + PartitionMode::Partitioned, ); // Create filters that can be pushed down to different sides @@ -353,8 +306,7 @@ async fn test_static_filter_pushdown_through_hash_join() { col("d", &join_schema).unwrap(), )) as Arc; - let filter = - Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap()); + let filter = Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join)).unwrap()); let filter = Arc::new(FilterExec::try_new(right_filter, filter).unwrap()); 
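    // Stack the filters: the left-side filter sits closest to the join, the
    // right-side filter wraps it, and the cross-side filter goes on top.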
let plan = Arc::new(FilterExec::try_new(cross_filter, filter).unwrap()) as Arc; @@ -836,8 +788,7 @@ async fn test_topk_dynamic_filter_pushdown_multi_column_sort() { #[tokio::test] async fn test_hashjoin_dynamic_filter_pushdown() { - use datafusion_common::JoinType; - use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + use datafusion_physical_plan::joins::PartitionMode; // Create build side with limited values let build_batches = vec![record_batch!( @@ -846,15 +797,6 @@ async fn test_hashjoin_dynamic_filter_pushdown() { ("c", Float64, [1.0, 2.0]) // Extra column not used in join ) .unwrap()]; - let build_side_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - Field::new("c", DataType::Float64, false), - ])); - let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) - .with_support(true) - .with_batches(build_batches) - .build(); // Create probe side with more values let probe_batches = vec![record_batch!( @@ -863,40 +805,15 @@ async fn test_hashjoin_dynamic_filter_pushdown() { ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join ) .unwrap()]; - let probe_side_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - Field::new("e", DataType::Float64, false), - ])); - let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) - .with_support(true) - .with_batches(probe_batches) - .build(); // Create HashJoinExec with dynamic filter - let on = vec![ - ( - col("a", &build_side_schema).unwrap(), - col("a", &probe_side_schema).unwrap(), - ), - ( - col("b", &build_side_schema).unwrap(), - col("b", &probe_side_schema).unwrap(), - ), - ]; - let plan = Arc::new( - HashJoinExec::try_new( - build_scan, - probe_scan, - on, - None, - &JoinType::Inner, - None, - PartitionMode::CollectLeft, - datafusion_common::NullEquality::NullEqualsNothing, - ) - .unwrap(), - ) as Arc; + let plan = build_hash_join( + build_batches, + probe_batches, + vec![("a", "a"), ("b", "b")], + JoinType::Inner, + PartitionMode::CollectLeft, + ); // expect the predicate to be pushed down into the probe side DataSource insta::assert_snapshot!( @@ -1219,6 +1136,20 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { ); } +fn build_int32_scan(values: &[i32]) -> Arc { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(values.to_vec()))], + ) + .unwrap(); + let batches = vec![batch]; + TestScanBuilder::new(schema) + .with_support(true) + .with_batches(batches) + .build() +} + #[tokio::test] async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { use datafusion_common::JoinType; @@ -1544,9 +1475,265 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() { ); } +// Verify dynamic filter pushdown for an INNER hash join. +// +// This test constructs a HashJoinExec where the right side is used to build a +// dynamic filter that should prune rows on the left side. It executes the +// plan with small test scans, enables dynamic filter pushdown, and asserts +// that the resulting plan contains a `DynamicFilterPhysicalExpr` on the probe +// (left) side and that the left scan's metrics reflect the pruning. 
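+//
+// Concretely, with build-side keys [1, 2] the filter installed on the probe
+// scan becomes `a@0 >= 1 AND a@0 <= 2`, so only 2 of the 4 probe-side rows
+// survive the scan (see the metrics assertion below).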
+#[tokio::test] +async fn test_hashjoin_dynamic_filter_pushdown_inner_join() { + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Left side with extra values that should be pruned by the dynamic filter + let left_scan = build_int32_scan(&[1, 2, 3, 4]); + let left_schema = left_scan.schema(); + + // Right side with limited values used to build the dynamic filter + let right_scan = build_int32_scan(&[1, 2]); + let right_schema = right_scan.schema(); + + let on = vec![( + col("a", &left_schema).unwrap(), + col("a", &right_schema).unwrap(), + )]; + + let plan = Arc::new( + HashJoinExec::try_new( + right_scan, + Arc::clone(&left_scan), + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ) as Arc; + + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + let plan_str = format_plan_for_test(&plan); + assert!(plan_str.contains("DynamicFilterPhysicalExpr [ a@0 >= 1 AND a@0 <= 2 ]")); + insta::assert_snapshot!( + format!("{}", plan_str), + @r" + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ a@0 >= 1 AND a@0 <= 2 ] + " + ); + assert_eq!(left_scan.metrics().unwrap().output_rows().unwrap(), 2); + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + insta::assert_snapshot!( + result, + @r" + +---+---+ + | a | a | + +---+---+ + | 1 | 1 | + | 2 | 2 | + +---+---+ + " + ); +} + +// Verify handling of a child pushdown result that carries a dynamic filter +// from the left child when the join expects dynamic filters from the left side. +// +// The test constructs a HashJoinExec with TestScanBuilders for left/right, builds a +// DynamicFilterPhysicalExpr over the left-side join keys and wraps it as an +// unsupported pushed predicate inside a `ChildPushdownResult` for the left child. +// It then calls `handle_child_pushdown_result(...)` and asserts that the returned +// propagation indicates the join node was updated. This ensures dynamic filter +// results reported by children are correctly processed and linked into the join +// operator when applicable. 
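+//
+// `self_filters` is indexed by child: entry 0 belongs to the left child and
+// entry 1 to the right child, which is why the dynamic filter sits in the
+// first slot below and the right child's slot stays empty.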
+#[test] +fn test_hashjoin_handle_child_pushdown_result_dynamic_filter_left() { + use datafusion_common::{JoinSide, NullEquality}; + use datafusion_physical_plan::filter_pushdown::{ + ChildPushdownResult, FilterPushdownPhase, PushedDownPredicate, + }; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Schemas for left and right inputs + let left_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let right_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create dummy scans + let left_scan = TestScanBuilder::new(Arc::clone(&left_schema)) + .with_support(true) + .build(); + let right_scan = TestScanBuilder::new(Arc::clone(&right_schema)) + .with_support(true) + .build(); + + let on = vec![( + col("a", &left_schema).unwrap(), + col("a", &right_schema).unwrap(), + )]; + + let join = HashJoinExec::try_new( + Arc::clone(&left_scan), + Arc::clone(&right_scan), + on.clone(), + None, + &JoinType::Right, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + ) + .unwrap(); + + assert_eq!(join.dynamic_filter_side(), JoinSide::Left); + + let keys: Vec<_> = on.iter().map(|(l, _)| l.clone()).collect(); + let df_expr = Arc::new(DynamicFilterPhysicalExpr::new(keys, lit(true))); + + let child_pushdown_result = ChildPushdownResult { + parent_filters: vec![], + self_filters: vec![ + vec![PushedDownPredicate::unsupported(df_expr.clone())], + vec![], + ], + }; + + let propagation = join + .handle_child_pushdown_result( + FilterPushdownPhase::Post, + child_pushdown_result, + &ConfigOptions::default(), + ) + .unwrap(); + + assert!(propagation.updated_node.is_some()); +} + +// Helper that builds and executes a FULL join used to validate dynamic filter +// behavior for `JoinType::Full`. +// +// The function constructs small left/right scans where the right side contains +// a subset of values present on the left. It enables dynamic filter pushdown +// and executes the plan to collect batches. Tests that call this helper assert +// that dynamic filters for FULL joins are applied appropriately and that +// scan metrics reflect the expected number of output rows (i.e., no incorrect +// pruning that would violate FULL join semantics). 
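+//
+// Since FULL joins preserve both inputs, `dynamic_filter_side()` returns
+// `JoinSide::None`, no dynamic filter is created, and both scans are expected
+// to report their full row counts (4 on the left, 2 on the right here).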
+async fn full_join_dynamic_filter_test() -> ( + Arc, + Arc, + Arc, + Vec, +) { + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Left side with extra values that would be pruned if dynamic filters applied + let left_scan = build_int32_scan(&[1, 2, 3, 4]); + let left_schema = left_scan.schema(); + + // Right side with limited values used for filter construction + let right_scan = build_int32_scan(&[1, 2]); + let right_schema = right_scan.schema(); + + let on = vec![( + col("a", &left_schema).unwrap(), + col("a", &right_schema).unwrap(), + )]; + + let plan = Arc::new( + HashJoinExec::try_new( + Arc::clone(&left_scan), + Arc::clone(&right_scan), + on, + None, + &JoinType::Full, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ) as Arc; + + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + (plan, left_scan, right_scan, batches) +} + +// Verify dynamic filter behavior for a FULL hash join. +// +// This test uses the `full_join_dynamic_filter_test` helper to build and run +// a FULL join where the right side contains a subset of the left values. It +// asserts that dynamic filter pushdown does not incorrectly prune rows that +// must be preserved by FULL join semantics and that scan metrics/reporting +// reflect the expected output row counts. 
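+//
+// Note that, unlike the inner-join snapshot above, neither DataSourceExec in
+// the expected plan carries a `predicate=DynamicFilterPhysicalExpr [...]`
+// annotation.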
+#[tokio::test] +async fn test_hashjoin_dynamic_filter_pushdown_full_join() { + let (plan, left_scan, right_scan, batches) = full_join_dynamic_filter_test().await; + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - HashJoinExec: mode=CollectLeft, join_type=Full, on=[(a@0, a@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a], file_type=test, pushdown_supported=true + " + ); + assert_eq!(left_scan.metrics().unwrap().output_rows().unwrap(), 4); + assert_eq!(right_scan.metrics().unwrap().output_rows().unwrap(), 2); + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + insta::assert_snapshot!( + result, + @r" + +---+---+ + | a | a | + +---+---+ + | 1 | 1 | + | 2 | 2 | + | 3 | | + | 4 | | + +---+---+ + " + ); +} + #[tokio::test] async fn test_hashjoin_parent_filter_pushdown() { - use datafusion_common::JoinType; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; // Create build side with limited values diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs index 2fe705b14921a..ecf0ae2db51fa 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs @@ -17,26 +17,30 @@ use arrow::datatypes::SchemaRef; use arrow::{array::RecordBatch, compute::concat_batches}; +use arrow_schema::SortOptions; use datafusion::{datasource::object_store::ObjectStoreUrl, physical_plan::PhysicalExpr}; -use datafusion_common::{config::ConfigOptions, internal_err, Result, Statistics}; +use datafusion_common::{ + config::ConfigOptions, internal_err, JoinType, Result, Statistics, +}; use datafusion_datasource::{ file::FileSource, file_meta::FileMeta, file_scan_config::FileScanConfig, file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture, file_stream::FileOpener, schema_adapter::DefaultSchemaAdapterFactory, schema_adapter::SchemaAdapterFactory, source::DataSourceExec, PartitionedFile, }; +use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_physical_plan::filter::batch_filter; -use datafusion_physical_plan::filter_pushdown::{FilterPushdownPhase, PushedDown}; use datafusion_physical_plan::{ displayable, - filter::FilterExec, + filter::{batch_filter, FilterExec}, filter_pushdown::{ ChildFilterDescription, ChildPushdownResult, FilterDescription, - FilterPushdownPropagation, + FilterPushdownPhase, FilterPushdownPropagation, PushedDown, }, + joins::{HashJoinExec, PartitionMode}, metrics::ExecutionPlanMetricsSet, + sorts::sort::SortExec, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; use futures::StreamExt; @@ -300,6 +304,79 @@ impl TestScanBuilder { } } +/// Create a [`DataSourceExec`] from the provided [`RecordBatch`]es that +/// supports filter pushdown. All batches must have the same schema. +pub fn build_scan(batches: Vec) -> Arc { + assert!(!batches.is_empty(), "batches must not be empty"); + let schema = Arc::clone(&batches[0].schema()); + TestScanBuilder::new(schema) + .with_support(true) + .with_batches(batches) + .build() +} + +/// Create a [`HashJoinExec`] joining the provided batches on the named columns. 
+/// +/// The `on` parameter specifies the join keys as pairs of column names +/// `(left, right)`. +pub fn build_hash_join( + left_batches: Vec, + right_batches: Vec, + on: Vec<(&str, &str)>, + join_type: JoinType, + partition_mode: PartitionMode, +) -> Arc { + let left_schema = Arc::clone(&left_batches[0].schema()); + let right_schema = Arc::clone(&right_batches[0].schema()); + let left = build_scan(left_batches); + let right = build_scan(right_batches); + let on = on + .into_iter() + .map(|(l, r)| { + ( + col(l, &left_schema).unwrap(), + col(r, &right_schema).unwrap(), + ) + }) + .collect(); + Arc::new( + HashJoinExec::try_new( + left, + right, + on, + None, + &join_type, + None, + partition_mode, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ) +} + +/// Create a [`SortExec`] configured for TopK behaviour with the given +/// sort expressions and fetch value. +pub fn build_topk( + input: Arc, + sort_exprs: Vec, + fetch: usize, +) -> Arc { + Arc::new( + SortExec::new(LexOrdering::new(sort_exprs).unwrap(), input) + .with_fetch(Some(fetch)), + ) +} + +/// Convenience function to create a [`PhysicalSortExpr`] given a column name +/// and [`SortOptions`]. +pub fn sort_expr( + column: &str, + schema: &SchemaRef, + options: SortOptions, +) -> PhysicalSortExpr { + PhysicalSortExpr::new(col(column, schema).unwrap(), options) +} + /// Index into the data that has been returned so far #[derive(Debug, Default, Clone)] pub struct BatchIndex { diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index c8ed1960393ca..79fe93529dd1e 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -29,6 +29,7 @@ use crate::filter_pushdown::{ use crate::joins::hash_join::shared_bounds::{ColumnBounds, SharedBoundsAccumulator}; use crate::joins::hash_join::stream::{ BuildSide, BuildSideInitialState, HashJoinStream, HashJoinStreamState, + ProbeSideBoundsAccumulator, }; use crate::joins::join_hash_map::{JoinHashMapU32, JoinHashMapU64}; use crate::joins::utils::{ @@ -63,7 +64,8 @@ use arrow_schema::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ - internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, Result, + internal_datafusion_err, internal_err, plan_err, project_schema, JoinSide, JoinType, + NullEquality, Result, }; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; @@ -84,6 +86,29 @@ use parking_lot::Mutex; const HASH_JOIN_SEED: RandomState = RandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64); +fn dynamic_filter_pushdown_side(join_type: JoinType) -> JoinSide { + use JoinSide::*; + + let preserves_left = join_type.preserves_left(); + let preserves_right = join_type.preserves_right(); + + match (preserves_left, preserves_right) { + (true, true) => None, + (true, false) => Right, + (false, true) => Left, + (false, false) => match join_type { + // Semi/anti joins do not preserve the build-side rows, but we still + // want dynamic filters that reference those rows to run there. By + // keeping them on the non-preserving side we avoid pushing the + // filter to the opposite input and incorrectly filtering the + // non-preserved (outer) rows instead. 
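+            // For example, a LeftSemi join emits only the left rows that have
+            // a match, so bounds gathered from the right (probe) input can be
+            // installed on the left input to skip left rows that can never
+            // match.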
+            JoinType::LeftSemi | JoinType::LeftAnti => Left,
+            JoinType::RightSemi | JoinType::RightAnti => Right,
+            _ => Right,
+        },
+    }
+}
+
 /// HashTable and input data for the left (build side) of a join
 pub(super) struct JoinLeftData {
     /// The hash table with indices into `batch`
@@ -463,12 +488,25 @@ impl HashJoinExec {
         })
     }
 
-    fn create_dynamic_filter(on: &JoinOn) -> Arc<DynamicFilterPhysicalExpr> {
-        // Extract the right-side keys (probe side keys) from the `on` clauses
-        // Dynamic filter will be created from build side values (left side) and applied to probe side (right side)
-        let right_keys: Vec<_> = on.iter().map(|(_, r)| Arc::clone(r)).collect();
+    fn join_exprs_for_side(on: &JoinOn, side: JoinSide) -> Vec<PhysicalExprRef> {
+        match side {
+            JoinSide::Left => on.iter().map(|(l, _)| Arc::clone(l)).collect(),
+            JoinSide::Right => on.iter().map(|(_, r)| Arc::clone(r)).collect(),
+            JoinSide::None => Vec::new(),
+        }
+    }
+
+    fn create_dynamic_filter(
+        on: &JoinOn,
+        side: JoinSide,
+    ) -> Result<Arc<DynamicFilterPhysicalExpr>> {
+        if side == JoinSide::None {
+            return internal_err!("dynamic filter side must be specified");
+        }
+        // Extract the join key expressions from the side that will receive the dynamic filter
+        let keys = Self::join_exprs_for_side(on, side);
         // Initialize with a placeholder expression (true) that will be updated when the hash table is built
-        Arc::new(DynamicFilterPhysicalExpr::new(right_keys, lit(true)))
+        Ok(Arc::new(DynamicFilterPhysicalExpr::new(keys, lit(true))))
     }
 
     /// left (build) side which gets hashed
@@ -481,6 +519,11 @@ impl HashJoinExec {
         &self.right
     }
 
+    /// Preferred input side for installing a dynamic filter.
+    pub fn dynamic_filter_side(&self) -> JoinSide {
+        dynamic_filter_pushdown_side(self.join_type)
+    }
+
     /// Set of common columns used to join on
     pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
         &self.on
@@ -912,6 +955,7 @@ impl ExecutionPlan for HashJoinExec {
         }
 
         let enable_dynamic_filter_pushdown = self.dynamic_filter.is_some();
+        let df_side = self.dynamic_filter_side();
         let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
 
         let left_fut = match self.mode {
@@ -929,7 +973,7 @@
                     reservation,
                     need_produce_result_in_final(self.join_type),
                     self.right().output_partitioning().partition_count(),
-                    enable_dynamic_filter_pushdown,
+                    enable_dynamic_filter_pushdown && df_side == JoinSide::Right,
                 ))
             })?,
             PartitionMode::Partitioned => {
@@ -947,7 +991,7 @@
                     reservation,
                     need_produce_result_in_final(self.join_type),
                     1,
-                    enable_dynamic_filter_pushdown,
+                    enable_dynamic_filter_pushdown && df_side == JoinSide::Right,
                 ))
             }
             PartitionMode::Auto => {
@@ -960,23 +1004,22 @@
 
         let batch_size = context.session_config().batch_size();
 
+        // Select join expressions for dynamic filter side
+        let dynamic_filter_on = Self::join_exprs_for_side(&self.on, df_side);
+
         // Initialize bounds_accumulator lazily with runtime partition counts (only if enabled)
         let bounds_accumulator = enable_dynamic_filter_pushdown
             .then(|| {
                 self.dynamic_filter.as_ref().map(|df| {
                     let filter = Arc::clone(&df.filter);
-                    let on_right = self
-                        .on
-                        .iter()
-                        .map(|(_, right_expr)| Arc::clone(right_expr))
-                        .collect::<Vec<_>>();
+                    let on = dynamic_filter_on.clone();
                     Some(Arc::clone(df.bounds_accumulator.get_or_init(|| {
                         Arc::new(SharedBoundsAccumulator::new_from_partition_mode(
                             self.mode,
                             self.left.as_ref(),
                             self.right.as_ref(),
                             filter,
-                            on_right,
+                            on,
                         ))
                     })))
                 })
@@ -984,9 +1027,10 @@
            .flatten()
            .flatten();
 
-        // we have the batches and the hash map with their keys. We can how create a stream
+        // we have the batches and the hash map with their keys. We can now create a stream
         // over the right that uses this information to issue new batches.
         let right_stream = self.right.execute(partition, context)?;
+        let right_schema = right_stream.schema();
 
         // update column indices to reflect the projection
         let column_indices_after_projection = match &self.projection {
@@ -1003,12 +1047,30 @@
             .map(|(_, right_expr)| Arc::clone(right_expr))
             .collect::<Vec<_>>();
 
+        let probe_bounds_accumulators =
+            if enable_dynamic_filter_pushdown && df_side == JoinSide::Left {
+                Some(
+                    on_right
+                        .iter()
+                        .map(|expr| {
+                            ProbeSideBoundsAccumulator::try_new(
+                                Arc::clone(expr),
+                                &right_schema,
+                            )
+                        })
+                        .collect::<Result<Vec<_>>>()?,
+                )
+            } else {
+                None
+            };
+
         Ok(Box::pin(HashJoinStream::new(
             partition,
             self.schema(),
             on_right,
             self.filter.clone(),
             self.join_type,
+            df_side,
             right_stream,
             self.random_state.clone(),
             join_metrics,
@@ -1020,6 +1082,7 @@
             vec![],
             self.right.output_ordering().is_some(),
             bounds_accumulator,
+            probe_bounds_accumulators,
             self.mode,
         )))
     }
@@ -1096,6 +1159,34 @@
         phase: FilterPushdownPhase,
         parent_filters: Vec<Arc<dyn PhysicalExpr>>,
         config: &ConfigOptions,
+    ) -> Result<FilterDescription> {
+        let df_side = self.dynamic_filter_side();
+        self.gather_filters_for_pushdown_with_side(phase, parent_filters, config, df_side)
+    }
+
+    fn handle_child_pushdown_result(
+        &self,
+        phase: FilterPushdownPhase,
+        child_pushdown_result: ChildPushdownResult,
+        config: &ConfigOptions,
+    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
+        let df_side = self.dynamic_filter_side();
+        self.handle_child_pushdown_result_with_side(
+            phase,
+            child_pushdown_result,
+            config,
+            df_side,
+        )
+    }
+}
+
+impl HashJoinExec {
+    fn gather_filters_for_pushdown_with_side(
+        &self,
+        phase: FilterPushdownPhase,
+        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
+        config: &ConfigOptions,
+        df_side: JoinSide,
     ) -> Result<FilterDescription> {
         // Other types of joins can support *some* filters, but restrictions are complex and error prone.
         // For now we don't support them.
@@ -1109,7 +1200,7 @@
         }
 
         // Get basic filter descriptions for both children
-        let left_child = crate::filter_pushdown::ChildFilterDescription::from_child(
+        let mut left_child = crate::filter_pushdown::ChildFilterDescription::from_child(
             &parent_filters,
             self.left(),
         )?;
@@ -1122,9 +1213,25 @@
         if matches!(phase, FilterPushdownPhase::Post)
             && config.optimizer.enable_dynamic_filter_pushdown
         {
-            // Add actual dynamic filter to right side (probe side)
-            let dynamic_filter = Self::create_dynamic_filter(&self.on);
-            right_child = right_child.with_self_filter(dynamic_filter);
+            if df_side == JoinSide::None {
+                // A join type that preserves both sides (e.g. FULL) cannot
+                // leverage dynamic filters. Return early before attempting to
+                // create one.
+                return Ok(FilterDescription::new()
+                    .with_child(left_child)
+                    .with_child(right_child));
+            }
+
+            let dynamic_filter = Self::create_dynamic_filter(&self.on, df_side)?;
+            match df_side {
+                JoinSide::Left => {
+                    left_child = left_child.with_self_filter(dynamic_filter);
+                }
+                JoinSide::Right => {
+                    right_child = right_child.with_self_filter(dynamic_filter);
+                }
+                JoinSide::None => unreachable!(),
+            }
         }
 
         Ok(FilterDescription::new()
             .with_child(left_child)
             .with_child(right_child))
     }
 
-    fn handle_child_pushdown_result(
+    fn handle_child_pushdown_result_with_side(
         &self,
         _phase: FilterPushdownPhase,
         child_pushdown_result: ChildPushdownResult,
         _config: &ConfigOptions,
+        df_side: JoinSide,
     ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
         // Note: this check shouldn't be necessary because we already marked all parent filters as unsupported for
         // non-inner joins in `gather_filters_for_pushdown`.
         // However it's a cheap check and serves to inform future devs touching this function that they need to be really
         // careful pushing down filters through non-inner joins.
-        if self.join_type != JoinType::Inner {
-            // Other types of joins can support *some* filters, but restrictions are complex and error prone.
-            // For now we don't support them.
-            // See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs
+        if df_side == JoinSide::None {
+            // Joins that preserve both sides (e.g. FULL) cannot leverage dynamic filters.
             return Ok(FilterPushdownPropagation::all_unsupported(
                 child_pushdown_result,
             ));
@@ -1153,38 +1259,43 @@
         let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone());
         assert_eq!(child_pushdown_result.self_filters.len(), 2); // Should always be 2, we have 2 children
-        let right_child_self_filters = &child_pushdown_result.self_filters[1]; // We only push down filters to the right child
-        // We expect 0 or 1 self filters
-        if let Some(filter) = right_child_self_filters.first() {
+        let filter_child_idx = match df_side {
+            JoinSide::Left => 0,
+            JoinSide::Right => 1,
+            JoinSide::None => return Ok(result),
+        };
+        // We expect 0 or 1 self filters
+        let child_self_filters = &child_pushdown_result.self_filters[filter_child_idx];
+        if let Some(filter) = child_self_filters.first() {
             // Note that we don't check PushdDownPredicate::discrimnant because even if nothing said
             // "yes, I can fully evaluate this filter" things might still use it for statistics -> it's worth updating
             let predicate = Arc::clone(&filter.predicate);
-            if let Ok(dynamic_filter) =
-                Arc::downcast::<DynamicFilterPhysicalExpr>(predicate)
-            {
-                // We successfully pushed down our self filter - we need to make a new node with the dynamic filter
-                let new_node = Arc::new(HashJoinExec {
-                    left: Arc::clone(&self.left),
-                    right: Arc::clone(&self.right),
-                    on: self.on.clone(),
-                    filter: self.filter.clone(),
-                    join_type: self.join_type,
-                    join_schema: Arc::clone(&self.join_schema),
-                    left_fut: Arc::clone(&self.left_fut),
-                    random_state: self.random_state.clone(),
-                    mode: self.mode,
-                    metrics: ExecutionPlanMetricsSet::new(),
-                    projection: self.projection.clone(),
-                    column_indices: self.column_indices.clone(),
-                    null_equality: self.null_equality,
-                    cache: self.cache.clone(),
-                    dynamic_filter: Some(HashJoinExecDynamicFilter {
-                        filter: dynamic_filter,
-                        bounds_accumulator: OnceLock::new(),
-                    }),
-                });
-                result = result.with_updated_node(new_node as Arc<dyn ExecutionPlan>);
-            }
+            let dynamic_filter = Arc::downcast::<DynamicFilterPhysicalExpr>(predicate)
+                .map_err(|_| {
+                    internal_datafusion_err!("expected DynamicFilterPhysicalExpr")
+                })?;
+            // We successfully pushed down our self filter - we need to make a new node with the dynamic filter
+            let new_node = Arc::new(HashJoinExec {
+                left: Arc::clone(&self.left),
+                right: Arc::clone(&self.right),
+                on: self.on.clone(),
+                filter: self.filter.clone(),
+                join_type: self.join_type,
+                join_schema: Arc::clone(&self.join_schema),
+                left_fut: Arc::clone(&self.left_fut),
+                random_state: self.random_state.clone(),
+                mode: self.mode,
+                metrics: ExecutionPlanMetricsSet::new(),
+                projection: self.projection.clone(),
+                column_indices: self.column_indices.clone(),
+                null_equality: self.null_equality,
+                cache: self.cache.clone(),
+                dynamic_filter: Some(HashJoinExecDynamicFilter {
+                    filter: dynamic_filter,
+                    bounds_accumulator: OnceLock::new(),
+                }),
+            });
+            result = result.with_updated_node(new_node as Arc<dyn ExecutionPlan>);
         }
         Ok(result)
     }
@@ -1576,6 +1687,67 @@ mod tests {
         )
     }
 
+    #[test]
+    fn create_dynamic_filter_none_side_returns_error() {
+        let on: JoinOn = vec![];
+        let err = HashJoinExec::create_dynamic_filter(&on, JoinSide::None).unwrap_err();
+        assert_contains!(err.to_string(), "dynamic filter side must be specified");
+    }
+
+    #[rstest]
+    #[case(JoinType::Inner, JoinSide::Right)]
+    #[case(JoinType::Left, JoinSide::Right)]
+    #[case(JoinType::Right, JoinSide::Left)]
+    #[case(JoinType::Full, JoinSide::None)]
+    #[case(JoinType::LeftMark, JoinSide::Right)]
+    #[case(JoinType::RightMark, JoinSide::Left)]
+    #[case(JoinType::LeftSemi, JoinSide::Left)]
+    #[case(JoinType::RightSemi, JoinSide::Right)]
+    #[case(JoinType::LeftAnti, JoinSide::Left)]
+    #[case(JoinType::RightAnti, JoinSide::Right)]
+    fn dynamic_filter_side_prefers_non_preserved_input(
+        #[case] join_type: JoinType,
+        #[case] expected_side: JoinSide,
+    ) {
+        assert_eq!(dynamic_filter_pushdown_side(join_type), expected_side);
+    }
+
+    #[test]
+    fn full_join_skips_dynamic_filter_creation() -> Result<()> {
+        use arrow::array::Int32Array;
+        use arrow::datatypes::{DataType, Field, Schema};
+        use datafusion_physical_expr::expressions::col;
+
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(vec![1]))],
+        )?;
+        let left =
+            TestMemoryExec::try_new(&[vec![batch.clone()]], Arc::clone(&schema), None)?;
+        let right = TestMemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None)?;
+
+        let on = vec![(col("a", &left.schema())?, col("a", &right.schema())?)];
+        let join = HashJoinExec::try_new(
+            Arc::new(left),
+            Arc::new(right),
+            on,
+            None,
+            &JoinType::Full,
+            None,
+            PartitionMode::CollectLeft,
+            NullEquality::NullEqualsNull,
+        )?;
+
+        let mut config = ConfigOptions::default();
+        config.optimizer.enable_dynamic_filter_pushdown = true;
+
+        let desc =
+            join.gather_filters_for_pushdown(FilterPushdownPhase::Post, vec![], &config)?;
+        assert!(desc.self_filters().iter().all(|f| f.is_empty()));
+        Ok(())
+    }
+
     async fn join_collect(
         left: Arc<dyn ExecutionPlan>,
         right: Arc<dyn ExecutionPlan>,
@@ -4377,6 +4549,71 @@
         Ok(())
     }
 
+    // This test verifies that when a HashJoinExec installs its dynamic filter
+    // on the left side, the probe phase collects min/max bounds from the
+    // probe-side (right) input and reports them into the filter attached to
+    // the left side. Concretely:
+    // - Left input has values [1, 3, 5]
+    // - Right (probe) input has values [2, 4, 6]
+    // - JoinType::Right is used: every right-side row must be preserved, so
+    //   the dynamic filter is attached to the left side expression.
+    // - After fully executing the join, the dynamic filter should be updated
+    //   with the observed bounds `a@0 >= 2 AND a@0 <= 6` (min=2, max=6).
+    // The test asserts that HashJoinExec correctly accumulates and reports these
+    // bounds so downstream consumers can use the dynamic predicate for pruning.
+    #[tokio::test]
+    async fn reports_bounds_when_dynamic_filter_side_left() -> Result<()> {
+        use datafusion_physical_expr::expressions::col;
+
+        let task_ctx = Arc::new(TaskContext::default());
+
+        let left_schema =
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+        let left_batch = RecordBatch::try_new(
+            Arc::clone(&left_schema),
+            vec![Arc::new(Int32Array::from(vec![1, 3, 5]))],
+        )?;
+        let left = TestMemoryExec::try_new(&[vec![left_batch]], left_schema, None)?;
+
+        let right_schema =
+            Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, false)]));
+        let right_batch = RecordBatch::try_new(
+            Arc::clone(&right_schema),
+            vec![Arc::new(Int32Array::from(vec![2, 4, 6]))],
+        )?;
+        let right = TestMemoryExec::try_new(&[vec![right_batch]], right_schema, None)?;
+
+        let on = vec![(col("a", &left.schema())?, col("b", &right.schema())?)];
+
+        let mut join = HashJoinExec::try_new(
+            Arc::new(left),
+            Arc::new(right),
+            on,
+            None,
+            &JoinType::Right,
+            None,
+            PartitionMode::CollectLeft,
+            NullEquality::NullEqualsNull,
+        )?;
+
+        let dynamic_filter =
+            HashJoinExec::create_dynamic_filter(&join.on, JoinSide::Left)?;
+        join.dynamic_filter = Some(HashJoinExecDynamicFilter {
+            filter: Arc::clone(&dynamic_filter),
+            bounds_accumulator: OnceLock::new(),
+        });
+
+        let stream = join.execute(0, task_ctx)?;
+        let _batches: Vec<RecordBatch> = stream.try_collect().await?;
+
+        assert_eq!(
+            format!("{}", dynamic_filter.current().unwrap()),
+            "a@0 >= 2 AND a@0 <= 6"
+        );
+
+        Ok(())
+    }
+
     fn build_table_struct(
         struct_name: &str,
         field_name_and_values: (&str, &Vec<Option<i32>>),
diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs
index 25f7a0de31acd..e6c2a43ad9182 100644
--- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs
+++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs
@@ -105,10 +105,10 @@ pub(crate) struct SharedBoundsAccumulator {
     /// Shared state protected by a single mutex to avoid ordering concerns
     inner: Mutex<SharedBoundsState>,
     barrier: Barrier,
-    /// Dynamic filter for pushdown to probe side
+    /// Dynamic filter for pushdown
     dynamic_filter: Arc<DynamicFilterPhysicalExpr>,
-    /// Right side join expressions needed for creating filter bounds
-    on_right: Vec<PhysicalExprRef>,
+    /// Join expressions on the side receiving the dynamic filter
+    join_exprs: Vec<PhysicalExprRef>,
 }
 
 /// State protected by SharedBoundsAccumulator's mutex
@@ -149,7 +149,7 @@ impl SharedBoundsAccumulator {
         left_child: &dyn ExecutionPlan,
         right_child: &dyn ExecutionPlan,
         dynamic_filter: Arc<DynamicFilterPhysicalExpr>,
-        on_right: Vec<PhysicalExprRef>,
+        join_exprs: Vec<PhysicalExprRef>,
     ) -> Self {
         // Troubleshooting: If partition counts are incorrect, verify this logic matches
         // the actual execution pattern in collect_build_side()
@@ -171,7 +171,7 @@
             }),
             barrier: Barrier::new(expected_calls),
             dynamic_filter,
-            on_right,
+            join_exprs,
         }
     }
@@ -199,16 +199,16 @@
         // Create range predicates for each join key in this partition
         let mut column_predicates = Vec::with_capacity(partition_bounds.len());
 
-        for (col_idx, right_expr) in self.on_right.iter().enumerate() {
+        for (col_idx, expr) in self.join_exprs.iter().enumerate() {
             if let Some(column_bounds) =
partition_bounds.get_column_bounds(col_idx) { // Create predicate: col >= min AND col <= max let min_expr = Arc::new(BinaryExpr::new( - Arc::clone(right_expr), + Arc::clone(expr), Operator::GtEq, lit(column_bounds.min.clone()), )) as Arc; let max_expr = Arc::new(BinaryExpr::new( - Arc::clone(right_expr), + Arc::clone(expr), Operator::LtEq, lit(column_bounds.max.clone()), )) as Arc; @@ -311,3 +311,67 @@ impl fmt::Debug for SharedBoundsAccumulator { write!(f, "SharedBoundsAccumulator") } } +#[cfg(test)] +mod tests { + use super::*; + use crate::empty::EmptyExec; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common_runtime::SpawnedTask; + use datafusion_physical_expr::expressions::{col, lit, DynamicFilterPhysicalExpr}; + use tokio::task; + + // This test verifies the synchronization behavior of `SharedBoundsAccumulator`. + // It ensures that the dynamic filter is not updated until all expected + // partitions have reported their build-side bounds. One partition reports + // in a spawned task while the test reports another; the dynamic filter + // should remain the default until the final partition arrives, at which + // point the accumulated bounds are combined and the dynamic filter is + // updated exactly once with range predicates (>= and <=) for the join key. + #[tokio::test] + async fn waits_for_all_partitions_before_updating() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let left = EmptyExec::new(Arc::clone(&schema)).with_partitions(2); + let right = EmptyExec::new(Arc::clone(&schema)).with_partitions(2); + let col_expr = col("a", &schema).unwrap(); + let dynamic = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_expr)], + lit(true), + )); + let acc = Arc::new(SharedBoundsAccumulator::new_from_partition_mode( + PartitionMode::Partitioned, + &left, + &right, + Arc::clone(&dynamic), + vec![Arc::clone(&col_expr)], + )); + + assert_eq!(format!("{}", dynamic.current().unwrap()), "true"); + + let acc0 = Arc::clone(&acc); + let handle = SpawnedTask::spawn(async move { + acc0.report_partition_bounds( + 0, + Some(vec![ColumnBounds::new( + ScalarValue::from(1i32), + ScalarValue::from(2i32), + )]), + ) + .await + .unwrap(); + }); + task::yield_now().await; + assert_eq!(format!("{}", dynamic.current().unwrap()), "true"); + acc.report_partition_bounds( + 1, + Some(vec![ColumnBounds::new( + ScalarValue::from(3i32), + ScalarValue::from(4i32), + )]), + ) + .await + .unwrap(); + handle.await.unwrap(); + let updated = format!("{}", dynamic.current().unwrap()); + assert!(updated.contains(">=") && updated.contains("<=")); + } +} diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index adc00d9fe75ec..c744ba9c97625 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use std::task::Poll; use crate::joins::hash_join::exec::JoinLeftData; -use crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator; +use crate::joins::hash_join::shared_bounds::{ColumnBounds, SharedBoundsAccumulator}; use crate::joins::utils::{ equal_rows_arr, get_final_indices_from_shared_bitmap, OnceFut, }; @@ -43,11 +43,13 @@ use crate::{ }; use arrow::array::{ArrayRef, UInt32Array, UInt64Array}; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::{ 
internal_datafusion_err, internal_err, JoinSide, JoinType, NullEquality, Result,
 };
+use datafusion_expr::Accumulator;
+use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator};
 use datafusion_physical_expr::PhysicalExprRef;
 
 use ahash::RandomState;
@@ -102,6 +104,56 @@ impl BuildSide {
     }
 }
 
+/// Accumulates probe-side column bounds for dynamic filter pushdown.
+///
+/// This mirrors the build-side accumulator used when collecting bounds from
+/// the left (build) side. Each accumulator tracks the minimum and maximum
+/// values for a single join key expression.
+pub(super) struct ProbeSideBoundsAccumulator {
+    expr: PhysicalExprRef,
+    min: MinAccumulator,
+    max: MaxAccumulator,
+}
+
+impl ProbeSideBoundsAccumulator {
+    /// Creates a new accumulator for the given join key expression.
+    pub(super) fn try_new(expr: PhysicalExprRef, schema: &SchemaRef) -> Result<Self> {
+        fn dictionary_value_type(data_type: &DataType) -> DataType {
+            match data_type {
+                DataType::Dictionary(_, value_type) => {
+                    dictionary_value_type(value_type.as_ref())
+                }
+                _ => data_type.clone(),
+            }
+        }
+
+        let data_type = expr
+            .data_type(schema)
+            .map(|dt| dictionary_value_type(&dt))?;
+        Ok(Self {
+            expr,
+            min: MinAccumulator::try_new(&data_type)?,
+            max: MaxAccumulator::try_new(&data_type)?,
+        })
+    }
+
+    /// Updates bounds using values from the provided batch.
+    fn update_batch(&mut self, batch: &RecordBatch) -> Result<()> {
+        let array = self.expr.evaluate(batch)?.into_array(batch.num_rows())?;
+        self.min.update_batch(std::slice::from_ref(&array))?;
+        self.max.update_batch(std::slice::from_ref(&array))?;
+        Ok(())
+    }
+
+    /// Returns the final column bounds.
+    fn evaluate(mut self) -> Result<ColumnBounds> {
+        Ok(ColumnBounds::new(
+            self.min.evaluate()?,
+            self.max.evaluate()?,
+        ))
+    }
+}
+
 /// Represents state of HashJoinStream
 ///
 /// Expected state transitions performed by HashJoinStream are:
@@ -186,6 +238,8 @@ pub(super) struct HashJoinStream {
     filter: Option<JoinFilter>,
     /// type of the join (left, right, semi, etc)
     join_type: JoinType,
+    /// Preferred input side for dynamic filter installation
+    dynamic_filter_side: JoinSide,
     /// right (probe) input
     right: SendableRecordBatchStream,
     /// Random state used for hashing initialization
@@ -211,7 +265,10 @@ pub(super) struct HashJoinStream {
     /// Optional future to signal when bounds have been reported by all partitions
     /// and the dynamic filter has been updated
     bounds_waiter: Option<OnceFut<()>>,
-
+    /// Accumulators for probe-side bounds when filtering the left side
+    probe_bounds_accumulators: Option<Vec<ProbeSideBoundsAccumulator>>,
+    /// Total number of probe-side rows processed (for bounds reporting)
+    probe_side_row_count: usize,
     /// Partitioning mode to use
     mode: PartitionMode,
 }
@@ -305,6 +362,7 @@ impl HashJoinStream {
         on_right: Vec<PhysicalExprRef>,
         filter: Option<JoinFilter>,
         join_type: JoinType,
+        dynamic_filter_side: JoinSide,
         right: SendableRecordBatchStream,
         random_state: RandomState,
         join_metrics: BuildProbeJoinMetrics,
@@ -316,6 +374,7 @@
         hashes_buffer: Vec<u64>,
         right_side_ordered: bool,
         bounds_accumulator: Option<Arc<SharedBoundsAccumulator>>,
+        probe_bounds_accumulators: Option<Vec<ProbeSideBoundsAccumulator>>,
         mode: PartitionMode,
     ) -> Self {
         Self {
@@ -324,6 +383,7 @@
             on_right,
             filter,
             join_type,
+            dynamic_filter_side,
             right,
             random_state,
             join_metrics,
@@ -336,6 +396,8 @@
             right_side_ordered,
             bounds_accumulator,
             bounds_waiter: None,
+            probe_bounds_accumulators,
+            probe_side_row_count: 0,
             mode,
         }
     }
@@ -411,21 +473,26 @@
         // Dynamic filter coordination between partitions:
        // Report bounds to the
accumulator which will handle synchronization and filter updates if let Some(ref bounds_accumulator) = self.bounds_accumulator { - let bounds_accumulator = Arc::clone(bounds_accumulator); - - let left_side_partition_id = match self.mode { - PartitionMode::Partitioned => self.partition, - PartitionMode::CollectLeft => 0, - PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), - }; - - let left_data_bounds = left_data.bounds.clone(); - self.bounds_waiter = Some(OnceFut::new(async move { - bounds_accumulator - .report_partition_bounds(left_side_partition_id, left_data_bounds) - .await - })); - self.state = HashJoinStreamState::WaitPartitionBoundsReport; + if self.dynamic_filter_side == JoinSide::Right { + let bounds_accumulator = Arc::clone(bounds_accumulator); + + let left_side_partition_id = match self.mode { + PartitionMode::Partitioned => self.partition, + PartitionMode::CollectLeft => 0, + PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), + }; + + let left_data_bounds = left_data.bounds.clone(); + self.bounds_waiter = Some(OnceFut::new(async move { + bounds_accumulator + .report_partition_bounds(left_side_partition_id, left_data_bounds) + .await + })); + self.state = HashJoinStreamState::WaitPartitionBoundsReport; + } else { + // Bounds for other sides are not collected at this stage + self.state = HashJoinStreamState::FetchProbeBatch; + } } else { self.state = HashJoinStreamState::FetchProbeBatch; } @@ -444,7 +511,35 @@ impl HashJoinStream { ) -> Poll>>> { match ready!(self.right.poll_next_unpin(cx)) { None => { - self.state = HashJoinStreamState::ExhaustedProbeSide; + if let Some(ref bounds_accumulator) = self.bounds_accumulator { + if self.dynamic_filter_side == JoinSide::Left { + if let Some(accs) = self.probe_bounds_accumulators.take() { + let right_bounds = if self.probe_side_row_count > 0 { + Some( + accs.into_iter() + .map(|acc| acc.evaluate()) + .collect::>>()?, + ) + } else { + None + }; + let bounds_accumulator = Arc::clone(bounds_accumulator); + let partition = self.partition; + self.bounds_waiter = Some(OnceFut::new(async move { + bounds_accumulator + .report_partition_bounds(partition, right_bounds) + .await + })); + self.state = HashJoinStreamState::WaitPartitionBoundsReport; + } else { + self.state = HashJoinStreamState::ExhaustedProbeSide; + } + } else { + self.state = HashJoinStreamState::ExhaustedProbeSide; + } + } else { + self.state = HashJoinStreamState::ExhaustedProbeSide; + } } Some(Ok(batch)) => { // Precalculate hash values for fetched batch @@ -454,6 +549,13 @@ impl HashJoinStream { .map(|c| c.evaluate(&batch)?.into_array(batch.num_rows())) .collect::>>()?; + if let Some(accumulators) = self.probe_bounds_accumulators.as_mut() { + for acc in accumulators.iter_mut() { + acc.update_batch(&batch)?; + } + self.probe_side_row_count += batch.num_rows(); + } + self.hashes_buffer.clear(); self.hashes_buffer.resize(batch.num_rows(), 0); create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?;
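A minimal, self-contained sketch of the min/max tracking that `ProbeSideBoundsAccumulator` builds on, using the same `MinAccumulator`/`MaxAccumulator` building blocks imported in the diff above (assumed usage, not part of the patch): each probe batch updates the accumulators, and the final values become the `col >= min AND col <= max` bounds reported once the probe side is exhausted.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::Accumulator;
use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator};

fn main() -> Result<()> {
    let mut min = MinAccumulator::try_new(&DataType::Int32)?;
    let mut max = MaxAccumulator::try_new(&DataType::Int32)?;
    // Probe-side batches of join-key values arrive incrementally.
    for values in [vec![5, 2, 9], vec![4, 7]] {
        let array: ArrayRef = Arc::new(Int32Array::from(values));
        min.update_batch(std::slice::from_ref(&array))?;
        max.update_batch(std::slice::from_ref(&array))?;
    }
    // The final values would feed an `a >= 2 AND a <= 9` range predicate.
    assert_eq!(min.evaluate()?.to_string(), "2");
    assert_eq!(max.evaluate()?.to_string(), "9");
    Ok(())
}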