diff --git a/.github/workflows/ci-py.yml b/.github/workflows/ci-py.yml index 0c9cef18f2..d75ef2dcf6 100644 --- a/.github/workflows/ci-py.yml +++ b/.github/workflows/ci-py.yml @@ -13,8 +13,6 @@ on: env: SCCACHE_GHA_ENABLED: "true" - HUGR_BIN_DIR: ${{ github.workspace }}/target/debug - HUGR_BIN: ${{ github.workspace }}/target/debug/hugr # Pinned version for the uv package manager UV_VERSION: "0.9.5" UV_FROZEN: 1 @@ -68,30 +66,8 @@ jobs: - name: Lint with ruff run: uv run ruff check - build_binary: - needs: changes - if: ${{ needs.changes.outputs.python == 'true' }} - - name: Build HUGR binary - runs-on: ubuntu-latest - env: - SCCACHE_GHA_ENABLED: "true" - RUSTC_WRAPPER: "sccache" - - steps: - - uses: actions/checkout@v5 - - uses: mozilla-actions/sccache-action@v0.0.9 - - name: Install stable toolchain - uses: dtolnay/rust-toolchain@stable - - name: Build HUGR binary - run: cargo build -p hugr-cli - - name: Upload the binary to the artifacts - uses: actions/upload-artifact@v5 - with: - name: hugr_binary - path: target/debug/hugr test: - needs: [changes, build_binary] + needs: [changes] if: ${{ needs.changes.outputs.python == 'true' }} name: test python ${{ matrix.python-version.py }} runs-on: ubuntu-latest @@ -110,12 +86,6 @@ jobs: version: ${{ env.UV_VERSION }} enable-cache: true - - name: Download the hugr binary - uses: actions/download-artifact@v6 - with: - name: hugr_binary - path: ${{env.HUGR_BIN_DIR}} - - name: Setup dependencies run: uv sync --python ${{ matrix.python-version.py }} @@ -125,13 +95,11 @@ jobs: - name: Run tests if: github.event_name == 'merge_group' || !matrix.python-version.coverage run: | - chmod +x $HUGR_BIN HUGR_RENDER_DOT=1 uv run pytest - name: Run python tests with coverage instrumentation if: github.event_name != 'merge_group' && matrix.python-version.coverage run: | - chmod +x $HUGR_BIN HUGR_RENDER_DOT=1 uv run pytest --cov=./ --cov-report=xml - name: Upload python coverage to codecov.io diff --git a/Cargo.lock b/Cargo.lock index e28891b556..cd8dff63c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1346,7 +1346,7 @@ name = "hugr-py" version = "0.1.0" dependencies = [ "bumpalo", - "hugr-core", + "hugr-cli", "hugr-model", "pastey", "pyo3", diff --git a/Cargo.toml b/Cargo.toml index 595d677c35..9989a83981 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -129,3 +129,10 @@ jsonschema.debug = 1 [profile.dist] inherits = "release" lto = "thin" + +# The profile that 'hugr-py' will build with +[profile.release-py] +inherits = "release" +lto = "fat" +strip = "symbols" +panic = "abort" diff --git a/hugr-cli/Cargo.toml b/hugr-cli/Cargo.toml index 7cbbb8ec9e..903af035fc 100644 --- a/hugr-cli/Cargo.toml +++ b/hugr-cli/Cargo.toml @@ -15,6 +15,10 @@ categories = ["compilers"] [lib] bench = false +[features] +default = ["tracing"] +tracing = ["dep:tracing", "dep:tracing-subscriber"] + [dependencies] clap = { workspace = true, features = ["derive", "cargo"] } clap-verbosity-flag.workspace = true @@ -25,8 +29,8 @@ serde = { workspace = true, features = ["derive"] } clio = { workspace = true, features = ["clap-parse"] } anyhow.workspace = true thiserror.workspace = true -tracing = "0.1.41" -tracing-subscriber = { version = "0.3.20", features = ["fmt"] } +tracing = { version = "0.1.41", optional = true } +tracing-subscriber = { version = "0.3.20", features = ["fmt"], optional = true } tabled = "0.20.0" schemars = { workspace = true, features = ["derive"] } diff --git a/hugr-cli/src/convert.rs b/hugr-cli/src/convert.rs index 3f73bb2690..6a9aa18fd2 100644 --- a/hugr-cli/src/convert.rs +++ b/hugr-cli/src/convert.rs @@ -3,6 +3,7 @@ use anyhow::Result; use clap::Parser; use clio::Output; use hugr::envelope::{EnvelopeConfig, EnvelopeFormat, ZstdConfig}; +use std::io::{Read, Write}; use crate::CliError; use crate::hugr_io::HugrInputArgs; @@ -47,9 +48,20 @@ pub struct ConvertArgs { } impl ConvertArgs { - /// Convert a HUGR between different envelope formats - pub fn run_convert(&mut self) -> Result<()> { - let (env_config, package) = self.input_args.get_described_package()?; + /// Convert a HUGR between different envelope formats with optional input/output overrides. + /// + /// # Arguments + /// + /// * `input_override` - Optional reader to use instead of the CLI input argument. + /// * `output_override` - Optional writer to use instead of the CLI output argument. + pub fn run_convert_with_io( + &mut self, + input_override: Option, + mut output_override: Option, + ) -> Result<()> { + let (env_config, package) = self + .input_args + .get_described_package_with_reader(input_override)?; // Handle text and binary format flags, which override the format option let mut config = if self.text { @@ -78,8 +90,17 @@ impl ConvertArgs { } // Write the package with the requested format - hugr::envelope::write_envelope(&mut self.output, &package, config)?; + if let Some(ref mut writer) = output_override { + hugr::envelope::write_envelope(writer, &package, config)?; + } else { + hugr::envelope::write_envelope(&mut self.output, &package, config)?; + } Ok(()) } + + /// Convert a HUGR between different envelope formats + pub fn run_convert(&mut self) -> Result<()> { + self.run_convert_with_io(None::<&[u8]>, None::>) + } } diff --git a/hugr-cli/src/describe.rs b/hugr-cli/src/describe.rs index 017e93505c..2ccf3c3e72 100644 --- a/hugr-cli/src/describe.rs +++ b/hugr-cli/src/describe.rs @@ -8,7 +8,7 @@ use hugr::envelope::ReadError; use hugr::envelope::description::{ExtensionDesc, ModuleDesc, PackageDesc}; use hugr::extension::Version; use hugr::package::Package; -use std::io::Write; +use std::io::{Read, Write}; use tabled::Tabled; use tabled::derive::display; @@ -73,20 +73,37 @@ impl ModuleArgs { } } impl DescribeArgs { - /// Load and describe the HUGR package. - pub fn run_describe(&mut self) -> Result<()> { + /// Load and describe the HUGR package with optional input/output overrides. + /// + /// # Arguments + /// + /// * `input_override` - Optional reader to use instead of the CLI input argument. + /// * `output_override` - Optional writer to use instead of the CLI output argument. + pub fn run_describe_with_io( + &mut self, + input_override: Option, + mut output_override: Option, + ) -> Result<()> { if self.json_schema { let schema = schemars::schema_for!(PackageDescriptionJson); let schema_json = serde_json::to_string_pretty(&schema)?; - writeln!(self.output, "{schema_json}")?; + if let Some(ref mut writer) = output_override { + writeln!(writer, "{schema_json}")?; + } else { + writeln!(self.output, "{schema_json}")?; + } return Ok(()); } - let (mut desc, res) = match self.input_args.get_described_package() { + + let (mut desc, res) = match self + .input_args + .get_described_package_with_reader(input_override) + { Ok((desc, pkg)) => (desc, Ok(pkg)), Err(crate::CliError::ReadEnvelope(ReadError::Payload { source, partial_description, - })) => (partial_description, Err(source)), // keep error for later + })) => (partial_description, Err(source)), Err(e) => return Err(e.into()), }; @@ -96,91 +113,116 @@ impl DescribeArgs { } let res = res.map_err(anyhow::Error::from); + + let writer: &mut dyn Write = if let Some(ref mut w) = output_override { + w + } else { + &mut self.output + }; + if self.json { if !self.packaged_extensions { desc.packaged_extensions.clear(); } - self.output_json(desc, &res)?; + output_json(desc, &res, writer)?; } else { - self.print_description(desc)?; + print_description(desc, self.packaged_extensions, writer)?; } // bubble up any errors res.map(|_| ()) } - fn print_description(&mut self, desc: PackageDesc) -> Result<()> { - let header = desc.header(); - let n_modules = desc.n_modules(); - let n_extensions = desc.n_packaged_extensions(); - let module_str = if n_modules == 1 { "module" } else { "modules" }; - let extension_str = if n_extensions == 1 { - "extension" - } else { - "extensions" - }; - writeln!( - self.output, - "{header}\nPackage contains {n_modules} {module_str} and {n_extensions} {extension_str}", - )?; - let summaries: Vec = desc - .modules - .iter() - .map(|m| m.as_ref().map(Into::into).unwrap_or_default()) - .collect(); - let summary_table = tabled::Table::builder(summaries).index().build(); - writeln!(self.output, "{summary_table}")?; + /// Load and describe the HUGR package. + pub fn run_describe(&mut self) -> Result<()> { + self.run_describe_with_io(None::<&[u8]>, None::>) + } +} - for (i, module) in desc.modules.into_iter().enumerate() { - writeln!(self.output, "\nModule {i}:")?; - if let Some(module) = module { - self.display_module(module)?; - } - } - if self.packaged_extensions { - writeln!(self.output, "Packaged extensions:")?; - let ext_rows: Vec = desc - .packaged_extensions - .into_iter() - .flatten() - .map(Into::into) - .collect(); - let ext_table = tabled::Table::new(ext_rows); - writeln!(self.output, "{ext_table}")?; +/// Print a human-readable description of a package. +fn print_description( + desc: PackageDesc, + show_packaged_extensions: bool, + writer: &mut W, +) -> Result<()> { + let header = desc.header(); + let n_modules = desc.n_modules(); + let n_extensions = desc.n_packaged_extensions(); + let module_str = if n_modules == 1 { "module" } else { "modules" }; + let extension_str = if n_extensions == 1 { + "extension" + } else { + "extensions" + }; + + writeln!( + writer, + "{header}\nPackage contains {n_modules} {module_str} and {n_extensions} {extension_str}", + )?; + + let summaries: Vec = desc + .modules + .iter() + .map(|m| m.as_ref().map(Into::into).unwrap_or_default()) + .collect(); + let summary_table = tabled::Table::builder(summaries).index().build(); + writeln!(writer, "{summary_table}")?; + + for (i, module) in desc.modules.into_iter().enumerate() { + writeln!(writer, "\nModule {i}:")?; + if let Some(module) = module { + display_module(module, writer)?; } - Ok(()) } - - fn output_json(&mut self, package_desc: PackageDesc, res: &Result) -> Result<()> { - let err_str = res.as_ref().err().map(|e| format!("{e:?}")); - let json_desc = PackageDescriptionJson { - package_desc, - error: err_str, - }; - serde_json::to_writer_pretty(&mut self.output, &json_desc)?; - Ok(()) + if show_packaged_extensions { + writeln!(writer, "Packaged extensions:")?; + let ext_rows: Vec = desc + .packaged_extensions + .into_iter() + .flatten() + .map(Into::into) + .collect(); + let ext_table = tabled::Table::new(ext_rows); + writeln!(writer, "{ext_table}")?; } + Ok(()) +} - fn display_module(&mut self, desc: ModuleDesc) -> Result<()> { - if let Some(exts) = desc.used_extensions_resolved { - let ext_rows: Vec = exts.into_iter().map(Into::into).collect(); - let ext_table = tabled::Table::new(ext_rows); - writeln!(self.output, "Resolved extensions:\n{ext_table}")?; - } +/// Output a package description as JSON. +fn output_json( + package_desc: PackageDesc, + res: &Result, + writer: &mut W, +) -> Result<()> { + let err_str = res.as_ref().err().map(|e| format!("{e:?}")); + let json_desc = PackageDescriptionJson { + package_desc, + error: err_str, + }; + serde_json::to_writer_pretty(writer, &json_desc)?; + Ok(()) +} - if let Some(syms) = desc.public_symbols { - let sym_table = tabled::Table::new(syms.into_iter().map(|s| SymbolRow { symbol: s })); - writeln!(self.output, "Public symbols:\n{sym_table}")?; - } +/// Display information about a single module. +fn display_module(desc: ModuleDesc, writer: &mut W) -> Result<()> { + if let Some(exts) = desc.used_extensions_resolved { + let ext_rows: Vec = exts.into_iter().map(Into::into).collect(); + let ext_table = tabled::Table::new(ext_rows); + writeln!(writer, "Resolved extensions:\n{ext_table}")?; + } - if let Some(exts) = desc.used_extensions_generator { - let ext_rows: Vec = exts.into_iter().map(Into::into).collect(); - let ext_table = tabled::Table::new(ext_rows); - writeln!(self.output, "Generator claimed extensions:\n{ext_table}")?; - } + if let Some(syms) = desc.public_symbols { + let sym_table = tabled::Table::new(syms.into_iter().map(|s| SymbolRow { symbol: s })); + writeln!(writer, "Public symbols:\n{sym_table}")?; + } - Ok(()) + if let Some(exts) = desc.used_extensions_generator { + let ext_rows: Vec = exts.into_iter().map(Into::into).collect(); + let ext_table = tabled::Table::new(ext_rows); + writeln!(writer, "Generator claimed extensions:\n{ext_table}")?; } + + Ok(()) } #[derive(serde::Serialize, schemars::JsonSchema)] diff --git a/hugr-cli/src/hugr_io.rs b/hugr-cli/src/hugr_io.rs index 2691ccaf90..4578c1eaa4 100644 --- a/hugr-cli/src/hugr_io.rs +++ b/hugr-cli/src/hugr_io.rs @@ -73,10 +73,34 @@ impl HugrInputArgs { /// If [`HugrInputArgs::hugr_json`] is `true`, [`HugrInputArgs::get_hugr`] should be called instead as /// reading the input as a package will fail. pub fn get_described_package(&mut self) -> Result<(PackageDesc, Package), CliError> { + self.get_described_package_with_reader::<&[u8]>(None) + } + + /// Read a hugr envelope from an optional reader and return the envelope + /// description and the decoded package. + /// + /// If `reader` is `None`, reads from the input specified in the args. + /// + /// # Errors + /// + /// If [`HugrInputArgs::hugr_json`] is `true`, [`HugrInputArgs::get_hugr`] should be called instead as + /// reading the input as a package will fail. + pub fn get_described_package_with_reader( + &mut self, + reader: Option, + ) -> Result<(PackageDesc, Package), CliError> { let extensions = self.load_extensions()?; - let buffer = BufReader::new(&mut self.input); - Ok(read_described_envelope(buffer, &extensions)?) + match reader { + Some(r) => { + let buffer = BufReader::new(r); + Ok(read_described_envelope(buffer, &extensions)?) + } + None => { + let buffer = BufReader::new(&mut self.input); + Ok(read_described_envelope(buffer, &extensions)?) + } + } } /// Read a hugr JSON file from the input. /// @@ -86,15 +110,36 @@ impl HugrInputArgs { /// For most cases, [`HugrInputArgs::get_package`] should be called instead. #[deprecated(note = "Use `HugrInputArgs::get_package` instead.", since = "0.22.2")] pub fn get_hugr(&mut self) -> Result { + self.get_hugr_with_reader::<&[u8]>(None) + } + + /// Read a hugr JSON file from an optional reader. + /// + /// If `reader` is `None`, reads from the input specified in the args. + /// This is a legacy option for reading old HUGR JSON files. + pub(crate) fn get_hugr_with_reader( + &mut self, + reader: Option, + ) -> Result { let extensions = self.load_extensions()?; - let mut buffer = BufReader::new(&mut self.input); /// Wraps the hugr JSON so that it defines a valid envelope. const PREPEND: &str = r#"HUGRiHJv?@{"modules": ["#; const APPEND: &str = r#"],"extensions": []}"#; let mut envelope = PREPEND.to_string(); - buffer.read_to_string(&mut envelope)?; + + match reader { + Some(r) => { + let mut buffer = BufReader::new(r); + buffer.read_to_string(&mut envelope)?; + } + None => { + let mut buffer = BufReader::new(&mut self.input); + buffer.read_to_string(&mut envelope)?; + } + } + envelope.push_str(APPEND); let hugr = Hugr::load_str(envelope, Some(&extensions))?; diff --git a/hugr-cli/src/lib.rs b/hugr-cli/src/lib.rs index 1abfff2f45..d54592db5e 100644 --- a/hugr-cli/src/lib.rs +++ b/hugr-cli/src/lib.rs @@ -19,12 +19,18 @@ //! hugr validate --help //! ``` +use std::ffi::OsString; + +use anyhow::Result; use clap::{Parser, crate_version}; +#[cfg(feature = "tracing")] +use clap_verbosity_flag::VerbosityFilter; use clap_verbosity_flag::{InfoLevel, Verbosity}; use hugr::envelope::EnvelopeError; use hugr::package::PackageValidationError; -use std::ffi::OsString; use thiserror::Error; +#[cfg(feature = "tracing")] +use tracing::{error, metadata::LevelFilter}; pub mod convert; pub mod describe; @@ -32,6 +38,7 @@ pub mod extensions; pub mod hugr_io; pub mod mermaid; pub mod validate; + /// CLI arguments. #[derive(Parser, Debug)] #[clap(version = crate_version!(), long_about = None)] @@ -123,3 +130,157 @@ impl CliError { } } } + +impl CliCommand { + /// Run a CLI command with optional input/output overrides. + /// If overrides are `None`, behaves like the normal CLI. + /// If overrides are provided, stdin/stdout/files are ignored. + /// The `gen-extensions` and `external` commands don't support overrides. + /// + /// # Arguments + /// + /// * `input_override` - Optional reader to use instead of stdin/files + /// * `output_override` - Optional writer to use instead of stdout/files + /// + fn run_with_io( + self, + input_override: Option, + output_override: Option, + ) -> Result<()> { + match self { + Self::Validate(mut args) => args.run_with_input(input_override), + Self::GenExtensions(args) => { + if input_override.is_some() || output_override.is_some() { + return Err(anyhow::anyhow!( + "GenExtensions command does not support programmatic I/O overrides" + )); + } + args.run_dump(&hugr::std_extensions::STD_REG) + } + Self::Mermaid(mut args) => args.run_print_with_io(input_override, output_override), + Self::Convert(mut args) => args.run_convert_with_io(input_override, output_override), + Self::Describe(mut args) => args.run_describe_with_io(input_override, output_override), + Self::External(args) => { + if input_override.is_some() || output_override.is_some() { + return Err(anyhow::anyhow!( + "External commands do not support programmatic I/O overrides" + )); + } + run_external(args) + } + } + } +} + +impl Default for CliArgs { + fn default() -> Self { + Self::new() + } +} + +impl CliArgs { + /// Parse CLI arguments from the environment. + pub fn new() -> Self { + CliArgs::parse() + } + + /// Parse CLI arguments from an iterator. + pub fn new_from_args(args: I) -> Self + where + I: IntoIterator, + T: Into + Clone, + { + CliArgs::parse_from(args) + } + + /// Entrypoint for cli - process arguments and run commands. + /// + /// Process exits on error. + pub fn run_cli(self) { + #[cfg(feature = "tracing")] + { + let level = match self.verbose.filter() { + VerbosityFilter::Off => LevelFilter::OFF, + VerbosityFilter::Error => LevelFilter::ERROR, + VerbosityFilter::Warn => LevelFilter::WARN, + VerbosityFilter::Info => LevelFilter::INFO, + VerbosityFilter::Debug => LevelFilter::DEBUG, + VerbosityFilter::Trace => LevelFilter::TRACE, + }; + tracing_subscriber::fmt() + .with_writer(std::io::stderr) + .with_max_level(level) + .pretty() + .init(); + } + + let result = self + .command + .run_with_io(None::, None::); + + if let Err(err) = result { + #[cfg(feature = "tracing")] + error!("{:?}", err); + #[cfg(not(feature = "tracing"))] + eprintln!("{:?}", err); + std::process::exit(1); + } + } + + /// Run a CLI command with bytes input and capture bytes output. + /// + /// This provides a programmatic interface to the CLI. + /// Unlike `run_cli()`, this method: + /// - Accepts input instead of reading from stdin/files + /// - Returns output as a byte vector instead of writing to stdout/files + /// + /// # Arguments + /// + /// * `input` - The input data as bytes (e.g., a HUGR package) + /// + /// # Returns + /// + /// Returns `Ok(Vec)` with the command output, or an error on failure. + /// + /// + /// # Note + /// + /// The `gen-extensions` and `external` commands don't support byte I/O + /// and should use the normal `run_cli()` method instead. + pub fn run_with_io(self, input: impl std::io::Read) -> Result> { + let mut output = Vec::new(); + self.command.run_with_io(Some(input), Some(&mut output))?; + Ok(output) + } +} + +fn run_external(args: Vec) -> Result<()> { + // External subcommand support: invoke `hugr-` + if args.is_empty() { + eprintln!("No external subcommand specified."); + std::process::exit(1); + } + let subcmd = args[0].to_string_lossy(); + let exe = format!("hugr-{subcmd}"); + let rest: Vec<_> = args[1..] + .iter() + .map(|s| s.to_string_lossy().to_string()) + .collect(); + match std::process::Command::new(&exe).args(&rest).status() { + Ok(status) => { + if !status.success() { + std::process::exit(status.code().unwrap_or(1)); + } + } + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + eprintln!("error: no such subcommand: '{subcmd}'.\nCould not find '{exe}' in PATH."); + std::process::exit(1); + } + Err(e) => { + eprintln!("error: failed to invoke '{exe}': {e}"); + std::process::exit(1); + } + } + + Ok(()) +} diff --git a/hugr-cli/src/main.rs b/hugr-cli/src/main.rs index 48acd2a74f..c9cd47bdef 100644 --- a/hugr-cli/src/main.rs +++ b/hugr-cli/src/main.rs @@ -1,73 +1,6 @@ -//! Validate serialized HUGR on the command line - -use std::ffi::OsString; - -use anyhow::{Result, anyhow}; -use clap::Parser as _; -use clap_verbosity_flag::VerbosityFilter; -use hugr_cli::{CliArgs, CliCommand}; -use tracing::{error, metadata::LevelFilter}; +//! HUGR CLI tools. +use hugr_cli::CliArgs; fn main() { - let cli_args = CliArgs::parse(); - - let level = match cli_args.verbose.filter() { - VerbosityFilter::Off => LevelFilter::OFF, - VerbosityFilter::Error => LevelFilter::ERROR, - VerbosityFilter::Warn => LevelFilter::WARN, - VerbosityFilter::Info => LevelFilter::INFO, - VerbosityFilter::Debug => LevelFilter::DEBUG, - VerbosityFilter::Trace => LevelFilter::TRACE, - }; - tracing_subscriber::fmt() - .with_writer(std::io::stderr) - .with_max_level(level) - .pretty() - .init(); - - let result = match cli_args.command { - CliCommand::Validate(mut args) => args.run(), - CliCommand::GenExtensions(args) => args.run_dump(&hugr::std_extensions::STD_REG), - CliCommand::Mermaid(mut args) => args.run_print(), - CliCommand::Convert(mut args) => args.run_convert(), - CliCommand::Describe(mut args) => args.run_describe(), - CliCommand::External(args) => run_external(args), - _ => Err(anyhow!("Unknown command")), - }; - - if let Err(err) = result { - error!("{:?}", err); - std::process::exit(1); - } -} - -fn run_external(args: Vec) -> Result<()> { - // External subcommand support: invoke `hugr-` - if args.is_empty() { - eprintln!("No external subcommand specified."); - std::process::exit(1); - } - let subcmd = args[0].to_string_lossy(); - let exe = format!("hugr-{subcmd}"); - let rest: Vec<_> = args[1..] - .iter() - .map(|s| s.to_string_lossy().to_string()) - .collect(); - match std::process::Command::new(&exe).args(&rest).status() { - Ok(status) => { - if !status.success() { - std::process::exit(status.code().unwrap_or(1)); - } - } - Err(e) if e.kind() == std::io::ErrorKind::NotFound => { - eprintln!("error: no such subcommand: '{subcmd}'.\nCould not find '{exe}' in PATH."); - std::process::exit(1); - } - Err(e) => { - eprintln!("error: failed to invoke '{exe}': {e}"); - std::process::exit(1); - } - } - - Ok(()) + CliArgs::new().run_cli(); } diff --git a/hugr-cli/src/mermaid.rs b/hugr-cli/src/mermaid.rs index b9e2637604..50a4a79af0 100644 --- a/hugr-cli/src/mermaid.rs +++ b/hugr-cli/src/mermaid.rs @@ -1,5 +1,5 @@ //! Render mermaid diagrams. -use std::io::Write; +use std::io::{Read, Write}; use crate::CliError; use crate::hugr_io::HugrInputArgs; @@ -32,18 +32,34 @@ pub struct MermaidArgs { } impl MermaidArgs { - /// Write the mermaid diagram to the output. - pub fn run_print(&mut self) -> Result<()> { + /// Write the mermaid diagram to the output with optional input/output overrides. + /// + /// # Arguments + /// + /// * `input_override` - Optional reader to use instead of the CLI input argument. + /// * `output_override` - Optional writer to use instead of the CLI output argument. + pub fn run_print_with_io( + &mut self, + input_override: Option, + output_override: Option, + ) -> Result<()> { if self.input_args.hugr_json { - self.run_print_hugr() + self.run_print_hugr_with_io(input_override, output_override) } else { - self.run_print_envelope() + self.run_print_envelope_with_io(input_override, output_override) } } - /// Write the mermaid diagram for a HUGR envelope. - pub fn run_print_envelope(&mut self) -> Result<()> { - let (desc, package) = self.input_args.get_described_package()?; + /// Write the mermaid diagram for a HUGR envelope with optional I/O overrides. + fn run_print_envelope_with_io( + &mut self, + input_override: Option, + mut output_override: Option, + ) -> Result<()> { + let (desc, package) = self + .input_args + .get_described_package_with_reader(input_override)?; + let generator = desc.generator(); if self.validate { package @@ -52,22 +68,49 @@ impl MermaidArgs { } for hugr in package.modules { - writeln!(self.output, "{}", hugr.mermaid_string())?; + if let Some(ref mut writer) = output_override { + writeln!(writer, "{}", hugr.mermaid_string())?; + } else { + writeln!(self.output, "{}", hugr.mermaid_string())?; + } } Ok(()) } - /// Write the mermaid diagram for a legacy HUGR json. - pub fn run_print_hugr(&mut self) -> Result<()> { + /// Write the mermaid diagram for a legacy HUGR json with optional I/O overrides. + fn run_print_hugr_with_io( + &mut self, + input_override: Option, + mut output_override: Option, + ) -> Result<()> { #[allow(deprecated)] - let hugr = self.input_args.get_hugr()?; + let hugr = self.input_args.get_hugr_with_reader(input_override)?; if self.validate { hugr.validate() .map_err(PackageValidationError::Validation)?; } - writeln!(self.output, "{}", hugr.mermaid_string())?; + if let Some(ref mut writer) = output_override { + writeln!(writer, "{}", hugr.mermaid_string())?; + } else { + writeln!(self.output, "{}", hugr.mermaid_string())?; + } Ok(()) } + + /// Write the mermaid diagram to the output. + pub fn run_print(&mut self) -> Result<()> { + self.run_print_with_io(None::<&[u8]>, None::>) + } + + /// Write the mermaid diagram for a HUGR envelope. + pub fn run_print_envelope(&mut self) -> Result<()> { + self.run_print_envelope_with_io(None::<&[u8]>, None::>) + } + + /// Write the mermaid diagram for a legacy HUGR json. + pub fn run_print_hugr(&mut self) -> Result<()> { + self.run_print_hugr_with_io(None::<&[u8]>, None::>) + } } diff --git a/hugr-cli/src/validate.rs b/hugr-cli/src/validate.rs index 28d65c46b6..68c46146db 100644 --- a/hugr-cli/src/validate.rs +++ b/hugr-cli/src/validate.rs @@ -4,6 +4,8 @@ use anyhow::Result; use clap::Parser; use hugr::HugrView; use hugr::package::PackageValidationError; +use std::io::Read; +#[cfg(feature = "tracing")] use tracing::info; use crate::CliError; @@ -26,10 +28,16 @@ pub const VALID_PRINT: &str = "HUGR valid!"; impl ValArgs { /// Run the HUGR cli and validate against an extension registry. - pub fn run(&mut self) -> Result<()> { + /// + /// # Arguments + /// + /// * `input_override` - Optional reader to use instead of the CLI input argument. + /// If provided, this reader will be used for input instead of + /// `self.input_args.input`. + pub fn run_with_input(&mut self, input_override: Option) -> Result<()> { if self.input_args.hugr_json { #[allow(deprecated)] - let hugr = self.input_args.get_hugr()?; + let hugr = self.input_args.get_hugr_with_reader(input_override)?; #[allow(deprecated)] let generator = hugr::envelope::get_generator(&[&hugr]); @@ -37,15 +45,25 @@ impl ValArgs { .map_err(PackageValidationError::Validation) .map_err(|val_err| CliError::validation(generator, val_err))?; } else { - let (desc, package) = self.input_args.get_described_package()?; + let (desc, package) = self + .input_args + .get_described_package_with_reader(input_override)?; let generator = desc.generator(); package .validate() .map_err(|val_err| CliError::validation(generator, val_err))?; }; + #[cfg(feature = "tracing")] info!("{VALID_PRINT}"); + #[cfg(not(feature = "tracing"))] + eprintln!("{VALID_PRINT}"); Ok(()) } + + /// Run the HUGR cli and validate against an extension registry. + pub fn run(&mut self) -> Result<()> { + self.run_with_input(None::<&[u8]>) + } } diff --git a/hugr-cli/tests/convert.rs b/hugr-cli/tests/convert.rs index 1a1fab492a..d88055c66a 100644 --- a/hugr-cli/tests/convert.rs +++ b/hugr-cli/tests/convert.rs @@ -16,6 +16,7 @@ use hugr::{ extension::prelude::bool_t, types::Signature, }; +use hugr_cli::CliArgs; use predicates::str::contains; use rstest::{fixture, rstest}; use std::io::BufReader; @@ -269,3 +270,52 @@ fn test_format_conflicts(mut convert_cmd: Command) { .failure() .stderr(contains("cannot be used with")); } + +#[rstest] +fn test_convert_programmatic_api(test_package: Package) { + // Test the programmatic API (no CLI process spawning) + + // Serialize the test package as binary + let mut input_data = Vec::new(); + test_package + .store(&mut input_data, EnvelopeConfig::binary()) + .unwrap(); + + // Parse CLI args for conversion to JSON + let cli_args = CliArgs::new_from_args(["hugr", "convert", "--format", "json"]); + + let output = cli_args.run_with_io(input_data.as_slice()).unwrap(); + + let reader = BufReader::new(output.as_slice()); + let registry = ExtensionRegistry::default(); + let (desc, package_out) = + read_described_envelope(reader, ®istry).expect("Failed to read output envelope"); + + // Verify format is JSON + assert_eq!(desc.header.config().format, EnvelopeFormat::PackageJson); + + // Verify the package content is preserved + assert_eq!(package_out, test_package); +} + +#[rstest] +fn test_convert_programmatic_model_text(test_package: Package) { + // Test converting to model-text format programmatically + + let mut input_data = Vec::new(); + test_package + .store(&mut input_data, EnvelopeConfig::binary()) + .unwrap(); + + let cli_args = CliArgs::new_from_args(["hugr", "convert", "--format", "model-text"]); + + let output = cli_args.run_with_io(input_data.as_slice()).unwrap(); + + // Verify the output is valid model-text format + let reader = BufReader::new(output.as_slice()); + let registry = ExtensionRegistry::default(); + let (desc, _) = + read_described_envelope(reader, ®istry).expect("Failed to read output envelope"); + + assert_eq!(desc.header.config().format, EnvelopeFormat::ModelText); +} diff --git a/hugr-cli/tests/validate.rs b/hugr-cli/tests/validate.rs index 52dfb562aa..1e12b2ebd8 100644 --- a/hugr-cli/tests/validate.rs +++ b/hugr-cli/tests/validate.rs @@ -18,6 +18,7 @@ use hugr::{ std_extensions::arithmetic::float_types::float64_type, types::Signature, }; +use hugr_cli::CliArgs; use hugr_cli::validate::VALID_PRINT; use predicates::{prelude::*, str::contains}; use rstest::{fixture, rstest}; @@ -93,6 +94,7 @@ fn test_stdin(test_envelope_str: String, mut val_cmd: Command) { } #[rstest] +#[cfg(feature = "tracing")] fn test_stdin_silent(test_envelope_str: String, mut val_cmd: Command) { val_cmd.args(["-", "-q"]); val_cmd.write_stdin(test_envelope_str); @@ -240,3 +242,35 @@ fn test_validate_known_generator(invalid_hugr_with_generator: Vec, mut val_c .stderr(contains("unconnected port")) .stderr(contains("generated by test-generator-v1.0.1")); } + +#[rstest] +fn test_validate_programmatic_api(test_package: Package) { + // Serialize the test package to bytes + let mut package_bytes = Vec::new(); + test_package + .store(&mut package_bytes, EnvelopeConfig::binary()) + .unwrap(); + + // Create CLI args for validate command + let cli_args = CliArgs::new_from_args(vec!["hugr", "validate"]); + + let result = cli_args.run_with_io(package_bytes.as_slice()); + + let output = result.unwrap(); + assert!(output.is_empty()); +} + +#[rstest] +fn test_validate_programmatic_api_invalid(invalid_hugr_with_generator: Vec) { + // Create CLI args for validate command + let cli_args = CliArgs::new_from_args(vec!["hugr", "validate"]); + + // Run validation with invalid bytes input + let result = cli_args.run_with_io(invalid_hugr_with_generator.as_slice()); + + // Should fail + assert!(result.is_err()); + let err = result.unwrap_err(); + let err_string = format!("{:?}", err); + assert!(err_string.contains("unconnected port")); +} diff --git a/hugr-py/Cargo.toml b/hugr-py/Cargo.toml index 5f53884cea..4867794cb0 100644 --- a/hugr-py/Cargo.toml +++ b/hugr-py/Cargo.toml @@ -21,7 +21,7 @@ bench = false [dependencies] bumpalo = { workspace = true, features = ["collections"] } -hugr-core = { version = "0.24.3", path = "../hugr-core", features = ["zstd"] } +hugr-cli = { version = "0.24.3", path = "../hugr-cli", default-features = false } hugr-model = { version = "0.24.3", path = "../hugr-model", features = ["pyo3"] } pastey.workspace = true pyo3 = { workspace = true, features = ["extension-module", "abi3-py310"] } diff --git a/hugr-py/pyproject.toml b/hugr-py/pyproject.toml index 33742215e3..9f03767010 100644 --- a/hugr-py/pyproject.toml +++ b/hugr-py/pyproject.toml @@ -36,6 +36,10 @@ dependencies = [ "pyzstd~=0.18.0", ] + +[project.scripts] +hugr = "hugr._hugr:run_cli" + [project.optional-dependencies] docs = ["sphinx>=8.1.3,<9.0.0", "furo"] pytket = ["pytket >= 1.34.0"] @@ -55,6 +59,8 @@ python-source = "src/" manifest-path = "Cargo.toml" # "extension-module" tells pyo3 we want to build an extension module (skips linking against libpython.so) features = ["pyo3/extension-module"] +# custom profile with extra optimisations +profile = "release-py" [tool.pyright] # Rust bindings have typing stubs but no python source code. diff --git a/hugr-py/rust/lib.rs b/hugr-py/rust/lib.rs index cd4c3d3c17..ee98f57853 100644 --- a/hugr-py/rust/lib.rs +++ b/hugr-py/rust/lib.rs @@ -1,9 +1,6 @@ //! Supporting Rust library for the hugr Python bindings. -use hugr_core::{ - envelope::{EnvelopeConfig, EnvelopeFormat, read_described_envelope, write_envelope}, - std_extensions::STD_REG, -}; +use hugr_cli::CliArgs; use hugr_model::v0::ast; use pyo3::{exceptions::PyValueError, prelude::*}; @@ -54,17 +51,6 @@ fn bytes_to_package(bytes: &[u8]) -> PyResult { Ok(package) } -/// Convert an envelope to a new envelope in JSON format. -#[pyfunction] -fn to_json_envelope(bytes: &[u8]) -> PyResult> { - let (_, pkg) = read_described_envelope(bytes, &STD_REG) - .map_err(|err| PyValueError::new_err(err.to_string()))?; - let config_json = EnvelopeConfig::new(EnvelopeFormat::PackageJson); - let mut json_data: Vec = Vec::new(); - write_envelope(&mut json_data, &pkg, config_json).unwrap(); - Ok(json_data) -} - /// Returns the current version of the HUGR model format as a tuple of (major, minor, patch). #[pyfunction] fn current_model_version() -> (u64, u64, u64) { @@ -75,6 +61,43 @@ fn current_model_version() -> (u64, u64, u64) { ) } +#[pyfunction] +/// Directly invoke the HUGR CLI entrypoint. +fn run_cli() { + // python is the first arg so skip it + CliArgs::new_from_args(std::env::args().skip(1)).run_cli(); +} + +/// Run a CLI command with bytes input and return bytes output. +/// +/// This function provides a programmatic interface to the HUGR CLI, +/// allowing Python code to pass input data as bytes and receive output +/// as bytes, without needing to use stdin/stdout or temporary files. +/// +/// # Arguments +/// +/// * `args` - Command line arguments as a list of strings, not including the executable name. +/// * `input_bytes` - Optional input data as bytes (e.g., a HUGR package) +/// +/// # Returns +/// +/// Returns the command output as bytes, maybe empty. +/// Raises an exception on error. +/// +/// Errors or tracing may still be printed to stderr as normal. +/// ``` +#[pyfunction] +#[pyo3(signature = (args, input_bytes=None))] +fn cli_with_io(mut args: Vec, input_bytes: Option<&[u8]>) -> PyResult> { + // placeholder for executable + args.insert(0, String::new()); + let cli_args = CliArgs::new_from_args(args); + let input = input_bytes.unwrap_or(&[]); + cli_args + .run_with_io(input) + .map_err(|e| PyValueError::new_err(format!("{:?}", e))) +} + #[pymodule] fn _hugr(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(term_to_string, m)?)?; @@ -94,6 +117,7 @@ fn _hugr(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(symbol_to_string, m)?)?; m.add_function(wrap_pyfunction!(string_to_symbol, m)?)?; m.add_function(wrap_pyfunction!(current_model_version, m)?)?; - m.add_function(wrap_pyfunction!(to_json_envelope, m)?)?; + m.add_function(wrap_pyfunction!(run_cli, m)?)?; + m.add_function(wrap_pyfunction!(cli_with_io, m)?)?; Ok(()) } diff --git a/hugr-py/src/hugr/_hugr/__init__.pyi b/hugr-py/src/hugr/_hugr/__init__.pyi index 7c35e0ad8a..571392e5a6 100644 --- a/hugr-py/src/hugr/_hugr/__init__.pyi +++ b/hugr-py/src/hugr/_hugr/__init__.pyi @@ -20,3 +20,5 @@ def package_to_bytes(package: hugr.model.Package) -> bytes: ... def bytes_to_package(binary: bytes) -> hugr.model.Package: ... def current_model_version() -> tuple[int, int, int]: ... def to_json_envelope(binary: bytes) -> bytes: ... +def run_cli() -> None: ... +def cli_with_io(args: list[str], input_bytes: bytes | None = None) -> bytes: ... diff --git a/hugr-py/src/hugr/cli.py b/hugr-py/src/hugr/cli.py new file mode 100644 index 0000000000..8dc4980c5b --- /dev/null +++ b/hugr-py/src/hugr/cli.py @@ -0,0 +1,308 @@ +"""Python interface for the HUGR CLI subcommands. + +Provides programmatic access to the HUGR CLI via Rust bindings. +Exposes a generic `cli_with_io` function and helpers for the main subcommands: +validate, describe, convert, and mermaid. +""" + +from pydantic import BaseModel + +from hugr._hugr import cli_with_io + +__all__ = [ + "cli_with_io", + "validate", + "describe_str", + "describe", + "convert", + "mermaid", + "PackageDesc", + "ModuleDesc", + "ExtensionDesc", + "EntrypointDesc", +] + + +def _add_input_args( + args: list[str], no_std: bool, extensions: list[str] | None +) -> list[str]: + """Add common HugrInputArgs parameters to the argument list.""" + if no_std: + args.append("--no-std") + if extensions is not None: + for ext in extensions: + args.extend(["--extensions", ext]) + return args + + +def validate( + hugr_bytes: bytes, + *, + no_std: bool = False, + extensions: list[str] | None = None, +) -> None: + """Validate a HUGR package. + + Args: + hugr_bytes: The HUGR package as bytes. + no_std: Don't use standard extensions when validating hugrs. + Prelude is still used (default: False). + extensions: Paths to additional serialised extensions needed to load the HUGR. + + Raises: + ValueError: On validation failure. + """ + args = _add_input_args(["validate"], no_std, extensions) + cli_with_io(args, hugr_bytes) + + +class EntrypointDesc(BaseModel): + """Description of a module's entrypoint node. + + Attributes: + node: The node index of the entrypoint. + optype: String representation of the operation type. + """ + + node: int + optype: str + + +class ExtensionDesc(BaseModel): + """Description of a HUGR extension. + + Attributes: + name: The name of the extension. + version: The version string of the extension. + """ + + name: str + version: str + + +class ModuleDesc(BaseModel): + """Description of a HUGR module. + + Attributes: + entrypoint: The entrypoint node of the module, if present. + generator: Name and version of the generator that created this module. + num_nodes: Total number of nodes in the module. + public_symbols: List of public symbol names exported by the module. + used_extensions_generator: Extensions claimed by the generator in metadata. + used_extensions_resolved: Extensions actually used by the module operations. + """ + + entrypoint: EntrypointDesc | None = None + generator: str | None = None + num_nodes: int | None = None + public_symbols: list[str] | None = None + used_extensions_generator: list[ExtensionDesc] | None = None + used_extensions_resolved: list[ExtensionDesc] | None = None + + def uses_extension(self, extension_name: str) -> bool: + """Check if this module uses a specific extension. + + Args: + extension_name: The name of the extension to check. + + Returns: + True if the module uses the extension, False otherwise. + """ + return any( + ext.name == extension_name for ext in self.used_extensions_resolved or [] + ) + + +class PackageDesc(BaseModel): + """Description of a HUGR package. + + Attributes: + error: Error message if the package failed to load. + header: String representation of the envelope header. + modules: List of module descriptions in the package. + packaged_extensions: Extensions bundled with the package. + """ + + error: str | None = None + header: str + modules: list[ModuleDesc | None] + packaged_extensions: list[ExtensionDesc | None] | None = None + + def uses_extension(self, extension_name: str) -> bool: + """Check if any module in this package uses a specific extension. + + Args: + extension_name: The name of the extension to check. + + Returns: + True if any module uses the extension, False otherwise. + """ + return any( + module.uses_extension(extension_name) + for module in self.modules + if module is not None + ) + + +def describe_str( + hugr_bytes: bytes, + *, + packaged_extensions: bool = False, + no_resolved_extensions: bool = False, + public_symbols: bool = False, + generator_claimed_extensions: bool = False, + no_std: bool = False, + extensions: list[str] | None = None, + _json: bool = False, # only used by describe() +) -> str: + """Describe the contents of a HUGR package as text. + + If an error occurs during loading, partial descriptions are printed. + For example, if the first module is loaded and the second fails, + then only the first module will be described. + + Args: + hugr_bytes: The HUGR package as bytes. + packaged_extensions: Enumerate packaged extensions (default: False). + no_resolved_extensions: Don't display resolved extensions used by the module + (default: False). + public_symbols: Display public symbols in the module (default: False). + generator_claimed_extensions: Display claimed extensions set by generator + in module metadata (default: False). + no_std: Don't use standard extensions when validating hugrs. + Prelude is still used (default: False). + extensions: Paths to additional serialised extensions needed to load the HUGR. + + Returns: + Text description of the package. + """ + args = ["describe"] + if _json: + args.append("--json") + if packaged_extensions: + args.append("--packaged-extensions") + if no_resolved_extensions: + args.append("--no-resolved-extensions") + if public_symbols: + args.append("--public-symbols") + if generator_claimed_extensions: + args.append("--generator-claimed-extensions") + args = _add_input_args(args, no_std, extensions) + return cli_with_io(args, hugr_bytes).decode("utf-8") + + +def describe( + hugr_bytes: bytes, + *, + packaged_extensions: bool = False, + no_resolved_extensions: bool = False, + public_symbols: bool = False, + generator_claimed_extensions: bool = False, + no_std: bool = False, + extensions: list[str] | None = None, +) -> PackageDesc: + """Describe the contents of a HUGR package. + + If an error occurs during loading, partial descriptions are returned. + For example, if the first module is loaded and the second fails, + then only the first module will be described. + + Args: + hugr_bytes: The HUGR package as bytes. + packaged_extensions: Enumerate packaged extensions (default: False). + no_resolved_extensions: Don't display resolved extensions used by the module + (default: False). + public_symbols: Display public symbols in the module (default: False). + generator_claimed_extensions: Display claimed extensions set by generator + in module metadata (default: False). + no_std: Don't use standard extensions when validating hugrs. + Prelude is still used (default: False). + extensions: Paths to additional serialised extensions needed to load the HUGR. + + Returns: + Structured package description as a PackageDesc object. + """ + output = describe_str( + hugr_bytes, + _json=True, + packaged_extensions=packaged_extensions, + no_resolved_extensions=no_resolved_extensions, + public_symbols=public_symbols, + generator_claimed_extensions=generator_claimed_extensions, + no_std=no_std, + extensions=extensions, + ) + return PackageDesc.model_validate_json(output) + + +def convert( + hugr_bytes: bytes, + *, + format: str | None = None, + text: bool = False, + binary: bool = False, + compress: bool = False, + compression_level: int | None = None, + no_std: bool = False, + extensions: list[str] | None = None, +) -> bytes: + """Convert between different HUGR envelope formats. + + Args: + hugr_bytes: The HUGR package as bytes. + format: Output format. One of: json, model, model-exts, model-text, + model-text-exts (default: None, meaning same format as input). + text: Use default text-based envelope configuration. Cannot be combined + with format or binary (default: False). + binary: Use default binary envelope configuration. Cannot be combined + with format or text (default: False). + compress: Enable zstd compression for the output (default: False). + compression_level: Zstd compression level (1-22, where 1 is fastest and + 22 is best compression). (default None, uses the zstd default). + no_std: Don't use standard extensions when validating hugrs. + Prelude is still used (default: False). + extensions: Paths to additional serialised extensions needed to load the HUGR. + + Returns: + Converted package as bytes. + """ + args = ["convert"] + if format is not None: + args.extend(["--format", format]) + if text: + args.append("--text") + if binary: + args.append("--binary") + if compress: + args.append("--compress") + if compression_level is not None: + args.extend(["--compression-level", str(compression_level)]) + args = _add_input_args(args, no_std, extensions) + return cli_with_io(args, hugr_bytes) + + +def mermaid( + hugr_bytes: bytes, + *, + validate: bool = False, + no_std: bool = False, + extensions: list[str] | None = None, +) -> str: + """Generate mermaid diagrams from a HUGR package. + + Args: + hugr_bytes: The HUGR package as bytes. + validate: Validate before rendering, includes extension inference + (default: False). + no_std: Don't use standard extensions when validating hugrs. + Prelude is still used (default: False). + extensions: Paths to additional serialised extensions needed to load the HUGR. + + Returns: + Mermaid diagram output as a string. + """ + args = ["mermaid"] + if validate: + args.append("--validate") + args = _add_input_args(args, no_std, extensions) + return cli_with_io(args, hugr_bytes).decode("utf-8") diff --git a/hugr-py/src/hugr/envelope.py b/hugr-py/src/hugr/envelope.py index 0647359bee..088e0ae4d9 100644 --- a/hugr-py/src/hugr/envelope.py +++ b/hugr-py/src/hugr/envelope.py @@ -39,7 +39,7 @@ import pyzstd -import hugr._hugr as rust +from hugr import cli if TYPE_CHECKING: from hugr.hugr.base import Hugr @@ -117,7 +117,7 @@ def read_envelope(envelope: bytes) -> Package: # TODO Going via JSON is a temporary solution, until we get model import to # python properly implemented. # https://github.com/CQCL/hugr/issues/2287 - json_data = rust.to_json_envelope(envelope) + json_data = cli.convert(envelope, format="json") return read_envelope(json_data) diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 64862e2fa5..26c24638a4 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -1,15 +1,13 @@ from __future__ import annotations import os -import pathlib -import subprocess from dataclasses import dataclass from enum import Enum from typing import TYPE_CHECKING, TypeVar from typing_extensions import Self -from hugr import ext, tys +from hugr import cli, ext, tys from hugr.envelope import EnvelopeConfig from hugr.hugr import Hugr from hugr.ops import AsExtOp, Command, Const, Custom, DataflowOp, ExtOp, RegisteredOp @@ -148,19 +146,6 @@ def __call__(self, q: ComWire, fl_wire: ComWire) -> Command: Rz = RzDef() -def _base_command() -> list[str]: - workspace_dir = pathlib.Path(__file__).parent.parent.parent - # use the HUGR_BIN environment variable if set, otherwise use the debug build - bin_loc = os.environ.get("HUGR_BIN", str(workspace_dir / "target/debug/hugr")) - return [bin_loc] - - -def mermaid(h: Hugr): - """Render the Hugr as a mermaid diagram for debugging.""" - cmd = [*_base_command(), "mermaid", "-"] - _run_hugr_cmd(h.to_str().encode(), cmd) - - def validate( h: Hugr | Package, *, @@ -193,12 +178,10 @@ def validate( # test hugrs. LOAD_FORMATS = ["json", "model-exts"] - cmd = [*_base_command(), "validate", "-"] - # validate text and binary formats for write_fmt in WRITE_FORMATS: serial = h.to_bytes(FORMATS[write_fmt]) - _run_hugr_cmd(serial, cmd) + cli.validate(serial) if roundtrip: # Roundtrip tests: @@ -209,9 +192,7 @@ def validate( # Run `pytest` with `-vv` to see the hash diff. for load_fmt in LOAD_FORMATS: if load_fmt != write_fmt: - cmd = [*_base_command(), "convert", "--format", load_fmt, "-"] - out = _run_hugr_cmd(serial, cmd) - converted_serial = out.stdout + converted_serial = cli.convert(serial, format=load_fmt) else: converted_serial = serial loaded = Package.from_bytes(converted_serial) @@ -226,12 +207,6 @@ def validate( h1_hash == h2_hash ), f"HUGRs are not the same for {write_fmt} -> {load_fmt}" - # Lowering functions are currently ignored in Python, - # because we don't support loading -model envelopes yet. - for ext in loaded.extensions: - for op in ext.operations.values(): - assert op.lower_funcs == [] - @dataclass(frozen=True, order=True) class _NodeHash: @@ -308,20 +283,3 @@ class _OpHash: def __lt__(self, other: _OpHash) -> bool: """Compare two op hashes by name and payload.""" return (self.name, repr(self.payload)) < (other.name, repr(other.payload)) - - -def _get_mermaid(serial: bytes) -> str: # - """Render a HUGR as a mermaid diagram using the CLI.""" - return _run_hugr_cmd(serial, [*_base_command(), "mermaid", "-"]).stdout.decode() - - -def _run_hugr_cmd(serial: bytes, cmd: list[str]) -> subprocess.CompletedProcess[bytes]: - """Run a HUGR command. - - The `serial` argument is the serialized HUGR to pass to the command via stdin. - """ - try: - return subprocess.run(cmd, check=True, input=serial, capture_output=True) # noqa: S603 - except subprocess.CalledProcessError as e: - error = e.stderr.decode() - raise RuntimeError(error) from e diff --git a/hugr-py/tests/test_cli.py b/hugr-py/tests/test_cli.py new file mode 100644 index 0000000000..158e4fd304 --- /dev/null +++ b/hugr-py/tests/test_cli.py @@ -0,0 +1,141 @@ +"""Tests for the CLI bindings.""" + +from typing import Any + +import pytest + +from hugr import cli +from hugr.build import Module +from hugr.ext import Extension +from hugr.package import Package + +from .serialization.test_extension import EXAMPLE + + +@pytest.fixture +def simple_hugr_bytes() -> bytes: + """Create a simple HUGR package as bytes for testing.""" + return Package([Module().hugr]).to_bytes() + + +@pytest.fixture +def hugr_with_extension_bytes() -> bytes: + """Create a HUGR package with an extension as bytes for testing.""" + ext = Extension.from_json(EXAMPLE) + module = Module() + return Package([module.hugr], [ext]).to_bytes() + + +def test_validate_with_bytes(simple_hugr_bytes: bytes): + """Test validating a HUGR package using the programmatic API.""" + cli.validate(simple_hugr_bytes) + + +def test_validate_with_bytes_invalid(): + """Test that invalid packages raise errors through the programmatic API.""" + # We need to pass invalid bytes through cli_with_io directly + # since Package construction would fail first + + invalid_bytes = b"not a valid hugr package" + + with pytest.raises(ValueError, match="Bad magic number"): + cli.cli_with_io(["validate"], invalid_bytes) + + +def test_validate_no_std(simple_hugr_bytes: bytes): + """Test validate with no_std flag.""" + cli.validate(simple_hugr_bytes, no_std=True) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"format": "json"}, # convert to JSON format + {"text": True}, # convert to text format + {"compress": True, "compression_level": 9}, # convert with compression + ], +) +def test_convert_format(hugr_with_extension_bytes: bytes, kwargs: dict[str, Any]): + """Test converting a HUGR package between formats.""" + output_bytes = cli.convert(hugr_with_extension_bytes, **kwargs) + + output_package = Package.from_bytes(output_bytes) + input_package = Package.from_bytes(hugr_with_extension_bytes) + assert output_package == input_package + + +def test_mermaid_output(simple_hugr_bytes: bytes): + """Test generating mermaid diagrams from a HUGR package.""" + output = cli.mermaid(simple_hugr_bytes) + + assert "graph LR" in output + + +def test_mermaid_with_validation(simple_hugr_bytes: bytes): + """Test generating mermaid diagrams with validation.""" + output = cli.mermaid(simple_hugr_bytes, validate=True) + assert "graph LR" in output + + +def test_describe_output(simple_hugr_bytes: bytes): + """Test describing a HUGR package.""" + output_text = cli.describe_str(simple_hugr_bytes) + + # Should contain package information + assert "Package contains" in output_text + + +def test_describe_with_options(hugr_with_extension_bytes: bytes): + """Test describe with various options.""" + output_text = cli.describe_str(hugr_with_extension_bytes, packaged_extensions=True) + assert "Packaged extensions:" in output_text + + output_text = cli.describe_str( + hugr_with_extension_bytes, no_resolved_extensions=True + ) + assert "resolved" not in output_text + + output_text = cli.describe_str(hugr_with_extension_bytes, public_symbols=True) + assert len(output_text) > 0 + + # Test with generator_claimed_extensions flag + output_text = cli.describe_str( + hugr_with_extension_bytes, generator_claimed_extensions=True + ) + assert len(output_text) > 0 + + +def test_describe_json_basic(simple_hugr_bytes: bytes): + """Test describe_json returns structured PackageDesc.""" + desc = cli.describe(simple_hugr_bytes) + + assert isinstance(desc, cli.PackageDesc) + + # Should have expected fields + assert desc.header is not None + assert isinstance(desc.modules, list) + assert len(desc.modules) == 1 + + # Module should have properties + module = desc.modules[0] + assert module is not None + assert isinstance(module, cli.ModuleDesc) + assert module.num_nodes is not None + assert module.num_nodes > 0 + + +def test_describe_json_with_packaged_extensions(hugr_with_extension_bytes: bytes): + """Test describe_json with packaged_extensions flag.""" + desc = cli.describe(hugr_with_extension_bytes, packaged_extensions=True) + + # Should have packaged_extensions field populated + assert isinstance(desc, cli.PackageDesc) + assert desc.packaged_extensions is not None + + mod = desc.modules[0] + assert mod is not None + # mock use of extension in module + mod.used_extensions_resolved = desc.packaged_extensions # type: ignore[assignment] + + assert desc.uses_extension("ext") + assert not desc.uses_extension("nonexistent_extension") diff --git a/justfile b/justfile index 9761319f08..146da5226e 100644 --- a/justfile +++ b/justfile @@ -32,8 +32,7 @@ test-rust *TEST_ARGS: _check_nextest_installed # Run all python tests. test-python *TEST_ARGS: uv run maturin develop --uv - cargo build -p hugr-cli - HUGR_RENDER_DOT=1 uv run pytest -n auto {{TEST_ARGS}} + HUGR_RENDER_DOT=1 uv run pytest {{TEST_ARGS}} # Run all the benchmarks. bench language="[rust|python]": (_run_lang language \ diff --git a/pyproject.toml b/pyproject.toml index 9a2dada11d..ba0ad38e52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,6 @@ dev-dependencies = [ "syrupy >=4.7.1,<5", "types-zstd >= 1.5.6.6", "pytket >= 1.34.0", - "pytest-xdist>=3.8.0", ] [tool.pytest.ini_options] diff --git a/uv.lock b/uv.lock index 56950e47f3..51dd3a2bd5 100644 --- a/uv.lock +++ b/uv.lock @@ -18,7 +18,6 @@ dev = [ { name = "pre-commit", specifier = ">=3.6.2,<4" }, { name = "pytest", specifier = ">=8.1.1,<9" }, { name = "pytest-cov", specifier = ">=5.0.0,<6" }, - { name = "pytest-xdist", specifier = ">=3.8.0" }, { name = "pytket", specifier = ">=1.34.0" }, { name = "ruff", specifier = ">=0.6.2,<0.7" }, { name = "syrupy", specifier = ">=4.7.1,<5" }, @@ -328,15 +327,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" }, ] -[[package]] -name = "execnet" -version = "2.1.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bb/ff/b4c0dc78fbe20c3e59c0c7334de0c27eb4001a2b2017999af398bf730817/execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3", size = 166524, upload-time = "2024-04-08T09:04:19.245Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/43/09/2aea36ff60d16dd8879bdb2f5b3ee0ba8d08cbbdcdfe870e695ce3784385/execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc", size = 40612, upload-time = "2024-04-08T09:04:17.414Z" }, -] - [[package]] name = "filelock" version = "3.20.0" @@ -1054,19 +1044,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/3a/af5b4fa5961d9a1e6237b530eb87dd04aea6eb83da09d2a4073d81b54ccf/pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652", size = 21990, upload-time = "2024-03-24T20:16:32.444Z" }, ] -[[package]] -name = "pytest-xdist" -version = "3.8.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "execnet" }, - { name = "pytest" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/78/b4/439b179d1ff526791eb921115fca8e44e596a13efeda518b9d845a619450/pytest_xdist-3.8.0.tar.gz", hash = "sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1", size = 88069, upload-time = "2025-07-01T13:30:59.346Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" }, -] - [[package]] name = "pytket" version = "2.10.1"