-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Support compute return types from argument values (not just their DataTypes) #8985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
4e06013
17a2c91
02f2284
0c9acdd
56b71ae
3dbc0c7
491a4a1
468b38f
5772d9f
59b3958
f195fba
21d495f
b2e8457
4efb395
a9546ee
040d319
93b72ee
e0add48
653577f
7993af8
a3b9648
2121770
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,7 +28,7 @@ use crate::{utils, LogicalPlan, Projection, Subquery}; | |
| use arrow::compute::can_cast_types; | ||
| use arrow::datatypes::{DataType, Field}; | ||
| use datafusion_common::{ | ||
| internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, | ||
| internal_err, plan_datafusion_err, plan_err, Column, DFField, | ||
| DataFusionError, ExprSchema, Result, | ||
| }; | ||
| use std::collections::HashMap; | ||
|
|
@@ -37,19 +37,19 @@ use std::sync::Arc; | |
| /// trait to allow expr to typable with respect to a schema | ||
| pub trait ExprSchemable { | ||
| /// given a schema, return the type of the expr | ||
| fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType>; | ||
| fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I had to change the traits to take `&dyn ExprSchema` instead of a generic `S: ExprSchema` parameter. I expect this to have zero performance impact, but I will run the planning benchmarks to be sure this is acceptable.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I ran the planning benchmarks, and the results looked good (within the noise threshold; the run reported 1% slower, which I don't attribute to this change). |
||
|
|
||
| /// given a schema, return the nullability of the expr | ||
| fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool>; | ||
| fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>; | ||
|
|
||
| /// given a schema, return the expr's optional metadata | ||
| fn metadata<S: ExprSchema>(&self, schema: &S) -> Result<HashMap<String, String>>; | ||
| fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>>; | ||
|
|
||
| /// convert to a field with respect to a schema | ||
| fn to_field(&self, input_schema: &DFSchema) -> Result<DFField>; | ||
| fn to_field(&self, input_schema: &dyn ExprSchema) -> Result<DFField>; | ||
|
|
||
| /// cast to a type with respect to a schema | ||
| fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr>; | ||
| fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>; | ||
| } | ||
|
|
||
| impl ExprSchemable for Expr { | ||
|
|
@@ -90,7 +90,7 @@ impl ExprSchemable for Expr { | |
| /// expression refers to a column that does not exist in the | ||
| /// schema, or when the expression is incorrectly typed | ||
| /// (e.g. `[utf8] + [bool]`). | ||
| fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> { | ||
| fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> { | ||
| match self { | ||
| Expr::Alias(Alias { expr, name, .. }) => match &**expr { | ||
| Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { | ||
|
|
@@ -136,7 +136,7 @@ impl ExprSchemable for Expr { | |
| fun.return_type(&arg_data_types) | ||
| } | ||
| ScalarFunctionDefinition::UDF(fun) => { | ||
| Ok(fun.return_type(&arg_data_types)?) | ||
| Ok(fun.return_type_from_exprs(args, schema)?) | ||
| } | ||
| ScalarFunctionDefinition::Name(_) => { | ||
| internal_err!("Function `Expr` with name should be resolved.") | ||
|
|
@@ -220,7 +220,7 @@ impl ExprSchemable for Expr { | |
| /// This function errors when it is not possible to compute its | ||
| /// nullability. This happens when the expression refers to a | ||
| /// column that does not exist in the schema. | ||
| fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> { | ||
| fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> { | ||
| match self { | ||
| Expr::Alias(Alias { expr, .. }) | ||
| | Expr::Not(expr) | ||
|
|
@@ -327,7 +327,7 @@ impl ExprSchemable for Expr { | |
| } | ||
| } | ||
|
|
||
| fn metadata<S: ExprSchema>(&self, schema: &S) -> Result<HashMap<String, String>> { | ||
| fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>> { | ||
| match self { | ||
| Expr::Column(c) => Ok(schema.metadata(c)?.clone()), | ||
| Expr::Alias(Alias { expr, .. }) => expr.metadata(schema), | ||
|
|
@@ -339,7 +339,7 @@ impl ExprSchemable for Expr { | |
| /// | ||
| /// So for example, a projected expression `col(c1) + col(c2)` is | ||
| /// placed in an output field **named** col("c1 + c2") | ||
| fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> { | ||
| fn to_field(&self, input_schema: &dyn ExprSchema) -> Result<DFField> { | ||
| match self { | ||
| Expr::Column(c) => Ok(DFField::new( | ||
| c.relation.clone(), | ||
|
|
@@ -370,7 +370,7 @@ impl ExprSchemable for Expr { | |
| /// | ||
| /// This function errors when it is impossible to cast the | ||
| /// expression to the target [arrow::datatypes::DataType]. | ||
| fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr> { | ||
| fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> { | ||
| let this_type = self.get_type(schema)?; | ||
| if this_type == *cast_to_type { | ||
| return Ok(self); | ||
|
|
@@ -394,10 +394,10 @@ impl ExprSchemable for Expr { | |
| } | ||
|
|
||
| /// return the schema [`Field`] for the type referenced by `get_indexed_field` | ||
| fn field_for_index<S: ExprSchema>( | ||
| fn field_for_index( | ||
| expr: &Expr, | ||
| field: &GetFieldAccess, | ||
| schema: &S, | ||
| schema: &dyn ExprSchema, | ||
| ) -> Result<Field> { | ||
| let expr_dt = expr.get_type(schema)?; | ||
| match field { | ||
|
|
@@ -457,7 +457,7 @@ mod tests { | |
| use super::*; | ||
| use crate::{col, lit}; | ||
| use arrow::datatypes::{DataType, Fields}; | ||
| use datafusion_common::{Column, ScalarValue, TableReference}; | ||
| use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; | ||
|
|
||
| macro_rules! test_is_expr_nullable { | ||
| ($EXPR_TYPE:ident) => {{ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,12 +17,13 @@ | |
|
|
||
| //! [`ScalarUDF`]: Scalar User Defined Functions | ||
|
|
||
| use crate::ExprSchemable; | ||
| use crate::{ | ||
| ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction, | ||
| ScalarFunctionImplementation, Signature, | ||
| }; | ||
| use arrow::datatypes::DataType; | ||
| use datafusion_common::Result; | ||
| use datafusion_common::{ExprSchema, Result}; | ||
| use std::any::Any; | ||
| use std::fmt; | ||
| use std::fmt::Debug; | ||
|
|
@@ -110,7 +111,7 @@ impl ScalarUDF { | |
| /// | ||
| /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly. | ||
| pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self { | ||
| Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases)) | ||
| Self::new_from_impl(AliasedScalarUDFImpl::new(self.inner.clone(), aliases)) | ||
| } | ||
|
|
||
| /// Returns a [`Expr`] logical expression to call this UDF with specified | ||
|
|
@@ -146,10 +147,17 @@ impl ScalarUDF { | |
| } | ||
|
|
||
| /// The datatype this function returns given the input argument input types. | ||
| /// This function is used when the input arguments are [`Expr`]s. | ||
| /// | ||
| /// See [`ScalarUDFImpl::return_type`] for more details. | ||
| pub fn return_type(&self, args: &[DataType]) -> Result<DataType> { | ||
| self.inner.return_type(args) | ||
| /// | ||
| /// See [`ScalarUDFImpl::return_type_from_exprs`] for more details. | ||
alamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| pub fn return_type_from_exprs( | ||
| &self, | ||
| args: &[Expr], | ||
| schema: &dyn ExprSchema, | ||
| ) -> Result<DataType> { | ||
| // If the implementation provides a return_type_from_exprs, use it | ||
| self.inner.return_type_from_exprs(args, schema) | ||
| } | ||
|
|
||
| /// Invoke the function on `args`, returning the appropriate result. | ||
|
|
@@ -249,6 +257,43 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { | |
| /// the arguments | ||
| fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>; | ||
|
|
||
| /// What [`DataType`] will be returned by this function, given the | ||
| /// arguments? | ||
| /// | ||
| /// Note most UDFs should implement [`Self::return_type`] and not this | ||
| /// function. The output type for most functions only depends on the types | ||
| /// of their inputs (e.g. `sqrt(f32)` is always `f32`). | ||
| /// | ||
| /// By default, this function calls [`Self::return_type`] with the | ||
| /// types of each argument. | ||
| /// | ||
| /// This method can be overridden for functions that return different | ||
| /// *types* based on the *values* of their arguments. | ||
| /// | ||
| /// For example, the following two function calls get the same argument | ||
| /// types (something and a `Utf8` string) but return different types based | ||
| /// on the value of the second argument: | ||
| /// | ||
| /// * `arrow_cast(x, 'Int16')` --> `Int16` | ||
| /// * `arrow_cast(x, 'Float32')` --> `Float32` | ||
| /// | ||
| /// # Notes: | ||
| /// | ||
| /// This function must consistently return the same type for the same | ||
| /// logical input even if the input is simplified (e.g. it must return the same | ||
| /// value for `('foo' | 'bar')` as it does for ('foobar'). | ||
|
Comment on lines
+268
to
+292
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe add some documentation about what would happen if a user tries to implement both `return_type` and `return_type_from_exprs`, and what the suggested implementation for `return_type` is in that case.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Excellent idea - added in 653577f |
||
| fn return_type_from_exprs( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we need to use the we could change the trait impl to something like this pub trait ScalarUDFImpl: Debug + Send + Sync {
/// What [`DataType`] will be returned by this function, given the types of
/// the expr arguments
fn return_type_from_exprs(
&self,
arg_exprs: &[Expr],
schema: &dyn ExprSchema,
) -> Option<Result<DataType>> {
// The default implementation returns None
// so that people don't have to implement `return_type_from_exprs` if they dont want to
None
}
}
then change the `ScalarUDF` impl:
impl ScalarUDF {
/// The datatype this function returns given the input argument input types.
/// This function is used when the input arguments are [`Expr`]s.
/// See [`ScalarUDFImpl::return_type_from_exprs`] for more details.
pub fn return_type_from_exprs<S: ExprSchema>(
&self,
args: &[Expr],
schema: &S,
) -> Result<DataType> {
// If the implementation provides a return_type_from_exprs, use it
if let Some(return_type) = self.inner.return_type_from_exprs(args, schema) {
return_type
// Otherwise, use the return_type function
} else {
let arg_types = args
.iter()
.map(|arg| arg.get_type(schema))
.collect::<Result<Vec<_>>>()?;
self.return_type(&arg_types)
}
}
}this way we don't need to constrain the ScalarFunctionDefinition::UDF(fun) => {
Ok(fun.return_type_from_exprs(&args, schema)?)
}and it still makes It does make it very slightly less ergonomic as end users now need to wrap their body in an
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This works well on my side. Thanks! Another question for me is, how can user implement
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm yeah @yyy1000, you still run into the same error then. I'm wondering if it'd be easiest to just change the type signature on `ExprSchemable`:
pub trait ExprSchemable<S: ExprSchema> {
/// given a schema, return the type of the expr
fn get_type(&self, schema: &S) -> Result<DataType>;
/// given a schema, return the nullability of the expr
fn nullable(&self, input_schema: &S) -> Result<bool>;
/// given a schema, return the expr's optional metadata
fn metadata(&self, schema: &S) -> Result<HashMap<String, String>>;
/// convert to a field with respect to a schema
fn to_field(&self, input_schema: &DFSchema) -> Result<DFField>;
/// cast to a type with respect to a schema
fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result<Expr>;
}
impl ExprSchemable<DFSchema> for Expr {
//...
}
then the trait can just go back to the original implementation you had, using `&DFSchema`:
fn return_type_from_exprs(
&self,
arg_exprs: &[Expr],
schema: &DFSchema,
) -> Result<DataType> {
let arg_types = arg_exprs
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
self.return_type(&arg_types)
}
I tried this locally and was able to get things to compile, and was able to implement a UDF using the trait. It does make it a little less flexible, as it's expecting a `DFSchema`. I think the only other approach would be to make …
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for your help! @universalmind303
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Update: I changed the signature to take `&dyn ExprSchema`. |
||
| &self, | ||
| args: &[Expr], | ||
| schema: &dyn ExprSchema, | ||
| ) -> Result<DataType> { | ||
| let arg_types = args | ||
| .iter() | ||
| .map(|arg| arg.get_type(schema)) | ||
| .collect::<Result<Vec<_>>>()?; | ||
| self.return_type(&arg_types) | ||
| } | ||
|
|
||
| /// Invoke the function on `args`, returning the appropriate result | ||
| /// | ||
| /// The function will be invoked passed with the slice of [`ColumnarValue`] | ||
|
|
@@ -290,13 +335,13 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { | |
| /// implement [`ScalarUDFImpl`], which supports aliases, directly if possible. | ||
| #[derive(Debug)] | ||
| struct AliasedScalarUDFImpl { | ||
| inner: ScalarUDF, | ||
| inner: Arc<dyn ScalarUDFImpl>, | ||
| aliases: Vec<String>, | ||
| } | ||
|
|
||
| impl AliasedScalarUDFImpl { | ||
| pub fn new( | ||
| inner: ScalarUDF, | ||
| inner: Arc<dyn ScalarUDFImpl>, | ||
| new_aliases: impl IntoIterator<Item = &'static str>, | ||
| ) -> Self { | ||
| let mut aliases = inner.aliases().to_vec(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is an example of the feature working