-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Migrate arrow_cast to a UDF
#9610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
b9aa40b
0c63b47
040f5c2
33cc854
01d1f6b
769ff55
56c337b
6025b0a
9cb7ff1
182e1da
5b8fc25
1af869f
0c7b7be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,63 +15,123 @@ | |
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| //! Implementation of the `arrow_cast` function that allows | ||
| //! casting to arbitrary arrow types (rather than SQL types) | ||
| //! [`ArrowCastFunc`]: Implementation of the `arrow_cast` | ||
|
|
||
| use std::any::Any; | ||
| use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; | ||
|
|
||
| use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; | ||
| use datafusion_common::{ | ||
| plan_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, | ||
| internal_err, plan_datafusion_err, plan_err, DataFusionError, ExprSchema, Result, | ||
| ScalarValue, | ||
| }; | ||
|
|
||
| use datafusion_common::plan_err; | ||
| use datafusion_expr::{Expr, ExprSchemable}; | ||
| use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; | ||
| use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; | ||
|
|
||
| pub const ARROW_CAST_NAME: &str = "arrow_cast"; | ||
|
|
||
| /// Create an [`Expr`] that evaluates the `arrow_cast` function | ||
| /// Implements casting to arbitrary arrow types (rather than SQL types) | ||
| /// | ||
| /// Note that the `arrow_cast` function is somewhat special in that its | ||
| /// return depends only on the *value* of its second argument (not its type) | ||
| /// | ||
| /// This function is not a [`BuiltinScalarFunction`] because the | ||
| /// return type of [`BuiltinScalarFunction`] depends only on the | ||
| /// *types* of the arguments. However, the type of `arrow_type` depends on | ||
| /// the *value* of its second argument. | ||
| /// It is implemented by calling the same underlying arrow `cast` kernel as | ||
| /// normal SQL casts. | ||
| /// | ||
| /// Use the `cast` function to cast to SQL type (which is then mapped | ||
| /// to the corresponding arrow type). For example to cast to `int` | ||
| /// (which is then mapped to the arrow type `Int32`) | ||
| /// For example to cast to `int` using SQL (which is then mapped to the arrow | ||
| /// type `Int32`) | ||
| /// | ||
| /// ```sql | ||
| /// select cast(column_x as int) ... | ||
| /// ``` | ||
| /// | ||
| /// Use the `arrow_cast` functiont to cast to a specfic arrow type | ||
| /// You can use the `arrow_cast` function to cast to a specific arrow type | ||
| /// | ||
| /// For example | ||
| /// ```sql | ||
| /// select arrow_cast(column_x, 'Float64') | ||
| /// ``` | ||
| /// [`BuiltinScalarFunction`]: datafusion_expr::BuiltinScalarFunction | ||
| pub fn create_arrow_cast(mut args: Vec<Expr>, schema: &DFSchema) -> Result<Expr> { | ||
| #[derive(Debug)] | ||
| pub(super) struct ArrowCastFunc { | ||
| signature: Signature, | ||
| } | ||
|
|
||
| impl ArrowCastFunc { | ||
| pub fn new() -> Self { | ||
| Self { | ||
| signature: Signature::any(2, Volatility::Immutable), | ||
| } | ||
| } | ||
| } | ||
|
|
||
| impl ScalarUDFImpl for ArrowCastFunc { | ||
| fn as_any(&self) -> &dyn Any { | ||
| self | ||
| } | ||
|
|
||
| fn name(&self) -> &str { | ||
| "arrow_cast" | ||
| } | ||
|
|
||
| fn signature(&self) -> &Signature { | ||
| &self.signature | ||
| } | ||
|
|
||
| fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { | ||
| parse_data_type(&arg_types[1].to_string()) | ||
|
||
| } | ||
|
|
||
| fn return_type_from_exprs( | ||
| &self, | ||
| args: &[Expr], | ||
| _schema: &dyn ExprSchema, | ||
| _arg_types: &[DataType], | ||
| ) -> Result<DataType> { | ||
| data_type_from_args(args) | ||
| } | ||
|
|
||
| fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
| internal_err!("arrow_cast should have been simplified to cast") | ||
| } | ||
|
|
||
| fn simplify( | ||
| &self, | ||
| mut args: Vec<Expr>, | ||
| info: &dyn SimplifyInfo, | ||
| ) -> Result<ExprSimplifyResult> { | ||
| // convert this into a real cast | ||
| let target_type = data_type_from_args(&args)?; | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This simplify logic mirrors the previous behavior in that `arrow_cast` is replaced with a normal cast. |
||
| // remove second (type) argument | ||
| args.pop().unwrap(); | ||
| let arg = args.pop().unwrap(); | ||
|
|
||
| let source_type = info.get_data_type(&arg)?; | ||
| let new_expr = if source_type == target_type { | ||
| // the argument's data type is already the correct type | ||
| arg | ||
| } else { | ||
| // Use an actual cast to get the correct type | ||
| Expr::Cast(datafusion_expr::Cast { | ||
| expr: Box::new(arg), | ||
| data_type: target_type, | ||
| }) | ||
| }; | ||
| // return the newly written argument to DataFusion | ||
| Ok(ExprSimplifyResult::Simplified(new_expr)) | ||
| } | ||
| } | ||
|
|
||
| /// Returns the requested type from the arguments | ||
| fn data_type_from_args(args: &[Expr]) -> Result<DataType> { | ||
| if args.len() != 2 { | ||
| return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); | ||
| } | ||
| let arg1 = args.pop().unwrap(); | ||
| let arg0 = args.pop().unwrap(); | ||
|
|
||
| // arg1 must be a string | ||
| let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 { | ||
| v | ||
| } else { | ||
| let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else { | ||
| return plan_err!( | ||
| "arrow_cast requires its second argument to be a constant string, got {arg1}" | ||
| "arrow_cast requires its second argument to be a constant string, got {:?}", | ||
| &args[1] | ||
| ); | ||
| }; | ||
|
|
||
| // do the actual lookup to the appropriate data type | ||
| let data_type = parse_data_type(&data_type_string)?; | ||
|
|
||
| arg0.cast_to(&data_type, schema) | ||
| parse_data_type(val) | ||
| } | ||
|
|
||
| /// Parses `str` into a `DataType`. | ||
|
|
@@ -80,22 +140,8 @@ pub fn create_arrow_cast(mut args: Vec<Expr>, schema: &DFSchema) -> Result<Expr> | |
| /// impl, and maintains the invariant that | ||
| /// `parse_data_type(data_type.to_string()) == data_type` | ||
| /// | ||
| /// Example: | ||
| /// ``` | ||
| /// # use datafusion_sql::parse_data_type; | ||
| /// # use arrow_schema::DataType; | ||
| /// let display_value = "Int32"; | ||
| /// | ||
| /// // "Int32" is the Display value of `DataType` | ||
| /// assert_eq!(display_value, &format!("{}", DataType::Int32)); | ||
| /// | ||
| /// // parse_data_type coverts "Int32" back to `DataType`: | ||
| /// let data_type = parse_data_type(display_value).unwrap(); | ||
| /// assert_eq!(data_type, DataType::Int32); | ||
| /// ``` | ||
| /// | ||
| /// Remove if added to arrow: <https://github.com/apache/arrow-rs/issues/3821> | ||
| pub fn parse_data_type(val: &str) -> Result<DataType> { | ||
| fn parse_data_type(val: &str) -> Result<DataType> { | ||
| Parser::new(val).parse() | ||
| } | ||
|
|
||
|
|
@@ -844,7 +890,6 @@ mod test { | |
| assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'")); | ||
| } | ||
| } | ||
| println!(" Ok"); | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These differences are due to the fact that `arrow_cast` is just a normal function now rather than a special case in the parser. Thus the naming reflects normal function naming.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting, I got stuck implementing the simplify function because I thought it should convert `arrow_cast(t.values,Utf8(\"Utf8\"))` to `t.values`, and other similar cases as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah -- this is pretty tricky. `arrow_cast` was quite special in the parser, so now that it is handled like a normal function it has the same (somewhat strange) function effect of column naming.