Merged
Changes from 5 commits
170 changes: 170 additions & 0 deletions datafusion-examples/examples/return_types_udf.rs
@@ -0,0 +1,170 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::any::Any;

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};

use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{
internal_err, DataFusionError, ExprSchema, ScalarValue, ToDFSchema
};
use datafusion_expr::{
expr::ScalarFunction, ColumnarValue, ExprSchemable, ScalarUDF, ScalarUDFImpl,
Signature,
};

#[derive(Debug)]
struct UDFWithExprReturn {
signature: Signature,
}

impl UDFWithExprReturn {
fn new() -> Self {
Self {
signature: Signature::any(3, Volatility::Immutable),
}
}
}

// Implement the ScalarUDFImpl trait for UDFWithExprReturn
impl ScalarUDFImpl for UDFWithExprReturn {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"udf_with_expr_return"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Int32)
Review comment (Contributor): This is pretty confusing I think -- it seems inconsistent with `return_type_from_exprs`.
}
// An example of how to use the exprs to determine the return type
// If the third argument is '0', return the type of the first argument
// If the third argument is '1', return the type of the second argument
fn return_type_from_exprs(
&self,
arg_exprs: &[Expr],
schema: &dyn ExprSchema,
) -> Result<DataType> {
if arg_exprs.len() != 3 {
return internal_err!("The size of the args must be 3.");
}
let take_idx = match arg_exprs.get(2).unwrap() {
Expr::Literal(ScalarValue::Int64(Some(idx))) if (idx == &0 || idx == &1) => {
*idx as usize
}
_ => return internal_err!("The third argument must be the literal 0 or 1."),
};
arg_exprs.get(take_idx).unwrap().get_type(schema)
}
// The actual implementation is omitted here; this example only exercises return types
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
unimplemented!()
}
}

#[derive(Debug)]
struct UDFDefault {
signature: Signature,
}

impl UDFDefault {
fn new() -> Self {
Self {
signature: Signature::any(3, Volatility::Immutable),
}
}
}

// Implement the ScalarUDFImpl trait for UDFDefault
// This is the same as UDFWithExprReturn, except without return_type_from_exprs
impl ScalarUDFImpl for UDFDefault {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"udf_default"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Boolean)
}
// The actual implementation is omitted here; this example only exercises return types
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
unimplemented!()
}
}

#[tokio::main]
async fn main() -> Result<()> {
Review comment (Contributor):

I think this example is missing actually using the function in a query / dataframe. As @Weijun-H pointed out, the logic added to ScalarUDFImpl doesn't seem to be connected anywhere else 🤔

What I think the example needs to do is something like:

  1. Create a ScalarUDF
  2. Register the function with a SessionContext
  3. Run a query that uses that function (ideally both with SQL and dataframe APIs)

So for example, a good example function might be one that takes a string argument, select my_cast(<arg>, 'string'), and converts the argument based on the value of the string.

Then for example run queries like:

select my_cast(c1, 'i32'), arrow_typeof(my_cast(c1, 'i32')); -- returns value and DataType::Int32
select my_cast(c1, 'i64'), arrow_typeof(my_cast(c1, 'i64')); -- returns value and DataType::Int64

Does that make sense?

Review comment (Contributor Author):

Yeah, I realized that it is missing actually using the function.
I tried to make changes in https://github.com/apache/arrow-datafusion/blob/4d02cc0114908d4f805b2323f20751b1f6d9c2f4/datafusion/expr/src/expr_schema.rs#L106-L108 to replace return_type with return_type_from_exprs, but passing schema as a param is a problem because the type of schema is a generic type S. 🤔

// Create a new ScalarUDF from the implementation
let udf_with_expr_return = ScalarUDF::from(UDFWithExprReturn::new());

// Call 'return_type' to get the return type of the function
let ret = udf_with_expr_return.return_type(&[DataType::Int32])?;
assert_eq!(ret, DataType::Int32);

let schema = Schema::new(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float64, false),
])
.to_dfschema()?;

// Set the third argument to 0 to return the type of the first argument
let expr0 = udf_with_expr_return.call(vec![col("a"), col("b"), lit(0_i64)]);
let args = match expr0 {
Expr::ScalarFunction(ScalarFunction { func_def: _, args }) => args,
_ => panic!("Expected ScalarFunction"),
};
let ret = udf_with_expr_return.return_type_from_exprs(&args, &schema)?;
// The return type should be the same as the first argument
assert_eq!(ret, DataType::Float32);

// Set the third argument to 1 to return the type of the second argument
let expr1 = udf_with_expr_return.call(vec![col("a"), col("b"), lit(1_i64)]);
let args1 = match expr1 {
Expr::ScalarFunction(ScalarFunction { func_def: _, args }) => args,
_ => panic!("Expected ScalarFunction"),
};
let ret = udf_with_expr_return.return_type_from_exprs(&args1, &schema)?;
// The return type should be the same as the second argument
assert_eq!(ret, DataType::Float64);

// Create a new ScalarUDF from the implementation
let udf_default = ScalarUDF::from(UDFDefault::new());
// Call 'return_type' to get the return type of the function
let ret = udf_default.return_type(&[DataType::Int32])?;
assert_eq!(ret, DataType::Boolean);

// Set the third argument to 0 to return the type of the first argument
let expr2 = udf_default.call(vec![col("a"), col("b"), lit(0_i64)]);
let args = match expr2 {
Expr::ScalarFunction(ScalarFunction { func_def: _, args }) => args,
_ => panic!("Expected ScalarFunction"),
};
let ret = udf_default.return_type_from_exprs(&args, &schema)?;
assert_eq!(ret, DataType::Boolean);

Ok(())
}
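The "third argument selects the return type" dispatch that `return_type_from_exprs` implements above can be sketched in isolation. This is a minimal, self-contained sketch; `MyExpr` and `MyType` are hypothetical stand-ins for DataFusion's `Expr` and `DataType`, not the real API:

```rust
// Hypothetical stand-ins for DataFusion's DataType and Expr, used only to
// show the dispatch logic without depending on the datafusion crate.
#[derive(Debug, Clone, PartialEq)]
enum MyType {
    Float32,
    Float64,
}

#[derive(Debug)]
enum MyExpr {
    Column(MyType),  // a column whose type we can look up
    IntLiteral(i64), // a literal index argument
}

impl MyExpr {
    fn get_type(&self) -> Result<MyType, String> {
        match self {
            MyExpr::Column(t) => Ok(t.clone()),
            MyExpr::IntLiteral(_) => Err("literal has no column type".into()),
        }
    }
}

// If the third argument is the literal 0, return the type of the first
// argument; if it is the literal 1, return the type of the second argument.
fn return_type_from_exprs(args: &[MyExpr]) -> Result<MyType, String> {
    if args.len() != 3 {
        return Err("the function takes exactly 3 arguments".into());
    }
    let take_idx = match args[2] {
        MyExpr::IntLiteral(idx @ (0 | 1)) => idx as usize,
        _ => return Err("third argument must be the literal 0 or 1".into()),
    };
    args[take_idx].get_type()
}

fn main() {
    let args = [
        MyExpr::Column(MyType::Float32),
        MyExpr::Column(MyType::Float64),
        MyExpr::IntLiteral(1),
    ];
    // The literal 1 selects the type of the second argument.
    assert_eq!(return_type_from_exprs(&args), Ok(MyType::Float64));
    println!("ok");
}
```

The real implementation additionally needs a schema to resolve column types, which is exactly where the `Sized` issue discussed below comes from.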
8 changes: 4 additions & 4 deletions datafusion/expr/src/expr_schema.rs
@@ -37,7 +37,7 @@ use std::sync::Arc;
/// trait to allow expr to be typable with respect to a schema
pub trait ExprSchemable {
/// given a schema, return the type of the expr
fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType>;
fn get_type<S: ExprSchema + ?Sized>(&self, schema: &S) -> Result<DataType>;

/// given a schema, return the nullability of the expr
fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool>;
@@ -90,7 +90,7 @@ impl ExprSchemable for Expr {
/// expression refers to a column that does not exist in the
/// schema, or when the expression is incorrectly typed
/// (e.g. `[utf8] + [bool]`).
fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
fn get_type<S: ExprSchema + ?Sized>(&self, schema: &S) -> Result<DataType> {
match self {
Expr::Alias(Alias { expr, name, .. }) => match &**expr {
Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
@@ -136,7 +136,7 @@ impl ExprSchemable for Expr {
fun.return_type(&arg_data_types)
}
ScalarFunctionDefinition::UDF(fun) => {
Ok(fun.return_type(&arg_data_types)?)
fun.return_type_from_exprs(args, schema)
Review comment (Contributor Author, @yyy1000, Feb 9, 2024):

@alamb @brayanjuls It's here that it gets the return type of a ScalarUDF, and for a UDF it should be changed to this.
However, this can't work because schema needs to be Sized. [screenshot of the compiler error, 2024-02-08]
}
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
@@ -394,7 +394,7 @@ impl ExprSchemable for Expr {
}

/// return the schema [`Field`] for the type referenced by `get_indexed_field`
fn field_for_index<S: ExprSchema>(
fn field_for_index<S: ExprSchema + ?Sized>(
expr: &Expr,
field: &GetFieldAccess,
schema: &S,
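The `+ ?Sized` relaxation in the hunks above is what lets a `&dyn` trait object be passed where a generic `&S` parameter is expected. A minimal sketch of the mechanism, using a hypothetical `Schema` trait in place of DataFusion's `ExprSchema`:

```rust
trait Schema {
    fn field_type(&self, name: &str) -> Option<&'static str>;
}

struct ConcreteSchema;

impl Schema for ConcreteSchema {
    fn field_type(&self, _name: &str) -> Option<&'static str> {
        Some("Float32")
    }
}

// Without `?Sized`, `S` carries the implicit `S: Sized` bound, so this
// function accepts only concrete schema types, never `dyn Schema`.
fn get_type_sized<S: Schema>(schema: &S, name: &str) -> Option<&'static str> {
    schema.field_type(name)
}

// With `?Sized`, the same generic function also accepts the unsized
// `dyn Schema`, which is what a method taking `&dyn ExprSchema` needs
// when it calls back into `get_type`.
fn get_type_dyn<S: Schema + ?Sized>(schema: &S, name: &str) -> Option<&'static str> {
    schema.field_type(name)
}

fn main() {
    let concrete = ConcreteSchema;
    let dynamic: &dyn Schema = &concrete;

    assert_eq!(get_type_sized(&concrete, "a"), Some("Float32"));
    // `get_type_sized(dynamic, "a")` would not compile: `dyn Schema` is not `Sized`.
    assert_eq!(get_type_dyn(dynamic, "a"), Some("Float32"));
    println!("ok");
}
```

This is the compile error the author hits below ("the trait Sized is not implemented for dyn ExprSchema") whenever the `?Sized` bound is missing.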
30 changes: 29 additions & 1 deletion datafusion/expr/src/udf.rs
@@ -17,12 +17,13 @@

//! [`ScalarUDF`]: Scalar User Defined Functions

use crate::ExprSchemable;
use crate::{
ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction,
ScalarFunctionImplementation, Signature,
};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_common::{DFSchema, ExprSchema, Result};
use std::any::Any;
use std::fmt;
use std::fmt::Debug;
@@ -152,6 +153,17 @@ impl ScalarUDF {
self.inner.return_type(args)
}

/// The datatype this function returns given the input argument types.
/// This function is used when the input arguments are [`Expr`]s.
/// See [`ScalarUDFImpl::return_type_from_exprs`] for more details.
pub fn return_type_from_exprs(
&self,
args: &[Expr],
schema: &dyn ExprSchema,
) -> Result<DataType> {
self.inner.return_type_from_exprs(args, schema)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke`] for more details.
@@ -249,6 +261,22 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// the arguments
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;

/// What [`DataType`] will be returned by this function, given the types of
/// the expr arguments
fn return_type_from_exprs(
Review comment (Contributor, @universalmind303, Feb 9, 2024):
Since we need to use the ExprSchema here, but there are issues going from <S: ExprSchema> to &dyn ExprSchema inside the trait, what if we moved the default implementation out of the trait and into ScalarUDF?

we could change the trait impl to something like this

pub trait ScalarUDFImpl: Debug + Send + Sync {
    /// What [`DataType`] will be returned by this function, given the types of
    /// the expr arguments
    fn return_type_from_exprs(
        &self,
        arg_exprs: &[Expr],
        schema: &dyn ExprSchema,
    ) -> Option<Result<DataType>> {
        // The default implementation returns None
        // so that people don't have to implement `return_type_from_exprs` if they don't want to
        None
    }
}

then change the ScalarUDF impl to this

impl ScalarUDF
    /// The datatype this function returns given the input argument input types.
    /// This function is used when the input arguments are [`Expr`]s.
    /// See [`ScalarUDFImpl::return_type_from_exprs`] for more details.
    pub fn return_type_from_exprs<S: ExprSchema>(
        &self,
        args: &[Expr],
        schema: &S,
    ) -> Result<DataType> {
        // If the implementation provides a return_type_from_exprs, use it
        if let Some(return_type) = self.inner.return_type_from_exprs(args, schema) {
            return_type
        // Otherwise, use the return_type function
        } else {
            let arg_types = args
                .iter()
                .map(|arg| arg.get_type(schema))
                .collect::<Result<Vec<_>>>()?;
            self.return_type(&arg_types)
        }
    }
}

this way we don't need to constrain the ExprSchemable functions to ?Sized, and we can update the get_type function to use the return_type_from_exprs without any compile time errors.

ScalarFunctionDefinition::UDF(fun) => {
    Ok(fun.return_type_from_exprs(&args, schema)?)
}

and it still makes return_type_from_exprs an opt-in method.


It does make it very slightly less ergonomic, as end users now need to wrap their body in an Option, but overall I think it's a decent compromise.
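The opt-in pattern suggested in this comment (a default trait method returning `None`, with the wrapper falling back to the plain `return_type`) can be sketched self-contained. All names here are hypothetical stand-ins for the DataFusion types, and string type names stand in for `DataType`:

```rust
// Sketch of the opt-in fallback pattern: the trait's default hook returns
// None, and the wrapper (the ScalarUDF role) falls back to `return_type`.
trait UdfImpl {
    fn return_type(&self) -> String;

    // Opt-in hook: implementors that don't need the argument expressions
    // simply inherit this default.
    fn return_type_from_exprs(&self, _args: &[i64]) -> Option<String> {
        None
    }
}

struct DefaultUdf;
impl UdfImpl for DefaultUdf {
    fn return_type(&self) -> String {
        "Boolean".to_string()
    }
}

struct ExprAwareUdf;
impl UdfImpl for ExprAwareUdf {
    fn return_type(&self) -> String {
        "Int32".to_string()
    }
    // Opted in: here the return type depends on the arguments.
    fn return_type_from_exprs(&self, args: &[i64]) -> Option<String> {
        Some(format!("Int{}", args.len() * 32))
    }
}

// The wrapper tries the expr-aware hook first, then falls back.
fn resolve_return_type(udf: &dyn UdfImpl, args: &[i64]) -> String {
    udf.return_type_from_exprs(args)
        .unwrap_or_else(|| udf.return_type())
}

fn main() {
    assert_eq!(resolve_return_type(&DefaultUdf, &[1, 2]), "Boolean");
    assert_eq!(resolve_return_type(&ExprAwareUdf, &[1, 2]), "Int64");
    println!("ok");
}
```

The `Option` is the cost the comment mentions: every opted-in implementor must wrap its result in `Some`, but implementors that don't care write nothing at all.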

Review comment (Contributor Author):

This works well on my side. Thanks!

Another question for me is: how can a user implement return_type_from_exprs when using schema? For example, arg_exprs.get(take_idx).unwrap().get_type(schema) will lead to an error:

the size for values of type dyn ExprSchema cannot be known at compilation time
the trait Sized is not implemented for dyn ExprSchema

Review comment (Contributor):
Hmm yeah @yyy1000 you still run into the same error then.

I'm wondering if it'd be easiest to just change the type signature on ExprSchemable to be generic over the trait instead of the functions.

pub trait ExprSchemable<S: ExprSchema> {
    /// given a schema, return the type of the expr
    fn get_type(&self, schema: &S) -> Result<DataType>;

    /// given a schema, return the nullability of the expr
    fn nullable(&self, input_schema: &S) -> Result<bool>;

    /// given a schema, return the expr's optional metadata
    fn metadata(&self, schema: &S) -> Result<HashMap<String, String>>;

    /// convert to a field with respect to a schema
    fn to_field(&self, input_schema: &DFSchema) -> Result<DFField>;

    /// cast to a type with respect to a schema
    fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result<Expr>;
}

impl ExprSchemable<DFSchema> for Expr {
//... 
}

then the trait can just go back to the original implementation you had using &DFSchema

    fn return_type_from_exprs(
        &self,
        arg_exprs: &[Expr],
        schema: &DFSchema,
    ) -> Result<DataType> {
        let arg_types = arg_exprs
            .iter()
            .map(|e| e.get_type(schema))
            .collect::<Result<Vec<_>>>()?;
        self.return_type(&arg_types)
    }

I tried this locally, was able to get things to compile, and was able to implement a udf using the trait.

It does make it a little less flexible, as it's expecting a DFSchema, but I think that's ok?

I think the only other approach would be to make ScalarUDFImpl dynamic over <S: ExprSchema>, but I feel like that's much less ideal than just using a concrete type.

Review comment (Contributor Author):

Thanks for your help! @universalmind303
It looks good; I'll try it and see. :)

Review comment (Contributor):

Update: I changed the signature to take &dyn ExprSchema, which seems to have worked just fine:

&self,
arg_exprs: &[Expr],
schema: &dyn ExprSchema,
) -> Result<DataType> {
// provide default implementation that calls `self.return_type()`
// so that people don't have to implement `return_type_from_exprs` if they don't want to
let arg_types = arg_exprs
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
self.return_type(&arg_types)
}
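The approach that ultimately worked above, a default trait method taking `&dyn ExprSchema` and delegating to `return_type`, can be sketched without the datafusion crate. All names (`ExprSchemaLike`, `ColumnExpr`, `ScalarUdfLike`, `TwoColumnSchema`) are hypothetical stand-ins:

```rust
// A dyn-compatible schema trait, standing in for DataFusion's ExprSchema.
trait ExprSchemaLike {
    fn type_of(&self, column: &str) -> String;
}

struct TwoColumnSchema;
impl ExprSchemaLike for TwoColumnSchema {
    fn type_of(&self, column: &str) -> String {
        match column {
            "a" => "Float32".to_string(),
            _ => "Float64".to_string(),
        }
    }
}

struct ColumnExpr(&'static str);

impl ColumnExpr {
    // Mirrors `Expr::get_type(schema)`: taking `&dyn` keeps this callable
    // from inside a default trait method with no `Sized` trouble.
    fn get_type(&self, schema: &dyn ExprSchemaLike) -> String {
        schema.type_of(self.0)
    }
}

trait ScalarUdfLike {
    fn return_type(&self, arg_types: &[String]) -> String;

    // Default implementation: derive the argument types from the exprs and
    // delegate, so implementors override this only when they need to
    // inspect the expressions themselves.
    fn return_type_from_exprs(
        &self,
        args: &[ColumnExpr],
        schema: &dyn ExprSchemaLike,
    ) -> String {
        let arg_types: Vec<String> =
            args.iter().map(|e| e.get_type(schema)).collect();
        self.return_type(&arg_types)
    }
}

struct FirstArgUdf;
impl ScalarUdfLike for FirstArgUdf {
    // Returns the type of the first argument.
    fn return_type(&self, arg_types: &[String]) -> String {
        arg_types[0].clone()
    }
}

fn main() {
    let udf = FirstArgUdf;
    let args = [ColumnExpr("a"), ColumnExpr("b")];
    assert_eq!(udf.return_type_from_exprs(&args, &TwoColumnSchema), "Float32");
    println!("ok");
}
```

Compared with the earlier proposals, this keeps `return_type_from_exprs` opt-in (the default delegates) without the `Option` wrapper and without making the trait generic over a schema type.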

/// Invoke the function on `args`, returning the appropriate result
///
/// The function will be invoked passed with the slice of [`ColumnarValue`]