Skip to content
206 changes: 153 additions & 53 deletions parquet-variant-compute/src/cast_to_variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::sync::Arc;

use crate::type_conversion::{
decimal_to_variant_decimal, generic_conversion_array, non_generic_conversion_array,
primitive_conversion_array,
primitive_conversion_array, timestamp_to_variant_timestamp,
};
use crate::{VariantArray, VariantArrayBuilder};
use arrow::array::{
Expand All @@ -46,72 +46,99 @@ use parquet_variant::{
Variant, VariantBuilder, VariantDecimal16, VariantDecimal4, VariantDecimal8,
};

fn convert_timestamp(
/// Options for controlling the behavior of `cast_to_variant_with_options`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CastOptions {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks quite similar to https://docs.rs/arrow/latest/arrow/compute/struct.CastOptions.html

However that seems to be defined in arrow-compute

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought more about this and I think a dedicated CastOptions struct makes sense for variant

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifically they are different kernels and so some options like FormatOptions for the normal cast kernel may not be relevant for this one

/// If true, return error on conversion failure. If false, insert null for failed conversions.
pub strict: bool,
}

impl Default for CastOptions {
fn default() -> Self {
Self { strict: true }
}
}

fn convert_timestamp_with_options(
time_unit: &TimeUnit,
time_zone: &Option<Arc<str>>,
input: &dyn Array,
builder: &mut VariantArrayBuilder,
) {
options: &CastOptions,
) -> Result<(), ArrowError> {
let native_datetimes: Vec<Option<NaiveDateTime>> = match time_unit {
arrow_schema::TimeUnit::Second => {
let ts_array = input
.as_any()
.downcast_ref::<TimestampSecondArray>()
.expect("Array is not TimestampSecondArray");

ts_array
.iter()
.map(|x| x.map(|y| timestamp_s_to_datetime(y).unwrap()))
.collect()
timestamp_to_variant_timestamp!(
ts_array,
timestamp_s_to_datetime,
"seconds",
options.strict
)
}
arrow_schema::TimeUnit::Millisecond => {
let ts_array = input
.as_any()
.downcast_ref::<TimestampMillisecondArray>()
.expect("Array is not TimestampMillisecondArray");

ts_array
.iter()
.map(|x| x.map(|y| timestamp_ms_to_datetime(y).unwrap()))
.collect()
timestamp_to_variant_timestamp!(
ts_array,
timestamp_ms_to_datetime,
"milliseconds",
options.strict
)
}
arrow_schema::TimeUnit::Microsecond => {
let ts_array = input
.as_any()
.downcast_ref::<TimestampMicrosecondArray>()
.expect("Array is not TimestampMicrosecondArray");
ts_array
.iter()
.map(|x| x.map(|y| timestamp_us_to_datetime(y).unwrap()))
.collect()
timestamp_to_variant_timestamp!(
ts_array,
timestamp_us_to_datetime,
"microseconds",
options.strict
)
}
arrow_schema::TimeUnit::Nanosecond => {
let ts_array = input
.as_any()
.downcast_ref::<TimestampNanosecondArray>()
.expect("Array is not TimestampNanosecondArray");
ts_array
.iter()
.map(|x| x.map(|y| timestamp_ns_to_datetime(y).unwrap()))
.collect()
timestamp_to_variant_timestamp!(
ts_array,
timestamp_ns_to_datetime,
"nanoseconds",
options.strict
)
}
};

for x in native_datetimes {
for (i, x) in native_datetimes.iter().enumerate() {
match x {
Some(ndt) => {
if time_zone.is_none() {
builder.append_variant(ndt.into());
builder.append_variant((*ndt).into());
} else {
let utc_dt: DateTime<Utc> = Utc.from_utc_datetime(&ndt);
let utc_dt: DateTime<Utc> = Utc.from_utc_datetime(ndt);
builder.append_variant(utc_dt.into());
}
}
None if options.strict && input.is_valid(i) => {
return Err(ArrowError::ComputeError(format!(
"Failed to convert timestamp at index {}: invalid timestamp value",
i
)));
}
None => {
builder.append_null();
}
}
}
Ok(())
}

/// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when you
Expand Down Expand Up @@ -143,7 +170,14 @@ fn convert_timestamp(
/// `1970-01-01T00:00:01.234567890Z`
/// will be truncated to
/// `1970-01-01T00:00:01.234567Z`
pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
///
/// # Arguments
/// * `input` - The array to convert to VariantArray
/// * `options` - Options controlling conversion behavior
pub fn cast_to_variant_with_options(
input: &dyn Array,
options: &CastOptions,
) -> Result<VariantArray, ArrowError> {
let mut builder = VariantArrayBuilder::new(input.len());

let input_type = input.data_type();
Expand Down Expand Up @@ -248,7 +282,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
}
}
DataType::Timestamp(time_unit, time_zone) => {
convert_timestamp(time_unit, time_zone, input, &mut builder);
convert_timestamp_with_options(time_unit, time_zone, input, &mut builder, options)?;
}
DataType::Time32(unit) => {
match *unit {
Expand All @@ -257,10 +291,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
Time32SecondType,
as_primitive,
// nano second are always 0
|v| NaiveTime::from_num_seconds_from_midnight_opt(v as u32, 0u32).unwrap(),
|v| NaiveTime::from_num_seconds_from_midnight_opt(v as u32, 0u32),
input,
builder
);
builder,
options.strict
)?;
}
TimeUnit::Millisecond => {
generic_conversion_array!(
Expand All @@ -269,11 +304,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
|v| NaiveTime::from_num_seconds_from_midnight_opt(
v as u32 / 1000,
(v as u32 % 1000) * 1_000_000
)
.unwrap(),
),
input,
builder
);
builder,
options.strict
)?;
}
_ => {
return Err(ArrowError::CastError(format!(
Expand All @@ -292,11 +327,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
|v| NaiveTime::from_num_seconds_from_midnight_opt(
(v / 1_000_000) as u32,
(v % 1_000_000 * 1_000) as u32
)
.unwrap(),
),
input,
builder
);
builder,
options.strict
)?;
}
TimeUnit::Nanosecond => {
generic_conversion_array!(
Expand All @@ -305,11 +340,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
|v| NaiveTime::from_num_seconds_from_midnight_opt(
(v / 1_000_000_000) as u32,
(v % 1_000_000_000) as u32
)
.unwrap(),
),
input,
builder
);
builder,
options.strict
)?;
}
_ => {
return Err(ArrowError::CastError(format!(
Expand Down Expand Up @@ -396,10 +431,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
generic_conversion_array!(
Date64Type,
as_primitive,
|v: i64| { Date64Type::to_naive_date_opt(v).unwrap() },
|v: i64| Date64Type::to_naive_date_opt(v),
input,
builder
);
builder,
options.strict
)?;
}
DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() {
DataType::Int16 => convert_run_end_encoded::<Int16Type>(input, &mut builder)?,
Expand Down Expand Up @@ -545,6 +581,14 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
Ok(builder.build())
}

/// Convert an array to a `VariantArray` with strict mode enabled (returns errors on conversion failures).
///
/// This function provides backward compatibility. For non-strict behavior,
/// use `cast_to_variant_with_options` with `CastOptions { strict: false }`.
pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
cast_to_variant_with_options(input, &CastOptions::default())
}

/// Convert union arrays
fn convert_union(
fields: &UnionFields,
Expand Down Expand Up @@ -645,10 +689,6 @@ fn convert_dictionary_encoded(
Ok(())
}

// TODO do we need a cast_with_options to allow specifying conversion behavior,
// e.g. how to handle overflows, whether to convert to Variant::Null or return
// an error, etc. ?

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -661,8 +701,8 @@ mod tests {
IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeListArray,
LargeStringArray, ListArray, MapArray, NullArray, StringArray, StringRunBuilder,
StringViewArray, StructArray, Time32MillisecondArray, Time32SecondArray,
Time64MicrosecondArray, Time64NanosecondArray, UInt16Array, UInt32Array, UInt64Array,
UInt8Array, UnionArray,
Time64MicrosecondArray, Time64NanosecondArray, TimestampSecondArray, UInt16Array,
UInt32Array, UInt64Array, UInt8Array, UnionArray,
};
use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};
Expand Down Expand Up @@ -2376,9 +2416,9 @@ mod tests {
/// Converts the given `Array` to a `VariantArray` and tests the conversion
/// against the expected values. It also tests the handling of nulls by
/// setting one element to null and verifying the output.
fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) {
// test without nulls
let variant_array = cast_to_variant(&values).unwrap();
fn run_test_with_options(values: ArrayRef, expected: Vec<Option<Variant>>, strict: bool) {
let options = CastOptions { strict };
let variant_array = cast_to_variant_with_options(&values, &options).unwrap();
assert_eq!(variant_array.len(), expected.len());
for (i, expected_value) in expected.iter().enumerate() {
match expected_value {
Expand All @@ -2392,4 +2432,64 @@ mod tests {
}
}
}

fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) {
run_test_with_options(values, expected, true);
}

fn run_test_non_strict(values: ArrayRef, expected: Vec<Option<Variant>>) {
run_test_with_options(values, expected, false);
}

#[test]
fn test_cast_to_variant_non_strict_mode_date64() {
let date64_values = Date64Array::from(vec![Some(i64::MAX), Some(0), Some(i64::MIN)]);

let values = Arc::new(date64_values);
run_test_non_strict(
values,
vec![
None,
Some(Variant::Date(Date64Type::to_naive_date_opt(0).unwrap())),
None,
],
);
}

#[test]
fn test_cast_to_variant_non_strict_mode_time32() {
let time32_array = Time32SecondArray::from(vec![Some(90000), Some(3600), Some(-1)]);

let values = Arc::new(time32_array);
run_test_non_strict(
values,
vec![
None,
Some(Variant::Time(
NaiveTime::from_num_seconds_from_midnight_opt(3600, 0).unwrap(),
)),
None,
],
);
}

#[test]
fn test_cast_to_variant_non_strict_mode_timestamp() {
let ts_array = TimestampSecondArray::from(vec![Some(i64::MAX), Some(0), Some(1609459200)])
.with_timezone_opt(None::<&str>);

let values = Arc::new(ts_array);
run_test_non_strict(
values,
vec![
None, // Invalid timestamp becomes null
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just realized that for overflow value, we may need to convert it to Variant::Null instead of None like what we did for Decimal, see this comment for context. However, I don't think it should be in this PR, we could discuss and make them consistent later.

cc @alamb

Some(Variant::TimestampNtzMicros(
timestamp_s_to_datetime(0).unwrap(),
)),
Some(Variant::TimestampNtzMicros(
timestamp_s_to_datetime(1609459200).unwrap(),
)),
],
);
}
}
3 changes: 2 additions & 1 deletion parquet-variant-compute/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
//! - [`VariantArrayBuilder`]: For building [`VariantArray`]
//! - [`json_to_variant`]: Function to convert a batch of JSON strings to a `VariantArray`.
//! - [`variant_to_json`]: Function to convert a `VariantArray` to a batch of JSON strings.
//! - [`cast_to_variant`]: Module to cast other Arrow arrays to `VariantArray`.
//! - [`mod@cast_to_variant`]: Module to cast other Arrow arrays to `VariantArray`.
//! - [`variant_get`]: Module to get values from a `VariantArray` using a specified [`VariantPath`]
//!
//! ## 🚧 Work In Progress
Expand All @@ -46,5 +46,6 @@ pub mod variant_get;
pub use variant_array::{ShreddingState, VariantArray};
pub use variant_array_builder::{VariantArrayBuilder, VariantArrayVariantBuilder};

pub use cast_to_variant::{cast_to_variant, cast_to_variant_with_options, CastOptions};
pub use from_json::json_to_variant;
pub use to_json::variant_to_json;
Loading
Loading