diff --git a/parquet-variant/src/builder.rs b/parquet-variant/src/builder.rs index a5fb66a84ff4..1c6ebe23d24f 100644 --- a/parquet-variant/src/builder.rs +++ b/parquet-variant/src/builder.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. use crate::decoder::{VariantBasicType, VariantPrimitiveType}; -use crate::{ShortString, Variant}; +use crate::{ShortString, Variant, VariantDecimal16, VariantDecimal4, VariantDecimal8}; use std::collections::BTreeMap; const BASIC_TYPE_BITS: u8 = 2; @@ -384,9 +384,15 @@ impl VariantBuilder { Variant::Date(v) => self.append_date(v), Variant::TimestampMicros(v) => self.append_timestamp_micros(v), Variant::TimestampNtzMicros(v) => self.append_timestamp_ntz_micros(v), - Variant::Decimal4 { integer, scale } => self.append_decimal4(integer, scale), - Variant::Decimal8 { integer, scale } => self.append_decimal8(integer, scale), - Variant::Decimal16 { integer, scale } => self.append_decimal16(integer, scale), + Variant::Decimal4(VariantDecimal4 { integer, scale }) => { + self.append_decimal4(integer, scale) + } + Variant::Decimal8(VariantDecimal8 { integer, scale }) => { + self.append_decimal8(integer, scale) + } + Variant::Decimal16(VariantDecimal16 { integer, scale }) => { + self.append_decimal16(integer, scale) + } Variant::Float(v) => self.append_float(v), Variant::Double(v) => self.append_double(v), Variant::Binary(v) => self.append_binary(v), diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs index 51327b4d2528..b343a538d54c 100644 --- a/parquet-variant/src/variant.rs +++ b/parquet-variant/src/variant.rs @@ -40,8 +40,100 @@ const MAX_SHORT_STRING_BYTES: usize = 0x3F; #[derive(Debug, Clone, Copy, PartialEq)] pub struct ShortString<'a>(pub(crate) &'a str); +/// Represents a 4-byte decimal value in the Variant format. +/// +/// This struct stores a decimal number using a 32-bit signed integer for the coefficient +/// and an 8-bit unsigned integer for the scale (number of decimal places). Its precision is limited to 9 digits. +/// +/// For valid precision and scale values, see the Variant specification: +/// +/// +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct VariantDecimal4 { + pub(crate) integer: i32, + pub(crate) scale: u8, +} + +impl VariantDecimal4 { + pub fn try_new(integer: i32, scale: u8) -> Result { + const PRECISION_MAX: u32 = 9; + + // Validate that scale doesn't exceed precision + if scale as u32 > PRECISION_MAX { + return Err(ArrowError::InvalidArgumentError(format!( + "Scale {} cannot be greater than precision 9 for 4-byte decimal", + scale + ))); + } + + Ok(VariantDecimal4 { integer, scale }) + } +} + +/// Represents an 8-byte decimal value in the Variant format. +/// +/// This struct stores a decimal number using a 64-bit signed integer for the coefficient +/// and an 8-bit unsigned integer for the scale (number of decimal places). Its precision is between 10 and 18 digits. +/// +/// For valid precision and scale values, see the Variant specification: +/// +/// +/// +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct VariantDecimal8 { + pub(crate) integer: i64, + pub(crate) scale: u8, +} + +impl VariantDecimal8 { + pub fn try_new(integer: i64, scale: u8) -> Result { + const PRECISION_MAX: u32 = 18; + + // Validate that scale doesn't exceed precision + if scale as u32 > PRECISION_MAX { + return Err(ArrowError::InvalidArgumentError(format!( + "Scale {} cannot be greater than precision 18 for 8-byte decimal", + scale + ))); + } + + Ok(VariantDecimal8 { integer, scale }) + } +} + +/// Represents an 16-byte decimal value in the Variant format. +/// +/// This struct stores a decimal number using a 128-bit signed integer for the coefficient +/// and an 8-bit unsigned integer for the scale (number of decimal places). Its precision is between 19 and 38 digits. +/// +/// For valid precision and scale values, see the Variant specification: +/// +/// +/// +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct VariantDecimal16 { + pub(crate) integer: i128, + pub(crate) scale: u8, +} + +impl VariantDecimal16 { + pub fn try_new(integer: i128, scale: u8) -> Result { + const PRECISION_MAX: u32 = 38; + + // Validate that scale doesn't exceed precision + if scale as u32 > PRECISION_MAX { + return Err(ArrowError::InvalidArgumentError(format!( + "Scale {} cannot be greater than precision 38 for 16-byte decimal", + scale + ))); + } + + Ok(VariantDecimal16 { integer, scale }) + } +} + impl<'a> ShortString<'a> { - /// Attempts to interpret `value` as a variant short string value. + /// Attempts to interpret `value` as a variant short string value. /// /// # Validation /// @@ -194,11 +286,11 @@ pub enum Variant<'m, 'v> { /// Primitive (type_id=1): TIMESTAMP(isAdjustedToUTC=false, MICROS) TimestampNtzMicros(NaiveDateTime), /// Primitive (type_id=1): DECIMAL(precision, scale) 32-bits - Decimal4 { integer: i32, scale: u8 }, + Decimal4(VariantDecimal4), /// Primitive (type_id=1): DECIMAL(precision, scale) 64-bits - Decimal8 { integer: i64, scale: u8 }, + Decimal8(VariantDecimal8), /// Primitive (type_id=1): DECIMAL(precision, scale) 128-bits - Decimal16 { integer: i128, scale: u8 }, + Decimal16(VariantDecimal16), /// Primitive (type_id=1): FLOAT Float(f32), /// Primitive (type_id=1): DOUBLE @@ -269,15 +361,15 @@ impl<'m, 'v> Variant<'m, 'v> { VariantPrimitiveType::Int64 => Variant::Int64(decoder::decode_int64(value_data)?), VariantPrimitiveType::Decimal4 => { let (integer, scale) = decoder::decode_decimal4(value_data)?; - Variant::Decimal4 { integer, scale } + Variant::Decimal4(VariantDecimal4 { integer, scale }) } VariantPrimitiveType::Decimal8 => { let (integer, scale) = decoder::decode_decimal8(value_data)?; - Variant::Decimal8 { integer, scale } + Variant::Decimal8(VariantDecimal8 { integer, scale }) } VariantPrimitiveType::Decimal16 => { let (integer, scale) = decoder::decode_decimal16(value_data)?; - Variant::Decimal16 { integer, scale } + Variant::Decimal16(VariantDecimal16 { integer, scale }) } VariantPrimitiveType::Float => Variant::Float(decoder::decode_float(value_data)?), VariantPrimitiveType::Double => { @@ -640,18 +732,18 @@ impl<'m, 'v> Variant<'m, 'v> { /// # Examples /// /// ``` - /// use parquet_variant::Variant; + /// use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8}; /// /// // you can extract decimal parts from smaller or equally-sized decimal variants - /// let v1 = Variant::from((1234_i32, 2)); + /// let v1 = Variant::from(VariantDecimal4::try_new(1234_i32, 2).unwrap()); /// assert_eq!(v1.as_decimal_int32(), Some((1234_i32, 2))); /// /// // and from larger decimal variants if they fit - /// let v2 = Variant::from((1234_i64, 2)); + /// let v2 = Variant::from(VariantDecimal8::try_new(1234_i64, 2).unwrap()); /// assert_eq!(v2.as_decimal_int32(), Some((1234_i32, 2))); /// /// // but not if the value would overflow i32 - /// let v3 = Variant::from((12345678901i64, 2)); + /// let v3 = Variant::from(VariantDecimal8::try_new(12345678901i64, 2).unwrap()); /// assert_eq!(v3.as_decimal_int32(), None); /// /// // or if the variant is not a decimal @@ -660,17 +752,17 @@ impl<'m, 'v> Variant<'m, 'v> { /// ``` pub fn as_decimal_int32(&self) -> Option<(i32, u8)> { match *self { - Variant::Decimal4 { integer, scale } => Some((integer, scale)), - Variant::Decimal8 { integer, scale } => { - if let Ok(converted_integer) = integer.try_into() { - Some((converted_integer, scale)) + Variant::Decimal4(decimal4) => Some((decimal4.integer, decimal4.scale)), + Variant::Decimal8(decimal8) => { + if let Ok(converted_integer) = decimal8.integer.try_into() { + Some((converted_integer, decimal8.scale)) } else { None } } - Variant::Decimal16 { integer, scale } => { - if let Ok(converted_integer) = integer.try_into() { - Some((converted_integer, scale)) + Variant::Decimal16(decimal16) => { + if let Ok(converted_integer) = decimal16.integer.try_into() { + Some((converted_integer, decimal16.scale)) } else { None } @@ -688,18 +780,18 @@ impl<'m, 'v> Variant<'m, 'v> { /// # Examples /// /// ``` - /// use parquet_variant::Variant; + /// use parquet_variant::{Variant, VariantDecimal8, VariantDecimal16}; /// /// // you can extract decimal parts from smaller or equally-sized decimal variants - /// let v1 = Variant::from((1234_i64, 2)); + /// let v1 = Variant::from(VariantDecimal8::try_new(1234_i64, 2).unwrap()); /// assert_eq!(v1.as_decimal_int64(), Some((1234_i64, 2))); /// /// // and from larger decimal variants if they fit - /// let v2 = Variant::from((1234_i128, 2)); + /// let v2 = Variant::from(VariantDecimal16::try_new(1234_i128, 2).unwrap()); /// assert_eq!(v2.as_decimal_int64(), Some((1234_i64, 2))); /// /// // but not if the value would overflow i64 - /// let v3 = Variant::from((2e19 as i128, 2)); + /// let v3 = Variant::from(VariantDecimal16::try_new(2e19 as i128, 2).unwrap()); /// assert_eq!(v3.as_decimal_int64(), None); /// /// // or if the variant is not a decimal @@ -708,11 +800,11 @@ impl<'m, 'v> Variant<'m, 'v> { /// ``` pub fn as_decimal_int64(&self) -> Option<(i64, u8)> { match *self { - Variant::Decimal4 { integer, scale } => Some((integer.into(), scale)), - Variant::Decimal8 { integer, scale } => Some((integer, scale)), - Variant::Decimal16 { integer, scale } => { - if let Ok(converted_integer) = integer.try_into() { - Some((converted_integer, scale)) + Variant::Decimal4(decimal) => Some((decimal.integer.into(), decimal.scale)), + Variant::Decimal8(decimal) => Some((decimal.integer, decimal.scale)), + Variant::Decimal16(decimal) => { + if let Ok(converted_integer) = decimal.integer.try_into() { + Some((converted_integer, decimal.scale)) } else { None } @@ -730,10 +822,10 @@ impl<'m, 'v> Variant<'m, 'v> { /// # Examples /// /// ``` - /// use parquet_variant::Variant; + /// use parquet_variant::{Variant, VariantDecimal16}; /// /// // you can extract decimal parts from smaller or equally-sized decimal variants - /// let v1 = Variant::from((1234_i128, 2)); + /// let v1 = Variant::from(VariantDecimal16::try_new(1234_i128, 2).unwrap()); /// assert_eq!(v1.as_decimal_int128(), Some((1234_i128, 2))); /// /// // but not if the variant is not a decimal @@ -742,9 +834,9 @@ impl<'m, 'v> Variant<'m, 'v> { /// ``` pub fn as_decimal_int128(&self) -> Option<(i128, u8)> { match *self { - Variant::Decimal4 { integer, scale } => Some((integer.into(), scale)), - Variant::Decimal8 { integer, scale } => Some((integer.into(), scale)), - Variant::Decimal16 { integer, scale } => Some((integer, scale)), + Variant::Decimal4(decimal) => Some((decimal.integer.into(), decimal.scale)), + Variant::Decimal8(decimal) => Some((decimal.integer.into(), decimal.scale)), + Variant::Decimal16(decimal) => Some((decimal.integer, decimal.scale)), _ => None, } } @@ -912,30 +1004,21 @@ impl From for Variant<'_, '_> { } } -impl From<(i32, u8)> for Variant<'_, '_> { - fn from(value: (i32, u8)) -> Self { - Variant::Decimal4 { - integer: value.0, - scale: value.1, - } +impl From for Variant<'_, '_> { + fn from(value: VariantDecimal4) -> Self { + Variant::Decimal4(value) } } -impl From<(i64, u8)> for Variant<'_, '_> { - fn from(value: (i64, u8)) -> Self { - Variant::Decimal8 { - integer: value.0, - scale: value.1, - } +impl From for Variant<'_, '_> { + fn from(value: VariantDecimal8) -> Self { + Variant::Decimal8(value) } } -impl From<(i128, u8)> for Variant<'_, '_> { - fn from(value: (i128, u8)) -> Self { - Variant::Decimal16 { - integer: value.0, - scale: value.1, - } +impl From for Variant<'_, '_> { + fn from(value: VariantDecimal16) -> Self { + Variant::Decimal16(value) } } @@ -994,6 +1077,36 @@ impl<'v> From<&'v str> for Variant<'_, 'v> { } } +impl TryFrom<(i32, u8)> for Variant<'_, '_> { + type Error = ArrowError; + + fn try_from(value: (i32, u8)) -> Result { + Ok(Variant::Decimal4(VariantDecimal4::try_new( + value.0, value.1, + )?)) + } +} + +impl TryFrom<(i64, u8)> for Variant<'_, '_> { + type Error = ArrowError; + + fn try_from(value: (i64, u8)) -> Result { + Ok(Variant::Decimal8(VariantDecimal8::try_new( + value.0, value.1, + )?)) + } +} + +impl TryFrom<(i128, u8)> for Variant<'_, '_> { + type Error = ArrowError; + + fn try_from(value: (i128, u8)) -> Result { + Ok(Variant::Decimal16(VariantDecimal16::try_new( + value.0, value.1, + )?)) + } +} + #[cfg(test)] mod tests { use super::*; @@ -1007,4 +1120,28 @@ mod tests { let res = ShortString::try_new(&long_string); assert!(res.is_err()); } + + #[test] + fn test_variant_decimal_conversion() { + let decimal4 = VariantDecimal4::try_new(1234_i32, 2).unwrap(); + let variant = Variant::from(decimal4); + assert_eq!(variant.as_decimal_int32(), Some((1234_i32, 2))); + + let decimal8 = VariantDecimal8::try_new(12345678901_i64, 2).unwrap(); + let variant = Variant::from(decimal8); + assert_eq!(variant.as_decimal_int64(), Some((12345678901_i64, 2))); + + let decimal16 = VariantDecimal16::try_new(123456789012345678901234567890_i128, 2).unwrap(); + let variant = Variant::from(decimal16); + assert_eq!( + variant.as_decimal_int128(), + Some((123456789012345678901234567890_i128, 2)) + ); + } + + #[test] + fn test_invalid_variant_decimal_conversion() { + let decimal4 = VariantDecimal4::try_new(123456789_i32, 20); + assert!(decimal4.is_err(), "i32 overflow should fail"); + } } diff --git a/parquet-variant/tests/variant_interop.rs b/parquet-variant/tests/variant_interop.rs index bfa2ab267c27..be63357422e4 100644 --- a/parquet-variant/tests/variant_interop.rs +++ b/parquet-variant/tests/variant_interop.rs @@ -24,7 +24,9 @@ use std::fs; use std::path::{Path, PathBuf}; use chrono::NaiveDate; -use parquet_variant::{ShortString, Variant, VariantBuilder}; +use parquet_variant::{ + ShortString, Variant, VariantBuilder, VariantDecimal16, VariantDecimal4, VariantDecimal8, +}; fn cases_dir() -> PathBuf { Path::new(env!("CARGO_MANIFEST_DIR")) @@ -63,9 +65,10 @@ fn get_primitive_cases() -> Vec<(&'static str, Variant<'static, 'static>)> { ("primitive_boolean_false", Variant::BooleanFalse), ("primitive_boolean_true", Variant::BooleanTrue), ("primitive_date", Variant::Date(NaiveDate::from_ymd_opt(2025, 4 , 16).unwrap())), - ("primitive_decimal4", Variant::Decimal4{integer: 1234, scale: 2}), - ("primitive_decimal8", Variant::Decimal8{integer: 1234567890, scale: 2}), - ("primitive_decimal16", Variant::Decimal16{integer: 1234567891234567890, scale: 2}), + ("primitive_decimal4", Variant::from(VariantDecimal4::try_new(1234i32, 2u8).unwrap())), + // ("primitive_decimal8", Variant::Decimal8{integer: 1234567890, scale: 2}), + ("primitive_decimal8", Variant::Decimal8(VariantDecimal8::try_new(1234567890,2).unwrap())), + ("primitive_decimal16", Variant::Decimal16(VariantDecimal16::try_new(1234567891234567890, 2).unwrap())), ("primitive_float", Variant::Float(1234567890.1234)), ("primitive_double", Variant::Double(1234567890.1234)), ("primitive_int8", Variant::Int8(42)), @@ -123,10 +126,7 @@ fn variant_object_primitive() { // spark wrote this as a decimal4 (not a double) ( "double_field", - Variant::Decimal4 { - integer: 123456789, - scale: 8, - }, + Variant::Decimal4(VariantDecimal4::try_new(123456789, 8).unwrap()), ), ("int_field", Variant::Int8(1)), ("null_field", Variant::Null), @@ -210,7 +210,10 @@ fn variant_object_builder() { // The double field is actually encoded as decimal4 with scale 8 // Value: 123456789, Scale: 8 -> 1.23456789 - obj.append_value("double_field", (123456789i32, 8u8)); + obj.append_value( + "double_field", + VariantDecimal4::try_new(123456789i32, 8u8).unwrap(), + ); obj.append_value("boolean_true_field", true); obj.append_value("boolean_false_field", false); obj.append_value("string_field", "Apache Parquet");