diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 5cb7cec503a67..e1fb401e7b73f 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -54,6 +54,7 @@ cargo run --example csv_sql - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde - [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify and analyze `Expr`s - [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients +- [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros - [`make_date.rs`](examples/make_date.rs): Examples of using the make_date function - [`memtable.rs`](examples/memtable.rs): Create an query data in memory using SQL and `RecordBatch`es - [`pruning.rs`](examples/parquet_sql.rs): Use pruning to rule out files based on statistics diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs new file mode 100644 index 0000000000000..6c033e6c8eef6 --- /dev/null +++ b/datafusion-examples/examples/function_factory.rs @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::error::Result; +use datafusion::execution::config::SessionConfig; +use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionContext}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{exec_err, internal_err, DataFusionError}; +use datafusion_expr::simplify::ExprSimplifyResult; +use datafusion_expr::simplify::SimplifyInfo; +use datafusion_expr::{CreateFunction, Expr, ScalarUDF, ScalarUDFImpl, Signature}; +use std::result::Result as RResult; +use std::sync::Arc; + +/// This example shows how to utilize [FunctionFactory] to implement simple +/// SQL-macro like functions using a `CREATE FUNCTION` statement. The same +/// functionality can support functions defined in any language or library. +/// +/// Apart from [FunctionFactory], this example covers +/// [ScalarUDFImpl::simplify()] which is often used at the same time, to replace +/// a function call with another expression at runtime. +/// +/// This example is rather simple and does not cover all cases required for a +/// real implementation. 
+#[tokio::main] +async fn main() -> Result<()> { + // First we must configure the SessionContext with our function factory + let ctx = SessionContext::new() + // register custom function factory + .with_function_factory(Arc::new(CustomFunctionFactory::default())); + + // With the function factory, we can now call `CREATE FUNCTION` SQL functions + + // Let us register a function called f1 which takes a single argument and + // returns that value plus one + let sql = r#" + CREATE FUNCTION f1(BIGINT) + RETURNS BIGINT + RETURN $1 + 1 + "#; + + ctx.sql(sql).await?.show().await?; + + // Now, let us register a function called f2 which takes two arguments and + // returns the first argument added to the result of calling f1 on the + // second argument + let sql = r#" + CREATE FUNCTION f2(BIGINT, BIGINT) + RETURNS BIGINT + RETURN $1 + f1($2) + "#; + + ctx.sql(sql).await?.show().await?; + + // Invoke f2, and we expect to see 1 + (1 + 2) = 4 + // Note this function works on columns as well as constants. + let sql = r#" + SELECT f2(1, 2) + "#; + ctx.sql(sql).await?.show().await?; + + // Now we clean up the session by dropping the functions + ctx.sql("DROP FUNCTION f1").await?.show().await?; + ctx.sql("DROP FUNCTION f2").await?.show().await?; + + Ok(()) +} + +/// This is our FunctionFactory that is responsible for converting `CREATE +/// FUNCTION` statements into function instances +#[derive(Debug, Default)] +struct CustomFunctionFactory {} + +#[async_trait::async_trait] +impl FunctionFactory for CustomFunctionFactory { + /// This function takes the parsed `CREATE FUNCTION` statement and returns + /// the function instance. + async fn create( + &self, + _state: &SessionConfig, + statement: CreateFunction, + ) -> Result { + let f: ScalarFunctionWrapper = statement.try_into()?; + + Ok(RegisterFunction::Scalar(Arc::new(ScalarUDF::from(f)))) + } +} + +/// This struct represents the function created by a `CREATE FUNCTION` statement. 
+#[derive(Debug)] +struct ScalarFunctionWrapper { + /// Name of the function; `expr` holds its body, `$1 + f1($2)` in our example + name: String, + expr: Expr, + signature: Signature, + return_type: arrow_schema::DataType, +} + +impl ScalarUDFImpl for ScalarFunctionWrapper { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &datafusion_expr::Signature { + &self.signature + } + + fn return_type( + &self, + _arg_types: &[arrow_schema::DataType], + ) -> Result { + Ok(self.return_type.clone()) + } + + fn invoke( + &self, + _args: &[datafusion_expr::ColumnarValue], + ) -> Result { + // Since this function is always simplified to another expression, it + // should never actually be invoked + internal_err!("This function should not get invoked!") + } + + /// The simplify function is called to simplify a call such as `f2(1, 2)`. It + /// replaces the placeholders in `expr` with `args` and returns the result + fn simplify( + &self, + args: Vec, + _info: &dyn SimplifyInfo, + ) -> Result { + let replacement = Self::replacement(&self.expr, &args)?; + + Ok(ExprSimplifyResult::Simplified(replacement)) + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn monotonicity(&self) -> Result> { + Ok(None) + } +} + +impl ScalarFunctionWrapper { + // replaces placeholders such as $1 with actual arguments (args[0]) + fn replacement(expr: &Expr, args: &[Expr]) -> Result { + let result = expr.clone().transform(&|e| { + let r = match e { + Expr::Placeholder(placeholder) => { + let placeholder_position = + Self::parse_placeholder_identifier(&placeholder.id)?; + if placeholder_position < args.len() { + Transformed::yes(args[placeholder_position].clone()) + } else { + exec_err!( + "Function argument {} not provided, argument missing!", + placeholder.id + )? 
+ } + } + _ => Transformed::no(e), + }; + + Ok(r) + })?; + + Ok(result.data) + } + // Finds placeholder identifier such as `$X` format where X >= 1 + fn parse_placeholder_identifier(placeholder: &str) -> Result { + if let Some(value) = placeholder.strip_prefix('$') { + Ok(value.parse().map(|v: usize| v - 1).map_err(|e| { + DataFusionError::Execution(format!( + "Placeholder `{}` parsing error: {}!", + placeholder, e + )) + })?) + } else { + exec_err!("Placeholder should start with `$`!") + } + } +} + +/// This impl block creates a scalar function from +/// a parsed `CREATE FUNCTION` statement (`CreateFunction`) +impl TryFrom for ScalarFunctionWrapper { + type Error = DataFusionError; + + fn try_from(definition: CreateFunction) -> RResult { + Ok(Self { + name: definition.name, + expr: definition + .params + .return_ + .expect("Expression has to be defined!"), + return_type: definition + .return_type + .expect("Return type has to be defined!"), + signature: Signature::exact( + definition + .args + .unwrap_or_default() + .into_iter() + .map(|a| a.data_type) + .collect(), + definition + .params + .behavior + .unwrap_or(datafusion_expr::Volatility::Volatile), + ), + }) + } +} diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index e071c5c80e118..7b37e4914cf9d 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -349,6 +349,15 @@ impl SessionContext { self.session_start_time } + /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements + pub fn with_function_factory( + self, + function_factory: Arc, + ) -> Self { + self.state.write().set_function_factory(function_factory); + self + } + /// Registers the [`RecordBatch`] as the specified table name pub fn register_batch( &self, @@ -1659,6 +1668,11 @@ impl SessionState { self } + /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements + pub fn set_function_factory(&mut self, 
function_factory: Arc) { + self.function_factory = Some(function_factory); + } + /// Replace the extension [`SerializerRegistry`] pub fn with_serializer_registry( mut self, diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index d9b60134b3d9a..ca61c61db16af 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -17,28 +17,25 @@ use arrow::compute::kernels::numeric::add; use arrow_array::{ - Array, ArrayRef, ArrowNativeTypeOp, Float32Array, Float64Array, Int32Array, - RecordBatch, UInt8Array, + Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array, }; use arrow_schema::DataType::Float64; use arrow_schema::{DataType, Field, Schema}; use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; -use datafusion_common::cast::as_float64_array; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err, - plan_err, ExprSchema, Result, ScalarValue, + assert_batches_eq, assert_batches_sorted_eq, cast::as_float64_array, + cast::as_int32_array, not_impl_err, plan_err, ExprSchema, Result, ScalarValue, }; -use datafusion_expr::simplify::ExprSimplifyResult; -use datafusion_expr::simplify::SimplifyInfo; +use datafusion_common::{exec_err, internal_err, DataFusionError}; +use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; -use parking_lot::Mutex; - -use 
datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use rand::{thread_rng, Rng}; use std::any::Any; use std::iter; @@ -739,65 +736,156 @@ async fn verify_udf_return_type() -> Result<()> { Ok(()) } +// create_scalar_function_from_sql_statement helper +// structures and methods. + #[derive(Debug, Default)] -struct MockFunctionFactory { - pub captured_expr: Mutex>, -} +struct CustomFunctionFactory {} #[async_trait::async_trait] -impl FunctionFactory for MockFunctionFactory { - #[doc = r" Crates and registers a function from [CreateFunction] statement"] - #[must_use] - #[allow(clippy::type_complexity, clippy::type_repetition_in_bounds)] +impl FunctionFactory for CustomFunctionFactory { async fn create( &self, - _config: &SessionConfig, + _state: &SessionConfig, statement: CreateFunction, - ) -> datafusion::error::Result { - // In this example, we always create a function that adds its arguments - // with the name specified in `CREATE FUNCTION`. In a real implementation - // the body of the created UDF would also likely be a function of the contents - // of the `CreateFunction` - let mock_add = Arc::new(|args: &[datafusion_expr::ColumnarValue]| { - let args = datafusion_expr::ColumnarValue::values_to_arrays(args)?; - let base = - datafusion_common::cast::as_float64_array(&args[0]).expect("cast failed"); - let exponent = - datafusion_common::cast::as_float64_array(&args[1]).expect("cast failed"); - - let array = base - .iter() - .zip(exponent.iter()) - .map(|(base, exponent)| match (base, exponent) { - (Some(base), Some(exponent)) => Some(base.add_wrapping(exponent)), - _ => None, - }) - .collect::(); - Ok(datafusion_expr::ColumnarValue::from( - Arc::new(array) as arrow_array::ArrayRef - )) - }); - - let args = statement.args.unwrap(); - let mock_udf = create_udf( - &statement.name, - vec![args[0].data_type.clone(), args[1].data_type.clone()], - Arc::new(statement.return_type.unwrap()), - datafusion_expr::Volatility::Immutable, - mock_add, - ); - - // capture 
expression so we can verify - // it has been parsed - *self.captured_expr.lock() = statement.params.return_; - - Ok(RegisterFunction::Scalar(Arc::new(mock_udf))) + ) -> Result { + let f: ScalarFunctionWrapper = statement.try_into()?; + + Ok(RegisterFunction::Scalar(Arc::new(ScalarUDF::from(f)))) + } +} +// a wrapper type to be used to register +// custom function to datafusion context +// +// it also defines custom [ScalarUDFImpl::simplify()] +// to replace ScalarUDF expression with one instance contains. +#[derive(Debug)] +struct ScalarFunctionWrapper { + name: String, + expr: Expr, + signature: Signature, + return_type: arrow_schema::DataType, +} + +impl ScalarUDFImpl for ScalarFunctionWrapper { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &datafusion_expr::Signature { + &self.signature + } + + fn return_type( + &self, + _arg_types: &[arrow_schema::DataType], + ) -> Result { + Ok(self.return_type.clone()) + } + + fn invoke( + &self, + _args: &[datafusion_expr::ColumnarValue], + ) -> Result { + internal_err!("This function should not get invoked!") + } + + fn simplify( + &self, + args: Vec, + _info: &dyn SimplifyInfo, + ) -> Result { + let replacement = Self::replacement(&self.expr, &args)?; + + Ok(ExprSimplifyResult::Simplified(replacement)) + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn monotonicity(&self) -> Result> { + Ok(None) + } +} + +impl ScalarFunctionWrapper { + // replaces placeholders with actual arguments + fn replacement(expr: &Expr, args: &[Expr]) -> Result { + let result = expr.clone().transform(&|e| { + let r = match e { + Expr::Placeholder(placeholder) => { + let placeholder_position = + Self::parse_placeholder_identifier(&placeholder.id)?; + if placeholder_position < args.len() { + Transformed::yes(args[placeholder_position].clone()) + } else { + exec_err!( + "Function argument {} not provided, argument missing!", + placeholder.id + )? 
+ } + } + _ => Transformed::no(e), + }; + + Ok(r) + })?; + + Ok(result.data) + } + // Finds placeholder identifier. + // placeholders are in `$X` format where X >= 1 + fn parse_placeholder_identifier(placeholder: &str) -> Result { + if let Some(value) = placeholder.strip_prefix('$') { + Ok(value.parse().map(|v: usize| v - 1).map_err(|e| { + DataFusionError::Execution(format!( + "Placeholder `{}` parsing error: {}!", + placeholder, e + )) + })?) + } else { + exec_err!("Placeholder should start with `$`!") + } + } +} + +impl TryFrom for ScalarFunctionWrapper { + type Error = DataFusionError; + + fn try_from(definition: CreateFunction) -> std::result::Result { + Ok(Self { + name: definition.name, + expr: definition + .params + .return_ + .expect("Expression has to be defined!"), + return_type: definition + .return_type + .expect("Return type has to be defined!"), + signature: Signature::exact( + definition + .args + .unwrap_or_default() + .into_iter() + .map(|a| a.data_type) + .collect(), + definition + .params + .behavior + .unwrap_or(datafusion_expr::Volatility::Volatile), + ), + }) } } #[tokio::test] async fn create_scalar_function_from_sql_statement() -> Result<()> { - let function_factory = Arc::new(MockFunctionFactory::default()); + let function_factory = Arc::new(CustomFunctionFactory::default()); let runtime_config = RuntimeConfig::new(); let runtime_environment = RuntimeEnv::new(runtime_config)?; @@ -826,11 +914,22 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { .await .is_err()); - ctx.sql("select better_add(2.0, 2.0)").await?.show().await?; + let result = ctx + .sql("select better_add(2.0, 2.0)") + .await? 
+ .collect() + .await?; - // check if we sql expr has been converted to datafusion expr - let captured_expression = function_factory.captured_expr.lock().clone().unwrap(); - assert_eq!("$1 + $2", captured_expression.to_string()); + assert_batches_eq!( + &[ + "+-----------------------------------+", + "| better_add(Float64(2),Float64(2)) |", + "+-----------------------------------+", + "| 4.0 |", + "+-----------------------------------+", + ], + &result + ); // statement drops function assert!(ctx.sql("drop function better_add").await.is_ok());