Skip to content

Commit 1f1130e

Browse files
authored
Use arrow eq kernels in CaseWhen (#52)
1 parent 589f355 commit 1f1130e

File tree

1 file changed

+38
-24
lines changed
  • datafusion/src/physical_plan/expressions

1 file changed

+38
-24
lines changed

datafusion/src/physical_plan/expressions/case.rs

Lines changed: 38 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -17,13 +17,13 @@
1717

1818
use std::{any::Any, sync::Arc};
1919

20+
use crate::error::{DataFusionError, Result};
21+
use crate::physical_plan::{ColumnarValue, PhysicalExpr};
2022
use arrow::array::{self, *};
23+
use arrow::compute::{eq, eq_utf8};
2124
use arrow::datatypes::{DataType, Schema};
2225
use arrow::record_batch::RecordBatch;
2326

24-
use crate::error::{DataFusionError, Result};
25-
use crate::physical_plan::{ColumnarValue, PhysicalExpr};
26-
2727
/// The CASE expression is similar to a series of nested if/else and there are two forms that
2828
/// can be used. The first form consists of a series of boolean "when" expressions with
2929
/// corresponding "then" expressions, and an optional "else" expression.
@@ -265,7 +265,7 @@ fn build_null_array(data_type: &DataType, num_rows: usize) -> Result<ArrayRef> {
265265
}
266266

267267
macro_rules! array_equals {
268-
($TY:ty, $L:expr, $R:expr) => {{
268+
($TY:ty, $L:expr, $R:expr, $eq_fn:expr) => {{
269269
let when_value = $L
270270
.as_ref()
271271
.as_any()
@@ -278,15 +278,7 @@ macro_rules! array_equals {
278278
.downcast_ref::<$TY>()
279279
.expect("array_equals downcast failed");
280280

281-
let mut builder = BooleanBuilder::new(when_value.len());
282-
for row in 0..when_value.len() {
283-
if when_value.is_valid(row) && base_value.is_valid(row) {
284-
builder.append_value(when_value.value(row) == base_value.value(row))?;
285-
} else {
286-
builder.append_null()?;
287-
}
288-
}
289-
Ok(builder.finish())
281+
$eq_fn(when_value, base_value).map_err(DataFusionError::from)
290282
}};
291283
}
292284

@@ -296,17 +288,39 @@ fn array_equals(
296288
base_value: ArrayRef,
297289
) -> Result<BooleanArray> {
298290
match data_type {
299-
DataType::UInt8 => array_equals!(array::UInt8Array, when_value, base_value),
300-
DataType::UInt16 => array_equals!(array::UInt16Array, when_value, base_value),
301-
DataType::UInt32 => array_equals!(array::UInt32Array, when_value, base_value),
302-
DataType::UInt64 => array_equals!(array::UInt64Array, when_value, base_value),
303-
DataType::Int8 => array_equals!(array::Int8Array, when_value, base_value),
304-
DataType::Int16 => array_equals!(array::Int16Array, when_value, base_value),
305-
DataType::Int32 => array_equals!(array::Int32Array, when_value, base_value),
306-
DataType::Int64 => array_equals!(array::Int64Array, when_value, base_value),
307-
DataType::Float32 => array_equals!(array::Float32Array, when_value, base_value),
308-
DataType::Float64 => array_equals!(array::Float64Array, when_value, base_value),
309-
DataType::Utf8 => array_equals!(array::StringArray, when_value, base_value),
291+
DataType::UInt8 => {
292+
array_equals!(array::UInt8Array, when_value, base_value, eq)
293+
}
294+
DataType::UInt16 => {
295+
array_equals!(array::UInt16Array, when_value, base_value, eq)
296+
}
297+
DataType::UInt32 => {
298+
array_equals!(array::UInt32Array, when_value, base_value, eq)
299+
}
300+
DataType::UInt64 => {
301+
array_equals!(array::UInt64Array, when_value, base_value, eq)
302+
}
303+
DataType::Int8 => {
304+
array_equals!(array::Int8Array, when_value, base_value, eq)
305+
}
306+
DataType::Int16 => {
307+
array_equals!(array::Int16Array, when_value, base_value, eq)
308+
}
309+
DataType::Int32 => {
310+
array_equals!(array::Int32Array, when_value, base_value, eq)
311+
}
312+
DataType::Int64 => {
313+
array_equals!(array::Int64Array, when_value, base_value, eq)
314+
}
315+
DataType::Float32 => {
316+
array_equals!(array::Float32Array, when_value, base_value, eq)
317+
}
318+
DataType::Float64 => {
319+
array_equals!(array::Float64Array, when_value, base_value, eq)
320+
}
321+
DataType::Utf8 => {
322+
array_equals!(array::StringArray, when_value, base_value, eq_utf8)
323+
}
310324
other => Err(DataFusionError::Execution(format!(
311325
"CASE does not support '{:?}'",
312326
other

0 commit comments

Comments
 (0)