diff --git a/parquet-variant/Cargo.toml b/parquet-variant/Cargo.toml index 838ca7de8885..6bec373d0204 100644 --- a/parquet-variant/Cargo.toml +++ b/parquet-variant/Cargo.toml @@ -38,4 +38,8 @@ chrono = { workspace = true } serde_json = "1.0" base64 = "0.22" +[dev-dependencies] +paste = { version = "1.0" } + + [lib] diff --git a/parquet-variant/src/decoder.rs b/parquet-variant/src/decoder.rs index e73911aa2953..6b5c1310787c 100644 --- a/parquet-variant/src/decoder.rs +++ b/parquet-variant/src/decoder.rs @@ -283,157 +283,182 @@ pub(crate) fn decode_short_string(metadata: u8, data: &[u8]) -> Result Result<(), ArrowError> { - let data = [0x2a]; - let result = decode_int8(&data)?; - assert_eq!(result, 42); - Ok(()) + use paste::paste; + + macro_rules! test_decoder_bounds { + ($test_name:ident, $data:expr, $decode_fn:ident, $expected:expr) => { + paste! { + #[test] + fn [<$test_name _exact_length>]() { + let result = $decode_fn(&$data).unwrap(); + assert_eq!(result, $expected); + } + + #[test] + fn [<$test_name _truncated_length>]() { + // Remove the last byte of data so that there is not enough to decode + let truncated_data = &$data[.. $data.len() - 1]; + let result = $decode_fn(truncated_data); + assert!(matches!(result, Err(ArrowError::InvalidArgumentError(_)))); + } + } + }; } - #[test] - fn test_i16() -> Result<(), ArrowError> { - let data = [0xd2, 0x04]; - let result = decode_int16(&data)?; - assert_eq!(result, 1234); - Ok(()) + mod integer { + use super::*; + + test_decoder_bounds!(test_i8, [0x2a], decode_int8, 42); + test_decoder_bounds!(test_i16, [0xd2, 0x04], decode_int16, 1234); + test_decoder_bounds!(test_i32, [0x40, 0xe2, 0x01, 0x00], decode_int32, 123456); + test_decoder_bounds!( + test_i64, + [0x15, 0x81, 0xe9, 0x7d, 0xf4, 0x10, 0x22, 0x11], + decode_int64, + 1234567890123456789 + ); } - #[test] - fn test_i32() -> Result<(), ArrowError> { - let data = [0x40, 0xe2, 0x01, 0x00]; - let result = decode_int32(&data)?; - assert_eq!(result, 123456); - Ok(()) - } + mod decimal { + use super::*; + + test_decoder_bounds!( + test_decimal4, + [ + 0x02, // Scale + 0xd2, 0x04, 0x00, 0x00, // Unscaled Value + ], + decode_decimal4, + (1234, 2) + ); - #[test] - fn test_i64() -> Result<(), ArrowError> { - let data = [0x15, 0x81, 0xe9, 0x7d, 0xf4, 0x10, 0x22, 0x11]; - let result = decode_int64(&data)?; - assert_eq!(result, 1234567890123456789); - Ok(()) - } + test_decoder_bounds!( + test_decimal8, + [ + 0x02, // Scale + 0xd2, 0x02, 0x96, 0x49, 0x00, 0x00, 0x00, 0x00, // Unscaled Value + ], + decode_decimal8, + (1234567890, 2) + ); - #[test] - fn test_decimal4() -> Result<(), ArrowError> { - let data = [ - 0x02, // Scale - 0xd2, 0x04, 0x00, 0x00, // Integer - ]; - let result = decode_decimal4(&data)?; - assert_eq!(result, (1234, 2)); - Ok(()) + test_decoder_bounds!( + test_decimal16, + [ + 0x02, // Scale + 0xd2, 0xb6, 0x23, 0xc0, 0xf4, 0x10, 0x22, 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, // Unscaled Value + ], + decode_decimal16, + (1234567891234567890, 2) + ); } - #[test] - fn test_decimal8() -> Result<(), ArrowError> { - let data = [ - 0x02, // Scale - 0xd2, 0x02, 0x96, 0x49, 0x00, 0x00, 0x00, 0x00, // Integer - ]; - let result = decode_decimal8(&data)?; - assert_eq!(result, (1234567890, 2)); - Ok(()) - } + mod float { + use super::*; - #[test] - fn test_decimal16() -> Result<(), ArrowError> { - let data = [ - 0x02, // Scale - 0xd2, 0xb6, 0x23, 0xc0, 0xf4, 0x10, 0x22, 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, // Integer - ]; - let result = decode_decimal16(&data)?; - assert_eq!(result, (1234567891234567890, 2)); - Ok(()) - } + test_decoder_bounds!( + test_float, + [0x06, 0x2c, 0x93, 0x4e], + decode_float, + 1234567890.1234 + ); - #[test] - fn test_float() -> Result<(), ArrowError> { - let data = [0x06, 0x2c, 0x93, 0x4e]; - let result = decode_float(&data)?; - assert_eq!(result, 1234567890.1234); - Ok(()) + test_decoder_bounds!( + test_double, + [0xc9, 0xe5, 0x87, 0xb4, 0x80, 0x65, 0xd2, 0x41], + decode_double, + 1234567890.1234 + ); } - #[test] - fn test_double() -> Result<(), ArrowError> { - let data = [0xc9, 0xe5, 0x87, 0xb4, 0x80, 0x65, 0xd2, 0x41]; - let result = decode_double(&data)?; - assert_eq!(result, 1234567890.1234); - Ok(()) - } + mod datetime { + use super::*; - #[test] - fn test_date() -> Result<(), ArrowError> { - let data = [0xe2, 0x4e, 0x0, 0x0]; - let result = decode_date(&data)?; - assert_eq!(result, NaiveDate::from_ymd_opt(2025, 4, 16).unwrap()); - Ok(()) - } + test_decoder_bounds!( + test_date, + [0xe2, 0x4e, 0x0, 0x0], + decode_date, + NaiveDate::from_ymd_opt(2025, 4, 16).unwrap() + ); - #[test] - fn test_timestamp_micros() -> Result<(), ArrowError> { - let data = [0xe0, 0x52, 0x97, 0xdd, 0xe7, 0x32, 0x06, 0x00]; - let result = decode_timestamp_micros(&data)?; - assert_eq!( - result, + test_decoder_bounds!( + test_timestamp_micros, + [0xe0, 0x52, 0x97, 0xdd, 0xe7, 0x32, 0x06, 0x00], + decode_timestamp_micros, NaiveDate::from_ymd_opt(2025, 4, 16) .unwrap() .and_hms_milli_opt(16, 34, 56, 780) .unwrap() .and_utc() ); - Ok(()) - } - #[test] - fn test_timestampntz_micros() -> Result<(), ArrowError> { - let data = [0xe0, 0x52, 0x97, 0xdd, 0xe7, 0x32, 0x06, 0x00]; - let result = decode_timestampntz_micros(&data)?; - assert_eq!( - result, + test_decoder_bounds!( + test_timestampntz_micros, + [0xe0, 0x52, 0x97, 0xdd, 0xe7, 0x32, 0x06, 0x00], + decode_timestampntz_micros, NaiveDate::from_ymd_opt(2025, 4, 16) .unwrap() .and_hms_milli_opt(16, 34, 56, 780) .unwrap() ); - Ok(()) } #[test] - fn test_binary() -> Result<(), ArrowError> { + fn test_binary_exact_length() { let data = [ 0x09, 0, 0, 0, // Length of binary data, 4-byte little-endian 0x03, 0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, ]; - let result = decode_binary(&data)?; + let result = decode_binary(&data).unwrap(); assert_eq!( result, [0x03, 0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe] ); - Ok(()) } #[test] - fn test_short_string() -> Result<(), ArrowError> { + fn test_binary_truncated_length() { + let data = [ + 0x09, 0, 0, 0, // Length of binary data, 4-byte little-endian + 0x03, 0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca, + ]; + let result = decode_binary(&data); + assert!(matches!(result, Err(ArrowError::InvalidArgumentError(_)))); + } + + #[test] + fn test_short_string_exact_length() { let data = [b'H', b'e', b'l', b'l', b'o', b'o']; - let result = decode_short_string(1 | 5 << 2, &data)?; + let result = decode_short_string(1 | 5 << 2, &data).unwrap(); assert_eq!(result.0, "Hello"); - Ok(()) } #[test] - fn test_string() -> Result<(), ArrowError> { + fn test_short_string_truncated_length() { + let data = [b'H', b'e', b'l']; + let result = decode_short_string(1 | 5 << 2, &data); + assert!(matches!(result, Err(ArrowError::InvalidArgumentError(_)))); + } + + #[test] + fn test_string_exact_length() { let data = [ 0x05, 0, 0, 0, // Length of string, 4-byte little-endian b'H', b'e', b'l', b'l', b'o', b'o', ]; - let result = decode_long_string(&data)?; + let result = decode_long_string(&data).unwrap(); assert_eq!(result, "Hello"); - Ok(()) + } + + #[test] + fn test_string_truncated_length() { + let data = [ + 0x05, 0, 0, 0, // Length of string, 4-byte little-endian + b'H', b'e', b'l', + ]; + let result = decode_long_string(&data); + assert!(matches!(result, Err(ArrowError::InvalidArgumentError(_)))); } #[test]