Skip to content

Commit 7eb2f99

Browse files
committed
Add tests for parameterized UDFs
1 parent e36f1c3 commit 7eb2f99

2 files changed

Lines changed: 201 additions & 8 deletions

File tree

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,19 @@
1818
//! This module contains end to end demonstrations of creating
1919
//! user defined aggregate functions
2020
21-
use arrow::{array::AsArray, datatypes::Fields};
22-
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
23-
use arrow_schema::Schema;
21+
use std::hash::{DefaultHasher, Hash, Hasher};
2422
use std::sync::{
2523
atomic::{AtomicBool, Ordering},
2624
Arc,
2725
};
2826

27+
use arrow::{array::AsArray, datatypes::Fields};
28+
use arrow_array::{
29+
types::UInt64Type, Int32Array, PrimitiveArray, StringArray, StructArray,
30+
};
31+
use arrow_schema::Schema;
32+
33+
use datafusion::dataframe::DataFrame;
2934
use datafusion::datasource::MemTable;
3035
use datafusion::test_util::plan_and_collect;
3136
use datafusion::{
@@ -45,7 +50,8 @@ use datafusion::{
4550
};
4651
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
4752
use datafusion_expr::{
48-
create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF,
53+
col, create_udaf, AggregateUDFImpl, GroupsAccumulator, LogicalPlanBuilder,
54+
SimpleAggregateUDF,
4955
};
5056
use datafusion_physical_expr::expressions::AvgAccumulator;
5157

@@ -376,6 +382,56 @@ async fn test_groups_accumulator() -> Result<()> {
376382
Ok(())
377383
}
378384

385+
#[ignore]
386+
#[tokio::test]
387+
async fn test_parameterized_aggregate_udf() -> Result<()> {
388+
let batch = RecordBatch::try_from_iter([(
389+
"text",
390+
Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
391+
)])?;
392+
393+
let ctx = SessionContext::new();
394+
ctx.register_batch("t", batch)?;
395+
let t = ctx.table("t").await?;
396+
let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
397+
let udf1 = AggregateUDF::from(TestGroupsAccumulator {
398+
signature: signature.clone(),
399+
result: 1,
400+
});
401+
let udf2 = AggregateUDF::from(TestGroupsAccumulator {
402+
signature: signature.clone(),
403+
result: 2,
404+
});
405+
406+
let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
407+
.aggregate(
408+
[col("text")],
409+
[
410+
udf1.call(vec![col("text")]).alias("a"),
411+
udf2.call(vec![col("text")]).alias("b"),
412+
],
413+
)?
414+
.build()?;
415+
416+
assert_eq!(
417+
format!("{plan:?}"),
418+
"Aggregate: groupBy=[[t.text]], aggr=[[geo_mean(t.text) AS a, geo_mean(t.text) AS b]]\n TableScan: t projection=[text]"
419+
);
420+
421+
let actual = DataFrame::new(ctx.state(), plan).collect().await?;
422+
let expected = [
423+
"+------+---+---+",
424+
"| text | a | b |",
425+
"+------+---+---+",
426+
"| foo | 1 | 2 |",
427+
"+------+---+---+",
428+
];
429+
assert_batches_eq!(expected, &actual);
430+
431+
ctx.deregister_table("t")?;
432+
Ok(())
433+
}
434+
379435
/// Returns an context with a table "t" and the "first" and "time_sum"
380436
/// aggregate functions registered.
381437
///
@@ -733,6 +789,21 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
733789
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
734790
Ok(Box::new(self.clone()))
735791
}
792+
793+
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
794+
if let Some(other) = other.as_any().downcast_ref::<TestGroupsAccumulator>() {
795+
self.result == other.result && self.signature == other.signature
796+
} else {
797+
false
798+
}
799+
}
800+
801+
fn hash_value(&self) -> u64 {
802+
let hasher = &mut DefaultHasher::new();
803+
self.signature.hash(hasher);
804+
self.result.hash(hasher);
805+
hasher.finish()
806+
}
736807
}
737808

738809
impl Accumulator for TestGroupsAccumulator {

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,23 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::any::Any;
19+
use std::hash::{DefaultHasher, Hash, Hasher};
20+
use std::iter;
21+
use std::sync::Arc;
22+
1823
use arrow::compute::kernels::numeric::add;
24+
use arrow_array::builder::BooleanBuilder;
25+
use arrow_array::cast::AsArray;
26+
use arrow_array::StringArray;
1927
use arrow_array::{
2028
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array,
2129
};
2230
use arrow_schema::DataType::Float64;
2331
use arrow_schema::{DataType, Field, Schema};
32+
use rand::{thread_rng, Rng};
33+
use regex::Regex;
34+
2435
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
2536
use datafusion::prelude::*;
2637
use datafusion::{execution::registry::FunctionRegistry, test_util};
@@ -36,10 +47,6 @@ use datafusion_expr::{
3647
create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable,
3748
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
3849
};
39-
use rand::{thread_rng, Rng};
40-
use std::any::Any;
41-
use std::iter;
42-
use std::sync::Arc;
4350

4451
/// test that casting happens on udfs.
4552
/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
@@ -961,6 +968,121 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> {
961968
Ok(())
962969
}
963970

971+
#[derive(Debug)]
972+
struct MyRegexUdf {
973+
signature: Signature,
974+
regex: Regex,
975+
}
976+
977+
impl MyRegexUdf {
978+
fn new(pattern: &str) -> Self {
979+
Self {
980+
signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable),
981+
regex: Regex::new(pattern).expect("regex"),
982+
}
983+
}
984+
985+
fn matches(&self, value: Option<&str>) -> Option<bool> {
986+
Some(self.regex.is_match(value?))
987+
}
988+
}
989+
990+
impl ScalarUDFImpl for MyRegexUdf {
991+
fn as_any(&self) -> &dyn Any {
992+
self
993+
}
994+
995+
fn name(&self) -> &str {
996+
"regex_udf"
997+
}
998+
999+
fn signature(&self) -> &Signature {
1000+
&self.signature
1001+
}
1002+
1003+
fn return_type(&self, args: &[DataType]) -> Result<DataType> {
1004+
if matches!(args, [DataType::Utf8]) {
1005+
Ok(DataType::Boolean)
1006+
} else {
1007+
plan_err!("regex_udf only accepts a Utf8 argument")
1008+
}
1009+
}
1010+
1011+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
1012+
match args {
1013+
[ColumnarValue::Scalar(ScalarValue::Utf8(value))] => {
1014+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
1015+
self.matches(value.as_deref()),
1016+
)))
1017+
}
1018+
[ColumnarValue::Array(values)] => {
1019+
let mut builder = BooleanBuilder::with_capacity(values.len());
1020+
for value in values.as_string::<i32>() {
1021+
builder.append_option(self.matches(value))
1022+
}
1023+
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
1024+
}
1025+
_ => exec_err!("regex_udf only accepts a Utf8 arguments"),
1026+
}
1027+
}
1028+
1029+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
1030+
if let Some(other) = other.as_any().downcast_ref::<MyRegexUdf>() {
1031+
self.regex.as_str() == other.regex.as_str()
1032+
} else {
1033+
false
1034+
}
1035+
}
1036+
1037+
fn hash_value(&self) -> u64 {
1038+
let hasher = &mut DefaultHasher::new();
1039+
self.regex.as_str().hash(hasher);
1040+
hasher.finish()
1041+
}
1042+
}
1043+
1044+
#[tokio::test]
1045+
async fn test_parameterized_scalar_udf() -> Result<()> {
1046+
let batch = RecordBatch::try_from_iter([(
1047+
"text",
1048+
Arc::new(StringArray::from(vec!["foo", "bar", "foobar", "barfoo"])) as ArrayRef,
1049+
)])?;
1050+
1051+
let ctx = SessionContext::new();
1052+
ctx.register_batch("t", batch)?;
1053+
let t = ctx.table("t").await?;
1054+
let foo_udf = ScalarUDF::from(MyRegexUdf::new("fo{2}"));
1055+
let bar_udf = ScalarUDF::from(MyRegexUdf::new("[Bb]ar"));
1056+
1057+
let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
1058+
.filter(
1059+
foo_udf
1060+
.call(vec![col("text")])
1061+
.and(bar_udf.call(vec![col("text")])),
1062+
)?
1063+
.filter(col("text").is_not_null())?
1064+
.build()?;
1065+
1066+
assert_eq!(
1067+
format!("{plan:?}"),
1068+
"Filter: t.text IS NOT NULL\n Filter: regex_udf(t.text) AND regex_udf(t.text)\n TableScan: t projection=[text]"
1069+
);
1070+
1071+
let actual = DataFrame::new(ctx.state(), plan).collect().await?;
1072+
let expected = [
1073+
"+--------+",
1074+
"| text |",
1075+
"+--------+",
1076+
"| foobar |",
1077+
"| barfoo |",
1078+
"+--------+",
1079+
];
1080+
assert_batches_eq!(expected, &actual);
1081+
1082+
ctx.deregister_table("t")?;
1083+
Ok(())
1084+
}
1085+
9641086
fn create_udf_context() -> SessionContext {
9651087
let ctx = SessionContext::new();
9661088
// register a custom UDF

0 commit comments

Comments
 (0)