Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 86 additions & 5 deletions parquet/src/column/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -921,12 +921,16 @@ impl<T: DataType> ColumnWriterImpl<T> {
}
}

#[allow(clippy::eq_op)]
fn update_page_min_max(&mut self, val: &T::T) {
if self.min_page_value.as_ref().map_or(true, |min| min > val) {
self.min_page_value = Some(val.clone());
}
if self.max_page_value.as_ref().map_or(true, |max| max < val) {
self.max_page_value = Some(val.clone());
// simple "is not NaN" check that works for all types: NaN is the only value that is not equal to itself
if val == val {
Comment on lines +926 to +927
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we introduce some generic T::is_nan function into the data type abstraction layer (that basically returns False for every non-float/double type and calls .is_nan() for float/double) or is this local workaround good enough?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's necessary. It'll just be adding overhead for us.

The compiler can figure out what we're trying to do, and will optimise out the branch for anything that's not a floating point number.

Look at this example: https://godbolt.org/z/z7M3Mfqzz

The is_nan::<i32>() always returns false, and the ucomisd does the floating point comparison

if self.min_page_value.as_ref().map_or(true, |min| min > val) {
self.min_page_value = Some(val.clone());
}
if self.max_page_value.as_ref().map_or(true, |max| max < val) {
self.max_page_value = Some(val.clone());
}
}
}

Expand Down Expand Up @@ -1652,6 +1656,68 @@ mod tests {
);
}

/// A NaN sitting between two regular `f32` values must not end up in the
/// written statistics: min/max are expected to be the regular values.
#[test]
fn test_float_statistics_nan_middle() {
    let stats = statistics_roundtrip::<FloatType>(&[1.0, f32::NAN, 2.0]);
    assert!(stats.has_min_max_set());
    match stats {
        Statistics::Float(typed) => {
            assert_eq!(typed.min(), &1.0);
            assert_eq!(typed.max(), &2.0);
        }
        _ => panic!("expecting Statistics::Float"),
    }
}

/// A leading NaN must not be picked up as the initial min/max for `f32`
/// data: the statistics are expected to cover only the regular values.
#[test]
fn test_float_statistics_nan_start() {
    let stats = statistics_roundtrip::<FloatType>(&[f32::NAN, 1.0, 2.0]);
    assert!(stats.has_min_max_set());
    match stats {
        Statistics::Float(typed) => {
            assert_eq!(typed.min(), &1.0);
            assert_eq!(typed.max(), &2.0);
        }
        _ => panic!("expecting Statistics::Float"),
    }
}

/// When every `f32` value is NaN, the statistics are still expected to be of
/// the float variant but must report that no min/max has been set.
#[test]
fn test_float_statistics_nan_only() {
    let all_nan = [f32::NAN, f32::NAN];
    let stats: Statistics = statistics_roundtrip::<FloatType>(&all_nan);
    assert!(!stats.has_min_max_set());
    assert!(matches!(stats, Statistics::Float(_)));
}

/// A NaN sitting between two regular `f64` values must not end up in the
/// written statistics: min/max are expected to be the regular values.
#[test]
fn test_double_statistics_nan_middle() {
    let stats = statistics_roundtrip::<DoubleType>(&[1.0, f64::NAN, 2.0]);
    assert!(stats.has_min_max_set());
    match stats {
        Statistics::Double(typed) => {
            assert_eq!(typed.min(), &1.0);
            assert_eq!(typed.max(), &2.0);
        }
        // The failure message names the variant this test actually expects
        // (it previously said `Float`, copied from the f32 tests).
        _ => panic!("expecting Statistics::Double"),
    }
}

/// A leading NaN must not be picked up as the initial min/max for `f64`
/// data: the statistics are expected to cover only the regular values.
#[test]
fn test_double_statistics_nan_start() {
    let stats = statistics_roundtrip::<DoubleType>(&[f64::NAN, 1.0, 2.0]);
    assert!(stats.has_min_max_set());
    match stats {
        Statistics::Double(typed) => {
            assert_eq!(typed.min(), &1.0);
            assert_eq!(typed.max(), &2.0);
        }
        // The failure message names the variant this test actually expects
        // (it previously said `Float`, copied from the f32 tests).
        _ => panic!("expecting Statistics::Double"),
    }
}

/// When every `f64` value is NaN, the statistics are still expected to be of
/// the double variant but must report that no min/max has been set.
#[test]
fn test_double_statistics_nan_only() {
    let all_nan = [f64::NAN, f64::NAN];
    let stats: Statistics = statistics_roundtrip::<DoubleType>(&all_nan);
    assert!(!stats.has_min_max_set());
    assert!(matches!(stats, Statistics::Double(_)));
}

/// Performs write-read roundtrip with randomly generated values and levels.
/// `max_size` is the maximum number of values or levels (if `max_def_level` > 0) to write
/// for a column.
Expand Down Expand Up @@ -1905,4 +1971,19 @@ mod tests {
Ok(())
}
}

/// Writes `values` as a single batch (no definition or repetition levels)
/// through a column writer built from [`get_test_page_writer`] and
/// [`get_test_column_writer`], closes the writer, and returns a clone of the
/// statistics found in the resulting column metadata.
///
/// # Panics
///
/// Panics if writing or closing the writer fails, or if the metadata carries
/// no statistics.
fn statistics_roundtrip<T: DataType>(values: &[<T as DataType>::T]) -> Statistics {
    let page_writer = get_test_page_writer();
    let default_props = WriterProperties::builder().build();
    let mut writer =
        get_test_column_writer::<T>(page_writer, 0, 0, Arc::new(default_props));
    writer.write_batch(values, None, None).unwrap();

    // Only the metadata is of interest here; the byte and row counts are ignored.
    let (_, _, metadata) = writer.close().unwrap();
    match metadata.statistics() {
        Some(found) => found.clone(),
        None => panic!("metadata missing statistics"),
    }
}
}