-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[Variant] add strict mode to cast_to_variant #8233
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
0663077
7ff0b9d
e2c455f
3ef1e17
097d4dc
c527b86
3c52e6d
a136b09
f42aed0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,67 +51,147 @@ fn convert_timestamp( | |
| time_zone: &Option<Arc<str>>, | ||
| input: &dyn Array, | ||
| builder: &mut VariantArrayBuilder, | ||
| ) { | ||
| strict: bool, | ||
| ) -> Result<(), ArrowError> { | ||
| let native_datetimes: Vec<Option<NaiveDateTime>> = match time_unit { | ||
| arrow_schema::TimeUnit::Second => { | ||
| let ts_array = input | ||
| .as_any() | ||
| .downcast_ref::<TimestampSecondArray>() | ||
| .expect("Array is not TimestampSecondArray"); | ||
|
|
||
| ts_array | ||
| .iter() | ||
| .map(|x| x.map(|y| timestamp_s_to_datetime(y).unwrap())) | ||
| .collect() | ||
| if strict { | ||
| let mut result = Vec::with_capacity(ts_array.len()); | ||
| for x in ts_array.iter() { | ||
| match x { | ||
| Some(y) => match timestamp_s_to_datetime(y) { | ||
| Some(dt) => result.push(Some(dt)), | ||
| None => { | ||
| return Err(ArrowError::ComputeError( | ||
| "Invalid timestamp seconds value".to_string(), | ||
| )) | ||
| } | ||
| }, | ||
| None => result.push(None), | ||
| } | ||
| } | ||
| result | ||
| } else { | ||
| ts_array | ||
| .iter() | ||
| .map(|x| x.and_then(timestamp_s_to_datetime)) | ||
| .collect() | ||
| } | ||
| } | ||
| arrow_schema::TimeUnit::Millisecond => { | ||
| let ts_array = input | ||
| .as_any() | ||
| .downcast_ref::<TimestampMillisecondArray>() | ||
| .expect("Array is not TimestampMillisecondArray"); | ||
|
|
||
| ts_array | ||
| .iter() | ||
| .map(|x| x.map(|y| timestamp_ms_to_datetime(y).unwrap())) | ||
| .collect() | ||
| if strict { | ||
| let mut result = Vec::with_capacity(ts_array.len()); | ||
| for x in ts_array.iter() { | ||
| match x { | ||
| Some(y) => match timestamp_ms_to_datetime(y) { | ||
| Some(dt) => result.push(Some(dt)), | ||
| None => { | ||
| return Err(ArrowError::ComputeError( | ||
| "Invalid timestamp milliseconds value".to_string(), | ||
| )) | ||
| } | ||
| }, | ||
| None => result.push(None), | ||
| } | ||
| } | ||
| result | ||
| } else { | ||
| ts_array | ||
| .iter() | ||
| .map(|x| x.and_then(timestamp_ms_to_datetime)) | ||
| .collect() | ||
| } | ||
| } | ||
| arrow_schema::TimeUnit::Microsecond => { | ||
| let ts_array = input | ||
| .as_any() | ||
| .downcast_ref::<TimestampMicrosecondArray>() | ||
| .expect("Array is not TimestampMicrosecondArray"); | ||
| ts_array | ||
| .iter() | ||
| .map(|x| x.map(|y| timestamp_us_to_datetime(y).unwrap())) | ||
| .collect() | ||
| if strict { | ||
| let mut result = Vec::with_capacity(ts_array.len()); | ||
| for x in ts_array.iter() { | ||
| match x { | ||
| Some(y) => match timestamp_us_to_datetime(y) { | ||
| Some(dt) => result.push(Some(dt)), | ||
| None => { | ||
| return Err(ArrowError::ComputeError( | ||
| "Invalid timestamp microseconds value".to_string(), | ||
| )) | ||
| } | ||
| }, | ||
| None => result.push(None), | ||
| } | ||
| } | ||
| result | ||
| } else { | ||
| ts_array | ||
| .iter() | ||
| .map(|x| x.and_then(timestamp_us_to_datetime)) | ||
| .collect() | ||
| } | ||
| } | ||
| arrow_schema::TimeUnit::Nanosecond => { | ||
| let ts_array = input | ||
| .as_any() | ||
| .downcast_ref::<TimestampNanosecondArray>() | ||
| .expect("Array is not TimestampNanosecondArray"); | ||
| ts_array | ||
| .iter() | ||
| .map(|x| x.map(|y| timestamp_ns_to_datetime(y).unwrap())) | ||
| .collect() | ||
| if strict { | ||
| let mut result = Vec::with_capacity(ts_array.len()); | ||
| for x in ts_array.iter() { | ||
| match x { | ||
| Some(y) => match timestamp_ns_to_datetime(y) { | ||
| Some(dt) => result.push(Some(dt)), | ||
| None => { | ||
| return Err(ArrowError::ComputeError( | ||
| "Invalid timestamp nanoseconds value".to_string(), | ||
| )) | ||
| } | ||
| }, | ||
| None => result.push(None), | ||
| } | ||
| } | ||
| result | ||
| } else { | ||
| ts_array | ||
| .iter() | ||
| .map(|x| x.and_then(timestamp_ns_to_datetime)) | ||
| .collect() | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| for x in native_datetimes { | ||
| for (i, x) in native_datetimes.iter().enumerate() { | ||
| match x { | ||
| Some(ndt) => { | ||
| if time_zone.is_none() { | ||
| builder.append_variant(ndt.into()); | ||
| builder.append_variant((*ndt).into()); | ||
| } else { | ||
| let utc_dt: DateTime<Utc> = Utc.from_utc_datetime(&ndt); | ||
| let utc_dt: DateTime<Utc> = Utc.from_utc_datetime(ndt); | ||
| builder.append_variant(utc_dt.into()); | ||
| } | ||
| } | ||
| None => { | ||
| if strict && input.is_valid(i) { | ||
| return Err(ArrowError::ComputeError(format!( | ||
| "Failed to convert timestamp at index {}: invalid timestamp value", | ||
| i | ||
| ))); | ||
| } | ||
| builder.append_null(); | ||
codephage2020 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
| } | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when you | ||
|
|
@@ -143,7 +223,14 @@ fn convert_timestamp( | |
| /// `1970-01-01T00:00:01.234567890Z` | ||
| /// will be truncated to | ||
| /// `1970-01-01T00:00:01.234567Z` | ||
| pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { | ||
| /// | ||
| /// # Arguments | ||
| /// * `input` - The array to convert to VariantArray | ||
| /// * `strict` - If true, return error on conversion failure. If false, insert null for failed conversions. | ||
| pub fn cast_to_variant_with_options( | ||
| input: &dyn Array, | ||
| strict: bool, | ||
|
||
| ) -> Result<VariantArray, ArrowError> { | ||
| let mut builder = VariantArrayBuilder::new(input.len()); | ||
|
|
||
| let input_type = input.data_type(); | ||
|
|
@@ -248,7 +335,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { | |
| } | ||
| } | ||
| DataType::Timestamp(time_unit, time_zone) => { | ||
| convert_timestamp(time_unit, time_zone, input, &mut builder); | ||
| convert_timestamp(time_unit, time_zone, input, &mut builder, strict)?; | ||
| } | ||
| DataType::Time32(unit) => { | ||
| match *unit { | ||
|
|
@@ -257,10 +344,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { | |
| Time32SecondType, | ||
| as_primitive, | ||
| // nano second are always 0 | ||
| |v| NaiveTime::from_num_seconds_from_midnight_opt(v as u32, 0u32).unwrap(), | ||
| |v| NaiveTime::from_num_seconds_from_midnight_opt(v as u32, 0u32), | ||
| input, | ||
| builder | ||
| ); | ||
| builder, | ||
| strict | ||
| )?; | ||
| } | ||
| TimeUnit::Millisecond => { | ||
| generic_conversion_array!( | ||
|
|
@@ -269,11 +357,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { | |
| |v| NaiveTime::from_num_seconds_from_midnight_opt( | ||
| v as u32 / 1000, | ||
| (v as u32 % 1000) * 1_000_000 | ||
| ) | ||
| .unwrap(), | ||
| ), | ||
| input, | ||
| builder | ||
| ); | ||
| builder, | ||
| strict | ||
| )?; | ||
| } | ||
| _ => { | ||
| return Err(ArrowError::CastError(format!( | ||
|
|
@@ -292,11 +380,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { | |
| |v| NaiveTime::from_num_seconds_from_midnight_opt( | ||
| (v / 1_000_000) as u32, | ||
| (v % 1_000_000 * 1_000) as u32 | ||
| ) | ||
| .unwrap(), | ||
| ), | ||
| input, | ||
| builder | ||
| ); | ||
| builder, | ||
| strict | ||
| )?; | ||
| } | ||
| TimeUnit::Nanosecond => { | ||
| generic_conversion_array!( | ||
|
|
@@ -305,11 +393,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { | |
| |v| NaiveTime::from_num_seconds_from_midnight_opt( | ||
| (v / 1_000_000_000) as u32, | ||
| (v % 1_000_000_000) as u32 | ||
| ) | ||
| .unwrap(), | ||
| ), | ||
| input, | ||
| builder | ||
| ); | ||
| builder, | ||
| strict | ||
| )?; | ||
| } | ||
| _ => { | ||
| return Err(ArrowError::CastError(format!( | ||
|
|
@@ -396,10 +484,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { | |
| generic_conversion_array!( | ||
| Date64Type, | ||
| as_primitive, | ||
| |v: i64| { Date64Type::to_naive_date_opt(v).unwrap() }, | ||
| |v: i64| Date64Type::to_naive_date_opt(v), | ||
| input, | ||
| builder | ||
| ); | ||
| builder, | ||
| strict | ||
| )?; | ||
| } | ||
| DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() { | ||
| DataType::Int16 => convert_run_end_encoded::<Int16Type>(input, &mut builder)?, | ||
|
|
@@ -545,6 +634,14 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { | |
| Ok(builder.build()) | ||
| } | ||
|
|
||
| /// Convert an array to a `VariantArray` with strict mode enabled (panics on conversion failures). | ||
| /// | ||
| /// This function provides backward compatibility. For non-panicking behavior, | ||
| /// use `cast_to_variant_with_options` with `strict = false`. | ||
codephage2020 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { | ||
| cast_to_variant_with_options(input, true) | ||
| } | ||
|
|
||
| /// Convert union arrays | ||
| fn convert_union( | ||
| fields: &UnionFields, | ||
|
|
@@ -645,10 +742,6 @@ fn convert_dictionary_encoded( | |
| Ok(()) | ||
| } | ||
|
|
||
| // TODO do we need a cast_with_options to allow specifying conversion behavior, | ||
| // e.g. how to handle overflows, whether to convert to Variant::Null or return | ||
| // an error, etc. ? | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
|
|
@@ -661,8 +754,8 @@ mod tests { | |
| IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeListArray, | ||
| LargeStringArray, ListArray, MapArray, NullArray, StringArray, StringRunBuilder, | ||
| StringViewArray, StructArray, Time32MillisecondArray, Time32SecondArray, | ||
| Time64MicrosecondArray, Time64NanosecondArray, UInt16Array, UInt32Array, UInt64Array, | ||
| UInt8Array, UnionArray, | ||
| Time64MicrosecondArray, Time64NanosecondArray, TimestampSecondArray, UInt16Array, | ||
| UInt32Array, UInt64Array, UInt8Array, UnionArray, | ||
| }; | ||
| use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; | ||
| use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano}; | ||
|
|
@@ -2376,9 +2469,8 @@ mod tests { | |
| /// Converts the given `Array` to a `VariantArray` and tests the conversion | ||
| /// against the expected values. It also tests the handling of nulls by | ||
| /// setting one element to null and verifying the output. | ||
| fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) { | ||
| // test without nulls | ||
| let variant_array = cast_to_variant(&values).unwrap(); | ||
| fn run_test_with_options(values: ArrayRef, expected: Vec<Option<Variant>>, strict: bool) { | ||
| let variant_array = cast_to_variant_with_options(&values, strict).unwrap(); | ||
| assert_eq!(variant_array.len(), expected.len()); | ||
| for (i, expected_value) in expected.iter().enumerate() { | ||
| match expected_value { | ||
|
|
@@ -2392,4 +2484,64 @@ mod tests { | |
| } | ||
| } | ||
| } | ||
|
|
||
| fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) { | ||
| run_test_with_options(values, expected, true); | ||
| } | ||
|
|
||
| fn run_test_non_strict(values: ArrayRef, expected: Vec<Option<Variant>>) { | ||
| run_test_with_options(values, expected, false); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_cast_to_variant_non_strict_mode_date64() { | ||
| let date64_values = Date64Array::from(vec![Some(i64::MAX), Some(0), Some(i64::MIN)]); | ||
|
|
||
| let values = Arc::new(date64_values); | ||
| run_test_non_strict( | ||
| values, | ||
| vec![ | ||
| None, | ||
| Some(Variant::Date(Date64Type::to_naive_date_opt(0).unwrap())), | ||
| None, | ||
| ], | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_cast_to_variant_non_strict_mode_time32() { | ||
| let time32_array = Time32SecondArray::from(vec![Some(90000), Some(3600), Some(-1)]); | ||
|
|
||
| let values = Arc::new(time32_array); | ||
| run_test_non_strict( | ||
| values, | ||
| vec![ | ||
| None, | ||
| Some(Variant::Time( | ||
| NaiveTime::from_num_seconds_from_midnight_opt(3600, 0).unwrap(), | ||
| )), | ||
| None, | ||
| ], | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_cast_to_variant_non_strict_mode_timestamp() { | ||
| let ts_array = TimestampSecondArray::from(vec![Some(i64::MAX), Some(0), Some(1609459200)]) | ||
| .with_timezone_opt(None::<&str>); | ||
|
|
||
| let values = Arc::new(ts_array); | ||
| run_test_non_strict( | ||
| values, | ||
| vec![ | ||
| None, // Invalid timestamp becomes null | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just realized that for overflow value, we may need to convert it to cc @alamb |
||
| Some(Variant::TimestampNtzMicros( | ||
| timestamp_s_to_datetime(0).unwrap(), | ||
| )), | ||
| Some(Variant::TimestampNtzMicros( | ||
| timestamp_s_to_datetime(1609459200).unwrap(), | ||
| )), | ||
| ], | ||
| ); | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.