diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 3879f779eb713..6da67c8a2798a 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -23,7 +23,8 @@ use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; use crate::string::common::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_int_type}; -use datafusion_common::{exec_err, plan_err, Result}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] @@ -40,8 +41,20 @@ impl Default for StrposFunc { impl StrposFunc { pub fn new() -> Self { + use DataType::*; Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + Exact(vec![Utf8View, Utf8View]), + Exact(vec![Utf8View, Utf8]), + Exact(vec![Utf8View, LargeUtf8]), + ], + Volatility::Immutable, + ), aliases: vec![String::from("instr"), String::from("position")], } } @@ -71,25 +84,6 @@ impl ScalarUDFImpl for StrposFunc { fn aliases(&self) -> &[String] { &self.aliases } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - match arg_types { - [first, second ] => { - match (first, second) { - (DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8, DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8) => Ok(arg_types.to_vec()), - (DataType::Null, DataType::Null) => Ok(vec![DataType::Utf8, DataType::Utf8]), - (DataType::Null, _) => Ok(vec![DataType::Utf8, second.to_owned()]), - (_, DataType::Null) => Ok(vec![first.to_owned(), DataType::Utf8]), - (DataType::Dictionary(_, value_type), DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8) => match **value_type { - DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 | DataType::Null | DataType::Binary => Ok(vec![*value_type.clone(), second.to_owned()]), - _ => plan_err!("The STRPOS/INSTR/POSITION function can only accept strings, but got {:?}.", **value_type), - }, - _ => plan_err!("The STRPOS/INSTR/POSITION function can only accept strings, but got {:?}.", arg_types) - } - }, - _ => plan_err!("The STRPOS/INSTR/POSITION function can only accept strings, but got {:?}", arg_types) - } - } } fn strpos(args: &[ArrayRef]) -> Result { diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 7d41c26ba012d..5b6017b08a00a 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -553,6 +553,16 @@ SELECT strpos(arrow_cast('helloworld', 'Dictionary(Int32, Utf8)'), 'world') ---- 6 +query I +SELECT strpos('helloworld', NULL) +---- +NULL + +query I +SELECT strpos(arrow_cast('helloworld', 'Dictionary(Int32, Utf8)'), NULL) +---- +NULL + statement ok CREATE TABLE products ( product_id INT PRIMARY KEY, diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 3b9c9a16042ca..8820fffaeb473 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1907,8 +1907,10 @@ select position('' in '') 1 -query error POSITION function can only accept strings +query I select position(1 in 1) +---- +1 query I