From e8698f7b8af1ca4a11f183fe234ca6800f9d076a Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sat, 18 Oct 2025 13:36:42 +0200 Subject: [PATCH 01/44] Reduce record batch filtering in case_when_no_expr --- .../physical-expr/src/expressions/case.rs | 121 +++++++++++------- 1 file changed, 75 insertions(+), 46 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 2db599047bcd..e740d62daa78 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -17,22 +17,25 @@ use crate::expressions::try_cast; use crate::PhysicalExpr; -use std::borrow::Cow; -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use arrow::array::*; use arrow::compute::kernels::zip::zip; -use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; +use arrow::compute::{ + and, and_not, filter_record_batch, is_null, not, nullif, or, prep_null_mask_filter, +}; use arrow::datatypes::{DataType, Schema}; +use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::ColumnarValue; +use std::borrow::Cow; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; use super::{Column, Literal}; use datafusion_physical_expr_common::datum::compare_with_eq; +use datafusion_physical_expr_common::utils::scatter; use itertools::Itertools; type WhenThen = (Arc, Arc); @@ -122,6 +125,24 @@ fn is_cheap_and_infallible(expr: &Arc) -> bool { expr.as_any().is::() } +fn merge_result( + current_value: &dyn Array, + then_value: ColumnarValue, + then_scatter: &BooleanArray, +) -> std::result::Result { + match then_value { + ColumnarValue::Scalar(ScalarValue::Null) => nullif(current_value, then_scatter), + ColumnarValue::Scalar(then_value) => { + zip(then_scatter, &then_value.to_scalar()?, ¤t_value) + } + 
ColumnarValue::Array(then_value) => { + // TODO this operation should probably be feasible in one pass + let scattered_then = scatter(then_scatter, then_value.as_ref())?; + zip(then_scatter, &scattered_then, ¤t_value) + } + } +} + impl CaseExpr { /// Create a new CASE WHEN expression pub fn try_new( @@ -286,64 +307,72 @@ impl CaseExpr { // start with nulls as default output let mut current_value = new_null_array(&return_type, batch.num_rows()); - let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]); - let mut remainder_count = batch.num_rows(); - for i in 0..self.when_then_expr.len() { - // If there are no rows left to process, break out of the loop early - if remainder_count == 0 { - break; - } + let mut remainder_scatter = BooleanArray::from(vec![true; batch.num_rows()]); + let mut remainder_batch = Cow::Borrowed(batch); + + for i in 0..self.when_then_expr.len() { + // Evaluate the 'when' predicate for the remainder batch + // This results in a boolean array with the same length as the remaining number of rows let when_predicate = &self.when_then_expr[i].0; - let when_value = when_predicate.evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows())?; + let when_value = when_predicate.evaluate(&remainder_batch)?; + let when_value = when_value.to_array(remainder_batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|_| { internal_datafusion_err!("WHEN expression did not return a BooleanArray") })?; - // Treat 'NULL' as false value - let when_value = match when_value.null_count() { - 0 => Cow::Borrowed(when_value), - _ => Cow::Owned(prep_null_mask_filter(when_value)), - }; - // Make sure we only consider rows that have not been matched yet - let when_value = and(&when_value, &remainder)?; - // If the predicate did not match any rows, continue to the next branch immediately + // If the 'when' predicate did not match any rows, continue to the next branch immediately let when_match_count = 
when_value.true_count(); if when_match_count == 0 { continue; } + // Make sure 'NULL' is treated as false + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => Cow::Owned(prep_null_mask_filter(when_value)), + }; + + // Filter the remainder batch based on the 'when' value + // This results in a batch containing only the rows that need to be evaluated + // for the current branch + let then_batch = filter_record_batch(&remainder_batch, &when_value)?; + + // Evaluate the then expression for the matching rows let then_expression = &self.when_then_expr[i].1; - let then_value = then_expression.evaluate_selection(batch, &when_value)?; + let then_value = then_expression.evaluate(&then_batch)?; - current_value = match then_value { - ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_value)? - } - ColumnarValue::Scalar(then_value) => { - zip(&when_value, &then_value.to_scalar()?, ¤t_value)? - } - ColumnarValue::Array(then_value) => { - zip(&when_value, &then_value, ¤t_value)? - } - }; + // Expand the 'when' match array using the 'remainder scatter' array + // This results in a truthy boolean array than we can use to merge the + // 'then' values with the `current_value` array. 
+ let then_merge = scatter(&remainder_scatter, when_value.as_ref())?; + let then_merge = then_merge.as_boolean(); - // Succeed tuples should be filtered out for short-circuit evaluation, - // null values for the current when expr should be kept - remainder = and_not(&remainder, &when_value)?; - remainder_count -= when_match_count; + // Merge the 'then' values with the `current_value` array + current_value = merge_result(¤t_value, then_value, then_merge)?; + + // If the 'when' predicate matched all remaining row, there's nothing left to do so + // we can return early + if remainder_batch.num_rows() == when_match_count { + return Ok(ColumnarValue::Array(current_value)); + } + + // Clear the positions in 'remainder scatter' for which we just evaluated a value + remainder_scatter = and_not(&remainder_scatter, then_merge)?; + + // Finally, prepare the remainder batch for the next branch + let next_selection = not(&when_value)?; + remainder_batch = + Cow::Owned(filter_record_batch(&remainder_batch, &next_selection)?); } + // If we reached this point, some rows were left unmatched. + // Check if those need to be evaluated using the 'else' expression. if let Some(e) = self.else_expr() { - if remainder_count > 0 { - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; - let else_ = expr - .evaluate_selection(batch, &remainder)? 
- .into_array(batch.num_rows())?; - current_value = zip(&remainder, &else_, ¤t_value)?; - } + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + let else_value = expr.evaluate(&remainder_batch)?; + current_value = merge_result(¤t_value, else_value, &remainder_scatter)?; } Ok(ColumnarValue::Array(current_value)) From 115b5542ba331d8613e08651e4a4d7517f85f36f Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sat, 18 Oct 2025 17:58:54 +0200 Subject: [PATCH 02/44] Reduce record batch filtering in case_when_with_expr --- .../physical-expr/src/expressions/case.rs | 167 ++++++++++++------ 1 file changed, 109 insertions(+), 58 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index e740d62daa78..919befd9b4dd 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. 
+use super::{Column, Literal}; use crate::expressions::try_cast; use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{ - and, and_not, filter_record_batch, is_null, not, nullif, or, prep_null_mask_filter, + and_not, filter, filter_record_batch, is_not_null, is_null, not, nullif, + prep_null_mask_filter, }; use arrow::datatypes::{DataType, Schema}; use arrow::error::ArrowError; @@ -29,14 +31,12 @@ use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::ColumnarValue; -use std::borrow::Cow; -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - -use super::{Column, Literal}; use datafusion_physical_expr_common::datum::compare_with_eq; use datafusion_physical_expr_common::utils::scatter; use itertools::Itertools; +use std::borrow::Cow; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; type WhenThen = (Arc, Arc); @@ -217,79 +217,130 @@ impl CaseExpr { /// END fn case_when_with_expr(&self, batch: &RecordBatch) -> Result { let return_type = self.data_type(&batch.schema())?; - let expr = self.expr.as_ref().unwrap(); - let base_value = expr.evaluate(batch)?; + + let base_value = self.expr.as_ref().unwrap().evaluate(batch)?; let base_value = base_value.into_array(batch.num_rows())?; - let base_nulls = is_null(base_value.as_ref())?; - // start with nulls as default output - let mut current_value = new_null_array(&return_type, batch.num_rows()); - // We only consider non-null values while comparing with whens - let mut remainder = not(&base_nulls)?; - let mut non_null_remainder_count = remainder.true_count(); - for i in 0..self.when_then_expr.len() { - // If there are no rows left to process, break out of the loop early - if non_null_remainder_count == 0 { - break; + let mut remainder_scatter = is_not_null(base_value.as_ref())?; + + let not_null_count = remainder_scatter.true_count(); + let initial_values = if not_null_count == 0 { + // 
All null base values. No need to evaluate any when expressions since + // those can never match null. + if let Some(e) = self.else_expr() { + // There is an else expression, so all rows evaluate to that. + + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + let else_value = expr.evaluate(batch)?; + return Ok(ColumnarValue::Array(else_value.to_array(batch.num_rows())?)); + } else { + // No else expression, so the entire result is null. + return Ok(ColumnarValue::Array(new_null_array( + &return_type, + batch.num_rows(), + ))); } + } else if not_null_count == batch.num_rows() { + // No null base values + ( + Cow::Borrowed(batch), + base_value, + new_null_array(&return_type, batch.num_rows()), + ) + } else { + // Some null values. The initial result array is either all nulls + // or the else value for null base values. + let current_value = if let Some(e) = self.else_expr() { + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + let nulls = is_null(base_value.as_ref())?; + let nulls_batch = filter_record_batch(batch, &nulls)?; + let else_value = expr.evaluate(&nulls_batch)?; + scatter(&nulls, &else_value.to_array(nulls_batch.num_rows())?)? 
+ } else { + new_null_array(&return_type, batch.num_rows()) + }; - let when_predicate = &self.when_then_expr[i].0; - let when_value = when_predicate.evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows())?; - // build boolean array representing which rows match the "when" value - let when_match = compare_with_eq( - &when_value, + ( + Cow::Owned(filter_record_batch(batch, &remainder_scatter)?), + filter(base_value.as_ref(), &remainder_scatter)?, + current_value, + ) + }; + + let (mut remainder_batch, mut base_value, mut current_value) = initial_values; + + for i in 0..self.when_then_expr.len() { + // Evaluate the 'when' predicate for the remainder batch + // This results in a boolean array with the same length as the remaining number of rows + let when_expression = &self.when_then_expr[i].0; + let when_value = when_expression.evaluate(&remainder_batch)?; + let when_value = when_value.to_array(remainder_batch.num_rows())?; + let when_value = compare_with_eq( &base_value, + &when_value, // The types of case and when expressions will be coerced to match. // We only need to check if the base_value is nested. 
base_value.data_type().is_nested(), )?; - // Treat nulls as false - let when_match = match when_match.null_count() { - 0 => Cow::Borrowed(&when_match), - _ => Cow::Owned(prep_null_mask_filter(&when_match)), - }; - // Make sure we only consider rows that have not been matched yet - let when_value = and(&when_match, &remainder)?; + let when_value = as_boolean_array(&when_value).map_err(|_| { + internal_datafusion_err!("WHEN expression did not return a BooleanArray") + })?; - // If the predicate did not match any rows, continue to the next branch immediately + // If the 'when' predicate did not match any rows, continue to the next branch immediately let when_match_count = when_value.true_count(); if when_match_count == 0 { continue; } - let then_expression = &self.when_then_expr[i].1; - let then_value = then_expression.evaluate_selection(batch, &when_value)?; - - current_value = match then_value { - ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_value)? - } - ColumnarValue::Scalar(then_value) => { - zip(&when_value, &then_value.to_scalar()?, ¤t_value)? - } - ColumnarValue::Array(then_value) => { - zip(&when_value, &then_value, ¤t_value)? 
- } + // Make sure 'NULL' is treated as false + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => Cow::Owned(prep_null_mask_filter(when_value)), }; - remainder = and_not(&remainder, &when_value)?; - non_null_remainder_count -= when_match_count; - } + // Filter the remainder batch based on the 'when' value + // This results in a batch containing only the rows that need to be evaluated + // for the current branch + let then_batch = filter_record_batch(&remainder_batch, &when_value)?; - if let Some(e) = self.else_expr() { - // null and unmatched tuples should be assigned else value - remainder = or(&base_nulls, &remainder)?; + // Evaluate the then expression for the matching rows + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate(&then_batch)?; - if remainder.true_count() > 0 { - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + // Expand the 'when' match array using the 'remainder scatter' array + // This results in a truthy boolean array than we can use to merge the + // 'then' values with the `current_value` array. + let then_merge = scatter(&remainder_scatter, when_value.as_ref())?; + let then_merge = then_merge.as_boolean(); - let else_ = expr - .evaluate_selection(batch, &remainder)? 
- .into_array(batch.num_rows())?; - current_value = zip(&remainder, &else_, ¤t_value)?; + // Merge the 'then' values with the `current_value` array + current_value = merge_result(¤t_value, then_value, then_merge)?; + + // If the 'when' predicate matched all remaining row, there's nothing left to do so + // we can return early + if remainder_batch.num_rows() == when_match_count { + return Ok(ColumnarValue::Array(current_value)); } + + // Clear the positions in 'remainder scatter' for which we just evaluated a value + remainder_scatter = and_not(&remainder_scatter, then_merge)?; + + // Finally, prepare the remainder batch for the next branch + let next_selection = not(&when_value)?; + remainder_batch = + Cow::Owned(filter_record_batch(&remainder_batch, &next_selection)?); + base_value = filter(base_value.as_ref(), &next_selection)?; + } + + // If we reached this point, some rows were left unmatched. + // Check if those need to be evaluated using the 'else' expression. + if let Some(e) = self.else_expr() { + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + let else_value = expr.evaluate(&remainder_batch)?; + current_value = merge_result(¤t_value, else_value, &remainder_scatter)?; } Ok(ColumnarValue::Array(current_value)) From 9c641aa2deef2c8b843ef7b919370abb7fbe8600 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sun, 19 Oct 2025 21:44:01 +0200 Subject: [PATCH 03/44] Avoid unnecessary filtering in last iteration when no else expression is present --- .../physical-expr/src/expressions/case.rs | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 919befd9b4dd..f73379d6dd4d 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -127,7 +127,7 @@ fn 
is_cheap_and_infallible(expr: &Arc) -> bool { fn merge_result( current_value: &dyn Array, - then_value: ColumnarValue, + then_value: &ColumnarValue, then_scatter: &BooleanArray, ) -> std::result::Result { match then_value { @@ -233,7 +233,9 @@ impl CaseExpr { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_value = expr.evaluate(batch)?; - return Ok(ColumnarValue::Array(else_value.to_array(batch.num_rows())?)); + return Ok(ColumnarValue::Array( + else_value.into_array(batch.num_rows())?, + )); } else { // No else expression, so the entire result is null. return Ok(ColumnarValue::Array(new_null_array( @@ -257,7 +259,7 @@ impl CaseExpr { let nulls = is_null(base_value.as_ref())?; let nulls_batch = filter_record_batch(batch, &nulls)?; let else_value = expr.evaluate(&nulls_batch)?; - scatter(&nulls, &else_value.to_array(nulls_batch.num_rows())?)? + scatter(&nulls, &else_value.into_array(nulls_batch.num_rows())?)? } else { new_null_array(&return_type, batch.num_rows()) }; @@ -276,10 +278,10 @@ impl CaseExpr { // This results in a boolean array with the same length as the remaining number of rows let when_expression = &self.when_then_expr[i].0; let when_value = when_expression.evaluate(&remainder_batch)?; - let when_value = when_value.to_array(remainder_batch.num_rows())?; + let when_value = when_value.into_array(remainder_batch.num_rows())?; let when_value = compare_with_eq( - &base_value, &when_value, + &base_value, // The types of case and when expressions will be coerced to match. // We only need to check if the base_value is nested. 
base_value.data_type().is_nested(), @@ -316,11 +318,13 @@ impl CaseExpr { let then_merge = then_merge.as_boolean(); // Merge the 'then' values with the `current_value` array - current_value = merge_result(¤t_value, then_value, then_merge)?; + current_value = merge_result(¤t_value, &then_value, then_merge)?; // If the 'when' predicate matched all remaining row, there's nothing left to do so // we can return early - if remainder_batch.num_rows() == when_match_count { + if remainder_batch.num_rows() == when_match_count + || (self.else_expr.is_none() && i == self.when_then_expr.len() - 1) + { return Ok(ColumnarValue::Array(current_value)); } @@ -340,7 +344,8 @@ impl CaseExpr { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_value = expr.evaluate(&remainder_batch)?; - current_value = merge_result(¤t_value, else_value, &remainder_scatter)?; + current_value = + merge_result(¤t_value, &else_value, &remainder_scatter)?; } Ok(ColumnarValue::Array(current_value)) @@ -367,7 +372,7 @@ impl CaseExpr { // This results in a boolean array with the same length as the remaining number of rows let when_predicate = &self.when_then_expr[i].0; let when_value = when_predicate.evaluate(&remainder_batch)?; - let when_value = when_value.to_array(remainder_batch.num_rows())?; + let when_value = when_value.into_array(remainder_batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|_| { internal_datafusion_err!("WHEN expression did not return a BooleanArray") })?; @@ -400,11 +405,13 @@ impl CaseExpr { let then_merge = then_merge.as_boolean(); // Merge the 'then' values with the `current_value` array - current_value = merge_result(¤t_value, then_value, then_merge)?; + current_value = merge_result(¤t_value, &then_value, then_merge)?; // If the 'when' predicate matched all remaining row, there's nothing left to do so // we can return early - if remainder_batch.num_rows() == 
when_match_count { + if remainder_batch.num_rows() == when_match_count + || (self.else_expr.is_none() && i == self.when_then_expr.len() - 1) + { return Ok(ColumnarValue::Array(current_value)); } @@ -423,7 +430,8 @@ impl CaseExpr { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_value = expr.evaluate(&remainder_batch)?; - current_value = merge_result(¤t_value, else_value, &remainder_scatter)?; + current_value = + merge_result(¤t_value, &else_value, &remainder_scatter)?; } Ok(ColumnarValue::Array(current_value)) From 7539068a5b8aac2132992121682d23510a270ea1 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 20 Oct 2025 22:45:43 +0200 Subject: [PATCH 04/44] Use interleave to construct case result --- .../physical-expr/src/expressions/case.rs | 252 +++++++++--------- 1 file changed, 130 insertions(+), 122 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index f73379d6dd4d..74dd22b3384a 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -21,10 +21,10 @@ use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{ - and_not, filter, filter_record_batch, is_not_null, is_null, not, nullif, - prep_null_mask_filter, + interleave, is_null, not, nullif, prep_null_mask_filter, FilterBuilder, + FilterPredicate, }; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Schema, UInt32Type}; use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; use datafusion_common::{ @@ -32,7 +32,6 @@ use datafusion_common::{ }; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::compare_with_eq; -use datafusion_physical_expr_common::utils::scatter; use itertools::Itertools; use std::borrow::Cow; use std::hash::Hash; @@ -125,21 
+124,80 @@ fn is_cheap_and_infallible(expr: &Arc) -> bool { expr.as_any().is::() } -fn merge_result( - current_value: &dyn Array, - then_value: &ColumnarValue, - then_scatter: &BooleanArray, +fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate { + let mut filter_builder = FilterBuilder::new(predicate); + if optimize { + filter_builder = filter_builder.optimize(); + } + filter_builder.build() +} + +fn filter_record_batch( + record_batch: &RecordBatch, + filter: &FilterPredicate, +) -> std::result::Result { + let filtered_columns = record_batch + .columns() + .iter() + .map(|a| filter_array(a, filter)) + .collect::, _>>()?; + unsafe { + Ok(RecordBatch::new_unchecked( + record_batch.schema(), + filtered_columns, + filter.count(), + )) + } +} + +fn filter_array( + array: &dyn Array, + filter: &FilterPredicate, ) -> std::result::Result { - match then_value { - ColumnarValue::Scalar(ScalarValue::Null) => nullif(current_value, then_scatter), - ColumnarValue::Scalar(then_value) => { - zip(then_scatter, &then_value.to_scalar()?, ¤t_value) + filter.filter(array) +} + +struct InterleaveBuilder { + indices: Vec<(usize, usize)>, + arrays: Vec, +} + +impl InterleaveBuilder { + fn new(data_type: &DataType, capacity: usize) -> Self { + Self { + indices: vec![(0, 0); capacity], + arrays: vec![new_null_array(data_type, 1)], } - ColumnarValue::Array(then_value) => { - // TODO this operation should probably be feasible in one pass - let scattered_then = scatter(then_scatter, then_value.as_ref())?; - zip(then_scatter, &scattered_then, ¤t_value) + } + + fn add(&mut self, rows: &ArrayRef, value: ColumnarValue) -> Result<()> { + let array_index = self.arrays.len(); + match value { + ColumnarValue::Array(a) => { + self.arrays.push(a); + for (array_ix, row_ix) in rows + .as_primitive::() + .values() + .iter() + .enumerate() + { + self.indices[*row_ix as usize] = (array_index, array_ix); + } + } + ColumnarValue::Scalar(s) => { + self.arrays.push(s.to_array()?); + 
for row_ix in rows.as_primitive::().values().iter() { + self.indices[*row_ix as usize] = (array_index, 0); + } + } } + Ok(()) + } + + fn finish(self) -> Result { + let array_refs = self.arrays.iter().map(|a| a.as_ref()).collect::>(); + let interleaved_result = interleave(&array_refs, &self.indices)?; + Ok(ColumnarValue::Array(interleaved_result)) } } @@ -217,68 +275,46 @@ impl CaseExpr { /// END fn case_when_with_expr(&self, batch: &RecordBatch) -> Result { let return_type = self.data_type(&batch.schema())?; + let optimize_filters = batch.num_columns() > 1; + let mut interleave_builder = + InterleaveBuilder::new(&return_type, batch.num_rows()); - let base_value = self.expr.as_ref().unwrap().evaluate(batch)?; - let base_value = base_value.into_array(batch.num_rows())?; - - let mut remainder_scatter = is_not_null(base_value.as_ref())?; + let mut remainder_rows: ArrayRef = + Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32)); + let mut remainder_batch = Cow::Borrowed(batch); + let mut base_value = self + .expr + .as_ref() + .unwrap() + .evaluate(batch)? + .into_array(batch.num_rows())?; - let not_null_count = remainder_scatter.true_count(); - let initial_values = if not_null_count == 0 { - // All null base values. No need to evaluate any when expressions since - // those can never match null. + let base_nulls = is_null(base_value.as_ref())?; + if base_nulls.true_count() > 0 { if let Some(e) = self.else_expr() { - // There is an else expression, so all rows evaluate to that. - - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; - let else_value = expr.evaluate(batch)?; - return Ok(ColumnarValue::Array( - else_value.into_array(batch.num_rows())?, - )); - } else { - // No else expression, so the entire result is null. 
- return Ok(ColumnarValue::Array(new_null_array( - &return_type, - batch.num_rows(), - ))); - } - } else if not_null_count == batch.num_rows() { - // No null base values - ( - Cow::Borrowed(batch), - base_value, - new_null_array(&return_type, batch.num_rows()), - ) - } else { - // Some null values. The initial result array is either all nulls - // or the else value for null base values. - let current_value = if let Some(e) = self.else_expr() { - // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; - let nulls = is_null(base_value.as_ref())?; - let nulls_batch = filter_record_batch(batch, &nulls)?; - let else_value = expr.evaluate(&nulls_batch)?; - scatter(&nulls, &else_value.into_array(nulls_batch.num_rows())?)? - } else { - new_null_array(&return_type, batch.num_rows()) - }; - ( - Cow::Owned(filter_record_batch(batch, &remainder_scatter)?), - filter(base_value.as_ref(), &remainder_scatter)?, - current_value, - ) - }; + let nulls_filter = create_filter(&base_nulls, optimize_filters); + let nulls_batch = filter_record_batch(&remainder_batch, &nulls_filter)?; + let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?; + let nulls_value = expr.evaluate(&nulls_batch)?; + interleave_builder.add(&nulls_rows, nulls_value)?; + } - let (mut remainder_batch, mut base_value, mut current_value) = initial_values; + let not_null_filter = create_filter(¬(&base_nulls)?, optimize_filters); + remainder_batch = + Cow::Owned(filter_record_batch(&remainder_batch, ¬_null_filter)?); + remainder_rows = filter_array(&remainder_rows, ¬_null_filter)?; + base_value = filter_array(&base_value, ¬_null_filter)?; + } for i in 0..self.when_then_expr.len() { // Evaluate the 'when' predicate for the remainder batch // This results in a boolean array with the same length as the remaining number of rows - let when_expression = &self.when_then_expr[i].0; - let when_value = when_expression.evaluate(&remainder_batch)?; 
- let when_value = when_value.into_array(remainder_batch.num_rows())?; + let when_expr = &self.when_then_expr[i].0; + let when_value = when_expr + .evaluate(&remainder_batch)? + .into_array(remainder_batch.num_rows())?; let when_value = compare_with_eq( &when_value, &base_value, @@ -286,9 +322,6 @@ impl CaseExpr { // We only need to check if the base_value is nested. base_value.data_type().is_nested(), )?; - let when_value = as_boolean_array(&when_value).map_err(|_| { - internal_datafusion_err!("WHEN expression did not return a BooleanArray") - })?; // If the 'when' predicate did not match any rows, continue to the next branch immediately let when_match_count = when_value.true_count(); @@ -296,46 +329,31 @@ impl CaseExpr { continue; } - // Make sure 'NULL' is treated as false - let when_value = match when_value.null_count() { - 0 => Cow::Borrowed(when_value), - _ => Cow::Owned(prep_null_mask_filter(when_value)), - }; - // Filter the remainder batch based on the 'when' value // This results in a batch containing only the rows that need to be evaluated // for the current branch - let then_batch = filter_record_batch(&remainder_batch, &when_value)?; - - // Evaluate the then expression for the matching rows + let then_filter = create_filter(&when_value, optimize_filters); + let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; let then_expression = &self.when_then_expr[i].1; let then_value = then_expression.evaluate(&then_batch)?; - // Expand the 'when' match array using the 'remainder scatter' array - // This results in a truthy boolean array than we can use to merge the - // 'then' values with the `current_value` array. 
- let then_merge = scatter(&remainder_scatter, when_value.as_ref())?; - let then_merge = then_merge.as_boolean(); - - // Merge the 'then' values with the `current_value` array - current_value = merge_result(¤t_value, &then_value, then_merge)?; + let then_rows = filter_array(&remainder_rows, &then_filter)?; + interleave_builder.add(&then_rows, then_value)?; // If the 'when' predicate matched all remaining row, there's nothing left to do so // we can return early if remainder_batch.num_rows() == when_match_count || (self.else_expr.is_none() && i == self.when_then_expr.len() - 1) { - return Ok(ColumnarValue::Array(current_value)); + return interleave_builder.finish(); } - // Clear the positions in 'remainder scatter' for which we just evaluated a value - remainder_scatter = and_not(&remainder_scatter, then_merge)?; - - // Finally, prepare the remainder batch for the next branch let next_selection = not(&when_value)?; + let next_filter = create_filter(&next_selection, optimize_filters); remainder_batch = - Cow::Owned(filter_record_batch(&remainder_batch, &next_selection)?); - base_value = filter(base_value.as_ref(), &next_selection)?; + Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?); + remainder_rows = filter_array(&remainder_rows, &next_filter)?; + base_value = filter_array(&base_value, &next_filter)?; } // If we reached this point, some rows were left unmatched. 
@@ -344,11 +362,10 @@ impl CaseExpr { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_value = expr.evaluate(&remainder_batch)?; - current_value = - merge_result(¤t_value, &else_value, &remainder_scatter)?; + interleave_builder.add(&remainder_rows, else_value)?; } - Ok(ColumnarValue::Array(current_value)) + interleave_builder.finish() } /// This function evaluates the form of CASE where each WHEN expression is a boolean @@ -359,12 +376,13 @@ impl CaseExpr { /// [ELSE result] /// END fn case_when_no_expr(&self, batch: &RecordBatch) -> Result { + let optimize_filters = batch.num_columns() > 1; let return_type = self.data_type(&batch.schema())?; + let mut interleave_builder = + InterleaveBuilder::new(&return_type, batch.num_rows()); - // start with nulls as default output - let mut current_value = new_null_array(&return_type, batch.num_rows()); - - let mut remainder_scatter = BooleanArray::from(vec![true; batch.num_rows()]); + let mut remainder_rows: ArrayRef = + Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32)); let mut remainder_batch = Cow::Borrowed(batch); for i in 0..self.when_then_expr.len() { @@ -392,36 +410,27 @@ impl CaseExpr { // Filter the remainder batch based on the 'when' value // This results in a batch containing only the rows that need to be evaluated // for the current branch - let then_batch = filter_record_batch(&remainder_batch, &when_value)?; - - // Evaluate the then expression for the matching rows + let then_filter = create_filter(&when_value, optimize_filters); + let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; let then_expression = &self.when_then_expr[i].1; let then_value = then_expression.evaluate(&then_batch)?; - // Expand the 'when' match array using the 'remainder scatter' array - // This results in a truthy boolean array than we can use to merge the - // 'then' values with the `current_value` array. 
- let then_merge = scatter(&remainder_scatter, when_value.as_ref())?; - let then_merge = then_merge.as_boolean(); - - // Merge the 'then' values with the `current_value` array - current_value = merge_result(¤t_value, &then_value, then_merge)?; + let then_rows = filter_array(&remainder_rows, &then_filter)?; + interleave_builder.add(&then_rows, then_value)?; // If the 'when' predicate matched all remaining row, there's nothing left to do so // we can return early if remainder_batch.num_rows() == when_match_count || (self.else_expr.is_none() && i == self.when_then_expr.len() - 1) { - return Ok(ColumnarValue::Array(current_value)); + return interleave_builder.finish(); } - // Clear the positions in 'remainder scatter' for which we just evaluated a value - remainder_scatter = and_not(&remainder_scatter, then_merge)?; - - // Finally, prepare the remainder batch for the next branch let next_selection = not(&when_value)?; + let next_filter = create_filter(&next_selection, optimize_filters); remainder_batch = - Cow::Owned(filter_record_batch(&remainder_batch, &next_selection)?); + Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?); + remainder_rows = filter_array(&remainder_rows, &next_filter)?; } // If we reached this point, some rows were left unmatched. 
@@ -430,11 +439,10 @@ impl CaseExpr { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_value = expr.evaluate(&remainder_batch)?; - current_value = - merge_result(¤t_value, &else_value, &remainder_scatter)?; + interleave_builder.add(&remainder_rows, else_value)?; } - Ok(ColumnarValue::Array(current_value)) + interleave_builder.finish() } /// This function evaluates the specialized case of: From b4c11173c4c48e476987d87d5a7a542d91f441e5 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 21 Oct 2025 00:59:50 +0200 Subject: [PATCH 05/44] Add comments --- .../physical-expr/src/expressions/case.rs | 53 +++++++++++++++---- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 74dd22b3384a..5de602688b4b 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -124,6 +124,7 @@ fn is_cheap_and_infallible(expr: &Arc) -> bool { expr.as_any().is::() } +/// Creates a [FilterPredicate] from a boolean array. fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate { let mut filter_builder = FilterBuilder::new(predicate); if optimize { @@ -150,6 +151,7 @@ fn filter_record_batch( } } +#[inline(always)] fn filter_array( array: &dyn Array, filter: &FilterPredicate, @@ -164,16 +166,28 @@ struct InterleaveBuilder { impl InterleaveBuilder { fn new(data_type: &DataType, capacity: usize) -> Self { + // By settings indices to (0, 0) every entry points to the single + // null value in the first array. Self { indices: vec![(0, 0); capacity], arrays: vec![new_null_array(data_type, 1)], } } - fn add(&mut self, rows: &ArrayRef, value: ColumnarValue) -> Result<()> { + /// Adds a result value. 
+ /// + /// `rows` should be a [UInt32Array] containing [RecordBatch] relative row indices + /// for which `value` contains result values. + /// + /// If `value` is a scalar, the scalar value is used for each row in `rows`. + /// If `value` is an array, the values from the array and the indices from `rows` will be + /// processed pairwise. + fn add_result(&mut self, rows: &ArrayRef, value: ColumnarValue) -> Result<()> { let array_index = self.arrays.len(); match value { ColumnarValue::Array(a) => { + assert_eq!(a.len(), rows.len()); + self.arrays.push(a); for (array_ix, row_ix) in rows .as_primitive::() @@ -274,14 +288,19 @@ impl CaseExpr { /// [ELSE result] /// END fn case_when_with_expr(&self, batch: &RecordBatch) -> Result { - let return_type = self.data_type(&batch.schema())?; let optimize_filters = batch.num_columns() > 1; + + let return_type = self.data_type(&batch.schema())?; let mut interleave_builder = InterleaveBuilder::new(&return_type, batch.num_rows()); + // `remainder_rows` contains the indices of the rows that need to be evaluated let mut remainder_rows: ArrayRef = Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32)); + // `remainder_batch` contains the rows themselves that need to be evaluated let mut remainder_batch = Cow::Borrowed(batch); + + // evaluate the base expression let mut base_value = self .expr .as_ref() @@ -289,8 +308,14 @@ impl CaseExpr { .evaluate(batch)? .into_array(batch.num_rows())?; + // Fill in a result value already for rows where the base expression value is null + // Since each when expression is tested against the base expression using the equality + // operator, null base values can never match any when expression. `x == NULL` is false, + // for all possible values of `x`. let base_nulls = is_null(base_value.as_ref())?; if base_nulls.true_count() > 0 { + // If there is an else expression, use that as the default value for the null rows + // Otherwise the default `null` value from the eInterleaveBuilder will be used. 
if let Some(e) = self.else_expr() { let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; @@ -298,9 +323,10 @@ impl CaseExpr { let nulls_batch = filter_record_batch(&remainder_batch, &nulls_filter)?; let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?; let nulls_value = expr.evaluate(&nulls_batch)?; - interleave_builder.add(&nulls_rows, nulls_value)?; + interleave_builder.add_result(&nulls_rows, nulls_value)?; } + // Remove the null rows from the remainder batch let not_null_filter = create_filter(¬(&base_nulls)?, optimize_filters); remainder_batch = Cow::Owned(filter_record_batch(&remainder_batch, ¬_null_filter)?); @@ -334,11 +360,11 @@ impl CaseExpr { // for the current branch let then_filter = create_filter(&when_value, optimize_filters); let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; + let then_rows = filter_array(&remainder_rows, &then_filter)?; + let then_expression = &self.when_then_expr[i].1; let then_value = then_expression.evaluate(&then_batch)?; - - let then_rows = filter_array(&remainder_rows, &then_filter)?; - interleave_builder.add(&then_rows, then_value)?; + interleave_builder.add_result(&then_rows, then_value)?; // If the 'when' predicate matched all remaining row, there's nothing left to do so // we can return early @@ -348,6 +374,7 @@ impl CaseExpr { return interleave_builder.finish(); } + // Prepare the next when branch (or the else branch) let next_selection = not(&when_value)?; let next_filter = create_filter(&next_selection, optimize_filters); remainder_batch = @@ -362,7 +389,7 @@ impl CaseExpr { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_value = expr.evaluate(&remainder_batch)?; - interleave_builder.add(&remainder_rows, else_value)?; + interleave_builder.add_result(&remainder_rows, else_value)?; } interleave_builder.finish() @@ -377,12 +404,15 @@ impl CaseExpr { /// END fn 
case_when_no_expr(&self, batch: &RecordBatch) -> Result { let optimize_filters = batch.num_columns() > 1; + let return_type = self.data_type(&batch.schema())?; let mut interleave_builder = InterleaveBuilder::new(&return_type, batch.num_rows()); + // `remainder_rows` contains the indices of the rows that need to be evaluated let mut remainder_rows: ArrayRef = Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32)); + // `remainder_batch` contains the rows themselves that need to be evaluated let mut remainder_batch = Cow::Borrowed(batch); for i in 0..self.when_then_expr.len() { @@ -412,11 +442,11 @@ impl CaseExpr { // for the current branch let then_filter = create_filter(&when_value, optimize_filters); let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; + let then_rows = filter_array(&remainder_rows, &then_filter)?; + let then_expression = &self.when_then_expr[i].1; let then_value = then_expression.evaluate(&then_batch)?; - - let then_rows = filter_array(&remainder_rows, &then_filter)?; - interleave_builder.add(&then_rows, then_value)?; + interleave_builder.add_result(&then_rows, then_value)?; // If the 'when' predicate matched all remaining row, there's nothing left to do so // we can return early @@ -426,6 +456,7 @@ impl CaseExpr { return interleave_builder.finish(); } + // Prepare the next when branch (or the else branch) let next_selection = not(&when_value)?; let next_filter = create_filter(&next_selection, optimize_filters); remainder_batch = @@ -439,7 +470,7 @@ impl CaseExpr { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_value = expr.evaluate(&remainder_batch)?; - interleave_builder.add(&remainder_rows, else_value)?; + interleave_builder.add_result(&remainder_rows, else_value)?; } interleave_builder.finish() From 7eae3e41b56c94901a63014da9ae3382aa4eea14 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 21 Oct 2025 13:09:39 
+0200 Subject: [PATCH 06/44] Handle null where values correctly in `case_when_with_expr` --- datafusion/physical-expr/src/expressions/case.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 5de602688b4b..da6d6160a92a 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -310,7 +310,7 @@ impl CaseExpr { // Fill in a result value already for rows where the base expression value is null // Since each when expression is tested against the base expression using the equality - // operator, null base values can never match any when expression. `x == NULL` is false, + // operator, null base values can never match any when expression. `x = NULL` is falsy, // for all possible values of `x`. let base_nulls = is_null(base_value.as_ref())?; if base_nulls.true_count() > 0 { @@ -355,6 +355,12 @@ impl CaseExpr { continue; } + // Make sure 'NULL' is treated as false + let when_value = match when_value.null_count() { + 0 => when_value, + _ => prep_null_mask_filter(&when_value), + }; + // Filter the remainder batch based on the 'when' value // This results in a batch containing only the rows that need to be evaluated // for the current branch From f5448c993d9029c8336b099316e2a30fec8eae49 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 21 Oct 2025 13:24:21 +0200 Subject: [PATCH 07/44] Align `case_when_with_expr` and `case_when_no_expr` --- datafusion/physical-expr/src/expressions/case.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index da6d6160a92a..8b9ec19c0952 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -425,8 +425,9 @@ impl CaseExpr { // Evaluate the 'when' predicate for 
the remainder batch // This results in a boolean array with the same length as the remaining number of rows let when_predicate = &self.when_then_expr[i].0; - let when_value = when_predicate.evaluate(&remainder_batch)?; - let when_value = when_value.into_array(remainder_batch.num_rows())?; + let when_value = when_predicate + .evaluate(&remainder_batch)? + .into_array(remainder_batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|_| { internal_datafusion_err!("WHEN expression did not return a BooleanArray") })?; From 132757e796221b5c55a5547ffc56c8a6c2631d6c Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 21 Oct 2025 15:12:52 +0200 Subject: [PATCH 08/44] Exit early when all base values are null --- datafusion/physical-expr/src/expressions/case.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 8b9ec19c0952..368a866ce230 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -326,6 +326,11 @@ impl CaseExpr { interleave_builder.add_result(&nulls_rows, nulls_value)?; } + // All base values were null, so we can return early + if base_nulls.true_count() == remainder_batch.num_rows() { + return interleave_builder.finish(); + } + // Remove the null rows from the remainder batch let not_null_filter = create_filter(¬(&base_nulls)?, optimize_filters); remainder_batch = From 43c2fe275ac99d7b79f7e824e1787bf1d94d7909 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 21 Oct 2025 15:13:17 +0200 Subject: [PATCH 09/44] Avoid calling `interleave` in simple cases --- .../physical-expr/src/expressions/case.rs | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 368a866ce230..ab3895d7d372 100644 --- 
a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -208,10 +208,26 @@ impl InterleaveBuilder { Ok(()) } - fn finish(self) -> Result { - let array_refs = self.arrays.iter().map(|a| a.as_ref()).collect::>(); - let interleaved_result = interleave(&array_refs, &self.indices)?; - Ok(ColumnarValue::Array(interleaved_result)) + fn finish(mut self) -> Result { + if self.arrays.len() == 1 { + // The first array is always a single null value. + if self.indices.len() == 1 { + // If there's only a single row, reuse the array + Ok(ColumnarValue::Array(self.arrays.remove(0))) + } else { + // Otherwise make a new null array with the correct type and length + Ok(ColumnarValue::Array(new_null_array(self.arrays[0].data_type(), self.indices.len()))) + } + } else if self.arrays.len() == 2 && !self.indices.iter().any(|(array_ix, _)| *array_ix == 0) && self.arrays[1].len() == self.indices.len() { + // There's only a single non-null array and no references to the null array. + // We can take a shortcut and return the non-null array directly. 
+ Ok(ColumnarValue::Array(self.arrays.remove(1))) + } else { + // Interleave arrays + let array_refs = self.arrays.iter().map(|a| a.as_ref()).collect::>(); + let interleaved_result = interleave(&array_refs, &self.indices)?; + Ok(ColumnarValue::Array(interleaved_result)) + } } } From f49d3eaead5aa1ffe0e94bb627befa1c466d1d3f Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 21 Oct 2025 15:24:41 +0200 Subject: [PATCH 10/44] Formatting --- datafusion/physical-expr/src/expressions/case.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index ab3895d7d372..eb1e85617e17 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -216,9 +216,15 @@ impl InterleaveBuilder { Ok(ColumnarValue::Array(self.arrays.remove(0))) } else { // Otherwise make a new null array with the correct type and length - Ok(ColumnarValue::Array(new_null_array(self.arrays[0].data_type(), self.indices.len()))) + Ok(ColumnarValue::Array(new_null_array( + self.arrays[0].data_type(), + self.indices.len(), + ))) } - } else if self.arrays.len() == 2 && !self.indices.iter().any(|(array_ix, _)| *array_ix == 0) && self.arrays[1].len() == self.indices.len() { + } else if self.arrays.len() == 2 + && !self.indices.iter().any(|(array_ix, _)| *array_ix == 0) + && self.arrays[1].len() == self.indices.len() + { // There's only a single non-null array and no references to the null array. // We can take a shortcut and return the non-null array directly. 
Ok(ColumnarValue::Array(self.arrays.remove(1))) From 9cb6496b1ee9249a7e87e4d24120151d68a13812 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 23 Oct 2025 14:33:38 +0200 Subject: [PATCH 11/44] Use a custom merge strategy that takes the case evaluation logic into account --- .../physical-expr/src/expressions/case.rs | 189 ++++++++++++------ 1 file changed, 124 insertions(+), 65 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index eb1e85617e17..dd3185b08430 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -21,8 +21,7 @@ use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{ - interleave, is_null, not, nullif, prep_null_mask_filter, FilterBuilder, - FilterPredicate, + is_null, not, nullif, prep_null_mask_filter, FilterBuilder, FilterPredicate, }; use arrow::datatypes::{DataType, Schema, UInt32Type}; use arrow::error::ArrowError; @@ -40,7 +39,7 @@ use std::{any::Any, sync::Arc}; type WhenThen = (Arc, Arc); #[derive(Debug, Hash, PartialEq, Eq)] -enum EvalMethod { +pub enum EvalMethod { /// CASE WHEN condition THEN result /// [WHEN ...] /// [ELSE result] @@ -96,7 +95,7 @@ pub struct CaseExpr { /// Optional "else" expression else_expr: Option>, /// Evaluation method to use - eval_method: EvalMethod, + pub eval_method: EvalMethod, } impl std::fmt::Display for CaseExpr { @@ -159,18 +158,27 @@ fn filter_array( filter.filter(array) } -struct InterleaveBuilder { - indices: Vec<(usize, usize)>, - arrays: Vec, +struct ResultBuilder { + data_type: DataType, + // A Vec of partial results that should be merged. `partial_result_indices` contains + // indexes into this vec. + partial_results: Vec, + // Indicates per result row from which array in `partial_results` a value should be taken. + // The indexes in this array are offset by +1. 
The special value 0 indicates null values. + partial_result_indices: Vec, + // An optional result that is the covering result for all rows. + // This is used as an optimisation to avoid the cost of merging when all rows + // evaluate to the same case branch. + covering_result: Option, } -impl InterleaveBuilder { +impl ResultBuilder { fn new(data_type: &DataType, capacity: usize) -> Self { - // By settings indices to (0, 0) every entry points to the single - // null value in the first array. Self { - indices: vec![(0, 0); capacity], - arrays: vec![new_null_array(data_type, 1)], + data_type: data_type.clone(), + partial_result_indices: vec![0; capacity], + partial_results: vec![], + covering_result: None, } } @@ -183,56 +191,109 @@ impl InterleaveBuilder { /// If `value` is an array, the values from the array and the indices from `rows` will be /// processed pairwise. fn add_result(&mut self, rows: &ArrayRef, value: ColumnarValue) -> Result<()> { - let array_index = self.arrays.len(); match value { ColumnarValue::Array(a) => { assert_eq!(a.len(), rows.len()); - - self.arrays.push(a); - for (array_ix, row_ix) in rows - .as_primitive::() - .values() - .iter() - .enumerate() - { - self.indices[*row_ix as usize] = (array_index, array_ix); + if rows.len() == self.partial_result_indices.len() { + self.set_covering_result(ColumnarValue::Array(a)); + } else { + self.add_partial_result(rows, a.to_data()); } } ColumnarValue::Scalar(s) => { - self.arrays.push(s.to_array()?); - for row_ix in rows.as_primitive::().values().iter() { - self.indices[*row_ix as usize] = (array_index, 0); + if rows.len() == self.partial_result_indices.len() { + self.set_covering_result(ColumnarValue::Scalar(s)); + } else { + self.add_partial_result( + rows, + s.to_array_of_size(rows.len())?.to_data(), + ); } } } Ok(()) } - fn finish(mut self) -> Result { - if self.arrays.len() == 1 { - // The first array is always a single null value. 
- if self.indices.len() == 1 { - // If there's only a single row, reuse the array - Ok(ColumnarValue::Array(self.arrays.remove(0))) - } else { - // Otherwise make a new null array with the correct type and length - Ok(ColumnarValue::Array(new_null_array( - self.arrays[0].data_type(), - self.indices.len(), - ))) + fn add_partial_result(&mut self, rows: &ArrayRef, data: ArrayData) { + assert!(self.covering_result.is_none()); + + self.partial_results.push(data); + let array_index = self.partial_results.len(); + + for row_ix in rows.as_primitive::().values().iter() { + self.partial_result_indices[*row_ix as usize] = array_index; + } + } + + fn set_covering_result(&mut self, value: ColumnarValue) { + assert!(self.partial_results.is_empty()); + self.covering_result = Some(value); + } + + fn finish(self) -> Result { + match self.covering_result { + Some(v) => { + // If we have a covering result, we can just return it. + Ok(v) } - } else if self.arrays.len() == 2 - && !self.indices.iter().any(|(array_ix, _)| *array_ix == 0) - && self.arrays[1].len() == self.indices.len() - { - // There's only a single non-null array and no references to the null array. - // We can take a shortcut and return the non-null array directly. - Ok(ColumnarValue::Array(self.arrays.remove(1))) - } else { - // Interleave arrays - let array_refs = self.arrays.iter().map(|a| a.as_ref()).collect::>(); - let interleaved_result = interleave(&array_refs, &self.indices)?; - Ok(ColumnarValue::Array(interleaved_result)) + None => match self.partial_results.len() { + 0 => { + // No covering result and no partial results. + // This can happen for case expressions with no else branch where no rows + // matched. + Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + &self.data_type, + )?)) + } + n => { + // There are n partial results. + // Merge into a single array. 
+ + let data_refs = self.partial_results.iter().collect(); + let mut mutable = MutableArrayData::new( + data_refs, + true, + self.partial_result_indices.len(), + ); + + // This loop extends the mutable array by taking slices from the partial results. + // + // take_offsets keeps track of how many values have been taken from each array. + let mut take_offsets = vec![0; n + 1]; + let mut start_row_ix = 0; + loop { + let array_ix = self.partial_result_indices[start_row_ix]; + + // Determine the length of the slice to take. + let mut end_row_ix = start_row_ix + 1; + while end_row_ix < self.partial_result_indices.len() + && self.partial_result_indices[end_row_ix] == array_ix + { + end_row_ix += 1; + } + + // Extend the mutable with either nulls or with values from the array. + let start_offset = take_offsets[array_ix]; + let end_offset = start_offset + (end_row_ix - start_row_ix); + if array_ix == 0 { + mutable.extend_nulls(end_offset - start_offset); + } else { + mutable.extend(array_ix - 1, start_offset, end_offset); + } + + if end_row_ix == self.partial_result_indices.len() { + break; + } else { + // Update the take_offsets array. + take_offsets[array_ix] = end_offset; + // Set the start_row_ix for the next slice. 
+ start_row_ix = end_row_ix; + } + } + + Ok(ColumnarValue::Array(make_array(mutable.freeze()))) + } + }, } } } @@ -313,12 +374,11 @@ impl CaseExpr { let optimize_filters = batch.num_columns() > 1; let return_type = self.data_type(&batch.schema())?; - let mut interleave_builder = - InterleaveBuilder::new(&return_type, batch.num_rows()); + let mut result_builder = ResultBuilder::new(&return_type, batch.num_rows()); // `remainder_rows` contains the indices of the rows that need to be evaluated let mut remainder_rows: ArrayRef = - Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32)); + Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as u32)); // `remainder_batch` contains the rows themselves that need to be evaluated let mut remainder_batch = Cow::Borrowed(batch); @@ -337,7 +397,7 @@ impl CaseExpr { let base_nulls = is_null(base_value.as_ref())?; if base_nulls.true_count() > 0 { // If there is an else expression, use that as the default value for the null rows - // Otherwise the default `null` value from the eInterleaveBuilder will be used. + // Otherwise the default `null` value from the result builder will be used. 
if let Some(e) = self.else_expr() { let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; @@ -345,12 +405,12 @@ impl CaseExpr { let nulls_batch = filter_record_batch(&remainder_batch, &nulls_filter)?; let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?; let nulls_value = expr.evaluate(&nulls_batch)?; - interleave_builder.add_result(&nulls_rows, nulls_value)?; + result_builder.add_result(&nulls_rows, nulls_value)?; } // All base values were null, so we can return early if base_nulls.true_count() == remainder_batch.num_rows() { - return interleave_builder.finish(); + return result_builder.finish(); } // Remove the null rows from the remainder batch @@ -397,14 +457,14 @@ impl CaseExpr { let then_expression = &self.when_then_expr[i].1; let then_value = then_expression.evaluate(&then_batch)?; - interleave_builder.add_result(&then_rows, then_value)?; + result_builder.add_result(&then_rows, then_value)?; // If the 'when' predicate matched all remaining row, there's nothing left to do so // we can return early if remainder_batch.num_rows() == when_match_count || (self.else_expr.is_none() && i == self.when_then_expr.len() - 1) { - return interleave_builder.finish(); + return result_builder.finish(); } // Prepare the next when branch (or the else branch) @@ -422,10 +482,10 @@ impl CaseExpr { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_value = expr.evaluate(&remainder_batch)?; - interleave_builder.add_result(&remainder_rows, else_value)?; + result_builder.add_result(&remainder_rows, else_value)?; } - interleave_builder.finish() + result_builder.finish() } /// This function evaluates the form of CASE where each WHEN expression is a boolean @@ -439,8 +499,7 @@ impl CaseExpr { let optimize_filters = batch.num_columns() > 1; let return_type = self.data_type(&batch.schema())?; - let mut interleave_builder = - InterleaveBuilder::new(&return_type, 
batch.num_rows()); + let mut result_builder = ResultBuilder::new(&return_type, batch.num_rows()); // `remainder_rows` contains the indices of the rows that need to be evaluated let mut remainder_rows: ArrayRef = @@ -480,14 +539,14 @@ impl CaseExpr { let then_expression = &self.when_then_expr[i].1; let then_value = then_expression.evaluate(&then_batch)?; - interleave_builder.add_result(&then_rows, then_value)?; + result_builder.add_result(&then_rows, then_value)?; // If the 'when' predicate matched all remaining row, there's nothing left to do so // we can return early if remainder_batch.num_rows() == when_match_count || (self.else_expr.is_none() && i == self.when_then_expr.len() - 1) { - return interleave_builder.finish(); + return result_builder.finish(); } // Prepare the next when branch (or the else branch) @@ -504,10 +563,10 @@ impl CaseExpr { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_value = expr.evaluate(&remainder_batch)?; - interleave_builder.add_result(&remainder_rows, else_value)?; + result_builder.add_result(&remainder_rows, else_value)?; } - interleave_builder.finish() + result_builder.finish() } /// This function evaluates the specialized case of: From 37f1334f195cae055f9887dc21d87787a71c3760 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 23 Oct 2025 16:09:02 +0200 Subject: [PATCH 12/44] Always optimize filters --- .../physical-expr/src/expressions/case.rs | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index dd3185b08430..0e69f1814c3d 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -124,11 +124,10 @@ fn is_cheap_and_infallible(expr: &Arc) -> bool { } /// Creates a [FilterPredicate] from a boolean array. 
-fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate { +fn create_filter(predicate: &BooleanArray) -> FilterPredicate { let mut filter_builder = FilterBuilder::new(predicate); - if optimize { - filter_builder = filter_builder.optimize(); - } + // Always optimize the filter since we use them multiple times. + filter_builder = filter_builder.optimize(); filter_builder.build() } @@ -371,8 +370,6 @@ impl CaseExpr { /// [ELSE result] /// END fn case_when_with_expr(&self, batch: &RecordBatch) -> Result { - let optimize_filters = batch.num_columns() > 1; - let return_type = self.data_type(&batch.schema())?; let mut result_builder = ResultBuilder::new(&return_type, batch.num_rows()); @@ -401,7 +398,7 @@ impl CaseExpr { if let Some(e) = self.else_expr() { let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; - let nulls_filter = create_filter(&base_nulls, optimize_filters); + let nulls_filter = create_filter(&base_nulls); let nulls_batch = filter_record_batch(&remainder_batch, &nulls_filter)?; let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?; let nulls_value = expr.evaluate(&nulls_batch)?; @@ -414,7 +411,7 @@ impl CaseExpr { } // Remove the null rows from the remainder batch - let not_null_filter = create_filter(¬(&base_nulls)?, optimize_filters); + let not_null_filter = create_filter(¬(&base_nulls)?); remainder_batch = Cow::Owned(filter_record_batch(&remainder_batch, ¬_null_filter)?); remainder_rows = filter_array(&remainder_rows, ¬_null_filter)?; @@ -451,7 +448,7 @@ impl CaseExpr { // Filter the remainder batch based on the 'when' value // This results in a batch containing only the rows that need to be evaluated // for the current branch - let then_filter = create_filter(&when_value, optimize_filters); + let then_filter = create_filter(&when_value); let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; let then_rows = filter_array(&remainder_rows, &then_filter)?; @@ -469,7 +466,7 @@ impl 
CaseExpr { // Prepare the next when branch (or the else branch) let next_selection = not(&when_value)?; - let next_filter = create_filter(&next_selection, optimize_filters); + let next_filter = create_filter(&next_selection); remainder_batch = Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?); remainder_rows = filter_array(&remainder_rows, &next_filter)?; @@ -496,8 +493,6 @@ impl CaseExpr { /// [ELSE result] /// END fn case_when_no_expr(&self, batch: &RecordBatch) -> Result { - let optimize_filters = batch.num_columns() > 1; - let return_type = self.data_type(&batch.schema())?; let mut result_builder = ResultBuilder::new(&return_type, batch.num_rows()); @@ -533,7 +528,7 @@ impl CaseExpr { // Filter the remainder batch based on the 'when' value // This results in a batch containing only the rows that need to be evaluated // for the current branch - let then_filter = create_filter(&when_value, optimize_filters); + let then_filter = create_filter(&when_value); let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; let then_rows = filter_array(&remainder_rows, &then_filter)?; @@ -551,7 +546,7 @@ impl CaseExpr { // Prepare the next when branch (or the else branch) let next_selection = not(&when_value)?; - let next_filter = create_filter(&next_selection, optimize_filters); + let next_filter = create_filter(&next_selection); remainder_batch = Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?); remainder_rows = filter_array(&remainder_rows, &next_filter)?; From 1b2942bf69feebbfe75e9fb63142c9266f04a923 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 23 Oct 2025 17:40:16 +0200 Subject: [PATCH 13/44] Remove accidental addition of pub --- datafusion/physical-expr/src/expressions/case.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 0e69f1814c3d..8cf16b0424f6 100644 --- 
a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -39,7 +39,7 @@ use std::{any::Any, sync::Arc}; type WhenThen = (Arc, Arc); #[derive(Debug, Hash, PartialEq, Eq)] -pub enum EvalMethod { +enum EvalMethod { /// CASE WHEN condition THEN result /// [WHEN ...] /// [ELSE result] @@ -95,7 +95,7 @@ pub struct CaseExpr { /// Optional "else" expression else_expr: Option>, /// Evaluation method to use - pub eval_method: EvalMethod, + eval_method: EvalMethod, } impl std::fmt::Display for CaseExpr { From f6a5734256a7a53fbd4e30f2f6f3c0197a821a88 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 23 Oct 2025 17:45:55 +0200 Subject: [PATCH 14/44] Add comments regarding `RecordBatch::new_unchecked` usage --- datafusion/physical-expr/src/expressions/case.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 8cf16b0424f6..72520b40fe2d 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -131,6 +131,8 @@ fn create_filter(predicate: &BooleanArray) -> FilterPredicate { filter_builder.build() } +// This should be removed when https://github.com/apache/arrow-rs/pull/8693 +// is merged and becomes available. fn filter_record_batch( record_batch: &RecordBatch, filter: &FilterPredicate, @@ -140,6 +142,11 @@ fn filter_record_batch( .iter() .map(|a| filter_array(a, filter)) .collect::, _>>()?; + // SAFETY: since we start from a valid RecordBatch, there's no need to revalidate the schema + // since the set of columns has not changed. + // The input column arrays all had the same length (since they're coming from a valid RecordBatch) + // and the filtering them with the same filter will produces a new set of arrays with identical + // lengths. 
unsafe { Ok(RecordBatch::new_unchecked( record_batch.schema(), From 30305579509144db6b52b9070cac992b2ae756a1 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 23 Oct 2025 18:18:49 +0200 Subject: [PATCH 15/44] Attempt to clarify merge logic --- .../physical-expr/src/expressions/case.rs | 187 ++++++++++++------ 1 file changed, 124 insertions(+), 63 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 72520b40fe2d..676efcab3d1b 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -164,6 +164,92 @@ fn filter_array( filter.filter(array) } +/// +/// Merges elements by index from a list of [`ArrayData`], creating a new [`ColumnarValue`] from +/// those values. +/// +/// Each element in `indices` is the index of an array in `values` offset by 1. The first +/// occurrence of index value `n` will be mapped to the first value of array `n -1`. The second +/// occurrence to the second value, and so on. +/// +/// The index value `0` is used to indicate null values. 
+/// +/// ```text +/// ┌─────────────────┐ ┌─────────┐ ┌─────────────────┐ +/// │ A │ │ 0 │ merge( │ NULL │ +/// ├─────────────────┤ ├─────────┤ [values0, values1], ├─────────────────┤ +/// │ D │ │ 2 │ indices │ B │ +/// └─────────────────┘ ├─────────┤ ) ├─────────────────┤ +/// values array 0 │ 2 │ ─────────────────────────▶ │ C │ +/// ├─────────┤ ├─────────────────┤ +/// │ 1 │ │ A │ +/// ├─────────┤ ├─────────────────┤ +/// │ 1 │ │ D │ +/// ┌─────────────────┐ ├─────────┤ ├─────────────────┤ +/// │ B │ │ 2 │ │ E │ +/// ├─────────────────┤ └─────────┘ └─────────────────┘ +/// │ C │ +/// ├─────────────────┤ indices +/// │ E │ array result +/// └─────────────────┘ +/// values array 1 +/// values array 1 +/// ``` +fn merge(values: &[ArrayData], indices: &[usize]) -> Result { + let data_refs = values.iter().collect(); + let mut mutable = MutableArrayData::new( + data_refs, + true, + indices.len(), + ); + + // This loop extends the mutable array by taking slices from the partial results. + // + // take_offsets keeps track of how many values have been taken from each array. + let mut take_offsets = vec![0; values.len() + 1]; + let mut start_row_ix = 0; + loop { + let array_ix = indices[start_row_ix]; + + // Determine the length of the slice to take. + let mut end_row_ix = start_row_ix + 1; + while end_row_ix < indices.len() + && indices[end_row_ix] == array_ix + { + end_row_ix += 1; + } + + // Extend mutable with either nulls or with values from the array. + let start_offset = take_offsets[array_ix]; + let end_offset = start_offset + (end_row_ix - start_row_ix); + if array_ix == 0 { + mutable.extend_nulls(end_offset - start_offset); + } else { + mutable.extend(array_ix - 1, start_offset, end_offset); + } + + if end_row_ix == indices.len() { + break; + } else { + // Update the take_offsets array. + take_offsets[array_ix] = end_offset; + // Set the start_row_ix for the next slice. 
+ start_row_ix = end_row_ix; + } + } + + Ok(make_array(mutable.freeze())) +} + +/// A builder for constructing result arrays for CASE expressions. +/// +/// Rather than building a monolithic array containing all results, it maintains a set of +/// partial result arrays and a mapping that indicates for each row which partial array +/// contains the result value for that row. +/// +/// On finish(), the builder will merge all partial results into a single array if necessary. +/// If all rows evaluated to the same array, that array can be returned directly without +/// any merging overhead. struct ResultBuilder { data_type: DataType, // A Vec of partial results that should be merged. `partial_result_indices` contains @@ -179,6 +265,9 @@ struct ResultBuilder { } impl ResultBuilder { + /// Creates a new ResultBuilder that will produce arrays of the given data type. + /// + /// The capacity parameter indicates the number of rows in the result. fn new(data_type: &DataType, capacity: usize) -> Self { Self { data_type: data_type.clone(), @@ -188,31 +277,32 @@ impl ResultBuilder { } } - /// Adds a result value. + /// Adds a result for one branch of the case expression. /// - /// `rows` should be a [UInt32Array] containing [RecordBatch] relative row indices + /// `row_indices` should be a [UInt32Array] containing [RecordBatch] relative row indices /// for which `value` contains result values. /// - /// If `value` is a scalar, the scalar value is used for each row in `rows`. - /// If `value` is an array, the values from the array and the indices from `rows` will be - /// processed pairwise. - fn add_result(&mut self, rows: &ArrayRef, value: ColumnarValue) -> Result<()> { + /// If `value` is a scalar, the scalar value will be used as the value for each row in `row_indices`. + /// + /// If `value` is an array, the values from the array and the indices from `row_indices` will be + /// processed pairwise. The lengths of `value` and `row_indices` must match. 
+ fn add_branch_result(&mut self, row_indices: &ArrayRef, value: ColumnarValue) -> Result<()> { match value { ColumnarValue::Array(a) => { - assert_eq!(a.len(), rows.len()); - if rows.len() == self.partial_result_indices.len() { + assert_eq!(a.len(), row_indices.len()); + if row_indices.len() == self.partial_result_indices.len() { self.set_covering_result(ColumnarValue::Array(a)); } else { - self.add_partial_result(rows, a.to_data()); + self.add_partial_result(row_indices, a.to_data()); } } ColumnarValue::Scalar(s) => { - if rows.len() == self.partial_result_indices.len() { + if row_indices.len() == self.partial_result_indices.len() { self.set_covering_result(ColumnarValue::Scalar(s)); } else { self.add_partial_result( - rows, - s.to_array_of_size(rows.len())?.to_data(), + row_indices, + s.to_array_of_size(row_indices.len())?.to_data(), ); } } @@ -220,6 +310,11 @@ impl ResultBuilder { Ok(()) } + /// Adds a partial result array. + /// + /// This method adds the given array data as a partial result and updates the index mapping + /// to indicate that the specified rows should take their values from this array. + /// The partial results will be merged into a single array when finish() is called. fn add_partial_result(&mut self, rows: &ArrayRef, data: ArrayData) { assert!(self.covering_result.is_none()); @@ -231,11 +326,21 @@ impl ResultBuilder { } } + /// Sets a covering result that applies to all rows. + /// + /// This is an optimization for cases where all rows evaluate to the same result. + /// When a covering result is set, the builder will return it directly from finish() + /// without any merging overhead. fn set_covering_result(&mut self, value: ColumnarValue) { assert!(self.partial_results.is_empty()); self.covering_result = Some(value); } + /// Finishes building the result and returns the final array. + /// + /// If a covering result was set with set_covering_result(), that result will be returned directly. 
+ /// Otherwise, all partial results will be merged into a single array. + /// If no results were added, a null array of the appropriate type will be returned. fn finish(self) -> Result { match self.covering_result { Some(v) => { @@ -251,53 +356,9 @@ impl ResultBuilder { &self.data_type, )?)) } - n => { - // There are n partial results. + _ => { // Merge into a single array. - - let data_refs = self.partial_results.iter().collect(); - let mut mutable = MutableArrayData::new( - data_refs, - true, - self.partial_result_indices.len(), - ); - - // This loop extends the mutable array by taking slices from the partial results. - // - // take_offsets keeps track of how many values have been taken from each array. - let mut take_offsets = vec![0; n + 1]; - let mut start_row_ix = 0; - loop { - let array_ix = self.partial_result_indices[start_row_ix]; - - // Determine the length of the slice to take. - let mut end_row_ix = start_row_ix + 1; - while end_row_ix < self.partial_result_indices.len() - && self.partial_result_indices[end_row_ix] == array_ix - { - end_row_ix += 1; - } - - // Extend the mutable with either nulls or with values from the array. - let start_offset = take_offsets[array_ix]; - let end_offset = start_offset + (end_row_ix - start_row_ix); - if array_ix == 0 { - mutable.extend_nulls(end_offset - start_offset); - } else { - mutable.extend(array_ix - 1, start_offset, end_offset); - } - - if end_row_ix == self.partial_result_indices.len() { - break; - } else { - // Update the take_offsets array. - take_offsets[array_ix] = end_offset; - // Set the start_row_ix for the next slice. 
- start_row_ix = end_row_ix; - } - } - - Ok(ColumnarValue::Array(make_array(mutable.freeze()))) + Ok(ColumnarValue::Array(merge(&self.partial_results, &self.partial_result_indices)?)) } }, } @@ -409,7 +470,7 @@ impl CaseExpr { let nulls_batch = filter_record_batch(&remainder_batch, &nulls_filter)?; let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?; let nulls_value = expr.evaluate(&nulls_batch)?; - result_builder.add_result(&nulls_rows, nulls_value)?; + result_builder.add_branch_result(&nulls_rows, nulls_value)?; } // All base values were null, so we can return early @@ -461,7 +522,7 @@ impl CaseExpr { let then_expression = &self.when_then_expr[i].1; let then_value = then_expression.evaluate(&then_batch)?; - result_builder.add_result(&then_rows, then_value)?; + result_builder.add_branch_result(&then_rows, then_value)?; // If the 'when' predicate matched all remaining row, there's nothing left to do so // we can return early @@ -486,7 +547,7 @@ impl CaseExpr { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_value = expr.evaluate(&remainder_batch)?; - result_builder.add_result(&remainder_rows, else_value)?; + result_builder.add_branch_result(&remainder_rows, else_value)?; } result_builder.finish() @@ -541,7 +602,7 @@ impl CaseExpr { let then_expression = &self.when_then_expr[i].1; let then_value = then_expression.evaluate(&then_batch)?; - result_builder.add_result(&then_rows, then_value)?; + result_builder.add_branch_result(&then_rows, then_value)?; // If the 'when' predicate matched all remaining row, there's nothing left to do so // we can return early @@ -565,7 +626,7 @@ impl CaseExpr { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_value = expr.evaluate(&remainder_batch)?; - result_builder.add_result(&remainder_rows, else_value)?; + 
result_builder.add_branch_result(&remainder_rows, else_value)?; } result_builder.finish() From cad318eac92ba5545f81c789a6f247d8ad0a48bc Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 23 Oct 2025 18:21:17 +0200 Subject: [PATCH 16/44] Rename arguments --- datafusion/physical-expr/src/expressions/case.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 676efcab3d1b..8b227786670c 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -315,13 +315,13 @@ impl ResultBuilder { /// This method adds the given array data as a partial result and updates the index mapping /// to indicate that the specified rows should take their values from this array. /// The partial results will be merged into a single array when finish() is called. - fn add_partial_result(&mut self, rows: &ArrayRef, data: ArrayData) { + fn add_partial_result(&mut self, row_indices: &ArrayRef, row_values: ArrayData) { assert!(self.covering_result.is_none()); - self.partial_results.push(data); + self.partial_results.push(row_values); let array_index = self.partial_results.len(); - for row_ix in rows.as_primitive::().values().iter() { + for row_ix in row_indices.as_primitive::().values().iter() { self.partial_result_indices[*row_ix as usize] = array_index; } } From d57e7b68830efcc69e06b67fb505ada62112eabe Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 23 Oct 2025 18:26:51 +0200 Subject: [PATCH 17/44] Formatting --- .../physical-expr/src/expressions/case.rs | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 8b227786670c..3ac7605cc6b7 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs 
@@ -197,11 +197,7 @@ fn filter_array( /// ``` fn merge(values: &[ArrayData], indices: &[usize]) -> Result { let data_refs = values.iter().collect(); - let mut mutable = MutableArrayData::new( - data_refs, - true, - indices.len(), - ); + let mut mutable = MutableArrayData::new(data_refs, true, indices.len()); // This loop extends the mutable array by taking slices from the partial results. // @@ -213,9 +209,7 @@ fn merge(values: &[ArrayData], indices: &[usize]) -> Result { // Determine the length of the slice to take. let mut end_row_ix = start_row_ix + 1; - while end_row_ix < indices.len() - && indices[end_row_ix] == array_ix - { + while end_row_ix < indices.len() && indices[end_row_ix] == array_ix { end_row_ix += 1; } @@ -286,7 +280,11 @@ impl ResultBuilder { /// /// If `value` is an array, the values from the array and the indices from `row_indices` will be /// processed pairwise. The lengths of `value` and `row_indices` must match. - fn add_branch_result(&mut self, row_indices: &ArrayRef, value: ColumnarValue) -> Result<()> { + fn add_branch_result( + &mut self, + row_indices: &ArrayRef, + value: ColumnarValue, + ) -> Result<()> { match value { ColumnarValue::Array(a) => { assert_eq!(a.len(), row_indices.len()); @@ -358,7 +356,10 @@ impl ResultBuilder { } _ => { // Merge into a single array. 
- Ok(ColumnarValue::Array(merge(&self.partial_results, &self.partial_result_indices)?)) + Ok(ColumnarValue::Array(merge( + &self.partial_results, + &self.partial_result_indices, + )?)) } }, } From 7f60d68f3f6c04faa19d5d899c2f8152bc210ab2 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 23 Oct 2025 18:41:55 +0200 Subject: [PATCH 18/44] More diagrams --- .../physical-expr/src/expressions/case.rs | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 3ac7605cc6b7..f35729af5ad1 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -168,9 +168,9 @@ fn filter_array( /// Merges elements by index from a list of [`ArrayData`], creating a new [`ColumnarValue`] from /// those values. /// -/// Each element in `indices` is the index of an array in `values` offset by 1. The first -/// occurrence of index value `n` will be mapped to the first value of array `n -1`. The second -/// occurrence to the second value, and so on. +/// Each element in `indices` is the index of an array in `values` offset by 1. `indices` is +/// processed sequentially. The first occurrence of index value `n` will be mapped to the first +/// value of array `n - 1`. The second occurrence to the second value, and so on. /// /// The index value `0` is used to indicate null values. /// @@ -193,7 +193,6 @@ fn filter_array( /// │ E │ array result /// └─────────────────┘ /// values array 1 -/// values array 1 /// ``` fn merge(values: &[ArrayData], indices: &[usize]) -> Result { let data_refs = values.iter().collect(); @@ -280,6 +279,30 @@ impl ResultBuilder { /// /// If `value` is an array, the values from the array and the indices from `row_indices` will be /// processed pairwise. The lengths of `value` and `row_indices` must match. 
+ /// + /// The diagram below shows a situation where a when expression matched rows 1 and 4 of the + /// record batch. The then expression produced the value array `[A, D]`. + /// After adding this result, the result array will have been added to `partial_results` and + /// `partial_indices` will have been updated at indexes 1 and 4. + /// + /// ```text + /// ┌─────────┐ ┌─────────┐┌───────────┐ ┌─────────┐┌───────────┐ + /// │ A │ │ 0 ││ │ │ 0 ││┌─────────┐│ + /// ├─────────┤ ├─────────┤│ │ ├─────────┤││ A ││ + /// │ D │ │ 0 ││ │ │ 1 ││├─────────┤│ + /// └─────────┘ ├─────────┤│ │ add_branch_result( ├─────────┤││ D ││ + /// value │ 0 ││ │ row indices, │ 0 ││└─────────┘│ + /// ├─────────┤│ │ value ├─────────┤│ │ + /// │ 0 ││ │ ) │ 0 ││ │ + /// ┌─────────┐ ├─────────┤│ │ ─────────────────────────▶ ├─────────┤│ │ + /// │ 1 │ │ 0 ││ │ │ 1 ││ │ + /// ├─────────┤ ├─────────┤│ │ ├─────────┤│ │ + /// │ 4 │ │ 0 ││ │ │ 0 ││ │ + /// └─────────┘ └─────────┘└───────────┘ └─────────┘└───────────┘ + /// row indices + /// partial partial partial partial + /// indices results indices results + /// ``` fn add_branch_result( &mut self, row_indices: &ArrayRef, From 9cf46b24a02cb04435442a1d9ede8fb4d0660d29 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 23 Oct 2025 18:44:54 +0200 Subject: [PATCH 19/44] More comments --- datafusion/physical-expr/src/expressions/case.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index f35729af5ad1..850aa84e8036 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -337,6 +337,9 @@ impl ResultBuilder { /// to indicate that the specified rows should take their values from this array. /// The partial results will be merged into a single array when finish() is called. 
fn add_partial_result(&mut self, row_indices: &ArrayRef, row_values: ArrayData) { + // Covering results and partial results are mutually exclusive. + // We can assert this since the case evaluation methods are written to only evaluate + // each row of the record batch once. assert!(self.covering_result.is_none()); self.partial_results.push(row_values); @@ -353,7 +356,11 @@ impl ResultBuilder { /// When a covering result is set, the builder will return it directly from finish() /// without any merging overhead. fn set_covering_result(&mut self, value: ColumnarValue) { + // Covering results and partial results are mutually exclusive. + // We can assert this since the case evaluation methods are written to only evaluate + // each row of the record batch once. assert!(self.partial_results.is_empty()); + self.covering_result = Some(value); } From fc04bd510c3037b0b6be5bca8b6f8f7beb478d98 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 23 Oct 2025 18:45:34 +0200 Subject: [PATCH 20/44] Remove redundant comment --- datafusion/physical-expr/src/expressions/case.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 850aa84e8036..15e6c9c08325 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -365,10 +365,6 @@ impl ResultBuilder { } /// Finishes building the result and returns the final array. - /// - /// If a covering result was set with set_covering_result(), that result will be returned directly. - /// Otherwise, all partial results will be merged into a single array. - /// If no results were added, a null array of the appropriate type will be returned. 
fn finish(self) -> Result { match self.covering_result { Some(v) => { From 7dba5549b47e704a2195f8396576d6e55e01d4e1 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Fri, 24 Oct 2025 17:02:57 +0200 Subject: [PATCH 21/44] Calculate slice length once --- datafusion/physical-expr/src/expressions/case.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 15e6c9c08325..6e714b95e618 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -211,12 +211,13 @@ fn merge(values: &[ArrayData], indices: &[usize]) -> Result { while end_row_ix < indices.len() && indices[end_row_ix] == array_ix { end_row_ix += 1; } + let slice_length = end_row_ix - start_row_ix; // Extend mutable with either nulls or with values from the array. let start_offset = take_offsets[array_ix]; - let end_offset = start_offset + (end_row_ix - start_row_ix); + let end_offset = start_offset + slice_length; if array_ix == 0 { - mutable.extend_nulls(end_offset - start_offset); + mutable.extend_nulls(slice_length); } else { mutable.extend(array_ix - 1, start_offset, end_offset); } From 44b763cd8a26fb24a4e6e12044dc4cc6c31b3211 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Fri, 24 Oct 2025 23:03:54 +0200 Subject: [PATCH 22/44] Avoid filtering when all base values are null --- .../physical-expr/src/expressions/case.rs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 6e714b95e618..ca1a6a7bf520 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -494,11 +494,18 @@ impl CaseExpr { if let Some(e) = self.else_expr() { let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; - let 
nulls_filter = create_filter(&base_nulls); - let nulls_batch = filter_record_batch(&remainder_batch, &nulls_filter)?; - let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?; - let nulls_value = expr.evaluate(&nulls_batch)?; - result_builder.add_branch_result(&nulls_rows, nulls_value)?; + if base_nulls.true_count() == remainder_batch.num_rows() { + // All base values were null, so no need to filter + let nulls_value = expr.evaluate(&remainder_batch)?; + result_builder.add_branch_result(&remainder_rows, nulls_value)?; + } else { + let nulls_filter = create_filter(&base_nulls); + let nulls_batch = + filter_record_batch(&remainder_batch, &nulls_filter)?; + let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?; + let nulls_value = expr.evaluate(&nulls_batch)?; + result_builder.add_branch_result(&nulls_rows, nulls_value)?; + } } // All base values were null, so we can return early From be578d43be1dfb6861df00d0b0f9369703c8d444 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Fri, 24 Oct 2025 23:37:54 +0200 Subject: [PATCH 23/44] Avoid filtering when branch matches all remaining rows --- .../physical-expr/src/expressions/case.rs | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index ca1a6a7bf520..4c44c78970a8 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -521,6 +521,10 @@ impl CaseExpr { base_value = filter_array(&base_value, &not_null_filter)?; } + // The types of case and when expressions will be coerced to match. + // We only need to check if the base_value is nested. 
+ let base_value_is_nested = base_value.data_type().is_nested(); + for i in 0..self.when_then_expr.len() { // Evaluate the 'when' predicate for the remainder batch // This results in a boolean array with the same length as the remaining number of rows @@ -533,15 +537,24 @@ impl CaseExpr { &base_value, // The types of case and when expressions will be coerced to match. // We only need to check if the base_value is nested. - base_value.data_type().is_nested(), + base_value_is_nested, )?; - // If the 'when' predicate did not match any rows, continue to the next branch immediately let when_match_count = when_value.true_count(); + + // If the 'when' predicate did not match any rows, continue to the next branch immediately if when_match_count == 0 { continue; } + // If the 'when' predicate matched all remaining rows, there is no need to filter + if when_match_count == remainder_batch.num_rows() { + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate(&remainder_batch)?; + result_builder.add_branch_result(&remainder_rows, then_value)?; + return result_builder.finish(); + } + // Make sure 'NULL' is treated as false let when_value = match when_value.null_count() { 0 => when_value, @@ -561,9 +574,7 @@ impl CaseExpr { // If the 'when' predicate matched all remaining row, there's nothing left to do so // we can return early - if remainder_batch.num_rows() == when_match_count - || (self.else_expr.is_none() && i == self.when_then_expr.len() - 1) - { + if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 { return result_builder.finish(); } @@ -616,12 +627,21 @@ impl CaseExpr { internal_datafusion_err!("WHEN expression did not return a BooleanArray") })?; - // If the 'when' predicate did not match any rows, continue to the next branch immediately let when_match_count = when_value.true_count(); + + // If the 'when' predicate did not match any rows, continue to the next branch immediately if when_match_count == 0 { continue; } + 
// If the 'when' predicate matched all remaining rows, there is no need to filter + if when_match_count == remainder_batch.num_rows() { + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate(&remainder_batch)?; + result_builder.add_branch_result(&remainder_rows, then_value)?; + return result_builder.finish(); + } + // Make sure 'NULL' is treated as false let when_value = match when_value.null_count() { 0 => Cow::Borrowed(when_value), @@ -641,9 +661,7 @@ impl CaseExpr { // If the 'when' predicate matched all remaining row, there's nothing left to do so // we can return early - if remainder_batch.num_rows() == when_match_count - || (self.else_expr.is_none() && i == self.when_then_expr.len() - 1) - { + if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 { return result_builder.finish(); } From 53cb37bd6038158c58249dfbf3b0858516c40525 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sat, 25 Oct 2025 00:24:40 +0200 Subject: [PATCH 24/44] Avoid expanding scalars in case_when_with_expr --- .../physical-expr/src/expressions/case.rs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 4c44c78970a8..be4688db9268 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -529,16 +529,15 @@ impl CaseExpr { // Evaluate the 'when' predicate for the remainder batch // This results in a boolean array with the same length as the remaining number of rows let when_expr = &self.when_then_expr[i].0; - let when_value = when_expr - .evaluate(&remainder_batch)? - .into_array(remainder_batch.num_rows())?; - let when_value = compare_with_eq( - &when_value, - &base_value, - // The types of case and when expressions will be coerced to match. - // We only need to check if the base_value is nested. 
- base_value_is_nested, - )?; + let when_value = match when_expr.evaluate(&remainder_batch)? { + ColumnarValue::Array(a) => { + compare_with_eq(&a, &base_value, base_value_is_nested) + } + ColumnarValue::Scalar(s) => { + let scalar = Scalar::new(s.to_array()?); + compare_with_eq(&scalar, &base_value, base_value_is_nested) + } + }?; let when_match_count = when_value.true_count(); From 9ebe463051e36dc18dab2835d32eb6393272f656 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sat, 25 Oct 2025 12:07:56 +0200 Subject: [PATCH 25/44] Update code comment --- datafusion/physical-expr/src/expressions/case.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index be4688db9268..882c0adcf216 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -571,8 +571,8 @@ impl CaseExpr { let then_value = then_expression.evaluate(&then_batch)?; result_builder.add_branch_result(&then_rows, then_value)?; - // If the 'when' predicate matched all remaining row, there's nothing left to do so - // we can return early + // If this is the last 'when' branch and there is no 'else' expression, there's no + // point in calculating the remaining rows. if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 { return result_builder.finish(); } @@ -658,9 +658,9 @@ impl CaseExpr { let then_value = then_expression.evaluate(&then_batch)?; result_builder.add_branch_result(&then_rows, then_value)?; - // If the 'when' predicate matched all remaining row, there's nothing left to do so - // we can return early - if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 { + // If this is the last 'when' branch and there is no 'else' expression, there's no + // point in calculating the remaining rows. 
+ if i == self.when_then_expr.len() - 1 && self.else_expr.is_none() { return result_builder.finish(); } From 2a6d1aec6b737d338ec40d8d2910b2ab572969fc Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sat, 25 Oct 2025 15:04:47 +0200 Subject: [PATCH 26/44] Processing review comments - Use usize::MAX as null marker instead of zero - Introduce `ResultState` enum - Make diagrams more representative - Extend documentation of `merge` --- .../physical-expr/src/expressions/case.rs | 248 ++++++++++-------- 1 file changed, 145 insertions(+), 103 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 882c0adcf216..280136785a24 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -16,6 +16,7 @@ // under the License. use super::{Column, Literal}; +use crate::expressions::case::ResultState::{Complete, Partial}; use crate::expressions::try_cast; use crate::PhysicalExpr; use arrow::array::*; @@ -156,6 +157,10 @@ fn filter_record_batch( } } +// This function exists purely to be able to use the same call style +// for `filter_record_batch` and `filter_array` at the point of use. +// When https://github.com/apache/arrow-rs/pull/8693 is available, replace +// both with method calls on `FilterPredicate`. #[inline(always)] fn filter_array( array: &dyn Array, @@ -164,35 +169,60 @@ fn filter_array( filter.filter(array) } -/// +const MERGE_NULL_MARKER: usize = usize::MAX; + /// Merges elements by index from a list of [`ArrayData`], creating a new [`ColumnarValue`] from /// those values. /// -/// Each element in `indices` is the index of an array in `values` offset by 1. `indices` is -/// processed sequentially. The first occurrence of index value `n` will be mapped to the first -/// value of array `n - 1`. The second occurrence to the second value, and so on. +/// Each element in `indices` is the index of an array in `values`. 
The `indices` array is processed +/// sequentially. The first occurrence of index value `n` will be mapped to the first +/// value of the array at index `n`. The second occurrence to the second value, and so on. +/// An index value of `usize::MAX` is used to indicate null values. +/// +/// # Implementation notes +/// +/// This algorithm is similar in nature to both `zip` and `interleave`, but there are some important +/// differences. +/// +/// In contrast to `zip`, this function supports multiple input arrays. Instead of a boolean +/// selection vector, an index array is to take values from the input arrays, and a special marker +/// value is used to indicate null values. +/// +/// In contrast to `interleave`, this function does not use pairs of indices. The values in +/// `indices` serve the same purpose as the first value in the pairs passed to `interleave`. +/// The index in the array is implicit and is derived from the number of times a particular array +/// index occurs. +/// The more constrained indexing mechanism used by this algorithm makes it easier to copy values +/// in contiguous slices. In the example below, the two subsequent elements from array `2` can be +/// copied in a single operation from the source array instead of copying them one by one. +/// Long spans of null values are also especially cheap because they do not need to be represented +/// in an input array. +/// +/// # Safety +/// +/// This function does not check that the number of occurrences of any particular array index matches +/// the length of the corresponding input array. If an array contains more values than required, the +/// spurious values will be ignored. If an array contains fewer values than necessary, this function +/// will panic. /// -/// The index value `0` is used to indicate null values. 
+/// # Example /// /// ```text -/// ┌─────────────────┐ ┌─────────┐ ┌─────────────────┐ -/// │ A │ │ 0 │ merge( │ NULL │ -/// ├─────────────────┤ ├─────────┤ [values0, values1], ├─────────────────┤ -/// │ D │ │ 2 │ indices │ B │ -/// └─────────────────┘ ├─────────┤ ) ├─────────────────┤ -/// values array 0 │ 2 │ ─────────────────────────▶ │ C │ -/// ├─────────┤ ├─────────────────┤ -/// │ 1 │ │ A │ -/// ├─────────┤ ├─────────────────┤ -/// │ 1 │ │ D │ -/// ┌─────────────────┐ ├─────────┤ ├─────────────────┤ -/// │ B │ │ 2 │ │ E │ -/// ├─────────────────┤ └─────────┘ └─────────────────┘ -/// │ C │ -/// ├─────────────────┤ indices -/// │ E │ array result -/// └─────────────────┘ -/// values array 1 +/// ┌───────────┐ ┌─────────┐ ┌─────────┐ +/// │┌─────────┐│ │ MAX │ │ NULL │ +/// ││ A ││ ├─────────┤ ├─────────┤ +/// │└─────────┘│ │ 1 │ │ B │ +/// │┌─────────┐│ ├─────────┤ ├─────────┤ +/// ││ B ││ │ 0 │ merge(values, indices) │ A │ +/// │└─────────┘│ ├─────────┤ ─────────────────────────▶ ├─────────┤ +/// │┌─────────┐│ │ MAX │ │ NULL │ +/// ││ C ││ ├─────────┤ ├─────────┤ +/// │├─────────┤│ │ 2 │ │ C │ +/// ││ D ││ ├─────────┤ ├─────────┤ +/// │└─────────┘│ │ 2 │ │ D │ +/// └───────────┘ └─────────┘ └─────────┘ +/// values indices result +/// /// ``` fn merge(values: &[ArrayData], indices: &[usize]) -> Result { let data_refs = values.iter().collect(); @@ -214,19 +244,18 @@ fn merge(values: &[ArrayData], indices: &[usize]) -> Result { let slice_length = end_row_ix - start_row_ix; // Extend mutable with either nulls or with values from the array. 
- let start_offset = take_offsets[array_ix]; - let end_offset = start_offset + slice_length; - if array_ix == 0 { + if array_ix == MERGE_NULL_MARKER { mutable.extend_nulls(slice_length); } else { - mutable.extend(array_ix - 1, start_offset, end_offset); + let start_offset = take_offsets[array_ix]; + let end_offset = start_offset + slice_length; + mutable.extend(array_ix, start_offset, end_offset); + take_offsets[array_ix] = end_offset; } if end_row_ix == indices.len() { break; } else { - // Update the take_offsets array. - take_offsets[array_ix] = end_offset; // Set the start_row_ix for the next slice. start_row_ix = end_row_ix; } @@ -235,6 +264,21 @@ fn merge(values: &[ArrayData], indices: &[usize]) -> Result { Ok(make_array(mutable.freeze())) } +enum ResultState { + /// The final result needs to be computed by merging the the data in `arrays`. + Partial { + // A Vec of partial results that should be merged. `partial_result_indices` contains + // indexes into this vec. + arrays: Vec, + // Indicates per result row from which array in `partial_results` a value should be taken. + // The indexes in this array are offset by +1. The special value 0 indicates null values. + indices: Vec, + }, + /// A single branch matched all input rows. When creating the final result, no further merging + /// of partial results is necessary. + Complete(ColumnarValue), +} + /// A builder for constructing result arrays for CASE expressions. /// /// Rather than building a monolithic array containing all results, it maintains a set of @@ -246,28 +290,22 @@ fn merge(values: &[ArrayData], indices: &[usize]) -> Result { /// any merging overhead. struct ResultBuilder { data_type: DataType, - // A Vec of partial results that should be merged. `partial_result_indices` contains - // indexes into this vec. - partial_results: Vec, - // Indicates per result row from which array in `partial_results` a value should be taken. - // The indexes in this array are offset by +1. 
The special value 0 indicates null values. - partial_result_indices: Vec, - // An optional result that is the covering result for all rows. - // This is used as an optimisation to avoid the cost of merging when all rows - // evaluate to the same case branch. - covering_result: Option, + row_count: usize, + state: ResultState, } impl ResultBuilder { /// Creates a new ResultBuilder that will produce arrays of the given data type. /// - /// The capacity parameter indicates the number of rows in the result. - fn new(data_type: &DataType, capacity: usize) -> Self { + /// The `row_count` parameter indicates the number of rows in the final result. + fn new(data_type: &DataType, row_count: usize) -> Self { Self { data_type: data_type.clone(), - partial_result_indices: vec![0; capacity], - partial_results: vec![], - covering_result: None, + row_count, + state: Partial { + arrays: vec![], + indices: vec![MERGE_NULL_MARKER; row_count], + }, } } @@ -283,26 +321,26 @@ impl ResultBuilder { /// /// The diagram below shows a situation where a when expression matched rows 1 and 4 of the /// record batch. The then expression produced the value array `[A, D]`. - /// After adding this result, the result array will have been added to `partial_results` and - /// `partial_indices` will have been updated at indexes 1 and 4. + /// After adding this result, the result array will have been added to [Partial::arrays] and + /// [Partial::indices] will have been updated at indexes 1 and 4. 
/// /// ```text /// ┌─────────┐ ┌─────────┐┌───────────┐ ┌─────────┐┌───────────┐ - /// │ A │ │ 0 ││ │ │ 0 ││┌─────────┐│ - /// ├─────────┤ ├─────────┤│ │ ├─────────┤││ A ││ - /// │ D │ │ 0 ││ │ │ 1 ││├─────────┤│ - /// └─────────┘ ├─────────┤│ │ add_branch_result( ├─────────┤││ D ││ - /// value │ 0 ││ │ row indices, │ 0 ││└─────────┘│ - /// ├─────────┤│ │ value ├─────────┤│ │ - /// │ 0 ││ │ ) │ 0 ││ │ - /// ┌─────────┐ ├─────────┤│ │ ─────────────────────────▶ ├─────────┤│ │ - /// │ 1 │ │ 0 ││ │ │ 1 ││ │ - /// ├─────────┤ ├─────────┤│ │ ├─────────┤│ │ - /// │ 4 │ │ 0 ││ │ │ 0 ││ │ + /// │ C │ │ MAX ││┌─────────┐│ │ MAX ││┌─────────┐│ + /// ├─────────┤ ├─────────┤││ A ││ ├─────────┤││ A ││ + /// │ D │ │ MAX ││└─────────┘│ │ 2 ││└─────────┘│ + /// └─────────┘ ├─────────┤│┌─────────┐│ add_branch_result( ├─────────┤│┌─────────┐│ + /// value │ 0 │││ B ││ row indices, │ 0 │││ B ││ + /// ├─────────┤│└─────────┘│ value ├─────────┤│└─────────┘│ + /// │ MAX ││ │ ) │ MAX ││┌─────────┐│ + /// ┌─────────┐ ├─────────┤│ │ ─────────────────────────▶ ├─────────┤││ C ││ + /// │ 1 │ │ MAX ││ │ │ 2 ││├─────────┤│ + /// ├─────────┤ ├─────────┤│ │ ├─────────┤││ D ││ + /// │ 4 │ │ 1 ││ │ │ 1 ││└─────────┘│ /// └─────────┘ └─────────┘└───────────┘ └─────────┘└───────────┘ /// row indices /// partial partial partial partial - /// indices results indices results + /// indices arrays indices arrays /// ``` fn add_branch_result( &mut self, @@ -311,25 +349,25 @@ impl ResultBuilder { ) -> Result<()> { match value { ColumnarValue::Array(a) => { - assert_eq!(a.len(), row_indices.len()); - if row_indices.len() == self.partial_result_indices.len() { - self.set_covering_result(ColumnarValue::Array(a)); + if a.len() != row_indices.len() { + internal_err!("Array length must match row indices length") + } else if row_indices.len() == self.row_count { + self.set_single_result(ColumnarValue::Array(a)) } else { - self.add_partial_result(row_indices, a.to_data()); + self.add_partial_result(row_indices, 
a.to_data()) } } ColumnarValue::Scalar(s) => { - if row_indices.len() == self.partial_result_indices.len() { - self.set_covering_result(ColumnarValue::Scalar(s)); + if row_indices.len() == self.row_count { + self.set_single_result(ColumnarValue::Scalar(s)) } else { self.add_partial_result( row_indices, s.to_array_of_size(row_indices.len())?.to_data(), - ); + ) } } } - Ok(()) } /// Adds a partial result array. @@ -337,17 +375,22 @@ impl ResultBuilder { /// This method adds the given array data as a partial result and updates the index mapping /// to indicate that the specified rows should take their values from this array. /// The partial results will be merged into a single array when finish() is called. - fn add_partial_result(&mut self, row_indices: &ArrayRef, row_values: ArrayData) { - // Covering results and partial results are mutually exclusive. - // We can assert this since the case evaluation methods are written to only evaluate - // each row of the record batch once. - assert!(self.covering_result.is_none()); - - self.partial_results.push(row_values); - let array_index = self.partial_results.len(); + fn add_partial_result( + &mut self, + row_indices: &ArrayRef, + row_values: ArrayData, + ) -> Result<()> { + match &mut self.state { + Partial { arrays, indices } => { + let array_index = arrays.len(); + arrays.push(row_values); - for row_ix in row_indices.as_primitive::().values().iter() { - self.partial_result_indices[*row_ix as usize] = array_index; + for row_ix in row_indices.as_primitive::().values().iter() { + indices[*row_ix as usize] = array_index; + } + Ok(()) + } + Complete(_) => internal_err!("Complete result already set"), } } @@ -356,39 +399,38 @@ impl ResultBuilder { /// This is an optimization for cases where all rows evaluate to the same result. /// When a covering result is set, the builder will return it directly from finish() /// without any merging overhead. 
- fn set_covering_result(&mut self, value: ColumnarValue) { - // Covering results and partial results are mutually exclusive. - // We can assert this since the case evaluation methods are written to only evaluate - // each row of the record batch once. - assert!(self.partial_results.is_empty()); - - self.covering_result = Some(value); + fn set_single_result(&mut self, value: ColumnarValue) -> Result<()> { + match &self.state { + Partial { arrays, .. } if !arrays.is_empty() => { + internal_err!("Partial result already set") + } + Complete(_) => internal_err!("Complete result already set"), + _ => { + self.state = Complete(value); + Ok(()) + } + } } /// Finishes building the result and returns the final array. fn finish(self) -> Result { - match self.covering_result { - Some(v) => { - // If we have a covering result, we can just return it. + match self.state { + Complete(v) => { + // If we have a complete result, we can just return it. Ok(v) } - None => match self.partial_results.len() { - 0 => { - // No covering result and no partial results. - // This can happen for case expressions with no else branch where no rows - // matched. - Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( - &self.data_type, - )?)) - } - _ => { - // Merge into a single array. - Ok(ColumnarValue::Array(merge( - &self.partial_results, - &self.partial_result_indices, - )?)) - } - }, + Partial { arrays, .. } if arrays.is_empty() => { + // No complete result and no partial results. + // This can happen for case expressions with no else branch where no rows + // matched. + Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + &self.data_type, + )?)) + } + Partial { arrays, indices } => { + // Merge partial results into a single array. 
+ Ok(ColumnarValue::Array(merge(&arrays, &indices)?)) + } } } } From 58cdfa8a64d219aa5c35d64dc57c50e156fadfd6 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sat, 25 Oct 2025 15:27:14 +0200 Subject: [PATCH 27/44] Remove links to private fields --- datafusion/physical-expr/src/expressions/case.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 280136785a24..8a66ce9f6eb2 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -321,8 +321,8 @@ impl ResultBuilder { /// /// The diagram below shows a situation where a when expression matched rows 1 and 4 of the /// record batch. The then expression produced the value array `[A, D]`. - /// After adding this result, the result array will have been added to [Partial::arrays] and - /// [Partial::indices] will have been updated at indexes 1 and 4. + /// After adding this result, the result array will have been added to `Partial::arrays` and + /// `Partial::indices` will have been updated at indexes 1 and 4. /// /// ```text /// ┌─────────┐ ┌─────────┐┌───────────┐ ┌─────────┐┌───────────┐ From e385f080e449bf154d69540ac5725fe6553da33e Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sat, 25 Oct 2025 16:33:34 +0200 Subject: [PATCH 28/44] Add debug assertion to detect duplicate row values. 
--- datafusion/physical-expr/src/expressions/case.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 8a66ce9f6eb2..6cf86b512307 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -382,6 +382,16 @@ impl ResultBuilder { ) -> Result<()> { match &mut self.state { Partial { arrays, indices } => { + // This is check is only active for debug config because the callers of this method, + // `case_when_with_expr` and `case_when_no_expr`, already ensure that + // they only calculate a value for each row at most once. + #[cfg(debug_assertions)] + for row_ix in row_indices.as_primitive::().values().iter() { + if indices[*row_ix as usize] != MERGE_NULL_MARKER { + return internal_err!("Duplicate value for row {}", *row_ix); + } + } + let array_index = arrays.len(); arrays.push(row_values); From d512b2f9fe179f20e5d036e3414bcd38c105de64 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sun, 26 Oct 2025 16:28:20 +0100 Subject: [PATCH 29/44] Add SLT with unreachable 1/0 in when predicate --- datafusion/sqllogictest/test_files/case.slt | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 2f9173d2dcbd..58df98482ab8 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -594,4 +594,18 @@ query I SELECT CASE WHEN a = 'a' THEN 0 WHEN a = 'b' THEN 1 ELSE 2 END FROM (VALUES (NULL), ('z')) t(a) ---- 2 -2 \ No newline at end of file +2 + +# The `WHEN 1/0` is not effectively reachable in this query and should never be executed +query T +SELECT CASE a WHEN 1 THEN 'a' WHEN 2 THEN 'b' WHEN 1 / 0 THEN 'c' ELSE 'd' END FROM (VALUES (1), (2)) t(a) +---- +a +b + +# The `WHEN 1/0` is not effectively reachable in this query and 
should never be executed +query T +SELECT CASE WHEN a = 1 THEN 'a' WHEN a = 2 THEN 'b' WHEN a = 1 / 0 THEN 'c' ELSE 'd' END FROM (VALUES (1), (2)) t(a) +---- +a +b \ No newline at end of file From 865bace33cdce9c094cf333c7b2e376591baf1c7 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sun, 26 Oct 2025 16:36:48 +0100 Subject: [PATCH 30/44] Add SLT that verifies that when branches are not evaluated for already matched rows --- datafusion/sqllogictest/test_files/case.slt | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 58df98482ab8..91b03d3490f5 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -608,4 +608,12 @@ query T SELECT CASE WHEN a = 1 THEN 'a' WHEN a = 2 THEN 'b' WHEN a = 1 / 0 THEN 'c' ELSE 'd' END FROM (VALUES (1), (2)) t(a) ---- a -b \ No newline at end of file +b + +# The `WHEN 1/0` is not effectively reachable in this query and should never be executed +query T +SELECT CASE WHEN a = 0 THEN 'a' WHEN 1 / a = 1 THEN 'b' ELSE 'c' END FROM (VALUES (0), (1), (2)) t(a) +---- +a +b +c \ No newline at end of file From 17b6f9a8d0fec578fb7e57fcb3f747eadac35f96 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sun, 26 Oct 2025 16:38:30 +0100 Subject: [PATCH 31/44] Introduce PartialResultIndex to reduce size of indices array and avoid potential programming errors --- .../physical-expr/src/expressions/case.rs | 93 +++++++++++++++---- 1 file changed, 77 insertions(+), 16 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 6cf86b512307..c0ab50e89e8f 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -36,6 +36,7 @@ use itertools::Itertools; use std::borrow::Cow; use std::hash::Hash; use std::{any::Any, sync::Arc}; +use 
std::fmt::{Debug, Formatter}; type WhenThen = (Arc, Arc); @@ -100,7 +101,7 @@ pub struct CaseExpr { } impl std::fmt::Display for CaseExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { write!(f, "CASE ")?; if let Some(e) = &self.expr { write!(f, "{e} ")?; @@ -169,8 +170,6 @@ fn filter_array( filter.filter(array) } -const MERGE_NULL_MARKER: usize = usize::MAX; - /// Merges elements by index from a list of [`ArrayData`], creating a new [`ColumnarValue`] from /// those values. /// @@ -224,7 +223,7 @@ const MERGE_NULL_MARKER: usize = usize::MAX; /// values indices result /// /// ``` -fn merge(values: &[ArrayData], indices: &[usize]) -> Result { +fn merge(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result { let data_refs = values.iter().collect(); let mut mutable = MutableArrayData::new(data_refs, true, indices.len()); @@ -244,13 +243,14 @@ fn merge(values: &[ArrayData], indices: &[usize]) -> Result { let slice_length = end_row_ix - start_row_ix; // Extend mutable with either nulls or with values from the array. - if array_ix == MERGE_NULL_MARKER { - mutable.extend_nulls(slice_length); - } else { - let start_offset = take_offsets[array_ix]; - let end_offset = start_offset + slice_length; - mutable.extend(array_ix, start_offset, end_offset); - take_offsets[array_ix] = end_offset; + match array_ix.index() { + None => mutable.extend_nulls(slice_length), + Some(index) => { + let start_offset = take_offsets[index]; + let end_offset = start_offset + slice_length; + mutable.extend(index, start_offset, end_offset); + take_offsets[index] = end_offset; + } } if end_row_ix == indices.len() { @@ -264,6 +264,64 @@ fn merge(values: &[ArrayData], indices: &[usize]) -> Result { Ok(make_array(mutable.freeze())) } +/// An index into the partial results array that's more compact than `usize`. +/// +/// `u32::MAX` is reserved as a special 'none' value. 
This is used instead of +/// `Option` to keep the array of indices as compact as possible. +#[derive(Copy, Clone, PartialEq, Eq)] +struct PartialResultIndex { + index: u32, +} + +const NONE_VALUE: u32 = u32::MAX; + +impl PartialResultIndex { + /// Returns the 'none' placeholder value. + fn none() -> Self { + Self { index: NONE_VALUE } + } + + /// Creates a new partial result index. + /// + /// If the provide value is greater than or equal to `u32::MAX` + /// an error will be returned. + fn try_new(index: usize) -> Result { + let Ok(index) = u32::try_from(index) else { + return internal_err!("Partial result index exceeds limit"); + }; + + if index == NONE_VALUE { + return internal_err!("Partial result index exceeds limit"); + } + + Ok(Self { index }) + } + + /// Determines if this index is the 'none' placeholder value or not. + fn is_none(&self) -> bool { + self.index == NONE_VALUE + } + + /// Returns `Some(index)` if this value is not the 'none' placeholder, `None` otherwise. + fn index(&self) -> Option { + if self.is_none() { + None + } else { + Some(self.index as usize) + } + } +} + +impl Debug for PartialResultIndex { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if self.is_none() { + write!(f, "null") + } else { + write!(f, "{}", self.index) + } + } +} + enum ResultState { /// The final result needs to be computed by merging the the data in `arrays`. Partial { @@ -271,8 +329,7 @@ enum ResultState { // indexes into this vec. arrays: Vec, // Indicates per result row from which array in `partial_results` a value should be taken. - // The indexes in this array are offset by +1. The special value 0 indicates null values. - indices: Vec, + indices: Vec, }, /// A single branch matched all input rows. When creating the final result, no further merging /// of partial results is necessary. 
@@ -304,7 +361,7 @@ impl ResultBuilder { row_count, state: Partial { arrays: vec![], - indices: vec![MERGE_NULL_MARKER; row_count], + indices: vec![PartialResultIndex::none(); row_count], }, } } @@ -387,12 +444,16 @@ impl ResultBuilder { // they only calculate a value for each row at most once. #[cfg(debug_assertions)] for row_ix in row_indices.as_primitive::().values().iter() { - if indices[*row_ix as usize] != MERGE_NULL_MARKER { + if !indices[*row_ix as usize].is_none() { return internal_err!("Duplicate value for row {}", *row_ix); } } let array_index = arrays.len(); + let Ok(array_index) = PartialResultIndex::try_new(array_index) else { + return internal_err!("Partial result count exceeds limit"); + }; + arrays.push(row_values); for row_ix in row_indices.as_primitive::().values().iter() { @@ -974,7 +1035,7 @@ impl PhysicalExpr for CaseExpr { } } - fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE ")?; if let Some(e) = &self.expr { e.fmt_sql(f)?; From ba584d387f34594ccd01b1478bcb09f918f67ce2 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sun, 26 Oct 2025 17:46:13 +0100 Subject: [PATCH 32/44] Formatting --- datafusion/physical-expr/src/expressions/case.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index c0ab50e89e8f..2612e64e38ee 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -34,9 +34,9 @@ use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; use std::borrow::Cow; +use std::fmt::{Debug, Formatter}; use std::hash::Hash; use std::{any::Any, sync::Arc}; -use std::fmt::{Debug, Formatter}; type WhenThen = (Arc, Arc); From 07119f81a91427ddb33a53f247600c6386887e7c Mon Sep 17 00:00:00 2001 From: 
Pepijn Van Eeckhoudt Date: Sun, 26 Oct 2025 18:14:15 +0100 Subject: [PATCH 33/44] Minor code cleanup --- datafusion/physical-expr/src/expressions/case.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 2612e64e38ee..9f133b974bce 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -449,10 +449,7 @@ impl ResultBuilder { } } - let array_index = arrays.len(); - let Ok(array_index) = PartialResultIndex::try_new(array_index) else { - return internal_err!("Partial result count exceeds limit"); - }; + let array_index = PartialResultIndex::try_new(arrays.len())?; arrays.push(row_values); From 3b8b6b82fb4efbc28a2c25033020bc26aab4c86d Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 27 Oct 2025 16:25:04 +0100 Subject: [PATCH 34/44] Reorganise base_values null handling to avoid unnecessary computations --- .../physical-expr/src/expressions/case.rs | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 9f133b974bce..2ec64c696084 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -22,7 +22,7 @@ use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{ - is_null, not, nullif, prep_null_mask_filter, FilterBuilder, FilterPredicate, + is_not_null, not, nullif, prep_null_mask_filter, FilterBuilder, FilterPredicate, }; use arrow::datatypes::{DataType, Schema, UInt32Type}; use arrow::error::ArrowError; @@ -586,7 +586,7 @@ impl CaseExpr { let mut remainder_batch = Cow::Borrowed(batch); // evaluate the base expression - let mut base_value = self + let mut base_values = self .expr .as_ref() .unwrap() @@ -597,19 +597,25 @@ impl 
CaseExpr { // Since each when expression is tested against the base expression using the equality // operator, null base values can never match any when expression. `x = NULL` is falsy, // for all possible values of `x`. - let base_nulls = is_null(base_value.as_ref())?; - if base_nulls.true_count() > 0 { + if base_values.null_count() > 0 { + // Use `is_not_null` since this is a cheap clone of the null buffer from 'base_value'. + // We already checked there are nulls, so we can be sure a new buffer will not be + // created. + let base_not_nulls = is_not_null(base_values.as_ref())?; + let base_all_null = base_values.null_count() == remainder_batch.num_rows(); + // If there is an else expression, use that as the default value for the null rows // Otherwise the default `null` value from the result builder will be used. if let Some(e) = self.else_expr() { let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; - if base_nulls.true_count() == remainder_batch.num_rows() { + if base_all_null { // All base values were null, so no need to filter let nulls_value = expr.evaluate(&remainder_batch)?; result_builder.add_branch_result(&remainder_rows, nulls_value)?; } else { - let nulls_filter = create_filter(&base_nulls); + // Filter out the null rows and evaluate the else expression for those + let nulls_filter = create_filter(&not(&base_not_nulls)?); let nulls_batch = filter_record_batch(&remainder_batch, &nulls_filter)?; let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?; @@ -618,22 +624,22 @@ impl CaseExpr { } } - // All base values were null, so we can return early - if base_nulls.true_count() == remainder_batch.num_rows() { + // All base values are null, so we can return early + if base_all_null { return result_builder.finish(); } // Remove the null rows from the remainder batch - let not_null_filter = create_filter(&not(&base_nulls)?); + let not_null_filter = create_filter(&base_not_nulls); remainder_batch =
Cow::Owned(filter_record_batch(&remainder_batch, &not_null_filter)?); remainder_rows = filter_array(&remainder_rows, &not_null_filter)?; - base_value = filter_array(&base_value, &not_null_filter)?; + base_values = filter_array(&base_values, &not_null_filter)?; } // The types of case and when expressions will be coerced to match. // We only need to check if the base_value is nested. - let base_value_is_nested = base_value.data_type().is_nested(); + let base_value_is_nested = base_values.data_type().is_nested(); for i in 0..self.when_then_expr.len() { // Evaluate the 'when' predicate for the remainder batch @@ -641,11 +647,11 @@ impl CaseExpr { let when_expr = &self.when_then_expr[i].0; let when_value = match when_expr.evaluate(&remainder_batch)? { ColumnarValue::Array(a) => { - compare_with_eq(&a, &base_value, base_value_is_nested) + compare_with_eq(&a, &base_values, base_value_is_nested) } ColumnarValue::Scalar(s) => { let scalar = Scalar::new(s.to_array()?); - compare_with_eq(&scalar, &base_value, base_value_is_nested) + compare_with_eq(&scalar, &base_values, base_value_is_nested) } }?; @@ -693,7 +699,7 @@ impl CaseExpr { remainder_batch = Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?); remainder_rows = filter_array(&remainder_rows, &next_filter)?; - base_value = filter_array(&base_value, &next_filter)?; + base_values = filter_array(&base_values, &next_filter)?; } // If we reached this point, some rows were left unmatched. 
From d189c3b6cd50c9847cc41f9a3c3e58dd28fc7acd Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 27 Oct 2025 16:42:10 +0100 Subject: [PATCH 35/44] Doc revisions --- .../physical-expr/src/expressions/case.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 2ec64c696084..12a6e12851e6 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -176,7 +176,7 @@ fn filter_array( /// Each element in `indices` is the index of an array in `values`. The `indices` array is processed /// sequentially. The first occurrence of index value `n` will be mapped to the first /// value of the array at index `n`. The second occurrence to the second value, and so on. -/// An index value of `usize::MAX` is used to indicate null values. +/// An index value where `PartialResultIndex::is_none` is `true` is used to indicate null values. /// /// # Implementation notes /// @@ -208,19 +208,19 @@ fn filter_array( /// /// ```text /// ┌───────────┐ ┌─────────┐ ┌─────────┐ -/// │┌─────────┐│ │ MAX │ │ NULL │ +/// │┌─────────┐│ │ None │ │ NULL │ /// ││ A ││ ├─────────┤ ├─────────┤ /// │└─────────┘│ │ 1 │ │ B │ /// │┌─────────┐│ ├─────────┤ ├─────────┤ /// ││ B ││ │ 0 │ merge(values, indices) │ A │ /// │└─────────┘│ ├─────────┤ ─────────────────────────▶ ├─────────┤ -/// │┌─────────┐│ │ MAX │ │ NULL │ +/// │┌─────────┐│ │ None │ │ NULL │ /// ││ C ││ ├─────────┤ ├─────────┤ /// │├─────────┤│ │ 2 │ │ C │ /// ││ D ││ ├─────────┤ ├─────────┤ /// │└─────────┘│ │ 2 │ │ D │ /// └───────────┘ └─────────┘ └─────────┘ -/// values indices result +/// values indices result /// /// ``` fn merge(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result { @@ -283,7 +283,7 @@ impl PartialResultIndex { /// Creates a new partial result index. 
/// - /// If the provide value is greater than or equal to `u32::MAX` + /// If the provided value is greater than or equal to `u32::MAX` /// an error will be returned. fn try_new(index: usize) -> Result { let Ok(index) = u32::try_from(index) else { @@ -347,6 +347,7 @@ enum ResultState { /// any merging overhead. struct ResultBuilder { data_type: DataType, + /// The number of rows in the final result. row_count: usize, state: ResultState, } @@ -383,15 +384,15 @@ impl ResultBuilder { /// /// ```text /// ┌─────────┐ ┌─────────┐┌───────────┐ ┌─────────┐┌───────────┐ - /// │ C │ │ MAX ││┌─────────┐│ │ MAX ││┌─────────┐│ + /// │ C │ │ None ││┌─────────┐│ │ None ││┌─────────┐│ /// ├─────────┤ ├─────────┤││ A ││ ├─────────┤││ A ││ - /// │ D │ │ MAX ││└─────────┘│ │ 2 ││└─────────┘│ + /// │ D │ │ None ││└─────────┘│ │ 2 ││└─────────┘│ /// └─────────┘ ├─────────┤│┌─────────┐│ add_branch_result( ├─────────┤│┌─────────┐│ /// value │ 0 │││ B ││ row indices, │ 0 │││ B ││ /// ├─────────┤│└─────────┘│ value ├─────────┤│└─────────┘│ - /// │ MAX ││ │ ) │ MAX ││┌─────────┐│ + /// │ None ││ │ ) │ None ││┌─────────┐│ /// ┌─────────┐ ├─────────┤│ │ ─────────────────────────▶ ├─────────┤││ C ││ - /// │ 1 │ │ MAX ││ │ │ 2 ││├─────────┤│ + /// │ 1 │ │ None ││ │ │ 2 ││├─────────┤│ /// ├─────────┤ ├─────────┤│ │ ├─────────┤││ D ││ /// │ 4 │ │ 1 ││ │ │ 1 ││└─────────┘│ /// └─────────┘ └─────────┘└───────────┘ └─────────┘└───────────┘ From 596bd11ecf5be1d1b5915df80a9a353b5fe4f788 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 27 Oct 2025 16:51:35 +0100 Subject: [PATCH 36/44] Assert row indices does not contain nulls --- datafusion/physical-expr/src/expressions/case.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 12a6e12851e6..3d0396c0995e 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -438,6 
+438,10 @@ impl ResultBuilder { row_indices: &ArrayRef, row_values: ArrayData, ) -> Result<()> { + if row_indices.null_count() != 0 { + return internal_err!("Row indices must not contain nulls"); + } + match &mut self.state { Partial { arrays, indices } => { // This is check is only active for debug config because the callers of this method, From 8b079cccf3d0dbdb8f2b697e00e22a423d10919b Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 27 Oct 2025 16:52:22 +0100 Subject: [PATCH 37/44] Complete/single/covering naming consistency --- datafusion/physical-expr/src/expressions/case.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 3d0396c0995e..adac2fe92718 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -410,14 +410,14 @@ impl ResultBuilder { if a.len() != row_indices.len() { internal_err!("Array length must match row indices length") } else if row_indices.len() == self.row_count { - self.set_single_result(ColumnarValue::Array(a)) + self.set_complete_result(ColumnarValue::Array(a)) } else { self.add_partial_result(row_indices, a.to_data()) } } ColumnarValue::Scalar(s) => { if row_indices.len() == self.row_count { - self.set_single_result(ColumnarValue::Scalar(s)) + self.set_complete_result(ColumnarValue::Scalar(s)) } else { self.add_partial_result( row_indices, @@ -463,19 +463,19 @@ impl ResultBuilder { } Ok(()) } - Complete(_) => internal_err!("Complete result already set"), + Complete(_) => internal_err!("Cannot add a partial result when complete result is already set"), } } - /// Sets a covering result that applies to all rows. + /// Sets a result that applies to all rows. /// /// This is an optimization for cases where all rows evaluate to the same result. 
- /// When a covering result is set, the builder will return it directly from finish() + /// When a complete result is set, the builder will return it directly from finish() /// without any merging overhead. - fn set_single_result(&mut self, value: ColumnarValue) -> Result<()> { + fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> { match &self.state { Partial { arrays, .. } if !arrays.is_empty() => { - internal_err!("Partial result already set") + internal_err!("Cannot set a complete result when there are already partial results") } Complete(_) => internal_err!("Complete result already set"), _ => { From 1afa6f6ff0f8306b97d6d32730313b75f6a07ef0 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 27 Oct 2025 16:52:46 +0100 Subject: [PATCH 38/44] Formatting --- datafusion/physical-expr/src/expressions/case.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index adac2fe92718..5dd741c18c9c 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -463,7 +463,9 @@ impl ResultBuilder { } Ok(()) } - Complete(_) => internal_err!("Cannot add a partial result when complete result is already set"), + Complete(_) => internal_err!( + "Cannot add a partial result when complete result is already set" + ), } } @@ -475,7 +477,9 @@ impl ResultBuilder { fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> { match &self.state { Partial { arrays, .. 
} if !arrays.is_empty() => { - internal_err!("Cannot set a complete result when there are already partial results") + internal_err!( + "Cannot set a complete result when there are already partial results" + ) } Complete(_) => internal_err!("Complete result already set"), _ => { From 878d86afa1746d50ef1c0c8d9a17d7242260be2a Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 27 Oct 2025 17:04:30 +0100 Subject: [PATCH 39/44] Add debug_assertions bounds check in merge --- datafusion/physical-expr/src/expressions/case.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 5dd741c18c9c..d187875e1750 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -224,6 +224,13 @@ fn filter_array( /// /// ``` fn merge(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result { + #[cfg(debug_assertions)] + for ix in indices { + if let Some(index) = ix.index() { + assert!(index < values.len(), "Index out of bounds: {} >= {}", index, values.len()); + } + } + let data_refs = values.iter().collect(); let mut mutable = MutableArrayData::new(data_refs, true, indices.len()); From 2b5a168e2a517e135202ad7f10784d3b689aeee1 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 27 Oct 2025 17:32:53 +0100 Subject: [PATCH 40/44] Clarify usage of prep_null_mask_filter --- .../physical-expr/src/expressions/case.rs | 58 ++++++++++++------- 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index d187875e1750..5f9a2e6635a7 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -227,7 +227,12 @@ fn merge(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result= {}", index, values.len()); + assert!( + index < 
values.len(), + "Index out of bounds: {} >= {}", + index, + values.len() + ); } } @@ -671,30 +676,28 @@ impl CaseExpr { } }?; - let when_match_count = when_value.true_count(); + // `true_count` ignores `true` values where the validity bit is not set, so there's + // no need to call `prep_null_mask_filter`. + let when_true_count = when_value.true_count(); // If the 'when' predicate did not match any rows, continue to the next branch immediately - if when_match_count == 0 { + if when_true_count == 0 { continue; } // If the 'when' predicate matched all remaining rows, there is no need to filter - if when_match_count == remainder_batch.num_rows() { + if when_true_count == remainder_batch.num_rows() { let then_expression = &self.when_then_expr[i].1; let then_value = then_expression.evaluate(&remainder_batch)?; result_builder.add_branch_result(&remainder_rows, then_value)?; return result_builder.finish(); } - // Make sure 'NULL' is treated as false - let when_value = match when_value.null_count() { - 0 => when_value, - _ => prep_null_mask_filter(&when_value), - }; - // Filter the remainder batch based on the 'when' value // This results in a batch containing only the rows that need to be evaluated // for the current branch + // Still no need to call `prep_null_mask_filter` since `create_filter` will already do + // this unconditionally. 
let then_filter = create_filter(&when_value); let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; let then_rows = filter_array(&remainder_rows, &then_filter)?; @@ -710,7 +713,14 @@ impl CaseExpr { } // Prepare the next when branch (or the else branch) - let next_selection = not(&when_value)?; + let next_selection = match when_value.null_count() { + 0 => not(&when_value), + _ => { + // `prep_null_mask_filter` is required to ensure the not operation treats nulls + // as false + not(&prep_null_mask_filter(&when_value)) + } + }?; let next_filter = create_filter(&next_selection); remainder_batch = Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?); @@ -758,31 +768,29 @@ impl CaseExpr { internal_datafusion_err!("WHEN expression did not return a BooleanArray") })?; - let when_match_count = when_value.true_count(); + // `true_count` ignores `true` values where the validity bit is not set, so there's + // no need to call `prep_null_mask_filter`. + let when_true_count = when_value.true_count(); // If the 'when' predicate did not match any rows, continue to the next branch immediately - if when_match_count == 0 { + if when_true_count == 0 { continue; } // If the 'when' predicate matched all remaining rows, there is no need to filter - if when_match_count == remainder_batch.num_rows() { + if when_true_count == remainder_batch.num_rows() { let then_expression = &self.when_then_expr[i].1; let then_value = then_expression.evaluate(&remainder_batch)?; result_builder.add_branch_result(&remainder_rows, then_value)?; return result_builder.finish(); } - // Make sure 'NULL' is treated as false - let when_value = match when_value.null_count() { - 0 => Cow::Borrowed(when_value), - _ => Cow::Owned(prep_null_mask_filter(when_value)), - }; - // Filter the remainder batch based on the 'when' value // This results in a batch containing only the rows that need to be evaluated // for the current branch - let then_filter = create_filter(&when_value); + // Still 
no need to call `prep_null_mask_filter` since `create_filter` will already do + // this unconditionally. + let then_filter = create_filter(when_value); let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; let then_rows = filter_array(&remainder_rows, &then_filter)?; @@ -797,7 +805,13 @@ impl CaseExpr { } // Prepare the next when branch (or the else branch) - let next_selection = not(&when_value)?; + let next_selection = match when_value.null_count() { + 0 => not(when_value), + _ => { + // `prep_null_mask_filter` is required to ensure the not operation treats nulls as false + not(&prep_null_mask_filter(when_value)) + } + }?; let next_filter = create_filter(&next_selection); remainder_batch = Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?); From d08a9160f0880cd83bca331e0d8aab3573ba58fc Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 27 Oct 2025 17:48:09 +0100 Subject: [PATCH 41/44] Move debug assertion inside loop to remove loop duplication --- .../physical-expr/src/expressions/case.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 5f9a2e6635a7..08b9404e6643 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -456,21 +456,19 @@ impl ResultBuilder { match &mut self.state { Partial { arrays, indices } => { - // This is check is only active for debug config because the callers of this method, - // `case_when_with_expr` and `case_when_no_expr`, already ensure that - // they only calculate a value for each row at most once. 
- #[cfg(debug_assertions)] - for row_ix in row_indices.as_primitive::().values().iter() { - if !indices[*row_ix as usize].is_none() { - return internal_err!("Duplicate value for row {}", *row_ix); - } - } - let array_index = PartialResultIndex::try_new(arrays.len())?; arrays.push(row_values); for row_ix in row_indices.as_primitive::().values().iter() { + // This is check is only active for debug config because the callers of this method, + // `case_when_with_expr` and `case_when_no_expr`, already ensure that + // they only calculate a value for each row at most once. + #[cfg(debug_assertions)] + if !indices[*row_ix as usize].is_none() { + return internal_err!("Duplicate value for row {}", *row_ix); + } + indices[*row_ix as usize] = array_index; } Ok(()) From 6fce2132c89883a2e9f9371499cc80e0559d93db Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 27 Oct 2025 18:05:39 +0100 Subject: [PATCH 42/44] Align with and without expr code --- datafusion/physical-expr/src/expressions/case.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 08b9404e6643..ba682e1e2c57 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -798,7 +798,7 @@ impl CaseExpr { // If this is the last 'when' branch and there is no 'else' expression, there's no // point in calculating the remaining rows. 
- if i == self.when_then_expr.len() - 1 && self.else_expr.is_none() { + if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 { return result_builder.finish(); } @@ -806,7 +806,8 @@ impl CaseExpr { let next_selection = match when_value.null_count() { 0 => not(when_value), _ => { - // `prep_null_mask_filter` is required to ensure the not operation treats nulls as false + // `prep_null_mask_filter` is required to ensure the not operation treats nulls + // as false not(&prep_null_mask_filter(when_value)) } }?; From dafea94b3b43a77d1d784c5e70754377d5592f7f Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 27 Oct 2025 19:29:51 +0100 Subject: [PATCH 43/44] Clarify add_branch_result diagram --- .../physical-expr/src/expressions/case.rs | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index ba682e1e2c57..28da27a8b3b0 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -391,26 +391,25 @@ impl ResultBuilder { /// /// The diagram below shows a situation where a when expression matched rows 1 and 4 of the /// record batch. The then expression produced the value array `[A, D]`. - /// After adding this result, the result array will have been added to `Partial::arrays` and - /// `Partial::indices` will have been updated at indexes 1 and 4. + /// After adding this result, the result array will have been added to `partial arrays` and + /// `partial indices` will have been updated at indexes `1` and `4`. 
/// /// ```text - /// ┌─────────┐ ┌─────────┐┌───────────┐ ┌─────────┐┌───────────┐ - /// │ C │ │ None ││┌─────────┐│ │ None ││┌─────────┐│ - /// ├─────────┤ ├─────────┤││ A ││ ├─────────┤││ A ││ - /// │ D │ │ None ││└─────────┘│ │ 2 ││└─────────┘│ - /// └─────────┘ ├─────────┤│┌─────────┐│ add_branch_result( ├─────────┤│┌─────────┐│ - /// value │ 0 │││ B ││ row indices, │ 0 │││ B ││ - /// ├─────────┤│└─────────┘│ value ├─────────┤│└─────────┘│ - /// │ None ││ │ ) │ None ││┌─────────┐│ - /// ┌─────────┐ ├─────────┤│ │ ─────────────────────────▶ ├─────────┤││ C ││ - /// │ 1 │ │ None ││ │ │ 2 ││├─────────┤│ - /// ├─────────┤ ├─────────┤│ │ ├─────────┤││ D ││ - /// │ 4 │ │ 1 ││ │ │ 1 ││└─────────┘│ - /// └─────────┘ └─────────┘└───────────┘ └─────────┘└───────────┘ - /// row indices - /// partial partial partial partial - /// indices arrays indices arrays + /// ┌─────────┐ ┌─────────┐┌───────────┐ ┌─────────┐┌───────────┐ + /// │ C │ │ 0: None ││┌ 0 ──────┐│ │ 0: None ││┌ 0 ──────┐│ + /// ├─────────┤ ├─────────┤││ A ││ ├─────────┤││ A ││ + /// │ D │ │ 1: None ││└─────────┘│ │ 1: 2 ││└─────────┘│ + /// └─────────┘ ├─────────┤│┌ 1 ──────┐│ add_branch_result( ├─────────┤│┌ 1 ──────┐│ + /// matching │ 2: 0 │││ B ││ row indices, │ 2: 0 │││ B ││ + /// 'then' values ├─────────┤│└─────────┘│ value ├─────────┤│└─────────┘│ + /// │ 3: None ││ │ ) │ 3: None ││┌ 2 ──────┐│ + /// ┌─────────┐ ├─────────┤│ │ ─────────────────────────▶ ├─────────┤││ C ││ + /// │ 1 │ │ 4: None ││ │ │ 4: 2 ││├─────────┤│ + /// ├─────────┤ ├─────────┤│ │ ├─────────┤││ D ││ + /// │ 4 │ │ 5: 1 ││ │ │ 5: 1 ││└─────────┘│ + /// └─────────┘ └─────────┘└───────────┘ └─────────┘└───────────┘ + /// row indices partial partial partial partial + /// indices arrays indices arrays /// ``` fn add_branch_result( &mut self, From 7a4a24b513ca79c94f7d09521b9a82fd6afdd18d Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 28 Oct 2025 08:07:55 +0100 Subject: [PATCH 44/44] Add ResultState::Empty to defer 
creation of the indices array --- .../physical-expr/src/expressions/case.rs | 53 ++++++++++++------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 28da27a8b3b0..0b4c3af1d9c5 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -16,7 +16,7 @@ // under the License. use super::{Column, Literal}; -use crate::expressions::case::ResultState::{Complete, Partial}; +use crate::expressions::case::ResultState::{Complete, Empty, Partial}; use crate::expressions::try_cast; use crate::PhysicalExpr; use arrow::array::*; @@ -293,6 +293,10 @@ impl PartialResultIndex { Self { index: NONE_VALUE } } + fn zero() -> Self { + Self { index: 0 } + } + /// Creates a new partial result index. /// /// If the provided value is greater than or equal to `u32::MAX` @@ -335,10 +339,12 @@ impl Debug for PartialResultIndex { } enum ResultState { - /// The final result needs to be computed by merging the the data in `arrays`. + /// The final result is an array containing only null values. + Empty, + /// The final result needs to be computed by merging the data in `arrays`. Partial { - // A Vec of partial results that should be merged. `partial_result_indices` contains - // indexes into this vec. + // A `Vec` of partial results that should be merged. + // `partial_result_indices` contains indexes into this vec. arrays: Vec, // Indicates per result row from which array in `partial_results` a value should be taken. 
indices: Vec, @@ -372,10 +378,7 @@ impl ResultBuilder { Self { data_type: data_type.clone(), row_count, - state: Partial { - arrays: vec![], - indices: vec![PartialResultIndex::none(); row_count], - }, + state: Empty, } } @@ -454,6 +457,20 @@ impl ResultBuilder { } match &mut self.state { + Empty => { + let array_index = PartialResultIndex::zero(); + let mut indices = vec![PartialResultIndex::none(); self.row_count]; + for row_ix in row_indices.as_primitive::().values().iter() { + indices[*row_ix as usize] = array_index; + } + + self.state = Partial { + arrays: vec![row_values], + indices, + }; + + Ok(()) + } Partial { arrays, indices } => { let array_index = PartialResultIndex::try_new(arrays.len())?; @@ -485,27 +502,23 @@ impl ResultBuilder { /// without any merging overhead. fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> { match &self.state { - Partial { arrays, .. } if !arrays.is_empty() => { + Empty => { + self.state = Complete(value); + Ok(()) + } + Partial { .. } => { internal_err!( "Cannot set a complete result when there are already partial results" ) } Complete(_) => internal_err!("Complete result already set"), - _ => { - self.state = Complete(value); - Ok(()) - } } } /// Finishes building the result and returns the final array. fn finish(self) -> Result { match self.state { - Complete(v) => { - // If we have a complete result, we can just return it. - Ok(v) - } - Partial { arrays, .. } if arrays.is_empty() => { + Empty => { // No complete result and no partial results. // This can happen for case expressions with no else branch where no rows // matched. @@ -517,6 +530,10 @@ impl ResultBuilder { // Merge partial results into a single array. Ok(ColumnarValue::Array(merge(&arrays, &indices)?)) } + Complete(v) => { + // If we have a complete result, we can just return it. + Ok(v) + } } } }