diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs index d1fc3bcdc029f..16a9383f79747 100644 --- a/datafusion/src/physical_plan/expressions/binary.rs +++ b/datafusion/src/physical_plan/expressions/binary.rs @@ -25,10 +25,9 @@ use arrow::compute::kernels::arithmetic::{ multiply_scalar, subtract, subtract_scalar, }; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; -use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow::compute::kernels::comparison::{ - eq_bool, eq_bool_scalar, gt_bool, gt_bool_scalar, gt_eq_bool, gt_eq_bool_scalar, - lt_bool, lt_bool_scalar, lt_eq_bool, lt_eq_bool_scalar, neq_bool, neq_bool_scalar, + eq_bool_scalar, gt_bool_scalar, gt_eq_bool_scalar, lt_bool_scalar, lt_eq_bool_scalar, + neq_bool_scalar, }; use arrow::compute::kernels::comparison::{ eq_dyn_bool_scalar, gt_dyn_bool_scalar, gt_eq_dyn_bool_scalar, lt_dyn_bool_scalar, @@ -45,15 +44,12 @@ use arrow::compute::kernels::comparison::{ use arrow::compute::kernels::comparison::{ eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar, }; -use arrow::compute::kernels::comparison::{ - eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, lt_eq_utf8, lt_utf8, neq_utf8, nlike_utf8, - regexp_is_match_utf8, -}; use arrow::compute::kernels::comparison::{ eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, like_utf8_scalar, lt_eq_utf8_scalar, lt_utf8_scalar, neq_utf8_scalar, nlike_utf8_scalar, regexp_is_match_utf8_scalar, }; +use arrow::compute::kernels::comparison::{like_utf8, nlike_utf8, regexp_is_match_utf8}; use arrow::datatypes::{ArrowNumericType, DataType, Schema, TimeUnit}; use arrow::error::ArrowError::DivideByZero; use arrow::record_batch::RecordBatch; @@ -65,6 +61,50 @@ use crate::physical_plan::expressions::try_cast; use crate::physical_plan::{ColumnarValue, PhysicalExpr}; use crate::scalar::ScalarValue; +// TODO move to arrow_rs +// 
https://github.com/apache/arrow-rs/issues/1312 +fn as_decimal_array(arr: &dyn Array) -> &DecimalArray { + arr.as_any() + .downcast_ref::<DecimalArray>() + .expect("Unable to downcast to typed array to DecimalArray") +} + +/// create a `dyn_op` wrapper function for the specified operation +/// that calls the underlying dyn_op arrow kernel if the type is +/// supported, and translates ArrowError to DataFusionError +macro_rules! make_dyn_comp_op { + ($OP:tt) => { + paste::paste! { + /// wrapper over arrow compute kernel that maps Error types and + /// patches missing support in arrow + fn [<$OP _dyn>] (left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> { + match (left.data_type(), right.data_type()) { + // Call `op_decimal` (e.g. `eq_decimal`) until + // arrow has native support + // https://github.com/apache/arrow-rs/issues/1200 + (DataType::Decimal(_, _), DataType::Decimal(_, _)) => { + [<$OP _decimal>](as_decimal_array(left), as_decimal_array(right)) + }, + // By default call the arrow kernel + _ => { + arrow::compute::kernels::comparison::[<$OP _dyn>](left, right) + .map_err(|e| e.into()) + } + } + .map(|a| Arc::new(a) as ArrayRef) + } + } + }; +} + +// create eq_dyn, gt_dyn, wrappers etc +make_dyn_comp_op!(eq); +make_dyn_comp_op!(gt); +make_dyn_comp_op!(gt_eq); +make_dyn_comp_op!(lt); +make_dyn_comp_op!(lt_eq); +make_dyn_comp_op!(neq); + // Simple (low performance) kernels until optimized kernels are added to arrow // See https://github.com/apache/arrow-rs/issues/960 @@ -91,8 +131,10 @@ fn is_not_distinct_from_bool( .collect()) } -// TODO add iter for decimal array -// TODO move this to arrow-rs +// TODO move decimal kernels to arrow-rs +// https://github.com/apache/arrow-rs/issues/1200 + +// TODO use iter added for decimal array in // https://github.com/apache/arrow-rs/issues/1083 pub(super) fn eq_decimal_scalar( left: &DecimalArray, @@ -1194,12 +1236,12 @@ impl BinaryExpr { match &self.op { Operator::Like => binary_string_array_op!(left, right, like), Operator::NotLike => 
binary_string_array_op!(left, right, nlike), - Operator::Lt => binary_array_op!(left, right, lt), - Operator::LtEq => binary_array_op!(left, right, lt_eq), - Operator::Gt => binary_array_op!(left, right, gt), - Operator::GtEq => binary_array_op!(left, right, gt_eq), - Operator::Eq => binary_array_op!(left, right, eq), - Operator::NotEq => binary_array_op!(left, right, neq), + Operator::Lt => lt_dyn(&left, &right), + Operator::LtEq => lt_eq_dyn(&left, &right), + Operator::Gt => gt_dyn(&left, &right), + Operator::GtEq => gt_eq_dyn(&left, &right), + Operator::Eq => eq_dyn(&left, &right), + Operator::NotEq => neq_dyn(&left, &right), Operator::IsDistinctFrom => binary_array_op!(left, right, is_distinct_from), Operator::IsNotDistinctFrom => { binary_array_op!(left, right, is_not_distinct_from)