diff --git a/crates/polars-config/src/lib.rs b/crates/polars-config/src/lib.rs index 8db90af13ac7..a8e1d83dc150 100644 --- a/crates/polars-config/src/lib.rs +++ b/crates/polars-config/src/lib.rs @@ -62,6 +62,12 @@ const DEFAULT_OOC_SPILL_MIN_BYTES: u64 = 100 * 1024; // 100 KB const JOIN_SAMPLE_LIMIT: &str = "POLARS_JOIN_SAMPLE_LIMIT"; const DEFAULT_JOIN_SAMPLE_LIMIT: u64 = 10_000_000; +/// Allows pruning of strict hconcat inputs in projection pushdown. This can reduce data loading +/// but may discard shape errors. +const PROJECTION_PUSHDOWN_PRUNE_STRICT_HCONCAT_INPUTS: &str = + "POLARS_PROJECTION_PUSHDOWN_PRUNE_STRICT_HCONCAT_INPUTS"; +const DEFAULT_PROJECTION_PUSHDOWN_PRUNE_STRICT_HCONCAT_INPUTS: bool = false; + static KNOWN_OPTIONS: &[&str] = &[ // Public. VERBOSE, @@ -103,6 +109,7 @@ static KNOWN_OPTIONS: &[&str] = &[ OOC_MEMORY_BUDGET_FRACTION, OOC_SPILL_MIN_BYTES, JOIN_SAMPLE_LIMIT, + PROJECTION_PUSHDOWN_PRUNE_STRICT_HCONCAT_INPUTS, ]; pub struct Config { @@ -123,6 +130,7 @@ pub struct Config { ooc_memory_budget_fraction: AtomicU64, ooc_spill_min_bytes: AtomicU64, join_sample_limit: AtomicU64, + projection_pushdown_prune_strict_hconcat_inputs: AtomicBool, } impl Config { @@ -149,6 +157,9 @@ impl Config { ), ooc_spill_min_bytes: AtomicU64::new(DEFAULT_OOC_SPILL_MIN_BYTES), join_sample_limit: AtomicU64::new(DEFAULT_JOIN_SAMPLE_LIMIT), + projection_pushdown_prune_strict_hconcat_inputs: AtomicBool::new( + DEFAULT_PROJECTION_PUSHDOWN_PRUNE_STRICT_HCONCAT_INPUTS, + ), }; cfg.reload_env_vars(); cfg @@ -252,6 +263,13 @@ impl Config { .unwrap_or(DEFAULT_JOIN_SAMPLE_LIMIT), Ordering::Relaxed, ), + PROJECTION_PUSHDOWN_PRUNE_STRICT_HCONCAT_INPUTS => { + self.projection_pushdown_prune_strict_hconcat_inputs.store( + val.and_then(|x| parse::parse_bool(var, x)) + .unwrap_or(DEFAULT_PROJECTION_PUSHDOWN_PRUNE_STRICT_HCONCAT_INPUTS), + Ordering::Relaxed, + ) + }, _ => { if var.starts_with("POLARS_") { if self.warn_unknown_config.load(Ordering::Relaxed) { @@ -334,6 +352,11 @@ impl Config { pub fn join_sample_limit(&self) -> u64 { self.join_sample_limit.load(Ordering::Relaxed) } + + pub fn projection_pushdown_prune_strict_hconcat_inputs(&self) -> bool { + self.projection_pushdown_prune_strict_hconcat_inputs + .load(Ordering::Relaxed) + } } pub fn config() -> &'static Config { diff --git a/crates/polars-lazy/src/tests/projection_queries.rs b/crates/polars-lazy/src/tests/projection_queries.rs index 7154ac11a9f3..43f7872d6ead 100644 --- a/crates/polars-lazy/src/tests/projection_queries.rs +++ b/crates/polars-lazy/src/tests/projection_queries.rs @@ -220,52 +220,3 @@ fn test_select_hconcat_pushdown_non_strict_25263() -> PolarsResult<()> { Ok(()) } - -#[test] -fn test_select_hconcat_pushdown_strict_25263() -> PolarsResult<()> { - let df_a = df![ - "a" => [1, 2], - "b" => [4, 5], - ]? - .lazy(); - - let df_b = df![ - "d" => [1, 2], - ]? - .lazy(); - - // strict: we don't read any columns from `df_a` - let lf = concat_lf_horizontal( - [df_a, df_b], - HConcatOptions { - strict: true, - ..Default::default() - }, - )? - .select([col("d")]); - let plan = lf.clone().to_alp_optimized()?; - - let node = plan.lp_top; - let lp_arena = plan.lp_arena; - - assert!(lp_arena.iter(node).all(|(_, plan)| match plan { - IR::DataFrameScan { schema, .. } => { - // make sure that we don't read any columns from `df_a` - if schema.contains("a") { - panic!("should not have read any columns from `df_a`"); - } - true - }, - _ => true, - })); - - let out = lf.collect()?; - assert_eq!( - out, - df![ - "d" => [Some(1), Some(2)] - ]? - ); - - Ok(()) -} diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs index 827a55d03ed6..bcb739a2e10b 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs @@ -136,55 +136,60 @@ impl<'a, 'arena> NodeVisitor for ProjectionPushdownVisitor<'a, 'arena> { edges: &mut dyn crate::traversal::edge_provider::NodeEdgesProvider, ) -> std::ops::ControlFlow { let out_edge = &mut edges.outputs()[0]; + let parent_key_and_port = out_edge.parent_key_and_port(); // This node was unlinked. We skip post-visit but remove the deletion mark, - // as otherwise the parent node will not be visited. - if out_edge.parent_key_and_port().is_deleted() { + // as otherwise the parent node will not be called for post_visit. + if parent_key_and_port.is_deleted() { out_edge.parent_key_and_port_mut().set_deleted(false); return ControlFlow::Continue(()); } - let parent_key_and_port = out_edge.parent_key_and_port(); + match storage.get(key) { + IR::HConcat { inputs, .. } => { + debug_assert_eq!(inputs.len(), edges.inputs().len()); + }, - 'patch_ext_context: { - let IR::ExtContext { schema, .. } = storage.get(key) else { - break 'patch_ext_context; - }; + IR::ExtContext { schema, .. } => { + let schema = match storage.get(parent_key_and_port.node) { + // Replace simple-projection added from pre-visit + IR::SimpleProjection { columns, .. } => columns.clone(), + // Wrap in `Select {}` if it is the root node, otherwise it only returns cols from first input. + _ if parent_key_and_port.node + == self.default_edge.parent_key_and_port().node => + { + schema.clone() + }, + _ => return ControlFlow::Continue(()), + }; - let schema = match storage.get(parent_key_and_port.node) { - // Replace simple-projection added from pre-visit - IR::SimpleProjection { columns, .. } => columns.clone(), - // Wrap in `Select {}` if it is the root node, otherwise it only returns cols from first input. - _ if parent_key_and_port.node == self.default_edge.parent_key_and_port().node => { - schema.clone() - }, - _ => break 'patch_ext_context, - }; + let mut exprs = Vec::with_capacity(schema.len()); + let schema = schema.clone(); + exprs.extend( + schema + .iter_names_cloned() + .map(|name| ExprIR::from_column_name(name, self.expr_arena)), + ); - let mut exprs = Vec::with_capacity(schema.len()); - let schema = schema.clone(); - exprs.extend( - schema - .iter_names_cloned() - .map(|name| ExprIR::from_column_name(name, self.expr_arena)), - ); - - let ext_ctx_ir = storage.take(key); - let new_key = storage.add(ext_ctx_ir); - - storage.replace( - key, - IR::Select { - input: new_key, - expr: exprs, - schema, - options: ProjectionOptions { - run_parallel: false, - duplicate_check: false, - should_broadcast: true, + let ext_ctx_ir = storage.take(key); + let new_key = storage.add(ext_ctx_ir); + + storage.replace( + key, + IR::Select { + input: new_key, + expr: exprs, + schema, + options: ProjectionOptions { + run_parallel: false, + duplicate_check: false, + should_broadcast: true, + }, }, - }, - ); + ); + }, + + _ => {}, } ControlFlow::Continue(()) @@ -677,9 +682,9 @@ impl ProjectionPushdownVisitor<'_, '_> { let output_schema_arc = output_schema; - if exprs.len() != orig_exprs_len - || input_names_projection.len() != input_schema.len() - { + let has_dropped_input_column = input_names_projection.len() != input_schema.len(); + + if exprs.len() != orig_exprs_len || has_dropped_input_column { let output_schema = Arc::make_mut(output_schema_arc); let mut orig_schema = mem::take(output_schema); @@ -706,18 +711,17 @@ impl ProjectionPushdownVisitor<'_, '_> { ) .map(Arc::new) }) - .or_else(|| { - (out_edge.projection() == Projection::All - && input_names_projection.len() != input_schema.len()) - .then_some(original_schema) - }) + .or( + (has_dropped_input_column && out_edge.projection() == Projection::All) + .then_some(original_schema), + ) { out_edge .parent_key_and_port_mut() .attach_simple_projection(schema, storage); } - if input_names_projection.len() != input_schema.len() { + if has_dropped_input_column { mem::swap(out_edge.names_mut(), input_names_projection); let names = out_edge.take_names(); *edges.inputs()[0].projection_state_mut() = ProjectionState { @@ -1217,6 +1221,7 @@ impl ProjectionPushdownVisitor<'_, '_> { let mut idx: usize = 0; let mut deleted: usize = 0; + let mut last_kept_input: usize = usize::MAX; inputs.retain(|input_node| { idx += 1; @@ -1243,7 +1248,11 @@ impl ProjectionPushdownVisitor<'_, '_> { ); if hconcat_projected_names.len() == base_new_names_len { - if strict && !self.maintain_errors { + if strict + && polars_config::config() + .projection_pushdown_prune_strict_hconcat_inputs() + && !self.maintain_errors + { break 'set_keep; } @@ -1274,14 +1283,42 @@ impl ProjectionPushdownVisitor<'_, '_> { if !keep { in_port.set_deleted(true); deleted += 1; - } else if deleted != 0 { - in_port.idx = idx - deleted; + } else { + last_kept_input = idx; + + if deleted != 0 { + in_port.idx = idx - deleted; + } } keep }); let new_inputs = inputs; + + if new_inputs.len() == 1 { + let input_node = new_inputs.into_iter().next().unwrap(); + let [in_edge, out_edge] = edges.get_input_output_mut(last_kept_input, 0); + let parent_key_and_port = out_edge.parent_key_and_port(); + + *storage + .get_mut(parent_key_and_port.node) + .inputs_mut() + .nth(parent_key_and_port.idx) + .unwrap() = input_node; + + // Only update parent node info; projection is already set from above. + mem::swap( + in_edge.parent_key_and_port_mut(), + out_edge.parent_key_and_port_mut(), + ); + + edges.outputs()[0] + .parent_key_and_port_mut() + .set_deleted(true); + return; + } + let IR::HConcat { inputs, schema, .. } = storage.get_mut(key) else { unreachable!() }; @@ -1289,20 +1326,12 @@ impl ProjectionPushdownVisitor<'_, '_> { Arc::make_mut(schema).retain(|name, _| hconcat_projected_names.contains(name)); - if hconcat_projected_names.len() != projected_names.len() { + if let Some(projected_schema) = + compute_simple_projection_schema(projected_names.as_slice(), schema, false) + { edges.outputs()[0] .parent_key_and_port_mut() - .attach_simple_projection( - Arc::new( - compute_simple_projection_schema( - projected_names.as_slice(), - schema, - false, - ) - .unwrap(), - ), - storage, - ); + .attach_simple_projection(Arc::new(projected_schema), storage); } }, diff --git a/py-polars/tests/unit/functions/test_concat.py b/py-polars/tests/unit/functions/test_concat.py index e1dfb8fc4eb4..4e699ba37bca 100644 --- a/py-polars/tests/unit/functions/test_concat.py +++ b/py-polars/tests/unit/functions/test_concat.py @@ -7,6 +7,7 @@ from polars._typing import ConcatMethod from polars.exceptions import ShapeError from polars.testing import assert_frame_equal +from tests.conftest import PlMonkeyPatch @pytest.mark.may_fail_cloud # reason: @serialize-stack-overflow @@ -472,3 +473,41 @@ def test_concat_horizontal_zero_width_height_mismatch_26876() -> None: with pytest.raises(ShapeError): q.collect() + + +def test_concat_horizontal_lazy_strict_raises_shape_error_27415( + plmonkeypatch: PlMonkeyPatch, +) -> None: + hconcat = pl.concat( + [ + pl.LazyFrame({"x": [0, 1]}), + pl.LazyFrame({"y": [0, 1, 2]}), + pl.LazyFrame({"z": [0, -1, -2]}), + ], + how="horizontal", + strict=True, + ) + + q = hconcat.select("y") + + with pytest.raises(ShapeError): + q.collect() + + plmonkeypatch.setenv("POLARS_PROJECTION_PUSHDOWN_PRUNE_STRICT_HCONCAT_INPUTS", "1") + q = hconcat.select("y") + plan = q.explain() + + assert "HCONCAT" not in plan + assert_frame_equal(q.collect(), pl.DataFrame({"y": [0, 1, 2]})) + + q = hconcat.select("z", "y") + + assert_frame_equal( + q.collect(), + pl.DataFrame( + { + "z": [0, -1, -2], + "y": [0, 1, 2], + } + ), + )