Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions hugr-core/src/envelope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,13 @@ fn decode_model(
let bump = Bump::default();
let model_package = hugr_model::v0::binary::read_from_reader(&mut stream, &bump)?;

let mut extension_registry = extension_registry.clone();
if format == EnvelopeFormat::ModelWithExtensions {
let extra_extensions = ExtensionRegistry::load_json(stream, &extension_registry)?;
extension_registry.extend(extra_extensions);
}
let packaged_extensions = if format == EnvelopeFormat::ModelWithExtensions {
ExtensionRegistry::load_json(stream, extension_registry)?
} else {
ExtensionRegistry::new([])
};

let package = import_package(&model_package, &extension_registry)?;
let package = import_package(&model_package, packaged_extensions, extension_registry)?;

Ok(package)
}
Expand Down Expand Up @@ -432,18 +432,17 @@ fn decode_model_ast(
) -> Result<Package, EnvelopeError> {
check_model_version(format)?;

let mut extension_registry = extension_registry.clone();
if format == EnvelopeFormat::ModelTextWithExtensions {
let packaged_extensions = if format == EnvelopeFormat::ModelTextWithExtensions {
let deserializer = serde_json::Deserializer::from_reader(&mut stream);
// Deserialize the first json object, leaving the rest of the reader unconsumed.
let extra_extensions = deserializer
.into_iter::<Vec<Extension>>()
.next()
.unwrap_or(Ok(vec![]))?;
for ext in extra_extensions {
extension_registry.register_updated(ext);
}
}
ExtensionRegistry::new(extra_extensions.into_iter().map(std::sync::Arc::new))
} else {
ExtensionRegistry::new([])
};

// Read the package into a string, then parse it.
//
Expand All @@ -455,7 +454,7 @@ fn decode_model_ast(
let bump = Bump::default();
let model_package = ast_package.resolve(&bump)?;

let package = import_package(&model_package, &extension_registry)?;
let package = import_package(&model_package, packaged_extensions, extension_registry)?;

Ok(package)
}
Expand Down
13 changes: 9 additions & 4 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,19 +182,24 @@ macro_rules! error_context {
}
}

/// Import a [`Package`] from its model representation.
/// Import a [`Package`] from the model representation
/// of the modules and any included extensions.
pub fn import_package(
package: &table::Package,
extensions: &ExtensionRegistry,
packaged_extensions: ExtensionRegistry,
loaded_extensions: &ExtensionRegistry,
) -> Result<Package, ImportError> {
let mut registry = loaded_extensions.clone();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reading through the rest of this I kept thinking, why are we not building this registry earlier....now I see (below) that we need to keep them separate so we can store the correct one in the Package, so ok.

Should we return this new/merged ExtensionRegistry, though?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The extension resolution forms a merged registry of all the extensions actually used in the HUGR from this larger set of packaged + loaded, that is stored in Hugr.extensions
See the PR this is based on:
#2588

registry.extend(&packaged_extensions);
let modules = package
.modules
.iter()
.map(|module| import_hugr(module, extensions))
.map(|module| import_hugr(module, &registry))
.collect::<Result<Vec<_>, _>>()?;

// This does not panic since the import already requires a module root.
let package = Package::new(modules);
let mut package = Package::new(modules);
package.extensions = packaged_extensions;
Ok(package)
}

Expand Down
46 changes: 44 additions & 2 deletions hugr-core/tests/model.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
#![allow(missing_docs)]

use anyhow::Result;
use rstest::{fixture, rstest};
use std::str::FromStr;

use hugr::std_extensions::std_reg;
use hugr::{
Extension, Hugr,
builder::{Dataflow as _, DataflowHugr as _},
envelope::{EnvelopeConfig, EnvelopeFormat, read_envelope, write_envelope},
extension::prelude::bool_t,
package::Package,
std_extensions::std_reg,
types::Signature,
};
use hugr_core::{export::export_package, import::import_package};
use hugr_model::v0 as model;

fn roundtrip(source: &str) -> Result<String> {
let bump = model::bumpalo::Bump::new();
let package_ast = model::ast::Package::from_str(source)?;
let package_table = package_ast.resolve(&bump)?;
let core = import_package(&package_table, &std_reg())?;
let core = import_package(&package_table, Default::default(), &std_reg())?;
let exported_table = export_package(&core.modules, &core.extensions, &bump);
let exported_ast = exported_table.as_ast().unwrap();

Expand Down Expand Up @@ -83,3 +92,36 @@ test_roundtrip!(
test_roundtrip_entrypoint,
"../../hugr-model/tests/fixtures/model-entrypoint.edn"
);

#[fixture]
fn simple_dfg_hugr() -> Hugr {
let dfg_builder =
hugr::builder::DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t()])).unwrap();
let [i1] = dfg_builder.input_wires_arr();
dfg_builder.finish_hugr_with_outputs([i1]).unwrap()
}

#[rstest]
#[case(EnvelopeFormat::ModelTextWithExtensions)]
#[case(EnvelopeFormat::ModelWithExtensions)]
fn import_package_with_extensions(#[case] format: EnvelopeFormat, simple_dfg_hugr: Hugr) {
let ext = Extension::new_arc(
"miniquantum".try_into().unwrap(),
hugr::extension::Version::new(0, 1, 0),
|_, _| {},
);
let mut package = Package::new([simple_dfg_hugr]);
package.extensions.register_updated(ext);

let mut bytes: Vec<u8> = Vec::new();
write_envelope(&mut bytes, &package, EnvelopeConfig::new(format)).unwrap();

let buff = std::io::BufReader::new(bytes.as_slice());
let (_, loaded_pkg) = read_envelope(buff, &std_reg()).unwrap();

assert_eq!(loaded_pkg.extensions.len(), 1);
let read_ext = loaded_pkg.extensions.iter().next().unwrap();
assert_eq!(read_ext.name(), &"miniquantum".try_into().unwrap());

assert_eq!(package, loaded_pkg);
}
Loading