diff --git a/hugr-cli/src/lib.rs b/hugr-cli/src/lib.rs index d54592db5..810ad4167 100644 --- a/hugr-cli/src/lib.rs +++ b/hugr-cli/src/lib.rs @@ -247,13 +247,35 @@ impl CliArgs { /// /// The `gen-extensions` and `external` commands don't support byte I/O /// and should use the normal `run_cli()` method instead. - pub fn run_with_io(self, input: impl std::io::Read) -> Result> { + pub fn run_with_io(self, input: impl std::io::Read) -> Result, RunWithIoError> { let mut output = Vec::new(); - self.command.run_with_io(Some(input), Some(&mut output))?; - Ok(output) + let is_describe = matches!(self.command, CliCommand::Describe(_)); + let res = self.command.run_with_io(Some(input), Some(&mut output)); + match (res, is_describe) { + (Ok(()), _) => Ok(output), + (Err(e), true) => Err(RunWithIoError::Describe { source: e, output }), + (Err(e), false) => Err(RunWithIoError::Other(e)), + } } } +#[derive(Debug, Error)] +#[non_exhaustive] +#[error("Error running CLI command with IO.")] +/// Error type for `run_with_io` method. +pub enum RunWithIoError { + /// Error describing HUGR package. + Describe { + #[source] + /// Error returned from describe command. + source: anyhow::Error, + /// Describe command output. + output: Vec, + }, + /// Non-describe command error. + Other(anyhow::Error), +} + fn run_external(args: Vec) -> Result<()> { // External subcommand support: invoke `hugr-` if args.is_empty() { diff --git a/hugr-core/src/envelope/reader.rs b/hugr-core/src/envelope/reader.rs index 0d3f7f6d8..0045d0959 100644 --- a/hugr-core/src/envelope/reader.rs +++ b/hugr-core/src/envelope/reader.rs @@ -5,7 +5,7 @@ use hugr_model::v0::table; use itertools::{Either, Itertools as _}; use crate::HugrView as _; -use crate::envelope::description::PackageDesc; +use crate::envelope::description::{ExtensionDesc, ModuleDesc, PackageDesc}; use crate::envelope::header::{EnvelopeFormat, HeaderError}; use crate::envelope::{ EnvelopeError, EnvelopeHeader, ExtensionBreakingError, FormatUnsupportedError, @@ -103,6 +103,32 @@ impl EnvelopeReader { self.registry.extend(extensions); } + /// Handle extension resolution errors by recording missing extensions in the description. + /// + /// This function inspects the error and adds any missing extensions to the module description + /// with a default version of 0.0.0. + fn handle_resolution_error(desc: &mut ModuleDesc, err: &ExtensionResolutionError) { + match err { + ExtensionResolutionError::MissingOpExtension { + missing_extension, .. + } + | ExtensionResolutionError::MissingTypeExtension { + missing_extension, .. + } => desc.extend_used_extensions_resolved([ExtensionDesc::new( + missing_extension, + crate::extension::Version::new(0, 0, 0), + )]), + ExtensionResolutionError::InvalidConstTypes { + missing_extensions, .. + } => desc.extend_used_extensions_resolved( + missing_extensions + .iter() + .map(|ext| ExtensionDesc::new(ext, crate::extension::Version::new(0, 0, 0))), + ), + _ => {} + } + } + fn read_impl(&mut self) -> Result { let mut package = match self.header().format { EnvelopeFormat::PackageJson => self.decode_json()?, @@ -121,7 +147,10 @@ impl EnvelopeReader { check_breaking_extensions(module.extensions(), used_exts.drain(..))?; } - module.resolve_extension_defs(&self.registry)?; + module + .resolve_extension_defs(&self.registry) + .inspect_err(|err| Self::handle_resolution_error(desc, err))?; + // overwrite the description with the actual module read, // cheap so ok to repeat. desc.load_from_hugr(&module); @@ -415,4 +444,77 @@ mod test { assert_eq!(description.header, header); assert_eq!(description.n_modules(), 0); // No valid modules should be set } + + #[test] + fn test_handle_resolution_error() { + use crate::extension::ExtensionId; + use crate::ops::{OpName, constant::ValueName}; + use crate::types::TypeName; + + let mut desc = ModuleDesc::default(); + let handle_error = |d: &mut ModuleDesc, err: &ExtensionResolutionError| { + EnvelopeReader::>>::handle_resolution_error(d, err) + }; + let assert_extensions = |d: &ModuleDesc, expected_ids: &[&ExtensionId]| { + let resolved = d.used_extensions_resolved.as_ref().unwrap(); + assert_eq!(resolved.len(), expected_ids.len()); + let names: Vec<_> = resolved.iter().map(|e| &e.name).collect(); + for ext_id in expected_ids { + assert!(names.contains(&&ext_id.to_string())); + } + assert!( + resolved + .iter() + .all(|e| e.version == crate::extension::Version::new(0, 0, 0)) + ); + }; + + // Test MissingOpExtension + let ext_id = ExtensionId::new("test.extension").unwrap(); + let error = ExtensionResolutionError::MissingOpExtension { + node: None, + op: OpName::new("test.op"), + missing_extension: ext_id.clone(), + available_extensions: vec![], + }; + handle_error(&mut desc, &error); + assert_extensions(&desc, &[&ext_id]); + + // Test MissingTypeExtension + desc.used_extensions_resolved = None; + let ext_id2 = ExtensionId::new("test.extension2").unwrap(); + let error = ExtensionResolutionError::MissingTypeExtension { + node: None, + ty: TypeName::new("test.type"), + missing_extension: ext_id2.clone(), + available_extensions: vec![], + }; + handle_error(&mut desc, &error); + assert_extensions(&desc, &[&ext_id2]); + + // Test InvalidConstTypes with multiple extensions + desc.used_extensions_resolved = None; + let ext_id3 = ExtensionId::new("test.extension3").unwrap(); + let ext_id4 = ExtensionId::new("test.extension4").unwrap(); + let mut missing_exts = crate::extension::ExtensionSet::new(); + missing_exts.insert(ext_id3.clone()); + missing_exts.insert(ext_id4.clone()); + + let error = ExtensionResolutionError::InvalidConstTypes { + value: ValueName::new("test.value"), + missing_extensions: missing_exts, + }; + handle_error(&mut desc, &error); + assert_extensions(&desc, &[&ext_id3, &ext_id4]); + + // Test other error variant (should not add anything) + desc.used_extensions_resolved = None; + let error = ExtensionResolutionError::WrongTypeDefExtension { + extension: ExtensionId::new("ext1").unwrap(), + def: TypeName::new("def"), + wrong_extension: ExtensionId::new("ext2").unwrap(), + }; + handle_error(&mut desc, &error); + assert!(desc.used_extensions_resolved.is_none()); + } } diff --git a/hugr-py/rust/lib.rs b/hugr-py/rust/lib.rs index f988498d8..7bb8dacf8 100644 --- a/hugr-py/rust/lib.rs +++ b/hugr-py/rust/lib.rs @@ -1,8 +1,41 @@ //! Supporting Rust library for the hugr Python bindings. -use hugr_cli::CliArgs; +use hugr_cli::{CliArgs, RunWithIoError}; use hugr_model::v0::ast; -use pyo3::{exceptions::PyValueError, prelude::*}; +use pyo3::{create_exception, exceptions::PyException, exceptions::PyValueError, prelude::*}; + +// Define custom exceptions +create_exception!( + _hugr, + HugrCliError, + PyException, + "Base exception for HUGR CLI errors." +); +create_exception!( + _hugr, + HugrCliDescribeError, + HugrCliError, + "Exception for HUGR CLI describe command errors with partial output." +); + +/// Helper to convert RunWithIoError to Python exception +fn cli_error_to_py(err: RunWithIoError) -> PyErr { + match err { + RunWithIoError::Describe { source, output } => { + // Convert output bytes to string, falling back to empty string if invalid UTF-8 + let output_str = String::from_utf8(output).unwrap_or_else(|e| { + format!("", e.as_bytes().len()) + }); + + HugrCliDescribeError::new_err((format!("{:?}", source), output_str)) + } + RunWithIoError::Other(e) => HugrCliError::new_err(format!("{:?}", e)), + _ => { + // Catch-all for any future error variants (non_exhaustive enum) + HugrCliError::new_err(format!("{:?}", err)) + } + } +} macro_rules! syntax_to_and_from_string { ($name:ident, $ty:ty) => { @@ -93,13 +126,18 @@ fn cli_with_io(mut args: Vec, input_bytes: Option<&[u8]>) -> PyResult) -> PyResult<()> { + // Register custom exceptions + m.add("HugrCliError", m.py().get_type::())?; + m.add( + "HugrCliDescribeError", + m.py().get_type::(), + )?; + m.add_function(wrap_pyfunction!(term_to_string, m)?)?; m.add_function(wrap_pyfunction!(string_to_term, m)?)?; m.add_function(wrap_pyfunction!(node_to_string, m)?)?; diff --git a/hugr-py/src/hugr/_hugr/__init__.pyi b/hugr-py/src/hugr/_hugr/__init__.pyi index 571392e5a..4287361c5 100644 --- a/hugr-py/src/hugr/_hugr/__init__.pyi +++ b/hugr-py/src/hugr/_hugr/__init__.pyi @@ -1,5 +1,8 @@ import hugr.model +class HugrCliError(Exception): ... +class HugrCliDescribeError(HugrCliError): ... + def term_to_string(term: hugr.model.Term) -> str: ... def string_to_term(string: str) -> hugr.model.Term: ... def node_to_string(node: hugr.model.Node) -> str: ... diff --git a/hugr-py/src/hugr/cli.py b/hugr-py/src/hugr/cli.py index 8dc4980c5..2c999d1cf 100644 --- a/hugr-py/src/hugr/cli.py +++ b/hugr-py/src/hugr/cli.py @@ -7,7 +7,7 @@ from pydantic import BaseModel -from hugr._hugr import cli_with_io +from hugr._hugr import HugrCliDescribeError, HugrCliError, cli_with_io __all__ = [ "cli_with_io", @@ -20,6 +20,8 @@ "ModuleDesc", "ExtensionDesc", "EntrypointDesc", + "HugrCliError", + "HugrCliDescribeError", ] @@ -50,7 +52,7 @@ def validate( extensions: Paths to additional serialised extensions needed to load the HUGR. Raises: - ValueError: On validation failure. + HugrCliError: On validation failure or other CLI errors. """ args = _add_input_args(["validate"], no_std, extensions) cli_with_io(args, hugr_bytes) @@ -175,6 +177,10 @@ def describe_str( Returns: Text description of the package. + + Raises: + HugrCliDescribeError: On error during package description. The exception + contains partial output if available. """ args = ["describe"] if _json: @@ -222,16 +228,19 @@ def describe( Returns: Structured package description as a PackageDesc object. """ - output = describe_str( - hugr_bytes, - _json=True, - packaged_extensions=packaged_extensions, - no_resolved_extensions=no_resolved_extensions, - public_symbols=public_symbols, - generator_claimed_extensions=generator_claimed_extensions, - no_std=no_std, - extensions=extensions, - ) + try: + output = describe_str( + hugr_bytes, + _json=True, + packaged_extensions=packaged_extensions, + no_resolved_extensions=no_resolved_extensions, + public_symbols=public_symbols, + generator_claimed_extensions=generator_claimed_extensions, + no_std=no_std, + extensions=extensions, + ) + except HugrCliDescribeError as e: + output = e.args[1] return PackageDesc.model_validate_json(output) @@ -265,6 +274,9 @@ def convert( Returns: Converted package as bytes. + + Raises: + HugrCliError: On conversion failure or other CLI errors. """ args = ["convert"] if format is not None: @@ -300,6 +312,9 @@ def mermaid( Returns: Mermaid diagram output as a string. + + Raises: + HugrCliError: On mermaid generation failure or other CLI errors. """ args = ["mermaid"] if validate: diff --git a/hugr-py/tests/test_cli.py b/hugr-py/tests/test_cli.py index 158e4fd30..25f393e0f 100644 --- a/hugr-py/tests/test_cli.py +++ b/hugr-py/tests/test_cli.py @@ -4,8 +4,8 @@ import pytest -from hugr import cli -from hugr.build import Module +from hugr import cli, tys +from hugr.build import Dfg, Module from hugr.ext import Extension from hugr.package import Package @@ -38,7 +38,7 @@ def test_validate_with_bytes_invalid(): invalid_bytes = b"not a valid hugr package" - with pytest.raises(ValueError, match="Bad magic number"): + with pytest.raises(cli.HugrCliError, match="Bad magic number"): cli.cli_with_io(["validate"], invalid_bytes) @@ -139,3 +139,33 @@ def test_describe_json_with_packaged_extensions(hugr_with_extension_bytes: bytes assert desc.uses_extension("ext") assert not desc.uses_extension("nonexistent_extension") + + +@pytest.fixture +def hugr_using_ext() -> bytes: + """A simple HUGR package that uses an extension, but doesn't package it.""" + ext = Extension.from_json(EXAMPLE) + u_t = tys.USize() + op = ext.get_op("New").instantiate( + [u_t.type_arg()], concrete_signature=tys.FunctionType([u_t], []) + ) + h = Dfg(u_t) + a = h.inputs()[0] + h.add_op(op, a) + h.set_outputs() + + package = Package([h.hugr], []) + + return package.to_bytes() + + +def test_failed_describe(hugr_using_ext): + """Json description still succeeds, with error field populated""" + desc = cli.describe(hugr_using_ext) + mod = desc.modules[0] + assert mod is not None + assert mod.num_nodes == 8 # computed before error + assert isinstance(desc.error, str) + assert "requires extension ext" in desc.error + + assert desc.uses_extension("ext")