Skip to content
250 changes: 201 additions & 49 deletions parquet-variant-compute/src/cast_to_variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,67 +51,147 @@ fn convert_timestamp(
time_zone: &Option<Arc<str>>,
input: &dyn Array,
builder: &mut VariantArrayBuilder,
) {
strict: bool,
) -> Result<(), ArrowError> {
let native_datetimes: Vec<Option<NaiveDateTime>> = match time_unit {
arrow_schema::TimeUnit::Second => {
let ts_array = input
.as_any()
.downcast_ref::<TimestampSecondArray>()
.expect("Array is not TimestampSecondArray");

ts_array
.iter()
.map(|x| x.map(|y| timestamp_s_to_datetime(y).unwrap()))
.collect()
if strict {
let mut result = Vec::with_capacity(ts_array.len());
for x in ts_array.iter() {
match x {
Some(y) => match timestamp_s_to_datetime(y) {
Some(dt) => result.push(Some(dt)),
None => {
return Err(ArrowError::ComputeError(
"Invalid timestamp seconds value".to_string(),
))
}
},
None => result.push(None),
}
}
result
} else {
ts_array
.iter()
.map(|x| x.and_then(timestamp_s_to_datetime))
.collect()
}
}
arrow_schema::TimeUnit::Millisecond => {
let ts_array = input
.as_any()
.downcast_ref::<TimestampMillisecondArray>()
.expect("Array is not TimestampMillisecondArray");

ts_array
.iter()
.map(|x| x.map(|y| timestamp_ms_to_datetime(y).unwrap()))
.collect()
if strict {
let mut result = Vec::with_capacity(ts_array.len());
for x in ts_array.iter() {
match x {
Some(y) => match timestamp_ms_to_datetime(y) {
Some(dt) => result.push(Some(dt)),
None => {
return Err(ArrowError::ComputeError(
"Invalid timestamp milliseconds value".to_string(),
))
}
},
None => result.push(None),
}
}
result
} else {
ts_array
.iter()
.map(|x| x.and_then(timestamp_ms_to_datetime))
.collect()
}
}
arrow_schema::TimeUnit::Microsecond => {
let ts_array = input
.as_any()
.downcast_ref::<TimestampMicrosecondArray>()
.expect("Array is not TimestampMicrosecondArray");
ts_array
.iter()
.map(|x| x.map(|y| timestamp_us_to_datetime(y).unwrap()))
.collect()
if strict {
let mut result = Vec::with_capacity(ts_array.len());
for x in ts_array.iter() {
match x {
Some(y) => match timestamp_us_to_datetime(y) {
Some(dt) => result.push(Some(dt)),
None => {
return Err(ArrowError::ComputeError(
"Invalid timestamp microseconds value".to_string(),
))
}
},
None => result.push(None),
}
}
result
} else {
ts_array
.iter()
.map(|x| x.and_then(timestamp_us_to_datetime))
.collect()
}
}
arrow_schema::TimeUnit::Nanosecond => {
let ts_array = input
.as_any()
.downcast_ref::<TimestampNanosecondArray>()
.expect("Array is not TimestampNanosecondArray");
ts_array
.iter()
.map(|x| x.map(|y| timestamp_ns_to_datetime(y).unwrap()))
.collect()
if strict {
let mut result = Vec::with_capacity(ts_array.len());
for x in ts_array.iter() {
match x {
Some(y) => match timestamp_ns_to_datetime(y) {
Some(dt) => result.push(Some(dt)),
None => {
return Err(ArrowError::ComputeError(
"Invalid timestamp nanoseconds value".to_string(),
))
}
},
None => result.push(None),
}
}
result
} else {
ts_array
.iter()
.map(|x| x.and_then(timestamp_ns_to_datetime))
.collect()
}
}
};

for x in native_datetimes {
for (i, x) in native_datetimes.iter().enumerate() {
match x {
Some(ndt) => {
if time_zone.is_none() {
builder.append_variant(ndt.into());
builder.append_variant((*ndt).into());
} else {
let utc_dt: DateTime<Utc> = Utc.from_utc_datetime(&ndt);
let utc_dt: DateTime<Utc> = Utc.from_utc_datetime(ndt);
builder.append_variant(utc_dt.into());
}
}
None => {
if strict && input.is_valid(i) {
return Err(ArrowError::ComputeError(format!(
"Failed to convert timestamp at index {}: invalid timestamp value",
i
)));
}
builder.append_null();
}
}
}
Ok(())
}

/// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when you
Expand Down Expand Up @@ -143,7 +223,14 @@ fn convert_timestamp(
/// `1970-01-01T00:00:01.234567890Z`
/// will be truncated to
/// `1970-01-01T00:00:01.234567Z`
pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
///
/// # Arguments
/// * `input` - The array to convert to VariantArray
/// * `strict` - If true, return error on conversion failure. If false, insert null for failed conversions.
pub fn cast_to_variant_with_options(
input: &dyn Array,
strict: bool,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any possibility/risk that we may need additional options in the future? If so, it might be better to pass a struct? Or maybe we just deal with that if/when it comes?

Copy link
Contributor Author

@codephage2020 codephage2020 Aug 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, good catch. I had considered that approach as well. We can refactor this part in a follow-up PR.

Done.

) -> Result<VariantArray, ArrowError> {
let mut builder = VariantArrayBuilder::new(input.len());

let input_type = input.data_type();
Expand Down Expand Up @@ -248,7 +335,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
}
}
DataType::Timestamp(time_unit, time_zone) => {
convert_timestamp(time_unit, time_zone, input, &mut builder);
convert_timestamp(time_unit, time_zone, input, &mut builder, strict)?;
}
DataType::Time32(unit) => {
match *unit {
Expand All @@ -257,10 +344,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
Time32SecondType,
as_primitive,
// nano second are always 0
|v| NaiveTime::from_num_seconds_from_midnight_opt(v as u32, 0u32).unwrap(),
|v| NaiveTime::from_num_seconds_from_midnight_opt(v as u32, 0u32),
input,
builder
);
builder,
strict
)?;
}
TimeUnit::Millisecond => {
generic_conversion_array!(
Expand All @@ -269,11 +357,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
|v| NaiveTime::from_num_seconds_from_midnight_opt(
v as u32 / 1000,
(v as u32 % 1000) * 1_000_000
)
.unwrap(),
),
input,
builder
);
builder,
strict
)?;
}
_ => {
return Err(ArrowError::CastError(format!(
Expand All @@ -292,11 +380,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
|v| NaiveTime::from_num_seconds_from_midnight_opt(
(v / 1_000_000) as u32,
(v % 1_000_000 * 1_000) as u32
)
.unwrap(),
),
input,
builder
);
builder,
strict
)?;
}
TimeUnit::Nanosecond => {
generic_conversion_array!(
Expand All @@ -305,11 +393,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
|v| NaiveTime::from_num_seconds_from_midnight_opt(
(v / 1_000_000_000) as u32,
(v % 1_000_000_000) as u32
)
.unwrap(),
),
input,
builder
);
builder,
strict
)?;
}
_ => {
return Err(ArrowError::CastError(format!(
Expand Down Expand Up @@ -396,10 +484,11 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
generic_conversion_array!(
Date64Type,
as_primitive,
|v: i64| { Date64Type::to_naive_date_opt(v).unwrap() },
|v: i64| Date64Type::to_naive_date_opt(v),
input,
builder
);
builder,
strict
)?;
}
DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() {
DataType::Int16 => convert_run_end_encoded::<Int16Type>(input, &mut builder)?,
Expand Down Expand Up @@ -545,6 +634,14 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
Ok(builder.build())
}

/// Convert an array to a `VariantArray` with strict mode enabled.
///
/// This is a thin wrapper that calls `cast_to_variant_with_options` with
/// `strict = true`, so a conversion failure is reported as an `Err` rather
/// than being replaced by a null.
///
/// This function provides backward compatibility. To have failed conversions
/// inserted as nulls instead of returning an error, use
/// `cast_to_variant_with_options` with `strict = false`.
pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
cast_to_variant_with_options(input, true)
}

/// Convert union arrays
fn convert_union(
fields: &UnionFields,
Expand Down Expand Up @@ -645,10 +742,6 @@ fn convert_dictionary_encoded(
Ok(())
}

// TODO do we need a cast_with_options to allow specifying conversion behavior,
// e.g. how to handle overflows, whether to convert to Variant::Null or return
// an error, etc. ?

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -661,8 +754,8 @@ mod tests {
IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeListArray,
LargeStringArray, ListArray, MapArray, NullArray, StringArray, StringRunBuilder,
StringViewArray, StructArray, Time32MillisecondArray, Time32SecondArray,
Time64MicrosecondArray, Time64NanosecondArray, UInt16Array, UInt32Array, UInt64Array,
UInt8Array, UnionArray,
Time64MicrosecondArray, Time64NanosecondArray, TimestampSecondArray, UInt16Array,
UInt32Array, UInt64Array, UInt8Array, UnionArray,
};
use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};
Expand Down Expand Up @@ -2376,9 +2469,8 @@ mod tests {
/// Converts the given `Array` to a `VariantArray` and tests the conversion
/// against the expected values. It also tests the handling of nulls by
/// setting one element to null and verifying the output.
fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) {
// test without nulls
let variant_array = cast_to_variant(&values).unwrap();
fn run_test_with_options(values: ArrayRef, expected: Vec<Option<Variant>>, strict: bool) {
let variant_array = cast_to_variant_with_options(&values, strict).unwrap();
assert_eq!(variant_array.len(), expected.len());
for (i, expected_value) in expected.iter().enumerate() {
match expected_value {
Expand All @@ -2392,4 +2484,64 @@ mod tests {
}
}
}

/// Checks the conversion of `values` against `expected` in strict mode
/// (forwards to `run_test_with_options` with `strict = true`).
fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) {
run_test_with_options(values, expected, true);
}

/// Checks the conversion of `values` against `expected` in non-strict mode
/// (forwards to `run_test_with_options` with `strict = false`).
fn run_test_non_strict(values: ArrayRef, expected: Vec<Option<Variant>>) {
run_test_with_options(values, expected, false);
}

/// Non-strict mode: the two extreme Date64 inputs are expected to come back
/// as nulls, while the epoch value converts to a `Variant::Date`.
#[test]
fn test_cast_to_variant_non_strict_mode_date64() {
    let epoch = Date64Type::to_naive_date_opt(0).unwrap();
    let expected = vec![None, Some(Variant::Date(epoch)), None];

    let input = Date64Array::from(vec![Some(i64::MAX), Some(0), Some(i64::MIN)]);
    run_test_non_strict(Arc::new(input), expected);
}

/// Non-strict mode: 90000 is past the last second of a day and -1 is
/// negative, so both are expected to come back as nulls; 3600 converts to
/// a `Variant::Time` one hour after midnight.
#[test]
fn test_cast_to_variant_non_strict_mode_time32() {
    let one_hour = NaiveTime::from_num_seconds_from_midnight_opt(3600, 0).unwrap();
    let expected = vec![None, Some(Variant::Time(one_hour)), None];

    let input = Time32SecondArray::from(vec![Some(90000), Some(3600), Some(-1)]);
    run_test_non_strict(Arc::new(input), expected);
}

#[test]
fn test_cast_to_variant_non_strict_mode_timestamp() {
let ts_array = TimestampSecondArray::from(vec![Some(i64::MAX), Some(0), Some(1609459200)])
.with_timezone_opt(None::<&str>);

let values = Arc::new(ts_array);
run_test_non_strict(
values,
vec![
None, // Invalid timestamp becomes null
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just realized that for an overflowing value we may need to convert it to `Variant::Null` instead of `None`, as we did for Decimal; see this comment for context. However, I don't think that belongs in this PR; we can discuss it and make the behavior consistent later.

cc @alamb

Some(Variant::TimestampNtzMicros(
timestamp_s_to_datetime(0).unwrap(),
)),
Some(Variant::TimestampNtzMicros(
timestamp_s_to_datetime(1609459200).unwrap(),
)),
],
);
}
}
Loading
Loading