diff --git a/tensorboard/data/server/BUILD b/tensorboard/data/server/BUILD index 5c2387adde..32694283ef 100644 --- a/tensorboard/data/server/BUILD +++ b/tensorboard/data/server/BUILD @@ -32,6 +32,7 @@ rust_library( "reservoir.rs", "scripted_reader.rs", "tf_record.rs", + "types.rs", ] + _checked_in_proto_files, edition = "2018", deps = [ diff --git a/tensorboard/data/server/lib.rs b/tensorboard/data/server/lib.rs index 1230ddfe09..fdfd4c2d65 100644 --- a/tensorboard/data/server/lib.rs +++ b/tensorboard/data/server/lib.rs @@ -19,6 +19,7 @@ pub mod event_file; pub mod masked_crc; pub mod reservoir; pub mod tf_record; +pub mod types; #[cfg(test)] mod scripted_reader; diff --git a/tensorboard/data/server/reservoir.rs b/tensorboard/data/server/reservoir.rs index 0d4dd3f62b..ba8131ad24 100644 --- a/tensorboard/data/server/reservoir.rs +++ b/tensorboard/data/server/reservoir.rs @@ -21,6 +21,8 @@ use rand::{ }; use rand_chacha::ChaCha20Rng; +use crate::types::Step; + /// A [reservoir sampling] data structure, with support for preemption and deferred "commits" of /// records to a separate destination for better concurrency. /// @@ -98,10 +100,6 @@ pub struct StageReservoir { seen: usize, } -/// A step associated with a record, strictly increasing over time within a record stream. -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Copy, Clone)] -pub struct Step(pub i64); - /// A buffer of records that have been committed and not yet evicted from the reservoir. /// /// This is a snapshot of the reservoir contents at some point in time that is periodically updated diff --git a/tensorboard/data/server/types.rs b/tensorboard/data/server/types.rs new file mode 100644 index 0000000000..59a764504a --- /dev/null +++ b/tensorboard/data/server/types.rs @@ -0,0 +1,110 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +//! Core simple types. + +use std::borrow::Borrow; + +/// A step associated with a record, strictly increasing over time within a record stream. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Copy, Clone)] +pub struct Step(pub i64); + +/// The wall time of a TensorBoard event. +/// +/// Wall times represent floating-point seconds since Unix epoch. They must be finite and non-NaN. +#[derive(Debug, PartialEq, PartialOrd, Copy, Clone)] +pub struct WallTime(f64); + +impl WallTime { + /// Parses a wall time from a time stamp representing seconds since Unix epoch. + /// + /// Returns `None` if the given time is infinite or NaN. + pub fn new(time: f64) -> Option { + if time.is_finite() { + Some(WallTime(time)) + } else { + None + } + } +} + +// Wall times are totally ordered and have a total equivalence relation, since we guarantee that +// they are not NaN. +#[allow(clippy::derive_ord_xor_partial_ord)] // okay because it agrees with `PartialOrd` impl +impl Ord for WallTime { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.partial_cmp(&other) + .unwrap_or_else(|| unreachable!("{:?} <> {:?}", &self, &other)) + } +} +impl Eq for WallTime {} + +impl From for f64 { + fn from(wt: WallTime) -> f64 { + wt.0 + } +} + +/// The name of a time series within the context of a run. +/// +/// Tag names are valid Unicode text strings. They should be non-empty, though this type does not +/// enforce that. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] +pub struct Tag(pub String); + +impl Borrow for Tag { + fn borrow(&self) -> &str { + &self.0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tag_hash_map_str_access() { + use std::collections::HashMap; + let mut m: HashMap = HashMap::new(); + m.insert(Tag("accuracy".to_string()), 1); + m.insert(Tag("loss".to_string()), 2); + // We can call `get` given only a `&str`, not an owned `Tag`. + assert_eq!(m.get("accuracy"), Some(&1)); + assert_eq!(m.get("xent"), None); + } + + #[test] + fn test_wall_time() { + assert_eq!(WallTime::new(f64::INFINITY), None); + assert_eq!(WallTime::new(-f64::INFINITY), None); + assert_eq!(WallTime::new(f64::NAN), None); + + assert_eq!(f64::from(WallTime::new(1234.5).unwrap()), 1234.5); + assert!(WallTime::new(1234.5) < WallTime::new(2345.625)); + + let mut actual = vec![ + WallTime::new(123.0).unwrap(), + WallTime::new(-456.0).unwrap(), + WallTime::new(789.0).unwrap(), + ]; + actual.sort(); + let expected = vec![ + WallTime::new(-456.0).unwrap(), + WallTime::new(123.0).unwrap(), + WallTime::new(789.0).unwrap(), + ]; + assert_eq!(actual, expected); + } +}