diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index cd5d98700676..c18241f799e3 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -267,11 +267,11 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( - regexp_replace::( + regexp_replace::( data.as_string::(), regex.as_string::(), replacement.as_string::(), - Some(&flags), + Some(flags.as_string::()), ) .expect("regexp_replace should work on valid values"), ) @@ -282,19 +282,18 @@ fn criterion_benchmark(c: &mut Criterion) { let mut rng = rand::rng(); let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap(); - // flags are not allowed to be utf8view according to the function - let flags = Arc::new(flags(&mut rng)) as ArrayRef; + let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); let replacement = Arc::new(StringViewArray::from_iter_values(iter::repeat_n( "XX", 1000, ))); b.iter(|| { black_box( - regexp_replace::( + regexp_replace::( data.as_string_view(), regex.as_string_view(), - &replacement, - Some(&flags), + &*replacement, + Some(flags.as_string_view()), ) .expect("regexp_replace should work on valid values"), ) diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index 39858119c89c..ca3d19822e13 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -24,7 +24,9 @@ use arrow::array::{new_null_array, ArrayIter, AsArray}; use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; use arrow::array::{ArrayAccessor, StringViewArray}; use arrow::datatypes::DataType; -use datafusion_common::cast::as_string_view_array; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; use datafusion_common::exec_err; use datafusion_common::plan_err; use datafusion_common::ScalarValue; @@ 
-95,13 +97,12 @@ impl Default for RegexpReplaceFunc { impl RegexpReplaceFunc { pub fn new() -> Self { use DataType::*; + use TypeSignature::*; Self { signature: Signature::one_of( vec![ - TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8, Utf8, Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8View, Utf8, Utf8, Utf8]), + Uniform(3, vec![Utf8View, LargeUtf8, Utf8]), + Uniform(4, vec![Utf8View, LargeUtf8, Utf8]), ], Volatility::Immutable, ), @@ -238,15 +239,14 @@ fn regex_replace_posix_groups(replacement: &str) -> String { /// # Ok(()) /// # } /// ``` -pub fn regexp_replace<'a, T: OffsetSizeTrait, V, B>( - string_array: V, - pattern_array: B, - replacement_array: B, - flags: Option<&ArrayRef>, +pub fn regexp_replace<'a, T: OffsetSizeTrait, U>( + string_array: U, + pattern_array: U, + replacement_array: U, + flags_array: Option, ) -> Result where - V: ArrayAccessor, - B: ArrayAccessor, + U: ArrayAccessor, { // Default implementation for regexp_replace, assumes all args are arrays // and args is a sequence of 3 or 4 elements. 
@@ -260,7 +260,7 @@ where let pattern_array_iter = ArrayIter::new(pattern_array); let replacement_array_iter = ArrayIter::new(replacement_array); - match flags { + match flags_array { None => { let result_iter = string_array_iter .zip(pattern_array_iter) @@ -307,13 +307,13 @@ where } } } - Some(flags) => { - let flags_array = as_generic_string_array::(flags)?; + Some(flags_array) => { + let flags_array_iter = ArrayIter::new(flags_array); let result_iter = string_array_iter .zip(pattern_array_iter) .zip(replacement_array_iter) - .zip(flags_array.iter()) + .zip(flags_array_iter) .map(|(((string, pattern), replacement), flags)| { match (string, pattern, replacement, flags) { (Some(string), Some(pattern), Some(replacement), Some(flags)) => { @@ -398,12 +398,37 @@ fn _regexp_replace_early_abort( /// Note: If the array is empty or the first argument is null, /// then calls the given early abort function. macro_rules! fetch_string_arg { - ($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident, $ARRAY_SIZE:expr) => {{ - let array = as_generic_string_array::<$T>($ARG)?; - if array.len() == 0 || array.is_null(0) { - return $EARLY_ABORT(array, $ARRAY_SIZE); - } else { - array.value(0) + ($ARG:expr, $NAME:expr, $EARLY_ABORT:ident, $ARRAY_SIZE:expr) => {{ + let string_array_type = ($ARG).data_type(); + match string_array_type { + DataType::Utf8 => { + let array = as_string_array($ARG)?; + if array.len() == 0 || array.is_null(0) { + return $EARLY_ABORT(array, $ARRAY_SIZE); + } else { + array.value(0) + } + } + DataType::LargeUtf8 => { + let array = as_large_string_array($ARG)?; + if array.len() == 0 || array.is_null(0) { + return $EARLY_ABORT(array, $ARRAY_SIZE); + } else { + array.value(0) + } + } + DataType::Utf8View => { + let array = as_string_view_array($ARG)?; + if array.len() == 0 || array.is_null(0) { + return $EARLY_ABORT(array, $ARRAY_SIZE); + } else { + array.value(0) + } + } + _ => unreachable!( + "Invalid data type for regexp_replace: {}", + string_array_type + ), } 
}}; } @@ -417,23 +442,17 @@ fn _regexp_replace_static_pattern_replace( args: &[ArrayRef], ) -> Result { let array_size = args[0].len(); - let pattern = fetch_string_arg!( - &args[1], - "pattern", - i32, - _regexp_replace_early_abort, - array_size - ); + let pattern = + fetch_string_arg!(&args[1], "pattern", _regexp_replace_early_abort, array_size); let replacement = fetch_string_arg!( &args[2], "replacement", - i32, _regexp_replace_early_abort, array_size ); let flags = match args.len() { 3 => None, - 4 => Some(fetch_string_arg!(&args[3], "flags", i32, _regexp_replace_early_abort, array_size)), + 4 => Some(fetch_string_arg!(&args[3], "flags", _regexp_replace_early_abort, array_size)), other => { return exec_err!( "regexp_replace was called with {other} arguments. It requires at least 3 and at most 4." @@ -590,38 +609,61 @@ pub fn specialize_regexp_replace( .map(|arg| arg.to_array(inferred_length)) .collect::>>()?; - match args[0].data_type() { - DataType::Utf8View => { - let string_array = args[0].as_string_view(); + match ( + args[0].data_type(), + args[1].data_type(), + args[2].data_type(), + args.get(3).map(|a| a.data_type()), + ) { + ( + DataType::Utf8, + DataType::Utf8, + DataType::Utf8, + Some(DataType::Utf8) | None, + ) => { + let string_array = args[0].as_string::(); let pattern_array = args[1].as_string::(); let replacement_array = args[2].as_string::(); - regexp_replace::( + let flags_array = args.get(3).map(|a| a.as_string::()); + regexp_replace::( string_array, pattern_array, replacement_array, - args.get(3), + flags_array, ) } - DataType::Utf8 => { - let string_array = args[0].as_string::(); - let pattern_array = args[1].as_string::(); - let replacement_array = args[2].as_string::(); - regexp_replace::( + ( + DataType::Utf8View, + DataType::Utf8View, + DataType::Utf8View, + Some(DataType::Utf8View) | None, + ) => { + let string_array = args[0].as_string_view(); + let pattern_array = args[1].as_string_view(); + let replacement_array = 
args[2].as_string_view(); + let flags_array = args.get(3).map(|a| a.as_string_view()); + regexp_replace::( string_array, pattern_array, replacement_array, - args.get(3), + flags_array, ) } - DataType::LargeUtf8 => { + ( + DataType::LargeUtf8, + DataType::LargeUtf8, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) | None, + ) => { let string_array = args[0].as_string::(); let pattern_array = args[1].as_string::(); let replacement_array = args[2].as_string::(); - regexp_replace::( + let flags_array = args.get(3).map(|a| a.as_string::()); + regexp_replace::( string_array, pattern_array, replacement_array, - args.get(3), + flags_array, ) } other => { @@ -650,8 +692,8 @@ mod tests { vec!["afooc", "acd", "afoocd1234567890123", "123456789012afooc"]; let values = <$T>::from(values); - let patterns = StringArray::from(patterns); - let replacements = StringArray::from(replacement); + let patterns = <$T>::from(patterns); + let replacements = <$T>::from(replacement); let expected = <$T>::from(expected); let re = _regexp_replace_static_pattern_replace::<$O>(&[ diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt index 79b783f89a61..f602dbb54b08 100644 --- a/datafusion/sqllogictest/test_files/string/string_literal.slt +++ b/datafusion/sqllogictest/test_files/string/string_literal.slt @@ -303,6 +303,26 @@ SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'x ---- fooxx +query T +SELECT regexp_replace(arrow_cast('foobar', 'LargeUtf8'), 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT regexp_replace(arrow_cast('foobar', 'Utf8View'), 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT regexp_replace('foobar', arrow_cast('bar', 'LargeUtf8'), 'xx', 'gi') +---- +fooxx + +query T +SELECT regexp_replace('foobar', arrow_cast('bar', 'Utf8View'), 'xx', 'gi') +---- +fooxx + query T SELECT repeat('foo', 3) ---- diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt 
b/datafusion/sqllogictest/test_files/string/string_view.slt index a72c8f574484..7d10a0615d45 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -804,7 +804,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: regexp_replace(test.column1_utf8view, Utf8("^https?://(?:www\.)?([^/]+)/.*$"), Utf8("\1")) AS k +01)Projection: regexp_replace(test.column1_utf8view, Utf8View("^https?://(?:www\.)?([^/]+)/.*$"), Utf8View("\1")) AS k 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for REPEAT