Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions native/spark-expr/src/comet_scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ use crate::scalar_funcs::hash_expressions::{
};
use crate::scalar_funcs::{
spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex,
spark_isnan, spark_make_decimal, spark_murmur3_hash, spark_read_side_padding, spark_round,
spark_unhex, spark_unscaled_value, spark_xxhash64, SparkChrFunc,
spark_make_decimal, spark_murmur3_hash, spark_read_side_padding, spark_round, spark_unhex,
spark_unscaled_value, spark_xxhash64, SparkChrFunc,
};
use crate::spark_isnan;
use arrow_schema::DataType;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_expr::registry::FunctionRegistry;
Expand Down
4 changes: 2 additions & 2 deletions native/spark-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ mod strings;
pub use strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExpr, SubstringExpr};
mod kernels;
mod list;
mod regexp;
pub mod scalar_funcs;
mod schema_adapter;
pub use schema_adapter::SparkSchemaAdapterFactory;
Expand All @@ -62,6 +61,8 @@ mod unbound;
pub use unbound::UnboundColumn;
pub mod utils;
pub use normalize_nan::NormalizeNaNAndZero;
mod predicate_funcs;
pub use predicate_funcs::*;

mod variance;
pub use variance::Variance;
Expand All @@ -71,7 +72,6 @@ pub use comet_scalar_funcs::create_comet_physical_fun;
pub use error::{SparkError, SparkResult};
pub use if_expr::IfExpr;
pub use list::{ArrayInsert, GetArrayStructFields, ListExtract};
pub use regexp::RLike;
pub use structs::{CreateNamedStruct, GetStructField};
pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr};
pub use to_json::ToJson;
Expand Down
70 changes: 70 additions & 0 deletions native/spark-expr/src/predicate_funcs/is_nan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Float32Array, Float64Array};
use arrow_array::{Array, BooleanArray};
use arrow_schema::DataType;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use std::sync::Arc;

/// Spark-compatible `isnan` expression.
///
/// Returns a boolean indicating, per row, whether the input floating-point
/// value is NaN. Supports `Float32`/`Float64` arrays and scalars; any other
/// input type yields an `Internal` error. Matching Spark's null-intolerant
/// `IsNaN`, a NULL input evaluates to `false`, never to NULL.
pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    // Rewrite null slots in the intermediate result to plain `false`.
    // A slot should be `true` only when it is both valid and NaN, so AND the
    // value buffer with the validity bitmap and drop the null buffer entirely.
    fn nulls_as_false(is_nan: BooleanArray) -> ColumnarValue {
        let result = match is_nan.nulls() {
            Some(validity) => BooleanArray::new(is_nan.values() & validity.inner(), None),
            None => is_nan,
        };
        ColumnarValue::Array(Arc::new(result))
    }

    let input = &args[0];
    match input {
        ColumnarValue::Array(array) => {
            let is_nan = match array.data_type() {
                DataType::Float64 => {
                    let floats = array.as_any().downcast_ref::<Float64Array>().unwrap();
                    BooleanArray::from_unary(floats, |v| v.is_nan())
                }
                DataType::Float32 => {
                    let floats = array.as_any().downcast_ref::<Float32Array>().unwrap();
                    BooleanArray::from_unary(floats, |v| v.is_nan())
                }
                other => {
                    return Err(DataFusionError::Internal(format!(
                        "Unsupported data type {:?} for function isnan",
                        other,
                    )))
                }
            };
            Ok(nulls_as_false(is_nan))
        }
        ColumnarValue::Scalar(scalar) => {
            // `None` (SQL NULL) maps to `false`, consistent with the array path.
            let is_nan = match scalar {
                ScalarValue::Float64(v) => v.map(f64::is_nan).unwrap_or(false),
                ScalarValue::Float32(v) => v.map(f32::is_nan).unwrap_or(false),
                _ => {
                    return Err(DataFusionError::Internal(format!(
                        "Unsupported data type {:?} for function isnan",
                        input.data_type(),
                    )))
                }
            };
            Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(is_nan))))
        }
    }
}
22 changes: 22 additions & 0 deletions native/spark-expr/src/predicate_funcs/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

mod is_nan;
mod rlike;

pub use is_nan::spark_isnan;
pub use rlike::RLike;
49 changes: 1 addition & 48 deletions native/spark-expr/src/scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use arrow::{
};
use arrow_array::builder::{GenericStringBuilder, IntervalDayTimeBuilder};
use arrow_array::types::{Int16Type, Int32Type, Int8Type};
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum, Decimal128Array};
use arrow_array::{Array, ArrowNativeTypeOp, Datum, Decimal128Array};
use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION};
use datafusion::physical_expr_common::datum;
use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
Expand Down Expand Up @@ -505,53 +505,6 @@ pub fn spark_decimal_div(
Ok(ColumnarValue::Array(Arc::new(result)))
}

/// Spark-compatible `isnan` expression.
///
/// Computes, per row, whether the `Float32`/`Float64` input is NaN; any other
/// type produces an `Internal` error. Like Spark's null-intolerant `IsNaN`,
/// a NULL input evaluates to `false` rather than NULL.
pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    // Convert null slots of the intermediate result to plain `false`: a slot
    // is `true` only when it is both valid (per the validity bitmap) and NaN,
    // so AND the value bits with the validity bits and drop the null buffer.
    fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue {
        match is_nan.nulls() {
            Some(nulls) => {
                let is_not_null = nulls.inner();
                ColumnarValue::Array(Arc::new(BooleanArray::new(
                    is_nan.values() & is_not_null,
                    None,
                )))
            }
            // No nulls present: the NaN bitmap is already the final answer.
            None => ColumnarValue::Array(Arc::new(is_nan)),
        }
    }
    let value = &args[0];
    match value {
        ColumnarValue::Array(array) => match array.data_type() {
            DataType::Float64 => {
                let array = array.as_any().downcast_ref::<Float64Array>().unwrap();
                let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
                Ok(set_nulls_to_false(is_nan))
            }
            DataType::Float32 => {
                let array = array.as_any().downcast_ref::<Float32Array>().unwrap();
                let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
                Ok(set_nulls_to_false(is_nan))
            }
            other => Err(DataFusionError::Internal(format!(
                "Unsupported data type {:?} for function isnan",
                other,
            ))),
        },
        // Scalar path: a NULL scalar (`None`) yields `false`, mirroring the
        // null-as-false behavior of the array path above.
        ColumnarValue::Scalar(a) => match a {
            ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
                a.map(|x| x.is_nan()).unwrap_or(false),
            )))),
            ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
                a.map(|x| x.is_nan()).unwrap_or(false),
            )))),
            _ => Err(DataFusionError::Internal(format!(
                "Unsupported data type {:?} for function isnan",
                value.data_type(),
            ))),
        },
    }
}

macro_rules! scalar_date_arithmetic {
($start:expr, $days:expr, $op:expr) => {{
let interval = IntervalDayTime::new(*$days as i32, 0);
Expand Down