-
Notifications
You must be signed in to change notification settings - Fork 13
feat: Detect and fail on unrecognised envelope flags #2453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,8 @@ | |
| use std::io::{Read, Write}; | ||
| use std::num::NonZeroU8; | ||
|
|
||
| use itertools::Itertools; | ||
|
|
||
| use super::EnvelopeError; | ||
|
|
||
| /// Magic number identifying the start of an envelope. | ||
|
|
@@ -11,6 +13,12 @@ use super::EnvelopeError; | |
| /// to avoid accidental collisions with other file formats. | ||
| pub const MAGIC_NUMBERS: &[u8] = "HUGRiHJv".as_bytes(); | ||
|
|
||
| /// The all-unset header flags configuration. | ||
| /// Bit 7 is always set to ensure we have a printable ASCII character. | ||
| const DEFAULT_FLAGS: u8 = 0b0100_0000u8; | ||
| /// The ZSTD flag bit in the header's flags. | ||
| const ZSTD_FLAG: u8 = 0b0000_0001; | ||
|
|
||
| /// Header at the start of a binary envelope file. | ||
| /// | ||
| /// See the [`crate::envelope`] module documentation for the binary format. | ||
|
|
@@ -224,8 +232,10 @@ impl EnvelopeHeader { | |
| let format_bytes = [self.format as u8]; | ||
| writer.write_all(&format_bytes)?; | ||
| // Next is the flags byte. | ||
| let mut flags = 0b01000000u8; | ||
| flags |= u8::from(self.zstd); | ||
| let mut flags = DEFAULT_FLAGS; | ||
| if self.zstd { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks that's much clearer! |
||
| flags |= ZSTD_FLAG; | ||
| } | ||
| writer.write_all(&[flags])?; | ||
|
|
||
| Ok(()) | ||
|
|
@@ -259,7 +269,16 @@ impl EnvelopeHeader { | |
| // Next is the flags byte. | ||
| let mut flags_bytes = [0; 1]; | ||
| reader.read_exact(&mut flags_bytes)?; | ||
| let zstd = flags_bytes[0] & 0x1 != 0; | ||
| let flags: u8 = flags_bytes[0]; | ||
|
|
||
| let zstd = flags & ZSTD_FLAG != 0; | ||
|
|
||
| // Check if there's any unrecognized flags. | ||
| let other_flags = (flags ^ DEFAULT_FLAGS) & !ZSTD_FLAG; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd find it easier to read as
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not quite the same. DEFAULT_FLAGS has an always true value, so if there's a zero there we have to complain too. |
||
| if other_flags != 0 { | ||
| let flag_ids = (0..8).filter(|i| other_flags & (1 << i) != 0).collect_vec(); | ||
| return Err(EnvelopeError::FlagUnsupported { flag_ids }); | ||
| } | ||
|
|
||
| Ok(Self { format, zstd }) | ||
| } | ||
|
|
@@ -268,6 +287,7 @@ impl EnvelopeHeader { | |
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
| use cool_asserts::assert_matches; | ||
| use rstest::rstest; | ||
|
|
||
| #[rstest] | ||
|
|
@@ -296,4 +316,35 @@ mod tests { | |
| let read_header = EnvelopeHeader::read(&mut buffer.as_slice()).unwrap(); | ||
| assert_eq!(header, read_header); | ||
| } | ||
|
|
||
| #[rstest] | ||
| fn header_errors() { | ||
| let header = EnvelopeHeader { | ||
| format: EnvelopeFormat::Model, | ||
| zstd: false, | ||
| }; | ||
| let mut buffer = Vec::new(); | ||
| header.write(&mut buffer).unwrap(); | ||
|
|
||
| assert_eq!(buffer.len(), 10); | ||
| let flags = buffer[9]; | ||
| assert_eq!(flags, DEFAULT_FLAGS); | ||
|
|
||
| // Invalid magic | ||
| let mut invalid_magic = buffer.clone(); | ||
| invalid_magic[7] = 0xFF; | ||
| assert_matches!( | ||
| EnvelopeHeader::read(&mut invalid_magic.as_slice()), | ||
| Err(EnvelopeError::MagicNumber { .. }) | ||
| ); | ||
|
|
||
| // Unrecognised flags | ||
| let mut unrecognised_flags = buffer.clone(); | ||
| unrecognised_flags[9] |= 0b0001_0010; | ||
| assert_matches!( | ||
| EnvelopeHeader::read(&mut unrecognised_flags.as_slice()), | ||
| Err(EnvelopeError::FlagUnsupported { flag_ids }) | ||
| => assert_eq!(flag_ids, vec![1, 4]) | ||
| ); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -46,6 +46,12 @@ | |
| # This is a hard-coded magic number that identifies the start of a HUGR envelope. | ||
| MAGIC_NUMBERS = b"HUGRiHJv" | ||
|
|
||
| # The all-unset header flags configuration. | ||
| # Bit 7 is always set to ensure we have a printable ASCII character. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you can't reuse the same consts as un Rust, can you maybe add a comment about which file their Rust counterparts are defined in? |
||
| _DEFAULT_FLAGS = 0b0100_0000 | ||
| # The ZSTD flag bit in the header's flags. | ||
| _ZSTD_FLAG = 0b0000_0001 | ||
|
|
||
|
|
||
| def make_envelope(package: Package | Hugr, config: EnvelopeConfig) -> bytes: | ||
| """Encode a HUGR or Package into an envelope, using the given configuration.""" | ||
|
|
@@ -180,9 +186,9 @@ class EnvelopeHeader: | |
| def to_bytes(self) -> bytes: | ||
| header_bytes = bytearray(MAGIC_NUMBERS) | ||
| header_bytes.append(self.format.value) | ||
| flags = 0b01000000 | ||
| flags = _DEFAULT_FLAGS | ||
| if self.zstd: | ||
| flags |= 0b00000001 | ||
| flags |= _ZSTD_FLAG | ||
| header_bytes.append(flags) | ||
| return bytes(header_bytes) | ||
|
|
||
|
|
@@ -204,7 +210,15 @@ def from_bytes(data: bytes) -> EnvelopeHeader: | |
| format: EnvelopeFormat = EnvelopeFormat(data[8]) | ||
|
|
||
| flags = data[9] | ||
| zstd = bool(flags & 0b00000001) | ||
| zstd = bool(flags & _ZSTD_FLAG) | ||
| other_flags = (flags ^ _DEFAULT_FLAGS) & ~_ZSTD_FLAG | ||
| if other_flags: | ||
| flag_ids = [i for i in range(8) if other_flags & (1 << i)] | ||
| msg = ( | ||
| f"Unrecognised Envelope flags {flag_ids}." | ||
| + " Please update your HUGR version." | ||
| ) | ||
| raise ValueError(msg) | ||
|
|
||
| return EnvelopeHeader(format=format, zstd=zstd) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: you're using # in the singular case. You can use it too in the plural by changing the join argument to ", #" as well as adding one to the string