Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions hugr-core/src/envelope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,15 @@ pub enum EnvelopeError {
#[from]
source: crate::extension::ExtensionRegistryLoadError,
},
/// The specified payload format is not supported.
#[error(
"The envelope configuration has unknown {}. Please update your HUGR version.",
if flag_ids.len() == 1 {format!("flag #{}", flag_ids[0])} else {format!("flags {}", flag_ids.iter().join(", "))}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: you're using # in the singular case. You can use it too in the plural by changing the join argument to ", #" as well as adding one to the string

)]
FlagUnsupported {
/// The unrecognized flag bits.
flag_ids: Vec<usize>,
},
}

/// Internal implementation of [`read_envelope`] to call with/without the zstd decompression wrapper.
Expand Down
57 changes: 54 additions & 3 deletions hugr-core/src/envelope/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
use std::io::{Read, Write};
use std::num::NonZeroU8;

use itertools::Itertools;

use super::EnvelopeError;

/// Magic number identifying the start of an envelope.
Expand All @@ -11,6 +13,12 @@ use super::EnvelopeError;
/// to avoid accidental collisions with other file formats.
pub const MAGIC_NUMBERS: &[u8] = "HUGRiHJv".as_bytes();

/// The all-unset header flags configuration.
/// Bit 7 is always set to ensure we have a printable ASCII character.
const DEFAULT_FLAGS: u8 = 0b0100_0000u8;
/// The ZSTD flag bit in the header's flags.
const ZSTD_FLAG: u8 = 0b0000_0001;

/// Header at the start of a binary envelope file.
///
/// See the [`crate::envelope`] module documentation for the binary format.
Expand Down Expand Up @@ -224,8 +232,10 @@ impl EnvelopeHeader {
let format_bytes = [self.format as u8];
writer.write_all(&format_bytes)?;
// Next is the flags byte.
let mut flags = 0b01000000u8;
flags |= u8::from(self.zstd);
let mut flags = DEFAULT_FLAGS;
if self.zstd {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks that's much clearer!

flags |= ZSTD_FLAG;
}
writer.write_all(&[flags])?;

Ok(())
Expand Down Expand Up @@ -259,7 +269,16 @@ impl EnvelopeHeader {
// Next is the flags byte.
let mut flags_bytes = [0; 1];
reader.read_exact(&mut flags_bytes)?;
let zstd = flags_bytes[0] & 0x1 != 0;
let flags: u8 = flags_bytes[0];

let zstd = flags & ZSTD_FLAG != 0;

// Check if there's any unrecognized flags.
let other_flags = (flags ^ DEFAULT_FLAGS) & !ZSTD_FLAG;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd find it easier to read as

let other_flags = flags & !(DEFAULT_FLAGS | ZSTD_FLAG);

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not quite the same. DEFAULT_FLAGS has an always true value, so if there's a zero there we have to complain too.
(I should make that more explicit)

if other_flags != 0 {
let flag_ids = (0..8).filter(|i| other_flags & (1 << i) != 0).collect_vec();
return Err(EnvelopeError::FlagUnsupported { flag_ids });
}

Ok(Self { format, zstd })
}
Expand All @@ -268,6 +287,7 @@ impl EnvelopeHeader {
#[cfg(test)]
mod tests {
use super::*;
use cool_asserts::assert_matches;
use rstest::rstest;

#[rstest]
Expand Down Expand Up @@ -296,4 +316,35 @@ mod tests {
let read_header = EnvelopeHeader::read(&mut buffer.as_slice()).unwrap();
assert_eq!(header, read_header);
}

#[rstest]
fn header_errors() {
let header = EnvelopeHeader {
format: EnvelopeFormat::Model,
zstd: false,
};
let mut buffer = Vec::new();
header.write(&mut buffer).unwrap();

assert_eq!(buffer.len(), 10);
let flags = buffer[9];
assert_eq!(flags, DEFAULT_FLAGS);

// Invalid magic
let mut invalid_magic = buffer.clone();
invalid_magic[7] = 0xFF;
assert_matches!(
EnvelopeHeader::read(&mut invalid_magic.as_slice()),
Err(EnvelopeError::MagicNumber { .. })
);

// Unrecognised flags
let mut unrecognised_flags = buffer.clone();
unrecognised_flags[9] |= 0b0001_0010;
assert_matches!(
EnvelopeHeader::read(&mut unrecognised_flags.as_slice()),
Err(EnvelopeError::FlagUnsupported { flag_ids })
=> assert_eq!(flag_ids, vec![1, 4])
);
}
}
20 changes: 17 additions & 3 deletions hugr-py/src/hugr/envelope.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@
# This is a hard-coded magic number that identifies the start of a HUGR envelope.
MAGIC_NUMBERS = b"HUGRiHJv"

# The all-unset header flags configuration.
# Bit 7 is always set to ensure we have a printable ASCII character.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you can't reuse the same consts as un Rust, can you maybe add a comment about which file their Rust counterparts are defined in?

_DEFAULT_FLAGS = 0b0100_0000
# The ZSTD flag bit in the header's flags.
_ZSTD_FLAG = 0b0000_0001


def make_envelope(package: Package | Hugr, config: EnvelopeConfig) -> bytes:
"""Encode a HUGR or Package into an envelope, using the given configuration."""
Expand Down Expand Up @@ -180,9 +186,9 @@ class EnvelopeHeader:
def to_bytes(self) -> bytes:
header_bytes = bytearray(MAGIC_NUMBERS)
header_bytes.append(self.format.value)
flags = 0b01000000
flags = _DEFAULT_FLAGS
if self.zstd:
flags |= 0b00000001
flags |= _ZSTD_FLAG
header_bytes.append(flags)
return bytes(header_bytes)

Expand All @@ -204,7 +210,15 @@ def from_bytes(data: bytes) -> EnvelopeHeader:
format: EnvelopeFormat = EnvelopeFormat(data[8])

flags = data[9]
zstd = bool(flags & 0b00000001)
zstd = bool(flags & _ZSTD_FLAG)
other_flags = (flags ^ _DEFAULT_FLAGS) & ~_ZSTD_FLAG
if other_flags:
flag_ids = [i for i in range(8) if other_flags & (1 << i)]
msg = (
f"Unrecognised Envelope flags {flag_ids}."
+ " Please update your HUGR version."
)
raise ValueError(msg)

return EnvelopeHeader(format=format, zstd=zstd)

Expand Down
Loading