// specific language governing permissions and limitations
// under the License.
1717
18+ use std:: any:: Any ;
19+ use std:: hash:: { DefaultHasher , Hash , Hasher } ;
20+ use std:: iter;
21+ use std:: sync:: Arc ;
22+
1823use arrow:: compute:: kernels:: numeric:: add;
24+ use arrow_array:: builder:: BooleanBuilder ;
25+ use arrow_array:: cast:: AsArray ;
26+ use arrow_array:: StringArray ;
1927use arrow_array:: {
2028 Array , ArrayRef , Float32Array , Float64Array , Int32Array , RecordBatch , UInt8Array ,
2129} ;
2230use arrow_schema:: DataType :: Float64 ;
2331use arrow_schema:: { DataType , Field , Schema } ;
32+ use rand:: { thread_rng, Rng } ;
33+ use regex:: Regex ;
34+
2435use datafusion:: execution:: context:: { FunctionFactory , RegisterFunction , SessionState } ;
2536use datafusion:: prelude:: * ;
2637use datafusion:: { execution:: registry:: FunctionRegistry , test_util} ;
@@ -36,10 +47,6 @@ use datafusion_expr::{
3647 create_udaf, create_udf, Accumulator , ColumnarValue , CreateFunction , ExprSchemable ,
3748 LogicalPlanBuilder , ScalarUDF , ScalarUDFImpl , Signature , Volatility ,
3849} ;
39- use rand:: { thread_rng, Rng } ;
40- use std:: any:: Any ;
41- use std:: iter;
42- use std:: sync:: Arc ;
4350
/// test that casting happens on udfs.
/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
@@ -961,6 +968,121 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> {
961968 Ok ( ( ) )
962969}
963970
971+ #[ derive( Debug ) ]
972+ struct MyRegexUdf {
973+ signature : Signature ,
974+ regex : Regex ,
975+ }
976+
977+ impl MyRegexUdf {
978+ fn new ( pattern : & str ) -> Self {
979+ Self {
980+ signature : Signature :: exact ( vec ! [ DataType :: Utf8 ] , Volatility :: Immutable ) ,
981+ regex : Regex :: new ( pattern) . expect ( "regex" ) ,
982+ }
983+ }
984+
985+ fn matches ( & self , value : Option < & str > ) -> Option < bool > {
986+ Some ( self . regex . is_match ( value?) )
987+ }
988+ }
989+
990+ impl ScalarUDFImpl for MyRegexUdf {
991+ fn as_any ( & self ) -> & dyn Any {
992+ self
993+ }
994+
995+ fn name ( & self ) -> & str {
996+ "regex_udf"
997+ }
998+
999+ fn signature ( & self ) -> & Signature {
1000+ & self . signature
1001+ }
1002+
1003+ fn return_type ( & self , args : & [ DataType ] ) -> Result < DataType > {
1004+ if matches ! ( args, [ DataType :: Utf8 ] ) {
1005+ Ok ( DataType :: Boolean )
1006+ } else {
1007+ plan_err ! ( "regex_udf only accepts a Utf8 argument" )
1008+ }
1009+ }
1010+
1011+ fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
1012+ match args {
1013+ [ ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( value) ) ] => {
1014+ Ok ( ColumnarValue :: Scalar ( ScalarValue :: Boolean (
1015+ self . matches ( value. as_deref ( ) ) ,
1016+ ) ) )
1017+ }
1018+ [ ColumnarValue :: Array ( values) ] => {
1019+ let mut builder = BooleanBuilder :: with_capacity ( values. len ( ) ) ;
1020+ for value in values. as_string :: < i32 > ( ) {
1021+ builder. append_option ( self . matches ( value) )
1022+ }
1023+ Ok ( ColumnarValue :: Array ( Arc :: new ( builder. finish ( ) ) ) )
1024+ }
1025+ _ => exec_err ! ( "regex_udf only accepts a Utf8 arguments" ) ,
1026+ }
1027+ }
1028+
1029+ fn equals ( & self , other : & dyn ScalarUDFImpl ) -> bool {
1030+ if let Some ( other) = other. as_any ( ) . downcast_ref :: < MyRegexUdf > ( ) {
1031+ self . regex . as_str ( ) == other. regex . as_str ( )
1032+ } else {
1033+ false
1034+ }
1035+ }
1036+
1037+ fn hash_value ( & self ) -> u64 {
1038+ let hasher = & mut DefaultHasher :: new ( ) ;
1039+ self . regex . as_str ( ) . hash ( hasher) ;
1040+ hasher. finish ( )
1041+ }
1042+ }
1043+
1044+ #[ tokio:: test]
1045+ async fn test_parameterized_scalar_udf ( ) -> Result < ( ) > {
1046+ let batch = RecordBatch :: try_from_iter ( [ (
1047+ "text" ,
1048+ Arc :: new ( StringArray :: from ( vec ! [ "foo" , "bar" , "foobar" , "barfoo" ] ) ) as ArrayRef ,
1049+ ) ] ) ?;
1050+
1051+ let ctx = SessionContext :: new ( ) ;
1052+ ctx. register_batch ( "t" , batch) ?;
1053+ let t = ctx. table ( "t" ) . await ?;
1054+ let foo_udf = ScalarUDF :: from ( MyRegexUdf :: new ( "fo{2}" ) ) ;
1055+ let bar_udf = ScalarUDF :: from ( MyRegexUdf :: new ( "[Bb]ar" ) ) ;
1056+
1057+ let plan = LogicalPlanBuilder :: from ( t. into_optimized_plan ( ) ?)
1058+ . filter (
1059+ foo_udf
1060+ . call ( vec ! [ col( "text" ) ] )
1061+ . and ( bar_udf. call ( vec ! [ col( "text" ) ] ) ) ,
1062+ ) ?
1063+ . filter ( col ( "text" ) . is_not_null ( ) ) ?
1064+ . build ( ) ?;
1065+
1066+ assert_eq ! (
1067+ format!( "{plan:?}" ) ,
1068+ "Filter: t.text IS NOT NULL\n Filter: regex_udf(t.text) AND regex_udf(t.text)\n TableScan: t projection=[text]"
1069+ ) ;
1070+
1071+ let actual = DataFrame :: new ( ctx. state ( ) , plan) . collect ( ) . await ?;
1072+ let expected = [
1073+ "+--------+" ,
1074+ "| text |" ,
1075+ "+--------+" ,
1076+ "| foobar |" ,
1077+ "| barfoo |" ,
1078+ "+--------+" ,
1079+ ] ;
1080+ assert_batches_eq ! ( expected, & actual) ;
1081+
1082+ ctx. deregister_table ( "t" ) ?;
1083+ Ok ( ( ) )
1084+ }
1085+
9641086fn create_udf_context ( ) -> SessionContext {
9651087 let ctx = SessionContext :: new ( ) ;
9661088 // register a custom UDF
0 commit comments