Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions hugr-cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,35 @@ impl CliArgs {
///
/// The `gen-extensions` and `external` commands don't support byte I/O
/// and should use the normal `run_cli()` method instead.
pub fn run_with_io(self, input: impl std::io::Read) -> Result<Vec<u8>> {
pub fn run_with_io(self, input: impl std::io::Read) -> Result<Vec<u8>, RunWithIoError> {
let mut output = Vec::new();
self.command.run_with_io(Some(input), Some(&mut output))?;
Ok(output)
let is_describe = matches!(self.command, CliCommand::Describe(_));
let res = self.command.run_with_io(Some(input), Some(&mut output));
match (res, is_describe) {
(Ok(()), _) => Ok(output),
(Err(e), true) => Err(RunWithIoError::Describe { source: e, output }),
(Err(e), false) => Err(RunWithIoError::Other(e)),
}
}
}

#[derive(Debug, Error)]
#[non_exhaustive]
#[error("Error running CLI command with IO.")]
/// Error type for `run_with_io` method.
pub enum RunWithIoError {
/// Error describing HUGR package.
Describe {
#[source]
/// Error returned from describe command.
source: anyhow::Error,
/// Describe command output.
output: Vec<u8>,
},
/// Non-describe command error.
Other(anyhow::Error),
}

fn run_external(args: Vec<OsString>) -> Result<()> {
// External subcommand support: invoke `hugr-<subcommand>`
if args.is_empty() {
Expand Down
106 changes: 104 additions & 2 deletions hugr-core/src/envelope/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use hugr_model::v0::table;
use itertools::{Either, Itertools as _};

use crate::HugrView as _;
use crate::envelope::description::PackageDesc;
use crate::envelope::description::{ExtensionDesc, ModuleDesc, PackageDesc};
use crate::envelope::header::{EnvelopeFormat, HeaderError};
use crate::envelope::{
EnvelopeError, EnvelopeHeader, ExtensionBreakingError, FormatUnsupportedError,
Expand Down Expand Up @@ -103,6 +103,32 @@ impl<R: BufRead> EnvelopeReader<R> {
self.registry.extend(extensions);
}

/// Handle extension resolution errors by recording missing extensions in the description.
///
/// This function inspects the error and adds any missing extensions to the module description
/// with a default version of 0.0.0.
fn handle_resolution_error(desc: &mut ModuleDesc, err: &ExtensionResolutionError) {
match err {
ExtensionResolutionError::MissingOpExtension {
missing_extension, ..
}
| ExtensionResolutionError::MissingTypeExtension {
missing_extension, ..
} => desc.extend_used_extensions_resolved([ExtensionDesc::new(
missing_extension,
crate::extension::Version::new(0, 0, 0),
)]),
ExtensionResolutionError::InvalidConstTypes {
missing_extensions, ..
} => desc.extend_used_extensions_resolved(
missing_extensions
.iter()
.map(|ext| ExtensionDesc::new(ext, crate::extension::Version::new(0, 0, 0))),
),
_ => {}
}
}

fn read_impl(&mut self) -> Result<Package, PayloadError> {
let mut package = match self.header().format {
EnvelopeFormat::PackageJson => self.decode_json()?,
Expand All @@ -121,7 +147,10 @@ impl<R: BufRead> EnvelopeReader<R> {
check_breaking_extensions(module.extensions(), used_exts.drain(..))?;
}

module.resolve_extension_defs(&self.registry)?;
module
.resolve_extension_defs(&self.registry)
.inspect_err(|err| Self::handle_resolution_error(desc, err))?;

// overwrite the description with the actual module read,
// cheap so ok to repeat.
desc.load_from_hugr(&module);
Expand Down Expand Up @@ -415,4 +444,77 @@ mod test {
assert_eq!(description.header, header);
assert_eq!(description.n_modules(), 0); // No valid modules should be set
}

#[test]
fn test_handle_resolution_error() {
use crate::extension::ExtensionId;
use crate::ops::{OpName, constant::ValueName};
use crate::types::TypeName;

let mut desc = ModuleDesc::default();
let handle_error = |d: &mut ModuleDesc, err: &ExtensionResolutionError| {
EnvelopeReader::<Cursor<Vec<u8>>>::handle_resolution_error(d, err)
};
let assert_extensions = |d: &ModuleDesc, expected_ids: &[&ExtensionId]| {
let resolved = d.used_extensions_resolved.as_ref().unwrap();
assert_eq!(resolved.len(), expected_ids.len());
let names: Vec<_> = resolved.iter().map(|e| &e.name).collect();
for ext_id in expected_ids {
assert!(names.contains(&&ext_id.to_string()));
}
assert!(
resolved
.iter()
.all(|e| e.version == crate::extension::Version::new(0, 0, 0))
);
};

// Test MissingOpExtension
let ext_id = ExtensionId::new("test.extension").unwrap();
let error = ExtensionResolutionError::MissingOpExtension {
node: None,
op: OpName::new("test.op"),
missing_extension: ext_id.clone(),
available_extensions: vec![],
};
handle_error(&mut desc, &error);
assert_extensions(&desc, &[&ext_id]);

// Test MissingTypeExtension
desc.used_extensions_resolved = None;
let ext_id2 = ExtensionId::new("test.extension2").unwrap();
let error = ExtensionResolutionError::MissingTypeExtension {
node: None,
ty: TypeName::new("test.type"),
missing_extension: ext_id2.clone(),
available_extensions: vec![],
};
handle_error(&mut desc, &error);
assert_extensions(&desc, &[&ext_id2]);

// Test InvalidConstTypes with multiple extensions
desc.used_extensions_resolved = None;
let ext_id3 = ExtensionId::new("test.extension3").unwrap();
let ext_id4 = ExtensionId::new("test.extension4").unwrap();
let mut missing_exts = crate::extension::ExtensionSet::new();
missing_exts.insert(ext_id3.clone());
missing_exts.insert(ext_id4.clone());

let error = ExtensionResolutionError::InvalidConstTypes {
value: ValueName::new("test.value"),
missing_extensions: missing_exts,
};
handle_error(&mut desc, &error);
assert_extensions(&desc, &[&ext_id3, &ext_id4]);

// Test other error variant (should not add anything)
desc.used_extensions_resolved = None;
let error = ExtensionResolutionError::WrongTypeDefExtension {
extension: ExtensionId::new("ext1").unwrap(),
def: TypeName::new("def"),
wrong_extension: ExtensionId::new("ext2").unwrap(),
};
handle_error(&mut desc, &error);
assert!(desc.used_extensions_resolved.is_none());
}
}
48 changes: 43 additions & 5 deletions hugr-py/rust/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,41 @@
//! Supporting Rust library for the hugr Python bindings.

use hugr_cli::CliArgs;
use hugr_cli::{CliArgs, RunWithIoError};
use hugr_model::v0::ast;
use pyo3::{exceptions::PyValueError, prelude::*};
use pyo3::{create_exception, exceptions::PyException, exceptions::PyValueError, prelude::*};

// Define custom exceptions
create_exception!(
_hugr,
HugrCliError,
PyException,
"Base exception for HUGR CLI errors."
);
create_exception!(
_hugr,
HugrCliDescribeError,
HugrCliError,
"Exception for HUGR CLI describe command errors with partial output."
);

/// Helper to convert RunWithIoError to Python exception
fn cli_error_to_py(err: RunWithIoError) -> PyErr {
match err {
RunWithIoError::Describe { source, output } => {
// Convert output bytes to string, falling back to empty string if invalid UTF-8
let output_str = String::from_utf8(output).unwrap_or_else(|e| {
format!("<Invalid UTF-8 output: {} bytes>", e.as_bytes().len())
});

HugrCliDescribeError::new_err((format!("{:?}", source), output_str))
}
RunWithIoError::Other(e) => HugrCliError::new_err(format!("{:?}", e)),
_ => {
// Catch-all for any future error variants (non_exhaustive enum)
HugrCliError::new_err(format!("{:?}", err))
}
}
}

macro_rules! syntax_to_and_from_string {
($name:ident, $ty:ty) => {
Expand Down Expand Up @@ -93,13 +126,18 @@ fn cli_with_io(mut args: Vec<String>, input_bytes: Option<&[u8]>) -> PyResult<Ve
args.insert(0, String::new());
let cli_args = CliArgs::new_from_args(args);
let input = input_bytes.unwrap_or(&[]);
cli_args
.run_with_io(input)
.map_err(|e| PyValueError::new_err(format!("{:?}", e)))
cli_args.run_with_io(input).map_err(cli_error_to_py)
}

#[pymodule]
fn _hugr(m: &Bound<'_, PyModule>) -> PyResult<()> {
// Register custom exceptions
m.add("HugrCliError", m.py().get_type::<HugrCliError>())?;
m.add(
"HugrCliDescribeError",
m.py().get_type::<HugrCliDescribeError>(),
)?;

m.add_function(wrap_pyfunction!(term_to_string, m)?)?;
m.add_function(wrap_pyfunction!(string_to_term, m)?)?;
m.add_function(wrap_pyfunction!(node_to_string, m)?)?;
Expand Down
3 changes: 3 additions & 0 deletions hugr-py/src/hugr/_hugr/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import hugr.model

class HugrCliError(Exception): ...
class HugrCliDescribeError(HugrCliError): ...

def term_to_string(term: hugr.model.Term) -> str: ...
def string_to_term(string: str) -> hugr.model.Term: ...
def node_to_string(node: hugr.model.Node) -> str: ...
Expand Down
39 changes: 27 additions & 12 deletions hugr-py/src/hugr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pydantic import BaseModel

from hugr._hugr import cli_with_io
from hugr._hugr import HugrCliDescribeError, HugrCliError, cli_with_io

__all__ = [
"cli_with_io",
Expand All @@ -20,6 +20,8 @@
"ModuleDesc",
"ExtensionDesc",
"EntrypointDesc",
"HugrCliError",
"HugrCliDescribeError",
]


Expand Down Expand Up @@ -50,7 +52,7 @@ def validate(
extensions: Paths to additional serialised extensions needed to load the HUGR.

Raises:
ValueError: On validation failure.
HugrCliError: On validation failure or other CLI errors.
"""
args = _add_input_args(["validate"], no_std, extensions)
cli_with_io(args, hugr_bytes)
Expand Down Expand Up @@ -175,6 +177,10 @@ def describe_str(

Returns:
Text description of the package.

Raises:
HugrCliDescribeError: On error during package description. The exception
contains partial output if available.
"""
args = ["describe"]
if _json:
Expand Down Expand Up @@ -222,16 +228,19 @@ def describe(
Returns:
Structured package description as a PackageDesc object.
"""
output = describe_str(
hugr_bytes,
_json=True,
packaged_extensions=packaged_extensions,
no_resolved_extensions=no_resolved_extensions,
public_symbols=public_symbols,
generator_claimed_extensions=generator_claimed_extensions,
no_std=no_std,
extensions=extensions,
)
try:
output = describe_str(
hugr_bytes,
_json=True,
packaged_extensions=packaged_extensions,
no_resolved_extensions=no_resolved_extensions,
public_symbols=public_symbols,
generator_claimed_extensions=generator_claimed_extensions,
no_std=no_std,
extensions=extensions,
)
except HugrCliDescribeError as e:
output = e.args[1]
return PackageDesc.model_validate_json(output)


Expand Down Expand Up @@ -265,6 +274,9 @@ def convert(

Returns:
Converted package as bytes.

Raises:
HugrCliError: On conversion failure or other CLI errors.
"""
args = ["convert"]
if format is not None:
Expand Down Expand Up @@ -300,6 +312,9 @@ def mermaid(

Returns:
Mermaid diagram output as a string.

Raises:
HugrCliError: On mermaid generation failure or other CLI errors.
"""
args = ["mermaid"]
if validate:
Expand Down
36 changes: 33 additions & 3 deletions hugr-py/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import pytest

from hugr import cli
from hugr.build import Module
from hugr import cli, tys
from hugr.build import Dfg, Module
from hugr.ext import Extension
from hugr.package import Package

Expand Down Expand Up @@ -38,7 +38,7 @@ def test_validate_with_bytes_invalid():

invalid_bytes = b"not a valid hugr package"

with pytest.raises(ValueError, match="Bad magic number"):
with pytest.raises(cli.HugrCliError, match="Bad magic number"):
cli.cli_with_io(["validate"], invalid_bytes)


Expand Down Expand Up @@ -139,3 +139,33 @@ def test_describe_json_with_packaged_extensions(hugr_with_extension_bytes: bytes

assert desc.uses_extension("ext")
assert not desc.uses_extension("nonexistent_extension")


@pytest.fixture
def hugr_using_ext() -> bytes:
"""A simple HUGR package that uses an extension, but doesn't package it."""
ext = Extension.from_json(EXAMPLE)
u_t = tys.USize()
op = ext.get_op("New").instantiate(
[u_t.type_arg()], concrete_signature=tys.FunctionType([u_t], [])
)
h = Dfg(u_t)
a = h.inputs()[0]
h.add_op(op, a)
h.set_outputs()

package = Package([h.hugr], [])

return package.to_bytes()


def test_failed_describe(hugr_using_ext):
"""Json description still succeeds, with error field populated"""
desc = cli.describe(hugr_using_ext)
mod = desc.modules[0]
assert mod is not None
assert mod.num_nodes == 8 # computed before error
assert isinstance(desc.error, str)
assert "requires extension ext" in desc.error

assert desc.uses_extension("ext")
Loading