-
Notifications
You must be signed in to change notification settings - Fork 1.9k
feat(spark): implement Spark try_parse_url function
#17485
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
e995629
8869878
54e0125
944d204
bf39835
87405e2
532bd38
82f5a9b
37dc796
36ead8b
a07f7eb
80e7259
b58b4c1
3cc0cbe
47d9a21
1e40033
2e0cc6a
ae48283
6904c7d
60488a6
8626fdb
3c97788
80e2dbc
e70efa3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,20 +49,7 @@ impl ParseUrl { | |
| pub fn new() -> Self { | ||
| Self { | ||
| signature: Signature::one_of( | ||
| vec![ | ||
| TypeSignature::Uniform( | ||
| 1, | ||
| vec![DataType::Utf8View, DataType::Utf8, DataType::LargeUtf8], | ||
| ), | ||
| TypeSignature::Uniform( | ||
| 2, | ||
| vec![DataType::Utf8View, DataType::Utf8, DataType::LargeUtf8], | ||
| ), | ||
| TypeSignature::Uniform( | ||
| 3, | ||
| vec![DataType::Utf8View, DataType::Utf8, DataType::LargeUtf8], | ||
| ), | ||
| ], | ||
| vec![TypeSignature::String(2), TypeSignature::String(3)], | ||
| Volatility::Immutable, | ||
| ), | ||
| } | ||
|
|
@@ -99,7 +86,11 @@ impl ParseUrl { | |
| .map_err(|e| exec_datafusion_err!("{e:?}")) | ||
| .map(|url| match part { | ||
| "HOST" => url.host_str().map(String::from), | ||
| "PATH" => Some(url.path().to_string()), | ||
| "PATH" => { | ||
| let path: String = url.path().to_string(); | ||
| let path: String = if path == "/" { "".to_string() } else { path }; | ||
| Some(path) | ||
| } | ||
| "QUERY" => match key { | ||
| None => url.query().map(String::from), | ||
| Some(key) => url | ||
|
|
@@ -116,7 +107,13 @@ impl ParseUrl { | |
| None => Some(path.to_string()), | ||
| } | ||
| } | ||
| "AUTHORITY" => Some(url.authority().to_string()), | ||
| "AUTHORITY" => { | ||
| let authority: String = url.authority().to_string(); | ||
| match (url.port(), url.port_or_known_default()) { | ||
| (None, Some(port)) => Some(format!("{authority}:{port}")), | ||
| _ => Some(authority), | ||
| } | ||
| } | ||
| "USERINFO" => { | ||
| let username = url.username(); | ||
| if username.is_empty() { | ||
|
|
@@ -154,7 +151,7 @@ impl ScalarUDFImpl for ParseUrl { | |
| ); | ||
| } | ||
| match arg_types.len() { | ||
| 2 | 3 => { | ||
| 2 | 3 if arg_types.iter().all(is_string_type) => { | ||
| if arg_types | ||
| .iter() | ||
| .any(|arg| matches!(arg, DataType::LargeUtf8)) | ||
|
|
@@ -169,6 +166,11 @@ impl ScalarUDFImpl for ParseUrl { | |
| Ok(DataType::Utf8) | ||
| } | ||
| } | ||
| 2 | 3 => plan_err!( | ||
| "`{}` expects STRING arguments, got {:?}", | ||
| &self.name(), | ||
| arg_types | ||
| ), | ||
| _ => plan_err!( | ||
| "`{}` expects 2 or 3 arguments, got {}", | ||
| &self.name(), | ||
|
|
@@ -183,6 +185,13 @@ impl ScalarUDFImpl for ParseUrl { | |
| } | ||
| } | ||
|
|
||
| fn is_string_type(dt: &DataType) -> bool { | ||
| matches!( | ||
| dt, | ||
| DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 | ||
| ) | ||
| } | ||
|
|
||
| /// Core implementation of URL parsing function. | ||
| /// | ||
| /// # Arguments | ||
|
|
@@ -200,6 +209,13 @@ impl ScalarUDFImpl for ParseUrl { | |
| /// - The output array type (StringArray or LargeStringArray) is determined by input types | ||
| /// | ||
| fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| spark_handled_parse_url(args, |x| x) | ||
| } | ||
|
|
||
| pub fn spark_handled_parse_url( | ||
| args: &[ArrayRef], | ||
| handler_err: impl Fn(Result<Option<String>>) -> Result<Option<String>>, | ||
| ) -> Result<ArrayRef> { | ||
| if args.len() < 2 || args.len() > 3 { | ||
| return exec_err!( | ||
| "{} expects 2 or 3 arguments, but got {}", | ||
|
|
@@ -212,6 +228,7 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> { | |
| let part = &args[1]; | ||
|
|
||
| let result = if args.len() == 3 { | ||
| // In this case, the 'key' argument is passed | ||
| let key = &args[2]; | ||
|
|
||
| match (url.data_type(), part.data_type(), key.data_type()) { | ||
|
|
@@ -220,20 +237,23 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> { | |
| as_string_array(url)?, | ||
| as_string_array(part)?, | ||
| as_string_array(key)?, | ||
| handler_err, | ||
| ) | ||
| } | ||
| (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => { | ||
| process_parse_url::<_, _, _, StringViewArray>( | ||
| as_string_view_array(url)?, | ||
| as_string_view_array(part)?, | ||
| as_string_view_array(key)?, | ||
| handler_err, | ||
| ) | ||
| } | ||
| (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => { | ||
| process_parse_url::<_, _, _, LargeStringArray>( | ||
| as_large_string_array(url)?, | ||
| as_large_string_array(part)?, | ||
| as_large_string_array(key)?, | ||
| handler_err, | ||
| ) | ||
| } | ||
| _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args), | ||
|
|
@@ -253,20 +273,23 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> { | |
| as_string_array(url)?, | ||
| as_string_array(part)?, | ||
| &key, | ||
| handler_err, | ||
| ) | ||
| } | ||
| (DataType::Utf8View, DataType::Utf8View) => { | ||
| process_parse_url::<_, _, _, StringViewArray>( | ||
| as_string_view_array(url)?, | ||
| as_string_view_array(part)?, | ||
| &key, | ||
| handler_err, | ||
| ) | ||
| } | ||
| (DataType::LargeUtf8, DataType::LargeUtf8) => { | ||
| process_parse_url::<_, _, _, LargeStringArray>( | ||
| as_large_string_array(url)?, | ||
| as_large_string_array(part)?, | ||
| &key, | ||
| handler_err, | ||
| ) | ||
| } | ||
| _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args), | ||
|
|
@@ -279,6 +302,7 @@ fn process_parse_url<'a, A, B, C, T>( | |
| url_array: &'a A, | ||
| part_array: &'a B, | ||
| key_array: &'a C, | ||
| handle: impl Fn(Result<Option<String>>) -> Result<Option<String>>, | ||
| ) -> Result<ArrayRef> | ||
| where | ||
| &'a A: StringArrayType<'a>, | ||
|
|
@@ -292,11 +316,190 @@ where | |
| .zip(key_array.iter()) | ||
| .map(|((url, part), key)| { | ||
| if let (Some(url), Some(part), key) = (url, part, key) { | ||
| ParseUrl::parse(url, part, key) | ||
| handle(ParseUrl::parse(url, part, key)) | ||
| } else { | ||
| Ok(None) | ||
| } | ||
| }) | ||
| .collect::<Result<T>>() | ||
| .map(|array| Arc::new(array) as ArrayRef) | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
| use arrow::array::{ArrayRef, Int32Array, StringArray}; | ||
| use arrow::datatypes::DataType; | ||
| use datafusion_common::Result; | ||
| use std::array::from_ref; | ||
| use std::sync::Arc; | ||
|
|
||
| fn sa(vals: &[Option<&str>]) -> ArrayRef { | ||
| Arc::new(StringArray::from(vals.to_vec())) as ArrayRef | ||
| } | ||
|
|
||
| #[test] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel quite a few of these tests could be done as slt tests; @alamb do we have a preference on where tests should be done? Should we prefer slt over rust tests, and fallback only to rust if it is something that slt can't handle? Took a look at https://datafusion.apache.org/contributor-guide/testing.html but it doesn't mention if we have a specific preference, other than slt's being easier to maintain.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I personally prefer slt tests. But i agree we don't have clear guidance |
||
| fn test_parse_host() -> Result<()> { | ||
| let got = ParseUrl::parse("https://example.com/a?x=1", "HOST", None)?; | ||
| assert_eq!(got, Some("example.com".to_string())); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_parse_query_no_key_vs_with_key() -> Result<()> { | ||
| let got_all = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", None)?; | ||
| assert_eq!(got_all, Some("a=1&b=2".to_string())); | ||
|
|
||
| let got_a = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("a"))?; | ||
| assert_eq!(got_a, Some("1".to_string())); | ||
|
|
||
| let got_c = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("c"))?; | ||
| assert_eq!(got_c, None); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_parse_ref_protocol_userinfo_file_authority() -> Result<()> { | ||
| let url = "ftp://user:[email protected]:21/files?x=1#frag"; | ||
| assert_eq!(ParseUrl::parse(url, "REF", None)?, Some("frag".to_string())); | ||
| assert_eq!( | ||
| ParseUrl::parse(url, "PROTOCOL", None)?, | ||
| Some("ftp".to_string()) | ||
| ); | ||
| assert_eq!( | ||
| ParseUrl::parse(url, "USERINFO", None)?, | ||
| Some("user:pwd".to_string()) | ||
| ); | ||
| assert_eq!( | ||
| ParseUrl::parse(url, "FILE", None)?, | ||
| Some("/files?x=1".to_string()) | ||
| ); | ||
| assert_eq!( | ||
| ParseUrl::parse(url, "AUTHORITY", None)?, | ||
| Some("user:[email protected]:21".to_string()) | ||
| ); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_parse_path_root_is_empty_string() -> Result<()> { | ||
| let got = ParseUrl::parse("https://example.com/", "PATH", None)?; | ||
| assert_eq!(got, Some("".to_string())); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_parse_malformed_url_returns_error() { | ||
| let err = ParseUrl::parse("notaurl", "HOST", None).unwrap_err(); | ||
| let msg = format!("{err}"); | ||
| assert!( | ||
| msg.contains("DataFusion") || msg.contains("error"), | ||
| "msg was: {msg}" | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_spark_utf8_two_args() -> Result<()> { | ||
| let urls = sa(&[Some("https://example.com/a?x=1"), Some("https://ex.com/")]); | ||
| let parts = sa(&[Some("HOST"), Some("PATH")]); | ||
|
|
||
| let out = spark_handled_parse_url(&[urls, parts], |x| x)?; | ||
| let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap(); | ||
|
|
||
| assert_eq!(out_sa.len(), 2); | ||
| assert_eq!(out_sa.value(0), "example.com"); | ||
| assert_eq!(out_sa.value(1), ""); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_spark_utf8_three_args_query_key() -> Result<()> { | ||
| let urls = sa(&[ | ||
| Some("https://example.com/a?x=1&y=2"), | ||
| Some("https://ex.com/?a=1"), | ||
| ]); | ||
| let parts = sa(&[Some("QUERY"), Some("QUERY")]); | ||
| let keys = sa(&[Some("y"), Some("b")]); | ||
|
|
||
| let out = spark_handled_parse_url(&[urls, parts, keys], |x| x)?; | ||
| let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap(); | ||
|
|
||
| assert_eq!(out_sa.len(), 2); | ||
| assert_eq!(out_sa.value(0), "2"); | ||
| assert!(out_sa.is_null(1)); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_spark_userinfo_and_nulls() -> Result<()> { | ||
| let urls = sa(&[ | ||
| Some("ftp://user:[email protected]:21/files"), | ||
| Some("https://example.com"), | ||
| None, | ||
| ]); | ||
| let parts = sa(&[Some("USERINFO"), Some("USERINFO"), Some("USERINFO")]); | ||
|
|
||
| let out = spark_handled_parse_url(&[urls, parts], |x| x)?; | ||
| let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap(); | ||
|
|
||
| assert_eq!(out_sa.len(), 3); | ||
| assert_eq!(out_sa.value(0), "user:pwd"); | ||
| assert!(out_sa.is_null(1)); | ||
| assert!(out_sa.is_null(2)); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_invalid_arg_count() { | ||
| let urls = sa(&[Some("https://example.com")]); | ||
| let err = spark_handled_parse_url(from_ref(&urls), |x| x).unwrap_err(); | ||
| assert!(format!("{err}").contains("expects 2 or 3 arguments")); | ||
|
|
||
| let parts = sa(&[Some("HOST")]); | ||
| let keys = sa(&[Some("x")]); | ||
| let err = | ||
| spark_handled_parse_url(&[urls, parts, keys, sa(&[Some("extra")])], |x| x) | ||
| .unwrap_err(); | ||
| assert!(format!("{err}").contains("expects 2 or 3 arguments")); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_non_string_types_error() { | ||
| let urls = sa(&[Some("https://example.com")]); | ||
| let bad_part = Arc::new(Int32Array::from(vec![1])) as ArrayRef; | ||
|
|
||
| let err = spark_handled_parse_url(&[urls, bad_part], |x| x).unwrap_err(); | ||
| let msg = format!("{err}"); | ||
| assert!(msg.contains("expects STRING arguments")); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_return_type_and_coercion() -> Result<()> { | ||
| let udf = ParseUrl::new(); | ||
|
|
||
| let rt = udf.return_type(&[DataType::Utf8, DataType::Utf8])?; | ||
| assert_eq!(rt, DataType::Utf8); | ||
|
|
||
| let rt = udf.return_type(&[DataType::LargeUtf8, DataType::LargeUtf8])?; | ||
| assert_eq!(rt, DataType::LargeUtf8); | ||
|
|
||
| let rt = udf.return_type(&[DataType::Utf8, DataType::Utf8, DataType::Utf8])?; | ||
| assert_eq!(rt, DataType::Utf8); | ||
|
|
||
| let rt = udf.return_type(&[DataType::LargeUtf8, DataType::Utf8])?; | ||
| assert_eq!(rt, DataType::LargeUtf8); | ||
|
|
||
| let rt = udf.return_type(&[DataType::Utf8View, DataType::Utf8])?; | ||
| assert_eq!(rt, DataType::Utf8View); | ||
|
|
||
| let err = udf | ||
| .return_type(&[DataType::Int32, DataType::Utf8]) | ||
| .unwrap_err(); | ||
| assert!(format!("{err}").contains("expects STRING arguments")); | ||
|
|
||
| let err = udf.return_type(&[DataType::Utf8]).unwrap_err(); | ||
| assert!(format!("{err}").contains("expects 2 or 3 arguments")); | ||
|
|
||
| Ok(()) | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.