Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 40 additions & 23 deletions hugr-core/src/envelope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ fn read_impl(
header: EnvelopeHeader,
registry: &ExtensionRegistry,
) -> Result<Package, EnvelopeError> {
let (package, combined_registry) = match header.format {
let package = match header.format {
#[allow(deprecated)]
EnvelopeFormat::PackageJson => Ok(package_json::from_json_reader(payload, registry)?),
EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
Expand All @@ -373,8 +373,7 @@ fn read_impl(
}?;

package.modules.iter().try_for_each(|module| {
check_breaking_extensions(module, &combined_registry)
.map_err(|err| WithGenerator::new(err, &package.modules))
check_breaking_extensions(module).map_err(|err| WithGenerator::new(err, &package.modules))
})?;
Ok(package)
}
Expand All @@ -393,7 +392,7 @@ fn decode_model(
mut stream: impl BufRead,
extension_registry: &ExtensionRegistry,
format: EnvelopeFormat,
) -> Result<(Package, ExtensionRegistry), EnvelopeError> {
) -> Result<Package, EnvelopeError> {
check_model_version(format)?;
let bump = Bump::default();
let model_package = hugr_model::v0::binary::read_from_reader(&mut stream, &bump)?;
Expand All @@ -406,7 +405,7 @@ fn decode_model(

let package = import_package(&model_package, &extension_registry)?;

Ok((package, extension_registry))
Ok(package)
}

fn check_model_version(format: EnvelopeFormat) -> Result<(), EnvelopeError> {
Expand All @@ -430,7 +429,7 @@ fn decode_model_ast(
mut stream: impl BufRead,
extension_registry: &ExtensionRegistry,
format: EnvelopeFormat,
) -> Result<(Package, ExtensionRegistry), EnvelopeError> {
) -> Result<Package, EnvelopeError> {
check_model_version(format)?;

let mut extension_registry = extension_registry.clone();
Expand Down Expand Up @@ -458,7 +457,7 @@ fn decode_model_ast(

let package = import_package(&model_package, &extension_registry)?;

Ok((package, extension_registry))
Ok(package)
}

/// Internal implementation of [`write_envelope`] to call with/without the zstd compression wrapper.
Expand Down Expand Up @@ -558,12 +557,20 @@ pub enum ExtensionBreakingError {
#[error("Failed to deserialize used extensions metadata")]
Deserialization(#[from] serde_json::Error),
}
/// If HUGR metadata contains a list of used extensions, under the key [`USED_EXTENSIONS_KEY`],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this comment is true, from tests it seems that if MAJOR is 0 then minor must match, otherwise a later minor version is OK.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment is wrong, fixed

/// and extension is resolved in the HUGR, check that the
/// version of the extension in the metadata matches the resolved version (up to
/// MAJOR.MINOR).
fn check_breaking_extensions(hugr: impl crate::HugrView) -> Result<(), ExtensionBreakingError> {
check_breaking_extensions_against_registry(&hugr, hugr.extensions())
}

/// If HUGR metadata contains a list of used extensions, under the key [`USED_EXTENSIONS_KEY`],
/// and extension is registered in the given registry, check that the
/// version of the extension in the metadata matches the registered version (up to
/// MAJOR.MINOR).
fn check_breaking_extensions(
hugr: impl crate::HugrView,
fn check_breaking_extensions_against_registry(
hugr: &impl crate::HugrView,
registry: &ExtensionRegistry,
) -> Result<(), ExtensionBreakingError> {
let Some(exts) = hugr.get_metadata(hugr.module_root(), USED_EXTENSIONS_KEY) else {
Expand Down Expand Up @@ -740,6 +747,11 @@ pub(crate) mod test {
assert_eq!(package, new_package);
}

/// Test helper to call `check_breaking_extensions_against_registry`
fn check(hugr: &Hugr, registry: &ExtensionRegistry) -> Result<(), ExtensionBreakingError> {
check_breaking_extensions_against_registry(hugr, registry)
}

#[rstest]
#[case::simple(simple_package())]
fn test_check_breaking_extensions(#[case] mut package: Package) {
Expand All @@ -756,35 +768,40 @@ pub(crate) mod test {
let mut hugr = package.modules.remove(0);

// No metadata - should pass
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));
assert_matches!(check(&hugr, &registry), Ok(()));

// Matching version for v0 - should pass
let used_exts = json!([{ "name": "test-v0", "version": "0.2.3" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));
assert_matches!(check(&hugr, &registry), Ok(()));

// Matching major/minor but different patch for v0 - should pass
let used_exts = json!([{ "name": "test-v0", "version": "0.2.4" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));
assert_matches!(check(&hugr, &registry), Ok(()));

//Different minor version for v0 - should fail
let used_exts = json!([{ "name": "test-v0", "version": "0.3.3" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(
check_breaking_extensions(&hugr, &registry),
check(&hugr, &registry),
Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
name,
registered,
used
})) if name == "test-v0" && registered == Version::new(0, 2, 3) && used == Version::new(0, 3, 3)
);

assert!(
check_breaking_extensions(&hugr).is_ok(),
"Extension is not actually used in the HUGR, should be ignored by full check"
);

// Different major version for v0 - should fail
let used_exts = json!([{ "name": "test-v0", "version": "1.2.3" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(
check_breaking_extensions(&hugr, &registry),
check(&hugr, &registry),
Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
name,
registered,
Expand All @@ -795,23 +812,23 @@ pub(crate) mod test {
// Matching version for v1 - should pass
let used_exts = json!([{ "name": "test-v1", "version": "1.2.3" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));
assert_matches!(check(&hugr, &registry), Ok(()));

// Different minor version for v1 - should pass
let used_exts = json!([{ "name": "test-v1", "version": "1.3.0" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));
assert_matches!(check(&hugr, &registry), Ok(()));

// Different patch for v1 - should pass
let used_exts = json!([{ "name": "test-v1", "version": "1.2.4" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));
assert_matches!(check(&hugr, &registry), Ok(()));

// Different major version for v1 - should fail
let used_exts = json!([{ "name": "test-v1", "version": "2.2.3" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(
check_breaking_extensions(&hugr, &registry),
check(&hugr, &registry),
Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
name,
registered,
Expand All @@ -822,7 +839,7 @@ pub(crate) mod test {
// Non-registered extension - should pass
let used_exts = json!([{ "name": "unknown", "version": "1.0.0" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));
assert_matches!(check(&hugr, &registry), Ok(()));

// Multiple extensions - one mismatch should fail
let used_exts = json!([
Expand All @@ -831,7 +848,7 @@ pub(crate) mod test {
]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(
check_breaking_extensions(&hugr, &registry),
check(&hugr, &registry),
Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
name,
registered,
Expand All @@ -846,7 +863,7 @@ pub(crate) mod test {
json!("not an array"),
);
assert_matches!(
check_breaking_extensions(&hugr, &registry),
check(&hugr, &registry),
Err(ExtensionBreakingError::Deserialization(_))
);

Expand All @@ -856,7 +873,7 @@ pub(crate) mod test {
{ "name": "test-v1", "version": "1.9.9" }
]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));
assert_matches!(check(&hugr, &registry), Ok(()));
}

#[test]
Expand All @@ -875,7 +892,7 @@ pub(crate) mod test {
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);

// Create the error and wrap it with WithGenerator
let err = check_breaking_extensions(&hugr, &registry).unwrap_err();
let err = check_breaking_extensions_against_registry(&hugr, &registry).unwrap_err();
let with_gen = WithGenerator::new(err, &[&hugr]);

let err_msg = with_gen.to_string();
Expand Down
13 changes: 5 additions & 8 deletions hugr-core/src/envelope/package_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{Extension, Hugr};
pub(super) fn from_json_reader(
reader: impl io::Read,
extension_registry: &ExtensionRegistry,
) -> Result<(Package, ExtensionRegistry), PackageEncodingError> {
) -> Result<Package, PackageEncodingError> {
let val: serde_json::Value = serde_json::from_reader(reader)?;

let PackageDeser {
Expand All @@ -38,13 +38,10 @@ pub(super) fn from_json_reader(
.try_for_each(|module| module.resolve_extension_defs(&combined_registry))
.map_err(|err| WithGenerator::new(err, &modules))?;

Ok((
Package {
modules,
extensions: pkg_extensions,
},
combined_registry,
))
Ok(Package {
modules,
extensions: pkg_extensions,
})
}

/// Write the Package in json format into an io writer.
Expand Down
Loading