1818use std:: any:: Any ;
1919use std:: sync:: Arc ;
2020
21- use crate :: strings:: make_and_append_view;
22- use crate :: utils:: make_scalar_function;
2321use arrow:: array:: {
24- Array , ArrayIter , ArrayRef , AsArray , Int64Array , NullBufferBuilder , StringArrayType ,
25- StringViewArray , StringViewBuilder ,
22+ Array , ArrayIter , ArrayRef , AsArray , Int64Array , OffsetSizeTrait ,
23+ StringArrayType , StringViewBuilder ,
2624} ;
27- use arrow:: buffer:: ScalarBuffer ;
2825use arrow:: datatypes:: DataType ;
26+
2927use datafusion_common:: cast:: as_int64_array;
3028use datafusion_common:: { exec_err, plan_err, Result } ;
3129use datafusion_expr:: {
3230 ColumnarValue , Documentation , ScalarUDFImpl , Signature , Volatility ,
3331} ;
3432use datafusion_macros:: user_doc;
3533
34+ use crate :: utils:: { make_scalar_function, utf8_to_str_type} ;
35+
3636#[ user_doc(
3737 doc_section( label = "String Functions" ) ,
3838 description = "Extracts a substring of a specified number of characters from a specific starting position in a string." ,
@@ -44,7 +44,7 @@ use datafusion_macros::user_doc;
4444| substr(Utf8("datafusion"),Int64(5),Int64(3)) |
4545+----------------------------------------------+
4646| fus |
47- +----------------------------------------------+
47+ +----------------------------------------------+
4848```"# ,
4949 standard_argument( name = "str" , prefix = "String" ) ,
5050 argument(
@@ -90,9 +90,8 @@ impl ScalarUDFImpl for SubstrFunc {
9090 & self . signature
9191 }
9292
93- // `SubstrFunc` always generates `Utf8View` output for its efficiency.
94- fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
95- Ok ( DataType :: Utf8View )
93+ fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
94+ utf8_to_str_type ( & arg_types[ 0 ] , "substr" )
9695 }
9796
9897 fn invoke_with_args (
@@ -177,28 +176,21 @@ impl ScalarUDFImpl for SubstrFunc {
177176 }
178177}
179178
180- /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
181- /// substr('alphabet', 3) = 'phabet'
182- /// substr('alphabet', 3, 2) = 'ph'
183- /// The implementation uses UTF-8 code points as characters
184179pub fn substr ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
185180 match args[ 0 ] . data_type ( ) {
186181 DataType :: Utf8 => {
187182 let string_array = args[ 0 ] . as_string :: < i32 > ( ) ;
188- string_substr :: < _ > ( string_array, & args[ 1 ..] )
183+ calculate_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
189184 }
190185 DataType :: LargeUtf8 => {
191186 let string_array = args[ 0 ] . as_string :: < i64 > ( ) ;
192- string_substr :: < _ > ( string_array, & args[ 1 ..] )
187+ calculate_substr :: < _ , i64 > ( string_array, & args[ 1 ..] )
193188 }
194189 DataType :: Utf8View => {
195190 let string_array = args[ 0 ] . as_string_view ( ) ;
196- string_view_substr ( string_array, & args[ 1 ..] )
191+ calculate_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
197192 }
198- other => exec_err ! (
199- "Unsupported data type {other:?} for function substr,\
200- expected Utf8View, Utf8 or LargeUtf8."
201- ) ,
193+ other => exec_err ! ( "Unsupported data type {other:?} for function substr" ) ,
202194 }
203195}
204196
@@ -312,120 +304,11 @@ fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>(
312304 }
313305}
314306
315- // The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
316- // From<u128> for ByteView
317- fn string_view_substr (
318- string_view_array : & StringViewArray ,
319- args : & [ ArrayRef ] ,
320- ) -> Result < ArrayRef > {
321- let mut views_buf = Vec :: with_capacity ( string_view_array. len ( ) ) ;
322- let mut null_builder = NullBufferBuilder :: new ( string_view_array. len ( ) ) ;
323-
324- let start_array = as_int64_array ( & args[ 0 ] ) ?;
325- let count_array_opt = if args. len ( ) == 2 {
326- Some ( as_int64_array ( & args[ 1 ] ) ?)
327- } else {
328- None
329- } ;
330-
331- let enable_ascii_fast_path =
332- enable_ascii_fast_path ( & string_view_array, start_array, count_array_opt) ;
333-
334- // In either case of `substr(s, i)` or `substr(s, i, cnt)`
335- // If any of input argument is `NULL`, the result is `NULL`
336- match args. len ( ) {
337- 1 => {
338- for ( ( str_opt, raw_view) , start_opt) in string_view_array
339- . iter ( )
340- . zip ( string_view_array. views ( ) . iter ( ) )
341- . zip ( start_array. iter ( ) )
342- {
343- if let ( Some ( str) , Some ( start) ) = ( str_opt, start_opt) {
344- let ( start, end) =
345- get_true_start_end ( str, start, None , enable_ascii_fast_path) ;
346- let substr = & str[ start..end] ;
347-
348- make_and_append_view (
349- & mut views_buf,
350- & mut null_builder,
351- raw_view,
352- substr,
353- start as u32 ,
354- ) ;
355- } else {
356- null_builder. append_null ( ) ;
357- views_buf. push ( 0 ) ;
358- }
359- }
360- }
361- 2 => {
362- let count_array = count_array_opt. unwrap ( ) ;
363- for ( ( ( str_opt, raw_view) , start_opt) , count_opt) in string_view_array
364- . iter ( )
365- . zip ( string_view_array. views ( ) . iter ( ) )
366- . zip ( start_array. iter ( ) )
367- . zip ( count_array. iter ( ) )
368- {
369- if let ( Some ( str) , Some ( start) , Some ( count) ) =
370- ( str_opt, start_opt, count_opt)
371- {
372- if count < 0 {
373- return exec_err ! (
374- "negative substring length not allowed: substr(<str>, {start}, {count})"
375- ) ;
376- } else {
377- if start == i64:: MIN {
378- return exec_err ! (
379- "negative overflow when calculating skip value"
380- ) ;
381- }
382- let ( start, end) = get_true_start_end (
383- str,
384- start,
385- Some ( count as u64 ) ,
386- enable_ascii_fast_path,
387- ) ;
388- let substr = & str[ start..end] ;
389-
390- make_and_append_view (
391- & mut views_buf,
392- & mut null_builder,
393- raw_view,
394- substr,
395- start as u32 ,
396- ) ;
397- }
398- } else {
399- null_builder. append_null ( ) ;
400- views_buf. push ( 0 ) ;
401- }
402- }
403- }
404- other => {
405- return exec_err ! (
406- "substr was called with {other} arguments. It requires 2 or 3."
407- )
408- }
409- }
410-
411- let views_buf = ScalarBuffer :: from ( views_buf) ;
412- let nulls_buf = null_builder. finish ( ) ;
413-
414- // Safety:
415- // (1) The blocks of the given views are all provided
416- // (2) Each of the range `view.offset+start..end` of view in views_buf is within
417- // the bounds of each of the blocks
418- unsafe {
419- let array = StringViewArray :: new_unchecked (
420- views_buf,
421- string_view_array. data_buffers ( ) . to_vec ( ) ,
422- nulls_buf,
423- ) ;
424- Ok ( Arc :: new ( array) as ArrayRef )
425- }
426- }
427-
428- fn string_substr < ' a , V > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
307+ /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
308+ /// substr('alphabet', 3) = 'phabet'
309+ /// substr('alphabet', 3, 2) = 'ph'
310+ /// The implementation uses UTF-8 code points as characters
311+ fn calculate_substr < ' a , V , T > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
429312where
430313 V : StringArrayType < ' a > ,
431314{
@@ -507,8 +390,8 @@ where
507390
508391#[ cfg( test) ]
509392mod tests {
510- use arrow:: array:: { Array , StringViewArray } ;
511- use arrow:: datatypes:: DataType :: Utf8View ;
393+ use arrow:: array:: { Array , StringArray } ;
394+ use arrow:: datatypes:: DataType :: Utf8 ;
512395
513396 use datafusion_common:: { exec_err, Result , ScalarValue } ;
514397 use datafusion_expr:: { ColumnarValue , ScalarUDFImpl } ;
@@ -526,8 +409,8 @@ mod tests {
526409 ] ,
527410 Ok ( None ) ,
528411 & str ,
529- Utf8View ,
530- StringViewArray
412+ Utf8 ,
413+ StringArray
531414 ) ;
532415 test_function ! (
533416 SubstrFunc :: new( ) ,
@@ -539,35 +422,8 @@ mod tests {
539422 ] ,
540423 Ok ( Some ( "alphabet" ) ) ,
541424 & str ,
542- Utf8View ,
543- StringViewArray
544- ) ;
545- test_function ! (
546- SubstrFunc :: new( ) ,
547- vec![
548- ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
549- "this és longer than 12B"
550- ) ) ) ) ,
551- ColumnarValue :: Scalar ( ScalarValue :: from( 5i64 ) ) ,
552- ColumnarValue :: Scalar ( ScalarValue :: from( 2i64 ) ) ,
553- ] ,
554- Ok ( Some ( " é" ) ) ,
555- & str ,
556- Utf8View ,
557- StringViewArray
558- ) ;
559- test_function ! (
560- SubstrFunc :: new( ) ,
561- vec![
562- ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
563- "this is longer than 12B"
564- ) ) ) ) ,
565- ColumnarValue :: Scalar ( ScalarValue :: from( 5i64 ) ) ,
566- ] ,
567- Ok ( Some ( " is longer than 12B" ) ) ,
568- & str ,
569- Utf8View ,
570- StringViewArray
425+ Utf8 ,
426+ StringArray
571427 ) ;
572428 test_function ! (
573429 SubstrFunc :: new( ) ,
@@ -579,8 +435,8 @@ mod tests {
579435 ] ,
580436 Ok ( Some ( "ésoj" ) ) ,
581437 & str ,
582- Utf8View ,
583- StringViewArray
438+ Utf8 ,
439+ StringArray
584440 ) ;
585441 test_function ! (
586442 SubstrFunc :: new( ) ,
@@ -593,8 +449,8 @@ mod tests {
593449 ] ,
594450 Ok ( Some ( "ph" ) ) ,
595451 & str ,
596- Utf8View ,
597- StringViewArray
452+ Utf8 ,
453+ StringArray
598454 ) ;
599455 test_function ! (
600456 SubstrFunc :: new( ) ,
@@ -607,8 +463,8 @@ mod tests {
607463 ] ,
608464 Ok ( Some ( "phabet" ) ) ,
609465 & str ,
610- Utf8View ,
611- StringViewArray
466+ Utf8 ,
467+ StringArray
612468 ) ;
613469 test_function ! (
614470 SubstrFunc :: new( ) ,
0 commit comments