From cea8ab9f53f2e531330dfb74f23e7ce049ef60f4 Mon Sep 17 00:00:00 2001 From: Jenny Chen Date: Mon, 28 Jul 2025 14:02:05 -0600 Subject: [PATCH 1/6] temporarily revert back to hugr-v0.20.2 --- .config/nextest.toml | 7 - .github/change-filters.yml | 7 - .github/workflows/ci-py.yml | 45 +- .github/workflows/ci-rs.yml | 36 +- .github/workflows/notify-coverage.yml | 2 +- .github/workflows/release-please.yml | 2 - .github/workflows/release-plz.yml | 8 +- .github/workflows/semver-checks.yml | 31 +- .github/workflows/unsoundness.yml | 49 +- .release-please-manifest.json | 4 +- Cargo.lock | 239 +-- Cargo.toml | 20 +- hugr-cli/CHANGELOG.md | 12 - hugr-cli/Cargo.toml | 6 +- hugr-cli/src/extensions.rs | 13 +- hugr-cli/src/lib.rs | 62 +- hugr-cli/src/main.rs | 126 +- hugr-cli/src/mermaid.rs | 25 +- hugr-cli/src/validate.rs | 46 +- hugr-cli/tests/validate.rs | 43 +- hugr-core/CHANGELOG.md | 60 - hugr-core/Cargo.toml | 13 +- hugr-core/src/builder.rs | 12 +- hugr-core/src/builder/build_traits.rs | 106 +- hugr-core/src/builder/circuit.rs | 16 +- hugr-core/src/builder/dataflow.rs | 89 +- hugr-core/src/builder/module.rs | 143 +- hugr-core/src/core.rs | 88 +- hugr-core/src/envelope.rs | 94 +- hugr-core/src/envelope/header.rs | 57 +- hugr-core/src/envelope/package_json.rs | 26 +- hugr-core/src/envelope/serde_with.rs | 383 +--- hugr-core/src/export.rs | 306 ++-- hugr-core/src/extension.rs | 37 +- hugr-core/src/extension/declarative/types.rs | 4 +- hugr-core/src/extension/op_def.rs | 101 +- hugr-core/src/extension/prelude.rs | 128 +- hugr-core/src/extension/prelude/generic.rs | 8 +- hugr-core/src/extension/resolution.rs | 4 +- hugr-core/src/extension/resolution/ops.rs | 4 +- hugr-core/src/extension/resolution/test.rs | 8 +- hugr-core/src/extension/resolution/types.rs | 65 +- .../src/extension/resolution/types_mut.rs | 49 +- hugr-core/src/extension/simple_op.rs | 7 +- hugr-core/src/extension/type_def.rs | 45 +- hugr-core/src/hugr.rs | 25 +- hugr-core/src/hugr/hugrmut.rs | 45 +- hugr-core/src/hugr/internal.rs | 20 +- hugr-core/src/hugr/patch/inline_call.rs | 29 +- hugr-core/src/hugr/patch/outline_cfg.rs | 4 +- hugr-core/src/hugr/patch/peel_loop.rs | 20 +- hugr-core/src/hugr/patch/simple_replace.rs | 125 +- .../src/hugr/persistent.rs | 421 +++-- .../src/hugr/persistent}/parents_view.rs | 14 +- hugr-core/src/hugr/persistent/resolver.rs | 43 + .../src/hugr/persistent}/state_space.rs | 264 +-- .../hugr/persistent}/state_space/serial.rs | 64 +- .../src/hugr/persistent}/tests.rs | 137 +- .../src/hugr/persistent}/trait_impls.rs | 45 +- .../src/hugr/persistent}/walker.rs | 526 ++---- .../src/hugr/persistent/walker/pinned.rs | 164 ++ hugr-core/src/hugr/serialize.rs | 9 +- hugr-core/src/hugr/serialize/test.rs | 165 +- .../upgrade/testcases/hugr_with_named_op.json | 114 +- hugr-core/src/hugr/validate.rs | 160 +- hugr-core/src/hugr/validate/test.rs | 346 ++-- hugr-core/src/hugr/views.rs | 2 +- hugr-core/src/hugr/views/impls.rs | 1 - hugr-core/src/hugr/views/render.rs | 14 +- hugr-core/src/hugr/views/rerooted.rs | 1 - hugr-core/src/hugr/views/root_checked/dfg.rs | 316 ++-- hugr-core/src/hugr/views/sibling_subgraph.rs | 10 +- hugr-core/src/hugr/views/tests.rs | 8 +- hugr-core/src/import.rs | 1597 +++++++---------- hugr-core/src/lib.rs | 3 +- hugr-core/src/ops/constant.rs | 63 +- hugr-core/src/ops/constant/serialize.rs | 59 - hugr-core/src/ops/controlflow.rs | 14 +- hugr-core/src/ops/custom.rs | 13 +- hugr-core/src/ops/dataflow.rs | 12 +- hugr-core/src/ops/module.rs | 66 +- hugr-core/src/ops/tag.rs | 6 +- 
hugr-core/src/ops/validate.rs | 22 +- hugr-core/src/package.rs | 7 +- hugr-core/src/std_extensions.rs | 1 - .../std_extensions/arithmetic/int_types.rs | 28 +- .../src/std_extensions/arithmetic/mod.rs | 2 +- hugr-core/src/std_extensions/collections.rs | 1 - .../src/std_extensions/collections/array.rs | 4 +- .../collections/array/array_clone.rs | 14 +- .../collections/array/array_conversion.rs | 18 +- .../collections/array/array_discard.rs | 14 +- .../collections/array/array_op.rs | 46 +- .../collections/array/array_repeat.rs | 16 +- .../collections/array/array_scan.rs | 36 +- .../collections/array/array_value.rs | 6 +- .../collections/array/op_builder.rs | 13 - .../collections/borrow_array.rs | 797 -------- .../src/std_extensions/collections/list.rs | 30 +- .../collections/static_array.rs | 12 +- .../std_extensions/collections/value_array.rs | 2 +- hugr-core/src/std_extensions/ptr.rs | 6 +- hugr-core/src/types.rs | 165 +- hugr-core/src/types/check.rs | 8 +- hugr-core/src/types/custom.rs | 4 +- hugr-core/src/types/poly_func.rs | 141 +- hugr-core/src/types/row_var.rs | 4 +- hugr-core/src/types/serialize.rs | 172 +- hugr-core/src/types/type_param.rs | 1090 ++++------- hugr-core/src/types/type_row.rs | 162 +- hugr-core/tests/model.rs | 142 +- .../tests/persistent_walker_example.rs | 227 ++- .../tests/snapshots/model__roundtrip_add.snap | 33 +- .../snapshots/model__roundtrip_alias.snap | 2 +- .../snapshots/model__roundtrip_call.snap | 5 +- .../tests/snapshots/model__roundtrip_cfg.snap | 17 +- .../snapshots/model__roundtrip_cond.snap | 41 +- .../snapshots/model__roundtrip_const.snap | 13 +- .../model__roundtrip_constraints.snap | 16 +- .../model__roundtrip_entrypoint.snap | 19 +- .../snapshots/model__roundtrip_loop.snap | 7 +- .../snapshots/model__roundtrip_order.snap | 76 +- .../snapshots/model__roundtrip_params.snap | 40 +- hugr-llvm/CHANGELOG.md | 12 - hugr-llvm/Cargo.toml | 6 +- hugr-llvm/src/emit/ops/cfg.rs | 6 +- ...test_fns__diverse_cfg_children@llvm14.snap | 17 +- ...verse_cfg_children@pre-mem2reg@llvm14.snap | 45 +- ...test_fns__diverse_dfg_children@llvm14.snap | 23 + ...verse_dfg_children@pre-mem2reg@llvm14.snap | 38 + hugr-llvm/src/emit/test.rs | 98 +- hugr-llvm/src/extension/collections/array.rs | 39 +- hugr-llvm/src/extension/collections/list.rs | 8 +- ...t_static_array_of_static_array@llvm14.snap | 4 +- ...ay_of_static_array@pre-mem2reg@llvm14.snap | 4 +- .../src/extension/collections/stack_array.rs | 40 +- .../src/extension/collections/static_array.rs | 11 +- hugr-llvm/src/extension/conversions.rs | 20 +- hugr-llvm/src/extension/float.rs | 7 +- hugr-llvm/src/extension/int.rs | 17 +- hugr-llvm/src/extension/logic.rs | 4 +- hugr-llvm/src/extension/prelude.rs | 143 +- ...lude__test__prelude_make_error@llvm14.snap | 19 - ...prelude_make_error@pre-mem2reg@llvm14.snap | 31 - ...__prelude_make_error_and_panic@llvm14.snap | 28 - ...ke_error_and_panic@pre-mem2reg@llvm14.snap | 37 - hugr-llvm/src/test.rs | 2 +- hugr-llvm/src/utils/fat.rs | 9 +- hugr-model/CHANGELOG.md | 24 - hugr-model/Cargo.toml | 3 +- hugr-model/FORMAT_VERSION | 1 - hugr-model/capnp/hugr-v0.capnp | 13 - hugr-model/src/capnp/hugr_v0_capnp.rs | 550 +----- hugr-model/src/lib.rs | 19 - hugr-model/src/v0/ast/hugr.pest | 6 +- hugr-model/src/v0/ast/mod.rs | 4 +- hugr-model/src/v0/ast/parse.rs | 12 +- hugr-model/src/v0/ast/print.rs | 8 +- hugr-model/src/v0/ast/python.rs | 32 - hugr-model/src/v0/ast/resolve.rs | 18 - hugr-model/src/v0/ast/view.rs | 2 - hugr-model/src/v0/binary/read.rs | 140 +- hugr-model/src/v0/binary/write.rs | 18 
+- hugr-model/src/v0/mod.rs | 53 +- hugr-model/src/v0/scope/vars.rs | 27 +- hugr-model/src/v0/table/mod.rs | 4 +- hugr-model/tests/fixtures/model-add.edn | 25 +- hugr-model/tests/fixtures/model-call.edn | 6 +- hugr-model/tests/fixtures/model-cfg.edn | 32 +- hugr-model/tests/fixtures/model-cond.edn | 37 +- hugr-model/tests/fixtures/model-const.edn | 6 +- .../tests/fixtures/model-constraints.edn | 5 +- .../tests/fixtures/model-entrypoint.edn | 10 +- hugr-model/tests/fixtures/model-loop.edn | 4 +- hugr-model/tests/fixtures/model-order.edn | 55 +- hugr-model/tests/fixtures/model-params.edn | 17 +- hugr-passes/CHANGELOG.md | 31 - hugr-passes/Cargo.toml | 4 +- hugr-passes/src/call_graph.rs | 11 +- hugr-passes/src/composable.rs | 40 +- hugr-passes/src/const_fold.rs | 1 + hugr-passes/src/const_fold/test.rs | 2 +- hugr-passes/src/dataflow.rs | 1 + hugr-passes/src/dataflow/datalog.rs | 5 - hugr-passes/src/dataflow/test.rs | 18 +- hugr-passes/src/dead_funcs.rs | 14 +- hugr-passes/src/inline_dfgs.rs | 99 - hugr-passes/src/inline_funcs.rs | 229 --- hugr-passes/src/lib.rs | 4 +- hugr-passes/src/linearize_array.rs | 10 +- hugr-passes/src/lower.rs | 2 - hugr-passes/src/monomorphize.rs | 171 +- hugr-passes/src/non_local.rs | 1 + hugr-passes/src/replace_types.rs | 172 +- hugr-passes/src/replace_types/handlers.rs | 10 +- hugr-passes/src/replace_types/linearize.rs | 156 +- hugr-persistent/CHANGELOG.md | 11 - hugr-persistent/Cargo.toml | 43 - hugr-persistent/README.md | 59 - hugr-persistent/src/lib.rs | 98 - hugr-persistent/src/persistent_hugr/serial.rs | 75 - ..._serial__tests__serde_persistent_hugr.snap | 184 -- hugr-persistent/src/resolver.rs | 147 -- ..._serial__tests__serialize_state_space.snap | 244 --- hugr-persistent/src/subgraph.rs | 215 --- hugr-persistent/src/wire.rs | 303 ---- hugr-py/CHANGELOG.md | 103 -- hugr-py/Cargo.toml | 2 +- hugr-py/pyproject.toml | 2 +- hugr-py/rust/lib.rs | 11 - hugr-py/src/hugr/__init__.py | 2 +- hugr-py/src/hugr/_hugr/__init__.pyi | 1 - hugr-py/src/hugr/_serialization/extension.py | 28 +- hugr-py/src/hugr/_serialization/ops.py | 50 +- hugr-py/src/hugr/_serialization/tys.py | 94 +- hugr-py/src/hugr/build/dfg.py | 62 +- hugr-py/src/hugr/build/function.py | 42 +- hugr-py/src/hugr/envelope.py | 36 +- hugr-py/src/hugr/ext.py | 4 +- hugr-py/src/hugr/hugr/base.py | 229 +-- hugr-py/src/hugr/hugr/render.py | 11 +- hugr-py/src/hugr/model/__init__.py | 19 - hugr-py/src/hugr/model/export.py | 201 +-- hugr-py/src/hugr/ops.py | 83 +- .../_json_defs/collections/borrow_arr.json | 1139 ------------ hugr-py/src/hugr/std/_json_defs/prelude.json | 34 +- hugr-py/src/hugr/std/collections/array.py | 2 +- .../src/hugr/std/collections/borrow_array.py | 94 - .../src/hugr/std/collections/static_array.py | 3 - hugr-py/src/hugr/std/int.py | 30 +- hugr-py/src/hugr/tys.py | 229 +-- hugr-py/src/hugr/utils.py | 24 - hugr-py/src/hugr/val.py | 114 +- .../tests/__snapshots__/test_hugr_build.ambr | 374 +--- .../tests/__snapshots__/test_order_edges.ambr | 258 --- hugr-py/tests/conftest.py | 203 +-- hugr-py/tests/test_cfg.py | 2 +- hugr-py/tests/test_custom.py | 2 +- hugr-py/tests/test_envelope.py | 43 +- hugr-py/tests/test_hugr_build.py | 150 +- hugr-py/tests/test_ops.py | 6 +- hugr-py/tests/test_order_edges.py | 49 - hugr-py/tests/test_prelude.py | 38 - hugr-py/tests/test_tys.py | 58 +- hugr-py/tests/test_val.py | 14 +- hugr/CHANGELOG.md | 89 - hugr/Cargo.toml | 12 +- hugr/benches/benchmarks/hugr/examples.rs | 4 +- hugr/benches/benchmarks/types.rs | 2 +- hugr/src/lib.rs | 4 - justfile | 4 +- 
release-plz.toml | 9 - resources/test/hugr-no-visibility.hugr | 52 - scripts/check_extension_versions.py | 90 - scripts/generate_schema.py | 5 +- specification/hugr.md | 47 +- specification/schema/hugr_schema_live.json | 232 +-- .../schema/hugr_schema_strict_live.json | 232 +-- .../schema/testing_hugr_schema_live.json | 232 +-- .../testing_hugr_schema_strict_live.json | 232 +-- .../collections/borrow_arr.json | 1139 ------------ specification/std_extensions/prelude.json | 34 +- uv.lock | 2 +- 263 files changed, 5225 insertions(+), 16221 deletions(-) delete mode 100644 .config/nextest.toml rename hugr-persistent/src/persistent_hugr.rs => hugr-core/src/hugr/persistent.rs (59%) rename {hugr-persistent/src => hugr-core/src/hugr/persistent}/parents_view.rs (95%) create mode 100644 hugr-core/src/hugr/persistent/resolver.rs rename {hugr-persistent/src => hugr-core/src/hugr/persistent}/state_space.rs (70%) rename {hugr-persistent/src => hugr-core/src/hugr/persistent}/state_space/serial.rs (66%) rename {hugr-persistent/src => hugr-core/src/hugr/persistent}/tests.rs (81%) rename {hugr-persistent/src => hugr-core/src/hugr/persistent}/trait_impls.rs (92%) rename {hugr-persistent/src => hugr-core/src/hugr/persistent}/walker.rs (52%) create mode 100644 hugr-core/src/hugr/persistent/walker/pinned.rs delete mode 100644 hugr-core/src/ops/constant/serialize.rs delete mode 100644 hugr-core/src/std_extensions/collections/borrow_array.rs rename {hugr-persistent => hugr-core}/tests/persistent_walker_example.rs (62%) create mode 100644 hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@llvm14.snap create mode 100644 hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@pre-mem2reg@llvm14.snap delete mode 100644 hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@llvm14.snap delete mode 100644 hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@pre-mem2reg@llvm14.snap delete mode 100644 hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@llvm14.snap delete mode 100644 hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@pre-mem2reg@llvm14.snap delete mode 100644 hugr-model/FORMAT_VERSION delete mode 100644 hugr-passes/src/inline_dfgs.rs delete mode 100644 hugr-passes/src/inline_funcs.rs delete mode 100644 hugr-persistent/CHANGELOG.md delete mode 100644 hugr-persistent/Cargo.toml delete mode 100644 hugr-persistent/README.md delete mode 100644 hugr-persistent/src/lib.rs delete mode 100644 hugr-persistent/src/persistent_hugr/serial.rs delete mode 100644 hugr-persistent/src/persistent_hugr/snapshots/hugr_persistent__persistent_hugr__serial__tests__serde_persistent_hugr.snap delete mode 100644 hugr-persistent/src/resolver.rs delete mode 100644 hugr-persistent/src/state_space/snapshots/hugr_persistent__state_space__serial__tests__serialize_state_space.snap delete mode 100644 hugr-persistent/src/subgraph.rs delete mode 100644 hugr-persistent/src/wire.rs delete mode 100644 hugr-py/src/hugr/std/_json_defs/collections/borrow_arr.json delete mode 100644 hugr-py/src/hugr/std/collections/borrow_array.py delete mode 100644 hugr-py/tests/__snapshots__/test_order_edges.ambr delete mode 100644 hugr-py/tests/test_order_edges.py delete mode 100644 resources/test/hugr-no-visibility.hugr delete mode 100644 scripts/check_extension_versions.py delete mode 100644 
specification/std_extensions/collections/borrow_arr.json diff --git a/.config/nextest.toml b/.config/nextest.toml deleted file mode 100644 index e267109367..0000000000 --- a/.config/nextest.toml +++ /dev/null @@ -1,7 +0,0 @@ - - -[profile.default-miri] -# Fail if tests take more than 5 mins. -# Those tests should be skipped in `.github/workflows/unsondness.yml`. -slow-timeout = { period = "60s", terminate-after = 5 } -fail-fast = false diff --git a/.github/change-filters.yml b/.github/change-filters.yml index 4fe700efc2..f7f7864564 100644 --- a/.github/change-filters.yml +++ b/.github/change-filters.yml @@ -1,8 +1,5 @@ # Filters used by [dorny/path-filters](https://github.com/dorny/paths-filter) # to detect changes in each subproject, and only run the corresponding jobs. -# -# We use a composable action to add some additional checks. -# When adding a new category here, make sure to also update `.github/actions/check-changes/action.yml` # Dependencies and common workspace configuration. rust-config: &rust-config @@ -27,12 +24,8 @@ rust: &rust - "hugr-cli/**" - "hugr-core/**" - "hugr-passes/**" - - "hugr-persistent/**" - "specification/schema/**" -std-extensions: - - "specification/std_extensions/**" - python: - *rust - "hugr-py/**" diff --git a/.github/workflows/ci-py.yml b/.github/workflows/ci-py.yml index c4f88b3f2e..b71cbdca04 100644 --- a/.github/workflows/ci-py.yml +++ b/.github/workflows/ci-py.yml @@ -21,19 +21,22 @@ env: jobs: # Check if changes were made to the relevant files. - # Always returns true if running on the default branch, to ensure all changes are thoroughly checked. + # Always returns true if running on the default branch, to ensure all changes are throughly checked. changes: - name: Check for changes + name: Check for changes in Python files runs-on: ubuntu-latest + # Required permissions permissions: pull-requests: read + # Set job outputs to values from filter step outputs: - python: ${{ steps.filter.outputs.python }} - extensions: ${{ steps.filter.outputs.llvm }} + python: ${{ github.ref_name == github.event.repository.default_branch || steps.filter.outputs.python }} steps: - uses: actions/checkout@v4 - - uses: ./.github/actions/check-changes + - uses: dorny/paths-filter@v3 id: filter + with: + filters: .github/change-filters.yml check: needs: changes @@ -176,41 +179,11 @@ jobs: exit 1 fi - extension-versions: - runs-on: ubuntu-latest - needs: [changes] - if: ${{ needs.changes.outputs.extensions == 'true' }} - name: Check std extensions versions - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 # Need full history to compare with main - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.10' - - - name: Check if extension versions are updated - run: | - # Check against latest tag on the target branch - # When not on a pull request, base_ref should be empty so we default to HEAD - if [ -z "$TARGET_REF" ]; then - BASE_SHA="HEAD~1" - else - BASE_SHA=$(git rev-parse origin/$TARGET_REF) - fi - echo "Comparing to ref: $BASE_SHA" - - python ./scripts/check_extension_versions.py $BASE_SHA - env: - TARGET_REF: ${{ github.base_ref }} - # This is a meta job to mark successful completion of the required checks, # even if they are skipped due to no changes in the relevant files. 
required-checks: name: Required checks 🐍 - needs: [changes, check, test, serialization-schema, extension-versions] + needs: [changes, check, test, serialization-schema] if: ${{ !cancelled() }} runs-on: ubuntu-latest steps: diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index cc7a19635c..40e0046e06 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -19,7 +19,6 @@ env: CI: true # insta snapshots behave differently on ci SCCACHE_GHA_ENABLED: "true" RUSTC_WRAPPER: "sccache" - HUGR_TEST_SCHEMA: "1" # different strings for install action and feature name # adapted from https://github.com/TheDan64/inkwell/blob/master/.github/workflows/test.yml LLVM_VERSION: "14.0" @@ -31,16 +30,36 @@ jobs: changes: name: Check for changes runs-on: ubuntu-latest + # Required permissions permissions: pull-requests: read + # Set job outputs to values from filter step + # These outputs are always true when running after a merge to main, or if the PR has a `run-ci-checks` label. outputs: - rust: ${{ steps.filter.outputs.rust }} - llvm: ${{ steps.filter.outputs.llvm }} - model: ${{ steps.filter.outputs.model }} + rust: ${{ steps.filter.outputs.rust == 'true' || steps.override.outputs.out == 'true' }} + python: ${{ steps.filter.outputs.python == 'true' || steps.override.outputs.out == 'true' }} + model: ${{ steps.filter.outputs.model == 'true' || steps.override.outputs.out == 'true' }} + llvm: ${{ steps.filter.outputs.llvm == 'true' || steps.override.outputs.out == 'true' }} steps: - uses: actions/checkout@v4 - - uses: ./.github/actions/check-changes + - name: Override label + id: override + run: | + echo "Label contains run-ci-checks: $OVERRIDE_LABEL" + if [ "$OVERRIDE_LABEL" == "true" ]; then + echo "Overriding due to label 'run-ci-checks'" + echo "out=true" >> $GITHUB_OUTPUT + elif [ "$DEFAULT_BRANCH" == "true" ]; then + echo "Overriding due to running on the default branch" + echo "out=true" >> $GITHUB_OUTPUT + fi + env: + OVERRIDE_LABEL: ${{ github.event_name == 'pull_request' && contains( github.event.pull_request.labels.*.name, 'run-ci-checks') }} + DEFAULT_BRANCH: ${{ github.ref_name == github.event.repository.default_branch }} + - uses: dorny/paths-filter@v3 id: filter + with: + filters: .github/change-filters.yml check: needs: changes @@ -151,12 +170,6 @@ jobs: - name: Tests hugr-llvm if: ${{ needs.changes.outputs.llvm == 'true'}} run: cargo test -p hugr-llvm --verbose --features llvm${{ env.LLVM_FEATURE_NAME }} - - name: Build hugr-persistent - if: ${{ needs.changes.outputs.rust == 'true'}} - run: cargo test -p hugr-persistent --verbose --no-run - - name: Tests hugr-persistent - if: ${{ needs.changes.outputs.rust == 'true'}} - run: cargo test -p hugr-persistent --verbose - name: Build HUGR binary run: cargo build -p hugr-cli - name: Upload the binary to the artifacts @@ -342,7 +355,6 @@ jobs: cargo llvm-cov --no-report --no-default-features --doctests cargo llvm-cov --no-report --all-features --doctests cargo llvm-cov --no-report -p hugr-llvm --features llvm14-0 --doctests - cargo llvm-cov --no-report -p hugr-persistent --doctests - name: Generate coverage report run: cargo llvm-cov --all-features report --codecov --output-path coverage.json - name: Upload coverage to codecov.io diff --git a/.github/workflows/notify-coverage.yml b/.github/workflows/notify-coverage.yml index 7eae317ab0..25c8f21702 100644 --- a/.github/workflows/notify-coverage.yml +++ b/.github/workflows/notify-coverage.yml @@ -22,7 +22,7 @@ jobs: if: 
needs.coverage-trend.outputs.should_notify == 'true' steps: - name: Send notification - uses: slackapi/slack-github-action@v2.1.1 + uses: slackapi/slack-github-action@v2.1.0 with: method: chat.postMessage token: ${{ secrets.SLACK_BOT_TOKEN }} diff --git a/.github/workflows/release-please.yml b/.github/workflows/release-please.yml index 35f2c6711b..16534f141c 100644 --- a/.github/workflows/release-please.yml +++ b/.github/workflows/release-please.yml @@ -6,7 +6,6 @@ on: push: branches: - main - - release/* permissions: contents: write @@ -22,4 +21,3 @@ jobs: # Using a personal access token so releases created by this workflow can trigger the deployment workflow token: ${{ secrets.HUGRBOT_PAT }} config-file: release-please-config.json - target-branch: ${{ github.ref_name }} diff --git a/.github/workflows/release-plz.yml b/.github/workflows/release-plz.yml index 9139b9eb32..5eacec276b 100644 --- a/.github/workflows/release-plz.yml +++ b/.github/workflows/release-plz.yml @@ -13,9 +13,6 @@ jobs: release-plz: name: Release-plz runs-on: ubuntu-latest - environment: crate-release - permissions: - id-token: write # Required for OIDC token exchange steps: - name: Checkout repository uses: actions/checkout@v4 @@ -35,11 +32,8 @@ jobs: # otherwise release-plz fails due to uncommitted changes. directory: ${{ runner.temp }}/llvm - - uses: rust-lang/crates-io-auth-action@v1 - id: auth - - name: Run release-plz uses: MarcoIeni/release-plz-action@v0.5 env: GITHUB_TOKEN: ${{ secrets.HUGRBOT_PAT }} - CARGO_REGISTRY_TOKEN: ${{ steps.auth.outputs.token }} + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} diff --git a/.github/workflows/semver-checks.yml b/.github/workflows/semver-checks.yml index 881ec2227d..2c410aa85d 100644 --- a/.github/workflows/semver-checks.yml +++ b/.github/workflows/semver-checks.yml @@ -6,19 +6,38 @@ on: jobs: # Check if changes were made to the relevant files. - # Always returns true if running on the default branch, to ensure all changes are thoroughly checked. + # Always returns true if running on the default branch, to ensure all changes are throughly checked. changes: name: Check for changes runs-on: ubuntu-latest + # Required permissions permissions: pull-requests: read + # Set job outputs to values from filter step + # These outputs are always true when running after a merge to main, or if the PR has a `run-ci-checks` label. 
outputs: - rust: ${{ steps.filter.outputs.rust }} - python: ${{ steps.filter.outputs.python }} + rust: ${{ steps.filter.outputs.rust == 'true' || steps.override.outputs.out == 'true' }} + python: ${{ steps.filter.outputs.python == 'true' || steps.override.outputs.out == 'true' }} steps: - - uses: actions/checkout@v4 - - uses: ./.github/actions/check-changes - id: filter + - uses: actions/checkout@v4 + - name: Override label + id: override + run: | + echo "Label contains run-ci-checks: $OVERRIDE_LABEL" + if [ "$OVERRIDE_LABEL" == "true" ]; then + echo "Overriding due to label 'run-ci-checks'" + echo "out=true" >> $GITHUB_OUTPUT + elif [ "$DEFAULT_BRANCH" == "true" ]; then + echo "Overriding due to running on the default branch" + echo "out=true" >> $GITHUB_OUTPUT + fi + env: + OVERRIDE_LABEL: ${{ github.event_name == 'pull_request' && contains( github.event.pull_request.labels.*.name, 'run-ci-checks') }} + DEFAULT_BRANCH: ${{ github.ref_name == github.event.repository.default_branch }} + - uses: dorny/paths-filter@v3 + id: filter + with: + filters: .github/change-filters.yml rs-semver-checks: needs: [changes] diff --git a/.github/workflows/unsoundness.yml b/.github/workflows/unsoundness.yml index f110fe168b..3ba0b4cf0d 100644 --- a/.github/workflows/unsoundness.yml +++ b/.github/workflows/unsoundness.yml @@ -31,53 +31,12 @@ jobs: rustup toolchain install nightly --component miri rustup override set nightly cargo miri setup - - - uses: taiki-e/install-action@v2 + - uses: Swatinem/rust-cache@v2 with: - tool: nextest - - # Run miri unsoundness checks. - # - # The default "zstd" feature requires FFI to the zstd library encode/decode envelopes. - # As this is not supported in miri, we must disable it here. - # - # We also skip tests that take over 5mins in CI. 
+ prefix-key: v0-miri - name: Test with Miri - run: | - cargo miri nextest run --no-default-features -- \ - --skip "builder::circuit::test::with_nonlinear_and_outputs" \ - --skip "extension::op_def::test::check_ext_id_wellformed" \ - --skip "extension::resolution::test::register_new_cyclic" \ - --skip "extension::simple_op::test::check_ext_id_wellformed" \ - --skip "extension::test::test_register_update" \ - --skip "extension::type_def::test::test_instantiate_typedef" \ - --skip "hugr::ident::test::proptest::arbitrary_identlist_valid" \ - --skip "hugr::ident::test::test_idents" \ - --skip "hugr::patch::replace::test::test_invalid" \ - --skip "hugr::validate::test::check_ext_id_wellformed" \ - --skip "ops::constant::test::test_json_const" \ - --skip "ops::custom::test::resolve_missing" \ - --skip "ops::custom::test::new_opaque_op" \ - --skip "std_extensions::arithmetic::int_types::test::proptest::valid_signed_int" \ - --skip "types::test::construct" \ - --skip "types::test::transform" \ - --skip "types::test::transform_copyable_to_linear" \ - --skip "types::type_param::test::proptest::term_contains_itself" \ - `# -------- hugr-model` \ - --skip "v0::test::test_literal_text" \ - `# -------- hugr-passes` \ - --skip "dataflow::partial_value::test::bounded_lattice" \ - --skip "dataflow::partial_value::test::lattice" \ - --skip "dataflow::partial_value::test::lattice_associative" \ - --skip "dataflow::partial_value::test::meet_join_self_noop" \ - --skip "dataflow::partial_value::test::partial_value_type" \ - --skip "dataflow::partial_value::test::partial_value_valid" \ - --skip "merge_bbs::test::check_ext_id_wellformed" \ - --skip "monomorphize::test::test_recursion_module" \ - --skip "replace_types::test::dfg_conditional_case" \ - --skip "replace_types::test::module_func_cfg_call" \ - --skip "replace_types::test::op_to_call" \ - # + run: cargo miri test + create-issue: uses: CQCL/hugrverse-actions/.github/workflows/create-issue.yml@main diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 8e5bbf9e12..7f59cf2351 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - "hugr-py": "0.13.0rc1" -} \ No newline at end of file + "hugr-py": "0.12.1" +} diff --git a/Cargo.lock b/Cargo.lock index a9d98367a2..4981ad7340 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -327,9 +327,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.19.0" +version = "3.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" +checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee" [[package]] name = "bytecount" @@ -351,9 +351,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "capnp" -version = "0.21.3" +version = "0.20.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d55799fdec2a55eee8c267430d7464eb9c27ad2e5c8a49b433ff213b56852c7f" +checksum = "053b81915c2ce1629b8fb964f578b18cb39b23ef9d5b24120d0dfc959569a1d9" dependencies = [ "embedded-io", ] @@ -440,9 +440,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.41" +version = "4.5.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be92d32e80243a54711e5d7ce823c35c41c9d929dc4ab58e1276f625841aadf9" +checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" dependencies = [ "clap_builder", 
"clap_derive", @@ -460,9 +460,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.41" +version = "4.5.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "707eab41e9622f9139419d573eca0900137718000c517d47da73045f54331c3d" +checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" dependencies = [ "anstream", "anstyle", @@ -472,9 +472,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.41" +version = "4.5.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491" +checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce" dependencies = [ "heck", "proc-macro2", @@ -676,9 +676,9 @@ dependencies = [ [[package]] name = "delegate" -version = "0.13.4" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6178a82cf56c836a3ba61a7935cdb1c49bfaa6fa4327cd5bf554a503087de26b" +checksum = "b9b6483c2bbed26f97861cf57651d4f2b731964a28cd2257f934a4b452480d21" dependencies = [ "proc-macro2", "quote", @@ -1201,7 +1201,7 @@ checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "hugr" -version = "0.22.1" +version = "0.20.2" dependencies = [ "bumpalo", "criterion", @@ -1209,14 +1209,13 @@ dependencies = [ "hugr-llvm", "hugr-model", "hugr-passes", - "hugr-persistent", "lazy_static", "serde_json", ] [[package]] name = "hugr-cli" -version = "0.22.1" +version = "0.20.2" dependencies = [ "anyhow", "assert_cmd", @@ -1231,16 +1230,12 @@ dependencies = [ "serde_json", "tempfile", "thiserror 2.0.12", - "tracing", - "tracing-subscriber", ] [[package]] name = "hugr-core" -version = "0.22.1" +version = "0.20.2" dependencies = [ - "anyhow", - "base64", "cgmath", "cool_asserts", "delegate", @@ -1251,12 +1246,11 @@ dependencies = [ "html-escape", "hugr", "hugr-model", - "indexmap 2.10.0", + "indexmap 2.9.0", "insta", "itertools 0.14.0", "jsonschema", "lazy_static", - "ordered-float", "paste", "petgraph 0.8.2", "portgraph", @@ -1270,19 +1264,17 @@ dependencies = [ "serde_json", "serde_with", "serde_yaml", - "smallvec", "smol_str", "static_assertions", "strum", "thiserror 2.0.12", - "tracing", "typetag", "zstd", ] [[package]] name = "hugr-llvm" -version = "0.22.1" +version = "0.20.2" dependencies = [ "anyhow", "delegate", @@ -1301,14 +1293,14 @@ dependencies = [ [[package]] name = "hugr-model" -version = "0.22.1" +version = "0.20.2" dependencies = [ "base64", "bumpalo", "capnp", "derive_more 1.0.0", "fxhash", - "indexmap 2.10.0", + "indexmap 2.9.0", "insta", "itertools 0.14.0", "ordered-float", @@ -1319,14 +1311,13 @@ dependencies = [ "proptest", "proptest-derive", "pyo3", - "semver", "smol_str", "thiserror 2.0.12", ] [[package]] name = "hugr-passes" -version = "0.22.1" +version = "0.20.2" dependencies = [ "ascent", "derive_more 1.0.0", @@ -1343,28 +1334,6 @@ dependencies = [ "thiserror 2.0.12", ] -[[package]] -name = "hugr-persistent" -version = "0.2.1" -dependencies = [ - "delegate", - "derive_more 1.0.0", - "hugr-core", - "insta", - "itertools 0.14.0", - "lazy_static", - "petgraph 0.8.2", - "portgraph", - "relrc", - "rstest", - "semver", - "serde", - "serde_json", - "serde_with", - "thiserror 2.0.12", - "wyhash", -] - [[package]] name = "hugr-py" version = "0.1.0" @@ -1604,9 +1573,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.10.0" +version = "2.9.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" +checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", "hashbrown 0.15.4", @@ -1879,16 +1848,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" -[[package]] -name = "nu-ansi-term" -version = "0.46.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" -dependencies = [ - "overload", - "winapi", -] - [[package]] name = "num" version = "0.4.3" @@ -2008,8 +1967,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2c1f9f56e534ac6a9b8a4600bdf0f530fb393b5f393e7b4d03489c3cf0c3f01" dependencies = [ "num-traits", - "rand 0.8.5", - "serde", ] [[package]] @@ -2018,12 +1975,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "parking_lot" version = "0.12.4" @@ -2110,7 +2061,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset 0.4.2", - "indexmap 2.10.0", + "indexmap 2.9.0", ] [[package]] @@ -2121,7 +2072,7 @@ checksum = "54acf3a685220b533e437e264e4d932cfbdc4cc7ec0cd232ed73c08d03b8a7ca" dependencies = [ "fixedbitset 0.5.7", "hashbrown 0.15.4", - "indexmap 2.10.0", + "indexmap 2.9.0", "serde", ] @@ -2179,14 +2130,13 @@ checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" [[package]] name = "portgraph" -version = "0.15.1" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61fb905fbfbc9abf3bd37853bbd4b25d31dffd5631994f8df528f85455085657" +checksum = "5fdce52d51ec359351ff3c209fafb6f133562abf52d951ce5821c0184798d979" dependencies = [ "bitvec", "delegate", "itertools 0.14.0", - "num-traits", "petgraph 0.8.2", "serde", "thiserror 2.0.12", @@ -2296,7 +2246,7 @@ dependencies = [ "bitflags", "lazy_static", "num-traits", - "rand 0.9.1", + "rand", "rand_chacha", "rand_xorshift", "regex-syntax", @@ -2416,16 +2366,6 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "rand_core 0.6.4", - "serde", -] - [[package]] name = "rand" version = "0.9.1" @@ -2451,9 +2391,6 @@ name = "rand_core" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "serde", -] [[package]] name = "rand_core" @@ -2738,18 +2675,6 @@ dependencies = [ "serde_json", ] -[[package]] -name 
= "schemars" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1375ba8ef45a6f15d83fa8748f1079428295d403d6ea991d09ab100155fbc06d" -dependencies = [ - "dyn-clone", - "ref-cast", - "serde", - "serde_json", -] - [[package]] name = "scopeguard" version = "1.2.0" @@ -2811,17 +2736,16 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.14.0" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2c45cd61fefa9db6f254525d46e392b852e0e61d9a1fd36e5bd183450a556d5" +checksum = "bf65a400f8f66fb7b0552869ad70157166676db75ed8181f8104ea91cf9d0b42" dependencies = [ "base64", "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.10.0", - "schemars 0.9.0", - "schemars 1.0.3", + "indexmap 2.9.0", + "schemars", "serde", "serde_derive", "serde_json", @@ -2831,9 +2755,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.14.0" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de90945e6565ce0d9a25098082ed4ee4002e047cb59892c318d66821e14bb30f" +checksum = "81679d9ed988d5e9a5e6531dc3f2c28efbd639cbd1dfb628df08edea6004da77" dependencies = [ "darling", "proc-macro2", @@ -2847,7 +2771,7 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.10.0", + "indexmap 2.9.0", "itoa", "ryu", "serde", @@ -2865,15 +2789,6 @@ dependencies = [ "digest", ] -[[package]] -name = "sharded-slab" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" -dependencies = [ - "lazy_static", -] - [[package]] name = "shlex" version = "1.3.0" @@ -3070,15 +2985,6 @@ dependencies = [ "syn", ] -[[package]] -name = "thread_local" -version = "1.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" -dependencies = [ - "cfg-if", -] - [[package]] name = "time" version = "0.3.41" @@ -3156,7 +3062,7 @@ version = "0.22.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" dependencies = [ - "indexmap 2.10.0", + "indexmap 2.9.0", "toml_datetime", "winnow", ] @@ -3213,21 +3119,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", - "tracing-attributes", "tracing-core", ] -[[package]] -name = "tracing-attributes" -version = "0.1.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1ffbcf9c6f6b99d386e7444eb608ba646ae452a36b39737deb9663b610f662" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "tracing-core" version = "0.1.34" @@ -3235,32 +3129,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = [ "once_cell", - "valuable", -] - -[[package]] -name = "tracing-log" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" -dependencies = [ - "log", - "once_cell", - "tracing-core", -] - -[[package]] -name = "tracing-subscriber" -version = "0.3.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" -dependencies = [ - "nu-ansi-term", - "sharded-slab", - "smallvec", - "thread_local", - "tracing-core", - "tracing-log", ] [[package]] @@ -3403,12 +3271,6 @@ dependencies = [ "vsimd", ] -[[package]] -name = "valuable" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" - [[package]] name = "version_check" version = "0.9.5" @@ -3545,22 +3407,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - [[package]] name = "winapi-util" version = "0.1.9" @@ -3570,12 +3416,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - [[package]] name = "windows-core" version = "0.61.2" @@ -3798,15 +3638,6 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" -[[package]] -name = "wyhash" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca4d373340c479fd1e779f7a763acee85da3e423b1a9a9acccf97babcc92edbb" -dependencies = [ - "rand_core 0.9.3", -] - [[package]] name = "wyz" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 56d2033786..b123c9897d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,6 @@ members = [ "hugr-model", "hugr-llvm", "hugr-py", - "hugr-persistent", ] default-members = ["hugr", "hugr-core", "hugr-passes", "hugr-cli", "hugr-model"] @@ -38,14 +37,18 @@ missing_docs = "warn" # https://github.com/rust-lang/rust-clippy/issues/5112 debug_assert_with_mut_call = "warn" +# TODO: Reduce the size of error types. 
+result_large_err = "allow" +large_enum_variant = "allow" + [workspace.dependencies] anyhow = "1.0.98" insta = { version = "1.43.1" } bitvec = "1.0.1" -capnp = "0.21.3" +capnp = "0.20.6" cgmath = "0.18.0" cool_asserts = "2.0.3" -delegate = "0.13.4" +delegate = "0.13.3" derive_more = "1.0.0" downcast-rs = "2.0.1" enum_dispatch = "0.3.11" @@ -63,7 +66,7 @@ rstest = "0.24.0" semver = "1.0.26" serde = "1.0.219" serde_json = "1.0.140" -serde_with = "3.14.0" +serde_with = "3.13.0" serde_yaml = "0.9.34" smol_str = "0.3.1" static_assertions = "1.1.0" @@ -71,15 +74,15 @@ strum = "0.27.0" tempfile = "3.20" thiserror = "2.0.12" typetag = "0.2.20" -clap = { version = "4.5.41" } +clap = { version = "4.5.40" } clio = "0.3.5" clap-verbosity-flag = "3.0.3" assert_cmd = "2.0.17" assert_fs = "1.1.3" predicates = "3.1.0" -indexmap = "2.10.0" +indexmap = "2.9.0" fxhash = "0.2.1" -bumpalo = "3.19.0" +bumpalo = "3.18.1" pathsearch = "0.2.0" base64 = "0.22.1" ordered-float = "5.0.0" @@ -89,12 +92,11 @@ pretty = "0.12.4" pretty_assertions = "1.4.1" zstd = "0.13.2" relrc = "0.4.6" -wyhash = "0.6.0" # These public dependencies usually require breaking changes downstream, so we # try to be as permissive as possible. pyo3 = ">= 0.23.4, < 0.25" -portgraph = { version = "0.15.1" } +portgraph = { version = "0.14.1" } petgraph = { version = ">= 0.8.1, < 0.9", default-features = false } [profile.dev.package] diff --git a/hugr-cli/CHANGELOG.md b/hugr-cli/CHANGELOG.md index 090b49ea50..5bbc26a608 100644 --- a/hugr-cli/CHANGELOG.md +++ b/hugr-cli/CHANGELOG.md @@ -1,18 +1,6 @@ # Changelog -## [0.22.0](https://github.com/CQCL/hugr/compare/hugr-cli-v0.21.0...hugr-cli-v0.22.0) - 2025-07-24 - -### New Features - -- include generator metatada in model import and cli validate errors ([#2452](https://github.com/CQCL/hugr/pull/2452)) - -## [0.21.0](https://github.com/CQCL/hugr/compare/hugr-cli-v0.20.2...hugr-cli-v0.21.0) - 2025-07-09 - -### New Features - -- [**breaking**] Better error reporting in `hugr-cli`. ([#2318](https://github.com/CQCL/hugr/pull/2318)) - ## [0.20.2](https://github.com/CQCL/hugr/compare/hugr-cli-v0.20.1...hugr-cli-v0.20.2) - 2025-06-25 ### New Features diff --git a/hugr-cli/Cargo.toml b/hugr-cli/Cargo.toml index aba4984530..a05666cc0a 100644 --- a/hugr-cli/Cargo.toml +++ b/hugr-cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hugr-cli" -version = "0.22.1" +version = "0.20.2" edition = { workspace = true } rust-version = { workspace = true } license = { workspace = true } @@ -19,13 +19,11 @@ bench = false clap = { workspace = true, features = ["derive", "cargo"] } clap-verbosity-flag.workspace = true derive_more = { workspace = true, features = ["display", "error", "from"] } -hugr = { path = "../hugr", version = "0.22.1" } +hugr = { path = "../hugr", version = "0.20.2" } serde_json.workspace = true clio = { workspace = true, features = ["clap-parse"] } anyhow.workspace = true thiserror.workspace = true -tracing = "0.1.41" -tracing-subscriber = { version = "0.3.19", features = ["fmt"] } [lints] workspace = true diff --git a/hugr-cli/src/extensions.rs b/hugr-cli/src/extensions.rs index 1fc31e8571..ff4862634b 100644 --- a/hugr-cli/src/extensions.rs +++ b/hugr-cli/src/extensions.rs @@ -1,5 +1,4 @@ //! Dump standard extensions in serialized form. 
-use anyhow::Result; use clap::Parser; use hugr::extension::ExtensionRegistry; use std::{io::Write, path::PathBuf}; @@ -26,7 +25,7 @@ impl ExtArgs { /// Write out the standard extensions in serialized form. /// Qualified names of extensions used to generate directories under the specified output directory. /// E.g. extension "foo.bar.baz" will be written to "OUTPUT/foo/bar/baz.json". - pub fn run_dump(&self, registry: &ExtensionRegistry) -> Result<()> { + pub fn run_dump(&self, registry: &ExtensionRegistry) { let base_dir = &self.outdir; for ext in registry { @@ -36,17 +35,15 @@ impl ExtArgs { } path.set_extension("json"); - std::fs::create_dir_all(path.clone().parent().unwrap())?; + std::fs::create_dir_all(path.clone().parent().unwrap()).unwrap(); // file buffer - let mut file = std::fs::File::create(&path)?; + let mut file = std::fs::File::create(&path).unwrap(); - serde_json::to_writer_pretty(&mut file, &ext)?; + serde_json::to_writer_pretty(&mut file, &ext).unwrap(); // write newline, for pre-commit end of file check that edits the file to // add newlines if missing. - file.write_all(b"\n")?; + file.write_all(b"\n").unwrap(); } - - Ok(()) } } diff --git a/hugr-cli/src/lib.rs b/hugr-cli/src/lib.rs index d4c269d811..0b91ed547b 100644 --- a/hugr-cli/src/lib.rs +++ b/hugr-cli/src/lib.rs @@ -57,11 +57,11 @@ //! ``` use clap::{Parser, crate_version}; +use clap_verbosity_flag::log::Level; use clap_verbosity_flag::{InfoLevel, Verbosity}; use hugr::envelope::EnvelopeError; use hugr::package::PackageValidationError; use std::ffi::OsString; -use thiserror::Error; pub mod convert; pub mod extensions; @@ -74,19 +74,8 @@ pub mod validate; #[clap(version = crate_version!(), long_about = None)] #[clap(about = "HUGR CLI tools.")] #[group(id = "hugr")] -pub struct CliArgs { - /// The command to be run. - #[command(subcommand)] - pub command: CliCommand, - /// Verbosity. - #[command(flatten)] - pub verbose: Verbosity, -} - -/// The CLI subcommands. -#[derive(Debug, clap::Subcommand)] #[non_exhaustive] -pub enum CliCommand { +pub enum CliArgs { /// Validate and visualize a HUGR file. Validate(validate::ValArgs), /// Write standard extensions out in serialized form. @@ -101,38 +90,45 @@ pub enum CliCommand { } /// Error type for the CLI. -#[derive(Debug, Error)] +#[derive(Debug, derive_more::Display, thiserror::Error, derive_more::From)] #[non_exhaustive] pub enum CliError { /// Error reading input. - #[error("Error reading from path.")] - InputFile(#[from] std::io::Error), + #[display("Error reading from path: {_0}")] + InputFile(std::io::Error), /// Error parsing input. - #[error("Error parsing package.")] - Parse(#[from] serde_json::Error), - #[error("Error validating HUGR.")] + #[display("Error parsing package: {_0}")] + Parse(serde_json::Error), + #[display("Error validating HUGR: {_0}")] /// Errors produced by the `validate` subcommand. - Validate(#[from] PackageValidationError), - #[error("Error decoding HUGR envelope.")] + Validate(PackageValidationError), + #[display("Error decoding HUGR envelope: {_0}")] /// Errors produced by the `validate` subcommand. - Envelope(#[from] EnvelopeError), + Envelope(EnvelopeError), /// Pretty error when the user passes a non-envelope file. - #[error( + #[display( "Input file is not a HUGR envelope. Invalid magic number.\n\nUse `--hugr-json` to read a raw HUGR JSON file instead." )] NotAnEnvelope, /// Invalid format string for conversion. - #[error( + #[display( "Invalid format: '{_0}'. 
Valid formats are: json, model, model-exts, model-text, model-text-exts" )] InvalidFormat(String), - #[error("Error validating HUGR generated by {generator}")] - /// Errors produced by the `validate` subcommand, with a known generator of the HUGR. - ValidateKnownGenerator { - #[source] - /// The inner validation error. - inner: PackageValidationError, - /// The generator of the HUGR. - generator: Box, - }, +} + +/// Other arguments affecting the HUGR CLI runtime. +#[derive(Parser, Debug)] +pub struct OtherArgs { + /// Verbosity. + #[command(flatten)] + pub verbose: Verbosity, +} + +impl OtherArgs { + /// Test whether a `level` message should be output. + #[must_use] + pub fn verbosity(&self, level: Level) -> bool { + self.verbose.log_level_filter() >= level + } } diff --git a/hugr-cli/src/main.rs b/hugr-cli/src/main.rs index 28e64020ff..8063f25916 100644 --- a/hugr-cli/src/main.rs +++ b/hugr-cli/src/main.rs @@ -1,72 +1,84 @@ //! Validate serialized HUGR on the command line -use std::ffi::OsString; - -use anyhow::{Result, anyhow}; use clap::Parser as _; -use clap_verbosity_flag::VerbosityFilter; -use hugr_cli::{CliArgs, CliCommand}; -use tracing::{error, metadata::LevelFilter}; - -fn main() { - let cli_args = CliArgs::parse(); - let level = match cli_args.verbose.filter() { - VerbosityFilter::Off => LevelFilter::OFF, - VerbosityFilter::Error => LevelFilter::ERROR, - VerbosityFilter::Warn => LevelFilter::WARN, - VerbosityFilter::Info => LevelFilter::INFO, - VerbosityFilter::Debug => LevelFilter::DEBUG, - VerbosityFilter::Trace => LevelFilter::TRACE, - }; - tracing_subscriber::fmt() - .with_writer(std::io::stderr) - .with_max_level(level) - .pretty() - .init(); +use hugr_cli::{CliArgs, convert, mermaid, validate}; - let result = match cli_args.command { - CliCommand::Validate(mut args) => args.run(), - CliCommand::GenExtensions(args) => args.run_dump(&hugr::std_extensions::STD_REG), - CliCommand::Mermaid(mut args) => args.run_print(), - CliCommand::Convert(mut args) => args.run_convert(), - CliCommand::External(args) => run_external(args), - _ => Err(anyhow!("Unknown command")), - }; +use clap_verbosity_flag::log::Level; - if let Err(err) = result { - error!("{:?}", err); - std::process::exit(1); +fn main() { + match CliArgs::parse() { + CliArgs::Validate(args) => run_validate(args), + CliArgs::GenExtensions(args) => args.run_dump(&hugr::std_extensions::STD_REG), + CliArgs::Mermaid(args) => run_mermaid(args), + CliArgs::Convert(args) => run_convert(args), + CliArgs::External(args) => { + // External subcommand support: invoke `hugr-` + if args.is_empty() { + eprintln!("No external subcommand specified."); + std::process::exit(1); + } + let subcmd = args[0].to_string_lossy(); + let exe = format!("hugr-{}", subcmd); + let rest: Vec<_> = args[1..] + .iter() + .map(|s| s.to_string_lossy().to_string()) + .collect(); + match std::process::Command::new(&exe).args(&rest).status() { + Ok(status) => { + if !status.success() { + std::process::exit(status.code().unwrap_or(1)); + } + } + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + eprintln!( + "error: no such subcommand: '{subcmd}'.\nCould not find '{exe}' in PATH." 
+ ); + std::process::exit(1); + } + Err(e) => { + eprintln!("error: failed to invoke '{exe}': {e}"); + std::process::exit(1); + } + } + } + _ => { + eprintln!("Unknown command"); + std::process::exit(1); + } } } -fn run_external(args: Vec) -> Result<()> { - // External subcommand support: invoke `hugr-` - if args.is_empty() { - eprintln!("No external subcommand specified."); +/// Run the `validate` subcommand. +fn run_validate(mut args: validate::ValArgs) { + let result = args.run(); + + if let Err(e) = result { + if args.verbosity(Level::Error) { + eprintln!("{e}"); + } std::process::exit(1); } - let subcmd = args[0].to_string_lossy(); - let exe = format!("hugr-{subcmd}"); - let rest: Vec<_> = args[1..] - .iter() - .map(|s| s.to_string_lossy().to_string()) - .collect(); - match std::process::Command::new(&exe).args(&rest).status() { - Ok(status) => { - if !status.success() { - std::process::exit(status.code().unwrap_or(1)); - } - } - Err(e) if e.kind() == std::io::ErrorKind::NotFound => { - eprintln!("error: no such subcommand: '{subcmd}'.\nCould not find '{exe}' in PATH."); - std::process::exit(1); - } - Err(e) => { - eprintln!("error: failed to invoke '{exe}': {e}"); - std::process::exit(1); +} + +/// Run the `mermaid` subcommand. +fn run_mermaid(mut args: mermaid::MermaidArgs) { + let result = args.run_print(); + + if let Err(e) = result { + if args.other_args.verbosity(Level::Error) { + eprintln!("{e}"); } + std::process::exit(1); } +} - Ok(()) +/// Run the `convert` subcommand. +fn run_convert(mut args: convert::ConvertArgs) { + let result = args.run_convert(); + + if let Err(e) = result { + eprintln!("{e}"); + std::process::exit(1); + } } diff --git a/hugr-cli/src/mermaid.rs b/hugr-cli/src/mermaid.rs index fbe4a09a4e..0ce5c9b8a4 100644 --- a/hugr-cli/src/mermaid.rs +++ b/hugr-cli/src/mermaid.rs @@ -1,14 +1,15 @@ //! Render mermaid diagrams. use std::io::Write; -use crate::CliError; -use crate::hugr_io::HugrInputArgs; -use anyhow::Result; use clap::Parser; +use clap_verbosity_flag::log::Level; use clio::Output; use hugr::HugrView; use hugr::package::PackageValidationError; +use crate::OtherArgs; +use crate::hugr_io::HugrInputArgs; + /// Dump the standard extensions. #[derive(Parser, Debug)] #[clap(version = "1.0", long_about = None)] @@ -29,11 +30,15 @@ pub struct MermaidArgs { /// Output file '-' for stdout #[clap(long, short, value_parser, default_value = "-")] output: Output, + + /// Additional arguments + #[command(flatten)] + pub other_args: OtherArgs, } impl MermaidArgs { /// Write the mermaid diagram to the output. - pub fn run_print(&mut self) -> Result<()> { + pub fn run_print(&mut self) -> Result<(), crate::CliError> { if self.input_args.hugr_json { self.run_print_hugr() } else { @@ -42,11 +47,11 @@ impl MermaidArgs { } /// Write the mermaid diagram for a HUGR envelope. - pub fn run_print_envelope(&mut self) -> Result<()> { + pub fn run_print_envelope(&mut self) -> Result<(), crate::CliError> { let package = self.input_args.get_package()?; if self.validate { - package.validate().map_err(CliError::Validate)?; + package.validate()?; } for hugr in package.modules { @@ -56,7 +61,7 @@ impl MermaidArgs { } /// Write the mermaid diagram for a legacy HUGR json. 
- pub fn run_print_hugr(&mut self) -> Result<()> { + pub fn run_print_hugr(&mut self) -> Result<(), crate::CliError> { let hugr = self.input_args.get_hugr()?; if self.validate { @@ -67,4 +72,10 @@ impl MermaidArgs { writeln!(self.output, "{}", hugr.mermaid_string())?; Ok(()) } + + /// Test whether a `level` message should be output. + #[must_use] + pub fn verbosity(&self, level: Level) -> bool { + self.other_args.verbosity(level) + } } diff --git a/hugr-cli/src/validate.rs b/hugr-cli/src/validate.rs index 2ec19e8b38..ddf51d135c 100644 --- a/hugr-cli/src/validate.rs +++ b/hugr-cli/src/validate.rs @@ -1,13 +1,12 @@ //! The `validate` subcommand. -use anyhow::Result; use clap::Parser; -use hugr::HugrView; +use clap_verbosity_flag::log::Level; use hugr::package::PackageValidationError; -use tracing::info; +use hugr::{Hugr, HugrView}; -use crate::CliError; use crate::hugr_io::HugrInputArgs; +use crate::{CliError, OtherArgs}; /// Validate and visualise a HUGR file. #[derive(Parser, Debug)] @@ -19,6 +18,10 @@ pub struct ValArgs { /// Hugr input. #[command(flatten)] pub input_args: HugrInputArgs, + + /// Additional arguments + #[command(flatten)] + pub other_args: OtherArgs, } /// String to print when validation is successful. @@ -26,35 +29,28 @@ pub const VALID_PRINT: &str = "HUGR valid!"; impl ValArgs { /// Run the HUGR cli and validate against an extension registry. - pub fn run(&mut self) -> Result<()> { - if self.input_args.hugr_json { + pub fn run(&mut self) -> Result, CliError> { + let result = if self.input_args.hugr_json { let hugr = self.input_args.get_hugr()?; - let generator = hugr::envelope::get_generator(&[&hugr]); - hugr.validate() - .map_err(PackageValidationError::Validation) - .map_err(|val_err| wrap_generator(generator, val_err))?; + .map_err(PackageValidationError::Validation)?; + vec![hugr] } else { let package = self.input_args.get_package()?; - let generator = hugr::envelope::get_generator(&package.modules); - package - .validate() - .map_err(|val_err| wrap_generator(generator, val_err))?; + package.validate()?; + package.modules }; - info!("{VALID_PRINT}"); + if self.verbosity(Level::Info) { + eprintln!("{VALID_PRINT}"); + } - Ok(()) + Ok(result) } -} -fn wrap_generator(generator: Option, val_err: PackageValidationError) -> CliError { - if let Some(g) = generator { - CliError::ValidateKnownGenerator { - inner: val_err, - generator: Box::new(g.to_string()), - } - } else { - CliError::Validate(val_err) + /// Test whether a `level` message should be output. 
+ #[must_use] + pub fn verbosity(&self, level: Level) -> bool { + self.other_args.verbosity(level) } } diff --git a/hugr-cli/tests/validate.rs b/hugr-cli/tests/validate.rs index d7d552decc..5fb09ca091 100644 --- a/hugr-cli/tests/validate.rs +++ b/hugr-cli/tests/validate.rs @@ -13,15 +13,12 @@ use hugr::types::Type; use hugr::{ builder::{Container, Dataflow}, extension::prelude::{bool_t, qb_t}, - hugr::HugrView, - hugr::hugrmut::HugrMut, std_extensions::arithmetic::float_types::float64_type, types::Signature, }; use hugr_cli::validate::VALID_PRINT; use predicates::{prelude::*, str::contains}; use rstest::{fixture, rstest}; -use serde_json::json; #[fixture] fn cmd() -> Command { @@ -126,7 +123,6 @@ fn test_mermaid_invalid(bad_hugr_string: String, mut cmd: Command) { cmd.write_stdin(bad_hugr_string); cmd.assert() .failure() - .stderr(contains("unconnected port")) .stderr(contains("Error validating HUGR")); } @@ -138,7 +134,6 @@ fn test_bad_hugr(bad_hugr_string: String, mut val_cmd: Command) { val_cmd .assert() .failure() - .stderr(contains("unconnected port")) .stderr(contains("Error validating HUGR")); } @@ -151,8 +146,7 @@ fn test_bad_json(mut val_cmd: Command) { val_cmd .assert() .failure() - .stderr(contains("Error decoding HUGR envelope")) - .stderr(contains("missing field")); + .stderr(contains("Error decoding HUGR envelope")); } #[rstest] @@ -205,38 +199,3 @@ fn test_package_validation(package_string: String, mut val_cmd: Command) { val_cmd.assert().success().stderr(contains(VALID_PRINT)); } - -/// Create a deliberately invalid HUGR with a known generator -#[fixture] -fn invalid_hugr_with_generator() -> Vec { - // Create an invalid HUGR (missing outputs in a dataflow) - let df = DFGBuilder::new(Signature::new_endo(vec![qb_t()])).unwrap(); - let mut bad_hugr = df.hugr().clone(); // Missing outputs makes this invalid - bad_hugr.set_metadata( - bad_hugr.module_root(), - hugr::envelope::GENERATOR_KEY, - json!({"name": "test-generator", "version": "1.0.1"}), - ); - // Create envelope with a specific generator - let envelope_config = EnvelopeConfig::binary(); - - let mut buff = Vec::new(); - // Serialize to string - bad_hugr.store(&mut buff, envelope_config).unwrap(); - buff -} - -#[rstest] -fn test_validate_known_generator(invalid_hugr_with_generator: Vec, mut val_cmd: Command) { - // Write the invalid HUGR to stdin - val_cmd.write_stdin(invalid_hugr_with_generator); - val_cmd.arg("-"); - - // Expect a failure with the generator name in the error message - val_cmd - .assert() - .failure() - .stderr(contains("Error validating HUGR")) - .stderr(contains("unconnected port")) - .stderr(contains("generated by test-generator-v1.0.1")); -} diff --git a/hugr-core/CHANGELOG.md b/hugr-core/CHANGELOG.md index 2202aac02a..9026431a7a 100644 --- a/hugr-core/CHANGELOG.md +++ b/hugr-core/CHANGELOG.md @@ -1,65 +1,5 @@ # Changelog - -## [0.22.0](https://github.com/CQCL/hugr/compare/hugr-core-v0.21.0...hugr-core-v0.22.0) - 2025-07-24 - -### Bug Fixes - -- Ensure SumTypes have the same json encoding in -rs and -py ([#2465](https://github.com/CQCL/hugr/pull/2465)) - -### New Features - -- Export entrypoint metadata in Python and fix bug in import ([#2434](https://github.com/CQCL/hugr/pull/2434)) -- Names of private functions become `core.title` metadata. 
([#2448](https://github.com/CQCL/hugr/pull/2448)) -- [**breaking**] Use binary envelopes for operation lower_func encoding ([#2447](https://github.com/CQCL/hugr/pull/2447)) -- [**breaking**] Update portgraph dependency to 0.15 ([#2455](https://github.com/CQCL/hugr/pull/2455)) -- Detect and fail on unrecognised envelope flags ([#2453](https://github.com/CQCL/hugr/pull/2453)) -- include generator metatada in model import and cli validate errors ([#2452](https://github.com/CQCL/hugr/pull/2452)) -- [**breaking**] Add `insert_region` to HugrMut ([#2463](https://github.com/CQCL/hugr/pull/2463)) -- Non-region entrypoints in `hugr-model`. ([#2467](https://github.com/CQCL/hugr/pull/2467)) -## [0.21.0](https://github.com/CQCL/hugr/compare/hugr-core-v0.20.2...hugr-core-v0.21.0) - 2025-07-09 - -### Bug Fixes - -- Fixed two bugs in import/export of function operations ([#2324](https://github.com/CQCL/hugr/pull/2324)) -- Model import should perform extension resolution ([#2326](https://github.com/CQCL/hugr/pull/2326)) -- [**breaking**] Fixed bugs in model CFG handling and improved CFG signatures ([#2334](https://github.com/CQCL/hugr/pull/2334)) -- Use List instead of Tuple in conversions for TypeArg/TypeRow ([#2378](https://github.com/CQCL/hugr/pull/2378)) -- Do extension resolution on loaded extensions from the model format ([#2389](https://github.com/CQCL/hugr/pull/2389)) -- Make JSON Schema checks actually work again ([#2412](https://github.com/CQCL/hugr/pull/2412)) -- Order hints on input and output nodes. ([#2422](https://github.com/CQCL/hugr/pull/2422)) - -### New Features - -- [**breaking**] No nested FuncDefns (or AliasDefns) ([#2256](https://github.com/CQCL/hugr/pull/2256)) -- [**breaking**] Split `TypeArg::Sequence` into tuples and lists. ([#2140](https://github.com/CQCL/hugr/pull/2140)) -- [**breaking**] Added float and bytes literal to core and python bindings. ([#2289](https://github.com/CQCL/hugr/pull/2289)) -- [**breaking**] More helpful error messages in model import ([#2272](https://github.com/CQCL/hugr/pull/2272)) -- [**breaking**] Better error reporting in `hugr-cli`. 
([#2318](https://github.com/CQCL/hugr/pull/2318)) -- [**breaking**] Merge `TypeParam` and `TypeArg` into one `Term` type in Rust ([#2309](https://github.com/CQCL/hugr/pull/2309)) -- *(persistent)* Add serialisation for CommitStateSpace ([#2344](https://github.com/CQCL/hugr/pull/2344)) -- add TryFrom impls for TypeArg/TypeRow ([#2366](https://github.com/CQCL/hugr/pull/2366)) -- Add `MakeError` op ([#2377](https://github.com/CQCL/hugr/pull/2377)) -- Open lists and tuples in `Term` ([#2360](https://github.com/CQCL/hugr/pull/2360)) -- Call `FunctionBuilder::add_{in,out}put` for any AsMut ([#2376](https://github.com/CQCL/hugr/pull/2376)) -- Add Root checked methods to DataflowParentID ([#2382](https://github.com/CQCL/hugr/pull/2382)) -- Add PersistentWire type ([#2361](https://github.com/CQCL/hugr/pull/2361)) -- Add `BorrowArray` extension ([#2395](https://github.com/CQCL/hugr/pull/2395)) -- [**breaking**] Rename 'Any' type bound to 'Linear' ([#2421](https://github.com/CQCL/hugr/pull/2421)) -- [**breaking**] Add Visibility to FuncDefn/FuncDecl. ([#2143](https://github.com/CQCL/hugr/pull/2143)) -- *(per)* [**breaking**] Support empty wires in commits ([#2349](https://github.com/CQCL/hugr/pull/2349)) -- [**breaking**] hugr-model use explicit Option, with ::Unspecified in capnp ([#2424](https://github.com/CQCL/hugr/pull/2424)) - -### Refactor - -- [**breaking**] move PersistentHugr into separate crate ([#2277](https://github.com/CQCL/hugr/pull/2277)) -- [**breaking**] remove deprecated runtime extension errors ([#2369](https://github.com/CQCL/hugr/pull/2369)) -- [**breaking**] Reduce error type sizes ([#2420](https://github.com/CQCL/hugr/pull/2420)) - -### Testing - -- Check hugr json serializations against the schema (again) ([#2216](https://github.com/CQCL/hugr/pull/2216)) - ## [0.20.2](https://github.com/CQCL/hugr/compare/hugr-core-v0.20.1...hugr-core-v0.20.2) - 2025-06-25 ### Documentation diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 02ea45885f..a365c73e19 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hugr-core" -version = "0.22.1" +version = "0.20.2" edition = { workspace = true } rust-version = { workspace = true } @@ -19,7 +19,6 @@ workspace = true [features] declarative = ["serde_yaml"] zstd = ["dep:zstd"] -default = [] [lib] bench = false @@ -27,8 +26,11 @@ bench = false [[test]] name = "model" +[[test]] +name = "persistent_walker_example" + [dependencies] -hugr-model = { version = "0.22.1", path = "../hugr-model" } +hugr-model = { version = "0.20.2", path = "../hugr-model" } cgmath = { workspace = true, features = ["serde"] } delegate = { workspace = true } @@ -61,11 +63,7 @@ thiserror = { workspace = true } typetag = { workspace = true } semver = { workspace = true, features = ["serde"] } zstd = { workspace = true, optional = true } -ordered-float = { workspace = true, features = ["serde"] } -base64.workspace = true relrc = { workspace = true, features = ["petgraph", "serde"] } -smallvec = "1.15.0" -tracing = "0.1.41" [dev-dependencies] rstest = { workspace = true } @@ -79,4 +77,3 @@ proptest-derive = { workspace = true } # Required for documentation examples hugr = { path = "../hugr" } serde_yaml = 
"0.9.34" -anyhow = { workspace = true } diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index ee5046dd5c..aa2d949056 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -189,7 +189,7 @@ pub enum BuildError { #[error("Found an error while setting the outputs of a {} container, {container_node}. {error}", .container_op.name())] #[allow(missing_docs)] OutputWiring { - container_op: Box, + container_op: OpType, container_node: Node, #[source] error: BuilderWiringError, @@ -201,7 +201,7 @@ pub enum BuildError { #[error("Got an input wire while adding a {} to the circuit. {error}", .op.name())] #[allow(missing_docs)] OperationWiring { - op: Box, + op: OpType, #[source] error: BuilderWiringError, }, @@ -219,7 +219,7 @@ pub enum BuilderWiringError { #[error("Cannot copy linear type {typ} from output {src_offset} of node {src}")] #[allow(missing_docs)] NoCopyLinear { - typ: Box, + typ: Type, src: Node, src_offset: Port, }, @@ -244,7 +244,7 @@ pub enum BuilderWiringError { src_offset: Port, dst: Node, dst_offset: Port, - typ: Box, + typ: Type, }, } @@ -261,8 +261,8 @@ pub(crate) mod test { use super::handle::BuildHandle; use super::{ - BuildError, CFGBuilder, DFGBuilder, Dataflow, DataflowHugr, FuncID, FunctionBuilder, - ModuleBuilder, + BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, FuncID, + FunctionBuilder, ModuleBuilder, }; use super::{DataflowSubContainer, HugrBuilder}; diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index ac4d46645d..03731bb7da 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -9,7 +9,7 @@ use crate::{Extension, IncomingPort, Node, OutgoingPort}; use std::iter; use std::sync::Arc; -use super::{BuilderWiringError, ModuleBuilder}; +use super::{BuilderWiringError, FunctionBuilder}; use super::{ CircuitBuilder, handle::{BuildHandle, Outputs}, @@ -21,7 +21,7 @@ use crate::{ }; use crate::extension::ExtensionRegistry; -use crate::types::{Signature, Type, TypeArg, TypeRow}; +use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeRow}; use itertools::Itertools; @@ -82,20 +82,37 @@ pub trait Container { self.add_child_node(constant.into()).into() } - /// Insert a HUGR's entrypoint region as a child of the container. + /// Add a [`ops::FuncDefn`] node and returns a builder to define the function + /// body graph. /// - /// To insert an arbitrary region of a HUGR, use [`Container::add_hugr_region`]. - fn add_hugr(&mut self, child: Hugr) -> InsertionResult { - let region = child.entrypoint(); - self.add_hugr_region(child, region) + /// # Errors + /// + /// This function will return an error if there is an error in adding the + /// [`ops::FuncDefn`] node. + fn define_function( + &mut self, + name: impl Into, + signature: impl Into, + ) -> Result, BuildError> { + let signature: PolyFuncType = signature.into(); + let body = signature.body().clone(); + let f_node = self.add_child_node(ops::FuncDefn::new(name, signature)); + + // Add the extensions used by the function types. + self.use_extensions( + body.used_extensions().unwrap_or_else(|e| { + panic!("Build-time signatures should have valid extensions. {e}") + }), + ); + + let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; + Ok(FunctionBuilder::from_dfg_builder(db)) } - /// Insert a HUGR region as a child of the container. - /// - /// To insert the entrypoint region of a HUGR, use [`Container::add_hugr`]. 
- fn add_hugr_region(&mut self, child: Hugr, region: Node) -> InsertionResult { + /// Insert a HUGR as a child of the container. + fn add_hugr(&mut self, child: Hugr) -> InsertionResult { let parent = self.container_node(); - self.hugr_mut().insert_region(parent, child, region) + self.hugr_mut().insert_hugr(parent, child) } /// Insert a copy of a HUGR as a child of the container. @@ -138,19 +155,8 @@ pub trait Container { } /// Types implementing this trait can be used to build complete HUGRs -/// (with varying entrypoint node types) +/// (with varying root node types) pub trait HugrBuilder: Container { - /// Allows adding definitions to the module root of which - /// this builder is building a part - fn module_root_builder(&mut self) -> ModuleBuilder<&mut Hugr> { - debug_assert!( - self.hugr() - .get_optype(self.hugr().module_root()) - .is_module() - ); - ModuleBuilder(self.hugr_mut()) - } - /// Finish building the HUGR, perform any validation checks and return it. fn finish_hugr(self) -> Result>; } @@ -210,10 +216,6 @@ pub trait Dataflow: Container { /// Insert a hugr-defined op to the sibling graph, wiring up the /// `input_wires` to the incoming ports of the resulting root node. /// - /// Inserts everything from the entrypoint region of the HUGR. - /// See [`Dataflow::add_hugr_region_with_wires`] for a generic version that allows - /// inserting a region other than the entrypoint. - /// /// # Errors /// /// This function will return an error if there is an error when adding the @@ -223,34 +225,12 @@ pub trait Dataflow: Container { hugr: Hugr, input_wires: impl IntoIterator, ) -> Result, BuildError> { - let region = hugr.entrypoint(); - self.add_hugr_region_with_wires(hugr, region, input_wires) - } - - /// Insert a hugr-defined op to the sibling graph, wiring up the - /// `input_wires` to the incoming ports of the resulting root node. - /// - /// `region` must be a node in the `hugr`. See [`Dataflow::add_hugr_with_wires`] - /// for a helper that inserts the entrypoint region to the HUGR. - /// - /// # Errors - /// - /// This function will return an error if there is an error when adding the - /// node. 
- fn add_hugr_region_with_wires( - &mut self, - hugr: Hugr, - region: Node, - input_wires: impl IntoIterator, - ) -> Result, BuildError> { - let optype = hugr.get_optype(region).clone(); + let optype = hugr.get_optype(hugr.entrypoint()).clone(); let num_outputs = optype.value_output_count(); - let node = self.add_hugr_region(hugr, region).inserted_entrypoint; + let node = self.add_hugr(hugr).inserted_entrypoint; - wire_up_inputs(input_wires, node, self).map_err(|error| BuildError::OperationWiring { - op: Box::new(optype), - error, - })?; + wire_up_inputs(input_wires, node, self) + .map_err(|error| BuildError::OperationWiring { op: optype, error })?; Ok((node, num_outputs).into()) } @@ -271,10 +251,8 @@ pub trait Dataflow: Container { let optype = hugr.get_optype(hugr.entrypoint()).clone(); let num_outputs = optype.value_output_count(); - wire_up_inputs(input_wires, node, self).map_err(|error| BuildError::OperationWiring { - op: Box::new(optype), - error, - })?; + wire_up_inputs(input_wires, node, self) + .map_err(|error| BuildError::OperationWiring { op: optype, error })?; Ok((node, num_outputs).into()) } @@ -291,7 +269,7 @@ pub trait Dataflow: Container { let [_, out] = self.io(); wire_up_inputs(output_wires.into_iter().collect_vec(), out, self).map_err(|error| { BuildError::OutputWiring { - container_op: Box::new(self.hugr().get_optype(self.container_node()).clone()), + container_op: self.hugr().get_optype(self.container_node()).clone(), container_node: self.container_node(), error, } @@ -700,10 +678,8 @@ fn add_node_with_wires( let num_outputs = op.value_output_count(); let op_node = data_builder.add_child_node(op.clone()); - wire_up_inputs(inputs, op_node, data_builder).map_err(|error| BuildError::OperationWiring { - op: Box::new(op), - error, - })?; + wire_up_inputs(inputs, op_node, data_builder) + .map_err(|error| BuildError::OperationWiring { op, error })?; Ok((op_node, num_outputs)) } @@ -755,7 +731,7 @@ fn wire_up( src_offset: src_port.into(), dst, dst_offset: dst_port.into(), - typ: Box::new(typ), + typ, }); } @@ -786,7 +762,7 @@ fn wire_up( } else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() { // Don't copy linear edges. return Err(BuilderWiringError::NoCopyLinear { - typ: Box::new(typ), + typ, src, src_offset: src_port.into(), }); diff --git a/hugr-core/src/builder/circuit.rs b/hugr-core/src/builder/circuit.rs index 193bfe0675..58f388f439 100644 --- a/hugr-core/src/builder/circuit.rs +++ b/hugr-core/src/builder/circuit.rs @@ -27,10 +27,10 @@ pub struct CircuitBuilder<'a, T: ?Sized> { #[non_exhaustive] pub enum CircuitBuildError { /// Invalid index for stored wires. - #[error("Invalid wire index {invalid_index} while attempting to add operation {}.", .op.as_ref().map(|op| op.name()).unwrap_or_default())] + #[error("Invalid wire index {invalid_index} while attempting to add operation {}.", .op.as_ref().map(NamedOp::name).unwrap_or_default())] InvalidWireIndex { /// The operation. - op: Option>, + op: Option, /// The invalid indices. invalid_index: usize, }, @@ -38,7 +38,7 @@ pub enum CircuitBuildError { #[error("The linear inputs {:?} had no corresponding output wire in operation {}.", .index.as_slice(), .op.name())] MismatchedLinearInputs { /// The operation. - op: Box, + op: OpType, /// The index of the input that had no corresponding output wire. 
index: Vec, }, @@ -143,7 +143,7 @@ impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> { let input_wires = input_wires.map_err(|invalid_index| CircuitBuildError::InvalidWireIndex { - op: Some(Box::new(op.clone())), + op: Some(op.clone()), invalid_index, })?; @@ -169,7 +169,7 @@ impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> { if !linear_inputs.is_empty() { return Err(CircuitBuildError::MismatchedLinearInputs { - op: Box::new(op), + op, index: linear_inputs.values().copied().collect(), } .into()); @@ -245,7 +245,7 @@ mod test { use cool_asserts::assert_matches; use crate::Extension; - use crate::builder::{HugrBuilder, ModuleBuilder}; + use crate::builder::{Container, HugrBuilder, ModuleBuilder}; use crate::extension::ExtensionId; use crate::extension::prelude::{qb_t, usize_t}; use crate::std_extensions::arithmetic::float_types::ConstF64; @@ -389,7 +389,7 @@ mod test { assert_matches!( circ.append(cx_gate(), [q0, invalid_index]), Err(BuildError::CircuitError(CircuitBuildError::InvalidWireIndex { op, invalid_index: idx })) - if op == Some(Box::new(cx_gate().into())) && idx == invalid_index, + if op == Some(cx_gate().into()) && idx == invalid_index, ); // Untracking an invalid index returns an error @@ -403,7 +403,7 @@ mod test { assert_matches!( circ.append(q_discard(), [q1]), Err(BuildError::CircuitError(CircuitBuildError::MismatchedLinearInputs { op, index })) - if *op == q_discard().into() && index == [q1], + if op == q_discard().into() && index == [q1], ); let outs = circ.finish(); diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index d1131116f3..2a0fdf9315 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -8,9 +8,14 @@ use std::marker::PhantomData; use crate::hugr::internal::HugrMutInternals; use crate::hugr::{HugrView, ValidationError}; -use crate::ops::{self, DataflowParent, FuncDefn, Input, OpParent, Output}; +use crate::ops::{self, OpParent}; +use crate::ops::{DataflowParent, Input, Output}; +use crate::{Direction, IncomingPort, OutgoingPort, Wire}; + use crate::types::{PolyFuncType, Signature, Type}; -use crate::{Direction, Hugr, IncomingPort, Node, OutgoingPort, Visibility, Wire, hugr::HugrMut}; + +use crate::Node; +use crate::{Hugr, hugr::HugrMut}; /// Builder for a [`ops::DFG`] node. #[derive(Debug, Clone, PartialEq)] @@ -147,65 +152,22 @@ impl DFGWrapper { pub type FunctionBuilder = DFGWrapper>>; impl FunctionBuilder { - /// Initialize a builder for a [`FuncDefn`](ops::FuncDefn)-rooted HUGR; - /// the function will be private. (See also [Self::new_vis].) - /// + /// Initialize a builder for a `FuncDefn` rooted HUGR /// # Errors /// /// Error in adding DFG child nodes. pub fn new( name: impl Into, signature: impl Into, - ) -> Result { - Self::new_with_op(FuncDefn::new(name, signature)) - } - - /// Initialize a builder for a FuncDefn-rooted HUGR, with the specified - /// [Visibility]. - /// - /// # Errors - /// - /// Error in adding DFG child nodes. 
- pub fn new_vis( - name: impl Into, - signature: impl Into, - visibility: Visibility, - ) -> Result { - Self::new_with_op(FuncDefn::new_vis(name, signature, visibility)) - } - - fn new_with_op(op: FuncDefn) -> Result { - let body = op.signature().body().clone(); - - let base = Hugr::new_with_entrypoint(op).expect("FuncDefn entrypoint should be valid"); - let root = base.entrypoint(); - - let db = DFGBuilder::create_with_io(base, root, body)?; - Ok(Self::from_dfg_builder(db)) - } -} - -impl + AsRef> FunctionBuilder { - /// Initialize a new function definition on the root module of an existing HUGR. - /// - /// The HUGR's entrypoint will **not** be modified. - /// - /// # Errors - /// - /// Error in adding DFG child nodes. - pub fn with_hugr( - mut hugr: B, - name: impl Into, - signature: impl Into, ) -> Result { let signature: PolyFuncType = signature.into(); let body = signature.body().clone(); let op = ops::FuncDefn::new(name, signature); - let module = hugr.as_ref().module_root(); - let func = hugr.as_mut().add_node_with_parent(module, op); + let base = Hugr::new_with_entrypoint(op).expect("FuncDefn entrypoint should be valid"); + let root = base.entrypoint(); - let db = DFGBuilder::create_with_io(hugr, func, body)?; + let db = DFGBuilder::create_with_io(base, root, body)?; Ok(Self::from_dfg_builder(db)) } @@ -297,6 +259,31 @@ impl + AsRef> FunctionBuilder { } } +impl + AsRef> FunctionBuilder { + /// Initialize a new function definition on the root module of an existing HUGR. + /// + /// The HUGR's entrypoint will **not** be modified. + /// + /// # Errors + /// + /// Error in adding DFG child nodes. + pub fn with_hugr( + mut hugr: B, + name: impl Into, + signature: impl Into, + ) -> Result { + let signature: PolyFuncType = signature.into(); + let body = signature.body().clone(); + let op = ops::FuncDefn::new(name, signature); + + let module = hugr.as_ref().module_root(); + let func = hugr.as_mut().add_node_with_parent(module, op); + + let db = DFGBuilder::create_with_io(hugr, func, body)?; + Ok(Self::from_dfg_builder(db)) + } +} + impl + AsRef, T> Container for DFGWrapper { #[inline] fn container_node(&self) -> Node { @@ -450,7 +437,7 @@ pub(crate) mod test { error: BuilderWiringError::NoCopyLinear { typ, .. }, .. }) - if *typ == qb_t() + if typ == qb_t() ); } @@ -665,7 +652,7 @@ pub(crate) mod test { FunctionBuilder::new( "bad_eval", PolyFuncType::new( - [TypeParam::new_list_type(TypeBound::Copyable)], + [TypeParam::new_list(TypeBound::Copyable)], Signature::new( Type::new_function(FuncValueType::new(usize_t(), tv.clone())), vec![], diff --git a/hugr-core/src/builder/module.rs b/hugr-core/src/builder/module.rs index 5499abae92..543b9f2c1e 100644 --- a/hugr-core/src/builder/module.rs +++ b/hugr-core/src/builder/module.rs @@ -4,24 +4,25 @@ use super::{ dataflow::{DFGBuilder, FunctionBuilder}, }; +use crate::hugr::ValidationError; use crate::hugr::internal::HugrMutInternals; use crate::hugr::views::HugrView; use crate::ops; -use crate::ops::handle::{AliasID, FuncID, NodeHandle}; use crate::types::{PolyFuncType, Type, TypeBound}; -use crate::{Hugr, Node, Visibility}; -use crate::{hugr::ValidationError, ops::FuncDefn}; +use crate::ops::handle::{AliasID, FuncID, NodeHandle}; + +use crate::{Hugr, Node}; use smol_str::SmolStr; /// Builder for a HUGR module. 
-#[derive(Debug, Default, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub struct ModuleBuilder(pub(super) T); impl + AsRef> Container for ModuleBuilder { #[inline] fn container_node(&self) -> Node { - self.0.as_ref().module_root() + self.0.as_ref().entrypoint() } #[inline] @@ -38,7 +39,13 @@ impl ModuleBuilder { /// Begin building a new module. #[must_use] pub fn new() -> Self { - Self::default() + Self(Default::default()) + } +} + +impl Default for ModuleBuilder { + fn default() -> Self { + Self::new() } } @@ -68,61 +75,25 @@ impl + AsRef> ModuleBuilder { f_id: &FuncID, ) -> Result, BuildError> { let f_node = f_id.node(); - let opty = self.hugr_mut().optype_mut(f_node); - let ops::OpType::FuncDecl(decl) = opty else { - return Err(BuildError::UnexpectedType { - node: f_node, - op_desc: "crate::ops::OpType::FuncDecl", - }); - }; - - let body = decl.signature().body().clone(); - *opty = ops::FuncDefn::new_vis( - decl.func_name(), - decl.signature().clone(), - decl.visibility().clone(), - ) - .into(); - - let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; - Ok(FunctionBuilder::from_dfg_builder(db)) - } - - /// Add a [`ops::FuncDefn`] node of the specified visibility. - /// Returns a builder to define the function body graph. - /// - /// # Errors - /// - /// This function will return an error if there is an error in adding the - /// [`ops::FuncDefn`] node. - pub fn define_function_vis( - &mut self, - name: impl Into, - signature: impl Into, - visibility: Visibility, - ) -> Result, BuildError> { - self.define_function_op(FuncDefn::new_vis(name, signature, visibility)) - } - - fn define_function_op( - &mut self, - op: FuncDefn, - ) -> Result, BuildError> { - let body = op.signature().body().clone(); - let f_node = self.add_child_node(op); - - // Add the extensions used by the function types. - self.use_extensions( - body.used_extensions().unwrap_or_else(|e| { - panic!("Build-time signatures should have valid extensions. {e}") - }), - ); + let decl = + self.hugr() + .get_optype(f_node) + .as_func_decl() + .ok_or(BuildError::UnexpectedType { + node: f_node, + op_desc: "crate::ops::OpType::FuncDecl", + })?; + let name = decl.func_name().clone(); + let sig = decl.signature().clone(); + let body = sig.body().clone(); + self.hugr_mut() + .replace_op(f_node, ops::FuncDefn::new(name, sig)); let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; Ok(FunctionBuilder::from_dfg_builder(db)) } - /// Declare a [Visibility::Public] function with `signature` and return a handle to the declaration. + /// Declare a function with `signature` and return a handle to the declaration. /// /// # Errors /// @@ -132,26 +103,10 @@ impl + AsRef> ModuleBuilder { &mut self, name: impl Into, signature: PolyFuncType, - ) -> Result, BuildError> { - self.declare_vis(name, signature, Visibility::Public) - } - - /// Declare a function with the specified `signature` and [Visibility], - /// and return a handle to the declaration. - /// - /// # Errors - /// - /// This function will return an error if there is an error in adding the - /// [`crate::ops::OpType::FuncDecl`] node. - pub fn declare_vis( - &mut self, - name: impl Into, - signature: PolyFuncType, - visibility: Visibility, ) -> Result, BuildError> { let body = signature.body().clone(); // TODO add param names to metadata - let declare_n = self.add_child_node(ops::FuncDecl::new_vis(name, signature, visibility)); + let declare_n = self.add_child_node(ops::FuncDecl::new(name, signature)); // Add the extensions used by the function types. 
self.use_extensions( @@ -163,21 +118,6 @@ impl + AsRef> ModuleBuilder { Ok(declare_n.into()) } - /// Adds a [`ops::FuncDefn`] node and returns a builder to define the function - /// body graph. The function will be private. (See [Self::define_function_vis].) - /// - /// # Errors - /// - /// This function will return an error if there is an error in adding the - /// [`ops::FuncDefn`] node. - pub fn define_function( - &mut self, - name: impl Into, - signature: impl Into, - ) -> Result, BuildError> { - self.define_function_op(FuncDefn::new(name, signature)) - } - /// Add a [`crate::ops::OpType::AliasDefn`] node and return a handle to the Alias. /// /// # Errors @@ -259,7 +199,7 @@ mod test { let mut module_builder = ModuleBuilder::new(); let qubit_state_type = - module_builder.add_alias_declare("qubit_state", TypeBound::Linear)?; + module_builder.add_alias_declare("qubit_state", TypeBound::Any)?; let f_build = module_builder.define_function( "main", @@ -275,6 +215,31 @@ mod test { Ok(()) } + #[test] + fn local_def() -> Result<(), BuildError> { + let build_result = { + let mut module_builder = ModuleBuilder::new(); + + let mut f_build = module_builder.define_function( + "main", + Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]), + )?; + let local_build = f_build.define_function( + "local", + Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]), + )?; + let [wire] = local_build.input_wires_arr(); + let f_id = local_build.finish_with_outputs([wire, wire])?; + + let call = f_build.call(f_id.handle(), &[], f_build.input_wires())?; + + f_build.finish_with_outputs(call.outputs())?; + module_builder.finish_hugr() + }; + assert_matches!(build_result, Ok(_)); + Ok(()) + } + #[test] fn builder_from_existing() -> Result<(), BuildError> { let hugr = Hugr::new(); diff --git a/hugr-core/src/core.rs b/hugr-core/src/core.rs index 4578aa5357..06366822ae 100644 --- a/hugr-core/src/core.rs +++ b/hugr-core/src/core.rs @@ -7,7 +7,7 @@ pub use itertools::Either; use derive_more::From; use itertools::Either::{Left, Right}; -use crate::{HugrView, hugr::HugrError}; +use crate::hugr::HugrError; /// A handle to a node in the HUGR. #[derive( @@ -34,7 +34,7 @@ pub struct Node { )] #[serde(transparent)] pub struct Port { - offset: portgraph::PortOffset, + offset: portgraph::PortOffset, } /// A trait for getting the undirected index of a port. @@ -139,7 +139,7 @@ impl Port { /// Returns the port as a portgraph `PortOffset`. #[inline] - pub(crate) fn pg_offset(self) -> portgraph::PortOffset { + pub(crate) fn pg_offset(self) -> portgraph::PortOffset { self.offset } } @@ -219,55 +219,17 @@ impl Wire { Self(node, port.into()) } - /// Create a new wire from a node and a port that is connected to the wire. - /// - /// If `port` is an incoming port, the wire is traversed to find the unique - /// outgoing port that is connected to the wire. Otherwise, this is - /// equivalent to constructing a wire using [`Wire::new`]. - /// - /// ## Panics - /// - /// This will panic if the wire is not connected to a unique outgoing port. - #[inline] - pub fn from_connected_port( - node: N, - port: impl Into, - hugr: &impl HugrView, - ) -> Self { - let (node, outgoing) = match port.into().as_directed() { - Either::Left(incoming) => hugr - .single_linked_output(node, incoming) - .expect("invalid dfg port"), - Either::Right(outgoing) => (node, outgoing), - }; - Self::new(node, outgoing) - } - - /// The node of the unique outgoing port that the wire is connected to. + /// The node that this wire is connected to. 
#[inline] pub fn node(&self) -> N { self.0 } - /// The unique outgoing port that the wire is connected to. + /// The output port that this wire is connected to. #[inline] pub fn source(&self) -> OutgoingPort { self.1 } - - /// Get all ports connected to the wire. - /// - /// Return a chained iterator of the unique outgoing port, followed by all - /// incoming ports connected to the wire. - pub fn all_connected_ports<'h, H: HugrView>( - &self, - hugr: &'h H, - ) -> impl Iterator + use<'h, N, H> { - let node = self.node(); - let out_port = self.source(); - - std::iter::once((node, out_port.into())).chain(hugr.linked_ports(node, out_port)) - } } impl std::fmt::Display for Wire { @@ -276,46 +238,6 @@ impl std::fmt::Display for Wire { } } -/// Marks [FuncDefn](crate::ops::FuncDefn)s and [FuncDecl](crate::ops::FuncDecl)s as -/// to whether they should be considered for linking. -#[derive( - Clone, - Debug, - derive_more::Display, - PartialEq, - Eq, - PartialOrd, - Ord, - serde::Serialize, - serde::Deserialize, -)] -#[cfg_attr(test, derive(proptest_derive::Arbitrary))] -#[non_exhaustive] -pub enum Visibility { - /// Function is visible or exported - Public, - /// Function is hidden, for use within the hugr only - Private, -} - -impl From for Visibility { - fn from(value: hugr_model::v0::Visibility) -> Self { - match value { - hugr_model::v0::Visibility::Private => Self::Private, - hugr_model::v0::Visibility::Public => Self::Public, - } - } -} - -impl From for hugr_model::v0::Visibility { - fn from(value: Visibility) -> Self { - match value { - Visibility::Public => hugr_model::v0::Visibility::Public, - Visibility::Private => hugr_model::v0::Visibility::Private, - } - } -} - /// Enum for uniquely identifying the origin of linear wires in a circuit-like /// dataflow region. /// diff --git a/hugr-core/src/envelope.rs b/hugr-core/src/envelope.rs index d07f4b3e41..0223267b85 100644 --- a/hugr-core/src/envelope.rs +++ b/hugr-core/src/envelope.rs @@ -73,11 +73,11 @@ pub const USED_EXTENSIONS_KEY: &str = "core.used_extensions"; /// If multiple modules have different generators, a comma-separated list is returned in /// module order. /// If no generator is found, `None` is returned. -pub fn get_generator(modules: &[H]) -> Option { +fn get_generator(modules: &[H]) -> Option { let generators: Vec = modules .iter() .filter_map(|hugr| hugr.get_metadata(hugr.module_root(), GENERATOR_KEY)) - .map(format_generator) + .map(|v| v.to_string()) .collect(); if generators.is_empty() { return None; @@ -86,31 +86,6 @@ pub fn get_generator(modules: &[H]) -> Option { Some(generators.join(", ")) } -/// Format a generator value from the metadata. 
-pub fn format_generator(json_val: &serde_json::Value) -> String { - match json_val { - serde_json::Value::String(s) => s.clone(), - serde_json::Value::Object(obj) => { - if let (Some(name), version) = ( - obj.get("name").and_then(|v| v.as_str()), - obj.get("version").and_then(|v| v.as_str()), - ) { - if let Some(version) = version { - // Expected format: {"name": "generator", "version": "1.0.0"} - format!("{name}-v{version}") - } else { - name.to_string() - } - } else { - // just print the whole object as a string - json_val.to_string() - } - } - // Raw JSON string fallback - _ => json_val.to_string(), - } -} - fn gen_str(generator: &Option) -> String { match generator { Some(g) => format!("\ngenerated by {g}"), @@ -122,7 +97,7 @@ fn gen_str(generator: &Option) -> String { #[derive(Error, Debug)] #[error("{inner}{}", gen_str(&self.generator))] pub struct WithGenerator { - inner: Box, + inner: E, /// The name of the generator that produced the envelope, if any. generator: Option, } @@ -130,7 +105,7 @@ pub struct WithGenerator { impl WithGenerator { fn new(err: E, modules: &[impl HugrView]) -> Self { Self { - inner: Box::new(err), + inner: err, generator: get_generator(modules), } } @@ -204,15 +179,16 @@ pub(crate) fn write_envelope_impl<'h>( } /// Error type for envelope operations. -#[derive(Debug, Error)] +#[derive(derive_more::Display, derive_more::Error, Debug, derive_more::From)] #[non_exhaustive] pub enum EnvelopeError { /// Bad magic number. - #[error( + #[display( "Bad magic number. expected 0x{:X} found 0x{:X}", u64::from_be_bytes(*expected), u64::from_be_bytes(*found) )] + #[from(ignore)] MagicNumber { /// The expected magic number. /// @@ -222,18 +198,20 @@ pub enum EnvelopeError { found: [u8; 8], }, /// The specified payload format is invalid. - #[error("Format descriptor {descriptor} is invalid.")] + #[display("Format descriptor {descriptor} is invalid.")] + #[from(ignore)] InvalidFormatDescriptor { /// The unsupported format. descriptor: usize, }, /// The specified payload format is not supported. - #[error("Payload format {format} is not supported.{}", + #[display("Payload format {format} is not supported.{}", match feature { Some(f) => format!(" This requires the '{f}' feature for `hugr`."), None => String::new() }, )] + #[from(ignore)] FormatUnsupported { /// The unsupported format. format: EnvelopeFormat, @@ -243,97 +221,68 @@ pub enum EnvelopeError { /// Not all envelope formats can be represented as ASCII. /// /// This error is used when trying to store the envelope into a string. - #[error("Envelope format {format} cannot be represented as ASCII.")] + #[display("Envelope format {format} cannot be represented as ASCII.")] + #[from(ignore)] NonASCIIFormat { /// The unsupported format. format: EnvelopeFormat, }, /// Envelope encoding required zstd compression, but the feature is not enabled. - #[error("Zstd compression is not supported. This requires the 'zstd' feature for `hugr`.")] + #[display("Zstd compression is not supported. This requires the 'zstd' feature for `hugr`.")] + #[from(ignore)] ZstdUnsupported, /// Expected the envelope to contain a single HUGR. - #[error("Expected an envelope containing a single hugr, but it contained {}.", if *count == 0 { + #[display("Expected an envelope containing a single hugr, but it contained {}.", if *count == 0 { "none".to_string() } else { count.to_string() })] + #[from(ignore)] ExpectedSingleHugr { /// The number of HUGRs in the package. count: usize, }, /// JSON serialization error. 
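// Illustrative sketch (not part of the patch): the generator formatting that
// the hunk above folds back into `get_generator`. The `core.generator`
// metadata may be a plain string or an object such as
// {"name": "some-tool", "version": "1.0.0"}; the latter renders as
// "some-tool-v1.0.0". Function and key names follow the removed code; the
// example values are made up.
use serde_json::Value;

fn format_generator(json_val: &Value) -> String {
    match json_val {
        Value::String(s) => s.clone(),
        Value::Object(obj) => match (
            obj.get("name").and_then(Value::as_str),
            obj.get("version").and_then(Value::as_str),
        ) {
            (Some(name), Some(version)) => format!("{name}-v{version}"),
            (Some(name), None) => name.to_string(),
            // No recognisable fields: print the whole object.
            _ => json_val.to_string(),
        },
        // Any other JSON value is printed verbatim.
        _ => json_val.to_string(),
    }
}

#[test]
fn formats_name_and_version() {
    let v = serde_json::json!({"name": "some-tool", "version": "1.0.0"});
    assert_eq!(format_generator(&v), "some-tool-v1.0.0");
}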
- #[error(transparent)] SerdeError { /// The source error. - #[from] source: serde_json::Error, }, /// IO read/write error. - #[error(transparent)] IO { /// The source error. - #[from] source: std::io::Error, }, /// Error writing a json package to the payload. - #[error(transparent)] PackageEncoding { /// The source error. - #[from] source: PackageEncodingError, }, /// Error importing a HUGR from a hugr-model payload. - #[error(transparent)] ModelImport { /// The source error. - #[from] source: ImportError, // TODO add generator to model import errors }, /// Error reading a HUGR model payload. - #[error(transparent)] ModelRead { /// The source error. - #[from] source: hugr_model::v0::binary::ReadError, }, /// Error writing a HUGR model payload. - #[error(transparent)] ModelWrite { /// The source error. - #[from] source: hugr_model::v0::binary::WriteError, }, /// Error reading a HUGR model payload. - #[error("Model text parsing error")] ModelTextRead { /// The source error. - #[from] source: hugr_model::v0::ast::ParseError, }, /// Error reading a HUGR model payload. - #[error(transparent)] ModelTextResolve { /// The source error. - #[from] source: hugr_model::v0::ast::ResolveError, }, - /// Error reading a list of extensions from the envelope. - #[error(transparent)] - ExtensionLoad { - /// The source error. - #[from] - source: crate::extension::ExtensionRegistryLoadError, - }, - /// The specified payload format is not supported. - #[error( - "The envelope configuration has unknown {}. Please update your HUGR version.", - if flag_ids.len() == 1 {format!("flag #{}", flag_ids[0])} else {format!("flags {}", flag_ids.iter().join(", "))} - )] - FlagUnsupported { - /// The unrecognized flag bits. - flag_ids: Vec, - }, } /// Internal implementation of [`read_envelope`] to call with/without the zstd decompression wrapper. @@ -380,8 +329,11 @@ fn decode_model( let mut extension_registry = extension_registry.clone(); if format == EnvelopeFormat::ModelWithExtensions { - let extra_extensions = ExtensionRegistry::load_json(stream, &extension_registry)?; - extension_registry.extend(extra_extensions); + let extra_extensions: Vec = + serde_json::from_reader::<_, Vec>(stream)?; + for ext in extra_extensions { + extension_registry.register_updated(ext); + } } Ok(import_package(&model_package, &extension_registry)?) @@ -851,6 +803,6 @@ pub(crate) mod test { let err_msg = with_gen.to_string(); assert!(err_msg.contains("Extension 'test' version mismatch")); - assert!(err_msg.contains("TestGenerator-v1.2.3")); + assert!(err_msg.contains(generator_name.to_string().as_str())); } } diff --git a/hugr-core/src/envelope/header.rs b/hugr-core/src/envelope/header.rs index 54353e2f18..66af887454 100644 --- a/hugr-core/src/envelope/header.rs +++ b/hugr-core/src/envelope/header.rs @@ -3,8 +3,6 @@ use std::io::{Read, Write}; use std::num::NonZeroU8; -use itertools::Itertools; - use super::EnvelopeError; /// Magic number identifying the start of an envelope. @@ -13,12 +11,6 @@ use super::EnvelopeError; /// to avoid accidental collisions with other file formats. pub const MAGIC_NUMBERS: &[u8] = "HUGRiHJv".as_bytes(); -/// The all-unset header flags configuration. -/// Bit 7 is always set to ensure we have a printable ASCII character. -const DEFAULT_FLAGS: u8 = 0b0100_0000u8; -/// The ZSTD flag bit in the header's flags. -const ZSTD_FLAG: u8 = 0b0000_0001; - /// Header at the start of a binary envelope file. /// /// See the [`crate::envelope`] module documentation for the binary format. 
@@ -232,10 +224,8 @@ impl EnvelopeHeader { let format_bytes = [self.format as u8]; writer.write_all(&format_bytes)?; // Next is the flags byte. - let mut flags = DEFAULT_FLAGS; - if self.zstd { - flags |= ZSTD_FLAG; - } + let mut flags = 0b01000000u8; + flags |= u8::from(self.zstd); writer.write_all(&[flags])?; Ok(()) @@ -269,16 +259,7 @@ impl EnvelopeHeader { // Next is the flags byte. let mut flags_bytes = [0; 1]; reader.read_exact(&mut flags_bytes)?; - let flags: u8 = flags_bytes[0]; - - let zstd = flags & ZSTD_FLAG != 0; - - // Check if there's any unrecognized flags. - let other_flags = (flags ^ DEFAULT_FLAGS) & !ZSTD_FLAG; - if other_flags != 0 { - let flag_ids = (0..8).filter(|i| other_flags & (1 << i) != 0).collect_vec(); - return Err(EnvelopeError::FlagUnsupported { flag_ids }); - } + let zstd = flags_bytes[0] & 0x1 != 0; Ok(Self { format, zstd }) } @@ -287,7 +268,6 @@ impl EnvelopeHeader { #[cfg(test)] mod tests { use super::*; - use cool_asserts::assert_matches; use rstest::rstest; #[rstest] @@ -316,35 +296,4 @@ mod tests { let read_header = EnvelopeHeader::read(&mut buffer.as_slice()).unwrap(); assert_eq!(header, read_header); } - - #[rstest] - fn header_errors() { - let header = EnvelopeHeader { - format: EnvelopeFormat::Model, - zstd: false, - }; - let mut buffer = Vec::new(); - header.write(&mut buffer).unwrap(); - - assert_eq!(buffer.len(), 10); - let flags = buffer[9]; - assert_eq!(flags, DEFAULT_FLAGS); - - // Invalid magic - let mut invalid_magic = buffer.clone(); - invalid_magic[7] = 0xFF; - assert_matches!( - EnvelopeHeader::read(&mut invalid_magic.as_slice()), - Err(EnvelopeError::MagicNumber { .. }) - ); - - // Unrecognised flags - let mut unrecognised_flags = buffer.clone(); - unrecognised_flags[9] |= 0b0001_0010; - assert_matches!( - EnvelopeHeader::read(&mut unrecognised_flags.as_slice()), - Err(EnvelopeError::FlagUnsupported { flag_ids }) - => assert_eq!(flag_ids, vec![1, 4]) - ); - } } diff --git a/hugr-core/src/envelope/package_json.rs b/hugr-core/src/envelope/package_json.rs index 2aa6c982f7..bbdf19d26e 100644 --- a/hugr-core/src/envelope/package_json.rs +++ b/hugr-core/src/envelope/package_json.rs @@ -6,6 +6,7 @@ use std::io; use super::{ExtensionBreakingError, WithGenerator, check_breaking_extensions}; use crate::extension::ExtensionRegistry; use crate::extension::resolution::ExtensionResolutionError; +use crate::hugr::ExtensionError; use crate::package::Package; use crate::{Extension, Hugr}; @@ -56,20 +57,6 @@ pub(super) fn to_json_writer<'h>( modules: hugrs.into_iter().map(HugrSer).collect(), extensions: extensions.iter().map(std::convert::AsRef::as_ref).collect(), }; - - // Validate the hugr serializations against the schema. - // - // NOTE: The schema definition is currently broken, so this check always succeeds. - // See - #[cfg(all(test, not(miri)))] - if std::env::var("HUGR_TEST_SCHEMA").is_ok_and(|x| !x.is_empty()) { - use crate::hugr::serialize::test::check_hugr_serialization_schema; - - for hugr in &pkg_ser.modules { - check_hugr_serialization_schema(hugr.0); - } - } - serde_json::to_writer(writer, &pkg_ser)?; Ok(()) } @@ -77,16 +64,17 @@ pub(super) fn to_json_writer<'h>( /// Error raised while loading a package. #[derive(Debug, Display, Error, From)] #[non_exhaustive] -#[display("Error reading or writing a package in JSON format.")] pub enum PackageEncodingError { /// Error raised while parsing the package json. - JsonEncoding(#[from] serde_json::Error), + JsonEncoding(serde_json::Error), /// Error raised while reading from a file. 
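// Illustrative sketch (not part of the patch): the flags-byte layout used by
// the header hunk above. Bit 6 is always set so the byte stays printable
// ASCII, bit 0 marks zstd compression, and any other set bit is reported as
// an unrecognised flag. The constants mirror the removed code; the error
// type here is a stand-in.
const DEFAULT_FLAGS: u8 = 0b0100_0000;
const ZSTD_FLAG: u8 = 0b0000_0001;

fn encode_flags(zstd: bool) -> u8 {
    DEFAULT_FLAGS | if zstd { ZSTD_FLAG } else { 0 }
}

fn decode_flags(flags: u8) -> Result<bool, Vec<usize>> {
    // Bits that are neither the fixed marker bit nor the zstd bit.
    let unknown = (flags ^ DEFAULT_FLAGS) & !ZSTD_FLAG;
    if unknown != 0 {
        return Err((0..8).filter(|&i| unknown & (1u8 << i) != 0).collect());
    }
    Ok(flags & ZSTD_FLAG != 0)
}

#[test]
fn zstd_bit_roundtrips() {
    assert_eq!(decode_flags(encode_flags(true)), Ok(true));
    assert_eq!(decode_flags(0b0101_0010), Err(vec![1, 4]));
}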
- IOError(#[from] io::Error), + IOError(io::Error), /// Could not resolve the extension needed to encode the hugr. - ExtensionResolution(#[from] WithGenerator), + ExtensionResolution(WithGenerator), /// Error raised while checking for breaking extension version mismatch. - ExtensionVersion(#[from] WithGenerator), + ExtensionVersion(WithGenerator), + /// Could not resolve the runtime extensions for the hugr. + RuntimeExtensionResolution(ExtensionError), } /// A private package structure implementing the serde traits. diff --git a/hugr-core/src/envelope/serde_with.rs b/hugr-core/src/envelope/serde_with.rs index 28d3cd3189..7b9517d3e0 100644 --- a/hugr-core/src/envelope/serde_with.rs +++ b/hugr-core/src/envelope/serde_with.rs @@ -15,9 +15,6 @@ use crate::std_extensions::STD_REG; /// De/Serialize a package or hugr by encoding it into a textual Envelope and /// storing it as a string. /// -/// This is similar to [`AsBinaryEnvelope`], but uses a textual envelope instead -/// of a binary one. -/// /// Note that only PRELUDE extensions are used to decode the package's content. /// When serializing a HUGR, any additional extensions required to load it are /// embedded in the envelope. Packages should manually add any required @@ -48,53 +45,9 @@ use crate::std_extensions::STD_REG; /// When reading an encoded HUGR, the `AsStringEnvelope` deserializer will first /// try to decode the value as an string-encoded envelope. If that fails, it /// will fallback to decoding the legacy HUGR serde definition. This temporary -/// compatibility is required to support `hugr <= 0.19` and will be removed in -/// a future version. +/// compatibility layer is meant to be removed in 0.21.0. pub struct AsStringEnvelope; -/// De/Serialize a package or hugr by encoding it into a binary envelope and -/// storing it as a base64-encoded string. -/// -/// This is similar to [`AsStringEnvelope`], but uses a binary envelope instead -/// of a string. -/// When deserializing, if the string starts with the envelope magic 'HUGRiHJv' -/// it will be loaded as a string envelope without base64 decoding. -/// -/// Note that only PRELUDE extensions are used to decode the package's content. -/// When serializing a HUGR, any additional extensions required to load it are -/// embedded in the envelope. Packages should manually add any required -/// extensions before serializing. -/// -/// # Examples -/// -/// ```rust -/// # use serde::{Deserialize, Serialize}; -/// # use serde_json::json; -/// # use serde_with::{serde_as}; -/// # use hugr_core::Hugr; -/// # use hugr_core::package::Package; -/// # use hugr_core::envelope::serde_with::AsBinaryEnvelope; -/// # -/// #[serde_as] -/// #[derive(Deserialize, Serialize)] -/// struct A { -/// #[serde_as(as = "AsBinaryEnvelope")] -/// package: Package, -/// #[serde_as(as = "Vec")] -/// hugrs: Vec, -/// } -/// ``` -/// -/// # Backwards compatibility -/// -/// When reading an encoded HUGR, the `AsBinaryEnvelope` deserializer will first -/// try to decode the value as an binary-encoded envelope. If that fails, it -/// will fallback to decoding a string envelope instead, and then finally to -/// decoding the legacy HUGR serde definition. This temporary compatibility -/// layer is required to support `hugr <= 0.19` and will be removed in a future -/// version. -pub struct AsBinaryEnvelope; - /// Implements [`serde_with::DeserializeAs`] and [`serde_with::SerializeAs`] for /// the helper to deserialize `Hugr` and `Package` types, using the given /// extension registry. @@ -258,337 +211,3 @@ macro_rules! 
impl_serde_as_string_envelope { pub use impl_serde_as_string_envelope; impl_serde_as_string_envelope!(AsStringEnvelope, &STD_REG); - -/// Implements [`serde_with::DeserializeAs`] and [`serde_with::SerializeAs`] for -/// the helper to deserialize `Hugr` and `Package` types, using the given -/// extension registry. -/// -/// This macro is used to implement the default [`AsBinaryEnvelope`] wrapper. -/// -/// # Parameters -/// -/// - `$adaptor`: The name of the adaptor type to implement. -/// - `$extension_reg`: A reference to the extension registry to use for deserialization. -/// -/// # Examples -/// -/// ```rust -/// # use serde::{Deserialize, Serialize}; -/// # use serde_json::json; -/// # use serde_with::{serde_as}; -/// # use hugr_core::Hugr; -/// # use hugr_core::package::Package; -/// # use hugr_core::envelope::serde_with::AsBinaryEnvelope; -/// # use hugr_core::envelope::serde_with::impl_serde_as_binary_envelope; -/// # use hugr_core::extension::ExtensionRegistry; -/// # -/// struct CustomAsEnvelope; -/// -/// impl_serde_as_binary_envelope!(CustomAsEnvelope, &hugr_core::extension::EMPTY_REG); -/// -/// #[serde_as] -/// #[derive(Deserialize, Serialize)] -/// struct A { -/// #[serde_as(as = "CustomAsEnvelope")] -/// package: Package, -/// } -/// ``` -/// -#[macro_export] -macro_rules! impl_serde_as_binary_envelope { - ($adaptor:ident, $extension_reg:expr) => { - impl<'de> serde_with::DeserializeAs<'de, $crate::package::Package> for $adaptor { - fn deserialize_as(deserializer: D) -> Result<$crate::package::Package, D::Error> - where - D: serde::Deserializer<'de>, - { - struct Helper; - impl serde::de::Visitor<'_> for Helper { - type Value = $crate::package::Package; - - fn expecting( - &self, - formatter: &mut std::fmt::Formatter<'_>, - ) -> std::fmt::Result { - formatter.write_str("a base64-encoded envelope") - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - use $crate::envelope::serde_with::base64::{DecoderReader, STANDARD}; - - let extensions: &$crate::extension::ExtensionRegistry = $extension_reg; - - if value - .as_bytes() - .starts_with($crate::envelope::MAGIC_NUMBERS) - { - // If the string starts with the envelope magic 'HUGRiHJv', - // skip the base64 decoding. 
- let reader = std::io::Cursor::new(value.as_bytes()); - $crate::package::Package::load(reader, Some(extensions)) - .map_err(|e| serde::de::Error::custom(format!("{e:?}"))) - } else { - let reader = DecoderReader::new(value.as_bytes(), &STANDARD); - let buf_reader = std::io::BufReader::new(reader); - $crate::package::Package::load(buf_reader, Some(extensions)) - .map_err(|e| serde::de::Error::custom(format!("{e:?}"))) - } - } - } - - deserializer.deserialize_str(Helper) - } - } - - impl<'de> serde_with::DeserializeAs<'de, $crate::Hugr> for $adaptor { - fn deserialize_as(deserializer: D) -> Result<$crate::Hugr, D::Error> - where - D: serde::Deserializer<'de>, - { - struct Helper; - impl<'vis> serde::de::Visitor<'vis> for Helper { - type Value = $crate::Hugr; - - fn expecting( - &self, - formatter: &mut std::fmt::Formatter<'_>, - ) -> std::fmt::Result { - formatter.write_str("a base64-encoded envelope") - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - use $crate::envelope::serde_with::base64::{DecoderReader, STANDARD}; - - let extensions: &$crate::extension::ExtensionRegistry = $extension_reg; - - if value - .as_bytes() - .starts_with($crate::envelope::MAGIC_NUMBERS) - { - // If the string starts with the envelope magic 'HUGRiHJv', - // skip the base64 decoding. - let reader = std::io::Cursor::new(value.as_bytes()); - $crate::Hugr::load(reader, Some(extensions)) - .map_err(|e| serde::de::Error::custom(format!("{e:?}"))) - } else { - let reader = DecoderReader::new(value.as_bytes(), &STANDARD); - let buf_reader = std::io::BufReader::new(reader); - $crate::Hugr::load(buf_reader, Some(extensions)) - .map_err(|e| serde::de::Error::custom(format!("{e:?}"))) - } - } - - fn visit_map(self, map: A) -> Result - where - A: serde::de::MapAccess<'vis>, - { - // Backwards compatibility: If the encoded value is not a - // string, we may have a legacy HUGR serde structure instead. In that - // case, we can add an envelope header and try again. - // - // TODO: Remove this fallback in a breaking change - let deserializer = serde::de::value::MapAccessDeserializer::new(map); - #[allow(deprecated)] - let mut hugr = - $crate::hugr::serialize::serde_deserialize_hugr(deserializer) - .map_err(serde::de::Error::custom)?; - - let extensions: &$crate::extension::ExtensionRegistry = $extension_reg; - hugr.resolve_extension_defs(extensions) - .map_err(serde::de::Error::custom)?; - Ok(hugr) - } - } - - // TODO: Go back to `deserialize_str` once the fallback is removed. - deserializer.deserialize_any(Helper) - } - } - - impl serde_with::SerializeAs<$crate::package::Package> for $adaptor { - fn serialize_as( - source: &$crate::package::Package, - serializer: S, - ) -> Result - where - S: serde::Serializer, - { - use $crate::envelope::serde_with::base64::{EncoderStringWriter, STANDARD}; - - let mut writer = EncoderStringWriter::new(&STANDARD); - source - .store(&mut writer, $crate::envelope::EnvelopeConfig::binary()) - .map_err(serde::ser::Error::custom)?; - let str = writer.into_inner(); - serializer.collect_str(&str) - } - } - - impl serde_with::SerializeAs<$crate::Hugr> for $adaptor { - fn serialize_as(source: &$crate::Hugr, serializer: S) -> Result - where - S: serde::Serializer, - { - // Include any additional extension required to load the HUGR in the envelope. 
- let extensions: &$crate::extension::ExtensionRegistry = $extension_reg; - let mut extra_extensions = $crate::extension::ExtensionRegistry::default(); - for ext in $crate::hugr::views::HugrView::extensions(source).iter() { - if !extensions.contains(ext.name()) { - extra_extensions.register_updated(ext.clone()); - } - } - use $crate::envelope::serde_with::base64::{EncoderStringWriter, STANDARD}; - - let mut writer = EncoderStringWriter::new(&STANDARD); - source - .store_with_exts( - &mut writer, - $crate::envelope::EnvelopeConfig::binary(), - &extra_extensions, - ) - .map_err(serde::ser::Error::custom)?; - let str = writer.into_inner(); - serializer.collect_str(&str) - } - } - }; -} -pub use impl_serde_as_binary_envelope; - -impl_serde_as_binary_envelope!(AsBinaryEnvelope, &STD_REG); - -// Hidden re-export required to expand the binary envelope macros on external -// crates. -#[doc(hidden)] -pub mod base64 { - pub use base64::Engine; - pub use base64::engine::general_purpose::STANDARD; - pub use base64::read::DecoderReader; - pub use base64::write::EncoderStringWriter; -} - -#[cfg(test)] -mod test { - use rstest::rstest; - use serde::{Deserialize, Serialize}; - use serde_with::serde_as; - - use crate::Hugr; - use crate::package::Package; - - use super::*; - - #[serde_as] - #[derive(Deserialize, Serialize)] - struct TextPkg { - #[serde_as(as = "AsStringEnvelope")] - data: Package, - } - - #[serde_as] - #[derive(Default, Deserialize, Serialize)] - struct TextHugr { - #[serde_as(as = "AsStringEnvelope")] - data: Hugr, - } - - #[serde_as] - #[derive(Deserialize, Serialize)] - struct BinaryPkg { - #[serde_as(as = "AsBinaryEnvelope")] - data: Package, - } - - #[serde_as] - #[derive(Default, Deserialize, Serialize)] - struct BinaryHugr { - #[serde_as(as = "AsBinaryEnvelope")] - data: Hugr, - } - - #[derive(Default, Deserialize, Serialize)] - struct LegacyHugr { - #[serde(deserialize_with = "Hugr::serde_deserialize")] - #[serde(serialize_with = "Hugr::serde_serialize")] - data: Hugr, - } - - impl Default for TextPkg { - fn default() -> Self { - // Default package with a single hugr (so it can be decoded as a hugr too). - Self { - data: Package::from_hugr(Hugr::default()), - } - } - } - - impl Default for BinaryPkg { - fn default() -> Self { - // Default package with a single hugr (so it can be decoded as a hugr too). 
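// Illustrative sketch (not part of the patch): how the binary-envelope serde
// adaptor above distinguishes a raw textual envelope from a base64-encoded
// binary one. Strings beginning with the "HUGRiHJv" magic are used as-is;
// anything else is assumed to be base64 and decoded first. Uses the `base64`
// crate that the module re-exports; the helper name is made up.
use base64::Engine as _;
use base64::engine::general_purpose::STANDARD;

const MAGIC_NUMBERS: &[u8] = b"HUGRiHJv";

fn envelope_bytes(value: &str) -> Result<Vec<u8>, base64::DecodeError> {
    if value.as_bytes().starts_with(MAGIC_NUMBERS) {
        // Already an envelope: no base64 layer to strip.
        Ok(value.as_bytes().to_vec())
    } else {
        // Otherwise treat it as a base64-encoded binary envelope.
        STANDARD.decode(value)
    }
}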
- Self { - data: Package::from_hugr(Hugr::default()), - } - } - } - - fn decode serde::Deserialize<'a>>(encoded: String) -> Result<(), serde_json::Error> { - let _: T = serde_json::de::from_str(&encoded)?; - Ok(()) - } - - #[rstest] - // Text formats are swappable - #[case::text_pkg_text_pkg(TextPkg::default(), decode::, false)] - #[case::text_pkg_text_hugr(TextPkg::default(), decode::, false)] - #[case::text_hugr_text_pkg(TextHugr::default(), decode::, false)] - #[case::text_hugr_text_hugr(TextHugr::default(), decode::, false)] - // Binary formats can read each other - #[case::bin_pkg_bin_pkg(BinaryPkg::default(), decode::, false)] - #[case::bin_pkg_bin_hugr(BinaryPkg::default(), decode::, false)] - #[case::bin_hugr_bin_pkg(BinaryHugr::default(), decode::, false)] - #[case::bin_hugr_bin_hugr(BinaryHugr::default(), decode::, false)] - // Binary formats can read text ones - #[case::text_pkg_bin_pkg(TextPkg::default(), decode::, false)] - #[case::text_pkg_bin_hugr(TextPkg::default(), decode::, false)] - #[case::text_hugr_bin_pkg(TextHugr::default(), decode::, false)] - #[case::text_hugr_bin_hugr(TextHugr::default(), decode::, false)] - // But text formats can't read binary - #[case::bin_pkg_text_pkg(BinaryPkg::default(), decode::, true)] - #[case::bin_pkg_text_hugr(BinaryPkg::default(), decode::, true)] - #[case::bin_hugr_text_pkg(BinaryHugr::default(), decode::, true)] - #[case::bin_hugr_text_hugr(BinaryHugr::default(), decode::, true)] - // We can read old hugrs into hugrs, but not packages - #[case::legacy_hugr_text_pkg(LegacyHugr::default(), decode::, true)] - #[case::legacy_hugr_text_hugr(LegacyHugr::default(), decode::, false)] - #[case::legacy_hugr_bin_pkg(LegacyHugr::default(), decode::, true)] - #[case::legacy_hugr_bin_hugr(LegacyHugr::default(), decode::, false)] - // Decoding any new format as legacy hugr always fails - #[case::text_pkg_legacy_hugr(TextPkg::default(), decode::, true)] - #[case::text_hugr_legacy_hugr(TextHugr::default(), decode::, true)] - #[case::bin_pkg_legacy_hugr(BinaryPkg::default(), decode::, true)] - #[case::bin_hugr_legacy_hugr(BinaryHugr::default(), decode::, true)] - #[cfg_attr(all(miri, feature = "zstd"), ignore)] // FFI calls (required to compress with zstd) are not supported in miri - fn check_format_compatibility( - #[case] encoder: impl serde::Serialize, - #[case] decoder: fn(String) -> Result<(), serde_json::Error>, - #[case] errors: bool, - ) { - let encoded = serde_json::to_string(&encoder).unwrap(); - let decoded = decoder(encoded); - match (errors, decoded) { - (false, Err(e)) => { - panic!("Decoding error: {e}"); - } - (true, Ok(_)) => { - panic!("Roundtrip should have failed"); - } - _ => {} - } - } -} diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 5eaead792a..dff471cc59 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -1,8 +1,6 @@ //! Exporting HUGR graphs to their `hugr-model` representation. 
-use crate::Visibility; use crate::extension::ExtensionRegistry; use crate::hugr::internal::HugrInternals; -use crate::types::type_param::Term; use crate::{ Direction, Hugr, HugrView, IncomingPort, Node, NodeIndex as _, Port, extension::{ExtensionId, OpDef, SignatureFunc}, @@ -16,19 +14,19 @@ use crate::{ }, types::{ CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, - TypeBase, TypeBound, TypeEnum, type_param::TermVar, type_row::TypeRowBase, + TypeArg, TypeBase, TypeBound, TypeEnum, TypeRow, + type_param::{TypeArgVariable, TypeParam}, + type_row::TypeRowBase, }, }; use fxhash::{FxBuildHasher, FxHashMap}; -use hugr_model::v0::bumpalo; use hugr_model::v0::{ self as model, bumpalo::{Bump, collections::String as BumpString, collections::Vec as BumpVec}, table, }; use petgraph::unionfind::UnionFind; -use smol_str::ToSmolStr; use std::fmt::Write; /// Exports a deconstructed `Package` to its representation in the model. @@ -97,8 +95,6 @@ struct Context<'a> { // that ensures that the `node_to_id` and `id_to_node` maps stay in sync. } -const NO_VIS: Option = None; - impl<'a> Context<'a> { pub fn new(hugr: &'a Hugr, bump: &'a Bump) -> Self { let mut module = table::Module::default(); @@ -235,6 +231,16 @@ impl<'a> Context<'a> { } } + /// Get the name of a function definition or declaration node. Returns `None` if not + /// one of those operations. + fn get_func_name(&self, func_node: Node) -> Option<&'a str> { + match self.hugr.get_optype(func_node) { + OpType::FuncDecl(func_decl) => Some(func_decl.func_name()), + OpType::FuncDefn(func_defn) => Some(func_defn.func_name()), + _ => None, + } + } + fn with_local_scope(&mut self, node: table::NodeId, f: impl FnOnce(&mut Self) -> T) -> T { let prev_local_scope = self.local_scope.replace(node); let prev_local_constraints = std::mem::take(&mut self.local_constraints); @@ -263,12 +269,8 @@ impl<'a> Context<'a> { // We record the name of the symbol defined by the node, if any. let symbol = match optype { - OpType::FuncDefn(_) | OpType::FuncDecl(_) => { - // Functions aren't exported using their core name but with a mangled - // name derived from their id. The function's core name will be recorded - // using `core.title` metadata. - Some(self.mangled_name(node)) - } + OpType::FuncDefn(func_defn) => Some(func_defn.func_name().as_str()), + OpType::FuncDecl(func_decl) => Some(func_decl.func_name().as_str()), OpType::AliasDecl(alias_decl) => Some(alias_decl.name.as_str()), OpType::AliasDefn(alias_defn) => Some(alias_defn.name.as_str()), _ => None, @@ -288,7 +290,6 @@ impl<'a> Context<'a> { // the node id. This is necessary to establish the correct node id for the // local scope introduced by some operations. We will overwrite this node later. 
let mut regions: &[_] = &[]; - let mut meta = Vec::new(); let node = self.id_to_node[&node_id]; let optype = self.hugr.get_optype(node); @@ -309,7 +310,6 @@ impl<'a> Context<'a> { node, model::ScopeClosure::Open, false, - false, )]); table::Operation::Dfg } @@ -334,36 +334,24 @@ impl<'a> Context<'a> { node, model::ScopeClosure::Open, false, - false, )]); table::Operation::Block } OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| { - let symbol_name = this.export_func_name(node, &mut meta); - - let symbol = this.export_poly_func_type( - symbol_name, - Some(func.visibility().clone().into()), - func.signature(), - ); + let name = this.get_func_name(node).unwrap(); + let symbol = this.export_poly_func_type(name, func.signature()); regions = this.bump.alloc_slice_copy(&[this.export_dfg( node, model::ScopeClosure::Closed, false, - false, )]); table::Operation::DefineFunc(symbol) }), OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| { - let symbol_name = this.export_func_name(node, &mut meta); - - let symbol = this.export_poly_func_type( - symbol_name, - Some(func.visibility().clone().into()), - func.signature(), - ); + let name = this.get_func_name(node).unwrap(); + let symbol = this.export_poly_func_type(name, func.signature()); table::Operation::DeclareFunc(symbol) }), @@ -371,7 +359,6 @@ impl<'a> Context<'a> { // TODO: We should support aliases with different types and with parameters let signature = this.make_term_apply(model::CORE_TYPE, &[]); let symbol = this.bump.alloc(table::Symbol { - visibility: &NO_VIS, // not spec'd in hugr-core name: &alias.name, params: &[], constraints: &[], @@ -385,7 +372,6 @@ impl<'a> Context<'a> { // TODO: We should support aliases with different types and with parameters let signature = this.make_term_apply(model::CORE_TYPE, &[]); let symbol = this.bump.alloc(table::Symbol { - visibility: &NO_VIS, // not spec'd in hugr-core name: &alias.name, params: &[], constraints: &[], @@ -399,7 +385,7 @@ impl<'a> Context<'a> { let node = self.connected_function(node).unwrap(); let symbol = self.node_to_id[&node]; let mut args = BumpVec::new_in(self.bump); - args.extend(call.type_args.iter().map(|arg| self.export_term(arg, None))); + args.extend(call.type_args.iter().map(|arg| self.export_type_arg(arg))); let args = args.into_bump_slice(); let func = self.make_term(table::Term::Apply(symbol, args)); @@ -415,7 +401,7 @@ impl<'a> Context<'a> { let node = self.connected_function(node).unwrap(); let symbol = self.node_to_id[&node]; let mut args = BumpVec::new_in(self.bump); - args.extend(load.type_args.iter().map(|arg| self.export_term(arg, None))); + args.extend(load.type_args.iter().map(|arg| self.export_type_arg(arg))); let args = args.into_bump_slice(); let func = self.make_term(table::Term::Apply(symbol, args)); let runtime_type = self.make_term(table::Term::Wildcard); @@ -465,7 +451,6 @@ impl<'a> Context<'a> { node, model::ScopeClosure::Open, false, - false, )]); table::Operation::TailLoop } @@ -479,7 +464,7 @@ impl<'a> Context<'a> { let node = self.export_opdef(op.def()); let params = self .bump - .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_term(arg, None))); + .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); let operation = self.make_term(table::Term::Apply(node, params)); table::Operation::Custom(operation) } @@ -488,7 +473,7 @@ impl<'a> Context<'a> { let node = self.make_named_global_ref(op.extension(), op.unqualified_id()); let params = self .bump - 
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_term(arg, None))); + .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); let operation = self.make_term(table::Term::Apply(node, params)); table::Operation::Custom(operation) } @@ -517,10 +502,12 @@ impl<'a> Context<'a> { let inputs = self.make_ports(node, Direction::Incoming, num_inputs); let outputs = self.make_ports(node, Direction::Outgoing, num_outputs); - self.export_node_json_metadata(node, &mut meta); - self.export_node_order_metadata(node, &mut meta); - self.export_node_entrypoint_metadata(node, &mut meta); - let meta = self.bump.alloc_slice_copy(&meta); + let meta = { + let mut meta = Vec::new(); + self.export_node_json_metadata(node, &mut meta); + self.export_node_order_metadata(node, &mut meta); + self.bump.alloc_slice_copy(&meta) + }; self.module.nodes[node_id.index()] = table::Node { operation, @@ -559,7 +546,7 @@ impl<'a> Context<'a> { let symbol = self.with_local_scope(node, |this| { let name = this.make_qualified_name(opdef.extension_id(), opdef.name()); - this.export_poly_func_type(name, None, poly_func_type) + this.export_poly_func_type(name, poly_func_type) }); let meta = { @@ -591,6 +578,7 @@ impl<'a> Context<'a> { pub fn export_block_signature(&mut self, block: &DataflowBlock) -> table::TermId { let inputs = { let inputs = self.export_type_row(&block.inputs); + let inputs = self.make_term_apply(model::CORE_CTRL, &[inputs]); self.make_term(table::Term::List( self.bump.alloc_slice_copy(&[table::SeqPart::Item(inputs)]), )) @@ -602,12 +590,13 @@ impl<'a> Context<'a> { let mut outputs = BumpVec::with_capacity_in(block.sum_rows.len(), self.bump); for sum_row in &block.sum_rows { let variant = self.export_type_row_with_tail(sum_row, Some(tail)); - outputs.push(table::SeqPart::Item(variant)); + let control = self.make_term_apply(model::CORE_CTRL, &[variant]); + outputs.push(table::SeqPart::Item(control)); } self.make_term(table::Term::List(outputs.into_bump_slice())) }; - self.make_term_apply(model::CORE_CTRL, &[inputs, outputs]) + self.make_term_apply(model::CORE_FN, &[inputs, outputs]) } /// Creates a data flow region from the given node's children. 
@@ -618,7 +607,6 @@ impl<'a> Context<'a> { node: Node, closure: model::ScopeClosure, export_json_meta: bool, - export_entrypoint_meta: bool, ) -> table::RegionId { let region = self.module.insert_region(table::Region::default()); @@ -637,54 +625,46 @@ impl<'a> Context<'a> { if export_json_meta { self.export_node_json_metadata(node, &mut meta); } - if export_entrypoint_meta { - self.export_node_entrypoint_metadata(node, &mut meta); - } + self.export_node_entrypoint_metadata(node, &mut meta); let children = self.hugr.children(node); let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 - 2, self.bump); + let mut output_node = None; + for child in children { match self.hugr.get_optype(child) { OpType::Input(input) => { sources = self.make_ports(child, Direction::Outgoing, input.types.len()); input_types = Some(&input.types); - - if has_order_edges(self.hugr, child) { - let key = self.make_term(model::Literal::Nat(child.index() as u64).into()); - meta.push(self.make_term_apply(model::ORDER_HINT_INPUT_KEY, &[key])); - } } OpType::Output(output) => { targets = self.make_ports(child, Direction::Incoming, output.types.len()); output_types = Some(&output.types); - - if has_order_edges(self.hugr, child) { - let key = self.make_term(model::Literal::Nat(child.index() as u64).into()); - meta.push(self.make_term_apply(model::ORDER_HINT_OUTPUT_KEY, &[key])); - } + output_node = Some(child); } - _ => { + child_optype => { if let Some(child_id) = self.export_node_shallow(child) { region_children.push(child_id); + + // Record all order edges that originate from this node in metadata. + let successors = child_optype + .other_output_port() + .into_iter() + .flat_map(|port| self.hugr.linked_inputs(child, port)) + .map(|(successor, _)| successor) + .filter(|successor| Some(*successor) != output_node); + + for successor in successors { + let a = + self.make_term(model::Literal::Nat(child.index() as u64).into()); + let b = self + .make_term(model::Literal::Nat(successor.index() as u64).into()); + meta.push(self.make_term_apply(model::ORDER_HINT_ORDER, &[a, b])); + } } } } - - // Record all order edges that originate from this node in metadata. 
- let successors = self - .hugr - .get_optype(child) - .other_output_port() - .into_iter() - .flat_map(|port| self.hugr.linked_inputs(child, port)) - .map(|(successor, _)| successor); - - for successor in successors { - let a = self.make_term(model::Literal::Nat(child.index() as u64).into()); - let b = self.make_term(model::Literal::Nat(successor.index() as u64).into()); - meta.push(self.make_term_apply(model::ORDER_HINT_ORDER, &[a, b])); - } } for child_id in ®ion_children { @@ -760,21 +740,18 @@ impl<'a> Context<'a> { let signature = { let node_signature = self.hugr.signature(node).unwrap(); - let inputs = { - let types = self.export_type_row(node_signature.input()); - self.make_term(table::Term::List( - self.bump.alloc_slice_copy(&[table::SeqPart::Item(types)]), - )) - }; - - let outputs = { - let types = self.export_type_row(node_signature.output()); + let mut wrap_ctrl = |types: &TypeRow| { + let types = self.export_type_row(types); + let types_ctrl = self.make_term_apply(model::CORE_CTRL, &[types]); self.make_term(table::Term::List( - self.bump.alloc_slice_copy(&[table::SeqPart::Item(types)]), + self.bump + .alloc_slice_copy(&[table::SeqPart::Item(types_ctrl)]), )) }; - Some(self.make_term_apply(model::CORE_CTRL, &[inputs, outputs])) + let inputs = wrap_ctrl(node_signature.input()); + let outputs = wrap_ctrl(node_signature.output()); + Some(self.make_term_apply(model::CORE_FN, &[inputs, outputs])) }; let scope = match closure { @@ -809,7 +786,7 @@ impl<'a> Context<'a> { panic!("expected a `Case` node as a child of a `Conditional` node"); }; - regions.push(self.export_dfg(child, model::ScopeClosure::Open, true, true)); + regions.push(self.export_dfg(child, model::ScopeClosure::Open, true)); } regions.into_bump_slice() @@ -819,17 +796,16 @@ impl<'a> Context<'a> { pub fn export_poly_func_type( &mut self, name: &'a str, - visibility: Option, t: &PolyFuncTypeBase, ) -> &'a table::Symbol<'a> { let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); let scope = self .local_scope .expect("exporting poly func type outside of local scope"); - let visibility = self.bump.alloc(visibility); + for (i, param) in t.params().iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); - let r#type = self.export_term(param, Some((scope, i as _))); + let r#type = self.export_type_param(param, Some((scope, i as _))); let param = table::Param { name, r#type }; params.push(param); } @@ -838,7 +814,6 @@ impl<'a> Context<'a> { let body = self.export_func_type(t.body()); self.bump.alloc(table::Symbol { - visibility, name, params: params.into_bump_slice(), constraints, @@ -878,12 +853,30 @@ impl<'a> Context<'a> { let args = self .bump - .alloc_slice_fill_iter(t.args().iter().map(|p| self.export_term(p, None))); + .alloc_slice_fill_iter(t.args().iter().map(|p| self.export_type_arg(p))); let term = table::Term::Apply(symbol, args); self.make_term(term) } - pub fn export_type_arg_var(&mut self, var: &TermVar) -> table::TermId { + pub fn export_type_arg(&mut self, t: &TypeArg) -> table::TermId { + match t { + TypeArg::Type { ty } => self.export_type(ty), + TypeArg::BoundedNat { n } => self.make_term(model::Literal::Nat(*n).into()), + TypeArg::String { arg } => self.make_term(model::Literal::Str(arg.into()).into()), + TypeArg::Sequence { elems } => { + // For now we assume that the sequence is meant to be a list. 
+ let parts = self.bump.alloc_slice_fill_iter( + elems + .iter() + .map(|elem| table::SeqPart::Item(self.export_type_arg(elem))), + ); + self.make_term(table::Term::List(parts)) + } + TypeArg::Variable { v } => self.export_type_arg_var(v), + } + } + + pub fn export_type_arg_var(&mut self, var: &TypeArgVariable) -> table::TermId { let node = self.local_scope.expect("local variable out of scope"); self.make_term(table::Term::Var(table::VarId(node, var.index() as _))) } @@ -949,19 +942,19 @@ impl<'a> Context<'a> { self.make_term(table::Term::List(parts)) } - /// Exports a term. + /// Exports a `TypeParam` to a term. /// - /// The `var` argument is set when the term being exported is the + /// The `var` argument is set when the type parameter being exported is the /// type of a parameter to a polymorphic definition. In that case we can /// generate a `nonlinear` constraint for the type of runtime types marked as /// `TypeBound::Copyable`. - pub fn export_term( + pub fn export_type_param( &mut self, - t: &Term, + t: &TypeParam, var: Option<(table::NodeId, table::VarIndex)>, ) -> table::TermId { match t { - Term::RuntimeType(b) => { + TypeParam::Type { b } => { if let (Some((node, index)), TypeBound::Copyable) = (var, b) { let term = self.make_term(table::Term::Var(table::VarId(node, index))); let non_linear = self.make_term_apply(model::CORE_NON_LINEAR, &[term]); @@ -970,57 +963,22 @@ impl<'a> Context<'a> { self.make_term_apply(model::CORE_TYPE, &[]) } - Term::BoundedNatType(_) => self.make_term_apply(model::CORE_NAT_TYPE, &[]), - Term::StringType => self.make_term_apply(model::CORE_STR_TYPE, &[]), - Term::BytesType => self.make_term_apply(model::CORE_BYTES_TYPE, &[]), - Term::FloatType => self.make_term_apply(model::CORE_FLOAT_TYPE, &[]), - Term::ListType(item_type) => { - let item_type = self.export_term(item_type, None); + // This ignores the bound on the natural for now. + TypeParam::BoundedNat { .. 
} => self.make_term_apply(model::CORE_NAT_TYPE, &[]), + TypeParam::String => self.make_term_apply(model::CORE_STR_TYPE, &[]), + TypeParam::List { param } => { + let item_type = self.export_type_param(param, None); self.make_term_apply(model::CORE_LIST_TYPE, &[item_type]) } - Term::TupleType(item_types) => { - let item_types = self.export_term(item_types, None); - self.make_term_apply(model::CORE_TUPLE_TYPE, &[item_types]) - } - Term::Runtime(ty) => self.export_type(ty), - Term::BoundedNat(value) => self.make_term(model::Literal::Nat(*value).into()), - Term::String(value) => self.make_term(model::Literal::Str(value.into()).into()), - Term::Float(value) => self.make_term(model::Literal::Float(*value).into()), - Term::Bytes(value) => self.make_term(model::Literal::Bytes(value.clone()).into()), - Term::List(elems) => { + TypeParam::Tuple { params } => { let parts = self.bump.alloc_slice_fill_iter( - elems - .iter() - .map(|elem| table::SeqPart::Item(self.export_term(elem, None))), - ); - self.make_term(table::Term::List(parts)) - } - Term::ListConcat(lists) => { - let parts = self.bump.alloc_slice_fill_iter( - lists - .iter() - .map(|elem| table::SeqPart::Splice(self.export_term(elem, None))), - ); - self.make_term(table::Term::List(parts)) - } - Term::Tuple(elems) => { - let parts = self.bump.alloc_slice_fill_iter( - elems + params .iter() - .map(|elem| table::SeqPart::Item(self.export_term(elem, None))), + .map(|param| table::SeqPart::Item(self.export_type_param(param, None))), ); - self.make_term(table::Term::Tuple(parts)) + let types = self.make_term(table::Term::List(parts)); + self.make_term_apply(model::CORE_TUPLE_TYPE, &[types]) } - Term::TupleConcat(tuples) => { - let parts = self.bump.alloc_slice_fill_iter( - tuples - .iter() - .map(|elem| table::SeqPart::Splice(self.export_term(elem, None))), - ); - self.make_term(table::Term::Tuple(parts)) - } - Term::Variable(v) => self.export_type_arg_var(v), - Term::StaticType => self.make_term_apply(model::CORE_STATIC, &[]), } } @@ -1084,7 +1042,7 @@ impl<'a> Context<'a> { let region = match hugr.entrypoint_optype() { OpType::DFG(_) => { - self.export_dfg(hugr.entrypoint(), model::ScopeClosure::Closed, true, true) + self.export_dfg(hugr.entrypoint(), model::ScopeClosure::Closed, true) } _ => panic!("Value::Function root must be a DFG"), }; @@ -1125,7 +1083,21 @@ impl<'a> Context<'a> { } fn export_node_order_metadata(&mut self, node: Node, meta: &mut Vec) { - if has_order_edges(self.hugr, node) { + fn is_relevant_node(hugr: &Hugr, node: Node) -> bool { + let optype = hugr.get_optype(node); + !optype.is_input() && !optype.is_output() + } + + let optype = self.hugr.get_optype(node); + + let has_order_edges = Direction::BOTH + .iter() + .filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder)) + .filter_map(|dir| optype.other_port(*dir)) + .flat_map(|port| self.hugr.linked_ports(node, port)) + .any(|(other, _)| is_relevant_node(self.hugr, other)); + + if has_order_edges { let key = self.make_term(model::Literal::Nat(node.index() as u64).into()); meta.push(self.make_term_apply(model::ORDER_HINT_KEY, &[key])); } @@ -1137,33 +1109,6 @@ impl<'a> Context<'a> { } } - /// Used when exporting function definitions or declarations. When the - /// function is public, its symbol name will be the core name. For private - /// functions, the symbol name is derived from the node id and the core name - /// is exported as `core.title` metadata. - /// - /// This is a hack, necessary due to core names for functions being - /// non-functional. 
Once functions have a "link name", that should be used as the symbol name here. - fn export_func_name(&mut self, node: Node, meta: &mut Vec) -> &'a str { - let (name, vis) = match self.hugr.get_optype(node) { - OpType::FuncDefn(func_defn) => (func_defn.func_name(), func_defn.visibility()), - OpType::FuncDecl(func_decl) => (func_decl.func_name(), func_decl.visibility()), - _ => panic!( - "`export_func_name` is only supposed to be used on function declarations and definitions" - ), - }; - - match vis { - Visibility::Public => name, - Visibility::Private => { - let literal = - self.make_term(table::Term::Literal(model::Literal::Str(name.to_smolstr()))); - meta.push(self.make_term_apply(model::CORE_TITLE, &[literal])); - self.mangled_name(node) - } - } - } - pub fn make_json_meta(&mut self, name: &str, value: &serde_json::Value) -> table::TermId { let value = serde_json::to_string(value).expect("json values are always serializable"); let value = self.make_term(model::Literal::Str(value.into()).into()); @@ -1190,11 +1135,6 @@ impl<'a> Context<'a> { let args = self.bump.alloc_slice_copy(args); self.make_term(table::Term::Apply(symbol, args)) } - - /// Creates a mangled name for a particular node. - fn mangled_name(&self, node: Node) -> &'a str { - bumpalo::format!(in &self.bump, "_{}", node.index()).into_bump_str() - } } type FxIndexSet = indexmap::IndexSet; @@ -1272,18 +1212,6 @@ impl Links { } } -/// Returns `true` if a node has any incident order edges. -fn has_order_edges(hugr: &Hugr, node: Node) -> bool { - let optype = hugr.get_optype(node); - Direction::BOTH - .iter() - .filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder)) - .filter_map(|dir| optype.other_port(*dir)) - .flat_map(|port| hugr.linked_ports(node, port)) - .next() - .is_some() -} - #[cfg(test)] mod test { use rstest::{fixture, rstest}; diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index c6dc2be25a..bb5034e1b1 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -22,7 +22,7 @@ use crate::hugr::IdentList; use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{OpName, OpNameRef}; use crate::types::RowVariable; -use crate::types::type_param::{TermTypeError, TypeArg, TypeParam}; +use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::{CustomType, TypeBound, TypeName}; use crate::types::{Signature, TypeNameRef}; @@ -36,7 +36,7 @@ mod type_def; pub use const_fold::{ConstFold, ConstFoldResult, Folder, fold_out_row}; pub use op_def::{ CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc, - ValidateJustArgs, ValidateTypeArgs, deserialize_lower_funcs, + ValidateJustArgs, ValidateTypeArgs, }; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; pub use type_def::{TypeDef, TypeDefBound}; @@ -136,8 +136,8 @@ impl ExtensionRegistry { match self.exts.entry(extension.name().clone()) { btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered( extension.name().clone(), - Box::new(prev.get().version().clone()), - Box::new(extension.version().clone()), + prev.get().version().clone(), + extension.version().clone(), )), btree_map::Entry::Vacant(ve) => { ve.insert(extension); @@ -387,7 +387,7 @@ pub enum SignatureError { ExtensionMismatch(ExtensionId, ExtensionId), /// When the type arguments of the node did not match the params declared by the `OpDef` #[error("Type arguments of node did not match params declared by definition: {0}")] - TypeArgMismatch(#[from] TermTypeError), + 
TypeArgMismatch(#[from] TypeArgError), /// Invalid type arguments #[error("Invalid type arguments for operation")] InvalidTypeArgs, @@ -408,8 +408,8 @@ pub enum SignatureError { /// A Type Variable's cache of its declared kind is incorrect #[error("Type Variable claims to be {cached} but actual declaration {actual}")] TypeVarDoesNotMatchDeclaration { - actual: Box<TypeParam>, - cached: Box<TypeParam>, + actual: TypeParam, + cached: TypeParam, }, /// A type variable that was used has not been declared #[error("Type variable {idx} was not declared ({num_decls} in scope)")] @@ -425,8 +425,8 @@ pub enum SignatureError { "Incorrect result of type application in Call - cached {cached} but expected {expected}" )] CallIncorrectlyAppliesType { - cached: Box<Signature>, - expected: Box<Signature>, + cached: Signature, + expected: Signature, }, /// The result of the type application stored in a [`LoadFunction`] /// is not what we get by applying the type-args to the polymorphic function @@ -436,8 +436,8 @@ pub enum SignatureError { "Incorrect result of type application in LoadFunction - cached {cached} but expected {expected}" )] LoadFunctionIncorrectlyAppliesType { - cached: Box<Signature>, - expected: Box<Signature>, + cached: Signature, + expected: Signature, }, /// Extension declaration specifies a binary compute signature function, but none @@ -697,7 +697,7 @@ pub enum ExtensionRegistryError { #[error( "The registry already contains an extension with id {0} and version {1}. New extension has version {2}." )] - AlreadyRegistered(ExtensionId, Box<Version>, Box<Version>), + AlreadyRegistered(ExtensionId, Version, Version), /// A registered extension has invalid signatures. #[error("The extension {0} contains an invalid signature, {1}.")] InvalidSignature(ExtensionId, #[source] SignatureError), @@ -706,20 +706,13 @@ pub enum ExtensionRegistryError { /// An error that can occur while loading an extension registry. #[derive(Debug, Error)] #[non_exhaustive] -#[error("Extension registry load error")] pub enum ExtensionRegistryLoadError { /// Deserialization error. #[error(transparent)] SerdeError(#[from] serde_json::Error), /// Error when resolving internal extension references. #[error(transparent)] - ExtensionResolutionError(Box<ExtensionResolutionError>), -} - -impl From<ExtensionResolutionError> for ExtensionRegistryLoadError { - fn from(error: ExtensionResolutionError) -> Self { - Self::ExtensionResolutionError(Box::new(error)) - } + ExtensionResolutionError(#[from] ExtensionResolutionError), } /// An error that can occur in building a new extension.
@@ -896,8 +889,8 @@ pub mod test { reg.register(ext1_1.clone()), Err(ExtensionRegistryError::AlreadyRegistered( ext_1_id.clone(), - Box::new(Version::new(1, 0, 0)), - Box::new(Version::new(1, 1, 0)) + Version::new(1, 0, 0), + Version::new(1, 1, 0) )) ); diff --git a/hugr-core/src/extension/declarative/types.rs b/hugr-core/src/extension/declarative/types.rs index 46224d48e8..ebbf628d68 100644 --- a/hugr-core/src/extension/declarative/types.rs +++ b/hugr-core/src/extension/declarative/types.rs @@ -100,7 +100,7 @@ impl From for TypeDefBound { bound: TypeBound::Copyable, }, TypeDefBoundDeclaration::Any => Self::Explicit { - bound: TypeBound::Linear, + bound: TypeBound::Any, }, } } @@ -129,6 +129,6 @@ impl TypeParamDeclaration { _extension: &Extension, _ctx: DeclarationContext<'_>, ) -> Result { - Ok(TypeParam::StringType) + Ok(TypeParam::String) } } diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 6a2b5ab69f..9c30cbdd47 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -12,9 +12,9 @@ use super::{ }; use crate::Hugr; -use crate::envelope::serde_with::AsBinaryEnvelope; +use crate::envelope::serde_with::AsStringEnvelope; use crate::ops::{OpName, OpNameRef}; -use crate::types::type_param::{TypeArg, TypeParam, check_term_types}; +use crate::types::type_param::{TypeArg, TypeParam, check_type_args}; use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature}; mod serialize_signature_func; @@ -239,7 +239,7 @@ impl SignatureFunc { let static_params = func.static_params(); let (static_args, other_args) = args.split_at(min(static_params.len(), args.len())); - check_term_types(static_args, static_params)?; + check_type_args(static_args, static_params)?; temp = func.compute_signature(static_args, def)?; (&temp, other_args) } @@ -268,12 +268,8 @@ impl Debug for SignatureFunc { /// Different ways that an [OpDef] can lower operation nodes i.e. provide a Hugr /// that implements the operation using a set of other extensions. -/// -/// Does not implement [`serde::Deserialize`] directly since the serde error for -/// untagged enums is unhelpful. Use [`deserialize_lower_funcs`] with -/// [`serde(deserialize_with = "deserialize_lower_funcs")] instead. #[serde_as] -#[derive(serde::Serialize)] +#[derive(serde::Deserialize, serde::Serialize)] #[serde(untagged)] pub enum LowerFunc { /// Lowering to a fixed Hugr. Since this cannot depend upon the [TypeArg]s, @@ -285,8 +281,8 @@ pub enum LowerFunc { /// [OpDef] /// /// [ExtensionOp]: crate::ops::ExtensionOp - #[serde_as(as = "Box")] - hugr: Box, + #[serde_as(as = "AsStringEnvelope")] + hugr: Hugr, }, /// Custom binary function that can (fallibly) compute a Hugr /// for the particular instance and set of available extensions. @@ -294,34 +290,6 @@ pub enum LowerFunc { CustomFunc(Box), } -/// A function for deserializing sequences of [`LowerFunc::FixedHugr`]. -/// -/// We could let serde deserialize [`LowerFunc`] as-is, but if the LowerFunc -/// deserialization fails it just returns an opaque "data did not match any -/// variant of untagged enum LowerFunc" error. This function will return the -/// internal errors instead. 
-pub fn deserialize_lower_funcs<'de, D>(deserializer: D) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - #[serde_as] - #[derive(serde::Deserialize)] - struct FixedHugrDeserializer { - pub extensions: ExtensionSet, - #[serde_as(as = "Box")] - pub hugr: Box, - } - - let funcs: Vec = serde::Deserialize::deserialize(deserializer)?; - Ok(funcs - .into_iter() - .map(|f| LowerFunc::FixedHugr { - extensions: f.extensions, - hugr: f.hugr, - }) - .collect()) -} - impl Debug for LowerFunc { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -354,11 +322,7 @@ pub struct OpDef { signature_func: SignatureFunc, // Some operations cannot lower themselves and tools that do not understand them // can only treat them as opaque/black-box ops. - #[serde( - default, - skip_serializing_if = "Vec::is_empty", - deserialize_with = "deserialize_lower_funcs" - )] + #[serde(default, skip_serializing_if = "Vec::is_empty")] pub(crate) lower_funcs: Vec, /// Operations can optionally implement [`ConstFold`] to implement constant folding. @@ -383,7 +347,7 @@ impl OpDef { let (static_args, other_args) = args.split_at(min(custom.static_params().len(), args.len())); static_args.iter().try_for_each(|ta| ta.validate(&[]))?; - check_term_types(static_args, custom.static_params())?; + check_type_args(static_args, custom.static_params())?; temp = custom.compute_signature(static_args, self)?; (&temp, other_args) } @@ -393,7 +357,7 @@ impl OpDef { } }; args.iter().try_for_each(|ta| ta.validate(var_decls))?; - check_term_types(args, pf.params())?; + check_type_args(args, pf.params())?; Ok(()) } @@ -413,7 +377,7 @@ impl OpDef { .filter_map(|f| match f { LowerFunc::FixedHugr { extensions, hugr } => { if available_extensions.is_superset(extensions) { - Some(hugr.as_ref().clone()) + Some(hugr.clone()) } else { None } @@ -589,7 +553,7 @@ pub(super) mod test { use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; use crate::ops::OpName; use crate::std_extensions::collections::list; - use crate::types::type_param::{TermTypeError, TypeParam}; + use crate::types::type_param::{TypeArgError, TypeParam}; use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV}; use crate::{Extension, const_extension_ids}; @@ -692,7 +656,7 @@ pub(super) mod test { const OP_NAME: OpName = OpName::new_inline("Reverse"); let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { - const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear); + const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; let list_of_var = Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?); let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var])); @@ -700,7 +664,7 @@ pub(super) mod test { let def = ext.add_op(OP_NAME, "desc".into(), type_scheme, extension_ref)?; def.add_lower_func(LowerFunc::FixedHugr { extensions: ExtensionSet::new(), - hugr: Box::new(crate::builder::test::simple_dfg_hugr()), // this is nonsense, but we are not testing the actual lowering here + hugr: crate::builder::test::simple_dfg_hugr(), // this is nonsense, but we are not testing the actual lowering here }); def.add_misc("key", Default::default()); assert_eq!(def.description(), "desc"); @@ -714,10 +678,11 @@ pub(super) mod test { reg.validate()?; let e = reg.get(&EXT_ID).unwrap(); - let list_usize = Type::new_extension(list_def.instantiate(vec![usize_t().into()])?); + let list_usize = + Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: 
usize_t() }])?); let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?; let rev = dfg.add_dataflow_op( - e.instantiate_extension_op(&OP_NAME, vec![usize_t().into()]) + e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: usize_t() }]) .unwrap(), dfg.input_wires(), )?; @@ -738,13 +703,13 @@ pub(super) mod test { &self, arg_values: &[TypeArg], ) -> Result { - const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear); - let [TypeArg::BoundedNat(n)] = arg_values else { + const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; + let [TypeArg::BoundedNat { n }] = arg_values else { return Err(SignatureError::InvalidTypeArgs); }; let n = *n as usize; let tvs: Vec = (0..n) - .map(|_| Type::new_var_use(0, TypeBound::Linear)) + .map(|_| Type::new_var_use(0, TypeBound::Any)) .collect(); Ok(PolyFuncTypeRV::new( vec![TP.clone()], @@ -753,7 +718,7 @@ pub(super) mod test { } fn static_params(&self) -> &[TypeParam] { - const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat_type()]; + const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat()]; MAX_NAT } } @@ -762,7 +727,7 @@ pub(super) mod test { ext.add_op("MyOp".into(), String::new(), SigFun(), extension_ref)?; // Base case, no type variables: - let args = [TypeArg::BoundedNat(3), usize_t().into()]; + let args = [TypeArg::BoundedNat { n: 3 }, usize_t().into()]; assert_eq!( def.compute_signature(&args), Ok(Signature::new( @@ -775,7 +740,7 @@ pub(super) mod test { // Second arg may be a variable (substitutable) let tyvar = Type::new_var_use(0, TypeBound::Copyable); let tyvars: Vec = vec![tyvar.clone(); 3]; - let args = [TypeArg::BoundedNat(3), tyvar.clone().into()]; + let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; assert_eq!( def.compute_signature(&args), Ok(Signature::new( @@ -788,15 +753,15 @@ pub(super) mod test { // quick sanity check that we are validating the args - note changed bound: assert_eq!( - def.validate_args(&args, &[TypeBound::Linear.into()]), + def.validate_args(&args, &[TypeBound::Any.into()]), Err(SignatureError::TypeVarDoesNotMatchDeclaration { - actual: Box::new(TypeBound::Linear.into()), - cached: Box::new(TypeBound::Copyable.into()) + actual: TypeBound::Any.into(), + cached: TypeBound::Copyable.into() }) ); // First arg must be concrete, not a variable - let kind = TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap()); + let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap()); let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()]; // We can't prevent this from getting into our compute_signature implementation: assert_eq!( @@ -827,13 +792,13 @@ pub(super) mod test { "SimpleOp".into(), String::new(), PolyFuncTypeRV::new( - vec![TypeBound::Linear.into()], - Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Linear)]), + vec![TypeBound::Any.into()], + Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), ), extension_ref, )?; let tv = Type::new_var_use(0, TypeBound::Copyable); - let args = [tv.clone().into()]; + let args = [TypeArg::Type { ty: tv.clone() }]; let decls = [TypeBound::Copyable.into()]; def.validate_args(&args, &decls).unwrap(); assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo(tv))); @@ -842,9 +807,9 @@ pub(super) mod test { assert_eq!( def.compute_signature(&[arg.clone()]), Err(SignatureError::TypeArgMismatch( - TermTypeError::TypeMismatch { - type_: Box::new(TypeBound::Linear.into()), - term: Box::new(arg), + TypeArgError::TypeMismatch { + param: TypeBound::Any.into(), + arg } )) ); @@ -887,7 +852,7 @@ pub(super) mod test { any::() 
.prop_map(|extensions| LowerFunc::FixedHugr { extensions, - hugr: Box::new(simple_dfg_hugr()), + hugr: simple_dfg_hugr(), }) .boxed() } diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index 1b59d50ea7..3af70b75b4 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -18,8 +18,8 @@ use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName}; use crate::ops::{NamedOp, Value}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{ - CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Term, Type, - TypeBound, TypeName, TypeRV, TypeRow, TypeRowRV, + CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeBound, + TypeName, TypeRV, TypeRow, TypeRowRV, }; use crate::utils::sorted_consts; use crate::{Extension, type_row}; @@ -39,7 +39,7 @@ pub mod generic; /// Name of prelude extension. pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude"); /// Extension version. -pub const VERSION: semver::Version = semver::Version::new(0, 2, 1); +pub const VERSION: semver::Version = semver::Version::new(0, 2, 0); lazy_static! { /// Prelude extension, containing common types and operations. pub static ref PRELUDE: Arc = { @@ -52,7 +52,6 @@ lazy_static! { // would try to access the `PRELUDE` lazy static recursively, // causing a deadlock. let string_type: Type = string_custom_type(extension_ref).into(); - let usize_type: Type = usize_custom_t(extension_ref).into(); let error_type: CustomType = error_custom_type(extension_ref); prelude @@ -75,7 +74,7 @@ lazy_static! { prelude.add_op( PRINT_OP_ID, "Print the string to standard output".to_string(), - Signature::new(vec![string_type.clone()], type_row![]), + Signature::new(vec![string_type], type_row![]), extension_ref, ) .unwrap(); @@ -97,23 +96,15 @@ lazy_static! { extension_ref, ) .unwrap(); - prelude - .add_op( - MAKE_ERROR_OP_ID, - "Create an error value".to_string(), - Signature::new(vec![usize_type, string_type], vec![error_type.clone().into()]), - extension_ref, - ) - .unwrap(); prelude .add_op( PANIC_OP_ID, "Panic with input error".to_string(), PolyFuncTypeRV::new( - [TypeParam::new_list_type(TypeBound::Linear), TypeParam::new_list_type(TypeBound::Linear)], + [TypeParam::new_list(TypeBound::Any), TypeParam::new_list(TypeBound::Any)], FuncValueType::new( - vec![TypeRV::new_extension(error_type.clone()), TypeRV::new_row_var_use(0, TypeBound::Linear)], - vec![TypeRV::new_row_var_use(1, TypeBound::Linear)], + vec![TypeRV::new_extension(error_type.clone()), TypeRV::new_row_var_use(0, TypeBound::Any)], + vec![TypeRV::new_row_var_use(1, TypeBound::Any)], ), ), extension_ref, @@ -124,10 +115,10 @@ lazy_static! 
{ EXIT_OP_ID, "Exit with input error".to_string(), PolyFuncTypeRV::new( - [TypeParam::new_list_type(TypeBound::Linear), TypeParam::new_list_type(TypeBound::Linear)], + [TypeParam::new_list(TypeBound::Any), TypeParam::new_list(TypeBound::Any)], FuncValueType::new( - vec![TypeRV::new_extension(error_type), TypeRV::new_row_var_use(0, TypeBound::Linear)], - vec![TypeRV::new_row_var_use(1, TypeBound::Linear)], + vec![TypeRV::new_extension(error_type), TypeRV::new_row_var_use(0, TypeBound::Any)], + vec![TypeRV::new_row_var_use(1, TypeBound::Any)], ), ), extension_ref, @@ -160,7 +151,7 @@ pub(crate) fn qb_custom_t(extension_ref: &Weak) -> CustomType { TypeName::new_inline("qubit"), vec![], PRELUDE_ID, - TypeBound::Linear, + TypeBound::Any, extension_ref, ) } @@ -181,15 +172,10 @@ pub fn bool_t() -> Type { Type::new_unit_sum(2) } -/// Name of the prelude `MakeError` operation. -/// -/// This operation can be used to dynamically create error values. -pub const MAKE_ERROR_OP_ID: OpName = OpName::new_inline("MakeError"); - /// Name of the prelude panic operation. /// /// This operation can have any input and any output wires; it is instantiated -/// with two [`TypeArg::List`]s representing these. The first input to the +/// with two [`TypeArg::Sequence`]s representing these. The first input to the /// operation is always an error type; the remaining inputs correspond to the /// first sequence of types in its instantiation; the outputs correspond to the /// second sequence of types in its instantiation. Note that the inputs and @@ -203,7 +189,7 @@ pub const PANIC_OP_ID: OpName = OpName::new_inline("panic"); /// Name of the prelude exit operation. /// /// This operation can have any input and any output wires; it is instantiated -/// with two [`TypeArg::List`]s representing these. The first input to the +/// with two [`TypeArg::Sequence`]s representing these. The first input to the /// operation is always an error type; the remaining inputs correspond to the /// first sequence of types in its instantiation; the outputs correspond to the /// second sequence of types in its instantiation. 
Note that the inputs and @@ -626,10 +612,10 @@ impl MakeOpDef for TupleOpDef { } fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { - let rv = TypeRV::new_row_var_use(0, TypeBound::Linear); + let rv = TypeRV::new_row_var_use(0, TypeBound::Any); let tuple_type = TypeRV::new_tuple(vec![rv.clone()]); - let param = TypeParam::new_list_type(TypeBound::Linear); + let param = TypeParam::new_list(TypeBound::Any); match self { TupleOpDef::MakeTuple => { PolyFuncTypeRV::new([param], FuncValueType::new(rv, tuple_type)) @@ -692,13 +678,13 @@ impl MakeExtensionOp for MakeTuple { if def != TupleOpDef::MakeTuple { return Err(OpLoadError::NotMember(ext_op.unqualified_id().to_string()))?; } - let [TypeArg::List(elems)] = ext_op.args() else { + let [TypeArg::Sequence { elems }] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; let tys: Result, _> = elems .iter() .map(|a| match a { - TypeArg::Runtime(ty) => Ok(ty.clone()), + TypeArg::Type { ty } => Ok(ty.clone()), _ => Err(SignatureError::InvalidTypeArgs), }) .collect(); @@ -706,7 +692,13 @@ impl MakeExtensionOp for MakeTuple { } fn type_args(&self) -> Vec { - vec![Term::new_list(self.0.iter().map(|t| t.clone().into()))] + vec![TypeArg::Sequence { + elems: self + .0 + .iter() + .map(|t| TypeArg::Type { ty: t.clone() }) + .collect(), + }] } } @@ -747,21 +739,27 @@ impl MakeExtensionOp for UnpackTuple { if def != TupleOpDef::UnpackTuple { return Err(OpLoadError::NotMember(ext_op.unqualified_id().to_string()))?; } - let [Term::List(elems)] = ext_op.args() else { + let [TypeArg::Sequence { elems }] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; let tys: Result, _> = elems .iter() .map(|a| match a { - Term::Runtime(ty) => Ok(ty.clone()), + TypeArg::Type { ty } => Ok(ty.clone()), _ => Err(SignatureError::InvalidTypeArgs), }) .collect(); Ok(Self(tys?.into())) } - fn type_args(&self) -> Vec { - vec![Term::new_list(self.0.iter().map(|t| t.clone().into()))] + fn type_args(&self) -> Vec { + vec![TypeArg::Sequence { + elems: self + .0 + .iter() + .map(|t| TypeArg::Type { ty: t.clone() }) + .collect(), + }] } } @@ -800,8 +798,8 @@ impl MakeOpDef for NoopDef { } fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { - let tv = Type::new_var_use(0, TypeBound::Linear); - PolyFuncType::new([TypeBound::Linear.into()], Signature::new_endo(tv)).into() + let tv = Type::new_var_use(0, TypeBound::Any); + PolyFuncType::new([TypeBound::Any.into()], Signature::new_endo(tv)).into() } fn description(&self) -> String { @@ -865,14 +863,14 @@ impl MakeExtensionOp for Noop { Self: Sized, { let _def = NoopDef::from_def(ext_op.def())?; - let [TypeArg::Runtime(ty)] = ext_op.args() else { + let [TypeArg::Type { ty }] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; Ok(Self(ty.clone())) } fn type_args(&self) -> Vec { - vec![self.0.clone().into()] + vec![TypeArg::Type { ty: self.0.clone() }] } } @@ -912,8 +910,8 @@ impl MakeOpDef for BarrierDef { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { PolyFuncTypeRV::new( - vec![TypeParam::new_list_type(TypeBound::Linear)], - FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Linear)), + vec![TypeParam::new_list(TypeBound::Any)], + FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Any)), ) .into() } @@ -971,13 +969,13 @@ impl MakeExtensionOp for Barrier { { let _def = BarrierDef::from_def(ext_op.def())?; - let [TypeArg::List(elems)] = ext_op.args() else { + let [TypeArg::Sequence { elems }] = ext_op.args() else { 
return Err(SignatureError::InvalidTypeArgs)?; }; let tys: Result, _> = elems .iter() .map(|a| match a { - TypeArg::Runtime(ty) => Ok(ty.clone()), + TypeArg::Type { ty } => Ok(ty.clone()), _ => Err(SignatureError::InvalidTypeArgs), }) .collect(); @@ -987,9 +985,13 @@ impl MakeExtensionOp for Barrier { } fn type_args(&self) -> Vec { - vec![TypeArg::new_list( - self.type_row.iter().map(|t| t.clone().into()), - )] + vec![TypeArg::Sequence { + elems: self + .type_row + .iter() + .map(|t| TypeArg::Type { ty: t.clone() }) + .collect(), + }] } } @@ -1007,7 +1009,6 @@ impl MakeRegisteredOp for Barrier { mod test { use crate::builder::inout_sig; use crate::std_extensions::arithmetic::float_types::{ConstF64, float64_type}; - use crate::types::Term; use crate::{ Hugr, Wire, builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig}, @@ -1020,8 +1021,6 @@ mod test { type_row, }; - use crate::hugr::views::HugrView; - #[test] fn test_make_tuple() { let op = MakeTuple::new(type_row![Type::UNIT]); @@ -1133,8 +1132,9 @@ mod test { let err = b.add_load_value(error_val); + const TYPE_ARG_NONE: TypeArg = TypeArg::Sequence { elems: vec![] }; let op = PRELUDE - .instantiate_extension_op(&EXIT_OP_ID, [Term::new_list([]), Term::new_list([])]) + .instantiate_extension_op(&EXIT_OP_ID, [TYPE_ARG_NONE, TYPE_ARG_NONE]) .unwrap(); b.add_dataflow_op(op, [err]).unwrap(); @@ -1142,32 +1142,14 @@ mod test { b.finish_hugr_with_outputs([]).unwrap(); } - #[test] - /// test the prelude make error op with the panic op. - fn test_make_error() { - let err_op = PRELUDE - .instantiate_extension_op(&MAKE_ERROR_OP_ID, []) - .unwrap(); - let panic_op = PRELUDE - .instantiate_extension_op(&EXIT_OP_ID, [Term::new_list([]), Term::new_list([])]) - .unwrap(); - - let mut b = - DFGBuilder::new(Signature::new(vec![usize_t(), string_type()], type_row![])).unwrap(); - let [signal, message] = b.input_wires_arr(); - let err_value = b.add_dataflow_op(err_op, [signal, message]).unwrap(); - b.add_dataflow_op(panic_op, err_value.outputs()).unwrap(); - - let h = b.finish_hugr_with_outputs([]).unwrap(); - h.validate().unwrap(); - } - #[test] /// test the panic operation with input and output wires fn test_panic_with_io() { let error_val = ConstError::new(42, "PANIC"); - let type_arg_q: Term = qb_t().into(); - let type_arg_2q: Term = Term::new_list([type_arg_q.clone(), type_arg_q]); + let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() }; + let type_arg_2q: TypeArg = TypeArg::Sequence { + elems: vec![type_arg_q.clone(), type_arg_q], + }; let panic_op = PRELUDE .instantiate_extension_op(&PANIC_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) .unwrap(); diff --git a/hugr-core/src/extension/prelude/generic.rs b/hugr-core/src/extension/prelude/generic.rs index ca00c713fd..9ea231e1bb 100644 --- a/hugr-core/src/extension/prelude/generic.rs +++ b/hugr-core/src/extension/prelude/generic.rs @@ -74,7 +74,7 @@ impl MakeOpDef for LoadNatDef { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { let usize_t: Type = usize_custom_t(_extension_ref).into(); - let params = vec![TypeParam::max_nat_type()]; + let params = vec![TypeParam::max_nat()]; PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![usize_t])).into() } @@ -166,7 +166,7 @@ mod tests { extension::prelude::{ConstUsize, usize_t}, ops::{OpType, constant}, type_row, - types::Term, + types::TypeArg, }; use super::LoadNat; @@ -175,7 +175,7 @@ mod tests { fn test_load_nat() { let mut b = DFGBuilder::new(inout_sig(type_row![], vec![usize_t()])).unwrap(); - let arg = Term::from(4u64); + 
let arg = TypeArg::BoundedNat { n: 4 }; let op = LoadNat::new(arg); let out = b.add_dataflow_op(op.clone(), []).unwrap(); @@ -195,7 +195,7 @@ mod tests { #[test] fn test_load_nat_fold() { - let arg = Term::from(5u64); + let arg = TypeArg::BoundedNat { n: 5 }; let op = LoadNat::new(arg); let optype: OpType = op.into(); diff --git a/hugr-core/src/extension/resolution.rs b/hugr-core/src/extension/resolution.rs index 52f2c5dbf5..0e7bfbbab8 100644 --- a/hugr-core/src/extension/resolution.rs +++ b/hugr-core/src/extension/resolution.rs @@ -26,7 +26,7 @@ pub(crate) use ops::{collect_op_extension, resolve_op_extensions}; pub(crate) use types::{collect_op_types_extensions, collect_signature_exts, collect_type_exts}; pub(crate) use types_mut::resolve_op_types_extensions; use types_mut::{ - resolve_custom_type_exts, resolve_term_exts, resolve_type_exts, resolve_value_exts, + resolve_custom_type_exts, resolve_type_exts, resolve_typearg_exts, resolve_value_exts, }; use derive_more::{Display, Error, From}; @@ -63,7 +63,7 @@ pub fn resolve_typearg_extensions( extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { let mut used_extensions = WeakExtensionRegistry::default(); - resolve_term_exts(None, arg, extensions, &mut used_extensions) + resolve_typearg_exts(None, arg, extensions, &mut used_extensions) } /// Update all weak Extension pointers inside a constant value. diff --git a/hugr-core/src/extension/resolution/ops.rs b/hugr-core/src/extension/resolution/ops.rs index e6727fd834..a76ba47d8c 100644 --- a/hugr-core/src/extension/resolution/ops.rs +++ b/hugr-core/src/extension/resolution/ops.rs @@ -98,8 +98,8 @@ pub(crate) fn resolve_op_extensions<'e>( node, extension: opaque.extension().clone(), op: def.name().clone(), - computed: Box::new(ext_op.signature().into_owned()), - stored: Box::new(opaque.signature().into_owned()), + computed: ext_op.signature().into_owned(), + stored: opaque.signature().into_owned(), } .into()); } diff --git a/hugr-core/src/extension/resolution/test.rs b/hugr-core/src/extension/resolution/test.rs index 43c64b561d..e73dd54fbd 100644 --- a/hugr-core/src/extension/resolution/test.rs +++ b/hugr-core/src/extension/resolution/test.rs @@ -25,7 +25,7 @@ use crate::std_extensions::arithmetic::int_types::{self, int_type}; use crate::std_extensions::collections::list::ListValue; use crate::std_extensions::std_reg; use crate::types::type_param::TypeParam; -use crate::types::{PolyFuncType, Signature, Type, TypeBound}; +use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}; use crate::{Extension, Hugr, HugrView, type_row}; #[rstest] @@ -333,12 +333,12 @@ fn resolve_custom_const(#[case] custom_const: impl CustomConst) { #[rstest] fn resolve_call() { let dummy_fn_sig = PolyFuncType::new( - vec![TypeParam::RuntimeType(TypeBound::Linear)], + vec![TypeParam::Type { b: TypeBound::Any }], Signature::new(vec![], vec![bool_t()]), ); - let generic_type_1 = float64_type().into(); - let generic_type_2 = int_type(6).into(); + let generic_type_1 = TypeArg::Type { ty: float64_type() }; + let generic_type_2 = TypeArg::Type { ty: int_type(6) }; let expected_exts = [ float_types::EXTENSION_ID.clone(), int_types::EXTENSION_ID.clone(), diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index 0ea6bd7007..531509d6ee 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -11,7 +11,7 @@ use crate::Node; use crate::extension::{ExtensionRegistry, ExtensionSet}; use 
crate::ops::{DataflowOpTrait, OpType, Value}; use crate::types::type_row::TypeRowBase; -use crate::types::{FuncTypeBase, MaybeRV, SumType, Term, TypeBase, TypeEnum}; +use crate::types::{FuncTypeBase, MaybeRV, SumType, TypeArg, TypeBase, TypeEnum}; /// Collects every extension used to define the types in an operation. /// @@ -38,7 +38,7 @@ pub(crate) fn collect_op_types_extensions( match op { OpType::ExtensionOp(ext) => { for arg in ext.args() { - collect_term_exts(arg, &mut used, &mut missing); + collect_typearg_exts(arg, &mut used, &mut missing); } collect_signature_exts(&ext.signature(), &mut used, &mut missing); } @@ -55,7 +55,7 @@ pub(crate) fn collect_op_types_extensions( collect_signature_exts(c.func_sig.body(), &mut used, &mut missing); collect_signature_exts(&c.instantiation, &mut used, &mut missing); for arg in &c.type_args { - collect_term_exts(arg, &mut used, &mut missing); + collect_typearg_exts(arg, &mut used, &mut missing); } } OpType::CallIndirect(c) => collect_signature_exts(&c.signature, &mut used, &mut missing), @@ -64,13 +64,13 @@ pub(crate) fn collect_op_types_extensions( collect_signature_exts(lf.func_sig.body(), &mut used, &mut missing); collect_signature_exts(&lf.instantiation, &mut used, &mut missing); for arg in &lf.type_args { - collect_term_exts(arg, &mut used, &mut missing); + collect_typearg_exts(arg, &mut used, &mut missing); } } OpType::DFG(dfg) => collect_signature_exts(&dfg.signature, &mut used, &mut missing), OpType::OpaqueOp(op) => { for arg in op.args() { - collect_term_exts(arg, &mut used, &mut missing); + collect_typearg_exts(arg, &mut used, &mut missing); } collect_signature_exts(&op.signature(), &mut used, &mut missing); } @@ -172,7 +172,7 @@ pub(crate) fn collect_type_exts( match typ.as_type_enum() { TypeEnum::Extension(custom) => { for arg in custom.args() { - collect_term_exts(arg, used_extensions, missing_extensions); + collect_typearg_exts(arg, used_extensions, missing_extensions); } let ext_ref = custom.extension_ref(); // Check if the extension reference is still valid. @@ -202,58 +202,29 @@ pub(crate) fn collect_type_exts( } } -/// Collect the Extension pointers in the [`CustomType`]s inside a [`Term`]. +/// Collect the Extension pointers in the [`CustomType`]s inside a type argument. /// /// # Attributes /// -/// - `term`: The term argument to collect the extensions from. +/// - `arg`: The type argument to collect the extensions from. /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. 
-pub(super) fn collect_term_exts( - term: &Term, +pub(super) fn collect_typearg_exts( + arg: &TypeArg, used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { - match term { - Term::Runtime(ty) => collect_type_exts(ty, used_extensions, missing_extensions), - Term::List(elems) => { - for elem in elems.iter() { - collect_term_exts(elem, used_extensions, missing_extensions); + match arg { + TypeArg::Type { ty } => collect_type_exts(ty, used_extensions, missing_extensions), + TypeArg::Sequence { elems } => { + for elem in elems { + collect_typearg_exts(elem, used_extensions, missing_extensions); } } - Term::Tuple(elems) => { - for elem in elems.iter() { - collect_term_exts(elem, used_extensions, missing_extensions); - } - } - Term::ListType(item_type) => { - collect_term_exts(item_type, used_extensions, missing_extensions) - } - Term::TupleType(item_types) => { - collect_term_exts(item_types, used_extensions, missing_extensions) - } - Term::ListConcat(lists) => { - for list in lists { - collect_term_exts(list, used_extensions, missing_extensions); - } - } - Term::TupleConcat(tuples) => { - for tuple in tuples { - collect_term_exts(tuple, used_extensions, missing_extensions); - } - } - Term::Variable(_) - | Term::RuntimeType(_) - | Term::StaticType - | Term::BoundedNatType(_) - | Term::StringType - | Term::BytesType - | Term::FloatType - | Term::BoundedNat(_) - | Term::String(_) - | Term::Bytes(_) - | Term::Float(_) => {} + // We ignore the `TypeArg::Extension` case, as it is not required to + // **define** the hugr. + _ => {} } } diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index 8135ca0b1b..c4093a18c2 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -10,7 +10,7 @@ use super::{ExtensionResolutionError, WeakExtensionRegistry}; use crate::extension::ExtensionSet; use crate::ops::{OpType, Value}; use crate::types::type_row::TypeRowBase; -use crate::types::{CustomType, FuncTypeBase, MaybeRV, SumType, Term, TypeBase, TypeEnum}; +use crate::types::{CustomType, FuncTypeBase, MaybeRV, SumType, TypeArg, TypeBase, TypeEnum}; use crate::{Extension, Node}; /// Replace the dangling extension pointer in the [`CustomType`]s inside an @@ -30,7 +30,7 @@ pub fn resolve_op_types_extensions( match op { OpType::ExtensionOp(ext) => { for arg in ext.args_mut() { - resolve_term_exts(node, arg, extensions, used_extensions)?; + resolve_typearg_exts(node, arg, extensions, used_extensions)?; } resolve_signature_exts(node, ext.signature_mut(), extensions, used_extensions)?; } @@ -61,7 +61,7 @@ pub fn resolve_op_types_extensions( resolve_signature_exts(node, c.func_sig.body_mut(), extensions, used_extensions)?; resolve_signature_exts(node, &mut c.instantiation, extensions, used_extensions)?; for arg in &mut c.type_args { - resolve_term_exts(node, arg, extensions, used_extensions)?; + resolve_typearg_exts(node, arg, extensions, used_extensions)?; } } OpType::CallIndirect(c) => { @@ -74,7 +74,7 @@ pub fn resolve_op_types_extensions( resolve_signature_exts(node, lf.func_sig.body_mut(), extensions, used_extensions)?; resolve_signature_exts(node, &mut lf.instantiation, extensions, used_extensions)?; for arg in &mut lf.type_args { - resolve_term_exts(node, arg, extensions, used_extensions)?; + resolve_typearg_exts(node, arg, extensions, used_extensions)?; } } OpType::DFG(dfg) => { @@ -82,7 +82,7 @@ pub fn resolve_op_types_extensions( } OpType::OpaqueOp(op) => { for 
arg in op.args_mut() { - resolve_term_exts(node, arg, extensions, used_extensions)?; + resolve_typearg_exts(node, arg, extensions, used_extensions)?; } resolve_signature_exts(node, op.signature_mut(), extensions, used_extensions)?; } @@ -195,7 +195,7 @@ pub(super) fn resolve_custom_type_exts( used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { for arg in custom.args_mut() { - resolve_term_exts(node, arg, extensions, used_extensions)?; + resolve_typearg_exts(node, arg, extensions, used_extensions)?; } let ext_id = custom.extension(); @@ -211,42 +211,23 @@ pub(super) fn resolve_custom_type_exts( Ok(()) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a [`Term`]. +/// Update all weak Extension pointers in the [`CustomType`]s inside a type arg. /// /// Adds the extensions used in the type to the `used_extensions` registry. -pub(super) fn resolve_term_exts( +pub(super) fn resolve_typearg_exts( node: Option, - term: &mut Term, + arg: &mut TypeArg, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - match term { - Term::Runtime(ty) => resolve_type_exts(node, ty, extensions, used_extensions)?, - Term::List(children) - | Term::ListConcat(children) - | Term::Tuple(children) - | Term::TupleConcat(children) => { - for child in children.iter_mut() { - resolve_term_exts(node, child, extensions, used_extensions)?; + match arg { + TypeArg::Type { ty } => resolve_type_exts(node, ty, extensions, used_extensions)?, + TypeArg::Sequence { elems } => { + for elem in elems.iter_mut() { + resolve_typearg_exts(node, elem, extensions, used_extensions)?; } } - Term::ListType(item_type) => { - resolve_term_exts(node, item_type.as_mut(), extensions, used_extensions)?; - } - Term::TupleType(item_types) => { - resolve_term_exts(node, item_types.as_mut(), extensions, used_extensions)?; - } - Term::Variable(_) - | Term::RuntimeType(_) - | Term::StaticType - | Term::BoundedNatType(_) - | Term::StringType - | Term::BytesType - | Term::FloatType - | Term::BoundedNat(_) - | Term::String(_) - | Term::Bytes(_) - | Term::Float(_) => {} + _ => {} } Ok(()) } diff --git a/hugr-core/src/extension/simple_op.rs b/hugr-core/src/extension/simple_op.rs index 8685b63325..bf013ba5dc 100644 --- a/hugr-core/src/extension/simple_op.rs +++ b/hugr-core/src/extension/simple_op.rs @@ -308,10 +308,7 @@ impl From for OpType { mod test { use std::sync::Arc; - use crate::{ - const_extension_ids, type_row, - types::{Signature, Term}, - }; + use crate::{const_extension_ids, type_row, types::Signature}; use super::*; use lazy_static::lazy_static; @@ -396,7 +393,7 @@ mod test { assert_eq!(o.instantiate(&[]), Ok(o.clone())); assert_eq!( - o.instantiate(&[Term::from(1u64)]), + o.instantiate(&[TypeArg::BoundedNat { n: 1 }]), Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs)) ); } diff --git a/hugr-core/src/extension/type_def.rs b/hugr-core/src/extension/type_def.rs index b848c7528f..fceb336b2f 100644 --- a/hugr-core/src/extension/type_def.rs +++ b/hugr-core/src/extension/type_def.rs @@ -6,7 +6,7 @@ use super::{Extension, ExtensionId, SignatureError}; use crate::types::{CustomType, TypeName, least_upper_bound}; -use crate::types::type_param::{TypeArg, check_term_types}; +use crate::types::type_param::{TypeArg, check_type_args}; use crate::types::type_param::TypeParam; @@ -34,7 +34,7 @@ impl TypeDefBound { #[must_use] pub fn any() -> Self { TypeDefBound::Explicit { - bound: TypeBound::Linear, + bound: 
TypeBound::Any, } } @@ -79,7 +79,7 @@ pub struct TypeDef { impl TypeDef { /// Check provided type arguments are valid against parameters. pub fn check_args(&self, args: &[TypeArg]) -> Result<(), SignatureError> { - check_term_types(args, &self.params).map_err(SignatureError::TypeArgMismatch) + check_type_args(args, &self.params).map_err(SignatureError::TypeArgMismatch) } /// Check [`CustomType`] is a valid instantiation of this definition. @@ -102,7 +102,7 @@ impl TypeDef { )); } - check_term_types(custom.type_args(), &self.params)?; + check_type_args(custom.type_args(), &self.params)?; let calc_bound = self.bound(custom.args()); if calc_bound == custom.bound() { @@ -123,7 +123,7 @@ impl TypeDef { /// valid instances of the type parameters. pub fn instantiate(&self, args: impl Into>) -> Result { let args = args.into(); - check_term_types(&args, &self.params)?; + check_type_args(&args, &self.params)?; let bound = self.bound(&args); Ok(CustomType::new( self.name().clone(), @@ -142,12 +142,12 @@ impl TypeDef { let args: Vec<_> = args.iter().collect(); if indices.is_empty() { // Assume most general case - return TypeBound::Linear; + return TypeBound::Any; } least_upper_bound(indices.iter().map(|i| { let ta = args.get(*i); match ta { - Some(TypeArg::Runtime(s)) => s.least_upper_bound(), + Some(TypeArg::Type { ty: s }) => s.least_upper_bound(), _ => panic!("TypeArg index does not refer to a type."), } })) @@ -241,7 +241,7 @@ mod test { use crate::extension::SignatureError; use crate::extension::prelude::{qb_t, usize_t}; use crate::std_extensions::arithmetic::float_types::float64_type; - use crate::types::type_param::{TermTypeError, TypeParam}; + use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::{Signature, Type, TypeBound}; use super::{TypeDef, TypeDefBound}; @@ -250,7 +250,9 @@ mod test { fn test_instantiate_typedef() { let def = TypeDef { name: "MyType".into(), - params: vec![TypeParam::RuntimeType(TypeBound::Copyable)], + params: vec![TypeParam::Type { + b: TypeBound::Copyable, + }], extension: "MyRsrc".try_into().unwrap(), // Dummy extension. Will return `None` when trying to upgrade it into an `Arc`. 
extension_ref: Default::default(), @@ -258,9 +260,9 @@ mod test { bound: TypeDefBound::FromParams { indices: vec![0] }, }; let typ = Type::new_extension( - def.instantiate(vec![ - Type::new_function(Signature::new(vec![], vec![])).into(), - ]) + def.instantiate(vec![TypeArg::Type { + ty: Type::new_function(Signature::new(vec![], vec![])), + }]) .unwrap(), ); assert_eq!(typ.least_upper_bound(), TypeBound::Copyable); @@ -269,24 +271,27 @@ mod test { // And some bad arguments...firstly, wrong kind of TypeArg: assert_eq!( - def.instantiate([qb_t().into()]), + def.instantiate([TypeArg::Type { ty: qb_t() }]), Err(SignatureError::TypeArgMismatch( - TermTypeError::TypeMismatch { - term: Box::new(qb_t().into()), - type_: Box::new(TypeBound::Copyable.into()) + TypeArgError::TypeMismatch { + arg: TypeArg::Type { ty: qb_t() }, + param: TypeBound::Copyable.into() } )) ); // Too few arguments: assert_eq!( def.instantiate([]).unwrap_err(), - SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(0, 1)) + SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(0, 1)) ); // Too many arguments: assert_eq!( - def.instantiate([float64_type().into(), float64_type().into(),]) - .unwrap_err(), - SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(2, 1)) + def.instantiate([ + TypeArg::Type { ty: float64_type() }, + TypeArg::Type { ty: float64_type() }, + ]) + .unwrap_err(), + SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(2, 1)) ); } } diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 78bdd88390..f9398b0cf8 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -5,6 +5,7 @@ pub mod hugrmut; pub(crate) mod ident; pub mod internal; pub mod patch; +pub mod persistent; pub mod serialize; pub mod validate; pub mod views; @@ -41,7 +42,7 @@ use crate::{Direction, Node}; #[derive(Clone, Debug, PartialEq)] pub struct Hugr { /// The graph encoding the adjacency structure of the HUGR. - graph: MultiPortGraph, + graph: MultiPortGraph, /// The node hierarchy. hierarchy: Hierarchy, @@ -553,8 +554,8 @@ pub(crate) mod test { use crate::extension::prelude::bool_t; use crate::ops::OpaqueOp; use crate::ops::handle::NodeHandle; + use crate::test_file; use crate::types::Signature; - use crate::{Visibility, test_file}; use cool_asserts::assert_matches; use itertools::Either; use portgraph::LinkView; @@ -674,26 +675,6 @@ pub(crate) mod test { assert_matches!(&hugr, Ok(_)); } - #[test] - #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri - fn load_funcs_no_visibility() { - let hugr = Hugr::load( - BufReader::new(File::open(test_file!("hugr-no-visibility.hugr")).unwrap()), - None, - ) - .unwrap(); - - let [_mod, decl, defn] = hugr.nodes().take(3).collect_array().unwrap(); - assert_eq!( - hugr.get_optype(decl).as_func_decl().unwrap().visibility(), - &Visibility::Public - ); - assert_eq!( - hugr.get_optype(defn).as_func_defn().unwrap().visibility(), - &Visibility::Private - ); - } - fn hugr_failing_2262() -> Hugr { let sig = Signature::new(vec![bool_t(); 2], bool_t()); let mut mb = ModuleBuilder::new(); diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 74a6d1461d..0265acd59f 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -183,23 +183,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the root node is not in the graph. 
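To make the reverted `insert_hugr` signature concrete, here is a minimal usage sketch. The builder calls (`DFGBuilder`, `bool_t`, `finish_hugr_with_outputs`) and the `InsertionResult` fields are taken from code elsewhere in this patch; the `hugr::` facade paths and the `main` wrapper are assumptions for illustration only.

```rust
use hugr::HugrView;
use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr};
use hugr::extension::prelude::bool_t;
use hugr::hugr::HugrMut;
use hugr::types::Signature;

fn main() {
    // Two trivial DFG HUGRs, each passing a single bool straight through.
    let build_identity = || {
        let builder = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap();
        let [w] = builder.input_wires_arr();
        builder.finish_hugr_with_outputs([w]).unwrap()
    };
    let mut host = build_identity();
    let other = build_identity();

    // Insert the whole of `other` under the host's entrypoint node.
    let parent = host.entrypoint();
    let result = host.insert_hugr(parent, other);

    // `inserted_entrypoint` is the copy of `other`'s entrypoint in `host`;
    // `node_map` maps every inserted source node to its new location.
    println!("new entrypoint: {:?}", result.inserted_entrypoint);
    println!("inserted {} nodes", result.node_map.len());
}
```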
- fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult { - let region = other.entrypoint(); - Self::insert_region(self, root, other, region) - } - - /// Insert a sub-region of another hugr into this one, under a given parent node. - /// - /// # Panics - /// - /// - If the root node is not in the graph. - /// - If the `region` node is not in `other`. - fn insert_region( - &mut self, - root: Self::Node, - other: Hugr, - region: Node, - ) -> InsertionResult; + fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult; /// Copy another hugr into this one, under a given parent node. /// @@ -263,17 +247,15 @@ pub trait HugrMut: HugrMutInternals { ExtensionRegistry: Extend; } -/// Records the result of inserting a Hugr or view via [`HugrMut::insert_hugr`], -/// [`HugrMut::insert_from_view`], or [`HugrMut::insert_region`]. +/// Records the result of inserting a Hugr or view +/// via [`HugrMut::insert_hugr`] or [`HugrMut::insert_from_view`]. /// -/// Contains a map from the nodes in the source HUGR to the nodes in the target -/// HUGR, using their respective `Node` types. +/// Contains a map from the nodes in the source HUGR to the nodes in the +/// target HUGR, using their respective `Node` types. pub struct InsertionResult { - /// The node, after insertion, that was the root of the inserted Hugr. + /// The node, after insertion, that was the entrypoint of the inserted Hugr. /// - /// That is, the value in [`InsertionResult::node_map`] under the key that - /// was the the `region` passed to [`HugrMut::insert_region`] or the - /// [`HugrView::entrypoint`] in the other cases. + /// That is, the value in [`InsertionResult::node_map`] under the key that was the [`HugrView::entrypoint`]. pub inserted_entrypoint: TargetN, /// Map from nodes in the Hugr/view that was inserted, to their new /// positions in the Hugr into which said was inserted. @@ -412,14 +394,17 @@ impl HugrMut for Hugr { (src_port, dst_port) } - fn insert_region( + fn insert_hugr( &mut self, root: Self::Node, mut other: Hugr, - region: Node, ) -> InsertionResult { - let node_map = insert_hugr_internal(self, &other, other.descendants(region), |&n| { - if n == region { Some(root) } else { None } + let node_map = insert_hugr_internal(self, &other, other.entry_descendants(), |&n| { + if n == other.entrypoint() { + Some(root) + } else { + None + } }); // Merge the extension sets. self.extensions.extend(other.extensions()); @@ -435,7 +420,7 @@ impl HugrMut for Hugr { self.metadata.set(new_node_pg, meta); } InsertionResult { - inserted_entrypoint: node_map[®ion], + inserted_entrypoint: node_map[&other.entrypoint()], node_map, } } diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index bb1a77f423..523ddcb1b1 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -21,7 +21,7 @@ use crate::ops::handle::NodeHandle; /// view. pub trait HugrInternals { /// The portgraph graph structure returned by [`HugrInternals::region_portgraph`]. - type RegionPortgraph<'p>: LinkView + Clone + 'p + type RegionPortgraph<'p>: LinkView + Clone + 'p where Self: 'p; @@ -109,7 +109,7 @@ impl PortgraphNodeMap for std::collections::HashMap { impl HugrInternals for Hugr { type RegionPortgraph<'p> - = &'p MultiPortGraph + = &'p MultiPortGraph where Self: 'p; @@ -390,22 +390,6 @@ impl HugrMutInternals for Hugr { } } -impl Hugr { - /// Consumes the HUGR and return a flat portgraph view of the region rooted - /// at `parent`. 
- #[inline] - pub fn into_region_portgraph( - self, - parent: Node, - ) -> portgraph::view::FlatRegion<'static, MultiPortGraph> { - let root = parent.into_portgraph(); - let Self { - graph, hierarchy, .. - } = self; - portgraph::view::FlatRegion::new_without_root(graph, hierarchy, root) - } -} - #[cfg(test)] mod test { use crate::{ diff --git a/hugr-core/src/hugr/patch/inline_call.rs b/hugr-core/src/hugr/patch/inline_call.rs index a4d8383847..40eac06e0e 100644 --- a/hugr-core/src/hugr/patch/inline_call.rs +++ b/hugr-core/src/hugr/patch/inline_call.rs @@ -291,32 +291,29 @@ mod test { fn test_polymorphic() -> Result<(), Box> { let tuple_ty = Type::new_tuple(vec![usize_t(); 2]); let mut fb = FunctionBuilder::new("mkpair", Signature::new(usize_t(), tuple_ty.clone()))?; - let helper = { - let mut mb = fb.module_root_builder(); - let fb2 = mb.define_function( - "id", - PolyFuncType::new( - [TypeBound::Copyable.into()], - Signature::new_endo(Type::new_var_use(0, TypeBound::Copyable)), - ), - )?; - let inps = fb2.input_wires(); - fb2.finish_with_outputs(inps)? - }; - let call1 = fb.call(helper.handle(), &[usize_t().into()], fb.input_wires())?; + let inner = fb.define_function( + "id", + PolyFuncType::new( + [TypeBound::Copyable.into()], + Signature::new_endo(Type::new_var_use(0, TypeBound::Copyable)), + ), + )?; + let inps = inner.input_wires(); + let inner = inner.finish_with_outputs(inps)?; + let call1 = fb.call(inner.handle(), &[usize_t().into()], fb.input_wires())?; let [call1_out] = call1.outputs_arr(); let tup = fb.make_tuple([call1_out, call1_out])?; - let call2 = fb.call(helper.handle(), &[tuple_ty.into()], [tup])?; + let call2 = fb.call(inner.handle(), &[tuple_ty.into()], [tup])?; let mut hugr = fb.finish_hugr_with_outputs(call2.outputs()).unwrap(); assert_eq!( - hugr.output_neighbours(helper.node()).collect::>(), + hugr.output_neighbours(inner.node()).collect::>(), [call1.node(), call2.node()] ); hugr.apply_patch(InlineCall::new(call1.node()))?; assert_eq!( - hugr.output_neighbours(helper.node()).collect::>(), + hugr.output_neighbours(inner.node()).collect::>(), [call2.node()] ); assert!(hugr.get_optype(call1.node()).is_dfg()); diff --git a/hugr-core/src/hugr/patch/outline_cfg.rs b/hugr-core/src/hugr/patch/outline_cfg.rs index 081ba24ea1..e0d5a27850 100644 --- a/hugr-core/src/hugr/patch/outline_cfg.rs +++ b/hugr-core/src/hugr/patch/outline_cfg.rs @@ -48,7 +48,7 @@ impl OutlineCfg { }; let o = h.get_optype(cfg_n); let OpType::CFG(_) = o else { - return Err(OutlineCfgError::ParentNotCfg(cfg_n, Box::new(o.clone()))); + return Err(OutlineCfgError::ParentNotCfg(cfg_n, o.clone())); }; let cfg_entry = h.children(cfg_n).next().unwrap(); let mut entry = None; @@ -215,7 +215,7 @@ pub enum OutlineCfgError { NotSiblings, /// The parent node was not a CFG node #[error("The parent node {0} was not a CFG but a {1}")] - ParentNotCfg(Node, Box), + ParentNotCfg(Node, OpType), /// Multiple blocks had incoming edges #[error("Multiple blocks had predecessors outside the set - at least {0} and {1}")] MultipleEntryNodes(Node, Node), diff --git a/hugr-core/src/hugr/patch/peel_loop.rs b/hugr-core/src/hugr/patch/peel_loop.rs index ccb9218283..9cf61290b6 100644 --- a/hugr-core/src/hugr/patch/peel_loop.rs +++ b/hugr-core/src/hugr/patch/peel_loop.rs @@ -135,7 +135,7 @@ mod test { use crate::builder::test::simple_dfg_hugr; use crate::builder::{ - Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, + Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, }; use 
crate::extension::prelude::{bool_t, usize_t}; use crate::ops::{OpTag, OpTrait, Tag, TailLoop, handle::NodeHandle}; @@ -165,13 +165,8 @@ mod test { #[test] fn peel_loop_incoming_edges() { let i32_t = || INT_TYPES[5].clone(); - let mut fb = FunctionBuilder::new( - "main", - Signature::new(vec![bool_t(), usize_t(), i32_t()], usize_t()), - ) - .unwrap(); - let helper = fb - .module_root_builder() + let mut mb = crate::builder::ModuleBuilder::new(); + let helper = mb .declare( "helper", Signature::new( @@ -181,6 +176,12 @@ mod test { .into(), ) .unwrap(); + let mut fb = mb + .define_function( + "main", + Signature::new(vec![bool_t(), usize_t(), i32_t()], usize_t()), + ) + .unwrap(); let [b, u, i] = fb.input_wires_arr(); let (tl, call) = { let mut tlb = fb @@ -196,7 +197,8 @@ mod test { let [pred, other] = c.outputs_arr(); (tlb.finish_with_outputs(pred, [other]).unwrap(), c.node()) }; - let mut h = fb.finish_hugr_with_outputs(tl.outputs()).unwrap(); + let _ = fb.finish_with_outputs(tl.outputs()).unwrap(); + let mut h = mb.finish_hugr().unwrap(); h.apply_patch(PeelTailLoop::new(tl.node())).unwrap(); h.validate().unwrap(); diff --git a/hugr-core/src/hugr/patch/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs index 46e5bde205..8fb22febda 100644 --- a/hugr-core/src/hugr/patch/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -58,12 +58,12 @@ impl SimpleReplacement { .inner_function_type() .ok_or(InvalidReplacement::InvalidDataflowGraph { node: replacement.entrypoint(), - op: Box::new(replacement.get_optype(replacement.entrypoint()).to_owned()), + op: replacement.get_optype(replacement.entrypoint()).to_owned(), })?; if subgraph_sig != repl_sig { return Err(InvalidReplacement::InvalidSignature { - expected: Box::new(subgraph_sig), - actual: Some(Box::new(repl_sig.into_owned())), + expected: subgraph_sig, + actual: Some(repl_sig.into_owned()), }); } Ok(Self { @@ -126,16 +126,11 @@ impl SimpleReplacement { /// of `self`. /// /// The returned port will be in `replacement`, unless the wire in the - /// replacement is empty and `boundary` is [`BoundaryMode::SnapToHost`] (the - /// default), in which case it will be another `host` port. If - /// [`BoundaryMode::IncludeIO`] is passed, the returned port will always - /// be in `replacement` even if it is invalid (i.e. it is an IO node in - /// the replacement). + /// replacement is empty, in which case it will another `host` port. pub fn linked_replacement_output( &self, port: impl Into>, host: &impl HugrView, - boundary: BoundaryMode, ) -> Option> { let HostPort(node, port) = port.into(); let pos = self @@ -144,7 +139,7 @@ impl SimpleReplacement { .iter() .position(move |&(n, p)| host.linked_inputs(n, p).contains(&(node, port)))?; - Some(self.linked_replacement_output_by_position(pos, host, boundary)) + Some(self.linked_replacement_output_by_position(pos, host)) } /// The outgoing port linked to the i-th output boundary edge of `subgraph`. @@ -155,7 +150,6 @@ impl SimpleReplacement { &self, pos: usize, host: &impl HugrView, - boundary: BoundaryMode, ) -> BoundaryPort { debug_assert!(pos < self.subgraph().signature(host).output_count()); @@ -166,7 +160,7 @@ impl SimpleReplacement { .single_linked_output(repl_out, pos) .expect("valid dfg wire"); - if out_node != repl_inp || boundary == BoundaryMode::IncludeIO { + if out_node != repl_inp { BoundaryPort::Replacement(out_node, out_port) } else { let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()] @@ -213,16 +207,11 @@ impl SimpleReplacement { /// of `self`. 
/// /// The returned ports will be in `replacement`, unless the wires in the - /// replacement are empty and `boundary` is [`BoundaryMode::SnapToHost`] - /// (the default), in which case they will be other `host` ports. If - /// [`BoundaryMode::IncludeIO`] is passed, the returned ports will - /// always be in `replacement` even if they are invalid (i.e. they are - /// an IO node in the replacement). + /// replacement are empty, in which case they are other `host` ports. pub fn linked_replacement_inputs<'a>( &'a self, port: impl Into>, host: &'a impl HugrView, - boundary: BoundaryMode, ) -> impl Iterator> + 'a { let HostPort(node, port) = port.into(); let positions = self @@ -234,16 +223,18 @@ impl SimpleReplacement { host.single_linked_output(n, p).expect("valid dfg wire") == (node, port) }); - positions - .flat_map(move |pos| self.linked_replacement_inputs_by_position(pos, host, boundary)) + positions.flat_map(|pos| self.linked_replacement_inputs_by_position(pos, host)) } /// The incoming ports linked to the i-th input boundary edge of `subgraph`. + /// + /// The ports will be in `replacement` for all endpoints of the i-th input + /// wire that are not the output node of `replacement` and be in `host` + /// otherwise. fn linked_replacement_inputs_by_position( &self, pos: usize, host: &impl HugrView, - boundary: BoundaryMode, ) -> impl Iterator> { debug_assert!(pos < self.subgraph().signature(host).input_count()); @@ -251,7 +242,7 @@ impl SimpleReplacement { self.replacement .linked_inputs(repl_inp, pos) .flat_map(move |(in_node, in_port)| { - if in_node != repl_out || boundary == BoundaryMode::IncludeIO { + if in_node != repl_out { Either::Left(std::iter::once(BoundaryPort::Replacement(in_node, in_port))) } else { let (out_node, out_port) = self.subgraph.outgoing_ports()[in_port.index()]; @@ -325,7 +316,7 @@ impl SimpleReplacement { subgraph_outgoing_ports .enumerate() .flat_map(|(pos, subg_np)| { - self.linked_replacement_inputs_by_position(pos, host, BoundaryMode::SnapToHost) + self.linked_replacement_inputs_by_position(pos, host) .filter_map(move |np| Some((np.as_replacement()?, subg_np))) }) .map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| { @@ -368,7 +359,7 @@ impl SimpleReplacement { .enumerate() .filter_map(|(pos, subg_all)| { let np = self - .linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost) + .linked_replacement_output_by_position(pos, host) .as_replacement()?; Some((np, subg_all)) }) @@ -415,7 +406,7 @@ impl SimpleReplacement { .enumerate() .filter_map(|(pos, subg_all)| { Some(( - self.linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost) + self.linked_replacement_output_by_position(pos, host) .as_host()?, subg_all, )) @@ -509,25 +500,27 @@ impl SimpleReplacement { /// Map the host nodes in `self` according to `node_map`. /// /// `node_map` must map nodes in the current HUGR of the subgraph to - /// its equivalent nodes in some `new_host`. + /// its equivalent nodes in some `new_hugr`. /// /// This converts a replacement that acts on nodes of type `HostNode` to - /// a replacement that acts on `new_host`, with nodes of type `N`. - pub fn map_host_nodes( + /// a replacement that acts on `new_hugr`, with nodes of type `N`. + /// + /// This does not check convexity. It is up to the caller to ensure that + /// the mapped replacement obtained from this applies on a convex subgraph + /// of the new HUGR. 
+ pub(crate) fn map_host_nodes( &self, node_map: impl Fn(HostNode) -> N, - new_host: &impl HugrView, - ) -> Result, InvalidReplacement> { + ) -> SimpleReplacement { let Self { subgraph, replacement, } = self; let subgraph = subgraph.map_nodes(node_map); - SimpleReplacement::try_new(subgraph, new_host, replacement.clone()) + SimpleReplacement::new_unchecked(subgraph, replacement.clone()) } - /// Allows to get the [Self::invalidated_nodes] without requiring a - /// [HugrView]. + /// Allows to get the [Self::invalidated_nodes] without requiring a [HugrView]. pub fn invalidation_set(&self) -> impl Iterator { self.subgraph.nodes().iter().copied() } @@ -550,24 +543,6 @@ impl PatchVerification for SimpleReplacement { } } -/// In [`SimpleReplacement::replacement`], IO nodes marking the boundary will -/// not be valid nodes in the host after the replacement is applied. -/// -/// This enum allows specifying whether these invalid nodes on the boundary -/// should be returned or should be resolved to valid nodes in the host. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] -pub enum BoundaryMode { - /// Only consider nodes that are valid after the replacement is applied. - /// - /// This means that nodes in hosts may be returned in places where nodes in - /// the replacement would be typically expected. - #[default] - SnapToHost, - /// Include all nodes, including potentially invalid ones (inputs and - /// outputs of replacements). - IncludeIO, -} - /// Result of applying a [`SimpleReplacement`]. pub struct Outcome { /// Map from Node in replacement to corresponding Node in the result Hugr @@ -676,11 +651,11 @@ pub(in crate::hugr::patch) mod test { use crate::builder::test::n_identity; use crate::builder::{ - BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, - ModuleBuilder, endo_sig, inout_sig, + BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + HugrBuilder, ModuleBuilder, endo_sig, inout_sig, }; use crate::extension::prelude::{bool_t, qb_t}; - use crate::hugr::patch::simple_replace::{BoundaryMode, Outcome}; + use crate::hugr::patch::simple_replace::Outcome; use crate::hugr::patch::{BoundaryPort, HostPort, PatchVerification, ReplacementPort}; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Patch}; @@ -1173,11 +1148,7 @@ pub(in crate::hugr::patch) mod test { // Test linked_replacement_inputs with empty replacement let replacement_inputs: Vec<_> = repl - .linked_replacement_inputs( - (inp, OutgoingPort::from(0)), - &hugr, - BoundaryMode::SnapToHost, - ) + .linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr) .collect(); assert_eq!( @@ -1190,12 +1161,8 @@ pub(in crate::hugr::patch) mod test { // Test linked_replacement_output with empty replacement let replacement_output = (0..4) .map(|i| { - repl.linked_replacement_output( - (out, IncomingPort::from(i)), - &hugr, - BoundaryMode::SnapToHost, - ) - .unwrap() + repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr) + .unwrap() }) .collect_vec(); @@ -1227,11 +1194,7 @@ pub(in crate::hugr::patch) mod test { }; let replacement_inputs: Vec<_> = repl - .linked_replacement_inputs( - (inp, OutgoingPort::from(0)), - &hugr, - BoundaryMode::SnapToHost, - ) + .linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr) .collect(); assert_eq!( @@ -1243,12 +1206,8 @@ pub(in crate::hugr::patch) mod test { let replacement_output = (0..4) .map(|i| { - repl.linked_replacement_output( - (out, IncomingPort::from(i)), - &hugr, - 
BoundaryMode::SnapToHost, - ) - .unwrap() + repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr) + .unwrap() }) .collect_vec(); @@ -1285,11 +1244,7 @@ pub(in crate::hugr::patch) mod test { }; let replacement_inputs: Vec<_> = repl - .linked_replacement_inputs( - (inp, OutgoingPort::from(0)), - &hugr, - BoundaryMode::SnapToHost, - ) + .linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr) .collect(); assert_eq!( @@ -1305,12 +1260,8 @@ pub(in crate::hugr::patch) mod test { let replacement_output = (0..4) .map(|i| { - repl.linked_replacement_output( - (out, IncomingPort::from(i)), - &hugr, - BoundaryMode::SnapToHost, - ) - .unwrap() + repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr) + .unwrap() }) .collect_vec(); diff --git a/hugr-persistent/src/persistent_hugr.rs b/hugr-core/src/hugr/persistent.rs similarity index 59% rename from hugr-persistent/src/persistent_hugr.rs rename to hugr-core/src/hugr/persistent.rs index bdca32ec1f..d2813e11b6 100644 --- a/hugr-persistent/src/persistent_hugr.rs +++ b/hugr-core/src/hugr/persistent.rs @@ -1,23 +1,98 @@ +//! Persistent data structure for HUGR mutations. +//! +//! This module provides a persistent data structure [`PersistentHugr`] that +//! implements [`crate::HugrView`]; mutations to the data are stored +//! persistently as a set of [`Commit`]s along with the dependencies between the +//! commits. +//! +//! As a result of persistency, the entire mutation history of a HUGR can be +//! traversed and references to previous versions of the data remain valid even +//! as the HUGR graph is "mutated" by applying patches: the patches are in +//! effect added to the history as new commits. +//! +//! The data structure underlying [`PersistentHugr`], which stores the history +//! of all commits, is [`CommitStateSpace`]. Multiple [`PersistentHugr`] can be +//! stored within a single [`CommitStateSpace`], which allows for the efficient +//! exploration of the space of all possible graph rewrites. +//! +//! ## Overlapping commits +//! +//! In general, [`CommitStateSpace`] may contain overlapping commits. Such +//! mutations are mutually exclusive as they modify the same nodes. It is +//! therefore not possible to apply all commits in a [`CommitStateSpace`] +//! simultaneously. A [`PersistentHugr`] on the other hand always corresponds to +//! a subgraph of a [`CommitStateSpace`] that is guaranteed to contain only +//! non-overlapping, compatible commits. By applying all commits in a +//! [`PersistentHugr`], we can materialize a [`Hugr`]. Traversing the +//! materialized HUGR is equivalent to using the [`crate::HugrView`] +//! implementation of the corresponding [`PersistentHugr`]. +//! +//! ## Summary of data types +//! +//! - [`Commit`] A modification to a [`Hugr`] (currently a +//! [`SimpleReplacement`]) that forms the atomic unit of change for a +//! [`PersistentHugr`] (like a commit in git). This is a reference-counted +//! value that is cheap to clone and will be freed when the last reference is +//! dropped. +//! - [`PersistentHugr`] A data structure that implements [`crate::HugrView`] +//! and can be used as a drop-in replacement for a [`crate::Hugr`] for +//! read-only access and mutations through the [`PatchVerification`] and +//! [`Patch`] traits. Mutations are stored as a history of commits. Unlike +//! [`CommitStateSpace`], it maintains the invariant that all contained +//! commits are compatible with eachother. +//! - [`CommitStateSpace`] Stores commits, recording the dependencies between +//! them. 
Includes the base HUGR and any number of possibly incompatible +//! (overlapping) commits. Unlike a [`PersistentHugr`], a state space can +//! contain mutually exclusive commits. +//! +//! ## Usage +//! +//! A [`PersistentHugr`] can be created from a base HUGR using +//! [`PersistentHugr::with_base`]. Replacements can then be applied to it +//! using [`PersistentHugr::add_replacement`]. Alternatively, if you already +//! have a populated state space, use [`PersistentHugr::try_new`] to create a +//! new HUGR with those commits. +//! +//! Add a sequence of commits to a state space by merging a [`PersistentHugr`] +//! into it using [`CommitStateSpace::extend`] or directly using +//! [`CommitStateSpace::try_add_commit`]. +//! +//! To obtain a [`PersistentHugr`] from your state space, use +//! [`CommitStateSpace::try_extract_hugr`]. A [`PersistentHugr`] can always be +//! materialized into a [`Hugr`] type using [`PersistentHugr::to_hugr`]. +//! +//! +//! [`PatchVerification`]: crate::hugr::patch::PatchVerification + +mod parents_view; +mod resolver; +mod state_space; +mod trait_impls; +pub mod walker; + +pub use walker::{PinnedWire, Walker}; + use std::{ - collections::{BTreeSet, HashMap}, + collections::{BTreeSet, HashMap, VecDeque}, mem, vec, }; use delegate::delegate; use derive_more::derive::From; -use hugr_core::{ - Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement, - hugr::patch::{Patch, simple_replace}, -}; use itertools::{Either, Itertools}; use relrc::RelRc; +use state_space::CommitData; +pub use state_space::{CommitId, CommitStateSpace, InvalidCommit, PatchNode}; + +pub use resolver::PointerEqResolver; use crate::{ - CommitData, CommitId, CommitStateSpace, InvalidCommit, PatchNode, PersistentReplacement, - Resolver, + Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement, + hugr::patch::{Patch, simple_replace}, }; -pub mod serial; +/// A replacement operation that can be applied to a [`PersistentHugr`]. +pub type PersistentReplacement = SimpleReplacement; /// A patch that can be applied to a [`PersistentHugr`] or a /// [`CommitStateSpace`] as an atomic commit. @@ -38,49 +113,29 @@ impl Commit { /// Requires a reference to the commit state space that the nodes in /// `replacement` refer to. /// - /// Use [`Self::try_new`] instead if the parents of the commit cannot be - /// inferred from the invalidation set of `replacement` alone. - /// /// The replacement must act on a non-empty subgraph, otherwise this /// function will return an [`InvalidCommit::EmptyReplacement`] error. /// /// If any of the parents of the replacement are not in the commit state /// space, this function will return an [`InvalidCommit::UnknownParent`] /// error. - pub fn try_from_replacement( - replacement: PersistentReplacement, - graph: &CommitStateSpace, - ) -> Result { - Self::try_new(replacement, [], graph) - } - - /// Create a new commit - /// - /// Requires a reference to the commit state space that the nodes in - /// `replacement` refer to. - /// - /// The returned commit will correspond to the application of `replacement` - /// and will be the child of the commits in `parents` as well as of all - /// the commits in the invalidation set of `replacement`. - /// - /// The replacement must act on a non-empty subgraph, otherwise this - /// function will return an [`InvalidCommit::EmptyReplacement`] error. 
- /// If any of the parents of the replacement are not in the commit state - /// space, this function will return an [`InvalidCommit::UnknownParent`] - /// error. - pub fn try_new( + pub fn try_from_replacement( replacement: PersistentReplacement, - parents: impl IntoIterator, - graph: &CommitStateSpace, + graph: &CommitStateSpace, ) -> Result { if replacement.subgraph().nodes().is_empty() { return Err(InvalidCommit::EmptyReplacement); } - let repl_parents = get_parent_commits(&replacement, graph)?; - let parents = parents - .into_iter() - .chain(repl_parents) - .unique_by(|p| p.as_ptr()); + let parent_ids = replacement.invalidation_set().map(|n| n.0).unique(); + let parents = parent_ids + .map(|id| { + if graph.contains_id(id) { + Ok(graph.get_commit(id).clone()) + } else { + Err(InvalidCommit::UnknownParent(id)) + } + }) + .collect::, _>>()?; let rc = RelRc::with_parents( replacement.into(), parents.into_iter().map(|p| (p.into(), ())), @@ -88,7 +143,7 @@ impl Commit { Ok(Self(rc)) } - pub(crate) fn as_relrc(&self) -> &RelRc { + fn as_relrc(&self) -> &RelRc { &self.0 } @@ -132,8 +187,8 @@ impl Commit { delegate! { to self.0 { - pub(crate) fn value(&self) -> &CommitData; - pub(crate) fn as_ptr(&self) -> *const relrc::node::InnerData; + fn value(&self) -> &CommitData; + fn as_ptr(&self) -> *const relrc::node::InnerData; } } @@ -187,13 +242,12 @@ impl<'a> From<&'a RelRc> for &'a Commit { /// /// ## Supported access and mutation /// -/// [`PersistentHugr`] implements [`HugrView`], so that it can used as +/// [`PersistentHugr`] implements [`crate::HugrView`], so that it can used as /// a drop-in substitute for a Hugr wherever read-only access is required. It -/// does not implement [`HugrMut`](hugr_core::hugr::hugrmut::HugrMut), however. -/// Mutations must be performed by applying patches (see -/// [`PatchVerification`](hugr_core::hugr::patch::PatchVerification) -/// and [`Patch`]). Currently, only [`SimpleReplacement`] patches are supported. -/// You can use [`Self::add_replacement`] to add a patch to `self`, or use the +/// does not implement [`HugrMut`](crate::hugr::HugrMut), however. Mutations +/// must be performed by applying patches (see [`PatchVerification`] and +/// [`Patch`]). Currently, only [`SimpleReplacement`] patches are supported. You +/// can use [`Self::add_replacement`] to add a patch to `self`, or use the /// aforementioned patch traits. /// /// ## Patches, commits and history @@ -213,16 +267,19 @@ impl<'a> From<&'a RelRc> for &'a Commit { /// /// Currently, only patches that apply to subgraphs within dataflow regions /// are supported. +/// +/// [`PatchVerification`]: crate::hugr::patch::PatchVerification + #[derive(Clone, Debug)] -pub struct PersistentHugr { +pub struct PersistentHugr { /// The state space of all commits. /// /// Invariant: all commits are "compatible", meaning that no two patches /// invalidate the same node. - state_space: CommitStateSpace, + state_space: CommitStateSpace, } -impl PersistentHugr { +impl PersistentHugr { /// Create a [`PersistentHugr`] with `hugr` as its base HUGR. /// /// All replacements added in the future will apply on top of `hugr`. @@ -252,6 +309,13 @@ impl PersistentHugr { graph.try_extract_hugr(graph.all_commit_ids()) } + /// Construct a [`PersistentHugr`] from a [`CommitStateSpace`]. + /// + /// Does not check that the commits are compatible. + fn from_state_space_unsafe(state_space: CommitStateSpace) -> Self { + Self { state_space } + } + /// Add a replacement to `self`. 
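As a companion to the usage notes in the module documentation above, the `with_base` / `add_replacement` / `to_hugr` flow can be sketched as below. The `hugr::hugr::persistent` paths and the helper that takes a ready-made `PersistentReplacement` are illustrative assumptions, not part of this change.

```rust
use hugr::Hugr;
use hugr::hugr::persistent::{PersistentHugr, PersistentReplacement};

/// Sketch: record one rewrite persistently and materialize the result.
fn apply_persistently(base: Hugr, replacement: PersistentReplacement) -> Hugr {
    // `base` becomes the root commit of a new history.
    let mut hugr = PersistentHugr::with_base(base);
    // The replacement's `PatchNode`s must refer to commits already in the
    // history (here, only the base commit); it is recorded as a new commit.
    hugr.add_replacement(replacement);
    // Apply every commit in the history to obtain a plain Hugr.
    hugr.to_hugr()
}
```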
/// /// The effect of this is equivalent to applying `replacement` to the @@ -331,22 +395,13 @@ impl PersistentHugr { } Ok(commit_id.expect("new_commits cannot be empty")) } -} - -impl PersistentHugr { - /// Construct a [`PersistentHugr`] from a [`CommitStateSpace`]. - /// - /// Does not check that the commits are compatible. - pub(crate) fn from_state_space_unsafe(state_space: CommitStateSpace) -> Self { - Self { state_space } - } /// Convert this `PersistentHugr` to a materialized Hugr by applying all /// commits in `self`. /// /// This operation may be expensive and should be avoided in /// performance-critical paths. For read-only views into the data, rely - /// instead on the [`HugrView`] implementation when possible. + /// instead on the [`crate::HugrView`] implementation when possible. pub fn to_hugr(&self) -> Hugr { self.apply_all().0 } @@ -366,9 +421,7 @@ impl PersistentHugr { continue; }; - let repl = repl - .map_host_nodes(|n| node_map[&n], &hugr) - .expect("invalid replacement"); + let repl = repl.map_host_nodes(|n| node_map[&n]); let simple_replace::Outcome { node_map: new_node_map, @@ -399,12 +452,12 @@ impl PersistentHugr { } /// Get a reference to the underlying state space of `self`. - pub fn as_state_space(&self) -> &CommitStateSpace { + pub fn as_state_space(&self) -> &CommitStateSpace { &self.state_space } /// Convert `self` into its underlying [`CommitStateSpace`]. - pub fn into_state_space(self) -> CommitStateSpace { + pub fn into_state_space(self) -> CommitStateSpace { self.state_space } @@ -414,14 +467,68 @@ impl PersistentHugr { /// /// Panics if `node` is not in `self` (in particular if it is deleted) or if /// `port` is not a value port in `node`. - pub(crate) fn single_outgoing_port( + fn get_single_outgoing_port( &self, node: PatchNode, port: impl Into, ) -> (PatchNode, OutgoingPort) { - let w = self.get_wire(node, port.into()); - w.single_outgoing_port(self) - .expect("found invalid dfg wire") + let mut in_port = port.into(); + let PatchNode(commit_id, mut in_node) = node; + + assert!(self.is_value_port(node, in_port), "not a dataflow wire"); + assert!(self.contains_node(node), "node not in self"); + + let hugr = self.commit_hugr(commit_id); + let (mut out_node, mut out_port) = hugr + .single_linked_output(in_node, in_port) + .map(|(n, p)| (PatchNode(commit_id, n), p)) + .expect("invalid HUGR"); + + // invariant: (out_node, out_port) -> (in_node, in_port) is a boundary + // edge, i.e. 
it never is the case that both are deleted by the same + // child commit + loop { + let commit_id = out_node.0; + + if let Some(deleted_by) = self.find_deleting_commit(out_node) { + (out_node, out_port) = self + .state_space + .linked_child_output(PatchNode(commit_id, in_node), in_port, deleted_by) + .expect("valid boundary edge"); + // update (in_node, in_port) + (in_node, in_port) = { + let new_commit_id = out_node.0; + let hugr = self.commit_hugr(new_commit_id); + hugr.linked_inputs(out_node.1, out_port) + .find(|&(n, _)| { + self.find_deleting_commit(PatchNode(commit_id, n)).is_none() + }) + .expect("out_node is connected to output node (which is never deleted)") + }; + } else if self + .replacement(commit_id) + .is_some_and(|repl| repl.get_replacement_io()[0] == out_node.1) + { + // out_node is an input node + (out_node, out_port) = self + .as_state_space() + .linked_parent_input(PatchNode(commit_id, in_node), in_port); + // update (in_node, in_port) + (in_node, in_port) = { + let new_commit_id = out_node.0; + let hugr = self.commit_hugr(new_commit_id); + hugr.linked_inputs(out_node.1, out_port) + .find(|&(n, _)| { + self.find_deleting_commit(PatchNode(new_commit_id, n)) + == Some(commit_id) + }) + .expect("boundary edge must connect out_node to deleted node") + }; + } else { + // valid outgoing node! + return (out_node, out_port); + } + } } /// All incoming ports that the given outgoing port is attached to. @@ -430,14 +537,99 @@ impl PersistentHugr { /// /// Panics if `out_node` is not in `self` (in particular if it is deleted) /// or if `out_port` is not a value port in `out_node`. - pub(crate) fn all_incoming_ports( + fn get_all_incoming_ports( &self, out_node: PatchNode, out_port: OutgoingPort, ) -> impl Iterator { - let w = self.get_wire(out_node, out_port); - w.into_all_ports(self, Direction::Incoming) - .map(|(node, port)| (node, port.as_incoming().unwrap())) + assert!( + self.is_value_port(out_node, out_port), + "not a dataflow wire" + ); + assert!(self.contains_node(out_node), "node not in self"); + + let mut visited = BTreeSet::new(); + // enqueue the outport and initialise the set of valid incoming ports + // to the valid incoming ports in this commit + let mut queue = VecDeque::from([(out_node, out_port)]); + let mut valid_incoming_ports = BTreeSet::from_iter( + self.commit_hugr(out_node.0) + .linked_inputs(out_node.1, out_port) + .map(|(in_node, in_port)| (PatchNode(out_node.0, in_node), in_port)) + .filter(|(in_node, _)| self.contains_node(*in_node)), + ); + + // A simple BFS across the commit history to find all equivalent incoming ports. 
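+        // Each queue entry is an outgoing port equivalent to the original
+        // (out_node, out_port). For every entry we follow wires that reach
+        // the replacement's output node up into the parent commits, and
+        // wires whose target is deleted by a child commit down into that
+        // child; any linked input found on the other side that still exists
+        // in `self` is recorded as a valid incoming port, and the outgoing
+        // port it is attached to is enqueued to continue the search.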
+ while let Some((out_node, out_port)) = queue.pop_front() { + if !visited.insert((out_node, out_port)) { + continue; + } + let commit_id = out_node.0; + let hugr = self.commit_hugr(commit_id); + let out_deleted_by = self.find_deleting_commit(out_node); + let curr_repl_out = { + let repl = self.replacement(commit_id); + repl.map(|r| r.get_replacement_io()[1]) + }; + // incoming ports are of interest to us if + // (i) they are connected to the output of a replacement (then there will be a + // linked port in a parent commit), or + // (ii) they are deleted by a child commit and are not equal to the out_node + // (then there will be a linked port in a child commit) + let is_linked_to_output = curr_repl_out.is_some_and(|curr_repl_out| { + hugr.linked_inputs(out_node.1, out_port) + .any(|(in_node, _)| in_node == curr_repl_out) + }); + + let deleted_by_child: BTreeSet<_> = hugr + .linked_inputs(out_node.1, out_port) + .filter(|(in_node, _)| Some(in_node) != curr_repl_out.as_ref()) + .filter_map(|(in_node, _)| { + self.find_deleting_commit(PatchNode(commit_id, in_node)) + .filter(|other_deleted_by| + // (out_node, out_port) -> (in_node, in_port) is a boundary edge + // into the child commit `other_deleted_by` + (Some(other_deleted_by) != out_deleted_by.as_ref())) + }) + .collect(); + + // Convert an incoming port to the unique outgoing port that it is linked to + let to_outgoing_port = |(PatchNode(commit_id, in_node), in_port)| { + let hugr = self.commit_hugr(commit_id); + let (out_node, out_port) = hugr + .single_linked_output(in_node, in_port) + .expect("valid dfg wire"); + (PatchNode(commit_id, out_node), out_port) + }; + + if is_linked_to_output { + // Traverse boundary to parent(s) + let new_ins = self + .as_state_space() + .linked_parent_outputs(out_node, out_port); + for (in_node, in_port) in new_ins { + if self.contains_node(in_node) { + valid_incoming_ports.insert((in_node, in_port)); + } + queue.push_back(to_outgoing_port((in_node, in_port))); + } + } + + for child in deleted_by_child { + // Traverse boundary to `child` + let new_ins = self + .as_state_space() + .linked_child_inputs(out_node, out_port, child); + for (in_node, in_port) in new_ins { + if self.contains_node(in_node) { + valid_incoming_ports.insert((in_node, in_port)); + } + queue.push_back(to_outgoing_port((in_node, in_port))); + } + } + } + + valid_incoming_ports.into_iter() } delegate! { @@ -454,19 +646,17 @@ impl PersistentHugr { pub fn base_commit(&self) -> &Commit; /// Get the commit with ID `commit_id`. pub fn get_commit(&self, commit_id: CommitId) -> &Commit; - /// Check whether `commit_id` exists and return it. - pub fn try_get_commit(&self, commit_id: CommitId) -> Option<&Commit>; /// Get an iterator over all nodes inserted by `commit_id`. /// /// All nodes will be PatchNodes with commit ID `commit_id`. pub fn inserted_nodes(&self, commit_id: CommitId) -> impl Iterator + '_; /// Get the replacement for `commit_id`. - pub(crate) fn replacement(&self, commit_id: CommitId) -> Option<&SimpleReplacement>; + fn replacement(&self, commit_id: CommitId) -> Option<&SimpleReplacement>; /// Get the Hugr inserted by `commit_id`. /// /// This is either the replacement Hugr of a [`CommitData::Replacement`] or /// the base Hugr of a [`CommitData::Base`]. - pub(crate) fn commit_hugr(&self, commit_id: CommitId) -> &Hugr; + pub(super) fn commit_hugr(&self, commit_id: CommitId) -> &Hugr; /// Get an iterator over all commit IDs in the persistent HUGR. 
pub fn all_commit_ids(&self) -> impl Iterator + Clone + '_; } @@ -511,11 +701,7 @@ impl PersistentHugr { .unique() } - /// Get the child commit that deletes `node`. - pub(crate) fn find_deleting_commit( - &self, - node @ PatchNode(commit_id, _): PatchNode, - ) -> Option { + fn find_deleting_commit(&self, node @ PatchNode(commit_id, _): PatchNode) -> Option { let mut children = self.state_space.children(commit_id); children.find(move |&child_id| { let child = self.get_commit(child_id); @@ -523,12 +709,6 @@ impl PersistentHugr { }) } - /// Convert a node ID specific to a commit HUGR into a patch node in the - /// [`PersistentHugr`]. - pub(crate) fn to_persistent_node(&self, node: Node, commit_id: CommitId) -> PatchNode { - PatchNode(commit_id, node) - } - /// Check if a patch node is in the PersistentHugr, that is, it belongs to /// a commit in the state space and is not deleted by any child commit. pub fn contains_node(&self, PatchNode(commit_id, node): PatchNode) -> bool { @@ -540,46 +720,16 @@ impl PersistentHugr { self.contains_id(commit_id) && !is_replacement_io() && !is_deleted() } - pub(crate) fn is_value_port( - &self, - PatchNode(commit_id, node): PatchNode, - port: impl Into, - ) -> bool { + fn is_value_port(&self, PatchNode(commit_id, node): PatchNode, port: impl Into) -> bool { self.commit_hugr(commit_id) .get_optype(node) .port_kind(port) .expect("invalid port") .is_value() } - - pub(super) fn value_ports( - &self, - patch_node @ PatchNode(commit_id, node): PatchNode, - dir: Direction, - ) -> impl Iterator + '_ { - let hugr = self.commit_hugr(commit_id); - let ports = hugr.node_ports(node, dir); - ports.filter_map(move |p| self.is_value_port(patch_node, p).then_some((patch_node, p))) - } - - pub(super) fn output_value_ports( - &self, - patch_node: PatchNode, - ) -> impl Iterator + '_ { - self.value_ports(patch_node, Direction::Outgoing) - .map(|(n, p)| (n, p.as_outgoing().expect("unexpected port direction"))) - } - - pub(super) fn input_value_ports( - &self, - patch_node: PatchNode, - ) -> impl Iterator + '_ { - self.value_ports(patch_node, Direction::Incoming) - .map(|(n, p)| (n, p.as_incoming().expect("unexpected port direction"))) - } } -impl IntoIterator for PersistentHugr { +impl IntoIterator for PersistentHugr { type Item = Commit; type IntoIter = vec::IntoIter; @@ -595,13 +745,13 @@ impl IntoIterator for PersistentHugr { /// Find a node in `commit` that is invalidated by more than one child commit /// among `children`. -pub(crate) fn find_conflicting_node<'a>( +fn find_conflicting_node<'a>( commit_id: CommitId, - children: impl IntoIterator, + mut children: impl Iterator, ) -> Option { let mut all_invalidated = BTreeSet::new(); - children.into_iter().find_map(|child| { + children.find_map(|child| { let mut new_invalidated = child .invalidation_set() @@ -616,17 +766,12 @@ pub(crate) fn find_conflicting_node<'a>( }) } -fn get_parent_commits( - replacement: &PersistentReplacement, - graph: &CommitStateSpace, -) -> Result, InvalidCommit> { - let parent_ids = replacement.invalidation_set().map(|n| n.owner()).unique(); - parent_ids - .map(|id| { - graph - .try_get_commit(id) - .cloned() - .ok_or(InvalidCommit::UnknownParent(id)) - }) - .collect() +pub mod serial { + //! Serialization formats of [`CommitStateSpace`](super::CommitStateSpace) + //! 
and related types + #[doc(inline)] + pub use super::state_space::serial::*; } + +#[cfg(test)] +mod tests; diff --git a/hugr-persistent/src/parents_view.rs b/hugr-core/src/hugr/persistent/parents_view.rs similarity index 95% rename from hugr-persistent/src/parents_view.rs rename to hugr-core/src/hugr/persistent/parents_view.rs index 6f1f3c86de..b4aa076060 100644 --- a/hugr-persistent/src/parents_view.rs +++ b/hugr-core/src/hugr/persistent/parents_view.rs @@ -1,10 +1,9 @@ use std::collections::{BTreeMap, HashMap}; -use hugr_core::{ +use crate::{ Direction, Hugr, HugrView, Node, Port, extension::ExtensionRegistry, hugr::{ - self, internal::HugrInternals, views::{ExtractionResult, render}, }, @@ -18,15 +17,12 @@ use super::{CommitStateSpace, PatchNode, state_space::CommitId}; /// Note that this is not a valid HUGR: not a single entrypoint, root etc. As /// a consequence, not all HugrView methods are implemented. #[derive(Debug, Clone)] -pub(crate) struct ParentsView<'a> { +pub(super) struct ParentsView<'a> { hugrs: BTreeMap, } impl<'a> ParentsView<'a> { - pub(crate) fn from_commit( - commit_id: CommitId, - state_space: &'a CommitStateSpace, - ) -> Self { + pub(super) fn from_commit(commit_id: CommitId, state_space: &'a CommitStateSpace) -> Self { let mut hugrs = BTreeMap::new(); for parent in state_space.parents(commit_id) { hugrs.insert(parent, state_space.commit_hugr(parent)); @@ -37,7 +33,7 @@ impl<'a> ParentsView<'a> { impl HugrInternals for ParentsView<'_> { type RegionPortgraph<'p> - = portgraph::MultiPortGraph + = portgraph::MultiPortGraph where Self: 'p; @@ -55,7 +51,7 @@ impl HugrInternals for ParentsView<'_> { unimplemented!() } - fn node_metadata_map(&self, node: Self::Node) -> &hugr::NodeMetadataMap { + fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap { let PatchNode(commit_id, node) = node; self.hugrs .get(&commit_id) diff --git a/hugr-core/src/hugr/persistent/resolver.rs b/hugr-core/src/hugr/persistent/resolver.rs new file mode 100644 index 0000000000..0a0d140ee5 --- /dev/null +++ b/hugr-core/src/hugr/persistent/resolver.rs @@ -0,0 +1,43 @@ +use relrc::EquivalenceResolver; + +/// A resolver that considers two nodes equivalent if they are the same pointer. +/// +/// Resolvers determine when two patches are equivalent and should be merged +/// in the patch history. +/// +/// This is a trivial resolver (to be expanded on later), that considers two +/// patches equivalent if they point to the same data in memory. 
+#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct PointerEqResolver; + +impl EquivalenceResolver for PointerEqResolver { + type MergeMapping = (); + + type DedupKey = *const N; + + fn id(&self) -> String { + "PointerEqResolver".to_string() + } + + fn dedup_key(&self, value: &N, _incoming_edges: &[&E]) -> Self::DedupKey { + value as *const N + } + + fn try_merge_mapping( + &self, + a_value: &N, + _a_incoming_edges: &[&E], + b_value: &N, + _b_incoming_edges: &[&E], + ) -> Result { + if std::ptr::eq(a_value, b_value) { + Ok(()) + } else { + Err(relrc::resolver::NotEquivalent) + } + } + + fn move_edge_source(&self, _mapping: &Self::MergeMapping, edge: &E) -> E { + edge.clone() + } +} diff --git a/hugr-persistent/src/state_space.rs b/hugr-core/src/hugr/persistent/state_space.rs similarity index 70% rename from hugr-persistent/src/state_space.rs rename to hugr-core/src/hugr/persistent/state_space.rs index 30f704347d..d710c1aaec 100644 --- a/hugr-persistent/src/state_space.rs +++ b/hugr-core/src/hugr/persistent/state_space.rs @@ -1,29 +1,19 @@ -//! Store of commit histories for a [`PersistentHugr`]. - use std::collections::{BTreeSet, VecDeque}; use delegate::delegate; use derive_more::From; -use hugr_core::{ - Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement, - hugr::{ - self, - internal::HugrInternals, - patch::{ - BoundaryPort, - simple_replace::{BoundaryMode, InvalidReplacement}, - }, - views::InvalidSignature, - }, - ops::OpType, -}; -use itertools::{Either, Itertools}; +use itertools::Itertools; use relrc::{HistoryGraph, RelRc}; use thiserror::Error; +use super::{ + Commit, PersistentHugr, PersistentReplacement, PointerEqResolver, find_conflicting_node, + parents_view::ParentsView, +}; use crate::{ - Commit, PersistentHugr, PersistentReplacement, PointerEqResolver, Resolver, - find_conflicting_node, parents_view::ParentsView, subgraph::InvalidPinnedSubgraph, + Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement, + hugr::{internal::HugrInternals, patch::BoundaryPort}, + ops::OpType, }; pub mod serial; @@ -33,46 +23,23 @@ pub type CommitId = relrc::NodeId; /// A HUGR node within a commit of the commit state space #[derive( - Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize, + Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug, Hash, serde::Serialize, serde::Deserialize, )] pub struct PatchNode(pub CommitId, pub Node); -impl PatchNode { - /// Get the commit ID of the commit that owns this node. - pub fn owner(&self) -> CommitId { - self.0 - } -} - -// Print out PatchNodes as `Node(x)@commit_hex` -impl std::fmt::Debug for PatchNode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}@{}", self.1, self.0) - } -} - impl std::fmt::Display for PatchNode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") + write!(f, "{:?}", self) } } -mod hidden { - use super::*; - - /// The data stored in a [`Commit`], either the base [`Hugr`] (on which all - /// other commits apply), or a [`PersistentReplacement`] - /// - /// This is a "unnamable" type: we do not expose this struct publicly in our - /// API, but we can still use it in public trait bounds (see - /// [`Resolver`](crate::resolver::Resolver)). 
- #[derive(Debug, Clone, From)] - pub enum CommitData { - Base(Hugr), - Replacement(PersistentReplacement), - } +/// The data stored in a [`Commit`], either the base [`Hugr`] (on which all +/// other commits apply), or a [`PersistentReplacement`] +#[derive(Debug, Clone, From)] +pub(super) enum CommitData { + Base(Hugr), + Replacement(PersistentReplacement), } -pub(crate) use hidden::CommitData; /// A set of commits with directed (acyclic) dependencies between them. /// @@ -94,24 +61,24 @@ pub(crate) use hidden::CommitData; /// same subgraph. Use [`Self::try_extract_hugr`] to get a [`PersistentHugr`] /// with a set of compatible commits. #[derive(Clone, Debug)] -pub struct CommitStateSpace { +pub struct CommitStateSpace { /// A set of commits with directed (acyclic) dependencies between them. /// /// Each commit is stored as a [`RelRc`]. - pub(super) graph: HistoryGraph, + graph: HistoryGraph, /// The unique root of the commit graph. /// /// The only commit in the graph with variant [`CommitData::Base`]. All /// other commits are [`CommitData::Replacement`]s, and are descendants /// of this. - pub(super) base_commit: CommitId, + base_commit: CommitId, } -impl CommitStateSpace { +impl CommitStateSpace { /// Create a new commit state space with a single base commit. pub fn with_base(hugr: Hugr) -> Self { let commit = RelRc::new(CommitData::Base(hugr)); - let graph = HistoryGraph::new([commit.clone()], R::default()); + let graph = HistoryGraph::new([commit.clone()], PointerEqResolver); let base_commit = graph .all_node_ids() .exactly_one() @@ -127,7 +94,7 @@ impl CommitStateSpace { pub fn try_from_commits( commits: impl IntoIterator, ) -> Result { - let graph = HistoryGraph::new(commits.into_iter().map_into(), R::default()); + let graph = HistoryGraph::new(commits.into_iter().map_into(), PointerEqResolver); let base_commits = graph .all_node_ids() .filter(|&id| matches!(graph.get_node(id).value(), CommitData::Base(_))) @@ -151,29 +118,39 @@ impl CommitStateSpace { self.try_add_commit(commit) } + /// Add a set of commits to the state space. + /// + /// Commits must be valid replacement commits or coincide with the existing + /// base commit. + pub fn extend(&mut self, commits: impl IntoIterator) { + // TODO: make this more efficient + for commit in commits { + self.try_add_commit(commit) + .expect("invalid commit in extend"); + } + } + /// Add a commit (and all its ancestors) to the state space. /// /// Returns an [`InvalidCommit::NonUniqueBase`] error if the commit is a /// base commit and does not coincide with the existing base commit. pub fn try_add_commit(&mut self, commit: Commit) -> Result { - let is_base = commit.as_relrc().ptr_eq(self.base_commit().as_relrc()); - if !is_base && matches!(commit.value(), CommitData::Base(_)) { + if matches!(commit.value(), CommitData::Base(_) if !commit.0.ptr_eq(&self.base_commit().0)) + { return Err(InvalidCommit::NonUniqueBase(2)); } let commit = commit.into(); Ok(self.graph.insert_node(commit)) } - /// Add a set of commits to the state space. - /// - /// Commits must be valid replacement commits or coincide with the existing - /// base commit. - pub fn extend(&mut self, commits: impl IntoIterator) { - // TODO: make this more efficient - for commit in commits { - self.try_add_commit(commit) - .expect("invalid commit in extend"); - } + /// Check if `commit` is in the commit state space. + pub fn contains(&self, commit: &Commit) -> bool { + self.graph.contains(commit.as_relrc()) + } + + /// Check if `commit_id` is in the commit state space. 
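The state-space API introduced in this hunk (`CommitStateSpace::with_base`, `Commit::try_from_replacement`, `try_add_commit`, `try_extract_hugr`) composes as in the sketch below. The facade-crate paths and the hypothetical `explore` helper are assumptions for illustration; the method signatures follow the declarations in this patch.

```rust
use hugr::Hugr;
use hugr::hugr::persistent::{
    Commit, CommitStateSpace, InvalidCommit, PersistentReplacement,
};

/// Sketch: build a state space, add one commit, and extract a compatible selection.
fn explore(base: Hugr, replacement: PersistentReplacement) -> Result<Hugr, InvalidCommit> {
    // The base HUGR becomes the root commit of the state space.
    let mut space = CommitStateSpace::with_base(base);
    // Wrap the replacement as a commit; its parent commits are inferred from
    // the nodes it invalidates (here, nodes of the base commit).
    let commit = Commit::try_from_replacement(replacement, &space)?;
    let commit_id = space.try_add_commit(commit)?;
    // Select a compatible subset of commits and materialize it.
    let persistent = space.try_extract_hugr([commit_id])?;
    Ok(persistent.to_hugr())
}
```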
+ pub fn contains_id(&self, commit_id: CommitId) -> bool { + self.graph.contains_id(commit_id) } /// Extract a `PersistentHugr` from this state space, consisting of @@ -187,7 +164,7 @@ impl CommitStateSpace { pub fn try_extract_hugr( &self, commits: impl IntoIterator, - ) -> Result, InvalidCommit> { + ) -> Result { // Define commits as the set of all ancestors of the given commits let all_commit_ids = get_all_ancestors(&self.graph, commits); @@ -210,25 +187,13 @@ impl CommitStateSpace { let commits = all_commit_ids .into_iter() .map(|id| self.get_commit(id).as_relrc().clone()); - let subgraph = HistoryGraph::new(commits, R::default()); + let subgraph = HistoryGraph::new(commits, PointerEqResolver); Ok(PersistentHugr::from_state_space_unsafe(Self { graph: subgraph, base_commit: self.base_commit, })) } -} - -impl CommitStateSpace { - /// Check if `commit` is in the commit state space. - pub fn contains(&self, commit: &Commit) -> bool { - self.graph.contains(commit.as_relrc()) - } - - /// Check if `commit_id` is in the commit state space. - pub fn contains_id(&self, commit_id: CommitId) -> bool { - self.graph.contains_id(commit_id) - } /// Get the base commit ID. pub fn base(&self) -> CommitId { @@ -253,12 +218,6 @@ impl CommitStateSpace { self.graph.get_node(commit_id).into() } - /// Check whether `commit_id` exists and return it. - pub fn try_get_commit(&self, commit_id: CommitId) -> Option<&Commit> { - self.contains_id(commit_id) - .then(|| self.get_commit(commit_id)) - } - /// Get an iterator over all commit IDs in the state space. pub fn all_commit_ids(&self) -> impl Iterator + Clone + '_ { let vec = self.graph.all_node_ids().collect_vec(); @@ -297,7 +256,7 @@ impl CommitStateSpace { } } - pub(crate) fn as_history_graph(&self) -> &HistoryGraph { + pub(super) fn as_history_graph(&self) -> &HistoryGraph { &self.graph } @@ -305,7 +264,7 @@ impl CommitStateSpace { /// /// This is either the replacement Hugr of a [`CommitData::Replacement`] or /// the base Hugr of a [`CommitData::Base`]. - pub(crate) fn commit_hugr(&self, commit_id: CommitId) -> &Hugr { + pub(super) fn commit_hugr(&self, commit_id: CommitId) -> &Hugr { let commit = self.get_commit(commit_id); match commit.value() { CommitData::Base(base) => base, @@ -336,23 +295,15 @@ impl CommitStateSpace { /// Get the boundary inputs linked to `(node, port)` in `child`. /// - /// The returned ports will be ports on successors of the input node in the - /// `child` commit, unless (node, port) is connected to a passthrough wire - /// in `child` (i.e. a wire from input node to output node), in which - /// case they will be in one of the parents of `child`. - /// - /// `child` should be a child commit of the owner of `node`. - /// /// ## Panics /// - /// Panics if `(node, port)` is not a boundary edge, if `child` is not - /// a valid commit ID or if it is the base commit. - pub(crate) fn linked_child_inputs( + /// Panics if `(node, port)` is not a boundary edge, or if `child` is not + /// a valid commit ID. 
+ pub(super) fn linked_child_inputs( &self, node: PatchNode, port: OutgoingPort, child: CommitId, - return_invalid: BoundaryMode, ) -> impl Iterator + '_ { assert!( self.is_boundary_edge(node, port, child), @@ -361,7 +312,7 @@ impl CommitStateSpace { let parent_hugrs = ParentsView::from_commit(child, self); let repl = self.replacement(child).expect("valid child commit"); - repl.linked_replacement_inputs((node, port), &parent_hugrs, return_invalid) + repl.linked_replacement_inputs((node, port), &parent_hugrs) .collect_vec() .into_iter() .map(move |np| match np { @@ -372,70 +323,32 @@ impl CommitStateSpace { /// Get the single boundary output linked to `(node, port)` in `child`. /// - /// The returned port will be a port on a predecessor of the output node in - /// the `child` commit, unless (node, port) is connected to a passthrough - /// wire in `child` (i.e. a wire from input node to output node), in - /// which case it will be in one of the parents of `child`. - /// - /// `child` should be a child commit of the owner of `node` (or `None` will - /// be returned). - /// /// ## Panics /// /// Panics if `child` is not a valid commit ID. - pub(crate) fn linked_child_output( + pub(super) fn linked_child_output( &self, node: PatchNode, port: IncomingPort, child: CommitId, - return_invalid: BoundaryMode, ) -> Option<(PatchNode, OutgoingPort)> { let parent_hugrs = ParentsView::from_commit(child, self); - let repl = self.replacement(child)?; - match repl.linked_replacement_output((node, port), &parent_hugrs, return_invalid)? { + let repl = self.replacement(child).expect("valid child commit"); + match repl.linked_replacement_output((node, port), &parent_hugrs)? { BoundaryPort::Host(patch_node, port) => (patch_node, port), BoundaryPort::Replacement(node, port) => (PatchNode(child, node), port), } .into() } - /// Get the boundary ports linked to `(node, port)` in `child`. - /// - /// `child` should be a child commit of the owner of `node`. - /// - /// See [`Self::linked_child_inputs`] and [`Self::linked_child_output`] for - /// more details. - pub(crate) fn linked_child_ports( - &self, - node: PatchNode, - port: impl Into, - child: CommitId, - return_invalid: BoundaryMode, - ) -> impl Iterator + '_ { - match port.into().as_directed() { - Either::Left(incoming) => Either::Left( - self.linked_child_output(node, incoming, child, return_invalid) - .into_iter() - .map(|(node, port)| (node, port.into())), - ), - Either::Right(outgoing) => Either::Right( - self.linked_child_inputs(node, outgoing, child, return_invalid) - .map(|(node, port)| (node, port.into())), - ), - } - } - - /// Get the single output port linked to `(node, port)` in a parent of the - /// commit of `node`. - /// - /// The returned port belongs to the input boundary of the subgraph in - /// parent. + /// Get the single output boundary port linked to `(node, port)` in a + /// parent of the commit of `node`. /// /// ## Panics /// /// Panics if `(node, port)` is not connected to the input node in the /// commit of `node`, or if the node is not valid. - pub fn linked_parent_input( + pub(super) fn linked_parent_input( &self, PatchNode(commit_id, node): PatchNode, port: IncomingPort, @@ -453,17 +366,7 @@ impl CommitStateSpace { repl.linked_host_input((node, port), &parent_hugrs).into() } - /// Get the input ports linked to `(node, port)` in a parent of the commit - /// of `node`. - /// - /// The returned ports belong to the output boundary of the subgraph in - /// parent. 
- /// - /// ## Panics - /// - /// Panics if `(node, port)` is not connected to the output node in the - /// commit of `node`, or if the node is not valid. - pub fn linked_parent_outputs( + pub(super) fn linked_parent_outputs( &self, PatchNode(commit_id, node): PatchNode, port: OutgoingPort, @@ -484,35 +387,8 @@ impl CommitStateSpace { .into_iter() } - /// Get the ports linked to `(node, port)` in a parent of the commit of - /// `node`. - /// - /// See [`Self::linked_parent_input`] and [`Self::linked_parent_outputs`] - /// for more details. - /// - /// ## Panics - /// - /// Panics if `(node, port)` is not connected to an IO node in the commit - /// of `node`, or if the node is not valid. - pub fn linked_parent_ports( - &self, - node: PatchNode, - port: impl Into, - ) -> impl Iterator + '_ { - match port.into().as_directed() { - Either::Left(incoming) => { - let (node, port) = self.linked_parent_input(node, incoming); - Either::Left(std::iter::once((node, port.into()))) - } - Either::Right(outgoing) => Either::Right( - self.linked_parent_outputs(node, outgoing) - .map(|(node, port)| (node, port.into())), - ), - } - } - /// Get the replacement for `commit_id`. - pub(crate) fn replacement(&self, commit_id: CommitId) -> Option<&SimpleReplacement> { + pub(super) fn replacement(&self, commit_id: CommitId) -> Option<&SimpleReplacement> { let commit = self.get_commit(commit_id); commit.replacement() } @@ -520,7 +396,7 @@ impl CommitStateSpace { // The subset of HugrView methods that can be implemented on CommitStateSpace // by simplify delegating to the patches' respective HUGRs -impl CommitStateSpace { +impl CommitStateSpace { /// Get the type of the operation at `node`. pub fn get_optype(&self, PatchNode(commit_id, node): PatchNode) -> &OpType { let hugr = self.commit_hugr(commit_id); @@ -571,7 +447,7 @@ impl CommitStateSpace { pub fn node_metadata_map( &self, PatchNode(commit_id, node): PatchNode, - ) -> &hugr::NodeMetadataMap { + ) -> &crate::hugr::NodeMetadataMap { self.commit_hugr(commit_id).node_metadata_map(node) } } @@ -615,20 +491,4 @@ pub enum InvalidCommit { /// The commit is an empty replacement. #[error("Not allowed: empty replacement")] EmptyReplacement, - - #[error("Invalid subgraph: {0}")] - /// The subgraph of the replacement is not convex. - InvalidSubgraph(#[from] InvalidPinnedSubgraph), - - /// The replacement of the commit is invalid. - #[error("Invalid replacement: {0}")] - InvalidReplacement(#[from] InvalidReplacement), - - /// The signature of the replacement is invalid. - #[error("Invalid signature: {0}")] - InvalidSignature(#[from] InvalidSignature), - - /// A wire has an unpinned port. - #[error("Incomplete wire: {0} is unpinned")] - IncompleteWire(PatchNode, Port), } diff --git a/hugr-persistent/src/state_space/serial.rs b/hugr-core/src/hugr/persistent/state_space/serial.rs similarity index 66% rename from hugr-persistent/src/state_space/serial.rs rename to hugr-core/src/hugr/persistent/state_space/serial.rs index b20c585eb7..c345308b8b 100644 --- a/hugr-persistent/src/state_space/serial.rs +++ b/hugr-core/src/hugr/persistent/state_space/serial.rs @@ -1,9 +1,7 @@ -//! 
Serialized format for [`CommitStateSpace`] - use relrc::serialization::SerializedHistoryGraph; use super::*; -use hugr_core::hugr::patch::simple_replace::serial::SerialSimpleReplacement; +use crate::hugr::patch::simple_replace::serial::SerialSimpleReplacement; /// Serialized format for [`PersistentReplacement`] pub type SerialPersistentReplacement = SerialSimpleReplacement; @@ -53,28 +51,28 @@ impl> From> for CommitData { /// Serialized format for commit state space #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub struct SerialCommitStateSpace { +pub struct SerialCommitStateSpace { /// The serialized history graph containing commit data - pub graph: SerializedHistoryGraph, (), R>, + pub graph: SerializedHistoryGraph, (), PointerEqResolver>, /// The base commit ID pub base_commit: CommitId, } -impl CommitStateSpace { +impl CommitStateSpace { /// Create a new [`CommitStateSpace`] from its serialized format - pub fn from_serial>(value: SerialCommitStateSpace) -> Self { + pub fn from_serial + Clone>(value: SerialCommitStateSpace) -> Self { let SerialCommitStateSpace { graph, base_commit } = value; // Deserialize the SerializedHistoryGraph into a HistoryGraph with CommitData let graph = graph.map_nodes(|n| CommitData::from_serial(n)); - let graph = HistoryGraph::try_from_serialized(graph, R::default()) + let graph = HistoryGraph::try_from_serialized(graph, PointerEqResolver) .expect("failed to deserialize history graph"); Self { graph, base_commit } } /// Convert a [`CommitStateSpace`] into its serialized format - pub fn into_serial>(self) -> SerialCommitStateSpace { + pub fn into_serial>(self) -> SerialCommitStateSpace { let Self { graph, base_commit } = self; let graph = graph.to_serialized(); let graph = graph.map_nodes(|n| n.into_serial()); @@ -82,7 +80,10 @@ impl CommitStateSpace { } /// Create a serialized format from a reference to [`CommitStateSpace`] - pub fn to_serial>(&self) -> SerialCommitStateSpace { + pub fn to_serial(&self) -> SerialCommitStateSpace + where + H: From, + { let Self { graph, base_commit } = self; let graph = graph.to_serialized(); let graph = graph.map_nodes(|n| n.into_serial()); @@ -93,46 +94,51 @@ impl CommitStateSpace { } } -impl, R: Resolver> From> for SerialCommitStateSpace { - fn from(value: CommitStateSpace) -> Self { +impl> From for SerialCommitStateSpace { + fn from(value: CommitStateSpace) -> Self { value.into_serial() } } -impl, R: Resolver> From> for CommitStateSpace { - fn from(value: SerialCommitStateSpace) -> Self { +impl> From> for CommitStateSpace { + fn from(value: SerialCommitStateSpace) -> Self { CommitStateSpace::from_serial(value) } } #[cfg(test)] mod tests { + use derive_more::derive::Into; use rstest::rstest; + use serde_with::serde_as; use super::*; use crate::{ - SerdeHashResolver, - tests::{WrappedHugr, test_state_space}, + envelope::serde_with::AsStringEnvelope, hugr::persistent::tests::test_state_space, }; + #[serde_as] + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, From, Into)] + struct WrappedHugr { + #[serde_as(as = "AsStringEnvelope")] + pub hugr: Hugr, + } + #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri #[rstest] - fn test_serialize_state_space( - test_state_space: ( - CommitStateSpace>, - [CommitId; 4], - ), - ) { + fn test_serialize_state_space(test_state_space: (CommitStateSpace, [CommitId; 4])) { let (state_space, _) = test_state_space; let serialized = state_space.to_serial::(); - let deser = CommitStateSpace::from_serial(serialized.clone()); - 
let serialized_2 = deser.to_serial::(); + let deser = CommitStateSpace::from_serial(serialized); + let _serialized_2 = deser.to_serial::(); - insta::assert_snapshot!(serde_json::to_string_pretty(&serialized).unwrap()); - assert_eq!( - serde_json::to_string(&serialized).unwrap(), - serde_json::to_string(&serialized_2).unwrap() - ); + // TODO: add this once PointerEqResolver is replaced by a deterministic resolver + // insta::assert_snapshot!(serde_json::to_string_pretty(&serialized).unwrap()); + // see https://github.com/CQCL/hugr/issues/2299 + // assert_eq!( + // serde_json::to_string(&serialized), + // serde_json::to_string(&serialized_2) + // ); } } diff --git a/hugr-persistent/src/tests.rs b/hugr-core/src/hugr/persistent/tests.rs similarity index 81% rename from hugr-persistent/src/tests.rs rename to hugr-core/src/hugr/persistent/tests.rs index 77b26be8ac..ae0876ef5b 100644 --- a/hugr-persistent/src/tests.rs +++ b/hugr-core/src/hugr/persistent/tests.rs @@ -1,20 +1,22 @@ use std::collections::{BTreeMap, HashMap}; -use derive_more::derive::{From, Into}; -use hugr_core::{ +use rstest::*; + +use crate::{ IncomingPort, Node, OutgoingPort, SimpleReplacement, - builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig, inout_sig}, + builder::{DFGBuilder, Dataflow, DataflowHugr, inout_sig}, extension::prelude::bool_t, - hugr::{Hugr, HugrView, patch::Patch, views::SiblingSubgraph}, + hugr::{ + Hugr, HugrView, + patch::Patch, + persistent::{Commit, PatchNode}, + views::SiblingSubgraph, + }, ops::handle::NodeHandle, std_extensions::logic::LogicOp, }; -use rstest::*; -use crate::{ - Commit, CommitStateSpace, PatchNode, PersistentHugr, PersistentReplacement, Resolver, - state_space::CommitId, -}; +use super::{CommitStateSpace, state_space::CommitId}; /// Creates a simple test Hugr with a DFG that contains a small boolean circuit /// @@ -205,10 +207,10 @@ fn create_not_and_to_xor_replacement(hugr: &Hugr) -> SimpleReplacement { /// - `commit1` and `commit2` are disjoint with `commit4` (i.e. 
compatible), /// - `commit2` depends on `commit1` #[fixture] -pub(crate) fn test_state_space() -> (CommitStateSpace, [CommitId; 4]) { +pub(super) fn test_state_space() -> (CommitStateSpace, [CommitId; 4]) { let (base_hugr, [not0_node, not1_node, _and_node]) = simple_hugr(); - let mut state_space = CommitStateSpace::::with_base(base_hugr); + let mut state_space = CommitStateSpace::with_base(base_hugr); // Create first replacement (replace NOT0 with two NOT gates) let replacement1 = create_double_not_replacement(state_space.base_hugr(), not0_node); @@ -216,11 +218,8 @@ pub(crate) fn test_state_space() -> (CommitStateSpace, [CommitId // Add first commit to state space, replacing NOT0 with two NOT gates let commit1 = { let to_patch_node = |n: Node| PatchNode(state_space.base(), n); - let new_host = state_space.try_extract_hugr([state_space.base()]).unwrap(); // translate replacement1 to patch nodes in the base commit of the state space - let replacement1 = replacement1 - .map_host_nodes(to_patch_node, &new_host) - .unwrap(); + let replacement1 = replacement1.map_host_nodes(to_patch_node); state_space.try_add_replacement(replacement1).unwrap() }; @@ -260,10 +259,7 @@ pub(crate) fn test_state_space() -> (CommitStateSpace, [CommitId }; // translate replacement2 to patch nodes - let new_host = state_space.try_extract_hugr([commit1]).unwrap(); - let replacement2 = replacement2 - .map_host_nodes(to_patch_node, &new_host) - .unwrap(); + let replacement2 = replacement2.map_host_nodes(to_patch_node); state_space.try_add_replacement(replacement2).unwrap() }; @@ -272,11 +268,9 @@ pub(crate) fn test_state_space() -> (CommitStateSpace, [CommitId let commit3 = { let replacement3 = create_not_and_to_xor_replacement(state_space.base_hugr()); let to_patch_node = |n: Node| PatchNode(state_space.base(), n); - let new_host = state_space.try_extract_hugr([state_space.base()]).unwrap(); - let replacement3 = replacement3 - .map_host_nodes(to_patch_node, &new_host) - .unwrap(); - state_space.try_add_replacement(replacement3).unwrap() + state_space + .try_add_replacement(replacement3.map_host_nodes(to_patch_node)) + .unwrap() }; // Create a fourth commit that is disjoint from `commit1`, replacing NOT1 @@ -284,54 +278,13 @@ pub(crate) fn test_state_space() -> (CommitStateSpace, [CommitId let commit4 = { let replacement4 = create_double_not_replacement(state_space.base_hugr(), not1_node); let to_patch_node = |n: Node| PatchNode(state_space.base(), n); - let new_host = state_space.try_extract_hugr([state_space.base()]).unwrap(); - let replacement4 = replacement4 - .map_host_nodes(to_patch_node, &new_host) - .unwrap(); + let replacement4 = replacement4.map_host_nodes(to_patch_node); state_space.try_add_replacement(replacement4).unwrap() }; (state_space, [commit1, commit2, commit3, commit4]) } -#[fixture] -pub(super) fn persistent_hugr_empty_child() -> (PersistentHugr, [CommitId; 2], [PatchNode; 3]) { - let (triple_not_hugr, not_nodes) = { - let mut dfg_builder = DFGBuilder::new(endo_sig(bool_t())).unwrap(); - let [mut w] = dfg_builder.input_wires_arr(); - let not_nodes = [(); 3].map(|()| { - let handle = dfg_builder.add_dataflow_op(LogicOp::Not, vec![w]).unwrap(); - [w] = handle.outputs_arr(); - handle.node() - }); - ( - dfg_builder.finish_hugr_with_outputs([w]).unwrap(), - not_nodes, - ) - }; - let mut hugr = PersistentHugr::with_base(triple_not_hugr); - let empty_hugr = { - let dfg_builder = DFGBuilder::new(endo_sig(bool_t())).unwrap(); - let inputs = dfg_builder.input_wires(); - 
dfg_builder.finish_hugr_with_outputs(inputs).unwrap() - }; - let subg_nodes = [PatchNode(hugr.base(), not_nodes[1])]; - let repl = PersistentReplacement::try_new( - SiblingSubgraph::try_from_nodes(subg_nodes, &hugr).unwrap(), - &hugr, - empty_hugr, - ) - .unwrap(); - - let empty_commit = hugr.try_add_replacement(repl).unwrap(); - let base_commit = hugr.base(); - ( - hugr, - [base_commit, empty_commit], - not_nodes.map(|n| PatchNode(base_commit, n)), - ) -} - #[rstest] fn test_successive_replacements(test_state_space: (CommitStateSpace, [CommitId; 4])) { let (state_space, [commit1, commit2, _commit3, _commit4]) = test_state_space; @@ -466,7 +419,8 @@ fn test_try_add_replacement(test_state_space: (CommitStateSpace, [CommitId; 4])) let result = persistent_hugr.try_add_replacement(repl4.clone()); assert!( result.is_ok(), - "[commit1, commit2] + [commit4] are compatible. Got {result:?}" + "[commit1, commit2] + [commit4] are compatible. Got {:?}", + result ); let hugr = persistent_hugr.to_hugr(); let exp_hugr = state_space @@ -482,7 +436,8 @@ fn test_try_add_replacement(test_state_space: (CommitStateSpace, [CommitId; 4])) let result = persistent_hugr.try_add_replacement(repl3.clone()); assert!( result.is_err(), - "[commit1, commit2] + [commit3] are incompatible. Got {result:?}" + "[commit1, commit2] + [commit3] are incompatible. Got {:?}", + result ); } } @@ -522,49 +477,3 @@ fn test_try_add_commit(test_state_space: (CommitStateSpace, [CommitId; 4])) { .expect_err("commit3 is incompatible with [commit1, commit2]"); } } - -/// A Hugr that serialises with no extensions -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, From, Into)] -pub(crate) struct WrappedHugr { - #[serde(with = "serial")] - pub hugr: Hugr, -} - -mod serial { - use hugr_core::envelope::EnvelopeConfig; - use hugr_core::std_extensions::STD_REG; - use serde::Deserialize; - - use super::*; - - pub(crate) fn serialize(hugr: &Hugr, serializer: S) -> Result - where - S: serde::Serializer, - { - let mut str = hugr - .store_str_with_exts(EnvelopeConfig::text(), &STD_REG) - .map_err(serde::ser::Error::custom)?; - // TODO: replace this with a proper hugr hash (see https://github.com/CQCL/hugr/issues/2091) - remove_encoder_version(&mut str); - serializer.serialize_str(&str) - } - - fn remove_encoder_version(str: &mut String) { - // Remove encoder version information for consistent test output - let encoder_pattern = r#""encoder":"hugr-rs v"#; - if let Some(start) = str.find(encoder_pattern) { - if let Some(end) = str[start..].find(r#"","#) { - let end = start + end + 2; // +2 for the `",` part - str.replace_range(start..end, ""); - } - } - } - - pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let str = String::deserialize(deserializer)?; - Hugr::load_str(str, Some(&STD_REG)).map_err(serde::de::Error::custom) - } -} diff --git a/hugr-persistent/src/trait_impls.rs b/hugr-core/src/hugr/persistent/trait_impls.rs similarity index 92% rename from hugr-persistent/src/trait_impls.rs rename to hugr-core/src/hugr/persistent/trait_impls.rs index 17fadca6c6..6c68762029 100644 --- a/hugr-persistent/src/trait_impls.rs +++ b/hugr-core/src/hugr/persistent/trait_impls.rs @@ -1,19 +1,18 @@ use std::collections::HashMap; use itertools::{Either, Itertools}; +use portgraph::render::MermaidFormat; -use hugr_core::{ +use crate::{ Direction, Hugr, HugrView, Node, Port, - extension::ExtensionRegistry, hugr::{ - self, Patch, SimpleReplacementError, + Patch, SimpleReplacementError, 
internal::HugrInternals, views::{ ExtractionResult, render::{self, MermaidFormatter, NodeLabel}, }, }, - ops::OpType, }; use super::{ @@ -37,9 +36,9 @@ impl Patch for PersistentReplacement { } } -impl HugrInternals for PersistentHugr { +impl HugrInternals for PersistentHugr { type RegionPortgraph<'p> - = portgraph::MultiPortGraph + = portgraph::MultiPortGraph where Self: 'p; @@ -58,10 +57,15 @@ impl HugrInternals for PersistentHugr { let (hugr, node_map) = self.apply_all(); let parent = node_map[&parent]; - (hugr.into_region_portgraph(parent), node_map) + let region = portgraph::view::FlatRegion::new_without_root( + hugr.graph, + hugr.hierarchy, + parent.into_portgraph(), + ); + (region, node_map) } - fn node_metadata_map(&self, node: Self::Node) -> &hugr::NodeMetadataMap { + fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap { self.as_state_space().node_metadata_map(node) } } @@ -71,7 +75,7 @@ impl HugrInternals for PersistentHugr { // the whole extracted HUGR in memory. We are currently prioritizing correctness // and clarity over performance and will optimise some of these operations in // the future as bottlenecks are encountered. (see #2248) -impl HugrView for PersistentHugr { +impl HugrView for PersistentHugr { fn entrypoint(&self) -> Self::Node { // The entrypoint remains unchanged throughout the patch history, and is // found in the base hugr. @@ -107,7 +111,7 @@ impl HugrView for PersistentHugr { Some(parent_inv) } - fn get_optype(&self, node: Self::Node) -> &OpType { + fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType { self.as_state_space().get_optype(node) } @@ -175,11 +179,11 @@ impl HugrView for PersistentHugr { } else { match port.as_directed() { Either::Left(incoming) => { - let (out_node, out_port) = self.single_outgoing_port(node, incoming); + let (out_node, out_port) = self.get_single_outgoing_port(node, incoming); ret_ports.push((out_node, out_port.into())) } Either::Right(outgoing) => ret_ports.extend( - self.all_incoming_ports(node, outgoing) + self.get_all_incoming_ports(node, outgoing) .map(|(node, port)| (node, port.into())), ), } @@ -256,7 +260,7 @@ impl HugrView for PersistentHugr { // replace node labels with patch node IDs let node_labels_map: HashMap<_, _> = node_map .into_iter() - .map(|(k, v)| (v, format!("{k:?}"))) + .map(|(k, v)| (v, format!("{:?}", k))) .collect(); NodeLabel::Custom(node_labels_map) } @@ -277,7 +281,12 @@ impl HugrView for PersistentHugr { .with_port_offsets(formatter.port_offsets()) .with_type_labels(formatter.type_labels()); - config.finish() + hugr.graph + .mermaid_format() + .with_hierarchy(&hugr.hierarchy) + .with_node_style(render::node_style(&hugr, config.clone())) + .with_edge_style(render::edge_style(&hugr, config)) + .finish() } fn dot_string(&self) -> String @@ -287,8 +296,8 @@ impl HugrView for PersistentHugr { unimplemented!("use mermaid_string instead") } - fn extensions(&self) -> &ExtensionRegistry { - self.base_hugr().extensions() + fn extensions(&self) -> &crate::extension::ExtensionRegistry { + &self.base_hugr().extensions } fn extract_hugr( @@ -296,7 +305,7 @@ impl HugrView for PersistentHugr { parent: Self::Node, ) -> ( Hugr, - impl hugr::views::ExtractionResult + 'static, + impl crate::hugr::views::ExtractionResult + 'static, ) { let (hugr, apply_node_map) = self.apply_all(); let (extracted_hugr, extracted_node_map) = hugr.extract_hugr(apply_node_map[&parent]); @@ -321,7 +330,7 @@ impl HugrView for PersistentHugr { mod tests { use std::collections::HashSet; - use 
crate::{CommitStateSpace, state_space::CommitId}; + use crate::hugr::persistent::{CommitStateSpace, state_space::CommitId}; use super::super::tests::test_state_space; use super::*; diff --git a/hugr-persistent/src/walker.rs b/hugr-core/src/hugr/persistent/walker.rs similarity index 52% rename from hugr-persistent/src/walker.rs rename to hugr-core/src/hugr/persistent/walker.rs index bf579398c5..47123f393e 100644 --- a/hugr-persistent/src/walker.rs +++ b/hugr-core/src/hugr/persistent/walker.rs @@ -44,29 +44,25 @@ //! 5. Once exploration is complete (e.g. a pattern was fully matched), the //! walker can be converted into a [`PersistentHugr`] instance using //! [`Walker::into_persistent_hugr`]. The matched nodes and ports can then be -//! used to create a -//! [`SiblingSubgraph`](hugr_core::hugr::views::SiblingSubgraph) object, -//! which can then be used to create new -//! [`SimpleReplacement`](hugr_core::SimpleReplacement) instances---and -//! possibly in turn be added to the commit state space and the exploration -//! of the state space restarted! +//! used to create a [`SiblingSubgraph`](crate::hugr::views::SiblingSubgraph) +//! object, which can then be used to create new +//! [`SimpleReplacement`](crate::SimpleReplacement) instances---and possibly +//! in turn be added to the commit state space and the exploration of the +//! state space restarted! //! //! This approach allows efficiently finding patterns across many potential //! versions of the graph simultaneously, without having to materialize //! each version separately. +mod pinned; +pub use pinned::PinnedWire; + use std::{borrow::Cow, collections::BTreeSet}; -use hugr_core::hugr::patch::simple_replace::BoundaryMode; -use hugr_core::ops::handle::DataflowParentID; use itertools::{Either, Itertools}; use thiserror::Error; -use hugr_core::{Direction, Hugr, HugrView, Port, PortIndex, hugr::views::RootCheckable}; - -use crate::{Commit, PersistentReplacement, PinnedSubgraph}; - -use crate::{PersistentWire, PointerEqResolver, resolver::Resolver}; +use crate::{Direction, HugrView, Port}; use super::{CommitStateSpace, InvalidCommit, PatchNode, PersistentHugr, state_space::CommitId}; @@ -88,30 +84,30 @@ use super::{CommitStateSpace, InvalidCommit, PatchNode, PersistentHugr, state_sp /// expansions of the current walker. /// current walker. #[derive(Debug, Clone)] -pub struct Walker<'a, R: Clone = PointerEqResolver> { +pub struct Walker<'a> { /// The state space being traversed. - state_space: Cow<'a, CommitStateSpace>, + state_space: Cow<'a, CommitStateSpace>, /// The subset of compatible commits in `state_space` that are currently /// selected. // Note that we could store this as a set of `CommitId`s, but it is very // convenient to have access to all the methods of PersistentHugr (on top // of guaranteeing the compatibility invariant). The tradeoff is more // memory consumption. - selected_commits: PersistentHugr, + selected_commits: PersistentHugr, /// The set of nodes that have been traversed by the walker and can no /// longer be rewritten. pinned_nodes: BTreeSet, } -impl<'a, R: Resolver> Walker<'a, R> { +impl<'a> Walker<'a> { /// Create a new [`Walker`] over the given state space. /// /// No nodes are pinned initially. The [`Walker`] starts with only the base /// Hugr `state_space.base_hugr()` selected. 
- pub fn new(state_space: impl Into>>) -> Self { + pub fn new(state_space: impl Into>) -> Self { let state_space = state_space.into(); let base = state_space.base_commit().clone(); - let selected_commits: PersistentHugr = PersistentHugr::from_commit(base); + let selected_commits = PersistentHugr::from_commit(base); Self { state_space, selected_commits, @@ -122,7 +118,7 @@ impl<'a, R: Resolver> Walker<'a, R> { /// Create a new [`Walker`] with a single pinned node. pub fn from_pinned_node( node: PatchNode, - state_space: impl Into>>, + state_space: impl Into>, ) -> Self { let mut walker = Self::new(state_space); walker @@ -150,59 +146,68 @@ impl<'a, R: Resolver> Walker<'a, R> { } } else { let commit = self.state_space.get_commit(commit_id).clone(); - self.try_select_commit(commit)?; + // TODO/Optimize: we should be able to check for an AlreadyPinned error at + // the same time that we check the ancestors are compatible in + // `PersistentHugr`, with e.g. a callback, instead of storing a backup + let backup = self.selected_commits.clone(); + self.selected_commits.try_add_commit(commit)?; + if let Some(&pinned_node) = self + .pinned_nodes + .iter() + .find(|&&n| !self.selected_commits.contains_node(n)) + { + self.selected_commits = backup; + return Err(PinNodeError::AlreadyPinned(pinned_node)); + } } Ok(self.pinned_nodes.insert(node)) } - /// Add a commit to the selected commits of the Walker. + /// Get the wire connected to a specified port of a pinned node. /// - /// Return the ID of the added commit if it was added successfully, or the - /// existing ID of the commit if it was already selected. + /// # Panics + /// Panics if `node` is not already pinned in this Walker. + pub fn get_wire(&self, node: PatchNode, port: impl Into) -> PinnedWire { + PinnedWire::from_pinned_port(node, port, self) + } + + /// Materialise the [`PersistentHugr`] containing all the compatible commits + /// that have been selected during exploration. + pub fn into_persistent_hugr(self) -> PersistentHugr { + self.selected_commits + } + + /// View the [`PersistentHugr`] containing all the compatible commits that + /// have been selected so far during exploration. /// - /// Return an error if the commit is not compatible with the current set of - /// selected commits, or if the commit deletes an already pinned node. - pub fn try_select_commit(&mut self, commit: Commit) -> Result { - // TODO: we should be able to check for an AlreadyPinned error at - // the same time that we check the ancestors are compatible in - // `PersistentHugr`, with e.g. a callback, instead of storing a backup - let backup = self.selected_commits.clone(); - let commit_id = self.selected_commits.try_add_commit(commit)?; - if let Some(&pinned_node) = self - .pinned_nodes - .iter() - .find(|&&n| !self.selected_commits.contains_node(n)) - { - self.selected_commits = backup; - return Err(PinNodeError::AlreadyPinned(pinned_node)); - } - Ok(commit_id) + /// Of the space of all possible HUGRs that can be obtained from future + /// expansions of the walker, this is the HUGR corresponding to selecting + /// as few commits as possible (i.e. all the commits that have been selected + /// so far and no more). + pub fn as_hugr_view(&self) -> &PersistentHugr { + &self.selected_commits } /// Expand the Walker by pinning a node connected to the given wire. 
/// /// To understand how Walkers are expanded, it is useful to understand how /// in a walker, the HUGR graph is partitioned into two parts: - /// - a subgraph made of pinned nodes: this part of the HUGR is frozen: it - /// cannot be modified by further expansions the Walker. + /// - a subgraph made of pinned nodes: this part of the HUGR is frozen: it cannot be + /// modified by further expansions the Walker. /// - the complement subgraph: the unpinned part of the HUGR has not been - /// explored yet. Multiple alternative HUGRs can be obtained depending on - /// which commits are selected. + /// explored yet. Multiple alternative HUGRs can be obtained depending + /// on which commits are selected. /// /// To every walker thus corresponds a space of possible HUGRs that can be - /// obtained, depending on which commits are selected and which further - /// nodes are pinned. The expansion of a walker returns a set of - /// walkers, which together cover the same space of possible HUGRs, each - /// having a different additional node pinned. + /// obtained, depending on which commits are selected and which further nodes + /// are pinned. The expansion of a walker returns a set of walkers, which + /// together cover the same space of possible HUGRs, each having a different + /// additional node pinned. /// - /// If the wire is not complete yet, return an iterator over all possible - /// [`Walker`]s that can be created by pinning exactly one additional - /// node (or one additonal commit with an empty wire) connected to - /// `wire`. Each returned [`Walker`] represents a different alternative - /// Hugr in the exploration space. - /// - /// If the wire is already complete, return an iterator containing one - /// walker: the current walker unchanged. + /// Return an iterator over all possible [`Walker`]s that can be created by + /// pinning exactly one additional node connected to `wire`. Each returned + /// [`Walker`] represents a different alternative Hugr in the exploration + /// space. /// /// Optionally, the expansion can be restricted to only ports with the given /// direction (incoming or outgoing). @@ -214,221 +219,78 @@ impl<'a, R: Resolver> Walker<'a, R> { /// true, then an empty iterator is returned. pub fn expand<'b>( &'b self, - wire: &'b PersistentWire, + wire: &'b PinnedWire, dir: impl Into>, - ) -> impl Iterator> + 'b { + ) -> impl Iterator> + 'b { let dir = dir.into(); - if self.is_complete(wire, dir) { - return Either::Left(std::iter::once(self.clone())); - } - // Find unpinned ports on the wire (satisfying the direction constraint) - let unpinned_ports = self.wire_unpinned_ports(wire, dir); + let unpinned_ports = wire.unpinned_ports(dir); // Obtain set of pinnable nodes by considering all ports (in descendant // commits) equivalent to currently unpinned ports. let pinnable_nodes = unpinned_ports .flat_map(|(node, port)| self.equivalent_descendant_ports(node, port)) - .map(|(n, _, commits)| (n, commits)) + .map(|(n, _)| n) .unique(); - let new_walkers = pinnable_nodes.filter_map(|(pinnable_node, new_commits)| { - let contains_new_commit = || { - new_commits - .iter() - .any(|&cm| !self.selected_commits.contains_id(cm)) - }; + pinnable_nodes.filter_map(|pinnable_node| { debug_assert!( - !self.is_pinned(pinnable_node) || contains_new_commit(), - "trying to pin already pinned node and no new commit is selected" + !self.is_pinned(pinnable_node), + "trying to pin already pinned node" ); - // Update the selected commits to include the new commits. 
- let new_selected_commits = self - .state_space - .try_extract_hugr(self.selected_commits.all_commit_ids().chain(new_commits)) - .ok()?; - - // Make sure that the pinned nodes are still valid after including the new - // selected commits. - if self - .pinned_nodes - .iter() - .any(|&pnode| !new_selected_commits.contains_node(pnode)) - { - return None; - } - - // Construct a new walker and pin `pinnable_node`. - let mut new_walker = Walker { - state_space: self.state_space.clone(), - selected_commits: new_selected_commits, - pinned_nodes: self.pinned_nodes.clone(), - }; + // Construct a new walker by pinning `pinnable_node` (if possible). + let mut new_walker = self.clone(); new_walker.try_pin_node(pinnable_node).ok()?; Some(new_walker) - }); - - Either::Right(new_walkers) - } - - /// Create a new commit from a set of complete pinned wires and a - /// replacement. - /// - /// The subgraph of the commit is the subgraph given by the set of edges - /// in `wires`. `map_boundary` must provide a map from the boundary ports - /// of the subgraph to the inputs/output ports in `repl`. The returned port - /// must be of the opposite direction as the port passed as argument: - /// - an incoming subgraph port must be mapped to an outgoing port of the - /// input node of `repl` - /// - an outgoing subgraph port must be mapped to an incoming port of the - /// output node of `repl` - /// - /// ## Panics - /// - /// This will panic if repl is not a DFG graph. - pub fn try_create_commit( - &self, - subgraph: impl Into, - repl: impl RootCheckable, - map_boundary: impl Fn(PatchNode, Port) -> Port, - ) -> Result { - let pinned_subgraph = subgraph.into(); - let subgraph = pinned_subgraph.to_sibling_subgraph(self.as_hugr_view())?; - let selected_commits = pinned_subgraph - .selected_commits() - .map(|id| self.state_space.get_commit(id).clone()); - - let repl = { - let mut repl = repl.try_into_checked().expect("replacement is not DFG"); - let new_inputs = subgraph - .incoming_ports() - .iter() - .flatten() // because of singleton-vec wrapping above - .map(|&(n, p)| { - map_boundary(n, p.into()) - .as_outgoing() - .expect("unexpected port direction returned by map_boundary") - .index() - }) - .collect_vec(); - let new_outputs = subgraph - .outgoing_ports() - .iter() - .map(|&(n, p)| { - map_boundary(n, p.into()) - .as_incoming() - .expect("unexpected port direction returned by map_boundary") - .index() - }) - .collect_vec(); - repl.map_function_type(&new_inputs, &new_outputs)?; - PersistentReplacement::try_new(subgraph, self.as_hugr_view(), repl.into_hugr())? - }; - - Commit::try_new(repl, selected_commits, &self.state_space) - } -} - -impl Walker<'_, R> { - /// Get the wire connected to a specified port of a pinned node. - /// - /// # Panics - /// Panics if `node` is not already pinned in this Walker. - pub fn get_wire(&self, node: PatchNode, port: impl Into) -> PersistentWire { - assert!(self.is_pinned(node), "node must be pinned"); - self.selected_commits.get_wire(node, port) - } - - /// Materialise the [`PersistentHugr`] containing all the compatible commits - /// that have been selected during exploration. - pub fn into_persistent_hugr(self) -> PersistentHugr { - self.selected_commits - } - - /// View the [`PersistentHugr`] containing all the compatible commits that - /// have been selected so far during exploration. 
- /// - /// Of the space of all possible HUGRs that can be obtained from future - /// expansions of the walker, this is the HUGR corresponding to selecting - /// as few commits as possible (i.e. all the commits that have been selected - /// so far and no more). - pub fn as_hugr_view(&self) -> &PersistentHugr { - &self.selected_commits - } - - /// Check if a node is pinned in the [`Walker`]. - pub fn is_pinned(&self, node: PatchNode) -> bool { - self.pinned_nodes.contains(&node) - } - - /// Iterate over all pinned nodes in the [`Walker`]. - pub fn pinned_nodes(&self) -> impl Iterator + '_ { - self.pinned_nodes.iter().copied() + }) } /// Get all equivalent ports among the commits that are descendants of the /// current commit. /// /// The ports in the returned iterator will be in the same direction as - /// `port`. For each equivalent port, also return the set of empty commits - /// that were visited to find it. - fn equivalent_descendant_ports( - &self, - node: PatchNode, - port: Port, - ) -> Vec<(PatchNode, Port, BTreeSet)> { + /// `port`. + fn equivalent_descendant_ports(&self, node: PatchNode, port: Port) -> Vec<(PatchNode, Port)> { // Now, perform a BFS to find all equivalent ports - let mut all_ports = vec![(node, port, BTreeSet::new())]; + let mut all_ports = vec![(node, port)]; let mut index = 0; while index < all_ports.len() { - let (node, port, empty_commits) = all_ports[index].clone(); + let (node, port) = all_ports[index]; index += 1; for (child_id, (opp_node, opp_port)) in self.state_space.children_at_boundary_port(node, port) { - for (node, port) in self.state_space.linked_child_ports( - opp_node, - opp_port, - child_id, - BoundaryMode::SnapToHost, - ) { - let mut empty_commits = empty_commits.clone(); - if node.0 != child_id { - empty_commits.insert(child_id); + match opp_port.as_directed() { + Either::Left(in_port) => { + if let Some((n, p)) = self + .state_space + .linked_child_output(opp_node, in_port, child_id) + { + all_ports.push((n, p.into())); + } + } + Either::Right(out_port) => { + all_ports.extend( + self.state_space + .linked_child_inputs(opp_node, out_port, child_id) + .map(|(n, p)| (n, p.into())), + ); } - all_ports.push((node, port, empty_commits)); } } } all_ports } -} - -#[cfg(test)] -impl Walker<'_, R> { - // Check walker equality by comparing pointers to the state space and - // other fields. Only for testing purposes. - fn component_wise_ptr_eq(&self, other: &Self) -> bool { - std::ptr::eq(self.state_space.as_ref(), other.state_space.as_ref()) - && self.pinned_nodes == other.pinned_nodes - && BTreeSet::from_iter(self.selected_commits.all_commit_ids()) - == BTreeSet::from_iter(other.selected_commits.all_commit_ids()) - } - /// Check if the Walker cannot be expanded further, i.e. expanding it - /// returns the same Walker. - fn no_more_expansion(&self, wire: &PersistentWire, dir: impl Into>) -> bool { - let Some([new_walker]) = self.expand(wire, dir).collect_array() else { - return false; - }; - new_walker.component_wise_ptr_eq(self) + fn is_pinned(&self, node: PatchNode) -> bool { + self.pinned_nodes.contains(&node) } } -impl CommitStateSpace { +impl CommitStateSpace { /// Given a node and port, return all child commits of the current `node` /// that delete `node` but keep at least one port linked to `(node, port)`. 
/// In other words, (node, port) is a boundary port of the subgraph of @@ -487,37 +349,27 @@ impl From for PinNodeError { } } -impl<'a, R: Clone> From<&'a CommitStateSpace> for Cow<'a, CommitStateSpace> { - fn from(value: &'a CommitStateSpace) -> Self { +impl<'a> From<&'a CommitStateSpace> for Cow<'a, CommitStateSpace> { + fn from(value: &'a CommitStateSpace) -> Self { Cow::Borrowed(value) } } -impl From> for Cow<'_, CommitStateSpace> { - fn from(value: CommitStateSpace) -> Self { +impl From for Cow<'_, CommitStateSpace> { + fn from(value: CommitStateSpace) -> Self { Cow::Owned(value) } } #[cfg(test)] mod tests { - use std::collections::BTreeSet; - - use hugr_core::{ - Direction, HugrView, IncomingPort, OutgoingPort, - builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig}, - extension::prelude::bool_t, - std_extensions::logic::LogicOp, - }; - use itertools::Itertools; use rstest::rstest; + use crate::hugr::persistent::{state_space::CommitId, tests::test_state_space}; + use crate::std_extensions::logic::LogicOp; + use crate::{IncomingPort, OutgoingPort}; + use super::*; - use crate::{ - PersistentHugr, Walker, - state_space::CommitId, - tests::{persistent_hugr_empty_child, test_state_space}, - }; #[rstest] fn test_walker_base_or_child_expansion(test_state_space: (CommitStateSpace, [CommitId; 4])) { @@ -538,8 +390,7 @@ mod tests { let in0 = walker.get_wire(base_and_node, IncomingPort::from(0)); // a single incoming port (already pinned) => no more expansion - assert!(walker.no_more_expansion(&in0, Direction::Incoming)); - + assert!(walker.expand(&in0, Direction::Incoming).next().is_none()); // commit 2 cannot be applied, because AND is pinned // => only base commit, or commit1 let out_walkers = walker.expand(&in0, Direction::Outgoing).collect_vec(); @@ -547,11 +398,11 @@ mod tests { for new_walker in out_walkers { // new wire is complete (and thus cannot be expanded) let in0 = new_walker.get_wire(base_and_node, IncomingPort::from(0)); - assert!(new_walker.is_complete(&in0, None)); - assert!(new_walker.no_more_expansion(&in0, None)); + assert!(in0.is_complete(None)); + assert!(new_walker.expand(&in0, None).next().is_none()); // all nodes on wire are pinned - let (not_node, _) = in0.single_outgoing_port(new_walker.as_hugr_view()).unwrap(); + let (not_node, _) = in0.pinned_outport().unwrap(); assert!(new_walker.is_pinned(base_and_node)); assert!(new_walker.is_pinned(not_node)); @@ -605,8 +456,9 @@ mod tests { assert!(walker.is_pinned(not4_node)); let not4_out = walker.get_wire(not4_node, OutgoingPort::from(0)); + let expanded_out = walker.expand(¬4_out, Direction::Outgoing).collect_vec(); // a single outgoing port (already pinned) => no more expansion - assert!(walker.no_more_expansion(¬4_out, Direction::Outgoing)); + assert!(expanded_out.is_empty()); // Three options: // - AND gate from base @@ -625,20 +477,17 @@ mod tests { .collect::>(); assert!( exp_options.remove(&commit_ids), - "{commit_ids:?} not an expected set of commit IDs (or duplicate)" + "{:?} not an expected set of commit IDs (or duplicate)", + commit_ids ); // new wire is complete (and thus cannot be expanded) let not4_out = new_walker.get_wire(not4_node, OutgoingPort::from(0)); - assert!(new_walker.is_complete(¬4_out, None)); - assert!(new_walker.no_more_expansion(¬4_out, None)); + assert!(not4_out.is_complete(None)); + assert!(new_walker.expand(¬4_out, None).next().is_none()); // all nodes on wire are pinned - let (next_node, _) = not4_out - .all_incoming_ports(new_walker.as_hugr_view()) - .exactly_one() - .ok() - .unwrap(); 
+ let (next_node, _) = not4_out.pinned_inports().exactly_one().ok().unwrap(); assert!(new_walker.is_pinned(not4_node)); assert!(new_walker.is_pinned(next_node)); @@ -659,7 +508,8 @@ mod tests { assert!( exp_options.is_empty(), - "missing expected options: {exp_options:?}" + "missing expected options: {:?}", + exp_options ); } @@ -677,7 +527,7 @@ mod tests { let hugr = state_space.try_extract_hugr([commit4]).unwrap(); let (second_not_node, out_port) = - hugr.single_outgoing_port(base_and_node, IncomingPort::from(1)); + hugr.get_single_outgoing_port(base_and_node, IncomingPort::from(1)); assert_eq!(second_not_node.0, commit4); assert_eq!(out_port, OutgoingPort::from(0)); @@ -685,153 +535,11 @@ mod tests { .try_extract_hugr([commit1, commit2, commit4]) .unwrap(); let (new_and_node, in_port) = hugr - .all_incoming_ports(second_not_node, out_port) + .get_all_incoming_ports(second_not_node, out_port) .exactly_one() .ok() .unwrap(); assert_eq!(new_and_node.0, commit2); assert_eq!(in_port, 1.into()); } - - /// Test that the walker handles empty replacements correctly. - /// - /// The base hugr is a sequence of 3 NOT gates, with a single input/output - /// boolean. A single replacement exists in the state space, which replaces - /// the middle NOT gate with nothing. - #[rstest] - fn test_walk_over_empty_repls( - persistent_hugr_empty_child: (PersistentHugr, [CommitId; 2], [PatchNode; 3]), - ) { - let (hugr, [base_commit, empty_commit], [not0, not1, not2]) = persistent_hugr_empty_child; - let walker = Walker::from_pinned_node(not0, hugr.as_state_space()); - - let not0_outwire = walker.get_wire(not0, OutgoingPort::from(0)); - let expanded_wires = walker - .expand(¬0_outwire, Direction::Incoming) - .collect_vec(); - - assert_eq!(expanded_wires.len(), 2); - - let connected_inports: BTreeSet<_> = expanded_wires - .iter() - .map(|new_walker| { - let wire = new_walker.get_wire(not0, OutgoingPort::from(0)); - wire.all_incoming_ports(new_walker.as_hugr_view()) - .exactly_one() - .ok() - .unwrap() - }) - .collect(); - - assert_eq!( - connected_inports, - BTreeSet::from_iter([(not1, IncomingPort::from(0)), (not2, IncomingPort::from(0))]) - ); - - let traversed_commits: BTreeSet> = expanded_wires - .iter() - .map(|new_walker| { - let wire = new_walker.get_wire(not0, OutgoingPort::from(0)); - wire.owners().collect() - }) - .collect(); - - assert_eq!( - traversed_commits, - BTreeSet::from_iter([ - BTreeSet::from_iter([base_commit]), - BTreeSet::from_iter([base_commit, empty_commit]) - ]) - ); - } - - #[rstest] - fn test_create_commit_over_empty( - persistent_hugr_empty_child: (PersistentHugr, [CommitId; 2], [PatchNode; 3]), - ) { - let (hugr, [base_commit, empty_commit], [not0, _not1, not2]) = persistent_hugr_empty_child; - let mut walker = Walker { - state_space: hugr.as_state_space().into(), - selected_commits: hugr.clone(), - pinned_nodes: BTreeSet::from_iter([not0]), - }; - - // wire: Not0 -> Not2 (bridging over Not1) - let wire = walker.get_wire(not0, OutgoingPort::from(0)); - walker = walker.expand(&wire, None).exactly_one().ok().unwrap(); - let wire = walker.get_wire(not0, OutgoingPort::from(0)); - assert!(walker.is_complete(&wire, None)); - - let empty_hugr = { - let dfg_builder = DFGBuilder::new(endo_sig(bool_t())).unwrap(); - let inputs = dfg_builder.input_wires(); - dfg_builder.finish_hugr_with_outputs(inputs).unwrap() - }; - let commit = walker - .try_create_commit( - PinnedSubgraph::try_from_pinned(std::iter::empty(), [wire], &walker).unwrap(), - empty_hugr, - |node, port| { - 
assert_eq!(port.index(), 0); - assert!([not0, not2].contains(&node)); - match port.direction() { - Direction::Incoming => OutgoingPort::from(0).into(), - Direction::Outgoing => IncomingPort::from(0).into(), - } - }, - ) - .unwrap(); - - let mut new_state_space = hugr.as_state_space().to_owned(); - let commit_id = new_state_space.try_add_commit(commit.clone()).unwrap(); - assert_eq!( - new_state_space.parents(commit_id).collect::>(), - BTreeSet::from_iter([base_commit, empty_commit]) - ); - - let res_hugr: PersistentHugr = PersistentHugr::from_commit(commit); - assert!(res_hugr.validate().is_ok()); - - // should be an empty DFG hugr - // module root + function def + func I/O nodes + DFG entrypoint + I/O nodes - assert_eq!(res_hugr.num_nodes(), 1 + 1 + 2 + 1 + 2); - } - - /// Test that the walker handles empty replacements correctly. - /// - /// The base hugr is a sequence of 3 NOT gates, with a single input/output - /// boolean. A single replacement exists in the state space, which replaces - /// the middle NOT gate with nothing. - /// - /// In this test, we pin both the first and third NOT and see if the walker - /// suggests to possible wires as outgoing from the first NOT. This tests - /// the edge case in which a new wire already has all its ports pinned. - #[rstest] - fn test_walk_over_two_pinned_nodes( - persistent_hugr_empty_child: (PersistentHugr, [CommitId; 2], [PatchNode; 3]), - ) { - let (hugr, [base_commit, empty_commit], [not0, _not1, not2]) = persistent_hugr_empty_child; - let mut walker = Walker::from_pinned_node(not0, hugr.as_state_space()); - assert!(walker.try_pin_node(not2).unwrap()); - - let not0_outwire = walker.get_wire(not0, OutgoingPort::from(0)); - let expanded_walkers = walker.expand(¬0_outwire, Direction::Incoming); - - let expanded_wires: BTreeSet> = expanded_walkers - .map(|new_walker| { - new_walker - .get_wire(not0, OutgoingPort::from(0)) - .owners() - .collect() - }) - .collect(); - - assert_eq!( - expanded_wires, - BTreeSet::from_iter([ - BTreeSet::from_iter([base_commit]), - BTreeSet::from_iter([base_commit, empty_commit]) - ]) - ); - } } diff --git a/hugr-core/src/hugr/persistent/walker/pinned.rs b/hugr-core/src/hugr/persistent/walker/pinned.rs new file mode 100644 index 0000000000..02c4d4dcad --- /dev/null +++ b/hugr-core/src/hugr/persistent/walker/pinned.rs @@ -0,0 +1,164 @@ +//! Utilities for pinned ports and pinned wires. +//! +//! Encapsulation: we only ever expose pinned values publicly. + +use itertools::Either; + +use crate::{Direction, IncomingPort, OutgoingPort, Port, hugr::persistent::PatchNode}; + +use super::Walker; + +/// A wire in the current HUGR of a [`Walker`] with some of its endpoints +/// pinned. +/// +/// Just like a normal HUGR [`Wire`](crate::Wire), a [`PinnedWire`] has +/// endpoints: the ports that are linked together by the wire. A [`PinnedWire`] +/// however distinguishes itself in that each of its ports is specified either +/// as "pinned" or "unpinned". A port is pinned if and only if the node it is +/// attached to is pinned in the walker. +/// +/// A [`PinnedWire`] always has at least one pinned port. +/// +/// All pinned ports of a [`PinnedWire`] can be retrieved using +/// [`PinnedWire::pinned_inports`] and [`PinnedWire::pinned_outport`]. Unpinned +/// ports, on the other hand, represent undetermined connections, which may +/// still change as the walker is expanded (see [`Walker::expand`]). +/// +/// Whether all incoming or outgoing ports are pinned can be checked using +/// [`PinnedWire::is_complete`]. 
+#[derive(Debug, Clone)]
+pub struct PinnedWire {
+    outgoing: MaybePinned<OutgoingPort>,
+    incoming: Vec<MaybePinned<IncomingPort>>,
+}
+
+/// A private enum to track whether a port is pinned.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+enum MaybePinned<P> {
+    Pinned(PatchNode, P),
+    Unpinned(PatchNode, P),
+}
+
+impl<P> MaybePinned<P> {
+    fn new(node: PatchNode, port: P, walker: &Walker) -> Self {
+        debug_assert!(
+            walker.selected_commits.contains_node(node),
+            "pinned node not in walker"
+        );
+        if walker.is_pinned(node) {
+            MaybePinned::Pinned(node, port)
+        } else {
+            MaybePinned::Unpinned(node, port)
+        }
+    }
+
+    fn is_pinned(&self) -> bool {
+        matches!(self, MaybePinned::Pinned(_, _))
+    }
+
+    fn into_unpinned<PP: From<P>>(self) -> Option<(PatchNode, PP)> {
+        match self {
+            MaybePinned::Pinned(_, _) => None,
+            MaybePinned::Unpinned(node, port) => Some((node, port.into())),
+        }
+    }
+
+    fn into_pinned<PP: From<P>>(self) -> Option<(PatchNode, PP)> {
+        match self {
+            MaybePinned::Pinned(node, port) => Some((node, port.into())),
+            MaybePinned::Unpinned(_, _) => None,
+        }
+    }
+}
+
+impl PinnedWire {
+    /// Create a new pinned wire in `walker` from a pinned node and a port.
+    ///
+    /// # Panics
+    /// Panics if `node` is not pinned in `walker`.
+    pub fn from_pinned_port(node: PatchNode, port: impl Into<Port>, walker: &Walker) -> Self {
+        assert!(walker.is_pinned(node), "node must be pinned");
+
+        let (outgoing_node, outgoing_port) = match port.into().as_directed() {
+            Either::Left(incoming) => walker
+                .selected_commits
+                .get_single_outgoing_port(node, incoming),
+            Either::Right(outgoing) => (node, outgoing),
+        };
+
+        let outgoing = MaybePinned::new(outgoing_node, outgoing_port, walker);
+
+        let incoming = walker
+            .selected_commits
+            .get_all_incoming_ports(outgoing_node, outgoing_port)
+            .map(|(n, p)| MaybePinned::new(n, p, walker))
+            .collect();
+
+        Self { outgoing, incoming }
+    }
+
+    /// Check if all ports on the wire in the given direction are pinned.
+    ///
+    /// A wire is complete in a direction if and only if expanding the wire
+    /// in that direction would yield no new walkers. If no direction is
+    /// specified, checks if the wire is complete in both directions.
+    pub fn is_complete(&self, dir: impl Into<Option<Direction>>) -> bool {
+        match dir.into() {
+            Some(Direction::Outgoing) => self.outgoing.is_pinned(),
+            Some(Direction::Incoming) => self.incoming.iter().all(|p| p.is_pinned()),
+            None => self.outgoing.is_pinned() && self.incoming.iter().all(|p| p.is_pinned()),
+        }
+    }
+
+    /// Get the outgoing port of the wire, if it is pinned.
+    ///
+    /// Returns `None` if the outgoing port is not pinned.
+    pub fn pinned_outport(&self) -> Option<(PatchNode, OutgoingPort)> {
+        self.outgoing.into_pinned()
+    }
+
+    /// Get all pinned incoming ports of the wire.
+    ///
+    /// Returns an iterator over all pinned incoming ports.
+    pub fn pinned_inports(&self) -> impl Iterator<Item = (PatchNode, IncomingPort)> + '_ {
+        self.incoming.iter().filter_map(|&p| p.into_pinned())
+    }
+
+    /// Get all pinned ports of the wire.
+    pub fn all_pinned_ports(&self) -> impl Iterator<Item = (PatchNode, Port)> + '_ {
+        fn to_port((node, port): (PatchNode, impl Into<Port>)) -> (PatchNode, Port) {
+            (node, port.into())
+        }
+        self.pinned_outport()
+            .into_iter()
+            .map(to_port)
+            .chain(self.pinned_inports().map(to_port))
+    }
+
+    /// Get all unpinned ports of the wire, optionally filtering to only those
+    /// in the given direction.
+    pub(super) fn unpinned_ports(
+        &self,
+        dir: impl Into<Option<Direction>>,
+    ) -> impl Iterator<Item = (PatchNode, Port)> + '_ {
+        let incoming = self
+            .incoming
+            .iter()
+            .filter_map(|p| p.into_unpinned::<Port>());
+        let outgoing = self.outgoing.into_unpinned::<Port>();
+        let dir = dir.into();
+        mask_iter(incoming, dir != Some(Direction::Outgoing))
+            .chain(mask_iter(outgoing, dir != Some(Direction::Incoming)))
+    }
+}
+
+/// Return an iterator over the items in `iter` if `mask` is true, otherwise
+/// return an empty iterator. 
-#[inline] -fn mask_iter(iter: impl IntoIterator, mask: bool) -> impl Iterator { - match mask { - true => Either::Left(iter.into_iter()), - false => Either::Right(std::iter::empty()), - } - .into_iter() -} diff --git a/hugr-core/src/hugr/serialize.rs b/hugr-core/src/hugr/serialize.rs index 7dcb14c1c1..8df7cf8357 100644 --- a/hugr-core/src/hugr/serialize.rs +++ b/hugr-core/src/hugr/serialize.rs @@ -78,6 +78,7 @@ impl Versioned { #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)] struct NodeSer { + /// Node index of the parent. parent: Node, #[serde(flatten)] op: OpType, @@ -89,7 +90,7 @@ struct SerHugrLatest { /// For each node: (parent, `node_operation`) nodes: Vec, /// for each edge: (src, `src_offset`, tgt, `tgt_offset`) - edges: Vec<[(Node, Option); 2]>, + edges: Vec<[(Node, Option); 2]>, /// for each node: (metadata) #[serde(default)] metadata: Option>>, @@ -113,7 +114,7 @@ pub enum HUGRSerializationError { AttachError(#[from] AttachError), /// Failed to add edge. #[error("Failed to build edge when deserializing: {0}.")] - LinkError(#[from] LinkError), + LinkError(#[from] LinkError), /// Edges without port offsets cannot be present in operations without non-dataflow ports. #[error( "Cannot connect an {dir:?} edge without port offset to node {node} with operation type {op_type}." @@ -214,7 +215,7 @@ impl TryFrom<&Hugr> for SerHugrLatest { let op = hugr.get_optype(node); let is_value_port = offset < op.value_port_count(dir); let is_static_input = op.static_port(dir).is_some_and(|p| p.index() == offset); - let offset = (is_value_port || is_static_input).then_some(offset as u16); + let offset = (is_value_port || is_static_input).then_some(offset as u32); (node_rekey[&node], offset) }; @@ -282,7 +283,7 @@ impl TryFrom for Hugr { } if let Some(entrypoint) = entrypoint { - hugr.set_entrypoint(entrypoint); + hugr.set_entrypoint(hugr_node(entrypoint)); } if let Some(metadata) = metadata { diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 2b500ed038..60249489df 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -24,7 +24,7 @@ use crate::types::{ FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeArg, TypeBound, TypeRV, }; -use crate::{OutgoingPort, type_row}; +use crate::{OutgoingPort, Visibility, type_row}; use itertools::Itertools; use jsonschema::{Draft, Validator}; @@ -62,26 +62,29 @@ impl NamedSchema { Self { name, schema } } - pub fn check(&self, val: &serde_json::Value) { + pub fn check(&self, val: &serde_json::Value) -> Result<(), String> { let mut errors = self.schema.iter_errors(val).peekable(); - if errors.peek().is_some() { - // errors don't necessarily implement Debug - eprintln!("Schema failed to validate: {}", self.name); - for error in errors { - eprintln!("Validation error: {error}"); - eprintln!("Instance path: {}", error.instance_path); - } - panic!("Serialization test failed."); + if errors.peek().is_none() { + return Ok(()); } + + // errors don't necessarily implement Debug + let mut strs = vec![format!("Schema failed to validate: {}", self.name)]; + strs.extend(errors.flat_map(|error| { + [ + format!("Validation error: {error}"), + format!("Instance path: {}", error.instance_path), + ] + })); + strs.push("Serialization test failed.".to_string()); + Err(strs.join("\n")) } pub fn check_schemas( val: &serde_json::Value, schemas: impl IntoIterator, - ) { - for schema in schemas { - schema.check(val); - } + ) -> Result<(), String> { + 
schemas.into_iter().try_for_each(|schema| schema.check(val)) } } @@ -89,7 +92,7 @@ macro_rules! include_schema { ($name:ident, $path:literal) => { lazy_static! { static ref $name: NamedSchema = - NamedSchema::new("$name", { + NamedSchema::new(stringify!($name), { let schema_val: serde_json::Value = serde_json::from_str(include_str!( concat!("../../../../specification/schema/", $path, "_live.json") )) @@ -161,7 +164,7 @@ fn ser_deserialize_check_schema( val: serde_json::Value, schemas: impl IntoIterator, ) -> T { - NamedSchema::check_schemas(&val, schemas); + NamedSchema::check_schemas(&val, schemas).unwrap(); serde_json::from_value(val).unwrap() } @@ -171,8 +174,22 @@ fn ser_roundtrip_check_schema, ) -> TDeser { let val = serde_json::to_value(g).unwrap(); - NamedSchema::check_schemas(&val, schemas); - serde_json::from_value(val).unwrap() + match NamedSchema::check_schemas(&val, schemas) { + Ok(()) => serde_json::from_value(val).unwrap(), + Err(msg) => panic!("ser_roundtrip_check_schema failed with {msg}, input was {val}"), + } +} + +/// Serialize a Hugr and check that it is valid against the schema. +/// +/// # Panics +/// +/// Panics if the serialization fails or if the schema validation fails. +pub(crate) fn check_hugr_serialization_schema(hugr: &Hugr) { + let schemas = get_schemas(true); + let hugr_ser = HugrSer(hugr); + let val = serde_json::to_value(hugr_ser).unwrap(); + NamedSchema::check_schemas(&val, schemas).unwrap(); } /// Serialize and deserialize a HUGR, and check that the result is the same as the original. @@ -210,8 +227,80 @@ fn check_testing_roundtrip(t: impl Into) { assert_eq!(before, after); } +fn test_schema_val() -> serde_json::Value { + serde_json::json!({ + "op_def":null, + "optype":{ + "name":"polyfunc1", + "op":"FuncDefn", + "visibility": "Public", + "parent":0, + "signature":{ + "body":{ + "input":[], + "output":[] + }, + "params":[ + {"bound":null,"tp":"BoundedNat"} + ] + } + }, + "poly_func_type":null, + "sum_type":null, + "typ":null, + "value":null, + "version":"live" + }) +} + +fn schema_val() -> serde_json::Value { + serde_json::json!({"nodes": [], "edges": [], "version": "live"}) +} + +#[rstest] +#[case(&TESTING_SCHEMA, &TESTING_SCHEMA_STRICT, test_schema_val(), Some("optype"))] +#[case(&SCHEMA, &SCHEMA_STRICT, schema_val(), None)] +fn wrong_fields( + #[case] lax_schema: &'static NamedSchema, + #[case] strict_schema: &'static NamedSchema, + #[case] mut val: serde_json::Value, + #[case] target_loc: impl IntoIterator + Clone, +) { + use serde_json::Value; + fn get_fields( + val: &mut Value, + mut path: impl Iterator, + ) -> &mut serde_json::Map { + let Value::Object(fields) = val else { panic!() }; + match path.next() { + Some(n) => get_fields(fields.get_mut(n).unwrap(), path), + None => fields, + } + } + // First, some "known good" JSON + NamedSchema::check_schemas(&val, [lax_schema, strict_schema]).unwrap(); + + // Now try adding an extra field + let fields = get_fields(&mut val, target_loc.clone().into_iter()); + fields.insert( + "extra_field".to_string(), + Value::String("not in schema".to_string()), + ); + strict_schema.check(&val).unwrap_err(); + lax_schema.check(&val).unwrap(); + + // And removing one + let fields = get_fields(&mut val, target_loc.into_iter()); + fields.remove("extra_field").unwrap(); + let key = fields.keys().next().unwrap().clone(); + fields.remove(&key).unwrap(); + + lax_schema.check(&val).unwrap_err(); + strict_schema.check(&val).unwrap_err(); +} + /// Generate an optype for a node with a matching amount of inputs and outputs. 
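In the test changes above, `NamedSchema::check` stops panicking and instead returns `Result<(), String>`, letting `check_schemas` chain the individual checks with `try_for_each` while callers decide whether to `unwrap` or surface the message. The standalone sketch below mirrors only that control-flow shape; `NamedCheck` and `check_all` are made-up stand-ins, not the hugr test helpers or the `jsonschema` API.

// Hypothetical stand-in for a named validator; the real code wraps a
// `jsonschema::Validator` and joins its `iter_errors` output into the message.
struct NamedCheck {
    name: &'static str,
    ok: bool,
}

impl NamedCheck {
    fn check(&self) -> Result<(), String> {
        if self.ok {
            Ok(())
        } else {
            Err(format!("Schema failed to validate: {}", self.name))
        }
    }
}

/// Run every check, stopping at the first failure (same shape as `check_schemas`).
fn check_all<'a>(checks: impl IntoIterator<Item = &'a NamedCheck>) -> Result<(), String> {
    checks.into_iter().try_for_each(|c| c.check())
}

fn main() {
    let checks = [
        NamedCheck { name: "lax", ok: true },
        NamedCheck { name: "strict", ok: false },
    ];
    // The strict check fails, so the chained result is an Err carrying its message.
    assert!(check_all(&checks).is_err());
}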
-fn gen_optype(g: &MultiPortGraph, node: portgraph::NodeIndex) -> OpType { +fn gen_optype(g: &MultiPortGraph, node: portgraph::NodeIndex) -> OpType { let inputs = g.num_inputs(node); let outputs = g.num_outputs(node); match (inputs == 0, outputs == 0) { @@ -428,7 +517,7 @@ fn serialize_types_roundtrip() { #[case(bool_t())] #[case(usize_t())] #[case(INT_TYPES[2].clone())] -#[case(Type::new_alias(crate::ops::AliasDecl::new("t", TypeBound::Any)))] +#[case(Type::new_alias(crate::ops::AliasDecl::new("t", TypeBound::Linear)))] #[case(Type::new_var_use(2, TypeBound::Copyable))] #[case(Type::new_tuple(vec![bool_t(),qb_t()]))] #[case(Type::new_sum([vec![bool_t(),qb_t()], vec![Type::new_unit_sum(4)]]))] @@ -458,13 +547,13 @@ fn roundtrip_value(#[case] value: Value) { fn polyfunctype1() -> PolyFuncType { let function_type = Signature::new_endo(type_row![]); - PolyFuncType::new([TypeParam::max_nat()], function_type) + PolyFuncType::new([TypeParam::max_nat_type()], function_type) } fn polyfunctype2() -> PolyFuncTypeRV { - let tv0 = TypeRV::new_row_var_use(0, TypeBound::Any); + let tv0 = TypeRV::new_row_var_use(0, TypeBound::Linear); let tv1 = TypeRV::new_row_var_use(1, TypeBound::Copyable); - let params = [TypeBound::Any, TypeBound::Copyable].map(TypeParam::new_list); + let params = [TypeBound::Linear, TypeBound::Copyable].map(TypeParam::new_list_type); let inputs = vec![ TypeRV::new_function(FuncValueType::new(tv0.clone(), tv1.clone())), tv0, @@ -479,26 +568,26 @@ fn polyfunctype2() -> PolyFuncTypeRV { #[rstest] #[case(Signature::new_endo(type_row![]).into())] #[case(polyfunctype1())] -#[case(PolyFuncType::new([TypeParam::String], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(PolyFuncType::new([TypeParam::StringType], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] #[case(PolyFuncType::new([TypeBound::Copyable.into()], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] -#[case(PolyFuncType::new([TypeParam::new_list(TypeBound::Any)], Signature::new_endo(type_row![])))] -#[case(PolyFuncType::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], Signature::new_endo(type_row![])))] +#[case(PolyFuncType::new([TypeParam::new_list_type(TypeBound::Linear)], Signature::new_endo(type_row![])))] +#[case(PolyFuncType::new([TypeParam::new_tuple_type([TypeBound::Linear.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())])], Signature::new_endo(type_row![])))] #[case(PolyFuncType::new( - [TypeParam::new_list(TypeBound::Any)], - Signature::new_endo(Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Any)))))] + [TypeParam::new_list_type(TypeBound::Linear)], + Signature::new_endo(Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Linear)))))] fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { check_testing_roundtrip(poly_func_type); } #[rstest] #[case(FuncValueType::new_endo(type_row![]).into())] -#[case(PolyFuncTypeRV::new([TypeParam::String], FuncValueType::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(PolyFuncTypeRV::new([TypeParam::StringType], FuncValueType::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] #[case(PolyFuncTypeRV::new([TypeBound::Copyable.into()], FuncValueType::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] -#[case(PolyFuncTypeRV::new([TypeParam::new_list(TypeBound::Any)], FuncValueType::new_endo(type_row![])))] -#[case(PolyFuncTypeRV::new([TypeParam::Tuple { params: 
[TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FuncValueType::new_endo(type_row![])))] +#[case(PolyFuncTypeRV::new([TypeParam::new_list_type(TypeBound::Linear)], FuncValueType::new_endo(type_row![])))] +#[case(PolyFuncTypeRV::new([TypeParam::new_tuple_type([TypeBound::Linear.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())])], FuncValueType::new_endo(type_row![])))] #[case(PolyFuncTypeRV::new( - [TypeParam::new_list(TypeBound::Any)], - FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Any))))] + [TypeParam::new_list_type(TypeBound::Linear)], + FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Linear))))] #[case(polyfunctype2())] fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { check_testing_roundtrip(poly_func_type); @@ -506,15 +595,15 @@ fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { #[rstest] #[case(ops::Module::new())] -#[case(ops::FuncDefn::new("polyfunc1", polyfunctype1()))] -#[case(ops::FuncDecl::new("polyfunc2", polyfunctype1()))] +#[case(ops::FuncDefn::new_vis("polyfunc1", polyfunctype1(), Visibility::Private))] +#[case(ops::FuncDefn::new_vis("pubfunc1", polyfunctype1(), Visibility::Public))] #[case(ops::AliasDefn { name: "aliasdefn".into(), definition: Type::new_unit_sum(4)})] -#[case(ops::AliasDecl { name: "aliasdecl".into(), bound: TypeBound::Any})] +#[case(ops::AliasDecl { name: "aliasdecl".into(), bound: TypeBound::Linear})] #[case(ops::Const::new(Value::false_val()))] #[case(ops::Const::new(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap()))] #[case(ops::Input::new(vec![Type::new_var_use(3,TypeBound::Copyable)]))] #[case(ops::Output::new(vec![Type::new_function(FuncValueType::new_endo(type_row![]))]))] -#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}]).unwrap())] +#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat(1)]).unwrap())] #[case(ops::CallIndirect { signature : Signature::new_endo(vec![bool_t()]) })] fn roundtrip_optype(#[case] optype: impl Into + std::fmt::Debug) { check_testing_roundtrip(NodeSer { @@ -529,7 +618,7 @@ fn std_extensions_valid() { let std_reg = crate::std_extensions::std_reg(); for ext in std_reg { let val = serde_json::to_value(ext).unwrap(); - NamedSchema::check_schemas(&val, get_schemas(true)); + NamedSchema::check_schemas(&val, get_schemas(true)).unwrap(); // check deserialises correctly, can't check equality because of custom binaries. 
let deser: crate::extension::Extension = serde_json::from_value(val.clone()).unwrap(); assert_eq!(serde_json::to_value(deser).unwrap(), val); diff --git a/hugr-core/src/hugr/serialize/upgrade/testcases/hugr_with_named_op.json b/hugr-core/src/hugr/serialize/upgrade/testcases/hugr_with_named_op.json index ca3965d874..112581e94f 100644 --- a/hugr-core/src/hugr/serialize/upgrade/testcases/hugr_with_named_op.json +++ b/hugr-core/src/hugr/serialize/upgrade/testcases/hugr_with_named_op.json @@ -3,67 +3,8 @@ "nodes": [ { "parent": 0, - "op": "Module" - }, - { - "parent": 0, - "op": "FuncDefn", - "name": "main", - "signature": { - "params": [], - "body": { - "input": [ - { - "t": "Sum", - "s": "Unit", - "size": 2 - }, - { - "t": "Sum", - "s": "Unit", - "size": 2 - } - ], - "output": [ - { - "t": "Sum", - "s": "Unit", - "size": 2 - } - ] - } - } - }, - { - "parent": 1, - "op": "Input", - "types": [ - { - "t": "Sum", - "s": "Unit", - "size": 2 - }, - { - "t": "Sum", - "s": "Unit", - "size": 2 - } - ] - }, - { - "parent": 1, - "op": "Output", - "types": [ - { - "t": "Sum", - "s": "Unit", - "size": 2 - } - ] - }, - { - "parent": 1, "op": "DFG", + "name": "main", "signature": { "input": [ { @@ -87,7 +28,7 @@ } }, { - "parent": 4, + "parent": 0, "op": "Input", "types": [ { @@ -103,7 +44,7 @@ ] }, { - "parent": 4, + "parent": 0, "op": "Output", "types": [ { @@ -114,7 +55,7 @@ ] }, { - "parent": 4, + "parent": 0, "op": "Extension", "extension": "logic", "name": "And", @@ -146,27 +87,7 @@ "edges": [ [ [ - 2, - 0 - ], - [ - 4, - 0 - ] - ], - [ - [ - 2, - 1 - ], - [ - 4, - 1 - ] - ], - [ - [ - 4, + 1, 0 ], [ @@ -176,45 +97,30 @@ ], [ [ - 5, - 0 - ], - [ - 7, - 0 - ] - ], - [ - [ - 5, + 1, 1 ], [ - 7, + 3, 1 ] ], [ [ - 7, + 3, 0 ], [ - 6, + 2, 0 ] ] ], "metadata": [ - null, - null, - null, - null, null, null, null, null ], - "encoder": "hugr-rs v0.15.4", - "entrypoint": 4 + "encoder": "hugr-rs v0.15.4" } diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 8291cdcde9..41fe7ba45b 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -1,6 +1,7 @@ //! HUGR invariant checks. use std::collections::HashMap; +use std::collections::hash_map::Entry; use std::iter; use itertools::Itertools; @@ -19,9 +20,8 @@ use crate::ops::validate::{ use crate::ops::{NamedOp, OpName, OpTag, OpTrait, OpType, ValidateOp}; use crate::types::EdgeKind; use crate::types::type_param::TypeParam; -use crate::{Direction, Port}; +use crate::{Direction, Port, Visibility}; -use super::ExtensionError; use super::internal::PortgraphNodeMap; use super::views::HugrView; @@ -60,6 +60,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { // Hierarchy and children. No type variables declared outside the root. self.validate_subtree(self.hugr.entrypoint(), &[])?; + self.validate_linkage()?; // In tests we take the opportunity to verify that the hugr // serialization round-trips. We verify the schema of the serialization // format only when an environment variable is set. 
This allows @@ -81,6 +82,44 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { Ok(()) } + fn validate_linkage(&self) -> Result<(), ValidationError> { + // Map from func_name, for visible funcs only, to *tuple of* + // Node with that func_name, + // Signature, + // bool - true for FuncDefn + let mut node_sig_defn = HashMap::new(); + + for c in self.hugr.children(self.hugr.module_root()) { + let (func_name, sig, is_defn) = match self.hugr.get_optype(c) { + OpType::FuncDecl(fd) if fd.visibility() == &Visibility::Public => { + (fd.func_name(), fd.signature(), false) + } + OpType::FuncDefn(fd) if fd.visibility() == &Visibility::Public => { + (fd.func_name(), fd.signature(), true) + } + _ => continue, + }; + match node_sig_defn.entry(func_name) { + Entry::Vacant(ve) => { + ve.insert((c, sig, is_defn)); + } + Entry::Occupied(oe) => { + // Allow two decls of the same sig (aliasing - we are allowing some laziness here). + // Reject if at least one Defn - either two conflicting impls, + // or Decl+Defn which should have been linked + let (prev_c, prev_sig, prev_defn) = oe.get(); + if prev_sig != &sig || is_defn || *prev_defn { + return Err(ValidationError::DuplicateExport { + link_name: func_name.clone(), + children: [*prev_c, c], + }); + }; + } + } + } + Ok(()) + } + /// Compute the dominator tree for a CFG region, identified by its container /// node. /// @@ -119,7 +158,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if num_ports != op_type.port_count(dir) { return Err(ValidationError::WrongNumberOfPorts { node, - optype: op_type.clone(), + optype: Box::new(op_type.clone()), actual: num_ports, expected: op_type.port_count(dir), dir, @@ -137,9 +176,9 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if !allowed_children.is_superset(op_type.tag()) { return Err(ValidationError::InvalidParentOp { child: node, - child_optype: op_type.clone(), + child_optype: Box::new(op_type.clone()), parent, - parent_optype: parent_optype.clone(), + parent_optype: Box::new(parent_optype.clone()), allowed_children, }); } @@ -151,7 +190,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if validity_flags.allowed_children == OpTag::None { return Err(ValidationError::EntrypointNotContainer { node, - optype: op_type.clone(), + optype: Box::new(op_type.clone()), }); } } @@ -200,7 +239,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { return Err(ValidationError::UnconnectedPort { node, port, - port_kind, + port_kind: Box::new(port_kind), }); } @@ -210,7 +249,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { return Err(ValidationError::TooManyConnections { node, port, - port_kind, + port_kind: Box::new(port_kind), }); } return Ok(()); @@ -230,7 +269,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { return Err(ValidationError::TooManyConnections { node, port, - port_kind, + port_kind: Box::new(port_kind), }); } @@ -244,10 +283,10 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { return Err(ValidationError::IncompatiblePorts { from: node, from_port: port, - from_kind: port_kind, + from_kind: Box::new(port_kind), to: other_node, to_port: other_offset, - to_kind: other_kind, + to_kind: Box::new(other_kind), }); } @@ -286,7 +325,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if flags.allowed_children.is_empty() { return Err(ValidationError::NonContainerWithChildren { node, - optype: op_type.clone(), + optype: Box::new(op_type.clone()), }); } @@ -296,8 +335,8 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if !flags.allowed_first_child.is_superset(first_child.tag()) { return 
Err(ValidationError::InvalidInitialChild { parent: node, - parent_optype: op_type.clone(), - optype: first_child.clone(), + parent_optype: Box::new(op_type.clone()), + optype: Box::new(first_child.clone()), expected: flags.allowed_first_child, position: "first", }); @@ -310,8 +349,8 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if !flags.allowed_second_child.is_superset(second_child.tag()) { return Err(ValidationError::InvalidInitialChild { parent: node, - parent_optype: op_type.clone(), - optype: second_child.clone(), + parent_optype: Box::new(op_type.clone()), + optype: Box::new(second_child.clone()), expected: flags.allowed_second_child, position: "second", }); @@ -322,7 +361,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if let Err(source) = op_type.validate_op_children(children_optypes) { return Err(ValidationError::InvalidChildren { parent: node, - parent_optype: op_type.clone(), + parent_optype: Box::new(op_type.clone()), source, }); } @@ -349,7 +388,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if let Err(source) = edge_check(edge_data) { return Err(ValidationError::InvalidEdges { parent: node, - parent_optype: op_type.clone(), + parent_optype: Box::new(op_type.clone()), source, }); } @@ -364,7 +403,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { } else if flags.requires_children { return Err(ValidationError::ContainerWithoutChildren { node, - optype: op_type.clone(), + optype: Box::new(op_type.clone()), }); } @@ -395,7 +434,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if nodes_visited != node_count { return Err(ValidationError::NotADag { node: parent, - optype: op_type.clone(), + optype: Box::new(op_type.clone()), }); } @@ -433,7 +472,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { from_offset, to, to_offset, - ty: edge_kind, + ty: Box::new(edge_kind), }); } @@ -443,28 +482,12 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { // // This search could be sped-up with a pre-computed LCA structure, but // for valid Hugrs this search should be very short. - // - // For Value edges only, we record any FuncDefn we went through; if there is - // any such, then that is an error, but we report that only if the dom/ext - // relation was otherwise ok (an error about an edge "entering" some ancestor - // node could be misleading if the source isn't where it's expected) - let mut err_entered_func = None; let from_parent_parent = self.hugr.get_parent(from_parent); for (ancestor, ancestor_parent) in iter::successors(to_parent, |&p| self.hugr.get_parent(p)).tuple_windows() { - if !is_static && self.hugr.get_optype(ancestor).is_func_defn() { - err_entered_func.get_or_insert(InterGraphEdgeError::ValueEdgeIntoFunc { - to, - to_offset, - from, - from_offset, - func: ancestor, - }); - } if ancestor_parent == from_parent { // External edge. - err_entered_func.map_or(Ok(()), Err)?; if !is_static { // Must have an order edge. 
self.hugr @@ -488,10 +511,10 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { from_offset, to, to_offset, - ancestor_parent_op: ancestor_parent_op.clone(), + ancestor_parent_op: Box::new(ancestor_parent_op.clone()), }); } - err_entered_func.map_or(Ok(()), Err)?; + // Check domination let (dominator_tree, node_map) = if let Some(tree) = self.dominators.get(&ancestor_parent) { @@ -617,7 +640,7 @@ pub enum ValidationError { )] WrongNumberOfPorts { node: N, - optype: OpType, + optype: Box, actual: usize, expected: usize, dir: Direction, @@ -627,14 +650,14 @@ pub enum ValidationError { UnconnectedPort { node: N, port: Port, - port_kind: EdgeKind, + port_kind: Box, }, /// A linear port is connected to more than one thing. #[error("{node} has a port {port} of type {port_kind} with more than one connection.")] TooManyConnections { node: N, port: Port, - port_kind: EdgeKind, + port_kind: Box, }, /// Connected ports have different types, or non-unifiable types. #[error( @@ -643,10 +666,10 @@ pub enum ValidationError { IncompatiblePorts { from: N, from_port: Port, - from_kind: EdgeKind, + from_kind: Box, to: N, to_port: Port, - to_kind: EdgeKind, + to_kind: Box, }, /// The non-root node has no parent. #[error("{node} has no parent.")] @@ -655,9 +678,9 @@ pub enum ValidationError { #[error("The operation {parent_optype} cannot contain a {child_optype} as a child. Allowed children: {}. In {child} with parent {parent}.", allowed_children.description())] InvalidParentOp { child: N, - child_optype: OpType, + child_optype: Box, parent: N, - parent_optype: OpType, + parent_optype: Box, allowed_children: OpTag, }, /// Invalid first/second child. @@ -666,8 +689,8 @@ pub enum ValidationError { )] InvalidInitialChild { parent: N, - parent_optype: OpType, - optype: OpType, + parent_optype: Box, + optype: Box, expected: OpTag, position: &'static str, }, @@ -678,9 +701,19 @@ pub enum ValidationError { )] InvalidChildren { parent: N, - parent_optype: OpType, + parent_optype: Box, source: ChildrenValidationError, }, + /// Multiple, incompatible, nodes with [Visibility::Public] use the same `func_name` + /// in a [Module](super::Module). (Multiple [`FuncDecl`](crate::ops::FuncDecl)s with + /// the same signature are allowed) + #[error("FuncDefn/Decl {} is exported under same name {link_name} as earlier node {}", children[0], children[1])] + DuplicateExport { + /// The `func_name` of a public `FuncDecl` or `FuncDefn` + link_name: String, + /// Two nodes using that name + children: [N; 2], + }, /// The children graph has invalid edges. #[error( "An operation {parent_optype} contains invalid edges between its children: {source}. In parent {parent}, edge from {from:?} port {from_port:?} to {to:?} port {to_port:?}", @@ -691,27 +724,23 @@ pub enum ValidationError { )] InvalidEdges { parent: N, - parent_optype: OpType, + parent_optype: Box, source: EdgeValidationError, }, /// The node operation is not a container, but has children. #[error("{node} with optype {optype} is not a container, but has children.")] - NonContainerWithChildren { node: N, optype: OpType }, + NonContainerWithChildren { node: N, optype: Box }, /// The node must have children, but has none. #[error("{node} with optype {optype} must have children, but has none.")] - ContainerWithoutChildren { node: N, optype: OpType }, + ContainerWithoutChildren { node: N, optype: Box }, /// The children of a node do not form a DAG. #[error( "The children of an operation {optype} must form a DAG. Loops are not allowed. In {node}." 
)] - NotADag { node: N, optype: OpType }, + NotADag { node: N, optype: Box }, /// There are invalid inter-graph edges. #[error(transparent)] InterGraphEdgeError(#[from] InterGraphEdgeError), - /// There are errors in the extension deltas. - #[deprecated(note = "Never returned since hugr-core-v0.20.0")] - #[error(transparent)] - ExtensionError(#[from] ExtensionError), /// A node claims to still be awaiting extension inference. Perhaps it is not acted upon by inference. #[error( "{node} needs a concrete ExtensionSet - inference will provide this for Case/CFG/Conditional/DataflowBlock/DFG/TailLoop only" @@ -740,7 +769,7 @@ pub enum ValidationError { ConstTypeError(#[from] ConstTypeError), /// The HUGR entrypoint must be a region container. #[error("The HUGR entrypoint ({node}) must be a region container, but '{}' does not accept children.", optype.name())] - EntrypointNotContainer { node: N, optype: OpType }, + EntrypointNotContainer { node: N, optype: Box }, } /// Errors related to the inter-graph edge validations. @@ -757,18 +786,7 @@ pub enum InterGraphEdgeError { from_offset: Port, to: N, to_offset: Port, - ty: EdgeKind, - }, - /// Inter-Graph edges may not enter into `FuncDefns` unless they are static - #[error( - "Inter-graph Value edges cannot enter into FuncDefns. Inter-graph edge from {from} ({from_offset}) to {to} ({to_offset} enters FuncDefn {func}" - )] - ValueEdgeIntoFunc { - from: N, - from_offset: Port, - to: N, - to_offset: Port, - func: N, + ty: Box, }, /// The grandparent of a dominator inter-graph edge must be a CFG container. #[error( @@ -779,7 +797,7 @@ pub enum InterGraphEdgeError { from_offset: Port, to: N, to_offset: Port, - ancestor_parent_op: OpType, + ancestor_parent_op: Box, }, /// The sibling ancestors of the external inter-graph edge endpoints must be have an order edge between them. 
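The `validate_linkage` pass added above indexes the module root's public `FuncDefn`/`FuncDecl` children by `func_name` and reports `DuplicateExport` on a clash, tolerating repeated declarations with identical signatures. The sketch below reproduces that entry-based conflict check in isolation; `Symbol`, `Sig` and `first_duplicate_export` are illustrative stand-ins rather than hugr-core types.

use std::collections::HashMap;
use std::collections::hash_map::Entry;

// Hypothetical stand-ins for an exported symbol and its signature.
#[derive(PartialEq)]
struct Sig(&'static str);

struct Symbol {
    name: &'static str,
    sig: Sig,
    is_defn: bool, // true for a definition, false for a declaration
}

/// Return the name of the first conflicting export, if any. Two public
/// declarations may alias each other when their signatures agree; any
/// clash involving a definition is an error.
fn first_duplicate_export(symbols: &[Symbol]) -> Option<&'static str> {
    let mut seen: HashMap<&str, (&Sig, bool)> = HashMap::new();
    for s in symbols {
        match seen.entry(s.name) {
            Entry::Vacant(v) => {
                v.insert((&s.sig, s.is_defn));
            }
            Entry::Occupied(o) => {
                let (prev_sig, prev_defn) = *o.get();
                if prev_sig != &s.sig || s.is_defn || prev_defn {
                    return Some(s.name);
                }
            }
        }
    }
    None
}

fn main() {
    // Two identical public declarations alias each other: no conflict.
    let decls = [
        Symbol { name: "foo", sig: Sig("bool -> bool"), is_defn: false },
        Symbol { name: "foo", sig: Sig("bool -> bool"), is_defn: false },
    ];
    assert_eq!(first_duplicate_export(&decls), None);

    // Two definitions under the same name conflict even with equal signatures.
    let clash = [
        Symbol { name: "foo", sig: Sig("bool -> bool"), is_defn: true },
        Symbol { name: "foo", sig: Sig("bool -> bool"), is_defn: true },
    ];
    assert_eq!(first_duplicate_export(&clash), Some("foo"));
}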
#[error( diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 8ee95cde61..ec086243ee 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -1,38 +1,44 @@ +use std::borrow::Cow; use std::fs::File; use std::io::BufReader; use std::sync::Arc; use cool_asserts::assert_matches; +use rstest::rstest; use super::*; use crate::builder::test::closed_dfg_root_hugr; use crate::builder::{ BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, inout_sig, + FunctionBuilder, HugrBuilder, ModuleBuilder, inout_sig, }; use crate::extension::prelude::Noop; use crate::extension::prelude::{bool_t, qb_t, usize_t}; use crate::extension::{Extension, ExtensionRegistry, PRELUDE, TypeDefBound}; use crate::hugr::HugrMut; use crate::hugr::internal::HugrMutInternals; -use crate::ops::dataflow::IOTrait; +use crate::ops::dataflow::{DataflowParent, IOTrait}; use crate::ops::handle::NodeHandle; -use crate::ops::{self, OpType, Value}; +use crate::ops::{self, FuncDecl, FuncDefn, OpType, Value}; use crate::std_extensions::logic::LogicOp; use crate::std_extensions::logic::test::{and_op, or_op}; -use crate::types::type_param::{TypeArg, TypeArgError}; +use crate::types::type_param::{TermTypeError, TypeArg}; use crate::types::{ - CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, TypeRV, - TypeRow, + CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Term, Type, TypeBound, + TypeRV, TypeRow, }; use crate::{Direction, Hugr, IncomingPort, Node, const_extension_ids, test_file, type_row}; -/// Creates a hugr with a single function definition that copies a bit `copies` times. +/// Creates a hugr with a single, public, function definition that copies a bit `copies` times. /// /// Returns the hugr and the node index of the definition. fn make_simple_hugr(copies: usize) -> (Hugr, Node) { - let def_op: OpType = - ops::FuncDefn::new("main", Signature::new(bool_t(), vec![bool_t(); copies])).into(); + let def_op: OpType = FuncDefn::new_vis( + "main", + Signature::new(bool_t(), vec![bool_t(); copies]), + Visibility::Public, + ) + .into(); let mut b = Hugr::default(); let root = b.entrypoint(); @@ -126,7 +132,7 @@ fn children_restrictions() { // Add a definition without children let def_sig = Signature::new(vec![bool_t()], vec![bool_t(), bool_t()]); - let new_def = b.add_node_with_parent(root, ops::FuncDefn::new("main", def_sig)); + let new_def = b.add_node_with_parent(root, FuncDefn::new("main", def_sig)); assert_matches!( b.validate(), Err(ValidationError::ContainerWithoutChildren { node, .. 
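A recurring mechanical change in the validation hunks above is wrapping `OpType`, `EdgeKind` and similar fields of error variants in `Box`. A common motivation for this pattern is keeping the error enum, and hence every `Result` that carries it, small; the sketch below demonstrates the effect with made-up types and is not taken from the hugr codebase.

use std::mem::size_of;

// A deliberately large stand-in for something like `OpType` (hypothetical).
#[allow(dead_code)]
struct BigPayload([u64; 16]);

#[allow(dead_code)]
enum UnboxedError {
    WrongNumberOfPorts { node: u32, optype: BigPayload },
    NotADag { node: u32 },
}

#[allow(dead_code)]
enum BoxedError {
    WrongNumberOfPorts { node: u32, optype: Box<BigPayload> },
    NotADag { node: u32 },
}

fn main() {
    // Every `Result<_, E>` pays for E's largest variant; boxing the bulky
    // field shrinks the whole enum to roughly a pointer plus discriminant.
    assert!(size_of::<BoxedError>() < size_of::<UnboxedError>());
    println!(
        "unboxed: {} bytes, boxed: {} bytes",
        size_of::<UnboxedError>(),
        size_of::<BoxedError>()
    );
}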
}) => assert_eq!(node, new_def) @@ -225,35 +231,6 @@ fn test_ext_edge() { h.validate().unwrap(); } -#[test] -fn no_ext_edge_into_func() -> Result<(), Box> { - let b2b = Signature::new_endo(bool_t()); - let mut h = DFGBuilder::new(Signature::new(bool_t(), Type::new_function(b2b.clone())))?; - let [input] = h.input_wires_arr(); - - let mut dfg = h.dfg_builder(Signature::new(vec![], Type::new_function(b2b.clone())), [])?; - let mut func = dfg.define_function("AndWithOuter", b2b.clone())?; - let [fn_input] = func.input_wires_arr(); - let and_op = func.add_dataflow_op(and_op(), [fn_input, input])?; // 'ext' edge - let func = func.finish_with_outputs(and_op.outputs())?; - let loadfn = dfg.load_func(func.handle(), &[])?; - let dfg = dfg.finish_with_outputs([loadfn])?; - let res = h.finish_hugr_with_outputs(dfg.outputs()); - assert_eq!( - res, - Err(BuildError::InvalidHUGR( - ValidationError::InterGraphEdgeError(InterGraphEdgeError::ValueEdgeIntoFunc { - from: input.node(), - from_offset: input.source().into(), - to: and_op.node(), - to_offset: IncomingPort::from(1).into(), - func: func.node() - }) - )) - ); - Ok(()) -} - #[test] fn test_local_const() { let mut h = closed_dfg_root_hugr(Signature::new_endo(bool_t())); @@ -266,7 +243,7 @@ fn test_local_const() { Err(ValidationError::UnconnectedPort { node: and, port: IncomingPort::from(1).into(), - port_kind: EdgeKind::Value(bool_t()) + port_kind: Box::new(EdgeKind::Value(bool_t())) }) ); let const_op: ops::Const = ops::Value::from_bool(true).into(); @@ -305,7 +282,7 @@ fn identity_hugr_with_type(t: Type) -> (Hugr, Node) { let def = b.add_node_with_parent( b.entrypoint(), - ops::FuncDefn::new("main", Signature::new_endo(row.clone())), + FuncDefn::new("main", Signature::new_endo(row.clone())), ); let input = b.add_node_with_parent(def, ops::Input::new(row.clone())); @@ -347,9 +324,9 @@ fn invalid_types() { let valid = Type::new_extension(CustomType::new( "MyContainer", - vec![TypeArg::Type { ty: usize_t() }], + vec![usize_t().into()], EXT_ID, - TypeBound::Any, + TypeBound::Linear, &Arc::downgrade(&ext), )); let mut hugr = identity_hugr_with_type(valid.clone()).0; @@ -359,22 +336,22 @@ fn invalid_types() { // valid is Any, so is not allowed as an element of an outer MyContainer. let element_outside_bound = CustomType::new( "MyContainer", - vec![TypeArg::Type { ty: valid.clone() }], + vec![valid.clone().into()], EXT_ID, - TypeBound::Any, + TypeBound::Linear, &Arc::downgrade(&ext), ); assert_eq!( validate_to_sig_error(element_outside_bound), - SignatureError::TypeArgMismatch(TypeArgError::TypeMismatch { - param: TypeBound::Copyable.into(), - arg: TypeArg::Type { ty: valid } + SignatureError::TypeArgMismatch(TermTypeError::TypeMismatch { + type_: Box::new(TypeBound::Copyable.into()), + term: Box::new(valid.into()) }) ); let bad_bound = CustomType::new( "MyContainer", - vec![TypeArg::Type { ty: usize_t() }], + vec![usize_t().into()], EXT_ID, TypeBound::Copyable, &Arc::downgrade(&ext), @@ -383,41 +360,36 @@ fn invalid_types() { validate_to_sig_error(bad_bound.clone()), SignatureError::WrongBound { actual: TypeBound::Copyable, - expected: TypeBound::Any + expected: TypeBound::Linear } ); // bad_bound claims to be Copyable, which is valid as an element for the outer MyContainer. 
let nested = CustomType::new( "MyContainer", - vec![TypeArg::Type { - ty: Type::new_extension(bad_bound), - }], + vec![Type::new_extension(bad_bound).into()], EXT_ID, - TypeBound::Any, + TypeBound::Linear, &Arc::downgrade(&ext), ); assert_eq!( validate_to_sig_error(nested), SignatureError::WrongBound { actual: TypeBound::Copyable, - expected: TypeBound::Any + expected: TypeBound::Linear } ); let too_many_type_args = CustomType::new( "MyContainer", - vec![ - TypeArg::Type { ty: usize_t() }, - TypeArg::BoundedNat { n: 3 }, - ], + vec![usize_t().into(), 3u64.into()], EXT_ID, - TypeBound::Any, + TypeBound::Linear, &Arc::downgrade(&ext), ); assert_eq!( validate_to_sig_error(too_many_type_args), - SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(2, 1)) + SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(2, 1)) ); } @@ -427,8 +399,8 @@ fn typevars_declared() -> Result<(), Box> { let f = FunctionBuilder::new( "myfunc", PolyFuncType::new( - [TypeBound::Any.into()], - Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + [TypeBound::Linear.into()], + Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Linear)]), ), )?; let [w] = f.input_wires_arr(); @@ -437,8 +409,8 @@ fn typevars_declared() -> Result<(), Box> { let f = FunctionBuilder::new( "myfunc", PolyFuncType::new( - [TypeBound::Any.into()], - Signature::new_endo(vec![Type::new_var_use(1, TypeBound::Any)]), + [TypeBound::Linear.into()], + Signature::new_endo(vec![Type::new_var_use(1, TypeBound::Linear)]), ), )?; let [w] = f.input_wires_arr(); @@ -447,7 +419,7 @@ fn typevars_declared() -> Result<(), Box> { let f = FunctionBuilder::new( "myfunc", PolyFuncType::new( - [TypeBound::Any.into()], + [TypeBound::Linear.into()], Signature::new_endo(vec![Type::new_var_use(1, TypeBound::Copyable)]), ), )?; @@ -456,51 +428,39 @@ fn typevars_declared() -> Result<(), Box> { Ok(()) } -/// Test that nested `FuncDefns` cannot use Type Variables declared by enclosing `FuncDefns` +/// Test that `FuncDefns` cannot be nested. #[test] -fn nested_typevars() -> Result<(), Box> { - const OUTER_BOUND: TypeBound = TypeBound::Any; - const INNER_BOUND: TypeBound = TypeBound::Copyable; - fn build(t: Type) -> Result { - let mut outer = FunctionBuilder::new( - "outer", - PolyFuncType::new( - [OUTER_BOUND.into()], - Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), - ), - )?; - let inner = outer.define_function( - "inner", - PolyFuncType::new([INNER_BOUND.into()], Signature::new_endo(vec![t])), - )?; - let [w] = inner.input_wires_arr(); - inner.finish_with_outputs([w])?; - let [w] = outer.input_wires_arr(); - outer.finish_hugr_with_outputs([w]) - } - assert!(build(Type::new_var_use(0, INNER_BOUND)).is_ok()); - assert_matches!( - build(Type::new_var_use(1, OUTER_BOUND)).unwrap_err(), - BuildError::InvalidHUGR(ValidationError::SignatureError { - cause: SignatureError::FreeTypeVar { - idx: 1, - num_decls: 1 - }, - .. +fn no_nested_funcdefns() -> Result<(), Box> { + let mut outer = FunctionBuilder::new("outer", Signature::new_endo(usize_t()))?; + let inner = outer + .add_hugr({ + let inner = FunctionBuilder::new("inner", Signature::new_endo(bool_t()))?; + let [w] = inner.input_wires_arr(); + inner.finish_hugr_with_outputs([w])? 
}) + .inserted_entrypoint; + let [w] = outer.input_wires_arr(); + let outer_node = outer.container_node(); + let hugr = outer.finish_hugr_with_outputs([w]); + assert_matches!( + hugr.unwrap_err(), + BuildError::InvalidHUGR(ValidationError::InvalidParentOp { + child_optype, + allowed_children: OpTag::DataflowChild, + parent_optype, + child, parent + }) if matches!(*child_optype, OpType::FuncDefn(_)) && matches!(*parent_optype, OpType::FuncDefn(_)) => { + assert_eq!(child, inner); + assert_eq!(parent, outer_node); + } ); - assert_matches!(build(Type::new_var_use(0, OUTER_BOUND)).unwrap_err(), - BuildError::InvalidHUGR(ValidationError::SignatureError { cause: SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached }, .. }) => - {assert_eq!(actual, INNER_BOUND.into()); assert_eq!(cached, OUTER_BOUND.into())}); Ok(()) } #[test] fn no_polymorphic_consts() -> Result<(), Box> { use crate::std_extensions::collections::list; - const BOUND: TypeParam = TypeParam::Type { - b: TypeBound::Copyable, - }; + const BOUND: TypeParam = TypeParam::RuntimeType(TypeBound::Copyable); let list_of_var = Type::new_extension( list::EXTENSION .get_type(&list::LIST_TYPENAME) @@ -533,10 +493,10 @@ fn no_polymorphic_consts() -> Result<(), Box> { } pub(crate) fn extension_with_eval_parallel() -> Arc { - let rowp = TypeParam::new_list(TypeBound::Any); + let rowp = TypeParam::new_list_type(TypeBound::Linear); Extension::new_test_arc(EXT_ID, |ext, extension_ref| { - let inputs = TypeRV::new_row_var_use(0, TypeBound::Any); - let outputs = TypeRV::new_row_var_use(1, TypeBound::Any); + let inputs = TypeRV::new_row_var_use(0, TypeBound::Linear); + let outputs = TypeRV::new_row_var_use(1, TypeBound::Linear); let evaled_fn = TypeRV::new_function(FuncValueType::new(inputs.clone(), outputs.clone())); let pf = PolyFuncTypeRV::new( [rowp.clone(), rowp.clone()], @@ -545,7 +505,7 @@ pub(crate) fn extension_with_eval_parallel() -> Arc { ext.add_op("eval".into(), String::new(), pf, extension_ref) .unwrap(); - let rv = |idx| TypeRV::new_row_var_use(idx, TypeBound::Any); + let rv = |idx| TypeRV::new_row_var_use(idx, TypeBound::Linear); let pf = PolyFuncTypeRV::new( [rowp.clone(), rowp.clone(), rowp.clone(), rowp.clone()], Signature::new( @@ -563,8 +523,8 @@ pub(crate) fn extension_with_eval_parallel() -> Arc { #[test] fn instantiate_row_variables() -> Result<(), Box> { - fn uint_seq(i: usize) -> TypeArg { - vec![TypeArg::Type { ty: usize_t() }; i].into() + fn uint_seq(i: usize) -> Term { + vec![usize_t().into(); i].into() } let e = extension_with_eval_parallel(); let mut dfb = DFGBuilder::new(inout_sig( @@ -588,124 +548,49 @@ fn instantiate_row_variables() -> Result<(), Box> { Ok(()) } -fn seq1ty(t: TypeRV) -> TypeArg { - TypeArg::Sequence { - elems: vec![t.into()], - } +fn list1ty(t: TypeRV) -> Term { + Term::new_list([t.into()]) } #[test] fn row_variables() -> Result<(), Box> { let e = extension_with_eval_parallel(); - let tv = TypeRV::new_row_var_use(0, TypeBound::Any); + let tv = TypeRV::new_row_var_use(0, TypeBound::Linear); let inner_ft = Type::new_function(FuncValueType::new_endo(tv.clone())); let ft_usz = Type::new_function(FuncValueType::new_endo(vec![tv.clone(), usize_t().into()])); let mut fb = FunctionBuilder::new( "id", PolyFuncType::new( - [TypeParam::new_list(TypeBound::Any)], + [TypeParam::new_list_type(TypeBound::Linear)], Signature::new(inner_ft.clone(), ft_usz), ), )?; // All the wires here are carrying higher-order Function values let [func_arg] = fb.input_wires_arr(); let id_usz = { - let bldr = 
fb.define_function("id_usz", Signature::new_endo(usize_t()))?; + let mut mb = fb.module_root_builder(); + let bldr = mb.define_function("id_usz", Signature::new_endo(usize_t()))?; let vals = bldr.input_wires(); - let inner_def = bldr.finish_with_outputs(vals)?; - fb.load_func(inner_def.handle(), &[])? + let helper_def = bldr.finish_with_outputs(vals)?; + fb.load_func(helper_def.handle(), &[])? }; let par = e.instantiate_extension_op( "parallel", - [tv.clone(), usize_t().into(), tv.clone(), usize_t().into()].map(seq1ty), + [tv.clone(), usize_t().into(), tv.clone(), usize_t().into()].map(list1ty), )?; let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?; fb.finish_hugr_with_outputs(par_func.outputs())?; Ok(()) } -#[test] -fn test_polymorphic_call() -> Result<(), Box> { - // TODO: This tests a function call that is polymorphic in an extension set. - // Should this be rewritten to be polymorphic in something else or removed? - - let e = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { - let params: Vec = vec![TypeBound::Any.into(), TypeBound::Any.into()]; - let evaled_fn = Type::new_function(Signature::new( - Type::new_var_use(0, TypeBound::Any), - Type::new_var_use(1, TypeBound::Any), - )); - // Single-input/output version of the higher-order "eval" operation, with extension param. - // Note the extension-delta of the eval node includes that of the input function. - ext.add_op( - "eval".into(), - String::new(), - PolyFuncTypeRV::new( - params.clone(), - Signature::new( - vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], - Type::new_var_use(1, TypeBound::Any), - ), - ), - extension_ref, - )?; - - Ok(()) - })?; - - fn utou() -> Type { - Type::new_function(Signature::new_endo(usize_t())) - } - - let int_pair = Type::new_tuple(vec![usize_t(); 2]); - // Root DFG: applies a function int-->int to each element of a pair of two ints - let mut d = DFGBuilder::new(inout_sig( - vec![utou(), int_pair.clone()], - vec![int_pair.clone()], - ))?; - // ....by calling a function (int-->int, int_pair) -> int_pair - let f = { - let mut f = d.define_function( - "two_ints", - PolyFuncType::new( - vec![], - Signature::new(vec![utou(), int_pair.clone()], int_pair.clone()), - ), - )?; - let [func, tup] = f.input_wires_arr(); - let mut c = f.conditional_builder( - (vec![vec![usize_t(); 2].into()], tup), - vec![], - vec![usize_t(); 2].into(), - )?; - let mut cc = c.case_builder(0)?; - let [i1, i2] = cc.input_wires_arr(); - let op = e.instantiate_extension_op("eval", vec![usize_t().into(), usize_t().into()])?; - let [f1] = cc.add_dataflow_op(op.clone(), [func, i1])?.outputs_arr(); - let [f2] = cc.add_dataflow_op(op, [func, i2])?.outputs_arr(); - cc.finish_with_outputs([f1, f2])?; - let res = c.finish_sub_container()?.outputs(); - let tup = f.make_tuple(res)?; - f.finish_with_outputs([tup])? 
- }; - - let [func, tup] = d.input_wires_arr(); - let call = d.call(f.handle(), &[], [func, tup])?; - let h = d.finish_hugr_with_outputs(call.outputs())?; - let call_ty = h.get_optype(call.node()).dataflow_signature().unwrap(); - let exp_fun_ty = Signature::new(vec![utou(), int_pair.clone()], int_pair); - assert_eq!(call_ty.as_ref(), &exp_fun_ty); - Ok(()) -} - #[test] fn test_polymorphic_load() -> Result<(), Box> { let mut m = ModuleBuilder::new(); let id = m.declare( "id", PolyFuncType::new( - vec![TypeBound::Any.into()], - Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + vec![TypeBound::Linear.into()], + Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Linear)]), ), )?; let sig = Signature::new( @@ -852,7 +737,7 @@ fn cfg_connections() -> Result<(), Box> { Err(ValidationError::TooManyConnections { node: middle.node(), port: Port::new(Direction::Outgoing, 0), - port_kind: EdgeKind::ControlFlow + port_kind: Box::new(EdgeKind::ControlFlow) }) ); Ok(()) @@ -876,3 +761,74 @@ fn cfg_entry_io_bug() -> Result<(), Box> { Ok(()) } + +fn sig1() -> Signature { + Signature::new_endo(bool_t()) +} + +fn sig2() -> Signature { + Signature::new_endo(usize_t()) +} + +#[rstest] +// Private FuncDefns never conflict even if different sig +#[case( + FuncDefn::new_vis("foo", sig1(), Visibility::Public), + FuncDefn::new("foo", sig2()), + None +)] +#[case(FuncDefn::new("foo", sig1()), FuncDecl::new("foo", sig2()), None)] +// Public FuncDefn conflicts with anything Public even if same sig +#[case( + FuncDefn::new_vis("foo", sig1(), Visibility::Public), + FuncDefn::new_vis("foo", sig1(), Visibility::Public), + Some("foo") +)] +#[case( + FuncDefn::new_vis("foo", sig1(), Visibility::Public), + FuncDecl::new("foo", sig1()), + Some("foo") +)] +// Two public FuncDecls are ok with same sig +#[case(FuncDecl::new("foo", sig1()), FuncDecl::new("foo", sig1()), None)] +// But two public FuncDecls not ok if different sigs +#[case( + FuncDecl::new("foo", sig1()), + FuncDecl::new("foo", sig2()), + Some("foo") +)] +fn validate_linkage( + #[case] f1: impl Into, + #[case] f2: impl Into, + #[case] err: Option<&str>, +) { + let mut h = Hugr::new(); + let [n1, n2] = [f1.into(), f2.into()].map(|f| { + let def_sig = f + .as_func_defn() + .map(FuncDefn::inner_signature) + .map(Cow::into_owned); + let n = h.add_node_with_parent(h.module_root(), f); + if let Some(Signature { input, output }) = def_sig { + let i = h.add_node_with_parent(n, ops::Input::new(input)); + let o = h.add_node_with_parent(n, ops::Output::new(output)); + h.connect(i, 0, o, 0); // Assume all sig's used in test are 1-ary endomorphic + } + n + }); + let r = h.validate(); + match err { + None => r.unwrap(), + Some(name) => { + let Err(ValidationError::DuplicateExport { + link_name, + children, + }) = r + else { + panic!("validate() should have produced DuplicateExport error not {r:?}") + }; + assert_eq!(link_name, name); + assert!(children == [n1, n2] || children == [n2, n1]); + } + } +} diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index 1704d79f65..3a7b435cf7 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -421,7 +421,7 @@ pub trait HugrView: HugrInternals { let config = match RenderConfig::try_from(formatter) { Ok(config) => config, Err(e) => { - panic!("Unsupported format option: {}", e); + panic!("Unsupported format option: {e}"); } }; #[allow(deprecated)] diff --git a/hugr-core/src/hugr/views/impls.rs b/hugr-core/src/hugr/views/impls.rs index 75311f2f73..dbc9ad7b8f 100644 --- 
a/hugr-core/src/hugr/views/impls.rs +++ b/hugr-core/src/hugr/views/impls.rs @@ -115,6 +115,7 @@ macro_rules! hugr_mut_methods { fn disconnect(&mut self, node: Self::Node, port: impl Into); fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (crate::OutgoingPort, crate::IncomingPort); fn insert_hugr(&mut self, root: Self::Node, other: crate::Hugr) -> crate::hugr::hugrmut::InsertionResult; + fn insert_region(&mut self, root: Self::Node, other: crate::Hugr, region: crate::Node) -> crate::hugr::hugrmut::InsertionResult; fn insert_from_view(&mut self, root: Self::Node, other: &Other) -> crate::hugr::hugrmut::InsertionResult; fn insert_subgraph(&mut self, root: Self::Node, other: &Other, subgraph: &crate::hugr::views::SiblingSubgraph) -> std::collections::HashMap; fn use_extension(&mut self, extension: impl Into>); diff --git a/hugr-core/src/hugr/views/render.rs b/hugr-core/src/hugr/views/render.rs index b787a9e383..3f8e48c963 100644 --- a/hugr-core/src/hugr/views/render.rs +++ b/hugr-core/src/hugr/views/render.rs @@ -340,8 +340,8 @@ pub(in crate::hugr) fn edge_style<'a>( config: MermaidFormatter<'_>, ) -> Box< dyn FnMut( - ::LinkEndpoint, - ::LinkEndpoint, + as LinkView>::LinkEndpoint, + as LinkView>::LinkEndpoint, ) -> EdgeStyle + 'a, > { @@ -417,15 +417,5 @@ mod tests { { assert!(RenderConfig::try_from(config).is_err()); } - - #[allow(deprecated)] - let config = RenderConfig { - entrypoint: Some(h.entrypoint()), - ..Default::default() - }; - assert_eq!( - MermaidFormatter::from_render_config(config, &h), - h.mermaid_format() - ) } } diff --git a/hugr-core/src/hugr/views/rerooted.rs b/hugr-core/src/hugr/views/rerooted.rs index 8c84abdc71..18821cfe29 100644 --- a/hugr-core/src/hugr/views/rerooted.rs +++ b/hugr-core/src/hugr/views/rerooted.rs @@ -138,6 +138,7 @@ impl HugrMut for Rerooted { fn disconnect(&mut self, node: Self::Node, port: impl Into); fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (crate::OutgoingPort, crate::IncomingPort); fn insert_hugr(&mut self, root: Self::Node, other: crate::Hugr) -> crate::hugr::hugrmut::InsertionResult; + fn insert_region(&mut self, root: Self::Node, other: crate::Hugr, region: crate::Node) -> crate::hugr::hugrmut::InsertionResult; fn insert_from_view(&mut self, root: Self::Node, other: &Other) -> crate::hugr::hugrmut::InsertionResult; fn insert_subgraph(&mut self, root: Self::Node, other: &Other, subgraph: &crate::hugr::views::SiblingSubgraph) -> std::collections::HashMap; fn use_extension(&mut self, extension: impl Into>); diff --git a/hugr-core/src/hugr/views/root_checked/dfg.rs b/hugr-core/src/hugr/views/root_checked/dfg.rs index f9681c9ca7..160b4a3116 100644 --- a/hugr-core/src/hugr/views/root_checked/dfg.rs +++ b/hugr-core/src/hugr/views/root_checked/dfg.rs @@ -8,138 +8,148 @@ use thiserror::Error; use crate::{ IncomingPort, OutgoingPort, PortIndex, hugr::HugrMut, - ops::{DFG, FuncDefn, Input, OpTrait, OpType, Output, dataflow::IOTrait, handle::DfgID}, + ops::{ + OpTrait, OpType, + handle::{DataflowParentID, DfgID}, + }, types::{NoRV, Signature, TypeBase}, }; use super::RootChecked; -impl RootChecked> { - /// Get the input and output nodes of the DFG at the entrypoint node. - pub fn get_io(&self) -> [H::Node; 2] { - self.hugr() - .get_io(self.hugr().entrypoint()) - .expect("valid DFG graph") - } - - /// Rewire the inputs and outputs of the DFG to modify its signature. - /// - /// Reorder the outgoing resp. incoming wires at the input resp. output - /// node of the DFG to modify the signature of the DFG HUGR. 
This will - /// recursively update the signatures of all ancestors of the entrypoint. - /// - /// ### Arguments - /// - /// * `new_inputs`: The new input signature. After the map, the i-th input - /// wire will be connected to the ports connected to the - /// `new_inputs[i]`-th input of the old DFG. - /// * `new_outputs`: The new output signature. After the map, the i-th - /// output wire will be connected to the ports connected to the - /// `new_outputs[i]`-th output of the old DFG. - /// - /// Returns an `InvalidSignature` error if the new_inputs and new_outputs - /// map are not valid signatures. - /// - /// ### Panics - /// - /// Panics if the DFG is not trivially nested, i.e. if there is an ancestor - /// DFG of the entrypoint that has more than one inner DFG. - pub fn map_function_type( - &mut self, - new_inputs: &[usize], - new_outputs: &[usize], - ) -> Result<(), InvalidSignature> { - let [inp, out] = self.get_io(); - let Self(hugr, _) = self; - - // Record the old connections from and to the input and output nodes - let old_inputs_incoming = hugr - .node_outputs(inp) - .map(|p| hugr.linked_inputs(inp, p).collect_vec()) - .collect_vec(); - let old_outputs_outgoing = hugr - .node_inputs(out) - .map(|p| hugr.linked_outputs(out, p).collect_vec()) - .collect_vec(); - - // The old signature types - let old_inp_sig = hugr - .get_optype(inp) - .dataflow_signature() - .expect("input has signature"); - let old_inp_sig = old_inp_sig.output_types(); - let old_out_sig = hugr - .get_optype(out) - .dataflow_signature() - .expect("output has signature"); - let old_out_sig = old_out_sig.input_types(); - - // Check if the signature map is valid - check_valid_inputs(&old_inputs_incoming, old_inp_sig, new_inputs)?; - check_valid_outputs(old_out_sig, new_outputs)?; - - // The new signature types - let new_inp_sig = new_inputs - .iter() - .map(|&i| old_inp_sig[i].clone()) - .collect_vec(); - let new_out_sig = new_outputs - .iter() - .map(|&i| old_out_sig[i].clone()) - .collect_vec(); - let new_sig = Signature::new(new_inp_sig, new_out_sig); - - // Remove all edges of the input and output nodes - disconnect_all(hugr, inp); - disconnect_all(hugr, out); - - // Update the signatures of the IO and their ancestors - let mut is_ancestor = false; - let mut node = hugr.entrypoint(); - while matches!(hugr.get_optype(node), OpType::FuncDefn(_) | OpType::DFG(_)) { - let [inner_inp, inner_out] = hugr.get_io(node).expect("valid DFG graph"); - for node in [node, inner_inp, inner_out] { - update_signature(hugr, node, &new_sig); +macro_rules! impl_dataflow_parent_methods { + ($handle_type:ident) => { + impl RootChecked> { + /// Get the input and output nodes of the DFG at the entrypoint node. + pub fn get_io(&self) -> [H::Node; 2] { + self.hugr() + .get_io(self.hugr().entrypoint()) + .expect("valid DFG graph") } - if is_ancestor { - update_inner_dfg_links(hugr, node); - } - if let Some(parent) = hugr.get_parent(node) { - node = parent; - is_ancestor = true; - } else { - break; - } - } - // Insert the new edges at the input - let mut old_output_to_new_input = BTreeMap::::new(); - for (inp_pos, &old_pos) in new_inputs.iter().enumerate() { - for &(node, port) in &old_inputs_incoming[old_pos] { - if node != out { - hugr.connect(inp, inp_pos, node, port); - } else { - old_output_to_new_input.insert(port, inp_pos.into()); + /// Rewire the inputs and outputs of the nested DFG to modify its signature. + /// + /// Reorder the outgoing resp. incoming wires at the input resp. 
output + /// node of the DFG to modify the signature of the DFG HUGR. This will + /// recursively update the signatures of all ancestors of the entrypoint. + /// + /// ### Arguments + /// + /// * `new_inputs`: The new input signature. After the map, the i-th input + /// wire will be connected to the ports connected to the + /// `new_inputs[i]`-th input of the old DFG. + /// * `new_outputs`: The new output signature. After the map, the i-th + /// output wire will be connected to the ports connected to the + /// `new_outputs[i]`-th output of the old DFG. + /// + /// Returns an `InvalidSignature` error if the new_inputs and new_outputs + /// map are not valid signatures. + /// + /// ### Panics + /// + /// Panics if the DFG is not trivially nested, i.e. if there is an ancestor + /// DFG of the entrypoint that has more than one inner DFG. + pub fn map_function_type( + &mut self, + new_inputs: &[usize], + new_outputs: &[usize], + ) -> Result<(), InvalidSignature> { + let [inp, out] = self.get_io(); + let Self(hugr, _) = self; + + // Record the old connections from and to the input and output nodes + let old_inputs_incoming = hugr + .node_outputs(inp) + .map(|p| hugr.linked_inputs(inp, p).collect_vec()) + .collect_vec(); + let old_outputs_outgoing = hugr + .node_inputs(out) + .map(|p| hugr.linked_outputs(out, p).collect_vec()) + .collect_vec(); + + // The old signature types + let old_inp_sig = hugr + .get_optype(inp) + .dataflow_signature() + .expect("input has signature"); + let old_inp_sig = old_inp_sig.output_types(); + let old_out_sig = hugr + .get_optype(out) + .dataflow_signature() + .expect("output has signature"); + let old_out_sig = old_out_sig.input_types(); + + // Check if the signature map is valid + check_valid_inputs(&old_inputs_incoming, old_inp_sig, new_inputs)?; + check_valid_outputs(old_out_sig, new_outputs)?; + + // The new signature types + let new_inp_sig = new_inputs + .iter() + .map(|&i| old_inp_sig[i].clone()) + .collect_vec(); + let new_out_sig = new_outputs + .iter() + .map(|&i| old_out_sig[i].clone()) + .collect_vec(); + let new_sig = Signature::new(new_inp_sig, new_out_sig); + + // Remove all edges of the input and output nodes + disconnect_all(hugr, inp); + disconnect_all(hugr, out); + + // Update the signatures of the IO and their ancestors + let mut is_ancestor = false; + let mut node = hugr.entrypoint(); + while matches!(hugr.get_optype(node), OpType::FuncDefn(_) | OpType::DFG(_)) { + let [inner_inp, inner_out] = hugr.get_io(node).expect("valid DFG graph"); + for node in [node, inner_inp, inner_out] { + update_signature(hugr, node, &new_sig); + } + if is_ancestor { + update_inner_dfg_links(hugr, node); + } + if let Some(parent) = hugr.get_parent(node) { + node = parent; + is_ancestor = true; + } else { + break; + } + } + + // Insert the new edges at the input + let mut old_output_to_new_input = BTreeMap::::new(); + for (inp_pos, &old_pos) in new_inputs.iter().enumerate() { + for &(node, port) in &old_inputs_incoming[old_pos] { + if node != out { + hugr.connect(inp, inp_pos, node, port); + } else { + old_output_to_new_input.insert(port, inp_pos.into()); + } + } } - } - } - // Insert the new edges at the output - for (out_pos, &old_pos) in new_outputs.iter().enumerate() { - for &(node, port) in &old_outputs_outgoing[old_pos] { - if node != inp { - hugr.connect(node, port, out, out_pos); - } else { - let &inp_pos = old_output_to_new_input.get(&old_pos.into()).unwrap(); - hugr.connect(inp, inp_pos, out, out_pos); + // Insert the new edges at the output + for (out_pos, 
&old_pos) in new_outputs.iter().enumerate() { + for &(node, port) in &old_outputs_outgoing[old_pos] { + if node != inp { + hugr.connect(node, port, out, out_pos); + } else { + let &inp_pos = old_output_to_new_input.get(&old_pos.into()).unwrap(); + hugr.connect(inp, inp_pos, out, out_pos); + } + } } + + Ok(()) } } - - Ok(()) - } + }; } +impl_dataflow_parent_methods!(DataflowParentID); +impl_dataflow_parent_methods!(DfgID); + /// Panics if the DFG within `node` is not a single inner DFG. fn update_inner_dfg_links(hugr: &mut H, node: H::Node) { // connect all edges of the inner DFG to the input and output nodes @@ -168,20 +178,19 @@ fn disconnect_all(hugr: &mut H, node: H::Node) { } fn update_signature(hugr: &mut H, node: H::Node, new_sig: &Signature) { - let new_op: OpType = match hugr.get_optype(node) { - OpType::DFG(_) => DFG { - signature: new_sig.clone(), + match hugr.optype_mut(node) { + OpType::DFG(dfg) => { + dfg.signature = new_sig.clone(); } - .into(), - OpType::FuncDefn(fn_def_op) => { - FuncDefn::new(fn_def_op.func_name().clone(), new_sig.clone()).into() + OpType::FuncDefn(fn_def_op) => *fn_def_op.signature_mut() = new_sig.clone().into(), + OpType::Input(inp) => { + inp.types = new_sig.input().clone(); } - OpType::Input(_) => Input::new(new_sig.input().clone()).into(), - OpType::Output(_) => Output::new(new_sig.output().clone()).into(), + OpType::Output(out) => out.types = new_sig.output().clone(), _ => panic!("only update signature of DFG, FuncDefn, Input, or Output"), }; + let new_op = hugr.get_optype(node); hugr.set_num_ports(node, new_op.input_count(), new_op.output_count()); - hugr.replace_op(node, new_op); } fn check_valid_inputs( @@ -268,11 +277,11 @@ mod test { use super::*; use crate::builder::{ - Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, endo_sig, + DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, endo_sig, }; use crate::extension::prelude::{bool_t, qb_t}; use crate::hugr::views::root_checked::RootChecked; - use crate::ops::handle::{DfgID, NodeHandle}; + use crate::ops::handle::NodeHandle; use crate::ops::{NamedOp, OpParent}; use crate::types::Signature; use crate::utils::test_quantum_extension::cx_gate; @@ -290,6 +299,51 @@ mod test { let sig = Signature::new_endo(vec![qb_t(), qb_t()]); let mut hugr = new_empty_dfg(sig); + // Wrap in RootChecked + let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap(); + + // Test mapping inputs: [0,1] -> [1,0] + let input_map = vec![1, 0]; + let output_map = vec![0, 1]; + + // Map the I/O + dfg_view.map_function_type(&input_map, &output_map).unwrap(); + + // Verify the new signature + let dfg_hugr = dfg_view.hugr(); + let new_sig = dfg_hugr + .get_optype(dfg_hugr.entrypoint()) + .dataflow_signature() + .unwrap(); + assert_eq!(new_sig.input_count(), 2); + assert_eq!(new_sig.output_count(), 2); + + // Test invalid mapping - missing input + let invalid_input_map = vec![0, 0]; + let err = dfg_view.map_function_type(&invalid_input_map, &output_map); + assert!(matches!(err, Err(InvalidSignature::MissingIO(1, "input")))); + + // Test invalid mapping - duplicate input + let invalid_input_map = vec![0, 0, 1]; + assert!(matches!( + dfg_view.map_function_type(&invalid_input_map, &output_map), + Err(InvalidSignature::DuplicateInput(0)) + )); + + // Test invalid mapping - unknown output + let invalid_output_map = vec![0, 2]; + assert!(matches!( + dfg_view.map_function_type(&input_map, &invalid_output_map), + Err(InvalidSignature::UnknownIO(2, 
"output")) + )); + } + + #[test] + fn test_map_io_dfg_id() { + // Create a DFG with 2 inputs and 2 outputs + let sig = Signature::new_endo(vec![qb_t(), qb_t()]); + let mut hugr = new_empty_dfg(sig); + // Wrap in RootChecked let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap(); @@ -337,7 +391,7 @@ mod test { let mut hugr = new_empty_dfg(sig); // Wrap in RootChecked - let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap(); + let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap(); // Test mapping outputs: [0] -> [0,0] (duplicating the output) let input_map = vec![0]; @@ -377,7 +431,7 @@ mod test { .unwrap(); // Wrap in RootChecked - let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap(); + let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap(); // Test mapping inputs: [0,1] -> [1,0] (swapping inputs) let input_map = vec![1, 0]; diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index 03dff67f51..bb7a53ad15 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -428,7 +428,7 @@ impl SiblingSubgraph { if !OpTag::DataflowParent.is_superset(dfg_optype.tag()) { return Err(InvalidReplacement::InvalidDataflowGraph { node: rep_root, - op: dfg_optype.clone(), + op: Box::new(dfg_optype.clone()), }); } let [rep_input, rep_output] = replacement @@ -575,7 +575,7 @@ fn pick_parent<'a, N: HugrNode>( } fn make_boundary<'a, H: HugrView>( - region: &impl LinkView, + region: &impl LinkView, node_map: &H::RegionPortgraphNodes, inputs: &'a IncomingPorts, outputs: &'a OutgoingPorts, @@ -881,7 +881,7 @@ pub enum InvalidReplacement { /// The node ID of the root node. node: Node, /// The op type of the root node. - op: OpType, + op: Box, }, /// Replacement graph type mismatch. #[error( @@ -890,9 +890,9 @@ pub enum InvalidReplacement { ] InvalidSignature { /// The expected signature. - expected: Signature, + expected: Box, /// The actual signature. - actual: Option, + actual: Option>, }, /// `SiblingSubgraph` is not convex. 
#[error("SiblingSubgraph is not convex.")] diff --git a/hugr-core/src/hugr/views/tests.rs b/hugr-core/src/hugr/views/tests.rs index 3a83b5b9d8..28d304d2ee 100644 --- a/hugr-core/src/hugr/views/tests.rs +++ b/hugr-core/src/hugr/views/tests.rs @@ -4,7 +4,8 @@ use rstest::{fixture, rstest}; use crate::{ Hugr, HugrView, builder::{ - BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, endo_sig, inout_sig, + BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, HugrBuilder, + endo_sig, inout_sig, }, extension::prelude::qb_t, ops::{ @@ -183,8 +184,9 @@ fn test_dataflow_ports_only() { let mut dfg = DFGBuilder::new(endo_sig(bool_t())).unwrap(); let local_and = { - let local_and = dfg - .define_function("and", Signature::new(vec![bool_t(); 2], vec![bool_t()])) + let mut mb = dfg.module_root_builder(); + let local_and = mb + .define_function("and", Signature::new(vec![bool_t(); 2], bool_t())) .unwrap(); let first_input = local_and.input().out_wire(0); local_and.finish_with_outputs([first_input]).unwrap() diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 664f4d3ea1..b1a606da96 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -7,7 +7,9 @@ use std::sync::Arc; use crate::{ Direction, Hugr, HugrView, Node, Port, - extension::{ExtensionId, ExtensionRegistry, SignatureError}, + extension::{ + ExtensionId, ExtensionRegistry, SignatureError, resolution::ExtensionResolutionError, + }, hugr::{HugrMut, NodeMetadata}, ops::{ AliasDecl, AliasDefn, CFG, Call, CallIndirect, Case, Conditional, Const, DFG, @@ -22,67 +24,102 @@ use crate::{ }, types::{ CustomType, FuncTypeBase, MaybeRV, PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, - Type, TypeArg, TypeBase, TypeBound, TypeEnum, TypeName, TypeRow, type_param::TypeParam, + Term, Type, TypeArg, TypeBase, TypeBound, TypeEnum, TypeName, TypeRow, + type_param::{SeqPart, TypeParam}, type_row::TypeRowBase, }, }; use fxhash::FxHashMap; -use hugr_model::v0 as model; use hugr_model::v0::table; -use itertools::Either; +use hugr_model::v0::{self as model}; +use itertools::{Either, Itertools}; use smol_str::{SmolStr, ToSmolStr}; use thiserror::Error; -/// Error during import. +fn gen_str(generator: &Option) -> String { + match generator { + Some(g) => format!(" generated by {g}"), + None => String::new(), + } +} + +/// An error that can occur during import. +#[derive(Debug, Clone, Error)] +#[error("failed to import hugr{}", gen_str(&self.generator))] +pub struct ImportError { + #[source] + inner: ImportErrorInner, + generator: Option, +} + #[derive(Debug, Clone, Error)] -#[non_exhaustive] -pub enum ImportError { +enum ImportErrorInner { /// The model contains a feature that is not supported by the importer yet. /// Errors of this kind are expected to be removed as the model format and /// the core HUGR representation converge. #[error("currently unsupported: {0}")] Unsupported(String), + /// The model contains implicit information that has not yet been inferred. /// This includes wildcards and application of functions with implicit parameters. #[error("uninferred implicit: {0}")] Uninferred(String), + + /// The model is not well-formed. + #[error("{0}")] + Invalid(String), + + /// An error with additional context. + #[error("import failed in context: {1}")] + Context(#[source] Box, String), + /// A signature mismatch was detected during import. - #[error("signature error: {0}")] + #[error("signature error")] Signature(#[from] SignatureError), - /// A required extension is missing. 
+ + /// An error relating to the loaded extension registry. + #[error("extension error")] + Extension(#[from] ExtensionError), + + /// Incorrect order hints. + #[error("incorrect order hint")] + OrderHint(#[from] OrderHintError), + + /// Extension resolution. + #[error("extension resolution error")] + ExtensionResolution(#[from] ExtensionResolutionError), +} + +#[derive(Debug, Clone, Error)] +enum ExtensionError { + /// An extension is missing. #[error("Importing the hugr requires extension {missing_ext}, which was not found in the registry. The available extensions are: [{}]", available.iter().map(std::string::ToString::to_string).collect::>().join(", "))] - Extension { + Missing { /// The missing extension. missing_ext: ExtensionId, /// The available extensions in the registry. available: Vec, }, + /// An extension type is missing. #[error( "Importing the hugr requires extension {ext} to have a type named {name}, but it was not found." )] - ExtensionType { + MissingType { /// The extension that is missing the type. ext: ExtensionId, /// The name of the missing type. name: TypeName, }, - /// The model is not well-formed. - #[error("validate error: {0}")] - Model(#[from] table::ModelError), - /// Incorrect order hints. - #[error("incorrect order hint: {0}")] - OrderHint(#[from] OrderHintError), } /// Import error caused by incorrect order hints. #[derive(Debug, Clone, Error)] -#[non_exhaustive] -pub enum OrderHintError { +enum OrderHintError { /// Duplicate order hint key in the same region. #[error("duplicate order hint key {0}")] - DuplicateKey(table::NodeId, u64), + DuplicateKey(table::RegionId, u64), /// Order hint including a key not defined in the region. #[error("order hint with unknown key {0}")] UnknownKey(u64), @@ -91,14 +128,28 @@ pub enum OrderHintError { NoOrderPort(table::NodeId), } -/// Helper macro to create an `ImportError::Unsupported` error with a formatted message. +/// Helper macro to create an `ImportErrorInner::Unsupported` error with a formatted message. macro_rules! error_unsupported { - ($($e:expr),*) => { ImportError::Unsupported(format!($($e),*)) } + ($($e:expr),*) => { ImportErrorInner::Unsupported(format!($($e),*)) } } -/// Helper macro to create an `ImportError::Uninferred` error with a formatted message. +/// Helper macro to create an `ImportErrorInner::Uninferred` error with a formatted message. macro_rules! error_uninferred { - ($($e:expr),*) => { ImportError::Uninferred(format!($($e),*)) } + ($($e:expr),*) => { ImportErrorInner::Uninferred(format!($($e),*)) } +} + +/// Helper macro to create an `ImportErrorInner::Invalid` error with a formatted message. +macro_rules! error_invalid { + ($($e:expr),*) => { ImportErrorInner::Invalid(format!($($e),*)) } +} + +/// Helper macro to create an `ImportErrorInner::Context` error with a formatted message. +macro_rules! error_context { + ($err:expr, $($e:expr),*) => { + { + ImportErrorInner::Context(Box::new($err), format!($($e),*)) + } + } } /// Import a [`Package`] from its model representation. @@ -117,6 +168,22 @@ pub fn import_package( Ok(package) } +/// Get the name of the generator from the metadata of the module. +/// If no generator is found, `None` is returned. 
+fn get_generator(ctx: &Context<'_>) -> Option { + ctx.module + .get_region(ctx.module.root) + .map(|r| r.meta.iter()) + .into_iter() + .flatten() + .find_map(|meta| { + let (name, json_val) = ctx.decode_json_meta(*meta).ok()??; + + (name == crate::envelope::GENERATOR_KEY) + .then_some(crate::envelope::format_generator(&json_val)) + }) +} + /// Import a [`Hugr`] module from its model representation. pub fn import_hugr( module: &table::Module, @@ -136,10 +203,26 @@ pub fn import_hugr( region_scope: table::RegionId::default(), }; - ctx.import_root()?; - ctx.link_ports()?; - ctx.link_static_ports()?; - + let import_steps: [fn(&mut Context) -> _; 3] = [ + |ctx| ctx.import_root(), + |ctx| ctx.link_ports(), + |ctx| ctx.link_static_ports(), + ]; + + for step in import_steps { + if let Err(e) = step(&mut ctx) { + return Err(ImportError { + inner: e, + generator: get_generator(&ctx), + }); + } + } + ctx.hugr + .resolve_extension_defs(extensions) + .map_err(|e| ImportError { + inner: ImportErrorInner::ExtensionResolution(e), + generator: get_generator(&ctx), + })?; Ok(ctx.hugr) } @@ -173,7 +256,7 @@ struct Context<'a> { impl<'a> Context<'a> { /// Get the signature of the node with the given `NodeId`. - fn get_node_signature(&mut self, node: table::NodeId) -> Result { + fn get_node_signature(&mut self, node: table::NodeId) -> Result { let node_data = self.get_node(node)?; let signature = node_data .signature @@ -183,26 +266,29 @@ impl<'a> Context<'a> { /// Get the node with the given `NodeId`, or return an error if it does not exist. #[inline] - fn get_node(&self, node_id: table::NodeId) -> Result<&'a table::Node<'a>, ImportError> { + fn get_node(&self, node_id: table::NodeId) -> Result<&'a table::Node<'a>, ImportErrorInner> { self.module .get_node(node_id) - .ok_or_else(|| table::ModelError::NodeNotFound(node_id).into()) + .ok_or_else(|| error_invalid!("unknown node {}", node_id)) } /// Get the term with the given `TermId`, or return an error if it does not exist. #[inline] - fn get_term(&self, term_id: table::TermId) -> Result<&'a table::Term<'a>, ImportError> { + fn get_term(&self, term_id: table::TermId) -> Result<&'a table::Term<'a>, ImportErrorInner> { self.module .get_term(term_id) - .ok_or_else(|| table::ModelError::TermNotFound(term_id).into()) + .ok_or_else(|| error_invalid!("unknown term {}", term_id)) } /// Get the region with the given `RegionId`, or return an error if it does not exist. 
#[inline] - fn get_region(&self, region_id: table::RegionId) -> Result<&'a table::Region<'a>, ImportError> { + fn get_region( + &self, + region_id: table::RegionId, + ) -> Result<&'a table::Region<'a>, ImportErrorInner> { self.module .get_region(region_id) - .ok_or_else(|| table::ModelError::RegionNotFound(region_id).into()) + .ok_or_else(|| error_invalid!("unknown region {}", region_id)) } fn make_node( @@ -210,7 +296,7 @@ impl<'a> Context<'a> { node_id: table::NodeId, op: OpType, parent: Node, - ) -> Result { + ) -> Result { let node = self.hugr.add_node_with_parent(parent, op); self.nodes.insert(node_id, node); @@ -219,7 +305,8 @@ impl<'a> Context<'a> { self.record_links(node, Direction::Outgoing, node_data.outputs); for meta_item in node_data.meta { - self.import_node_metadata(node, *meta_item)?; + self.import_node_metadata(node, *meta_item) + .map_err(|err| error_context!(err, "node metadata"))?; } Ok(node) @@ -229,21 +316,9 @@ impl<'a> Context<'a> { &mut self, node: Node, meta_item: table::TermId, - ) -> Result<(), ImportError> { + ) -> Result<(), ImportErrorInner> { // Import the JSON metadata - if let Some([name_arg, json_arg]) = self.match_symbol(meta_item, model::COMPAT_META_JSON)? { - let table::Term::Literal(model::Literal::Str(name)) = self.get_term(name_arg)? else { - return Err(table::ModelError::TypeError(meta_item).into()); - }; - - let table::Term::Literal(model::Literal::Str(json_str)) = self.get_term(json_arg)? - else { - return Err(table::ModelError::TypeError(meta_item).into()); - }; - - let json_value: NodeMetadata = serde_json::from_str(json_str) - .map_err(|_| table::ModelError::TypeError(meta_item))?; - + if let Some((name, json_value)) = self.decode_json_meta(meta_item)? { self.hugr.set_metadata(node, name, json_value); } @@ -255,6 +330,44 @@ impl<'a> Context<'a> { Ok(()) } + fn decode_json_meta( + &self, + meta_item: table::TermId, + ) -> Result, ImportErrorInner> { + Ok( + if let Some([name_arg, json_arg]) = + self.match_symbol(meta_item, model::COMPAT_META_JSON)? + { + let table::Term::Literal(model::Literal::Str(name)) = self.get_term(name_arg)? + else { + return Err(error_invalid!( + "`{}` expects a string literal as its first argument", + model::COMPAT_META_JSON + )); + }; + + let table::Term::Literal(model::Literal::Str(json_str)) = + self.get_term(json_arg)? + else { + return Err(error_invalid!( + "`{}` expects a string literal as its second argument", + model::COMPAT_CONST_JSON + )); + }; + + let json_value: NodeMetadata = serde_json::from_str(json_str).map_err(|_| { + error_invalid!( + "failed to parse JSON string for `{}` metadata", + model::COMPAT_CONST_JSON + ) + })?; + Some((name.to_owned(), json_value)) + } else { + None + }, + ) + } + /// Associate links with the ports of the given node in the given direction. fn record_links(&mut self, node: Node, direction: Direction, links: &'a [table::LinkIndex]) { let optype = self.hugr.get_optype(node); @@ -271,7 +384,7 @@ impl<'a> Context<'a> { /// Link up the ports in the hugr graph, according to the connectivity information that /// has been gathered in the `link_ports` map. - fn link_ports(&mut self) -> Result<(), ImportError> { + fn link_ports(&mut self) -> Result<(), ImportErrorInner> { // For each edge, we group the ports by their direction. We reuse the `inputs` and // `outputs` vectors to avoid unnecessary allocations. 
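// Illustrative sketch of the linking strategy described in the comment above,
// reduced to a self-contained example: endpoints that share a link are split by
// direction, and every collected output endpoint is connected to every collected
// input endpoint. `Endpoint`, `Direction` and the returned edge list are
// simplified stand-ins, not the hugr-core API.
use std::collections::HashMap;

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Direction { Incoming, Outgoing }

#[derive(Clone, Copy, Debug)]
struct Endpoint { node: usize, port: usize, dir: Direction }

/// For each link, group its endpoints by direction and connect every
/// output endpoint to every input endpoint.
fn link_ports_sketch(
    endpoints_by_link: &HashMap<u32, Vec<Endpoint>>,
) -> Vec<((usize, usize), (usize, usize))> {
    let mut edges = Vec::new();
    for endpoints in endpoints_by_link.values() {
        let outputs: Vec<_> = endpoints.iter().filter(|e| e.dir == Direction::Outgoing).collect();
        let inputs: Vec<_> = endpoints.iter().filter(|e| e.dir == Direction::Incoming).collect();
        for o in &outputs {
            for i in &inputs {
                edges.push(((o.node, o.port), (i.node, i.port)));
            }
        }
    }
    edges
}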
let mut inputs = Vec::new(); @@ -319,7 +432,7 @@ impl<'a> Context<'a> { Ok(()) } - fn link_static_ports(&mut self) -> Result<(), ImportError> { + fn link_static_ports(&mut self) -> Result<(), ImportErrorInner> { for (src_id, dst_id) in std::mem::take(&mut self.static_edges) { // None of these lookups should fail given how we constructed `static_edges`. let src = self.nodes[&src_id]; @@ -332,35 +445,40 @@ impl<'a> Context<'a> { Ok(()) } - fn get_symbol_name(&self, node_id: table::NodeId) -> Result<&'a str, ImportError> { + fn get_symbol_name(&self, node_id: table::NodeId) -> Result<&'a str, ImportErrorInner> { let node_data = self.get_node(node_id)?; let name = node_data .operation .symbol() - .ok_or(table::ModelError::InvalidSymbol(node_id))?; + .ok_or_else(|| error_invalid!("node {} is expected to be a symbol", node_id))?; Ok(name) } fn get_func_signature( &mut self, func_node: table::NodeId, - ) -> Result { + ) -> Result { let symbol = match self.get_node(func_node)?.operation { table::Operation::DefineFunc(symbol) => symbol, table::Operation::DeclareFunc(symbol) => symbol, - _ => return Err(table::ModelError::UnexpectedOperation(func_node).into()), + _ => { + return Err(error_invalid!( + "node {} is expected to be a function declaration or definition", + func_node + )); + } }; self.import_poly_func_type(func_node, *symbol, |_, signature| Ok(signature)) } /// Import the root region of the module. - fn import_root(&mut self) -> Result<(), ImportError> { + fn import_root(&mut self) -> Result<(), ImportErrorInner> { self.region_scope = self.module.root; let region_data = self.get_region(self.module.root)?; for node in region_data.children { - self.import_node(*node, self.hugr.entrypoint())?; + self.import_node(*node, self.hugr.module_root())?; } for meta_item in region_data.meta { @@ -374,250 +492,126 @@ impl<'a> Context<'a> { &mut self, node_id: table::NodeId, parent: Node, - ) -> Result, ImportError> { + ) -> Result, ImportErrorInner> { let node_data = self.get_node(node_id)?; - match node_data.operation { - table::Operation::Invalid => Err(table::ModelError::InvalidOperation(node_id).into()), - table::Operation::Dfg => { - let signature = self.get_node_signature(node_id)?; - let optype = OpType::DFG(DFG { signature }); - let node = self.make_node(node_id, optype, parent)?; - - let [region] = node_data.regions else { - return Err(table::ModelError::InvalidRegions(node_id).into()); - }; - - self.import_dfg_region(node_id, *region, node)?; - Ok(Some(node)) - } - - table::Operation::Cfg => { - let signature = self.get_node_signature(node_id)?; - let optype = OpType::CFG(CFG { signature }); - let node = self.make_node(node_id, optype, parent)?; - - let [region] = node_data.regions else { - return Err(table::ModelError::InvalidRegions(node_id).into()); - }; - - self.import_cfg_region(node_id, *region, node)?; - Ok(Some(node)) - } - - table::Operation::Block => { - let node = self.import_cfg_block(node_id, parent)?; - Ok(Some(node)) - } - - table::Operation::DefineFunc(symbol) => { - self.import_poly_func_type(node_id, *symbol, |ctx, signature| { - let optype = OpType::FuncDefn(FuncDefn::new(symbol.name, signature)); - - let node = ctx.make_node(node_id, optype, parent)?; - - let [region] = node_data.regions else { - return Err(table::ModelError::InvalidRegions(node_id).into()); - }; - - ctx.import_dfg_region(node_id, *region, node)?; - - Ok(Some(node)) - }) - } - - table::Operation::DeclareFunc(symbol) => { - self.import_poly_func_type(node_id, *symbol, |ctx, signature| { - let optype = 
OpType::FuncDecl(FuncDecl::new(symbol.name, signature)); - - let node = ctx.make_node(node_id, optype, parent)?; - - Ok(Some(node)) - }) - } - - table::Operation::TailLoop => { - let node = self.import_tail_loop(node_id, parent)?; - Ok(Some(node)) + let result = match node_data.operation { + table::Operation::Invalid => { + return Err(error_invalid!("tried to import an `invalid` operation")); } - table::Operation::Conditional => { - let node = self.import_conditional(node_id, parent)?; - Ok(Some(node)) - } - - table::Operation::Custom(operation) => { - if let Some([_, _]) = self.match_symbol(operation, model::CORE_CALL_INDIRECT)? { - let signature = self.get_node_signature(node_id)?; - let optype = OpType::CallIndirect(CallIndirect { signature }); - let node = self.make_node(node_id, optype, parent)?; - return Ok(Some(node)); - } - - if let Some([_, _, func]) = self.match_symbol(operation, model::CORE_CALL)? { - let table::Term::Apply(symbol, args) = self.get_term(func)? else { - return Err(table::ModelError::TypeError(func).into()); - }; - - let func_sig = self.get_func_signature(*symbol)?; - - let type_args = args - .iter() - .map(|term| self.import_type_arg(*term)) - .collect::, _>>()?; - - self.static_edges.push((*symbol, node_id)); - let optype = OpType::Call(Call::try_new(func_sig, type_args)?); - - let node = self.make_node(node_id, optype, parent)?; - return Ok(Some(node)); - } - - if let Some([_, value]) = self.match_symbol(operation, model::CORE_LOAD_CONST)? { - // If the constant refers directly to a function, import this as the `LoadFunc` operation. - if let table::Term::Apply(symbol, args) = self.get_term(value)? { - let func_node_data = self - .module - .get_node(*symbol) - .ok_or(table::ModelError::NodeNotFound(*symbol))?; - - if let table::Operation::DefineFunc(_) | table::Operation::DeclareFunc(_) = - func_node_data.operation - { - let func_sig = self.get_func_signature(*symbol)?; - let type_args = args - .iter() - .map(|term| self.import_type_arg(*term)) - .collect::, _>>()?; - - self.static_edges.push((*symbol, node_id)); - - let optype = - OpType::LoadFunction(LoadFunction::try_new(func_sig, type_args)?); - - let node = self.make_node(node_id, optype, parent)?; - return Ok(Some(node)); - } - } - - // Otherwise use const nodes - let signature = node_data - .signature - .ok_or_else(|| error_uninferred!("node signature"))?; - let [_, outputs] = self.get_func_type(signature)?; - let outputs = self.import_closed_list(outputs)?; - let output = outputs - .first() - .ok_or(table::ModelError::TypeError(signature))?; - let datatype = self.import_type(*output)?; - - let imported_value = self.import_value(value, *output)?; - - let load_const_node = self.make_node( - node_id, - OpType::LoadConstant(LoadConstant { - datatype: datatype.clone(), - }), - parent, - )?; - - let const_node = self - .hugr - .add_node_with_parent(parent, OpType::Const(Const::new(imported_value))); - - self.hugr.connect(const_node, 0, load_const_node, 0); - - return Ok(Some(load_const_node)); - } - - if let Some([_, _, tag]) = self.match_symbol(operation, model::CORE_MAKE_ADT)? { - let table::Term::Literal(model::Literal::Nat(tag)) = self.get_term(tag)? 
else { - return Err(table::ModelError::TypeError(tag).into()); - }; - - let signature = node_data - .signature - .ok_or_else(|| error_uninferred!("node signature"))?; - let [_, outputs] = self.get_func_type(signature)?; - let (variants, _) = self.import_adt_and_rest(node_id, outputs)?; - let node = self.make_node( - node_id, - OpType::Tag(Tag { - variants, - tag: *tag as usize, - }), - parent, - )?; - return Ok(Some(node)); - } - - let table::Term::Apply(node, params) = self.get_term(operation)? else { - return Err(table::ModelError::TypeError(operation).into()); - }; - let name = self.get_symbol_name(*node)?; - let args = params - .iter() - .map(|param| self.import_type_arg(*param)) - .collect::, _>>()?; - let (extension, name) = self.import_custom_name(name)?; - let signature = self.get_node_signature(node_id)?; - // TODO: Currently we do not have the description or any other metadata for - // the custom op. This will improve with declarative extensions being able - // to declare operations as a node, in which case the description will be attached - // to that node as metadata. - - let optype = OpType::OpaqueOp(OpaqueOp::new(extension, name, args, signature)); - - let node = self.make_node(node_id, optype, parent)?; + table::Operation::Dfg => Some( + self.import_node_dfg(node_id, parent, node_data) + .map_err(|err| error_context!(err, "`dfg` node with id {}", node_id))?, + ), + + table::Operation::Cfg => Some( + self.import_node_cfg(node_id, parent, node_data) + .map_err(|err| error_context!(err, "`cfg` node with id {}", node_id))?, + ), + + table::Operation::Block => Some( + self.import_node_block(node_id, parent) + .map_err(|err| error_context!(err, "`block` node with id {}", node_id))?, + ), + + table::Operation::DefineFunc(symbol) => Some( + self.import_node_define_func(node_id, symbol, node_data, parent) + .map_err(|err| error_context!(err, "`define-func` node with id {}", node_id))?, + ), + + table::Operation::DeclareFunc(symbol) => Some( + self.import_node_declare_func(node_id, symbol, parent) + .map_err(|err| { + error_context!(err, "`declare-func` node with id {}", node_id) + })?, + ), + + table::Operation::TailLoop => Some( + self.import_tail_loop(node_id, parent) + .map_err(|err| error_context!(err, "`tail-loop` node with id {}", node_id))?, + ), + + table::Operation::Conditional => Some( + self.import_conditional(node_id, parent) + .map_err(|err| error_context!(err, "`cond` node with id {}", node_id))?, + ), + + table::Operation::Custom(operation) => Some( + self.import_node_custom(node_id, operation, node_data, parent) + .map_err(|err| error_context!(err, "custom node with id {}", node_id))?, + ), + + table::Operation::DefineAlias(symbol, value) => Some( + self.import_node_define_alias(node_id, symbol, value, parent) + .map_err(|err| { + error_context!(err, "`define-alias` node with id {}", node_id) + })?, + ), + + table::Operation::DeclareAlias(symbol) => Some( + self.import_node_declare_alias(node_id, symbol, parent) + .map_err(|err| { + error_context!(err, "`declare-alias` node with id {}", node_id) + })?, + ), + + table::Operation::Import { .. } => None, + + table::Operation::DeclareConstructor { .. } => None, + table::Operation::DeclareOperation { .. 
} => None, + }; - Ok(Some(node)) - } + Ok(result) + } - table::Operation::DefineAlias(symbol, value) => { - if !symbol.params.is_empty() { - return Err(error_unsupported!( - "parameters or constraints in alias definition" - )); - } + fn import_node_dfg( + &mut self, + node_id: table::NodeId, + parent: Node, + node_data: &'a table::Node<'a>, + ) -> Result { + let signature = self + .get_node_signature(node_id) + .map_err(|err| error_context!(err, "node signature"))?; - let optype = OpType::AliasDefn(AliasDefn { - name: symbol.name.to_smolstr(), - definition: self.import_type(value)?, - }); + let optype = OpType::DFG(DFG { signature }); + let node = self.make_node(node_id, optype, parent)?; - let node = self.make_node(node_id, optype, parent)?; - Ok(Some(node)) - } + let [region] = node_data.regions else { + return Err(error_invalid!("dfg region expects a single region")); + }; - table::Operation::DeclareAlias(symbol) => { - if !symbol.params.is_empty() { - return Err(error_unsupported!( - "parameters or constraints in alias declaration" - )); - } + self.import_dfg_region(*region, node)?; + Ok(node) + } - let optype = OpType::AliasDecl(AliasDecl { - name: symbol.name.to_smolstr(), - bound: TypeBound::Copyable, - }); + fn import_node_cfg( + &mut self, + node_id: table::NodeId, + parent: Node, + node_data: &'a table::Node<'a>, + ) -> Result { + let signature = self + .get_node_signature(node_id) + .map_err(|err| error_context!(err, "node signature"))?; - let node = self.make_node(node_id, optype, parent)?; - Ok(Some(node)) - } + let optype = OpType::CFG(CFG { signature }); + let node = self.make_node(node_id, optype, parent)?; - table::Operation::Import { .. } => Ok(None), + let [region] = node_data.regions else { + return Err(error_invalid!("cfg nodes expect a single region")); + }; - table::Operation::DeclareConstructor { .. } => Ok(None), - table::Operation::DeclareOperation { .. } => Ok(None), - } + self.import_cfg_region(*region, node)?; + Ok(node) } fn import_dfg_region( &mut self, - node_id: table::NodeId, region: table::RegionId, node: Node, - ) -> Result<(), ImportError> { + ) -> Result<(), ImportErrorInner> { let region_data = self.get_region(region)?; let prev_region = self.region_scope; @@ -626,14 +620,16 @@ impl<'a> Context<'a> { } if region_data.kind != model::RegionKind::DataFlow { - return Err(table::ModelError::InvalidRegions(node_id).into()); + return Err(error_invalid!("expected dfg region")); } - let signature = self.import_func_type( - region_data - .signature - .ok_or_else(|| error_uninferred!("region signature"))?, - )?; + let signature = self + .import_func_type( + region_data + .signature + .ok_or_else(|| error_uninferred!("region signature"))?, + ) + .map_err(|err| error_context!(err, "signature of dfg region with id {}", region))?; // Create the input and output nodes let input = self.hugr.add_node_with_parent( @@ -657,7 +653,7 @@ impl<'a> Context<'a> { self.import_node(*child, node)?; } - self.create_order_edges(region)?; + self.create_order_edges(region, input, output)?; for meta_item in region_data.meta { self.import_node_metadata(node, *meta_item)?; @@ -671,13 +667,18 @@ impl<'a> Context<'a> { /// Create order edges between nodes of a dataflow region based on order hint metadata. /// /// This method assumes that the nodes for the children of the region have already been imported. 
- fn create_order_edges(&mut self, region_id: table::RegionId) -> Result<(), ImportError> { + fn create_order_edges( + &mut self, + region_id: table::RegionId, + input: Node, + output: Node, + ) -> Result<(), ImportErrorInner> { let region_data = self.get_region(region_id)?; debug_assert_eq!(region_data.kind, model::RegionKind::DataFlow); // Collect order hint keys // PERFORMANCE: It might be worthwhile to reuse the map to avoid allocations. - let mut order_keys = FxHashMap::::default(); + let mut order_keys = FxHashMap::::default(); for child_id in region_data.children { let child_data = self.get_node(*child_id)?; @@ -691,8 +692,42 @@ impl<'a> Context<'a> { continue; }; - if order_keys.insert(*key, *child_id).is_some() { - return Err(OrderHintError::DuplicateKey(*child_id, *key).into()); + // NOTE: The lookups here are expected to succeed since we only + // process the order metadata after we have imported the nodes. + let child_node = self.nodes[child_id]; + let child_optype = self.hugr.get_optype(child_node); + + // Check that the node has order ports. + // NOTE: This assumes that a node has an input order port iff it has an output one. + if child_optype.other_output_port().is_none() { + return Err(OrderHintError::NoOrderPort(*child_id).into()); + } + + if order_keys.insert(*key, child_node).is_some() { + return Err(OrderHintError::DuplicateKey(region_id, *key).into()); + } + } + } + + // Collect the order hint keys for the input and output nodes + for meta_id in region_data.meta { + if let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_INPUT_KEY)? { + let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else { + continue; + }; + + if order_keys.insert(*key, input).is_some() { + return Err(OrderHintError::DuplicateKey(region_id, *key).into()); + } + } + + if let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_OUTPUT_KEY)? { + let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else { + continue; + }; + + if order_keys.insert(*key, output).is_some() { + return Err(OrderHintError::DuplicateKey(region_id, *key).into()); } } } @@ -714,24 +749,13 @@ impl<'a> Context<'a> { let a = order_keys.get(a).ok_or(OrderHintError::UnknownKey(*a))?; let b = order_keys.get(b).ok_or(OrderHintError::UnknownKey(*b))?; - // NOTE: The lookups here are expected to succeed since we only - // process the order metadata after we have imported the nodes. - let a_node = self.nodes[a]; - let b_node = self.nodes[b]; + // NOTE: The unwrap here must succeed: + // - For all ordinary nodes we checked that they have an order port. + // - Input and output nodes always have an order port. 
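// A simplified model of the order-hint mechanism used here: hint keys resolve to
// already-imported nodes, and each (earlier, later) hint pair becomes an edge
// between the nodes' order ports. `DagSketch` and its edge list are illustrative
// stand-ins for the hugr-core graph and port APIs.
use std::collections::HashMap;

struct DagSketch {
    /// (from_node, to_node) order edges
    order_edges: Vec<(usize, usize)>,
}

impl DagSketch {
    fn connect_order(&mut self, from: usize, to: usize) {
        self.order_edges.push((from, to));
    }
}

fn apply_order_hints(
    dag: &mut DagSketch,
    order_keys: &HashMap<u64, usize>, // hint key -> imported node id
    hints: &[(u64, u64)],             // (key of earlier node, key of later node)
) -> Result<(), String> {
    for (a, b) in hints {
        let a_node = *order_keys.get(a).ok_or_else(|| format!("unknown key {a}"))?;
        let b_node = *order_keys.get(b).ok_or_else(|| format!("unknown key {b}"))?;
        dag.connect_order(a_node, b_node);
    }
    Ok(())
}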
+ let a_port = self.hugr.get_optype(*a).other_output_port().unwrap(); + let b_port = self.hugr.get_optype(*b).other_input_port().unwrap(); - let a_port = self - .hugr - .get_optype(a_node) - .other_output_port() - .ok_or(OrderHintError::NoOrderPort(*a))?; - - let b_port = self - .hugr - .get_optype(b_node) - .other_input_port() - .ok_or(OrderHintError::NoOrderPort(*b))?; - - self.hugr.connect(a_node, a_port, b_node, b_port); + self.hugr.connect(*a, a_port, *b, b_port); } Ok(()) @@ -739,13 +763,12 @@ impl<'a> Context<'a> { fn import_adt_and_rest( &mut self, - node_id: table::NodeId, list: table::TermId, - ) -> Result<(Vec, TypeRow), ImportError> { + ) -> Result<(Vec, TypeRow), ImportErrorInner> { let items = self.import_closed_list(list)?; let Some((first, rest)) = items.split_first() else { - return Err(table::ModelError::InvalidRegions(node_id).into()); + return Err(error_invalid!("expected list to have at least one element")); }; let sum_rows: Vec<_> = { @@ -766,35 +789,40 @@ impl<'a> Context<'a> { &mut self, node_id: table::NodeId, parent: Node, - ) -> Result { + ) -> Result { let node_data = self.get_node(node_id)?; debug_assert_eq!(node_data.operation, table::Operation::TailLoop); let [region] = node_data.regions else { - return Err(table::ModelError::InvalidRegions(node_id).into()); + return Err(error_invalid!( + "loop node {} expects a single region", + node_id + )); }; - let region_data = self.get_region(*region)?; - let [_, region_outputs] = self.get_func_type( - region_data - .signature - .ok_or_else(|| error_uninferred!("region signature"))?, - )?; - let (sum_rows, rest) = self.import_adt_and_rest(node_id, region_outputs)?; + let region_data = self.get_region(*region)?; - let (just_inputs, just_outputs) = { - let mut sum_rows = sum_rows.into_iter(); + let (just_inputs, just_outputs, rest) = (|| { + let [_, region_outputs] = self.get_func_type( + region_data + .signature + .ok_or_else(|| error_uninferred!("region signature"))?, + )?; + let (sum_rows, rest) = self.import_adt_and_rest(region_outputs)?; - let Some(just_inputs) = sum_rows.next() else { - return Err(table::ModelError::TypeError(region_outputs).into()); - }; + if sum_rows.len() != 2 { + return Err(error_invalid!( + "loop nodes expect their first target to be an ADT with two variants" + )); + } - let Some(just_outputs) = sum_rows.next() else { - return Err(table::ModelError::TypeError(region_outputs).into()); - }; + let mut sum_rows = sum_rows.into_iter(); + let just_inputs = sum_rows.next().unwrap(); + let just_outputs = sum_rows.next().unwrap(); - (just_inputs, just_outputs) - }; + Ok((just_inputs, just_outputs, rest)) + })() + .map_err(|err| error_context!(err, "region signature"))?; let optype = OpType::TailLoop(TailLoop { just_inputs, @@ -804,7 +832,7 @@ impl<'a> Context<'a> { let node = self.make_node(node_id, optype, parent)?; - self.import_dfg_region(node_id, *region, node)?; + self.import_dfg_region(*region, node)?; Ok(node) } @@ -812,16 +840,22 @@ impl<'a> Context<'a> { &mut self, node_id: table::NodeId, parent: Node, - ) -> Result { + ) -> Result { let node_data = self.get_node(node_id)?; debug_assert_eq!(node_data.operation, table::Operation::Conditional); - let [inputs, outputs] = self.get_func_type( - node_data - .signature - .ok_or_else(|| error_uninferred!("node signature"))?, - )?; - let (sum_rows, other_inputs) = self.import_adt_and_rest(node_id, inputs)?; - let outputs = self.import_type_row(outputs)?; + + let (sum_rows, other_inputs, outputs) = (|| { + let [inputs, outputs] = self.get_func_type( + 
node_data + .signature + .ok_or_else(|| error_uninferred!("node signature"))?, + )?; + let (sum_rows, other_inputs) = self.import_adt_and_rest(inputs)?; + let outputs = self.import_type_row(outputs)?; + + Ok((sum_rows, other_inputs, outputs)) + })() + .map_err(|err| error_context!(err, "node signature"))?; let optype = OpType::Conditional(Conditional { sum_rows, @@ -843,7 +877,7 @@ impl<'a> Context<'a> { .hugr .add_node_with_parent(node, OpType::Case(Case { signature })); - self.import_dfg_region(node_id, *region, case_node)?; + self.import_dfg_region(*region, case_node)?; } Ok(node) @@ -851,14 +885,13 @@ impl<'a> Context<'a> { fn import_cfg_region( &mut self, - node_id: table::NodeId, region: table::RegionId, node: Node, - ) -> Result<(), ImportError> { + ) -> Result<(), ImportErrorInner> { let region_data = self.get_region(region)?; if region_data.kind != model::RegionKind::ControlFlow { - return Err(table::ModelError::InvalidRegions(node_id).into()); + return Err(error_invalid!("expected cfg region")); } let prev_region = self.region_scope; @@ -866,19 +899,22 @@ impl<'a> Context<'a> { self.region_scope = region; } - let [_, region_targets] = self.get_func_type( - region_data - .signature - .ok_or_else(|| error_uninferred!("region signature"))?, - )?; + let region_target_types = (|| { + let [_, region_targets] = self.get_ctrl_type( + region_data + .signature + .ok_or_else(|| error_uninferred!("region signature"))?, + )?; - let region_target_types = self.import_closed_list(region_targets)?; + self.import_closed_list(region_targets) + })() + .map_err(|err| error_context!(err, "signature of cfg region with id {}", region))?; // Identify the entry node of the control flow region by looking for // a block whose input is linked to the sole source port of the CFG region. let entry_node = 'find_entry: { let [entry_link] = region_data.sources else { - return Err(table::ModelError::InvalidRegions(node_id).into()); + return Err(error_invalid!("cfg region expects a single source")); }; for child in region_data.children { @@ -894,29 +930,22 @@ impl<'a> Context<'a> { // directly from the source to the target of the region. This is // currently not allowed in hugr core directly, but may be simulated // by constructing an empty entry block. - return Err(table::ModelError::InvalidRegions(node_id).into()); + return Err(error_invalid!("cfg region without entry node")); }; // The entry node in core control flow regions is identified by being - // the first child node of the CFG node. We therefore import the entry - // node first and follow it up by every other node. + // the first child node of the CFG node. We therefore import the entry node first. self.import_node(entry_node, node)?; - for child in region_data.children { - if *child != entry_node { - self.import_node(*child, node)?; - } - } - - // Create the exit node for the control flow region. + // Create the exit node for the control flow region. This always needs + // to be second in the node list. { let cfg_outputs = { - let [ctrl_type] = region_target_types.as_slice() else { - return Err(table::ModelError::TypeError(region_targets).into()); + let [target_types] = region_target_types.as_slice() else { + return Err(error_invalid!("cfg region expects a single target")); }; - let [types] = self.expect_symbol(*ctrl_type, model::CORE_CTRL)?; - self.import_type_row(types)? + self.import_type_row(*target_types)? 
}; let exit = self @@ -925,8 +954,16 @@ impl<'a> Context<'a> { self.record_links(exit, Direction::Incoming, region_data.targets); } + // Finally we import all other nodes. + for child in region_data.children { + if *child != entry_node { + self.import_node(*child, node)?; + } + } + for meta_item in region_data.meta { - self.import_node_metadata(node, *meta_item)?; + self.import_node_metadata(node, *meta_item) + .map_err(|err| error_context!(err, "node metadata"))?; } self.region_scope = prev_region; @@ -934,16 +971,16 @@ impl<'a> Context<'a> { Ok(()) } - fn import_cfg_block( + fn import_node_block( &mut self, node_id: table::NodeId, parent: Node, - ) -> Result { + ) -> Result { let node_data = self.get_node(node_id)?; debug_assert_eq!(node_data.operation, table::Operation::Block); let [region] = node_data.regions else { - return Err(table::ModelError::InvalidRegions(node_id).into()); + return Err(error_invalid!("basic block expects a single region")); }; let region_data = self.get_region(*region)?; let [inputs, outputs] = self.get_func_type( @@ -952,7 +989,7 @@ impl<'a> Context<'a> { .ok_or_else(|| error_uninferred!("region signature"))?, )?; let inputs = self.import_type_row(inputs)?; - let (sum_rows, other_outputs) = self.import_adt_and_rest(node_id, outputs)?; + let (sum_rows, other_outputs) = self.import_adt_and_rest(outputs)?; let optype = OpType::DataflowBlock(DataflowBlock { inputs, @@ -961,350 +998,545 @@ impl<'a> Context<'a> { }); let node = self.make_node(node_id, optype, parent)?; - self.import_dfg_region(node_id, *region, node)?; + self.import_dfg_region(*region, node).map_err(|err| { + error_context!(err, "block body defined by region with id {}", *region) + })?; Ok(node) } - fn import_poly_func_type( + fn import_node_define_func( &mut self, - node: table::NodeId, - symbol: table::Symbol<'a>, - in_scope: impl FnOnce(&mut Self, PolyFuncTypeBase) -> Result, - ) -> Result { - let mut imported_params = Vec::with_capacity(symbol.params.len()); - - for (index, param) in symbol.params.iter().enumerate() { - self.local_vars - .insert(table::VarId(node, index as _), LocalVar::new(param.r#type)); - } + node_id: table::NodeId, + symbol: &'a table::Symbol<'a>, + node_data: &'a table::Node<'a>, + parent: Node, + ) -> Result { + let visibility = symbol.visibility.clone().ok_or(ImportErrorInner::Invalid( + "No visibility for FuncDefn".to_string(), + ))?; + self.import_poly_func_type(node_id, *symbol, |ctx, signature| { + let func_name = ctx.import_title_metadata(node_id)?.unwrap_or(symbol.name); + + let optype = + OpType::FuncDefn(FuncDefn::new_vis(func_name, signature, visibility.into())); + + let node = ctx.make_node(node_id, optype, parent)?; + + let [region] = node_data.regions else { + return Err(error_invalid!( + "function definition nodes expect a single region" + )); + }; - for constraint in symbol.constraints { - if let Some([term]) = self.match_symbol(*constraint, model::CORE_NON_LINEAR)? { - let table::Term::Var(var) = self.get_term(term)? else { - return Err(error_unsupported!( - "constraint on term that is not a variable" - )); - }; + ctx.import_dfg_region(*region, node).map_err(|err| { + error_context!(err, "function body defined by region with id {}", *region) + })?; - self.local_vars - .get_mut(var) - .ok_or(table::ModelError::InvalidVar(*var))? 
- .bound = TypeBound::Copyable; - } else { - return Err(error_unsupported!("constraint other than copy or discard")); - } - } - - for (index, param) in symbol.params.iter().enumerate() { - // NOTE: `PolyFuncType` only has explicit type parameters at present. - let bound = self.local_vars[&table::VarId(node, index as _)].bound; - imported_params.push(self.import_type_param(param.r#type, bound)?); - } + Ok(node) + }) + } - let body = self.import_func_type::(symbol.signature)?; - in_scope(self, PolyFuncTypeBase::new(imported_params, body)) + fn import_node_declare_func( + &mut self, + node_id: table::NodeId, + symbol: &'a table::Symbol<'a>, + parent: Node, + ) -> Result { + let visibility = symbol.visibility.clone().ok_or(ImportErrorInner::Invalid( + "No visibility for FuncDecl".to_string(), + ))?; + self.import_poly_func_type(node_id, *symbol, |ctx, signature| { + let func_name = ctx.import_title_metadata(node_id)?.unwrap_or(symbol.name); + + let optype = + OpType::FuncDecl(FuncDecl::new_vis(func_name, signature, visibility.into())); + let node = ctx.make_node(node_id, optype, parent)?; + Ok(node) + }) } - /// Import a [`TypeParam`] from a term that represents a static type. - fn import_type_param( + fn import_node_custom( &mut self, - term_id: table::TermId, - bound: TypeBound, - ) -> Result { - if let Some([]) = self.match_symbol(term_id, model::CORE_STR_TYPE)? { - return Ok(TypeParam::String); + node_id: table::NodeId, + operation: table::TermId, + node_data: &'a table::Node<'a>, + parent: Node, + ) -> Result { + if let Some([inputs, outputs]) = self.match_symbol(operation, model::CORE_CALL_INDIRECT)? { + let inputs = self.import_type_row(inputs)?; + let outputs = self.import_type_row(outputs)?; + let signature = Signature::new(inputs, outputs); + let optype = OpType::CallIndirect(CallIndirect { signature }); + let node = self.make_node(node_id, optype, parent)?; + return Ok(node); } - if let Some([]) = self.match_symbol(term_id, model::CORE_NAT_TYPE)? { - return Ok(TypeParam::max_nat()); - } + if let Some([_, _, func]) = self.match_symbol(operation, model::CORE_CALL)? { + let table::Term::Apply(symbol, args) = self.get_term(func)? else { + return Err(error_invalid!( + "expected a symbol application to be passed to `{}`", + model::CORE_CALL + )); + }; - if let Some([]) = self.match_symbol(term_id, model::CORE_BYTES_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeParam`", - model::CORE_BYTES_TYPE - )); - } + let func_sig = self.get_func_signature(*symbol)?; - if let Some([]) = self.match_symbol(term_id, model::CORE_FLOAT_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeParam`", - model::CORE_FLOAT_TYPE - )); - } + let type_args = args + .iter() + .map(|term| self.import_term(*term)) + .collect::, _>>()?; - if let Some([]) = self.match_symbol(term_id, model::CORE_TYPE)? { - return Ok(TypeParam::Type { b: bound }); - } + self.static_edges.push((*symbol, node_id)); + let optype = OpType::Call( + Call::try_new(func_sig, type_args).map_err(ImportErrorInner::Signature)?, + ); - if let Some([]) = self.match_symbol(term_id, model::CORE_STATIC)? { - return Err(error_unsupported!( - "`{}` as `TypeParam`", - model::CORE_STATIC - )); + let node = self.make_node(node_id, optype, parent)?; + return Ok(node); } - if let Some([]) = self.match_symbol(term_id, model::CORE_CONSTRAINT)? { - return Err(error_unsupported!( - "`{}` as `TypeParam`", - model::CORE_CONSTRAINT - )); - } + if let Some([_, value]) = self.match_symbol(operation, model::CORE_LOAD_CONST)? 
{ + // If the constant refers directly to a function, import this as the `LoadFunc` operation. + if let table::Term::Apply(symbol, args) = self.get_term(value)? { + let func_node_data = self.get_node(*symbol)?; - if let Some([]) = self.match_symbol(term_id, model::CORE_CONST)? { - return Err(error_unsupported!("`{}` as `TypeParam`", model::CORE_CONST)); - } + if let table::Operation::DefineFunc(_) | table::Operation::DeclareFunc(_) = + func_node_data.operation + { + let func_sig = self.get_func_signature(*symbol)?; + let type_args = args + .iter() + .map(|term| self.import_term(*term)) + .collect::, _>>()?; - if let Some([]) = self.match_symbol(term_id, model::CORE_CTRL_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeParam`", - model::CORE_CTRL_TYPE - )); - } + self.static_edges.push((*symbol, node_id)); - if let Some([item_type]) = self.match_symbol(term_id, model::CORE_LIST_TYPE)? { - // At present `hugr-model` has no way to express that the item - // type of a list must be copyable. Therefore we import it as `Any`. - let param = Box::new(self.import_type_param(item_type, TypeBound::Any)?); - return Ok(TypeParam::List { param }); - } + let optype = OpType::LoadFunction( + LoadFunction::try_new(func_sig, type_args) + .map_err(ImportErrorInner::Signature)?, + ); - if let Some([_]) = self.match_symbol(term_id, model::CORE_TUPLE_TYPE)? { - // At present `hugr-model` has no way to express that the item - // types of a tuple must be copyable. Therefore we import it as `Any`. - todo!("import tuple type"); - } + let node = self.make_node(node_id, optype, parent)?; + return Ok(node); + } + } - match self.get_term(term_id)? { - table::Term::Wildcard => Err(error_uninferred!("wildcard")), + // Otherwise use const nodes + let signature = node_data + .signature + .ok_or_else(|| error_uninferred!("node signature"))?; + let [_, outputs] = self.get_func_type(signature)?; + let outputs = self.import_closed_list(outputs)?; + let output = outputs.first().ok_or_else(|| { + error_invalid!("`{}` expects a single output", model::CORE_LOAD_CONST) + })?; + let datatype = self.import_type(*output)?; + + let imported_value = self.import_value(value, *output)?; + + let load_const_node = self.make_node( + node_id, + OpType::LoadConstant(LoadConstant { + datatype: datatype.clone(), + }), + parent, + )?; - table::Term::Var { .. } => Err(error_unsupported!("type variable as `TypeParam`")), - table::Term::Apply(symbol, _) => { - let name = self.get_symbol_name(*symbol)?; - Err(error_unsupported!("custom type `{}` as `TypeParam`", name)) - } + let const_node = self + .hugr + .add_node_with_parent(parent, OpType::Const(Const::new(imported_value))); - table::Term::Tuple(_) - | table::Term::List { .. } - | table::Term::Func { .. } - | table::Term::Literal(_) => Err(table::ModelError::TypeError(term_id).into()), - } - } + self.hugr.connect(const_node, 0, load_const_node, 0); - /// Import a `TypeArg` from a term that represents a static type or value. - fn import_type_arg(&mut self, term_id: table::TermId) -> Result { - if let Some([]) = self.match_symbol(term_id, model::CORE_STR_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_STR_TYPE - )); + return Ok(load_const_node); } - if let Some([]) = self.match_symbol(term_id, model::CORE_NAT_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_NAT_TYPE - )); + if let Some([_, _, tag]) = self.match_symbol(operation, model::CORE_MAKE_ADT)? { + let table::Term::Literal(model::Literal::Nat(tag)) = self.get_term(tag)? 
else { + return Err(error_invalid!( + "`{}` expects a nat literal tag", + model::CORE_MAKE_ADT + )); + }; + + let signature = node_data + .signature + .ok_or_else(|| error_uninferred!("node signature"))?; + let [_, outputs] = self.get_func_type(signature)?; + let (variants, _) = self.import_adt_and_rest(outputs)?; + let node = self.make_node( + node_id, + OpType::Tag(Tag { + variants, + tag: *tag as usize, + }), + parent, + )?; + return Ok(node); } - if let Some([]) = self.match_symbol(term_id, model::CORE_BYTES_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_BYTES_TYPE + let table::Term::Apply(node, params) = self.get_term(operation)? else { + return Err(error_invalid!( + "custom operations expect a symbol application referencing an operation" )); - } + }; + let name = self.get_symbol_name(*node)?; + let args = params + .iter() + .map(|param| self.import_term(*param)) + .collect::, _>>()?; + let (extension, name) = self.import_custom_name(name)?; + let signature = self.get_node_signature(node_id)?; + + // TODO: Currently we do not have the description or any other metadata for + // the custom op. This will improve with declarative extensions being able + // to declare operations as a node, in which case the description will be attached + // to that node as metadata. + + let optype = OpType::OpaqueOp(OpaqueOp::new(extension, name, args, signature)); + self.make_node(node_id, optype, parent) + } - if let Some([]) = self.match_symbol(term_id, model::CORE_FLOAT_TYPE)? { + fn import_node_define_alias( + &mut self, + node_id: table::NodeId, + symbol: &'a table::Symbol<'a>, + value: table::TermId, + parent: Node, + ) -> Result { + if !symbol.params.is_empty() { return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_FLOAT_TYPE + "parameters or constraints in alias definition" )); } - if let Some([]) = self.match_symbol(term_id, model::CORE_TYPE)? { - return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_TYPE)); - } + let optype = OpType::AliasDefn(AliasDefn { + name: symbol.name.to_smolstr(), + definition: self.import_type(value)?, + }); - if let Some([]) = self.match_symbol(term_id, model::CORE_CONSTRAINT)? { + let node = self.make_node(node_id, optype, parent)?; + Ok(node) + } + + fn import_node_declare_alias( + &mut self, + node_id: table::NodeId, + symbol: &'a table::Symbol<'a>, + parent: Node, + ) -> Result { + if !symbol.params.is_empty() { return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_CONSTRAINT + "parameters or constraints in alias declaration" )); } - if let Some([]) = self.match_symbol(term_id, model::CORE_STATIC)? { - return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_STATIC)); - } + let optype = OpType::AliasDecl(AliasDecl { + name: symbol.name.to_smolstr(), + bound: TypeBound::Copyable, + }); - if let Some([]) = self.match_symbol(term_id, model::CORE_CTRL_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_CTRL_TYPE - )); - } + let node = self.make_node(node_id, optype, parent)?; + Ok(node) + } - if let Some([]) = self.match_symbol(term_id, model::CORE_CONST)? { - return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_CONST)); - } + fn import_poly_func_type( + &mut self, + node: table::NodeId, + symbol: table::Symbol<'a>, + in_scope: impl FnOnce(&mut Self, PolyFuncTypeBase) -> Result, + ) -> Result { + (|| { + let mut imported_params = Vec::with_capacity(symbol.params.len()); - if let Some([]) = self.match_symbol(term_id, model::CORE_LIST_TYPE)? 
{ - return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_LIST_TYPE - )); - } + for (index, param) in symbol.params.iter().enumerate() { + self.local_vars + .insert(table::VarId(node, index as _), LocalVar::new(param.r#type)); + } - match self.get_term(term_id)? { - table::Term::Wildcard => Err(error_uninferred!("wildcard")), + for constraint in symbol.constraints { + if let Some([term]) = self.match_symbol(*constraint, model::CORE_NON_LINEAR)? { + let table::Term::Var(var) = self.get_term(term)? else { + return Err(error_unsupported!( + "constraint on term that is not a variable" + )); + }; + + self.local_vars + .get_mut(var) + .ok_or_else(|| error_invalid!("unknown variable {}", var))? + .bound = TypeBound::Copyable; + } else { + return Err(error_unsupported!("constraint other than copy or discard")); + } + } - table::Term::Var(var) => { - let var_info = self - .local_vars - .get(var) - .ok_or(table::ModelError::InvalidVar(*var))?; - let decl = self.import_type_param(var_info.r#type, var_info.bound)?; - Ok(TypeArg::new_var_use(var.1 as _, decl)) + for (index, param) in symbol.params.iter().enumerate() { + // NOTE: `PolyFuncType` only has explicit type parameters at present. + let bound = self.local_vars[&table::VarId(node, index as _)].bound; + imported_params.push( + self.import_term_with_bound(param.r#type, bound) + .map_err(|err| error_context!(err, "type of parameter `{}`", param.name))?, + ); } - table::Term::List { .. } => { - let elems = self - .import_closed_list(term_id)? - .iter() - .map(|item| self.import_type_arg(*item)) - .collect::>()?; + let body = self.import_func_type::(symbol.signature)?; + in_scope(self, PolyFuncTypeBase::new(imported_params, body)) + })() + .map_err(|err| error_context!(err, "symbol `{}` defined by node {}", symbol.name, node)) + } - Ok(TypeArg::Sequence { elems }) + /// Import a [`Term`] from a term that represents a static type or value. + fn import_term(&mut self, term_id: table::TermId) -> Result { + self.import_term_with_bound(term_id, TypeBound::Linear) + } + + fn import_term_with_bound( + &mut self, + term_id: table::TermId, + bound: TypeBound, + ) -> Result { + (|| { + if let Some([]) = self.match_symbol(term_id, model::CORE_STR_TYPE)? { + return Ok(Term::StringType); } - table::Term::Tuple { .. } => { - // NOTE: While `TypeArg`s can represent tuples as - // `TypeArg::Sequence`s, this conflates lists and tuples. To - // avoid ambiguity we therefore report an error here for now. - Err(error_unsupported!("tuples as `TypeArg`")) + if let Some([]) = self.match_symbol(term_id, model::CORE_NAT_TYPE)? { + return Ok(Term::max_nat_type()); } - table::Term::Literal(model::Literal::Str(value)) => Ok(TypeArg::String { - arg: value.to_string(), - }), + if let Some([]) = self.match_symbol(term_id, model::CORE_BYTES_TYPE)? { + return Ok(Term::BytesType); + } - table::Term::Literal(model::Literal::Nat(value)) => { - Ok(TypeArg::BoundedNat { n: *value }) + if let Some([]) = self.match_symbol(term_id, model::CORE_FLOAT_TYPE)? { + return Ok(Term::FloatType); } - table::Term::Literal(model::Literal::Bytes(_)) => { - Err(error_unsupported!("`(bytes ..)` as `TypeArg`")) + if let Some([]) = self.match_symbol(term_id, model::CORE_TYPE)? { + return Ok(TypeParam::RuntimeType(bound)); } - table::Term::Literal(model::Literal::Float(_)) => { - Err(error_unsupported!("float literal as `TypeArg`")) + + if let Some([]) = self.match_symbol(term_id, model::CORE_CONSTRAINT)? 
{ + return Err(error_unsupported!("`{}`", model::CORE_CONSTRAINT)); } - table::Term::Func { .. } => Err(error_unsupported!("function constant as `TypeArg`")), - table::Term::Apply { .. } => { - let ty = self.import_type(term_id)?; - Ok(TypeArg::Type { ty }) + if let Some([]) = self.match_symbol(term_id, model::CORE_STATIC)? { + return Ok(Term::StaticType); } - } + + if let Some([]) = self.match_symbol(term_id, model::CORE_CONST)? { + return Err(error_unsupported!("`{}`", model::CORE_CONST)); + } + + if let Some([item_type]) = self.match_symbol(term_id, model::CORE_LIST_TYPE)? { + // At present `hugr-model` has no way to express that the item + // type of a list must be copyable. Therefore we import it as `Any`. + let item_type = self + .import_term(item_type) + .map_err(|err| error_context!(err, "item type of list type"))?; + return Ok(TypeParam::new_list_type(item_type)); + } + + if let Some([item_types]) = self.match_symbol(term_id, model::CORE_TUPLE_TYPE)? { + // At present `hugr-model` has no way to express that the item + // types of a tuple must be copyable. Therefore we import it as `Any`. + let item_types = self + .import_term(item_types) + .map_err(|err| error_context!(err, "item types of tuple type"))?; + return Ok(TypeParam::new_tuple_type(item_types)); + } + + match self.get_term(term_id)? { + table::Term::Wildcard => Err(error_uninferred!("wildcard")), + + table::Term::Var(var) => { + let var_info = self + .local_vars + .get(var) + .ok_or_else(|| error_invalid!("unknown variable {}", var))?; + let decl = self.import_term_with_bound(var_info.r#type, var_info.bound)?; + Ok(Term::new_var_use(var.1 as _, decl)) + } + + table::Term::List(parts) => { + // PERFORMANCE: Can we do this without the additional allocation? + let parts: Vec<_> = parts + .iter() + .map(|part| self.import_seq_part(part)) + .collect::>() + .map_err(|err| error_context!(err, "list parts"))?; + Ok(TypeArg::new_list_from_parts(parts)) + } + + table::Term::Tuple(parts) => { + // PERFORMANCE: Can we do this without the additional allocation? + let parts: Vec<_> = parts + .iter() + .map(|part| self.import_seq_part(part)) + .try_collect() + .map_err(|err| error_context!(err, "tuple parts"))?; + Ok(TypeArg::new_tuple_from_parts(parts)) + } + + table::Term::Literal(model::Literal::Str(value)) => { + Ok(Term::String(value.to_string())) + } + + table::Term::Literal(model::Literal::Nat(value)) => Ok(Term::BoundedNat(*value)), + + table::Term::Literal(model::Literal::Bytes(value)) => { + Ok(Term::Bytes(value.clone())) + } + table::Term::Literal(model::Literal::Float(value)) => Ok(Term::Float(*value)), + table::Term::Func { .. } => Err(error_unsupported!("function constant")), + + table::Term::Apply { .. } => { + let ty: Type = self.import_type(term_id)?; + Ok(ty.into()) + } + } + })() + .map_err(|err| error_context!(err, "term {}", term_id)) + } + + fn import_seq_part( + &mut self, + seq_part: &'a table::SeqPart, + ) -> Result, ImportErrorInner> { + Ok(match seq_part { + table::SeqPart::Item(term_id) => SeqPart::Item(self.import_term(*term_id)?), + table::SeqPart::Splice(term_id) => SeqPart::Splice(self.import_term(*term_id)?), + }) } /// Import a `Type` from a term that represents a runtime type. fn import_type( &mut self, term_id: table::TermId, - ) -> Result, ImportError> { - if let Some([_, _]) = self.match_symbol(term_id, model::CORE_FN)? 
{ - let func_type = self.import_func_type::(term_id)?; - return Ok(TypeBase::new_function(func_type)); - } + ) -> Result, ImportErrorInner> { + (|| { + if let Some([_, _]) = self.match_symbol(term_id, model::CORE_FN)? { + let func_type = self.import_func_type::(term_id)?; + return Ok(TypeBase::new_function(func_type)); + } - if let Some([variants]) = self.match_symbol(term_id, model::CORE_ADT)? { - let variants = self.import_closed_list(variants)?; - let variants = variants - .iter() - .map(|variant| self.import_type_row::(*variant)) - .collect::, _>>()?; - return Ok(TypeBase::new_sum(variants)); - } + if let Some([variants]) = self.match_symbol(term_id, model::CORE_ADT)? { + let variants = (|| { + self.import_closed_list(variants)? + .iter() + .map(|variant| self.import_type_row::(*variant)) + .collect::, _>>() + })() + .map_err(|err| error_context!(err, "adt variants"))?; - match self.get_term(term_id)? { - table::Term::Wildcard => Err(error_uninferred!("wildcard")), + return Ok(TypeBase::new_sum(variants)); + } - table::Term::Apply(symbol, args) => { - let args = args - .iter() - .map(|arg| self.import_type_arg(*arg)) - .collect::, _>>()?; - - let name = self.get_symbol_name(*symbol)?; - let (extension, id) = self.import_custom_name(name)?; - - let extension_ref = - self.extensions - .get(&extension) - .ok_or_else(|| ImportError::Extension { - missing_ext: extension.clone(), - available: self.extensions.ids().cloned().collect(), - })?; + match self.get_term(term_id)? { + table::Term::Wildcard => Err(error_uninferred!("wildcard")), + + table::Term::Apply(symbol, args) => { + let name = self.get_symbol_name(*symbol)?; - let ext_type = - extension_ref - .get_type(&id) - .ok_or_else(|| ImportError::ExtensionType { - ext: extension.clone(), - name: id.clone(), + let args = args + .iter() + .map(|arg| self.import_term(*arg)) + .collect::, _>>() + .map_err(|err| { + error_context!(err, "type argument of custom type `{}`", name) })?; - let bound = ext_type.bound(&args); + let (extension, id) = self.import_custom_name(name)?; + + let extension_ref = + self.extensions + .get(&extension) + .ok_or_else(|| ExtensionError::Missing { + missing_ext: extension.clone(), + available: self.extensions.ids().cloned().collect(), + })?; + + let ext_type = + extension_ref + .get_type(&id) + .ok_or_else(|| ExtensionError::MissingType { + ext: extension.clone(), + name: id.clone(), + })?; + + let bound = ext_type.bound(&args); + + Ok(TypeBase::new_extension(CustomType::new( + id, + args, + extension, + bound, + &Arc::downgrade(extension_ref), + ))) + } - Ok(TypeBase::new_extension(CustomType::new( - id, - args, - extension, - bound, - &Arc::downgrade(extension_ref), - ))) - } + table::Term::Var(var @ table::VarId(_, index)) => { + let local_var = self + .local_vars + .get(var) + .ok_or(error_invalid!("unknown var {}", var))?; + Ok(TypeBase::new_var_use(*index as _, local_var.bound)) + } - table::Term::Var(var @ table::VarId(_, index)) => { - let local_var = self - .local_vars - .get(var) - .ok_or(table::ModelError::InvalidVar(*var))?; - Ok(TypeBase::new_var_use(*index as _, local_var.bound)) + // The following terms are not runtime types, but the core `Type` only contains runtime types. + // We therefore report a type error here. + table::Term::List { .. } + | table::Term::Tuple { .. } + | table::Term::Literal(_) + | table::Term::Func { .. } => Err(error_invalid!("expected a runtime type")), } - - // The following terms are not runtime types, but the core `Type` only contains runtime types. 
- // We therefore report a type error here. - table::Term::List { .. } - | table::Term::Tuple { .. } - | table::Term::Literal(_) - | table::Term::Func { .. } => Err(table::ModelError::TypeError(term_id).into()), - } + })() + .map_err(|err| error_context!(err, "term {} as `Type`", term_id)) } - fn get_func_type(&mut self, term_id: table::TermId) -> Result<[table::TermId; 2], ImportError> { + fn get_func_type( + &mut self, + term_id: table::TermId, + ) -> Result<[table::TermId; 2], ImportErrorInner> { self.match_symbol(term_id, model::CORE_FN)? - .ok_or(table::ModelError::TypeError(term_id).into()) + .ok_or(error_invalid!("expected a function type")) + } + + fn get_ctrl_type( + &mut self, + term_id: table::TermId, + ) -> Result<[table::TermId; 2], ImportErrorInner> { + self.match_symbol(term_id, model::CORE_CTRL)? + .ok_or(error_invalid!("expected a control type")) } fn import_func_type( &mut self, term_id: table::TermId, - ) -> Result, ImportError> { - let [inputs, outputs] = self.get_func_type(term_id)?; - let inputs = self.import_type_row(inputs)?; - let outputs = self.import_type_row(outputs)?; - Ok(FuncTypeBase::new(inputs, outputs)) + ) -> Result, ImportErrorInner> { + (|| { + let [inputs, outputs] = self.get_func_type(term_id)?; + let inputs = self + .import_type_row(inputs) + .map_err(|err| error_context!(err, "function inputs"))?; + let outputs = self + .import_type_row(outputs) + .map_err(|err| error_context!(err, "function outputs"))?; + Ok(FuncTypeBase::new(inputs, outputs)) + })() + .map_err(|err| error_context!(err, "function type")) } fn import_closed_list( &mut self, term_id: table::TermId, - ) -> Result, ImportError> { + ) -> Result, ImportErrorInner> { fn import_into( ctx: &mut Context, term_id: table::TermId, types: &mut Vec, - ) -> Result<(), ImportError> { + ) -> Result<(), ImportErrorInner> { match ctx.get_term(term_id)? { table::Term::List(parts) => { types.reserve(parts.len()); @@ -1320,7 +1552,7 @@ impl<'a> Context<'a> { } } } - _ => return Err(table::ModelError::TypeError(term_id).into()), + _ => return Err(error_invalid!("expected a closed list")), } Ok(()) @@ -1334,12 +1566,12 @@ impl<'a> Context<'a> { fn import_closed_tuple( &mut self, term_id: table::TermId, - ) -> Result, ImportError> { + ) -> Result, ImportErrorInner> { fn import_into( ctx: &mut Context, term_id: table::TermId, types: &mut Vec, - ) -> Result<(), ImportError> { + ) -> Result<(), ImportErrorInner> { match ctx.get_term(term_id)? { table::Term::Tuple(parts) => { types.reserve(parts.len()); @@ -1355,7 +1587,7 @@ impl<'a> Context<'a> { } } } - _ => return Err(table::ModelError::TypeError(term_id).into()), + _ => return Err(error_invalid!("expected a closed tuple")), } Ok(()) @@ -1369,7 +1601,7 @@ impl<'a> Context<'a> { fn import_type_rows( &mut self, term_id: table::TermId, - ) -> Result>, ImportError> { + ) -> Result>, ImportErrorInner> { self.import_closed_list(term_id)? .into_iter() .map(|term_id| self.import_type_row::(term_id)) @@ -1379,12 +1611,12 @@ impl<'a> Context<'a> { fn import_type_row( &mut self, term_id: table::TermId, - ) -> Result, ImportError> { + ) -> Result, ImportErrorInner> { fn import_into( ctx: &mut Context, term_id: table::TermId, types: &mut Vec>, - ) -> Result<(), ImportError> { + ) -> Result<(), ImportErrorInner> { match ctx.get_term(term_id)? 
{ table::Term::List(parts) => { types.reserve(parts.len()); @@ -1401,11 +1633,11 @@ impl<'a> Context<'a> { } } table::Term::Var(table::VarId(_, index)) => { - let var = RV::try_from_rv(RowVariable(*index as _, TypeBound::Any)) - .map_err(|_| table::ModelError::TypeError(term_id))?; + let var = RV::try_from_rv(RowVariable(*index as _, TypeBound::Linear)) + .map_err(|_| error_invalid!("expected a closed list"))?; types.push(TypeBase::new(TypeEnum::RowVar(var))); } - _ => return Err(table::ModelError::TypeError(term_id).into()), + _ => return Err(error_invalid!("expected a list")), } Ok(()) @@ -1419,17 +1651,17 @@ impl<'a> Context<'a> { fn import_custom_name( &mut self, symbol: &'a str, - ) -> Result<(ExtensionId, SmolStr), ImportError> { + ) -> Result<(ExtensionId, SmolStr), ImportErrorInner> { use std::collections::hash_map::Entry; match self.custom_name_cache.entry(symbol) { Entry::Occupied(occupied_entry) => Ok(occupied_entry.get().clone()), Entry::Vacant(vacant_entry) => { let qualified_name = ExtensionId::new(symbol) - .map_err(|_| table::ModelError::MalformedName(symbol.to_smolstr()))?; + .map_err(|_| error_invalid!("`{}` is not a valid symbol name", symbol))?; let (extension, id) = qualified_name .split_last() - .ok_or_else(|| table::ModelError::MalformedName(symbol.to_smolstr()))?; + .ok_or_else(|| error_invalid!("`{}` is not a valid symbol name", symbol))?; vacant_entry.insert((extension.clone(), id.clone())); Ok((extension, id)) @@ -1441,7 +1673,7 @@ impl<'a> Context<'a> { &mut self, term_id: table::TermId, type_id: table::TermId, - ) -> Result { + ) -> Result { let term_data = self.get_term(term_id)?; // NOTE: We have special cased arrays, integers, and floats for now. @@ -1449,7 +1681,10 @@ impl<'a> Context<'a> { if let Some([runtime_type, json]) = self.match_symbol(term_id, model::COMPAT_CONST_JSON)? { let table::Term::Literal(model::Literal::Str(json)) = self.get_term(json)? else { - return Err(table::ModelError::TypeError(term_id).into()); + return Err(error_invalid!( + "`{}` expects a string literal", + model::COMPAT_CONST_JSON + )); }; // We attempt to deserialize as the custom const directly. @@ -1462,8 +1697,12 @@ impl<'a> Context<'a> { return Ok(Value::Extension { e: opaque_value }); } else { let runtime_type = self.import_type(runtime_type)?; - let value: serde_json::Value = serde_json::from_str(json) - .map_err(|_| table::ModelError::TypeError(term_id))?; + let value: serde_json::Value = serde_json::from_str(json).map_err(|_| { + error_invalid!( + "unable to parse JSON string for `{}`", + model::COMPAT_CONST_JSON + ) + })?; let custom_const = CustomSerialized::new(runtime_type, value); let opaque_value = OpaqueValue::new(custom_const); return Ok(Value::Extension { e: opaque_value }); @@ -1487,29 +1726,42 @@ impl<'a> Context<'a> { let table::Term::Literal(model::Literal::Nat(bitwidth)) = self.get_term(bitwidth)? else { - return Err(table::ModelError::TypeError(term_id).into()); + return Err(error_invalid!( + "`{}` expects a nat literal in its `bitwidth` argument", + ConstInt::CTR_NAME + )); }; if *bitwidth > 6 { - return Err(table::ModelError::TypeError(term_id).into()); + return Err(error_invalid!( + "`{}` expects the bitwidth to be at most 6, got {}", + ConstInt::CTR_NAME, + bitwidth + )); } *bitwidth as u8 }; let value = { let table::Term::Literal(model::Literal::Nat(value)) = self.get_term(value)? 
else { - return Err(table::ModelError::TypeError(term_id).into()); + return Err(error_invalid!( + "`{}` expects a nat literal value", + ConstInt::CTR_NAME + )); }; *value }; return Ok(ConstInt::new_u(bitwidth, value) - .map_err(|_| table::ModelError::TypeError(term_id))? + .map_err(|_| error_invalid!("failed to create int constant"))? .into()); } if let Some([value]) = self.match_symbol(term_id, ConstF64::CTR_NAME)? { let table::Term::Literal(model::Literal::Float(value)) = self.get_term(value)? else { - return Err(table::ModelError::TypeError(term_id).into()); + return Err(error_invalid!( + "`{}` expects a float literal value", + ConstF64::CTR_NAME + )); }; return Ok(ConstF64::new(value.into_inner()).into()); @@ -1521,12 +1773,16 @@ impl<'a> Context<'a> { let variants = self.import_closed_list(variants)?; let table::Term::Literal(model::Literal::Nat(tag)) = self.get_term(tag)? else { - return Err(table::ModelError::TypeError(term_id).into()); + return Err(error_invalid!( + "`{}` expects a nat literal tag", + model::CORE_ADT + )); }; - let variant = variants - .get(*tag as usize) - .ok_or(table::ModelError::TypeError(term_id))?; + let variant = variants.get(*tag as usize).ok_or(error_invalid!( + "the tag of a `{}` must be a valid index into the list of variants", + model::CORE_CONST_ADT + ))?; let variant = self.import_closed_list(*variant)?; @@ -1564,7 +1820,7 @@ impl<'a> Context<'a> { } table::Term::List { .. } | table::Term::Tuple(_) | table::Term::Literal(_) => { - Err(table::ModelError::TypeError(term_id).into()) + Err(error_invalid!("expected constant")) } table::Term::Func { .. } => Err(error_unsupported!("constant function value")), @@ -1575,7 +1831,7 @@ impl<'a> Context<'a> { &self, term_id: table::TermId, name: &str, - ) -> Result, ImportError> { + ) -> Result, ImportErrorInner> { let term = self.get_term(term_id)?; // TODO: Follow alias chains? @@ -1609,9 +1865,36 @@ impl<'a> Context<'a> { &self, term_id: table::TermId, name: &str, - ) -> Result<[table::TermId; N], ImportError> { - self.match_symbol(term_id, name)? - .ok_or(table::ModelError::TypeError(term_id).into()) + ) -> Result<[table::TermId; N], ImportErrorInner> { + self.match_symbol(term_id, name)?.ok_or(error_invalid!( + "expected symbol `{}` with arity {}", + name, + N + )) + } + + /// Searches for `core.title` metadata on the given node. + fn import_title_metadata( + &self, + node_id: table::NodeId, + ) -> Result, ImportErrorInner> { + let node_data = self.get_node(node_id)?; + for meta in node_data.meta { + let Some([name]) = self.match_symbol(*meta, model::CORE_TITLE)? else { + continue; + }; + + let table::Term::Literal(model::Literal::Str(name)) = self.get_term(name)? 
else { + return Err(error_invalid!( + "`{}` metadata expected a string literal as argument", + model::CORE_TITLE + )); + }; + + return Ok(Some(name.as_str())); + } + + Ok(None) } } @@ -1628,7 +1911,7 @@ impl LocalVar { pub fn new(r#type: table::TermId) -> Self { Self { r#type, - bound: TypeBound::Any, + bound: TypeBound::Linear, } } } diff --git a/hugr-core/src/lib.rs b/hugr-core/src/lib.rs index e5f57d2a8f..862b8dee8a 100644 --- a/hugr-core/src/lib.rs +++ b/hugr-core/src/lib.rs @@ -24,7 +24,8 @@ pub mod types; pub mod utils; pub use crate::core::{ - CircuitUnit, Direction, IncomingPort, Node, NodeIndex, OutgoingPort, Port, PortIndex, Wire, + CircuitUnit, Direction, IncomingPort, Node, NodeIndex, OutgoingPort, Port, PortIndex, + Visibility, Wire, }; pub use crate::extension::Extension; pub use crate::hugr::{Hugr, HugrView, SimpleReplacement}; diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index d27a4a0ad8..deebe5434f 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -1,6 +1,7 @@ //! Constant value definitions. mod custom; +mod serialize; use std::borrow::Cow; use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. @@ -11,6 +12,7 @@ use super::{OpTag, OpType}; use crate::envelope::serde_with::AsStringEnvelope; use crate::types::{CustomType, EdgeKind, Signature, SumType, SumTypeError, Type, TypeRow}; use crate::{Hugr, HugrView}; +use serialize::SerialSum; use delegate::delegate; use itertools::Itertools; @@ -107,16 +109,6 @@ impl AsRef for Const { } } -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] -struct SerialSum { - #[serde(default)] - tag: usize, - #[serde(rename = "vs")] - values: Vec, - #[serde(default, rename = "typ")] - sum_type: Option, -} - #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[serde(try_from = "SerialSum")] #[serde(into = "SerialSum")] @@ -160,43 +152,6 @@ pub(crate) fn maybe_hash_values(vals: &[Value], st: &mut H) -> bool { } } -impl TryFrom for Sum { - type Error = &'static str; - - fn try_from(value: SerialSum) -> Result { - let SerialSum { - tag, - values, - sum_type, - } = value; - - let sum_type = if let Some(sum_type) = sum_type { - sum_type - } else { - if tag != 0 { - return Err("Sum type must be provided if tag is not 0"); - } - SumType::new_tuple(values.iter().map(Value::get_type).collect_vec()) - }; - - Ok(Self { - tag, - values, - sum_type, - }) - } -} - -impl From for SerialSum { - fn from(value: Sum) -> Self { - Self { - tag: value.tag, - values: value.values, - sum_type: Some(value.sum_type), - } - } -} - #[serde_as] #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[serde(tag = "v")] @@ -327,9 +282,9 @@ pub enum CustomCheckFailure { #[error("Expected type: {expected} but value was of type: {found}")] TypeMismatch { /// The expected custom type. - expected: CustomType, + expected: Box, /// The custom type found when checking. - found: Type, + found: Box, }, /// Any other message #[error("{0}")] @@ -349,11 +304,11 @@ pub enum ConstTypeError { )] NotMonomorphicFunction { /// The root node type of the Hugr that (claims to) define the function constant. - hugr_root_type: OpType, + hugr_root_type: Box, }, /// A mismatch between the type expected and the value. #[error("Value {1:?} does not match expected type {0}")] - ConstCheckFail(Type, Value), + ConstCheckFail(Box, Value), /// Error when checking a custom value. 
#[error("Error when checking custom type: {0}")] CustomCheckFail(#[from] CustomCheckFailure), @@ -362,7 +317,7 @@ pub enum ConstTypeError { /// Hugrs (even functions) inside Consts must be monomorphic fn mono_fn_type(h: &Hugr) -> Result, ConstTypeError> { let err = || ConstTypeError::NotMonomorphicFunction { - hugr_root_type: h.entrypoint_optype().clone(), + hugr_root_type: Box::new(h.entrypoint_optype().clone()), }; if let Some(pf) = h.poly_func_type() { match pf.try_into() { @@ -728,7 +683,7 @@ pub(crate) mod test { index: 1, expected, found, - })) if expected == float64_type() && found == const_usize() + })) if *expected == float64_type() && *found == const_usize() ); } @@ -860,7 +815,7 @@ pub(crate) mod test { let ex_id: ExtensionId = "my_extension".try_into().unwrap(); let typ_int = CustomType::new( "my_type", - vec![TypeArg::BoundedNat { n: 8 }], + vec![TypeArg::BoundedNat(8)], ex_id.clone(), TypeBound::Copyable, // Dummy extension reference. diff --git a/hugr-core/src/ops/constant/serialize.rs b/hugr-core/src/ops/constant/serialize.rs new file mode 100644 index 0000000000..1ccfe523b3 --- /dev/null +++ b/hugr-core/src/ops/constant/serialize.rs @@ -0,0 +1,59 @@ +//! Helper definitions used to serialize constant values and ops. + +use itertools::Itertools; + +use crate::ops::Value; +use crate::types::SumType; +use crate::types::serialize::SerSimpleType; + +use super::Sum; + +/// Helper struct to serialize constant [`Sum`] values with a custom layout. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub(super) struct SerialSum { + #[serde(default)] + tag: usize, + #[serde(rename = "vs")] + values: Vec, + /// Uses the `SerSimpleType` wrapper here instead of a direct `SumType`, + /// to ensure it gets correctly tagged with the `t` discriminant field. 
+ #[serde(default, rename = "typ")] + sum_type: Option, +} + +impl From for SerialSum { + fn from(value: Sum) -> Self { + Self { + tag: value.tag, + values: value.values, + sum_type: Some(SerSimpleType::Sum(value.sum_type)), + } + } +} + +impl TryFrom for Sum { + type Error = &'static str; + + fn try_from(value: SerialSum) -> Result { + let SerialSum { + tag, + values, + sum_type, + } = value; + + let sum_type = if let Some(SerSimpleType::Sum(sum_type)) = sum_type { + sum_type + } else { + if tag != 0 { + return Err("Sum type must be provided if tag is not 0"); + } + SumType::new_tuple(values.iter().map(Value::get_type).collect_vec()) + }; + + Ok(Self { + tag, + values, + sum_type, + }) + } +} diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index ca358c624b..874a2b6ce9 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -359,7 +359,7 @@ mod test { #[test] fn test_subst_dataflow_block() { use crate::ops::OpTrait; - let tv0 = Type::new_var_use(0, TypeBound::Any); + let tv0 = Type::new_var_use(0, TypeBound::Linear); let dfb = DataflowBlock { inputs: vec![usize_t(), tv0.clone()].into(), other_outputs: vec![tv0.clone()].into(), @@ -375,16 +375,18 @@ mod test { #[test] fn test_subst_conditional() { - let tv1 = Type::new_var_use(1, TypeBound::Any); + let tv1 = Type::new_var_use(1, TypeBound::Linear); let cond = Conditional { sum_rows: vec![usize_t().into(), tv1.clone().into()], - other_inputs: vec![Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Any))].into(), + other_inputs: vec![Type::new_tuple(TypeRV::new_row_var_use( + 0, + TypeBound::Linear, + ))] + .into(), outputs: vec![usize_t(), tv1].into(), }; let cond2 = cond.substitute(&Substitution::new(&[ - TypeArg::Sequence { - elems: vec![usize_t().into(); 3], - }, + TypeArg::new_list([usize_t().into(), usize_t().into(), usize_t().into()]), qb_t().into(), ])); let st = Type::new_sum(vec![usize_t(), qb_t()]); //both single-element variants diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index f639584789..139c87505b 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -7,7 +7,8 @@ use thiserror::Error; #[cfg(test)] use { crate::extension::test::SimpleOpDef, crate::proptest::any_nonempty_smolstr, - ::proptest::prelude::*, ::proptest_derive::Arbitrary, + crate::types::proptest_utils::any_serde_type_arg_vec, ::proptest::prelude::*, + ::proptest_derive::Arbitrary, }; use crate::core::HugrNode; @@ -35,6 +36,7 @@ pub struct ExtensionOp { proptest(strategy = "any::().prop_map(|x| Arc::new(x.into()))") )] def: Arc, + #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))] args: Vec, signature: Signature, // Cache } @@ -235,6 +237,7 @@ pub struct OpaqueOp { extension: ExtensionId, #[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))] name: OpName, + #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))] args: Vec, // note that the `signature` field might not include `extension`. 
Thus this must // remain private, and should be accessed through @@ -353,8 +356,8 @@ pub enum OpaqueOpError { node: N, extension: ExtensionId, op: OpName, - stored: Signature, - computed: Signature, + stored: Box, + computed: Box, }, /// An error in computing the signature of the `ExtensionOp` #[error("Error in signature of operation '{name}' in {node}: {cause}")] @@ -406,11 +409,11 @@ mod test { let op = OpaqueOp::new( "res".try_into().unwrap(), "op", - vec![TypeArg::Type { ty: usize_t() }], + vec![usize_t().into()], sig.clone(), ); assert_eq!(op.name(), "OpaqueOp:res.op"); - assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]); + assert_eq!(op.args(), &[usize_t().into()]); assert_eq!(op.signature().as_ref(), &sig); } diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index 66aa4144b6..2a09fef5c8 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -10,7 +10,7 @@ use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeAr use crate::{IncomingPort, type_row}; #[cfg(test)] -use proptest_derive::Arbitrary; +use {crate::types::proptest_utils::any_serde_type_arg_vec, proptest_derive::Arbitrary}; /// Trait implemented by all dataflow operations. pub trait DataflowOpTrait: Sized { @@ -191,6 +191,7 @@ pub struct Call { /// Signature of function being called. pub func_sig: PolyFuncType, /// The type arguments that instantiate `func_sig`. + #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))] pub type_args: Vec, /// The instantiation of `func_sig`. pub instantiation: Signature, // Cache, so we can fail in try_new() not in signature() @@ -284,8 +285,8 @@ impl Call { Ok(()) } else { Err(SignatureError::CallIncorrectlyAppliesType { - cached: self.instantiation.clone(), - expected: other.instantiation.clone(), + cached: Box::new(self.instantiation.clone()), + expected: Box::new(other.instantiation.clone()), }) } } @@ -391,6 +392,7 @@ pub struct LoadFunction { /// Signature of the function pub func_sig: PolyFuncType, /// The type arguments that instantiate `func_sig`. + #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))] pub type_args: Vec, /// The instantiation of `func_sig`. pub instantiation: Signature, // Cache, so we can fail in try_new() not in signature() @@ -474,8 +476,8 @@ impl LoadFunction { Ok(()) } else { Err(SignatureError::LoadFunctionIncorrectlyAppliesType { - cached: self.instantiation.clone(), - expected: other.instantiation.clone(), + cached: Box::new(self.instantiation.clone()), + expected: Box::new(other.instantiation.clone()), }) } } diff --git a/hugr-core/src/ops/module.rs b/hugr-core/src/ops/module.rs index db2b81f9f3..eda121f235 100644 --- a/hugr-core/src/ops/module.rs +++ b/hugr-core/src/ops/module.rs @@ -9,12 +9,11 @@ use { ::proptest_derive::Arbitrary, }; -use crate::types::{EdgeKind, PolyFuncType, Signature}; -use crate::types::{Type, TypeBound}; +use crate::Visibility; +use crate::types::{EdgeKind, PolyFuncType, Signature, Type, TypeBound}; -use super::StaticTag; use super::dataflow::DataflowParent; -use super::{OpTag, OpTrait, impl_op_name}; +use super::{OpTag, OpTrait, StaticTag, impl_op_name}; /// The root of a module, parent of all other `OpType`s. 
#[derive(Debug, Clone, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)] @@ -57,14 +56,31 @@ pub struct FuncDefn { #[cfg_attr(test, proptest(strategy = "any_nonempty_string()"))] name: String, signature: PolyFuncType, + #[serde(default = "priv_vis")] // sadly serde does not pick this up from the schema + visibility: Visibility, +} + +fn priv_vis() -> Visibility { + Visibility::Private } impl FuncDefn { - /// Create a new instance with the given name and signature + /// Create a new, [Visibility::Private], instance with the given name and signature. + /// See also [Self::new_vis]. pub fn new(name: impl Into, signature: impl Into) -> Self { + Self::new_vis(name, signature, Visibility::Private) + } + + /// Create a new instance with the specified name and visibility + pub fn new_vis( + name: impl Into, + signature: impl Into, + visibility: Visibility, + ) -> Self { Self { name: name.into(), signature: signature.into(), + visibility, } } @@ -87,6 +103,16 @@ impl FuncDefn { pub fn signature_mut(&mut self) -> &mut PolyFuncType { &mut self.signature } + + /// The visibility of the function, e.g. for linking + pub fn visibility(&self) -> &Visibility { + &self.visibility + } + + /// Allows changing [Self::visibility] + pub fn visibility_mut(&mut self) -> &mut Visibility { + &mut self.visibility + } } impl_op_name!(FuncDefn); @@ -123,14 +149,32 @@ pub struct FuncDecl { #[cfg_attr(test, proptest(strategy = "any_nonempty_string()"))] name: String, signature: PolyFuncType, + // (again) sadly serde does not pick this up from the schema + #[serde(default = "pub_vis")] // Note opposite of FuncDefn + visibility: Visibility, +} + +fn pub_vis() -> Visibility { + Visibility::Public } impl FuncDecl { - /// Create a new instance with the given name and signature + /// Create a new [Visibility::Public] instance with the given name and signature. + /// See also [Self::new_vis] pub fn new(name: impl Into, signature: impl Into) -> Self { + Self::new_vis(name, signature, Visibility::Public) + } + + /// Create a new instance with the given name, signature and visibility + pub fn new_vis( + name: impl Into, + signature: impl Into, + visibility: Visibility, + ) -> Self { Self { name: name.into(), signature: signature.into(), + visibility, } } @@ -139,11 +183,21 @@ impl FuncDecl { &self.name } + /// The visibility of the function, e.g. for linking + pub fn visibility(&self) -> &Visibility { + &self.visibility + } + /// Allows mutating the name of the function (as per [Self::func_name]) pub fn func_name_mut(&mut self) -> &mut String { &mut self.name } + /// Allows mutating the [Self::visibility] of the function + pub fn visibility_mut(&mut self) -> &mut Visibility { + &mut self.visibility + } + /// Gets the signature of the function pub fn signature(&self) -> &PolyFuncType { &self.signature diff --git a/hugr-core/src/ops/tag.rs b/hugr-core/src/ops/tag.rs index bed7e47370..2834cd94eb 100644 --- a/hugr-core/src/ops/tag.rs +++ b/hugr-core/src/ops/tag.rs @@ -57,6 +57,8 @@ pub enum OpTag { /// A function load operation. LoadFunc, /// A definition that could be at module level or inside a DSG. + /// Note that this means only Constants, as all other defn/decls + /// must be at Module level. ScopedDefn, /// A tail-recursive loop. 
TailLoop, @@ -112,8 +114,8 @@ impl OpTag { OpTag::Input => &[OpTag::DataflowChild], OpTag::Output => &[OpTag::DataflowChild], OpTag::Function => &[OpTag::ModuleOp, OpTag::StaticOutput], - OpTag::Alias => &[OpTag::ScopedDefn], - OpTag::FuncDefn => &[OpTag::Function, OpTag::ScopedDefn, OpTag::DataflowParent], + OpTag::Alias => &[OpTag::ModuleOp], + OpTag::FuncDefn => &[OpTag::Function, OpTag::DataflowParent], OpTag::DataflowBlock => &[OpTag::ControlFlowChild, OpTag::DataflowParent], OpTag::BasicBlockExit => &[OpTag::ControlFlowChild], OpTag::Case => &[OpTag::Any, OpTag::DataflowParent], diff --git a/hugr-core/src/ops/validate.rs b/hugr-core/src/ops/validate.rs index f19ae5e079..9bb4ebbe89 100644 --- a/hugr-core/src/ops/validate.rs +++ b/hugr-core/src/ops/validate.rs @@ -103,7 +103,7 @@ impl ValidateOp for super::Conditional { if sig.input != self.case_input_row(i).unwrap() || sig.output != self.outputs { return Err(ChildrenValidationError::ConditionalCaseSignature { child, - optype: optype.clone(), + optype: Box::new(optype.clone()), }); } } @@ -177,7 +177,7 @@ pub enum ChildrenValidationError { #[error("A {optype} operation is only allowed as a {expected_position} child")] InternalIOChildren { child: N, - optype: OpType, + optype: Box, expected_position: &'static str, }, /// The signature of the contained dataflow graph does not match the one of the container. @@ -193,7 +193,7 @@ pub enum ChildrenValidationError { }, /// The signature of a child case in a conditional operation does not match the container's signature. #[error("A conditional case has optype {sig}, which differs from the signature of Conditional container", sig=optype.dataflow_signature().unwrap_or_default())] - ConditionalCaseSignature { child: N, optype: OpType }, + ConditionalCaseSignature { child: N, optype: Box }, /// The conditional container's branching value does not match the number of children. #[error("The conditional container's branch Sum input should be a sum with {expected_count} elements, but it had {} elements. Sum rows: {actual_sum_rows:?}", actual_sum_rows.len())] @@ -227,9 +227,9 @@ pub enum EdgeValidationError { source_ty = source_types.clone().unwrap_or_default(), )] CFGEdgeSignatureMismatch { - edge: ChildrenEdgeData, - source_types: Option, - target_types: TypeRow, + edge: Box>, + source_types: Option>, + target_types: Box, }, } @@ -323,14 +323,14 @@ fn validate_io_nodes<'a, N: HugrNode>( OpTag::Input => { return Err(ChildrenValidationError::InternalIOChildren { child, - optype: optype.clone(), + optype: Box::new(optype.clone()), expected_position: "first", }); } OpTag::Output => { return Err(ChildrenValidationError::InternalIOChildren { child, - optype: optype.clone(), + optype: Box::new(optype.clone()), expected_position: "second", }); } @@ -357,9 +357,9 @@ fn validate_cfg_edge(edge: ChildrenEdgeData) -> Result<(), EdgeV if source_types.as_ref() != Some(target_input) { let target_types = target_input.clone(); return Err(EdgeValidationError::CFGEdgeSignatureMismatch { - edge, - source_types, - target_types, + edge: Box::new(edge), + source_types: source_types.map(Box::new), + target_types: Box::new(target_types), }); } diff --git a/hugr-core/src/package.rs b/hugr-core/src/package.rs index 7be90da270..e50639ec05 100644 --- a/hugr-core/src/package.rs +++ b/hugr-core/src/package.rs @@ -1,6 +1,5 @@ //! Bundles of hugr modules along with the extension required to load them. 
-use derive_more::{Display, Error, From}; use std::io; use crate::envelope::{EnvelopeConfig, EnvelopeError, read_envelope, write_envelope}; @@ -8,6 +7,7 @@ use crate::extension::ExtensionRegistry; use crate::hugr::{HugrView, ValidationError}; use crate::std_extensions::STD_REG; use crate::{Hugr, Node}; +use thiserror::Error; #[derive(Debug, Default, Clone, PartialEq)] /// Package of module HUGRs. @@ -131,11 +131,12 @@ impl AsRef<[Hugr]> for Package { } /// Error raised while validating a package. -#[derive(Debug, Display, From, Error)] +#[derive(Debug, Error)] #[non_exhaustive] +#[error("Package validation error.")] pub enum PackageValidationError { /// Error raised while validating the package hugrs. - Validation(ValidationError), + Validation(#[from] ValidationError), } #[cfg(test)] diff --git a/hugr-core/src/std_extensions.rs b/hugr-core/src/std_extensions.rs index 1d49ea4e1e..d57663391d 100644 --- a/hugr-core/src/std_extensions.rs +++ b/hugr-core/src/std_extensions.rs @@ -21,6 +21,7 @@ pub fn std_reg() -> ExtensionRegistry { arithmetic::float_types::EXTENSION.to_owned(), collections::array::EXTENSION.to_owned(), collections::list::EXTENSION.to_owned(), + collections::borrow_array::EXTENSION.to_owned(), collections::static_array::EXTENSION.to_owned(), collections::value_array::EXTENSION.to_owned(), logic::EXTENSION.to_owned(), diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 5db32d55ed..71eb8fa91e 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -4,14 +4,14 @@ use std::num::NonZeroU64; use std::sync::{Arc, Weak}; use crate::ops::constant::ValueName; -use crate::types::TypeName; +use crate::types::{Term, TypeName}; use crate::{ Extension, extension::ExtensionId, ops::constant::CustomConst, types::{ ConstTypeError, CustomType, Type, TypeBound, - type_param::{TypeArg, TypeArgError, TypeParam}, + type_param::{TermTypeError, TypeArg, TypeParam}, }, }; use lazy_static::lazy_static; @@ -49,7 +49,7 @@ pub fn int_type(width_arg: impl Into) -> Type { lazy_static! { /// Array of valid integer types, indexed by log width of the integer. pub static ref INT_TYPES: [Type; LOG_WIDTH_BOUND as usize] = (0..LOG_WIDTH_BOUND) - .map(|i| int_type(TypeArg::BoundedNat { n: u64::from(i) })) + .map(|i| int_type(Term::from(u64::from(i)))) .collect::>() .try_into() .unwrap(); @@ -69,27 +69,25 @@ pub const LOG_WIDTH_BOUND: u8 = LOG_WIDTH_MAX + 1; /// Type parameter for the log width of the integer. #[allow(clippy::assertions_on_constants)] -pub const LOG_WIDTH_TYPE_PARAM: TypeParam = TypeParam::bounded_nat({ +pub const LOG_WIDTH_TYPE_PARAM: TypeParam = TypeParam::bounded_nat_type({ assert!(LOG_WIDTH_BOUND > 0); NonZeroU64::MIN.saturating_add(LOG_WIDTH_BOUND as u64 - 1) }); /// Get the log width of the specified type argument or error if the argument /// is invalid. 
-pub(super) fn get_log_width(arg: &TypeArg) -> Result { +pub(super) fn get_log_width(arg: &TypeArg) -> Result { match arg { - TypeArg::BoundedNat { n } if is_valid_log_width(*n as u8) => Ok(*n as u8), - _ => Err(TypeArgError::TypeMismatch { - arg: arg.clone(), - param: LOG_WIDTH_TYPE_PARAM, + TypeArg::BoundedNat(n) if is_valid_log_width(*n as u8) => Ok(*n as u8), + _ => Err(TermTypeError::TypeMismatch { + term: Box::new(arg.clone()), + type_: Box::new(LOG_WIDTH_TYPE_PARAM), }), } } const fn type_arg(log_width: u8) -> TypeArg { - TypeArg::BoundedNat { - n: log_width as u64, - } + TypeArg::BoundedNat(log_width as u64) } /// An integer (either signed or unsigned) @@ -239,13 +237,13 @@ mod test { #[test] fn test_int_widths() { - let type_arg_32 = TypeArg::BoundedNat { n: 5 }; + let type_arg_32 = TypeArg::BoundedNat(5); assert_matches!(get_log_width(&type_arg_32), Ok(5)); - let type_arg_128 = TypeArg::BoundedNat { n: 7 }; + let type_arg_128 = TypeArg::BoundedNat(7); assert_matches!( get_log_width(&type_arg_128), - Err(TypeArgError::TypeMismatch { .. }) + Err(TermTypeError::TypeMismatch { .. }) ); } diff --git a/hugr-core/src/std_extensions/arithmetic/mod.rs b/hugr-core/src/std_extensions/arithmetic/mod.rs index dc26ac4b0b..fbf3531ee7 100644 --- a/hugr-core/src/std_extensions/arithmetic/mod.rs +++ b/hugr-core/src/std_extensions/arithmetic/mod.rs @@ -20,7 +20,7 @@ mod test { for i in 0..LOG_WIDTH_BOUND { assert_eq!( INT_TYPES[i as usize], - int_type(TypeArg::BoundedNat { n: u64::from(i) }) + int_type(TypeArg::BoundedNat(u64::from(i))) ); } } diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index efd53c805e..0c52ad94d6 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -1,6 +1,7 @@ //! List type and operations. pub mod array; +pub mod borrow_array; pub mod list; pub mod static_array; pub mod value_array; diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index eb31441453..de55a41947 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -96,7 +96,7 @@ lazy_static! { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { extension.add_type( ARRAY_TYPENAME, - vec![ TypeParam::max_nat(), TypeBound::Any.into()], + vec![ TypeParam::max_nat_type(), TypeBound::Linear.into()], "Fixed-length array".into(), // Default array is linear, even if the elements are copyable TypeDefBound::any(), @@ -223,7 +223,7 @@ pub trait ArrayOpBuilder: GenericArrayOpBuilder { self.add_generic_array_unpack::(elem_ty, size, input) } /// Adds an array clone operation to the dataflow graph and return the wires - /// representing the originala and cloned array. + /// representing the original and cloned array. /// /// # Arguments /// diff --git a/hugr-core/src/std_extensions/collections/array/array_clone.rs b/hugr-core/src/std_extensions/collections/array/array_clone.rs index 2a3de6d6d9..2575a32c26 100644 --- a/hugr-core/src/std_extensions/collections/array/array_clone.rs +++ b/hugr-core/src/std_extensions/collections/array/array_clone.rs @@ -51,8 +51,8 @@ impl FromStr for GenericArrayCloneDef { impl GenericArrayCloneDef { /// To avoid recursion when defining the extension, take the type definition as an argument. 
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; - let size = TypeArg::new_var_use(0, TypeParam::max_nat()); + let params = vec![TypeParam::max_nat_type(), TypeBound::Copyable.into()]; + let size = TypeArg::new_var_use(0, TypeParam::max_nat_type()); let element_ty = Type::new_var_use(1, TypeBound::Copyable); let array_ty = AK::instantiate_ty(array_def, size, element_ty) .expect("Array type instantiation failed"); @@ -157,10 +157,7 @@ impl MakeExtensionOp for GenericArrayClone { } fn type_args(&self) -> Vec { - vec![ - TypeArg::BoundedNat { n: self.size }, - self.elem_ty.clone().into(), - ] + vec![self.size.into(), self.elem_ty.clone().into()] } } @@ -183,7 +180,7 @@ impl HasConcrete for GenericArrayCloneDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] if ty.copyable() => { + [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if ty.copyable() => { Ok(GenericArrayClone::new(ty.clone(), *n).unwrap()) } _ => Err(SignatureError::InvalidTypeArgs.into()), @@ -197,6 +194,7 @@ mod tests { use crate::extension::prelude::bool_t; use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::{ extension::prelude::qb_t, ops::{OpTrait, OpType}, @@ -206,6 +204,7 @@ mod tests { #[rstest] #[case(Array)] + #[case(BorrowArray)] fn test_clone_def(#[case] _kind: AK) { let op = GenericArrayClone::::new(bool_t(), 2).unwrap(); let optype: OpType = op.clone().into(); @@ -220,6 +219,7 @@ mod tests { #[rstest] #[case(Array)] + #[case(BorrowArray)] fn test_clone(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); diff --git a/hugr-core/src/std_extensions/collections/array/array_conversion.rs b/hugr-core/src/std_extensions/collections/array/array_conversion.rs index 21544dfd15..015b968002 100644 --- a/hugr-core/src/std_extensions/collections/array/array_conversion.rs +++ b/hugr-core/src/std_extensions/collections/array/array_conversion.rs @@ -76,9 +76,9 @@ impl { /// To avoid recursion when defining the extension, take the type definition as an argument. 
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; - let size = TypeArg::new_var_use(0, TypeParam::max_nat()); - let element_ty = Type::new_var_use(1, TypeBound::Any); + let params = vec![TypeParam::max_nat_type(), TypeBound::Linear.into()]; + let size = TypeArg::new_var_use(0, TypeParam::max_nat_type()); + let element_ty = Type::new_var_use(1, TypeBound::Linear); let this_ty = AK::instantiate_ty(array_def, size.clone(), element_ty.clone()) .expect("Array type instantiation failed"); @@ -202,10 +202,7 @@ impl MakeExtensionOp } fn type_args(&self) -> Vec { - vec![ - TypeArg::BoundedNat { n: self.size }, - self.elem_ty.clone().into(), - ] + vec![TypeArg::BoundedNat(self.size), self.elem_ty.clone().into()] } } @@ -234,7 +231,7 @@ impl HasConcrete fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] => { + [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => { Ok(GenericArrayConvert::new(ty.clone(), *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), @@ -249,12 +246,14 @@ mod tests { use crate::extension::prelude::bool_t; use crate::ops::{OpTrait, OpType}; use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::std_extensions::collections::value_array::ValueArray; use super::*; #[rstest] #[case(ValueArray, Array)] + #[case(BorrowArray, Array)] fn test_convert_from_def( #[case] _kind: AK, #[case] _other_kind: OtherAK, @@ -267,6 +266,7 @@ mod tests { #[rstest] #[case(ValueArray, Array)] + #[case(BorrowArray, Array)] fn test_convert_into_def( #[case] _kind: AK, #[case] _other_kind: OtherAK, @@ -279,6 +279,7 @@ mod tests { #[rstest] #[case(ValueArray, Array)] + #[case(BorrowArray, Array)] fn test_convert_from( #[case] _kind: AK, #[case] _other_kind: OtherAK, @@ -299,6 +300,7 @@ mod tests { #[rstest] #[case(ValueArray, Array)] + #[case(BorrowArray, Array)] fn test_convert_into( #[case] _kind: AK, #[case] _other_kind: OtherAK, diff --git a/hugr-core/src/std_extensions/collections/array/array_discard.rs b/hugr-core/src/std_extensions/collections/array/array_discard.rs index 67e2281f72..7e7a6599e0 100644 --- a/hugr-core/src/std_extensions/collections/array/array_discard.rs +++ b/hugr-core/src/std_extensions/collections/array/array_discard.rs @@ -51,8 +51,8 @@ impl FromStr for GenericArrayDiscardDef { impl GenericArrayDiscardDef { /// To avoid recursion when defining the extension, take the type definition as an argument. 
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; - let size = TypeArg::new_var_use(0, TypeParam::max_nat()); + let params = vec![TypeParam::max_nat_type(), TypeBound::Copyable.into()]; + let size = TypeArg::new_var_use(0, TypeParam::max_nat_type()); let element_ty = Type::new_var_use(1, TypeBound::Copyable); let array_ty = AK::instantiate_ty(array_def, size, element_ty) .expect("Array type instantiation failed"); @@ -141,10 +141,7 @@ impl MakeExtensionOp for GenericArrayDiscard { } fn type_args(&self) -> Vec { - vec![ - TypeArg::BoundedNat { n: self.size }, - self.elem_ty.clone().into(), - ] + vec![self.size.into(), self.elem_ty.clone().into()] } } @@ -167,7 +164,7 @@ impl HasConcrete for GenericArrayDiscardDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] if ty.copyable() => { + [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if ty.copyable() => { Ok(GenericArrayDiscard::new(ty.clone(), *n).unwrap()) } _ => Err(SignatureError::InvalidTypeArgs.into()), @@ -181,6 +178,7 @@ mod tests { use crate::extension::prelude::bool_t; use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::{ extension::prelude::qb_t, ops::{OpTrait, OpType}, @@ -190,6 +188,7 @@ mod tests { #[rstest] #[case(Array)] + #[case(BorrowArray)] fn test_discard_def(#[case] _kind: AK) { let op = GenericArrayDiscard::::new(bool_t(), 2).unwrap(); let optype: OpType = op.clone().into(); @@ -201,6 +200,7 @@ mod tests { #[rstest] #[case(Array)] + #[case(BorrowArray)] fn test_discard(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); diff --git a/hugr-core/src/std_extensions/collections/array/array_op.rs b/hugr-core/src/std_extensions/collections/array/array_op.rs index 915603c1da..dc7cf3d940 100644 --- a/hugr-core/src/std_extensions/collections/array/array_op.rs +++ b/hugr-core/src/std_extensions/collections/array/array_op.rs @@ -16,7 +16,7 @@ use crate::extension::{ use crate::ops::{ExtensionOp, OpName}; use crate::type_row; use crate::types::type_param::{TypeArg, TypeParam}; -use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound}; +use crate::types::{FuncValueType, PolyFuncTypeRV, Term, Type, TypeBound}; use crate::utils::Never; use super::array_kind::ArrayKind; @@ -65,16 +65,16 @@ pub enum GenericArrayOpDef { } /// Static parameters for array operations. Includes array size. Type is part of the type scheme. -const STATIC_SIZE_PARAM: &[TypeParam; 1] = &[TypeParam::max_nat()]; +const STATIC_SIZE_PARAM: &[TypeParam; 1] = &[TypeParam::max_nat_type()]; impl SignatureFromArgs for GenericArrayOpDef { fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { - let [TypeArg::BoundedNat { n }] = *arg_values else { + let [TypeArg::BoundedNat(n)] = *arg_values else { return Err(SignatureError::InvalidTypeArgs); }; - let elem_ty_var = Type::new_var_use(0, TypeBound::Any); + let elem_ty_var = Type::new_var_use(0, TypeBound::Linear); let array_ty = AK::ty(n, elem_ty_var.clone()); - let params = vec![TypeBound::Any.into()]; + let params = vec![TypeBound::Linear.into()]; let poly_func_ty = match self { GenericArrayOpDef::new_array => PolyFuncTypeRV::new( params, @@ -139,11 +139,11 @@ impl GenericArrayOpDef { // signature computed dynamically, so can rely on type definition in extension. 
(*self).into() } else { - let size_var = TypeArg::new_var_use(0, TypeParam::max_nat()); - let elem_ty_var = Type::new_var_use(1, TypeBound::Any); + let size_var = TypeArg::new_var_use(0, TypeParam::max_nat_type()); + let elem_ty_var = Type::new_var_use(1, TypeBound::Linear); let array_ty = AK::instantiate_ty(array_def, size_var.clone(), elem_ty_var.clone()) .expect("Array type instantiation failed"); - let standard_params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; + let standard_params = vec![TypeParam::max_nat_type(), TypeBound::Linear.into()]; // We can assume that the prelude has ben loaded at this point, // since it doesn't depend on the array extension. @@ -151,7 +151,7 @@ impl GenericArrayOpDef { match self { get => { - let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; + let params = vec![TypeParam::max_nat_type(), TypeBound::Copyable.into()]; let copy_elem_ty = Type::new_var_use(1, TypeBound::Copyable); let copy_array_ty = AK::instantiate_ty(array_def, size_var, copy_elem_ty.clone()) @@ -184,9 +184,9 @@ impl GenericArrayOpDef { ) } discard_empty => PolyFuncTypeRV::new( - vec![TypeBound::Any.into()], + vec![TypeBound::Linear.into()], FuncValueType::new( - AK::instantiate_ty(array_def, 0, Type::new_var_use(0, TypeBound::Any)) + AK::instantiate_ty(array_def, 0, Type::new_var_use(0, TypeBound::Linear)) .expect("Array type instantiation failed"), type_row![], ), @@ -282,13 +282,11 @@ impl MakeExtensionOp for GenericArrayOp { def.instantiate(ext_op.args()) } - fn type_args(&self) -> Vec { + fn type_args(&self) -> Vec { use GenericArrayOpDef::{ _phantom, discard_empty, get, new_array, pop_left, pop_right, set, swap, unpack, }; - let ty_arg = TypeArg::Type { - ty: self.elem_ty.clone(), - }; + let ty_arg = self.elem_ty.clone().into(); match self.def { discard_empty => { debug_assert_eq!( @@ -298,7 +296,7 @@ impl MakeExtensionOp for GenericArrayOp { vec![ty_arg] } new_array | unpack | pop_left | pop_right | get | set | swap => { - vec![TypeArg::BoundedNat { n: self.size }, ty_arg] + vec![self.size.into(), ty_arg] } _phantom(_, never) => match never {}, } @@ -322,10 +320,10 @@ impl HasDef for GenericArrayOp { impl HasConcrete for GenericArrayOpDef { type Concrete = GenericArrayOp; - fn instantiate(&self, type_args: &[TypeArg]) -> Result { + fn instantiate(&self, type_args: &[Term]) -> Result { let (ty, size) = match (self, type_args) { - (GenericArrayOpDef::discard_empty, [TypeArg::Type { ty }]) => (ty.clone(), 0), - (_, [TypeArg::BoundedNat { n }, TypeArg::Type { ty }]) => (ty.clone(), *n), + (GenericArrayOpDef::discard_empty, [Term::Runtime(ty)]) => (ty.clone(), 0), + (_, [Term::BoundedNat(n), Term::Runtime(ty)]) => (ty.clone(), *n), _ => return Err(SignatureError::InvalidTypeArgs.into()), }; @@ -341,6 +339,7 @@ mod tests { use crate::extension::prelude::usize_t; use crate::std_extensions::arithmetic::float_types::float64_type; use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::std_extensions::collections::value_array::ValueArray; use crate::{ builder::{DFGBuilder, Dataflow, DataflowHugr, inout_sig}, @@ -353,6 +352,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] fn test_array_ops(#[case] _kind: AK) { for def in GenericArrayOpDef::::iter() { let ty = if def == GenericArrayOpDef::get { @@ -375,6 +375,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] /// Test building a HUGR involving a new_array operation. 
fn test_new_array(#[case] _kind: AK) { let mut b = DFGBuilder::new(inout_sig(vec![qb_t(), qb_t()], AK::ty(2, qb_t()))).unwrap(); @@ -391,6 +392,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] /// Test building a HUGR involving an unpack operation. fn test_unpack(#[case] _kind: AK) { let mut b = DFGBuilder::new(inout_sig(AK::ty(2, qb_t()), vec![qb_t(), qb_t()])).unwrap(); @@ -407,6 +409,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] fn test_get(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); @@ -432,6 +435,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] fn test_set(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); @@ -454,6 +458,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] fn test_swap(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); @@ -475,6 +480,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] fn test_pops(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); @@ -507,6 +513,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] fn test_discard_empty(#[case] _kind: AK) { let size = 0; let element_ty = bool_t(); @@ -525,6 +532,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] /// Initialize an array operation where the element type is not from the prelude. fn test_non_prelude_op(#[case] _kind: AK) { let size = 2; diff --git a/hugr-core/src/std_extensions/collections/array/array_repeat.rs b/hugr-core/src/std_extensions/collections/array/array_repeat.rs index d3302d253a..e2c77ef21c 100644 --- a/hugr-core/src/std_extensions/collections/array/array_repeat.rs +++ b/hugr-core/src/std_extensions/collections/array/array_repeat.rs @@ -52,9 +52,9 @@ impl FromStr for GenericArrayRepeatDef { impl GenericArrayRepeatDef { /// To avoid recursion when defining the extension, take the type definition as an argument. 
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; - let n = TypeArg::new_var_use(0, TypeParam::max_nat()); - let t = Type::new_var_use(1, TypeBound::Any); + let params = vec![TypeParam::max_nat_type(), TypeBound::Linear.into()]; + let n = TypeArg::new_var_use(0, TypeParam::max_nat_type()); + let t = Type::new_var_use(1, TypeBound::Linear); let func = Type::new_function(Signature::new(vec![], vec![t.clone()])); let array_ty = AK::instantiate_ty(array_def, n, t).expect("Array type instantiation failed"); @@ -147,10 +147,7 @@ impl MakeExtensionOp for GenericArrayRepeat { } fn type_args(&self) -> Vec { - vec![ - TypeArg::BoundedNat { n: self.size }, - self.elem_ty.clone().into(), - ] + vec![self.size.into(), self.elem_ty.clone().into()] } } @@ -173,7 +170,7 @@ impl HasConcrete for GenericArrayRepeatDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] => { + [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => { Ok(GenericArrayRepeat::new(ty.clone(), *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), @@ -186,6 +183,7 @@ mod tests { use rstest::rstest; use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::std_extensions::collections::value_array::ValueArray; use crate::{ extension::prelude::qb_t, @@ -198,6 +196,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] fn test_repeat_def(#[case] _kind: AK) { let op = GenericArrayRepeat::::new(qb_t(), 2); let optype: OpType = op.clone().into(); @@ -208,6 +207,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] fn test_repeat(#[case] _kind: AK) { let size = 2; let element_ty = qb_t(); diff --git a/hugr-core/src/std_extensions/collections/array/array_scan.rs b/hugr-core/src/std_extensions/collections/array/array_scan.rs index 2dc5d2f734..416777f436 100644 --- a/hugr-core/src/std_extensions/collections/array/array_scan.rs +++ b/hugr-core/src/std_extensions/collections/array/array_scan.rs @@ -56,15 +56,15 @@ impl GenericArrayScanDef { fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { // array, (T1, *A -> T2, *A), *A, -> array, *A let params = vec![ - TypeParam::max_nat(), - TypeBound::Any.into(), - TypeBound::Any.into(), - TypeParam::new_list(TypeBound::Any), + TypeParam::max_nat_type(), + TypeBound::Linear.into(), + TypeBound::Linear.into(), + TypeParam::new_list_type(TypeBound::Linear), ]; - let n = TypeArg::new_var_use(0, TypeParam::max_nat()); - let t1 = Type::new_var_use(1, TypeBound::Any); - let t2 = Type::new_var_use(2, TypeBound::Any); - let s = TypeRV::new_row_var_use(3, TypeBound::Any); + let n = TypeArg::new_var_use(0, TypeParam::max_nat_type()); + let t1 = Type::new_var_use(1, TypeBound::Linear); + let t2 = Type::new_var_use(2, TypeBound::Linear); + let s = TypeRV::new_row_var_use(3, TypeBound::Linear); PolyFuncTypeRV::new( params, FuncTypeBase::::new( @@ -185,12 +185,10 @@ impl MakeExtensionOp for GenericArrayScan { fn type_args(&self) -> Vec { vec![ - TypeArg::BoundedNat { n: self.size }, + self.size.into(), self.src_ty.clone().into(), self.tgt_ty.clone().into(), - TypeArg::Sequence { - elems: self.acc_tys.clone().into_iter().map_into().collect(), - }, + TypeArg::new_list(self.acc_tys.clone().into_iter().map_into()), ] } } @@ -215,15 +213,15 @@ impl HasConcrete for GenericArrayScanDef { fn instantiate(&self, 
type_args: &[TypeArg]) -> Result { match type_args { [ - TypeArg::BoundedNat { n }, - TypeArg::Type { ty: src_ty }, - TypeArg::Type { ty: tgt_ty }, - TypeArg::Sequence { elems: acc_tys }, + TypeArg::BoundedNat(n), + TypeArg::Runtime(src_ty), + TypeArg::Runtime(tgt_ty), + TypeArg::List(acc_tys), ] => { let acc_tys: Result<_, OpLoadError> = acc_tys .iter() .map(|acc_ty| match acc_ty { - TypeArg::Type { ty } => Ok(ty.clone()), + TypeArg::Runtime(ty) => Ok(ty.clone()), _ => Err(SignatureError::InvalidTypeArgs.into()), }) .collect(); @@ -245,6 +243,7 @@ mod tests { use crate::extension::prelude::usize_t; use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::std_extensions::collections::value_array::ValueArray; use crate::{ extension::prelude::{bool_t, qb_t}, @@ -257,6 +256,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] fn test_scan_def(#[case] _kind: AK) { let op = GenericArrayScan::::new(bool_t(), qb_t(), vec![usize_t()], 2); let optype: OpType = op.clone().into(); @@ -267,6 +267,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] fn test_scan_map(#[case] _kind: AK) { let size = 2; let src_ty = qb_t(); @@ -292,6 +293,7 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] fn test_scan_accs(#[case] _kind: AK) { let size = 2; let src_ty = qb_t(); diff --git a/hugr-core/src/std_extensions/collections/array/array_value.rs b/hugr-core/src/std_extensions/collections/array/array_value.rs index 8828acd982..916f218739 100644 --- a/hugr-core/src/std_extensions/collections/array/array_value.rs +++ b/hugr-core/src/std_extensions/collections/array/array_value.rs @@ -94,9 +94,7 @@ impl GenericArrayValue { // constant can only hold classic type. let ty = match typ.args() { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] - if *n as usize == self.values.len() => - { + [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if *n as usize == self.values.len() => { ty } _ => { @@ -148,6 +146,7 @@ mod test { use crate::std_extensions::arithmetic::float_types::ConstF64; use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::std_extensions::collections::value_array::ValueArray; use super::*; @@ -155,6 +154,7 @@ mod test { #[rstest] #[case(Array)] #[case(ValueArray)] + #[case(BorrowArray)] fn test_array_value(#[case] _kind: AK) { let array_value = GenericArrayValue::::new(usize_t(), vec![ConstUsize::new(3).into()]); array_value.validate().unwrap(); diff --git a/hugr-core/src/std_extensions/collections/array/op_builder.rs b/hugr-core/src/std_extensions/collections/array/op_builder.rs index 2740673f80..b408e1a3de 100644 --- a/hugr-core/src/std_extensions/collections/array/op_builder.rs +++ b/hugr-core/src/std_extensions/collections/array/op_builder.rs @@ -1,6 +1,7 @@ //! Builder trait for array operations in the dataflow graph. use crate::std_extensions::collections::array::GenericArrayOpDef; +use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::std_extensions::collections::value_array::ValueArray; use crate::{ Wire, @@ -390,6 +391,11 @@ pub fn build_all_value_array_ops(builder: B) -> B { build_all_array_ops_generic::(builder) } +/// Helper function to build a Hugr that contains all basic array operations. 
+pub fn build_all_borrow_array_ops(builder: B) -> B { + build_all_array_ops_generic::(builder) +} + /// Testing utilities to generate Hugrs that contain array operations. #[cfg(test)] mod test { @@ -411,4 +417,11 @@ mod test { let builder = DFGBuilder::new(sig).unwrap(); build_all_value_array_ops(builder).finish_hugr().unwrap(); } + + #[test] + fn all_borrow_array_ops() { + let sig = Signature::new_endo(Type::EMPTY_TYPEROW); + let builder = DFGBuilder::new(sig).unwrap(); + build_all_borrow_array_ops(builder).finish_hugr().unwrap(); + } } diff --git a/hugr-core/src/std_extensions/collections/borrow_array.rs b/hugr-core/src/std_extensions/collections/borrow_array.rs new file mode 100644 index 0000000000..52982d0833 --- /dev/null +++ b/hugr-core/src/std_extensions/collections/borrow_array.rs @@ -0,0 +1,797 @@ +//! A version of the standard fixed-length array extension that includes unsafe +//! operations for borrowing and returning that may panic. + +use std::sync::{self, Arc}; + +use delegate::delegate; +use lazy_static::lazy_static; + +use crate::extension::{ExtensionId, SignatureError, TypeDef, TypeDefBound}; +use crate::ops::constant::{CustomConst, ValueName}; +use crate::type_row; +use crate::types::type_param::{TypeArg, TypeParam}; +use crate::types::{CustomCheckFailure, Term, Type, TypeBound, TypeName}; +use crate::{Extension, Wire}; +use crate::{ + builder::{BuildError, Dataflow}, + extension::SignatureFunc, +}; +use crate::{ + extension::simple_op::{HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp}, + ops::ExtensionOp, +}; +use crate::{ + extension::{ + OpDef, + prelude::usize_t, + resolution::{ExtensionResolutionError, WeakExtensionRegistry}, + simple_op::{OpLoadError, try_from_name}, + }, + ops::OpName, + types::{FuncValueType, PolyFuncTypeRV}, +}; + +use super::array::op_builder::GenericArrayOpBuilder; +use super::array::{ + Array, ArrayKind, FROM, GenericArrayClone, GenericArrayCloneDef, GenericArrayConvert, + GenericArrayConvertDef, GenericArrayDiscard, GenericArrayDiscardDef, GenericArrayOp, + GenericArrayOpDef, GenericArrayRepeat, GenericArrayRepeatDef, GenericArrayScan, + GenericArrayScanDef, GenericArrayValue, INTO, +}; + +/// Reported unique name of the borrow array type. +pub const BORROW_ARRAY_TYPENAME: TypeName = TypeName::new_inline("borrow_array"); +/// Reported unique name of the borrow array value. +pub const BORROW_ARRAY_VALUENAME: TypeName = TypeName::new_inline("borrow_array"); +/// Reported unique name of the extension +pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.borrow_arr"); +/// Extension version. +pub const VERSION: semver::Version = semver::Version::new(0, 1, 1); + +/// A linear, unsafe, fixed-length collection of values. +/// +/// Borrow arrays are linear, even if their elements are copyable. +#[derive(Clone, Copy, Debug, derive_more::Display, Eq, PartialEq, Default)] +pub struct BorrowArray; + +impl ArrayKind for BorrowArray { + const EXTENSION_ID: ExtensionId = EXTENSION_ID; + const TYPE_NAME: TypeName = BORROW_ARRAY_TYPENAME; + const VALUE_NAME: ValueName = BORROW_ARRAY_VALUENAME; + + fn extension() -> &'static Arc { + &EXTENSION + } + + fn type_def() -> &'static TypeDef { + EXTENSION.get_type(&BORROW_ARRAY_TYPENAME).unwrap() + } +} + +/// Borrow array operation definitions. +pub type BArrayOpDef = GenericArrayOpDef; +/// Borrow array clone operation definition. +pub type BArrayCloneDef = GenericArrayCloneDef; +/// Borrow array discard operation definition. 
+pub type BArrayDiscardDef = GenericArrayDiscardDef; +/// Borrow array repeat operation definition. +pub type BArrayRepeatDef = GenericArrayRepeatDef; +/// Borrow array scan operation definition. +pub type BArrayScanDef = GenericArrayScanDef; +/// Borrow array to default array conversion operation definition. +pub type BArrayToArrayDef = GenericArrayConvertDef; +/// Borrow array from default array conversion operation definition. +pub type BArrayFromArrayDef = GenericArrayConvertDef; + +/// Borrow array operations. +pub type BArrayOp = GenericArrayOp; +/// The borrow array clone operation. +pub type BArrayClone = GenericArrayClone; +/// The borrow array discard operation. +pub type BArrayDiscard = GenericArrayDiscard; +/// The borrow array repeat operation. +pub type BArrayRepeat = GenericArrayRepeat; +/// The borrow array scan operation. +pub type BArrayScan = GenericArrayScan; +/// The borrow array to default array conversion operation. +pub type BArrayToArray = GenericArrayConvert; +/// The borrow array from default array conversion operation. +pub type BArrayFromArray = GenericArrayConvert; + +/// A borrow array extension value. +pub type BArrayValue = GenericArrayValue; + +#[derive( + Clone, + Copy, + Debug, + Hash, + PartialEq, + Eq, + strum::EnumIter, + strum::IntoStaticStr, + strum::EnumString, +)] +#[allow(non_camel_case_types, missing_docs)] +#[non_exhaustive] +pub enum BArrayUnsafeOpDef { + /// `borrow: borrow_array, index -> elem_ty, borrow_array` + borrow, + /// `return: borrow_array, index, elem_ty -> borrow_array` + #[strum(serialize = "return")] + r#return, + /// `discard_all_borrowed: borrow_array -> ()` + discard_all_borrowed, + /// `new_all_borrowed: () -> borrow_array` + new_all_borrowed, +} + +impl BArrayUnsafeOpDef { + /// Instantiate a new unsafe borrow array operation with the given element type and array size. 
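// As a concrete reading of the op signatures listed above (a sketch; `qb_t` is the
// prelude qubit type, and the generic parameters of `borrow_array` are written out
// here for illustration), instantiating `borrow` for a four-qubit borrow array gives
// an op with runtime signature
// `(borrow_array<4, qubit>, usize) -> (qubit, borrow_array<4, qubit>)`:
//
//     let op = BArrayUnsafeOpDef::borrow.to_concrete(qb_t(), 4);
//     assert_eq!(op.size, 4);
//     assert_eq!(op.elem_ty, qb_t());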
+ #[must_use] + pub fn to_concrete(self, elem_ty: Type, size: u64) -> BArrayUnsafeOp { + BArrayUnsafeOp { + def: self, + elem_ty, + size, + } + } + + fn signature_from_def(&self, def: &TypeDef, _: &sync::Weak) -> SignatureFunc { + let size_var = TypeArg::new_var_use(0, TypeParam::max_nat_type()); + let elem_ty_var = Type::new_var_use(1, TypeBound::Linear); + let array_ty: Type = def + .instantiate(vec![size_var, elem_ty_var.clone().into()]) + .unwrap() + .into(); + + let params = vec![TypeParam::max_nat_type(), TypeBound::Linear.into()]; + + let usize_t: Type = usize_t(); + + match self { + Self::borrow => PolyFuncTypeRV::new( + params, + FuncValueType::new(vec![array_ty.clone(), usize_t], vec![elem_ty_var, array_ty]), + ), + Self::r#return => PolyFuncTypeRV::new( + params, + FuncValueType::new( + vec![array_ty.clone(), usize_t, elem_ty_var.clone()], + vec![array_ty], + ), + ), + Self::discard_all_borrowed => { + PolyFuncTypeRV::new(params, FuncValueType::new(vec![array_ty], type_row![])) + } + Self::new_all_borrowed => { + PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![array_ty])) + } + } + .into() + } +} + +impl MakeOpDef for BArrayUnsafeOpDef { + fn opdef_id(&self) -> OpName { + <&'static str>::from(self).into() + } + + fn from_def(op_def: &OpDef) -> Result + where + Self: Sized, + { + try_from_name(op_def.name(), op_def.extension_id()) + } + + fn init_signature(&self, extension_ref: &sync::Weak) -> SignatureFunc { + self.signature_from_def( + EXTENSION.get_type(&BORROW_ARRAY_TYPENAME).unwrap(), + extension_ref, + ) + } + + fn extension_ref(&self) -> sync::Weak { + Arc::downgrade(&EXTENSION) + } + + fn extension(&self) -> ExtensionId { + EXTENSION_ID.clone() + } + + fn description(&self) -> String { + match self { + Self::borrow => { + "Take an element from a borrow array (panicking if it was already taken before)" + } + Self::r#return => { + "Put an element into a borrow array (panicking if there is an element already)" + } + Self::discard_all_borrowed => { + "Discard a borrow array where all elements have been borrowed" + } + Self::new_all_borrowed => "Create a new borrow array that contains no elements", + } + .into() + } + + // This method is re-defined here to avoid recursive loops initializing the extension. + fn add_to_extension( + &self, + extension: &mut Extension, + extension_ref: &sync::Weak, + ) -> Result<(), crate::extension::ExtensionBuildError> { + let sig = self.signature_from_def( + extension.get_type(&BORROW_ARRAY_TYPENAME).unwrap(), + extension_ref, + ); + let def = extension.add_op(self.opdef_id(), self.description(), sig, extension_ref)?; + + self.post_opdef(def); + + Ok(()) + } +} + +#[derive(Clone, Debug, PartialEq)] +/// Concrete array operation. +pub struct BArrayUnsafeOp { + /// The operation definition. + pub def: BArrayUnsafeOpDef, + /// The element type of the array. + pub elem_ty: Type, + /// The size of the array. 
+ pub size: u64, +} + +impl MakeExtensionOp for BArrayUnsafeOp { + fn op_id(&self) -> OpName { + self.def.opdef_id() + } + + fn from_extension_op(ext_op: &ExtensionOp) -> Result + where + Self: Sized, + { + let def = BArrayUnsafeOpDef::from_def(ext_op.def())?; + def.instantiate(ext_op.args()) + } + + fn type_args(&self) -> Vec { + vec![self.size.into(), self.elem_ty.clone().into()] + } +} + +impl HasDef for BArrayUnsafeOp { + type Def = BArrayUnsafeOpDef; +} + +impl HasConcrete for BArrayUnsafeOpDef { + type Concrete = BArrayUnsafeOp; + + fn instantiate(&self, type_args: &[TypeArg]) -> Result { + match type_args { + [Term::BoundedNat(n), Term::Runtime(ty)] => Ok(self.to_concrete(ty.clone(), *n)), + _ => Err(SignatureError::InvalidTypeArgs.into()), + } + } +} + +impl MakeRegisteredOp for BArrayUnsafeOp { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.clone() + } + + fn extension_ref(&self) -> sync::Weak { + Arc::downgrade(&EXTENSION) + } +} + +lazy_static! { + /// Extension for borrow array operations. + pub static ref EXTENSION: Arc = { + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.add_type( + BORROW_ARRAY_TYPENAME, + vec![ TypeParam::max_nat_type(), TypeBound::Linear.into()], + "Fixed-length borrow array".into(), + // Borrow array is linear, even if the elements are copyable. + TypeDefBound::any(), + extension_ref, + ) + .unwrap(); + + BArrayOpDef::load_all_ops(extension, extension_ref).unwrap(); + BArrayCloneDef::new().add_to_extension(extension, extension_ref).unwrap(); + BArrayDiscardDef::new().add_to_extension(extension, extension_ref).unwrap(); + BArrayRepeatDef::new().add_to_extension(extension, extension_ref).unwrap(); + BArrayScanDef::new().add_to_extension(extension, extension_ref).unwrap(); + BArrayToArrayDef::new().add_to_extension(extension, extension_ref).unwrap(); + BArrayFromArrayDef::new().add_to_extension(extension, extension_ref).unwrap(); + + BArrayUnsafeOpDef::load_all_ops(extension, extension_ref).unwrap(); + }) + }; +} + +#[typetag::serde(name = "BArrayValue")] +impl CustomConst for BArrayValue { + delegate! { + to self { + fn name(&self) -> ValueName; + fn validate(&self) -> Result<(), CustomCheckFailure>; + fn update_extensions( + &mut self, + extensions: &WeakExtensionRegistry, + ) -> Result<(), ExtensionResolutionError>; + fn get_type(&self) -> Type; + } + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::ops::constant::downcast_equal_consts(self, other) + } +} + +/// Gets the [`TypeDef`] for borrow arrays. Note that instantiations are more easily +/// created via [`borrow_array_type`] and [`borrow_array_type_parametric`] +#[must_use] +pub fn borrow_array_type_def() -> &'static TypeDef { + BorrowArray::type_def() +} + +/// Instantiate a new borrow array type given a size argument and element type. +/// +/// This method is equivalent to [`borrow_array_type_parametric`], but uses concrete +/// arguments types to ensure no errors are possible. +#[must_use] +pub fn borrow_array_type(size: u64, element_ty: Type) -> Type { + BorrowArray::ty(size, element_ty) +} + +/// Instantiate a new borrow array type given the size and element type parameters. +/// +/// This is a generic version of [`borrow_array_type`]. +pub fn borrow_array_type_parametric( + size: impl Into, + element_ty: impl Into, +) -> Result { + BorrowArray::ty_parametric(size, element_ty) +} + +/// Trait for building borrow array operations in a dataflow graph. 
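// A rough usage sketch of this trait (mirroring `test_borrow_and_return` at the
// bottom of this file; error handling elided): borrow an element out of a
// `borrow_array` and return it afterwards.
//
//     let mut dfb = DFGBuilder::new(Signature::new_endo(vec![borrow_array_type(2, qb_t())]))?;
//     let [arr] = dfb.input_wires_arr();
//     let idx = dfb.add_load_value(ConstUsize::new(0));
//     let (elem, arr) = dfb.add_borrow_array_borrow(qb_t(), 2, arr, idx)?;
//     let idx = dfb.add_load_value(ConstUsize::new(0));
//     let arr = dfb.add_borrow_array_return(qb_t(), 2, arr, idx, elem)?;
//     let hugr = dfb.finish_hugr_with_outputs([arr])?;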
+pub trait BArrayOpBuilder: GenericArrayOpBuilder { + /// Adds a new array operation to the dataflow graph and return the wire + /// representing the new array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `values` - An iterator over the values to initialize the array with. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// The wire representing the new array. + fn add_new_borrow_array( + &mut self, + elem_ty: Type, + values: impl IntoIterator, + ) -> Result { + self.add_new_generic_array::(elem_ty, values) + } + /// Adds an array unpack operation to the dataflow graph. + /// + /// This operation unpacks an array into individual elements. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array to unpack. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// A vector of wires representing the individual elements from the array. + fn add_borrow_array_unpack( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result, BuildError> { + self.add_generic_array_unpack::(elem_ty, size, input) + } + /// Adds an array clone operation to the dataflow graph and return the wires + /// representing the original and cloned array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// The wires representing the original and cloned array. + fn add_borrow_array_clone( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result<(Wire, Wire), BuildError> { + self.add_generic_array_clone::(elem_ty, size, input) + } + + /// Adds an array discard operation to the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// If building the operation fails. + fn add_borrow_array_discard( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result<(), BuildError> { + self.add_generic_array_discard::(elem_ty, size, input) + } + + /// Adds an array get operation to the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to get. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// * The wire representing the value at the specified index in the array + /// * The wire representing the array + fn add_borrow_array_get( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + ) -> Result<(Wire, Wire), BuildError> { + self.add_generic_array_get::(elem_ty, size, input, index) + } + + /// Adds an array set operation to the dataflow graph. + /// + /// This operation sets the value at a specified index in the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to set. 
+ /// * `value` - The wire representing the value to set at the specified index. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the updated array after the set operation. + fn add_borrow_array_set( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + value: Wire, + ) -> Result { + self.add_generic_array_set::(elem_ty, size, input, index, value) + } + + /// Adds an array swap operation to the dataflow graph. + /// + /// This operation swaps the values at two specified indices in the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index1` - The wire representing the first index to swap. + /// * `index2` - The wire representing the second index to swap. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the updated array after the swap operation. + fn add_borrow_array_swap( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index1: Wire, + index2: Wire, + ) -> Result { + let op = + GenericArrayOpDef::::swap.instantiate(&[size.into(), elem_ty.into()])?; + let [out] = self + .add_dataflow_op(op, vec![input, index1, index2])? + .outputs_arr(); + Ok(out) + } + + /// Adds an array pop-left operation to the dataflow graph. + /// + /// This operation removes the leftmost element from the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the Option> + fn add_borrow_array_pop_left( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result { + self.add_generic_array_pop_left::(elem_ty, size, input) + } + + /// Adds an array pop-right operation to the dataflow graph. + /// + /// This operation removes the rightmost element from the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the Option> + fn add_borrow_array_pop_right( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result { + self.add_generic_array_pop_right::(elem_ty, size, input) + } + + /// Adds an operation to discard an empty array from the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + fn add_borrow_array_discard_empty( + &mut self, + elem_ty: Type, + input: Wire, + ) -> Result<(), BuildError> { + self.add_generic_array_discard_empty::(elem_ty, input) + } + + /// Adds a borrow array borrow operation to the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to get. 
+ /// + /// # Errors + /// + /// Returns an error if building the operation fails. + fn add_borrow_array_borrow( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + ) -> Result<(Wire, Wire), BuildError> { + let op = BArrayUnsafeOpDef::borrow.instantiate(&[size.into(), elem_ty.into()])?; + let [out, arr] = self + .add_dataflow_op(op.to_extension_op().unwrap(), vec![input, index])? + .outputs_arr(); + Ok((out, arr)) + } + + /// Adds a borrow array put operation to the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to set. + /// * `value` - The wire representing the value to set at the specified index. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + fn add_borrow_array_return( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + value: Wire, + ) -> Result { + let op = BArrayUnsafeOpDef::r#return.instantiate(&[size.into(), elem_ty.into()])?; + let [arr] = self + .add_dataflow_op(op.to_extension_op().unwrap(), vec![input, index, value])? + .outputs_arr(); + Ok(arr) + } + + /// Adds an operation to discard a borrow array where all elements have been borrowed. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + fn add_discard_all_borrowed( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result<(), BuildError> { + let op = + BArrayUnsafeOpDef::discard_all_borrowed.instantiate(&[size.into(), elem_ty.into()])?; + self.add_dataflow_op(op.to_extension_op().unwrap(), vec![input])?; + Ok(()) + } + + /// Adds an operation to create a new empty borrowed array in the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + fn add_new_all_borrowed(&mut self, elem_ty: Type, size: u64) -> Result { + let op = BArrayUnsafeOpDef::new_all_borrowed.instantiate(&[size.into(), elem_ty.into()])?; + let [arr] = self + .add_dataflow_op(op.to_extension_op().unwrap(), vec![])? 
+ .outputs_arr(); + Ok(arr) + } +} + +impl BArrayOpBuilder for D {} + +#[cfg(test)] +mod test { + use strum::IntoEnumIterator; + + use crate::{ + builder::{DFGBuilder, Dataflow, DataflowHugr as _}, + extension::prelude::{ConstUsize, qb_t, usize_t}, + ops::OpType, + std_extensions::collections::borrow_array::{ + BArrayOpBuilder, BArrayUnsafeOp, BArrayUnsafeOpDef, borrow_array_type, + }, + types::Signature, + }; + + #[test] + fn test_borrow_array_unsafe_ops() { + for def in BArrayUnsafeOpDef::iter() { + let op = def.to_concrete(qb_t(), 2); + let optype: OpType = op.clone().into(); + let new_op: BArrayUnsafeOp = optype.cast().unwrap(); + assert_eq!(new_op, op); + } + } + + #[test] + fn test_borrow_and_return() { + let size = 22; + let elem_ty = qb_t(); + let arr_ty = borrow_array_type(size, elem_ty.clone()); + let _ = { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![arr_ty.clone()])).unwrap(); + let idx1 = builder.add_load_value(ConstUsize::new(11)); + let idx2 = builder.add_load_value(ConstUsize::new(11)); + let [arr] = builder.input_wires_arr(); + let (el, arr_with_take) = builder + .add_borrow_array_borrow(elem_ty.clone(), size, arr, idx1) + .unwrap(); + let arr_with_put = builder + .add_borrow_array_return(elem_ty, size, arr_with_take, idx2, el) + .unwrap(); + builder.finish_hugr_with_outputs([arr_with_put]).unwrap() + }; + } + + #[test] + fn test_discard_all_borrowed() { + let size = 1; + let elem_ty = qb_t(); + let arr_ty = borrow_array_type(size, elem_ty.clone()); + let _ = { + let mut builder = + DFGBuilder::new(Signature::new(vec![arr_ty.clone()], vec![qb_t()])).unwrap(); + let idx = builder.add_load_value(ConstUsize::new(0)); + let [arr] = builder.input_wires_arr(); + let (el, arr_with_borrowed) = builder + .add_borrow_array_borrow(elem_ty.clone(), size, arr, idx) + .unwrap(); + builder + .add_discard_all_borrowed(elem_ty, size, arr_with_borrowed) + .unwrap(); + builder.finish_hugr_with_outputs([el]).unwrap() + }; + } + + #[test] + fn test_new_all_borrowed() { + let size = 5; + let elem_ty = usize_t(); + let arr_ty = borrow_array_type(size, elem_ty.clone()); + let _ = { + let mut builder = + DFGBuilder::new(Signature::new(vec![], vec![arr_ty.clone()])).unwrap(); + let arr = builder.add_new_all_borrowed(elem_ty.clone(), size).unwrap(); + let idx = builder.add_load_value(ConstUsize::new(3)); + let val = builder.add_load_value(ConstUsize::new(202)); + let arr_with_put = builder + .add_borrow_array_return(elem_ty, size, arr, idx, val) + .unwrap(); + builder.finish_hugr_with_outputs([arr_with_put]).unwrap() + }; + } +} diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 05d05048a6..817f90dba4 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -21,7 +21,7 @@ use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc}; use crate::ops::constant::{TryHash, ValueName, maybe_hash_values}; use crate::ops::{OpName, Value}; -use crate::types::{TypeName, TypeRowRV}; +use crate::types::{Term, TypeName, TypeRowRV}; use crate::{ Extension, extension::{ @@ -112,7 +112,7 @@ impl CustomConst for ListValue { .map_err(|_| error())?; // constant can only hold classic type. - let [TypeArg::Type { ty }] = typ.args() else { + let [TypeArg::Runtime(ty)] = typ.args() else { return Err(error()); }; @@ -167,7 +167,7 @@ pub enum ListOp { impl ListOp { /// Type parameter used in the list types. 
- const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; + const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear); /// Instantiate a list operation with an `element_type`. #[must_use] @@ -181,7 +181,7 @@ impl ListOp { /// Compute the signature of the operation, given the list type definition. fn compute_signature(self, list_type_def: &TypeDef) -> SignatureFunc { use ListOp::{get, insert, length, pop, push, set}; - let e = Type::new_var_use(0, TypeBound::Any); + let e = Type::new_var_use(0, TypeBound::Linear); let l = self.list_type(list_type_def, 0); match self { pop => self @@ -325,9 +325,7 @@ pub fn list_type_def() -> &'static TypeDef { /// Get the type of a list of `elem_type` as a `CustomType`. #[must_use] pub fn list_custom_type(elem_type: Type) -> CustomType { - list_type_def() - .instantiate(vec![TypeArg::Type { ty: elem_type }]) - .unwrap() + list_type_def().instantiate(vec![elem_type.into()]).unwrap() } /// Get the `Type` of a list of `elem_type`. @@ -353,7 +351,7 @@ impl MakeExtensionOp for ListOpInst { fn from_extension_op( ext_op: &ExtensionOp, ) -> Result { - let [TypeArg::Type { ty }] = ext_op.args() else { + let [Term::Runtime(ty)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs.into()); }; let name = ext_op.unqualified_id(); @@ -367,10 +365,8 @@ impl MakeExtensionOp for ListOpInst { }) } - fn type_args(&self) -> Vec { - vec![TypeArg::Type { - ty: self.elem_type.clone(), - }] + fn type_args(&self) -> Vec { + vec![self.elem_type.clone().into()] } } @@ -413,15 +409,9 @@ mod test { fn test_list() { let list_def = list_type_def(); - let list_type = list_def - .instantiate([TypeArg::Type { ty: usize_t() }]) - .unwrap(); + let list_type = list_def.instantiate([usize_t().into()]).unwrap(); - assert!( - list_def - .instantiate([TypeArg::BoundedNat { n: 3 }]) - .is_err() - ); + assert!(list_def.instantiate([3u64.into()]).is_err()); list_def.check_custom(&list_type).unwrap(); let list_value = ListValue(vec![ConstUsize::new(3).into()], usize_t()); diff --git a/hugr-core/src/std_extensions/collections/static_array.rs b/hugr-core/src/std_extensions/collections/static_array.rs index 6f3e889e68..c99e7617b2 100644 --- a/hugr-core/src/std_extensions/collections/static_array.rs +++ b/hugr-core/src/std_extensions/collections/static_array.rs @@ -38,7 +38,7 @@ use crate::{ types::{ ConstTypeError, CustomCheckFailure, CustomType, PolyFuncType, Signature, Type, TypeArg, TypeBound, TypeName, - type_param::{TypeArgError, TypeParam}, + type_param::{TermTypeError, TypeParam}, }, }; @@ -309,12 +309,12 @@ impl HasConcrete for StaticArrayOpDef { match type_args { [arg] => { let elem_ty = arg - .as_type() + .as_runtime() .filter(|t| Copyable.contains(t.least_upper_bound())) .ok_or(SignatureError::TypeArgMismatch( - TypeArgError::TypeMismatch { - param: Copyable.into(), - arg: arg.clone(), + TermTypeError::TypeMismatch { + type_: Box::new(Copyable.into()), + term: Box::new(arg.clone()), }, ))?; @@ -324,7 +324,7 @@ impl HasConcrete for StaticArrayOpDef { }) } _ => Err( - SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(type_args.len(), 1)) + SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(type_args.len(), 1)) .into(), ), } diff --git a/hugr-core/src/std_extensions/collections/value_array.rs b/hugr-core/src/std_extensions/collections/value_array.rs index fe89824d77..947fef9188 100644 --- a/hugr-core/src/std_extensions/collections/value_array.rs +++ b/hugr-core/src/std_extensions/collections/value_array.rs @@ -102,7 +102,7 @@ lazy_static! 
{ Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { extension.add_type( VALUE_ARRAY_TYPENAME, - vec![ TypeParam::max_nat(), TypeBound::Any.into()], + vec![ TypeParam::max_nat_type(), TypeBound::Linear.into()], "Fixed-length value array".into(), // Value arrays are copyable iff their elements are TypeDefBound::from_params(vec![1]), diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index 3955c3a972..74ecf63fc1 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -89,9 +89,7 @@ impl MakeOpDef for PtrOpDef { pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("ptr"); /// Name of pointer type. pub const PTR_TYPE_ID: TypeName = TypeName::new_inline("ptr"); -const TYPE_PARAMS: [TypeParam; 1] = [TypeParam::Type { - b: TypeBound::Copyable, -}]; +const TYPE_PARAMS: [TypeParam; 1] = [TypeParam::RuntimeType(TypeBound::Copyable)]; /// Extension version. pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); @@ -209,7 +207,7 @@ impl HasConcrete for PtrOpDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { let ty = match type_args { - [TypeArg::Type { ty }] => ty.clone(), + [TypeArg::Runtime(ty)] => ty.clone(), _ => return Err(SignatureError::InvalidTypeArgs.into()), }; diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 2b36233133..cdfd1012a8 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -4,7 +4,7 @@ mod check; pub mod custom; mod poly_func; mod row_var; -mod serialize; +pub(crate) mod serialize; mod signature; pub mod type_param; pub mod type_row; @@ -15,14 +15,14 @@ use crate::extension::resolution::{ ExtensionCollectionError, WeakExtensionRegistry, collect_type_exts, }; pub use crate::ops::constant::{ConstTypeError, CustomCheckFailure}; -use crate::types::type_param::check_type_arg; +use crate::types::type_param::check_term_type; use crate::utils::display_list_with_separator; pub use check::SumTypeError; pub use custom::CustomType; pub use poly_func::{PolyFuncType, PolyFuncTypeRV}; pub use signature::{FuncTypeBase, FuncValueType, Signature}; use smol_str::SmolStr; -pub use type_param::TypeArg; +pub use type_param::{Term, TypeArg}; pub use type_row::{TypeRow, TypeRowRV}; pub(crate) use poly_func::PolyFuncTypeBase; @@ -131,9 +131,11 @@ pub enum TypeBound { #[serde(rename = "C", alias = "E")] // alias to read in legacy Eq variants Copyable, /// No bound on the type. + /// + /// It cannot be copied nor discarded. #[serde(rename = "A")] #[default] - Any, + Linear, } impl TypeBound { @@ -152,16 +154,16 @@ impl TypeBound { /// Report if this bound contains another. #[must_use] pub const fn contains(&self, other: TypeBound) -> bool { - use TypeBound::{Any, Copyable}; - matches!((self, other), (Any, _) | (_, Copyable)) + use TypeBound::{Copyable, Linear}; + matches!((self, other), (Linear, _) | (_, Copyable)) } } /// Calculate the least upper bound for an iterator of bounds pub(crate) fn least_upper_bound(mut tags: impl Iterator) -> TypeBound { tags.fold_while(TypeBound::Copyable, |acc, new| { - if acc == TypeBound::Any || new == TypeBound::Any { - Done(TypeBound::Any) + if acc == TypeBound::Linear || new == TypeBound::Linear { + Done(TypeBound::Linear) } else { Continue(acc.union(new)) } @@ -490,7 +492,7 @@ impl TypeBase { /// New use (occurrence) of the type variable with specified index. /// `bound` must be exactly that with which the variable was declared - /// (i.e. as a [`TypeParam::Type`]`(bound)`), which may be narrower + /// (i.e. 
as a [`Term::RuntimeType`]`(bound)`), which may be narrower /// than required for the use. #[must_use] pub const fn new_var_use(idx: usize, bound: TypeBound) -> Self { @@ -575,7 +577,7 @@ impl TypeBase { TypeEnum::RowVar(rv) => rv.substitute(t), TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone()], TypeEnum::Variable(idx, bound) => { - let TypeArg::Type { ty } = t.apply_var(*idx, &((*bound).into())) else { + let TypeArg::Runtime(ty) = t.apply_var(*idx, &((*bound).into())) else { panic!("Variable was not a type - try validate() first") }; vec![ty.into_()] @@ -653,7 +655,7 @@ impl TypeRV { /// New use (occurrence) of the row variable with specified index. /// `bound` must match that with which the variable was declared - /// (i.e. as a [TypeParam::List]` of a `[TypeParam::Type]` of that bound). + /// (i.e. as a list of runtime types of that bound). /// For use in [OpDef], not [FuncDefn], type schemes only. /// /// [OpDef]: crate::extension::OpDef @@ -740,7 +742,7 @@ impl<'a> Substitution<'a> { .0 .get(idx) .expect("Undeclared type variable - call validate() ?"); - debug_assert_eq!(check_type_arg(arg, decl), Ok(())); + debug_assert_eq!(check_term_type(arg, decl), Ok(())); arg.clone() } @@ -749,14 +751,14 @@ impl<'a> Substitution<'a> { .0 .get(idx) .expect("Undeclared type variable - call validate() ?"); - debug_assert!(check_type_arg(arg, &TypeParam::new_list(bound)).is_ok()); + debug_assert!(check_term_type(arg, &TypeParam::new_list_type(bound)).is_ok()); match arg { - TypeArg::Sequence { elems } => elems + TypeArg::List(elems) => elems .iter() .map(|ta| { match ta { - TypeArg::Type { ty } => return ty.clone().into(), - TypeArg::Variable { v } => { + Term::Runtime(ty) => return ty.clone().into(), + Term::Variable(v) => { if let Some(b) = v.bound_if_row_var() { return TypeRV::new_row_var_use(v.index(), b); } @@ -766,7 +768,7 @@ impl<'a> Substitution<'a> { panic!("Not a list of types - call validate() ?") }) .collect(), - TypeArg::Type { ty } if matches!(ty.0, TypeEnum::RowVar(_)) => { + Term::Runtime(ty) if matches!(ty.0, TypeEnum::RowVar(_)) => { // Standalone "Type" can be used iff its actually a Row Variable not an actual (single) Type vec![ty.clone().into()] } @@ -777,11 +779,11 @@ impl<'a> Substitution<'a> { /// A transformation that can be applied to a [Type] or [`TypeArg`]. /// More general in some ways than a Substitution: can fail with a -/// [`Self::Err`], may change [`TypeBound::Copyable`] to [`TypeBound::Any`], +/// [`Self::Err`], may change [`TypeBound::Copyable`] to [`TypeBound::Linear`], /// and applies to arbitrary extension types rather than type variables. pub trait TypeTransformer { /// Error returned when a [`CustomType`] cannot be transformed, or a type - /// containing it (e.g. if changing a [`TypeArg::Type`] from copyable to + /// containing it (e.g. if changing a runtime type from copyable to /// linear invalidates a parameterized type). 
type Err: std::error::Error + From; @@ -839,8 +841,8 @@ pub(crate) fn check_typevar_decl( Ok(()) } else { Err(SignatureError::TypeVarDoesNotMatchDeclaration { - cached: cached_decl.clone(), - actual: actual.clone(), + cached: Box::new(cached_decl.clone()), + actual: Box::new(actual.clone()), }) } } @@ -857,7 +859,7 @@ pub(crate) mod test { use crate::extension::prelude::{option_type, qb_t, usize_t}; use crate::std_extensions::collections::array::{array_type, array_type_parametric}; use crate::std_extensions::collections::list::list_type; - use crate::types::type_param::TypeArgError; + use crate::types::type_param::TermTypeError; use crate::{Extension, hugr::IdentList, type_row}; #[test] @@ -930,7 +932,7 @@ pub(crate) mod test { fn sum_variants() { let variants: Vec = vec![ TypeRV::UNIT.into(), - vec![TypeRV::new_row_var_use(0, TypeBound::Any)].into(), + vec![TypeRV::new_row_var_use(0, TypeBound::Linear)].into(), ]; let t = SumType::new(variants.clone()); assert_eq!(variants, t.variants().cloned().collect_vec()); @@ -977,7 +979,7 @@ pub(crate) mod test { |t| array_type(10, t), |t| { array_type_parametric( - TypeArg::new_var_use(0, TypeParam::bounded_nat(3.try_into().unwrap())), + TypeArg::new_var_use(0, TypeParam::bounded_nat_type(3.try_into().unwrap())), t, ) .unwrap() @@ -1001,7 +1003,7 @@ pub(crate) mod test { .unwrap(); e.add_type( COLN, - vec![TypeParam::new_list(TypeBound::Copyable)], + vec![TypeParam::new_list_type(TypeBound::Copyable)], String::new(), TypeDefBound::copyable(), w, @@ -1020,31 +1022,27 @@ pub(crate) mod test { let coln = e.get_type(&COLN).unwrap(); let c_of_cpy = coln - .instantiate([TypeArg::Sequence { - elems: vec![Type::from(cpy.clone()).into()], - }]) + .instantiate([Term::new_list([Type::from(cpy.clone()).into()])]) .unwrap(); let mut t = Type::new_extension(c_of_cpy.clone()); assert_eq!( t.transform(&cpy_to_qb), - Err(SignatureError::from(TypeArgError::TypeMismatch { - param: TypeBound::Copyable.into(), - arg: qb_t().into() + Err(SignatureError::from(TermTypeError::TypeMismatch { + type_: Box::new(TypeBound::Copyable.into()), + term: Box::new(qb_t().into()) })) ); let mut t = Type::new_extension( - coln.instantiate([TypeArg::Sequence { - elems: vec![mk_opt(Type::from(cpy.clone())).into()], - }]) - .unwrap(), + coln.instantiate([Term::new_list([mk_opt(Type::from(cpy.clone())).into()])]) + .unwrap(), ); assert_eq!( t.transform(&cpy_to_qb), - Err(SignatureError::from(TypeArgError::TypeMismatch { - param: TypeBound::Copyable.into(), - arg: mk_opt(qb_t()).into() + Err(SignatureError::from(TermTypeError::TypeMismatch { + type_: Box::new(TypeBound::Copyable.into()), + term: Box::new(mk_opt(qb_t()).into()) })) ); @@ -1054,19 +1052,15 @@ pub(crate) mod test { (ct == &c_of_cpy).then_some(usize_t()) }); let mut t = Type::new_extension( - coln.instantiate([TypeArg::Sequence { - elems: vec![Type::from(c_of_cpy.clone()).into(); 2], - }]) - .unwrap(), + coln.instantiate([Term::new_list(vec![Type::from(c_of_cpy.clone()).into(); 2])]) + .unwrap(), ); assert_eq!(t.transform(&cpy_to_qb2), Ok(true)); assert_eq!( t, Type::new_extension( - coln.instantiate([TypeArg::Sequence { - elems: vec![usize_t().into(); 2] - }]) - .unwrap() + coln.instantiate([Term::new_list([usize_t().into(), usize_t().into()])]) + .unwrap() ) ); } @@ -1115,3 +1109,82 @@ pub(crate) mod test { } } } + +#[cfg(test)] +pub(super) mod proptest_utils { + use proptest::collection::vec; + use proptest::prelude::{Strategy, any_with}; + + use super::serialize::{TermSer, TypeArgSer, TypeParamSer}; + use 
super::type_param::Term; + + use crate::proptest::RecursionDepth; + use crate::types::serialize::ArrayOrTermSer; + + fn term_is_serde_type_arg(t: &Term) -> bool { + let TermSer::TypeArg(arg) = TermSer::from(t.clone()) else { + return false; + }; + match arg { + TypeArgSer::List { elems: terms } + | TypeArgSer::ListConcat { lists: terms } + | TypeArgSer::Tuple { elems: terms } + | TypeArgSer::TupleConcat { tuples: terms } => terms.iter().all(term_is_serde_type_arg), + TypeArgSer::Variable { v } => term_is_serde_type_param(&v.cached_decl), + TypeArgSer::Type { ty } => { + if let Some(cty) = ty.as_extension() { + cty.args().iter().all(term_is_serde_type_arg) + } else { + true + } + } // Do we need to inspect inside function types? sum types? + TypeArgSer::BoundedNat { .. } + | TypeArgSer::String { .. } + | TypeArgSer::Bytes { .. } + | TypeArgSer::Float { .. } => true, + } + } + + fn term_is_serde_type_param(t: &Term) -> bool { + let TermSer::TypeParam(parm) = TermSer::from(t.clone()) else { + return false; + }; + match parm { + TypeParamSer::Type { .. } + | TypeParamSer::BoundedNat { .. } + | TypeParamSer::String + | TypeParamSer::Bytes + | TypeParamSer::Float + | TypeParamSer::StaticType => true, + TypeParamSer::List { param } => term_is_serde_type_param(¶m), + TypeParamSer::Tuple { params } => { + match ¶ms { + ArrayOrTermSer::Array(terms) => terms.iter().all(term_is_serde_type_param), + ArrayOrTermSer::Term(b) => match &**b { + Term::List(_) => panic!("Should be represented as ArrayOrTermSer::Array"), + // This might be well-typed, but does not fit the (TODO: update) JSON schema + Term::Variable(_) => false, + // Similarly, but not produced by our `impl Arbitrary`: + Term::ListConcat(_) => todo!("Update schema"), + + // The others do not fit the JSON schema, and are not well-typed, + // but can be produced by our impl of Arbitrary, so we must filter out: + _ => false, + }, + } + } + } + } + + pub fn any_serde_type_arg(depth: RecursionDepth) -> impl Strategy { + any_with::(depth).prop_filter("Term was not a TypeArg", term_is_serde_type_arg) + } + + pub fn any_serde_type_arg_vec() -> impl Strategy> { + vec(any_serde_type_arg(RecursionDepth::default()), 1..3) + } + + pub fn any_serde_type_param(depth: RecursionDepth) -> impl Strategy { + any_with::(depth).prop_filter("Term was not a TypeParam", term_is_serde_type_param) + } +} diff --git a/hugr-core/src/types/check.rs b/hugr-core/src/types/check.rs index 2146ee41ba..072da5884e 100644 --- a/hugr-core/src/types/check.rs +++ b/hugr-core/src/types/check.rs @@ -17,9 +17,9 @@ pub enum SumTypeError { /// The element in the tuple that was wrong. index: usize, /// The expected type. - expected: Type, + expected: Box, /// The value that was found. 
- found: Value, + found: Box, }, /// The type of the variant we were trying to convert into contained type variables #[error("Sum variant #{tag} contained a variable #{varidx}")] @@ -88,8 +88,8 @@ impl super::SumType { Err(SumTypeError::InvalidValueType { tag, index, - expected: t.clone(), - found: v.clone(), + expected: Box::new(t.clone()), + found: Box::new(v.clone()), })?; } } diff --git a/hugr-core/src/types/custom.rs b/hugr-core/src/types/custom.rs index 02ab188338..248e0f6253 100644 --- a/hugr-core/src/types/custom.rs +++ b/hugr-core/src/types/custom.rs @@ -188,7 +188,7 @@ mod test { use crate::extension::ExtensionId; use crate::proptest::RecursionDepth; use crate::proptest::any_nonempty_string; - use crate::types::type_param::TypeArg; + use crate::types::proptest_utils::any_serde_type_arg; use crate::types::{CustomType, TypeBound}; use ::proptest::collection::vec; use ::proptest::prelude::*; @@ -224,7 +224,7 @@ mod test { Just(vec![]).boxed() } else { // a TypeArg may contain a CustomType, so we descend here - vec(any_with::(depth.descend()), 0..3).boxed() + vec(any_serde_type_arg(depth.descend()), 0..3).boxed() }; (any_nonempty_string(), args, any::(), bound) .prop_map(|(id, args, extension, bound)| { diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 8121741bf8..0de6b1b029 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -7,13 +7,14 @@ use itertools::Itertools; use crate::extension::SignatureError; #[cfg(test)] use { + super::proptest_utils::any_serde_type_param, crate::proptest::RecursionDepth, ::proptest::{collection::vec, prelude::*}, proptest_derive::Arbitrary, }; use super::Substitution; -use super::type_param::{TypeArg, TypeParam, check_type_args}; +use super::type_param::{TypeArg, TypeParam, check_term_types}; use super::{MaybeRV, NoRV, RowVariable, signature::FuncTypeBase}; /// A polymorphic type scheme, i.e. of a [`FuncDecl`], [`FuncDefn`] or [`OpDef`]. @@ -31,7 +32,7 @@ pub struct PolyFuncTypeBase { /// The declared type parameters, i.e., these must be instantiated with /// the same number of [`TypeArg`]s before the function can be called. This /// defines the indices used by variables inside the body. - #[cfg_attr(test, proptest(strategy = "vec(any_with::(params), 0..3)"))] + #[cfg_attr(test, proptest(strategy = "vec(any_serde_type_param(params), 0..3)"))] params: Vec, /// Template for the function. May contain variables up to length of [`Self::params`] #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] @@ -122,7 +123,7 @@ impl PolyFuncTypeBase { pub fn instantiate(&self, args: &[TypeArg]) -> Result, SignatureError> { // Check that args are applicable, and that we have a value for each binder, // i.e. each possible free variable within the body. 
- check_type_args(args, &self.params)?; + check_term_types(args, &self.params)?; Ok(self.body.substitute(&Substitution(args))) } @@ -166,9 +167,9 @@ pub(crate) mod test { use crate::std_extensions::collections::array::{self, array_type_parametric}; use crate::std_extensions::collections::list; use crate::types::signature::FuncTypeBase; - use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; + use crate::types::type_param::{TermTypeError, TypeArg, TypeParam}; use crate::types::{ - CustomType, FuncValueType, MaybeRV, Signature, Type, TypeBound, TypeName, TypeRV, + CustomType, FuncValueType, MaybeRV, Signature, Term, Type, TypeBound, TypeName, TypeRV, }; use super::PolyFuncTypeBase; @@ -192,21 +193,19 @@ pub(crate) mod test { #[test] fn test_opaque() -> Result<(), SignatureError> { let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); - let tyvar = TypeArg::new_var_use(0, TypeBound::Any.into()); + let tyvar = TypeArg::new_var_use(0, TypeBound::Linear.into()); let list_of_var = Type::new_extension(list_def.instantiate([tyvar.clone()])?); let list_len = PolyFuncTypeBase::new_validated( - [TypeBound::Any.into()], + [TypeBound::Linear.into()], Signature::new(vec![list_of_var], vec![usize_t()]), )?; - let t = list_len.instantiate(&[TypeArg::Type { ty: usize_t() }])?; + let t = list_len.instantiate(&[usize_t().into()])?; assert_eq!( t, Signature::new( vec![Type::new_extension( - list_def - .instantiate([TypeArg::Type { ty: usize_t() }]) - .unwrap() + list_def.instantiate([usize_t().into()]).unwrap() )], vec![usize_t()] ) @@ -217,9 +216,9 @@ pub(crate) mod test { #[test] fn test_mismatched_args() -> Result<(), SignatureError> { - let size_var = TypeArg::new_var_use(0, TypeParam::max_nat()); - let ty_var = TypeArg::new_var_use(1, TypeBound::Any.into()); - let type_params = [TypeParam::max_nat(), TypeBound::Any.into()]; + let size_var = TypeArg::new_var_use(0, TypeParam::max_nat_type()); + let ty_var = TypeArg::new_var_use(1, TypeBound::Linear.into()); + let type_params = [TypeParam::max_nat_type(), TypeBound::Linear.into()]; // Valid schema... 
let good_array = array_type_parametric(size_var.clone(), ty_var.clone())?; @@ -227,29 +226,23 @@ pub(crate) mod test { PolyFuncTypeBase::new_validated(type_params.clone(), Signature::new_endo(good_array))?; // Sanity check (good args) - good_ts.instantiate(&[ - TypeArg::BoundedNat { n: 5 }, - TypeArg::Type { ty: usize_t() }, - ])?; - - let wrong_args = good_ts.instantiate(&[ - TypeArg::Type { ty: usize_t() }, - TypeArg::BoundedNat { n: 5 }, - ]); + good_ts.instantiate(&[5u64.into(), usize_t().into()])?; + + let wrong_args = good_ts.instantiate(&[usize_t().into(), 5u64.into()]); assert_eq!( wrong_args, Err(SignatureError::TypeArgMismatch( - TypeArgError::TypeMismatch { - param: type_params[0].clone(), - arg: TypeArg::Type { ty: usize_t() } + TermTypeError::TypeMismatch { + type_: Box::new(type_params[0].clone()), + term: Box::new(usize_t().into()), } )) ); // (Try to) make a schema with the args in the wrong order - let arg_err = SignatureError::TypeArgMismatch(TypeArgError::TypeMismatch { - param: type_params[0].clone(), - arg: ty_var.clone(), + let arg_err = SignatureError::TypeArgMismatch(TermTypeError::TypeMismatch { + type_: Box::new(type_params[0].clone()), + term: Box::new(ty_var.clone()), }); assert_eq!( array_type_parametric(ty_var.clone(), size_var.clone()), @@ -260,7 +253,7 @@ pub(crate) mod test { "array", [ty_var, size_var], array::EXTENSION_ID, - TypeBound::Any, + TypeBound::Linear, &Arc::downgrade(&array::EXTENSION), )); let bad_ts = @@ -277,20 +270,16 @@ pub(crate) mod test { let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); let body_type = Signature::new_endo(Type::new_extension(list_def.instantiate([tv])?)); for decl in [ - TypeParam::List { - param: Box::new(TypeParam::max_nat()), - }, - TypeParam::String, - TypeParam::Tuple { - params: vec![TypeBound::Any.into(), TypeParam::max_nat()], - }, + Term::new_list_type(Term::max_nat_type()), + Term::StringType, + Term::new_tuple_type([TypeBound::Linear.into(), Term::max_nat_type()]), ] { let invalid_ts = PolyFuncTypeBase::new_validated([decl.clone()], body_type.clone()); assert_eq!( invalid_ts.err(), Some(SignatureError::TypeVarDoesNotMatchDeclaration { - cached: TypeBound::Copyable.into(), - actual: decl + cached: Box::new(TypeBound::Copyable.into()), + actual: Box::new(decl) }) ); } @@ -336,7 +325,7 @@ pub(crate) mod test { TYPE_NAME, [TypeArg::new_var_use(0, tp)], EXT_ID, - TypeBound::Any, + TypeBound::Linear, &Arc::downgrade(&ext), ))), ) @@ -348,9 +337,9 @@ pub(crate) mod test { assert_eq!( make_scheme(decl.clone()).err(), Some(SignatureError::TypeArgMismatch( - TypeArgError::TypeMismatch { - param: bound.clone(), - arg: TypeArg::new_var_use(0, decl.clone()) + TermTypeError::TypeMismatch { + type_: Box::new(bound.clone()), + term: Box::new(TypeArg::new_var_use(0, decl.clone())) } )) ); @@ -363,38 +352,33 @@ pub(crate) mod test { decl_accepts_rejects_var( TypeBound::Copyable.into(), &[TypeBound::Copyable.into()], - &[TypeBound::Any.into()], + &[TypeBound::Linear.into()], )?; - let list_of_tys = |b: TypeBound| TypeParam::List { - param: Box::new(b.into()), - }; decl_accepts_rejects_var( - list_of_tys(TypeBound::Copyable), - &[list_of_tys(TypeBound::Copyable)], - &[list_of_tys(TypeBound::Any)], + Term::new_list_type(TypeBound::Copyable), + &[Term::new_list_type(TypeBound::Copyable)], + &[Term::new_list_type(TypeBound::Linear)], )?; decl_accepts_rejects_var( - TypeParam::max_nat(), - &[TypeParam::bounded_nat(NonZeroU64::new(5).unwrap())], + TypeParam::max_nat_type(), + 
&[TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap())], &[], )?; decl_accepts_rejects_var( - TypeParam::bounded_nat(NonZeroU64::new(10).unwrap()), - &[TypeParam::bounded_nat(NonZeroU64::new(5).unwrap())], - &[TypeParam::max_nat()], + TypeParam::bounded_nat_type(NonZeroU64::new(10).unwrap()), + &[TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap())], + &[TypeParam::max_nat_type()], )?; Ok(()) } - const TP_ANY: TypeParam = TypeParam::Type { b: TypeBound::Any }; + const TP_ANY: TypeParam = TypeParam::RuntimeType(TypeBound::Linear); #[test] fn row_variables_bad_schema() { // Mismatched TypeBound (Copyable vs Any) - let decl = TypeParam::List { - param: Box::new(TP_ANY), - }; + let decl = Term::new_list_type(TP_ANY); let e = PolyFuncTypeBase::new_validated( [decl.clone()], FuncValueType::new( @@ -404,26 +388,26 @@ pub(crate) mod test { ) .unwrap_err(); assert_matches!(e, SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached } => { - assert_eq!(actual, decl); - assert_eq!(cached, TypeParam::List {param: Box::new(TypeParam::Type {b: TypeBound::Copyable})}); + assert_eq!(*actual, decl); + assert_eq!(*cached, TypeParam::new_list_type(TypeBound::Copyable)); }); // Declared as row variable, used as type variable let e = PolyFuncTypeBase::new_validated( [decl.clone()], - Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Linear)]), ) .unwrap_err(); assert_matches!(e, SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached } => { - assert_eq!(actual, decl); - assert_eq!(cached, TP_ANY); + assert_eq!(*actual, decl); + assert_eq!(*cached, TP_ANY); }); } #[test] fn row_variables() { - let rty = TypeRV::new_row_var_use(0, TypeBound::Any); + let rty = TypeRV::new_row_var_use(0, TypeBound::Linear); let pf = PolyFuncTypeBase::new_validated( - [TypeParam::new_list(TP_ANY)], + [TypeParam::new_list_type(TP_ANY)], FuncValueType::new( vec![usize_t().into(), rty.clone()], vec![TypeRV::new_tuple(rty)], @@ -434,16 +418,11 @@ pub(crate) mod test { fn seq2() -> Vec { vec![usize_t().into(), bool_t().into()] } - pf.instantiate(&[TypeArg::Type { ty: usize_t() }]) + pf.instantiate(&[usize_t().into()]).unwrap_err(); + pf.instantiate(&[Term::new_list([usize_t().into(), Term::new_list(seq2())])]) .unwrap_err(); - pf.instantiate(&[TypeArg::Sequence { - elems: vec![usize_t().into(), TypeArg::Sequence { elems: seq2() }], - }]) - .unwrap_err(); - let t2 = pf - .instantiate(&[TypeArg::Sequence { elems: seq2() }]) - .unwrap(); + let t2 = pf.instantiate(&[Term::new_list(seq2())]).unwrap(); assert_eq!( t2, Signature::new( @@ -460,20 +439,18 @@ pub(crate) mod test { TypeBound::Copyable, ))); let pf = PolyFuncTypeBase::new_validated( - [TypeParam::List { - param: Box::new(TypeParam::Type { - b: TypeBound::Copyable, - }), - }], + [Term::new_list_type(TypeBound::Copyable)], Signature::new(vec![usize_t(), inner_fty.clone()], vec![inner_fty]), ) .unwrap(); let inner3 = Type::new_function(Signature::new_endo(vec![usize_t(), bool_t(), usize_t()])); let t3 = pf - .instantiate(&[TypeArg::Sequence { - elems: vec![usize_t().into(), bool_t().into(), usize_t().into()], - }]) + .instantiate(&[Term::new_list([ + usize_t().into(), + bool_t().into(), + usize_t().into(), + ])]) .unwrap(); assert_eq!( t3, diff --git a/hugr-core/src/types/row_var.rs b/hugr-core/src/types/row_var.rs index 106870003b..086ab7b076 100644 --- a/hugr-core/src/types/row_var.rs +++ b/hugr-core/src/types/row_var.rs @@ -6,7 +6,7 @@ use crate::extension::SignatureError; #[cfg(test)] 
use proptest::prelude::{BoxedStrategy, Strategy, any}; -/// Describes a row variable - a type variable bound with a [`TypeParam::List`] of [`TypeParam::Type`] +/// Describes a row variable - a type variable bound with a list of runtime types /// of the specified bound (checked in validation) // The serde derives here are not used except as markers // so that other types containing this can also #derive-serde the same way. @@ -70,7 +70,7 @@ impl MaybeRV for RowVariable { } fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { - check_typevar_decl(var_decls, self.0, &TypeParam::new_list(self.1)) + check_typevar_decl(var_decls, self.0, &TypeParam::new_list_type(self.1)) } #[allow(private_interfaces)] diff --git a/hugr-core/src/types/serialize.rs b/hugr-core/src/types/serialize.rs index 198c0c1eda..c0a35dfd5e 100644 --- a/hugr-core/src/types/serialize.rs +++ b/hugr-core/src/types/serialize.rs @@ -1,3 +1,7 @@ +use std::sync::Arc; + +use ordered_float::OrderedFloat; + use super::{FuncValueType, MaybeRV, RowVariable, SumType, TypeBase, TypeBound, TypeEnum}; use super::custom::CustomType; @@ -5,10 +9,12 @@ use super::custom::CustomType; use crate::extension::SignatureError; use crate::extension::prelude::{qb_t, usize_t}; use crate::ops::AliasDecl; +use crate::types::type_param::{TermVar, UpperBound}; +use crate::types::{Term, Type}; #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] #[serde(tag = "t")] -pub(super) enum SerSimpleType { +pub(crate) enum SerSimpleType { Q, I, G(Box), @@ -60,3 +66,167 @@ impl TryFrom for TypeBase { }) } } + +#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] +#[non_exhaustive] +#[serde(tag = "tp")] +pub(super) enum TypeParamSer { + Type { b: TypeBound }, + BoundedNat { bound: UpperBound }, + String, + Bytes, + Float, + StaticType, + List { param: Box }, + Tuple { params: ArrayOrTermSer }, +} + +#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] +#[non_exhaustive] +#[serde(tag = "tya")] +pub(super) enum TypeArgSer { + Type { + ty: Type, + }, + BoundedNat { + n: u64, + }, + String { + arg: String, + }, + Bytes { + #[serde(with = "base64")] + value: Arc<[u8]>, + }, + Float { + value: OrderedFloat, + }, + List { + elems: Vec, + }, + ListConcat { + lists: Vec, + }, + Tuple { + elems: Vec, + }, + TupleConcat { + tuples: Vec, + }, + Variable { + #[serde(flatten)] + v: TermVar, + }, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(untagged)] +pub(super) enum TermSer { + TypeArg(TypeArgSer), + TypeParam(TypeParamSer), +} + +impl From for TermSer { + fn from(value: Term) -> Self { + match value { + Term::RuntimeType(b) => TermSer::TypeParam(TypeParamSer::Type { b }), + Term::StaticType => TermSer::TypeParam(TypeParamSer::StaticType), + Term::BoundedNatType(bound) => TermSer::TypeParam(TypeParamSer::BoundedNat { bound }), + Term::StringType => TermSer::TypeParam(TypeParamSer::String), + Term::BytesType => TermSer::TypeParam(TypeParamSer::Bytes), + Term::FloatType => TermSer::TypeParam(TypeParamSer::Float), + Term::ListType(param) => TermSer::TypeParam(TypeParamSer::List { param }), + Term::Runtime(ty) => TermSer::TypeArg(TypeArgSer::Type { ty }), + Term::TupleType(params) => TermSer::TypeParam(TypeParamSer::Tuple { + params: (*params).into(), + }), + Term::BoundedNat(n) => TermSer::TypeArg(TypeArgSer::BoundedNat { n }), + Term::String(arg) => TermSer::TypeArg(TypeArgSer::String { arg }), + Term::Bytes(value) => TermSer::TypeArg(TypeArgSer::Bytes { value }), + Term::Float(value) => 
TermSer::TypeArg(TypeArgSer::Float { value }), + Term::List(elems) => TermSer::TypeArg(TypeArgSer::List { elems }), + Term::Tuple(elems) => TermSer::TypeArg(TypeArgSer::Tuple { elems }), + Term::Variable(v) => TermSer::TypeArg(TypeArgSer::Variable { v }), + Term::ListConcat(lists) => TermSer::TypeArg(TypeArgSer::ListConcat { lists }), + Term::TupleConcat(tuples) => TermSer::TypeArg(TypeArgSer::TupleConcat { tuples }), + } + } +} + +impl From for Term { + fn from(value: TermSer) -> Self { + match value { + TermSer::TypeParam(param) => match param { + TypeParamSer::Type { b } => Term::RuntimeType(b), + TypeParamSer::StaticType => Term::StaticType, + TypeParamSer::BoundedNat { bound } => Term::BoundedNatType(bound), + TypeParamSer::String => Term::StringType, + TypeParamSer::Bytes => Term::BytesType, + TypeParamSer::Float => Term::FloatType, + TypeParamSer::List { param } => Term::ListType(param), + TypeParamSer::Tuple { params } => Term::TupleType(Box::new(params.into())), + }, + TermSer::TypeArg(arg) => match arg { + TypeArgSer::Type { ty } => Term::Runtime(ty), + TypeArgSer::BoundedNat { n } => Term::BoundedNat(n), + TypeArgSer::String { arg } => Term::String(arg), + TypeArgSer::Bytes { value } => Term::Bytes(value), + TypeArgSer::Float { value } => Term::Float(value), + TypeArgSer::List { elems } => Term::List(elems), + TypeArgSer::Tuple { elems } => Term::Tuple(elems), + TypeArgSer::Variable { v } => Term::Variable(v), + TypeArgSer::ListConcat { lists } => Term::ListConcat(lists), + TypeArgSer::TupleConcat { tuples } => Term::TupleConcat(tuples), + }, + } + } +} + +/// Helper type that serialises lists as JSON arrays for compatibility. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(untagged)] +pub(super) enum ArrayOrTermSer { + Array(Vec), + Term(Box), // TODO JSON Schema does not really support this yet +} + +impl From for Term { + fn from(value: ArrayOrTermSer) -> Self { + match value { + ArrayOrTermSer::Array(terms) => Term::new_list(terms), + ArrayOrTermSer::Term(term) => *term, + } + } +} + +impl From for ArrayOrTermSer { + fn from(term: Term) -> Self { + match term { + Term::List(terms) => ArrayOrTermSer::Array(terms), + term => ArrayOrTermSer::Term(Box::new(term)), + } + } +} + +/// Helper for to serialize and deserialize the byte string in [`TypeArg::Bytes`] via base64. +mod base64 { + use std::sync::Arc; + + use base64::Engine as _; + use base64::prelude::BASE64_STANDARD; + use serde::{Deserialize, Serialize}; + use serde::{Deserializer, Serializer}; + + pub fn serialize(v: &Arc<[u8]>, s: S) -> Result { + let base64 = BASE64_STANDARD.encode(v); + base64.serialize(s) + } + + pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { + let base64 = String::deserialize(d)?; + BASE64_STANDARD + .decode(base64.as_bytes()) + .map(|v| v.into()) + .map_err(serde::de::Error::custom) + } +} diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index 0b6d19fa7d..a1cbe4bcea 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -4,11 +4,15 @@ //! //! [`TypeDef`]: crate::extension::TypeDef -use itertools::Itertools; +use ordered_float::OrderedFloat; #[cfg(test)] use proptest_derive::Arbitrary; +use smallvec::{SmallVec, smallvec}; +use std::iter::FusedIterator; use std::num::NonZeroU64; +use std::sync::Arc; use thiserror::Error; +use tracing::warn; use super::row_var::MaybeRV; use super::{ @@ -48,242 +52,286 @@ impl UpperBound { } } -/// A *kind* of [`TypeArg`]. 
Thus, a parameter declared by a [`PolyFuncType`] or [`PolyFuncTypeRV`], -/// specifying a value that must be provided statically in order to instantiate it. -/// -/// [`PolyFuncType`]: super::PolyFuncType -/// [`PolyFuncTypeRV`]: super::PolyFuncTypeRV +/// A [`Term`] that is a static argument to an operation or constructor. +pub type TypeArg = Term; + +/// A [`Term`] that is the static type of an operation or constructor parameter. +pub type TypeParam = Term; + +/// A term in the language of static parameters in HUGR. #[derive( Clone, Debug, PartialEq, Eq, Hash, derive_more::Display, serde::Deserialize, serde::Serialize, )] #[non_exhaustive] -#[serde(tag = "tp")] -pub enum TypeParam { - /// Argument is a [`TypeArg::Type`]. - #[display("Type{}", match b { - TypeBound::Any => String::new(), - _ => format!("[{b}]") +#[serde( + from = "crate::types::serialize::TermSer", + into = "crate::types::serialize::TermSer" +)] +pub enum Term { + /// The type of runtime types. + #[display("Type{}", match _0 { + TypeBound::Linear => String::new(), + _ => format!("[{_0}]") })] - Type { - /// Bound for the type parameter. - b: TypeBound, - }, - /// Argument is a [`TypeArg::BoundedNat`] that is less than the upper bound. - #[display("{}", match bound.value() { + RuntimeType(TypeBound), + /// The type of static data. + StaticType, + /// The type of static natural numbers up to a given bound. + #[display("{}", match _0.value() { Some(v) => format!("BoundedNat[{v}]"), None => "Nat".to_string() })] - BoundedNat { - /// Upper bound for the Nat parameter. - bound: UpperBound, - }, - /// Argument is a [`TypeArg::String`]. - String, - /// Argument is a [`TypeArg::Sequence`]. A list of indeterminate size containing - /// parameters all of the (same) specified element type. - #[display("List[{param}]")] - List { - /// The [`TypeParam`] describing each element of the list. - param: Box, - }, - /// Argument is a [`TypeArg::Sequence`]. A tuple of parameters. - #[display("Tuple[{}]", params.iter().map(std::string::ToString::to_string).join(", "))] - Tuple { - /// The [`TypeParam`]s contained in the tuple. - params: Vec, - }, + BoundedNatType(UpperBound), + /// The type of static strings. See [`Term::String`]. + StringType, + /// The type of static byte strings. See [`Term::Bytes`]. + BytesType, + /// The type of static floating point numbers. See [`Term::Float`]. + FloatType, + /// The type of static lists of indeterminate size containing terms of the + /// specified static type. + #[display("ListType[{_0}]")] + ListType(Box), + /// The type of static tuples. + #[display("TupleType[{_0}]")] + TupleType(Box), + /// A runtime type as a term. Instance of [`Term::RuntimeType`]. + #[display("{_0}")] + Runtime(Type), + /// A 64bit unsigned integer literal. Instance of [`Term::BoundedNatType`]. + #[display("{_0}")] + BoundedNat(u64), + /// UTF-8 encoded string literal. Instance of [`Term::StringType`]. + #[display("\"{_0}\"")] + String(String), + /// Byte string literal. Instance of [`Term::BytesType`]. + #[display("bytes")] + Bytes(Arc<[u8]>), + /// A 64-bit floating point number. Instance of [`Term::FloatType`]. + #[display("{}", _0.into_inner())] + Float(OrderedFloat), + /// A list of static terms. Instance of [`Term::ListType`]. + #[display("[{}]", { + use itertools::Itertools as _; + _0.iter().map(|t|t.to_string()).join(",") + })] + List(Vec), + /// Instance of [`TypeParam::List`] defined by a sequence of concatenated lists of the same type. 
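+    ///
+    /// For example (a sketch built only from the `Term` constructors defined
+    /// in this module), a list variable spliced between two closed lists:
+    ///
+    /// ```
+    /// # use hugr_core::types::type_param::Term;
+    /// let var = Term::new_var_use(0, Term::new_list_type(Term::StringType));
+    /// let concat = Term::new_list_concat([
+    ///     Term::new_list([Term::new_string("a")]),
+    ///     var,
+    ///     Term::new_list([Term::new_string("b")]),
+    /// ]);
+    /// ```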
+ #[display("[{}]", { + use itertools::Itertools as _; + _0.iter().map(|t| format!("... {t}")).join(",") + })] + ListConcat(Vec), + /// Instance of [`TypeParam::Tuple`] defined by a sequence of elements of varying type. + #[display("({})", { + use itertools::Itertools as _; + _0.iter().map(std::string::ToString::to_string).join(",") + })] + Tuple(Vec), + /// Instance of [`TypeParam::Tuple`] defined by a sequence of concatenated tuples. + #[display("({})", { + use itertools::Itertools as _; + _0.iter().map(|tuple| format!("... {tuple}")).join(",") + })] + TupleConcat(Vec), + /// Variable (used in type schemes or inside polymorphic functions), + /// but not a runtime type (not even a row variable i.e. list of runtime types) + /// - see [`Term::new_var_use`] + #[display("{_0}")] + Variable(TermVar), } -impl TypeParam { - /// [`TypeParam::BoundedNat`] with the maximum bound (`u64::MAX` + 1) +impl Term { + /// Creates a [`Term::BoundedNatType`] with the maximum bound (`u64::MAX` + 1). #[must_use] - pub const fn max_nat() -> Self { - Self::BoundedNat { - bound: UpperBound(None), - } + pub const fn max_nat_type() -> Self { + Self::BoundedNatType(UpperBound(None)) } - /// [`TypeParam::BoundedNat`] with the stated upper bound (non-exclusive) + /// Creates a [`Term::BoundedNatType`] with the stated upper bound (non-exclusive). #[must_use] - pub const fn bounded_nat(upper_bound: NonZeroU64) -> Self { - Self::BoundedNat { - bound: UpperBound(Some(upper_bound)), - } + pub const fn bounded_nat_type(upper_bound: NonZeroU64) -> Self { + Self::BoundedNatType(UpperBound(Some(upper_bound))) } - /// Make a new `TypeParam::List` (an arbitrary-length homogeneous list) - pub fn new_list(elem: impl Into) -> Self { - Self::List { - param: Box::new(elem.into()), - } + /// Creates a new [`Term::List`] given a sequence of its items. + pub fn new_list(items: impl IntoIterator) -> Self { + Self::List(items.into_iter().collect()) + } + + /// Creates a new [`Term::ListType`] given the type of its elements. + pub fn new_list_type(elem: impl Into) -> Self { + Self::ListType(Box::new(elem.into())) + } + + /// Creates a new [`Term::TupleType`] given the type of its elements. + pub fn new_tuple_type(item_types: impl Into) -> Self { + Self::TupleType(Box::new(item_types.into())) } - fn contains(&self, other: &TypeParam) -> bool { + /// Checks if this term is a supertype of another. + /// + /// The subtyping relation applies primarily to terms that represent static + /// types. For consistency the relation is extended to a partial order on + /// all terms; in particular it is reflexive so that every term (even if it + /// is not a static type) is considered a subtype of itself. 
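+    ///
+    /// A sketch of the intended ordering (the method is private, so this
+    /// snippet is illustrative rather than a runnable doctest):
+    ///
+    /// ```ignore
+    /// // The linear bound admits every copyable type...
+    /// assert!(Term::RuntimeType(TypeBound::Linear)
+    ///     .is_supertype(&Term::RuntimeType(TypeBound::Copyable)));
+    /// // ...while literal terms are related only to themselves.
+    /// assert!(Term::BoundedNat(5).is_supertype(&Term::BoundedNat(5)));
+    /// assert!(!Term::BoundedNat(5).is_supertype(&Term::BoundedNat(6)));
+    /// ```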
+ fn is_supertype(&self, other: &Term) -> bool { match (self, other) { - (TypeParam::Type { b: b1 }, TypeParam::Type { b: b2 }) => b1.contains(*b2), - (TypeParam::BoundedNat { bound: b1 }, TypeParam::BoundedNat { bound: b2 }) => { - b1.contains(b2) + (Term::RuntimeType(b1), Term::RuntimeType(b2)) => b1.contains(*b2), + (Term::BoundedNatType(b1), Term::BoundedNatType(b2)) => b1.contains(b2), + (Term::StringType, Term::StringType) => true, + (Term::StaticType, Term::StaticType) => true, + (Term::ListType(e1), Term::ListType(e2)) => e1.is_supertype(e2), + (Term::TupleType(es1), Term::TupleType(es2)) => es1.is_supertype(es2), + (Term::BytesType, Term::BytesType) => true, + (Term::FloatType, Term::FloatType) => true, + (Term::Runtime(t1), Term::Runtime(t2)) => t1 == t2, + (Term::BoundedNat(n1), Term::BoundedNat(n2)) => n1 == n2, + (Term::String(s1), Term::String(s2)) => s1 == s2, + (Term::Bytes(v1), Term::Bytes(v2)) => v1 == v2, + (Term::Float(f1), Term::Float(f2)) => f1 == f2, + (Term::Variable(v1), Term::Variable(v2)) => v1 == v2, + (Term::List(es1), Term::List(es2)) => { + es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) } - (TypeParam::String, TypeParam::String) => true, - (TypeParam::List { param: e1 }, TypeParam::List { param: e2 }) => e1.contains(e2), - (TypeParam::Tuple { params: es1 }, TypeParam::Tuple { params: es2 }) => { - es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.contains(e2)) + (Term::Tuple(es1), Term::Tuple(es2)) => { + es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) } _ => false, } } } -impl From for TypeParam { +impl From for Term { fn from(bound: TypeBound) -> Self { - Self::Type { b: bound } + Self::RuntimeType(bound) } } -impl From for TypeParam { +impl From for Term { fn from(bound: UpperBound) -> Self { - Self::BoundedNat { bound } + Self::BoundedNatType(bound) } } -/// A statically-known argument value to an operation. -#[derive( - Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize, derive_more::Display, -)] -#[non_exhaustive] -#[serde(tag = "tya")] -pub enum TypeArg { - /// Where the (Type/Op)Def declares that an argument is a [`TypeParam::Type`] - #[display("{ty}")] - Type { - /// The concrete type for the parameter. - ty: Type, - }, - /// Instance of [`TypeParam::BoundedNat`]. 64-bit unsigned integer. - #[display("{n}")] - BoundedNat { - /// The integer value for the parameter. - n: u64, - }, - ///Instance of [`TypeParam::String`]. UTF-8 encoded string argument. - #[display("\"{arg}\"")] - String { - /// The string value for the parameter. - arg: String, - }, - /// Instance of [`TypeParam::List`] or [`TypeParam::Tuple`], defined by a - /// sequence of elements. - #[display("({})", { - use itertools::Itertools as _; - elems.iter().map(std::string::ToString::to_string).join(",") - })] - Sequence { - /// List of element types - elems: Vec, - }, - /// Variable (used in type schemes or inside polymorphic functions), - /// but not a [`TypeArg::Type`] (not even a row variable i.e. 
[`TypeParam::List`] of type) - /// - see [`TypeArg::new_var_use`] - #[display("{v}")] - Variable { - #[allow(missing_docs)] - #[serde(flatten)] - v: TypeArgVariable, - }, -} - -impl From> for TypeArg { +impl From> for Term { fn from(value: TypeBase) -> Self { match value.try_into_type() { - Ok(ty) => TypeArg::Type { ty }, - Err(RowVariable(idx, bound)) => TypeArg::new_var_use(idx, TypeParam::new_list(bound)), + Ok(ty) => Term::Runtime(ty), + Err(RowVariable(idx, bound)) => Term::new_var_use(idx, TypeParam::new_list_type(bound)), } } } -impl From for TypeArg { +impl From for Term { fn from(n: u64) -> Self { - Self::BoundedNat { n } + Self::BoundedNat(n) } } -impl From for TypeArg { +impl From for Term { fn from(arg: String) -> Self { - TypeArg::String { arg } + Term::String(arg) } } -impl From<&str> for TypeArg { +impl From<&str> for Term { fn from(arg: &str) -> Self { - TypeArg::String { - arg: arg.to_string(), - } + Term::String(arg.to_string()) } } -impl From> for TypeArg { - fn from(elems: Vec) -> Self { - Self::Sequence { elems } +impl From> for Term { + fn from(elems: Vec) -> Self { + Self::new_list(elems) } } -/// Variable in a `TypeArg`, that is not a single [`TypeArg::Type`] (i.e. not a [`Type::new_var_use`] +impl From<[Term; N]> for Term { + fn from(value: [Term; N]) -> Self { + Self::new_list(value) + } +} + +/// Variable in a [`Term`], that is not a single runtime type (i.e. not a [`Type::new_var_use`] /// - it might be a [`Type::new_row_var_use`]). #[derive( Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize, derive_more::Display, )] #[display("#{idx}")] -pub struct TypeArgVariable { +pub struct TermVar { idx: usize, - cached_decl: TypeParam, + pub(in crate::types) cached_decl: Box, } -impl TypeArg { - /// [`Type::UNIT`] as a [`TypeArg::Type`] - pub const UNIT: Self = Self::Type { ty: Type::UNIT }; +impl Term { + /// [`Type::UNIT`] as a [`Term::Runtime`] + pub const UNIT: Self = Self::Runtime(Type::UNIT); /// Makes a `TypeArg` representing a use (occurrence) of the type variable /// with the specified index. /// `decl` must be exactly that with which the variable was declared. #[must_use] - pub fn new_var_use(idx: usize, decl: TypeParam) -> Self { + pub fn new_var_use(idx: usize, decl: Term) -> Self { match decl { // Note a TypeParam::List of TypeParam::Type *cannot* be represented // as a TypeArg::Type because the latter stores a Type i.e. only a single type, // not a RowVariable. - TypeParam::Type { b } => Type::new_var_use(idx, b).into(), - _ => TypeArg::Variable { - v: TypeArgVariable { - idx, - cached_decl: decl, - }, - }, + Term::RuntimeType(b) => Type::new_var_use(idx, b).into(), + _ => Term::Variable(TermVar { + idx, + cached_decl: Box::new(decl), + }), } } - /// Returns an integer if the `TypeArg` is an instance of `BoundedNat`. + /// Creates a new string literal. + #[inline] + pub fn new_string(str: impl ToString) -> Self { + Self::String(str.to_string()) + } + + /// Creates a new concatenated list. + #[inline] + pub fn new_list_concat(lists: impl IntoIterator) -> Self { + Self::ListConcat(lists.into_iter().collect()) + } + + /// Creates a new tuple from its items. + #[inline] + pub fn new_tuple(items: impl IntoIterator) -> Self { + Self::Tuple(items.into_iter().collect()) + } + + /// Creates a new concatenated tuple. + #[inline] + pub fn new_tuple_concat(tuples: impl IntoIterator) -> Self { + Self::TupleConcat(tuples.into_iter().collect()) + } + + /// Returns an integer if the [`Term`] is a natural number literal. 
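+    ///
+    /// A small example (the import path follows the other doc examples in
+    /// this module):
+    ///
+    /// ```
+    /// # use hugr_core::types::type_param::Term;
+    /// assert_eq!(Term::BoundedNat(7).as_nat(), Some(7));
+    /// assert_eq!(Term::new_string("seven").as_nat(), None);
+    /// ```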
#[must_use] pub fn as_nat(&self) -> Option { match self { - TypeArg::BoundedNat { n } => Some(*n), + TypeArg::BoundedNat(n) => Some(*n), _ => None, } } - /// Returns a type if the `TypeArg` is an instance of Type. + /// Returns a [`Type`] if the [`Term`] is a runtime type. #[must_use] - pub fn as_type(&self) -> Option> { + pub fn as_runtime(&self) -> Option> { match self { - TypeArg::Type { ty } => Some(ty.clone()), + TypeArg::Runtime(ty) => Some(ty.clone()), _ => None, } } - /// Returns a string if the `TypeArg` is an instance of String. + /// Returns a string if the [`Term`] is a string literal. #[must_use] pub fn as_string(&self) -> Option { match self { - TypeArg::String { arg } => Some(arg.clone()), + TypeArg::String(arg) => Some(arg.clone()), _ => None, } } @@ -292,75 +340,264 @@ impl TypeArg { /// is valid and closed. pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { match self { - TypeArg::Type { ty } => ty.validate(var_decls), - TypeArg::BoundedNat { .. } | TypeArg::String { .. } => Ok(()), - TypeArg::Sequence { elems } => elems.iter().try_for_each(|a| a.validate(var_decls)), - TypeArg::Variable { - v: TypeArgVariable { idx, cached_decl }, - } => { + Term::Runtime(ty) => ty.validate(var_decls), + Term::List(elems) => { + // TODO: Full validation would check that the type of the elements agrees + elems.iter().try_for_each(|a| a.validate(var_decls)) + } + Term::Tuple(elems) => elems.iter().try_for_each(|a| a.validate(var_decls)), + Term::BoundedNat(_) | Term::String { .. } | Term::Float(_) | Term::Bytes(_) => Ok(()), + TypeArg::ListConcat(lists) => { + // TODO: Full validation would check that each of the lists is indeed a + // list or list variable of the correct types. + lists.iter().try_for_each(|a| a.validate(var_decls)) + } + TypeArg::TupleConcat(tuples) => tuples.iter().try_for_each(|a| a.validate(var_decls)), + Term::Variable(TermVar { idx, cached_decl }) => { assert!( - !matches!(cached_decl, TypeParam::Type { .. }), + !matches!(&**cached_decl, TypeParam::RuntimeType { .. }), "Malformed TypeArg::Variable {cached_decl} - should be inconstructible" ); check_typevar_decl(var_decls, *idx, cached_decl) } + Term::RuntimeType { .. } => Ok(()), + Term::BoundedNatType { .. } => Ok(()), + Term::StringType => Ok(()), + Term::BytesType => Ok(()), + Term::FloatType => Ok(()), + Term::ListType(item_type) => item_type.validate(var_decls), + Term::TupleType(item_types) => item_types.validate(var_decls), + Term::StaticType => Ok(()), } } pub(crate) fn substitute(&self, t: &Substitution) -> Self { match self { - TypeArg::Type { ty } => { - // RowVariables are represented as TypeArg::Variable + Term::Runtime(ty) => { + // RowVariables are represented as Term::Variable ty.substitute1(t).into() } - TypeArg::BoundedNat { .. } | TypeArg::String { .. } => self.clone(), // We do not allow variables as bounds on BoundedNat's - TypeArg::Sequence { elems } => { - let mut are_types = elems.iter().map(|ta| match ta { - TypeArg::Type { .. } => true, - TypeArg::Variable { v } => v.bound_if_row_var().is_some(), - _ => false, - }); - let elems = match are_types.next() { - Some(true) => { - assert!(are_types.all(|b| b)); // If one is a Type, so must the rest be - // So, anything that doesn't produce a Type, was a row variable => multiple Types - elems - .iter() - .flat_map(|ta| match ta.substitute(t) { - ty @ TypeArg::Type { .. 
} => vec![ty], - TypeArg::Sequence { elems } => elems, - _ => panic!("Expected Type or row of Types"), - }) - .collect() + TypeArg::BoundedNat(_) | TypeArg::String(_) | TypeArg::Bytes(_) | TypeArg::Float(_) => { + self.clone() + } // We do not allow variables as bounds on BoundedNat's + TypeArg::List(elems) => { + // NOTE: This implements a hack allowing substitutions to + // replace `TypeArg::Variable`s representing "row variables" + // with a list that is to be spliced into the containing list. + // We won't need this code anymore once we stop conflating types + // with lists of types. + + fn is_type(type_arg: &TypeArg) -> bool { + match type_arg { + TypeArg::Runtime(_) => true, + TypeArg::Variable(v) => v.bound_if_row_var().is_some(), + _ => false, } - _ => { - // not types, no need to flatten (and mustn't, in case of nested Sequences) - elems.iter().map(|ta| ta.substitute(t)).collect() + } + + let are_types = elems.first().map(is_type).unwrap_or(false); + + Self::new_list_from_parts(elems.iter().map(|elem| match elem.substitute(t) { + list @ TypeArg::List { .. } if are_types => SeqPart::Splice(list), + list @ TypeArg::ListConcat { .. } if are_types => SeqPart::Splice(list), + elem => SeqPart::Item(elem), + })) + } + TypeArg::ListConcat(lists) => { + // When a substitution instantiates spliced list variables, we + // may be able to merge the concatenated lists. + Self::new_list_from_parts( + lists.iter().map(|list| SeqPart::Splice(list.substitute(t))), + ) + } + Term::Tuple(elems) => { + Term::Tuple(elems.iter().map(|elem| elem.substitute(t)).collect()) + } + TypeArg::TupleConcat(tuples) => { + // When a substitution instantiates spliced tuple variables, + // we may be able to merge the concatenated tuples. + Self::new_tuple_from_parts( + tuples + .iter() + .map(|tuple| SeqPart::Splice(tuple.substitute(t))), + ) + } + TypeArg::Variable(TermVar { idx, cached_decl }) => t.apply_var(*idx, cached_decl), + Term::RuntimeType(_) => self.clone(), + Term::BoundedNatType(_) => self.clone(), + Term::StringType => self.clone(), + Term::BytesType => self.clone(), + Term::FloatType => self.clone(), + Term::ListType(item_type) => Term::new_list_type(item_type.substitute(t)), + Term::TupleType(item_types) => Term::new_list_type(item_types.substitute(t)), + Term::StaticType => self.clone(), + } + } + + /// Helper method for [`TypeArg::new_list_from_parts`] and [`TypeArg::new_tuple_from_parts`]. + fn new_seq_from_parts( + parts: impl IntoIterator>, + make_items: impl Fn(Vec) -> Self, + make_concat: impl Fn(Vec) -> Self, + ) -> Self { + let mut items = Vec::new(); + let mut seqs = Vec::new(); + + for part in parts { + match part { + SeqPart::Item(item) => items.push(item), + SeqPart::Splice(seq) => { + if !items.is_empty() { + seqs.push(make_items(std::mem::take(&mut items))); } - }; - TypeArg::Sequence { elems } + seqs.push(seq); + } } - TypeArg::Variable { - v: TypeArgVariable { idx, cached_decl }, - } => t.apply_var(*idx, cached_decl), } + + if seqs.is_empty() { + make_items(items) + } else if items.is_empty() { + make_concat(seqs) + } else { + seqs.push(make_items(items)); + make_concat(seqs) + } + } + + /// Creates a new list from a sequence of [`SeqPart`]s. + pub fn new_list_from_parts(parts: impl IntoIterator>) -> Self { + Self::new_seq_from_parts( + parts.into_iter().flat_map(ListPartIter::new), + TypeArg::List, + TypeArg::ListConcat, + ) + } + + /// Iterates over the [`SeqPart`]s of a list. 
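+    ///
+    /// Closed lists are expanded into [`SeqPart::Item`]s, nested
+    /// concatenations are flattened recursively, and any other term (such as
+    /// a variable standing for a list) is yielded as a single
+    /// [`SeqPart::Splice`].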
+ /// + /// # Examples + /// + /// The parts of a closed list are the items of that list wrapped in [`SeqPart::Item`]: + /// + /// ``` + /// # use hugr_core::types::type_param::{Term, SeqPart}; + /// # let a = Term::new_string("a"); + /// # let b = Term::new_string("b"); + /// let term = Term::new_list([a.clone(), b.clone()]); + /// + /// assert_eq!( + /// term.into_list_parts().collect::>(), + /// vec![SeqPart::Item(a), SeqPart::Item(b)] + /// ); + /// ``` + /// + /// Parts of a concatenated list that are not closed lists are wrapped in [`SeqPart::Splice`]: + /// + /// ``` + /// # use hugr_core::types::type_param::{Term, SeqPart}; + /// # let a = Term::new_string("a"); + /// # let b = Term::new_string("b"); + /// # let c = Term::new_string("c"); + /// let var = Term::new_var_use(0, Term::new_list_type(Term::StringType)); + /// let term = Term::new_list_concat([ + /// Term::new_list([a.clone(), b.clone()]), + /// var.clone(), + /// Term::new_list([c.clone()]) + /// ]); + /// + /// assert_eq!( + /// term.into_list_parts().collect::>(), + /// vec![SeqPart::Item(a), SeqPart::Item(b), SeqPart::Splice(var), SeqPart::Item(c)] + /// ); + /// ``` + /// + /// Nested concatenations are traversed recursively: + /// + /// ``` + /// # use hugr_core::types::type_param::{Term, SeqPart}; + /// # let a = Term::new_string("a"); + /// # let b = Term::new_string("b"); + /// # let c = Term::new_string("c"); + /// let term = Term::new_list_concat([ + /// Term::new_list_concat([ + /// Term::new_list([a.clone()]), + /// Term::new_list([b.clone()]) + /// ]), + /// Term::new_list([]), + /// Term::new_list([c.clone()]) + /// ]); + /// + /// assert_eq!( + /// term.into_list_parts().collect::>(), + /// vec![SeqPart::Item(a), SeqPart::Item(b), SeqPart::Item(c)] + /// ); + /// ``` + /// + /// When invoked on a type argument that is not a list, a single + /// [`SeqPart::Splice`] is returned that wraps the type argument. + /// This is the expected behaviour for type variables that stand for lists. + /// This behaviour also allows this method not to fail on ill-typed type arguments. + /// ``` + /// # use hugr_core::types::type_param::{Term, SeqPart}; + /// let term = Term::new_string("not a list"); + /// assert_eq!( + /// term.clone().into_list_parts().collect::>(), + /// vec![SeqPart::Splice(term)] + /// ); + /// ``` + #[inline] + pub fn into_list_parts(self) -> ListPartIter { + ListPartIter::new(SeqPart::Splice(self)) + } + + /// Creates a new tuple from a sequence of [`SeqPart`]s. + /// + /// Analogous to [`TypeArg::new_list_from_parts`]. + pub fn new_tuple_from_parts(parts: impl IntoIterator>) -> Self { + Self::new_seq_from_parts( + parts.into_iter().flat_map(TuplePartIter::new), + TypeArg::Tuple, + TypeArg::TupleConcat, + ) + } + + /// Iterates over the [`SeqPart`]s of a tuple. + /// + /// Analogous to [`TypeArg::into_list_parts`]. + #[inline] + pub fn into_tuple_parts(self) -> TuplePartIter { + TuplePartIter::new(SeqPart::Splice(self)) } } -impl Transformable for TypeArg { +impl Transformable for Term { fn transform(&mut self, tr: &T) -> Result { match self { - TypeArg::Type { ty } => ty.transform(tr), - TypeArg::Sequence { elems } => elems.transform(tr), - TypeArg::BoundedNat { .. } | TypeArg::String { .. } | TypeArg::Variable { .. 
} => { - Ok(false) - } + Term::Runtime(ty) => ty.transform(tr), + Term::List(elems) => elems.transform(tr), + Term::Tuple(elems) => elems.transform(tr), + Term::BoundedNat(_) + | Term::String(_) + | Term::Variable(_) + | Term::Float(_) + | Term::Bytes(_) => Ok(false), + Term::RuntimeType { .. } => Ok(false), + Term::BoundedNatType { .. } => Ok(false), + Term::StringType => Ok(false), + Term::BytesType => Ok(false), + Term::FloatType => Ok(false), + Term::ListType(item_type) => item_type.transform(tr), + Term::TupleType(item_types) => item_types.transform(tr), + Term::StaticType => Ok(false), + TypeArg::ListConcat(lists) => lists.transform(tr), + TypeArg::TupleConcat(tuples) => tuples.transform(tr), } } } -impl TypeArgVariable { +impl TermVar { /// Return the index. #[must_use] pub fn index(&self) -> usize { @@ -371,8 +608,8 @@ impl TypeArgVariable { /// the [`TypeBound`] of the individual types it might stand for. #[must_use] pub fn bound_if_row_var(&self) -> Option { - if let TypeParam::List { param } = &self.cached_decl { - if let TypeParam::Type { b } = **param { + if let Term::ListType(item_type) = &*self.cached_decl { + if let Term::RuntimeType(b) = **item_type { return Some(b); } } @@ -380,80 +617,103 @@ impl TypeArgVariable { } } -/// Checks a [`TypeArg`] is as expected for a [`TypeParam`] -pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgError> { - match (arg, param) { - ( - TypeArg::Variable { - v: TypeArgVariable { cached_decl, .. }, - }, - _, - ) if param.contains(cached_decl) => Ok(()), - (TypeArg::Type { ty }, TypeParam::Type { b: bound }) - if bound.contains(ty.least_upper_bound()) => - { +/// Checks that a [`Term`] is valid for a given type. +pub fn check_term_type(term: &Term, type_: &Term) -> Result<(), TermTypeError> { + match (term, type_) { + (Term::Variable(TermVar { cached_decl, .. 
}), _) if type_.is_supertype(cached_decl) => { Ok(()) } - (TypeArg::Sequence { elems }, TypeParam::List { param }) => { - elems.iter().try_for_each(|arg| { + (Term::Runtime(ty), Term::RuntimeType(bound)) if bound.contains(ty.least_upper_bound()) => { + Ok(()) + } + (Term::List(elems), Term::ListType(item_type)) => { + elems.iter().try_for_each(|term| { // Also allow elements that are RowVars if fitting into a List of Types - if let (TypeArg::Variable { v }, TypeParam::Type { b: param_bound }) = - (arg, &**param) - { + if let (Term::Variable(v), Term::RuntimeType(param_bound)) = (term, &**item_type) { if v.bound_if_row_var() .is_some_and(|arg_bound| param_bound.contains(arg_bound)) { return Ok(()); } } - check_type_arg(arg, param) + check_term_type(term, item_type) }) } - (TypeArg::Sequence { elems: items }, TypeParam::Tuple { params: types }) => { - if items.len() == types.len() { - items - .iter() - .zip(types.iter()) - .try_for_each(|(arg, param)| check_type_arg(arg, param)) - } else { - Err(TypeArgError::WrongNumberTuple(items.len(), types.len())) + (Term::ListConcat(lists), Term::ListType(item_type)) => lists + .iter() + .try_for_each(|list| check_term_type(list, item_type)), + (TypeArg::Tuple(_) | TypeArg::TupleConcat(_), TypeParam::TupleType(item_types)) => { + let term_parts: Vec<_> = term.clone().into_tuple_parts().collect(); + let type_parts: Vec<_> = item_types.clone().into_list_parts().collect(); + + for (term, type_) in term_parts.iter().zip(&type_parts) { + match (term, type_) { + (SeqPart::Item(term), SeqPart::Item(type_)) => { + check_term_type(term, type_)?; + } + (_, SeqPart::Splice(_)) | (SeqPart::Splice(_), _) => { + // TODO: Checking tuples with splicing requires more + // sophisticated validation infrastructure to do well. + warn!( + "Validation for open tuples is not implemented yet, succeeding regardless..." + ); + return Ok(()); + } + } } - } - (TypeArg::BoundedNat { n: val }, TypeParam::BoundedNat { bound }) - if bound.valid_value(*val) => - { + + if term_parts.len() != type_parts.len() { + return Err(TermTypeError::WrongNumberTuple( + term_parts.len(), + type_parts.len(), + )); + } + Ok(()) } + (Term::BoundedNat(val), Term::BoundedNatType(bound)) if bound.valid_value(*val) => Ok(()), + (Term::String { .. }, Term::StringType) => Ok(()), + (Term::Bytes(_), Term::BytesType) => Ok(()), + (Term::Float(_), Term::FloatType) => Ok(()), + + // Static types + (Term::StaticType, Term::StaticType) => Ok(()), + (Term::StringType, Term::StaticType) => Ok(()), + (Term::BytesType, Term::StaticType) => Ok(()), + (Term::BoundedNatType { .. }, Term::StaticType) => Ok(()), + (Term::FloatType, Term::StaticType) => Ok(()), + (Term::ListType { .. }, Term::StaticType) => Ok(()), + (Term::TupleType(_), Term::StaticType) => Ok(()), + (Term::RuntimeType(_), Term::StaticType) => Ok(()), - (TypeArg::String { .. }, TypeParam::String) => Ok(()), - _ => Err(TypeArgError::TypeMismatch { - arg: arg.clone(), - param: param.clone(), + _ => Err(TermTypeError::TypeMismatch { + term: Box::new(term.clone()), + type_: Box::new(type_.clone()), }), } } -/// Check a list of type arguments match a list of required type parameters -pub fn check_type_args(args: &[TypeArg], params: &[TypeParam]) -> Result<(), TypeArgError> { - if args.len() != params.len() { - return Err(TypeArgError::WrongNumberArgs(args.len(), params.len())); +/// Check a list of [`Term`]s is valid for a list of types. 
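+///
+/// A short example (assuming `check_term_types` is reachable at the module
+/// path used by the other doc examples in this file):
+///
+/// ```
+/// # use hugr_core::types::type_param::{check_term_types, Term};
+/// let terms = [Term::BoundedNat(1), Term::new_string("x")];
+/// let types = [Term::max_nat_type(), Term::StringType];
+/// assert!(check_term_types(&terms, &types).is_ok());
+/// // A length mismatch is reported as an error.
+/// assert!(check_term_types(&terms, &types[..1]).is_err());
+/// ```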
+pub fn check_term_types(terms: &[Term], types: &[Term]) -> Result<(), TermTypeError> { + if terms.len() != types.len() { + return Err(TermTypeError::WrongNumberArgs(terms.len(), types.len())); } - for (a, p) in args.iter().zip(params.iter()) { - check_type_arg(a, p)?; + for (term, type_) in terms.iter().zip(types.iter()) { + check_term_type(term, type_)?; } Ok(()) } -/// Errors that can occur fitting a [`TypeArg`] into a [`TypeParam`] +/// Errors that can occur when checking that a [`Term`] has an expected type. #[derive(Clone, Debug, PartialEq, Eq, Error)] #[non_exhaustive] -pub enum TypeArgError { +pub enum TermTypeError { #[allow(missing_docs)] - /// For now, general case of a type arg not fitting a param. + /// For now, general case of a term not fitting a type. /// We'll have more cases when we allow general Containers. // TODO It may become possible to combine this with ConstTypeError. - #[error("Type argument {arg} does not fit declared parameter {param}")] - TypeMismatch { param: TypeParam, arg: TypeArg }, + #[error("Term {term} does not fit declared type {type_}")] + TypeMismatch { term: Box, type_: Box }, /// Wrong number of type arguments (actual vs expected). // For now this only happens at the top level (TypeArgs of op/type vs TypeParams of Op/TypeDef). // However in the future it may be applicable to e.g. contents of Tuples too. @@ -470,35 +730,173 @@ pub enum TypeArgError { OpaqueTypeMismatch(#[from] crate::types::CustomCheckFailure), /// Invalid value #[error("Invalid value of type argument")] - InvalidValue(TypeArg), + InvalidValue(Box), +} + +/// Part of a sequence. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum SeqPart { + /// An individual item in the sequence. + Item(T), + /// A subsequence that is spliced into the parent sequence. + Splice(T), +} + +/// Iterator created by [`TypeArg::into_list_parts`]. +#[derive(Debug, Clone)] +pub struct ListPartIter { + parts: SmallVec<[SeqPart; 1]>, } +impl ListPartIter { + #[inline] + fn new(part: SeqPart) -> Self { + Self { + parts: smallvec![part], + } + } +} + +impl Iterator for ListPartIter { + type Item = SeqPart; + + fn next(&mut self) -> Option { + loop { + match self.parts.pop()? { + SeqPart::Splice(TypeArg::List(elems)) => self + .parts + .extend(elems.into_iter().rev().map(SeqPart::Item)), + SeqPart::Splice(TypeArg::ListConcat(lists)) => self + .parts + .extend(lists.into_iter().rev().map(SeqPart::Splice)), + part => return Some(part), + } + } + } +} + +impl FusedIterator for ListPartIter {} + +/// Iterator created by [`TypeArg::into_tuple_parts`]. +#[derive(Debug, Clone)] +pub struct TuplePartIter { + parts: SmallVec<[SeqPart; 1]>, +} + +impl TuplePartIter { + #[inline] + fn new(part: SeqPart) -> Self { + Self { + parts: smallvec![part], + } + } +} + +impl Iterator for TuplePartIter { + type Item = SeqPart; + + fn next(&mut self) -> Option { + loop { + match self.parts.pop()? 
{ + SeqPart::Splice(TypeArg::Tuple(elems)) => self + .parts + .extend(elems.into_iter().rev().map(SeqPart::Item)), + SeqPart::Splice(TypeArg::TupleConcat(tuples)) => self + .parts + .extend(tuples.into_iter().rev().map(SeqPart::Splice)), + part => return Some(part), + } + } + } +} + +impl FusedIterator for TuplePartIter {} + #[cfg(test)] mod test { use itertools::Itertools; - use super::{Substitution, TypeArg, TypeParam, check_type_arg}; + use super::{Substitution, TypeArg, TypeParam, check_term_type}; use crate::extension::prelude::{bool_t, usize_t}; - use crate::types::{TypeBound, TypeRV, type_param::TypeArgError}; + use crate::types::Term; + use crate::types::type_param::SeqPart; + use crate::types::{TypeBound, TypeRV, type_param::TermTypeError}; + + #[test] + fn new_list_from_parts_items() { + let a = TypeArg::new_string("a"); + let b = TypeArg::new_string("b"); + + let parts = [SeqPart::Item(a.clone()), SeqPart::Item(b.clone())]; + let items = [a, b]; + + assert_eq!( + TypeArg::new_list_from_parts(parts.clone()), + TypeArg::new_list(items.clone()) + ); + + assert_eq!( + TypeArg::new_tuple_from_parts(parts), + TypeArg::new_tuple(items) + ); + } + + #[test] + fn new_list_from_parts_flatten() { + let a = Term::new_string("a"); + let b = Term::new_string("b"); + let c = Term::new_string("c"); + let d = Term::new_string("d"); + let var = Term::new_var_use(0, Term::new_list_type(Term::StringType)); + let parts = [ + SeqPart::Splice(Term::new_list([a.clone(), b.clone()])), + SeqPart::Splice(Term::new_list_concat([Term::new_list([c.clone()])])), + SeqPart::Item(d.clone()), + SeqPart::Splice(var.clone()), + ]; + assert_eq!( + Term::new_list_from_parts(parts), + Term::new_list_concat([Term::new_list([a, b, c, d]), var]) + ); + } + + #[test] + fn new_tuple_from_parts_flatten() { + let a = Term::new_string("a"); + let b = Term::new_string("b"); + let c = Term::new_string("c"); + let d = Term::new_string("d"); + let var = Term::new_var_use(0, Term::new_tuple([Term::StringType])); + let parts = [ + SeqPart::Splice(Term::new_tuple([a.clone(), b.clone()])), + SeqPart::Splice(Term::new_tuple_concat([Term::new_tuple([c.clone()])])), + SeqPart::Item(d.clone()), + SeqPart::Splice(var.clone()), + ]; + assert_eq!( + Term::new_tuple_from_parts(parts), + Term::new_tuple_concat([Term::new_tuple([a, b, c, d]), var]) + ); + } #[test] fn type_arg_fits_param() { let rowvar = TypeRV::new_row_var_use; - fn check(arg: impl Into, param: &TypeParam) -> Result<(), TypeArgError> { - check_type_arg(&arg.into(), param) + fn check(arg: impl Into, param: &TypeParam) -> Result<(), TermTypeError> { + check_term_type(&arg.into(), param) } fn check_seq>( args: &[T], param: &TypeParam, - ) -> Result<(), TypeArgError> { + ) -> Result<(), TermTypeError> { let arg = args.iter().cloned().map_into().collect_vec().into(); - check_type_arg(&arg, param) + check_term_type(&arg, param) } - // Simple cases: a TypeArg::Type is a TypeParam::Type but singleton sequences are lists + // Simple cases: a Term::Type is a Term::RuntimeType but singleton sequences are lists check(usize_t(), &TypeBound::Copyable.into()).unwrap(); - let seq_param = TypeParam::new_list(TypeBound::Copyable); + let seq_param = TypeParam::new_list_type(TypeBound::Copyable); check(usize_t(), &seq_param).unwrap_err(); - check_seq(&[usize_t()], &TypeBound::Any.into()).unwrap_err(); + check_seq(&[usize_t()], &TypeBound::Linear.into()).unwrap_err(); // Into a list of type, we can fit a single row var check(rowvar(0, TypeBound::Copyable), &seq_param).unwrap(); @@ -507,17 
+905,17 @@ mod test { check_seq(&[rowvar(0, TypeBound::Copyable)], &seq_param).unwrap(); check_seq( &[ - rowvar(1, TypeBound::Any), + rowvar(1, TypeBound::Linear), usize_t().into(), rowvar(0, TypeBound::Copyable), ], - &TypeParam::new_list(TypeBound::Any), + &TypeParam::new_list_type(TypeBound::Linear), ) .unwrap(); // Next one fails because a list of Eq is required check_seq( &[ - rowvar(1, TypeBound::Any), + rowvar(1, TypeBound::Linear), usize_t().into(), rowvar(0, TypeBound::Copyable), ], @@ -532,9 +930,9 @@ mod test { .unwrap_err(); // Similar for nats (but no equivalent of fancy row vars) - check(5, &TypeParam::max_nat()).unwrap(); - check_seq(&[5], &TypeParam::max_nat()).unwrap_err(); - let list_of_nat = TypeParam::new_list(TypeParam::max_nat()); + check(5, &TypeParam::max_nat_type()).unwrap(); + check_seq(&[5], &TypeParam::max_nat_type()).unwrap_err(); + let list_of_nat = TypeParam::new_list_type(TypeParam::max_nat_type()); check(5, &list_of_nat).unwrap_err(); check_seq(&[5], &list_of_nat).unwrap(); check(TypeArg::new_var_use(0, list_of_nat.clone()), &list_of_nat).unwrap(); @@ -545,15 +943,23 @@ mod test { ) .unwrap_err(); - // TypeParam::Tuples require a TypeArg::Seq of the same number of elems - let usize_and_ty = TypeParam::Tuple { - params: vec![TypeParam::max_nat(), TypeBound::Copyable.into()], - }; - check(vec![5.into(), usize_t().into()], &usize_and_ty).unwrap(); - check(vec![usize_t().into(), 5.into()], &usize_and_ty).unwrap_err(); // Wrong way around - let two_types = TypeParam::Tuple { - params: vec![TypeBound::Any.into(), TypeBound::Any.into()], - }; + // `Term::TupleType` requires a `Term::Tuple` of the same number of elems + let usize_and_ty = + TypeParam::new_tuple_type([TypeParam::max_nat_type(), TypeBound::Copyable.into()]); + check( + TypeArg::Tuple(vec![5.into(), usize_t().into()]), + &usize_and_ty, + ) + .unwrap(); + check( + TypeArg::Tuple(vec![usize_t().into(), 5.into()]), + &usize_and_ty, + ) + .unwrap_err(); // Wrong way around + let two_types = TypeParam::new_tuple_type(Term::new_list([ + TypeBound::Linear.into(), + TypeBound::Linear.into(), + ])); check(TypeArg::new_var_use(0, two_types.clone()), &two_types).unwrap(); // not a Row Var which could have any number of elems check(TypeArg::new_var_use(0, seq_param), &two_types).unwrap_err(); @@ -561,116 +967,143 @@ mod test { #[test] fn type_arg_subst_row() { - let row_param = TypeParam::new_list(TypeBound::Copyable); - let row_arg: TypeArg = vec![bool_t().into(), TypeArg::UNIT].into(); - check_type_arg(&row_arg, &row_param).unwrap(); + let row_param = Term::new_list_type(TypeBound::Copyable); + let row_arg: Term = vec![bool_t().into(), Term::UNIT].into(); + check_term_type(&row_arg, &row_param).unwrap(); // Now say a row variable referring to *that* row was used // to instantiate an outer "row parameter" (list of type). 
- let outer_param = TypeParam::new_list(TypeBound::Any); - let outer_arg = TypeArg::Sequence { - elems: vec![ - TypeRV::new_row_var_use(0, TypeBound::Copyable).into(), - usize_t().into(), - ], - }; - check_type_arg(&outer_arg, &outer_param).unwrap(); + let outer_param = Term::new_list_type(TypeBound::Linear); + let outer_arg = Term::new_list([ + TypeRV::new_row_var_use(0, TypeBound::Copyable).into(), + usize_t().into(), + ]); + check_term_type(&outer_arg, &outer_param).unwrap(); let outer_arg2 = outer_arg.substitute(&Substitution(&[row_arg])); assert_eq!( outer_arg2, - vec![bool_t().into(), TypeArg::UNIT, usize_t().into()].into() + vec![bool_t().into(), Term::UNIT, usize_t().into()].into() ); // Of course this is still valid (as substitution is guaranteed to preserve validity) - check_type_arg(&outer_arg2, &outer_param).unwrap(); + check_term_type(&outer_arg2, &outer_param).unwrap(); } #[test] fn subst_list_list() { - let outer_param = TypeParam::new_list(TypeParam::new_list(TypeBound::Any)); - let row_var_decl = TypeParam::new_list(TypeBound::Copyable); - let row_var_use = TypeArg::new_var_use(0, row_var_decl.clone()); - let good_arg = TypeArg::Sequence { - elems: vec![ - // The row variables here refer to `row_var_decl` above - vec![usize_t().into()].into(), - row_var_use.clone(), - vec![row_var_use, usize_t().into()].into(), - ], - }; - check_type_arg(&good_arg, &outer_param).unwrap(); + let outer_param = Term::new_list_type(Term::new_list_type(TypeBound::Linear)); + let row_var_decl = Term::new_list_type(TypeBound::Copyable); + let row_var_use = Term::new_var_use(0, row_var_decl.clone()); + let good_arg = Term::new_list([ + // The row variables here refer to `row_var_decl` above + vec![usize_t().into()].into(), + row_var_use.clone(), + vec![row_var_use, usize_t().into()].into(), + ]); + check_term_type(&good_arg, &outer_param).unwrap(); // Outer list cannot include single types: - let TypeArg::Sequence { mut elems } = good_arg.clone() else { + let Term::List(mut elems) = good_arg.clone() else { panic!() }; elems.push(usize_t().into()); assert_eq!( - check_type_arg(&TypeArg::Sequence { elems }, &outer_param), - Err(TypeArgError::TypeMismatch { - arg: usize_t().into(), + check_term_type(&Term::new_list(elems), &outer_param), + Err(TermTypeError::TypeMismatch { + term: Box::new(usize_t().into()), // The error reports the type expected for each element of the list: - param: TypeParam::new_list(TypeBound::Any) + type_: Box::new(TypeParam::new_list_type(TypeBound::Linear)) }) ); // Now substitute a list of two types for that row-variable let row_var_arg = vec![usize_t().into(), bool_t().into()].into(); - check_type_arg(&row_var_arg, &row_var_decl).unwrap(); + check_term_type(&row_var_arg, &row_var_decl).unwrap(); let subst_arg = good_arg.substitute(&Substitution(&[row_var_arg.clone()])); - check_type_arg(&subst_arg, &outer_param).unwrap(); // invariance of substitution + check_term_type(&subst_arg, &outer_param).unwrap(); // invariance of substitution assert_eq!( subst_arg, - TypeArg::Sequence { - elems: vec![ - vec![usize_t().into()].into(), - row_var_arg, - vec![usize_t().into(), bool_t().into(), usize_t().into()].into() - ] - } + Term::new_list([ + Term::new_list([usize_t().into()]), + row_var_arg, + Term::new_list([usize_t().into(), bool_t().into(), usize_t().into()]) + ]) ); } + #[test] + fn bytes_json_roundtrip() { + let bytes_arg = Term::Bytes(vec![0, 1, 2, 3, 255, 254, 253, 252].into()); + let serialized = serde_json::to_string(&bytes_arg).unwrap(); + let deserialized: Term = 
serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized, bytes_arg); + } + mod proptest { use proptest::prelude::*; - use super::super::{TypeArg, TypeArgVariable, TypeParam, UpperBound}; + use super::super::{TermVar, UpperBound}; use crate::proptest::RecursionDepth; - use crate::types::{Type, TypeBound}; + use crate::types::{Term, Type, TypeBound, proptest_utils::any_serde_type_param}; - impl Arbitrary for TypeArgVariable { + impl Arbitrary for TermVar { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { - (any::(), any_with::(depth)) - .prop_map(|(idx, cached_decl)| Self { idx, cached_decl }) + (any::(), any_serde_type_param(depth)) + .prop_map(|(idx, cached_decl)| Self { + idx, + cached_decl: Box::new(cached_decl), + }) .boxed() } } - impl Arbitrary for TypeParam { + impl Arbitrary for Term { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { use prop::collection::vec; use prop::strategy::Union; let mut strat = Union::new([ - Just(Self::String).boxed(), - any::().prop_map(|b| Self::Type { b }).boxed(), - any::() - .prop_map(|bound| Self::BoundedNat { bound }) + Just(Self::StringType).boxed(), + Just(Self::BytesType).boxed(), + Just(Self::FloatType).boxed(), + Just(Self::StringType).boxed(), + any::().prop_map(Self::from).boxed(), + any::().prop_map(Self::from).boxed(), + any::().prop_map(Self::from).boxed(), + any::().prop_map(Self::from).boxed(), + any::>() + .prop_map(|bytes| Self::Bytes(bytes.into())) .boxed(), + any::() + .prop_map(|value| Self::Float(value.into())) + .boxed(), + any_with::(depth).prop_map(Self::from).boxed(), ]); if !depth.leaf() { - // we descend here because we these constructors contain TypeParams + // we descend here because we these constructors contain Terms strat = strat + .or( + // TODO this is a bit dodgy, TypeArgVariables are supposed + // to be constructed from TypeArg::new_var_use. We are only + // using this instance for serialization now, but if we want + // to generate valid TypeArgs this will need to change. + any_with::(depth.descend()) + .prop_map(Self::Variable) + .boxed(), + ) + .or(any_with::(depth.descend()) + .prop_map(Self::new_list_type) + .boxed()) .or(any_with::(depth.descend()) - .prop_map(|x| Self::List { param: Box::new(x) }) + .prop_map(Self::new_tuple_type) .boxed()) .or(vec(any_with::(depth.descend()), 0..3) - .prop_map(|params| Self::Tuple { params }) + .prop_map(Self::new_list) .boxed()); } @@ -678,33 +1111,10 @@ mod test { } } - impl Arbitrary for TypeArg { - type Parameters = RecursionDepth; - type Strategy = BoxedStrategy; - fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { - use prop::collection::vec; - use prop::strategy::Union; - let mut strat = Union::new([ - any::().prop_map(|n| Self::BoundedNat { n }).boxed(), - any::().prop_map(|arg| Self::String { arg }).boxed(), - any_with::(depth) - .prop_map(|ty| Self::Type { ty }) - .boxed(), - // TODO this is a bit dodgy, TypeArgVariables are supposed - // to be constructed from TypeArg::new_var_use. We are only - // using this instance for serialization now, but if we want - // to generate valid TypeArgs this will need to change. 
- any_with::(depth) - .prop_map(|v| Self::Variable { v }) - .boxed(), - ]); - if !depth.leaf() { - // We descend here because this constructor contains TypeArg> - strat = strat.or(vec(any_with::(depth.descend()), 0..3) - .prop_map(|elems| Self::Sequence { elems }) - .boxed()); - } - strat.boxed() + proptest! { + #[test] + fn term_contains_itself(term: Term) { + assert!(term.is_supertype(&term)); } } } diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index c458b8c181..7b9e24d282 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -8,8 +8,8 @@ use std::{ }; use super::{ - MaybeRV, NoRV, RowVariable, Substitution, Transformable, Type, TypeBase, TypeTransformer, - type_param::TypeParam, + MaybeRV, NoRV, RowVariable, Substitution, Term, Transformable, Type, TypeArg, TypeBase, TypeRV, + TypeTransformer, type_param::TypeParam, }; use crate::{extension::SignatureError, utils::display_list}; use delegate::delegate; @@ -28,7 +28,8 @@ pub struct TypeRowBase { /// Row of single types i.e. of known length, for node inputs/outputs pub type TypeRow = TypeRowBase; -/// Row of types and/or row variables, the number of actual types is thus unknown +/// Row of types and/or row variables, the number of actual types is thus +/// unknown pub type TypeRowRV = TypeRowBase; impl PartialEq> for TypeRowBase { @@ -195,6 +196,81 @@ impl From for TypeRow { } } +// Fallibly convert a [Term] to a [TypeRV]. +// +// This will fail if `arg` is of non-type kind (e.g. String). +impl TryFrom for TypeRV { + type Error = SignatureError; + + fn try_from(value: Term) -> Result { + match value { + TypeArg::Runtime(ty) => Ok(ty.into()), + TypeArg::Variable(v) => Ok(TypeRV::new_row_var_use( + v.index(), + v.bound_if_row_var() + .ok_or(SignatureError::InvalidTypeArgs)?, + )), + _ => Err(SignatureError::InvalidTypeArgs), + } + } +} + +// Fallibly convert a [Term] to a [TypeRow]. +// +// This will fail if `arg` is of non-sequence kind (e.g. Type) +// or if the sequence contains row variables. +impl TryFrom for TypeRow { + type Error = SignatureError; + + fn try_from(value: TypeArg) -> Result { + match value { + TypeArg::List(elems) => elems + .into_iter() + .map(|ta| ta.as_runtime()) + .collect::>>() + .map(|x| x.into()) + .ok_or(SignatureError::InvalidTypeArgs), + _ => Err(SignatureError::InvalidTypeArgs), + } + } +} + +// Fallibly convert a [TypeArg] to a [TypeRowRV]. +// +// This will fail if `arg` is of non-sequence kind (e.g. Type). 
+impl TryFrom for TypeRowRV { + type Error = SignatureError; + + fn try_from(value: Term) -> Result { + match value { + TypeArg::List(elems) => elems + .into_iter() + .map(TypeRV::try_from) + .collect::, _>>() + .map(|vec| vec.into()), + TypeArg::Variable(v) => Ok(vec![TypeRV::new_row_var_use( + v.index(), + v.bound_if_row_var() + .ok_or(SignatureError::InvalidTypeArgs)?, + )] + .into()), + _ => Err(SignatureError::InvalidTypeArgs), + } + } +} + +impl From for Term { + fn from(value: TypeRow) -> Self { + Term::List(value.into_owned().into_iter().map_into().collect()) + } +} + +impl From for Term { + fn from(value: TypeRowRV) -> Self { + Term::List(value.into_owned().into_iter().map_into().collect()) + } +} + impl Deref for TypeRowBase { type Target = [TypeBase]; @@ -211,6 +287,12 @@ impl DerefMut for TypeRowBase { #[cfg(test)] mod test { + use super::*; + use crate::{ + extension::prelude::bool_t, + types::{Type, TypeArg, TypeRV}, + }; + mod proptest { use crate::proptest::RecursionDepth; use crate::types::{MaybeRV, TypeBase, TypeRowBase}; @@ -231,4 +313,78 @@ mod test { } } } + + #[test] + fn test_try_from_term_to_typerv() { + // Test successful conversion with Runtime type + let runtime_type = Type::UNIT; + let term = TypeArg::Runtime(runtime_type.clone()); + let result = TypeRV::try_from(term); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), TypeRV::from(runtime_type)); + + // Test failure with non-type kind + let term = Term::String("test".to_string()); + let result = TypeRV::try_from(term); + assert!(result.is_err()); + } + + #[test] + fn test_try_from_term_to_typerow() { + // Test successful conversion with List + let types = vec![Type::new_unit_sum(1), bool_t()]; + let type_args = types.iter().map(|t| TypeArg::Runtime(t.clone())).collect(); + let term = TypeArg::List(type_args); + let result = TypeRow::try_from(term); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), TypeRow::from(types)); + + // Test failure with non-list + let term = TypeArg::Runtime(Type::UNIT); + let result = TypeRow::try_from(term); + assert!(result.is_err()); + } + + #[test] + fn test_try_from_term_to_typerowrv() { + // Test successful conversion with List + let types = [TypeRV::from(Type::UNIT), TypeRV::from(bool_t())]; + let type_args = types.iter().map(|t| t.clone().into()).collect(); + let term = TypeArg::List(type_args); + let result = TypeRowRV::try_from(term); + assert!(result.is_ok()); + + // Test failure with non-sequence kind + let term = Term::String("test".to_string()); + let result = TypeRowRV::try_from(term); + assert!(result.is_err()); + } + + #[test] + fn test_from_typerow_to_term() { + let types = vec![Type::UNIT, bool_t()]; + let type_row = TypeRow::from(types); + let term = Term::from(type_row); + + match term { + Term::List(elems) => { + assert_eq!(elems.len(), 2); + } + _ => panic!("Expected Term::List"), + } + } + + #[test] + fn test_from_typerowrv_to_term() { + let types = vec![TypeRV::from(Type::UNIT), TypeRV::from(bool_t())]; + let type_row_rv = TypeRowRV::from(types); + let term = Term::from(type_row_rv); + + match term { + TypeArg::List(elems) => { + assert_eq!(elems.len(), 2); + } + _ => panic!("Expected Term::List"), + } + } } diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index a312b16441..3e451e41df 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -1,105 +1,85 @@ #![allow(missing_docs)] +use anyhow::Result; use std::str::FromStr; use hugr::std_extensions::std_reg; use hugr_core::{export::export_package, 
import::import_package}; use hugr_model::v0 as model; -fn roundtrip(source: &str) -> String { +fn roundtrip(source: &str) -> Result { let bump = model::bumpalo::Bump::new(); - let package_ast = model::ast::Package::from_str(source).unwrap(); - let package_table = package_ast.resolve(&bump).unwrap(); - let core = import_package(&package_table, &std_reg()).unwrap(); + let package_ast = model::ast::Package::from_str(source)?; + let package_table = package_ast.resolve(&bump)?; + let core = import_package(&package_table, &std_reg())?; let exported_table = export_package(&core.modules, &core.extensions, &bump); let exported_ast = exported_table.as_ast().unwrap(); - exported_ast.to_string() -} -#[test] -#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri -pub fn test_roundtrip_add() { - insta::assert_snapshot!(roundtrip(include_str!( - "../../hugr-model/tests/fixtures/model-add.edn" - ))); + Ok(exported_ast.to_string()) } -#[test] -#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri -pub fn test_roundtrip_call() { - insta::assert_snapshot!(roundtrip(include_str!( - "../../hugr-model/tests/fixtures/model-call.edn" - ))); +macro_rules! test_roundtrip { + ($name: ident, $file: expr) => { + #[test] + #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri + pub fn $name() { + let ast = roundtrip(include_str!($file)).unwrap_or_else(|err| panic!("{:?}", err)); + insta::assert_snapshot!(ast) + } + }; } -#[test] -#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri -pub fn test_roundtrip_alias() { - insta::assert_snapshot!(roundtrip(include_str!( - "../../hugr-model/tests/fixtures/model-alias.edn" - ))); -} +test_roundtrip!( + test_roundtrip_add, + "../../hugr-model/tests/fixtures/model-add.edn" +); -#[test] -#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri -pub fn test_roundtrip_cfg() { - insta::assert_snapshot!(roundtrip(include_str!( - "../../hugr-model/tests/fixtures/model-cfg.edn" - ))); -} +test_roundtrip!( + test_roundtrip_call, + "../../hugr-model/tests/fixtures/model-call.edn" +); -#[test] -#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri -pub fn test_roundtrip_cond() { - insta::assert_snapshot!(roundtrip(include_str!( - "../../hugr-model/tests/fixtures/model-cond.edn" - ))); -} +test_roundtrip!( + test_roundtrip_alias, + "../../hugr-model/tests/fixtures/model-alias.edn" +); -#[test] -#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri -pub fn test_roundtrip_loop() { - insta::assert_snapshot!(roundtrip(include_str!( - "../../hugr-model/tests/fixtures/model-loop.edn" - ))); -} +test_roundtrip!( + test_roundtrip_cfg, + "../../hugr-model/tests/fixtures/model-cfg.edn" +); -#[test] -#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri -pub fn test_roundtrip_params() { - insta::assert_snapshot!(roundtrip(include_str!( - "../../hugr-model/tests/fixtures/model-params.edn" - ))); -} +test_roundtrip!( + test_roundtrip_cond, + "../../hugr-model/tests/fixtures/model-cond.edn" +); -#[test] -#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri -pub fn test_roundtrip_constraints() { - insta::assert_snapshot!(roundtrip(include_str!( - "../../hugr-model/tests/fixtures/model-constraints.edn" - ))); -} +test_roundtrip!( + test_roundtrip_loop, + "../../hugr-model/tests/fixtures/model-loop.edn" +); -#[test] -#[cfg_attr(miri, ignore)] // Opening files is not supported in 
(isolated) miri -pub fn test_roundtrip_const() { - insta::assert_snapshot!(roundtrip(include_str!( - "../../hugr-model/tests/fixtures/model-const.edn" - ))); -} +test_roundtrip!( + test_roundtrip_params, + "../../hugr-model/tests/fixtures/model-params.edn" +); -#[test] -#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri -pub fn test_roundtrip_order() { - insta::assert_snapshot!(roundtrip(include_str!( - "../../hugr-model/tests/fixtures/model-order.edn" - ))); -} +test_roundtrip!( + test_roundtrip_constraints, + "../../hugr-model/tests/fixtures/model-constraints.edn" +); -#[test] -#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri -pub fn test_roundtrip_entrypoint() { - insta::assert_snapshot!(roundtrip(include_str!( - "../../hugr-model/tests/fixtures/model-entrypoint.edn" - ))); -} +test_roundtrip!( + test_roundtrip_const, + "../../hugr-model/tests/fixtures/model-const.edn" +); + +test_roundtrip!( + test_roundtrip_order, + "../../hugr-model/tests/fixtures/model-order.edn" +); + +test_roundtrip!( + test_roundtrip_entrypoint, + "../../hugr-model/tests/fixtures/model-entrypoint.edn" +); diff --git a/hugr-core/tests/snapshots/model__roundtrip_add.snap b/hugr-core/tests/snapshots/model__roundtrip_add.snap index 4547fb3ebd..625b621f09 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_add.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_add.snap @@ -1,29 +1,42 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-add.edn\"))" +expression: ast --- (hugr 0) (mod) -(import core.fn) +(import core.meta.description) + +(import core.nat) -(import arithmetic.int.iadd) +(import core.fn) (import arithmetic.int.types.int) +(declare-operation + arithmetic.int.iadd + (param ?0 core.nat) + (core.fn + [(arithmetic.int.types.int ?0) (arithmetic.int.types.int ?0)] + [(arithmetic.int.types.int ?0)]) + (meta + (core.meta.description + "addition modulo 2^N (signed and unsigned versions are the same op)"))) + (define-func + public example.add (core.fn - [arithmetic.int.types.int arithmetic.int.types.int] - [arithmetic.int.types.int]) + [(arithmetic.int.types.int 6) (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)]) (dfg [%0 %1] [%2] (signature (core.fn - [arithmetic.int.types.int arithmetic.int.types.int] - [arithmetic.int.types.int])) - (arithmetic.int.iadd [%0 %1] [%2] + [(arithmetic.int.types.int 6) (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)])) + ((arithmetic.int.iadd 6) [%0 %1] [%2] (signature (core.fn - [arithmetic.int.types.int arithmetic.int.types.int] - [arithmetic.int.types.int]))))) + [(arithmetic.int.types.int 6) (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)]))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_alias.snap b/hugr-core/tests/snapshots/model__roundtrip_alias.snap index 5f0b44daf4..6174d8c744 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_alias.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_alias.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-alias.edn\"))" +expression: ast --- (hugr 0) diff --git a/hugr-core/tests/snapshots/model__roundtrip_call.snap b/hugr-core/tests/snapshots/model__roundtrip_call.snap index 50d9c55c33..8681cf372c 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_call.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_call.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs 
-expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call.edn\"))" +expression: ast --- (hugr 0) @@ -17,12 +17,14 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call (import arithmetic.int.types.int) (declare-func + public example.callee (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]) (meta (compat.meta_json "description" "\"This is a function declaration.\"")) (meta (compat.meta_json "title" "\"Callee\""))) (define-func + public example.caller (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]) (meta @@ -41,6 +43,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))))) (define-func + public example.load (core.fn [] [(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])]) (dfg [] [%0] diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap index 7a6136bdb2..da2fe4851f 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg.edn\"))" +expression: ast --- (hugr 0) @@ -16,36 +16,35 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg. (import core.adt) -(define-func example.cfg_loop (param ?0 core.type) (core.fn [?0] [?0]) +(define-func public example.cfg_loop (param ?0 core.type) (core.fn [?0] [?0]) (dfg [%0] [%1] (signature (core.fn [?0] [?0])) (cfg [%0] [%1] (signature (core.fn [?0] [?0])) (cfg [%2] [%3] - (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])])) + (signature (core.ctrl [[?0]] [[?0]])) (block [%2] [%3 %2] - (signature - (core.fn [(core.ctrl [?0])] [(core.ctrl [?0]) (core.ctrl [?0])])) + (signature (core.ctrl [[?0]] [[?0] [?0]])) (dfg [%4] [%5] (signature (core.fn [?0] [(core.adt [[?0] [?0]])])) ((core.make_adt 0) [%4] [%5] (signature (core.fn [?0] [(core.adt [[?0] [?0]])]))))))))) -(define-func example.cfg_order (param ?0 core.type) (core.fn [?0] [?0]) +(define-func public example.cfg_order (param ?0 core.type) (core.fn [?0] [?0]) (dfg [%0] [%1] (signature (core.fn [?0] [?0])) (cfg [%0] [%1] (signature (core.fn [?0] [?0])) (cfg [%2] [%3] - (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])])) + (signature (core.ctrl [[?0]] [[?0]])) (block [%2] [%6] - (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])])) + (signature (core.ctrl [[?0]] [[?0]])) (dfg [%4] [%5] (signature (core.fn [?0] [(core.adt [[?0]])])) ((core.make_adt 0) [%4] [%5] (signature (core.fn [?0] [(core.adt [[?0]])]))))) (block [%6] [%3] - (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])])) + (signature (core.ctrl [[?0]] [[?0]])) (dfg [%7] [%8] (signature (core.fn [?0] [(core.adt [[?0]])])) ((core.make_adt 0) [%7] [%8] diff --git a/hugr-core/tests/snapshots/model__roundtrip_cond.snap b/hugr-core/tests/snapshots/model__roundtrip_cond.snap index e4c49f1193..c51323db5c 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cond.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cond.snap @@ -1,42 +1,57 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cond.edn\"))" +expression: ast --- (hugr 0) (mod) +(import core.meta.description) + +(import core.nat) + (import core.fn) (import core.adt) (import arithmetic.int.types.int) -(import 
arithmetic.int.ineg) +(declare-operation + arithmetic.int.ineg + (param ?0 core.nat) + (core.fn [(arithmetic.int.types.int ?0)] [(arithmetic.int.types.int ?0)]) + (meta + (core.meta.description + "negation modulo 2^N (signed and unsigned versions are the same op)"))) (define-func + public example.cond (core.fn - [(core.adt [[] []]) arithmetic.int.types.int] - [arithmetic.int.types.int]) + [(core.adt [[] []]) (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)]) (dfg [%0 %1] [%2] (signature (core.fn - [(core.adt [[] []]) arithmetic.int.types.int] - [arithmetic.int.types.int])) + [(core.adt [[] []]) (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)])) (cond [%0 %1] [%2] (signature (core.fn - [(core.adt [[] []]) arithmetic.int.types.int] - [arithmetic.int.types.int])) + [(core.adt [[] []]) (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)])) (dfg [%3] [%3] (signature - (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))) + (core.fn + [(arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)]))) (dfg [%4] [%5] (signature - (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) - (arithmetic.int.ineg [%4] [%5] + (core.fn + [(arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)])) + ((arithmetic.int.ineg 6) [%4] [%5] (signature (core.fn - [arithmetic.int.types.int] - [arithmetic.int.types.int]))))))) + [(arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)]))))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_const.snap b/hugr-core/tests/snapshots/model__roundtrip_const.snap index 99cfdb55e9..3b386275ba 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_const.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_const.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-const.edn\"))" +expression: ast --- (hugr 0) @@ -28,7 +28,10 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons (import core.adt) -(define-func example.bools (core.fn [] [(core.adt [[] []]) (core.adt [[] []])]) +(define-func + public + example.bools + (core.fn [] [(core.adt [[] []]) (core.adt [[] []])]) (dfg [] [%0 %1] (signature (core.fn [] [(core.adt [[] []]) (core.adt [[] []])])) ((core.load_const (core.const.adt [[] []] _ 0 [])) [] [%0] @@ -37,6 +40,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons (signature (core.fn [] [(core.adt [[] []])]))))) (define-func + public example.make-pair (core.fn [] @@ -73,7 +77,10 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons [[(collections.array.array 5 (arithmetic.int.types.int 6)) arithmetic.float.types.float64]])]))))) -(define-func example.f64-json (core.fn [] [arithmetic.float.types.float64]) +(define-func + public + example.f64-json + (core.fn [] [arithmetic.float.types.float64]) (dfg [] [%0 %1] (signature (core.fn diff --git a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap index b9b406f3c5..2c50e73489 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-constraints.edn\"))" +expression: ast --- (hugr 0) @@ -10,20 +10,25 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons (import core.nat) 
-(import core.type) - (import core.nonlinear) +(import core.type) + (import core.fn) +(import core.title) + (declare-func - array.replicate + private + _1 (param ?0 core.nat) (param ?1 core.type) (where (core.nonlinear ?1)) - (core.fn [?1] [(collections.array.array ?0 ?1)])) + (core.fn [?1] [(collections.array.array ?0 ?1)]) + (meta (core.title "array.replicate"))) (declare-func + public array.copy (param ?0 core.nat) (param ?1 core.type) @@ -33,6 +38,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons [(collections.array.array ?0 ?1) (collections.array.array ?0 ?1)])) (define-func + public util.copy (param ?0 core.type) (where (core.nonlinear ?0)) diff --git a/hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap b/hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap index 1db0b9d1cd..1340bb6b02 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-entrypoint.edn\"))" +expression: ast --- (hugr 0) @@ -10,8 +10,9 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-entr (import core.entrypoint) -(define-func main (core.fn [] []) - (dfg (signature (core.fn [] [])) (meta core.entrypoint))) +(define-func public main (core.fn [] []) + (meta core.entrypoint) + (dfg (signature (core.fn [] [])))) (mod) @@ -19,8 +20,9 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-entr (import core.entrypoint) -(define-func wrapper_dfg (core.fn [] []) - (dfg (signature (core.fn [] [])) (meta core.entrypoint))) +(define-func public wrapper_dfg (core.fn [] []) + (meta core.entrypoint) + (dfg (signature (core.fn [] [])))) (mod) @@ -34,16 +36,17 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-entr (import core.adt) -(define-func wrapper_cfg (core.fn [] []) +(define-func public wrapper_cfg (core.fn [] []) (dfg (signature (core.fn [] [])) (cfg (signature (core.fn [] [])) + (meta core.entrypoint) (cfg [%0] [%1] - (signature (core.fn [(core.ctrl [])] [(core.ctrl [])])) + (signature (core.ctrl [[]] [[]])) (meta core.entrypoint) (block [%0] [%1] - (signature (core.fn [(core.ctrl [])] [(core.ctrl [])])) + (signature (core.ctrl [[]] [[]])) (dfg [] [%2] (signature (core.fn [] [(core.adt [[]])])) ((core.make_adt 0) [] [%2] diff --git a/hugr-core/tests/snapshots/model__roundtrip_loop.snap b/hugr-core/tests/snapshots/model__roundtrip_loop.snap index 50035a637c..e6991516f3 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_loop.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_loop.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-loop.edn\"))" +expression: ast --- (hugr 0) @@ -14,7 +14,10 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-loop (import core.adt) -(define-func example.loop (param ?0 core.type) (core.fn [?0] [?0]) +(import core.title) + +(define-func private _1 (param ?0 core.type) (core.fn [?0] [?0]) + (meta (core.title "example.loop")) (dfg [%0] [%1] (signature (core.fn [?0] [?0])) (tail-loop [%0] [%1] diff --git a/hugr-core/tests/snapshots/model__roundtrip_order.snap b/hugr-core/tests/snapshots/model__roundtrip_order.snap index e358670e1a..dda51b71cf 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_order.snap +++ 
b/hugr-core/tests/snapshots/model__roundtrip_order.snap @@ -1,60 +1,80 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-order.edn\"))" +expression: ast --- (hugr 0) (mod) -(import core.order_hint.key) +(import core.meta.description) -(import core.fn) +(import core.order_hint.input_key) (import core.order_hint.order) (import arithmetic.int.types.int) -(import arithmetic.int.ineg) +(import core.nat) + +(import core.order_hint.key) + +(import core.order_hint.output_key) + +(import core.fn) + +(declare-operation + arithmetic.int.ineg + (param ?0 core.nat) + (core.fn [(arithmetic.int.types.int ?0)] [(arithmetic.int.types.int ?0)]) + (meta + (core.meta.description + "negation modulo 2^N (signed and unsigned versions are the same op)"))) (define-func + public main (core.fn - [arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int] - [arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int]) + [(arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6)]) (dfg [%0 %1 %2 %3] [%4 %5 %6 %7] (signature (core.fn - [arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int] - [arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int])) + [(arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6)])) + (meta (core.order_hint.input_key 2)) + (meta (core.order_hint.order 2 4)) + (meta (core.order_hint.order 2 3)) + (meta (core.order_hint.output_key 3)) (meta (core.order_hint.order 4 7)) (meta (core.order_hint.order 5 6)) (meta (core.order_hint.order 5 4)) + (meta (core.order_hint.order 5 3)) (meta (core.order_hint.order 6 7)) - (arithmetic.int.ineg [%0] [%4] + ((arithmetic.int.ineg 6) [%0] [%4] (signature - (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) + (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) (meta (core.order_hint.key 4))) - (arithmetic.int.ineg [%1] [%5] + ((arithmetic.int.ineg 6) [%1] [%5] (signature - (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) + (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) (meta (core.order_hint.key 5))) - (arithmetic.int.ineg [%2] [%6] + ((arithmetic.int.ineg 6) [%2] [%6] (signature - (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) + (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) (meta (core.order_hint.key 6))) - (arithmetic.int.ineg [%3] [%7] + ((arithmetic.int.ineg 6) [%3] [%7] (signature - (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) + (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) (meta (core.order_hint.key 7))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_params.snap b/hugr-core/tests/snapshots/model__roundtrip_params.snap index 77d6d9cc77..c5c5eac95b 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_params.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_params.snap @@ -1,18 +1,54 @@ --- source: hugr-core/tests/model.rs -expression: 
"roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-params.edn\"))" +expression: ast --- (hugr 0) (mod) -(import core.fn) +(import core.call) (import core.type) +(import core.bytes) + +(import core.nat) + +(import core.fn) + +(import core.str) + +(import core.float) + +(import core.title) + (define-func + public example.swap (param ?0 core.type) (param ?1 core.type) (core.fn [?0 ?1] [?1 ?0]) (dfg [%0 %1] [%1 %0] (signature (core.fn [?0 ?1] [?1 ?0])))) + +(declare-func + public + example.literals + (param ?0 core.str) + (param ?1 core.nat) + (param ?2 core.bytes) + (param ?3 core.float) + (core.fn [] [])) + +(define-func private _5 (core.fn [] []) + (meta (core.title "example.call_literals")) + (dfg + (signature (core.fn [] [])) + ((core.call + [] + [] + (example.literals + "string" + 42 + (bytes "SGVsbG8gd29ybGQg8J+Yig==") + 6.023e23)) + (signature (core.fn [] []))))) diff --git a/hugr-llvm/CHANGELOG.md b/hugr-llvm/CHANGELOG.md index 0ed47be85a..7ab27c08a4 100644 --- a/hugr-llvm/CHANGELOG.md +++ b/hugr-llvm/CHANGELOG.md @@ -1,3 +1,5 @@ +# Changelog + # Changelog All notable changes to this project will be documented in this file. @@ -5,6 +7,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.21.0](https://github.com/CQCL/hugr/compare/hugr-llvm-v0.20.2...hugr-llvm-v0.21.0) - 2025-07-09 + +### New Features + +- [**breaking**] No nested FuncDefns (or AliasDefns) ([#2256](https://github.com/CQCL/hugr/pull/2256)) +- [**breaking**] Split `TypeArg::Sequence` into tuples and lists. ([#2140](https://github.com/CQCL/hugr/pull/2140)) +- [**breaking**] More helpful error messages in model import ([#2272](https://github.com/CQCL/hugr/pull/2272)) +- [**breaking**] Merge `TypeParam` and `TypeArg` into one `Term` type in Rust ([#2309](https://github.com/CQCL/hugr/pull/2309)) +- Add `MakeError` op ([#2377](https://github.com/CQCL/hugr/pull/2377)) + ## [0.20.2](https://github.com/CQCL/hugr/compare/hugr-llvm-v0.20.1...hugr-llvm-v0.20.2) - 2025-06-25 ### New Features diff --git a/hugr-llvm/Cargo.toml b/hugr-llvm/Cargo.toml index 31bd533960..0a906686db 100644 --- a/hugr-llvm/Cargo.toml +++ b/hugr-llvm/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hugr-llvm" -version = "0.20.2" +version = "0.22.1" description = "A general and extensible crate for lowering HUGRs into LLVM IR" edition.workspace = true @@ -26,8 +26,8 @@ workspace = true [dependencies] inkwell = { version = "0.6.0", default-features = false } -hugr-core = { path = "../hugr-core", version = "0.20.2" } -anyhow = "1.0.98" +hugr-core = { path = "../hugr-core", version = "0.22.1" } +anyhow.workspace = true itertools.workspace = true delegate.workspace = true petgraph.workspace = true diff --git a/hugr-llvm/src/emit/func.rs b/hugr-llvm/src/emit/func.rs index 3c5eed2ef1..77d865540f 100644 --- a/hugr-llvm/src/emit/func.rs +++ b/hugr-llvm/src/emit/func.rs @@ -109,11 +109,6 @@ impl<'c, 'a, H: HugrView> EmitFuncContext<'c, 'a, H> { self.todo.insert(node.node()); } - /// Returns the current [FunctionValue] being emitted. - pub fn func(&self) -> FunctionValue<'c> { - self.func - } - /// Returns the internal [Builder]. Callers must ensure that it is /// positioned at the end of a basic block. This invariant is not checked(it /// doesn't seem possible to check it). 
diff --git a/hugr-llvm/src/emit/ops/cfg.rs b/hugr-llvm/src/emit/ops/cfg.rs index bdf27389a2..790110fe2a 100644 --- a/hugr-llvm/src/emit/ops/cfg.rs +++ b/hugr-llvm/src/emit/ops/cfg.rs @@ -217,7 +217,7 @@ impl<'c, 'hugr, H: HugrView> CfgEmitter<'c, 'hugr, H> { #[cfg(test)] mod test { - use hugr_core::builder::{Dataflow, DataflowSubContainer, SubContainer}; + use hugr_core::builder::{Dataflow, DataflowHugr, SubContainer}; use hugr_core::extension::ExtensionRegistry; use hugr_core::extension::prelude::{self, bool_t}; use hugr_core::ops::Value; @@ -279,7 +279,7 @@ mod test { cfg_builder.branch(&b1, 1, &exit_block).unwrap(); let cfg = cfg_builder.finish_sub_container().unwrap(); let [cfg_out] = cfg.outputs_arr(); - builder.finish_with_outputs([cfg_out]).unwrap() + builder.finish_hugr_with_outputs([cfg_out]).unwrap() }); llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); check_emission!(hugr, llvm_ctx); @@ -395,7 +395,7 @@ mod test { .unwrap() .outputs_arr() }; - builder.finish_with_outputs([outer_cfg_out]).unwrap() + builder.finish_hugr_with_outputs([outer_cfg_out]).unwrap() }); check_emission!(hugr, llvm_ctx); } diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@llvm14.snap index 9ea0d09e8d..d673a4b73e 100644 --- a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@llvm14.snap +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@llvm14.snap @@ -13,21 +13,12 @@ entry_block: ; preds = %alloca_block br label %0 0: ; preds = %entry_block - %1 = call i1 @_hl.scoped_func.7() - switch i1 false, label %2 [ + switch i1 false, label %1 [ ] -2: ; preds = %0 - br label %3 +1: ; preds = %0 + br label %2 -3: ; preds = %2 - ret i1 %1 -} - -define i1 @_hl.scoped_func.7() { -alloca_block: - br label %entry_block - -entry_block: ; preds = %alloca_block +2: ; preds = %1 ret i1 false } diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@pre-mem2reg@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@pre-mem2reg@llvm14.snap index c38ac33f4d..025b85a9ac 100644 --- a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@pre-mem2reg@llvm14.snap @@ -10,31 +10,30 @@ alloca_block: %"0" = alloca i1, align 1 %"4_0" = alloca i1, align 1 %"01" = alloca i1, align 1 - %"15_0" = alloca {}, align 8 - %"16_0" = alloca i1, align 1 + %"11_0" = alloca {}, align 8 + %"12_0" = alloca i1, align 1 br label %entry_block entry_block: ; preds = %alloca_block br label %0 0: ; preds = %entry_block - %1 = call i1 @_hl.scoped_func.7() - store i1 %1, i1* %"16_0", align 1 - store {} undef, {}* %"15_0", align 1 - %"15_02" = load {}, {}* %"15_0", align 1 - %"16_03" = load i1, i1* %"16_0", align 1 - store {} %"15_02", {}* %"15_0", align 1 - store i1 %"16_03", i1* %"16_0", align 1 - %"15_04" = load {}, {}* %"15_0", align 1 - %"16_05" = load i1, i1* %"16_0", align 1 - switch i1 false, label %2 [ + store i1 false, i1* %"12_0", align 1 + store {} undef, {}* %"11_0", align 1 + %"11_02" = load {}, {}* %"11_0", align 1 + %"12_03" = load i1, i1* %"12_0", align 1 + store {} %"11_02", {}* %"11_0", align 1 + store i1 %"12_03", i1* %"12_0", align 1 + %"11_04" = load {}, {}* %"11_0", align 1 + %"12_05" 
= load i1, i1* %"12_0", align 1 + switch i1 false, label %1 [ ] -2: ; preds = %0 - store i1 %"16_05", i1* %"01", align 1 - br label %3 +1: ; preds = %0 + store i1 %"12_05", i1* %"01", align 1 + br label %2 -3: ; preds = %2 +2: ; preds = %1 %"06" = load i1, i1* %"01", align 1 store i1 %"06", i1* %"4_0", align 1 %"4_07" = load i1, i1* %"4_0", align 1 @@ -42,17 +41,3 @@ entry_block: ; preds = %alloca_block %"08" = load i1, i1* %"0", align 1 ret i1 %"08" } - -define i1 @_hl.scoped_func.7() { -alloca_block: - %"0" = alloca i1, align 1 - %"10_0" = alloca i1, align 1 - br label %entry_block - -entry_block: ; preds = %alloca_block - store i1 false, i1* %"10_0", align 1 - %"10_01" = load i1, i1* %"10_0", align 1 - store i1 %"10_01", i1* %"0", align 1 - %"02" = load i1, i1* %"0", align 1 - ret i1 %"02" -} diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@llvm14.snap deleted file mode 100644 index ea9074b87b..0000000000 --- a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@llvm14.snap +++ /dev/null @@ -1,23 +0,0 @@ ---- -source: hugr-llvm/src/emit/test.rs -expression: mod_str ---- -; ModuleID = 'test_context' -source_filename = "test_context" - -define i1 @_hl.main.1() { -alloca_block: - br label %entry_block - -entry_block: ; preds = %alloca_block - %0 = call i1 @_hl.scoped_func.8() - ret i1 %0 -} - -define i1 @_hl.scoped_func.8() { -alloca_block: - br label %entry_block - -entry_block: ; preds = %alloca_block - ret i1 false -} diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@pre-mem2reg@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@pre-mem2reg@llvm14.snap deleted file mode 100644 index f990db641b..0000000000 --- a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@pre-mem2reg@llvm14.snap +++ /dev/null @@ -1,38 +0,0 @@ ---- -source: hugr-llvm/src/emit/test.rs -expression: mod_str ---- -; ModuleID = 'test_context' -source_filename = "test_context" - -define i1 @_hl.main.1() { -alloca_block: - %"0" = alloca i1, align 1 - %"4_0" = alloca i1, align 1 - %"12_0" = alloca i1, align 1 - br label %entry_block - -entry_block: ; preds = %alloca_block - %0 = call i1 @_hl.scoped_func.8() - store i1 %0, i1* %"12_0", align 1 - %"12_01" = load i1, i1* %"12_0", align 1 - store i1 %"12_01", i1* %"4_0", align 1 - %"4_02" = load i1, i1* %"4_0", align 1 - store i1 %"4_02", i1* %"0", align 1 - %"03" = load i1, i1* %"0", align 1 - ret i1 %"03" -} - -define i1 @_hl.scoped_func.8() { -alloca_block: - %"0" = alloca i1, align 1 - %"11_0" = alloca i1, align 1 - br label %entry_block - -entry_block: ; preds = %alloca_block - store i1 false, i1* %"11_0", align 1 - %"11_01" = load i1, i1* %"11_0", align 1 - store i1 %"11_01", i1* %"0", align 1 - %"02" = load i1, i1* %"0", align 1 - ret i1 %"02" -} diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index d5194bf47b..d79ac361cd 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -1,9 +1,7 @@ use crate::types::HugrFuncType; use crate::utils::fat::FatNode; use anyhow::{Result, anyhow}; -use hugr_core::builder::{ - BuildHandle, Container, DFGWrapper, HugrBuilder, ModuleBuilder, SubContainer, -}; +use hugr_core::builder::{BuildHandle, DFGWrapper, FunctionBuilder}; use hugr_core::extension::ExtensionRegistry; use 
hugr_core::ops::handle::FuncID; use hugr_core::types::TypeRow; @@ -15,7 +13,7 @@ use inkwell::values::GenericValue; use super::EmitHugr; #[allow(clippy::upper_case_acronyms)] -pub type DFGW<'a> = DFGWrapper<&'a mut Hugr, BuildHandle<FuncID<true>>>; +pub type DFGW = DFGWrapper<Hugr, BuildHandle<FuncID<true>>>; pub struct SimpleHugrConfig { ins: TypeRow, @@ -131,31 +129,13 @@ impl SimpleHugrConfig { self } - pub fn finish( - self, - make: impl for<'a> FnOnce(DFGW<'a>) -> <DFGW<'a> as SubContainer>::ContainerHandle, - ) -> Hugr { + pub fn finish(self, make: impl FnOnce(DFGW) -> Hugr) -> Hugr { self.finish_with_exts(|builder, _| make(builder)) } - pub fn finish_with_exts( - self, - make: impl for<'a> FnOnce( - DFGW<'a>, - &ExtensionRegistry, - ) -> <DFGW<'a> as SubContainer>::ContainerHandle, - ) -> Hugr { - let mut mod_b = ModuleBuilder::new(); - let func_b = mod_b - .define_function("main", HugrFuncType::new(self.ins, self.outs)) - .unwrap(); - make(func_b, &self.extensions); - - // Intentionally left as a debugging aid. If the HUGR you construct - // fails validation, uncomment the following line to print it out - // unvalidated. - // println!("{}", mod_b.hugr().mermaid_string()); - mod_b.finish_hugr().unwrap_or_else(|e| panic!("{e}")) + pub fn finish_with_exts(self, make: impl FnOnce(DFGW, &ExtensionRegistry) -> Hugr) -> Hugr { + let func_b = FunctionBuilder::new("main", HugrFuncType::new(self.ins, self.outs)).unwrap(); + make(func_b, &self.extensions) } } @@ -187,11 +167,7 @@ pub use insta; macro_rules! check_emission { // Call the macro with a snapshot name. ($snapshot_name:expr, $hugr: ident, $test_ctx:ident) => {{ - let root = - $crate::utils::fat::FatExt::fat_root::<$crate::emit::test::hugr_core::ops::Module>( - &$hugr, - ) - .unwrap(); + let root = $crate::utils::fat::FatExt::fat_root(&$hugr).unwrap(); let emission = $crate::emit::test::Emission::emit_hugr(root, $test_ctx.get_emit_hugr()).unwrap(); @@ -237,8 +213,8 @@ mod test_fns { use crate::custom::CodegenExtsBuilder; use crate::types::{HugrFuncType, HugrSumType}; - use hugr_core::builder::DataflowSubContainer; use hugr_core::builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}; + use hugr_core::builder::{DataflowHugr, DataflowSubContainer}; use hugr_core::extension::PRELUDE_REGISTRY; use hugr_core::extension::prelude::{ConstUsize, bool_t, usize_t}; use hugr_core::ops::constant::CustomConst; @@ -266,7 +242,7 @@ mod test_fns { builder.input_wires(), ) .unwrap(); - builder.finish_with_outputs(tag.outputs()).unwrap() + builder.finish_hugr_with_outputs(tag.outputs()).unwrap() }); let _ = check_emission!(hugr, llvm_ctx); } @@ -284,7 +260,7 @@ mod test_fns { let w = b.input_wires(); b.finish_with_outputs(w).unwrap() }; - builder.finish_with_outputs(dfg.outputs()).unwrap() + builder.finish_hugr_with_outputs(dfg.outputs()).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -329,7 +305,7 @@ mod test_fns { cond_b.finish_sub_container().unwrap() }; let [o1, o2] = cond.outputs_arr(); - builder.finish_with_outputs([o1, o2]).unwrap() + builder.finish_hugr_with_outputs([o1, o2]).unwrap() }) }; check_emission!(hugr, llvm_ctx); @@ -349,7 +325,7 @@ mod test_fns { .with_extensions(STD_REG.to_owned()) .finish(|mut builder: DFGW| { let konst = builder.add_load_value(v); - builder.finish_with_outputs([konst]).unwrap() + builder.finish_hugr_with_outputs([konst]).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -411,7 +387,7 @@ mod test_fns { .instantiate_extension_op("iadd", [4.into()]) .unwrap(); let add = builder.add_dataflow_op(ext_op, [k1, k2]).unwrap(); - 
builder.finish_with_outputs(add.outputs()).unwrap() + builder.finish_hugr_with_outputs(add.outputs()).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -453,34 +429,6 @@ mod test_fns { check_emission!(hugr, llvm_ctx); } - #[rstest] - fn diverse_dfg_children(llvm_ctx: TestContext) { - let hugr = SimpleHugrConfig::new() - .with_outs(bool_t()) - .finish(|mut builder: DFGW| { - let [r] = { - let mut builder = builder - .dfg_builder(HugrFuncType::new(type_row![], bool_t()), []) - .unwrap(); - let konst = builder.add_constant(Value::false_val()); - let func = { - let mut builder = builder - .define_function( - "scoped_func", - HugrFuncType::new(type_row![], bool_t()), - ) - .unwrap(); - let w = builder.load_const(&konst); - builder.finish_with_outputs([w]).unwrap() - }; - let [r] = builder.call(func.handle(), &[], []).unwrap().outputs_arr(); - builder.finish_with_outputs([r]).unwrap().outputs_arr() - }; - builder.finish_with_outputs([r]).unwrap() - }); - check_emission!(hugr, llvm_ctx); - } - #[rstest] fn diverse_cfg_children(llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() @@ -489,29 +437,19 @@ mod test_fns { let [r] = { let mut builder = builder.cfg_builder([], vec![bool_t()].into()).unwrap(); let konst = builder.add_constant(Value::false_val()); - let func = { - let mut builder = builder - .define_function( - "scoped_func", - HugrFuncType::new(type_row![], bool_t()), - ) - .unwrap(); - let w = builder.load_const(&konst); - builder.finish_with_outputs([w]).unwrap() - }; let entry = { let mut builder = builder .entry_builder([type_row![]], vec![bool_t()].into()) .unwrap(); let control = builder.add_load_value(Value::unary_unit_sum()); - let [r] = builder.call(func.handle(), &[], []).unwrap().outputs_arr(); + let r = builder.load_const(&konst); builder.finish_with_outputs(control, [r]).unwrap() }; let exit = builder.exit_block(); builder.branch(&entry, 0, &exit).unwrap(); builder.finish_sub_container().unwrap().outputs_arr() }; - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -575,7 +513,7 @@ mod test_fns { .finish_with_outputs(sum_inp_w, []) .unwrap() .outputs_arr(); - builder.finish_with_outputs(outs).unwrap() + builder.finish_hugr_with_outputs(outs).unwrap() }) }; llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); @@ -696,7 +634,7 @@ mod test_fns { }; let [out_int] = tail_l.outputs_arr(); builder - .finish_with_outputs([out_int]) + .finish_hugr_with_outputs([out_int]) .unwrap_or_else(|e| panic!("{e}")) }) } @@ -731,7 +669,7 @@ mod test_fns { .with_extensions(PRELUDE_REGISTRY.to_owned()) .finish(|mut builder: DFGW| { let konst = builder.add_load_value(ConstUsize::new(42)); - builder.finish_with_outputs([konst]).unwrap() + builder.finish_hugr_with_outputs([konst]).unwrap() }); exec_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); assert_eq!(42, exec_ctx.exec_hugr_u64(hugr, "main")); diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index 725fbe2724..da5141d72f 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -214,7 +214,7 @@ impl CodegenExtension for ArrayCodegenExtension { .custom_type((array::EXTENSION_ID, array::ARRAY_TYPENAME), { let ccg = self.0.clone(); move |ts, hugr_type| { - let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = hugr_type.args() else { + let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = hugr_type.args() 
else { return Err(anyhow!("Invalid type args for array type")); }; let elem_ty = ts.llvm_type(ty)?; @@ -908,7 +908,7 @@ pub fn emit_scan_op<'c, H: HugrView>( #[cfg(test)] mod test { - use hugr_core::builder::Container as _; + use hugr_core::builder::{DataflowHugr, HugrBuilder}; use hugr_core::extension::prelude::either_type; use hugr_core::ops::Tag; use hugr_core::std_extensions::STD_REG; @@ -952,7 +952,7 @@ mod test { build_all_array_ops(builder.dfg_builder_endo([]).unwrap()) .finish_sub_container() .unwrap(); - builder.finish_sub_container().unwrap() + builder.finish_hugr().unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -971,7 +971,7 @@ mod test { let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); let (_, arr) = builder.add_array_get(usize_t(), 2, arr, us1).unwrap(); builder.add_array_discard(usize_t(), 2, arr).unwrap(); - builder.finish_with_outputs([]).unwrap() + builder.finish_hugr_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -991,7 +991,7 @@ mod test { let (arr1, arr2) = builder.add_array_clone(usize_t(), 2, arr).unwrap(); builder.add_array_discard(usize_t(), 2, arr1).unwrap(); builder.add_array_discard(usize_t(), 2, arr2).unwrap(); - builder.finish_with_outputs([]).unwrap() + builder.finish_hugr_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1008,7 +1008,7 @@ mod test { .finish(|mut builder| { let vs = vec![ConstUsize::new(1).into(), ConstUsize::new(2).into()]; let arr = builder.add_load_value(array::ArrayValue::new(usize_t(), vs)); - builder.finish_with_outputs([arr]).unwrap() + builder.finish_hugr_with_outputs([arr]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1102,7 +1102,7 @@ mod test { } builder.finish_sub_container().unwrap().out_wire(0) }; - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1207,7 +1207,7 @@ mod test { .unwrap(); conditional.finish_sub_container().unwrap().out_wire(0) }; - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1314,7 +1314,7 @@ mod test { conditional.finish_sub_container().unwrap().out_wire(0) }; builder.add_array_discard(int_ty.clone(), 2, arr).unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1371,7 +1371,7 @@ mod test { builder .add_array_discard(int_ty.clone(), 2, arr_clone) .unwrap(); - builder.finish_with_outputs([elem]).unwrap() + builder.finish_hugr_with_outputs([elem]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1441,7 +1441,7 @@ mod test { arr, ) .unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1486,7 +1486,7 @@ mod test { r = builder.add_iadd(6, r, elem).unwrap(); } - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1518,7 +1518,8 @@ mod test { .with_outs(int_ty.clone()) .with_extensions(exec_registry()) .finish(|mut builder| { - let mut func = builder + let mut mb = 
builder.module_root_builder(); + let mut func = mb .define_function("foo", Signature::new(vec![], vec![int_ty.clone()])) .unwrap(); let v = func.add_load_value(ConstInt::new_u(6, value).unwrap()); @@ -1539,7 +1540,7 @@ mod test { builder .add_array_discard(int_ty.clone(), size, arr) .unwrap(); - builder.finish_with_outputs([elem]).unwrap() + builder.finish_hugr_with_outputs([elem]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1572,7 +1573,8 @@ mod test { .add_new_array(int_ty.clone(), new_array_args) .unwrap(); - let mut func = builder + let mut mb = builder.module_root_builder(); + let mut func = mb .define_function( "foo", Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]), @@ -1610,7 +1612,7 @@ mod test { builder .add_array_discard_empty(int_ty.clone(), arr) .unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1642,7 +1644,8 @@ mod test { .add_new_array(int_ty.clone(), new_array_args) .unwrap(); - let mut func = builder + let mut mb = builder.module_root_builder(); + let mut func = mb .define_function( "foo", Signature::new( @@ -1666,7 +1669,7 @@ mod test { .unwrap() .outputs_arr(); builder.add_array_discard(Type::UNIT, size, arr).unwrap(); - builder.finish_with_outputs([sum]).unwrap() + builder.finish_hugr_with_outputs([sum]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() diff --git a/hugr-llvm/src/extension/collections/list.rs b/hugr-llvm/src/extension/collections/list.rs index e1ff76e2a8..746ce11bc0 100644 --- a/hugr-llvm/src/extension/collections/list.rs +++ b/hugr-llvm/src/extension/collections/list.rs @@ -203,7 +203,7 @@ fn emit_list_op<'c, H: HugrView>( op: ListOp, ) -> Result<()> { let hugr_elem_ty = match args.node().args() { - [TypeArg::Type { ty }] => ty.clone(), + [TypeArg::Runtime(ty)] => ty.clone(), _ => { bail!("Collections: invalid type args for list op"); } @@ -366,7 +366,7 @@ fn build_load_i8_ptr<'c, H: HugrView>( #[cfg(test)] mod test { use hugr_core::{ - builder::{Dataflow, DataflowSubContainer}, + builder::{Dataflow, DataflowHugr}, extension::{ ExtensionRegistry, prelude::{self, ConstUsize, qb_t, usize_t}, @@ -407,7 +407,7 @@ mod test { .add_dataflow_op(ext_op, hugr_builder.input_wires()) .unwrap() .outputs(); - hugr_builder.finish_with_outputs(outputs).unwrap() + hugr_builder.finish_hugr_with_outputs(outputs).unwrap() }); llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_list_extensions); @@ -427,7 +427,7 @@ mod test { .with_extensions(es) .finish(|mut hugr_builder| { let list = hugr_builder.add_load_value(ListValue::new(elem_ty, contents)); - hugr_builder.finish_with_outputs(vec![list]).unwrap() + hugr_builder.finish_hugr_with_outputs(vec![list]).unwrap() }); llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap index 1af774422e..ad17a2c59f 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap +++ 
b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap @@ -15,14 +15,14 @@ source_filename = "test_context" @sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } @sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } @sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.c4a5911a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } +@sa.outer.e55b610a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } define i64 @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.c4a5911a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), i32 0, i32 0 + %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.e55b610a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), i32 0, i32 0 %1 = load i64, i64* %0, align 4 ret i64 %1 } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap index be8b63018c..b0f0741226 100644 
--- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap @@ -15,7 +15,7 @@ source_filename = "test_context" @sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } @sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } @sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.c4a5911a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } +@sa.outer.e55b610a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } define i64 @_hl.main.1() { alloca_block: @@ -25,7 +25,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.c4a5911a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 + store { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.e55b610a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 %"5_01" = load { i64, [0 x { i64, [0 x i64] }*] }*, { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* %"5_01", i32 0, i32 0 %1 = load 
i64, i64* %0, align 4 diff --git a/hugr-llvm/src/extension/collections/stack_array.rs b/hugr-llvm/src/extension/collections/stack_array.rs index 9361a298cd..297f539511 100644 --- a/hugr-llvm/src/extension/collections/stack_array.rs +++ b/hugr-llvm/src/extension/collections/stack_array.rs @@ -126,7 +126,7 @@ impl CodegenExtension for ArrayCodegenExtension { .custom_type((array::EXTENSION_ID, array::ARRAY_TYPENAME), { let ccg = self.0.clone(); move |ts, hugr_type| { - let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = hugr_type.args() else { + let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = hugr_type.args() else { return Err(anyhow!("Invalid type args for array type")); }; let elem_ty = ts.llvm_type(ty)?; @@ -726,7 +726,7 @@ fn emit_scan_op<'c, H: HugrView>( #[cfg(test)] mod test { - use hugr_core::builder::Container as _; + use hugr_core::builder::{DataflowHugr as _, HugrBuilder}; use hugr_core::extension::prelude::either_type; use hugr_core::ops::Tag; use hugr_core::std_extensions::STD_REG; @@ -770,7 +770,7 @@ mod test { build_all_array_ops(builder.dfg_builder_endo([]).unwrap()) .finish_sub_container() .unwrap(); - builder.finish_sub_container().unwrap() + builder.finish_hugr().unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -789,7 +789,7 @@ mod test { let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); let (_, arr) = builder.add_array_get(usize_t(), 2, arr, us1).unwrap(); builder.add_array_discard(usize_t(), 2, arr).unwrap(); - builder.finish_with_outputs([]).unwrap() + builder.finish_hugr_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -809,7 +809,7 @@ mod test { let (arr1, arr2) = builder.add_array_clone(usize_t(), 2, arr).unwrap(); builder.add_array_discard(usize_t(), 2, arr1).unwrap(); builder.add_array_discard(usize_t(), 2, arr2).unwrap(); - builder.finish_with_outputs([]).unwrap() + builder.finish_hugr_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -826,7 +826,7 @@ mod test { .finish(|mut builder| { let vs = vec![ConstUsize::new(1).into(), ConstUsize::new(2).into()]; let arr = builder.add_load_value(array::ArrayValue::new(usize_t(), vs)); - builder.finish_with_outputs([arr]).unwrap() + builder.finish_hugr_with_outputs([arr]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -885,7 +885,7 @@ mod test { } builder.finish_sub_container().unwrap().out_wire(0) }; - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -990,7 +990,7 @@ mod test { .unwrap(); conditional.finish_sub_container().unwrap().out_wire(0) }; - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1097,7 +1097,7 @@ mod test { conditional.finish_sub_container().unwrap().out_wire(0) }; builder.add_array_discard(int_ty.clone(), 2, arr).unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1154,7 +1154,7 @@ mod test { builder .add_array_discard(int_ty.clone(), 2, arr_clone) .unwrap(); - builder.finish_with_outputs([elem]).unwrap() + builder.finish_hugr_with_outputs([elem]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1224,7 +1224,7 @@ 
mod test { arr, ) .unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1269,7 +1269,7 @@ mod test { r = builder.add_iadd(6, r, elem).unwrap(); } - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1301,7 +1301,8 @@ mod test { .with_outs(int_ty.clone()) .with_extensions(exec_registry()) .finish(|mut builder| { - let mut func = builder + let mut mb = builder.module_root_builder(); + let mut func = mb .define_function("foo", Signature::new(vec![], vec![int_ty.clone()])) .unwrap(); let v = func.add_load_value(ConstInt::new_u(6, value).unwrap()); @@ -1322,7 +1323,7 @@ mod test { builder .add_array_discard(int_ty.clone(), size, arr) .unwrap(); - builder.finish_with_outputs([elem]).unwrap() + builder.finish_hugr_with_outputs([elem]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1355,7 +1356,8 @@ mod test { .add_new_array(int_ty.clone(), new_array_args) .unwrap(); - let mut func = builder + let mut mb = builder.module_root_builder(); + let mut func = mb .define_function( "foo", Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]), @@ -1393,7 +1395,7 @@ mod test { builder .add_array_discard_empty(int_ty.clone(), arr) .unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1412,7 +1414,6 @@ mod test { // We build a HUGR that: // - Creates an array [1, 2, 3, ..., size] // - Sums up the elements of the array using a scan and returns that sum - let int_ty = int_type(6); let hugr = SimpleHugrConfig::new() .with_outs(int_ty.clone()) @@ -1425,7 +1426,8 @@ mod test { .add_new_array(int_ty.clone(), new_array_args) .unwrap(); - let mut func = builder + let mut mb = builder.module_root_builder(); + let mut func = mb .define_function( "foo", Signature::new( @@ -1449,7 +1451,7 @@ mod test { .unwrap() .outputs_arr(); builder.add_array_discard(Type::UNIT, size, arr).unwrap(); - builder.finish_with_outputs([sum]).unwrap() + builder.finish_hugr_with_outputs([sum]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() diff --git a/hugr-llvm/src/extension/collections/static_array.rs b/hugr-llvm/src/extension/collections/static_array.rs index e9df520ee3..50ac99b723 100644 --- a/hugr-llvm/src/extension/collections/static_array.rs +++ b/hugr-llvm/src/extension/collections/static_array.rs @@ -370,7 +370,7 @@ impl CodegenExtension for StaticArrayCodegenE let sac = self.0.clone(); move |ts, custom_type| { let element_type = custom_type.args()[0] - .as_type() + .as_runtime() .expect("Type argument for static array must be a type"); sac.static_array_type(ts, &element_type) } @@ -394,6 +394,7 @@ impl CodegenExtension for StaticArrayCodegenE mod test { use super::*; use float_types::float64_type; + use hugr_core::builder::DataflowHugr; use hugr_core::extension::prelude::ConstUsize; use hugr_core::ops::OpType; use hugr_core::ops::Value; @@ -459,7 +460,7 @@ mod test { ])) .finish(|mut builder| { let a = builder.add_load_value(value); - builder.finish_with_outputs([a]).unwrap() + builder.finish_hugr_with_outputs([a]).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -512,7 +513,7 @@ mod test { } cond.finish_sub_container().unwrap().outputs_arr() }; - 
builder.finish_with_outputs([out]).unwrap() + builder.finish_hugr_with_outputs([out]).unwrap() }); exec_ctx.add_extensions(|ceb| { @@ -534,7 +535,7 @@ mod test { let arr = builder .add_load_value(StaticArrayValue::try_new("empty", usize_t(), vec![]).unwrap()); let len = builder.add_static_array_len(usize_t(), arr).unwrap(); - builder.finish_with_outputs([len]).unwrap() + builder.finish_hugr_with_outputs([len]).unwrap() }); exec_ctx.add_extensions(|ceb| { @@ -574,7 +575,7 @@ mod test { let len = builder .add_static_array_len(inner_arr_ty, outer_arr) .unwrap(); - builder.finish_with_outputs([len]).unwrap() + builder.finish_hugr_with_outputs([len]).unwrap() }); check_emission!(hugr, llvm_ctx); } diff --git a/hugr-llvm/src/extension/conversions.rs b/hugr-llvm/src/extension/conversions.rs index cbc036719b..0ed8ec88c2 100644 --- a/hugr-llvm/src/extension/conversions.rs +++ b/hugr-llvm/src/extension/conversions.rs @@ -275,7 +275,7 @@ mod test { use crate::check_emission; use crate::emit::test::{DFGW, SimpleHugrConfig}; use crate::test::{TestContext, exec_ctx, llvm_ctx}; - use hugr_core::builder::SubContainer; + use hugr_core::builder::{DataflowHugr, SubContainer}; use hugr_core::std_extensions::STD_REG; use hugr_core::std_extensions::arithmetic::float_types::ConstF64; use hugr_core::std_extensions::arithmetic::int_types::ConstInt; @@ -311,7 +311,7 @@ mod test { .add_dataflow_op(ext_op, [in1]) .unwrap() .outputs(); - hugr_builder.finish_with_outputs(outputs).unwrap() + hugr_builder.finish_hugr_with_outputs(outputs).unwrap() }) } @@ -381,7 +381,7 @@ mod test { .add_dataflow_op(ext_op, [in1]) .unwrap() .outputs_arr(); - hugr_builder.finish_with_outputs([out1]).unwrap() + hugr_builder.finish_hugr_with_outputs([out1]).unwrap() }); check_emission!(op_name, hugr, llvm_ctx); } @@ -393,7 +393,7 @@ mod test { .with_extensions(PRELUDE_REGISTRY.to_owned()) .finish(|mut builder: DFGW| { let konst = builder.add_load_value(ConstUsize::new(42)); - builder.finish_with_outputs([konst]).unwrap() + builder.finish_hugr_with_outputs([konst]).unwrap() }); exec_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); assert_eq!(42, exec_ctx.exec_hugr_u64(hugr, "main")); @@ -417,7 +417,7 @@ mod test { .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [int]) .unwrap() .outputs_arr(); - builder.finish_with_outputs([usize_]).unwrap() + builder.finish_hugr_with_outputs([usize_]).unwrap() }); exec_ctx.add_extensions(|builder| { builder @@ -481,7 +481,7 @@ mod test { .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [cond_result]) .unwrap() .outputs_arr(); - builder.finish_with_outputs([usize_]).unwrap() + builder.finish_hugr_with_outputs([usize_]).unwrap() }) } @@ -613,7 +613,7 @@ mod test { let true_result = case_true.add_load_value(ConstUsize::new(6)); case_true.finish_with_outputs([true_result]).unwrap(); let res = cond.finish_sub_container().unwrap(); - builder.finish_with_outputs(res.outputs()).unwrap() + builder.finish_hugr_with_outputs(res.outputs()).unwrap() }); exec_ctx.add_extensions(|builder| { builder @@ -635,7 +635,7 @@ mod test { let [b] = builder.add_dataflow_op(i2b, [i]).unwrap().outputs_arr(); let b2i = EXTENSION.instantiate_extension_op("ifrombool", []).unwrap(); let [i] = builder.add_dataflow_op(b2i, [b]).unwrap().outputs_arr(); - builder.finish_with_outputs([i]).unwrap() + builder.finish_hugr_with_outputs([i]).unwrap() }); exec_ctx.add_extensions(|builder| { builder @@ -663,7 +663,7 @@ mod test { .instantiate_extension_op("bytecast_int64_to_float64", []) .unwrap(); let 
[f] = builder.add_dataflow_op(i2f, [i]).unwrap().outputs_arr(); - builder.finish_with_outputs([f]).unwrap() + builder.finish_hugr_with_outputs([f]).unwrap() }); exec_ctx.add_extensions(|builder| { builder @@ -690,7 +690,7 @@ mod test { .instantiate_extension_op("bytecast_float64_to_int64", []) .unwrap(); let [i] = builder.add_dataflow_op(f2i, [f]).unwrap().outputs_arr(); - builder.finish_with_outputs([i]).unwrap() + builder.finish_hugr_with_outputs([i]).unwrap() }); exec_ctx.add_extensions(|builder| { builder diff --git a/hugr-llvm/src/extension/float.rs b/hugr-llvm/src/extension/float.rs index b95a698b18..968ae3f585 100644 --- a/hugr-llvm/src/extension/float.rs +++ b/hugr-llvm/src/extension/float.rs @@ -149,13 +149,14 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { #[cfg(test)] mod test { use hugr_core::Hugr; + use hugr_core::builder::DataflowHugr; use hugr_core::extension::SignatureFunc; use hugr_core::extension::simple_op::MakeOpDef; use hugr_core::std_extensions::STD_REG; use hugr_core::std_extensions::arithmetic::float_ops::FloatOps; use hugr_core::types::TypeRow; use hugr_core::{ - builder::{Dataflow, DataflowSubContainer}, + builder::Dataflow, std_extensions::arithmetic::float_types::{ConstF64, float64_type}, }; use rstest::rstest; @@ -184,7 +185,7 @@ mod test { .add_dataflow_op(op, builder.input_wires()) .unwrap() .outputs(); - builder.finish_with_outputs(outputs).unwrap() + builder.finish_hugr_with_outputs(outputs).unwrap() }) } @@ -196,7 +197,7 @@ mod test { .with_extensions(STD_REG.to_owned()) .finish(|mut builder| { let c = builder.add_load_value(ConstF64::new(3.12)); - builder.finish_with_outputs([c]).unwrap() + builder.finish_hugr_with_outputs([c]).unwrap() }); check_emission!(hugr, llvm_ctx); } diff --git a/hugr-llvm/src/extension/int.rs b/hugr-llvm/src/extension/int.rs index 315c7c7296..bea508d774 100644 --- a/hugr-llvm/src/extension/int.rs +++ b/hugr-llvm/src/extension/int.rs @@ -668,7 +668,7 @@ fn emit_int_op<'c, H: HugrView>( ]) }), IntOpDef::inarrow_s => { - let Some(TypeArg::BoundedNat { n: out_log_width }) = args.node().args().last().cloned() + let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned() else { bail!("Type arg to inarrow_s wasn't a Nat"); }; @@ -686,7 +686,7 @@ fn emit_int_op<'c, H: HugrView>( }) } IntOpDef::inarrow_u => { - let Some(TypeArg::BoundedNat { n: out_log_width }) = args.node().args().last().cloned() + let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned() else { bail!("Type arg to inarrow_u wasn't a Nat"); }; @@ -756,7 +756,7 @@ pub(crate) fn get_width_arg>( args: &EmitOpArgs<'_, '_, ExtensionOp, H>, op: &impl MakeExtensionOp, ) -> Result { - let [TypeArg::BoundedNat { n: log_width }] = args.node.args() else { + let [TypeArg::BoundedNat(log_width)] = args.node.args() else { bail!( "Expected exactly one BoundedNat parameter to {}", op.op_id() @@ -1094,7 +1094,7 @@ fn llvm_type<'c>( context: TypingSession<'c, '_>, hugr_type: &CustomType, ) -> Result> { - if let [TypeArg::BoundedNat { n }] = hugr_type.args() { + if let [TypeArg::BoundedNat(n)] = hugr_type.args() { let m = *n as usize; if m < int_types::INT_TYPES.len() && int_types::INT_TYPES[m] == hugr_type.clone().into() { return Ok(match m { @@ -1141,6 +1141,7 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { #[cfg(test)] mod test { use anyhow::Result; + use hugr_core::builder::DataflowHugr; use hugr_core::extension::prelude::{ConstError, UnwrapBuilder, error_type}; use hugr_core::std_extensions::STD_REG; use hugr_core::{ @@ 
-1242,7 +1243,9 @@ mod test { .unwrap() .outputs(); let processed_outputs = process(&mut hugr_builder, outputs).unwrap(); - hugr_builder.finish_with_outputs(processed_outputs).unwrap() + hugr_builder + .finish_hugr_with_outputs(processed_outputs) + .unwrap() }) } @@ -1578,7 +1581,7 @@ mod test { .add_dataflow_op(iu_to_s, [unsigned]) .unwrap() .outputs_arr(); - hugr_builder.finish_with_outputs([signed]).unwrap() + hugr_builder.finish_hugr_with_outputs([signed]).unwrap() }); let act = int_exec_ctx.exec_hugr_i64(hugr, "main"); assert_eq!(act, val as i64); @@ -1605,7 +1608,7 @@ mod test { .add_dataflow_op(make_int_op("iadd", log_width), [unsigned, num]) .unwrap() .outputs_arr(); - hugr_builder.finish_with_outputs([res]).unwrap() + hugr_builder.finish_hugr_with_outputs([res]).unwrap() }); let act = int_exec_ctx.exec_hugr_u64(hugr, "main"); assert_eq!(act, (val as u64) + 42); diff --git a/hugr-llvm/src/extension/logic.rs b/hugr-llvm/src/extension/logic.rs index 50dd2bd17c..b382a21408 100644 --- a/hugr-llvm/src/extension/logic.rs +++ b/hugr-llvm/src/extension/logic.rs @@ -76,7 +76,7 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { mod test { use hugr_core::{ Hugr, - builder::{Dataflow, DataflowSubContainer}, + builder::{Dataflow, DataflowHugr}, extension::{ExtensionRegistry, prelude::bool_t}, std_extensions::logic::{self, LogicOp}, }; @@ -99,7 +99,7 @@ mod test { .add_dataflow_op(op, builder.input_wires()) .unwrap() .outputs(); - builder.finish_with_outputs(outputs).unwrap() + builder.finish_hugr_with_outputs(outputs).unwrap() }) } diff --git a/hugr-llvm/src/extension/prelude.rs b/hugr-llvm/src/extension/prelude.rs index 62a00527c8..d4b918b559 100644 --- a/hugr-llvm/src/extension/prelude.rs +++ b/hugr-llvm/src/extension/prelude.rs @@ -117,6 +117,37 @@ pub trait PreludeCodegen: Clone { Ok(err.into()) } + /// Emit instructions to construct an error value from a signal and message. + /// + /// The type of the returned value must match [`Self::error_type`]. + /// + /// The default implementation constructs a struct with the given signal and message. + fn emit_make_error<'c, H: HugrView>( + &self, + ctx: &mut EmitFuncContext<'c, '_, H>, + signal: BasicValueEnum<'c>, + message: BasicValueEnum<'c>, + ) -> Result> { + let builder = ctx.builder(); + + // The usize signal is an i64 but error struct stores an i32. + let i32_type = ctx.typing_session().iw_context().i32_type(); + let signal_int = signal.into_int_value(); + let signal_truncated = builder.build_int_truncate(signal_int, i32_type, "")?; + + // Construct the error struct as runtime value. + let err_ty = ctx.llvm_type(&error_type())?.into_struct_type(); + let undef = err_ty.get_undef(); + let err_with_sig = builder + .build_insert_value(undef, signal_truncated, 0, "")? + .into_struct_value(); + let err_complete = builder + .build_insert_value(err_with_sig, message, 1, "")? + .into_struct_value(); + + Ok(err_complete.into()) + } + /// Emit instructions to halt execution with the error `err`. /// /// The type of `err` must match that returned from [`Self::error_type`]. @@ -345,6 +376,22 @@ pub fn add_prelude_extensions<'a, H: HugrView + 'a>( args.outputs.finish(context.builder(), []) } }) + .extension_op(prelude::PRELUDE_ID, prelude::MAKE_ERROR_OP_ID, { + let pcg = pcg.clone(); + move |context, args| { + let signal = args.inputs[0]; + let message = args.inputs[1]; + ensure!( + message.get_type() + == pcg + .string_type(&context.typing_session())? 
+ .as_basic_type_enum(), + signal.get_type() == pcg.usize_type(&context.typing_session()).into() + ); + let err = pcg.emit_make_error(context, signal, message)?; + args.outputs.finish(context.builder(), [err]) + } + }) .extension_op(prelude::PRELUDE_ID, prelude::PANIC_OP_ID, { let pcg = pcg.clone(); move |context, args| { @@ -389,7 +436,7 @@ pub fn add_prelude_extensions<'a, H: HugrView + 'a>( move |context, args| { let load_nat = LoadNat::from_extension_op(args.node().as_ref())?; let v = match load_nat.get_nat() { - TypeArg::BoundedNat { n } => pcg + TypeArg::BoundedNat(n) => pcg .usize_type(&context.typing_session()) .const_int(n, false), arg => bail!("Unexpected type arg for LoadNat: {}", arg), @@ -405,10 +452,10 @@ pub fn add_prelude_extensions<'a, H: HugrView + 'a>( #[cfg(test)] mod test { - use hugr_core::builder::{Dataflow, DataflowSubContainer}; + use hugr_core::builder::{Dataflow, DataflowHugr}; use hugr_core::extension::PRELUDE; - use hugr_core::extension::prelude::{EXIT_OP_ID, Noop}; - use hugr_core::types::{Type, TypeArg}; + use hugr_core::extension::prelude::{EXIT_OP_ID, MAKE_ERROR_OP_ID, Noop}; + use hugr_core::types::{Term, Type}; use hugr_core::{Hugr, type_row}; use prelude::{PANIC_OP_ID, PRINT_OP_ID, bool_t, qb_t, usize_t}; use rstest::{fixture, rstest}; @@ -479,7 +526,7 @@ mod test { .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) .finish(|mut builder| { let k = builder.add_load_value(ConstUsize::new(17)); - builder.finish_with_outputs([k]).unwrap() + builder.finish_hugr_with_outputs([k]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -502,7 +549,7 @@ mod test { .finish(|mut builder| { let k1 = builder.add_load_value(konst1); let k2 = builder.add_load_value(konst2); - builder.finish_with_outputs([k1, k2]).unwrap() + builder.finish_hugr_with_outputs([k1, k2]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -519,7 +566,7 @@ mod test { .add_dataflow_op(Noop::new(usize_t()), in_wires) .unwrap() .outputs(); - builder.finish_with_outputs(r).unwrap() + builder.finish_hugr_with_outputs(r).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -533,7 +580,7 @@ mod test { .finish(|mut builder| { let in_wires = builder.input_wires(); let r = builder.make_tuple(in_wires).unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -551,7 +598,7 @@ mod test { builder.input_wires(), ) .unwrap(); - builder.finish_with_outputs(unpack.outputs()).unwrap() + builder.finish_hugr_with_outputs(unpack.outputs()).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -559,10 +606,8 @@ mod test { #[rstest] fn prelude_panic(prelude_llvm_ctx: TestContext) { let error_val = ConstError::new(42, "PANIC"); - let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() }; - let type_arg_2q: TypeArg = TypeArg::Sequence { - elems: vec![type_arg_q.clone(), type_arg_q], - }; + let type_arg_q: Term = qb_t().into(); + let type_arg_2q = Term::new_list([type_arg_q.clone(), type_arg_q]); let panic_op = PRELUDE .instantiate_extension_op(&PANIC_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) .unwrap(); @@ -578,7 +623,7 @@ mod test { .add_dataflow_op(panic_op, [err, q0, q1]) .unwrap() .outputs_arr(); - builder.finish_with_outputs([q0, q1]).unwrap() + builder.finish_hugr_with_outputs([q0, q1]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); @@ -587,10 +632,8 @@ mod test { #[rstest] fn prelude_exit(prelude_llvm_ctx: TestContext) { let error_val = ConstError::new(42, "EXIT"); - let 
type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() }; - let type_arg_2q: TypeArg = TypeArg::Sequence { - elems: vec![type_arg_q.clone(), type_arg_q], - }; + let type_arg_q: Term = qb_t().into(); + let type_arg_2q = Term::new_list([type_arg_q.clone(), type_arg_q]); let exit_op = PRELUDE .instantiate_extension_op(&EXIT_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) .unwrap(); @@ -606,7 +649,7 @@ mod test { .add_dataflow_op(exit_op, [err, q0, q1]) .unwrap() .outputs_arr(); - builder.finish_with_outputs([q0, q1]).unwrap() + builder.finish_hugr_with_outputs([q0, q1]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); @@ -622,7 +665,61 @@ mod test { .finish(|mut builder| { let greeting_out = builder.add_load_value(greeting); builder.add_dataflow_op(print_op, [greeting_out]).unwrap(); - builder.finish_with_outputs([]).unwrap() + builder.finish_hugr_with_outputs([]).unwrap() + }); + + check_emission!(hugr, prelude_llvm_ctx); + } + + #[rstest] + fn prelude_make_error(prelude_llvm_ctx: TestContext) { + let sig: ConstUsize = ConstUsize::new(100); + let msg: ConstString = ConstString::new("Error!".into()); + + let make_error_op = PRELUDE + .instantiate_extension_op(&MAKE_ERROR_OP_ID, []) + .unwrap(); + + let hugr = SimpleHugrConfig::new() + .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) + .with_outs(error_type()) + .finish(|mut builder| { + let sig_out = builder.add_load_value(sig); + let msg_out = builder.add_load_value(msg); + let [err] = builder + .add_dataflow_op(make_error_op, [sig_out, msg_out]) + .unwrap() + .outputs_arr(); + builder.finish_hugr_with_outputs([err]).unwrap() + }); + + check_emission!(hugr, prelude_llvm_ctx); + } + + #[rstest] + fn prelude_make_error_and_panic(prelude_llvm_ctx: TestContext) { + let sig: ConstUsize = ConstUsize::new(100); + let msg: ConstString = ConstString::new("Error!".into()); + + let make_error_op = PRELUDE + .instantiate_extension_op(&MAKE_ERROR_OP_ID, []) + .unwrap(); + + let panic_op = PRELUDE + .instantiate_extension_op(&PANIC_OP_ID, [Term::new_list([]), Term::new_list([])]) + .unwrap(); + + let hugr = SimpleHugrConfig::new() + .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) + .finish(|mut builder| { + let sig_out = builder.add_load_value(sig); + let msg_out = builder.add_load_value(msg); + let [err] = builder + .add_dataflow_op(make_error_op, [sig_out, msg_out]) + .unwrap() + .outputs_arr(); + builder.add_dataflow_op(panic_op, [err]).unwrap(); + builder.finish_hugr_with_outputs([]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); @@ -635,10 +732,10 @@ mod test { .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) .finish(|mut builder| { let v = builder - .add_dataflow_op(LoadNat::new(TypeArg::BoundedNat { n: 42 }), vec![]) + .add_dataflow_op(LoadNat::new(42u64.into()), vec![]) .unwrap() .out_wire(0); - builder.finish_with_outputs([v]).unwrap() + builder.finish_hugr_with_outputs([v]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -651,7 +748,7 @@ mod test { .finish(|mut builder| { let i = builder.add_load_value(ConstUsize::new(42)); let [w1, _w2] = builder.add_barrier([i, i]).unwrap().outputs_arr(); - builder.finish_with_outputs([w1]).unwrap() + builder.finish_hugr_with_outputs([w1]).unwrap() }) } diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@llvm14.snap new file mode 100644 index 0000000000..2a543d0e11 --- /dev/null +++ 
b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@llvm14.snap @@ -0,0 +1,19 @@ +--- +source: hugr-llvm/src/extension/prelude.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +@0 = private unnamed_addr constant [7 x i8] c"Error!\00", align 1 + +define { i32, i8* } @_hl.main.1() { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %0 = trunc i64 100 to i32 + %1 = insertvalue { i32, i8* } undef, i32 %0, 0 + %2 = insertvalue { i32, i8* } %1, i8* getelementptr inbounds ([7 x i8], [7 x i8]* @0, i32 0, i32 0), 1 + ret { i32, i8* } %2 +} diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@pre-mem2reg@llvm14.snap new file mode 100644 index 0000000000..d061dc36a3 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@pre-mem2reg@llvm14.snap @@ -0,0 +1,31 @@ +--- +source: hugr-llvm/src/extension/prelude.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +@0 = private unnamed_addr constant [7 x i8] c"Error!\00", align 1 + +define { i32, i8* } @_hl.main.1() { +alloca_block: + %"0" = alloca { i32, i8* }, align 8 + %"7_0" = alloca i8*, align 8 + %"5_0" = alloca i64, align 8 + %"8_0" = alloca { i32, i8* }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i8* getelementptr inbounds ([7 x i8], [7 x i8]* @0, i32 0, i32 0), i8** %"7_0", align 8 + store i64 100, i64* %"5_0", align 4 + %"5_01" = load i64, i64* %"5_0", align 4 + %"7_02" = load i8*, i8** %"7_0", align 8 + %0 = trunc i64 %"5_01" to i32 + %1 = insertvalue { i32, i8* } undef, i32 %0, 0 + %2 = insertvalue { i32, i8* } %1, i8* %"7_02", 1 + store { i32, i8* } %2, { i32, i8* }* %"8_0", align 8 + %"8_03" = load { i32, i8* }, { i32, i8* }* %"8_0", align 8 + store { i32, i8* } %"8_03", { i32, i8* }* %"0", align 8 + %"04" = load { i32, i8* }, { i32, i8* }* %"0", align 8 + ret { i32, i8* } %"04" +} diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@llvm14.snap new file mode 100644 index 0000000000..fdaae15e98 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@llvm14.snap @@ -0,0 +1,28 @@ +--- +source: hugr-llvm/src/extension/prelude.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +@0 = private unnamed_addr constant [7 x i8] c"Error!\00", align 1 +@prelude.panic_template = private unnamed_addr constant [34 x i8] c"Program panicked (signal %i): %s\0A\00", align 1 + +define void @_hl.main.1() { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %0 = trunc i64 100 to i32 + %1 = insertvalue { i32, i8* } undef, i32 %0, 0 + %2 = insertvalue { i32, i8* } %1, i8* getelementptr inbounds ([7 x i8], [7 x i8]* @0, i32 0, i32 0), 1 + %3 = extractvalue { i32, i8* } %2, 0 + %4 = extractvalue { i32, i8* } %2, 1 + %5 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template, i32 0, i32 0), i32 %3, i8* %4) + call void @abort() + ret void +} + +declare i32 @printf(i8*, ...) 
+ +declare void @abort() diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@pre-mem2reg@llvm14.snap new file mode 100644 index 0000000000..8ff4526e04 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@pre-mem2reg@llvm14.snap @@ -0,0 +1,37 @@ +--- +source: hugr-llvm/src/extension/prelude.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +@0 = private unnamed_addr constant [7 x i8] c"Error!\00", align 1 +@prelude.panic_template = private unnamed_addr constant [34 x i8] c"Program panicked (signal %i): %s\0A\00", align 1 + +define void @_hl.main.1() { +alloca_block: + %"7_0" = alloca i8*, align 8 + %"5_0" = alloca i64, align 8 + %"8_0" = alloca { i32, i8* }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i8* getelementptr inbounds ([7 x i8], [7 x i8]* @0, i32 0, i32 0), i8** %"7_0", align 8 + store i64 100, i64* %"5_0", align 4 + %"5_01" = load i64, i64* %"5_0", align 4 + %"7_02" = load i8*, i8** %"7_0", align 8 + %0 = trunc i64 %"5_01" to i32 + %1 = insertvalue { i32, i8* } undef, i32 %0, 0 + %2 = insertvalue { i32, i8* } %1, i8* %"7_02", 1 + store { i32, i8* } %2, { i32, i8* }* %"8_0", align 8 + %"8_03" = load { i32, i8* }, { i32, i8* }* %"8_0", align 8 + %3 = extractvalue { i32, i8* } %"8_03", 0 + %4 = extractvalue { i32, i8* } %"8_03", 1 + %5 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template, i32 0, i32 0), i32 %3, i8* %4) + call void @abort() + ret void +} + +declare i32 @printf(i8*, ...) + +declare void @abort() diff --git a/hugr-llvm/src/test.rs b/hugr-llvm/src/test.rs index 9864ae12e1..59919baad4 100644 --- a/hugr-llvm/src/test.rs +++ b/hugr-llvm/src/test.rs @@ -2,7 +2,7 @@ use std::rc::Rc; use hugr_core::{ Hugr, - builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, + builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, ops::{OpTrait, OpType}, types::PolyFuncType, }; diff --git a/hugr-llvm/src/utils/fat.rs b/hugr-llvm/src/utils/fat.rs index 1b046ddf02..1476bcb484 100644 --- a/hugr-llvm/src/utils/fat.rs +++ b/hugr-llvm/src/utils/fat.rs @@ -8,7 +8,7 @@ use hugr_core::hugr::views::Rerooted; use hugr_core::{ Hugr, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, core::HugrNode, - ops::{CFG, DataflowBlock, ExitBlock, Input, OpType, Output}, + ops::{CFG, DataflowBlock, ExitBlock, Input, Module, OpType, Output}, types::Type, }; use itertools::Itertools as _; @@ -373,7 +373,12 @@ pub trait FatExt: HugrView { } /// Try to create a specific [`FatNode`] for the root of a [`HugrView`]. - fn fat_root(&self) -> Option> + fn fat_root(&self) -> Option> { + self.try_fat(self.module_root()) + } + + /// Try to create a specific [`FatNode`] for the entrypoint of a [`HugrView`]. + fn fat_entrypoint(&self) -> Option> where for<'a> &'a OpType: TryInto<&'a OT>, { diff --git a/hugr-model/CHANGELOG.md b/hugr-model/CHANGELOG.md index 901f077d14..4a2e4ccb35 100644 --- a/hugr-model/CHANGELOG.md +++ b/hugr-model/CHANGELOG.md @@ -1,5 +1,29 @@ # Changelog + +## [0.22.0](https://github.com/CQCL/hugr/compare/hugr-model-v0.21.0...hugr-model-v0.22.0) - 2025-07-24 + +### New Features + +- Names of private functions become `core.title` metadata. 
([#2448](https://github.com/CQCL/hugr/pull/2448)) +- include generator metadata in model import and cli validate errors ([#2452](https://github.com/CQCL/hugr/pull/2452)) +- Version number in hugr binary format. ([#2468](https://github.com/CQCL/hugr/pull/2468)) +- Use semver crate for -model version, and include in docs ([#2471](https://github.com/CQCL/hugr/pull/2471)) +## [0.21.0](https://github.com/CQCL/hugr/compare/hugr-model-v0.20.2...hugr-model-v0.21.0) - 2025-07-09 + +### Bug Fixes + +- Model import should perform extension resolution ([#2326](https://github.com/CQCL/hugr/pull/2326)) +- [**breaking**] Fixed bugs in model CFG handling and improved CFG signatures ([#2334](https://github.com/CQCL/hugr/pull/2334)) +- [**breaking**] Fix panic in model resolver when variable is used outside of symbol. ([#2362](https://github.com/CQCL/hugr/pull/2362)) +- Order hints on input and output nodes. ([#2422](https://github.com/CQCL/hugr/pull/2422)) + +### New Features + +- [**breaking**] Added float and bytes literal to core and python bindings. ([#2289](https://github.com/CQCL/hugr/pull/2289)) +- [**breaking**] Add Visibility to FuncDefn/FuncDecl. ([#2143](https://github.com/CQCL/hugr/pull/2143)) +- [**breaking**] hugr-model use explicit Option, with ::Unspecified in capnp ([#2424](https://github.com/CQCL/hugr/pull/2424)) + ## [0.20.2](https://github.com/CQCL/hugr/compare/hugr-model-v0.20.1...hugr-model-v0.20.2) - 2025-06-25 ### New Features diff --git a/hugr-model/Cargo.toml b/hugr-model/Cargo.toml index 24c1f9e42e..fb67977250 100644 --- a/hugr-model/Cargo.toml +++ b/hugr-model/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hugr-model" -version = "0.20.2" +version = "0.22.1" readme = "README.md" documentation = "https://docs.rs/hugr-model/" description = "Data model for Quantinuum's HUGR intermediate representation" @@ -27,6 +27,7 @@ ordered-float = { workspace = true } pest = { workspace = true } pest_derive = { workspace = true } pretty = { workspace = true } +semver = { workspace = true } smol_str = { workspace = true, features = ["serde"] } thiserror.workspace = true pyo3 = { workspace = true, optional = true, features = ["extension-module"] } diff --git a/hugr-model/FORMAT_VERSION b/hugr-model/FORMAT_VERSION new file mode 100644 index 0000000000..3eefcb9dd5 --- /dev/null +++ b/hugr-model/FORMAT_VERSION @@ -0,0 +1 @@ +1.0.0 diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index f69beb18f7..7891b7f245 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -20,6 +20,12 @@ using LinkIndex = UInt32; struct Package { modules @0 :List(Module); + version @1 :Version; +} + +struct Version { + major @0 :UInt32; + minor @1 :UInt32; } struct Module { @@ -61,6 +67,7 @@ struct Operation { } struct Symbol { + visibility @4 :Visibility; name @0 :Text; params @1 :List(Param); constraints @2 :List(TermId); @@ -120,3 +127,9 @@ struct Param { name @0 :Text; type @1 :TermId; } + +enum Visibility { + unspecified @0; + private @1; + public @2; +} diff --git a/hugr-model/src/capnp/hugr_v0_capnp.rs b/hugr-model/src/capnp/hugr_v0_capnp.rs index aea608cfde..9e0a6ed7b3 100644 --- a/hugr-model/src/capnp/hugr_v0_capnp.rs +++ b/hugr-model/src/capnp/hugr_v0_capnp.rs @@ -72,11 +72,19 @@ pub mod package { pub fn has_modules(&self) -> bool { 
!self.reader.get_pointer_field(0).is_null() } + #[inline] + pub fn get_version(self) -> ::capnp::Result> { + ::capnp::traits::FromPointerReader::get_from_pointer(&self.reader.get_pointer_field(1), ::core::option::Option::None) + } + #[inline] + pub fn has_version(&self) -> bool { + !self.reader.get_pointer_field(1).is_null() + } } pub struct Builder<'a> { builder: ::capnp::private::layout::StructBuilder<'a> } impl <> ::capnp::traits::HasStructSize for Builder<'_,> { - const STRUCT_SIZE: ::capnp::private::layout::StructSize = ::capnp::private::layout::StructSize { data: 0, pointers: 1 }; + const STRUCT_SIZE: ::capnp::private::layout::StructSize = ::capnp::private::layout::StructSize { data: 0, pointers: 2 }; } impl <> ::capnp::traits::HasTypeId for Builder<'_,> { const TYPE_ID: u64 = _private::TYPE_ID; @@ -142,6 +150,22 @@ pub mod package { pub fn has_modules(&self) -> bool { !self.builder.is_pointer_field_null(0) } + #[inline] + pub fn get_version(self) -> ::capnp::Result> { + ::capnp::traits::FromPointerBuilder::get_from_pointer(self.builder.get_pointer_field(1), ::core::option::Option::None) + } + #[inline] + pub fn set_version(&mut self, value: crate::hugr_v0_capnp::version::Reader<'_>) -> ::capnp::Result<()> { + ::capnp::traits::SetterInput::set_pointer_builder(self.builder.reborrow().get_pointer_field(1), value, false) + } + #[inline] + pub fn init_version(self, ) -> crate::hugr_v0_capnp::version::Builder<'a> { + ::capnp::traits::FromPointerBuilder::init_pointer(self.builder.get_pointer_field(1), 0) + } + #[inline] + pub fn has_version(&self) -> bool { + !self.builder.is_pointer_field_null(1) + } } pub struct Pipeline { _typeless: ::capnp::any_pointer::Pipeline } @@ -151,19 +175,23 @@ pub mod package { } } impl Pipeline { + pub fn get_version(&self) -> crate::hugr_v0_capnp::version::Pipeline { + ::capnp::capability::FromTypelessPipeline::new(self._typeless.get_pointer_field(1)) + } } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 37] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 53] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(56, 36, 26, 168, 243, 12, 207, 208), ::capnp::word(20, 0, 0, 0, 1, 0, 0, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), - ::capnp::word(1, 0, 7, 0, 0, 0, 0, 0), + ::capnp::word(2, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(93, 1, 0, 0, 166, 1, 0, 0), ::capnp::word(21, 0, 0, 0, 226, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(29, 0, 0, 0, 63, 0, 0, 0), + ::capnp::word(29, 0, 0, 0, 119, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(99, 97, 112, 110, 112, 47, 104, 117), @@ -171,14 +199,21 @@ pub mod package { ::capnp::word(112, 110, 112, 58, 80, 97, 99, 107), ::capnp::word(97, 103, 101, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 1, 0, 1, 0), - ::capnp::word(4, 0, 0, 0, 3, 0, 4, 0), + ::capnp::word(8, 0, 0, 0, 3, 0, 4, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 1, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(13, 0, 0, 0, 66, 0, 0, 0), + ::capnp::word(41, 0, 0, 0, 66, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(8, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(36, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(36, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(64, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(1, 0, 0, 0, 1, 0, 0, 0), + ::capnp::word(0, 0, 1, 0, 1, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + 
::capnp::word(61, 0, 0, 0, 66, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(56, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(68, 0, 0, 0, 2, 0, 1, 0), ::capnp::word(109, 111, 100, 117, 108, 101, 115, 0), ::capnp::word(14, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -191,15 +226,24 @@ pub mod package { ::capnp::word(14, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(118, 101, 114, 115, 105, 111, 110, 0), + ::capnp::word(16, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(167, 171, 245, 145, 177, 155, 108, 182), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(16, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ]; pub fn get_field_types(index: u16) -> ::capnp::introspect::Type { match index { 0 => <::capnp::struct_list::Owned as ::capnp::introspect::Introspect>::introspect(), - _ => panic!("invalid field index {}", index), + 1 => ::introspect(), + _ => ::capnp::introspect::panic_invalid_field_index(index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -207,13 +251,237 @@ pub mod package { members_by_discriminant: MEMBERS_BY_DISCRIMINANT, members_by_name: MEMBERS_BY_NAME, }; - pub static NONUNION_MEMBERS : &[u16] = &[0]; + pub static NONUNION_MEMBERS : &[u16] = &[0,1]; pub static MEMBERS_BY_DISCRIMINANT : &[u16] = &[]; - pub static MEMBERS_BY_NAME : &[u16] = &[0]; + pub static MEMBERS_BY_NAME : &[u16] = &[0,1]; pub const TYPE_ID: u64 = 0xd0cf_0cf3_a81a_2438; } } +pub mod version { + #[derive(Copy, Clone)] + pub struct Owned(()); + impl ::capnp::introspect::Introspect for Owned { fn introspect() -> ::capnp::introspect::Type { ::capnp::introspect::TypeVariant::Struct(::capnp::introspect::RawBrandedStructSchema { generic: &_private::RAW_SCHEMA, field_types: _private::get_field_types, annotation_types: _private::get_annotation_types }).into() } } + impl ::capnp::traits::Owned for Owned { type Reader<'a> = Reader<'a>; type Builder<'a> = Builder<'a>; } + impl ::capnp::traits::OwnedStruct for Owned { type Reader<'a> = Reader<'a>; type Builder<'a> = Builder<'a>; } + impl ::capnp::traits::Pipelined for Owned { type Pipeline = Pipeline; } + + pub struct Reader<'a> { reader: ::capnp::private::layout::StructReader<'a> } + impl <> ::core::marker::Copy for Reader<'_,> {} + impl <> ::core::clone::Clone for Reader<'_,> { + fn clone(&self) -> Self { *self } + } + + impl <> ::capnp::traits::HasTypeId for Reader<'_,> { + const TYPE_ID: u64 = _private::TYPE_ID; + } + impl <'a,> ::core::convert::From<::capnp::private::layout::StructReader<'a>> for Reader<'a,> { + fn from(reader: ::capnp::private::layout::StructReader<'a>) -> Self { + Self { reader, } + } + } + + impl <'a,> ::core::convert::From> for ::capnp::dynamic_value::Reader<'a> { + fn from(reader: Reader<'a,>) -> Self { + Self::Struct(::capnp::dynamic_struct::Reader::new(reader.reader, ::capnp::schema::StructSchema::new(::capnp::introspect::RawBrandedStructSchema { generic: &_private::RAW_SCHEMA, field_types: _private::get_field_types::<>, annotation_types: _private::get_annotation_types::<>}))) + } + } + + impl <> ::core::fmt::Debug 
for Reader<'_,> { + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::result::Result<(), ::core::fmt::Error> { + core::fmt::Debug::fmt(&::core::convert::Into::<::capnp::dynamic_value::Reader<'_>>::into(*self), f) + } + } + + impl <'a,> ::capnp::traits::FromPointerReader<'a> for Reader<'a,> { + fn get_from_pointer(reader: &::capnp::private::layout::PointerReader<'a>, default: ::core::option::Option<&'a [::capnp::Word]>) -> ::capnp::Result { + ::core::result::Result::Ok(reader.get_struct(default)?.into()) + } + } + + impl <'a,> ::capnp::traits::IntoInternalStructReader<'a> for Reader<'a,> { + fn into_internal_struct_reader(self) -> ::capnp::private::layout::StructReader<'a> { + self.reader + } + } + + impl <'a,> ::capnp::traits::Imbue<'a> for Reader<'a,> { + fn imbue(&mut self, cap_table: &'a ::capnp::private::layout::CapTable) { + self.reader.imbue(::capnp::private::layout::CapTableReader::Plain(cap_table)) + } + } + + impl <> Reader<'_,> { + pub fn reborrow(&self) -> Reader<'_,> { + Self { .. *self } + } + + pub fn total_size(&self) -> ::capnp::Result<::capnp::MessageSize> { + self.reader.total_size() + } + #[inline] + pub fn get_major(self) -> u32 { + self.reader.get_data_field::(0) + } + #[inline] + pub fn get_minor(self) -> u32 { + self.reader.get_data_field::(1) + } + } + + pub struct Builder<'a> { builder: ::capnp::private::layout::StructBuilder<'a> } + impl <> ::capnp::traits::HasStructSize for Builder<'_,> { + const STRUCT_SIZE: ::capnp::private::layout::StructSize = ::capnp::private::layout::StructSize { data: 1, pointers: 0 }; + } + impl <> ::capnp::traits::HasTypeId for Builder<'_,> { + const TYPE_ID: u64 = _private::TYPE_ID; + } + impl <'a,> ::core::convert::From<::capnp::private::layout::StructBuilder<'a>> for Builder<'a,> { + fn from(builder: ::capnp::private::layout::StructBuilder<'a>) -> Self { + Self { builder, } + } + } + + impl <'a,> ::core::convert::From> for ::capnp::dynamic_value::Builder<'a> { + fn from(builder: Builder<'a,>) -> Self { + Self::Struct(::capnp::dynamic_struct::Builder::new(builder.builder, ::capnp::schema::StructSchema::new(::capnp::introspect::RawBrandedStructSchema { generic: &_private::RAW_SCHEMA, field_types: _private::get_field_types::<>, annotation_types: _private::get_annotation_types::<>}))) + } + } + + impl <'a,> ::capnp::traits::ImbueMut<'a> for Builder<'a,> { + fn imbue_mut(&mut self, cap_table: &'a mut ::capnp::private::layout::CapTable) { + self.builder.imbue(::capnp::private::layout::CapTableBuilder::Plain(cap_table)) + } + } + + impl <'a,> ::capnp::traits::FromPointerBuilder<'a> for Builder<'a,> { + fn init_pointer(builder: ::capnp::private::layout::PointerBuilder<'a>, _size: u32) -> Self { + builder.init_struct(::STRUCT_SIZE).into() + } + fn get_from_pointer(builder: ::capnp::private::layout::PointerBuilder<'a>, default: ::core::option::Option<&'a [::capnp::Word]>) -> ::capnp::Result { + ::core::result::Result::Ok(builder.get_struct(::STRUCT_SIZE, default)?.into()) + } + } + + impl <> ::capnp::traits::SetterInput> for Reader<'_,> { + fn set_pointer_builder(mut pointer: ::capnp::private::layout::PointerBuilder<'_>, value: Self, canonicalize: bool) -> ::capnp::Result<()> { pointer.set_struct(&value.reader, canonicalize) } + } + + impl <'a,> Builder<'a,> { + pub fn into_reader(self) -> Reader<'a,> { + self.builder.into_reader().into() + } + pub fn reborrow(&mut self) -> Builder<'_,> { + Builder { builder: self.builder.reborrow() } + } + pub fn reborrow_as_reader(&self) -> Reader<'_,> { + self.builder.as_reader().into() + } + + 
pub fn total_size(&self) -> ::capnp::Result<::capnp::MessageSize> { + self.builder.as_reader().total_size() + } + #[inline] + pub fn get_major(self) -> u32 { + self.builder.get_data_field::(0) + } + #[inline] + pub fn set_major(&mut self, value: u32) { + self.builder.set_data_field::(0, value); + } + #[inline] + pub fn get_minor(self) -> u32 { + self.builder.get_data_field::(1) + } + #[inline] + pub fn set_minor(&mut self, value: u32) { + self.builder.set_data_field::(1, value); + } + } + + pub struct Pipeline { _typeless: ::capnp::any_pointer::Pipeline } + impl ::capnp::capability::FromTypelessPipeline for Pipeline { + fn new(typeless: ::capnp::any_pointer::Pipeline) -> Self { + Self { _typeless: typeless, } + } + } + impl Pipeline { + } + mod _private { + pub static ENCODED_NODE: [::capnp::Word; 49] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + ::capnp::word(167, 171, 245, 145, 177, 155, 108, 182), + ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), + ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), + ::capnp::word(0, 0, 7, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(168, 1, 0, 0, 230, 1, 0, 0), + ::capnp::word(21, 0, 0, 0, 226, 0, 0, 0), + ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(29, 0, 0, 0, 119, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(99, 97, 112, 110, 112, 47, 104, 117), + ::capnp::word(103, 114, 45, 118, 48, 46, 99, 97), + ::capnp::word(112, 110, 112, 58, 86, 101, 114, 115), + ::capnp::word(105, 111, 110, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 1, 0, 1, 0), + ::capnp::word(8, 0, 0, 0, 3, 0, 4, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 1, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(41, 0, 0, 0, 50, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(36, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(48, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(1, 0, 0, 0, 1, 0, 0, 0), + ::capnp::word(0, 0, 1, 0, 1, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(45, 0, 0, 0, 50, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(40, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(52, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(109, 97, 106, 111, 114, 0, 0, 0), + ::capnp::word(8, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(8, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(109, 105, 110, 111, 114, 0, 0, 0), + ::capnp::word(8, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(8, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ]; + pub fn get_field_types(index: u16) -> ::capnp::introspect::Type { + match index { + 0 => ::introspect(), + 1 => ::introspect(), + _ => ::capnp::introspect::panic_invalid_field_index(index), + } + } + pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + } + pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { + encoded_node: &ENCODED_NODE, + nonunion_members: NONUNION_MEMBERS, + members_by_discriminant: MEMBERS_BY_DISCRIMINANT, + 
members_by_name: MEMBERS_BY_NAME, + }; + pub static NONUNION_MEMBERS : &[u16] = &[0,1]; + pub static MEMBERS_BY_DISCRIMINANT : &[u16] = &[]; + pub static MEMBERS_BY_NAME : &[u16] = &[0,1]; + pub const TYPE_ID: u64 = 0xb66c_9bb1_91f5_aba7; + } +} + pub mod module { #[derive(Copy, Clone)] pub struct Owned(()); @@ -424,13 +692,14 @@ pub mod module { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 90] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 91] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(167, 107, 35, 13, 152, 216, 48, 189), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(3, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(232, 1, 0, 0, 98, 2, 0, 0), ::capnp::word(21, 0, 0, 0, 218, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -522,11 +791,11 @@ pub mod module { 1 => <::capnp::struct_list::Owned as ::capnp::introspect::Introspect>::introspect(), 2 => <::capnp::struct_list::Owned as ::capnp::introspect::Introspect>::introspect(), 3 => <::capnp::struct_list::Owned as ::capnp::introspect::Introspect>::introspect(), - _ => panic!("invalid field index {}", index), + _ => ::capnp::introspect::panic_invalid_field_index(index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -802,13 +1071,14 @@ pub mod node { } } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 126] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 127] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(108, 130, 159, 249, 96, 124, 57, 228), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(5, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(100, 2, 0, 0, 46, 3, 0, 0), ::capnp::word(21, 0, 0, 0, 202, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -938,11 +1208,11 @@ pub mod node { 3 => <::capnp::primitive_list::Owned as ::capnp::introspect::Introspect>::introspect(), 4 => <::capnp::primitive_list::Owned as ::capnp::introspect::Introspect>::introspect(), 5 => ::introspect(), - _ => panic!("invalid field index {}", index), + _ => ::capnp::introspect::panic_invalid_field_index(index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -1393,13 +1663,14 @@ pub mod operation { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 229] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 230] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(216, 191, 119, 93, 53, 241, 240, 155), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(1, 0, 7, 0, 0, 0, 14, 0), ::capnp::word(2, 0, 0, 0, 0, 0, 
0, 0), + ::capnp::word(48, 3, 0, 0, 38, 5, 0, 0), ::capnp::word(21, 0, 0, 0, 242, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -1640,11 +1911,11 @@ pub mod operation { 11 => <::capnp::text::Owned as ::capnp::introspect::Introspect>::introspect(), 12 => ::introspect(), 13 => ::introspect(), - _ => panic!("invalid field index {}", index), + _ => ::capnp::introspect::panic_invalid_field_index(index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -1841,13 +2112,14 @@ pub mod operation { } } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 48] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 49] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(156, 202, 42, 93, 60, 14, 161, 193), ::capnp::word(30, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(216, 191, 119, 93, 53, 241, 240, 155), ::capnp::word(1, 0, 7, 0, 1, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(21, 0, 0, 0, 66, 1, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -1895,11 +2167,11 @@ pub mod operation { match index { 0 => ::introspect(), 1 => ::introspect(), - _ => panic!("invalid field index {}", index), + _ => ::capnp::introspect::panic_invalid_field_index(index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -2004,6 +2276,10 @@ pub mod symbol { pub fn get_signature(self) -> u32 { self.reader.get_data_field::(0) } + #[inline] + pub fn get_visibility(self) -> ::core::result::Result { + ::core::convert::TryInto::try_into(self.reader.get_data_field::(2)) + } } pub struct Builder<'a> { builder: ::capnp::private::layout::StructBuilder<'a> } @@ -2114,6 +2390,14 @@ pub mod symbol { pub fn set_signature(&mut self, value: u32) { self.builder.set_data_field::(0, value); } + #[inline] + pub fn get_visibility(self) -> ::core::result::Result { + ::core::convert::TryInto::try_into(self.builder.get_data_field::(2)) + } + #[inline] + pub fn set_visibility(&mut self, value: crate::hugr_v0_capnp::Visibility) { + self.builder.set_data_field::(2, value as u16); + } } pub struct Pipeline { _typeless: ::capnp::any_pointer::Pipeline } @@ -2125,17 +2409,18 @@ pub mod symbol { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 88] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 105] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(63, 209, 84, 70, 225, 154, 206, 223), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(3, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(40, 5, 0, 0, 195, 5, 0, 0), ::capnp::word(21, 0, 0, 0, 218, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(29, 0, 0, 0, 231, 0, 0, 0), + 
::capnp::word(29, 0, 0, 0, 31, 1, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(99, 97, 112, 110, 112, 47, 104, 117), @@ -2143,35 +2428,42 @@ pub mod symbol { ::capnp::word(112, 110, 112, 58, 83, 121, 109, 98), ::capnp::word(111, 108, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 1, 0, 1, 0), - ::capnp::word(16, 0, 0, 0, 3, 0, 4, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(20, 0, 0, 0, 3, 0, 4, 0), + ::capnp::word(1, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 1, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(97, 0, 0, 0, 42, 0, 0, 0), + ::capnp::word(125, 0, 0, 0, 42, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(92, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(104, 0, 0, 0, 2, 0, 1, 0), - ::capnp::word(1, 0, 0, 0, 1, 0, 0, 0), + ::capnp::word(120, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(132, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(2, 0, 0, 0, 1, 0, 0, 0), ::capnp::word(0, 0, 1, 0, 1, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(101, 0, 0, 0, 58, 0, 0, 0), + ::capnp::word(129, 0, 0, 0, 58, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(96, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(124, 0, 0, 0, 2, 0, 1, 0), - ::capnp::word(2, 0, 0, 0, 2, 0, 0, 0), + ::capnp::word(124, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(152, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(3, 0, 0, 0, 2, 0, 0, 0), ::capnp::word(0, 0, 1, 0, 2, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(121, 0, 0, 0, 98, 0, 0, 0), + ::capnp::word(149, 0, 0, 0, 98, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(120, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(148, 0, 0, 0, 2, 0, 1, 0), - ::capnp::word(3, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(148, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(176, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(4, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 1, 0, 3, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(145, 0, 0, 0, 82, 0, 0, 0), + ::capnp::word(173, 0, 0, 0, 82, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(172, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(184, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(0, 0, 0, 0, 2, 0, 0, 0), + ::capnp::word(0, 0, 1, 0, 4, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(181, 0, 0, 0, 90, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(144, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(156, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(180, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(192, 0, 0, 0, 2, 0, 1, 0), ::capnp::word(110, 97, 109, 101, 0, 0, 0, 0), ::capnp::word(12, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -2214,6 +2506,15 @@ pub mod symbol { ::capnp::word(8, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(118, 105, 115, 105, 98, 105, 108, 105), + ::capnp::word(116, 121, 0, 0, 0, 0, 0, 0), + ::capnp::word(15, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(1, 131, 104, 122, 242, 21, 131, 141), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(15, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ]; pub fn get_field_types(index: u16) -> ::capnp::introspect::Type { match index { @@ -2221,11 +2522,12 @@ pub mod symbol { 1 => <::capnp::struct_list::Owned as ::capnp::introspect::Introspect>::introspect(), 2 => <::capnp::primitive_list::Owned as ::capnp::introspect::Introspect>::introspect(), 3 => 
::introspect(), - _ => panic!("invalid field index {}", index), + 4 => ::introspect(), + _ => ::capnp::introspect::panic_invalid_field_index(index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -2233,9 +2535,9 @@ pub mod symbol { members_by_discriminant: MEMBERS_BY_DISCRIMINANT, members_by_name: MEMBERS_BY_NAME, }; - pub static NONUNION_MEMBERS : &[u16] = &[0,1,2,3]; + pub static NONUNION_MEMBERS : &[u16] = &[0,1,2,3,4]; pub static MEMBERS_BY_DISCRIMINANT : &[u16] = &[]; - pub static MEMBERS_BY_NAME : &[u16] = &[2,0,1,3]; + pub static MEMBERS_BY_NAME : &[u16] = &[2,0,1,3,4]; pub const TYPE_ID: u64 = 0xdfce_9ae1_4654_d13f; } } @@ -2513,13 +2815,14 @@ pub mod region { } } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 141] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 142] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(225, 113, 253, 231, 231, 39, 130, 153), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(5, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(197, 5, 0, 0, 168, 6, 0, 0), ::capnp::word(21, 0, 0, 0, 218, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -2665,11 +2968,11 @@ pub mod region { 4 => <::capnp::primitive_list::Owned as ::capnp::introspect::Introspect>::introspect(), 5 => ::introspect(), 6 => ::introspect(), - _ => panic!("invalid field index {}", index), + _ => ::capnp::introspect::panic_invalid_field_index(index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -2834,13 +3137,14 @@ pub mod region_scope { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 48] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 49] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(163, 135, 81, 30, 243, 205, 148, 170), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(0, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(170, 6, 0, 0, 236, 6, 0, 0), ::capnp::word(21, 0, 0, 0, 2, 1, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -2888,11 +3192,11 @@ pub mod region_scope { match index { 0 => ::introspect(), 1 => ::introspect(), - _ => panic!("invalid field index {}", index), + _ => ::capnp::introspect::panic_invalid_field_index(index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -2940,13 +3244,14 @@ impl ::capnp::traits::HasTypeId for 
RegionKind { const TYPE_ID: u64 = 0xe457_1af6_23a3_76b4u64; } mod region_kind { -pub static ENCODED_NODE: [::capnp::Word; 32] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), +pub static ENCODED_NODE: [::capnp::Word; 33] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(180, 118, 163, 35, 246, 26, 87, 228), ::capnp::word(20, 0, 0, 0, 2, 0, 0, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(238, 6, 0, 0, 53, 7, 0, 0), ::capnp::word(21, 0, 0, 0, 250, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -2975,7 +3280,7 @@ pub static ENCODED_NODE: [::capnp::Word; 32] = [ ::capnp::word(109, 111, 100, 117, 108, 101, 0, 0), ]; pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } } @@ -3332,13 +3637,14 @@ pub mod term { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 167] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 168] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(178, 107, 91, 137, 60, 121, 191, 207), ::capnp::word(20, 0, 0, 0, 1, 0, 2, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(1, 0, 7, 0, 0, 0, 10, 0), ::capnp::word(2, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(55, 7, 0, 0, 105, 9, 0, 0), ::capnp::word(21, 0, 0, 0, 202, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 23, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -3513,11 +3819,11 @@ pub mod term { 7 => ::introspect(), 8 => <() as ::capnp::introspect::Introspect>::introspect(), 9 => <::capnp::struct_list::Owned as ::capnp::introspect::Introspect>::introspect(), - _ => panic!("invalid field index {}", index), + _ => ::capnp::introspect::panic_invalid_field_index(index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -3715,13 +4021,14 @@ pub mod term { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 49] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 50] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(136, 151, 188, 135, 237, 57, 73, 141), ::capnp::word(25, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(178, 107, 91, 137, 60, 121, 191, 207), ::capnp::word(0, 0, 7, 0, 0, 0, 2, 0), ::capnp::word(2, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(251, 8, 0, 0, 103, 9, 0, 0), ::capnp::word(21, 0, 0, 0, 10, 1, 0, 0), ::capnp::word(37, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -3770,11 +4077,11 @@ pub mod term { match index { 0 => ::introspect(), 1 => ::introspect(), - _ => panic!("invalid field index {}", index), + _ => ::capnp::introspect::panic_invalid_field_index(index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { 
encoded_node: &ENCODED_NODE, @@ -3957,13 +4264,14 @@ pub mod term { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 51] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 52] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(150, 98, 109, 181, 159, 123, 122, 222), ::capnp::word(25, 0, 0, 0, 1, 0, 2, 0), ::capnp::word(178, 107, 91, 137, 60, 121, 191, 207), ::capnp::word(1, 0, 7, 0, 1, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(21, 0, 0, 0, 250, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -4014,11 +4322,11 @@ pub mod term { match index { 0 => ::introspect(), 1 => <::capnp::primitive_list::Owned as ::capnp::introspect::Introspect>::introspect(), - _ => panic!("invalid field index {}", index), + _ => ::capnp::introspect::panic_invalid_field_index(index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -4183,13 +4491,14 @@ pub mod term { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 48] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 49] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(55, 205, 218, 56, 109, 17, 119, 134), ::capnp::word(25, 0, 0, 0, 1, 0, 2, 0), ::capnp::word(178, 107, 91, 137, 60, 121, 191, 207), ::capnp::word(1, 0, 7, 0, 1, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(21, 0, 0, 0, 18, 1, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -4237,11 +4546,11 @@ pub mod term { match index { 0 => ::introspect(), 1 => ::introspect(), - _ => panic!("invalid field index {}", index), + _ => ::capnp::introspect::panic_invalid_field_index(index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -4419,13 +4728,14 @@ pub mod param { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 48] = [ - ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 49] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), ::capnp::word(232, 73, 199, 85, 129, 167, 53, 211), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(1, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(107, 9, 0, 0, 163, 9, 0, 0), ::capnp::word(21, 0, 0, 0, 210, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -4473,11 +4783,11 @@ pub mod param { match index { 0 => <::capnp::text::Owned as ::capnp::introspect::Introspect>::introspect(), 1 => ::introspect(), - _ => panic!("invalid field index {}", index), + _ => ::capnp::introspect::panic_invalid_field_index(index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - 
panic!("invalid annotation indices ({:?}, {}) ", child_index, index) + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -4491,3 +4801,75 @@ pub mod param { pub const TYPE_ID: u64 = 0xd335_a781_55c7_49e8; } } + +#[repr(u16)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Visibility { + Unspecified = 0, + Private = 1, + Public = 2, +} + +impl ::capnp::introspect::Introspect for Visibility { + fn introspect() -> ::capnp::introspect::Type { ::capnp::introspect::TypeVariant::Enum(::capnp::introspect::RawEnumSchema { encoded_node: &visibility::ENCODED_NODE, annotation_types: visibility::get_annotation_types }).into() } +} +impl ::core::convert::From for ::capnp::dynamic_value::Reader<'_> { + fn from(e: Visibility) -> Self { ::capnp::dynamic_value::Enum::new(e.into(), ::capnp::introspect::RawEnumSchema { encoded_node: &visibility::ENCODED_NODE, annotation_types: visibility::get_annotation_types }.into()).into() } +} +impl ::core::convert::TryFrom for Visibility { + type Error = ::capnp::NotInSchema; + fn try_from(value: u16) -> ::core::result::Result>::Error> { + match value { + 0 => ::core::result::Result::Ok(Self::Unspecified), + 1 => ::core::result::Result::Ok(Self::Private), + 2 => ::core::result::Result::Ok(Self::Public), + n => ::core::result::Result::Err(::capnp::NotInSchema(n)), + } + } +} +impl From for u16 { + #[inline] + fn from(x: Visibility) -> u16 { x as u16 } +} +impl ::capnp::traits::HasTypeId for Visibility { + const TYPE_ID: u64 = 0x8d83_15f2_7a68_8301u64; +} +mod visibility { +pub static ENCODED_NODE: [::capnp::Word; 32] = [ + ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + ::capnp::word(1, 131, 104, 122, 242, 21, 131, 141), + ::capnp::word(20, 0, 0, 0, 2, 0, 0, 0), + ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(165, 9, 0, 0, 235, 9, 0, 0), + ::capnp::word(21, 0, 0, 0, 250, 0, 0, 0), + ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(29, 0, 0, 0, 79, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(99, 97, 112, 110, 112, 47, 104, 117), + ::capnp::word(103, 114, 45, 118, 48, 46, 99, 97), + ::capnp::word(112, 110, 112, 58, 86, 105, 115, 105), + ::capnp::word(98, 105, 108, 105, 116, 121, 0, 0), + ::capnp::word(0, 0, 0, 0, 1, 0, 1, 0), + ::capnp::word(12, 0, 0, 0, 1, 0, 2, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(29, 0, 0, 0, 98, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(1, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(25, 0, 0, 0, 66, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(2, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(17, 0, 0, 0, 58, 0, 0, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(117, 110, 115, 112, 101, 99, 105, 102), + ::capnp::word(105, 101, 100, 0, 0, 0, 0, 0), + ::capnp::word(112, 114, 105, 118, 97, 116, 101, 0), + ::capnp::word(112, 117, 98, 108, 105, 99, 0, 0), +]; +pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { + ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) +} +} diff --git a/hugr-model/src/lib.rs b/hugr-model/src/lib.rs index c139cefc42..b4b0b8fce2 100644 --- a/hugr-model/src/lib.rs +++ b/hugr-model/src/lib.rs @@ -1,10 +1,29 @@ //! 
The data model of the HUGR intermediate representation. +//! //! This crate defines data structures that capture the structure of a HUGR graph and //! all its associated information in a form that can be stored on disk. The data structures //! are not designed for efficient traversal or modification, but for simplicity and serialization. +//! +//! This crate supports version ` +#![doc = include_str!("../FORMAT_VERSION")] +//! ` of the HUGR model format. mod capnp; pub mod v0; +use std::sync::LazyLock; + // This is required here since the generated code assumes it's in the package root. use capnp::hugr_v0_capnp; + +/// The current version of the HUGR model format. +pub static CURRENT_VERSION: LazyLock = LazyLock::new(|| { + // We allow non-zero patch versions, but ignore them for compatibility checks. + let v = semver::Version::parse(include_str!("../FORMAT_VERSION").trim()) + .expect("`FORMAT_VERSION` in `hugr-model` contains version that fails to parse"); + assert!( + v.pre.is_empty(), + "`FORMAT_VERSION` in `hugr-model` should not have a pre-release version" + ); + v +}); diff --git a/hugr-model/src/v0/ast/hugr.pest b/hugr-model/src/v0/ast/hugr.pest index 698f32b056..d960cb3a40 100644 --- a/hugr-model/src/v0/ast/hugr.pest +++ b/hugr-model/src/v0/ast/hugr.pest @@ -16,6 +16,8 @@ reserved = @{ | "list" | "meta" | "signature" + | "public" + | "private" | "dfg" | "cfg" | "block" @@ -79,7 +81,9 @@ node_cond = { "(" ~ "cond" ~ port_lists? ~ signature? ~ meta* ~ reg node_import = { "(" ~ "import" ~ symbol_name ~ meta* ~ ")" } node_custom = { "(" ~ term ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } -symbol = { symbol_name ~ param* ~ where_clause* ~ term } +visibility = { "public" | "private" } + +symbol = { visibility? ~ symbol_name ~ param* ~ where_clause* ~ term } signature = { "(" ~ "signature" ~ term ~ ")" } param = { "(" ~ "param" ~ term_var ~ term ~ ")" } diff --git a/hugr-model/src/v0/ast/mod.rs b/hugr-model/src/v0/ast/mod.rs index faee6f8276..b6e817b990 100644 --- a/hugr-model/src/v0/ast/mod.rs +++ b/hugr-model/src/v0/ast/mod.rs @@ -25,7 +25,7 @@ use std::sync::Arc; use bumpalo::Bump; use super::table::{self}; -use super::{LinkName, Literal, RegionKind, SymbolName, VarName}; +use super::{LinkName, Literal, RegionKind, SymbolName, VarName, Visibility}; mod parse; mod print; @@ -194,6 +194,8 @@ impl Operation { /// [`table::Symbol`]: crate::v0::table::Symbol #[derive(Debug, Clone, PartialEq, Eq)] pub struct Symbol { + /// The visibility of the symbol. + pub visibility: Option, /// The name of the symbol. pub name: SymbolName, /// The parameters of the symbol. 
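For illustration only (not part of the changeset): the compatibility rule that `read_package` applies further down is easy to get backwards, so the sketch below restates it as a small standalone helper — a file is readable only if it has the same major version as `CURRENT_VERSION` and a minor version that is not newer; patch versions are ignored. The helper name `is_readable` and the concrete version numbers are invented for the example.

use semver::Version;

/// Mirrors the check in `read_package`: same major version, and the file's
/// minor version must not be newer than the tooling's. Patch is ignored.
fn is_readable(current: &Version, file: &Version) -> bool {
    file.major == current.major && file.minor <= current.minor
}

fn main() {
    let tooling = Version::new(0, 21, 0); // stand-in for &*CURRENT_VERSION
    assert!(is_readable(&tooling, &Version::new(0, 20, 0)));  // older minor: ok
    assert!(is_readable(&tooling, &Version::new(0, 21, 3)));  // same minor, any patch: ok
    assert!(!is_readable(&tooling, &Version::new(0, 22, 0))); // newer minor: rejected
    assert!(!is_readable(&tooling, &Version::new(1, 21, 0))); // different major: rejected
}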
diff --git a/hugr-model/src/v0/ast/parse.rs b/hugr-model/src/v0/ast/parse.rs index d1d5dff741..a2c9a5cd9b 100644 --- a/hugr-model/src/v0/ast/parse.rs +++ b/hugr-model/src/v0/ast/parse.rs @@ -28,7 +28,7 @@ use thiserror::Error; use crate::v0::ast::{LinkName, Module, Operation, SeqPart}; use crate::v0::{Literal, RegionKind}; -use super::{Node, Package, Param, Region, Symbol, VarName}; +use super::{Node, Package, Param, Region, Symbol, VarName, Visibility}; use super::{SymbolName, Term}; mod pest_parser { @@ -292,13 +292,23 @@ fn parse_param(pair: Pair) -> ParseResult { fn parse_symbol(pair: Pair) -> ParseResult { debug_assert_eq!(Rule::symbol, pair.as_rule()); + let mut pairs = pair.into_inner(); + let visibility = take_rule(&mut pairs, Rule::visibility) + .next() + .map(|pair| match pair.as_str() { + "public" => Ok(Visibility::Public), + "private" => Ok(Visibility::Private), + _ => unreachable!("Expected 'public' or 'private', got {}", pair.as_str()), + }) + .transpose()?; let name = parse_symbol_name(pairs.next().unwrap())?; let params = parse_params(&mut pairs)?; let constraints = parse_constraints(&mut pairs)?; let signature = parse_term(pairs.next().unwrap())?; Ok(Symbol { + visibility, name, params, constraints, diff --git a/hugr-model/src/v0/ast/print.rs b/hugr-model/src/v0/ast/print.rs index dd47602a4b..071146dedd 100644 --- a/hugr-model/src/v0/ast/print.rs +++ b/hugr-model/src/v0/ast/print.rs @@ -7,7 +7,7 @@ use crate::v0::{Literal, RegionKind}; use super::{ LinkName, Module, Node, Operation, Package, Param, Region, SeqPart, Symbol, SymbolName, Term, - VarName, + VarName, Visibility, }; struct Printer<'a> { @@ -369,6 +369,12 @@ fn print_region<'a>(printer: &mut Printer<'a>, region: &'a Region) { } fn print_symbol<'a>(printer: &mut Printer<'a>, symbol: &'a Symbol) { + match symbol.visibility { + None => (), + Some(Visibility::Private) => printer.text("private"), + Some(Visibility::Public) => printer.text("public"), + } + print_symbol_name(printer, &symbol.name); for param in &symbol.params { diff --git a/hugr-model/src/v0/ast/python.rs b/hugr-model/src/v0/ast/python.rs index 90ef22a814..b70d9447c9 100644 --- a/hugr-model/src/v0/ast/python.rs +++ b/hugr-model/src/v0/ast/python.rs @@ -1,5 +1,7 @@ use std::sync::Arc; +use crate::v0::Visibility; + use super::{Module, Node, Operation, Package, Param, Region, SeqPart, Symbol, Term}; use pyo3::{ Bound, PyAny, PyResult, @@ -139,13 +141,41 @@ impl<'py> pyo3::IntoPyObject<'py> for &Param { } } +impl<'py> pyo3::FromPyObject<'py> for Visibility { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + match ob.str()?.to_str()? 
{ + "Public" => Ok(Visibility::Public), + "Private" => Ok(Visibility::Private), + s => Err(PyTypeError::new_err(format!( + "Expected \"Public\" or \"Private\", got {s}", + ))), + } + } +} + +impl<'py> pyo3::IntoPyObject<'py> for &Visibility { + type Target = pyo3::types::PyAny; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + let s = match self { + Visibility::Private => "Private", + Visibility::Public => "Public", + }; + Ok(pyo3::types::PyString::new(py, s).into_any()) + } +} + impl<'py> pyo3::FromPyObject<'py> for Symbol { fn extract_bound(symbol: &Bound<'py, PyAny>) -> PyResult { let name = symbol.getattr("name")?.extract()?; let params: Vec<_> = symbol.getattr("params")?.extract()?; + let visibility = symbol.getattr("visibility")?.extract()?; let constraints: Vec<_> = symbol.getattr("constraints")?.extract()?; let signature = symbol.getattr("signature")?.extract()?; Ok(Self { + visibility, name, signature, params: params.into(), @@ -164,6 +194,7 @@ impl<'py> pyo3::IntoPyObject<'py> for &Symbol { let py_class = py_module.getattr("Symbol")?; py_class.call1(( self.name.as_ref(), + &self.visibility, self.params.as_ref(), self.constraints.as_ref(), &self.signature, @@ -425,5 +456,6 @@ impl_into_pyobject_owned!(Symbol); impl_into_pyobject_owned!(Module); impl_into_pyobject_owned!(Package); impl_into_pyobject_owned!(Node); +impl_into_pyobject_owned!(Visibility); impl_into_pyobject_owned!(Region); impl_into_pyobject_owned!(Operation); diff --git a/hugr-model/src/v0/ast/resolve.rs b/hugr-model/src/v0/ast/resolve.rs index d691de0f01..f126d3e69e 100644 --- a/hugr-model/src/v0/ast/resolve.rs +++ b/hugr-model/src/v0/ast/resolve.rs @@ -289,11 +289,13 @@ impl<'a> Context<'a> { fn resolve_symbol(&mut self, symbol: &'a Symbol) -> BuildResult<&'a table::Symbol<'a>> { let name = symbol.name.as_ref(); + let visibility = &symbol.visibility; let params = self.resolve_params(&symbol.params)?; let constraints = self.resolve_terms(&symbol.constraints)?; let signature = self.resolve_term(&symbol.signature)?; Ok(self.bump.alloc(table::Symbol { + visibility, name, params, constraints, @@ -363,6 +365,7 @@ impl<'a> Context<'a> { /// Error that may occur in [`Module::resolve`]. #[derive(Debug, Clone, Error)] #[non_exhaustive] +#[error("Error resolving model module")] pub enum ResolveError { /// Unknown variable. 
#[error("unknown var: {0}")] @@ -388,3 +391,18 @@ fn try_alloc_slice( } Ok(vec.into_bump_slice()) } + +#[cfg(test)] +mod test { + use crate::v0::ast; + use bumpalo::Bump; + use std::str::FromStr as _; + + #[test] + fn vars_in_root_scope() { + let text = "(hugr 0) (mod) (meta ?x)"; + let ast = ast::Package::from_str(text).unwrap(); + let bump = Bump::new(); + assert!(ast.resolve(&bump).is_err()); + } +} diff --git a/hugr-model/src/v0/ast/view.rs b/hugr-model/src/v0/ast/view.rs index 8feb158539..8c38038100 100644 --- a/hugr-model/src/v0/ast/view.rs +++ b/hugr-model/src/v0/ast/view.rs @@ -91,11 +91,13 @@ impl<'a> View<'a, table::SeqPart> for SeqPart { impl<'a> View<'a, table::Symbol<'a>> for Symbol { fn view(module: &'a table::Module<'a>, id: table::Symbol<'a>) -> Option { + let visibility = id.visibility.clone(); let name = SymbolName::new(id.name); let params = module.view(id.params)?; let constraints = module.view(id.constraints)?; let signature = module.view(id.signature)?; Some(Symbol { + visibility, name, params, constraints, diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 001a805cc9..c5f65463ba 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -1,6 +1,6 @@ use crate::capnp::hugr_v0_capnp as hugr_capnp; -use crate::v0 as model; use crate::v0::table; +use crate::{CURRENT_VERSION, v0 as model}; use bumpalo::Bump; use bumpalo::collections::Vec as BumpVec; use std::io::BufRead; @@ -8,10 +8,20 @@ use std::io::BufRead; /// An error encountered while deserialising a model. #[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)] #[non_exhaustive] +#[display("Error reading a HUGR model payload.")] pub enum ReadError { #[from(forward)] /// An error encountered while decoding a model from a `capnproto` buffer. DecodingError(capnp::Error), + + /// The file could not be read due to a version mismatch. + #[display("Can not read file with version {actual} (tooling version {current}).")] + VersionError { + /// The current version of the hugr-model format. + current: semver::Version, + /// The version of the hugr-model format in the file. + actual: semver::Version, + }, } type ReadResult = Result; @@ -57,6 +67,15 @@ fn read_package<'a>( bump: &'a Bump, reader: hugr_capnp::package::Reader, ) -> ReadResult> { + let version = read_version(reader.get_version()?)?; + + if version.major != CURRENT_VERSION.major || version.minor > CURRENT_VERSION.minor { + return Err(ReadError::VersionError { + current: CURRENT_VERSION.clone(), + actual: version, + }); + } + let modules = reader .get_modules()? 
.iter() @@ -66,6 +85,12 @@ fn read_package<'a>( Ok(table::Package { modules }) } +fn read_version(reader: hugr_capnp::version::Reader) -> ReadResult { + let major = reader.get_major(); + let minor = reader.get_minor(); + Ok(semver::Version::new(major as u64, minor as u64, 0)) +} + fn read_module<'a>( bump: &'a Bump, reader: hugr_capnp::module::Reader, @@ -126,88 +151,21 @@ fn read_operation<'a>( Which::Dfg(()) => table::Operation::Dfg, Which::Cfg(()) => table::Operation::Cfg, Which::Block(()) => table::Operation::Block, - Which::FuncDefn(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let params = read_list!(bump, reader.get_params()?, read_param); - let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId); - let signature = table::TermId(reader.get_signature()); - let symbol = bump.alloc(table::Symbol { - name, - params, - constraints, - signature, - }); - table::Operation::DefineFunc(symbol) - } - Which::FuncDecl(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let params = read_list!(bump, reader.get_params()?, read_param); - let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId); - let signature = table::TermId(reader.get_signature()); - let symbol = bump.alloc(table::Symbol { - name, - params, - constraints, - signature, - }); - table::Operation::DeclareFunc(symbol) - } + Which::FuncDefn(reader) => table::Operation::DefineFunc(read_symbol(bump, reader?, None)?), + Which::FuncDecl(reader) => table::Operation::DeclareFunc(read_symbol(bump, reader?, None)?), Which::AliasDefn(reader) => { let symbol = reader.get_symbol()?; let value = table::TermId(reader.get_value()); - let name = bump.alloc_str(symbol.get_name()?.to_str()?); - let params = read_list!(bump, symbol.get_params()?, read_param); - let signature = table::TermId(symbol.get_signature()); - let symbol = bump.alloc(table::Symbol { - name, - params, - constraints: &[], - signature, - }); - table::Operation::DefineAlias(symbol, value) + table::Operation::DefineAlias(read_symbol(bump, symbol, Some(&[]))?, value) } Which::AliasDecl(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let params = read_list!(bump, reader.get_params()?, read_param); - let signature = table::TermId(reader.get_signature()); - let symbol = bump.alloc(table::Symbol { - name, - params, - constraints: &[], - signature, - }); - table::Operation::DeclareAlias(symbol) + table::Operation::DeclareAlias(read_symbol(bump, reader?, Some(&[]))?) } Which::ConstructorDecl(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let params = read_list!(bump, reader.get_params()?, read_param); - let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId); - let signature = table::TermId(reader.get_signature()); - let symbol = bump.alloc(table::Symbol { - name, - params, - constraints, - signature, - }); - table::Operation::DeclareConstructor(symbol) + table::Operation::DeclareConstructor(read_symbol(bump, reader?, None)?) 
} Which::OperationDecl(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let params = read_list!(bump, reader.get_params()?, read_param); - let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId); - let signature = table::TermId(reader.get_signature()); - let symbol = bump.alloc(table::Symbol { - name, - params, - constraints, - signature, - }); - table::Operation::DeclareOperation(symbol) + table::Operation::DeclareOperation(read_symbol(bump, reader?, None)?) } Which::Custom(operation) => table::Operation::Custom(table::TermId(operation)), Which::TailLoop(()) => table::Operation::TailLoop, @@ -257,6 +215,40 @@ fn read_region_scope(reader: hugr_capnp::region_scope::Reader) -> ReadResult for Option { + fn from(value: hugr_capnp::Visibility) -> Self { + match value { + hugr_capnp::Visibility::Unspecified => None, + hugr_capnp::Visibility::Private => Some(model::Visibility::Private), + hugr_capnp::Visibility::Public => Some(model::Visibility::Public), + } + } +} + +/// (Only) if `constraints` are None, then they are read from the `reader` +fn read_symbol<'a>( + bump: &'a Bump, + reader: hugr_capnp::symbol::Reader, + constraints: Option<&'a [table::TermId]>, +) -> ReadResult<&'a mut table::Symbol<'a>> { + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let visibility = reader.get_visibility()?.into(); + let visibility = bump.alloc(visibility); + let params = read_list!(bump, reader.get_params()?, read_param); + let constraints = match constraints { + Some(cs) => cs, + None => read_scalar_list!(bump, reader, get_constraints, table::TermId), + }; + let signature = table::TermId(reader.get_signature()); + Ok(bump.alloc(table::Symbol { + visibility, + name, + params, + constraints, + signature, + })) +} + fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult> { use hugr_capnp::term::Which; Ok(match reader.which()? { diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index e9b76eca1a..49919dc481 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -1,13 +1,15 @@ use std::io::Write; +use crate::CURRENT_VERSION; use crate::capnp::hugr_v0_capnp as hugr_capnp; -use crate::v0 as model; -use crate::v0::table; +use crate::v0::{self as model, table}; /// An error encounter while serializing a model. #[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)] #[non_exhaustive] +#[display("Error encoding a package in HUGR model format.")] pub enum WriteError { + #[from(forward)] /// An error encountered while encoding a `capnproto` buffer. 
EncodingError(capnp::Error), } @@ -45,6 +47,12 @@ pub fn write_to_vec(package: &table::Package) -> Vec { fn write_package(mut builder: hugr_capnp::package::Builder, package: &table::Package) { write_list!(builder, init_modules, write_module, package.modules); + write_version(builder.init_version(), &CURRENT_VERSION); +} + +fn write_version(mut builder: hugr_capnp::version::Builder, version: &semver::Version) { + builder.set_major(version.major as u32); + builder.set_minor(version.minor as u32); } fn write_module(mut builder: hugr_capnp::module::Builder, module: &table::Module) { @@ -110,6 +118,12 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &tabl fn write_symbol(mut builder: hugr_capnp::symbol::Builder, symbol: &table::Symbol) { builder.set_name(symbol.name); + if let Some(vis) = symbol.visibility { + builder.set_visibility(match vis { + model::Visibility::Private => hugr_capnp::Visibility::Private, + model::Visibility::Public => hugr_capnp::Visibility::Public, + }) + } // else, None -> use capnp default == Unspecified write_list!(builder, init_params, write_param, symbol.params); let _ = builder.set_constraints(table::TermId::unwrap_slice(symbol.constraints)); builder.set_signature(symbol.signature.0); diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 15e29f3bde..c74201a910 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -91,6 +91,15 @@ use smol_str::SmolStr; use std::sync::Arc; use table::LinkIndex; +/// Describes how a function or symbol should be acted upon by a linker +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum Visibility { + /// The linker should ignore this function or symbol + Private, + /// The linker should act upon this function or symbol + Public, +} + /// Core function types. /// /// - **Parameter:** `?inputs : (core.list core.type)` @@ -163,17 +172,13 @@ pub const CORE_BYTES_TYPE: &str = "core.bytes"; /// - **Result:** `core.static` pub const CORE_FLOAT_TYPE: &str = "core.float"; -/// Type of a control flow edge. +/// Type of control flow regions. /// -/// - **Parameter:** `?types : (core.list core.type)` -/// - **Result:** `core.ctrl_type` +/// - **Parameter:** `?inputs : (core.list (core.list core.type))` +/// - **Parameter:** `?outputs : (core.list (core.list core.type))` +/// - **Result:** `core.type` pub const CORE_CTRL: &str = "core.ctrl"; -/// The type of the types for control flow edges. -/// -/// - **Result:** `?type : core.static` -pub const CORE_CTRL_TYPE: &str = "core.ctrl_type"; - /// The type for runtime constants. /// /// - **Parameter:** `?type : core.type` @@ -282,6 +287,26 @@ pub const COMPAT_CONST_JSON: &str = "compat.const_json"; /// - **Result:** `core.meta` pub const ORDER_HINT_KEY: &str = "core.order_hint.key"; +/// Metadata constructor for order hint keys on input nodes. +/// +/// When the sources of a dataflow region are represented by an input operation +/// within the region, this metadata can be attached the region to give the +/// input node an order hint key. +/// +/// - **Parameter:** `?key : core.nat` +/// - **Result:** `core.meta` +pub const ORDER_HINT_INPUT_KEY: &str = "core.order_hint.input_key"; + +/// Metadata constructor for order hint keys on output nodes. +/// +/// When the targets of a dataflow region are represented by an output operation +/// within the region, this metadata can be attached the region to give the +/// output node an order hint key. 
+/// +/// - **Parameter:** `?key : core.nat` +/// - **Result:** `core.meta` +pub const ORDER_HINT_OUTPUT_KEY: &str = "core.order_hint.output_key"; + /// Metadata constructor for order hints. /// /// When this metadata is attached to a dataflow region, it can indicate a @@ -297,6 +322,18 @@ pub const ORDER_HINT_KEY: &str = "core.order_hint.key"; /// - **Result:** `core.meta` pub const ORDER_HINT_ORDER: &str = "core.order_hint.order"; +/// Metadata constructor for symbol titles. +/// +/// The names of functions in `hugr-core` are currently not used for symbol +/// resolution, but rather serve as a short description of the function. +/// As such, there is no requirement for uniqueness or formatting. +/// This metadata can be used to preserve that name when serializing through +/// `hugr-model`. +/// +/// - **Parameter:** `?title: core.str` +/// - **Result:** `core.meta` +pub const CORE_TITLE: &str = "core.title"; + pub mod ast; pub mod binary; pub mod scope; diff --git a/hugr-model/src/v0/scope/vars.rs b/hugr-model/src/v0/scope/vars.rs index e35d8812c3..b7085e2ee8 100644 --- a/hugr-model/src/v0/scope/vars.rs +++ b/hugr-model/src/v0/scope/vars.rs @@ -78,28 +78,23 @@ impl<'a> VarTable<'a> { /// # Errors /// /// Returns an error if the variable is not defined in the current scope. - /// - /// # Panics - /// - /// Panics if there are no open scopes. pub fn resolve(&self, name: &'a str) -> Result> { - let scope = self.scopes.last().unwrap(); + let scope = self.scopes.last().ok_or(UnknownVarError::Root(name))?; let set_index = self .vars .get_index_of(&(scope.node, name)) - .ok_or(UnknownVarError(scope.node, name))?; + .ok_or(UnknownVarError::WithinNode(scope.node, name))?; let var_index = (set_index - scope.var_stack) as u16; Ok(VarId(scope.node, var_index)) } /// Check if a variable is visible in the current scope. - /// - /// # Panics - /// - /// Panics if there are no open scopes. #[must_use] pub fn is_visible(&self, var: VarId) -> bool { - let scope = self.scopes.last().unwrap(); + let Some(scope) = self.scopes.last() else { + return false; + }; + scope.node == var.0 && var.1 < scope.var_count } @@ -149,5 +144,11 @@ pub struct DuplicateVarError<'a>(NodeId, &'a str); /// Error that occurs when a variable is not defined in the current scope. #[derive(Debug, Clone, Error)] -#[error("can not resolve variable `{1}` in node {0}")] -pub struct UnknownVarError<'a>(NodeId, &'a str); +pub enum UnknownVarError<'a> { + /// Failed to resolve a variable when in scope of a node. + #[error("can not resolve variable `{1}` in node {0}")] + WithinNode(NodeId, &'a str), + /// Failed to resolve a variable when in the root scope. + #[error("can not resolve variable `{0}` in the root scope")] + Root(&'a str), +} diff --git a/hugr-model/src/v0/table/mod.rs b/hugr-model/src/v0/table/mod.rs index 501305510b..6ca6370f8f 100644 --- a/hugr-model/src/v0/table/mod.rs +++ b/hugr-model/src/v0/table/mod.rs @@ -29,7 +29,7 @@ use smol_str::SmolStr; use thiserror::Error; mod view; -use super::{Literal, RegionKind, ast}; +use super::{Literal, RegionKind, Visibility, ast}; pub use view::View; /// A package consisting of a sequence of [`Module`]s. @@ -303,6 +303,8 @@ pub struct RegionScope { /// [`ast::Symbol`]: crate::v0::ast::Symbol #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Symbol<'a> { + /// The visibility of the symbol. + pub visibility: &'a Option, /// The name of the symbol. pub name: &'a str, /// The static parameters. 
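To summarise the wire convention established by the `read.rs` and `write.rs` changes above: a symbol with no declared visibility is stored as the Cap'n Proto default `Unspecified`, and `Unspecified` decodes back to `None`. The sketch below is illustrative only and uses local stand-in enums (`WireVisibility`, `Visibility`) rather than the generated capnp types or the builder/reader APIs.

// Stand-ins for hugr_capnp::Visibility and model::Visibility, for illustration only.
#[derive(Clone, Copy, Debug, PartialEq)]
enum WireVisibility { Unspecified, Private, Public }

#[derive(Clone, Copy, Debug, PartialEq)]
enum Visibility { Private, Public }

/// Decoding: `Unspecified` (the capnp default) becomes `None`.
fn decode(wire: WireVisibility) -> Option<Visibility> {
    match wire {
        WireVisibility::Unspecified => None,
        WireVisibility::Private => Some(Visibility::Private),
        WireVisibility::Public => Some(Visibility::Public),
    }
}

/// Encoding: `None` is written as the capnp default `Unspecified`.
fn encode(vis: Option<Visibility>) -> WireVisibility {
    match vis {
        None => WireVisibility::Unspecified,
        Some(Visibility::Private) => WireVisibility::Private,
        Some(Visibility::Public) => WireVisibility::Public,
    }
}

fn main() {
    for v in [None, Some(Visibility::Private), Some(Visibility::Public)] {
        assert_eq!(decode(encode(v)), v); // the mapping round-trips losslessly
    }
}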
diff --git a/hugr-model/tests/fixtures/model-add.edn b/hugr-model/tests/fixtures/model-add.edn index 93b3a1a5b3..5b02678744 100644 --- a/hugr-model/tests/fixtures/model-add.edn +++ b/hugr-model/tests/fixtures/model-add.edn @@ -2,14 +2,19 @@ (mod) -(define-func example.add +(define-func + public + example.add (core.fn - [arithmetic.int.types.int arithmetic.int.types.int] - [arithmetic.int.types.int]) - (dfg - [%0 %1] - [%2] - (signature (core.fn [arithmetic.int.types.int arithmetic.int.types.int] [arithmetic.int.types.int])) - (arithmetic.int.iadd - [%0 %1] [%2] - (signature (core.fn [arithmetic.int.types.int arithmetic.int.types.int] [arithmetic.int.types.int]))))) + [(arithmetic.int.types.int 6) (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)]) + (dfg [%0 %1] [%2] + (signature + (core.fn + [(arithmetic.int.types.int 6) (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)])) + ((arithmetic.int.iadd 6) [%0 %1] [%2] + (signature + (core.fn + [(arithmetic.int.types.int 6) (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)]))))) diff --git a/hugr-model/tests/fixtures/model-call.edn b/hugr-model/tests/fixtures/model-call.edn index 4bf2eaaac4..8fa25feaf6 100644 --- a/hugr-model/tests/fixtures/model-call.edn +++ b/hugr-model/tests/fixtures/model-call.edn @@ -2,13 +2,13 @@ (mod) -(declare-func +(declare-func public example.callee (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]) (meta (compat.meta_json "title" "\"Callee\"")) (meta (compat.meta_json "description" "\"This is a function declaration.\""))) -(define-func example.caller +(define-func public example.caller (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]) (meta (compat.meta_json "title" "\"Caller\"")) (meta (compat.meta_json "description" "\"This defines a function that calls the function which we declared earlier.\"")) @@ -17,7 +17,7 @@ ((core.call _ _ example.callee) [%3] [%4] (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))))) -(define-func +(define-func public example.load (core.fn [] [(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])]) (dfg diff --git a/hugr-model/tests/fixtures/model-cfg.edn b/hugr-model/tests/fixtures/model-cfg.edn index e2a760f5c0..bfce67854b 100644 --- a/hugr-model/tests/fixtures/model-cfg.edn +++ b/hugr-model/tests/fixtures/model-cfg.edn @@ -2,23 +2,23 @@ (mod) -(define-func example.cfg_loop +(define-func public example.cfg_loop (param ?a core.type) (core.fn [?a] [?a]) (dfg [%0] [%1] - (signature (core.fn [?a] [?a])) - (cfg [%0] [%1] - (signature (core.fn [?a] [?a])) - (cfg [%2] [%4] - (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])])) - (block [%2] [%4 %2] - (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a]) (core.ctrl [?a])])) - (dfg [%5] [%6] - (signature (core.fn [?a] [(core.adt [[?a] [?a]])])) - ((core.make_adt 0) [%5] [%6] - (signature (core.fn [?a] [(core.adt [[?a] [?a]])]))))))))) + (signature (core.fn [?a] [?a])) + (cfg [%0] [%1] + (signature (core.fn [?a] [?a])) + (cfg [%2] [%3] + (signature (core.ctrl [[?a]] [[?a]])) + (block [%2] [%3 %2] + (signature (core.ctrl [[?a]] [[?a] [?a]])) + (dfg [%4] [%5] + (signature (core.fn [?a] [(core.adt [[?a] [?a]])])) + ((core.make_adt 0) [%4] [%5] + (signature (core.fn [?a] [(core.adt [[?a] [?a]])]))))))))) -(define-func example.cfg_order +(define-func public example.cfg_order (param ?a core.type) (core.fn [?a] [?a]) (dfg [%0] [%1] @@ -26,15 +26,15 @@ (cfg [%0] [%1] (signature (core.fn [?a] [?a])) (cfg [%2] [%4] - (signature (core.fn [(core.ctrl 
[?a])] [(core.ctrl [?a])])) + (signature (core.ctrl [[?a]] [[?a]])) (block [%3] [%4] - (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])])) + (signature (core.ctrl [[?a]] [[?a]])) (dfg [%5] [%6] (signature (core.fn [?a] [(core.adt [[?a]])])) ((core.make_adt _ _ 0) [%5] [%6] (signature (core.fn [?a] [(core.adt [[?a]])]))))) (block [%2] [%3] - (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])])) + (signature (core.ctrl [[?a]] [[?a]])) (dfg [%7] [%8] (signature (core.fn [?a] [(core.adt [[?a]])])) ((core.make_adt _ _ 0) [%7] [%8] diff --git a/hugr-model/tests/fixtures/model-cond.edn b/hugr-model/tests/fixtures/model-cond.edn index fd6fadcc86..9f49446d6d 100644 --- a/hugr-model/tests/fixtures/model-cond.edn +++ b/hugr-model/tests/fixtures/model-cond.edn @@ -2,16 +2,29 @@ (mod) -(define-func example.cond - (core.fn [(core.adt [[] []]) arithmetic.int.types.int] - [arithmetic.int.types.int]) +(define-func public + example.cond + (core.fn + [(core.adt [[] []]) (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)]) (dfg [%0 %1] [%2] - (signature (core.fn [(core.adt [[] []]) arithmetic.int.types.int] [arithmetic.int.types.int])) - (cond [%0 %1] [%2] - (signature (core.fn [(core.adt [[] []]) arithmetic.int.types.int] [arithmetic.int.types.int])) - (dfg [%3] [%3] - (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))) - (dfg [%4] [%5] - (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) - (arithmetic.int.ineg [%4] [%5] - (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))))))) + (signature + (core.fn + [(core.adt [[] []]) (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)])) + (cond [%0 %1] [%2] + (signature + (core.fn + [(core.adt [[] []]) (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)])) + (dfg [%3] [%3] + (signature + (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)]))) + (dfg [%4] [%5] + (signature + (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) + ((arithmetic.int.ineg 6) [%4] [%5] + (signature + (core.fn + [(arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6)]))))))) diff --git a/hugr-model/tests/fixtures/model-const.edn b/hugr-model/tests/fixtures/model-const.edn index 5d9bcb49b4..959afe2ef4 100644 --- a/hugr-model/tests/fixtures/model-const.edn +++ b/hugr-model/tests/fixtures/model-const.edn @@ -2,7 +2,7 @@ (mod) -(define-func example.bools +(define-func public example.bools (core.fn [] [(core.adt [[] []]) (core.adt [[] []])]) (dfg [] [%false %true] @@ -12,7 +12,7 @@ ((core.load_const (core.const.adt 1 (tuple))) [] [%true] (signature (core.fn [] [(core.adt [[] []])]))))) -(define-func example.make-pair +(define-func public example.make-pair (core.fn [] [(core.adt [[(collections.array.array 5 (arithmetic.int.types.int 6)) @@ -45,7 +45,7 @@ [[(collections.array.array 5 (arithmetic.int.types.int 6)) arithmetic.float.types.float64]])]))))) -(define-func example.f64-json +(define-func public example.f64-json (core.fn [] [arithmetic.float.types.float64]) (dfg [] [%0 %1] diff --git a/hugr-model/tests/fixtures/model-constraints.edn b/hugr-model/tests/fixtures/model-constraints.edn index 6884e55936..761a33c058 100644 --- a/hugr-model/tests/fixtures/model-constraints.edn +++ b/hugr-model/tests/fixtures/model-constraints.edn @@ -2,13 +2,14 @@ (mod) -(declare-func array.replicate +(declare-func private array.replicate (param ?n core.nat) (param ?t core.type) (where (core.nonlinear ?t)) (core.fn [?t] [(collections.array.array ?n 
?t)])) (declare-func + public array.copy (param ?n core.nat) (param ?t core.type) @@ -18,7 +19,7 @@ [(collections.array.array ?n ?t) (collections.array.array ?n ?t)])) -(define-func util.copy +(define-func public util.copy (param ?t core.type) (where (core.nonlinear ?t)) (core.fn [?t] [?t ?t]) diff --git a/hugr-model/tests/fixtures/model-entrypoint.edn b/hugr-model/tests/fixtures/model-entrypoint.edn index 10cab9173b..fb70d10309 100644 --- a/hugr-model/tests/fixtures/model-entrypoint.edn +++ b/hugr-model/tests/fixtures/model-entrypoint.edn @@ -2,7 +2,7 @@ (mod) -(define-func main +(define-func public main (core.fn [] []) (meta core.entrypoint) (dfg [] [] @@ -10,7 +10,7 @@ (mod) -(define-func wrapper_dfg +(define-func public wrapper_dfg (core.fn [] []) (dfg [] [] (signature (core.fn [] [])) @@ -18,17 +18,17 @@ (mod) -(define-func wrapper_cfg +(define-func public wrapper_cfg (core.fn [] []) (dfg [] [] (signature (core.fn [] [])) (cfg [] [] (signature (core.fn [] [])) (cfg [%entry] [%exit] - (signature (core.fn [(core.ctrl [])] [(core.ctrl [])])) + (signature (core.ctrl [[]] [[]])) (meta core.entrypoint) (block [%entry] [%exit] - (signature (core.fn [(core.ctrl [])] [(core.ctrl [])])) + (signature (core.ctrl [[]] [[]])) (dfg [] [%value] (signature (core.fn [] [(core.adt [[]])])) ((core.make_adt _ _ 0) [] [%value] diff --git a/hugr-model/tests/fixtures/model-loop.edn b/hugr-model/tests/fixtures/model-loop.edn index 5c4a6779e3..8276ed74ba 100644 --- a/hugr-model/tests/fixtures/model-loop.edn +++ b/hugr-model/tests/fixtures/model-loop.edn @@ -2,7 +2,9 @@ (mod) -(define-func example.loop +(define-func + private + example.loop (param ?a core.type) (core.fn [?a] [?a]) (dfg [%0] [%1] diff --git a/hugr-model/tests/fixtures/model-order.edn b/hugr-model/tests/fixtures/model-order.edn index 76bf7b0ba6..ed5c1e69e9 100644 --- a/hugr-model/tests/fixtures/model-order.edn +++ b/hugr-model/tests/fixtures/model-order.edn @@ -2,49 +2,54 @@ (mod) -(define-func main +(define-func public main (core.fn - [arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int] - [arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int]) + [(arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6)]) (dfg [%0 %1 %2 %3] [%4 %5 %6 %7] (signature (core.fn - [arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int] - [arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int - arithmetic.int.types.int])) + [(arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6)] + [(arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6) + (arithmetic.int.types.int 6)])) (meta (core.order_hint.order 1 2)) (meta (core.order_hint.order 1 0)) (meta (core.order_hint.order 2 3)) (meta (core.order_hint.order 0 3)) + (meta (core.order_hint.input_key 4)) + (meta (core.order_hint.order 4 0)) + (meta (core.order_hint.order 4 5)) + (meta (core.order_hint.order 1 5)) + (meta (core.order_hint.output_key 5)) - (arithmetic.int.ineg + ((arithmetic.int.ineg 6) [%0] [%4] - (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) + (signature (core.fn [(arithmetic.int.types.int 6)] 
[(arithmetic.int.types.int 6)])) (meta (core.order_hint.key 0))) - (arithmetic.int.ineg + ((arithmetic.int.ineg 6) [%1] [%5] - (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) + (signature (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) (meta (core.order_hint.key 1))) - (arithmetic.int.ineg + ((arithmetic.int.ineg 6) [%2] [%6] - (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) + (signature (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) (meta (core.order_hint.key 2))) - (arithmetic.int.ineg + ((arithmetic.int.ineg 6) [%3] [%7] - (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) + (signature (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) (meta (core.order_hint.key 3))))) diff --git a/hugr-model/tests/fixtures/model-params.edn b/hugr-model/tests/fixtures/model-params.edn index ba81fa6007..c4e29f933c 100644 --- a/hugr-model/tests/fixtures/model-params.edn +++ b/hugr-model/tests/fixtures/model-params.edn @@ -2,10 +2,25 @@ (mod) -(define-func example.swap +(define-func public example.swap ; The types of the values to be swapped are passed as implicit parameters. (param ?a core.type) (param ?b core.type) (core.fn [?a ?b] [?b ?a]) (dfg [%a %b] [%b %a] (signature (core.fn [?a ?b] [?b ?a])))) + +(declare-func public example.literals + (param ?a core.str) + (param ?b core.nat) + (param ?c core.bytes) + (param ?d core.float) + (core.fn [] [])) + +(define-func private example.call_literals + (core.fn [] []) + (dfg [] [] + (signature (core.fn [] [])) + ((core.call + (example.literals "string" 42 (bytes "SGVsbG8gd29ybGQg8J+Yig==") 6.023e23)) + (signature (core.fn [] []))))) diff --git a/hugr-passes/CHANGELOG.md b/hugr-passes/CHANGELOG.md index d5f2921845..7ae8e089b1 100644 --- a/hugr-passes/CHANGELOG.md +++ b/hugr-passes/CHANGELOG.md @@ -1,6 +1,37 @@ # Changelog +## [0.22.1](https://github.com/CQCL/hugr/compare/hugr-passes-v0.22.0...hugr-passes-v0.22.1) - 2025-07-28 + +### New Features + +- Include copy_discard_array in DelegatingLinearizer::default ([#2479](https://github.com/CQCL/hugr/pull/2479)) +- Inline calls to functions not on cycles in the call graph ([#2450](https://github.com/CQCL/hugr/pull/2450)) + +## [0.22.0](https://github.com/CQCL/hugr/compare/hugr-passes-v0.21.0...hugr-passes-v0.22.0) - 2025-07-24 + +### New Features + +- ReplaceTypes allows linearizing inside Op replacements ([#2435](https://github.com/CQCL/hugr/pull/2435)) +- Add pass for DFG inlining ([#2460](https://github.com/CQCL/hugr/pull/2460)) + +## [0.21.0](https://github.com/CQCL/hugr/compare/hugr-passes-v0.20.2...hugr-passes-v0.21.0) - 2025-07-09 + +### Bug Fixes + +- DeadFuncElimPass+CallGraph w/ non-module-child entrypoint ([#2390](https://github.com/CQCL/hugr/pull/2390)) + +### New Features + +- [**breaking**] No nested FuncDefns (or AliasDefns) ([#2256](https://github.com/CQCL/hugr/pull/2256)) +- [**breaking**] Split `TypeArg::Sequence` into tuples and lists. 
([#2140](https://github.com/CQCL/hugr/pull/2140)) +- [**breaking**] Merge `TypeParam` and `TypeArg` into one `Term` type in Rust ([#2309](https://github.com/CQCL/hugr/pull/2309)) +- [**breaking**] Rename 'Any' type bound to 'Linear' ([#2421](https://github.com/CQCL/hugr/pull/2421)) + +### Refactor + +- [**breaking**] Reduce error type sizes ([#2420](https://github.com/CQCL/hugr/pull/2420)) + ## [0.20.2](https://github.com/CQCL/hugr/compare/hugr-passes-v0.20.1...hugr-passes-v0.20.2) - 2025-06-25 ### Bug Fixes diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 8c1daafbcd..adf019fcc5 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hugr-passes" -version = "0.20.2" +version = "0.22.1" edition = { workspace = true } rust-version = { workspace = true } license = { workspace = true } @@ -19,7 +19,7 @@ workspace = true bench = false [dependencies] -hugr-core = { path = "../hugr-core", version = "0.20.2" } +hugr-core = { path = "../hugr-core", version = "0.22.1" } portgraph = { workspace = true } ascent = { version = "0.8.0" } derive_more = { workspace = true, features = ["display", "error", "from"] } diff --git a/hugr-passes/src/call_graph.rs b/hugr-passes/src/call_graph.rs index e33881b1f7..7baf8530dd 100644 --- a/hugr-passes/src/call_graph.rs +++ b/hugr-passes/src/call_graph.rs @@ -1,4 +1,3 @@ -#![warn(missing_docs)] //! Data structure for call graphs of a Hugr use std::collections::HashMap; @@ -6,6 +5,7 @@ use hugr_core::{HugrView, Node, core::HugrNode, ops::OpType}; use petgraph::Graph; /// Weight for an edge in a [`CallGraph`] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum CallGraphEdge { /// Edge corresponds to a [Call](OpType::Call) node (specified) in the Hugr Call(N), @@ -48,19 +48,20 @@ impl CallGraph { /// Makes a new `CallGraph` for a Hugr. pub fn new(hugr: &impl HugrView) -> Self { let mut g = Graph::default(); - let non_func_root = - (!hugr.get_optype(hugr.entrypoint()).is_module()).then_some(hugr.entrypoint()); - let node_to_g = hugr + let mut node_to_g = hugr .children(hugr.module_root()) .filter_map(|n| { let weight = match hugr.get_optype(n) { OpType::FuncDecl(_) => CallGraphNode::FuncDecl(n), OpType::FuncDefn(_) => CallGraphNode::FuncDefn(n), - _ => (Some(n) == non_func_root).then_some(CallGraphNode::NonFuncRoot)?, + _ => return None, }; Some((n, g.add_node(weight))) }) .collect::>(); + if !hugr.entrypoint_optype().is_module() && !node_to_g.contains_key(&hugr.entrypoint()) { + node_to_g.insert(hugr.entrypoint(), g.add_node(CallGraphNode::NonFuncRoot)); + } for (func, cg_node) in &node_to_g { traverse(hugr, *cg_node, *func, &mut g, &node_to_g); } diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs index d7f44fcebb..bda5e66cf7 100644 --- a/hugr-passes/src/composable.rs +++ b/hugr-passes/src/composable.rs @@ -9,11 +9,16 @@ use itertools::Either; /// An optimization pass that can be sequenced with another and/or wrapped /// e.g. by [`ValidatingPass`] pub trait ComposablePass: Sized { + /// Error thrown by this pass. type Error: Error; + /// Result returned by this pass. type Result; // Would like to default to () but currently unstable + /// Run the pass on the given HUGR. fn run(&self, hugr: &mut H) -> Result; + /// Apply a function to the error type of this pass, returning a new + /// [`ComposablePass`] that has the same result type. 
fn map_err( self, f: impl Fn(Self::Error) -> E2, @@ -52,7 +57,9 @@ pub trait ComposablePass: Sized { /// Trait for combining the error types from two different passes /// into a single error. pub trait ErrorCombiner: Error { + /// Create a combined error from the first pass's error. fn from_first(a: A) -> Self; + /// Create a combined error from the second pass's error. fn from_second(b: B) -> Self; } @@ -113,20 +120,33 @@ pub enum ValidatePassError where N: HugrNode + 'static, { + /// Validation failed on the initial HUGR. #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] Input { + /// The validation error that occurred. #[source] - err: ValidationError, + err: Box>, + /// A pretty-printed representation of the HUGR that failed validation. pretty_hugr: String, }, + /// Validation failed on the final HUGR. #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")] Output { + /// The validation error that occurred. #[source] - err: ValidationError, + err: Box>, + /// A pretty-printed representation of the HUGR that failed validation. pretty_hugr: String, }, + /// An error from the underlying pass. #[error(transparent)] - Underlying(#[from] E), + Underlying(Box), +} + +impl From for ValidatePassError { + fn from(err: E) -> Self { + Self::Underlying(Box::new(err)) + } } /// Runs an underlying pass, but with validation of the Hugr @@ -134,6 +154,7 @@ where pub struct ValidatingPass(P, PhantomData); impl, H: HugrMut> ValidatingPass { + /// Return a new [`ValidatingPass`] that wraps the given underlying pass. pub fn new(underlying: P) -> Self { Self(underlying, PhantomData) } @@ -157,12 +178,12 @@ where fn run(&self, hugr: &mut H) -> Result { self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input { - err, + err: Box::new(err), pretty_hugr, })?; - let res = self.0.run(hugr).map_err(ValidatePassError::Underlying)?; + let res = self.0.run(hugr)?; self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output { - err, + err: Box::new(err), pretty_hugr, })?; Ok(res) @@ -213,7 +234,7 @@ pub(crate) fn validate_if_test, H: HugrMut>( if cfg!(test) { ValidatingPass::new(pass).run(hugr) } else { - pass.run(hugr).map_err(ValidatePassError::Underlying) + Ok(pass.run(hugr)?) } } @@ -223,8 +244,7 @@ mod test { use std::convert::Infallible; use hugr_core::builder::{ - Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, - ModuleBuilder, + Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, }; use hugr_core::extension::prelude::{ConstUsize, MakeTuple, UnpackTuple, bool_t, usize_t}; use hugr_core::hugr::hugrmut::HugrMut; @@ -304,7 +324,7 @@ mod test { assert_eq!(h, backup); // Did nothing let r = ValidatingPass::new(cfold).run(&mut h); - assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err)); + assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if *e == err)); } #[test] diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 11a92faa48..9c450b0aca 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -1,4 +1,3 @@ -#![warn(missing_docs)] //! Constant-folding pass. //! An (example) use of the [dataflow analysis framework](super::dataflow). 
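The trait documentation added above is abstract, so here is a minimal sketch of a pass that satisfies it. The `hugr_passes::ComposablePass` re-export and the `HugrMut<Node = Node>` bound follow the `InlineDFGsPass` implementation later in this patch; the pass itself (`CountNodesPass`) and the exact import paths are assumptions made for the example.

use std::convert::Infallible;

use hugr_core::{HugrView, Node, hugr::hugrmut::HugrMut};
use hugr_passes::ComposablePass;

/// A trivial pass: mutates nothing and reports the node count.
struct CountNodesPass;

impl<H: HugrMut<Node = Node>> ComposablePass<H> for CountNodesPass {
    type Error = Infallible; // this pass cannot fail
    type Result = usize;

    fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
        // Real passes would mutate `hugr` here; this one only inspects it.
        Ok(hugr.num_nodes())
    }
}

// It can then be sequenced with other passes, have its error adapted with
// `map_err`, or be wrapped as `ValidatingPass::new(CountNodesPass)` so that the
// input and output HUGRs are validated around the run.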
diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index f4165676b2..a60684ec07 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1591,7 +1591,7 @@ fn test_module() -> Result<(), Box> { // Define a top-level constant, (only) the second of which can be removed let c7 = mb.add_constant(Value::from(ConstInt::new_u(5, 7)?)); let c17 = mb.add_constant(Value::from(ConstInt::new_u(5, 17)?)); - let ad1 = mb.add_alias_declare("unused", TypeBound::Any)?; + let ad1 = mb.add_alias_declare("unused", TypeBound::Linear)?; let ad2 = mb.add_alias_def("unused2", INT_TYPES[3].clone())?; let mut main = mb.define_function( "main", diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index a97901c61b..3311409655 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -1,4 +1,3 @@ -#![warn(missing_docs)] //! Dataflow analysis of Hugrs. mod datalog; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 0eadbcdc10..0368f931bc 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -51,6 +51,11 @@ impl Machine { /// or [Conditional](hugr_core::ops::Conditional)). /// Any inputs not given values by `in_values`, are set to [`PartialValue::Top`]. /// Multiple calls for the same `parent` will `join` values for corresponding ports. + #[expect( + clippy::result_large_err, + reason = "Not called recursively and not a performance bottleneck" + )] + #[inline] pub fn prepopulate_inputs( &mut self, parent: H::Node, diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 205d9ba4fa..915f6f3425 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -2,7 +2,7 @@ use std::convert::Infallible; use ascent::{Lattice, lattice::BoundedLattice}; -use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder, inout_sig}; +use hugr_core::builder::{CFGBuilder, DataflowHugr, ModuleBuilder, inout_sig}; use hugr_core::ops::{CallIndirect, TailLoop}; use hugr_core::types::{ConstTypeError, TypeRow}; use hugr_core::{Hugr, Node, Wire}; @@ -409,11 +409,14 @@ fn test_call( #[case] out: PartialValue, ) { let mut builder = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap(); - let func_bldr = builder - .define_function("id", Signature::new_endo(bool_t())) - .unwrap(); - let [v] = func_bldr.input_wires_arr(); - let func_defn = func_bldr.finish_with_outputs([v]).unwrap(); + let func_defn = { + let mut mb = builder.module_root_builder(); + let func_bldr = mb + .define_function("id", Signature::new_endo(bool_t())) + .unwrap(); + let [v] = func_bldr.input_wires_arr(); + func_bldr.finish_with_outputs([v]).unwrap() + }; let [a, b] = builder.input_wires_arr(); let [a2] = builder .call(func_defn.handle(), &[], [a]) @@ -554,7 +557,8 @@ fn call_indirect(#[case] inp1: PartialValue, #[case] inp2: PartialValue> ComposablePass for InlineDFGsPass { + type Error = Infallible; + type Result = (); + + fn run(&self, h: &mut H) -> Result<(), Self::Error> { + let dfgs = h + .entry_descendants() + .skip(1) // Skip the entrypoint itself + .filter(|&n| h.get_optype(n).is_dfg()) + .collect_vec(); + for dfg in dfgs { + h.apply_patch(InlineDFG(dfg.into())) + .map_err(|err| -> Infallible { + match err { + InlineDFGError::CantInlineEntrypoint { .. } => { + unreachable!("We skipped the entrypoint") + } + InlineDFGError::NotDFG { .. 
} => unreachable!("Should be a DFG"), + _ => unreachable!("No other error cases"), + } + }) + .unwrap(); + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + use hugr_core::{ + HugrView, + builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}, + extension::prelude::qb_t, + types::Signature, + }; + + use crate::ComposablePass; + + use super::InlineDFGsPass; + + #[test] + fn inline_dfgs() -> Result<(), Box> { + let mut outer = DFGBuilder::new(Signature::new_endo(vec![qb_t(), qb_t()]))?; + let [a, b] = outer.input_wires_arr(); + + let inner1 = outer.dfg_builder_endo([(qb_t(), a)])?; + let [inner1_a] = inner1.input_wires_arr(); + let [a] = inner1.finish_with_outputs([inner1_a])?.outputs_arr(); + + let mut inner2 = outer.dfg_builder_endo([(qb_t(), b)])?; + let [inner2_b] = inner2.input_wires_arr(); + let inner2_inner = inner2.dfg_builder_endo([(qb_t(), inner2_b)])?; + let [inner2_inner_b] = inner2_inner.input_wires_arr(); + let [inner2_b] = inner2_inner + .finish_with_outputs([inner2_inner_b])? + .outputs_arr(); + let [b] = inner2.finish_with_outputs([inner2_b])?.outputs_arr(); + + let inner3 = outer.dfg_builder_endo([(qb_t(), a), (qb_t(), b)])?; + let [inner3_a, inner3_b] = inner3.input_wires_arr(); + let [a, b] = inner3 + .finish_with_outputs([inner3_a, inner3_b])? + .outputs_arr(); + + let mut h = outer.finish_hugr_with_outputs([a, b])?; + assert_eq!(h.num_nodes(), 5 * 3 + 4); // 5 DFGs with I/O + 4 nodes for module/func roots + InlineDFGsPass.run(&mut h).unwrap(); + + // Root should be the only remaining DFG + assert!(h.get_optype(h.entrypoint()).is_dfg()); + assert!( + h.entry_descendants() + .skip(1) + .all(|n| !h.get_optype(n).is_dfg()) + ); + assert_eq!(h.num_nodes(), 3 + 4); // 1 DFG with I/O + 4 nodes for module/func roots + Ok(()) + } +} diff --git a/hugr-passes/src/inline_funcs.rs b/hugr-passes/src/inline_funcs.rs new file mode 100644 index 0000000000..b999560f45 --- /dev/null +++ b/hugr-passes/src/inline_funcs.rs @@ -0,0 +1,229 @@ +//! Contains a pass to inline calls to selected functions in a Hugr. +use std::collections::{HashSet, VecDeque}; + +use hugr_core::hugr::hugrmut::HugrMut; +use hugr_core::hugr::patch::inline_call::InlineCall; +use itertools::Itertools; +use petgraph::algo::tarjan_scc; + +use crate::call_graph::{CallGraph, CallGraphNode}; + +/// Error raised by [inline_acyclic] +#[derive(Clone, Debug, thiserror::Error, PartialEq)] +#[non_exhaustive] +pub enum InlineFuncsError {} + +/// Inline (a subset of) [Call]s whose target [FuncDefn]s are not in cycles of the call +/// graph. +/// +/// The function `call_predicate` is passed each such [Call] node and can return +/// `false` to prevent that Call from being inlined. (Note the [Call] may be created as +/// a result of previous inlinings so may not have existed in the original Hugr). 
+/// +/// [Call]: hugr_core::ops::Call +/// [FuncDefn]: hugr_core::ops::FuncDefn +pub fn inline_acyclic( + h: &mut H, + call_predicate: impl Fn(&H, H::Node) -> bool, +) -> Result<(), InlineFuncsError> { + let cg = CallGraph::new(&*h); + let g = cg.graph(); + let all_funcs_in_cycles = tarjan_scc(g) + .into_iter() + .flat_map(|mut ns| { + if let Ok(n) = ns.iter().exactly_one() { + if g.edges_connecting(*n, *n).next().is_none() { + ns.clear(); // Single-node SCC has no self edge, so discard + } + } + ns.into_iter().map(|n| { + let CallGraphNode::FuncDefn(fd) = g.node_weight(n).unwrap() else { + panic!("Expected only FuncDefns in sccs") + }; + *fd + }) + }) + .collect::>(); + let target_funcs: HashSet = h + .children(h.module_root()) + .filter(|n| h.get_optype(*n).is_func_defn() && !all_funcs_in_cycles.contains(n)) + .collect(); + let mut q = VecDeque::from([h.entrypoint()]); + while let Some(n) = q.pop_front() { + if h.get_optype(n).is_call() { + if let Some(t) = h.static_source(n) { + if target_funcs.contains(&t) && call_predicate(h, n) { + // We've already checked all error conditions + h.apply_patch(InlineCall::new(n)).unwrap(); + } + } + } + // Traverse children - including any resulting from turning Call into DFG + q.extend(h.children(n)); + } + Ok(()) +} + +#[cfg(test)] +mod test { + use std::collections::HashSet; + + use hugr_core::core::HugrNode; + use hugr_core::ops::OpType; + use itertools::Itertools; + use petgraph::visit::EdgeRef; + + use hugr_core::HugrView; + use hugr_core::builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; + use hugr_core::{Hugr, extension::prelude::qb_t, types::Signature}; + use rstest::rstest; + + use crate::call_graph::{CallGraph, CallGraphNode}; + use crate::inline_funcs::inline_acyclic; + + /// /->-\ + /// main -> f g -> b -> c + /// / \-<-/ + /// / + /// \-> a -> x + fn make_test_hugr() -> Hugr { + let sig = || Signature::new_endo(qb_t()); + let mut mb = ModuleBuilder::new(); + let x = mb.declare("x", sig().into()).unwrap(); + let a = { + let mut fb = mb.define_function("a", sig()).unwrap(); + let ins = fb.input_wires(); + let res = fb.call(&x, &[], ins).unwrap(); + fb.finish_with_outputs(res.outputs()).unwrap() + }; + let c = { + let fb = mb.define_function("c", sig()).unwrap(); + let ins = fb.input_wires(); + fb.finish_with_outputs(ins).unwrap() + }; + let b = { + let mut fb = mb.define_function("b", sig()).unwrap(); + let ins = fb.input_wires(); + let res = fb.call(c.handle(), &[], ins).unwrap().outputs(); + fb.finish_with_outputs(res).unwrap() + }; + let f = mb.declare("f", sig().into()).unwrap(); + let g = { + let mut fb = mb.define_function("g", sig()).unwrap(); + let ins = fb.input_wires(); + let c1 = fb.call(&f, &[], ins).unwrap(); + let c2 = fb.call(b.handle(), &[], c1.outputs()).unwrap(); + fb.finish_with_outputs(c2.outputs()).unwrap() + }; + let _f = { + let mut fb = mb.define_declaration(&f).unwrap(); + let ins = fb.input_wires(); + let c1 = fb.call(g.handle(), &[], ins).unwrap(); + let c2 = fb.call(a.handle(), &[], c1.outputs()).unwrap(); + fb.finish_with_outputs(c2.outputs()).unwrap() + }; + mb.finish_hugr().unwrap() + } + + fn find_func(h: &H, name: &str) -> H::Node { + h.children(h.module_root()) + .find(|n| { + h.get_optype(*n) + .as_func_defn() + .is_some_and(|fd| fd.func_name() == name) + }) + .unwrap() + } + + #[rstest] + #[case(["a", "b", "c"], ["a", "b", "c"], [vec!["g", "x"], vec!["f"], vec!["x"], vec![], vec![]])] + #[case(["a", "b"], ["a", "b"], [vec!["g", "x"], vec!["f", "c"], vec!["x"], vec!["c"], vec![]])] + 
#[case(["c"], ["c"], [vec!["g", "a"], vec!("f", "b"), vec!["x"], vec![], vec![]])] + fn test_inline( + #[case] req: impl IntoIterator, + #[case] check_not_called: impl IntoIterator, + #[case] calls_fgabc: [Vec<&'static str>; 5], + ) { + let mut h = make_test_hugr(); + let target_funcs = req + .into_iter() + .map(|name| find_func(&h, name)) + .collect::>(); + inline_acyclic(&mut h, |h, call| { + let tgt = h.static_source(call).unwrap(); + // Check the callback is never asked about an impossible inlining + assert!(["a", "b", "c"].contains(&func_name(h, tgt).as_str())); + target_funcs.contains(&tgt) + }) + .unwrap(); + let cg = CallGraph::new(&h); + for fname in check_not_called { + let fnode = find_func(&h, fname); + let fnode = cg.node_index(fnode).unwrap(); + assert_eq!( + None, + cg.graph() + .edges_directed(fnode, petgraph::Direction::Incoming) + .next() + ); + } + for (fname, tgts) in ["f", "g", "a", "b", "c"].into_iter().zip_eq(calls_fgabc) { + let fnode = find_func(&h, fname); + assert_eq!( + outgoing_calls(&cg, fnode) + .into_iter() + .map(|n| func_name(&h, n).as_str()) + .collect::>(), + HashSet::from_iter(tgts), + "Calls from {fname}" + ); + } + } + + fn outgoing_calls(cg: &CallGraph, src: N) -> Vec { + let src = cg.node_index(src).unwrap(); + cg.graph() + .edges_directed(src, petgraph::Direction::Outgoing) + .map(|e| func_node(cg.graph().node_weight(e.target()).unwrap())) + .collect() + } + + #[test] + fn test_filter_caller() { + let mut h = make_test_hugr(); + let [g, b, c] = ["g", "b", "c"].map(|n| find_func(&h, n)); + // Inline calls contained within `g` + inline_acyclic(&mut h, |h, mut call| { + loop { + if call == g { + return true; + }; + let Some(parent) = h.get_parent(call) else { + return false; + }; + call = parent; + } + }) + .unwrap(); + let cg = CallGraph::new(&h); + // b and then c should have been inlined into g, leaving only cyclic call to f + assert_eq!(outgoing_calls(&cg, g), [find_func(&h, "f")]); + // But c should not have been inlined into b: + assert_eq!(outgoing_calls(&cg, b), [c]); + } + + fn func_node(cgn: &CallGraphNode) -> N { + match cgn { + CallGraphNode::FuncDecl(n) | CallGraphNode::FuncDefn(n) => *n, + CallGraphNode::NonFuncRoot => panic!(), + } + } + + fn func_name(h: &H, n: H::Node) -> &String { + match h.get_optype(n) { + OpType::FuncDecl(fd) => fd.func_name(), + OpType::FuncDefn(fd) => fd.func_name(), + _ => panic!(), + } + } +} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index c82fc5abe6..6e97f3422e 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,5 +1,4 @@ //! Compilation passes acting on the HUGR program representation. -#![expect(missing_docs)] // TODO: Fix... 
pub mod call_graph; pub mod composable; @@ -12,6 +11,9 @@ mod dead_funcs; pub use dead_funcs::{RemoveDeadFuncsError, RemoveDeadFuncsPass, remove_dead_funcs}; pub mod force_order; mod half_node; +pub mod inline_dfgs; +pub mod inline_funcs; +pub use inline_funcs::inline_acyclic; pub mod linearize_array; pub use linearize_array::LinearizeArrayPass; pub mod lower; diff --git a/hugr-passes/src/linearize_array.rs b/hugr-passes/src/linearize_array.rs index 4f8da110c9..07fbc6e958 100644 --- a/hugr-passes/src/linearize_array.rs +++ b/hugr-passes/src/linearize_array.rs @@ -10,7 +10,7 @@ use hugr_core::{ std_extensions::collections::{ array::{ ARRAY_REPEAT_OP_ID, ARRAY_SCAN_OP_ID, Array, ArrayKind, ArrayOpDef, ArrayRepeatDef, - ArrayScanDef, ArrayValue, array_type_def, array_type_parametric, + ArrayScanDef, ArrayValue, array_type_parametric, }, value_array::{self, VArrayFromArrayDef, VArrayToArrayDef, VArrayValue, ValueArray}, }, @@ -21,9 +21,7 @@ use strum::IntoEnumIterator; use crate::{ ComposablePass, ReplaceTypes, - replace_types::{ - DelegatingLinearizer, NodeTemplate, ReplaceTypesError, handlers::copy_discard_array, - }, + replace_types::{DelegatingLinearizer, NodeTemplate, ReplaceTypesError}, }; /// A HUGR -> HUGR pass that turns 'value_array`s into regular linear `array`s. @@ -66,7 +64,7 @@ impl Default for LinearizeArrayPass { // error out and make sure we're not emitting `get`s for nested value // arrays. assert!( - op_def != ArrayOpDef::get || args[1].as_type().unwrap().copyable(), + op_def != ArrayOpDef::get || args[1].as_runtime().unwrap().copyable(), "Cannot linearise arrays in this Hugr: \ Contains a `get` operation on nested value arrays" ); @@ -114,8 +112,6 @@ impl Default for LinearizeArrayPass { )) }, ); - pass.linearizer() - .register_callback(array_type_def(), copy_discard_array); Self(pass) } } diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 45b5ce9080..db9e60e135 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -1,3 +1,5 @@ +//! Passes to lower operations in a HUGR. + use hugr_core::{ Hugr, Node, hugr::{hugrmut::HugrMut, views::SiblingSubgraph}, diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 2a97f75240..2d5abd5eb1 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -133,25 +133,7 @@ fn instantiate( mono_sig: Signature, cache: &mut Instantiations, ) -> Node { - let for_func = cache.entry(poly_func).or_insert_with(|| { - // First time we've instantiated poly_func. Lift any nested FuncDefn's out to the same level. - let outer_name = h - .get_optype(poly_func) - .as_func_defn() - .unwrap() - .func_name() - .clone(); - let mut to_scan = Vec::from_iter(h.children(poly_func)); - while let Some(n) = to_scan.pop() { - if let OpType::FuncDefn(fd) = h.optype_mut(n) { - *fd.func_name_mut() = mangle_inner_func(&outer_name, fd.func_name()); - h.move_after_sibling(n, poly_func); - } else { - to_scan.extend(h.children(n)); - } - } - HashMap::new() - }); + let for_func = cache.entry(poly_func).or_default(); let ve = match for_func.entry(type_args.clone()) { Entry::Occupied(n) => return *n.get(), @@ -231,9 +213,10 @@ impl> ComposablePass for MonomorphizePass { } } -struct TypeArgsList<'a>(&'a [TypeArg]); +/// Helper to create mangled representations of lists of [TypeArg]s. 
+struct TypeArgsSeq<'a>(&'a [TypeArg]); -impl std::fmt::Display for TypeArgsList<'_> { +impl std::fmt::Display for TypeArgsSeq<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { for arg in self.0 { f.write_char('$')?; @@ -249,13 +232,14 @@ fn escape_dollar(str: impl AsRef) -> String { fn write_type_arg_str(arg: &TypeArg, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match arg { - TypeArg::Type { ty } => f.write_fmt(format_args!("t({})", escape_dollar(ty.to_string()))), - TypeArg::BoundedNat { n } => f.write_fmt(format_args!("n({n})")), - TypeArg::String { arg } => f.write_fmt(format_args!("s({})", escape_dollar(arg))), - TypeArg::Sequence { elems } => f.write_fmt(format_args!("seq({})", TypeArgsList(elems))), + TypeArg::Runtime(ty) => f.write_fmt(format_args!("t({})", escape_dollar(ty.to_string()))), + TypeArg::BoundedNat(n) => f.write_fmt(format_args!("n({n})")), + TypeArg::String(arg) => f.write_fmt(format_args!("s({})", escape_dollar(arg))), + TypeArg::List(elems) => f.write_fmt(format_args!("list({})", TypeArgsSeq(elems))), + TypeArg::Tuple(elems) => f.write_fmt(format_args!("tuple({})", TypeArgsSeq(elems))), // We are monomorphizing. We will never monomorphize to a signature // containing a variable. - TypeArg::Variable { .. } => panic!("type_arg_str variable: {arg}"), + TypeArg::Variable(_) => panic!("type_arg_str variable: {arg}"), _ => panic!("unknown type arg: {arg}"), } } @@ -275,11 +259,7 @@ fn write_type_arg_str(arg: &TypeArg, f: &mut std::fmt::Formatter<'_>) -> std::fm /// is used as "t({arg})" for the string representation of that arg. pub fn mangle_name(name: &str, type_args: impl AsRef<[TypeArg]>) -> String { let name = escape_dollar(name); - format!("${name}${}", TypeArgsList(type_args.as_ref())) -} - -fn mangle_inner_func(outer_name: &str, inner_name: &str) -> String { - format!("${outer_name}${inner_name}") + format!("${name}${}", TypeArgsSeq(type_args.as_ref())) } #[cfg(test)] @@ -288,6 +268,7 @@ mod test { use std::iter; use hugr_core::extension::simple_op::MakeRegisteredOp as _; + use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections; use hugr_core::std_extensions::collections::array::ArrayKind; @@ -308,7 +289,7 @@ mod test { use crate::{monomorphize, remove_dead_funcs}; - use super::{is_polymorphic, mangle_inner_func, mangle_name}; + use super::{is_polymorphic, mangle_name}; fn pair_type(ty: Type) -> Type { Type::new_tuple(vec![ty.clone(), ty]) @@ -425,13 +406,12 @@ mod test { } #[test] - fn test_flattening_multiargs_nats() { + fn test_multiargs_nats() { //pf1 contains pf2 contains mono_func -> pf1 and pf1 share pf2's and they share mono_func let tv = |i| Type::new_var_use(i, TypeBound::Copyable); - let sv = |i| TypeArg::new_var_use(i, TypeParam::max_nat()); - let sa = |n| TypeArg::BoundedNat { n }; - + let sv = |i| TypeArg::new_var_use(i, TypeParam::max_nat_type()); + let sa = |n| TypeArg::BoundedNat(n); let n: u64 = 5; let mut outer = FunctionBuilder::new( "mainish", @@ -447,32 +427,23 @@ mod test { .unwrap(); let arr2u = || ValueArray::ty_parametric(sa(2), usize_t()).unwrap(); - let pf1t = PolyFuncType::new( - [TypeParam::max_nat()], - Signature::new( - ValueArray::ty_parametric(sv(0), arr2u()).unwrap(), - usize_t(), - ), - ); - let mut pf1 = outer.define_function("pf1", pf1t).unwrap(); - let pf2t = PolyFuncType::new( - [TypeParam::max_nat(), TypeBound::Copyable.into()], - Signature::new( - vec![ValueArray::ty_parametric(sv(0), 
tv(1)).unwrap()], - tv(1), - ), - ); - let mut pf2 = pf1.define_function("pf2", pf2t).unwrap(); + let mut mb = outer.module_root_builder(); let mono_func = { - let mut fb = pf2 + let mut fb = mb .define_function("get_usz", Signature::new(vec![], usize_t())) .unwrap(); let cst0 = fb.add_load_value(ConstUsize::new(1)); fb.finish_with_outputs([cst0]).unwrap() }; + let pf2 = { + let pf2t = PolyFuncType::new( + [TypeParam::max_nat_type(), TypeBound::Copyable.into()], + Signature::new(ValueArray::ty_parametric(sv(0), tv(1)).unwrap(), tv(1)), + ); + let mut pf2 = mb.define_function("pf2", pf2t).unwrap(); let [inw] = pf2.input_wires_arr(); let [idx] = pf2.call(mono_func.handle(), &[], []).unwrap().outputs_arr(); let op_def = collections::value_array::EXTENSION.get_op("get").unwrap(); @@ -484,6 +455,16 @@ mod test { .unwrap(); pf2.finish_with_outputs([got]).unwrap() }; + + let pf1t = PolyFuncType::new( + [TypeParam::max_nat_type()], + Signature::new( + ValueArray::ty_parametric(sv(0), arr2u()).unwrap(), + usize_t(), + ), + ); + let mut pf1 = mb.define_function("pf1", pf1t).unwrap(); + // pf1: Two calls to pf2, one depending on pf1's TypeArg, the other not let inner = pf1 .call(pf2.handle(), &[sv(0), arr2u().into()], pf1.input_wires()) @@ -491,11 +472,12 @@ mod test { let elem = pf1 .call( pf2.handle(), - &[TypeArg::BoundedNat { n: 2 }, usize_t().into()], + &[TypeArg::BoundedNat(2), usize_t().into()], inner.outputs(), ) .unwrap(); let pf1 = pf1.finish_with_outputs(elem.outputs()).unwrap(); + // Outer: two calls to pf1 with different TypeArgs let [e1] = outer .call(pf1.handle(), &[sa(n)], outer.input_wires()) @@ -516,23 +498,24 @@ mod test { .call(pf1.handle(), &[sa(n - 1)], [ar2_unwrapped]) .unwrap() .outputs_arr(); + let outer_func = outer.container_node(); let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); + hugr.set_entrypoint(hugr.module_root()); // We want to act on everything, not just `main` monomorphize(&mut hugr).unwrap(); let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); - let pf2_name = mangle_inner_func("pf1", "pf2"); assert_eq!( funcs.keys().copied().sorted().collect_vec(), vec![ - &mangle_name("pf1", &[TypeArg::BoundedNat { n: 5 }]), - &mangle_name("pf1", &[TypeArg::BoundedNat { n: 4 }]), - &mangle_name(&pf2_name, &[TypeArg::BoundedNat { n: 5 }, arr2u().into()]), // from pf1<5> - &mangle_name(&pf2_name, &[TypeArg::BoundedNat { n: 4 }, arr2u().into()]), // from pf1<4> - &mangle_name(&pf2_name, &[TypeArg::BoundedNat { n: 2 }, usize_t().into()]), // from both pf1<4> and <5> - &mangle_inner_func(&pf2_name, "get_usz"), - &pf2_name, + &mangle_name("pf1", &[TypeArg::BoundedNat(5)]), + &mangle_name("pf1", &[TypeArg::BoundedNat(4)]), + &mangle_name("pf2", &[TypeArg::BoundedNat(5), arr2u().into()]), // from pf1<5> + &mangle_name("pf2", &[TypeArg::BoundedNat(4), arr2u().into()]), // from pf1<4> + &mangle_name("pf2", &[TypeArg::BoundedNat(2), usize_t().into()]), // from both pf1<4> and <5> + "get_usz", + "pf2", "mainish", "pf1" ] @@ -540,13 +523,10 @@ mod test { .sorted() .collect_vec() ); - for (n, fd) in funcs.into_values() { - if n == mono_hugr.entrypoint() { - assert_eq!(fd.func_name(), "mainish"); - } else { - assert_ne!(fd.func_name(), "mainish"); - } - } + #[allow(clippy::unnecessary_to_owned)] // it is necessary + let (n, fd) = *funcs.get(&"mainish".to_string()).unwrap(); + assert_eq!(n, outer_func); + assert_eq!(fd.func_name(), "mainish"); // just a sanity check on list_funcs } fn list_funcs(h: &Hugr) -> HashMap<&String, (Node, 
&FuncDefn)> { @@ -559,50 +539,6 @@ mod test { .collect::>() } - #[test] - fn test_no_flatten_out_of_mono_func() -> Result<(), Box> { - let ity = || INT_TYPES[4].clone(); - let sig = Signature::new_endo(vec![usize_t(), ity()]); - let mut dfg = DFGBuilder::new(sig.clone()).unwrap(); - let mut mono = dfg.define_function("id2", sig).unwrap(); - let pf = mono - .define_function( - "id", - PolyFuncType::new( - [TypeBound::Any.into()], - Signature::new_endo(Type::new_var_use(0, TypeBound::Any)), - ), - ) - .unwrap(); - let outs = pf.input_wires(); - let pf = pf.finish_with_outputs(outs).unwrap(); - let [a, b] = mono.input_wires_arr(); - let [a] = mono - .call(pf.handle(), &[usize_t().into()], [a]) - .unwrap() - .outputs_arr(); - let [b] = mono - .call(pf.handle(), &[ity().into()], [b]) - .unwrap() - .outputs_arr(); - let mono = mono.finish_with_outputs([a, b]).unwrap(); - let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap(); - let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); - monomorphize(&mut hugr)?; - let mono_hugr = hugr; - - let mut funcs = list_funcs(&mono_hugr); - #[allow(clippy::unnecessary_to_owned)] // It is necessary - let (m, _) = funcs.remove(&"id2".to_string()).unwrap(); - assert_eq!(m, mono.handle().node()); - assert_eq!(mono_hugr.get_parent(m), Some(mono_hugr.entrypoint())); - for t in [usize_t(), ity()] { - let (n, _) = funcs.remove(&mangle_name("id", &[t.into()])).unwrap(); - assert_eq!(mono_hugr.get_parent(n), Some(m)); // Not lifted to top - } - Ok(()) - } - #[test] fn load_function() { let mut hugr = { @@ -612,8 +548,8 @@ mod test { .define_function( "foo", PolyFuncType::new( - [TypeBound::Any.into()], - Signature::new_endo(Type::new_var_use(0, TypeBound::Any)), + [TypeBound::Linear.into()], + Signature::new_endo(Type::new_var_use(0, TypeBound::Linear)), ), ) .unwrap(); @@ -657,9 +593,10 @@ mod test { #[case::type_int(vec![INT_TYPES[2].clone().into()], "$foo$$t(int(2))")] #[case::string(vec!["arg".into()], "$foo$$s(arg)")] #[case::dollar_string(vec!["$arg".into()], "$foo$$s(\\$arg)")] - #[case::sequence(vec![vec![0.into(), Type::UNIT.into()].into()], "$foo$$seq($n(0)$t(Unit))")] + #[case::sequence(vec![vec![0.into(), Type::UNIT.into()].into()], "$foo$$list($n(0)$t(Unit))")] + #[case::sequence(vec![TypeArg::Tuple(vec![0.into(),Type::UNIT.into()])], "$foo$$tuple($n(0)$t(Unit))")] #[should_panic] - #[case::typeargvariable(vec![TypeArg::new_var_use(1, TypeParam::String)], + #[case::typeargvariable(vec![TypeArg::new_var_use(1, TypeParam::StringType)], "$foo$$v(1)")] #[case::multiple(vec![0.into(), "arg".into()], "$foo$$n(0)$s(arg)")] fn test_mangle_name(#[case] args: Vec, #[case] expected: String) { diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 75bbea399e..df276a1ff9 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -1,6 +1,5 @@ //! This module provides functions for finding non-local edges //! in a Hugr and converting them to local edges. -#![warn(missing_docs)] use itertools::Itertools as _; use hugr_core::{ diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index ac19094c19..0b5cca8f6a 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -1,5 +1,4 @@ #![allow(clippy::type_complexity)] -#![warn(missing_docs)] //! Replace types with other types across the Hugr. See [`ReplaceTypes`] and [Linearizer]. //! 
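The module doc above is terse, so here is a hedged sketch of the typical `ReplaceTypes` flow, using only calls that appear in this diff (`replace_type`, `ComposablePass::run`). The import paths, the `usize`-to-`i64` mapping, and the assumption that `INT_TYPES[6]` is the 64-bit int type are illustrative, not taken from the patch.

```rust
use hugr_core::Hugr;
use hugr_core::extension::prelude::usize_t;
use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES;
use hugr_passes::{ComposablePass, ReplaceTypes};

/// Swap every prelude `usize` for a 64-bit int, then run the pass.
/// Returns whether anything in the Hugr was changed.
fn widen_usize(hugr: &mut Hugr) -> Result<bool, Box<dyn std::error::Error>> {
    let mut pass = ReplaceTypes::default();
    pass.replace_type(
        usize_t().as_extension().unwrap().clone(),
        INT_TYPES[6].clone(),
    );
    Ok(pass.run(hugr)?)
}
```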
use std::borrow::Cow; @@ -108,9 +107,9 @@ impl NodeTemplate { } } - fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { + fn replace(self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { assert_eq!(hugr.children(n).count(), 0); - let new_optype = match self.clone() { + let new_optype = match self { NodeTemplate::SingleOp(op_type) => op_type, NodeTemplate::CompoundOp(new_h) => { let new_entrypoint = hugr.insert_hugr(n, *new_h).inserted_entrypoint; @@ -172,6 +171,23 @@ fn call>( Ok(Call::try_new(func_sig, type_args)?) } +/// Options for how the replacement for an op is processed. May be specified by +/// [ReplaceTypes::replace_op_with] and [ReplaceTypes::replace_parametrized_op_with]. +/// Otherwise (the default), replacements are inserted as is (without further processing). +#[derive(Clone, Default, PartialEq, Eq)] // More derives might inhibit future extension +pub struct ReplacementOptions { + linearize: bool, +} + +impl ReplacementOptions { + /// Specifies that all operations within the replacement should have their + /// output ports linearized. + pub fn with_linearization(mut self, lin: bool) -> Self { + self.linearize = lin; + self + } +} + /// A configuration of what types, ops, and constants should be replaced with what. /// May be applied to a Hugr via [`Self::run`]. /// @@ -204,8 +220,14 @@ pub struct ReplaceTypes { type_map: HashMap, param_types: HashMap Option>>, linearize: DelegatingLinearizer, - op_map: HashMap, - param_ops: HashMap Option>>, + op_map: HashMap, + param_ops: HashMap< + ParametricOp, + ( + Arc Option>, + ReplacementOptions, + ), + >, consts: HashMap< CustomType, Arc Result>, @@ -259,7 +281,7 @@ pub enum ReplaceTypesError { #[error(transparent)] LinearizeError(#[from] LinearizeError), #[error("Replacement op for {0} could not be added because {1}")] - AddTemplateError(Node, BuildError), + AddTemplateError(Node, Box), } impl ReplaceTypes { @@ -338,13 +360,36 @@ impl ReplaceTypes { } /// Configures this instance to change occurrences of `src` to `dest`. + /// Equivalent to [Self::replace_op_with] with default [ReplacementOptions]. + pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) { + self.replace_op_with(src, dest, ReplacementOptions::default()) + } + + /// Configures this instance to change occurrences of `src` to `dest`. + /// /// Note that if `src` is an instance of a *parametrized* [`OpDef`], this takes /// precedence over [`Self::replace_parametrized_op`] where the `src`s overlap. Thus, /// this should only be used on already-*[monomorphize](super::monomorphize())d* /// Hugrs, as substitution (parametric polymorphism) happening later will not respect /// this replacement. - pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) { - self.op_map.insert(OpHashWrapper::from(src), dest); + pub fn replace_op_with( + &mut self, + src: &ExtensionOp, + dest: NodeTemplate, + opts: ReplacementOptions, + ) { + self.op_map.insert(OpHashWrapper::from(src), (dest, opts)); + } + + /// Configures this instance to change occurrences of a parametrized op `src` + /// via a callback that builds the replacement type given the [`TypeArg`]s. + /// Equivalent to [Self::replace_parametrized_op_with] with default [ReplacementOptions]. 
+ pub fn replace_parametrized_op( + &mut self, + src: &OpDef, + dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, + ) { + self.replace_parametrized_op_with(src, dest_fn, ReplacementOptions::default()) } /// Configures this instance to change occurrences of a parametrized op `src` @@ -353,12 +398,13 @@ impl ReplaceTypes { /// fit the bounds of the original op). /// /// If the Callback returns None, the new typeargs will be applied to the original op. - pub fn replace_parametrized_op( + pub fn replace_parametrized_op_with( &mut self, src: &OpDef, dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, + opts: ReplacementOptions, ) { - self.param_ops.insert(src.into(), Arc::new(dest_fn)); + self.param_ops.insert(src.into(), (Arc::new(dest_fn), opts)); } /// Configures this instance to change [Const]s of type `src_ty`, using @@ -448,34 +494,40 @@ impl ReplaceTypes { | rest.transform(self)?), OpType::Const(Const { value, .. }) => self.change_value(value), - OpType::ExtensionOp(ext_op) => Ok( - // Copy/discard insertion done by caller - if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) { + OpType::ExtensionOp(ext_op) => Ok({ + let def = ext_op.def_arc(); + let mut changed = false; + let replacement = match self.op_map.get(&OpHashWrapper::from(&*ext_op)) { + r @ Some(_) => r.cloned(), + None => { + let mut args = ext_op.args().to_vec(); + changed = args.transform(self)?; + let r2 = self + .param_ops + .get(&def.as_ref().into()) + .and_then(|(rep_fn, opts)| rep_fn(&args).map(|nt| (nt, opts.clone()))); + if r2.is_none() && changed { + *ext_op = ExtensionOp::new(def.clone(), args)?; + } + r2 + } + }; + if let Some((replacement, opts)) = replacement { replacement .replace(hugr, n) - .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; - true - } else { - let def = ext_op.def_arc(); - let mut args = ext_op.args().to_vec(); - let ch = args.transform(self)?; - if let Some(replacement) = self - .param_ops - .get(&def.as_ref().into()) - .and_then(|rep_fn| rep_fn(&args)) - { - replacement - .replace(hugr, n) - .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; - true - } else { - if ch { - *ext_op = ExtensionOp::new(def.clone(), args)?; + .map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?; + if opts.linearize { + for d in hugr.descendants(n).collect::>() { + if d != n { + self.linearize_outputs(hugr, d)?; + } } - ch } - }, - ), + true + } else { + changed + } + }), OpType::OpaqueOp(_) => panic!("OpaqueOp should not be in a Hugr"), @@ -519,6 +571,27 @@ impl ReplaceTypes { Value::Function { hugr } => self.run(&mut **hugr), } } + + fn linearize_outputs>( + &self, + hugr: &mut H, + n: H::Node, + ) -> Result<(), LinearizeError> { + if let Some(new_sig) = hugr.get_optype(n).dataflow_signature() { + let new_sig = new_sig.into_owned(); + for outp in new_sig.output_ports() { + if !new_sig.out_port_type(outp).unwrap().copyable() { + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() != 1 { + hugr.disconnect(n, outp); + let src = Wire::new(n, outp); + self.linearize.insert_copy_discard(hugr, src, &targets)?; + } + } + } + } + Ok(()) + } } impl> ComposablePass for ReplaceTypes { @@ -529,21 +602,8 @@ impl> ComposablePass for ReplaceTypes { let mut changed = false; for n in hugr.entry_descendants().collect::>() { changed |= self.change_node(hugr, n)?; - let new_dfsig = hugr.get_optype(n).dataflow_signature(); - if let Some(new_sig) = new_dfsig - .filter(|_| changed && n != hugr.entrypoint()) - .map(Cow::into_owned) - { - for outp in 
new_sig.output_ports() { - if !new_sig.out_port_type(outp).unwrap().copyable() { - let targets = hugr.linked_inputs(n, outp).collect::>(); - if targets.len() != 1 { - hugr.disconnect(n, outp); - let src = Wire::new(n, outp); - self.linearize.insert_copy_discard(hugr, src, &targets)?; - } - } - } + if n != hugr.entrypoint() && changed { + self.linearize_outputs(hugr, n)?; } } Ok(changed) @@ -641,7 +701,7 @@ mod test { } fn just_elem_type(args: &[TypeArg]) -> &Type { - let [TypeArg::Type { ty }] = args else { + let [TypeArg::Runtime(ty)] = args else { panic!("Expected just elem type") }; ty @@ -655,7 +715,7 @@ mod test { let pv_of_var = ext .add_type( PACKED_VEC.into(), - vec![TypeBound::Any.into()], + vec![TypeBound::Linear.into()], String::new(), TypeDefBound::from_params(vec![0]), w, @@ -670,7 +730,7 @@ mod test { vec![TypeBound::Copyable.into()], Signature::new( vec![pv_of_var.into(), i64_t()], - Type::new_var_use(0, TypeBound::Any), + Type::new_var_use(0, TypeBound::Linear), ), ), w, @@ -748,9 +808,9 @@ mod test { let c_int = Type::from(coln.instantiate([i64_t().into()]).unwrap()); let c_bool = Type::from(coln.instantiate([bool_t().into()]).unwrap()); let mut mb = ModuleBuilder::new(); - let sig = Signature::new_endo(Type::new_var_use(0, TypeBound::Any)); + let sig = Signature::new_endo(Type::new_var_use(0, TypeBound::Linear)); let fb = mb - .define_function("id", PolyFuncType::new([TypeBound::Any.into()], sig)) + .define_function("id", PolyFuncType::new([TypeBound::Linear.into()], sig)) .unwrap(); let inps = fb.input_wires(); let id = fb.finish_with_outputs(inps).unwrap(); @@ -967,8 +1027,8 @@ mod test { IdentList::new_unchecked("NoBoundsCheck"), Version::new(0, 0, 0), |e, w| { - let params = vec![TypeBound::Any.into()]; - let tv = Type::new_var_use(0, TypeBound::Any); + let params = vec![TypeBound::Linear.into()]; + let tv = Type::new_var_use(0, TypeBound::Linear); let list_of_var = list_type(tv.clone()); e.add_op( READ.into(), diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index 25abb846bc..7c0fe5f550 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -106,7 +106,7 @@ pub fn linearize_generic_array( ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp - let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = args else { + let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = args else { panic!("Illegal TypeArgs to array: {args:?}") }; if num_outports == 0 { @@ -116,7 +116,9 @@ pub fn linearize_generic_array( let [to_discard] = dfb.input_wires_arr(); lin.copy_discard_op(ty, 0)? .add(&mut dfb, [to_discard]) - .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))?; + .map_err(|e| { + LinearizeError::NestedTemplateError(Box::new(ty.clone()), Box::new(e)) + })?; let ret = dfb.add_load_value(Value::unary_unit_sum()); dfb.finish_hugr_with_outputs([ret]).unwrap() }; @@ -189,7 +191,7 @@ pub fn linearize_generic_array( let mut copies = lin .copy_discard_op(ty, num_outports)? .add(&mut dfb, [elem]) - .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))? + .map_err(|e| LinearizeError::NestedTemplateError(Box::new(ty.clone()), Box::new(e)))? .outputs(); let copy0 = copies.next().unwrap(); // We'll return this directly @@ -307,7 +309,7 @@ pub fn copy_discard_array( ) -> Result { // Require known length i.e. 
usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp - let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = args else { + let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = args else { panic!("Illegal TypeArgs to array: {args:?}") }; if ty.copyable() { diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index bc12e730bd..4227c5d817 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -5,17 +5,19 @@ use hugr_core::builder::{ HugrBuilder, inout_sig, }; use hugr_core::extension::{SignatureError, TypeDef}; +use hugr_core::std_extensions::collections::array::array_type_def; use hugr_core::std_extensions::collections::value_array::value_array_type_def; use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeEnum, TypeRow}; use hugr_core::{HugrView, IncomingPort, Node, Wire, hugr::hugrmut::HugrMut, ops::Tag}; use itertools::Itertools; -use super::{NodeTemplate, ParametricType, handlers::linearize_value_array}; +use super::handlers::{copy_discard_array, linearize_value_array}; +use super::{NodeTemplate, ParametricType}; /// Trait for things that know how to wire up linear outports to other than one /// target. Used to restore Hugr validity when a [`ReplaceTypes`](super::ReplaceTypes) /// results in types of such outports changing from [Copyable] to linear (i.e. -/// [`hugr_core::types::TypeBound::Any`]). +/// [`hugr_core::types::TypeBound::Linear`]). /// /// Note that this is not really effective before [monomorphization]: if a /// function polymorphic over a [Copyable] becomes called with a @@ -52,12 +54,10 @@ pub trait Linearizer { src: Wire, targets: &[(Node, IncomingPort)], ) -> Result<(), LinearizeError> { - let sig = hugr.signature(src.node()).unwrap(); - let typ = sig.port_type(src.source()).unwrap(); let (tgt_node, tgt_inport) = if targets.len() == 1 { *targets.first().unwrap() } else { - // Fail fast if the edges are nonlocal. (TODO transform to local edges!) + // Fail fast if the edges are nonlocal. let src_parent = hugr .get_parent(src.node()) .expect("Root node cannot have out edges"); @@ -74,11 +74,12 @@ pub trait Linearizer { tgt_parent, }); } - let typ = typ.clone(); // Stop borrowing hugr in order to add_hugr to it + let sig = hugr.signature(src.node()).unwrap(); + let typ = sig.port_type(src.source()).unwrap().clone(); let copy_discard_op = self .copy_discard_op(&typ, targets.len())? 
.add_hugr(hugr, src_parent) - .map_err(|e| LinearizeError::NestedTemplateError(typ, e))?; + .map_err(|e| LinearizeError::NestedTemplateError(Box::new(typ), Box::new(e)))?; for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); } @@ -125,6 +126,7 @@ impl Default for DelegatingLinearizer { fn default() -> Self { let mut res = Self::new_empty(); res.register_callback(value_array_type_def(), linearize_value_array); + res.register_callback(array_type_def(), copy_discard_array); res } } @@ -140,15 +142,16 @@ pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); #[non_exhaustive] pub enum LinearizeError { #[error("Need copy/discard op for {_0}")] - NeedCopyDiscard(Type), + NeedCopyDiscard(Box), #[error("Copy/discard op for {typ} with {num_outports} outputs had wrong signature {sig:?}")] WrongSignature { - typ: Type, + typ: Box, num_outports: usize, - sig: Option, + sig: Option>, }, #[error( - "Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent})" + "Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent}). + Try using LocalizeEdges pass first." )] NoLinearNonLocalEdges { src: Node, @@ -163,14 +166,14 @@ pub enum LinearizeError { /// [Variable](TypeEnum::Variable)s, [Row variables](TypeEnum::RowVar), /// or [Alias](TypeEnum::Alias)es. #[error("Cannot linearize type {_0}")] - UnsupportedType(Type), + UnsupportedType(Box), /// Neither does linearization make sense for copyable types #[error("Type {_0} is copyable")] - CopyableType(Type), + CopyableType(Box), /// Error may be returned by a callback for e.g. a container because it could /// not generate a [`NodeTemplate`] because of a problem with an element #[error("Could not generate NodeTemplate for contained type {0} because {1}")] - NestedTemplateError(Type, BuildError), + NestedTemplateError(Box, Box), } impl DelegatingLinearizer { @@ -206,7 +209,7 @@ impl DelegatingLinearizer { ) -> Result<(), LinearizeError> { let typ = Type::new_extension(cty.clone()); if typ.copyable() { - return Err(LinearizeError::CopyableType(typ)); + return Err(LinearizeError::CopyableType(Box::new(typ))); } check_sig(©, &typ, 2)?; check_sig(&discard, &typ, 0)?; @@ -247,9 +250,9 @@ impl DelegatingLinearizer { fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), LinearizeError> { tmpl.check_signature(&typ.clone().into(), &vec![typ.clone(); num_outports].into()) .map_err(|sig| LinearizeError::WrongSignature { - typ: typ.clone(), + typ: Box::new(typ.clone()), num_outports, - sig, + sig: sig.map(Box::new), }) } @@ -260,7 +263,7 @@ impl Linearizer for DelegatingLinearizer { num_outports: usize, ) -> Result { if typ.copyable() { - return Err(LinearizeError::CopyableType(typ.clone())); + return Err(LinearizeError::CopyableType(Box::new(typ.clone()))); } assert!(num_outports != 1); @@ -338,14 +341,14 @@ impl Linearizer for DelegatingLinearizer { let copy_discard_fn = self .copy_discard_parametric .get(&cty.into()) - .ok_or_else(|| LinearizeError::NeedCopyDiscard(typ.clone()))?; + .ok_or_else(|| LinearizeError::NeedCopyDiscard(Box::new(typ.clone())))?; let tmpl = copy_discard_fn(cty.args(), num_outports, &CallbackHandler(self))?; check_sig(&tmpl, typ, num_outports)?; Ok(tmpl) } } TypeEnum::Function(_) => panic!("Ruled out above as copyable"), - _ => Err(LinearizeError::UnsupportedType(typ.clone())), + _ => 
Err(LinearizeError::UnsupportedType(Box::new(typ.clone()))), } } } @@ -371,7 +374,7 @@ mod test { HugrBuilder, inout_sig, }; - use hugr_core::extension::prelude::{option_type, usize_t}; + use hugr_core::extension::prelude::{option_type, qb_t, usize_t}; use hugr_core::extension::simple_op::MakeExtensionOp; use hugr_core::extension::{ CustomSignatureFunc, OpDef, SignatureError, SignatureFunc, TypeDefBound, Version, @@ -385,14 +388,16 @@ mod test { }; use hugr_core::types::type_param::TypeParam; use hugr_core::types::{ - FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeEnum, TypeRow, + FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeEnum, TypeRow, }; use hugr_core::{Extension, Hugr, HugrView, Node, hugr::IdentList, type_row}; use itertools::Itertools; use rstest::rstest; use crate::replace_types::handlers::linearize_value_array; - use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; + use crate::replace_types::{ + LinearizeError, NodeTemplate, ReplaceTypesError, ReplacementOptions, + }; use crate::{ComposablePass, ReplaceTypes}; const LIN_T: &str = "Lin"; @@ -404,7 +409,7 @@ mod test { arg_values: &[TypeArg], _def: &'o OpDef, ) -> Result { - let [TypeArg::BoundedNat { n }] = arg_values else { + let [TypeArg::BoundedNat(n)] = arg_values else { panic!() }; let outs = vec![self.0.clone(); *n as usize]; @@ -412,7 +417,7 @@ mod test { } fn static_params(&self) -> &[TypeParam] { - const JUST_NAT: &[TypeParam] = &[TypeParam::max_nat()]; + const JUST_NAT: &[TypeParam] = &[TypeParam::max_nat_type()]; JUST_NAT } } @@ -648,9 +653,9 @@ mod test { assert_eq!( bad_copy, Err(LinearizeError::WrongSignature { - typ: lin_t.clone(), + typ: Box::new(lin_t.clone()), num_outports: 2, - sig: sig3.clone() + sig: sig3.clone().map(Box::new) }) ); @@ -663,9 +668,9 @@ mod test { assert_eq!( bad_discard, Err(LinearizeError::WrongSignature { - typ: lin_t.clone(), + typ: Box::new(lin_t.clone()), num_outports: 0, - sig: sig3.clone() + sig: sig3.clone().map(Box::new) }) ); @@ -685,9 +690,9 @@ mod test { replacer.run(&mut h), Err(ReplaceTypesError::LinearizeError( LinearizeError::WrongSignature { - typ: lin_t.clone(), + typ: Box::new(lin_t.clone()), num_outports: 2, - sig: sig3.clone() + sig: sig3.clone().map(Box::new) } )) ); @@ -800,7 +805,8 @@ mod test { // A simple Hugr that discards a usize_t, with a "drop" function let mut dfb = DFGBuilder::new(inout_sig(usize_t(), type_row![])).unwrap(); let discard_fn = { - let mut fb = dfb + let mut mb = dfb.module_root_builder(); + let mut fb = mb .define_function("drop", Signature::new(lin_t.clone(), type_row![])) .unwrap(); let ins = fb.input_wires(); @@ -815,12 +821,11 @@ mod test { let backup = dfb.finish_hugr().unwrap(); let mut lower_discard_to_call = ReplaceTypes::default(); - // The `copy_fn` here will break completely, but we don't use it lower_discard_to_call .linearizer() .register_simple( lin_ct.clone(), - NodeTemplate::Call(backup.entrypoint(), vec![]), + NodeTemplate::Call(backup.entrypoint(), vec![]), // Arbitrary, unused NodeTemplate::Call(discard_fn, vec![]), ) .unwrap(); @@ -834,20 +839,85 @@ mod test { assert_eq!(h.output_neighbours(discard_fn).count(), 1); } - // But if we lower usize_t to array, the call will fail + // But if we lower usize_t to array, the call will fail. 
lower_discard_to_call.replace_type( usize_t().as_extension().unwrap().clone(), value_array_type(4, lin_ct.into()), ); let r = lower_discard_to_call.run(&mut backup.clone()); - assert!(matches!( - r, - Err(ReplaceTypesError::LinearizeError( - LinearizeError::NestedTemplateError( - nested_t, - BuildError::NodeNotFound { node } + // Note the error (or success) can be quite fragile, according to what the `discard_fn` + // Node points at in the (hidden here) inner Hugr built by the array linearization helper. + if let Err(ReplaceTypesError::LinearizeError(LinearizeError::NestedTemplateError( + nested_t, + build_err, + ))) = r + { + assert_eq!(*nested_t, lin_t); + assert!(matches!( + *build_err, BuildError::NodeNotFound { node } if node == discard_fn + )); + } else { + panic!("Expected error"); + } + } + + #[test] + fn use_in_op_callback() { + let (e, mut lowerer) = ext_lowerer(); + let drop_ext = Extension::new_arc( + IdentList::new_unchecked("DropExt"), + Version::new(0, 0, 0), + |e, w| { + e.add_op( + "drop".into(), + String::new(), + PolyFuncTypeRV::new( + [TypeBound::Linear.into()], // It won't *lower* for any type tho! + Signature::new(Type::new_var_use(0, TypeBound::Linear), vec![]), + ), + w, ) - )) if nested_t == lin_t && node == discard_fn - )); + .unwrap(); + }, + ); + let drop_op = drop_ext.get_op("drop").unwrap(); + lowerer.replace_parametrized_op_with( + drop_op, + |args| { + let [TypeArg::Runtime(ty)] = args else { + panic!("Expected just one type") + }; + // The Hugr here is invalid, so we have to pull it out manually + let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); + let h = std::mem::take(dfb.hugr_mut()); + Some(NodeTemplate::CompoundOp(Box::new(h))) + }, + ReplacementOptions::default().with_linearization(true), + ); + + let build_hugr = |ty: Type| { + let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); + let [inp] = dfb.input_wires_arr(); + let drop_op = drop_ext + .instantiate_extension_op("drop", [ty.into()]) + .unwrap(); + dfb.add_dataflow_op(drop_op, [inp]).unwrap(); + dfb.finish_hugr().unwrap() + }; + // We can drop a tuple of 2* lin_t + let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); + let mut h = build_hugr(Type::new_tuple(vec![lin_t; 2])); + lowerer.run(&mut h).unwrap(); + h.validate().unwrap(); + let mut exts = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()); + assert_eq!(exts.clone().count(), 2); + assert!(exts.all(|eo| eo.qualified_id() == "TestExt.discard")); + + // We cannot drop a qubit + let mut h = build_hugr(qb_t()); + assert_eq!( + lowerer.run(&mut h).unwrap_err(), + ReplaceTypesError::LinearizeError(LinearizeError::NeedCopyDiscard(Box::new(qb_t()))) + ); } } diff --git a/hugr-persistent/CHANGELOG.md b/hugr-persistent/CHANGELOG.md new file mode 100644 index 0000000000..a69263ca63 --- /dev/null +++ b/hugr-persistent/CHANGELOG.md @@ -0,0 +1,11 @@ +# Changelog + + +## [0.2.0](https://github.com/CQCL/hugr/compare/hugr-persistent-v0.1.0...hugr-persistent-v0.2.0) - 2025-07-24 + +### New Features + +- [**breaking**] Update portgraph dependency to 0.15 ([#2455](https://github.com/CQCL/hugr/pull/2455)) +## 0.1.0 (2025-07-10) + +Initial release. 
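Before the new `hugr-persistent` crate files below, a small sketch of the `DelegatingLinearizer` registration that the tests above exercise. The helper name and the assumption that `copy`/`discard` arrive as ready-made `OpType`s are illustrative; `linearizer()`, `register_simple` and `NodeTemplate::SingleOp` are as in this diff.

```rust
use hugr_core::ops::OpType;
use hugr_core::types::CustomType;
use hugr_passes::ReplaceTypes;
use hugr_passes::replace_types::{LinearizeError, NodeTemplate};

/// Teach `pass` how to copy and discard a custom linear type, so that wires
/// whose type becomes linear after replacement can still fan out (two copies)
/// or be dropped (zero copies).
/// `copy` must map one `lin_ct` value to two and `discard` one to none;
/// otherwise `register_simple` rejects the templates with a wrong-signature error.
fn register_linear_handlers(
    pass: &mut ReplaceTypes,
    lin_ct: CustomType,
    copy: OpType,
    discard: OpType,
) -> Result<(), LinearizeError> {
    pass.linearizer().register_simple(
        lin_ct,
        NodeTemplate::SingleOp(copy),
        NodeTemplate::SingleOp(discard),
    )
}
```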
diff --git a/hugr-persistent/Cargo.toml b/hugr-persistent/Cargo.toml new file mode 100644 index 0000000000..75aa61f240 --- /dev/null +++ b/hugr-persistent/Cargo.toml @@ -0,0 +1,43 @@ +[package] +name = "hugr-persistent" +version = "0.2.1" +edition = { workspace = true } +rust-version = { workspace = true } +license = { workspace = true } +readme = "README.md" +documentation = "https://docs.rs/hugr-persistent/" +homepage = { workspace = true } +repository = { workspace = true } +description = "Persistent IR structure for Quantinuum's HUGR" +keywords = ["Quantum", "Quantinuum"] +categories = ["compilers"] + +[[test]] +name = "persistent_walker_example" + +[dependencies] +hugr-core = { path = "../hugr-core", version = "0.22.1" } + +derive_more = { workspace = true, features = ["display", "error", "from"] } +delegate.workspace = true +itertools.workspace = true +petgraph.workspace = true +portgraph.workspace = true +relrc = { workspace = true, features = ["petgraph", "serde"] } +serde.workspace = true +serde_json.workspace = true +thiserror.workspace = true +wyhash.workspace = true + +[lints] +workspace = true + +[lib] +bench = false + +[dev-dependencies] +rstest.workspace = true +lazy_static.workspace = true +semver.workspace = true +serde_with.workspace = true +insta.workspace = true diff --git a/hugr-persistent/README.md b/hugr-persistent/README.md new file mode 100644 index 0000000000..95386fa664 --- /dev/null +++ b/hugr-persistent/README.md @@ -0,0 +1,59 @@ +![](/hugr/assets/hugr_logo.svg) + +# hugr-persistent + +[![build_status][]](https://github.com/CQCL/hugr/actions) +[![crates][]](https://crates.io/crates/hugr-persistent) +[![msrv][]](https://github.com/CQCL/hugr) +[![codecov][]](https://codecov.io/gh/CQCL/hugr) + +The Hierarchical Unified Graph Representation (HUGR, pronounced _hugger_) is the +common representation of quantum circuits and operations in the Quantinuum +ecosystem. + +It provides a high-fidelity representation of operations, that facilitates +compilation and encodes runnable programs. + +The HUGR specification is [here](https://github.com/CQCL/hugr/blob/main/specification/hugr.md). + +## Overview + +This crate provides a persistent data structure for HUGR mutations; mutations to +the data are stored persistently as a set of `Commit`s along with the +dependencies between them. + +As a result of persistency, the entire mutation history of a HUGR can be +traversed and references to previous versions of the data remain valid even +as the HUGR graph is "mutated" by applying patches: the patches are in +effect added to the history as new commits. + +## Usage + +Add the dependency to your project: + +```bash +cargo add hugr-persistent +``` + +Please read the [API documentation here][]. + +## Recent Changes + +See [CHANGELOG][] for a list of changes. The minimum supported rust +version will only change on major releases. + +## Development + +See [DEVELOPMENT.md](https://github.com/CQCL/hugr/blob/main/DEVELOPMENT.md) for instructions on setting up the development environment. + +## License + +This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). 
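A minimal usage sketch to complement the `cargo add` snippet above, based on the `with_base` / `add_replacement` / `to_hugr` methods described in this crate's API docs; the resolver type parameter is elided (assumed to have a default) and `base`/`repl` are built elsewhere.

```rust
use hugr_core::Hugr;
use hugr_persistent::{PersistentHugr, PersistentReplacement};

/// Record `repl` as a commit on top of `base`, then materialize the result.
fn apply_one_commit(base: Hugr, repl: PersistentReplacement) -> Hugr {
    let mut persistent = PersistentHugr::with_base(base);
    // Stored as a new commit in the history, not as an in-place mutation.
    let _commit = persistent.add_replacement(repl);
    // Apply all commits to produce an ordinary Hugr.
    persistent.to_hugr()
}
```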
+ + [API documentation here]: https://docs.rs/hugr-persistent/ + [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main + [msrv]: https://img.shields.io/crates/msrv/hugr-persistent + [crates]: https://img.shields.io/crates/v/hugr-persistent + [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov + [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE + [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-persistent/CHANGELOG.md diff --git a/hugr-persistent/src/lib.rs b/hugr-persistent/src/lib.rs new file mode 100644 index 0000000000..694e83e23e --- /dev/null +++ b/hugr-persistent/src/lib.rs @@ -0,0 +1,98 @@ +#![doc(hidden)] // TODO: remove when stable + +//! Persistent data structure for HUGR mutations. +//! +//! This crate provides a persistent data structure [`PersistentHugr`] that +//! implements [`HugrView`](hugr_core::HugrView); mutations to the data are +//! stored persistently as a set of [`Commit`]s along with the dependencies +//! between the commits. +//! +//! As a result of persistency, the entire mutation history of a HUGR can be +//! traversed and references to previous versions of the data remain valid even +//! as the HUGR graph is "mutated" by applying patches: the patches are in +//! effect added to the history as new commits. +//! +//! The data structure underlying [`PersistentHugr`], which stores the history +//! of all commits, is [`CommitStateSpace`]. Multiple [`PersistentHugr`] can be +//! stored within a single [`CommitStateSpace`], which allows for the efficient +//! exploration of the space of all possible graph rewrites. +//! +//! ## Overlapping commits +//! +//! In general, [`CommitStateSpace`] may contain overlapping commits. Such +//! mutations are mutually exclusive as they modify the same nodes. It is +//! therefore not possible to apply all commits in a [`CommitStateSpace`] +//! simultaneously. A [`PersistentHugr`] on the other hand always corresponds to +//! a subgraph of a [`CommitStateSpace`] that is guaranteed to contain only +//! non-overlapping, compatible commits. By applying all commits in a +//! [`PersistentHugr`], we can materialize a [`Hugr`](hugr_core::Hugr). +//! Traversing the materialized HUGR is equivalent to using the +//! [`HugrView`](hugr_core::HugrView) implementation of the corresponding +//! [`PersistentHugr`]. +//! +//! ## Summary of data types +//! +//! - [`Commit`] A modification to a [`Hugr`](hugr_core::Hugr) (currently a +//! [`SimpleReplacement`](hugr_core::SimpleReplacement)) that forms the atomic +//! unit of change for a [`PersistentHugr`] (like a commit in git). This is a +//! reference-counted value that is cheap to clone and will be freed when the +//! last reference is dropped. +//! - [`PersistentHugr`] A data structure that implements +//! [`HugrView`][hugr_core::HugrView] and can be used as a drop-in replacement +//! for a [`Hugr`][hugr_core::Hugr] for read-only access and mutations through +//! the [`PatchVerification`](hugr_core::hugr::patch::PatchVerification) and +//! [`Patch`](hugr_core::hugr::Patch) traits. Mutations are stored as a +//! history of commits. Unlike [`CommitStateSpace`], it maintains the +//! invariant that all contained commits are compatible with eachother. +//! - [`CommitStateSpace`] Stores commits, recording the dependencies between +//! them. Includes the base HUGR and any number of possibly incompatible +//! (overlapping) commits. Unlike a [`PersistentHugr`], a state space can +//! 
contain mutually exclusive commits. +//! +//! ## Usage +//! +//! A [`PersistentHugr`] can be created from a base HUGR using +//! [`PersistentHugr::with_base`]. Replacements can then be applied to it +//! using [`PersistentHugr::add_replacement`]. Alternatively, if you already +//! have a populated state space, use [`PersistentHugr::try_new`] to create a +//! new HUGR with those commits. +//! +//! Add a sequence of commits to a state space by merging a [`PersistentHugr`] +//! into it using [`CommitStateSpace::extend`] or directly using +//! [`CommitStateSpace::try_add_commit`]. +//! +//! To obtain a [`PersistentHugr`] from your state space, use +//! [`CommitStateSpace::try_extract_hugr`]. A [`PersistentHugr`] can always be +//! materialized into a [`Hugr`][hugr_core::Hugr] type using +//! [`PersistentHugr::to_hugr`]. + +mod parents_view; +mod persistent_hugr; +mod resolver; +pub mod state_space; +pub mod subgraph; +mod trait_impls; +pub mod walker; +mod wire; + +pub use persistent_hugr::{Commit, PersistentHugr}; +pub use resolver::{PointerEqResolver, Resolver, SerdeHashResolver}; +pub use state_space::{CommitId, CommitStateSpace, InvalidCommit, PatchNode}; +pub use subgraph::PinnedSubgraph; +pub use walker::Walker; +pub use wire::PersistentWire; + +/// A replacement operation that can be applied to a [`PersistentHugr`]. +pub type PersistentReplacement = hugr_core::SimpleReplacement; + +use persistent_hugr::find_conflicting_node; +use state_space::CommitData; + +pub mod serial { + //! Serialized formats for commits, state spaces and persistent HUGRs. + pub use super::persistent_hugr::serial::*; + pub use super::state_space::serial::*; +} + +#[cfg(test)] +mod tests; diff --git a/hugr-core/src/hugr/persistent/parents_view.rs b/hugr-persistent/src/parents_view.rs similarity index 95% rename from hugr-core/src/hugr/persistent/parents_view.rs rename to hugr-persistent/src/parents_view.rs index b4aa076060..6f1f3c86de 100644 --- a/hugr-core/src/hugr/persistent/parents_view.rs +++ b/hugr-persistent/src/parents_view.rs @@ -1,9 +1,10 @@ use std::collections::{BTreeMap, HashMap}; -use crate::{ +use hugr_core::{ Direction, Hugr, HugrView, Node, Port, extension::ExtensionRegistry, hugr::{ + self, internal::HugrInternals, views::{ExtractionResult, render}, }, @@ -17,12 +18,15 @@ use super::{CommitStateSpace, PatchNode, state_space::CommitId}; /// Note that this is not a valid HUGR: not a single entrypoint, root etc. As /// a consequence, not all HugrView methods are implemented. 
#[derive(Debug, Clone)] -pub(super) struct ParentsView<'a> { +pub(crate) struct ParentsView<'a> { hugrs: BTreeMap, } impl<'a> ParentsView<'a> { - pub(super) fn from_commit(commit_id: CommitId, state_space: &'a CommitStateSpace) -> Self { + pub(crate) fn from_commit( + commit_id: CommitId, + state_space: &'a CommitStateSpace, + ) -> Self { let mut hugrs = BTreeMap::new(); for parent in state_space.parents(commit_id) { hugrs.insert(parent, state_space.commit_hugr(parent)); @@ -33,7 +37,7 @@ impl<'a> ParentsView<'a> { impl HugrInternals for ParentsView<'_> { type RegionPortgraph<'p> - = portgraph::MultiPortGraph + = portgraph::MultiPortGraph where Self: 'p; @@ -51,7 +55,7 @@ impl HugrInternals for ParentsView<'_> { unimplemented!() } - fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap { + fn node_metadata_map(&self, node: Self::Node) -> &hugr::NodeMetadataMap { let PatchNode(commit_id, node) = node; self.hugrs .get(&commit_id) diff --git a/hugr-core/src/hugr/persistent.rs b/hugr-persistent/src/persistent_hugr.rs similarity index 59% rename from hugr-core/src/hugr/persistent.rs rename to hugr-persistent/src/persistent_hugr.rs index d2813e11b6..bdca32ec1f 100644 --- a/hugr-core/src/hugr/persistent.rs +++ b/hugr-persistent/src/persistent_hugr.rs @@ -1,98 +1,23 @@ -//! Persistent data structure for HUGR mutations. -//! -//! This module provides a persistent data structure [`PersistentHugr`] that -//! implements [`crate::HugrView`]; mutations to the data are stored -//! persistently as a set of [`Commit`]s along with the dependencies between the -//! commits. -//! -//! As a result of persistency, the entire mutation history of a HUGR can be -//! traversed and references to previous versions of the data remain valid even -//! as the HUGR graph is "mutated" by applying patches: the patches are in -//! effect added to the history as new commits. -//! -//! The data structure underlying [`PersistentHugr`], which stores the history -//! of all commits, is [`CommitStateSpace`]. Multiple [`PersistentHugr`] can be -//! stored within a single [`CommitStateSpace`], which allows for the efficient -//! exploration of the space of all possible graph rewrites. -//! -//! ## Overlapping commits -//! -//! In general, [`CommitStateSpace`] may contain overlapping commits. Such -//! mutations are mutually exclusive as they modify the same nodes. It is -//! therefore not possible to apply all commits in a [`CommitStateSpace`] -//! simultaneously. A [`PersistentHugr`] on the other hand always corresponds to -//! a subgraph of a [`CommitStateSpace`] that is guaranteed to contain only -//! non-overlapping, compatible commits. By applying all commits in a -//! [`PersistentHugr`], we can materialize a [`Hugr`]. Traversing the -//! materialized HUGR is equivalent to using the [`crate::HugrView`] -//! implementation of the corresponding [`PersistentHugr`]. -//! -//! ## Summary of data types -//! -//! - [`Commit`] A modification to a [`Hugr`] (currently a -//! [`SimpleReplacement`]) that forms the atomic unit of change for a -//! [`PersistentHugr`] (like a commit in git). This is a reference-counted -//! value that is cheap to clone and will be freed when the last reference is -//! dropped. -//! - [`PersistentHugr`] A data structure that implements [`crate::HugrView`] -//! and can be used as a drop-in replacement for a [`crate::Hugr`] for -//! read-only access and mutations through the [`PatchVerification`] and -//! [`Patch`] traits. Mutations are stored as a history of commits. Unlike -//! 
[`CommitStateSpace`], it maintains the invariant that all contained -//! commits are compatible with eachother. -//! - [`CommitStateSpace`] Stores commits, recording the dependencies between -//! them. Includes the base HUGR and any number of possibly incompatible -//! (overlapping) commits. Unlike a [`PersistentHugr`], a state space can -//! contain mutually exclusive commits. -//! -//! ## Usage -//! -//! A [`PersistentHugr`] can be created from a base HUGR using -//! [`PersistentHugr::with_base`]. Replacements can then be applied to it -//! using [`PersistentHugr::add_replacement`]. Alternatively, if you already -//! have a populated state space, use [`PersistentHugr::try_new`] to create a -//! new HUGR with those commits. -//! -//! Add a sequence of commits to a state space by merging a [`PersistentHugr`] -//! into it using [`CommitStateSpace::extend`] or directly using -//! [`CommitStateSpace::try_add_commit`]. -//! -//! To obtain a [`PersistentHugr`] from your state space, use -//! [`CommitStateSpace::try_extract_hugr`]. A [`PersistentHugr`] can always be -//! materialized into a [`Hugr`] type using [`PersistentHugr::to_hugr`]. -//! -//! -//! [`PatchVerification`]: crate::hugr::patch::PatchVerification - -mod parents_view; -mod resolver; -mod state_space; -mod trait_impls; -pub mod walker; - -pub use walker::{PinnedWire, Walker}; - use std::{ - collections::{BTreeSet, HashMap, VecDeque}, + collections::{BTreeSet, HashMap}, mem, vec, }; use delegate::delegate; use derive_more::derive::From; +use hugr_core::{ + Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement, + hugr::patch::{Patch, simple_replace}, +}; use itertools::{Either, Itertools}; use relrc::RelRc; -use state_space::CommitData; -pub use state_space::{CommitId, CommitStateSpace, InvalidCommit, PatchNode}; - -pub use resolver::PointerEqResolver; use crate::{ - Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement, - hugr::patch::{Patch, simple_replace}, + CommitData, CommitId, CommitStateSpace, InvalidCommit, PatchNode, PersistentReplacement, + Resolver, }; -/// A replacement operation that can be applied to a [`PersistentHugr`]. -pub type PersistentReplacement = SimpleReplacement; +pub mod serial; /// A patch that can be applied to a [`PersistentHugr`] or a /// [`CommitStateSpace`] as an atomic commit. @@ -113,29 +38,49 @@ impl Commit { /// Requires a reference to the commit state space that the nodes in /// `replacement` refer to. /// + /// Use [`Self::try_new`] instead if the parents of the commit cannot be + /// inferred from the invalidation set of `replacement` alone. + /// /// The replacement must act on a non-empty subgraph, otherwise this /// function will return an [`InvalidCommit::EmptyReplacement`] error. /// /// If any of the parents of the replacement are not in the commit state /// space, this function will return an [`InvalidCommit::UnknownParent`] /// error. - pub fn try_from_replacement( + pub fn try_from_replacement( + replacement: PersistentReplacement, + graph: &CommitStateSpace, + ) -> Result { + Self::try_new(replacement, [], graph) + } + + /// Create a new commit + /// + /// Requires a reference to the commit state space that the nodes in + /// `replacement` refer to. + /// + /// The returned commit will correspond to the application of `replacement` + /// and will be the child of the commits in `parents` as well as of all + /// the commits in the invalidation set of `replacement`. 
+ /// + /// The replacement must act on a non-empty subgraph, otherwise this + /// function will return an [`InvalidCommit::EmptyReplacement`] error. + /// If any of the parents of the replacement are not in the commit state + /// space, this function will return an [`InvalidCommit::UnknownParent`] + /// error. + pub fn try_new( replacement: PersistentReplacement, - graph: &CommitStateSpace, + parents: impl IntoIterator, + graph: &CommitStateSpace, ) -> Result { if replacement.subgraph().nodes().is_empty() { return Err(InvalidCommit::EmptyReplacement); } - let parent_ids = replacement.invalidation_set().map(|n| n.0).unique(); - let parents = parent_ids - .map(|id| { - if graph.contains_id(id) { - Ok(graph.get_commit(id).clone()) - } else { - Err(InvalidCommit::UnknownParent(id)) - } - }) - .collect::, _>>()?; + let repl_parents = get_parent_commits(&replacement, graph)?; + let parents = parents + .into_iter() + .chain(repl_parents) + .unique_by(|p| p.as_ptr()); let rc = RelRc::with_parents( replacement.into(), parents.into_iter().map(|p| (p.into(), ())), @@ -143,7 +88,7 @@ impl Commit { Ok(Self(rc)) } - fn as_relrc(&self) -> &RelRc { + pub(crate) fn as_relrc(&self) -> &RelRc { &self.0 } @@ -187,8 +132,8 @@ impl Commit { delegate! { to self.0 { - fn value(&self) -> &CommitData; - fn as_ptr(&self) -> *const relrc::node::InnerData; + pub(crate) fn value(&self) -> &CommitData; + pub(crate) fn as_ptr(&self) -> *const relrc::node::InnerData; } } @@ -242,12 +187,13 @@ impl<'a> From<&'a RelRc> for &'a Commit { /// /// ## Supported access and mutation /// -/// [`PersistentHugr`] implements [`crate::HugrView`], so that it can used as +/// [`PersistentHugr`] implements [`HugrView`], so that it can used as /// a drop-in substitute for a Hugr wherever read-only access is required. It -/// does not implement [`HugrMut`](crate::hugr::HugrMut), however. Mutations -/// must be performed by applying patches (see [`PatchVerification`] and -/// [`Patch`]). Currently, only [`SimpleReplacement`] patches are supported. You -/// can use [`Self::add_replacement`] to add a patch to `self`, or use the +/// does not implement [`HugrMut`](hugr_core::hugr::hugrmut::HugrMut), however. +/// Mutations must be performed by applying patches (see +/// [`PatchVerification`](hugr_core::hugr::patch::PatchVerification) +/// and [`Patch`]). Currently, only [`SimpleReplacement`] patches are supported. +/// You can use [`Self::add_replacement`] to add a patch to `self`, or use the /// aforementioned patch traits. /// /// ## Patches, commits and history @@ -267,19 +213,16 @@ impl<'a> From<&'a RelRc> for &'a Commit { /// /// Currently, only patches that apply to subgraphs within dataflow regions /// are supported. -/// -/// [`PatchVerification`]: crate::hugr::patch::PatchVerification - #[derive(Clone, Debug)] -pub struct PersistentHugr { +pub struct PersistentHugr { /// The state space of all commits. /// /// Invariant: all commits are "compatible", meaning that no two patches /// invalidate the same node. - state_space: CommitStateSpace, + state_space: CommitStateSpace, } -impl PersistentHugr { +impl PersistentHugr { /// Create a [`PersistentHugr`] with `hugr` as its base HUGR. /// /// All replacements added in the future will apply on top of `hugr`. @@ -309,13 +252,6 @@ impl PersistentHugr { graph.try_extract_hugr(graph.all_commit_ids()) } - /// Construct a [`PersistentHugr`] from a [`CommitStateSpace`]. - /// - /// Does not check that the commits are compatible. 
- fn from_state_space_unsafe(state_space: CommitStateSpace) -> Self { - Self { state_space } - } - /// Add a replacement to `self`. /// /// The effect of this is equivalent to applying `replacement` to the @@ -395,13 +331,22 @@ impl PersistentHugr { } Ok(commit_id.expect("new_commits cannot be empty")) } +} + +impl PersistentHugr { + /// Construct a [`PersistentHugr`] from a [`CommitStateSpace`]. + /// + /// Does not check that the commits are compatible. + pub(crate) fn from_state_space_unsafe(state_space: CommitStateSpace) -> Self { + Self { state_space } + } /// Convert this `PersistentHugr` to a materialized Hugr by applying all /// commits in `self`. /// /// This operation may be expensive and should be avoided in /// performance-critical paths. For read-only views into the data, rely - /// instead on the [`crate::HugrView`] implementation when possible. + /// instead on the [`HugrView`] implementation when possible. pub fn to_hugr(&self) -> Hugr { self.apply_all().0 } @@ -421,7 +366,9 @@ impl PersistentHugr { continue; }; - let repl = repl.map_host_nodes(|n| node_map[&n]); + let repl = repl + .map_host_nodes(|n| node_map[&n], &hugr) + .expect("invalid replacement"); let simple_replace::Outcome { node_map: new_node_map, @@ -452,12 +399,12 @@ impl PersistentHugr { } /// Get a reference to the underlying state space of `self`. - pub fn as_state_space(&self) -> &CommitStateSpace { + pub fn as_state_space(&self) -> &CommitStateSpace { &self.state_space } /// Convert `self` into its underlying [`CommitStateSpace`]. - pub fn into_state_space(self) -> CommitStateSpace { + pub fn into_state_space(self) -> CommitStateSpace { self.state_space } @@ -467,68 +414,14 @@ impl PersistentHugr { /// /// Panics if `node` is not in `self` (in particular if it is deleted) or if /// `port` is not a value port in `node`. - fn get_single_outgoing_port( + pub(crate) fn single_outgoing_port( &self, node: PatchNode, port: impl Into, ) -> (PatchNode, OutgoingPort) { - let mut in_port = port.into(); - let PatchNode(commit_id, mut in_node) = node; - - assert!(self.is_value_port(node, in_port), "not a dataflow wire"); - assert!(self.contains_node(node), "node not in self"); - - let hugr = self.commit_hugr(commit_id); - let (mut out_node, mut out_port) = hugr - .single_linked_output(in_node, in_port) - .map(|(n, p)| (PatchNode(commit_id, n), p)) - .expect("invalid HUGR"); - - // invariant: (out_node, out_port) -> (in_node, in_port) is a boundary - // edge, i.e. 
it never is the case that both are deleted by the same - // child commit - loop { - let commit_id = out_node.0; - - if let Some(deleted_by) = self.find_deleting_commit(out_node) { - (out_node, out_port) = self - .state_space - .linked_child_output(PatchNode(commit_id, in_node), in_port, deleted_by) - .expect("valid boundary edge"); - // update (in_node, in_port) - (in_node, in_port) = { - let new_commit_id = out_node.0; - let hugr = self.commit_hugr(new_commit_id); - hugr.linked_inputs(out_node.1, out_port) - .find(|&(n, _)| { - self.find_deleting_commit(PatchNode(commit_id, n)).is_none() - }) - .expect("out_node is connected to output node (which is never deleted)") - }; - } else if self - .replacement(commit_id) - .is_some_and(|repl| repl.get_replacement_io()[0] == out_node.1) - { - // out_node is an input node - (out_node, out_port) = self - .as_state_space() - .linked_parent_input(PatchNode(commit_id, in_node), in_port); - // update (in_node, in_port) - (in_node, in_port) = { - let new_commit_id = out_node.0; - let hugr = self.commit_hugr(new_commit_id); - hugr.linked_inputs(out_node.1, out_port) - .find(|&(n, _)| { - self.find_deleting_commit(PatchNode(new_commit_id, n)) - == Some(commit_id) - }) - .expect("boundary edge must connect out_node to deleted node") - }; - } else { - // valid outgoing node! - return (out_node, out_port); - } - } + let w = self.get_wire(node, port.into()); + w.single_outgoing_port(self) + .expect("found invalid dfg wire") } /// All incoming ports that the given outgoing port is attached to. @@ -537,99 +430,14 @@ impl PersistentHugr { /// /// Panics if `out_node` is not in `self` (in particular if it is deleted) /// or if `out_port` is not a value port in `out_node`. - fn get_all_incoming_ports( + pub(crate) fn all_incoming_ports( &self, out_node: PatchNode, out_port: OutgoingPort, ) -> impl Iterator { - assert!( - self.is_value_port(out_node, out_port), - "not a dataflow wire" - ); - assert!(self.contains_node(out_node), "node not in self"); - - let mut visited = BTreeSet::new(); - // enqueue the outport and initialise the set of valid incoming ports - // to the valid incoming ports in this commit - let mut queue = VecDeque::from([(out_node, out_port)]); - let mut valid_incoming_ports = BTreeSet::from_iter( - self.commit_hugr(out_node.0) - .linked_inputs(out_node.1, out_port) - .map(|(in_node, in_port)| (PatchNode(out_node.0, in_node), in_port)) - .filter(|(in_node, _)| self.contains_node(*in_node)), - ); - - // A simple BFS across the commit history to find all equivalent incoming ports. 
- while let Some((out_node, out_port)) = queue.pop_front() { - if !visited.insert((out_node, out_port)) { - continue; - } - let commit_id = out_node.0; - let hugr = self.commit_hugr(commit_id); - let out_deleted_by = self.find_deleting_commit(out_node); - let curr_repl_out = { - let repl = self.replacement(commit_id); - repl.map(|r| r.get_replacement_io()[1]) - }; - // incoming ports are of interest to us if - // (i) they are connected to the output of a replacement (then there will be a - // linked port in a parent commit), or - // (ii) they are deleted by a child commit and are not equal to the out_node - // (then there will be a linked port in a child commit) - let is_linked_to_output = curr_repl_out.is_some_and(|curr_repl_out| { - hugr.linked_inputs(out_node.1, out_port) - .any(|(in_node, _)| in_node == curr_repl_out) - }); - - let deleted_by_child: BTreeSet<_> = hugr - .linked_inputs(out_node.1, out_port) - .filter(|(in_node, _)| Some(in_node) != curr_repl_out.as_ref()) - .filter_map(|(in_node, _)| { - self.find_deleting_commit(PatchNode(commit_id, in_node)) - .filter(|other_deleted_by| - // (out_node, out_port) -> (in_node, in_port) is a boundary edge - // into the child commit `other_deleted_by` - (Some(other_deleted_by) != out_deleted_by.as_ref())) - }) - .collect(); - - // Convert an incoming port to the unique outgoing port that it is linked to - let to_outgoing_port = |(PatchNode(commit_id, in_node), in_port)| { - let hugr = self.commit_hugr(commit_id); - let (out_node, out_port) = hugr - .single_linked_output(in_node, in_port) - .expect("valid dfg wire"); - (PatchNode(commit_id, out_node), out_port) - }; - - if is_linked_to_output { - // Traverse boundary to parent(s) - let new_ins = self - .as_state_space() - .linked_parent_outputs(out_node, out_port); - for (in_node, in_port) in new_ins { - if self.contains_node(in_node) { - valid_incoming_ports.insert((in_node, in_port)); - } - queue.push_back(to_outgoing_port((in_node, in_port))); - } - } - - for child in deleted_by_child { - // Traverse boundary to `child` - let new_ins = self - .as_state_space() - .linked_child_inputs(out_node, out_port, child); - for (in_node, in_port) in new_ins { - if self.contains_node(in_node) { - valid_incoming_ports.insert((in_node, in_port)); - } - queue.push_back(to_outgoing_port((in_node, in_port))); - } - } - } - - valid_incoming_ports.into_iter() + let w = self.get_wire(out_node, out_port); + w.into_all_ports(self, Direction::Incoming) + .map(|(node, port)| (node, port.as_incoming().unwrap())) } delegate! { @@ -646,17 +454,19 @@ impl PersistentHugr { pub fn base_commit(&self) -> &Commit; /// Get the commit with ID `commit_id`. pub fn get_commit(&self, commit_id: CommitId) -> &Commit; + /// Check whether `commit_id` exists and return it. + pub fn try_get_commit(&self, commit_id: CommitId) -> Option<&Commit>; /// Get an iterator over all nodes inserted by `commit_id`. /// /// All nodes will be PatchNodes with commit ID `commit_id`. pub fn inserted_nodes(&self, commit_id: CommitId) -> impl Iterator + '_; /// Get the replacement for `commit_id`. - fn replacement(&self, commit_id: CommitId) -> Option<&SimpleReplacement>; + pub(crate) fn replacement(&self, commit_id: CommitId) -> Option<&SimpleReplacement>; /// Get the Hugr inserted by `commit_id`. /// /// This is either the replacement Hugr of a [`CommitData::Replacement`] or /// the base Hugr of a [`CommitData::Base`]. 
- pub(super) fn commit_hugr(&self, commit_id: CommitId) -> &Hugr; + pub(crate) fn commit_hugr(&self, commit_id: CommitId) -> &Hugr; /// Get an iterator over all commit IDs in the persistent HUGR. pub fn all_commit_ids(&self) -> impl Iterator + Clone + '_; } @@ -701,7 +511,11 @@ impl PersistentHugr { .unique() } - fn find_deleting_commit(&self, node @ PatchNode(commit_id, _): PatchNode) -> Option { + /// Get the child commit that deletes `node`. + pub(crate) fn find_deleting_commit( + &self, + node @ PatchNode(commit_id, _): PatchNode, + ) -> Option { let mut children = self.state_space.children(commit_id); children.find(move |&child_id| { let child = self.get_commit(child_id); @@ -709,6 +523,12 @@ impl PersistentHugr { }) } + /// Convert a node ID specific to a commit HUGR into a patch node in the + /// [`PersistentHugr`]. + pub(crate) fn to_persistent_node(&self, node: Node, commit_id: CommitId) -> PatchNode { + PatchNode(commit_id, node) + } + /// Check if a patch node is in the PersistentHugr, that is, it belongs to /// a commit in the state space and is not deleted by any child commit. pub fn contains_node(&self, PatchNode(commit_id, node): PatchNode) -> bool { @@ -720,16 +540,46 @@ impl PersistentHugr { self.contains_id(commit_id) && !is_replacement_io() && !is_deleted() } - fn is_value_port(&self, PatchNode(commit_id, node): PatchNode, port: impl Into) -> bool { + pub(crate) fn is_value_port( + &self, + PatchNode(commit_id, node): PatchNode, + port: impl Into, + ) -> bool { self.commit_hugr(commit_id) .get_optype(node) .port_kind(port) .expect("invalid port") .is_value() } + + pub(super) fn value_ports( + &self, + patch_node @ PatchNode(commit_id, node): PatchNode, + dir: Direction, + ) -> impl Iterator + '_ { + let hugr = self.commit_hugr(commit_id); + let ports = hugr.node_ports(node, dir); + ports.filter_map(move |p| self.is_value_port(patch_node, p).then_some((patch_node, p))) + } + + pub(super) fn output_value_ports( + &self, + patch_node: PatchNode, + ) -> impl Iterator + '_ { + self.value_ports(patch_node, Direction::Outgoing) + .map(|(n, p)| (n, p.as_outgoing().expect("unexpected port direction"))) + } + + pub(super) fn input_value_ports( + &self, + patch_node: PatchNode, + ) -> impl Iterator + '_ { + self.value_ports(patch_node, Direction::Incoming) + .map(|(n, p)| (n, p.as_incoming().expect("unexpected port direction"))) + } } -impl IntoIterator for PersistentHugr { +impl IntoIterator for PersistentHugr { type Item = Commit; type IntoIter = vec::IntoIter; @@ -745,13 +595,13 @@ impl IntoIterator for PersistentHugr { /// Find a node in `commit` that is invalidated by more than one child commit /// among `children`. -fn find_conflicting_node<'a>( +pub(crate) fn find_conflicting_node<'a>( commit_id: CommitId, - mut children: impl Iterator, + children: impl IntoIterator, ) -> Option { let mut all_invalidated = BTreeSet::new(); - children.find_map(|child| { + children.into_iter().find_map(|child| { let mut new_invalidated = child .invalidation_set() @@ -766,12 +616,17 @@ fn find_conflicting_node<'a>( }) } -pub mod serial { - //! Serialization formats of [`CommitStateSpace`](super::CommitStateSpace) - //! 
and related types - #[doc(inline)] - pub use super::state_space::serial::*; +fn get_parent_commits( + replacement: &PersistentReplacement, + graph: &CommitStateSpace, +) -> Result, InvalidCommit> { + let parent_ids = replacement.invalidation_set().map(|n| n.owner()).unique(); + parent_ids + .map(|id| { + graph + .try_get_commit(id) + .cloned() + .ok_or(InvalidCommit::UnknownParent(id)) + }) + .collect() } - -#[cfg(test)] -mod tests; diff --git a/hugr-persistent/src/persistent_hugr/serial.rs b/hugr-persistent/src/persistent_hugr/serial.rs new file mode 100644 index 0000000000..9a41e4acef --- /dev/null +++ b/hugr-persistent/src/persistent_hugr/serial.rs @@ -0,0 +1,75 @@ +//! Serialized format for [`PersistentHugr`] + +use hugr_core::Hugr; + +use crate::{CommitStateSpace, Resolver, state_space::serial::SerialCommitStateSpace}; + +use super::PersistentHugr; + +/// Serialized format for [`PersistentHugr`] +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct SerialPersistentHugr { + /// The state space of all commits. + state_space: SerialCommitStateSpace, +} + +impl PersistentHugr { + /// Create a new [`CommitStateSpace`] from its serialized format + pub fn from_serial>(value: SerialPersistentHugr) -> Self { + let SerialPersistentHugr { state_space } = value; + let state_space = CommitStateSpace::from_serial(state_space); + Self { state_space } + } + + /// Convert a [`CommitStateSpace`] into its serialized format + pub fn into_serial>(self) -> SerialPersistentHugr { + let Self { state_space } = self; + let state_space = state_space.into_serial(); + SerialPersistentHugr { state_space } + } + + /// Create a serialized format from a reference to [`CommitStateSpace`] + pub fn to_serial>(&self) -> SerialPersistentHugr { + let Self { state_space } = self; + let state_space = state_space.to_serial(); + SerialPersistentHugr { state_space } + } +} + +impl, R: Resolver> From> for SerialPersistentHugr { + fn from(value: PersistentHugr) -> Self { + value.into_serial() + } +} + +impl, R: Resolver> From> for PersistentHugr { + fn from(value: SerialPersistentHugr) -> Self { + PersistentHugr::from_serial(value) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + CommitId, SerdeHashResolver, + tests::{WrappedHugr, test_state_space}, + }; + + use rstest::rstest; + + #[rstest] + fn test_serde_persistent_hugr( + test_state_space: ( + CommitStateSpace>, + [CommitId; 4], + ), + ) { + let (state_space, [cm1, cm2, _, cm4]) = test_state_space; + + let per_hugr = state_space.try_extract_hugr([cm1, cm2, cm4]).unwrap(); + let ser_per_hugr = per_hugr.to_serial::(); + + insta::assert_snapshot!(serde_json::to_string_pretty(&ser_per_hugr).unwrap()); + } +} diff --git a/hugr-persistent/src/persistent_hugr/snapshots/hugr_persistent__persistent_hugr__serial__tests__serde_persistent_hugr.snap b/hugr-persistent/src/persistent_hugr/snapshots/hugr_persistent__persistent_hugr__serial__tests__serde_persistent_hugr.snap new file mode 100644 index 0000000000..e7f544586a --- /dev/null +++ b/hugr-persistent/src/persistent_hugr/snapshots/hugr_persistent__persistent_hugr__serial__tests__serde_persistent_hugr.snap @@ -0,0 +1,184 @@ +--- +source: hugr-persistent/src/persistent_hugr/serial.rs +expression: "serde_json::to_string_pretty(&ser_per_hugr).unwrap()" +--- +{ + "state_space": { + "graph": { + "nodes": { + "3fd58bd8c5f2494a": { + "value": { + "Base": { + "hugr": 
"HUGRiHJv?@{\"modules\":[{\"version\":\"live\",\"nodes\":[{\"parent\":0,\"op\":\"Module\"},{\"parent\":0,\"op\":\"FuncDefn\",\"name\":\"main\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"visibility\":\"Private\"},{\"parent\":1,\"op\":\"Input\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":1,\"op\":\"Output\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":1,\"op\":\"DFG\",\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},{\"parent\":4,\"op\":\"Input\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":4,\"op\":\"Output\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":4,\"op\":\"Extension\",\"extension\":\"logic\",\"name\":\"Not\",\"args\":[],\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},{\"parent\":4,\"op\":\"Extension\",\"extension\":\"logic\",\"name\":\"Not\",\"args\":[],\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},{\"parent\":4,\"op\":\"Extension\",\"extension\":\"logic\",\"name\":\"And\",\"args\":[],\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}}],\"edges\":[[[2,0],[4,0]],[[2,1],[4,1]],[[4,0],[3,0]],[[5,0],[7,0]],[[5,1],[8,0]],[[7,0],[9,0]],[[8,0],[9,1]],[[9,0],[6,0]]],\"metadata\":[null,null,null,null,null,null,null,null,null,null],\"entrypoint\":4}],\"extensions\":[{\"version\":\"0.1.0\",\"name\":\"arithmetic.conversions\",\"types\":{},\"operations\":{\"bytecast_float64_to_int64\":{\"extension\":\"arithmetic.conversions\",\"name\":\"bytecast_float64_to_int64\",\"description\":\"reinterpret an float64 as an int based on its bytes, with the same endianness\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}]}},\"binary\":false},\"bytecast_int64_to_float64\":{\"extension\":\"arithmetic.conversions\",\"name\":\"bytecast_int64_to_float64\",\"description\":\"reinterpret an int64 as a float64 based on its bytes, with the same endianness\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"convert_s\":{\"extension\":\"arithmetic.conversions\",\"name\":\"convert_s\",\"description\":\"signed int to 
float\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"convert_u\":{\"extension\":\"arithmetic.conversions\",\"name\":\"convert_u\",\"description\":\"unsigned int to float\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"ifrombool\":{\"extension\":\"arithmetic.conversions\",\"name\":\"ifrombool\",\"description\":\"convert from bool into a 1-bit integer (1 is true, 0 is false)\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":0}],\"bound\":\"C\"}]}},\"binary\":false},\"ifromusize\":{\"extension\":\"arithmetic.conversions\",\"name\":\"ifromusize\",\"description\":\"convert a usize to a 64b unsigned integer\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"I\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}]}},\"binary\":false},\"itobool\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itobool\",\"description\":\"convert a 1-bit integer to bool (1 is true, 0 is false)\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":0}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"itostring_s\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itostring_s\",\"description\":\"convert a signed integer to its string representation\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"string\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"itostring_u\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itostring_u\",\"description\":\"convert an unsigned integer to its string representation\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"string\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"itousize\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itousize\",\"description\":\"convert a 64b unsigned integer to its usize 
representation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}],\"output\":[{\"t\":\"I\"}]}},\"binary\":false},\"trunc_s\":{\"extension\":\"arithmetic.conversions\",\"name\":\"trunc_s\",\"description\":\"float to signed int\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false},\"trunc_u\":{\"extension\":\"arithmetic.conversions\",\"name\":\"trunc_u\",\"description\":\"float to unsigned int\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false}}},{\"version\":\"0.1.0\",\"name\":\"arithmetic.float\",\"types\":{},\"operations\":{\"fabs\":{\"extension\":\"arithmetic.float\",\"name\":\"fabs\",\"description\":\"absolute value\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fadd\":{\"extension\":\"arithmetic.float\",\"name\":\"fadd\",\"description\":\"addition\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fceil\":{\"extension\":\"arithmetic.float\",\"name\":\"fceil\",\"description\":\"ceiling\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fdiv\":{\"extension\":\"arithmetic.float\",\"name\":\"fdiv\",\"description\":\"division\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"feq\":{\"extension\":\"arithmetic.float\",\"name\":\"feq\",\"description\":\"equality 
test\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"ffloor\":{\"extension\":\"arithmetic.float\",\"name\":\"ffloor\",\"description\":\"floor\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fge\":{\"extension\":\"arithmetic.float\",\"name\":\"fge\",\"description\":\"\\\"greater than or equal\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fgt\":{\"extension\":\"arithmetic.float\",\"name\":\"fgt\",\"description\":\"\\\"greater than\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fle\":{\"extension\":\"arithmetic.float\",\"name\":\"fle\",\"description\":\"\\\"less than or equal\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"flt\":{\"extension\":\"arithmetic.float\",\"name\":\"flt\",\"description\":\"\\\"less 
than\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fmax\":{\"extension\":\"arithmetic.float\",\"name\":\"fmax\",\"description\":\"maximum\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fmin\":{\"extension\":\"arithmetic.float\",\"name\":\"fmin\",\"description\":\"minimum\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fmul\":{\"extension\":\"arithmetic.float\",\"name\":\"fmul\",\"description\":\"multiplication\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fne\":{\"extension\":\"arithmetic.float\",\"name\":\"fne\",\"description\":\"inequality 
test\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fneg\":{\"extension\":\"arithmetic.float\",\"name\":\"fneg\",\"description\":\"negation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fpow\":{\"extension\":\"arithmetic.float\",\"name\":\"fpow\",\"description\":\"exponentiation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fround\":{\"extension\":\"arithmetic.float\",\"name\":\"fround\",\"description\":\"round\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fsub\":{\"extension\":\"arithmetic.float\",\"name\":\"fsub\",\"description\":\"subtraction\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"ftostring\":{\"extension\":\"arithmetic.float\",\"name\":\"ftostring\",\"description\":\"string representation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"string\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false}}},{\"version\":\"0.1.0\",\"name\":\"arithmetic.float.types\",\"types\":{\"float64\":{\"extension\":\"arithmetic.float.types\",\"name\":\"float64\",\"params\":[],\"description\":\"64-bit IEEE 754-2019 floating-point value\",\"bound\":{\"b\":\"Explicit\",\"bound\":\"C\"}}},\"operations\":{}},{\"version\":\"0.1.0\",\"name\":\"arithmetic.int\",\"types\":{},\"operations\":{\"iabs\":{\"extension\":\"arithmetic.int\",\"name\":\"iabs\",\"description\":\"convert signed to unsigned by taking absolute 
value\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"iadd\":{\"extension\":\"arithmetic.int\",\"name\":\"iadd\",\"description\":\"addition modulo 2^N (signed and unsigned versions are the same op)\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"iand\":{\"extension\":\"arithmetic.int\",\"name\":\"iand\",\"description\":\"bitwise AND\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"idiv_checked_s\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_checked_s\",\"description\":\"as idivmod_checked_s but discarding the second output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false},\"idiv_checked_u\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_checked_u\",\"description\":\"as idivmod_checked_u but discarding the second 
output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false},\"idiv_s\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_s\",\"description\":\"as idivmod_s but discarding the second output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"idiv_u\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_u\",\"description\":\"as idivmod_u but discarding the second output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"idivmod_checked_s\":{\"extension\":\"arithmetic.int\",\"name\":\"idivmod_checked_s\",\"description\":\"given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^N, generates signed q and unsigned r where q*m+r=n, 0<=r {} +impl> Resolver for T {} + +/// A resolver that considers two nodes equivalent if they are the same pointer. +/// +/// Resolvers determine when two patches are equivalent and should be merged +/// in the patch history. +/// +/// This is a trivial resolver (to be expanded on later), that considers two +/// patches equivalent if they point to the same data in memory. 
+#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct PointerEqResolver; + +impl EquivalenceResolver for PointerEqResolver { + type MergeMapping = (); + + type DedupKey = *const N; + + fn id(&self) -> String { + "PointerEqResolver".to_string() + } + + fn dedup_key(&self, value: &N, _incoming_edges: &[&E]) -> Self::DedupKey { + value as *const N + } + + fn try_merge_mapping( + &self, + a_value: &N, + _a_incoming_edges: &[&E], + b_value: &N, + _b_incoming_edges: &[&E], + ) -> Result { + if std::ptr::eq(a_value, b_value) { + Ok(()) + } else { + Err(relrc::resolver::NotEquivalent) + } + } + + fn move_edge_source(&self, _mapping: &Self::MergeMapping, edge: &E) -> E { + edge.clone() + } +} + +/// A resolver that considers two nodes equivalent if the hashes of their +/// serialisation is the same. +/// +/// ### Generic type parameter +/// +/// This is parametrised over a serializable type `H`, which must implement +/// [`From`]. This type is used to serialise the commit data before +/// hashing it. +/// +/// Resolvers determine when two patches are equivalent and should be merged +/// in the patch history. +#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct SerdeHashResolver(#[serde(skip)] PhantomData); + +impl Default for SerdeHashResolver { + fn default() -> Self { + Self(PhantomData) + } +} + +impl SerdeHashResolver { + fn hash(value: &impl serde::Serialize) -> u64 { + let bytes = serde_json::to_vec(value).unwrap(); + const SEED: u64 = 0; + wyhash(&bytes, SEED) + } +} + +impl> EquivalenceResolver + for SerdeHashResolver +{ + type MergeMapping = (); + + type DedupKey = u64; + + fn id(&self) -> String { + "SerdeHashResolver".to_string() + } + + fn dedup_key(&self, value: &CommitData, _incoming_edges: &[&()]) -> Self::DedupKey { + let ser_value = value.clone().into_serial::(); + Self::hash(&ser_value) + } + + fn try_merge_mapping( + &self, + a_value: &CommitData, + _a_incoming_edges: &[&()], + b_value: &CommitData, + _b_incoming_edges: &[&()], + ) -> Result { + let a_ser_value = a_value.clone().into_serial::(); + let b_ser_value = b_value.clone().into_serial::(); + if Self::hash(&a_ser_value) == Self::hash(&b_ser_value) { + Ok(()) + } else { + Err(relrc::resolver::NotEquivalent) + } + } + + fn move_edge_source(&self, _mapping: &Self::MergeMapping, _edge: &()) {} +} + +#[cfg(test)] +mod tests { + use hugr_core::{builder::endo_sig, ops::FuncDefn}; + + use super::*; + use crate::{CommitData, tests::WrappedHugr}; + + #[test] + fn test_serde_hash_resolver_equality() { + let resolver = SerdeHashResolver::::default(); + + // Create a base CommitData + let base_data = CommitData::Base(Hugr::new()); + + // Clone the data to create an equivalent copy + let cloned_data = base_data.clone(); + + // Check that original and cloned data are considered equivalent + let result = resolver.try_merge_mapping(&base_data, &[], &cloned_data, &[]); + // Verify that the merge succeeds since the data is equivalent + assert!(result.is_ok()); + + // Check that the original and replacement data are considered different + let repl_data = CommitData::Base( + Hugr::new_with_entrypoint(FuncDefn::new("dummy", endo_sig(vec![]))).unwrap(), + ); + let result = resolver.try_merge_mapping(&base_data, &[], &repl_data, &[]); + assert!(result.is_err()); + } +} diff --git a/hugr-core/src/hugr/persistent/state_space.rs b/hugr-persistent/src/state_space.rs similarity index 70% rename from hugr-core/src/hugr/persistent/state_space.rs rename to 
hugr-persistent/src/state_space.rs index d710c1aaec..30f704347d 100644 --- a/hugr-core/src/hugr/persistent/state_space.rs +++ b/hugr-persistent/src/state_space.rs @@ -1,19 +1,29 @@ +//! Store of commit histories for a [`PersistentHugr`]. + use std::collections::{BTreeSet, VecDeque}; use delegate::delegate; use derive_more::From; -use itertools::Itertools; +use hugr_core::{ + Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement, + hugr::{ + self, + internal::HugrInternals, + patch::{ + BoundaryPort, + simple_replace::{BoundaryMode, InvalidReplacement}, + }, + views::InvalidSignature, + }, + ops::OpType, +}; +use itertools::{Either, Itertools}; use relrc::{HistoryGraph, RelRc}; use thiserror::Error; -use super::{ - Commit, PersistentHugr, PersistentReplacement, PointerEqResolver, find_conflicting_node, - parents_view::ParentsView, -}; use crate::{ - Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement, - hugr::{internal::HugrInternals, patch::BoundaryPort}, - ops::OpType, + Commit, PersistentHugr, PersistentReplacement, PointerEqResolver, Resolver, + find_conflicting_node, parents_view::ParentsView, subgraph::InvalidPinnedSubgraph, }; pub mod serial; @@ -23,23 +33,46 @@ pub type CommitId = relrc::NodeId; /// A HUGR node within a commit of the commit state space #[derive( - Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug, Hash, serde::Serialize, serde::Deserialize, + Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize, )] pub struct PatchNode(pub CommitId, pub Node); +impl PatchNode { + /// Get the commit ID of the commit that owns this node. + pub fn owner(&self) -> CommitId { + self.0 + } +} + +// Print out PatchNodes as `Node(x)@commit_hex` +impl std::fmt::Debug for PatchNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}@{}", self.1, self.0) + } +} + impl std::fmt::Display for PatchNode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) + write!(f, "{self:?}") } } -/// The data stored in a [`Commit`], either the base [`Hugr`] (on which all -/// other commits apply), or a [`PersistentReplacement`] -#[derive(Debug, Clone, From)] -pub(super) enum CommitData { - Base(Hugr), - Replacement(PersistentReplacement), +mod hidden { + use super::*; + + /// The data stored in a [`Commit`], either the base [`Hugr`] (on which all + /// other commits apply), or a [`PersistentReplacement`] + /// + /// This is a "unnamable" type: we do not expose this struct publicly in our + /// API, but we can still use it in public trait bounds (see + /// [`Resolver`](crate::resolver::Resolver)). + #[derive(Debug, Clone, From)] + pub enum CommitData { + Base(Hugr), + Replacement(PersistentReplacement), + } } +pub(crate) use hidden::CommitData; /// A set of commits with directed (acyclic) dependencies between them. /// @@ -61,24 +94,24 @@ pub(super) enum CommitData { /// same subgraph. Use [`Self::try_extract_hugr`] to get a [`PersistentHugr`] /// with a set of compatible commits. #[derive(Clone, Debug)] -pub struct CommitStateSpace { +pub struct CommitStateSpace { /// A set of commits with directed (acyclic) dependencies between them. /// /// Each commit is stored as a [`RelRc`]. - graph: HistoryGraph, + pub(super) graph: HistoryGraph, /// The unique root of the commit graph. /// /// The only commit in the graph with variant [`CommitData::Base`]. 
All /// other commits are [`CommitData::Replacement`]s, and are descendants /// of this. - base_commit: CommitId, + pub(super) base_commit: CommitId, } -impl CommitStateSpace { +impl CommitStateSpace { /// Create a new commit state space with a single base commit. pub fn with_base(hugr: Hugr) -> Self { let commit = RelRc::new(CommitData::Base(hugr)); - let graph = HistoryGraph::new([commit.clone()], PointerEqResolver); + let graph = HistoryGraph::new([commit.clone()], R::default()); let base_commit = graph .all_node_ids() .exactly_one() @@ -94,7 +127,7 @@ impl CommitStateSpace { pub fn try_from_commits( commits: impl IntoIterator, ) -> Result { - let graph = HistoryGraph::new(commits.into_iter().map_into(), PointerEqResolver); + let graph = HistoryGraph::new(commits.into_iter().map_into(), R::default()); let base_commits = graph .all_node_ids() .filter(|&id| matches!(graph.get_node(id).value(), CommitData::Base(_))) @@ -118,39 +151,29 @@ impl CommitStateSpace { self.try_add_commit(commit) } - /// Add a set of commits to the state space. - /// - /// Commits must be valid replacement commits or coincide with the existing - /// base commit. - pub fn extend(&mut self, commits: impl IntoIterator) { - // TODO: make this more efficient - for commit in commits { - self.try_add_commit(commit) - .expect("invalid commit in extend"); - } - } - /// Add a commit (and all its ancestors) to the state space. /// /// Returns an [`InvalidCommit::NonUniqueBase`] error if the commit is a /// base commit and does not coincide with the existing base commit. pub fn try_add_commit(&mut self, commit: Commit) -> Result { - if matches!(commit.value(), CommitData::Base(_) if !commit.0.ptr_eq(&self.base_commit().0)) - { + let is_base = commit.as_relrc().ptr_eq(self.base_commit().as_relrc()); + if !is_base && matches!(commit.value(), CommitData::Base(_)) { return Err(InvalidCommit::NonUniqueBase(2)); } let commit = commit.into(); Ok(self.graph.insert_node(commit)) } - /// Check if `commit` is in the commit state space. - pub fn contains(&self, commit: &Commit) -> bool { - self.graph.contains(commit.as_relrc()) - } - - /// Check if `commit_id` is in the commit state space. - pub fn contains_id(&self, commit_id: CommitId) -> bool { - self.graph.contains_id(commit_id) + /// Add a set of commits to the state space. + /// + /// Commits must be valid replacement commits or coincide with the existing + /// base commit. + pub fn extend(&mut self, commits: impl IntoIterator) { + // TODO: make this more efficient + for commit in commits { + self.try_add_commit(commit) + .expect("invalid commit in extend"); + } } /// Extract a `PersistentHugr` from this state space, consisting of @@ -164,7 +187,7 @@ impl CommitStateSpace { pub fn try_extract_hugr( &self, commits: impl IntoIterator, - ) -> Result { + ) -> Result, InvalidCommit> { // Define commits as the set of all ancestors of the given commits let all_commit_ids = get_all_ancestors(&self.graph, commits); @@ -187,13 +210,25 @@ impl CommitStateSpace { let commits = all_commit_ids .into_iter() .map(|id| self.get_commit(id).as_relrc().clone()); - let subgraph = HistoryGraph::new(commits, PointerEqResolver); + let subgraph = HistoryGraph::new(commits, R::default()); Ok(PersistentHugr::from_state_space_unsafe(Self { graph: subgraph, base_commit: self.base_commit, })) } +} + +impl CommitStateSpace { + /// Check if `commit` is in the commit state space. 
+ pub fn contains(&self, commit: &Commit) -> bool { + self.graph.contains(commit.as_relrc()) + } + + /// Check if `commit_id` is in the commit state space. + pub fn contains_id(&self, commit_id: CommitId) -> bool { + self.graph.contains_id(commit_id) + } /// Get the base commit ID. pub fn base(&self) -> CommitId { @@ -218,6 +253,12 @@ impl CommitStateSpace { self.graph.get_node(commit_id).into() } + /// Check whether `commit_id` exists and return it. + pub fn try_get_commit(&self, commit_id: CommitId) -> Option<&Commit> { + self.contains_id(commit_id) + .then(|| self.get_commit(commit_id)) + } + /// Get an iterator over all commit IDs in the state space. pub fn all_commit_ids(&self) -> impl Iterator + Clone + '_ { let vec = self.graph.all_node_ids().collect_vec(); @@ -256,7 +297,7 @@ impl CommitStateSpace { } } - pub(super) fn as_history_graph(&self) -> &HistoryGraph { + pub(crate) fn as_history_graph(&self) -> &HistoryGraph { &self.graph } @@ -264,7 +305,7 @@ impl CommitStateSpace { /// /// This is either the replacement Hugr of a [`CommitData::Replacement`] or /// the base Hugr of a [`CommitData::Base`]. - pub(super) fn commit_hugr(&self, commit_id: CommitId) -> &Hugr { + pub(crate) fn commit_hugr(&self, commit_id: CommitId) -> &Hugr { let commit = self.get_commit(commit_id); match commit.value() { CommitData::Base(base) => base, @@ -295,15 +336,23 @@ impl CommitStateSpace { /// Get the boundary inputs linked to `(node, port)` in `child`. /// + /// The returned ports will be ports on successors of the input node in the + /// `child` commit, unless (node, port) is connected to a passthrough wire + /// in `child` (i.e. a wire from input node to output node), in which + /// case they will be in one of the parents of `child`. + /// + /// `child` should be a child commit of the owner of `node`. + /// /// ## Panics /// - /// Panics if `(node, port)` is not a boundary edge, or if `child` is not - /// a valid commit ID. - pub(super) fn linked_child_inputs( + /// Panics if `(node, port)` is not a boundary edge, if `child` is not + /// a valid commit ID or if it is the base commit. + pub(crate) fn linked_child_inputs( &self, node: PatchNode, port: OutgoingPort, child: CommitId, + return_invalid: BoundaryMode, ) -> impl Iterator + '_ { assert!( self.is_boundary_edge(node, port, child), @@ -312,7 +361,7 @@ impl CommitStateSpace { let parent_hugrs = ParentsView::from_commit(child, self); let repl = self.replacement(child).expect("valid child commit"); - repl.linked_replacement_inputs((node, port), &parent_hugrs) + repl.linked_replacement_inputs((node, port), &parent_hugrs, return_invalid) .collect_vec() .into_iter() .map(move |np| match np { @@ -323,32 +372,70 @@ impl CommitStateSpace { /// Get the single boundary output linked to `(node, port)` in `child`. /// + /// The returned port will be a port on a predecessor of the output node in + /// the `child` commit, unless (node, port) is connected to a passthrough + /// wire in `child` (i.e. a wire from input node to output node), in + /// which case it will be in one of the parents of `child`. + /// + /// `child` should be a child commit of the owner of `node` (or `None` will + /// be returned). + /// /// ## Panics /// /// Panics if `child` is not a valid commit ID. 
- pub(super) fn linked_child_output( + pub(crate) fn linked_child_output( &self, node: PatchNode, port: IncomingPort, child: CommitId, + return_invalid: BoundaryMode, ) -> Option<(PatchNode, OutgoingPort)> { let parent_hugrs = ParentsView::from_commit(child, self); - let repl = self.replacement(child).expect("valid child commit"); - match repl.linked_replacement_output((node, port), &parent_hugrs)? { + let repl = self.replacement(child)?; + match repl.linked_replacement_output((node, port), &parent_hugrs, return_invalid)? { BoundaryPort::Host(patch_node, port) => (patch_node, port), BoundaryPort::Replacement(node, port) => (PatchNode(child, node), port), } .into() } - /// Get the single output boundary port linked to `(node, port)` in a - /// parent of the commit of `node`. + /// Get the boundary ports linked to `(node, port)` in `child`. + /// + /// `child` should be a child commit of the owner of `node`. + /// + /// See [`Self::linked_child_inputs`] and [`Self::linked_child_output`] for + /// more details. + pub(crate) fn linked_child_ports( + &self, + node: PatchNode, + port: impl Into, + child: CommitId, + return_invalid: BoundaryMode, + ) -> impl Iterator + '_ { + match port.into().as_directed() { + Either::Left(incoming) => Either::Left( + self.linked_child_output(node, incoming, child, return_invalid) + .into_iter() + .map(|(node, port)| (node, port.into())), + ), + Either::Right(outgoing) => Either::Right( + self.linked_child_inputs(node, outgoing, child, return_invalid) + .map(|(node, port)| (node, port.into())), + ), + } + } + + /// Get the single output port linked to `(node, port)` in a parent of the + /// commit of `node`. + /// + /// The returned port belongs to the input boundary of the subgraph in + /// parent. /// /// ## Panics /// /// Panics if `(node, port)` is not connected to the input node in the /// commit of `node`, or if the node is not valid. - pub(super) fn linked_parent_input( + pub fn linked_parent_input( &self, PatchNode(commit_id, node): PatchNode, port: IncomingPort, @@ -366,7 +453,17 @@ impl CommitStateSpace { repl.linked_host_input((node, port), &parent_hugrs).into() } - pub(super) fn linked_parent_outputs( + /// Get the input ports linked to `(node, port)` in a parent of the commit + /// of `node`. + /// + /// The returned ports belong to the output boundary of the subgraph in + /// parent. + /// + /// ## Panics + /// + /// Panics if `(node, port)` is not connected to the output node in the + /// commit of `node`, or if the node is not valid. + pub fn linked_parent_outputs( &self, PatchNode(commit_id, node): PatchNode, port: OutgoingPort, @@ -387,8 +484,35 @@ impl CommitStateSpace { .into_iter() } + /// Get the ports linked to `(node, port)` in a parent of the commit of + /// `node`. + /// + /// See [`Self::linked_parent_input`] and [`Self::linked_parent_outputs`] + /// for more details. + /// + /// ## Panics + /// + /// Panics if `(node, port)` is not connected to an IO node in the commit + /// of `node`, or if the node is not valid. + pub fn linked_parent_ports( + &self, + node: PatchNode, + port: impl Into, + ) -> impl Iterator + '_ { + match port.into().as_directed() { + Either::Left(incoming) => { + let (node, port) = self.linked_parent_input(node, incoming); + Either::Left(std::iter::once((node, port.into()))) + } + Either::Right(outgoing) => Either::Right( + self.linked_parent_outputs(node, outgoing) + .map(|(node, port)| (node, port.into())), + ), + } + } + /// Get the replacement for `commit_id`. 
- pub(super) fn replacement(&self, commit_id: CommitId) -> Option<&SimpleReplacement> { + pub(crate) fn replacement(&self, commit_id: CommitId) -> Option<&SimpleReplacement> { let commit = self.get_commit(commit_id); commit.replacement() } @@ -396,7 +520,7 @@ impl CommitStateSpace { // The subset of HugrView methods that can be implemented on CommitStateSpace // by simplify delegating to the patches' respective HUGRs -impl CommitStateSpace { +impl CommitStateSpace { /// Get the type of the operation at `node`. pub fn get_optype(&self, PatchNode(commit_id, node): PatchNode) -> &OpType { let hugr = self.commit_hugr(commit_id); @@ -447,7 +571,7 @@ impl CommitStateSpace { pub fn node_metadata_map( &self, PatchNode(commit_id, node): PatchNode, - ) -> &crate::hugr::NodeMetadataMap { + ) -> &hugr::NodeMetadataMap { self.commit_hugr(commit_id).node_metadata_map(node) } } @@ -491,4 +615,20 @@ pub enum InvalidCommit { /// The commit is an empty replacement. #[error("Not allowed: empty replacement")] EmptyReplacement, + + #[error("Invalid subgraph: {0}")] + /// The subgraph of the replacement is not convex. + InvalidSubgraph(#[from] InvalidPinnedSubgraph), + + /// The replacement of the commit is invalid. + #[error("Invalid replacement: {0}")] + InvalidReplacement(#[from] InvalidReplacement), + + /// The signature of the replacement is invalid. + #[error("Invalid signature: {0}")] + InvalidSignature(#[from] InvalidSignature), + + /// A wire has an unpinned port. + #[error("Incomplete wire: {0} is unpinned")] + IncompleteWire(PatchNode, Port), } diff --git a/hugr-core/src/hugr/persistent/state_space/serial.rs b/hugr-persistent/src/state_space/serial.rs similarity index 66% rename from hugr-core/src/hugr/persistent/state_space/serial.rs rename to hugr-persistent/src/state_space/serial.rs index c345308b8b..b20c585eb7 100644 --- a/hugr-core/src/hugr/persistent/state_space/serial.rs +++ b/hugr-persistent/src/state_space/serial.rs @@ -1,7 +1,9 @@ +//! 
Serialized format for [`CommitStateSpace`] + use relrc::serialization::SerializedHistoryGraph; use super::*; -use crate::hugr::patch::simple_replace::serial::SerialSimpleReplacement; +use hugr_core::hugr::patch::simple_replace::serial::SerialSimpleReplacement; /// Serialized format for [`PersistentReplacement`] pub type SerialPersistentReplacement = SerialSimpleReplacement; @@ -51,28 +53,28 @@ impl> From> for CommitData { /// Serialized format for commit state space #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub struct SerialCommitStateSpace { +pub struct SerialCommitStateSpace { /// The serialized history graph containing commit data - pub graph: SerializedHistoryGraph, (), PointerEqResolver>, + pub graph: SerializedHistoryGraph, (), R>, /// The base commit ID pub base_commit: CommitId, } -impl CommitStateSpace { +impl CommitStateSpace { /// Create a new [`CommitStateSpace`] from its serialized format - pub fn from_serial + Clone>(value: SerialCommitStateSpace) -> Self { + pub fn from_serial>(value: SerialCommitStateSpace) -> Self { let SerialCommitStateSpace { graph, base_commit } = value; // Deserialize the SerializedHistoryGraph into a HistoryGraph with CommitData let graph = graph.map_nodes(|n| CommitData::from_serial(n)); - let graph = HistoryGraph::try_from_serialized(graph, PointerEqResolver) + let graph = HistoryGraph::try_from_serialized(graph, R::default()) .expect("failed to deserialize history graph"); Self { graph, base_commit } } /// Convert a [`CommitStateSpace`] into its serialized format - pub fn into_serial>(self) -> SerialCommitStateSpace { + pub fn into_serial>(self) -> SerialCommitStateSpace { let Self { graph, base_commit } = self; let graph = graph.to_serialized(); let graph = graph.map_nodes(|n| n.into_serial()); @@ -80,10 +82,7 @@ impl CommitStateSpace { } /// Create a serialized format from a reference to [`CommitStateSpace`] - pub fn to_serial(&self) -> SerialCommitStateSpace - where - H: From, - { + pub fn to_serial>(&self) -> SerialCommitStateSpace { let Self { graph, base_commit } = self; let graph = graph.to_serialized(); let graph = graph.map_nodes(|n| n.into_serial()); @@ -94,51 +93,46 @@ impl CommitStateSpace { } } -impl> From for SerialCommitStateSpace { - fn from(value: CommitStateSpace) -> Self { +impl, R: Resolver> From> for SerialCommitStateSpace { + fn from(value: CommitStateSpace) -> Self { value.into_serial() } } -impl> From> for CommitStateSpace { - fn from(value: SerialCommitStateSpace) -> Self { +impl, R: Resolver> From> for CommitStateSpace { + fn from(value: SerialCommitStateSpace) -> Self { CommitStateSpace::from_serial(value) } } #[cfg(test)] mod tests { - use derive_more::derive::Into; use rstest::rstest; - use serde_with::serde_as; use super::*; use crate::{ - envelope::serde_with::AsStringEnvelope, hugr::persistent::tests::test_state_space, + SerdeHashResolver, + tests::{WrappedHugr, test_state_space}, }; - #[serde_as] - #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, From, Into)] - struct WrappedHugr { - #[serde_as(as = "AsStringEnvelope")] - pub hugr: Hugr, - } - #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri #[rstest] - fn test_serialize_state_space(test_state_space: (CommitStateSpace, [CommitId; 4])) { + fn test_serialize_state_space( + test_state_space: ( + CommitStateSpace>, + [CommitId; 4], + ), + ) { let (state_space, _) = test_state_space; let serialized = state_space.to_serial::(); - let deser = CommitStateSpace::from_serial(serialized); - let 
_serialized_2 = deser.to_serial::(); + let deser = CommitStateSpace::from_serial(serialized.clone()); + let serialized_2 = deser.to_serial::(); - // TODO: add this once PointerEqResolver is replaced by a deterministic resolver - // insta::assert_snapshot!(serde_json::to_string_pretty(&serialized).unwrap()); - // see https://github.com/CQCL/hugr/issues/2299 - // assert_eq!( - // serde_json::to_string(&serialized), - // serde_json::to_string(&serialized_2) - // ); + insta::assert_snapshot!(serde_json::to_string_pretty(&serialized).unwrap()); + assert_eq!( + serde_json::to_string(&serialized).unwrap(), + serde_json::to_string(&serialized_2).unwrap() + ); } } diff --git a/hugr-persistent/src/state_space/snapshots/hugr_persistent__state_space__serial__tests__serialize_state_space.snap b/hugr-persistent/src/state_space/snapshots/hugr_persistent__state_space__serial__tests__serialize_state_space.snap new file mode 100644 index 0000000000..b415f9d784 --- /dev/null +++ b/hugr-persistent/src/state_space/snapshots/hugr_persistent__state_space__serial__tests__serialize_state_space.snap @@ -0,0 +1,244 @@ +--- +source: hugr-persistent/src/state_space/serial.rs +expression: "serde_json::to_string_pretty(&serialized).unwrap()" +--- +{ + "graph": { + "nodes": { + "3fd58bd8c5f2494a": { + "value": { + "Base": { + "hugr": "HUGRiHJv?@{\"modules\":[{\"version\":\"live\",\"nodes\":[{\"parent\":0,\"op\":\"Module\"},{\"parent\":0,\"op\":\"FuncDefn\",\"name\":\"main\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"visibility\":\"Private\"},{\"parent\":1,\"op\":\"Input\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":1,\"op\":\"Output\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":1,\"op\":\"DFG\",\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},{\"parent\":4,\"op\":\"Input\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":4,\"op\":\"Output\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":4,\"op\":\"Extension\",\"extension\":\"logic\",\"name\":\"Not\",\"args\":[],\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},{\"parent\":4,\"op\":\"Extension\",\"extension\":\"logic\",\"name\":\"Not\",\"args\":[],\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},{\"parent\":4,\"op\":\"Extension\",\"extension\":\"logic\",\"name\":\"And\",\"args\":[],\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}}],\"edges\":[[[2,0],[4,0]],[[2,1],[4,1]],[[4,0],[3,0]],[[5,0],[7,0]],[[5,1],[8,0]],[[7,0],[9,0]],[[8,0],[9,1]],[[9,0],[6,0]]],\"metadata\":[null,null,null,null,null,null,null,null,null,null],\"entrypoint\":4}],\"extensions\":[{\"version\":\"0.1.0\",\"name\":\"arithmetic.conversions\",\"types\":{},\"operations\":{\"bytecast_float64_to_int64\":{\"extension\":\"arithmetic.conversions\",\"name\":\"bytecast_float64_to_int64\",\"description\":\"reinterpret an float64 as an int based on its bytes, with the same 
endianness\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}]}},\"binary\":false},\"bytecast_int64_to_float64\":{\"extension\":\"arithmetic.conversions\",\"name\":\"bytecast_int64_to_float64\",\"description\":\"reinterpret an int64 as a float64 based on its bytes, with the same endianness\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"convert_s\":{\"extension\":\"arithmetic.conversions\",\"name\":\"convert_s\",\"description\":\"signed int to float\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"convert_u\":{\"extension\":\"arithmetic.conversions\",\"name\":\"convert_u\",\"description\":\"unsigned int to float\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"ifrombool\":{\"extension\":\"arithmetic.conversions\",\"name\":\"ifrombool\",\"description\":\"convert from bool into a 1-bit integer (1 is true, 0 is false)\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":0}],\"bound\":\"C\"}]}},\"binary\":false},\"ifromusize\":{\"extension\":\"arithmetic.conversions\",\"name\":\"ifromusize\",\"description\":\"convert a usize to a 64b unsigned integer\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"I\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}]}},\"binary\":false},\"itobool\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itobool\",\"description\":\"convert a 1-bit integer to bool (1 is true, 0 is false)\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":0}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"itostring_s\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itostring_s\",\"description\":\"convert a signed integer to its string 
representation\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"string\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"itostring_u\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itostring_u\",\"description\":\"convert an unsigned integer to its string representation\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"string\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"itousize\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itousize\",\"description\":\"convert a 64b unsigned integer to its usize representation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}],\"output\":[{\"t\":\"I\"}]}},\"binary\":false},\"trunc_s\":{\"extension\":\"arithmetic.conversions\",\"name\":\"trunc_s\",\"description\":\"float to signed int\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false},\"trunc_u\":{\"extension\":\"arithmetic.conversions\",\"name\":\"trunc_u\",\"description\":\"float to unsigned int\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false}}},{\"version\":\"0.1.0\",\"name\":\"arithmetic.float\",\"types\":{},\"operations\":{\"fabs\":{\"extension\":\"arithmetic.float\",\"name\":\"fabs\",\"description\":\"absolute 
value\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fadd\":{\"extension\":\"arithmetic.float\",\"name\":\"fadd\",\"description\":\"addition\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fceil\":{\"extension\":\"arithmetic.float\",\"name\":\"fceil\",\"description\":\"ceiling\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fdiv\":{\"extension\":\"arithmetic.float\",\"name\":\"fdiv\",\"description\":\"division\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"feq\":{\"extension\":\"arithmetic.float\",\"name\":\"feq\",\"description\":\"equality test\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"ffloor\":{\"extension\":\"arithmetic.float\",\"name\":\"ffloor\",\"description\":\"floor\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fge\":{\"extension\":\"arithmetic.float\",\"name\":\"fge\",\"description\":\"\\\"greater than or equal\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fgt\":{\"extension\":\"arithmetic.float\",\"name\":\"fgt\",\"description\":\"\\\"greater than\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fle\":{\"extension\":\"arithmetic.float\",\"name\":\"fle\",\"description\":\"\\\"less than or 
equal\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"flt\":{\"extension\":\"arithmetic.float\",\"name\":\"flt\",\"description\":\"\\\"less than\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fmax\":{\"extension\":\"arithmetic.float\",\"name\":\"fmax\",\"description\":\"maximum\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fmin\":{\"extension\":\"arithmetic.float\",\"name\":\"fmin\",\"description\":\"minimum\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fmul\":{\"extension\":\"arithmetic.float\",\"name\":\"fmul\",\"description\":\"multiplication\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fne\":{\"extension\":\"arithmetic.float\",\"name\":\"fne\",\"description\":\"inequality 
test\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fneg\":{\"extension\":\"arithmetic.float\",\"name\":\"fneg\",\"description\":\"negation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fpow\":{\"extension\":\"arithmetic.float\",\"name\":\"fpow\",\"description\":\"exponentiation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fround\":{\"extension\":\"arithmetic.float\",\"name\":\"fround\",\"description\":\"round\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fsub\":{\"extension\":\"arithmetic.float\",\"name\":\"fsub\",\"description\":\"subtraction\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"ftostring\":{\"extension\":\"arithmetic.float\",\"name\":\"ftostring\",\"description\":\"string representation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"string\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false}}},{\"version\":\"0.1.0\",\"name\":\"arithmetic.float.types\",\"types\":{\"float64\":{\"extension\":\"arithmetic.float.types\",\"name\":\"float64\",\"params\":[],\"description\":\"64-bit IEEE 754-2019 floating-point value\",\"bound\":{\"b\":\"Explicit\",\"bound\":\"C\"}}},\"operations\":{}},{\"version\":\"0.1.0\",\"name\":\"arithmetic.int\",\"types\":{},\"operations\":{\"iabs\":{\"extension\":\"arithmetic.int\",\"name\":\"iabs\",\"description\":\"convert signed to unsigned by taking absolute 
value\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"iadd\":{\"extension\":\"arithmetic.int\",\"name\":\"iadd\",\"description\":\"addition modulo 2^N (signed and unsigned versions are the same op)\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"iand\":{\"extension\":\"arithmetic.int\",\"name\":\"iand\",\"description\":\"bitwise AND\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"idiv_checked_s\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_checked_s\",\"description\":\"as idivmod_checked_s but discarding the second output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false},\"idiv_checked_u\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_checked_u\",\"description\":\"as idivmod_checked_u but discarding the second 
output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false},\"idiv_s\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_s\",\"description\":\"as idivmod_s but discarding the second output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"idiv_u\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_u\",\"description\":\"as idivmod_u but discarding the second output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"idivmod_checked_s\":{\"extension\":\"arithmetic.int\",\"name\":\"idivmod_checked_s\",\"description\":\"given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^N, generates signed q and unsigned r where q*m+r=n, 0<=r, + /// The input ports of the subgraph. + /// + /// Grouped by input parameter. Each port must be unique and belong to a + /// node in `nodes`. + inputs: Vec>, + /// The output ports of the subgraph. + /// + /// Repeated ports are allowed and correspond to copying the output. Every + /// port must belong to a node in `nodes`. + outputs: Vec<(PatchNode, OutgoingPort)>, + /// The commits that must be selected in the host for the subgraph to be + /// valid. + selected_commits: BTreeSet, +} + +impl From> for PinnedSubgraph { + fn from(subgraph: SiblingSubgraph) -> Self { + Self { + inputs: subgraph.incoming_ports().clone(), + outputs: subgraph.outgoing_ports().clone(), + nodes: BTreeSet::from_iter(subgraph.nodes().iter().copied()), + selected_commits: BTreeSet::new(), + } + } +} + +impl PinnedSubgraph { + /// Create a new subgraph from a set of pinned nodes and wires. 
+ /// + /// All nodes must be pinned and all wires must be complete in the given + /// `walker`. + /// + /// Nodes that are not isolated, i.e. are attached to at least one wire in + /// `wires` will be added implicitly to the graph and do not need to be + /// explicitly listed in `nodes`. + pub fn try_from_pinned( + nodes: impl IntoIterator, + wires: impl IntoIterator, + walker: &Walker, + ) -> Result { + let mut selected_commits = BTreeSet::new(); + let host = walker.as_hugr_view(); + let wires = wires.into_iter().collect_vec(); + let nodes = nodes.into_iter().collect_vec(); + + for w in wires.iter() { + if !walker.is_complete(w, None) { + return Err(InvalidPinnedSubgraph::IncompleteWire(w.clone())); + } + for id in w.owners() { + if host.contains_id(id) { + selected_commits.insert(id); + } else { + return Err(InvalidPinnedSubgraph::InvalidCommit(id)); + } + } + } + + if let Some(&unpinned) = nodes.iter().find(|&&n| !walker.is_pinned(n)) { + return Err(InvalidPinnedSubgraph::UnpinnedNode(unpinned)); + } + + let (inputs, outputs, all_nodes) = Self::compute_io_ports(nodes, wires, host); + + Ok(Self { + selected_commits, + nodes: all_nodes, + inputs, + outputs, + }) + } + + /// Create a new subgraph from a set of complete wires in `walker`. + pub fn try_from_wires( + wires: impl IntoIterator, + walker: &Walker, + ) -> Result { + Self::try_from_pinned(std::iter::empty(), wires, walker) + } + + /// Compute the input and output ports for the given pinned nodes and wires. + /// + /// Return the input boundary ports, output boundary ports as well as the + /// set of all nodes in the subgraph. + pub fn compute_io_ports( + nodes: impl IntoIterator, + wires: impl IntoIterator, + host: &PersistentHugr, + ) -> ( + IncomingPorts, + OutgoingPorts, + BTreeSet, + ) { + let mut wire_ports_incoming = BTreeSet::new(); + let mut wire_ports_outgoing = BTreeSet::new(); + + for w in wires { + wire_ports_incoming.extend(w.all_incoming_ports(host)); + wire_ports_outgoing.extend(w.single_outgoing_port(host)); + } + + let mut all_nodes = BTreeSet::from_iter(nodes); + all_nodes.extend(wire_ports_incoming.iter().map(|&(n, _)| n)); + all_nodes.extend(wire_ports_outgoing.iter().map(|&(n, _)| n)); + + // (in/out) boundary: all in/out ports on the nodes of the wire, minus ports + // that are part of the wires + let inputs = all_nodes + .iter() + .flat_map(|&n| host.input_value_ports(n)) + .filter(|node_port| !wire_ports_incoming.contains(node_port)) + .map(|np| vec![np]) + .collect_vec(); + let outputs = all_nodes + .iter() + .flat_map(|&n| host.output_value_ports(n)) + .filter(|node_port| !wire_ports_outgoing.contains(node_port)) + .collect_vec(); + + (inputs, outputs, all_nodes) + } + + /// Convert the pinned subgraph to a [`SiblingSubgraph`] for the given + /// `host`. + /// + /// This will fail if any of the required selected commits are not in the + /// host, if any of the nodes are invalid in the host (e.g. deleted by + /// another commit in host), or if the subgraph is not convex. + pub fn to_sibling_subgraph( + &self, + host: &PersistentHugr, + ) -> Result, InvalidPinnedSubgraph> { + if let Some(&unselected) = self + .selected_commits + .iter() + .find(|&&id| !host.contains_id(id)) + { + return Err(InvalidPinnedSubgraph::InvalidCommit(unselected)); + } + + if let Some(invalid) = self.nodes.iter().find(|&&n| !host.contains_node(n)) { + return Err(InvalidPinnedSubgraph::InvalidNode(*invalid)); + } + + Ok(SiblingSubgraph::try_new( + self.inputs.clone(), + self.outputs.clone(), + host, + )?) 
+ } + + /// Iterate over all the commits required by this pinned subgraph. + pub fn selected_commits(&self) -> impl Iterator + '_ { + self.selected_commits.iter().copied() + } + + /// Iterate over all the nodes in this pinned subgraph. + pub fn nodes(&self) -> impl Iterator + '_ { + self.nodes.iter().copied() + } + + /// Returns the computed [`IncomingPorts`] of the subgraph. + #[must_use] + pub fn incoming_ports(&self) -> &IncomingPorts { + &self.inputs + } + + /// Returns the computed [`OutgoingPorts`] of the subgraph. + #[must_use] + pub fn outgoing_ports(&self) -> &OutgoingPorts { + &self.outputs + } +} + +#[derive(Debug, Clone, Error)] +#[non_exhaustive] +pub enum InvalidPinnedSubgraph { + #[error("Invalid subgraph: {0}")] + InvalidSubgraph(#[from] InvalidSubgraph), + #[error("Invalid commit in host: {0}")] + InvalidCommit(CommitId), + #[error("Wire is not complete: {0:?}")] + IncompleteWire(PersistentWire), + #[error("Node is not pinned: {0}")] + UnpinnedNode(PatchNode), + #[error("Invalid node in host: {0}")] + InvalidNode(PatchNode), +} diff --git a/hugr-core/src/hugr/persistent/tests.rs b/hugr-persistent/src/tests.rs similarity index 81% rename from hugr-core/src/hugr/persistent/tests.rs rename to hugr-persistent/src/tests.rs index ae0876ef5b..77b26be8ac 100644 --- a/hugr-core/src/hugr/persistent/tests.rs +++ b/hugr-persistent/src/tests.rs @@ -1,22 +1,20 @@ use std::collections::{BTreeMap, HashMap}; -use rstest::*; - -use crate::{ +use derive_more::derive::{From, Into}; +use hugr_core::{ IncomingPort, Node, OutgoingPort, SimpleReplacement, - builder::{DFGBuilder, Dataflow, DataflowHugr, inout_sig}, + builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig, inout_sig}, extension::prelude::bool_t, - hugr::{ - Hugr, HugrView, - patch::Patch, - persistent::{Commit, PatchNode}, - views::SiblingSubgraph, - }, + hugr::{Hugr, HugrView, patch::Patch, views::SiblingSubgraph}, ops::handle::NodeHandle, std_extensions::logic::LogicOp, }; +use rstest::*; -use super::{CommitStateSpace, state_space::CommitId}; +use crate::{ + Commit, CommitStateSpace, PatchNode, PersistentHugr, PersistentReplacement, Resolver, + state_space::CommitId, +}; /// Creates a simple test Hugr with a DFG that contains a small boolean circuit /// @@ -207,10 +205,10 @@ fn create_not_and_to_xor_replacement(hugr: &Hugr) -> SimpleReplacement { /// - `commit1` and `commit2` are disjoint with `commit4` (i.e. 
compatible), /// - `commit2` depends on `commit1` #[fixture] -pub(super) fn test_state_space() -> (CommitStateSpace, [CommitId; 4]) { +pub(crate) fn test_state_space() -> (CommitStateSpace, [CommitId; 4]) { let (base_hugr, [not0_node, not1_node, _and_node]) = simple_hugr(); - let mut state_space = CommitStateSpace::with_base(base_hugr); + let mut state_space = CommitStateSpace::::with_base(base_hugr); // Create first replacement (replace NOT0 with two NOT gates) let replacement1 = create_double_not_replacement(state_space.base_hugr(), not0_node); @@ -218,8 +216,11 @@ pub(super) fn test_state_space() -> (CommitStateSpace, [CommitId; 4]) { // Add first commit to state space, replacing NOT0 with two NOT gates let commit1 = { let to_patch_node = |n: Node| PatchNode(state_space.base(), n); + let new_host = state_space.try_extract_hugr([state_space.base()]).unwrap(); // translate replacement1 to patch nodes in the base commit of the state space - let replacement1 = replacement1.map_host_nodes(to_patch_node); + let replacement1 = replacement1 + .map_host_nodes(to_patch_node, &new_host) + .unwrap(); state_space.try_add_replacement(replacement1).unwrap() }; @@ -259,7 +260,10 @@ pub(super) fn test_state_space() -> (CommitStateSpace, [CommitId; 4]) { }; // translate replacement2 to patch nodes - let replacement2 = replacement2.map_host_nodes(to_patch_node); + let new_host = state_space.try_extract_hugr([commit1]).unwrap(); + let replacement2 = replacement2 + .map_host_nodes(to_patch_node, &new_host) + .unwrap(); state_space.try_add_replacement(replacement2).unwrap() }; @@ -268,9 +272,11 @@ pub(super) fn test_state_space() -> (CommitStateSpace, [CommitId; 4]) { let commit3 = { let replacement3 = create_not_and_to_xor_replacement(state_space.base_hugr()); let to_patch_node = |n: Node| PatchNode(state_space.base(), n); - state_space - .try_add_replacement(replacement3.map_host_nodes(to_patch_node)) - .unwrap() + let new_host = state_space.try_extract_hugr([state_space.base()]).unwrap(); + let replacement3 = replacement3 + .map_host_nodes(to_patch_node, &new_host) + .unwrap(); + state_space.try_add_replacement(replacement3).unwrap() }; // Create a fourth commit that is disjoint from `commit1`, replacing NOT1 @@ -278,13 +284,54 @@ pub(super) fn test_state_space() -> (CommitStateSpace, [CommitId; 4]) { let commit4 = { let replacement4 = create_double_not_replacement(state_space.base_hugr(), not1_node); let to_patch_node = |n: Node| PatchNode(state_space.base(), n); - let replacement4 = replacement4.map_host_nodes(to_patch_node); + let new_host = state_space.try_extract_hugr([state_space.base()]).unwrap(); + let replacement4 = replacement4 + .map_host_nodes(to_patch_node, &new_host) + .unwrap(); state_space.try_add_replacement(replacement4).unwrap() }; (state_space, [commit1, commit2, commit3, commit4]) } +#[fixture] +pub(super) fn persistent_hugr_empty_child() -> (PersistentHugr, [CommitId; 2], [PatchNode; 3]) { + let (triple_not_hugr, not_nodes) = { + let mut dfg_builder = DFGBuilder::new(endo_sig(bool_t())).unwrap(); + let [mut w] = dfg_builder.input_wires_arr(); + let not_nodes = [(); 3].map(|()| { + let handle = dfg_builder.add_dataflow_op(LogicOp::Not, vec![w]).unwrap(); + [w] = handle.outputs_arr(); + handle.node() + }); + ( + dfg_builder.finish_hugr_with_outputs([w]).unwrap(), + not_nodes, + ) + }; + let mut hugr = PersistentHugr::with_base(triple_not_hugr); + let empty_hugr = { + let dfg_builder = DFGBuilder::new(endo_sig(bool_t())).unwrap(); + let inputs = dfg_builder.input_wires(); + 
dfg_builder.finish_hugr_with_outputs(inputs).unwrap() + }; + let subg_nodes = [PatchNode(hugr.base(), not_nodes[1])]; + let repl = PersistentReplacement::try_new( + SiblingSubgraph::try_from_nodes(subg_nodes, &hugr).unwrap(), + &hugr, + empty_hugr, + ) + .unwrap(); + + let empty_commit = hugr.try_add_replacement(repl).unwrap(); + let base_commit = hugr.base(); + ( + hugr, + [base_commit, empty_commit], + not_nodes.map(|n| PatchNode(base_commit, n)), + ) +} + #[rstest] fn test_successive_replacements(test_state_space: (CommitStateSpace, [CommitId; 4])) { let (state_space, [commit1, commit2, _commit3, _commit4]) = test_state_space; @@ -419,8 +466,7 @@ fn test_try_add_replacement(test_state_space: (CommitStateSpace, [CommitId; 4])) let result = persistent_hugr.try_add_replacement(repl4.clone()); assert!( result.is_ok(), - "[commit1, commit2] + [commit4] are compatible. Got {:?}", - result + "[commit1, commit2] + [commit4] are compatible. Got {result:?}" ); let hugr = persistent_hugr.to_hugr(); let exp_hugr = state_space @@ -436,8 +482,7 @@ fn test_try_add_replacement(test_state_space: (CommitStateSpace, [CommitId; 4])) let result = persistent_hugr.try_add_replacement(repl3.clone()); assert!( result.is_err(), - "[commit1, commit2] + [commit3] are incompatible. Got {:?}", - result + "[commit1, commit2] + [commit3] are incompatible. Got {result:?}" ); } } @@ -477,3 +522,49 @@ fn test_try_add_commit(test_state_space: (CommitStateSpace, [CommitId; 4])) { .expect_err("commit3 is incompatible with [commit1, commit2]"); } } + +/// A Hugr that serialises with no extensions +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, From, Into)] +pub(crate) struct WrappedHugr { + #[serde(with = "serial")] + pub hugr: Hugr, +} + +mod serial { + use hugr_core::envelope::EnvelopeConfig; + use hugr_core::std_extensions::STD_REG; + use serde::Deserialize; + + use super::*; + + pub(crate) fn serialize(hugr: &Hugr, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut str = hugr + .store_str_with_exts(EnvelopeConfig::text(), &STD_REG) + .map_err(serde::ser::Error::custom)?; + // TODO: replace this with a proper hugr hash (see https://github.com/CQCL/hugr/issues/2091) + remove_encoder_version(&mut str); + serializer.serialize_str(&str) + } + + fn remove_encoder_version(str: &mut String) { + // Remove encoder version information for consistent test output + let encoder_pattern = r#""encoder":"hugr-rs v"#; + if let Some(start) = str.find(encoder_pattern) { + if let Some(end) = str[start..].find(r#"","#) { + let end = start + end + 2; // +2 for the `",` part + str.replace_range(start..end, ""); + } + } + } + + pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let str = String::deserialize(deserializer)?; + Hugr::load_str(str, Some(&STD_REG)).map_err(serde::de::Error::custom) + } +} diff --git a/hugr-core/src/hugr/persistent/trait_impls.rs b/hugr-persistent/src/trait_impls.rs similarity index 92% rename from hugr-core/src/hugr/persistent/trait_impls.rs rename to hugr-persistent/src/trait_impls.rs index 6c68762029..17fadca6c6 100644 --- a/hugr-core/src/hugr/persistent/trait_impls.rs +++ b/hugr-persistent/src/trait_impls.rs @@ -1,18 +1,19 @@ use std::collections::HashMap; use itertools::{Either, Itertools}; -use portgraph::render::MermaidFormat; -use crate::{ +use hugr_core::{ Direction, Hugr, HugrView, Node, Port, + extension::ExtensionRegistry, hugr::{ - Patch, SimpleReplacementError, + self, Patch, SimpleReplacementError, 
internal::HugrInternals, views::{ ExtractionResult, render::{self, MermaidFormatter, NodeLabel}, }, }, + ops::OpType, }; use super::{ @@ -36,9 +37,9 @@ impl Patch for PersistentReplacement { } } -impl HugrInternals for PersistentHugr { +impl HugrInternals for PersistentHugr { type RegionPortgraph<'p> - = portgraph::MultiPortGraph + = portgraph::MultiPortGraph where Self: 'p; @@ -57,15 +58,10 @@ impl HugrInternals for PersistentHugr { let (hugr, node_map) = self.apply_all(); let parent = node_map[&parent]; - let region = portgraph::view::FlatRegion::new_without_root( - hugr.graph, - hugr.hierarchy, - parent.into_portgraph(), - ); - (region, node_map) + (hugr.into_region_portgraph(parent), node_map) } - fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap { + fn node_metadata_map(&self, node: Self::Node) -> &hugr::NodeMetadataMap { self.as_state_space().node_metadata_map(node) } } @@ -75,7 +71,7 @@ impl HugrInternals for PersistentHugr { // the whole extracted HUGR in memory. We are currently prioritizing correctness // and clarity over performance and will optimise some of these operations in // the future as bottlenecks are encountered. (see #2248) -impl HugrView for PersistentHugr { +impl HugrView for PersistentHugr { fn entrypoint(&self) -> Self::Node { // The entrypoint remains unchanged throughout the patch history, and is // found in the base hugr. @@ -111,7 +107,7 @@ impl HugrView for PersistentHugr { Some(parent_inv) } - fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType { + fn get_optype(&self, node: Self::Node) -> &OpType { self.as_state_space().get_optype(node) } @@ -179,11 +175,11 @@ impl HugrView for PersistentHugr { } else { match port.as_directed() { Either::Left(incoming) => { - let (out_node, out_port) = self.get_single_outgoing_port(node, incoming); + let (out_node, out_port) = self.single_outgoing_port(node, incoming); ret_ports.push((out_node, out_port.into())) } Either::Right(outgoing) => ret_ports.extend( - self.get_all_incoming_ports(node, outgoing) + self.all_incoming_ports(node, outgoing) .map(|(node, port)| (node, port.into())), ), } @@ -260,7 +256,7 @@ impl HugrView for PersistentHugr { // replace node labels with patch node IDs let node_labels_map: HashMap<_, _> = node_map .into_iter() - .map(|(k, v)| (v, format!("{:?}", k))) + .map(|(k, v)| (v, format!("{k:?}"))) .collect(); NodeLabel::Custom(node_labels_map) } @@ -281,12 +277,7 @@ impl HugrView for PersistentHugr { .with_port_offsets(formatter.port_offsets()) .with_type_labels(formatter.type_labels()); - hugr.graph - .mermaid_format() - .with_hierarchy(&hugr.hierarchy) - .with_node_style(render::node_style(&hugr, config.clone())) - .with_edge_style(render::edge_style(&hugr, config)) - .finish() + config.finish() } fn dot_string(&self) -> String @@ -296,8 +287,8 @@ impl HugrView for PersistentHugr { unimplemented!("use mermaid_string instead") } - fn extensions(&self) -> &crate::extension::ExtensionRegistry { - &self.base_hugr().extensions + fn extensions(&self) -> &ExtensionRegistry { + self.base_hugr().extensions() } fn extract_hugr( @@ -305,7 +296,7 @@ impl HugrView for PersistentHugr { parent: Self::Node, ) -> ( Hugr, - impl crate::hugr::views::ExtractionResult + 'static, + impl hugr::views::ExtractionResult + 'static, ) { let (hugr, apply_node_map) = self.apply_all(); let (extracted_hugr, extracted_node_map) = hugr.extract_hugr(apply_node_map[&parent]); @@ -330,7 +321,7 @@ impl HugrView for PersistentHugr { mod tests { use std::collections::HashSet; - use 
crate::hugr::persistent::{CommitStateSpace, state_space::CommitId}; + use crate::{CommitStateSpace, state_space::CommitId}; use super::super::tests::test_state_space; use super::*; diff --git a/hugr-core/src/hugr/persistent/walker.rs b/hugr-persistent/src/walker.rs similarity index 52% rename from hugr-core/src/hugr/persistent/walker.rs rename to hugr-persistent/src/walker.rs index 47123f393e..bf579398c5 100644 --- a/hugr-core/src/hugr/persistent/walker.rs +++ b/hugr-persistent/src/walker.rs @@ -44,25 +44,29 @@ //! 5. Once exploration is complete (e.g. a pattern was fully matched), the //! walker can be converted into a [`PersistentHugr`] instance using //! [`Walker::into_persistent_hugr`]. The matched nodes and ports can then be -//! used to create a [`SiblingSubgraph`](crate::hugr::views::SiblingSubgraph) -//! object, which can then be used to create new -//! [`SimpleReplacement`](crate::SimpleReplacement) instances---and possibly -//! in turn be added to the commit state space and the exploration of the -//! state space restarted! +//! used to create a +//! [`SiblingSubgraph`](hugr_core::hugr::views::SiblingSubgraph) object, +//! which can then be used to create new +//! [`SimpleReplacement`](hugr_core::SimpleReplacement) instances---and +//! possibly in turn be added to the commit state space and the exploration +//! of the state space restarted! //! //! This approach allows efficiently finding patterns across many potential //! versions of the graph simultaneously, without having to materialize //! each version separately. -mod pinned; -pub use pinned::PinnedWire; - use std::{borrow::Cow, collections::BTreeSet}; +use hugr_core::hugr::patch::simple_replace::BoundaryMode; +use hugr_core::ops::handle::DataflowParentID; use itertools::{Either, Itertools}; use thiserror::Error; -use crate::{Direction, HugrView, Port}; +use hugr_core::{Direction, Hugr, HugrView, Port, PortIndex, hugr::views::RootCheckable}; + +use crate::{Commit, PersistentReplacement, PinnedSubgraph}; + +use crate::{PersistentWire, PointerEqResolver, resolver::Resolver}; use super::{CommitStateSpace, InvalidCommit, PatchNode, PersistentHugr, state_space::CommitId}; @@ -84,30 +88,30 @@ use super::{CommitStateSpace, InvalidCommit, PatchNode, PersistentHugr, state_sp /// expansions of the current walker. /// current walker. #[derive(Debug, Clone)] -pub struct Walker<'a> { +pub struct Walker<'a, R: Clone = PointerEqResolver> { /// The state space being traversed. - state_space: Cow<'a, CommitStateSpace>, + state_space: Cow<'a, CommitStateSpace>, /// The subset of compatible commits in `state_space` that are currently /// selected. // Note that we could store this as a set of `CommitId`s, but it is very // convenient to have access to all the methods of PersistentHugr (on top // of guaranteeing the compatibility invariant). The tradeoff is more // memory consumption. - selected_commits: PersistentHugr, + selected_commits: PersistentHugr, /// The set of nodes that have been traversed by the walker and can no /// longer be rewritten. pinned_nodes: BTreeSet, } -impl<'a> Walker<'a> { +impl<'a, R: Resolver> Walker<'a, R> { /// Create a new [`Walker`] over the given state space. /// /// No nodes are pinned initially. The [`Walker`] starts with only the base /// Hugr `state_space.base_hugr()` selected. 
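As a quick orientation for the exploration workflow described in the module docs, a walker is typically seeded from a state space by pinning a single starting node. A minimal sketch, assuming the `hugr-persistent` crate name and root re-exports introduced in this patch and the default resolver; `seed_walker` and `start` are placeholder names.

    use hugr_persistent::{CommitStateSpace, PatchNode, Walker};

    // Start exploring at a known node of the base commit. Equivalent to
    // `Walker::new(&state_space)` followed by `try_pin_node(start)`.
    fn seed_walker(state_space: &CommitStateSpace, start: PatchNode) -> Walker<'_> {
        Walker::from_pinned_node(start, state_space)
    }

From the seeded walker, `get_wire` and `expand` (below) drive the rest of the exploration.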
- pub fn new(state_space: impl Into>) -> Self { + pub fn new(state_space: impl Into>>) -> Self { let state_space = state_space.into(); let base = state_space.base_commit().clone(); - let selected_commits = PersistentHugr::from_commit(base); + let selected_commits: PersistentHugr = PersistentHugr::from_commit(base); Self { state_space, selected_commits, @@ -118,7 +122,7 @@ impl<'a> Walker<'a> { /// Create a new [`Walker`] with a single pinned node. pub fn from_pinned_node( node: PatchNode, - state_space: impl Into>, + state_space: impl Into>>, ) -> Self { let mut walker = Self::new(state_space); walker @@ -146,68 +150,59 @@ impl<'a> Walker<'a> { } } else { let commit = self.state_space.get_commit(commit_id).clone(); - // TODO/Optimize: we should be able to check for an AlreadyPinned error at - // the same time that we check the ancestors are compatible in - // `PersistentHugr`, with e.g. a callback, instead of storing a backup - let backup = self.selected_commits.clone(); - self.selected_commits.try_add_commit(commit)?; - if let Some(&pinned_node) = self - .pinned_nodes - .iter() - .find(|&&n| !self.selected_commits.contains_node(n)) - { - self.selected_commits = backup; - return Err(PinNodeError::AlreadyPinned(pinned_node)); - } + self.try_select_commit(commit)?; } Ok(self.pinned_nodes.insert(node)) } - /// Get the wire connected to a specified port of a pinned node. + /// Add a commit to the selected commits of the Walker. /// - /// # Panics - /// Panics if `node` is not already pinned in this Walker. - pub fn get_wire(&self, node: PatchNode, port: impl Into) -> PinnedWire { - PinnedWire::from_pinned_port(node, port, self) - } - - /// Materialise the [`PersistentHugr`] containing all the compatible commits - /// that have been selected during exploration. - pub fn into_persistent_hugr(self) -> PersistentHugr { - self.selected_commits - } - - /// View the [`PersistentHugr`] containing all the compatible commits that - /// have been selected so far during exploration. + /// Return the ID of the added commit if it was added successfully, or the + /// existing ID of the commit if it was already selected. /// - /// Of the space of all possible HUGRs that can be obtained from future - /// expansions of the walker, this is the HUGR corresponding to selecting - /// as few commits as possible (i.e. all the commits that have been selected - /// so far and no more). - pub fn as_hugr_view(&self) -> &PersistentHugr { - &self.selected_commits + /// Return an error if the commit is not compatible with the current set of + /// selected commits, or if the commit deletes an already pinned node. + pub fn try_select_commit(&mut self, commit: Commit) -> Result { + // TODO: we should be able to check for an AlreadyPinned error at + // the same time that we check the ancestors are compatible in + // `PersistentHugr`, with e.g. a callback, instead of storing a backup + let backup = self.selected_commits.clone(); + let commit_id = self.selected_commits.try_add_commit(commit)?; + if let Some(&pinned_node) = self + .pinned_nodes + .iter() + .find(|&&n| !self.selected_commits.contains_node(n)) + { + self.selected_commits = backup; + return Err(PinNodeError::AlreadyPinned(pinned_node)); + } + Ok(commit_id) } /// Expand the Walker by pinning a node connected to the given wire. 
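The expansion API is easiest to see as a loop. A condensed sketch of the exploration pattern from the module docs, assuming the root re-exports of this patch and the default resolver; `explore` and `on_match` are placeholder names, the wire followed is hard-coded to the first output of `start` (which must already be pinned), and error handling is elided. Each recursion step pins one more node or selects one more commit, so the walk stops once the wire is complete in every alternative.

    use hugr_core::OutgoingPort;
    use hugr_persistent::{PatchNode, Walker};

    // Depth-first exploration along `start`'s output wire: expand until the
    // wire is complete in each alternative walker, then hand the walker to
    // the caller (e.g. to build a `PinnedSubgraph` from its pinned wires and
    // turn it into a new commit with `try_create_commit`).
    fn explore<'a>(
        walker: Walker<'a>,
        start: PatchNode,
        on_match: &mut impl FnMut(&Walker<'a>),
    ) {
        let wire = walker.get_wire(start, OutgoingPort::from(0));
        if walker.is_complete(&wire, None) {
            on_match(&walker);
            return;
        }
        for next in walker.expand(&wire, None) {
            explore(next, start, on_match);
        }
    }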
/// /// To understand how Walkers are expanded, it is useful to understand how /// in a walker, the HUGR graph is partitioned into two parts: - /// - a subgraph made of pinned nodes: this part of the HUGR is frozen: it cannot be - /// modified by further expansions the Walker. + /// - a subgraph made of pinned nodes: this part of the HUGR is frozen: it + /// cannot be modified by further expansions the Walker. /// - the complement subgraph: the unpinned part of the HUGR has not been - /// explored yet. Multiple alternative HUGRs can be obtained depending - /// on which commits are selected. + /// explored yet. Multiple alternative HUGRs can be obtained depending on + /// which commits are selected. /// /// To every walker thus corresponds a space of possible HUGRs that can be - /// obtained, depending on which commits are selected and which further nodes - /// are pinned. The expansion of a walker returns a set of walkers, which - /// together cover the same space of possible HUGRs, each having a different - /// additional node pinned. + /// obtained, depending on which commits are selected and which further + /// nodes are pinned. The expansion of a walker returns a set of + /// walkers, which together cover the same space of possible HUGRs, each + /// having a different additional node pinned. /// - /// Return an iterator over all possible [`Walker`]s that can be created by - /// pinning exactly one additional node connected to `wire`. Each returned - /// [`Walker`] represents a different alternative Hugr in the exploration - /// space. + /// If the wire is not complete yet, return an iterator over all possible + /// [`Walker`]s that can be created by pinning exactly one additional + /// node (or one additonal commit with an empty wire) connected to + /// `wire`. Each returned [`Walker`] represents a different alternative + /// Hugr in the exploration space. + /// + /// If the wire is already complete, return an iterator containing one + /// walker: the current walker unchanged. /// /// Optionally, the expansion can be restricted to only ports with the given /// direction (incoming or outgoing). @@ -219,78 +214,221 @@ impl<'a> Walker<'a> { /// true, then an empty iterator is returned. pub fn expand<'b>( &'b self, - wire: &'b PinnedWire, + wire: &'b PersistentWire, dir: impl Into>, - ) -> impl Iterator> + 'b { + ) -> impl Iterator> + 'b { let dir = dir.into(); + if self.is_complete(wire, dir) { + return Either::Left(std::iter::once(self.clone())); + } + // Find unpinned ports on the wire (satisfying the direction constraint) - let unpinned_ports = wire.unpinned_ports(dir); + let unpinned_ports = self.wire_unpinned_ports(wire, dir); // Obtain set of pinnable nodes by considering all ports (in descendant // commits) equivalent to currently unpinned ports. let pinnable_nodes = unpinned_ports .flat_map(|(node, port)| self.equivalent_descendant_ports(node, port)) - .map(|(n, _)| n) + .map(|(n, _, commits)| (n, commits)) .unique(); - pinnable_nodes.filter_map(|pinnable_node| { + let new_walkers = pinnable_nodes.filter_map(|(pinnable_node, new_commits)| { + let contains_new_commit = || { + new_commits + .iter() + .any(|&cm| !self.selected_commits.contains_id(cm)) + }; debug_assert!( - !self.is_pinned(pinnable_node), - "trying to pin already pinned node" + !self.is_pinned(pinnable_node) || contains_new_commit(), + "trying to pin already pinned node and no new commit is selected" ); - // Construct a new walker by pinning `pinnable_node` (if possible). 
- let mut new_walker = self.clone(); + // Update the selected commits to include the new commits. + let new_selected_commits = self + .state_space + .try_extract_hugr(self.selected_commits.all_commit_ids().chain(new_commits)) + .ok()?; + + // Make sure that the pinned nodes are still valid after including the new + // selected commits. + if self + .pinned_nodes + .iter() + .any(|&pnode| !new_selected_commits.contains_node(pnode)) + { + return None; + } + + // Construct a new walker and pin `pinnable_node`. + let mut new_walker = Walker { + state_space: self.state_space.clone(), + selected_commits: new_selected_commits, + pinned_nodes: self.pinned_nodes.clone(), + }; new_walker.try_pin_node(pinnable_node).ok()?; Some(new_walker) - }) + }); + + Either::Right(new_walkers) + } + + /// Create a new commit from a set of complete pinned wires and a + /// replacement. + /// + /// The subgraph of the commit is the subgraph given by the set of edges + /// in `wires`. `map_boundary` must provide a map from the boundary ports + /// of the subgraph to the inputs/output ports in `repl`. The returned port + /// must be of the opposite direction as the port passed as argument: + /// - an incoming subgraph port must be mapped to an outgoing port of the + /// input node of `repl` + /// - an outgoing subgraph port must be mapped to an incoming port of the + /// output node of `repl` + /// + /// ## Panics + /// + /// This will panic if repl is not a DFG graph. + pub fn try_create_commit( + &self, + subgraph: impl Into, + repl: impl RootCheckable, + map_boundary: impl Fn(PatchNode, Port) -> Port, + ) -> Result { + let pinned_subgraph = subgraph.into(); + let subgraph = pinned_subgraph.to_sibling_subgraph(self.as_hugr_view())?; + let selected_commits = pinned_subgraph + .selected_commits() + .map(|id| self.state_space.get_commit(id).clone()); + + let repl = { + let mut repl = repl.try_into_checked().expect("replacement is not DFG"); + let new_inputs = subgraph + .incoming_ports() + .iter() + .flatten() // because of singleton-vec wrapping above + .map(|&(n, p)| { + map_boundary(n, p.into()) + .as_outgoing() + .expect("unexpected port direction returned by map_boundary") + .index() + }) + .collect_vec(); + let new_outputs = subgraph + .outgoing_ports() + .iter() + .map(|&(n, p)| { + map_boundary(n, p.into()) + .as_incoming() + .expect("unexpected port direction returned by map_boundary") + .index() + }) + .collect_vec(); + repl.map_function_type(&new_inputs, &new_outputs)?; + PersistentReplacement::try_new(subgraph, self.as_hugr_view(), repl.into_hugr())? + }; + + Commit::try_new(repl, selected_commits, &self.state_space) + } +} + +impl Walker<'_, R> { + /// Get the wire connected to a specified port of a pinned node. + /// + /// # Panics + /// Panics if `node` is not already pinned in this Walker. + pub fn get_wire(&self, node: PatchNode, port: impl Into) -> PersistentWire { + assert!(self.is_pinned(node), "node must be pinned"); + self.selected_commits.get_wire(node, port) + } + + /// Materialise the [`PersistentHugr`] containing all the compatible commits + /// that have been selected during exploration. + pub fn into_persistent_hugr(self) -> PersistentHugr { + self.selected_commits + } + + /// View the [`PersistentHugr`] containing all the compatible commits that + /// have been selected so far during exploration. + /// + /// Of the space of all possible HUGRs that can be obtained from future + /// expansions of the walker, this is the HUGR corresponding to selecting + /// as few commits as possible (i.e. 
all the commits that have been selected + /// so far and no more). + pub fn as_hugr_view(&self) -> &PersistentHugr { + &self.selected_commits + } + + /// Check if a node is pinned in the [`Walker`]. + pub fn is_pinned(&self, node: PatchNode) -> bool { + self.pinned_nodes.contains(&node) + } + + /// Iterate over all pinned nodes in the [`Walker`]. + pub fn pinned_nodes(&self) -> impl Iterator + '_ { + self.pinned_nodes.iter().copied() } /// Get all equivalent ports among the commits that are descendants of the /// current commit. /// /// The ports in the returned iterator will be in the same direction as - /// `port`. - fn equivalent_descendant_ports(&self, node: PatchNode, port: Port) -> Vec<(PatchNode, Port)> { + /// `port`. For each equivalent port, also return the set of empty commits + /// that were visited to find it. + fn equivalent_descendant_ports( + &self, + node: PatchNode, + port: Port, + ) -> Vec<(PatchNode, Port, BTreeSet)> { // Now, perform a BFS to find all equivalent ports - let mut all_ports = vec![(node, port)]; + let mut all_ports = vec![(node, port, BTreeSet::new())]; let mut index = 0; while index < all_ports.len() { - let (node, port) = all_ports[index]; + let (node, port, empty_commits) = all_ports[index].clone(); index += 1; for (child_id, (opp_node, opp_port)) in self.state_space.children_at_boundary_port(node, port) { - match opp_port.as_directed() { - Either::Left(in_port) => { - if let Some((n, p)) = self - .state_space - .linked_child_output(opp_node, in_port, child_id) - { - all_ports.push((n, p.into())); - } - } - Either::Right(out_port) => { - all_ports.extend( - self.state_space - .linked_child_inputs(opp_node, out_port, child_id) - .map(|(n, p)| (n, p.into())), - ); + for (node, port) in self.state_space.linked_child_ports( + opp_node, + opp_port, + child_id, + BoundaryMode::SnapToHost, + ) { + let mut empty_commits = empty_commits.clone(); + if node.0 != child_id { + empty_commits.insert(child_id); } + all_ports.push((node, port, empty_commits)); } } } all_ports } +} + +#[cfg(test)] +impl Walker<'_, R> { + // Check walker equality by comparing pointers to the state space and + // other fields. Only for testing purposes. + fn component_wise_ptr_eq(&self, other: &Self) -> bool { + std::ptr::eq(self.state_space.as_ref(), other.state_space.as_ref()) + && self.pinned_nodes == other.pinned_nodes + && BTreeSet::from_iter(self.selected_commits.all_commit_ids()) + == BTreeSet::from_iter(other.selected_commits.all_commit_ids()) + } - fn is_pinned(&self, node: PatchNode) -> bool { - self.pinned_nodes.contains(&node) + /// Check if the Walker cannot be expanded further, i.e. expanding it + /// returns the same Walker. + fn no_more_expansion(&self, wire: &PersistentWire, dir: impl Into>) -> bool { + let Some([new_walker]) = self.expand(wire, dir).collect_array() else { + return false; + }; + new_walker.component_wise_ptr_eq(self) } } -impl CommitStateSpace { +impl CommitStateSpace { /// Given a node and port, return all child commits of the current `node` /// that delete `node` but keep at least one port linked to `(node, port)`. 
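The `equivalent_descendant_ports` search above is a plain worklist traversal: `all_ports` doubles as result and frontier, and `index` tracks how much of it has been processed (termination relies on the commit hierarchy being finite and acyclic). The sketch below shows the same scheme in a self-contained form, with `u32` ids standing in for `(PatchNode, Port)` entries and an explicit `seen` set added so it also terminates on cyclic inputs; the names `collect_reachable` and `neighbours` are illustrative, not crate API.

use std::collections::BTreeSet;

// Collect everything reachable from `start` by repeatedly expanding the next
// unprocessed entry of the worklist.
fn collect_reachable(start: u32, neighbours: impl Fn(u32) -> Vec<u32>) -> Vec<u32> {
    let mut all = vec![start];
    let mut seen = BTreeSet::from([start]);
    let mut index = 0;
    while index < all.len() {
        let current = all[index];
        index += 1;
        for next in neighbours(current) {
            // `insert` returns false for ids seen before, so each id is
            // pushed (and later expanded) at most once.
            if seen.insert(next) {
                all.push(next);
            }
        }
    }
    all
}

fn main() {
    // 0 -> {1, 2}, 1 -> {2}, 2 -> {}: collected in breadth-first order.
    let adj = |n: u32| match n {
        0 => vec![1, 2],
        1 => vec![2],
        _ => vec![],
    };
    assert_eq!(collect_reachable(0, adj), vec![0, 1, 2]);
}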
/// In other words, (node, port) is a boundary port of the subgraph of @@ -349,27 +487,37 @@ impl From for PinNodeError { } } -impl<'a> From<&'a CommitStateSpace> for Cow<'a, CommitStateSpace> { - fn from(value: &'a CommitStateSpace) -> Self { +impl<'a, R: Clone> From<&'a CommitStateSpace> for Cow<'a, CommitStateSpace> { + fn from(value: &'a CommitStateSpace) -> Self { Cow::Borrowed(value) } } -impl From for Cow<'_, CommitStateSpace> { - fn from(value: CommitStateSpace) -> Self { +impl From> for Cow<'_, CommitStateSpace> { + fn from(value: CommitStateSpace) -> Self { Cow::Owned(value) } } #[cfg(test)] mod tests { + use std::collections::BTreeSet; + + use hugr_core::{ + Direction, HugrView, IncomingPort, OutgoingPort, + builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig}, + extension::prelude::bool_t, + std_extensions::logic::LogicOp, + }; + use itertools::Itertools; use rstest::rstest; - use crate::hugr::persistent::{state_space::CommitId, tests::test_state_space}; - use crate::std_extensions::logic::LogicOp; - use crate::{IncomingPort, OutgoingPort}; - use super::*; + use crate::{ + PersistentHugr, Walker, + state_space::CommitId, + tests::{persistent_hugr_empty_child, test_state_space}, + }; #[rstest] fn test_walker_base_or_child_expansion(test_state_space: (CommitStateSpace, [CommitId; 4])) { @@ -390,7 +538,8 @@ mod tests { let in0 = walker.get_wire(base_and_node, IncomingPort::from(0)); // a single incoming port (already pinned) => no more expansion - assert!(walker.expand(&in0, Direction::Incoming).next().is_none()); + assert!(walker.no_more_expansion(&in0, Direction::Incoming)); + // commit 2 cannot be applied, because AND is pinned // => only base commit, or commit1 let out_walkers = walker.expand(&in0, Direction::Outgoing).collect_vec(); @@ -398,11 +547,11 @@ mod tests { for new_walker in out_walkers { // new wire is complete (and thus cannot be expanded) let in0 = new_walker.get_wire(base_and_node, IncomingPort::from(0)); - assert!(in0.is_complete(None)); - assert!(new_walker.expand(&in0, None).next().is_none()); + assert!(new_walker.is_complete(&in0, None)); + assert!(new_walker.no_more_expansion(&in0, None)); // all nodes on wire are pinned - let (not_node, _) = in0.pinned_outport().unwrap(); + let (not_node, _) = in0.single_outgoing_port(new_walker.as_hugr_view()).unwrap(); assert!(new_walker.is_pinned(base_and_node)); assert!(new_walker.is_pinned(not_node)); @@ -456,9 +605,8 @@ mod tests { assert!(walker.is_pinned(not4_node)); let not4_out = walker.get_wire(not4_node, OutgoingPort::from(0)); - let expanded_out = walker.expand(¬4_out, Direction::Outgoing).collect_vec(); // a single outgoing port (already pinned) => no more expansion - assert!(expanded_out.is_empty()); + assert!(walker.no_more_expansion(¬4_out, Direction::Outgoing)); // Three options: // - AND gate from base @@ -477,17 +625,20 @@ mod tests { .collect::>(); assert!( exp_options.remove(&commit_ids), - "{:?} not an expected set of commit IDs (or duplicate)", - commit_ids + "{commit_ids:?} not an expected set of commit IDs (or duplicate)" ); // new wire is complete (and thus cannot be expanded) let not4_out = new_walker.get_wire(not4_node, OutgoingPort::from(0)); - assert!(not4_out.is_complete(None)); - assert!(new_walker.expand(¬4_out, None).next().is_none()); + assert!(new_walker.is_complete(¬4_out, None)); + assert!(new_walker.no_more_expansion(¬4_out, None)); // all nodes on wire are pinned - let (next_node, _) = not4_out.pinned_inports().exactly_one().ok().unwrap(); + let (next_node, _) = not4_out + 
.all_incoming_ports(new_walker.as_hugr_view()) + .exactly_one() + .ok() + .unwrap(); assert!(new_walker.is_pinned(not4_node)); assert!(new_walker.is_pinned(next_node)); @@ -508,8 +659,7 @@ mod tests { assert!( exp_options.is_empty(), - "missing expected options: {:?}", - exp_options + "missing expected options: {exp_options:?}" ); } @@ -527,7 +677,7 @@ mod tests { let hugr = state_space.try_extract_hugr([commit4]).unwrap(); let (second_not_node, out_port) = - hugr.get_single_outgoing_port(base_and_node, IncomingPort::from(1)); + hugr.single_outgoing_port(base_and_node, IncomingPort::from(1)); assert_eq!(second_not_node.0, commit4); assert_eq!(out_port, OutgoingPort::from(0)); @@ -535,11 +685,153 @@ mod tests { .try_extract_hugr([commit1, commit2, commit4]) .unwrap(); let (new_and_node, in_port) = hugr - .get_all_incoming_ports(second_not_node, out_port) + .all_incoming_ports(second_not_node, out_port) .exactly_one() .ok() .unwrap(); assert_eq!(new_and_node.0, commit2); assert_eq!(in_port, 1.into()); } + + /// Test that the walker handles empty replacements correctly. + /// + /// The base hugr is a sequence of 3 NOT gates, with a single input/output + /// boolean. A single replacement exists in the state space, which replaces + /// the middle NOT gate with nothing. + #[rstest] + fn test_walk_over_empty_repls( + persistent_hugr_empty_child: (PersistentHugr, [CommitId; 2], [PatchNode; 3]), + ) { + let (hugr, [base_commit, empty_commit], [not0, not1, not2]) = persistent_hugr_empty_child; + let walker = Walker::from_pinned_node(not0, hugr.as_state_space()); + + let not0_outwire = walker.get_wire(not0, OutgoingPort::from(0)); + let expanded_wires = walker + .expand(¬0_outwire, Direction::Incoming) + .collect_vec(); + + assert_eq!(expanded_wires.len(), 2); + + let connected_inports: BTreeSet<_> = expanded_wires + .iter() + .map(|new_walker| { + let wire = new_walker.get_wire(not0, OutgoingPort::from(0)); + wire.all_incoming_ports(new_walker.as_hugr_view()) + .exactly_one() + .ok() + .unwrap() + }) + .collect(); + + assert_eq!( + connected_inports, + BTreeSet::from_iter([(not1, IncomingPort::from(0)), (not2, IncomingPort::from(0))]) + ); + + let traversed_commits: BTreeSet> = expanded_wires + .iter() + .map(|new_walker| { + let wire = new_walker.get_wire(not0, OutgoingPort::from(0)); + wire.owners().collect() + }) + .collect(); + + assert_eq!( + traversed_commits, + BTreeSet::from_iter([ + BTreeSet::from_iter([base_commit]), + BTreeSet::from_iter([base_commit, empty_commit]) + ]) + ); + } + + #[rstest] + fn test_create_commit_over_empty( + persistent_hugr_empty_child: (PersistentHugr, [CommitId; 2], [PatchNode; 3]), + ) { + let (hugr, [base_commit, empty_commit], [not0, _not1, not2]) = persistent_hugr_empty_child; + let mut walker = Walker { + state_space: hugr.as_state_space().into(), + selected_commits: hugr.clone(), + pinned_nodes: BTreeSet::from_iter([not0]), + }; + + // wire: Not0 -> Not2 (bridging over Not1) + let wire = walker.get_wire(not0, OutgoingPort::from(0)); + walker = walker.expand(&wire, None).exactly_one().ok().unwrap(); + let wire = walker.get_wire(not0, OutgoingPort::from(0)); + assert!(walker.is_complete(&wire, None)); + + let empty_hugr = { + let dfg_builder = DFGBuilder::new(endo_sig(bool_t())).unwrap(); + let inputs = dfg_builder.input_wires(); + dfg_builder.finish_hugr_with_outputs(inputs).unwrap() + }; + let commit = walker + .try_create_commit( + PinnedSubgraph::try_from_pinned(std::iter::empty(), [wire], &walker).unwrap(), + empty_hugr, + |node, port| { + 
assert_eq!(port.index(), 0); + assert!([not0, not2].contains(&node)); + match port.direction() { + Direction::Incoming => OutgoingPort::from(0).into(), + Direction::Outgoing => IncomingPort::from(0).into(), + } + }, + ) + .unwrap(); + + let mut new_state_space = hugr.as_state_space().to_owned(); + let commit_id = new_state_space.try_add_commit(commit.clone()).unwrap(); + assert_eq!( + new_state_space.parents(commit_id).collect::>(), + BTreeSet::from_iter([base_commit, empty_commit]) + ); + + let res_hugr: PersistentHugr = PersistentHugr::from_commit(commit); + assert!(res_hugr.validate().is_ok()); + + // should be an empty DFG hugr + // module root + function def + func I/O nodes + DFG entrypoint + I/O nodes + assert_eq!(res_hugr.num_nodes(), 1 + 1 + 2 + 1 + 2); + } + + /// Test that the walker handles empty replacements correctly. + /// + /// The base hugr is a sequence of 3 NOT gates, with a single input/output + /// boolean. A single replacement exists in the state space, which replaces + /// the middle NOT gate with nothing. + /// + /// In this test, we pin both the first and third NOT and see if the walker + /// suggests to possible wires as outgoing from the first NOT. This tests + /// the edge case in which a new wire already has all its ports pinned. + #[rstest] + fn test_walk_over_two_pinned_nodes( + persistent_hugr_empty_child: (PersistentHugr, [CommitId; 2], [PatchNode; 3]), + ) { + let (hugr, [base_commit, empty_commit], [not0, _not1, not2]) = persistent_hugr_empty_child; + let mut walker = Walker::from_pinned_node(not0, hugr.as_state_space()); + assert!(walker.try_pin_node(not2).unwrap()); + + let not0_outwire = walker.get_wire(not0, OutgoingPort::from(0)); + let expanded_walkers = walker.expand(¬0_outwire, Direction::Incoming); + + let expanded_wires: BTreeSet> = expanded_walkers + .map(|new_walker| { + new_walker + .get_wire(not0, OutgoingPort::from(0)) + .owners() + .collect() + }) + .collect(); + + assert_eq!( + expanded_wires, + BTreeSet::from_iter([ + BTreeSet::from_iter([base_commit]), + BTreeSet::from_iter([base_commit, empty_commit]) + ]) + ); + } } diff --git a/hugr-persistent/src/wire.rs b/hugr-persistent/src/wire.rs new file mode 100644 index 0000000000..a84d4e6923 --- /dev/null +++ b/hugr-persistent/src/wire.rs @@ -0,0 +1,303 @@ +use std::collections::{BTreeSet, VecDeque}; + +use hugr_core::{ + Direction, HugrView, IncomingPort, OutgoingPort, Port, Wire, + hugr::patch::simple_replace::BoundaryMode, +}; +use itertools::Itertools; + +use crate::{CommitId, PatchNode, PersistentHugr, Resolver, Walker}; + +/// A wire in a [`PersistentHugr`]. +/// +/// A wire may be composed of multiple wires in the underlying commits +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct PersistentWire { + wires: BTreeSet, +} + +/// A wire within a commit HUGR of a [`PersistentHugr`]. +/// +/// Also stores the ID of the commit that contains the wire; +/// equivalent to (indeed contains) a `Wire`. +/// +/// Note that it does not correspond to a valid wire in a [`PersistentHugr`] +/// (see [`PersistentWire`]): some of its connected ports may be on deleted or +/// IO nodes that are not valid in the [`PersistentHugr`]. 
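As the note above says, a commit-local wire can touch nodes that are no longer valid in the selected HUGR (deleted nodes, or the input/output nodes of a replacement graph), so wire queries filter such ports out (see `NodeStatus` and `all_ports_impl` further down in this file). Below is a toy, self-contained sketch of that filtering step; the `Status` enum, `valid_ports`, and the `u32` port ids are stand-ins, not the types defined here.

// Toy stand-ins for the node classification used when querying a wire.
#[derive(Clone, Copy, PartialEq, Eq)]
enum Status {
    Valid,
    Deleted,
    ReplacementIo,
}

// Keep only the ports whose node is still a valid node of the selected HUGR.
fn valid_ports(ports: Vec<(u32, Status)>) -> Vec<u32> {
    ports
        .into_iter()
        .filter(|&(_, status)| status == Status::Valid)
        .map(|(port, _)| port)
        .collect()
}

fn main() {
    let ports = vec![
        (0, Status::Valid),
        (1, Status::Deleted),
        (2, Status::ReplacementIo),
        (3, Status::Valid),
    ];
    assert_eq!(valid_ports(ports), vec![0, 3]);
}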
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct CommitWire(Wire); + +impl CommitWire { + fn from_connected_port( + PatchNode(commit_id, node): PatchNode, + port: impl Into, + hugr: &PersistentHugr, + ) -> Self { + let commit_hugr = hugr.commit_hugr(commit_id); + let wire = Wire::from_connected_port(node, port, commit_hugr); + Self(Wire::new(PatchNode(commit_id, wire.node()), wire.source())) + } + + fn all_connected_ports<'h, R>( + &self, + hugr: &'h PersistentHugr, + ) -> impl Iterator + use<'h, R> { + let wire = Wire::new(self.0.node().1, self.0.source()); + let commit_id = self.commit_id(); + wire.all_connected_ports(hugr.commit_hugr(commit_id)) + .map(move |(node, port)| (hugr.to_persistent_node(node, commit_id), port)) + } + + fn commit_id(&self) -> CommitId { + self.0.node().0 + } + + delegate::delegate! { + to self.0 { + fn node(&self) -> PatchNode; + } + } +} + +/// A node in a commit of a [`PersistentHugr`] is either a valid node of the +/// HUGR, a node deleted by a child commit in that [`PersistentHugr`], or an +/// input or output node in a replacement graph. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +enum NodeStatus { + /// A node deleted by a child commit in that [`PersistentHugr`]. + /// + /// The ID of the child commit is stored in the variant. + Deleted(CommitId), + /// An input or output node in the replacement graph of a Commit + ReplacementIO, + /// A valid node in the [`PersistentHugr`] + Valid, +} + +impl PersistentHugr { + pub fn get_wire(&self, node: PatchNode, port: impl Into) -> PersistentWire { + PersistentWire::from_port(node, port, self) + } + + /// Whether a node is valid in `self`, is deleted or is an IO node in a + /// replacement graph. + fn node_status(&self, per_node @ PatchNode(commit_id, node): PatchNode) -> NodeStatus { + debug_assert!(self.contains_id(commit_id), "unknown commit"); + if self + .replacement(commit_id) + .is_some_and(|repl| repl.get_replacement_io().contains(&node)) + { + NodeStatus::ReplacementIO + } else if let Some(commit_id) = self.find_deleting_commit(per_node) { + NodeStatus::Deleted(commit_id) + } else { + NodeStatus::Valid + } + } +} + +impl PersistentWire { + /// Get the wire connected to a specified port of a pinned node in `hugr`. + fn from_port(node: PatchNode, port: impl Into, per_hugr: &PersistentHugr) -> Self { + assert!(per_hugr.contains_node(node), "node not in hugr"); + + // Queue of wires within each commit HUGR, that combined will form the + // persistent wire. + let mut commit_wires = + BTreeSet::from_iter([CommitWire::from_connected_port(node, port, per_hugr)]); + let mut queue = VecDeque::from_iter(commit_wires.iter().copied()); + + while let Some(wire) = queue.pop_front() { + let commit_id = wire.commit_id(); + let commit_hugr = per_hugr.commit_hugr(commit_id); + let all_ports = wire.all_connected_ports(per_hugr); + + for (per_node @ PatchNode(_, node), port) in all_ports { + match per_hugr.node_status(per_node) { + NodeStatus::Deleted(deleted_by) => { + // If node is deleted, check if there are wires between + // ports on the opposite end of the wire and boundary + // ports in the child commit that deleted the node. 
+ for (opp_node, opp_port) in commit_hugr.linked_ports(node, port) { + let opp_node = per_hugr.to_persistent_node(opp_node, commit_id); + for (child_node, child_port) in + per_hugr.as_state_space().linked_child_ports( + opp_node, + opp_port, + deleted_by, + BoundaryMode::IncludeIO, + ) + { + debug_assert_eq!(child_node.owner(), deleted_by); + let w = CommitWire::from_connected_port( + child_node, child_port, per_hugr, + ); + if commit_wires.insert(w) { + queue.push_back(w); + } + } + } + } + NodeStatus::ReplacementIO => { + // If node is an input (resp. output) node in a replacement graph, there + // must be (at least) one wire between the incoming (resp. outgoing) + // boundary ports of the commit (i.e. the ports connected to + // the input resp. output) and ports in a parent commit. + for (opp_node, opp_port) in commit_hugr.linked_ports(node, port) { + let opp_node = per_hugr.to_persistent_node(opp_node, commit_id); + for (parent_node, parent_port) in per_hugr + .as_state_space() + .linked_parent_ports(opp_node, opp_port) + { + let w = CommitWire::from_connected_port( + parent_node, + parent_port, + per_hugr, + ); + if commit_wires.insert(w) { + queue.push_back(w); + } + } + } + } + NodeStatus::Valid => {} + } + } + } + + Self { + wires: commit_wires, + } + } + + /// Get all ports attached to a wire in `hugr`. + /// + /// All ports returned are on nodes that are contained in `hugr`. + pub fn all_ports( + &self, + hugr: &PersistentHugr, + dir: impl Into>, + ) -> impl Iterator { + all_ports_impl(self.wires.iter().copied(), dir.into(), hugr) + } + + /// All commit IDs that the wire traverses. + pub fn owners(&self) -> impl Iterator { + self.wires.iter().map(|w| w.node().owner()).unique() + } + + /// Consume the wire and return all ports attached to a wire in `hugr`. + /// + /// All ports returned are on nodes that are contained in `hugr`. + pub fn into_all_ports( + self, + hugr: &PersistentHugr, + dir: impl Into>, + ) -> impl Iterator { + all_ports_impl(self.wires.into_iter(), dir.into(), hugr) + } + + pub fn single_outgoing_port( + &self, + hugr: &PersistentHugr, + ) -> Option<(PatchNode, OutgoingPort)> { + single_outgoing(self.all_ports(hugr, Direction::Outgoing)) + } + + pub fn all_incoming_ports( + &self, + hugr: &PersistentHugr, + ) -> impl Iterator { + self.all_ports(hugr, Direction::Incoming) + .map(|(node, port)| (node, port.as_incoming().unwrap())) + } +} + +impl Walker<'_, R> { + /// Get all ports on a wire that are not pinned in `self`. + pub(crate) fn wire_unpinned_ports( + &self, + wire: &PersistentWire, + dir: impl Into>, + ) -> impl Iterator { + let ports = wire.all_ports(self.as_hugr_view(), dir); + ports.filter(|(node, _)| !self.is_pinned(*node)) + } + + /// Get the ports of the wire that are on pinned nodes of `self`. + pub fn wire_pinned_ports( + &self, + wire: &PersistentWire, + dir: impl Into>, + ) -> impl Iterator { + let ports = wire.all_ports(self.as_hugr_view(), dir); + ports.filter(|(node, _)| self.is_pinned(*node)) + } + + /// Get the outgoing port of a wire if it is pinned in `walker`. + pub fn wire_pinned_outport(&self, wire: &PersistentWire) -> Option<(PatchNode, OutgoingPort)> { + single_outgoing(self.wire_pinned_ports(wire, Direction::Outgoing)) + } + + /// Get all pinned incoming ports of a wire. 
+ pub fn wire_pinned_inports( + &self, + wire: &PersistentWire, + ) -> impl Iterator { + self.wire_pinned_ports(wire, Direction::Incoming) + .map(|(node, port)| (node, port.as_incoming().expect("incoming port"))) + } + + /// Whether a wire is complete in the specified direction, i.e. there are no + /// unpinned ports left. + pub fn is_complete(&self, wire: &PersistentWire, dir: impl Into>) -> bool { + self.wire_unpinned_ports(wire, dir).next().is_none() + } +} + +/// Implementation of the (shared) body of [`PersistentWire::all_ports`] and +/// [`PersistentWire::into_all_ports`]. +fn all_ports_impl( + wires: impl Iterator, + dir: Option, + per_hugr: &PersistentHugr, +) -> impl Iterator { + let all_ports = wires.flat_map(move |w| w.all_connected_ports(per_hugr)); + + // Filter out invalid and wrong direction ports + all_ports + .filter(move |(_, port)| dir.is_none_or(|dir| port.direction() == dir)) + .filter(|&(node, _)| per_hugr.node_status(node) == NodeStatus::Valid) +} + +fn single_outgoing(iter: impl Iterator) -> Option<(N, OutgoingPort)> { + let (node, port) = iter.exactly_one().ok()?; + Some((node, port.as_outgoing().ok()?)) +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use crate::{CommitId, CommitStateSpace, PatchNode, tests::test_state_space}; + use hugr_core::{HugrView, OutgoingPort}; + use itertools::Itertools; + use rstest::rstest; + + #[rstest] + fn test_all_ports(test_state_space: (CommitStateSpace, [CommitId; 4])) { + let (state_space, [_, _, cm3, cm4]) = test_state_space; + let hugr = state_space.try_extract_hugr([cm3, cm4]).unwrap(); + let cm4_not = { + let hugr4 = state_space.commit_hugr(cm4); + let out = state_space.replacement(cm4).unwrap().get_replacement_io()[1]; + let node = hugr4.input_neighbours(out).exactly_one().ok().unwrap(); + PatchNode(cm4, node) + }; + let w = hugr.get_wire(cm4_not, OutgoingPort::from(0)); + assert_eq!( + BTreeSet::from_iter(w.wires.iter().map(|w| w.0.node().0)), + BTreeSet::from_iter([cm3, cm4, state_space.base(),]) + ); + } +} diff --git a/hugr-core/tests/persistent_walker_example.rs b/hugr-persistent/tests/persistent_walker_example.rs similarity index 62% rename from hugr-core/tests/persistent_walker_example.rs rename to hugr-persistent/tests/persistent_walker_example.rs index 8da20df657..19a02bac6a 100644 --- a/hugr-core/tests/persistent_walker_example.rs +++ b/hugr-persistent/tests/persistent_walker_example.rs @@ -2,42 +2,37 @@ use std::collections::{BTreeSet, VecDeque}; -use hugr::types::EdgeKind; -use itertools::Itertools; +use itertools::{Either, Itertools}; use hugr_core::{ - Hugr, HugrView, PortIndex, SimpleReplacement, + Hugr, HugrView, IncomingPort, OutgoingPort, Port, PortIndex, builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig}, extension::prelude::qb_t, - hugr::{ - persistent::{CommitStateSpace, PersistentReplacement, PinnedWire, Walker}, - views::SiblingSubgraph, - }, + ops::OpType, + types::EdgeKind, }; +use hugr_persistent::{Commit, CommitStateSpace, PersistentWire, PinnedSubgraph, Walker}; + /// The maximum commit depth that we will consider in this example -const MAX_COMMITS: usize = 2; +const MAX_COMMITS: usize = 4; // We define a HUGR extension within this file, with CZ and H gates. Normally, // you would use an existing extension (e.g. as provided by tket2). 
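The `single_outgoing` helper defined earlier in `wire.rs` combines two checks: the iterator must yield exactly one port, and that port must convert to an outgoing port. A small, self-contained sketch of that "exactly one item, then a further conversion" shape using `Itertools::exactly_one`; the function name and the evenness check are placeholders, not the wire API.

use itertools::Itertools;

// Return the single item of the input if there is exactly one and it passes
// the extra check; otherwise return None.
fn exactly_one_even(nums: impl IntoIterator<Item = u32>) -> Option<u32> {
    let n = nums.into_iter().exactly_one().ok()?;
    (n % 2 == 0).then_some(n)
}

fn main() {
    assert_eq!(exactly_one_even([4]), Some(4));
    assert_eq!(exactly_one_even([3]), None); // exactly one item, but it fails the check
    assert_eq!(exactly_one_even([2, 4]), None); // more than one item
}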
-use walker_example_extension::{cz_gate, h_gate}; +use walker_example_extension::cz_gate; mod walker_example_extension { use std::sync::Arc; - use hugr::Extension; - use hugr::extension::ExtensionId; - use hugr::ops::{ExtensionOp, OpName}; - use hugr::types::{FuncValueType, PolyFuncTypeRV}; + use hugr_core::Extension; + use hugr_core::extension::ExtensionId; + use hugr_core::ops::{ExtensionOp, OpName}; + use hugr_core::types::{FuncValueType, PolyFuncTypeRV}; use lazy_static::lazy_static; use semver::Version; use super::*; - fn one_qb_func() -> PolyFuncTypeRV { - FuncValueType::new_endo(qb_t()).into() - } - fn two_qb_func() -> PolyFuncTypeRV { FuncValueType::new_endo(vec![qb_t(), qb_t()]).into() } @@ -49,15 +44,6 @@ mod walker_example_extension { EXTENSION_ID, Version::new(0, 0, 0), |extension, extension_ref| { - extension - .add_op( - OpName::new_inline("H"), - "Hadamard".into(), - one_qb_func(), - extension_ref, - ) - .unwrap(); - extension .add_op( OpName::new_inline("CZ"), @@ -75,10 +61,6 @@ mod walker_example_extension { static ref EXTENSION: Arc = extension(); } - pub fn h_gate() -> ExtensionOp { - EXTENSION.instantiate_extension_op("H", []).unwrap() - } - pub fn cz_gate() -> ExtensionOp { EXTENSION.instantiate_extension_op("CZ", []).unwrap() } @@ -109,15 +91,12 @@ fn dfg_hugr() -> Hugr { builder.finish_hugr_with_outputs(vec![q0, q1, q2]).unwrap() } -// TODO: currently empty replacements are buggy, so we have temporarily added -// a single Hadamard gate on each qubit. -fn empty_2qb_hugr() -> Hugr { - let mut builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()])).unwrap(); - let [q0, q1] = builder.input_wires_arr(); - let h0 = builder.add_dataflow_op(h_gate(), vec![q0]).unwrap(); - let [q0] = h0.outputs_arr(); - let h1 = builder.add_dataflow_op(h_gate(), vec![q1]).unwrap(); - let [q1] = h1.outputs_arr(); +fn empty_2qb_hugr(flip_args: bool) -> Hugr { + let builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()])).unwrap(); + let [mut q0, mut q1] = builder.input_wires_arr(); + if flip_args { + (q0, q1) = (q1, q0); + } builder.finish_hugr_with_outputs(vec![q0, q1]).unwrap() } @@ -134,7 +113,7 @@ fn two_cz_3qb_hugr() -> Hugr { /// Traverse all commits in state space, enqueueing all outgoing wires of /// CZ nodes fn enqueue_all( - queue: &mut VecDeque<(PinnedWire, Walker<'static>)>, + queue: &mut VecDeque<(PersistentWire, Walker<'static>)>, state_space: &CommitStateSpace, ) { for id in state_space.all_commit_ids() { @@ -170,10 +149,10 @@ fn build_state_space() -> CommitStateSpace { enqueue_all(&mut wire_queue, &state_space); while let Some((wire, walker)) = wire_queue.pop_front() { - if !wire.is_complete(None) { + if !walker.is_complete(&wire, None) { // expand the wire in all possible ways - let (pinned_node, pinned_port) = wire - .all_pinned_ports() + let (pinned_node, pinned_port) = walker + .wire_pinned_ports(&wire, None) .next() .expect("at least one port was already pinned"); assert!( @@ -183,7 +162,7 @@ fn build_state_space() -> CommitStateSpace { for subwalker in walker.expand(&wire, None) { assert!( subwalker.as_hugr_view().contains_node(pinned_node), - "pinned node is deleted" + "pinned node {pinned_node:?} is deleted", ); wire_queue.push_back((subwalker.get_wire(pinned_node, pinned_port), subwalker)); } @@ -191,7 +170,10 @@ fn build_state_space() -> CommitStateSpace { // we have a complete wire, so we can commute the CZ gates (or // cancel them out) - let patch_nodes: BTreeSet<_> = wire.all_pinned_ports().map(|(n, _)| n).collect(); + let patch_nodes: BTreeSet<_> = walker + 
.wire_pinned_ports(&wire, None) + .map(|(n, _)| n) + .collect(); // check that the patch applies to more than one commit (or the base), // otherwise we have infinite commutations back and forth let patch_owners: BTreeSet<_> = patch_nodes.iter().map(|n| n.0).collect(); @@ -204,22 +186,16 @@ fn build_state_space() -> CommitStateSpace { continue; } - let Some(repl) = create_replacement(wire, &walker) else { + let Some(new_commit) = create_commit(wire, &walker) else { continue; }; assert_eq!( - repl.subgraph() - .nodes() - .iter() - .copied() - .collect::>(), + new_commit.deleted_nodes().collect::>(), patch_nodes ); - state_space - .try_add_replacement(repl) - .expect("repl acts on non-empty subgraph"); + state_space.try_add_commit(new_commit).unwrap(); // enqueue new wires added by the replacement // (this will also add a lot of already visited wires, but they will @@ -231,14 +207,14 @@ fn build_state_space() -> CommitStateSpace { state_space } -fn create_replacement(wire: PinnedWire, walker: &Walker) -> Option { +fn create_commit(wire: PersistentWire, walker: &Walker) -> Option { let hugr = walker.clone().into_persistent_hugr(); let (out_node, _) = wire - .pinned_outport() + .single_outgoing_port(&hugr) .expect("outgoing port was already pinned (and is unique)"); let (in_node, _) = wire - .pinned_inports() + .all_incoming_ports(&hugr) .exactly_one() .ok() .expect("all our wires have exactly one incoming port"); @@ -256,13 +232,30 @@ fn create_replacement(wire: PinnedWire, walker: &Walker) -> Option { // out_node and in_node act on the same qubits - // => cancel out the two CZ gates - ( - empty_2qb_hugr(), - SiblingSubgraph::try_from_nodes([out_node, in_node], &hugr).ok()?, + // => replace the two CZ gates with the empty 2qb HUGR + + // If the two CZ gates have flipped port ordering, we need to insert + // a swap gate + let add_swap = all_edges[0][0].index() != all_edges[0][1].index(); + + // Get the wires between the two CZ gates + let wires = all_edges + .into_iter() + .map(|[out_port, _]| walker.get_wire(out_node, out_port)); + + // Create the commit + walker.try_create_commit( + PinnedSubgraph::try_from_wires(wires, walker).unwrap(), + empty_2qb_hugr(add_swap), + |_, port| { + // the incoming/outgoing ports of the subgraph map trivially to the empty 2qb + // HUGR + let dir = port.direction(); + Port::new(dir.reverse(), port.index()) + }, ) } 1 => { @@ -273,32 +266,53 @@ fn create_replacement(wire: PinnedWire, walker: &Walker) -> Option establish which qubit is shared between the two CZ gates let [out_port, in_port] = all_edges.into_iter().exactly_one().unwrap(); - let shared_qb_on_out_node = out_port.index(); - let shared_qb_on_in_node = in_port.index(); - - let subgraph = SiblingSubgraph::try_new( - vec![ - vec![(out_node, shared_qb_on_out_node.into())], - vec![(out_node, (1 - shared_qb_on_out_node).into())], - vec![(in_node, (1 - shared_qb_on_in_node).into())], - ], - vec![ - (in_node, shared_qb_on_in_node.into()), - (out_node, (1 - shared_qb_on_out_node).into()), - (in_node, (1 - shared_qb_on_in_node).into()), - ], - &hugr, + let shared_qb_out = out_port.index(); + let shared_qb_in = in_port.index(); + + walker.try_create_commit( + PinnedSubgraph::try_from_wires([wire], walker).unwrap(), + repl_hugr, + |node, port| { + // map the incoming/outgoing ports of the subgraph to the replacement as + // follows: + // - the first qubit is the one that is shared between the two CZ gates + // - the second qubit only touches the first CZ (out_node) + // - the third qubit only touches the second CZ 
(in_node) + match port.as_directed() { + Either::Left(incoming) => { + let in_boundary: [(_, IncomingPort); 3] = [ + (out_node, shared_qb_out.into()), + (out_node, (1 - shared_qb_out).into()), + (in_node, (1 - shared_qb_in).into()), + ]; + let out_index = in_boundary + .iter() + .position(|&(n, p)| n == node && p == incoming) + .expect("invalid input port"); + OutgoingPort::from(out_index).into() + } + Either::Right(outgoing) => { + let out_boundary: [(_, OutgoingPort); 3] = [ + (in_node, shared_qb_in.into()), + (out_node, (1 - shared_qb_out).into()), + (in_node, (1 - shared_qb_in).into()), + ]; + let in_index = out_boundary + .iter() + .position(|&(n, p)| n == node && p == outgoing) + .expect("invalid output port"); + IncomingPort::from(in_index).into() + } + } + }, ) - .ok()?; - - (repl_hugr, subgraph) } _ => unreachable!(), - }; - - SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).ok() + } + .ok() } +#[ignore = "takes 10s (todo: optimise)"] #[test] fn walker_example() { let state_space = build_state_space(); @@ -324,26 +338,16 @@ fn walker_example() { ); } - // assert_eq!(state_space.all_commit_ids().count(), 13); - let empty_commits = state_space .all_commit_ids() - // .filter(|&id| state_space.commit_hugr(id).num_nodes() == 3) - .filter(|&id| { - state_space - .inserted_nodes(id) - .filter(|&n| state_space.get_optype(n) == &h_gate().into()) - .count() - == 2 - }) + .filter(|&id| state_space.inserted_nodes(id).count() == 0) .collect_vec(); // there should be a combination of three empty commits that are compatible // and such that the resulting HUGR is empty let mut empty_hugr = None; - // for cs in empty_commits.iter().combinations(3) { - for cs in empty_commits.iter().combinations(2) { - let cs = cs.into_iter().copied().collect_vec(); + for cs in empty_commits.iter().combinations(3) { + let cs = cs.into_iter().copied(); if let Ok(hugr) = state_space.try_extract_hugr(cs) { empty_hugr = Some(hugr); } @@ -351,16 +355,23 @@ fn walker_example() { let empty_hugr = empty_hugr.unwrap().to_hugr(); - // assert_eq!(empty_hugr.num_nodes(), 3); - - let n_cz = empty_hugr - .nodes() - .filter(|&n| empty_hugr.get_optype(n) == &cz_gate().into()) - .count(); - let n_h = empty_hugr - .nodes() - .filter(|&n| empty_hugr.get_optype(n) == &h_gate().into()) - .count(); - assert_eq!(n_cz, 2); - assert_eq!(n_h, 4); + // The empty hugr should have 7 nodes: + // module root, funcdef, 2 func IO, DFG root, 2 DFG IO + assert_eq!(empty_hugr.num_nodes(), 7); + assert_eq!( + empty_hugr + .nodes() + .filter(|&n| { + !matches!( + empty_hugr.get_optype(n), + OpType::Input(_) + | OpType::Output(_) + | OpType::FuncDefn(_) + | OpType::Module(_) + | OpType::DFG(_) + ) + }) + .count(), + 0 + ); } diff --git a/hugr-py/CHANGELOG.md b/hugr-py/CHANGELOG.md index 7cce076974..9400909bb1 100644 --- a/hugr-py/CHANGELOG.md +++ b/hugr-py/CHANGELOG.md @@ -1,5 +1,108 @@ # Changelog +## [0.13.0rc1](https://github.com/CQCL/hugr/compare/hugr-py-v0.12.5...hugr-py-v0.13.0rc1) (2025-07-24) + + +### ⚠ BREAKING CHANGES + +* Lowering functions in extension operations are now encoded as binary envelopes. Older hugr versions will error out when trying to load them. +* **py:** `EnvelopeConfig::BINARY` now uses the model binary encoding. `EnvelopeFormat.MODULE` is now `EnvelopeFormat.MODEL`. `EnvelopeFormat.MODULE_WITH_EXTS` is now `EnvelopeFormat.MODEL_WITH_EXTS` +* hugr-model: Symbol has an extra field +* Renamed the `Any` type bound to `Linear` +* The model CFG signature types were changed. 
+* Added `TypeParam`s and `TypeArg`s corresponding to floats and bytes. +* `TypeArg::Sequence` needs to be replaced with +* FuncDefns must be moved to beneath Module. `Container::define_function` is gone, use `HugrBuilder::module_root_builder`; similarly in hugr-py `DefinitionBuilder` (`define_function` -> `module_root_builder().define_function`). In hugr-llvm, some uses of + +### Features + +* Add `BorrowArray` extension ([#2395](https://github.com/CQCL/hugr/issues/2395)) ([782687e](https://github.com/CQCL/hugr/commit/782687ed917c3e4295c2c3c59a17d784fc6f932d)) +* Add `MakeError` op ([#2377](https://github.com/CQCL/hugr/issues/2377)) ([909a794](https://github.com/CQCL/hugr/commit/909a7948c1465aab5528895bdee0e49958a416b6)), closes [#1863](https://github.com/CQCL/hugr/issues/1863) +* add toposort to HUGR-py ([#2367](https://github.com/CQCL/hugr/issues/2367)) ([34eed34](https://github.com/CQCL/hugr/commit/34eed3422c9aa34bd6b8ad868dcbab733eb5d14c)) +* Add Visibility to FuncDefn/FuncDecl. ([#2143](https://github.com/CQCL/hugr/issues/2143)) ([5bbe0cd](https://github.com/CQCL/hugr/commit/5bbe0cdc60625b4047f0cddc9598d6652ed6f736)) +* Added float and bytes literal to core and python bindings. ([#2289](https://github.com/CQCL/hugr/issues/2289)) ([e9c5e91](https://github.com/CQCL/hugr/commit/e9c5e914d4fd9ee270dee8e43875d8a413b02926)) +* **core, llvm:** add array unpack operations ([#2339](https://github.com/CQCL/hugr/issues/2339)) ([a1a70f1](https://github.com/CQCL/hugr/commit/a1a70f1afb5d8d57082269d167816c7a90497dcf)), closes [#1947](https://github.com/CQCL/hugr/issues/1947) +* Detect and fail on unrecognised envelope flags ([#2453](https://github.com/CQCL/hugr/issues/2453)) ([5e36770](https://github.com/CQCL/hugr/commit/5e36770895b79e878c1bbdf22e67e8cbff6513b6)) +* Export entrypoint metadata in Python and fix bug in import ([#2434](https://github.com/CQCL/hugr/issues/2434)) ([d17b245](https://github.com/CQCL/hugr/commit/d17b245c41d943da1c338094c31a75b55efe4061)) +* Expose `BorrowArray` in `hugr-py` ([#2425](https://github.com/CQCL/hugr/issues/2425)) ([fdb675f](https://github.com/CQCL/hugr/commit/fdb675f1473a9bf349fce0824c56539e239c11f3)), closes [#2406](https://github.com/CQCL/hugr/issues/2406) +* include generator metatada in model import and cli validate errors ([#2452](https://github.com/CQCL/hugr/issues/2452)) ([f7cedb4](https://github.com/CQCL/hugr/commit/f7cedb4f39b67a77b4c6a55ec00b624b54668eaa)) +* Names of private functions become `core.title` metadata. ([#2448](https://github.com/CQCL/hugr/issues/2448)) ([4bc7f65](https://github.com/CQCL/hugr/commit/4bc7f65338d9a8b37d3a5625aeaf093970d97926)) +* No nested FuncDefns (or AliasDefns) ([#2256](https://github.com/CQCL/hugr/issues/2256)) ([214b8df](https://github.com/CQCL/hugr/commit/214b8df837537b8ac15c3b60845350c3818a6ac7)) +* Non-region entrypoints in `hugr-model`. 
([#2467](https://github.com/CQCL/hugr/issues/2467)) ([7b42da6](https://github.com/CQCL/hugr/commit/7b42da6f62de9fe36187512dba428fe3db8d6120)) +* Open lists and tuples in `Term` ([#2360](https://github.com/CQCL/hugr/issues/2360)) ([292af80](https://github.com/CQCL/hugr/commit/292af8010dba6b4c2ea5bb69edae31cbf1e0cb6a)) +* **py:** enable Model as default BINARY envelope format ([#2317](https://github.com/CQCL/hugr/issues/2317)) ([f089931](https://github.com/CQCL/hugr/commit/f08993124e48093c2328096a93cec8a9ad67a41c)) +* **py:** Helper methods to get the neighbours of a node ([#2370](https://github.com/CQCL/hugr/issues/2370)) ([bb6fa50](https://github.com/CQCL/hugr/commit/bb6fa50957ac5121bebc78a06335262a6559e695)) +* **py:** Use SumValue serialization for tuples ([#2466](https://github.com/CQCL/hugr/issues/2466)) ([f615037](https://github.com/CQCL/hugr/commit/f615037621aa0eeb37de8f1126fa9020705cb565)) +* Rename 'Any' type bound to 'Linear' ([#2421](https://github.com/CQCL/hugr/issues/2421)) ([c2f8b30](https://github.com/CQCL/hugr/commit/c2f8b30afd3a1b75f6babe77a90b13211e45e3a7)) +* Split `TypeArg::Sequence` into tuples and lists. ([#2140](https://github.com/CQCL/hugr/issues/2140)) ([cc4997f](https://github.com/CQCL/hugr/commit/cc4997f12dad4dfecc37be564712cae18dfce159)) +* Standarize the string formating of sum types and values ([#2432](https://github.com/CQCL/hugr/issues/2432)) ([ec207e7](https://github.com/CQCL/hugr/commit/ec207e7dbe6dbaa9f40421eb0836c9de7e3ea240)) +* Use binary envelopes for operation lower_func encoding ([#2447](https://github.com/CQCL/hugr/issues/2447)) ([2c16a77](https://github.com/CQCL/hugr/commit/2c16a7797a3b5800c5540d1e6a767dd38ad8ca6b)) + + +### Bug Fixes + +* Ensure SumTypes have the same json encoding in -rs and -py ([#2465](https://github.com/CQCL/hugr/issues/2465)) ([7f97e6f](https://github.com/CQCL/hugr/commit/7f97e6f84f0bb2b441fe3e2589e91f19de50198e)) +* Escape html-like labels in DotRenderer ([#2383](https://github.com/CQCL/hugr/issues/2383)) ([eaa7dfe](https://github.com/CQCL/hugr/commit/eaa7dfe35eb08dbd20d5f5353e92b58850e0f31f)) +* Export metadata in Python ([#2342](https://github.com/CQCL/hugr/issues/2342)) ([7be52db](https://github.com/CQCL/hugr/commit/7be52db4f63d7ce8556a5ba0d8d245ebb567e7ed)) +* Fix model export of `Opaque` types. ([#2446](https://github.com/CQCL/hugr/issues/2446)) ([3943499](https://github.com/CQCL/hugr/commit/39434996ba18db83a50455fda90c60aea11a8387)) +* Fixed bug in python model export name mangling. ([#2323](https://github.com/CQCL/hugr/issues/2323)) ([041342f](https://github.com/CQCL/hugr/commit/041342f58a3dcd9f73dbbaab102221c5d9ff5f61)) +* Fixed bugs in model CFG handling and improved CFG signatures ([#2334](https://github.com/CQCL/hugr/issues/2334)) ([ccd2eb2](https://github.com/CQCL/hugr/commit/ccd2eb226358b44aede7dd9e9217448c7e6c0f3a)) +* Fixed export of `Call` and `LoadConst` nodes in `hugr-py`. 
([#2429](https://github.com/CQCL/hugr/issues/2429)) ([6a0e270](https://github.com/CQCL/hugr/commit/6a0e270e7edbea4cc08e2948d3f8a16b9e763af7)) +* Fixed invalid extension name in test. ([#2319](https://github.com/CQCL/hugr/issues/2319)) ([c58ddbf](https://github.com/CQCL/hugr/commit/c58ddbfcc0a557a1644fc8094370e6c62a7ce129)) +* Fixed two bugs in import/export of function operations ([#2324](https://github.com/CQCL/hugr/issues/2324)) ([1ad450f](https://github.com/CQCL/hugr/commit/1ad450f807485f7ef6083270aaa4523cb95b2490)) +* map IntValue to unsigned repr when serializing ([#2413](https://github.com/CQCL/hugr/issues/2413)) ([26d426e](https://github.com/CQCL/hugr/commit/26d426ee7ffdc38063a337e66458b8d797131bca)), closes [#2409](https://github.com/CQCL/hugr/issues/2409) +* Order hints on input and output nodes. ([#2422](https://github.com/CQCL/hugr/issues/2422)) ([a31ccbc](https://github.com/CQCL/hugr/commit/a31ccbcaaa7561f8d221269262cd9ca9e89ad67b)) +* **py:** correct ConstString JSON encoding ([#2325](https://github.com/CQCL/hugr/issues/2325)) ([9649a48](https://github.com/CQCL/hugr/commit/9649a48d376aff27e475c70072aecd55ae7a4ccb)) +* StaticArrayVal payload encoding, improve roundtrip checker ([#2444](https://github.com/CQCL/hugr/issues/2444)) ([1a301eb](https://github.com/CQCL/hugr/commit/1a301eb818401c314d4d7bac40698ec2e73babe7)) +* stringify metadata before escaping in renderer ([#2405](https://github.com/CQCL/hugr/issues/2405)) ([8d67420](https://github.com/CQCL/hugr/commit/8d67420e8fd2e979256ff64bcf0b2813ed19ac00)) + +## [0.12.5](https://github.com/CQCL/hugr/compare/hugr-py-v0.12.4...hugr-py-v0.12.5) (2025-07-08) + + +### Bug Fixes + +* map IntValue to unsigned repr when serializing ([#2413](https://github.com/CQCL/hugr/issues/2413)) ([4ad1d4e](https://github.com/CQCL/hugr/commit/4ad1d4e010eca07207306320b3cf74396f1f8181)), closes [#2409](https://github.com/CQCL/hugr/issues/2409) + +## [0.12.4](https://github.com/CQCL/hugr/compare/hugr-py-v0.12.3...hugr-py-v0.12.4) (2025-07-03) + + +### Bug Fixes + +* stringify metadata before escaping in renderer ([#2405](https://github.com/CQCL/hugr/issues/2405)) ([1f01e97](https://github.com/CQCL/hugr/commit/1f01e97696afe02b46eedb2c6e3e2f2369a4ac7b)) + +## [0.12.3](https://github.com/CQCL/hugr/compare/hugr-py-v0.12.2...hugr-py-v0.12.3) (2025-07-03) + + +### Features + +* add toposort to HUGR-py ([#2367](https://github.com/CQCL/hugr/issues/2367)) ([ba8988e](https://github.com/CQCL/hugr/commit/ba8988e87c2a3d64953838e9a1cff4989740cf05)) +* **core, llvm:** add array unpack operations ([#2339](https://github.com/CQCL/hugr/issues/2339)) ([74b25aa](https://github.com/CQCL/hugr/commit/74b25aa3a704c082f84a0c34fad2654e3392ff50)), closes [#1947](https://github.com/CQCL/hugr/issues/1947) +* **py:** Helper methods to get the neighbours of a node ([#2370](https://github.com/CQCL/hugr/issues/2370)) ([1ed6440](https://github.com/CQCL/hugr/commit/1ed64409aaf7e8f26fb5928051245e560881a621)) + + +### Bug Fixes + +* Escape html-like labels 
in DotRenderer ([#2383](https://github.com/CQCL/hugr/issues/2383)) ([c7a43a6](https://github.com/CQCL/hugr/commit/c7a43a69878e1271251b570070f192ebf57aaadd)) +* Fixed invalid extension name in test. ([#2319](https://github.com/CQCL/hugr/issues/2319)) ([fbe1d9c](https://github.com/CQCL/hugr/commit/fbe1d9c061768360144f5463dcf357fb59ac736f)) +* **py:** correct ConstString JSON encoding ([#2325](https://github.com/CQCL/hugr/issues/2325)) ([325168b](https://github.com/CQCL/hugr/commit/325168b50b5e40e884127ad89d7acb5ab3a412f8)) + +## [0.12.2](https://github.com/CQCL/hugr/compare/hugr-py-v0.12.1...hugr-py-v0.12.2) (2025-06-03) + + +### Bug Fixes + +* use envelopes for `FixedHugr` encoding ([#2283](https://github.com/CQCL/hugr/issues/2283)) ([2c8cbb9](https://github.com/CQCL/hugr/commit/2c8cbb99bc74d5d43956b5f75c89f17748b5ee39)), closes [#2282](https://github.com/CQCL/hugr/issues/2282) + + +### Performance Improvements + +* **py:** mutable `Node` to avoid linear update cost ([#2288](https://github.com/CQCL/hugr/issues/2288)) ([84fb200](https://github.com/CQCL/hugr/commit/84fb2002dc835f6b98ceb95bd80a7bcff9eecdd8)) + + +### Documentation + +* **py:** fix `TypeDef` example ([#2268](https://github.com/CQCL/hugr/issues/2268)) ([ede8e7b](https://github.com/CQCL/hugr/commit/ede8e7b087591303038ecc5b449bb85bf39c948b)) + ## [0.12.1](https://github.com/CQCL/hugr/compare/hugr-py-v0.12.0...hugr-py-v0.12.1) (2025-05-20) diff --git a/hugr-py/Cargo.toml b/hugr-py/Cargo.toml index 27020c8122..c864e3bf5b 100644 --- a/hugr-py/Cargo.toml +++ b/hugr-py/Cargo.toml @@ -21,6 +21,6 @@ bench = false [dependencies] bumpalo = { workspace = true, features = ["collections"] } -hugr-model = { version = "0.20.2", path = "../hugr-model", features = ["pyo3"] } +hugr-model = { version = "0.22.1", path = "../hugr-model", features = ["pyo3"] } paste.workspace = true pyo3 = { workspace = true, features = ["extension-module", "abi3-py310"] } diff --git a/hugr-py/pyproject.toml b/hugr-py/pyproject.toml index 8ecd43912e..93739961da 100644 --- a/hugr-py/pyproject.toml +++ b/hugr-py/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "hugr" -version = "0.12.1" +version = "0.13.0rc1" requires-python = ">=3.10" description = "Quantinuum's common representation for quantum programs" license = { file = "LICENCE" } diff --git a/hugr-py/rust/lib.rs b/hugr-py/rust/lib.rs index 5a6c705d23..bf7f0f1cbb 100644 --- a/hugr-py/rust/lib.rs +++ b/hugr-py/rust/lib.rs @@ -50,6 +50,16 @@ fn bytes_to_package(bytes: &[u8]) -> PyResult { Ok(package) } +/// Returns the current version of the HUGR model format as a tuple of (major, minor, patch). 
+#[pyfunction] +fn current_model_version() -> (u64, u64, u64) { + ( + hugr_model::CURRENT_VERSION.major, + hugr_model::CURRENT_VERSION.minor, + hugr_model::CURRENT_VERSION.patch, + ) +} + #[pymodule] fn _hugr(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(term_to_string, m)?)?; @@ -68,5 +78,6 @@ fn _hugr(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(string_to_param, m)?)?; m.add_function(wrap_pyfunction!(symbol_to_string, m)?)?; m.add_function(wrap_pyfunction!(string_to_symbol, m)?)?; + m.add_function(wrap_pyfunction!(current_model_version, m)?)?; Ok(()) } diff --git a/hugr-py/src/hugr/__init__.py b/hugr-py/src/hugr/__init__.py index 267d6f9e6a..679d47d112 100644 --- a/hugr-py/src/hugr/__init__.py +++ b/hugr-py/src/hugr/__init__.py @@ -18,4 +18,4 @@ # This is updated by our release-please workflow, triggered by this # annotation: x-release-please-version -__version__ = "0.12.1" +__version__ = "0.13.0rc1" diff --git a/hugr-py/src/hugr/_hugr/__init__.pyi b/hugr-py/src/hugr/_hugr/__init__.pyi index 68605037f3..efcc99f910 100644 --- a/hugr-py/src/hugr/_hugr/__init__.pyi +++ b/hugr-py/src/hugr/_hugr/__init__.pyi @@ -18,3 +18,4 @@ def package_to_string(package: hugr.model.Package) -> str: ... def string_to_package(string: str) -> hugr.model.Package: ... def package_to_bytes(package: hugr.model.Package) -> bytes: ... def bytes_to_package(binary: bytes) -> hugr.model.Package: ... +def current_model_version() -> tuple[int, int, int]: ... diff --git a/hugr-py/src/hugr/_serialization/extension.py b/hugr-py/src/hugr/_serialization/extension.py index 5ffdae2ff9..fed975fa61 100644 --- a/hugr-py/src/hugr/_serialization/extension.py +++ b/hugr-py/src/hugr/_serialization/extension.py @@ -66,12 +66,26 @@ def deserialize(self, extension: ext.Extension) -> ext.TypeDef: class FixedHugr(ConfiguredBaseModel): + """Fixed HUGR used to define the lowering of an operation. + + Args: + extensions: Extensions used in the HUGR. + hugr: Base64-encoded HUGR envelope. + """ + extensions: ExtensionSet hugr: str def deserialize(self) -> ext.FixedHugr: - hugr = Hugr.from_str(self.hugr) - return ext.FixedHugr(extensions=self.extensions, hugr=hugr) + # Loading fixed HUGRs requires reading hugr-model envelopes, + # which is not currently supported in Python. + # TODO: Add support for loading fixed HUGRs in Python. + # https://github.com/CQCL/hugr/issues/2287 + msg = ( + "Loading extensions with operation lowering functions is not " + + "supported in Python" + ) + raise NotImplementedError(msg) class OpDef(ConfiguredBaseModel, populate_by_name=True): @@ -91,13 +105,21 @@ def deserialize(self, extension: ext.Extension) -> ext.OpDef: self.binary, ) + # Loading fixed HUGRs requires reading hugr-model envelopes, + # which is not currently supported in Python. + # We currently ignore any lower functions instead of raising an error. + # + # TODO: Add support for loading fixed HUGRs in Python. 
+ # https://github.com/CQCL/hugr/issues/2287 + lower_funcs: list[ext.FixedHugr] = [] + return extension.add_op_def( ext.OpDef( name=self.name, description=self.description, misc=self.misc or {}, signature=signature, - lower_funcs=[f.deserialize() for f in self.lower_funcs], + lower_funcs=lower_funcs, ) ) diff --git a/hugr-py/src/hugr/_serialization/ops.py b/hugr-py/src/hugr/_serialization/ops.py index cde3bd6160..8f602ac94c 100644 --- a/hugr-py/src/hugr/_serialization/ops.py +++ b/hugr-py/src/hugr/_serialization/ops.py @@ -7,6 +7,7 @@ from pydantic import ConfigDict, Field, RootModel +from hugr import tys from hugr.hugr.node_port import ( NodeIdx, # noqa: TCH001 # pydantic needs this alias in scope ) @@ -75,11 +76,16 @@ class FuncDefn(BaseOp): name: str signature: PolyFuncType + visibility: tys.Visibility = Field(default="Private") def deserialize(self) -> ops.FuncDefn: poly_func = self.signature.deserialize() return ops.FuncDefn( - self.name, inputs=poly_func.body.input, _outputs=poly_func.body.output + self.name, + params=poly_func.params, + inputs=poly_func.body.input, + _outputs=poly_func.body.output, + visibility=self.visibility, ) @@ -89,9 +95,12 @@ class FuncDecl(BaseOp): op: Literal["FuncDecl"] = "FuncDecl" name: str signature: PolyFuncType + visibility: tys.Visibility = Field(default="Public") def deserialize(self) -> ops.FuncDecl: - return ops.FuncDecl(self.name, self.signature.deserialize()) + return ops.FuncDecl( + self.name, self.signature.deserialize(), visibility=self.visibility + ) class CustomConst(ConfiguredBaseModel): @@ -123,24 +132,13 @@ class FunctionValue(BaseValue): """A higher-order function value.""" v: Literal["Function"] = Field(default="Function", title="ValueTag") - hugr: Any + hugr: str def deserialize(self) -> val.Value: - from hugr._serialization.serial_hugr import SerialHugr from hugr.hugr import Hugr # pydantic stores the serialized dictionary because of the "Any" annotation - return val.Function(Hugr._from_serial(SerialHugr(**self.hugr))) - - -class TupleValue(BaseValue): - """A constant tuple value.""" - - v: Literal["Tuple"] = Field(default="Tuple", title="ValueTag") - vs: list[Value] - - def deserialize(self) -> val.Value: - return val.Tuple(*deser_it(v.root for v in self.vs)) + return val.Function(Hugr.from_str(self.hugr)) class SumValue(BaseValue): @@ -149,9 +147,9 @@ class SumValue(BaseValue): For any Sum type where this value meets the type of the variant indicated by the tag """ - v: Literal["Sum"] = Field(default="Sum", title="ValueTag") - tag: int - typ: SumType + v: Literal["Sum", "Tuple"] = Field(default="Sum", title="ValueTag") + tag: int = Field(default=0, title="VariantTag") + typ: SumType | None = Field(default=None, title="SumType") vs: list[Value] model_config = ConfigDict( json_schema_extra={ @@ -163,15 +161,22 @@ class SumValue(BaseValue): ) def deserialize(self) -> val.Value: - return val.Sum( - self.tag, self.typ.deserialize(), deser_it(v.root for v in self.vs) - ) + if self.typ is None: + # Backwards compatibility of "Tuple" values + assert self.tag == 0, "Sum type must be provided if tag is not 0" + vs = deser_it(v.root for v in self.vs) + typ = tys.Sum(variant_rows=[[v.type_() for v in vs]]) + return val.Sum(0, typ, vs) + else: + return val.Sum( + self.tag, self.typ.deserialize(), deser_it(v.root for v in self.vs) + ) class Value(RootModel): """A constant Value.""" - root: CustomValue | FunctionValue | TupleValue | SumValue = Field(discriminator="v") + root: CustomValue | FunctionValue | SumValue = 
Field(discriminator="v") model_config = ConfigDict(json_schema_extra={"required": ["v"]}) @@ -598,6 +603,5 @@ class OpType(RootModel): from hugr import ( # noqa: E402 # needed to avoid circular imports ops, - tys, val, ) diff --git a/hugr-py/src/hugr/_serialization/tys.py b/hugr-py/src/hugr/_serialization/tys.py index c00a733751..1ed869c56c 100644 --- a/hugr-py/src/hugr/_serialization/tys.py +++ b/hugr-py/src/hugr/_serialization/tys.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import inspect import sys from abc import ABC, abstractmethod @@ -94,6 +95,20 @@ def deserialize(self) -> tys.StringParam: return tys.StringParam() +class BytesParam(BaseTypeParam): + tp: Literal["Bytes"] = "Bytes" + + def deserialize(self) -> tys.BytesParam: + return tys.BytesParam() + + +class FloatParam(BaseTypeParam): + tp: Literal["Float"] = "Float" + + def deserialize(self) -> tys.FloatParam: + return tys.FloatParam() + + class ListParam(BaseTypeParam): tp: Literal["List"] = "List" param: TypeParam @@ -114,7 +129,13 @@ class TypeParam(RootModel): """A type parameter.""" root: Annotated[ - TypeTypeParam | BoundedNatParam | StringParam | ListParam | TupleParam, + TypeTypeParam + | BoundedNatParam + | StringParam + | FloatParam + | BytesParam + | ListParam + | TupleParam, WrapValidator(_json_custom_error_validator), ] = Field(discriminator="tp") @@ -158,12 +179,56 @@ def deserialize(self) -> tys.StringArg: return tys.StringArg(value=self.arg) -class SequenceArg(BaseTypeArg): - tya: Literal["Sequence"] = "Sequence" +class FloatArg(BaseTypeArg): + tya: Literal["Float"] = "Float" + value: float + + def deserialize(self) -> tys.FloatArg: + return tys.FloatArg(value=self.value) + + +class BytesArg(BaseTypeArg): + tya: Literal["Bytes"] = "Bytes" + value: str = Field( + description="Base64-encoded byte string", + json_schema_extra={"contentEncoding": "base64"}, + ) + + def deserialize(self) -> tys.BytesArg: + value = base64.b64decode(self.value) + return tys.BytesArg(value=value) + + +class ListArg(BaseTypeArg): + tya: Literal["List"] = "List" elems: list[TypeArg] - def deserialize(self) -> tys.SequenceArg: - return tys.SequenceArg(elems=deser_it(self.elems)) + def deserialize(self) -> tys.ListArg: + return tys.ListArg(elems=deser_it(self.elems)) + + +class ListConcatArg(BaseTypeArg): + tya: Literal["ListConcat"] = "ListConcat" + lists: list[TypeArg] + + def deserialize(self) -> tys.ListConcatArg: + return tys.ListConcatArg(lists=deser_it(self.lists)) + + +class TupleArg(BaseTypeArg): + tya: Literal["Tuple"] = "Tuple" + elems: list[TypeArg] + + def deserialize(self) -> tys.TupleArg: + return tys.TupleArg(elems=deser_it(self.elems)) + + +class TupleConcatArg(BaseTypeArg): + tya: Literal["TupleConcat"] = "TupleConcat" + tuples: list[TypeArg] + + def deserialize(self) -> tys.TupleConcatArg: + return tys.TupleConcatArg(tuples=deser_it(self.tuples)) class VariableArg(BaseTypeArg): @@ -179,7 +244,14 @@ class TypeArg(RootModel): """A type argument.""" root: Annotated[ - TypeTypeArg | BoundedNatArg | StringArg | SequenceArg | VariableArg, + TypeTypeArg + | BoundedNatArg + | StringArg + | BytesArg + | FloatArg + | ListArg + | TupleArg + | VariableArg, WrapValidator(_json_custom_error_validator), ] = Field(discriminator="tya") @@ -340,15 +412,15 @@ def deserialize(self) -> tys.PolyFuncType: class TypeBound(Enum): Copyable = "C" - Any = "A" + Linear = "A" @staticmethod def join(*bs: TypeBound) -> TypeBound: """Computes the least upper bound for a sequence of bounds.""" res = TypeBound.Copyable for b in bs: - 
if b == TypeBound.Any: - return TypeBound.Any + if b == TypeBound.Linear: + return TypeBound.Linear if res == TypeBound.Copyable: res = b return res @@ -357,8 +429,8 @@ def __str__(self) -> str: match self: case TypeBound.Copyable: return "Copyable" - case TypeBound.Any: - return "Any" + case TypeBound.Linear: + return "Linear" class Opaque(BaseType): diff --git a/hugr-py/src/hugr/build/dfg.py b/hugr-py/src/hugr/build/dfg.py index 786723a606..ee4b917353 100644 --- a/hugr-py/src/hugr/build/dfg.py +++ b/hugr-py/src/hugr/build/dfg.py @@ -21,8 +21,9 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence + from hugr.build.function import Module from hugr.hugr.node_port import Node, OutPort, PortOffset, ToNode, Wire - from hugr.tys import Type, TypeParam, TypeRow + from hugr.tys import TypeParam, TypeRow from .cfg import Cfg from .cond_loop import Conditional, If, TailLoop @@ -36,40 +37,21 @@ class DataflowError(Exception): @dataclass() class DefinitionBuilder(Generic[OpVar]): - """Base class for builders that can define functions, constants, and aliases. + """Base class for builders that can define constants, and allow access + to the `Module` for declaring/defining functions and aliases. As this class may be a root node, it does not extend `ParentBuilder`. """ hugr: Hugr[OpVar] - def define_function( - self, - name: str, - input_types: TypeRow, - output_types: TypeRow | None = None, - type_params: list[TypeParam] | None = None, - parent: ToNode | None = None, - ) -> Function: - """Start building a function definition in the graph. - - Args: - name: The name of the function. - input_types: The input types for the function. - output_types: The output types for the function. - If not provided, it will be inferred after the function is built. - type_params: The type parameters for the function, if polymorphic. - parent: The parent node of the constant. Defaults to the entrypoint node. - - Returns: - The new function builder. + def module_root_builder(self) -> Module: + """Allows access to the `Module` at the root of the Hugr + (outside the scope of this builder, perhaps outside the entrypoint). """ - parent_node = parent or self.hugr.entrypoint - parent_op = ops.FuncDefn(name, input_types, type_params or []) - func = Function.new_nested(parent_op, self.hugr, parent_node) - if output_types is not None: - func.declare_outputs(output_types) - return func + from hugr.build.function import Module # Avoid circular import + + return Module(self.hugr) def add_const(self, value: val.Value, parent: ToNode | None = None) -> Node: """Add a static constant to the graph. 
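(Illustrative sketch, not part of the patch: with `define_function` dropped from `DefinitionBuilder`, a nested builder such as `Dfg` is now expected to go through the module root. This assumes the `module_root_builder` helper added above and the `Module.define_function` added to `function.py` below; the name "helper" is purely illustrative.)

    from hugr import tys
    from hugr.build.dfg import Dfg

    dfg = Dfg(tys.Bool)                 # Hugr with a Module root and a DFG entrypoint
    module = dfg.module_root_builder()  # reach the Module outside this builder's scope
    helper = module.define_function("helper", [tys.Bool])  # defined at the module root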
@@ -90,11 +72,6 @@ def add_const(self, value: val.Value, parent: ToNode | None = None) -> Node: parent_node = parent or self.hugr.entrypoint return self.hugr.add_node(ops.Const(value), parent_node) - def add_alias_defn(self, name: str, ty: Type, parent: ToNode | None = None) -> Node: - """Add a type alias definition.""" - parent_node = parent or self.hugr.entrypoint - return self.hugr.add_node(ops.AliasDefn(name, ty), parent_node) - DP = TypeVar("DP", bound=ops.DfParentOp) @@ -155,8 +132,15 @@ def new_nested( """ new = cls.__new__(cls) + try: + num_outs = parent_op.num_out + except ops.IncompleteOp: + num_outs = None + new.hugr = hugr - new.parent_node = hugr.add_node(parent_op, parent or hugr.entrypoint) + new.parent_node = hugr.add_node( + parent_op, parent or hugr.entrypoint, num_outs=num_outs + ) new._init_io_nodes(parent_op) return new @@ -228,7 +212,14 @@ def add_op( >>> dfg.add_op(ops.Noop(), dfg.inputs()[0]) Node(3) """ - new_n = self.hugr.add_node(op, self.parent_node, metadata=metadata) + try: + num_outs = op.num_out + except ops.IncompleteOp: + num_outs = None + + new_n = self.hugr.add_node( + op, self.parent_node, metadata=metadata, num_outs=num_outs + ) self._wire_up(new_n, args) new_n._num_out_ports = op.num_out return new_n @@ -755,7 +746,6 @@ def declare_outputs(self, output_types: TypeRow) -> None: defined yet. The wires passed to :meth:`set_outputs` must match the declared output types. """ - self._set_parent_output_count(len(output_types)) self.parent_op._set_out_types(output_types) def set_outputs(self, *args: Wire) -> None: diff --git a/hugr-py/src/hugr/build/function.py b/hugr-py/src/hugr/build/function.py index b5d8b8c1ff..6c7fa29249 100644 --- a/hugr-py/src/hugr/build/function.py +++ b/hugr-py/src/hugr/build/function.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from hugr.hugr.node_port import Node - from hugr.tys import PolyFuncType, TypeBound, TypeRow + from hugr.tys import PolyFuncType, Type, TypeBound, TypeParam, TypeRow __all__ = ["Function", "Module"] @@ -28,13 +28,39 @@ class Module(DefinitionBuilder[ops.Module]): hugr: Hugr[ops.Module] - def __init__(self) -> None: - self.hugr = Hugr(ops.Module()) + def __init__(self, hugr: Hugr | None = None) -> None: + self.hugr = Hugr(ops.Module()) if hugr is None else hugr def define_main(self, input_types: TypeRow) -> Function: """Define the 'main' function in the module. See :meth:`define_function`.""" return self.define_function("main", input_types) + def define_function( + self, + name: str, + input_types: TypeRow, + output_types: TypeRow | None = None, + type_params: list[TypeParam] | None = None, + ) -> Function: + """Start building a function definition in the graph. + + Args: + name: The name of the function. + input_types: The input types for the function. + output_types: The output types for the function. + If not provided, they will be inferred after the function is built. + type_params: The type parameters for the function, if polymorphic. + + Returns: + The new function builder. + """ + parent_op = ops.FuncDefn(name, input_types, type_params or []) + func = Function.new_nested(parent_op, self.hugr, self.hugr.module_root) + if output_types is not None: + func.declare_outputs(output_types) + return func + def declare_function(self, name: str, signature: PolyFuncType) -> Node: """Add a function declaration to the module.
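(Minimal usage sketch for the relocated `Module.define_function`, assuming the builder API shown in this hunk; output types are omitted so they are inferred once the body is wired up, and the name "identity" is illustrative.)

    from hugr import tys
    from hugr.build.function import Module

    module = Module()
    func = module.define_function("identity", [tys.Bool])  # outputs inferred later
    (b,) = func.inputs()                                    # wire from the function's Input node
    func.set_outputs(b)                                     # fixes the output row to [Bool]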
@@ -52,11 +78,17 @@ def declare_function(self, name: str, signature: PolyFuncType) -> Node: >>> m.declare_function("f", sig) Node(1) """ - return self.hugr.add_node(ops.FuncDecl(name, signature), self.hugr.entrypoint) + return self.hugr.add_node( + ops.FuncDecl(name, signature), self.hugr.entrypoint, num_outs=1 + ) + + def add_alias_defn(self, name: str, ty: Type) -> Node: + """Add a type alias definition.""" + return self.hugr.add_node(ops.AliasDefn(name, ty), self.hugr.module_root) def add_alias_decl(self, name: str, bound: TypeBound) -> Node: """Add a type alias declaration.""" - return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.entrypoint) + return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.module_root) @property def metadata(self) -> dict[str, object]: diff --git a/hugr-py/src/hugr/envelope.py b/hugr-py/src/hugr/envelope.py index 643c6b3ab4..07d199ad4f 100644 --- a/hugr-py/src/hugr/envelope.py +++ b/hugr-py/src/hugr/envelope.py @@ -46,6 +46,12 @@ # This is a hard-coded magic number that identifies the start of a HUGR envelope. MAGIC_NUMBERS = b"HUGRiHJv" +# The all-unset header flags configuration. +# Bit 7 is always set to ensure we have a printable ASCII character. +_DEFAULT_FLAGS = 0b0100_0000 +# The ZSTD flag bit in the header's flags. +_ZSTD_FLAG = 0b0000_0001 + def make_envelope(package: Package | Hugr, config: EnvelopeConfig) -> bytes: """Encode a HUGR or Package into an envelope, using the given configuration.""" @@ -65,10 +71,10 @@ def make_envelope(package: Package | Hugr, config: EnvelopeConfig) -> bytes: # `make_envelope_str`, but we prioritize speed for binary formats. payload = json_str.encode("utf-8") - case EnvelopeFormat.MODULE: + case EnvelopeFormat.MODEL: payload = bytes(package.to_model()) - case EnvelopeFormat.MODULE_WITH_EXTS: + case EnvelopeFormat.MODEL_WITH_EXTS: package_bytes = bytes(package.to_model()) extension_str = json.dumps( [ext._to_serial().model_dump(mode="json") for ext in package.extensions] @@ -105,7 +111,7 @@ def read_envelope(envelope: bytes) -> Package: match header.format: case EnvelopeFormat.JSON: return ext_s.Package.model_validate_json(payload).deserialize() - case EnvelopeFormat.MODULE | EnvelopeFormat.MODULE_WITH_EXTS: + case EnvelopeFormat.MODEL | EnvelopeFormat.MODEL_WITH_EXTS: msg = "Decoding HUGR envelopes in MODULE format is not supported yet." raise ValueError(msg) @@ -150,10 +156,10 @@ def read_envelope_hugr_str(envelope: str) -> Hugr: class EnvelopeFormat(Enum): """Format used to encode a HUGR envelope.""" - MODULE = 1 - """A capnp-encoded hugr-module.""" - MODULE_WITH_EXTS = 2 - """A capnp-encoded hugr-module, immediately followed by a json-encoded + MODEL = 1 + """A capnp-encoded hugr-model.""" + MODEL_WITH_EXTS = 2 + """A capnp-encoded hugr-model, immediately followed by a json-encoded extension registry.""" JSON = 63 # '?' in ASCII """A json-encoded hugr-package. 
This format is ASCII-printable.""" @@ -180,9 +186,9 @@ class EnvelopeHeader: def to_bytes(self) -> bytes: header_bytes = bytearray(MAGIC_NUMBERS) header_bytes.append(self.format.value) - flags = 0b01000000 + flags = _DEFAULT_FLAGS if self.zstd: - flags |= 0b00000001 + flags |= _ZSTD_FLAG header_bytes.append(flags) return bytes(header_bytes) @@ -204,7 +210,15 @@ def from_bytes(data: bytes) -> EnvelopeHeader: format: EnvelopeFormat = EnvelopeFormat(data[8]) flags = data[9] - zstd = bool(flags & 0b00000001) + zstd = bool(flags & _ZSTD_FLAG) + other_flags = (flags ^ _DEFAULT_FLAGS) & ~_ZSTD_FLAG + if other_flags: + flag_ids = [i for i in range(8) if other_flags & (1 << i)] + msg = ( + f"Unrecognised Envelope flags {flag_ids}." + + " Please update your HUGR version." + ) + raise ValueError(msg) return EnvelopeHeader(format=format, zstd=zstd) @@ -232,4 +246,4 @@ def _make_header(self) -> EnvelopeHeader: # Set EnvelopeConfig's class variables. # These can only be initialized _after_ the class is defined. EnvelopeConfig.TEXT = EnvelopeConfig(format=EnvelopeFormat.JSON, zstd=None) -EnvelopeConfig.BINARY = EnvelopeConfig(format=EnvelopeFormat.JSON, zstd=None) +EnvelopeConfig.BINARY = EnvelopeConfig(format=EnvelopeFormat.MODEL_WITH_EXTS, zstd=0) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 8123fa556b..53def975d3 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -2,6 +2,7 @@ from __future__ import annotations +import base64 from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, TypeVar @@ -154,7 +155,8 @@ class FixedHugr: hugr: Hugr def _to_serial(self) -> ext_s.FixedHugr: - return ext_s.FixedHugr(extensions=self.extensions, hugr=self.hugr.to_str()) + hugr_64: str = base64.b64encode(self.hugr.to_bytes()).decode() + return ext_s.FixedHugr(extensions=self.extensions, hugr=hugr_64) @dataclass diff --git a/hugr-py/src/hugr/hugr/base.py b/hugr-py/src/hugr/hugr/base.py index f5ca7ff8e2..d0fe111ee6 100644 --- a/hugr-py/src/hugr/hugr/base.py +++ b/hugr-py/src/hugr/hugr/base.py @@ -35,13 +35,16 @@ Conditional, Const, Custom, + DataflowBlock, DataflowOp, + ExitBlock, FuncDefn, IncompleteOp, Module, Op, + is_dataflow_op, ) -from hugr.tys import Kind, Type, ValueKind +from hugr.tys import Kind, OrderKind, Type, ValueKind from hugr.utils import BiMap from hugr.val import Value @@ -98,7 +101,8 @@ class Hugr(Mapping[Node, NodeData], Generic[OpVarCov]): """The core HUGR datastructure. Args: - root_op: The operation for the root node. Defaults to a Module. + entrypoint_op: The operation for the entrypoint node. Defaults to a Module + (which will then also be the root). Examples: >>> h = Hugr() @@ -148,7 +152,9 @@ def __init__(self, entrypoint_op: OpVarCov | None = None) -> None: case None | Module(): pass case ops.FuncDefn(): - self.entrypoint = self.add_node(entrypoint_op, self.module_root) + self.entrypoint = self.add_node( + entrypoint_op, self.module_root, num_outs=1 + ) case _: from hugr.build import Function @@ -226,6 +232,65 @@ def nodes(self) -> Iterable[tuple[Node, NodeData]]: """ return self.items() + def sorted_region_nodes(self, parent: Node) -> Iterator[Node]: + """Iterator over a topological ordering of all the hugr nodes. + + Note that the sort is performed within a hugr region and non-local + edges are ignored. + + Args: + parent: The parent node of the region to sort. + + Raises: + ValueError: If the region contains a cycle. 
+ + Examples: + >>> from hugr.build.tracked_dfg import TrackedDfg + >>> from hugr.std.logic import Not + >>> dfg = TrackedDfg(tys.Bool) + >>> [b] = dfg.track_inputs() + >>> for _ in range(6): + ... _= dfg.add(Not(b)); + >>> dfg.set_tracked_outputs() + >>> nodes = list(dfg.hugr) + >>> list(dfg.hugr.sorted_region_nodes(nodes[4])) + [Node(5), Node(7), Node(8), Node(9), Node(10), Node(11), Node(12), Node(6)] + """ + # A dict mapping each node to its remaining in-degree, counting only + # edges within this region. + # Implementation uses Kahn's algorithm + # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + visit_dict: dict[Node, int] = {} + queue: Queue[Node] = Queue() + for node in self.children(parent): + incoming = 0 + for n in self.input_neighbours(node): + same_region = self[n].parent == parent + # Only update the degree of the node if edge is within the same region. + # We do not count non-local edges. + if same_region: + incoming += 1 + if incoming: + visit_dict[node] = incoming + # If a Node has no dependencies, add it to the queue. + else: + queue.put(node) + + while not queue.empty(): + new_node = queue.get() + yield new_node + + for neigh in self.output_neighbours(new_node): + visit_dict[neigh] -= 1 + if visit_dict[neigh] == 0: + del visit_dict[neigh] + queue.put(neigh) + + # If our dict is non-empty here then our graph contains a cycle + if visit_dict: + err = "Graph contains a cycle. No topological ordering exists." + raise ValueError(err) + def links(self) -> Iterator[tuple[OutPort, InPort]]: """Iterator over all the links in the HUGR. @@ -482,6 +547,12 @@ def add_order_link(self, src: ToNode, dst: ToNode) -> None: """ source = src.out(-1) target = dst.inp(-1) + assert ( + self.port_kind(source) == OrderKind() + ), f"Operation {self[src].op.name()} does not support order edges" + assert ( + self.port_kind(target) == OrderKind() + ), f"Operation {self[dst].op.name()} does not support order edges" if not self.has_link(source, target): self.add_link(source, target) @@ -527,15 +598,20 @@ def num_ports(self, node: ToNode, direction: Direction) -> int: Not necessarily the number of connected ports - if port `i` is connected, then all ports `0..i` are assumed to exist. + This value includes order ports. + Args: node: Node to query. direction: Direction of ports to count. Examples: + >>> from hugr.std.logic import Not >>> h = Hugr() - >>> n1 = h.add_const(val.TRUE) - >>> n2 = h.add_const(val.FALSE) - >>> h.add_link(n1.out(0), n2.inp(2)) # not a valid link! + >>> n1 = h.add_node(Not) + >>> n2 = h.add_node(Not) + >>> # Passing offset `2` here allocates new ports automatically + >>> h.add_link(n1.out(0), n2.inp(2)) + >>> h.add_order_link(n1, n2) >>> h.num_ports(n1, Direction.OUTGOING) 1 >>> h.num_ports(n2, Direction.INCOMING) @@ -548,11 +624,17 @@ def num_ports(self, node: ToNode, direction: Direction) -> int: ) def num_in_ports(self, node: ToNode) -> int: - """The number of incoming ports of a node. See :meth:`num_ports`.""" + """The number of incoming ports of a node. See :meth:`num_ports`. + + This value does not include order ports. + """ return self[node]._num_inps def num_out_ports(self, node: ToNode) -> int: - """The number of outgoing ports of a node. See :meth:`num_ports`.""" + """The number of outgoing ports of a node. See :meth:`num_ports`. + + This value does not include order ports.
+ """ return self[node]._num_outs def _linked_ports( @@ -634,9 +716,16 @@ def _node_links( port = cast("P", node.port(offset, direction)) yield port, list(self._linked_ports(port, links)) + order_port = cast("P", node.port(-1, direction)) + linked_order = list(self._linked_ports(order_port, links)) + if linked_order: + yield order_port, linked_order + def outgoing_links(self, node: ToNode) -> Iterable[tuple[OutPort, list[InPort]]]: """Iterator over outgoing links from a given node. + This number includes order ports. + Args: node: Node to query. @@ -648,14 +737,17 @@ def outgoing_links(self, node: ToNode) -> Iterable[tuple[OutPort, list[InPort]]] >>> df = dfg.Dfg() >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0)) >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1)) + >>> df.hugr.add_order_link(df.input_node, df.output_node) >>> list(df.hugr.outgoing_links(df.input_node)) - [(OutPort(Node(5), 0), [InPort(Node(6), 0), InPort(Node(6), 1)])] - """ + [(OutPort(Node(5), 0), [InPort(Node(6), 0), InPort(Node(6), 1)]), (OutPort(Node(5), -1), [InPort(Node(6), -1)])] + """ # noqa: E501 return self._node_links(node, self._links.fwd) def incoming_links(self, node: ToNode) -> Iterable[tuple[InPort, list[OutPort]]]: """Iterator over incoming links to a given node. + This number includes order ports. + Args: node: Node to query. @@ -667,11 +759,81 @@ def incoming_links(self, node: ToNode) -> Iterable[tuple[InPort, list[OutPort]]] >>> df = dfg.Dfg() >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0)) >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1)) + >>> df.hugr.add_order_link(df.input_node, df.output_node) >>> list(df.hugr.incoming_links(df.output_node)) - [(InPort(Node(6), 0), [OutPort(Node(5), 0)]), (InPort(Node(6), 1), [OutPort(Node(5), 0)])] + [(InPort(Node(6), 0), [OutPort(Node(5), 0)]), (InPort(Node(6), 1), [OutPort(Node(5), 0)]), (InPort(Node(6), -1), [OutPort(Node(5), -1)])] """ # noqa: E501 return self._node_links(node, self._links.bck) + def neighbours( + self, node: ToNode, direction: Direction | None = None + ) -> Iterable[Node]: + """Iterator over the neighbours of a node. + + Args: + node: Node to query. + direction: If given, only return neighbours in that direction. + + Returns: + Iterator of nodes connected to `node`, ordered by direction and port + offset. Nodes connected via multiple links will be returned multiple times. + + Examples: + >>> df = dfg.Dfg() + >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0)) + >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1)) + >>> list(df.hugr.neighbours(df.input_node)) + [Node(6), Node(6)] + >>> list(df.hugr.neighbours(df.output_node, Direction.OUTGOING)) + [] + """ + if direction is None or direction == Direction.INCOMING: + for _, linked_outputs in self.incoming_links(node): + for out_port in linked_outputs: + yield out_port.node + if direction is None or direction == Direction.OUTGOING: + for _, linked_inputs in self.outgoing_links(node): + for in_port in linked_inputs: + yield in_port.node + + def input_neighbours(self, node: ToNode) -> Iterable[Node]: + """Iterator over the input neighbours of a node. + + Args: + node: Node to query. + + Returns: + Iterator of nodes connected to `node` via incoming links. + Nodes connected via multiple links will be returned multiple times. 
+ + Examples: + >>> df = dfg.Dfg() + >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0)) + >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1)) + >>> list(df.hugr.input_neighbours(df.output_node)) + [Node(5), Node(5)] + """ + return self.neighbours(node, Direction.INCOMING) + + def output_neighbours(self, node: ToNode) -> Iterable[Node]: + """Iterator over the output neighbours of a node. + + Args: + node: Node to query. + + Returns: + Iterator of nodes connected to `node` via outgoing links. + Nodes connected via multiple links will be returned multiple times. + + Examples: + >>> df = dfg.Dfg() + >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0)) + >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1)) + >>> list(df.hugr.output_neighbours(df.input_node)) + [Node(6), Node(6)] + """ + return self.neighbours(node, Direction.OUTGOING) + def num_incoming(self, node: Node) -> int: """The number of incoming links to a `node`. @@ -681,7 +843,7 @@ def num_incoming(self, node: Node) -> int: >>> df.hugr.num_incoming(df.output_node) 1 """ - return sum(1 for _ in self.incoming_links(node)) + return sum(len(links) for (_, links) in self.incoming_links(node)) def num_outgoing(self, node: ToNode) -> int: """The number of outgoing links from a `node`. @@ -692,7 +854,7 @@ def num_outgoing(self, node: ToNode) -> int: >>> df.hugr.num_outgoing(df.input_node) 1 """ - return sum(1 for _ in self.outgoing_links(node)) + return sum(len(links) for (_, links) in self.outgoing_links(node)) # TODO: num_links and _linked_ports @@ -777,7 +939,9 @@ def _to_serial(self) -> SerialHugr: def _serialize_link( link: tuple[_SO, _SI], - ) -> tuple[tuple[NodeIdx, PortOffset], tuple[NodeIdx, PortOffset]]: + ) -> tuple[ + tuple[NodeIdx, PortOffset | None], tuple[NodeIdx, PortOffset | None] + ]: src, dst = link s, d = self._constrain_offset(src.port), self._constrain_offset(dst.port) return (src.port.node.idx, s), (dst.port.node.idx, d) @@ -804,16 +968,16 @@ def _serialize_link( entrypoint=entrypoint, ) - def _constrain_offset(self, p: P) -> PortOffset: - # An offset of -1 is a special case, indicating an order edge, - # not counted in the number of ports. + def _constrain_offset(self, p: P) -> PortOffset | None: + """Constrain an offset to be a valid encoded port offset. + + Order edges and control flow edges should be encoded without an offset. 
+ """ if p.offset < 0: assert p.offset == -1, "Only order edges are allowed with offset < 0" - offset = self.num_ports(p.node, p.direction) + return None else: - offset = p.offset - - return offset + return p.offset def resolve_extensions(self, registry: ext.ExtensionRegistry) -> Hugr: """Resolve extension types and operations in the HUGR by matching them to @@ -833,7 +997,7 @@ def _connect_df_entrypoint_outputs(self) -> None: """ from hugr.build import Function - if not isinstance(self.entrypoint_op(), DataflowOp): + if not is_dataflow_op(self.entrypoint_op()): return func_node = self[self.entrypoint].parent @@ -878,9 +1042,8 @@ def get_meta(idx: int) -> dict[str, Any]: parent: Node | None = Node(serial_node.root.parent) serial_node.root.parent = -1 - n = hugr._add_node( - serial_node.root.deserialize(), parent, metadata=node_meta - ) + op = serial_node.root.deserialize() + n = hugr._add_node(op, parent, metadata=node_meta, num_outs=op.num_out) assert ( n.idx == idx + boilerplate_nodes ), "Nodes should be added contiguously" @@ -889,11 +1052,21 @@ def get_meta(idx: int) -> dict[str, Any]: hugr.entrypoint = n for (src_node, src_offset), (dst_node, dst_offset) in serial.edges: + src = Node(src_node, _metadata=get_meta(src_node)) + dst = Node(dst_node, _metadata=get_meta(dst_node)) if src_offset is None or dst_offset is None: - continue + src_op = hugr[src].op + if isinstance(src_op, DataflowBlock | ExitBlock): + # Control flow edge + src_offset = 0 + dst_offset = 0 + else: + # Order edge + hugr.add_order_link(src, dst) + continue hugr.add_link( - Node(src_node, _metadata=get_meta(src_node)).out(src_offset), - Node(dst_node, _metadata=get_meta(dst_node)).inp(dst_offset), + src.out(src_offset), + dst.inp(dst_offset), ) return hugr diff --git a/hugr-py/src/hugr/hugr/render.py b/hugr-py/src/hugr/hugr/render.py index 059ed2c03f..ce72b7cb09 100644 --- a/hugr-py/src/hugr/hugr/render.py +++ b/hugr-py/src/hugr/hugr/render.py @@ -1,5 +1,6 @@ """Visualise HUGR using graphviz.""" +import html from collections.abc import Iterable from dataclasses import dataclass, field @@ -101,7 +102,9 @@ def render(self, hugr: Hugr) -> Digraph: "margin": "0", "bgcolor": self.config.palette.background, } - if not (name := hugr[hugr.module_root].metadata.get("name", None)): + if name := hugr[hugr.module_root].metadata.get("name", None): + name = html.escape(str(name)) + else: name = "" graph = gv.Digraph(name, strict=False) @@ -215,7 +218,8 @@ def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None: meta = hugr[node].metadata if len(meta) > 0: data = "
<br/><br/>" + "<br/>
".join( - f"{key}: {value}" for key, value in meta.items() + html.escape(key) + ": " + html.escape(str(value)) + for key, value in meta.items() ) else: data = "" @@ -236,6 +240,7 @@ def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None: op_name = op.op_def().name else: op_name = op.name() + op_name = html.escape(op_name) label_config = { "node_back_color": self.config.palette.node, @@ -286,7 +291,7 @@ def _viz_link( label = "" match kind: case ValueKind(ty): - label = str(ty) + label = html.escape(str(ty)) color = self.config.palette.edge case OrderKind(): color = self.config.palette.dark diff --git a/hugr-py/src/hugr/model/__init__.py b/hugr-py/src/hugr/model/__init__.py index d7688e0112..ac97177cd2 100644 --- a/hugr-py/src/hugr/model/__init__.py +++ b/hugr-py/src/hugr/model/__init__.py @@ -5,7 +5,20 @@ from enum import Enum from typing import Protocol +from semver import Version + import hugr._hugr as rust +from hugr.tys import Visibility + + +def _current_version() -> Version: + """Get the current version of the HUGR model.""" + (major, minor, patch) = rust.current_model_version() + return Version(major=major, minor=minor, patch=patch) + + +# The current version of the HUGR model. +CURRENT_VERSION: Version = _current_version() class Term(Protocol): @@ -101,6 +114,7 @@ class Symbol: """A named symbol.""" name: str + visibility: Visibility params: Sequence[Param] = field(default_factory=list) constraints: Sequence[Term] = field(default_factory=list) signature: Term = field(default_factory=Wildcard) @@ -289,3 +303,8 @@ def from_str(s: str) -> "Package": def from_bytes(b: bytes) -> "Package": """Read a package from its binary representation.""" return rust.bytes_to_package(b) + + @property + def version(self) -> Version: + """Returns the model version used to encode this package.""" + return CURRENT_VERSION diff --git a/hugr-py/src/hugr/model/export.py b/hugr-py/src/hugr/model/export.py index d93713d2a9..8bef32e452 100644 --- a/hugr-py/src/hugr/model/export.py +++ b/hugr-py/src/hugr/model/export.py @@ -29,7 +29,15 @@ Tag, TailLoop, ) -from hugr.tys import ConstKind, FunctionKind, Type, TypeBound, TypeParam, TypeTypeParam +from hugr.tys import ( + ConstKind, + FunctionKind, + Type, + TypeBound, + TypeParam, + TypeTypeParam, + Visibility, +) class ModelExport: @@ -39,8 +47,7 @@ def __init__(self, hugr: Hugr): self.hugr = hugr self.link_ports: _UnionFind[InPort | OutPort] = _UnionFind() self.link_names: dict[InPort | OutPort, str] = {} - - # TODO: Store the hugr entrypoint + self.link_next = 0 for a, b in self.hugr.links(): self.link_ports.union(a, b) @@ -52,20 +59,26 @@ def link_name(self, port: InPort | OutPort) -> str: if root in self.link_names: return self.link_names[root] else: - index = str(len(self.link_names)) + index = str(self.link_next) + self.link_next += 1 self.link_names[root] = index return index - def export_node(self, node: Node) -> model.Node | None: + def export_node( + self, node: Node, virtual_input_links: Sequence[str] = [] + ) -> model.Node | None: """Export the node with the given node id.""" node_data = self.hugr[node] inputs = [self.link_name(InPort(node, i)) for i in range(node_data._num_inps)] + inputs = [*inputs, *virtual_input_links] + outputs = [self.link_name(OutPort(node, i)) for i in range(node_data._num_outs)] meta = self.export_json_meta(node) + meta += self.export_entrypoint_meta(node) # Add an order hint key to the node if necessary - if _needs_order_key(self.hugr, node): + if _has_order_links(self.hugr, node): 
meta.append(model.Apply("core.order_hint.key", [model.Literal(node.idx)])) match node_data.op: @@ -111,7 +124,8 @@ def export_node(self, node: Node) -> model.Node | None: case Conditional() as op: regions = [ - self.export_region_dfg(child) for child in node_data.children + self.export_region_dfg(child, entrypoint_meta=True) + for child in node_data.children ] signature = op.outer_signature().to_model() @@ -138,30 +152,45 @@ def export_node(self, node: Node) -> model.Node | None: ) case FuncDefn() as op: - name = _mangle_name(node, op.f_name) + name = _mangle_name(node, op.f_name, op.visibility) symbol = self.export_symbol( - name, op.signature.params, op.signature.body + name, op.visibility, op.signature.params, op.signature.body ) region = self.export_region_dfg(node) + if op.visibility == "Private": + meta.append(model.Apply("core.title", [model.Literal(op.f_name)])) + return model.Node( operation=model.DefineFunc(symbol), regions=[region], meta=meta ) case FuncDecl() as op: - name = _mangle_name(node, op.f_name) + name = _mangle_name(node, op.f_name, op.visibility) symbol = self.export_symbol( - name, op.signature.params, op.signature.body + name, op.visibility, op.signature.params, op.signature.body ) + + if op.visibility == "Private": + meta.append(model.Apply("core.title", [model.Literal(op.f_name)])) + return model.Node(operation=model.DeclareFunc(symbol), meta=meta) case AliasDecl() as op: - symbol = model.Symbol(name=op.alias, signature=model.Apply("core.type")) + symbol = model.Symbol( + name=op.alias, + visibility="Public", + signature=model.Apply("core.type"), + ) return model.Node(operation=model.DeclareAlias(symbol), meta=meta) case AliasDefn() as op: - symbol = model.Symbol(name=op.alias, signature=model.Apply("core.type")) + symbol = model.Symbol( + name=op.alias, + visibility="Public", + signature=model.Apply("core.type"), + ) alias_value = cast(model.Term, op.definition.to_model()) @@ -182,6 +211,11 @@ def export_node(self, node: Node) -> model.Node | None: error = f"Call node {node} is not connected to a function." raise ValueError(error) + # We ignore the static input edge since the function is passed + # as an argument instead. 
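+ # (Illustration: a call such as f(x, y) reaches this point with inputs [x, y, <static edge to the callee>], i.e. one more link than there are value input types, so the slice below drops exactly that trailing static input.)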
+ assert len(inputs) == len(input_types) + 1 + inputs = inputs[0 : len(inputs) - 1] + func = model.Apply(func_name, func_args) return model.Node( @@ -202,7 +236,8 @@ def export_node(self, node: Node) -> model.Node | None: ) case LoadFunc() as op: - signature = op.instantiation.to_model() + signature = op.outer_signature().to_model() + instantiation = op.instantiation.to_model() func_args = cast( list[model.Term], [type.to_model() for type in op.type_args] ) @@ -216,10 +251,10 @@ def export_node(self, node: Node) -> model.Node | None: return model.Node( operation=model.CustomOp( - model.Apply("core.load_const", [signature, func]) + model.Apply("core.load_const", [instantiation, func]) ), signature=signature, - inputs=inputs, + inputs=[], outputs=outputs, meta=meta, ) @@ -272,7 +307,7 @@ def export_node(self, node: Node) -> model.Node | None: model.Apply("core.load_const", [type, value]) ), signature=signature, - inputs=inputs, + inputs=[], outputs=outputs, meta=meta, ) @@ -296,31 +331,21 @@ def export_node(self, node: Node) -> model.Node | None: case DataflowBlock() as op: region = self.export_region_dfg(node) - input_types = [ - model.Apply( - "core.ctrl", - [model.List([type.to_model() for type in op.inputs])], - ) - ] + input_types = [model.List([type.to_model() for type in op.inputs])] other_output_types = [type.to_model() for type in op.other_outputs] output_types = [ - model.Apply( - "core.ctrl", + model.List( [ - model.List( - [ - *[type.to_model() for type in row], - *other_output_types, - ] - ) - ], + *[type.to_model() for type in row], + *other_output_types, + ] ) for row in op.sum_ty.variant_rows ] signature = model.Apply( - "core.fn", + "core.ctrl", [model.List(input_types), model.List(output_types)], ) @@ -379,10 +404,18 @@ def export_json_meta(self, node: Node) -> list[model.Term]: return meta + def export_entrypoint_meta(self, node: Node) -> list[model.Term]: + """Export entrypoint metadata if the node is the module entrypoint.""" + if self.hugr.entrypoint == node: + return [model.Apply("core.entrypoint")] + else: + return [] + def export_region_module(self, node: Node) -> model.Region: """Export a module node as a module region.""" node_data = self.hugr[node] meta = self.export_json_meta(node) + meta += self.export_entrypoint_meta(node) children = [] for child in node_data.children: @@ -393,7 +426,7 @@ def export_region_module(self, node: Node) -> model.Region: return model.Region(kind=model.RegionKind.MODULE, children=children, meta=meta) - def export_region_dfg(self, node: Node) -> model.Region: + def export_region_dfg(self, node: Node, entrypoint_meta=False) -> model.Region: """Export the children of a node as a dataflow region.""" node_data = self.hugr[node] children: list[model.Node] = [] @@ -403,6 +436,9 @@ def export_region_dfg(self, node: Node) -> model.Region: targets = [] meta = [] + if entrypoint_meta: + meta += self.export_entrypoint_meta(node) + for child in node_data.children: child_data = self.hugr[child] @@ -414,6 +450,13 @@ def export_region_dfg(self, node: Node) -> model.Region: for i in range(child_data._num_outs) ] + if _has_order_links(self.hugr, child): + meta.append( + model.Apply( + "core.order_hint.input_key", [model.Literal(child.idx)] + ) + ) + case Output() as op: target_types = model.List([type.to_model() for type in op.types]) targets = [ @@ -421,6 +464,13 @@ def export_region_dfg(self, node: Node) -> model.Region: for i in range(child_data._num_inps) ] + if _has_order_links(self.hugr, child): + meta.append( + model.Apply( + 
"core.order_hint.output_key", [model.Literal(child.idx)] + ) + ) + case _: child_node = self.export_node(child) @@ -429,14 +479,13 @@ def export_region_dfg(self, node: Node) -> model.Region: children.append(child_node) - meta += [ - model.Apply( - "core.order_hint.order", - [model.Literal(child.idx), model.Literal(successor.idx)], - ) - for successor in self.hugr.outgoing_order_links(child) - if not isinstance(self.hugr[successor].op, Output) - ] + meta += [ + model.Apply( + "core.order_hint.order", + [model.Literal(child.idx), model.Literal(successor.idx)], + ) + for successor in self.hugr.outgoing_order_links(child) + ] signature = model.Apply("core.fn", [source_types, target_types]) @@ -458,6 +507,7 @@ def export_region_cfg(self, node: Node) -> model.Region: source_types: model.Term = model.Wildcard() target_types: model.Term = model.Wildcard() children = [] + meta = self.export_entrypoint_meta(node) for child in node_data.children: child_data = self.hugr[child] @@ -476,9 +526,14 @@ def export_region_cfg(self, node: Node) -> model.Region: source_types = model.List( [type.to_model() for type in op.inputs] ) - source = self.link_name(OutPort(child, 0)) + source = str(self.link_next) + self.link_next += 1 - child_node = self.export_node(child) + child_node = self.export_node( + child, virtual_input_links=[source] + ) + else: + child_node = self.export_node(child) if child_node is not None: children.append(child_node) @@ -490,7 +545,13 @@ def export_region_cfg(self, node: Node) -> model.Region: error = f"CFG {node} has no entry block." raise ValueError(error) - signature = model.Apply("core.fn", [source_types, target_types]) + signature = model.Apply( + "core.ctrl", + [ + model.List([source_types]), + model.List([target_types]), + ], + ) return model.Region( kind=model.RegionKind.CONTROL_FLOW, @@ -498,10 +559,15 @@ def export_region_cfg(self, node: Node) -> model.Region: sources=[source], signature=signature, children=children, + meta=meta, ) def export_symbol( - self, name: str, param_types: Sequence[TypeParam], body: Type + self, + name: str, + visibility: Visibility, + param_types: Sequence[TypeParam], + body: Type, ) -> model.Symbol: """Export a symbol.""" constraints = [] @@ -522,13 +588,14 @@ def export_symbol( return model.Symbol( name=name, + visibility=visibility, params=params, constraints=constraints, signature=cast(model.Term, body.to_model()), ) def find_func_input(self, node: Node) -> str | None: - """Find the name of the function that a node is connected to, if any.""" + """Find the symbol name of the function that a node is connected to, if any.""" try: func_node = next( out_port.node @@ -542,12 +609,14 @@ def find_func_input(self, node: Node) -> str | None: match self.hugr[func_node].op: case FuncDecl() as func_op: name = func_op.f_name + visibility = func_op.visibility case FuncDefn() as func_op: name = func_op.f_name + visibility = func_op.visibility case _: return None - return _mangle_name(func_node, name) + return _mangle_name(func_node, name, visibility) def find_const_input(self, node: Node) -> model.Term | None: """Find and export the constant that a node is connected to, if any.""" @@ -568,10 +637,17 @@ def find_const_input(self, node: Node) -> model.Term | None: return None -def _mangle_name(node: Node, name: str) -> str: - # Until we come to an agreement on the uniqueness of names, we mangle the names - # by adding the node id. 
- return f"_{name}_{node.idx}" +def _mangle_name(node: Node, name: str, visibility: Visibility) -> str: + match visibility: + case "Private": + # Until we come to an agreement on the uniqueness of names, + # we mangle the names by replacing id with the node id. + return f"_{node.idx}" + case "Public": + return name + case _: + error = f"Unexpected visibility {visibility}" + raise ValueError(error) T = TypeVar("T") @@ -610,19 +686,12 @@ def union(self, a: T, b: T): self.sizes[a] += self.sizes[b] -def _needs_order_key(hugr: Hugr, node: Node) -> bool: - """Checks whether the node has any order links for the purposes of - exporting order hint metadata. Order links to `Input` or `Output` - operations are ignored, since they are not present in the model format. - """ - for succ in hugr.outgoing_order_links(node): - succ_op = hugr[succ].op - if not isinstance(succ_op, Output): - return True - - for pred in hugr.incoming_order_links(node): - pred_op = hugr[pred].op - if not isinstance(pred_op, Input): - return True +def _has_order_links(hugr: Hugr, node: Node) -> bool: + """Checks whether the node has any order links.""" + for _succ in hugr.outgoing_order_links(node): + return True + + for _pred in hugr.incoming_order_links(node): + return True return False diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index fdcdd89082..23ef810d84 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -19,13 +19,14 @@ import hugr._serialization.ops as sops from hugr import tys, val from hugr.hugr.node_port import Direction, InPort, Node, OutPort, PortOffset, Wire -from hugr.utils import comma_sep_str, ser_it +from hugr.utils import comma_sep_repr, comma_sep_str, ser_it if TYPE_CHECKING: from collections.abc import Sequence from hugr import ext from hugr._serialization.ops import BaseOp + from hugr.tys import Visibility @dataclass @@ -131,11 +132,12 @@ def port_type(self, port: InPort | OutPort) -> tys.Type: Bool """ - sig = self.outer_signature() if port.offset == -1: # Order port msg = "Order port has no type." 
raise ValueError(msg) + + sig = self.outer_signature() try: if port.direction == Direction.INCOMING: return sig.input[port.offset] @@ -240,6 +242,12 @@ def _to_serial(self, parent: Node) -> sops.Input: def _inputs(self) -> tys.TypeRow: return [] + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + # Input only allows order edges on outgoing ports + if port.offset == -1 and port.direction == Direction.OUTGOING: + return tys.OrderKind() + return tys.ValueKind(self.port_type(port)) + def outer_signature(self) -> tys.FunctionType: return tys.FunctionType(input=[], output=self.types) @@ -257,7 +265,10 @@ class Output(DataflowOp, _PartialOp): """ _types: tys.TypeRow | None = field(default=None, repr=False) - num_out: int = field(default=0, repr=False) + + @property + def num_out(self) -> int: + return 0 @property def types(self) -> tys.TypeRow: @@ -269,6 +280,12 @@ def _to_serial(self, parent: Node) -> sops.Output: def _inputs(self) -> tys.TypeRow: return self.types + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + # Output only allows order edges on incoming ports + if port.offset == -1 and port.direction == Direction.INCOMING: + return tys.OrderKind() + return tys.ValueKind(self.port_type(port)) + def outer_signature(self) -> tys.FunctionType: return tys.FunctionType(input=self.types, output=[]) @@ -534,7 +551,7 @@ def cached_signature(self) -> tys.FunctionType | None: ) def type_args(self) -> list[tys.TypeArg]: - return [tys.SequenceArg([t.type_arg() for t in self.types])] + return [tys.ListArg([t.type_arg() for t in self.types])] def __call__(self, *elements: ComWire) -> Command: return super().__call__(*elements) @@ -576,7 +593,7 @@ def cached_signature(self) -> tys.FunctionType | None: ) def type_args(self) -> list[tys.TypeArg]: - return [tys.SequenceArg([t.type_arg() for t in self.types])] + return [tys.ListArg([t.type_arg() for t in self.types])] @property def num_out(self) -> int: @@ -604,7 +621,7 @@ def name(self) -> str: return "UnpackTuple" -@dataclass() +@dataclass(frozen=True) class Tag(DataflowOp): """Tag a row of incoming values to make them a variant of a sum type. @@ -632,63 +649,67 @@ def outer_signature(self) -> tys.FunctionType: ) def __repr__(self) -> str: + if len(self.sum_ty.variant_rows) == 2: + left, right = self.sum_ty.variant_rows + if len(left) == 0 and self.tag == 1: + return f"Some({comma_sep_repr(right)})" + elif self.tag == 0: + return f"Left({left!r}, {right!r})" + else: + return f"Right({left!r}, {right!r})" + return f"Tag(tag={self.tag}, sum_ty={self.sum_ty!r})" + + def __str__(self) -> str: + if len(self.sum_ty.variant_rows) == 2: + left, right = self.sum_ty.variant_rows + if len(left) == 0 and self.tag == 1: + return "Some" + elif self.tag == 0: + return "Left" + else: + return "Right" return f"Tag({self.tag})" -@dataclass +@dataclass(frozen=True, eq=False, repr=False) class Some(Tag): """Tag operation for the `Some` variant of an Option type. 
Example: # construct a Some variant holding a row of Bool and Unit types >>> Some(tys.Bool, tys.Unit) - Some + Some(Bool, Unit) """ def __init__(self, *some_tys: tys.Type) -> None: super().__init__(1, tys.Option(*some_tys)) - def __repr__(self) -> str: - return "Some" - -@dataclass +@dataclass(frozen=True, eq=False, repr=False) class Right(Tag): """Tag operation for the `Right` variant of an Either type.""" def __init__(self, either_type: tys.Either) -> None: super().__init__(1, either_type) - def __repr__(self) -> str: - return "Right" - -@dataclass +@dataclass(frozen=True, eq=False, repr=False) class Left(Tag): """Tag operation for the `Left` variant of an Either type.""" def __init__(self, either_type: tys.Either) -> None: super().__init__(0, either_type) - def __repr__(self) -> str: - return "Left" - class Continue(Left): """Tag operation for the `Continue` variant of a TailLoop controlling Either type. """ - def __repr__(self) -> str: - return "Continue" - class Break(Right): """Tag operation for the `Break` variant of a TailLoop controlling Either type.""" - def __repr__(self) -> str: - return "Break" - class DfParentOp(Op, Protocol): """Abstract parent of dataflow graph operations. Can be queried for the
""" - num_out: int = field(default=1, repr=False) + @property + def num_out(self) -> int: + return 1 def _to_serial(self, parent: Node) -> sops.LoadFunction: return sops.LoadFunction( diff --git a/hugr-py/src/hugr/std/_json_defs/collections/borrow_arr.json b/hugr-py/src/hugr/std/_json_defs/collections/borrow_arr.json new file mode 100644 index 0000000000..1774b4aea6 --- /dev/null +++ b/hugr-py/src/hugr/std/_json_defs/collections/borrow_arr.json @@ -0,0 +1,1139 @@ +{ + "version": "0.1.1", + "name": "collections.borrow_arr", + "types": { + "borrow_array": { + "extension": "collections.borrow_arr", + "name": "borrow_array", + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "description": "Fixed-length borrow array", + "bound": { + "b": "Explicit", + "bound": "A" + } + } + }, + "operations": { + "borrow": { + "extension": "collections.borrow_arr", + "name": "borrow", + "description": "Take an element from a borrow array (panicking if it was already taken before)", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "I" + } + ], + "output": [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "clone": { + "extension": "collections.borrow_arr", + "name": "clone", + "description": "Clones an array with copyable elements", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "C" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "C" + } + } + ], + "bound": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "C" + } + } + ], + "bound": "A" + }, + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "C" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "discard": { + "extension": "collections.borrow_arr", + "name": "discard", + "description": "Discards an array with copyable elements", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "C" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": 
"BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "C" + } + } + ], + "bound": "A" + } + ], + "output": [] + } + }, + "binary": false + }, + "discard_all_borrowed": { + "extension": "collections.borrow_arr", + "name": "discard_all_borrowed", + "description": "Discard a borrow array where all elements have been borrowed", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "output": [] + } + }, + "binary": false + }, + "discard_empty": { + "extension": "collections.borrow_arr", + "name": "discard_empty", + "description": "Discard an empty array", + "signature": { + "params": [ + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "BoundedNat", + "n": 0 + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 0, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "output": [] + } + }, + "binary": false + }, + "from_array": { + "extension": "collections.borrow_arr", + "name": "from_array", + "description": "Turns `array` into `borrow_array`", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.array", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "get": { + "extension": "collections.borrow_arr", + "name": "get", + "description": "Get an element from an array", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "C" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "C" + } + } + ], + "bound": "A" + }, + { + "t": "I" + } + ], + "output": [ + { + "t": "Sum", + "s": "General", + "rows": [ + [], + [ + { + "t": "V", + "i": 1, + "b": "C" + } + ] + ] + }, + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "C" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "new_all_borrowed": { + "extension": "collections.borrow_arr", + "name": "new_all_borrowed", + "description": "Create a new borrow array that contains no elements", + "signature": { + "params": [ + { + "tp": 
"BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [], + "output": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "new_array": { + "extension": "collections.borrow_arr", + "name": "new_array", + "description": "Create a new array from elements", + "signature": null, + "binary": true + }, + "pop_left": { + "extension": "collections.borrow_arr", + "name": "pop_left", + "description": "Pop an element from the left of an array", + "signature": null, + "binary": true + }, + "pop_right": { + "extension": "collections.borrow_arr", + "name": "pop_right", + "description": "Pop an element from the right of an array", + "signature": null, + "binary": true + }, + "repeat": { + "extension": "collections.borrow_arr", + "name": "repeat", + "description": "Creates a new array whose elements are initialised by calling the given function n times", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "G", + "input": [], + "output": [ + { + "t": "V", + "i": 1, + "b": "A" + } + ] + } + ], + "output": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "return": { + "extension": "collections.borrow_arr", + "name": "return", + "description": "Put an element into a borrow array (panicking if there is an element already)", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "I" + }, + { + "t": "V", + "i": 1, + "b": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "scan": { + "extension": "collections.borrow_arr", + "name": "scan", + "description": "A combination of map and foldl. Applies a function to each element of the array with an accumulator that is passed through from start to finish. 
Returns the resulting array and the final state of the accumulator.", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "List", + "param": { + "tp": "Type", + "b": "A" + } + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "G", + "input": [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "V", + "i": 2, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ] + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 2, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ] + } + }, + "binary": false + }, + "set": { + "extension": "collections.borrow_arr", + "name": "set", + "description": "Set an element in an array", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "I" + }, + { + "t": "V", + "i": 1, + "b": "A" + } + ], + "output": [ + { + "t": "Sum", + "s": "General", + "rows": [ + [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + ] + } + ] + } + }, + "binary": false + }, + "swap": { + "extension": "collections.borrow_arr", + "name": "swap", + "description": "Swap two elements in an array", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "I" + }, + { + "t": "I" + } + ], + "output": [ + { + "t": "Sum", + "s": "General", + "rows": [ + [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + 
"bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + ] + } + ] + } + }, + "binary": false + }, + "to_array": { + "extension": "collections.borrow_arr", + "name": "to_array", + "description": "Turns `borrow_array` into `array`", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "collections.array", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "unpack": { + "extension": "collections.borrow_arr", + "name": "unpack", + "description": "Unpack an array into its elements", + "signature": null, + "binary": true + } + } +} diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index 7cf1d02c70..81c2f948a0 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -1,5 +1,5 @@ { - "version": "0.2.0", + "version": "0.2.1", "name": "prelude", "types": { "error": { @@ -77,6 +77,38 @@ }, "binary": false }, + "MakeError": { + "extension": "prelude", + "name": "MakeError", + "description": "Create an error value", + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "I" + }, + { + "t": "Opaque", + "extension": "prelude", + "id": "string", + "args": [], + "bound": "C" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "error", + "args": [], + "bound": "C" + } + ] + } + }, + "binary": false + }, "MakeTuple": { "extension": "prelude", "name": "MakeTuple", diff --git a/hugr-py/src/hugr/std/collections/array.py b/hugr-py/src/hugr/std/collections/array.py index 958b826502..d7f70a2318 100644 --- a/hugr-py/src/hugr/std/collections/array.py +++ b/hugr-py/src/hugr/std/collections/array.py @@ -54,7 +54,7 @@ def size(self) -> int | None: return None def type_bound(self) -> tys.TypeBound: - return tys.TypeBound.Any + return tys.TypeBound.Linear @dataclass diff --git a/hugr-py/src/hugr/std/collections/borrow_array.py b/hugr-py/src/hugr/std/collections/borrow_array.py new file mode 100644 index 0000000000..9d01f86e6a --- /dev/null +++ b/hugr-py/src/hugr/std/collections/borrow_array.py @@ -0,0 +1,94 @@ +"""Borrow array types and operations.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import cast + +import hugr.model as model +from hugr import tys, val +from hugr.std import _load_extension +from hugr.utils import comma_sep_str + +EXTENSION = _load_extension("collections.borrow_arr") + + +@dataclass(eq=False) +class BorrowArray(tys.ExtType): + """Fixed `size` borrow array of `ty` elements.""" + + def 
__init__(self, ty: tys.Type, size: int | tys.TypeArg) -> None: + if isinstance(size, int): + size = tys.BoundedNatArg(size) + + err_msg = ( + f"Borrow array size must be a bounded natural or a nat variable, not {size}" + ) + match size: + case tys.BoundedNatArg(_n): + pass + case tys.VariableArg(_idx, param): + if not isinstance(param, tys.BoundedNatParam): + raise ValueError(err_msg) # noqa: TRY004 + case _: + raise ValueError(err_msg) + + ty_arg = tys.TypeTypeArg(ty) + + self.type_def = EXTENSION.types["borrow_array"] + self.args = [size, ty_arg] + + @property + def ty(self) -> tys.Type: + assert isinstance( + self.args[1], tys.TypeTypeArg + ), "Borrow array elements must have a valid type" + return self.args[1].ty + + @property + def size(self) -> int | None: + """If the borrow array has a concrete size, return it. + + Otherwise, return None. + """ + if isinstance(self.args[0], tys.BoundedNatArg): + return self.args[0].n + return None + + def type_bound(self) -> tys.TypeBound: + return tys.TypeBound.Linear + + +# Note that only borrow array values with no elements borrowed should be emitted. +@dataclass +class BorrowArrayVal(val.ExtensionValue): + """Constant value for a statically sized borrow array of elements.""" + + v: list[val.Value] + ty: BorrowArray + + def __init__(self, v: list[val.Value], elem_ty: tys.Type) -> None: + self.v = v + self.ty = BorrowArray(elem_ty, len(v)) + + def to_value(self) -> val.Extension: + name = "BorrowArrayValue" + # The value list must be serialized at this point, otherwise the + # `Extension` value would not be serializable. + vs = [v._to_serial_root() for v in self.v] + element_ty = self.ty.ty._to_serial_root() + serial_val = {"values": vs, "typ": element_ty} + return val.Extension(name, typ=self.ty, val=serial_val) + + def __str__(self) -> str: + return f"borrow_array({comma_sep_str(self.v)})" + + def to_model(self) -> model.Term: + return model.Apply( + "collections.borrow_array.const", + [ + model.Literal(len(self.v)), + cast(model.Term, self.ty.ty.to_model()), + model.List([value.to_model() for value in self.v]), + ], + ) diff --git a/hugr-py/src/hugr/std/collections/static_array.py b/hugr-py/src/hugr/std/collections/static_array.py index 60975b336d..84731ee1bf 100644 --- a/hugr-py/src/hugr/std/collections/static_array.py +++ b/hugr-py/src/hugr/std/collections/static_array.py @@ -50,6 +50,9 @@ def __init__(self, v: list[val.Value], elem_ty: tys.Type, name: str) -> None: self.name = name def to_value(self) -> val.Extension: + # Encode the nested values as JSON strings directly, to mirror what + # happens when loading (where we can't decode the constant payload back + # into specialized `Value`s). serial_val = { "value": { "values": [v._to_serial_root() for v in self.v], diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index f58bc9e3eb..4107fcd284 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -52,16 +52,39 @@ def _int_tv(index: int) -> tys.ExtType: INT_T = int_t(5) +def _to_unsigned(val: int, bits: int) -> int: + """Convert a signed integer to its unsigned representation + in twos-complement form. + + Positive integers are unchanged, while negative integers + are converted by adding 2^bits to the value. + + Raises ValueError if the value is out of range for the given bit width + (valid range is [-2^(bits-1), 2^(bits-1)-1]). 
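    For illustration (example values, not from the original patch), with ``bits=8``:

    >>> _to_unsigned(5, 8)
    5
    >>> _to_unsigned(-1, 8)
    255
    >>> _to_unsigned(-128, 8)
    128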
+ """ + half_max = 1 << (bits - 1) + min_val = -half_max + max_val = half_max - 1 + if val < min_val or val > max_val: + msg = f"Value {val} out of range for {bits}-bit signed integer." + raise ValueError(msg) # + + if val < 0: + return (1 << bits) + val + return val + + @dataclass class IntVal(val.ExtensionValue): - """Custom value for an integer.""" + """Custom value for a signed integer.""" v: int width: int = field(default=5) def to_value(self) -> val.Extension: name = "ConstInt" - payload = {"log_width": self.width, "value": self.v} + unsigned = _to_unsigned(self.v, 1 << self.width) + payload = {"log_width": self.width, "value": unsigned} return val.Extension( name, typ=int_t(self.width), @@ -72,8 +95,9 @@ def __str__(self) -> str: return f"{self.v}" def to_model(self) -> model.Term: + unsigned = _to_unsigned(self.v, 1 << self.width) return model.Apply( - "arithmetic.int.const", [model.Literal(self.width), model.Literal(self.v)] + "arithmetic.int.const", [model.Literal(self.width), model.Literal(unsigned)] ) diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index 8411f19bfa..a59a9a90bf 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -2,12 +2,13 @@ from __future__ import annotations +import base64 from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Protocol, cast, runtime_checkable +from typing import TYPE_CHECKING, Literal, Protocol, cast, runtime_checkable import hugr._serialization.tys as stys import hugr.model as model -from hugr.utils import comma_sep_repr, comma_sep_str, ser_it +from hugr.utils import comma_sep_repr, comma_sep_str, comma_sep_str_paren, ser_it if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -18,6 +19,7 @@ ExtensionId = stys.ExtensionId ExtensionSet = stys.ExtensionSet TypeBound = stys.TypeBound +Visibility = Literal["Public", "Private"] @runtime_checkable @@ -67,7 +69,7 @@ def type_bound(self) -> stys.TypeBound: >>> Tuple(Bool, Bool).type_bound() >>> Tuple(Qubit, Bool).type_bound() - + """ ... 
# pragma: no cover @@ -154,6 +156,34 @@ def to_model(self) -> model.Term: return model.Apply("core.str") +@dataclass(frozen=True) +class FloatParam(TypeParam): + """Float type parameter.""" + + def _to_serial(self) -> stys.FloatParam: + return stys.FloatParam() + + def __str__(self) -> str: + return "Float" + + def to_model(self) -> model.Term: + return model.Apply("core.float") + + +@dataclass(frozen=True) +class BytesParam(TypeParam): + """Bytes type parameter.""" + + def _to_serial(self) -> stys.BytesParam: + return stys.BytesParam() + + def __str__(self) -> str: + return "Bytes" + + def to_model(self) -> model.Term: + return model.Apply("core.bytes") + + @dataclass(frozen=True) class ListParam(TypeParam): """Type parameter which requires a list of type arguments.""" @@ -245,26 +275,120 @@ def to_model(self) -> model.Term: @dataclass(frozen=True) -class SequenceArg(TypeArg): - """Sequence of type arguments, for a :class:`ListParam` or :class:`TupleParam`.""" +class FloatArg(TypeArg): + """A floating point type argument.""" + + value: float + + def _to_serial(self) -> stys.FloatArg: + return stys.FloatArg(value=self.value) + + def __str__(self) -> str: + return f"{self.value}" + + def to_model(self) -> model.Term: + return model.Literal(self.value) + + +@dataclass(frozen=True) +class BytesArg(TypeArg): + """A bytes type argument.""" + + value: bytes + + def _to_serial(self) -> stys.BytesArg: + value = base64.b64encode(self.value).decode() + return stys.BytesArg(value=value) + + def __str__(self) -> str: + return "bytes" + + def to_model(self) -> model.Term: + return model.Literal(self.value) + + +@dataclass(frozen=True) +class ListArg(TypeArg): + """Sequence of type arguments for a :class:`ListParam`.""" elems: list[TypeArg] - def _to_serial(self) -> stys.SequenceArg: - return stys.SequenceArg(elems=ser_it(self.elems)) + def _to_serial(self) -> stys.ListArg: + return stys.ListArg(elems=ser_it(self.elems)) def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: - return SequenceArg([arg.resolve(registry) for arg in self.elems]) + return ListArg([arg.resolve(registry) for arg in self.elems]) def __str__(self) -> str: - return f"({comma_sep_str(self.elems)})" + return f"[{comma_sep_str(self.elems)}]" def to_model(self) -> model.Term: - # TODO: We should separate lists and tuples. - # For now we assume that this is a list. return model.List([elem.to_model() for elem in self.elems]) +@dataclass(frozen=True) +class ListConcatArg(TypeArg): + """Sequence of lists to concatenate for a :class:`ListParam`.""" + + lists: list[TypeArg] + + def _to_serial(self) -> stys.ListConcatArg: + return stys.ListConcatArg(lists=ser_it(self.lists)) + + def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: + return ListConcatArg([arg.resolve(registry) for arg in self.lists]) + + def __str__(self) -> str: + lists = comma_sep_str(f"... 
{list}" for list in self.lists) + return f"[{lists}]" + + def to_model(self) -> model.Term: + return model.List( + [model.Splice(cast(model.Term, elem.to_model())) for elem in self.lists] + ) + + +@dataclass(frozen=True) +class TupleArg(TypeArg): + """Sequence of type arguments for a :class:`TupleParam`.""" + + elems: list[TypeArg] + + def _to_serial(self) -> stys.TupleArg: + return stys.TupleArg(elems=ser_it(self.elems)) + + def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: + return TupleArg([arg.resolve(registry) for arg in self.elems]) + + def __str__(self) -> str: + return f"({comma_sep_str(self.elems)})" + + def to_model(self) -> model.Term: + return model.Tuple([elem.to_model() for elem in self.elems]) + + +@dataclass(frozen=True) +class TupleConcatArg(TypeArg): + """Sequence of tuples to concatenate for a :class:`TupleParam`.""" + + tuples: list[TypeArg] + + def _to_serial(self) -> stys.TupleConcatArg: + return stys.TupleConcatArg(tuples=ser_it(self.tuples)) + + def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: + return TupleConcatArg([arg.resolve(registry) for arg in self.tuples]) + + def __str__(self) -> str: + tuples = comma_sep_str(f"... {tuple}" for tuple in self.tuples) + return f"({tuples})" + + def to_model(self) -> model.Term: + return model.Tuple( + [model.Splice(cast(model.Term, elem.to_model())) for elem in self.tuples] + ) + + @dataclass(frozen=True) class VariableArg(TypeArg): """A type argument variable.""" @@ -306,7 +430,38 @@ def as_tuple(self) -> Tuple: return Tuple(*self.variant_rows[0]) def __repr__(self) -> str: - return f"Sum({self.variant_rows})" + if self == Bool: + return "Bool" + elif self == Unit: + return "Unit" + elif all(len(row) == 0 for row in self.variant_rows): + return f"UnitSum({len(self.variant_rows)})" + elif len(self.variant_rows) == 1: + return f"Tuple{tuple(self.variant_rows[0])}" + elif len(self.variant_rows) == 2 and len(self.variant_rows[0]) == 0: + return f"Option({comma_sep_repr(self.variant_rows[1])})" + elif len(self.variant_rows) == 2: + left, right = self.variant_rows + return f"Either(left={left}, right={right})" + else: + return f"Sum({self.variant_rows})" + + def __str__(self) -> str: + if self == Bool: + return "Bool" + elif self == Unit: + return "Unit" + elif all(len(row) == 0 for row in self.variant_rows): + return f"UnitSum({len(self.variant_rows)})" + elif len(self.variant_rows) == 1: + return f"Tuple{tuple(self.variant_rows[0])}" + elif len(self.variant_rows) == 2 and len(self.variant_rows[0]) == 0: + return f"Option({comma_sep_str(self.variant_rows[1])})" + elif len(self.variant_rows) == 2: + left, right = self.variant_rows + return f"Either({comma_sep_str_paren(left)}, {comma_sep_str_paren(right)})" + else: + return f"Sum({self.variant_rows})" def __eq__(self, other: object) -> bool: return isinstance(other, Sum) and self.variant_rows == other.variant_rows @@ -325,7 +480,7 @@ def to_model(self) -> model.Term: return model.Apply("core.adt", [variants]) -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class UnitSum(Sum): """Simple :class:`Sum` type with `size` variants of empty rows.""" @@ -338,18 +493,14 @@ def __init__(self, size: int): def _to_serial(self) -> stys.UnitSum: # type: ignore[override] return stys.UnitSum(size=self.size) - def __repr__(self) -> str: - if self == Bool: - return "Bool" - elif self == Unit: - return "Unit" - return f"UnitSum({self.size})" - def resolve(self, registry: ext.ExtensionRegistry) -> UnitSum: return self + def __str__(self) -> str: + return 
self.__repr__() -@dataclass(eq=False) + +@dataclass(eq=False, repr=False) class Tuple(Sum): """Product type with `tys` elements. Instances of this type correspond to :class:`Sum` with a single variant. @@ -358,11 +509,8 @@ class Tuple(Sum): def __init__(self, *tys: Type): self.variant_rows = [list(tys)] - def __repr__(self) -> str: - return f"Tuple{tuple(self.variant_rows[0])}" - -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class Option(Sum): """Optional tuple of elements. @@ -373,11 +521,8 @@ class Option(Sum): def __init__(self, *tys: Type): self.variant_rows = [[], list(tys)] - def __repr__(self) -> str: - return f"Option({comma_sep_repr(self.variant_rows[1])})" - -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class Either(Sum): """Two-variant tuple of elements. @@ -390,16 +535,6 @@ class Either(Sum): def __init__(self, left: Iterable[Type], right: Iterable[Type]): self.variant_rows = [list(left), list(right)] - def __repr__(self) -> str: # pragma: no cover - left, right = self.variant_rows - return f"Either(left={left}, right={right})" - - def __str__(self) -> str: - left, right = self.variant_rows - left_str = left[0] if len(left) == 1 else tuple(left) - right_str = right[0] if len(right) == 1 else tuple(right) - return f"Either({left_str}, {right_str})" - @dataclass(frozen=True) class Variable(Type): @@ -632,15 +767,7 @@ def __eq__(self, value): return super().__eq__(value) def to_model(self) -> model.Term: - # This cast is only neccessary because `Type` can both be an - # actual type or a row variable. - args = [cast(model.Term, arg.to_model()) for arg in self.args] - - extension_name = self.type_def.get_extension().name - type_name = self.type_def.name - name = f"{extension_name}.{type_name}" - - return model.Apply(name, args) + return self._to_opaque().to_model() def _type_str(name: str, args: Sequence[TypeArg]) -> str: @@ -687,17 +814,17 @@ def __str__(self) -> str: return _type_str(self.id, self.args) def to_model(self) -> model.Term: - # This cast is only neccessary because `Type` can both be an + # This cast is only necessary because `Type` can both be an # actual type or a row variable. args = [cast(model.Term, arg.to_model()) for arg in self.args] - return model.Apply(self.id, args) + return model.Apply(f"{self.extension}.{self.id}", args) @dataclass class _QubitDef(Type): def type_bound(self) -> TypeBound: - return TypeBound.Any + return TypeBound.Linear def _to_serial(self) -> stys.Qubit: return stys.Qubit() diff --git a/hugr-py/src/hugr/utils.py b/hugr-py/src/hugr/utils.py index 480f3337b9..0c6048ec32 100644 --- a/hugr-py/src/hugr/utils.py +++ b/hugr-py/src/hugr/utils.py @@ -215,3 +215,27 @@ def comma_sep_str(items: Iterable[T]) -> str: def comma_sep_repr(items: Iterable[T]) -> str: """Join items with commas and repr.""" return ", ".join(map(repr, items)) + + +def comma_sep_str_paren(items: Iterable[T]) -> str: + """Join items with commas and str, wrapping them in parentheses if more than one.""" + items = list(items) + if len(items) == 0: + return "()" + elif len(items) == 1: + return f"{items[0]}" + else: + return f"({comma_sep_str(items)})" + + +def comma_sep_repr_paren(items: Iterable[T]) -> str: + """Join items with commas and repr, wrapping them in parentheses if more + than one. 
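    For example (illustrative values): an empty iterable yields ``'()'``, a single
    item is rendered bare, and two or more items are wrapped:

    >>> comma_sep_repr_paren([])
    '()'
    >>> comma_sep_repr_paren(["a"])
    'a'
    >>> comma_sep_repr_paren(["a", "b"])
    "('a', 'b')"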
+ """ + items = list(items) + if len(items) == 0: + return "()" + elif len(items) == 1: + return f"{items[0]}" + else: + return f"({comma_sep_repr(items)})" diff --git a/hugr-py/src/hugr/val.py b/hugr-py/src/hugr/val.py index 925c91f989..a929969edd 100644 --- a/hugr-py/src/hugr/val.py +++ b/hugr-py/src/hugr/val.py @@ -45,8 +45,8 @@ class Sum(Value): """Sum-of-product value. Example: - >>> Sum(0, tys.Sum([[tys.Bool], [tys.Unit]]), [TRUE]) - Sum(tag=0, typ=Sum([[Bool], [Unit]]), vals=[TRUE]) + >>> Sum(0, tys.Sum([[tys.Bool], [tys.Unit], [tys.Bool]]), [TRUE]) + Sum(tag=0, typ=Sum([[Bool], [Unit], [Bool]]), vals=[TRUE]) """ #: Tag identifying the variant. @@ -70,6 +70,59 @@ def _to_serial(self) -> sops.SumValue: vs=ser_it(self.vals), ) + def __repr__(self) -> str: + if self == TRUE: + return "TRUE" + elif self == FALSE: + return "FALSE" + elif self == Unit: + return "Unit" + elif all(len(row) == 0 for row in self.typ.variant_rows): + return f"UnitSum({self.tag}, {self.n_variants})" + elif len(self.typ.variant_rows) == 1: + return f"Tuple({comma_sep_repr(self.vals)})" + elif len(self.typ.variant_rows) == 2 and len(self.typ.variant_rows[0]) == 0: + # Option + if self.tag == 0: + return f"None({comma_sep_str(self.typ.variant_rows[1])})" + else: + return f"Some({comma_sep_repr(self.vals)})" + elif len(self.typ.variant_rows) == 2: + # Either + left_typ, right_typ = self.typ.variant_rows + if self.tag == 0: + return f"Left(vals={self.vals}, right_typ={list(right_typ)})" + else: + return f"Right(left_typ={list(left_typ)}, vals={self.vals})" + else: + return f"Sum(tag={self.tag}, typ={self.typ}, vals={self.vals})" + + def __str__(self) -> str: + if self == TRUE: + return "TRUE" + elif self == FALSE: + return "FALSE" + elif self == Unit: + return "Unit" + elif all(len(row) == 0 for row in self.typ.variant_rows): + return f"UnitSum({self.tag}, {self.n_variants})" + elif len(self.typ.variant_rows) == 1: + return f"Tuple({comma_sep_str(self.vals)})" + elif len(self.typ.variant_rows) == 2 and len(self.typ.variant_rows[0]) == 0: + # Option + if self.tag == 0: + return "None" + else: + return f"Some({comma_sep_str(self.vals)})" + elif len(self.typ.variant_rows) == 2: + # Either + if self.tag == 0: + return f"Left({comma_sep_str(self.vals)})" + else: + return f"Right({comma_sep_str(self.vals)})" + else: + return f"Sum({self.tag}, {self.typ}, {self.vals})" + def __eq__(self, other: object) -> bool: return ( isinstance(other, Sum) @@ -100,6 +153,7 @@ def to_model(self) -> model.Term: ) +@dataclass(eq=False, repr=False) class UnitSum(Sum): """Simple :class:`Sum` with each variant being an empty row. @@ -119,15 +173,6 @@ def __init__(self, tag: int, size: int): vals=[], ) - def __repr__(self) -> str: - if self == TRUE: - return "TRUE" - if self == FALSE: - return "FALSE" - if self == Unit: - return "Unit" - return f"UnitSum({self.tag}, {self.n_variants})" - def bool_value(b: bool) -> UnitSum: """Convert a python bool to a HUGR boolean value. @@ -149,7 +194,7 @@ def bool_value(b: bool) -> UnitSum: FALSE = bool_value(False) -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class Tuple(Sum): """Tuple or product value, defined by a list of values. Internally a :class:`Sum` with a single variant row. 
@@ -169,18 +214,18 @@ def __init__(self, *vals: Value): tag=0, typ=tys.Tuple(*(v.type_() for v in val_list)), vals=val_list ) - # sops.TupleValue isn't an instance of sops.SumValue - # so mypy doesn't like the override of Sum._to_serial - def _to_serial(self) -> sops.TupleValue: # type: ignore[override] - return sops.TupleValue( + def _to_serial(self) -> sops.SumValue: + return sops.SumValue( + tag=0, + typ=stys.SumType(root=self.type_()._to_serial()), vs=ser_it(self.vals), ) def __repr__(self) -> str: - return f"Tuple({comma_sep_repr(self.vals)})" + return super().__repr__() -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class Some(Sum): """Optional tuple of value, containing a list of values. @@ -199,11 +244,8 @@ def __init__(self, *vals: Value): tag=1, typ=tys.Option(*(v.type_() for v in val_list)), vals=val_list ) - def __repr__(self) -> str: - return f"Some({comma_sep_repr(self.vals)})" - -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class None_(Sum): """Optional tuple of value, containing no values. @@ -219,14 +261,8 @@ class None_(Sum): def __init__(self, *types: tys.Type): super().__init__(tag=0, typ=tys.Option(*types), vals=[]) - def __repr__(self) -> str: - return f"None({comma_sep_str(self.typ.variant_rows[1])})" - - def __str__(self) -> str: - return "None" - -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class Left(Sum): """Left variant of a :class:`tys.Either` type, containing a list of values. @@ -248,15 +284,8 @@ def __init__(self, vals: Iterable[Value], right_typ: Iterable[tys.Type]): vals=val_list, ) - def __repr__(self) -> str: - _, right_typ = self.typ.variant_rows - return f"Left(vals={self.vals}, right_typ={list(right_typ)})" - - def __str__(self) -> str: - return f"Left({comma_sep_str(self.vals)})" - -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class Right(Sum): """Right variant of a :class:`tys.Either` type, containing a list of values. @@ -280,13 +309,6 @@ def __init__(self, left_typ: Iterable[tys.Type], vals: Iterable[Value]): vals=val_list, ) - def __repr__(self) -> str: - left_typ, _ = self.typ.variant_rows - return f"Right(left_typ={list(left_typ)}, vals={self.vals})" - - def __str__(self) -> str: - return f"Right({comma_sep_str(self.vals)})" - @dataclass class Function(Value): @@ -298,9 +320,7 @@ def type_(self) -> tys.FunctionType: return self.body.entrypoint_op().inner_signature() def _to_serial(self) -> sops.FunctionValue: - return sops.FunctionValue( - hugr=self.body._to_serial(), - ) + return sops.FunctionValue(hugr=self.body.to_str()) def to_model(self) -> model.Term: module = self.body.to_model() diff --git a/hugr-py/tests/__snapshots__/test_hugr_build.ambr b/hugr-py/tests/__snapshots__/test_hugr_build.ambr index 733ba214c5..240953c02b 100644 --- a/hugr-py/tests/__snapshots__/test_hugr_build.ambr +++ b/hugr-py/tests/__snapshots__/test_hugr_build.ambr @@ -191,6 +191,16 @@ + + + + + + +
[Graphviz snapshot diff content: the HTML-like <TABLE> markup inside the DOT node
labels was stripped during extraction, so only the table cell text survives here.
The recoverable changes in this span are:

- hugr-py/tests/__snapshots__/test_hugr_build.ambr: several existing snapshots gain an
  extra output-port row ("0") on their FuncDefn node tables; new snapshots
  `test_fndef_output_ports` (Input, an Output with four in-ports, MakeTuple and
  FuncDefn(main) under [Module], with four Unit-labelled edges fanning out from the
  single MakeTuple output) and `test_html_labels` (a FuncDefn(<jupyter-notebook>) whose
  metadata rows render HTML-like labels such as <b>Bold Label</b> and
  <i>Italic Label</i>, under a Module named "Module Root") are added; the
  `test_higher_order` snapshot's Const(Function(...)) label changes its FuncDefn repr
  (the `visibility='Private'` field differs between the old and new lines); and a block
  of `int<5>`-labelled edge lines appears in both removed and added form with no
  visible content change.

- hugr-py/tests/__snapshots__/test_order_edges.ambr: new file containing the
  `test_order_unconnected` snapshot, in which MeasureFree and QAlloc nodes sit inside a
  nested [DFG] region under FuncDefn(main) and Module, joined by an order edge
  (out.-1 -> in.-1).]
+ > shape=plain] + color="#1CADE4" label="" margin=10 penwidth=1 + } + 2:"out.0" -> 4:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 5:"out.0" -> 7:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 7:"out.-1" -> 8:"in.-1" [label="" arrowhead=none arrowsize=1.0 color=black fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 8:"out.0" -> 6:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 4:"out.0" -> 3:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + } + + ''' +# --- diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index f4b2617a01..909d0a8bfd 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -12,13 +12,16 @@ from hugr import ext, tys from hugr.envelope import EnvelopeConfig from hugr.hugr import Hugr -from hugr.ops import AsExtOp, Command, DataflowOp, ExtOp, RegisteredOp +from hugr.ops import AsExtOp, Command, Const, Custom, DataflowOp, ExtOp, RegisteredOp from hugr.package import Package from hugr.std.float import FLOAT_T if TYPE_CHECKING: + import typing + from syrupy.assertion import SnapshotAssertion + from hugr.hugr.node_port import Node from hugr.ops import ComWire QUANTUM_EXT = ext.Extension("pytest.quantum", ext.Version(0, 1, 0)) @@ -106,6 +109,32 @@ def __call__(self, q: ComWire) -> Command: Measure = MeasureDef() +@QUANTUM_EXT.register_op( + "MeasureFree", + signature=tys.FunctionType([tys.Qubit], [tys.Bool]), +) +@dataclass(frozen=True) +class MeasureFreeDef(RegisteredOp): + def __call__(self, q: ComWire) -> Command: + return super().__call__(q) + + +MeasureFree = MeasureFreeDef() + + +@QUANTUM_EXT.register_op( + "QAlloc", + signature=tys.FunctionType([], [tys.Qubit]), +) +@dataclass(frozen=True) +class QAllocDef(RegisteredOp): + def __call__(self) -> Command: + return super().__call__() + + +QAlloc = QAllocDef() + + @QUANTUM_EXT.register_op( "Rz", signature=tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit]), @@ -146,42 +175,156 @@ def validate( snapshot: A hugr render snapshot. If not None, it will be compared against the rendered HUGR. Pass `--snapshot-update` to pytest to update the snapshot file. """ - # TODO: Use envelopes instead of legacy hugr-json + if snap is not None: + dot = h.render_dot() if isinstance(h, Hugr) else h.modules[0].render_dot() + assert snap == dot.source + if os.environ.get("HUGR_RENDER_DOT"): + dot.pipe("svg") + + # Encoding formats to test, indexed by the format name as used by + # `hugr convert --format`. + FORMATS = { + "json": EnvelopeConfig.TEXT, + "model-exts": EnvelopeConfig.BINARY, + } + # Envelope formats used when exporting test hugrs. + WRITE_FORMATS = ["json", "model-exts"] + # Envelope formats used as target for `hugr convert` before loading back the + # test hugrs. + # + # Model envelopes cannot currently be loaded from python. + # TODO: Add model envelope loading to python, and add it to the list. 
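    # A condensed sketch of the roundtrip performed below (same names as in this
    # function, shown for orientation only):
    #
    #     serial = h.to_bytes(FORMATS[write_fmt])
    #     out = _run_hugr_cmd(serial, [*_base_command(), "convert", "--format", load_fmt, "-"])
    #     loaded = Package.from_bytes(out.stdout)
    #     assert _NodeHash.hash_hugr(loaded.modules[0], ...) == _NodeHash.hash_hugr(h, ...)
    #
    # i.e. write every supported envelope format, use the CLI to convert it to a format
    # Python can load, and compare order-independent node hashes rather than raw JSON.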
+ LOAD_FORMATS = ["json"] + cmd = [*_base_command(), "validate", "-"] - serial = h.to_bytes(EnvelopeConfig.BINARY) - _run_hugr_cmd(serial, cmd) - - if not roundtrip: - return - - # Roundtrip checks - if isinstance(h, Hugr): - starting_json = h.to_str() - h2 = Hugr.from_str(starting_json) - roundtrip_json = h2.to_str() - assert roundtrip_json == starting_json - - if snap is not None: - dot = h.render_dot() - assert snap == dot.source - if os.environ.get("HUGR_RENDER_DOT"): - dot.pipe("svg") - else: - # Package - encoded = h.to_str(EnvelopeConfig.TEXT) - loaded = Package.from_str(encoded) - roundtrip_encoded = loaded.to_str(EnvelopeConfig.TEXT) - assert encoded == roundtrip_encoded - - -def _run_hugr_cmd(serial: bytes, cmd: list[str]): + # validate text and binary formats + for write_fmt in WRITE_FORMATS: + serial = h.to_bytes(FORMATS[write_fmt]) + _run_hugr_cmd(serial, cmd) + + if roundtrip: + # Roundtrip tests: + # Try converting to all possible LOAD_FORMATS, load them back in, + # and check that the loaded HUGR corresponds to the original using + # a node hash comparison. + # + # Run `pytest` with `-vv` to see the hash diff. + for load_fmt in LOAD_FORMATS: + if load_fmt != write_fmt: + cmd = [*_base_command(), "convert", "--format", load_fmt, "-"] + out = _run_hugr_cmd(serial, cmd) + converted_serial = out.stdout + else: + converted_serial = serial + loaded = Package.from_bytes(converted_serial) + + modules = [h] if isinstance(h, Hugr) else h.modules + + assert len(loaded.modules) == len(modules) + for m1, m2 in zip(loaded.modules, modules, strict=True): + h1_hash = _NodeHash.hash_hugr(m1, "original") + h2_hash = _NodeHash.hash_hugr(m2, "loaded") + assert ( + h1_hash == h2_hash + ), f"HUGRs are not the same for {write_fmt} -> {load_fmt}" + + # Lowering functions are currently ignored in Python, + # because we don't support loading -model envelopes yet. + for ext in loaded.extensions: + for op in ext.operations.values(): + assert op.lower_funcs == [] + + +@dataclass(frozen=True, order=True) +class _NodeHash: + op: _OpHash + entrypoint: bool + input_neighbours: int + output_neighbours: int + input_ports: int + output_ports: int + input_order_edges: int + output_order_edges: int + is_region: bool + node_depth: int + children_hashes: list[_NodeHash] + metadata: dict[str, str] + + @classmethod + def hash_hugr(cls, h: Hugr, name: str) -> _NodeHash: + """Returns an order-independent hash of a HUGR.""" + return cls._hash_node(h, h.module_root, 0, name) + + @classmethod + def _hash_node(cls, h: Hugr, n: Node, depth: int, name: str) -> _NodeHash: + children = h.children(n) + child_hashes = sorted(cls._hash_node(h, c, depth + 1, name) for c in children) + metadata = {k: str(v) for k, v in h[n].metadata.items()} + + # Pick a normalized representation of the op name. + op_type = h[n].op + if isinstance(op_type, AsExtOp): + op_type = op_type.ext_op.to_custom_op() + op = _OpHash(f"{op_type.extension}.{op_type.op_name}") + elif isinstance(op_type, Custom): + op = _OpHash(f"{op_type.extension}.{op_type.op_name}") + elif isinstance(op_type, Const): + # We need every custom value to have the same repr if they compare + # equal. For example, an `IntVal(42)` should be the same as the + # equivalent `Extension` value. This needs a lot of extra + # unwrapping, since each class implements different `__repr__` + # methods. + # + # Our solution here is to encode the value into JSON and compare those. + # This may miss some errors, but it's the best we can do for now. 
Note that + # roundtripping via `sops.Value` is not enough, since nested + # specialized values don't get serialized straight away. (e.g. + # StaticArrayVal's dictionary payload containing a SumValue + # internally, see `test_val_static_array`). + value_dict = op_type.val._to_serial_root().model_dump(mode="json") + op = _OpHash("Const", value_dict) + else: + op = _OpHash(op_type.name()) + + return _NodeHash( + entrypoint=n == h.entrypoint, + op=op, + input_neighbours=h.num_incoming(n), + output_neighbours=h.num_outgoing(n), + input_ports=h.num_in_ports(n), + output_ports=h.num_out_ports(n), + input_order_edges=len(list(h.incoming_order_links(n))), + output_order_edges=len(list(h.outgoing_order_links(n))), + is_region=len(children) > 0, + node_depth=depth, + children_hashes=child_hashes, + metadata=metadata, + ) + + +@dataclass(frozen=True) +class _OpHash: + name: str + payload: None | typing.Any = None + + def __lt__(self, other: _OpHash) -> bool: + """Compare two op hashes by name and payload.""" + return (self.name, repr(self.payload)) < (other.name, repr(other.payload)) + + +def _get_mermaid(serial: bytes) -> str: # + """Render a HUGR as a mermaid diagram using the CLI.""" + return _run_hugr_cmd(serial, [*_base_command(), "mermaid", "-"]).stdout.decode() + + +def _run_hugr_cmd(serial: bytes, cmd: list[str]) -> subprocess.CompletedProcess[bytes]: """Run a HUGR command. The `serial` argument is the serialized HUGR to pass to the command via stdin. """ try: - subprocess.run(cmd, check=True, input=serial, capture_output=True) # noqa: S603 + return subprocess.run(cmd, check=True, input=serial, capture_output=True) # noqa: S603 except subprocess.CalledProcessError as e: error = e.stderr.decode() raise RuntimeError(error) from e diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index 14753e2f44..3eed38e106 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -8,7 +8,7 @@ def build_basic_cfg(cfg: Cfg) -> None: with cfg.add_entry() as entry: entry.set_single_succ_outputs(*entry.inputs()) - cfg.branch(entry[0], cfg.exit) + cfg.branch_exit(entry[0]) def test_basic_cfg() -> None: diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index 48f57de7a7..1bb980915f 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -139,7 +139,7 @@ def test_custom_bad_eq(): ext.TypeDef( "List", description="A list of elements.", - params=[tys.TypeTypeParam(tys.TypeBound.Any)], + params=[tys.TypeTypeParam(tys.TypeBound.Linear)], bound=ext.FromParamsBound([0]), ) ) diff --git a/hugr-py/tests/test_envelope.py b/hugr-py/tests/test_envelope.py index 1e667e8eef..c97cd074bc 100644 --- a/hugr-py/tests/test_envelope.py +++ b/hugr-py/tests/test_envelope.py @@ -1,10 +1,17 @@ -from hugr import tys +from pathlib import Path + +import pytest +import semver + +from hugr import ops, tys from hugr.build.function import Module from hugr.envelope import EnvelopeConfig, EnvelopeFormat +from hugr.hugr.node_port import Node from hugr.package import Package -def test_envelope(): +@pytest.fixture +def package() -> Package: mod = Module() f_id = mod.define_function("id", [tys.Qubit]) f_id.set_outputs(f_id.input_node[0]) @@ -17,8 +24,10 @@ def test_envelope(): q = f_main.input_node[0] call = f_main.call(f_id_decl, q) f_main.set_outputs(call) - package = Package([mod.hugr, mod2.hugr]) + return Package([mod.hugr, mod2.hugr]) + +def test_envelope(package: Package): # Binary compression roundtrip for format in [EnvelopeFormat.JSON]: for compression in [None, 0]: 
@@ -27,6 +36,30 @@ def test_envelope(): assert decoded == package # String roundtrip - encoded = package.to_str(EnvelopeConfig.TEXT) - decoded = Package.from_str(encoded) + encoded_str = package.to_str(EnvelopeConfig.TEXT) + decoded = Package.from_str(encoded_str) assert decoded == package + + +def test_model(package: Package): + model_pkg = package.to_model() + + # This value is statically defined in the rust bindings. + assert model_pkg.version >= semver.Version(major=1, minor=0, patch=0) + + +def test_legacy_funcdefn(): + p = Path(__file__).parents[2] / "resources" / "test" / "hugr-no-visibility.hugr" + try: + with p.open("rb") as f: + pkg_bytes = f.read() + except FileNotFoundError: + pytest.skip("Missing test file") + decoded = Package.from_bytes(pkg_bytes) + h = decoded.modules[0] + op1 = h[Node(1)].op + assert isinstance(op1, ops.FuncDecl) + assert op1.visibility == "Public" + op2 = h[Node(2)].op + assert isinstance(op2, ops.FuncDefn) + assert op2.visibility == "Private" diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 74c8018f9e..0a6677145b 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -5,15 +5,16 @@ import hugr.ops as ops import hugr.tys as tys import hugr.val as val -from hugr.build.dfg import Dfg, _ancestral_sibling +from hugr.build.dfg import Dfg, Function, _ancestral_sibling from hugr.build.function import Module from hugr.hugr import Hugr from hugr.hugr.node_port import Node, _SubPort from hugr.ops import NoConcreteFunc +from hugr.package import Package from hugr.std.int import INT_T, DivMod, IntVal from hugr.std.logic import Not -from .conftest import validate +from .conftest import QUANTUM_EXT, H, validate def test_stable_indices(): @@ -196,7 +197,7 @@ def test_build_inter_graph(snapshot): validate(h.hugr, snap=snapshot) assert _SubPort(h.input_node.out(-1)) in h.hugr._links - assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order + assert h.hugr.num_outgoing(h.input_node) == 3 assert len(list(h.hugr.outgoing_order_links(h.input_node))) == 1 assert len(list(h.hugr.incoming_order_links(nested))) == 1 assert len(list(h.hugr.incoming_order_links(h.output_node))) == 0 @@ -214,7 +215,6 @@ def test_ancestral_sibling(): @pytest.mark.parametrize( "val", [ - val.Function(simple_id().hugr), val.Sum(1, tys.Sum([[INT_T], [tys.Bool, INT_T]]), [val.TRUE, IntVal(34)]), val.Tuple(val.TRUE, IntVal(23)), ], @@ -232,8 +232,8 @@ def test_poly_function(direct_call: bool) -> None: f_id = mod.declare_function( "id", tys.PolyFuncType( - [tys.TypeTypeParam(tys.TypeBound.Any)], - tys.FunctionType.endo([tys.Variable(0, tys.TypeBound.Any)]), + [tys.TypeTypeParam(tys.TypeBound.Linear)], + tys.FunctionType.endo([tys.Variable(0, tys.TypeBound.Linear)]), ), ) @@ -259,6 +259,39 @@ def test_poly_function(direct_call: bool) -> None: validate(mod.hugr) +def test_literals() -> None: + mod = Module() + + func = mod.declare_function( + "literals", + tys.PolyFuncType( + [ + tys.StringParam(), + tys.BoundedNatParam(), + tys.BytesParam(), + tys.FloatParam(), + ], + tys.FunctionType.endo([tys.Qubit]), + ), + ) + + caller = mod.define_function("caller", [tys.Qubit], [tys.Qubit]) + call = caller.call( + func, + caller.inputs()[0], + instantiation=tys.FunctionType.endo([tys.Qubit]), + type_args=[ + tys.StringArg("string"), + tys.BoundedNatArg(42), + tys.BytesArg(b"HUGR"), + tys.FloatArg(0.9), + ], + ) + caller.set_outputs(call) + + validate(mod.hugr) + + @pytest.mark.parametrize("direct_call", [True, False]) def 
test_mono_function(direct_call: bool) -> None: mod = Module() @@ -278,6 +311,37 @@ def test_mono_function(direct_call: bool) -> None: validate(mod.hugr) +def test_static_output() -> None: + mod = Module() + + mod.declare_function( + "declared", + tys.PolyFuncType( + [], + tys.FunctionType.endo([]), + ), + ) + + func = mod.define_function("defined", [], []) + func.declare_outputs([]) + func.set_outputs() + + validate(mod.hugr) + + +def test_function_dfg() -> None: + d = Dfg(tys.Qubit) + + f_id = d.module_root_builder().define_function("id", [tys.Qubit]) + f_id.set_outputs(f_id.input_node[0]) + + (q,) = d.inputs() + call = d.call(f_id, q) + d.set_outputs(call) + + validate(d.hugr) + + def test_recursive_function(snapshot) -> None: mod = Module() @@ -299,6 +363,7 @@ def test_invalid_recursive_function() -> None: f_recursive.set_outputs(f_recursive.input_node[0]) +@pytest.mark.skip("Value::Function is deprecated and not supported by model encoding.") def test_higher_order(snapshot) -> None: noop_fn = Dfg(tys.Qubit) noop_fn.set_outputs(noop_fn.add(ops.Noop()(noop_fn.input_node[0]))) @@ -358,3 +423,76 @@ def test_option() -> None: dfg.set_outputs(b) validate(dfg.hugr) + + +# a helper for the toposort tests +@pytest.fixture +def simple_fn() -> Function: + f = Function("prepare_qubit", [tys.Bool, tys.Qubit]) + [b, q] = f.inputs() + + h = f.add_op(H, q) + q = h.out(0) + + nnot = f.add_op(Not, b) + + f.set_outputs(q, nnot, b) + validate(Package([f.hugr], [QUANTUM_EXT])) + return f + + +# https://github.com/CQCL/hugr/issues/2350 +def test_toposort(simple_fn: Function) -> None: + nodes = list(simple_fn.hugr) + func_node = nodes[1] + + sorted_nodes = list(simple_fn.hugr.sorted_region_nodes(func_node)) + assert set(sorted_nodes) == set(simple_fn.hugr.children(simple_fn)) + assert sorted_nodes[0] == simple_fn.input_node + assert sorted_nodes[-1] == simple_fn.output_node + + +def test_toposort_error(simple_fn: Function) -> None: + # Test that we get an error if we toposort an invalid hugr containing a cycle + nodes = list(simple_fn.hugr) + func_node = nodes[1] + + # Add a loop, invalidating the HUGR + simple_fn.hugr.add_link(nodes[4].out_port(), nodes[4].inp(0)) + with pytest.raises( + ValueError, match="Graph contains a cycle. No topological ordering exists." + ): + list(simple_fn.hugr.sorted_region_nodes(func_node)) + + +def test_html_labels(snapshot) -> None: + """Ensures that HTML-like labels can be processed correctly by both the builder and + the renderer. 
+ """ + f = Function( + "", + [tys.Bool], + ) + f.metadata["label"] = "Bold Label" + f.metadata[""] = "Italic Label" + f.metadata["meta_can_be_anything"] = [42, "string", 3.14, True] + + f.hugr[f.hugr.module_root].metadata["name"] = "Module Root" + + b = f.inputs()[0] + f.add_op(ops.Some(tys.Bool), b) + f.set_outputs(b) + + validate(f.hugr, snap=snapshot) + + +# https://github.com/CQCL/hugr/issues/2438 +def test_fndef_output_ports(snapshot): + mod = Module() + main = mod.define_function("main", [], [tys.Unit, tys.Unit, tys.Unit, tys.Unit]) + unit = main.add_op(ops.MakeTuple()) + main.set_outputs(*4 * [unit]) + + assert mod.hugr.num_out_ports(main) == 1 + + validate(mod.hugr, snap=snapshot) diff --git a/hugr-py/tests/test_ops.py b/hugr-py/tests/test_ops.py index 3d5d444772..fd6b8870db 100644 --- a/hugr-py/tests/test_ops.py +++ b/hugr-py/tests/test_ops.py @@ -1,5 +1,6 @@ import pytest +from hugr import tys from hugr.hugr.node_port import InPort, Node, OutPort from hugr.ops import ( CFG, @@ -42,7 +43,8 @@ (DivMod, "arithmetic.int.idivmod_u<5>"), (MakeTuple(), "MakeTuple"), (UnpackTuple(), "UnpackTuple"), - (Tag(0, Bool), "Tag(0)"), + (Tag(0, Bool), "Left"), + (Tag(0, tys.Sum([[Bool, Bool, Bool]])), "Tag(0)"), (CFG([]), "CFG"), (DFG([]), "DFG"), (DataflowBlock([]), "DataflowBlock"), @@ -59,7 +61,7 @@ (FuncDecl("bar", PolyFuncType.empty()), "FuncDecl(bar)"), (Const(TRUE), "Const(TRUE)"), (Noop(), "Noop"), - (AliasDecl("baz", TypeBound.Any), "AliasDecl(baz)"), + (AliasDecl("baz", TypeBound.Linear), "AliasDecl(baz)"), (AliasDefn("baz", Bool), "AliasDefn(baz)"), ], ) diff --git a/hugr-py/tests/test_order_edges.py b/hugr-py/tests/test_order_edges.py new file mode 100644 index 0000000000..78e585a2b6 --- /dev/null +++ b/hugr-py/tests/test_order_edges.py @@ -0,0 +1,49 @@ +from hugr import tys +from hugr.build.dfg import Dfg +from hugr.package import Package + +from .conftest import QUANTUM_EXT, MeasureFree, QAlloc, validate + + +def test_order_links(): + dfg = Dfg(tys.Bool) + inp_0 = dfg.input_node.out(0) + inp_order = dfg.input_node.out(-1) + out_0 = dfg.output_node.inp(0) + out_1 = dfg.output_node.inp(1) + out_order = dfg.output_node.inp(-1) + + dfg.hugr.add_link(inp_0, out_0) + dfg.hugr.add_link(inp_0, out_1) + assert list(dfg.hugr.outgoing_links(dfg.input_node)) == [ + (inp_0, [out_0, out_1]), + ] + assert list(dfg.hugr.incoming_links(dfg.output_node)) == [ + (out_0, [inp_0]), + (out_1, [inp_0]), + ] + + # Now add an order link + dfg.hugr.add_order_link(dfg.input_node, dfg.output_node) + assert list(dfg.hugr.incoming_order_links(dfg.output_node)) == [dfg.input_node] + assert list(dfg.hugr.outgoing_order_links(dfg.input_node)) == [dfg.output_node] + assert list(dfg.hugr.outgoing_links(dfg.input_node)) == [ + (inp_0, [out_0, out_1]), + (inp_order, [out_order]), + ] + assert list(dfg.hugr.incoming_links(dfg.output_node)) == [ + (out_0, [inp_0]), + (out_1, [inp_0]), + (out_order, [inp_order]), + ] + + +# https://github.com/CQCL/hugr/issues/2439 +def test_order_unconnected(snapshot): + dfg = Dfg(tys.Qubit) + meas = dfg.add(MeasureFree(*dfg.inputs())) + alloc = dfg.add_op(QAlloc) + dfg.hugr.add_order_link(meas, alloc) + dfg.set_outputs(alloc) + + validate(Package([dfg.hugr], [QUANTUM_EXT]), snap=snapshot) diff --git a/hugr-py/tests/test_prelude.py b/hugr-py/tests/test_prelude.py index c2ef2cbeec..71af3d4e06 100644 --- a/hugr-py/tests/test_prelude.py +++ b/hugr-py/tests/test_prelude.py @@ -1,4 +1,7 @@ +import pytest + from hugr.build.dfg import Dfg +from hugr.std.int 
import IntVal, int_t from hugr.std.prelude import STRING_T, StringVal from .conftest import validate @@ -16,3 +19,38 @@ def test_string_val(): dfg.set_outputs(v) validate(dfg.hugr) + + +@pytest.mark.parametrize( + ("log_width", "v", "unsigned"), + [ + (5, 1, 1), + (4, 0, 0), + (6, 42, 42), + (2, -1, 15), + (1, -2, 2), + (3, -23, 233), + (3, -256, None), + (2, 16, None), + ], +) +def test_int_val(log_width: int, v: int, unsigned: int | None): + val = IntVal(v, log_width) + if unsigned is None: + with pytest.raises( + ValueError, + match=f"Value {v} out of range for {1<"), (StaticArray(Bool), "static_array"), (ValueArray(Bool, 3), "value_array<3, Type(Bool)>"), - (Variable(2, TypeBound.Any), "$2"), + (BorrowArray(Bool, 3), "borrow_array<3, Type(Bool)>"), + (Variable(2, TypeBound.Linear), "$2"), (RowVariable(4, TypeBound.Copyable), "$4"), (USize(), "USize"), (INT_T, "int<5>"), @@ -132,10 +147,10 @@ def test_args_str(arg: TypeArg, string: str): (FunctionType([Bool, Qubit], [Qubit, Bool]), "Bool, Qubit -> Qubit, Bool"), ( PolyFuncType( - [TypeTypeParam(TypeBound.Any), BoundedNatParam(7)], + [TypeTypeParam(TypeBound.Linear), BoundedNatParam(7)], FunctionType([_int_tv(1)], [Variable(0, TypeBound.Copyable)]), ), - "∀ Any, Nat(7). int<$1> -> $0", + "∀ Linear, Nat(7). int<$1> -> $0", ), ], ) @@ -166,12 +181,12 @@ def test_array(): ls = Array(Bool, 3) assert ls.ty == Bool assert ls.size == 3 - assert ls.type_bound() == TypeBound.Any + assert ls.type_bound() == TypeBound.Linear ls = Array(ty_var, len_var) assert ls.ty == ty_var assert ls.size is None - assert ls.type_bound() == TypeBound.Any + assert ls.type_bound() == TypeBound.Linear ar_val = ArrayVal([val.TRUE, val.FALSE], Bool) assert ar_val.v == [val.TRUE, val.FALSE] @@ -179,7 +194,7 @@ def test_array(): def test_value_array(): - ty_var = Variable(0, TypeBound.Any) + ty_var = Variable(0, TypeBound.Linear) len_var = VariableArg(1, BoundedNatParam()) ls = ValueArray(Bool, 3) @@ -190,13 +205,32 @@ def test_value_array(): ls = ValueArray(ty_var, len_var) assert ls.ty == ty_var assert ls.size is None - assert ls.type_bound() == TypeBound.Any + assert ls.type_bound() == TypeBound.Linear ar_val = ValueArrayVal([val.TRUE, val.FALSE], Bool) assert ar_val.v == [val.TRUE, val.FALSE] assert ar_val.ty == ValueArray(Bool, 2) +def test_borrow_array(): + ty_var = Variable(0, TypeBound.Copyable) + len_var = VariableArg(1, BoundedNatParam()) + + ls = BorrowArray(Bool, 3) + assert ls.ty == Bool + assert ls.size == 3 + assert ls.type_bound() == TypeBound.Linear + + ls = BorrowArray(ty_var, len_var) + assert ls.ty == ty_var + assert ls.size is None + assert ls.type_bound() == TypeBound.Linear + + ar_val = BorrowArrayVal([val.TRUE, val.FALSE], Bool) + assert ar_val.v == [val.TRUE, val.FALSE] + assert ar_val.ty == BorrowArray(Bool, 2) + + def test_static_array(): ty_var = Variable(0, TypeBound.Copyable) diff --git a/hugr-py/tests/test_val.py b/hugr-py/tests/test_val.py index 11fd1a1194..5afab88400 100644 --- a/hugr-py/tests/test_val.py +++ b/hugr-py/tests/test_val.py @@ -14,7 +14,6 @@ Sum, Tuple, UnitSum, - Value, bool_value, ) @@ -44,9 +43,9 @@ def test_sums(): ("value", "string", "repr_str"), [ ( - Sum(0, tys.Sum([[tys.Bool], [tys.Qubit]]), [TRUE, FALSE]), - "Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])", - "Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])", + Sum(0, tys.Sum([[tys.Bool], [tys.Qubit], [tys.Bool]]), [TRUE]), + "Sum(0, Sum([[Bool], [Qubit], [Bool]]), [TRUE])", + "Sum(tag=0, typ=Sum([[Bool], [Qubit], [Bool]]), vals=[TRUE])", ), 
 (UnitSum(0, size=1), "Unit", "Unit"),
+ (UnitSum(0, size=2), "FALSE", "FALSE"),
@@ -67,10 +66,15 @@ def test_sums():
 ),
 ],
 )
-def test_val_sum_str(value: Value, string: str, repr_str: str):
+def test_val_sum_str(value: Sum, string: str, repr_str: str):
 assert str(value) == string
 assert repr(value) == repr_str
+ # Make sure the corresponding `Sum` also renders the same
+ sum_val = Sum(value.tag, value.typ, value.vals)
+ assert str(sum_val) == string
+ assert repr(sum_val) == repr_str
+
 def test_val_static_array():
 from hugr.std.collections.static_array import StaticArrayVal
diff --git a/hugr/CHANGELOG.md b/hugr/CHANGELOG.md
index 46439ca695..a29df3b196 100644
--- a/hugr/CHANGELOG.md
+++ b/hugr/CHANGELOG.md
@@ -1,5 +1,94 @@
 # Changelog
+
+## [0.22.1](https://github.com/CQCL/hugr/compare/hugr-v0.22.0...hugr-v0.22.1) - 2025-07-28
+
+### New Features
+
+- Include copy_discard_array in DelegatingLinearizer::default ([#2479](https://github.com/CQCL/hugr/pull/2479))
+- Inline calls to functions not on cycles in the call graph ([#2450](https://github.com/CQCL/hugr/pull/2450))
+
+## [0.22.0](https://github.com/CQCL/hugr/compare/hugr-v0.21.0...hugr-v0.22.0) - 2025-07-24
+
+This release fixes multiple inconsistencies between the serialization formats
+and improves the error messages when loading unsupported envelopes.
+
+We now also support nodes with up to `2^32` connections to the same port (up from `2^16`).
+
+### Bug Fixes
+
+- Ensure SumTypes have the same json encoding in -rs and -py ([#2465](https://github.com/CQCL/hugr/pull/2465))
+
+### New Features
+
+- ReplaceTypes allows linearizing inside Op replacements ([#2435](https://github.com/CQCL/hugr/pull/2435))
+- Add pass for DFG inlining ([#2460](https://github.com/CQCL/hugr/pull/2460))
+- Export entrypoint metadata in Python and fix bug in import ([#2434](https://github.com/CQCL/hugr/pull/2434))
+- Names of private functions become `core.title` metadata. ([#2448](https://github.com/CQCL/hugr/pull/2448))
+- [**breaking**] Use binary envelopes for operation lower_func encoding ([#2447](https://github.com/CQCL/hugr/pull/2447))
+- [**breaking**] Update portgraph dependency to 0.15 ([#2455](https://github.com/CQCL/hugr/pull/2455))
+- Detect and fail on unrecognised envelope flags ([#2453](https://github.com/CQCL/hugr/pull/2453))
+- include generator metadata in model import and cli validate errors ([#2452](https://github.com/CQCL/hugr/pull/2452))
+- [**breaking**] Add `insert_region` to HugrMut ([#2463](https://github.com/CQCL/hugr/pull/2463))
+- Non-region entrypoints in `hugr-model`. ([#2467](https://github.com/CQCL/hugr/pull/2467))
+
+## [0.21.0](https://github.com/CQCL/hugr/compare/hugr-v0.20.2...hugr-v0.21.0) - 2025-07-09
+
+
+This release includes a long list of changes:
+
+- The HUGR model serialization format is now stable, and should be preferred over the old JSON format.
+- Type parameters and type arguments are now unified into a single `Term` type.
+- Function definitions can no longer be nested inside dataflow regions. Now they must be defined at the top level of the module.
+- Function definitions and declarations now have a `Visibility` field, which defines whether they are visible in the public API of the module.
+- And many more fixes and improvements.
+ +### Bug Fixes + +- DeadFuncElimPass+CallGraph w/ non-module-child entrypoint ([#2390](https://github.com/CQCL/hugr/pull/2390)) +- Fixed two bugs in import/export of function operations ([#2324](https://github.com/CQCL/hugr/pull/2324)) +- Model import should perform extension resolution ([#2326](https://github.com/CQCL/hugr/pull/2326)) +- [**breaking**] Fixed bugs in model CFG handling and improved CFG signatures ([#2334](https://github.com/CQCL/hugr/pull/2334)) +- Use List instead of Tuple in conversions for TypeArg/TypeRow ([#2378](https://github.com/CQCL/hugr/pull/2378)) +- Do extension resolution on loaded extensions from the model format ([#2389](https://github.com/CQCL/hugr/pull/2389)) +- Make JSON Schema checks actually work again ([#2412](https://github.com/CQCL/hugr/pull/2412)) +- Order hints on input and output nodes. ([#2422](https://github.com/CQCL/hugr/pull/2422)) + +### Documentation + +- Hide hugr-persistent docs ([#2357](https://github.com/CQCL/hugr/pull/2357)) + +### New Features + +- [**breaking**] Split `TypeArg::Sequence` into tuples and lists. ([#2140](https://github.com/CQCL/hugr/pull/2140)) +- [**breaking**] Added float and bytes literal to core and python bindings. ([#2289](https://github.com/CQCL/hugr/pull/2289)) +- [**breaking**] More helpful error messages in model import ([#2272](https://github.com/CQCL/hugr/pull/2272)) +- [**breaking**] Better error reporting in `hugr-cli`. ([#2318](https://github.com/CQCL/hugr/pull/2318)) +- [**breaking**] Merge `TypeParam` and `TypeArg` into one `Term` type in Rust ([#2309](https://github.com/CQCL/hugr/pull/2309)) +- *(persistent)* Add serialisation for CommitStateSpace ([#2344](https://github.com/CQCL/hugr/pull/2344)) +- add TryFrom impls for TypeArg/TypeRow ([#2366](https://github.com/CQCL/hugr/pull/2366)) +- Add `MakeError` op ([#2377](https://github.com/CQCL/hugr/pull/2377)) +- Open lists and tuples in `Term` ([#2360](https://github.com/CQCL/hugr/pull/2360)) +- Call `FunctionBuilder::add_{in,out}put` for any AsMut ([#2376](https://github.com/CQCL/hugr/pull/2376)) +- Add Root checked methods to DataflowParentID ([#2382](https://github.com/CQCL/hugr/pull/2382)) +- Add PersistentWire type ([#2361](https://github.com/CQCL/hugr/pull/2361)) +- Add `BorrowArray` extension ([#2395](https://github.com/CQCL/hugr/pull/2395)) +- [**breaking**] Add Visibility to FuncDefn/FuncDecl. 
([#2143](https://github.com/CQCL/hugr/pull/2143)) +- *(per)* [**breaking**] Support empty wires in commits ([#2349](https://github.com/CQCL/hugr/pull/2349)) +- [**breaking**] hugr-model use explicit Option, with ::Unspecified in capnp ([#2424](https://github.com/CQCL/hugr/pull/2424)) +- [**breaking**] No nested FuncDefns (or AliasDefns) ([#2256](https://github.com/CQCL/hugr/pull/2256)) +- [**breaking**] Rename 'Any' type bound to 'Linear' ([#2421](https://github.com/CQCL/hugr/pull/2421)) + +### Refactor + +- [**breaking**] remove deprecated runtime extension errors ([#2369](https://github.com/CQCL/hugr/pull/2369)) +- [**breaking**] Reduce error type sizes ([#2420](https://github.com/CQCL/hugr/pull/2420)) +- [**breaking**] move PersistentHugr into separate crate ([#2277](https://github.com/CQCL/hugr/pull/2277)) + +### Testing + +- Check hugr json serializations against the schema (again) ([#2216](https://github.com/CQCL/hugr/pull/2216)) + ## [0.20.2](https://github.com/CQCL/hugr/compare/hugr-v0.20.1...hugr-v0.20.2) - 2025-06-25 ### Bug Fixes diff --git a/hugr/Cargo.toml b/hugr/Cargo.toml index d43ee4ae90..895c147a60 100644 --- a/hugr/Cargo.toml +++ b/hugr/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hugr" -version = "0.20.2" +version = "0.22.1" edition = { workspace = true } rust-version = { workspace = true } @@ -28,12 +28,14 @@ declarative = ["hugr-core/declarative"] llvm = ["hugr-llvm/llvm14-0"] llvm-test = ["hugr-llvm/llvm14-0", "hugr-llvm/test-utils"] zstd = ["hugr-core/zstd"] +persistent_unstable = ["hugr-persistent"] [dependencies] -hugr-model = { path = "../hugr-model", version = "0.20.2" } -hugr-core = { path = "../hugr-core", version = "0.20.2" } -hugr-passes = { path = "../hugr-passes", version = "0.20.2" } -hugr-llvm = { path = "../hugr-llvm", version = "0.20.2", optional = true } +hugr-model = { path = "../hugr-model", version = "0.22.1" } +hugr-core = { path = "../hugr-core", version = "0.22.1" } +hugr-passes = { path = "../hugr-passes", version = "0.22.1" } +hugr-llvm = { path = "../hugr-llvm", version = "0.22.1", optional = true } +hugr-persistent = { path = "../hugr-persistent", version = "0.2.1", optional = true } [dev-dependencies] lazy_static = { workspace = true } diff --git a/hugr/benches/benchmarks/hugr/examples.rs b/hugr/benches/benchmarks/hugr/examples.rs index 2fdd1a762e..8e274ac8ff 100644 --- a/hugr/benches/benchmarks/hugr/examples.rs +++ b/hugr/benches/benchmarks/hugr/examples.rs @@ -3,8 +3,8 @@ use std::sync::{Arc, LazyLock}; use hugr::builder::{ - BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, ModuleBuilder, + BuildError, CFGBuilder, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, + ModuleBuilder, }; use hugr::extension::ExtensionRegistry; use hugr::extension::prelude::{bool_t, qb_t, usize_t}; diff --git a/hugr/benches/benchmarks/types.rs b/hugr/benches/benchmarks/types.rs index 0ed0a12a05..5b564bddde 100644 --- a/hugr/benches/benchmarks/types.rs +++ b/hugr/benches/benchmarks/types.rs @@ -13,7 +13,7 @@ fn make_complex_type() -> Type { let int = usize_t(); let q_register = Type::new_tuple(vec![qb; 8]); let b_register = Type::new_tuple(vec![int; 8]); - let q_alias = Type::new_alias(AliasDecl::new("QReg", TypeBound::Any)); + let q_alias = Type::new_alias(AliasDecl::new("QReg", TypeBound::Linear)); let sum = 
Type::new_sum([q_register, q_alias]); Type::new_function(Signature::new(vec![sum], vec![b_register])) } diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index 341d74f12b..ce460ce404 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -140,6 +140,10 @@ pub use hugr_passes as algorithms; #[doc(inline)] pub use hugr_llvm as llvm; +#[cfg(feature = "persistent_unstable")] +#[doc(hidden)] // TODO: remove when stable +pub use hugr_persistent as persistent; + // Modules with hand-picked re-exports. pub mod hugr; diff --git a/justfile b/justfile index 87771f9803..a092d45bf9 100644 --- a/justfile +++ b/justfile @@ -10,7 +10,7 @@ setup: # Run the pre-commit checks. check: - uv run pre-commit run --all-files + HUGR_TEST_SCHEMA=1 uv run pre-commit run --all-files # Run all the tests. test: test-rust test-python @@ -20,7 +20,7 @@ test-rust *TEST_ARGS: @# built into a binary build (without using `maturin`) @# @# This feature list should be kept in sync with the `hugr-py/pyproject.toml` - cargo test \ + HUGR_TEST_SCHEMA=1 cargo test \ --workspace \ --exclude 'hugr-py' \ --features 'hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' {{TEST_ARGS}} diff --git a/release-plz.toml b/release-plz.toml index 4bc9f71047..ebedf1be07 100644 --- a/release-plz.toml +++ b/release-plz.toml @@ -22,6 +22,11 @@ pr_labels = ["release"] release_always = false [changelog] + +header = """# Changelog + +""" + sort_commits = "oldest" # Allowed conventional commit types @@ -69,3 +74,7 @@ version_group = "hugr" name = "hugr-llvm" release = true version_group = "hugr" + +[[package]] +name = "hugr-persistent" +release = true diff --git a/resources/test/hugr-no-visibility.hugr b/resources/test/hugr-no-visibility.hugr new file mode 100644 index 0000000000..e61f933966 --- /dev/null +++ b/resources/test/hugr-no-visibility.hugr @@ -0,0 +1,52 @@ +HUGRiHJv?@{ + "modules": [ + { + "version": "live", + "nodes": [ + { + "parent": 0, + "op": "Module" + }, + { + "name":"polyfunc1", + "op":"FuncDecl", + "parent":0, + "signature":{ + "body":{ + "input":[], + "output":[] + }, + "params":[ + ] + } + }, + { + "name":"polyfunc2", + "op":"FuncDefn", + "parent":0, + "signature":{ + "body":{ + "input":[], + "output":[] + }, + "params":[ + ] + } + }, + { + "op": "Input", + "parent": 2, + "types": [] + }, + { + "op": "Output", + "parent": 2, + "types": [] + } + ], + "edges": [], + "encoder": null + } + ], + "extensions": [] +} diff --git a/scripts/check_extension_versions.py b/scripts/check_extension_versions.py new file mode 100644 index 0000000000..f871e23fea --- /dev/null +++ b/scripts/check_extension_versions.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python + +import json +import subprocess +import sys +from pathlib import Path + + +def get_changed_files(target: str) -> list[Path]: + """Get list of changed extension files in the PR""" + # Use git to get the list of files changed compared to target + cmd = [ + "git", + "diff", + "--name-only", + target, + "--", + "specification/std_extensions/", + ] + result = subprocess.run(cmd, capture_output=True, text=True, check=True) # noqa: S603 + changed_files = [Path(f) for f in result.stdout.splitlines() if f.endswith(".json")] + return changed_files + + +def check_version_changes(changed_files: list[Path], target: str) -> list[str]: + """Check if versions have been updated in changed files""" + errors = [] + + for file_path in changed_files: + # Skip files that don't exist anymore (deleted files) + if not file_path.exists(): + continue + + # Get the version in the current branch + with file_path.open("r") as f: + 
current = json.load(f) + current_version = current.get("version") + + # Get the version in the target branch + try: + cmd = ["git", "show", f"{target}:{file_path}"] + result = subprocess.run(cmd, capture_output=True, text=True) # noqa: S603 + + if result.returncode == 0: + # File exists in target + target_content = json.loads(result.stdout) + target_version = target_content.get("version") + + if current_version == target_version: + errors.append( + f"Error: {file_path} was modified but version {current_version}" + " was not updated." + ) + else: + print( + f"Version updated in {file_path}: {target_version}" + f" -> {current_version}" + ) + + else: + # New file - no version check needed + pass + + except json.JSONDecodeError: + # File is new or not valid JSON in target + pass + return errors + + +def main() -> int: + target = sys.argv[1] if len(sys.argv) > 1 else "origin/main" + changed_files = get_changed_files(target) + if not changed_files: + print("No extension files changed.") + return 0 + + print(f"Changed extension files: {', '.join(map(str, changed_files))}") + + errors = check_version_changes(changed_files, target) + if errors: + for error in errors: + sys.stderr.write(error) + return 1 + + print("All changed extension files have updated versions.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/generate_schema.py b/scripts/generate_schema.py index 98661914ce..eea90351ad 100644 --- a/scripts/generate_schema.py +++ b/scripts/generate_schema.py @@ -14,7 +14,7 @@ from pathlib import Path from pydantic import ConfigDict -from pydantic.json_schema import models_json_schema +from pydantic.json_schema import DEFAULT_REF_TEMPLATE, models_json_schema from hugr._serialization.extension import Extension, Package from hugr._serialization.serial_hugr import SerialHugr @@ -38,6 +38,9 @@ def write_schema( _, top_level_schema = models_json_schema( [(s, "validation") for s in schemas], title="HUGR schema" ) + top_level_schema["oneOf"] = [ + {"$ref": DEFAULT_REF_TEMPLATE.format(model=s.__name__)} for s in schemas + ] with path.open("w") as f: json.dump(top_level_schema, f, indent=4) diff --git a/specification/hugr.md b/specification/hugr.md index dc8251f30a..2cf934fd97 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -248,6 +248,11 @@ edges. The following operations are *only* valid as immediate children of a - `AliasDecl`: an external type alias declaration. At link time this can be replaced with the definition. An alias declared with `AliasDecl` is equivalent to a named opaque type. +- `FuncDefn` : a function definition. Like `FuncDecl` but with a function body. + The function body is defined by the sibling graph formed by its children. + At link time `FuncDecl` nodes are replaced by `FuncDefn`. +- `AliasDefn`: type alias definition. At link time `AliasDecl` can be replaced with + `AliasDefn`. There may also be other [scoped definitions](#scoped-definitions). @@ -258,11 +263,6 @@ regions and control-flow regions: - `Const` : a static constant value of type T stored in the node weight. Like `FuncDecl` and `FuncDefn` this has one `Const` out-edge per use. -- `FuncDefn` : a function definition. Like `FuncDecl` but with a function body. - The function body is defined by the sibling graph formed by its children. - At link time `FuncDecl` nodes are replaced by `FuncDefn`. -- `AliasDefn`: type alias definition. At link time `AliasDecl` can be replaced with - `AliasDefn`. 
A **loadable HUGR** is a module HUGR where all input ports are connected and there are no `FuncDecl/AliasDecl` nodes.
@@ -552,11 +552,8 @@ parent(n2) when the edge's locality is:
 Each of these localities have additional constraints as follows:
 1. For Ext edges, we require parent(n1) ==
- parenti(n2) for some i\>1, *and* for Value edges only:
- * there must be a order edge from n1 to
- parenti-1(n2).
- * None of the parentj(n2), for i\>j\>=1,
- may be a FuncDefn node
+ parenti(n2) for some i\>1, *and* for Value edges only there must be an order edge from n1 to
+ parenti-1(n2).
 The order edge records the ordering requirement that results, i.e. it must be possible to
@@ -569,9 +566,6 @@ Each of these localities have additional constraints as follows:
 For Static edges this order edge is not required since the source is guaranteed to causally precede the target.
- The FuncDefn restriction means that FuncDefn really are static,
- and do not capture runtime values from their environment.
-
 2. For Dom edges, we must have that parent2(n1) ==
 parenti(n2) is a CFG-node, for some i\>1, **and** parent(n1) strictly dominates
@@ -580,8 +574,6 @@ Each of these localities have additional constraints as follows:
 i\>1 allows the node to target an arbitrarily-deep descendant of the dominated block, similar to an Ext edge.)
- The same FuncDefn restriction also applies here, on the parent(j)(n2) for i\>j\>=1 (of course j=i is the CFG and j=i-1 is the basic block).
-
 Specifically, these rules allow for edges where in a given execution of the HUGR the source of the edge executes once, but the target may execute \>=0 times.
@@ -779,7 +771,7 @@ existing metadata, given the node ID.
 engine)?
 Reserved metadata keys used by the HUGR tooling are prefixed with `core.`.
-Use of this prefix by external tooling may cause issues.
+Use of this prefix by external tooling may cause issues.
 #### Generator Metadata
 Tooling generating HUGR can specify some reserved metadata keys to be used for debugging
@@ -832,7 +824,7 @@ copied or discarded (multiple or 0 links from on output port respectively):
 allows multiple (or 0) outgoing edges from an outport; also these types can be sent down `Const` edges.
-Note that all dataflow inputs (`Value`, `Const` and `Function`) always require a single connection, regardless of whether the type is `AnyType` or `Copyable`.
+Note that all dataflow inputs (`Value`, `Const` and `Function`) always require a single connection, regardless of whether the type is `Linear` or `Copyable`.
 **Rows** The `#` is a *row* which is a sequence of zero or more types. Types in the row can optionally be given names in metadata i.e. this does not affect behaviour of the HUGR. When writing literal types, we use `#` to distinguish between tuples and rows, e.g. `(int<1>,int<2>)` is a tuple while `Sum(#(int<1>),#(int<2>))` contains two rows.
@@ -866,6 +858,9 @@ such declarations may include (bind) any number of type parameters, of kinds as
 TypeParam ::= Type(Any|Copyable)
 | BoundedUSize(u64|) -- note optional bound
 | Extensions
+ | String
+ | Bytes
+ | Float
 | List(TypeParam) -- homogeneous, any sized
 | Tuple([TypeParam]) -- heterogeneous, fixed size
 | Opaque(Name, [TypeArg]) -- e.g.
Opaque("Array", [5, Opaque("usize", [])]) @@ -883,22 +878,26 @@ TypeArgs appropriate for the function's TypeParams: ```haskell TypeArg ::= Type(Type) -- could be a variable of kind Type, or contain variable(s) | BoundedUSize(u64) + | String(String) + | Bytes([u8]) + | Float(f64) | Extensions(Extensions) -- may contain TypeArg's of kind Extensions - | Sequence([TypeArg]) -- fits either a List or Tuple TypeParam + | List([TypeArg]) + | Tuple([TypeArg]) | Opaque(Value) | Variable -- refers to an enclosing TypeParam (binder) of any kind above ``` For example, a Function node declaring a `TypeParam::Opaque("Array", [5, TypeArg::Type(Type::Opaque("usize"))])` means that any `Call` to it must statically provide a *value* that is an array of 5 `usize`s; -or a Function node declaring a `TypeParam::BoundedUSize(5)` and a `TypeParam::Type(Any)` requires two TypeArgs, +or a Function node declaring a `TypeParam::BoundedUSize(5)` and a `TypeParam::Type(Linear)` requires two TypeArgs, firstly a non-negative integer less than 5, secondly a type (which might be from an extension, e.g. `usize`). Given TypeArgs, the body of the Function node's type can be converted to a monomorphic signature by substitution, i.e. replacing each type variable in the body with the corresponding TypeArg. This is guaranteed to produce a valid type as long as the TypeArgs match the declared TypeParams, which can be checked in advance. -(Note that within a polymorphic type scheme, type variables of kind `Sequence` or `Opaque` will only be usable +(Note that within a polymorphic type scheme, type variables of kind `List`, `Tuple` or `Opaque` will only be usable as arguments to Opaque types---see [Extension System](#extension-system).) #### Row Variables @@ -910,16 +909,16 @@ treatment, as follows: but also a single `TypeArg::Type`. (This is purely a notational convenience.) For example, `Type::Function(usize, unit, )` is equivalent shorthand for `Type::Function(#(usize), #(unit), )`. -* When a `TypeArg::Sequence` is provided as argument for such a TypeParam, we allow +* When a `TypeArg::List` is provided as argument for such a TypeParam, we allow elements to be a mixture of both types (including variables of kind `TypeParam::Type(_)`) and also row variables. When such variables are instantiated - (with other Sequences) the elements of the inner Sequence are spliced directly into - the outer (concatenating their elements), eliding the inner (Sequence) wrapper. + (with other `List`s) the elements of the inner `List` are spliced directly into + the outer (concatenating their elements), eliding the inner (`List`) wrapper. For example, a polymorphic FuncDefn might declare a row variable X of kind `TypeParam::List(TypeParam::Type(Copyable))` and have as output a (tuple) type `Sum([#(X, usize)])`. A call that instantiates said type-parameter with -`TypeArg::Sequence([usize, unit])` would then have output `Sum([#(usize, unit, usize)])`. +`TypeArg::List([usize, unit])` would then have output `Sum([#(usize, unit, usize)])`. See [Declarative Format](#declarative-format) for more examples. 
diff --git a/specification/schema/hugr_schema_live.json b/specification/schema/hugr_schema_live.json index 4a5c38b0e6..a0fd06f72a 100644 --- a/specification/schema/hugr_schema_live.json +++ b/specification/schema/hugr_schema_live.json @@ -130,6 +130,41 @@ "title": "BoundedNatParam", "type": "object" }, + "BytesArg": { + "additionalProperties": true, + "properties": { + "tya": { + "const": "Bytes", + "default": "Bytes", + "title": "Tya", + "type": "string" + }, + "value": { + "contentEncoding": "base64", + "description": "Base64-encoded byte string", + "title": "Value", + "type": "string" + } + }, + "required": [ + "value" + ], + "title": "BytesArg", + "type": "object" + }, + "BytesParam": { + "additionalProperties": true, + "properties": { + "tp": { + "const": "Bytes", + "default": "Bytes", + "title": "Tp", + "type": "string" + } + }, + "title": "BytesParam", + "type": "object" + }, "CFG": { "additionalProperties": true, "description": "A dataflow node which is defined by a child CFG.", @@ -546,6 +581,7 @@ "type": "object" }, "FixedHugr": { + "description": "Fixed HUGR used to define the lowering of an operation.\n\nArgs:\n extensions: Extensions used in the HUGR.\n hugr: Base64-encoded HUGR envelope.", "properties": { "extensions": { "items": { @@ -566,6 +602,39 @@ "title": "FixedHugr", "type": "object" }, + "FloatArg": { + "additionalProperties": true, + "properties": { + "tya": { + "const": "Float", + "default": "Float", + "title": "Tya", + "type": "string" + }, + "value": { + "title": "Value", + "type": "number" + } + }, + "required": [ + "value" + ], + "title": "FloatArg", + "type": "object" + }, + "FloatParam": { + "additionalProperties": true, + "properties": { + "tp": { + "const": "Float", + "default": "Float", + "title": "Tp", + "type": "string" + } + }, + "title": "FloatParam", + "type": "object" + }, "FromParamsBound": { "properties": { "b": { @@ -608,6 +677,15 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" + }, + "visibility": { + "default": "Public", + "enum": [ + "Public", + "Private" + ], + "title": "Visibility", + "type": "string" } }, "required": [ @@ -638,6 +716,15 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" + }, + "visibility": { + "default": "Private", + "enum": [ + "Public", + "Private" + ], + "title": "Visibility", + "type": "string" } }, "required": [ @@ -691,7 +778,8 @@ "type": "string" }, "hugr": { - "title": "Hugr" + "title": "Hugr", + "type": "string" } }, "required": [ @@ -761,6 +849,29 @@ "title": "Input", "type": "object" }, + "ListArg": { + "additionalProperties": true, + "properties": { + "tya": { + "const": "List", + "default": "List", + "title": "Tya", + "type": "string" + }, + "elems": { + "items": { + "$ref": "#/$defs/TypeArg" + }, + "title": "Elems", + "type": "array" + } + }, + "required": [ + "elems" + ], + "title": "ListArg", + "type": "object" + }, "ListParam": { "additionalProperties": true, "properties": { @@ -1171,29 +1282,6 @@ "title": "RowVar", "type": "object" }, - "SequenceArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "Sequence", - "default": "Sequence", - "title": "Tya", - "type": "string" - }, - "elems": { - "items": { - "$ref": "#/$defs/TypeArg" - }, - "title": "Elems", - "type": "array" - } - }, - "required": [ - "elems" - ], - "title": "SequenceArg", - "type": "object" - }, "SerialHugr": { "additionalProperties": true, "description": "A serializable representation of a Hugr.", @@ -1376,17 +1464,30 @@ "description": "A Sum variant For any Sum type where this value meets the type of the 
variant indicated by the tag.", "properties": { "v": { - "const": "Sum", "default": "Sum", + "enum": [ + "Sum", + "Tuple" + ], "title": "ValueTag", "type": "string" }, "tag": { - "title": "Tag", + "default": 0, + "title": "VariantTag", "type": "integer" }, "typ": { - "$ref": "#/$defs/SumType" + "anyOf": [ + { + "$ref": "#/$defs/SumType" + }, + { + "type": "null" + } + ], + "default": null, + "title": "SumType" }, "vs": { "items": { @@ -1397,8 +1498,6 @@ } }, "required": [ - "tag", - "typ", "vs" ], "title": "SumValue", @@ -1483,51 +1582,50 @@ "title": "TailLoop", "type": "object" }, - "TupleParam": { + "TupleArg": { "additionalProperties": true, "properties": { - "tp": { + "tya": { "const": "Tuple", "default": "Tuple", - "title": "Tp", + "title": "Tya", "type": "string" }, - "params": { + "elems": { "items": { - "$ref": "#/$defs/TypeParam" + "$ref": "#/$defs/TypeArg" }, - "title": "Params", + "title": "Elems", "type": "array" } }, "required": [ - "params" + "elems" ], - "title": "TupleParam", + "title": "TupleArg", "type": "object" }, - "TupleValue": { + "TupleParam": { "additionalProperties": true, - "description": "A constant tuple value.", "properties": { - "v": { + "tp": { "const": "Tuple", "default": "Tuple", - "title": "ValueTag", + "title": "Tp", "type": "string" }, - "vs": { + "params": { "items": { - "$ref": "#/$defs/Value" + "$ref": "#/$defs/TypeParam" }, - "title": "Vs", + "title": "Params", "type": "array" } }, "required": [ - "vs" + "params" ], - "title": "TupleValue", + "title": "TupleParam", "type": "object" }, "Type": { @@ -1581,8 +1679,11 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Sequence": "#/$defs/SequenceArg", + "Bytes": "#/$defs/BytesArg", + "Float": "#/$defs/FloatArg", + "List": "#/$defs/ListArg", "String": "#/$defs/StringArg", + "Tuple": "#/$defs/TupleArg", "Type": "#/$defs/TypeTypeArg", "Variable": "#/$defs/VariableArg" }, @@ -1599,7 +1700,16 @@ "$ref": "#/$defs/StringArg" }, { - "$ref": "#/$defs/SequenceArg" + "$ref": "#/$defs/BytesArg" + }, + { + "$ref": "#/$defs/FloatArg" + }, + { + "$ref": "#/$defs/ListArg" + }, + { + "$ref": "#/$defs/TupleArg" }, { "$ref": "#/$defs/VariableArg" @@ -1676,6 +1786,8 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", + "Bytes": "#/$defs/BytesParam", + "Float": "#/$defs/FloatParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1693,6 +1805,12 @@ { "$ref": "#/$defs/StringParam" }, + { + "$ref": "#/$defs/FloatParam" + }, + { + "$ref": "#/$defs/BytesParam" + }, { "$ref": "#/$defs/ListParam" }, @@ -1791,7 +1909,7 @@ "Extension": "#/$defs/CustomValue", "Function": "#/$defs/FunctionValue", "Sum": "#/$defs/SumValue", - "Tuple": "#/$defs/TupleValue" + "Tuple": "#/$defs/SumValue" }, "propertyName": "v" }, @@ -1802,9 +1920,6 @@ { "$ref": "#/$defs/FunctionValue" }, - { - "$ref": "#/$defs/TupleValue" - }, { "$ref": "#/$defs/SumValue" } @@ -1864,5 +1979,16 @@ "type": "object" } }, - "title": "HUGR schema" + "title": "HUGR schema", + "oneOf": [ + { + "$ref": "#/$defs/SerialHugr" + }, + { + "$ref": "#/$defs/Extension" + }, + { + "$ref": "#/$defs/Package" + } + ] } \ No newline at end of file diff --git a/specification/schema/hugr_schema_strict_live.json b/specification/schema/hugr_schema_strict_live.json index 419eb86d43..cd37e262cd 100644 --- a/specification/schema/hugr_schema_strict_live.json +++ b/specification/schema/hugr_schema_strict_live.json @@ -130,6 +130,41 @@ "title": "BoundedNatParam", "type": "object" }, + "BytesArg": { + 
"additionalProperties": false, + "properties": { + "tya": { + "const": "Bytes", + "default": "Bytes", + "title": "Tya", + "type": "string" + }, + "value": { + "contentEncoding": "base64", + "description": "Base64-encoded byte string", + "title": "Value", + "type": "string" + } + }, + "required": [ + "value" + ], + "title": "BytesArg", + "type": "object" + }, + "BytesParam": { + "additionalProperties": false, + "properties": { + "tp": { + "const": "Bytes", + "default": "Bytes", + "title": "Tp", + "type": "string" + } + }, + "title": "BytesParam", + "type": "object" + }, "CFG": { "additionalProperties": false, "description": "A dataflow node which is defined by a child CFG.", @@ -546,6 +581,7 @@ "type": "object" }, "FixedHugr": { + "description": "Fixed HUGR used to define the lowering of an operation.\n\nArgs:\n extensions: Extensions used in the HUGR.\n hugr: Base64-encoded HUGR envelope.", "properties": { "extensions": { "items": { @@ -566,6 +602,39 @@ "title": "FixedHugr", "type": "object" }, + "FloatArg": { + "additionalProperties": false, + "properties": { + "tya": { + "const": "Float", + "default": "Float", + "title": "Tya", + "type": "string" + }, + "value": { + "title": "Value", + "type": "number" + } + }, + "required": [ + "value" + ], + "title": "FloatArg", + "type": "object" + }, + "FloatParam": { + "additionalProperties": false, + "properties": { + "tp": { + "const": "Float", + "default": "Float", + "title": "Tp", + "type": "string" + } + }, + "title": "FloatParam", + "type": "object" + }, "FromParamsBound": { "properties": { "b": { @@ -608,6 +677,15 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" + }, + "visibility": { + "default": "Public", + "enum": [ + "Public", + "Private" + ], + "title": "Visibility", + "type": "string" } }, "required": [ @@ -638,6 +716,15 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" + }, + "visibility": { + "default": "Private", + "enum": [ + "Public", + "Private" + ], + "title": "Visibility", + "type": "string" } }, "required": [ @@ -691,7 +778,8 @@ "type": "string" }, "hugr": { - "title": "Hugr" + "title": "Hugr", + "type": "string" } }, "required": [ @@ -761,6 +849,29 @@ "title": "Input", "type": "object" }, + "ListArg": { + "additionalProperties": false, + "properties": { + "tya": { + "const": "List", + "default": "List", + "title": "Tya", + "type": "string" + }, + "elems": { + "items": { + "$ref": "#/$defs/TypeArg" + }, + "title": "Elems", + "type": "array" + } + }, + "required": [ + "elems" + ], + "title": "ListArg", + "type": "object" + }, "ListParam": { "additionalProperties": false, "properties": { @@ -1171,29 +1282,6 @@ "title": "RowVar", "type": "object" }, - "SequenceArg": { - "additionalProperties": false, - "properties": { - "tya": { - "const": "Sequence", - "default": "Sequence", - "title": "Tya", - "type": "string" - }, - "elems": { - "items": { - "$ref": "#/$defs/TypeArg" - }, - "title": "Elems", - "type": "array" - } - }, - "required": [ - "elems" - ], - "title": "SequenceArg", - "type": "object" - }, "SerialHugr": { "additionalProperties": false, "description": "A serializable representation of a Hugr.", @@ -1376,17 +1464,30 @@ "description": "A Sum variant For any Sum type where this value meets the type of the variant indicated by the tag.", "properties": { "v": { - "const": "Sum", "default": "Sum", + "enum": [ + "Sum", + "Tuple" + ], "title": "ValueTag", "type": "string" }, "tag": { - "title": "Tag", + "default": 0, + "title": "VariantTag", "type": "integer" }, "typ": { - "$ref": "#/$defs/SumType" + "anyOf": [ + { + 
"$ref": "#/$defs/SumType" + }, + { + "type": "null" + } + ], + "default": null, + "title": "SumType" }, "vs": { "items": { @@ -1397,8 +1498,6 @@ } }, "required": [ - "tag", - "typ", "vs" ], "title": "SumValue", @@ -1483,51 +1582,50 @@ "title": "TailLoop", "type": "object" }, - "TupleParam": { + "TupleArg": { "additionalProperties": false, "properties": { - "tp": { + "tya": { "const": "Tuple", "default": "Tuple", - "title": "Tp", + "title": "Tya", "type": "string" }, - "params": { + "elems": { "items": { - "$ref": "#/$defs/TypeParam" + "$ref": "#/$defs/TypeArg" }, - "title": "Params", + "title": "Elems", "type": "array" } }, "required": [ - "params" + "elems" ], - "title": "TupleParam", + "title": "TupleArg", "type": "object" }, - "TupleValue": { + "TupleParam": { "additionalProperties": false, - "description": "A constant tuple value.", "properties": { - "v": { + "tp": { "const": "Tuple", "default": "Tuple", - "title": "ValueTag", + "title": "Tp", "type": "string" }, - "vs": { + "params": { "items": { - "$ref": "#/$defs/Value" + "$ref": "#/$defs/TypeParam" }, - "title": "Vs", + "title": "Params", "type": "array" } }, "required": [ - "vs" + "params" ], - "title": "TupleValue", + "title": "TupleParam", "type": "object" }, "Type": { @@ -1581,8 +1679,11 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Sequence": "#/$defs/SequenceArg", + "Bytes": "#/$defs/BytesArg", + "Float": "#/$defs/FloatArg", + "List": "#/$defs/ListArg", "String": "#/$defs/StringArg", + "Tuple": "#/$defs/TupleArg", "Type": "#/$defs/TypeTypeArg", "Variable": "#/$defs/VariableArg" }, @@ -1599,7 +1700,16 @@ "$ref": "#/$defs/StringArg" }, { - "$ref": "#/$defs/SequenceArg" + "$ref": "#/$defs/BytesArg" + }, + { + "$ref": "#/$defs/FloatArg" + }, + { + "$ref": "#/$defs/ListArg" + }, + { + "$ref": "#/$defs/TupleArg" }, { "$ref": "#/$defs/VariableArg" @@ -1676,6 +1786,8 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", + "Bytes": "#/$defs/BytesParam", + "Float": "#/$defs/FloatParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1693,6 +1805,12 @@ { "$ref": "#/$defs/StringParam" }, + { + "$ref": "#/$defs/FloatParam" + }, + { + "$ref": "#/$defs/BytesParam" + }, { "$ref": "#/$defs/ListParam" }, @@ -1791,7 +1909,7 @@ "Extension": "#/$defs/CustomValue", "Function": "#/$defs/FunctionValue", "Sum": "#/$defs/SumValue", - "Tuple": "#/$defs/TupleValue" + "Tuple": "#/$defs/SumValue" }, "propertyName": "v" }, @@ -1802,9 +1920,6 @@ { "$ref": "#/$defs/FunctionValue" }, - { - "$ref": "#/$defs/TupleValue" - }, { "$ref": "#/$defs/SumValue" } @@ -1864,5 +1979,16 @@ "type": "object" } }, - "title": "HUGR schema" + "title": "HUGR schema", + "oneOf": [ + { + "$ref": "#/$defs/SerialHugr" + }, + { + "$ref": "#/$defs/Extension" + }, + { + "$ref": "#/$defs/Package" + } + ] } \ No newline at end of file diff --git a/specification/schema/testing_hugr_schema_live.json b/specification/schema/testing_hugr_schema_live.json index a9f483d3c4..157facb661 100644 --- a/specification/schema/testing_hugr_schema_live.json +++ b/specification/schema/testing_hugr_schema_live.json @@ -130,6 +130,41 @@ "title": "BoundedNatParam", "type": "object" }, + "BytesArg": { + "additionalProperties": true, + "properties": { + "tya": { + "const": "Bytes", + "default": "Bytes", + "title": "Tya", + "type": "string" + }, + "value": { + "contentEncoding": "base64", + "description": "Base64-encoded byte string", + "title": "Value", + "type": "string" + } + }, + "required": [ + "value" + 
], + "title": "BytesArg", + "type": "object" + }, + "BytesParam": { + "additionalProperties": true, + "properties": { + "tp": { + "const": "Bytes", + "default": "Bytes", + "title": "Tp", + "type": "string" + } + }, + "title": "BytesParam", + "type": "object" + }, "CFG": { "additionalProperties": true, "description": "A dataflow node which is defined by a child CFG.", @@ -546,6 +581,7 @@ "type": "object" }, "FixedHugr": { + "description": "Fixed HUGR used to define the lowering of an operation.\n\nArgs:\n extensions: Extensions used in the HUGR.\n hugr: Base64-encoded HUGR envelope.", "properties": { "extensions": { "items": { @@ -566,6 +602,39 @@ "title": "FixedHugr", "type": "object" }, + "FloatArg": { + "additionalProperties": true, + "properties": { + "tya": { + "const": "Float", + "default": "Float", + "title": "Tya", + "type": "string" + }, + "value": { + "title": "Value", + "type": "number" + } + }, + "required": [ + "value" + ], + "title": "FloatArg", + "type": "object" + }, + "FloatParam": { + "additionalProperties": true, + "properties": { + "tp": { + "const": "Float", + "default": "Float", + "title": "Tp", + "type": "string" + } + }, + "title": "FloatParam", + "type": "object" + }, "FromParamsBound": { "properties": { "b": { @@ -608,6 +677,15 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" + }, + "visibility": { + "default": "Public", + "enum": [ + "Public", + "Private" + ], + "title": "Visibility", + "type": "string" } }, "required": [ @@ -638,6 +716,15 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" + }, + "visibility": { + "default": "Private", + "enum": [ + "Public", + "Private" + ], + "title": "Visibility", + "type": "string" } }, "required": [ @@ -691,7 +778,8 @@ "type": "string" }, "hugr": { - "title": "Hugr" + "title": "Hugr", + "type": "string" } }, "required": [ @@ -761,6 +849,29 @@ "title": "Input", "type": "object" }, + "ListArg": { + "additionalProperties": true, + "properties": { + "tya": { + "const": "List", + "default": "List", + "title": "Tya", + "type": "string" + }, + "elems": { + "items": { + "$ref": "#/$defs/TypeArg" + }, + "title": "Elems", + "type": "array" + } + }, + "required": [ + "elems" + ], + "title": "ListArg", + "type": "object" + }, "ListParam": { "additionalProperties": true, "properties": { @@ -1171,29 +1282,6 @@ "title": "RowVar", "type": "object" }, - "SequenceArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "Sequence", - "default": "Sequence", - "title": "Tya", - "type": "string" - }, - "elems": { - "items": { - "$ref": "#/$defs/TypeArg" - }, - "title": "Elems", - "type": "array" - } - }, - "required": [ - "elems" - ], - "title": "SequenceArg", - "type": "object" - }, "SerialHugr": { "description": "A serializable representation of a Hugr.", "properties": { @@ -1375,17 +1463,30 @@ "description": "A Sum variant For any Sum type where this value meets the type of the variant indicated by the tag.", "properties": { "v": { - "const": "Sum", "default": "Sum", + "enum": [ + "Sum", + "Tuple" + ], "title": "ValueTag", "type": "string" }, "tag": { - "title": "Tag", + "default": 0, + "title": "VariantTag", "type": "integer" }, "typ": { - "$ref": "#/$defs/SumType" + "anyOf": [ + { + "$ref": "#/$defs/SumType" + }, + { + "type": "null" + } + ], + "default": null, + "title": "SumType" }, "vs": { "items": { @@ -1396,8 +1497,6 @@ } }, "required": [ - "tag", - "typ", "vs" ], "title": "SumValue", @@ -1561,51 +1660,50 @@ "title": "TestingHugr", "type": "object" }, - "TupleParam": { + "TupleArg": { "additionalProperties": 
true, "properties": { - "tp": { + "tya": { "const": "Tuple", "default": "Tuple", - "title": "Tp", + "title": "Tya", "type": "string" }, - "params": { + "elems": { "items": { - "$ref": "#/$defs/TypeParam" + "$ref": "#/$defs/TypeArg" }, - "title": "Params", + "title": "Elems", "type": "array" } }, "required": [ - "params" + "elems" ], - "title": "TupleParam", + "title": "TupleArg", "type": "object" }, - "TupleValue": { + "TupleParam": { "additionalProperties": true, - "description": "A constant tuple value.", "properties": { - "v": { + "tp": { "const": "Tuple", "default": "Tuple", - "title": "ValueTag", + "title": "Tp", "type": "string" }, - "vs": { + "params": { "items": { - "$ref": "#/$defs/Value" + "$ref": "#/$defs/TypeParam" }, - "title": "Vs", + "title": "Params", "type": "array" } }, "required": [ - "vs" + "params" ], - "title": "TupleValue", + "title": "TupleParam", "type": "object" }, "Type": { @@ -1659,8 +1757,11 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Sequence": "#/$defs/SequenceArg", + "Bytes": "#/$defs/BytesArg", + "Float": "#/$defs/FloatArg", + "List": "#/$defs/ListArg", "String": "#/$defs/StringArg", + "Tuple": "#/$defs/TupleArg", "Type": "#/$defs/TypeTypeArg", "Variable": "#/$defs/VariableArg" }, @@ -1677,7 +1778,16 @@ "$ref": "#/$defs/StringArg" }, { - "$ref": "#/$defs/SequenceArg" + "$ref": "#/$defs/BytesArg" + }, + { + "$ref": "#/$defs/FloatArg" + }, + { + "$ref": "#/$defs/ListArg" + }, + { + "$ref": "#/$defs/TupleArg" }, { "$ref": "#/$defs/VariableArg" @@ -1754,6 +1864,8 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", + "Bytes": "#/$defs/BytesParam", + "Float": "#/$defs/FloatParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1771,6 +1883,12 @@ { "$ref": "#/$defs/StringParam" }, + { + "$ref": "#/$defs/FloatParam" + }, + { + "$ref": "#/$defs/BytesParam" + }, { "$ref": "#/$defs/ListParam" }, @@ -1869,7 +1987,7 @@ "Extension": "#/$defs/CustomValue", "Function": "#/$defs/FunctionValue", "Sum": "#/$defs/SumValue", - "Tuple": "#/$defs/TupleValue" + "Tuple": "#/$defs/SumValue" }, "propertyName": "v" }, @@ -1880,9 +1998,6 @@ { "$ref": "#/$defs/FunctionValue" }, - { - "$ref": "#/$defs/TupleValue" - }, { "$ref": "#/$defs/SumValue" } @@ -1942,5 +2057,16 @@ "type": "object" } }, - "title": "HUGR schema" + "title": "HUGR schema", + "oneOf": [ + { + "$ref": "#/$defs/TestingHugr" + }, + { + "$ref": "#/$defs/Extension" + }, + { + "$ref": "#/$defs/Package" + } + ] } \ No newline at end of file diff --git a/specification/schema/testing_hugr_schema_strict_live.json b/specification/schema/testing_hugr_schema_strict_live.json index 108f69f2f4..33244f3ed7 100644 --- a/specification/schema/testing_hugr_schema_strict_live.json +++ b/specification/schema/testing_hugr_schema_strict_live.json @@ -130,6 +130,41 @@ "title": "BoundedNatParam", "type": "object" }, + "BytesArg": { + "additionalProperties": false, + "properties": { + "tya": { + "const": "Bytes", + "default": "Bytes", + "title": "Tya", + "type": "string" + }, + "value": { + "contentEncoding": "base64", + "description": "Base64-encoded byte string", + "title": "Value", + "type": "string" + } + }, + "required": [ + "value" + ], + "title": "BytesArg", + "type": "object" + }, + "BytesParam": { + "additionalProperties": false, + "properties": { + "tp": { + "const": "Bytes", + "default": "Bytes", + "title": "Tp", + "type": "string" + } + }, + "title": "BytesParam", + "type": "object" + }, "CFG": { "additionalProperties": false, 
"description": "A dataflow node which is defined by a child CFG.", @@ -546,6 +581,7 @@ "type": "object" }, "FixedHugr": { + "description": "Fixed HUGR used to define the lowering of an operation.\n\nArgs:\n extensions: Extensions used in the HUGR.\n hugr: Base64-encoded HUGR envelope.", "properties": { "extensions": { "items": { @@ -566,6 +602,39 @@ "title": "FixedHugr", "type": "object" }, + "FloatArg": { + "additionalProperties": false, + "properties": { + "tya": { + "const": "Float", + "default": "Float", + "title": "Tya", + "type": "string" + }, + "value": { + "title": "Value", + "type": "number" + } + }, + "required": [ + "value" + ], + "title": "FloatArg", + "type": "object" + }, + "FloatParam": { + "additionalProperties": false, + "properties": { + "tp": { + "const": "Float", + "default": "Float", + "title": "Tp", + "type": "string" + } + }, + "title": "FloatParam", + "type": "object" + }, "FromParamsBound": { "properties": { "b": { @@ -608,6 +677,15 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" + }, + "visibility": { + "default": "Public", + "enum": [ + "Public", + "Private" + ], + "title": "Visibility", + "type": "string" } }, "required": [ @@ -638,6 +716,15 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" + }, + "visibility": { + "default": "Private", + "enum": [ + "Public", + "Private" + ], + "title": "Visibility", + "type": "string" } }, "required": [ @@ -691,7 +778,8 @@ "type": "string" }, "hugr": { - "title": "Hugr" + "title": "Hugr", + "type": "string" } }, "required": [ @@ -761,6 +849,29 @@ "title": "Input", "type": "object" }, + "ListArg": { + "additionalProperties": false, + "properties": { + "tya": { + "const": "List", + "default": "List", + "title": "Tya", + "type": "string" + }, + "elems": { + "items": { + "$ref": "#/$defs/TypeArg" + }, + "title": "Elems", + "type": "array" + } + }, + "required": [ + "elems" + ], + "title": "ListArg", + "type": "object" + }, "ListParam": { "additionalProperties": false, "properties": { @@ -1171,29 +1282,6 @@ "title": "RowVar", "type": "object" }, - "SequenceArg": { - "additionalProperties": false, - "properties": { - "tya": { - "const": "Sequence", - "default": "Sequence", - "title": "Tya", - "type": "string" - }, - "elems": { - "items": { - "$ref": "#/$defs/TypeArg" - }, - "title": "Elems", - "type": "array" - } - }, - "required": [ - "elems" - ], - "title": "SequenceArg", - "type": "object" - }, "SerialHugr": { "description": "A serializable representation of a Hugr.", "properties": { @@ -1375,17 +1463,30 @@ "description": "A Sum variant For any Sum type where this value meets the type of the variant indicated by the tag.", "properties": { "v": { - "const": "Sum", "default": "Sum", + "enum": [ + "Sum", + "Tuple" + ], "title": "ValueTag", "type": "string" }, "tag": { - "title": "Tag", + "default": 0, + "title": "VariantTag", "type": "integer" }, "typ": { - "$ref": "#/$defs/SumType" + "anyOf": [ + { + "$ref": "#/$defs/SumType" + }, + { + "type": "null" + } + ], + "default": null, + "title": "SumType" }, "vs": { "items": { @@ -1396,8 +1497,6 @@ } }, "required": [ - "tag", - "typ", "vs" ], "title": "SumValue", @@ -1561,51 +1660,50 @@ "title": "TestingHugr", "type": "object" }, - "TupleParam": { + "TupleArg": { "additionalProperties": false, "properties": { - "tp": { + "tya": { "const": "Tuple", "default": "Tuple", - "title": "Tp", + "title": "Tya", "type": "string" }, - "params": { + "elems": { "items": { - "$ref": "#/$defs/TypeParam" + "$ref": "#/$defs/TypeArg" }, - "title": "Params", + "title": "Elems", "type": "array" } }, 
"required": [ - "params" + "elems" ], - "title": "TupleParam", + "title": "TupleArg", "type": "object" }, - "TupleValue": { + "TupleParam": { "additionalProperties": false, - "description": "A constant tuple value.", "properties": { - "v": { + "tp": { "const": "Tuple", "default": "Tuple", - "title": "ValueTag", + "title": "Tp", "type": "string" }, - "vs": { + "params": { "items": { - "$ref": "#/$defs/Value" + "$ref": "#/$defs/TypeParam" }, - "title": "Vs", + "title": "Params", "type": "array" } }, "required": [ - "vs" + "params" ], - "title": "TupleValue", + "title": "TupleParam", "type": "object" }, "Type": { @@ -1659,8 +1757,11 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Sequence": "#/$defs/SequenceArg", + "Bytes": "#/$defs/BytesArg", + "Float": "#/$defs/FloatArg", + "List": "#/$defs/ListArg", "String": "#/$defs/StringArg", + "Tuple": "#/$defs/TupleArg", "Type": "#/$defs/TypeTypeArg", "Variable": "#/$defs/VariableArg" }, @@ -1677,7 +1778,16 @@ "$ref": "#/$defs/StringArg" }, { - "$ref": "#/$defs/SequenceArg" + "$ref": "#/$defs/BytesArg" + }, + { + "$ref": "#/$defs/FloatArg" + }, + { + "$ref": "#/$defs/ListArg" + }, + { + "$ref": "#/$defs/TupleArg" }, { "$ref": "#/$defs/VariableArg" @@ -1754,6 +1864,8 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", + "Bytes": "#/$defs/BytesParam", + "Float": "#/$defs/FloatParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1771,6 +1883,12 @@ { "$ref": "#/$defs/StringParam" }, + { + "$ref": "#/$defs/FloatParam" + }, + { + "$ref": "#/$defs/BytesParam" + }, { "$ref": "#/$defs/ListParam" }, @@ -1869,7 +1987,7 @@ "Extension": "#/$defs/CustomValue", "Function": "#/$defs/FunctionValue", "Sum": "#/$defs/SumValue", - "Tuple": "#/$defs/TupleValue" + "Tuple": "#/$defs/SumValue" }, "propertyName": "v" }, @@ -1880,9 +1998,6 @@ { "$ref": "#/$defs/FunctionValue" }, - { - "$ref": "#/$defs/TupleValue" - }, { "$ref": "#/$defs/SumValue" } @@ -1942,5 +2057,16 @@ "type": "object" } }, - "title": "HUGR schema" + "title": "HUGR schema", + "oneOf": [ + { + "$ref": "#/$defs/TestingHugr" + }, + { + "$ref": "#/$defs/Extension" + }, + { + "$ref": "#/$defs/Package" + } + ] } \ No newline at end of file diff --git a/specification/std_extensions/collections/borrow_arr.json b/specification/std_extensions/collections/borrow_arr.json new file mode 100644 index 0000000000..1774b4aea6 --- /dev/null +++ b/specification/std_extensions/collections/borrow_arr.json @@ -0,0 +1,1139 @@ +{ + "version": "0.1.1", + "name": "collections.borrow_arr", + "types": { + "borrow_array": { + "extension": "collections.borrow_arr", + "name": "borrow_array", + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "description": "Fixed-length borrow array", + "bound": { + "b": "Explicit", + "bound": "A" + } + } + }, + "operations": { + "borrow": { + "extension": "collections.borrow_arr", + "name": "borrow", + "description": "Take an element from a borrow array (panicking if it was already taken before)", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "I" + } 
+ ], + "output": [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "clone": { + "extension": "collections.borrow_arr", + "name": "clone", + "description": "Clones an array with copyable elements", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "C" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "C" + } + } + ], + "bound": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "C" + } + } + ], + "bound": "A" + }, + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "C" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "discard": { + "extension": "collections.borrow_arr", + "name": "discard", + "description": "Discards an array with copyable elements", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "C" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "C" + } + } + ], + "bound": "A" + } + ], + "output": [] + } + }, + "binary": false + }, + "discard_all_borrowed": { + "extension": "collections.borrow_arr", + "name": "discard_all_borrowed", + "description": "Discard a borrow array where all elements have been borrowed", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "output": [] + } + }, + "binary": false + }, + "discard_empty": { + "extension": "collections.borrow_arr", + "name": "discard_empty", + "description": "Discard an empty array", + "signature": { + "params": [ + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "BoundedNat", + "n": 0 + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 0, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "output": [] + } + }, + "binary": false + }, + "from_array": { + "extension": "collections.borrow_arr", + "name": "from_array", + "description": "Turns `array` 
into `borrow_array`", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.array", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "get": { + "extension": "collections.borrow_arr", + "name": "get", + "description": "Get an element from an array", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "C" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "C" + } + } + ], + "bound": "A" + }, + { + "t": "I" + } + ], + "output": [ + { + "t": "Sum", + "s": "General", + "rows": [ + [], + [ + { + "t": "V", + "i": 1, + "b": "C" + } + ] + ] + }, + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "C" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "new_all_borrowed": { + "extension": "collections.borrow_arr", + "name": "new_all_borrowed", + "description": "Create a new borrow array that contains no elements", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [], + "output": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "new_array": { + "extension": "collections.borrow_arr", + "name": "new_array", + "description": "Create a new array from elements", + "signature": null, + "binary": true + }, + "pop_left": { + "extension": "collections.borrow_arr", + "name": "pop_left", + "description": "Pop an element from the left of an array", + "signature": null, + "binary": true + }, + "pop_right": { + "extension": "collections.borrow_arr", + "name": "pop_right", + "description": "Pop an element from the right of an array", + "signature": null, + "binary": true + }, + "repeat": { + "extension": "collections.borrow_arr", + "name": "repeat", + "description": "Creates a new array whose elements are initialised by calling the given function n times", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "G", + "input": [], + "output": [ + { + "t": "V", + "i": 1, + "b": "A" + } + ] + } + ], + "output": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + 
"id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "return": { + "extension": "collections.borrow_arr", + "name": "return", + "description": "Put an element into a borrow array (panicking if there is an element already)", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "I" + }, + { + "t": "V", + "i": 1, + "b": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "scan": { + "extension": "collections.borrow_arr", + "name": "scan", + "description": "A combination of map and foldl. Applies a function to each element of the array with an accumulator that is passed through from start to finish. Returns the resulting array and the final state of the accumulator.", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "List", + "param": { + "tp": "Type", + "b": "A" + } + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "G", + "input": [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "V", + "i": 2, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ] + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 2, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ] + } + }, + "binary": false + }, + "set": { + "extension": "collections.borrow_arr", + "name": "set", + "description": "Set an element in an array", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "I" + }, + { + "t": "V", + "i": 1, + "b": "A" + } + ], + "output": [ + { + "t": "Sum", + "s": "General", + "rows": [ + [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": 
"Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + ] + } + ] + } + }, + "binary": false + }, + "swap": { + "extension": "collections.borrow_arr", + "name": "swap", + "description": "Swap two elements in an array", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "I" + }, + { + "t": "I" + } + ], + "output": [ + { + "t": "Sum", + "s": "General", + "rows": [ + [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + ] + } + ] + } + }, + "binary": false + }, + "to_array": { + "extension": "collections.borrow_arr", + "name": "to_array", + "description": "Turns `borrow_array` into `array`", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "collections.array", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, + "unpack": { + "extension": "collections.borrow_arr", + "name": "unpack", + "description": "Unpack an array into its elements", + "signature": null, + "binary": true + } + } +} diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index 7cf1d02c70..81c2f948a0 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -1,5 +1,5 @@ { - "version": "0.2.0", + "version": "0.2.1", "name": "prelude", "types": { "error": { @@ -77,6 +77,38 @@ }, "binary": false }, + "MakeError": { + "extension": "prelude", + "name": "MakeError", + "description": 
"Create an error value", + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "I" + }, + { + "t": "Opaque", + "extension": "prelude", + "id": "string", + "args": [], + "bound": "C" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "error", + "args": [], + "bound": "C" + } + ] + } + }, + "binary": false + }, "MakeTuple": { "extension": "prelude", "name": "MakeTuple", diff --git a/uv.lock b/uv.lock index 35a6661182..e2a327aaae 100644 --- a/uv.lock +++ b/uv.lock @@ -281,7 +281,7 @@ wheels = [ [[package]] name = "hugr" -version = "0.12.1" +version = "0.13.0rc1" source = { editable = "hugr-py" } dependencies = [ { name = "graphviz" }, From 1986d0170503ffdf198586b6f2fca1f4c6a58c7f Mon Sep 17 00:00:00 2001 From: Jenny Chen Date: Mon, 28 Jul 2025 14:04:55 -0600 Subject: [PATCH 4/6] added func getter for EmitFuncContext; based on latest hugr 0.22.1 --- hugr-llvm/src/emit/func.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/hugr-llvm/src/emit/func.rs b/hugr-llvm/src/emit/func.rs index 77d865540f..3c5eed2ef1 100644 --- a/hugr-llvm/src/emit/func.rs +++ b/hugr-llvm/src/emit/func.rs @@ -109,6 +109,11 @@ impl<'c, 'a, H: HugrView> EmitFuncContext<'c, 'a, H> { self.todo.insert(node.node()); } + /// Returns the current [FunctionValue] being emitted. + pub fn func(&self) -> FunctionValue<'c> { + self.func + } + /// Returns the internal [Builder]. Callers must ensure that it is /// positioned at the end of a basic block. This invariant is not checked(it /// doesn't seem possible to check it). From 6ac12c6d764b836e235959c39e81a8be54016eff Mon Sep 17 00:00:00 2001 From: Jenny Chen Date: Mon, 28 Jul 2025 19:52:23 -0600 Subject: [PATCH 5/6] added unit test for EmitFuncContext::func getter --- hugr-llvm/src/emit/func.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/hugr-llvm/src/emit/func.rs b/hugr-llvm/src/emit/func.rs index 3c5eed2ef1..feeab3735d 100644 --- a/hugr-llvm/src/emit/func.rs +++ b/hugr-llvm/src/emit/func.rs @@ -357,3 +357,19 @@ pub fn build_ok_or_else<'c, H: HugrView>( let either = builder.build_select(is_ok, right, left, "")?; Ok(either) } + +#[cfg(test)] +mod tests { + #[test] + fn test_func_getter() { + // Use TestContext for consistent test setup + let test_ctx = crate::test::test_ctx(-1); + let emit_context = test_ctx.get_emit_module_context(); + let func_type = emit_context.iw_context().void_type().fn_type(&[], false); + let function = emit_context.module().add_function("test_func", func_type, None); + let func_context = super::EmitFuncContext::new(emit_context, function).unwrap(); + + // Assert the getter returns the correct function + assert_eq!(func_context.func(), function); + } +} From 45dd0ac81058d07a7a66c952a864e83dd1da03fc Mon Sep 17 00:00:00 2001 From: Jenny Chen Date: Mon, 28 Jul 2025 20:02:03 -0600 Subject: [PATCH 6/6] fix format --- hugr-llvm/src/emit/func.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hugr-llvm/src/emit/func.rs b/hugr-llvm/src/emit/func.rs index feeab3735d..33afc5f26f 100644 --- a/hugr-llvm/src/emit/func.rs +++ b/hugr-llvm/src/emit/func.rs @@ -366,7 +366,9 @@ mod tests { let test_ctx = crate::test::test_ctx(-1); let emit_context = test_ctx.get_emit_module_context(); let func_type = emit_context.iw_context().void_type().fn_type(&[], false); - let function = emit_context.module().add_function("test_func", func_type, None); + let function = emit_context + .module() + .add_function("test_func", func_type, None); let func_context = 
super::EmitFuncContext::new(emit_context, function).unwrap(); // Assert the getter returns the correct function
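
For context, a minimal sketch of how downstream emission code might use the new accessor added in patches 4-6/6. This is illustration only, not part of the patch: the import paths and the `log_current_function` helper are assumptions, while `get_name` and `count_basic_blocks` are existing inkwell `FunctionValue` methods.

    // Sketch only: the import paths below are assumptions, not part of this patch.
    use hugr_core::HugrView;
    use hugr_llvm::emit::func::EmitFuncContext;

    /// Hypothetical helper: report which LLVM function is currently being emitted.
    fn log_current_function<'c, H: HugrView>(ctx: &EmitFuncContext<'c, '_, H>) {
        // `func()` is the getter introduced by this patch; it returns the
        // `FunctionValue<'c>` the context is emitting into.
        let f = ctx.func();
        eprintln!(
            "emitting `{}` ({} basic blocks so far)",
            f.get_name().to_string_lossy(),
            f.count_basic_blocks()
        );
    }

The unit test added in patch 5/6 above exercises the same getter directly, asserting that the returned `FunctionValue` is the one the context was constructed with.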

{ + fn new(node: PatchNode, port: P, walker: &Walker) -> Self { + debug_assert!( + walker.selected_commits.contains_node(node), + "pinned node not in walker" + ); + if walker.is_pinned(node) { + MaybePinned::Pinned(node, port) + } else { + MaybePinned::Unpinned(node, port) + } + } + + fn is_pinned(&self) -> bool { + matches!(self, MaybePinned::Pinned(_, _)) + } + + fn into_unpinned>(self) -> Option<(PatchNode, PP)> { + match self { + MaybePinned::Pinned(_, _) => None, + MaybePinned::Unpinned(node, port) => Some((node, port.into())), + } + } + + fn into_pinned>(self) -> Option<(PatchNode, PP)> { + match self { + MaybePinned::Pinned(node, port) => Some((node, port.into())), + MaybePinned::Unpinned(_, _) => None, + } + } +} + +impl PinnedWire { + /// Create a new pinned wire in `walker` from a pinned node and a port. + /// + /// # Panics + /// Panics if `node` is not pinned in `walker`. + pub fn from_pinned_port(node: PatchNode, port: impl Into, walker: &Walker) -> Self { + assert!(walker.is_pinned(node), "node must be pinned"); + + let (outgoing_node, outgoing_port) = match port.into().as_directed() { + Either::Left(incoming) => walker + .selected_commits + .get_single_outgoing_port(node, incoming), + Either::Right(outgoing) => (node, outgoing), + }; + + let outgoing = MaybePinned::new(outgoing_node, outgoing_port, walker); + + let incoming = walker + .selected_commits + .get_all_incoming_ports(outgoing_node, outgoing_port) + .map(|(n, p)| MaybePinned::new(n, p, walker)) + .collect(); + + Self { outgoing, incoming } + } + + /// Check if all ports on the wire in the given direction are pinned. + /// + /// A wire is complete in a direction if and only if expanding the wire + /// in that direction would yield no new walkers. If no direction is + /// specified, checks if the wire is complete in both directions. + pub fn is_complete(&self, dir: impl Into>) -> bool { + match dir.into() { + Some(Direction::Outgoing) => self.outgoing.is_pinned(), + Some(Direction::Incoming) => self.incoming.iter().all(|p| p.is_pinned()), + None => self.outgoing.is_pinned() && self.incoming.iter().all(|p| p.is_pinned()), + } + } + + /// Get the outgoing port of the wire, if it is pinned. + /// + /// Returns `None` if the outgoing port is not pinned. + pub fn pinned_outport(&self) -> Option<(PatchNode, OutgoingPort)> { + self.outgoing.into_pinned() + } + + /// Get all pinned incoming ports of the wire. + /// + /// Returns an iterator over all pinned incoming ports. + pub fn pinned_inports(&self) -> impl Iterator + '_ { + self.incoming.iter().filter_map(|&p| p.into_pinned()) + } + + /// Get all pinned ports of the wire. + pub fn all_pinned_ports(&self) -> impl Iterator + '_ { + fn to_port((node, port): (PatchNode, impl Into)) -> (PatchNode, Port) { + (node, port.into()) + } + self.pinned_outport() + .into_iter() + .map(to_port) + .chain(self.pinned_inports().map(to_port)) + } + + /// Get all unpinned ports of the wire, optionally filtering to only those + /// in the given direction. + pub(super) fn unpinned_ports( + &self, + dir: impl Into>, + ) -> impl Iterator + '_ { + let incoming = self + .incoming + .iter() + .filter_map(|p| p.into_unpinned::()); + let outgoing = self.outgoing.into_unpinned::(); + let dir = dir.into(); + mask_iter(incoming, dir != Some(Direction::Outgoing)) + .chain(mask_iter(outgoing, dir != Some(Direction::Incoming))) + } +} + +/// Return an iterator over the items in `iter` if `mask` is true, otherwise +/// return an empty iterator. 
+#[inline] +fn mask_iter(iter: impl IntoIterator, mask: bool) -> impl Iterator { + match mask { + true => Either::Left(iter.into_iter()), + false => Either::Right(std::iter::empty()), + } + .into_iter() +} diff --git a/hugr-core/src/hugr/serialize.rs b/hugr-core/src/hugr/serialize.rs index 8df7cf8357..7dcb14c1c1 100644 --- a/hugr-core/src/hugr/serialize.rs +++ b/hugr-core/src/hugr/serialize.rs @@ -78,7 +78,6 @@ impl Versioned { #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)] struct NodeSer { - /// Node index of the parent. parent: Node, #[serde(flatten)] op: OpType, @@ -90,7 +89,7 @@ struct SerHugrLatest { /// For each node: (parent, `node_operation`) nodes: Vec, /// for each edge: (src, `src_offset`, tgt, `tgt_offset`) - edges: Vec<[(Node, Option); 2]>, + edges: Vec<[(Node, Option); 2]>, /// for each node: (metadata) #[serde(default)] metadata: Option>>, @@ -114,7 +113,7 @@ pub enum HUGRSerializationError { AttachError(#[from] AttachError), /// Failed to add edge. #[error("Failed to build edge when deserializing: {0}.")] - LinkError(#[from] LinkError), + LinkError(#[from] LinkError), /// Edges without port offsets cannot be present in operations without non-dataflow ports. #[error( "Cannot connect an {dir:?} edge without port offset to node {node} with operation type {op_type}." @@ -215,7 +214,7 @@ impl TryFrom<&Hugr> for SerHugrLatest { let op = hugr.get_optype(node); let is_value_port = offset < op.value_port_count(dir); let is_static_input = op.static_port(dir).is_some_and(|p| p.index() == offset); - let offset = (is_value_port || is_static_input).then_some(offset as u32); + let offset = (is_value_port || is_static_input).then_some(offset as u16); (node_rekey[&node], offset) }; @@ -283,7 +282,7 @@ impl TryFrom for Hugr { } if let Some(entrypoint) = entrypoint { - hugr.set_entrypoint(hugr_node(entrypoint)); + hugr.set_entrypoint(entrypoint); } if let Some(metadata) = metadata { diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 60249489df..2b500ed038 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -24,7 +24,7 @@ use crate::types::{ FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeArg, TypeBound, TypeRV, }; -use crate::{OutgoingPort, Visibility, type_row}; +use crate::{OutgoingPort, type_row}; use itertools::Itertools; use jsonschema::{Draft, Validator}; @@ -62,29 +62,26 @@ impl NamedSchema { Self { name, schema } } - pub fn check(&self, val: &serde_json::Value) -> Result<(), String> { + pub fn check(&self, val: &serde_json::Value) { let mut errors = self.schema.iter_errors(val).peekable(); - if errors.peek().is_none() { - return Ok(()); + if errors.peek().is_some() { + // errors don't necessarily implement Debug + eprintln!("Schema failed to validate: {}", self.name); + for error in errors { + eprintln!("Validation error: {error}"); + eprintln!("Instance path: {}", error.instance_path); + } + panic!("Serialization test failed."); } - - // errors don't necessarily implement Debug - let mut strs = vec![format!("Schema failed to validate: {}", self.name)]; - strs.extend(errors.flat_map(|error| { - [ - format!("Validation error: {error}"), - format!("Instance path: {}", error.instance_path), - ] - })); - strs.push("Serialization test failed.".to_string()); - Err(strs.join("\n")) } pub fn check_schemas( val: &serde_json::Value, schemas: impl IntoIterator, - ) -> Result<(), String> { - schemas.into_iter().try_for_each(|schema| schema.check(val)) + ) { + for schema 
in schemas { + schema.check(val); + } } } @@ -92,7 +89,7 @@ macro_rules! include_schema { ($name:ident, $path:literal) => { lazy_static! { static ref $name: NamedSchema = - NamedSchema::new(stringify!($name), { + NamedSchema::new("$name", { let schema_val: serde_json::Value = serde_json::from_str(include_str!( concat!("../../../../specification/schema/", $path, "_live.json") )) @@ -164,7 +161,7 @@ fn ser_deserialize_check_schema( val: serde_json::Value, schemas: impl IntoIterator, ) -> T { - NamedSchema::check_schemas(&val, schemas).unwrap(); + NamedSchema::check_schemas(&val, schemas); serde_json::from_value(val).unwrap() } @@ -174,22 +171,8 @@ fn ser_roundtrip_check_schema, ) -> TDeser { let val = serde_json::to_value(g).unwrap(); - match NamedSchema::check_schemas(&val, schemas) { - Ok(()) => serde_json::from_value(val).unwrap(), - Err(msg) => panic!("ser_roundtrip_check_schema failed with {msg}, input was {val}"), - } -} - -/// Serialize a Hugr and check that it is valid against the schema. -/// -/// # Panics -/// -/// Panics if the serialization fails or if the schema validation fails. -pub(crate) fn check_hugr_serialization_schema(hugr: &Hugr) { - let schemas = get_schemas(true); - let hugr_ser = HugrSer(hugr); - let val = serde_json::to_value(hugr_ser).unwrap(); - NamedSchema::check_schemas(&val, schemas).unwrap(); + NamedSchema::check_schemas(&val, schemas); + serde_json::from_value(val).unwrap() } /// Serialize and deserialize a HUGR, and check that the result is the same as the original. @@ -227,80 +210,8 @@ fn check_testing_roundtrip(t: impl Into) { assert_eq!(before, after); } -fn test_schema_val() -> serde_json::Value { - serde_json::json!({ - "op_def":null, - "optype":{ - "name":"polyfunc1", - "op":"FuncDefn", - "visibility": "Public", - "parent":0, - "signature":{ - "body":{ - "input":[], - "output":[] - }, - "params":[ - {"bound":null,"tp":"BoundedNat"} - ] - } - }, - "poly_func_type":null, - "sum_type":null, - "typ":null, - "value":null, - "version":"live" - }) -} - -fn schema_val() -> serde_json::Value { - serde_json::json!({"nodes": [], "edges": [], "version": "live"}) -} - -#[rstest] -#[case(&TESTING_SCHEMA, &TESTING_SCHEMA_STRICT, test_schema_val(), Some("optype"))] -#[case(&SCHEMA, &SCHEMA_STRICT, schema_val(), None)] -fn wrong_fields( - #[case] lax_schema: &'static NamedSchema, - #[case] strict_schema: &'static NamedSchema, - #[case] mut val: serde_json::Value, - #[case] target_loc: impl IntoIterator + Clone, -) { - use serde_json::Value; - fn get_fields( - val: &mut Value, - mut path: impl Iterator, - ) -> &mut serde_json::Map { - let Value::Object(fields) = val else { panic!() }; - match path.next() { - Some(n) => get_fields(fields.get_mut(n).unwrap(), path), - None => fields, - } - } - // First, some "known good" JSON - NamedSchema::check_schemas(&val, [lax_schema, strict_schema]).unwrap(); - - // Now try adding an extra field - let fields = get_fields(&mut val, target_loc.clone().into_iter()); - fields.insert( - "extra_field".to_string(), - Value::String("not in schema".to_string()), - ); - strict_schema.check(&val).unwrap_err(); - lax_schema.check(&val).unwrap(); - - // And removing one - let fields = get_fields(&mut val, target_loc.into_iter()); - fields.remove("extra_field").unwrap(); - let key = fields.keys().next().unwrap().clone(); - fields.remove(&key).unwrap(); - - lax_schema.check(&val).unwrap_err(); - strict_schema.check(&val).unwrap_err(); -} - /// Generate an optype for a node with a matching amount of inputs and outputs. 
-fn gen_optype(g: &MultiPortGraph, node: portgraph::NodeIndex) -> OpType { +fn gen_optype(g: &MultiPortGraph, node: portgraph::NodeIndex) -> OpType { let inputs = g.num_inputs(node); let outputs = g.num_outputs(node); match (inputs == 0, outputs == 0) { @@ -517,7 +428,7 @@ fn serialize_types_roundtrip() { #[case(bool_t())] #[case(usize_t())] #[case(INT_TYPES[2].clone())] -#[case(Type::new_alias(crate::ops::AliasDecl::new("t", TypeBound::Linear)))] +#[case(Type::new_alias(crate::ops::AliasDecl::new("t", TypeBound::Any)))] #[case(Type::new_var_use(2, TypeBound::Copyable))] #[case(Type::new_tuple(vec![bool_t(),qb_t()]))] #[case(Type::new_sum([vec![bool_t(),qb_t()], vec![Type::new_unit_sum(4)]]))] @@ -547,13 +458,13 @@ fn roundtrip_value(#[case] value: Value) { fn polyfunctype1() -> PolyFuncType { let function_type = Signature::new_endo(type_row![]); - PolyFuncType::new([TypeParam::max_nat_type()], function_type) + PolyFuncType::new([TypeParam::max_nat()], function_type) } fn polyfunctype2() -> PolyFuncTypeRV { - let tv0 = TypeRV::new_row_var_use(0, TypeBound::Linear); + let tv0 = TypeRV::new_row_var_use(0, TypeBound::Any); let tv1 = TypeRV::new_row_var_use(1, TypeBound::Copyable); - let params = [TypeBound::Linear, TypeBound::Copyable].map(TypeParam::new_list_type); + let params = [TypeBound::Any, TypeBound::Copyable].map(TypeParam::new_list); let inputs = vec![ TypeRV::new_function(FuncValueType::new(tv0.clone(), tv1.clone())), tv0, @@ -568,26 +479,26 @@ fn polyfunctype2() -> PolyFuncTypeRV { #[rstest] #[case(Signature::new_endo(type_row![]).into())] #[case(polyfunctype1())] -#[case(PolyFuncType::new([TypeParam::StringType], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(PolyFuncType::new([TypeParam::String], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] #[case(PolyFuncType::new([TypeBound::Copyable.into()], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] -#[case(PolyFuncType::new([TypeParam::new_list_type(TypeBound::Linear)], Signature::new_endo(type_row![])))] -#[case(PolyFuncType::new([TypeParam::new_tuple_type([TypeBound::Linear.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())])], Signature::new_endo(type_row![])))] +#[case(PolyFuncType::new([TypeParam::new_list(TypeBound::Any)], Signature::new_endo(type_row![])))] +#[case(PolyFuncType::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], Signature::new_endo(type_row![])))] #[case(PolyFuncType::new( - [TypeParam::new_list_type(TypeBound::Linear)], - Signature::new_endo(Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Linear)))))] + [TypeParam::new_list(TypeBound::Any)], + Signature::new_endo(Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Any)))))] fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { check_testing_roundtrip(poly_func_type); } #[rstest] #[case(FuncValueType::new_endo(type_row![]).into())] -#[case(PolyFuncTypeRV::new([TypeParam::StringType], FuncValueType::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(PolyFuncTypeRV::new([TypeParam::String], FuncValueType::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] #[case(PolyFuncTypeRV::new([TypeBound::Copyable.into()], FuncValueType::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] -#[case(PolyFuncTypeRV::new([TypeParam::new_list_type(TypeBound::Linear)], FuncValueType::new_endo(type_row![])))] 
-#[case(PolyFuncTypeRV::new([TypeParam::new_tuple_type([TypeBound::Linear.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())])], FuncValueType::new_endo(type_row![])))] +#[case(PolyFuncTypeRV::new([TypeParam::new_list(TypeBound::Any)], FuncValueType::new_endo(type_row![])))] +#[case(PolyFuncTypeRV::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FuncValueType::new_endo(type_row![])))] #[case(PolyFuncTypeRV::new( - [TypeParam::new_list_type(TypeBound::Linear)], - FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Linear))))] + [TypeParam::new_list(TypeBound::Any)], + FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Any))))] #[case(polyfunctype2())] fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { check_testing_roundtrip(poly_func_type); @@ -595,15 +506,15 @@ fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { #[rstest] #[case(ops::Module::new())] -#[case(ops::FuncDefn::new_vis("polyfunc1", polyfunctype1(), Visibility::Private))] -#[case(ops::FuncDefn::new_vis("pubfunc1", polyfunctype1(), Visibility::Public))] +#[case(ops::FuncDefn::new("polyfunc1", polyfunctype1()))] +#[case(ops::FuncDecl::new("polyfunc2", polyfunctype1()))] #[case(ops::AliasDefn { name: "aliasdefn".into(), definition: Type::new_unit_sum(4)})] -#[case(ops::AliasDecl { name: "aliasdecl".into(), bound: TypeBound::Linear})] +#[case(ops::AliasDecl { name: "aliasdecl".into(), bound: TypeBound::Any})] #[case(ops::Const::new(Value::false_val()))] #[case(ops::Const::new(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap()))] #[case(ops::Input::new(vec![Type::new_var_use(3,TypeBound::Copyable)]))] #[case(ops::Output::new(vec![Type::new_function(FuncValueType::new_endo(type_row![]))]))] -#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat(1)]).unwrap())] +#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}]).unwrap())] #[case(ops::CallIndirect { signature : Signature::new_endo(vec![bool_t()]) })] fn roundtrip_optype(#[case] optype: impl Into + std::fmt::Debug) { check_testing_roundtrip(NodeSer { @@ -618,7 +529,7 @@ fn std_extensions_valid() { let std_reg = crate::std_extensions::std_reg(); for ext in std_reg { let val = serde_json::to_value(ext).unwrap(); - NamedSchema::check_schemas(&val, get_schemas(true)).unwrap(); + NamedSchema::check_schemas(&val, get_schemas(true)); // check deserialises correctly, can't check equality because of custom binaries. 
let deser: crate::extension::Extension = serde_json::from_value(val.clone()).unwrap(); assert_eq!(serde_json::to_value(deser).unwrap(), val); diff --git a/hugr-core/src/hugr/serialize/upgrade/testcases/hugr_with_named_op.json b/hugr-core/src/hugr/serialize/upgrade/testcases/hugr_with_named_op.json index 112581e94f..ca3965d874 100644 --- a/hugr-core/src/hugr/serialize/upgrade/testcases/hugr_with_named_op.json +++ b/hugr-core/src/hugr/serialize/upgrade/testcases/hugr_with_named_op.json @@ -3,8 +3,67 @@ "nodes": [ { "parent": 0, - "op": "DFG", + "op": "Module" + }, + { + "parent": 0, + "op": "FuncDefn", "name": "main", + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "output": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ] + } + } + }, + { + "parent": 1, + "op": "Input", + "types": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ] + }, + { + "parent": 1, + "op": "Output", + "types": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ] + }, + { + "parent": 1, + "op": "DFG", "signature": { "input": [ { @@ -28,7 +87,7 @@ } }, { - "parent": 0, + "parent": 4, "op": "Input", "types": [ { @@ -44,7 +103,7 @@ ] }, { - "parent": 0, + "parent": 4, "op": "Output", "types": [ { @@ -55,7 +114,7 @@ ] }, { - "parent": 0, + "parent": 4, "op": "Extension", "extension": "logic", "name": "And", @@ -87,40 +146,75 @@ "edges": [ [ [ - 1, + 2, 0 ], [ - 3, + 4, 0 ] ], [ [ - 1, + 2, 1 ], [ - 3, + 4, 1 ] ], [ + [ + 4, + 0 + ], [ 3, 0 + ] + ], + [ + [ + 5, + 0 ], [ - 2, + 7, + 0 + ] + ], + [ + [ + 5, + 1 + ], + [ + 7, + 1 + ] + ], + [ + [ + 7, + 0 + ], + [ + 6, 0 ] ] ], "metadata": [ + null, + null, + null, + null, null, null, null, null ], - "encoder": "hugr-rs v0.15.4" + "encoder": "hugr-rs v0.15.4", + "entrypoint": 4 } diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 41fe7ba45b..8291cdcde9 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -1,7 +1,6 @@ //! HUGR invariant checks. use std::collections::HashMap; -use std::collections::hash_map::Entry; use std::iter; use itertools::Itertools; @@ -20,8 +19,9 @@ use crate::ops::validate::{ use crate::ops::{NamedOp, OpName, OpTag, OpTrait, OpType, ValidateOp}; use crate::types::EdgeKind; use crate::types::type_param::TypeParam; -use crate::{Direction, Port, Visibility}; +use crate::{Direction, Port}; +use super::ExtensionError; use super::internal::PortgraphNodeMap; use super::views::HugrView; @@ -60,7 +60,6 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { // Hierarchy and children. No type variables declared outside the root. self.validate_subtree(self.hugr.entrypoint(), &[])?; - self.validate_linkage()?; // In tests we take the opportunity to verify that the hugr // serialization round-trips. We verify the schema of the serialization // format only when an environment variable is set. 
This allows @@ -82,44 +81,6 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { Ok(()) } - fn validate_linkage(&self) -> Result<(), ValidationError> { - // Map from func_name, for visible funcs only, to *tuple of* - // Node with that func_name, - // Signature, - // bool - true for FuncDefn - let mut node_sig_defn = HashMap::new(); - - for c in self.hugr.children(self.hugr.module_root()) { - let (func_name, sig, is_defn) = match self.hugr.get_optype(c) { - OpType::FuncDecl(fd) if fd.visibility() == &Visibility::Public => { - (fd.func_name(), fd.signature(), false) - } - OpType::FuncDefn(fd) if fd.visibility() == &Visibility::Public => { - (fd.func_name(), fd.signature(), true) - } - _ => continue, - }; - match node_sig_defn.entry(func_name) { - Entry::Vacant(ve) => { - ve.insert((c, sig, is_defn)); - } - Entry::Occupied(oe) => { - // Allow two decls of the same sig (aliasing - we are allowing some laziness here). - // Reject if at least one Defn - either two conflicting impls, - // or Decl+Defn which should have been linked - let (prev_c, prev_sig, prev_defn) = oe.get(); - if prev_sig != &sig || is_defn || *prev_defn { - return Err(ValidationError::DuplicateExport { - link_name: func_name.clone(), - children: [*prev_c, c], - }); - }; - } - } - } - Ok(()) - } - /// Compute the dominator tree for a CFG region, identified by its container /// node. /// @@ -158,7 +119,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if num_ports != op_type.port_count(dir) { return Err(ValidationError::WrongNumberOfPorts { node, - optype: Box::new(op_type.clone()), + optype: op_type.clone(), actual: num_ports, expected: op_type.port_count(dir), dir, @@ -176,9 +137,9 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if !allowed_children.is_superset(op_type.tag()) { return Err(ValidationError::InvalidParentOp { child: node, - child_optype: Box::new(op_type.clone()), + child_optype: op_type.clone(), parent, - parent_optype: Box::new(parent_optype.clone()), + parent_optype: parent_optype.clone(), allowed_children, }); } @@ -190,7 +151,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if validity_flags.allowed_children == OpTag::None { return Err(ValidationError::EntrypointNotContainer { node, - optype: Box::new(op_type.clone()), + optype: op_type.clone(), }); } } @@ -239,7 +200,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { return Err(ValidationError::UnconnectedPort { node, port, - port_kind: Box::new(port_kind), + port_kind, }); } @@ -249,7 +210,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { return Err(ValidationError::TooManyConnections { node, port, - port_kind: Box::new(port_kind), + port_kind, }); } return Ok(()); @@ -269,7 +230,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { return Err(ValidationError::TooManyConnections { node, port, - port_kind: Box::new(port_kind), + port_kind, }); } @@ -283,10 +244,10 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { return Err(ValidationError::IncompatiblePorts { from: node, from_port: port, - from_kind: Box::new(port_kind), + from_kind: port_kind, to: other_node, to_port: other_offset, - to_kind: Box::new(other_kind), + to_kind: other_kind, }); } @@ -325,7 +286,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if flags.allowed_children.is_empty() { return Err(ValidationError::NonContainerWithChildren { node, - optype: Box::new(op_type.clone()), + optype: op_type.clone(), }); } @@ -335,8 +296,8 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if !flags.allowed_first_child.is_superset(first_child.tag()) { return 
Err(ValidationError::InvalidInitialChild { parent: node, - parent_optype: Box::new(op_type.clone()), - optype: Box::new(first_child.clone()), + parent_optype: op_type.clone(), + optype: first_child.clone(), expected: flags.allowed_first_child, position: "first", }); @@ -349,8 +310,8 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if !flags.allowed_second_child.is_superset(second_child.tag()) { return Err(ValidationError::InvalidInitialChild { parent: node, - parent_optype: Box::new(op_type.clone()), - optype: Box::new(second_child.clone()), + parent_optype: op_type.clone(), + optype: second_child.clone(), expected: flags.allowed_second_child, position: "second", }); @@ -361,7 +322,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if let Err(source) = op_type.validate_op_children(children_optypes) { return Err(ValidationError::InvalidChildren { parent: node, - parent_optype: Box::new(op_type.clone()), + parent_optype: op_type.clone(), source, }); } @@ -388,7 +349,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if let Err(source) = edge_check(edge_data) { return Err(ValidationError::InvalidEdges { parent: node, - parent_optype: Box::new(op_type.clone()), + parent_optype: op_type.clone(), source, }); } @@ -403,7 +364,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { } else if flags.requires_children { return Err(ValidationError::ContainerWithoutChildren { node, - optype: Box::new(op_type.clone()), + optype: op_type.clone(), }); } @@ -434,7 +395,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { if nodes_visited != node_count { return Err(ValidationError::NotADag { node: parent, - optype: Box::new(op_type.clone()), + optype: op_type.clone(), }); } @@ -472,7 +433,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { from_offset, to, to_offset, - ty: Box::new(edge_kind), + ty: edge_kind, }); } @@ -482,12 +443,28 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { // // This search could be sped-up with a pre-computed LCA structure, but // for valid Hugrs this search should be very short. + // + // For Value edges only, we record any FuncDefn we went through; if there is + // any such, then that is an error, but we report that only if the dom/ext + // relation was otherwise ok (an error about an edge "entering" some ancestor + // node could be misleading if the source isn't where it's expected) + let mut err_entered_func = None; let from_parent_parent = self.hugr.get_parent(from_parent); for (ancestor, ancestor_parent) in iter::successors(to_parent, |&p| self.hugr.get_parent(p)).tuple_windows() { + if !is_static && self.hugr.get_optype(ancestor).is_func_defn() { + err_entered_func.get_or_insert(InterGraphEdgeError::ValueEdgeIntoFunc { + to, + to_offset, + from, + from_offset, + func: ancestor, + }); + } if ancestor_parent == from_parent { // External edge. + err_entered_func.map_or(Ok(()), Err)?; if !is_static { // Must have an order edge. 
self.hugr @@ -511,10 +488,10 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { from_offset, to, to_offset, - ancestor_parent_op: Box::new(ancestor_parent_op.clone()), + ancestor_parent_op: ancestor_parent_op.clone(), }); } - + err_entered_func.map_or(Ok(()), Err)?; // Check domination let (dominator_tree, node_map) = if let Some(tree) = self.dominators.get(&ancestor_parent) { @@ -640,7 +617,7 @@ pub enum ValidationError { )] WrongNumberOfPorts { node: N, - optype: Box, + optype: OpType, actual: usize, expected: usize, dir: Direction, @@ -650,14 +627,14 @@ pub enum ValidationError { UnconnectedPort { node: N, port: Port, - port_kind: Box, + port_kind: EdgeKind, }, /// A linear port is connected to more than one thing. #[error("{node} has a port {port} of type {port_kind} with more than one connection.")] TooManyConnections { node: N, port: Port, - port_kind: Box, + port_kind: EdgeKind, }, /// Connected ports have different types, or non-unifiable types. #[error( @@ -666,10 +643,10 @@ pub enum ValidationError { IncompatiblePorts { from: N, from_port: Port, - from_kind: Box, + from_kind: EdgeKind, to: N, to_port: Port, - to_kind: Box, + to_kind: EdgeKind, }, /// The non-root node has no parent. #[error("{node} has no parent.")] @@ -678,9 +655,9 @@ pub enum ValidationError { #[error("The operation {parent_optype} cannot contain a {child_optype} as a child. Allowed children: {}. In {child} with parent {parent}.", allowed_children.description())] InvalidParentOp { child: N, - child_optype: Box, + child_optype: OpType, parent: N, - parent_optype: Box, + parent_optype: OpType, allowed_children: OpTag, }, /// Invalid first/second child. @@ -689,8 +666,8 @@ pub enum ValidationError { )] InvalidInitialChild { parent: N, - parent_optype: Box, - optype: Box, + parent_optype: OpType, + optype: OpType, expected: OpTag, position: &'static str, }, @@ -701,19 +678,9 @@ pub enum ValidationError { )] InvalidChildren { parent: N, - parent_optype: Box, + parent_optype: OpType, source: ChildrenValidationError, }, - /// Multiple, incompatible, nodes with [Visibility::Public] use the same `func_name` - /// in a [Module](super::Module). (Multiple [`FuncDecl`](crate::ops::FuncDecl)s with - /// the same signature are allowed) - #[error("FuncDefn/Decl {} is exported under same name {link_name} as earlier node {}", children[0], children[1])] - DuplicateExport { - /// The `func_name` of a public `FuncDecl` or `FuncDefn` - link_name: String, - /// Two nodes using that name - children: [N; 2], - }, /// The children graph has invalid edges. #[error( "An operation {parent_optype} contains invalid edges between its children: {source}. In parent {parent}, edge from {from:?} port {from_port:?} to {to:?} port {to_port:?}", @@ -724,23 +691,27 @@ pub enum ValidationError { )] InvalidEdges { parent: N, - parent_optype: Box, + parent_optype: OpType, source: EdgeValidationError, }, /// The node operation is not a container, but has children. #[error("{node} with optype {optype} is not a container, but has children.")] - NonContainerWithChildren { node: N, optype: Box }, + NonContainerWithChildren { node: N, optype: OpType }, /// The node must have children, but has none. #[error("{node} with optype {optype} must have children, but has none.")] - ContainerWithoutChildren { node: N, optype: Box }, + ContainerWithoutChildren { node: N, optype: OpType }, /// The children of a node do not form a DAG. #[error( "The children of an operation {optype} must form a DAG. Loops are not allowed. In {node}." 
)] - NotADag { node: N, optype: Box }, + NotADag { node: N, optype: OpType }, /// There are invalid inter-graph edges. #[error(transparent)] InterGraphEdgeError(#[from] InterGraphEdgeError), + /// There are errors in the extension deltas. + #[deprecated(note = "Never returned since hugr-core-v0.20.0")] + #[error(transparent)] + ExtensionError(#[from] ExtensionError), /// A node claims to still be awaiting extension inference. Perhaps it is not acted upon by inference. #[error( "{node} needs a concrete ExtensionSet - inference will provide this for Case/CFG/Conditional/DataflowBlock/DFG/TailLoop only" @@ -769,7 +740,7 @@ pub enum ValidationError { ConstTypeError(#[from] ConstTypeError), /// The HUGR entrypoint must be a region container. #[error("The HUGR entrypoint ({node}) must be a region container, but '{}' does not accept children.", optype.name())] - EntrypointNotContainer { node: N, optype: Box }, + EntrypointNotContainer { node: N, optype: OpType }, } /// Errors related to the inter-graph edge validations. @@ -786,7 +757,18 @@ pub enum InterGraphEdgeError { from_offset: Port, to: N, to_offset: Port, - ty: Box, + ty: EdgeKind, + }, + /// Inter-Graph edges may not enter into `FuncDefns` unless they are static + #[error( + "Inter-graph Value edges cannot enter into FuncDefns. Inter-graph edge from {from} ({from_offset}) to {to} ({to_offset} enters FuncDefn {func}" + )] + ValueEdgeIntoFunc { + from: N, + from_offset: Port, + to: N, + to_offset: Port, + func: N, }, /// The grandparent of a dominator inter-graph edge must be a CFG container. #[error( @@ -797,7 +779,7 @@ pub enum InterGraphEdgeError { from_offset: Port, to: N, to_offset: Port, - ancestor_parent_op: Box, + ancestor_parent_op: OpType, }, /// The sibling ancestors of the external inter-graph edge endpoints must be have an order edge between them. 
#[error( diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index ec086243ee..8ee95cde61 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -1,44 +1,38 @@ -use std::borrow::Cow; use std::fs::File; use std::io::BufReader; use std::sync::Arc; use cool_asserts::assert_matches; -use rstest::rstest; use super::*; use crate::builder::test::closed_dfg_root_hugr; use crate::builder::{ BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - FunctionBuilder, HugrBuilder, ModuleBuilder, inout_sig, + FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, inout_sig, }; use crate::extension::prelude::Noop; use crate::extension::prelude::{bool_t, qb_t, usize_t}; use crate::extension::{Extension, ExtensionRegistry, PRELUDE, TypeDefBound}; use crate::hugr::HugrMut; use crate::hugr::internal::HugrMutInternals; -use crate::ops::dataflow::{DataflowParent, IOTrait}; +use crate::ops::dataflow::IOTrait; use crate::ops::handle::NodeHandle; -use crate::ops::{self, FuncDecl, FuncDefn, OpType, Value}; +use crate::ops::{self, OpType, Value}; use crate::std_extensions::logic::LogicOp; use crate::std_extensions::logic::test::{and_op, or_op}; -use crate::types::type_param::{TermTypeError, TypeArg}; +use crate::types::type_param::{TypeArg, TypeArgError}; use crate::types::{ - CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Term, Type, TypeBound, - TypeRV, TypeRow, + CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, TypeRV, + TypeRow, }; use crate::{Direction, Hugr, IncomingPort, Node, const_extension_ids, test_file, type_row}; -/// Creates a hugr with a single, public, function definition that copies a bit `copies` times. +/// Creates a hugr with a single function definition that copies a bit `copies` times. /// /// Returns the hugr and the node index of the definition. fn make_simple_hugr(copies: usize) -> (Hugr, Node) { - let def_op: OpType = FuncDefn::new_vis( - "main", - Signature::new(bool_t(), vec![bool_t(); copies]), - Visibility::Public, - ) - .into(); + let def_op: OpType = + ops::FuncDefn::new("main", Signature::new(bool_t(), vec![bool_t(); copies])).into(); let mut b = Hugr::default(); let root = b.entrypoint(); @@ -132,7 +126,7 @@ fn children_restrictions() { // Add a definition without children let def_sig = Signature::new(vec![bool_t()], vec![bool_t(), bool_t()]); - let new_def = b.add_node_with_parent(root, FuncDefn::new("main", def_sig)); + let new_def = b.add_node_with_parent(root, ops::FuncDefn::new("main", def_sig)); assert_matches!( b.validate(), Err(ValidationError::ContainerWithoutChildren { node, .. 
}) => assert_eq!(node, new_def) @@ -231,6 +225,35 @@ fn test_ext_edge() { h.validate().unwrap(); } +#[test] +fn no_ext_edge_into_func() -> Result<(), Box> { + let b2b = Signature::new_endo(bool_t()); + let mut h = DFGBuilder::new(Signature::new(bool_t(), Type::new_function(b2b.clone())))?; + let [input] = h.input_wires_arr(); + + let mut dfg = h.dfg_builder(Signature::new(vec![], Type::new_function(b2b.clone())), [])?; + let mut func = dfg.define_function("AndWithOuter", b2b.clone())?; + let [fn_input] = func.input_wires_arr(); + let and_op = func.add_dataflow_op(and_op(), [fn_input, input])?; // 'ext' edge + let func = func.finish_with_outputs(and_op.outputs())?; + let loadfn = dfg.load_func(func.handle(), &[])?; + let dfg = dfg.finish_with_outputs([loadfn])?; + let res = h.finish_hugr_with_outputs(dfg.outputs()); + assert_eq!( + res, + Err(BuildError::InvalidHUGR( + ValidationError::InterGraphEdgeError(InterGraphEdgeError::ValueEdgeIntoFunc { + from: input.node(), + from_offset: input.source().into(), + to: and_op.node(), + to_offset: IncomingPort::from(1).into(), + func: func.node() + }) + )) + ); + Ok(()) +} + #[test] fn test_local_const() { let mut h = closed_dfg_root_hugr(Signature::new_endo(bool_t())); @@ -243,7 +266,7 @@ fn test_local_const() { Err(ValidationError::UnconnectedPort { node: and, port: IncomingPort::from(1).into(), - port_kind: Box::new(EdgeKind::Value(bool_t())) + port_kind: EdgeKind::Value(bool_t()) }) ); let const_op: ops::Const = ops::Value::from_bool(true).into(); @@ -282,7 +305,7 @@ fn identity_hugr_with_type(t: Type) -> (Hugr, Node) { let def = b.add_node_with_parent( b.entrypoint(), - FuncDefn::new("main", Signature::new_endo(row.clone())), + ops::FuncDefn::new("main", Signature::new_endo(row.clone())), ); let input = b.add_node_with_parent(def, ops::Input::new(row.clone())); @@ -324,9 +347,9 @@ fn invalid_types() { let valid = Type::new_extension(CustomType::new( "MyContainer", - vec![usize_t().into()], + vec![TypeArg::Type { ty: usize_t() }], EXT_ID, - TypeBound::Linear, + TypeBound::Any, &Arc::downgrade(&ext), )); let mut hugr = identity_hugr_with_type(valid.clone()).0; @@ -336,22 +359,22 @@ fn invalid_types() { // valid is Any, so is not allowed as an element of an outer MyContainer. let element_outside_bound = CustomType::new( "MyContainer", - vec![valid.clone().into()], + vec![TypeArg::Type { ty: valid.clone() }], EXT_ID, - TypeBound::Linear, + TypeBound::Any, &Arc::downgrade(&ext), ); assert_eq!( validate_to_sig_error(element_outside_bound), - SignatureError::TypeArgMismatch(TermTypeError::TypeMismatch { - type_: Box::new(TypeBound::Copyable.into()), - term: Box::new(valid.into()) + SignatureError::TypeArgMismatch(TypeArgError::TypeMismatch { + param: TypeBound::Copyable.into(), + arg: TypeArg::Type { ty: valid } }) ); let bad_bound = CustomType::new( "MyContainer", - vec![usize_t().into()], + vec![TypeArg::Type { ty: usize_t() }], EXT_ID, TypeBound::Copyable, &Arc::downgrade(&ext), @@ -360,36 +383,41 @@ fn invalid_types() { validate_to_sig_error(bad_bound.clone()), SignatureError::WrongBound { actual: TypeBound::Copyable, - expected: TypeBound::Linear + expected: TypeBound::Any } ); // bad_bound claims to be Copyable, which is valid as an element for the outer MyContainer. 
let nested = CustomType::new( "MyContainer", - vec![Type::new_extension(bad_bound).into()], + vec![TypeArg::Type { + ty: Type::new_extension(bad_bound), + }], EXT_ID, - TypeBound::Linear, + TypeBound::Any, &Arc::downgrade(&ext), ); assert_eq!( validate_to_sig_error(nested), SignatureError::WrongBound { actual: TypeBound::Copyable, - expected: TypeBound::Linear + expected: TypeBound::Any } ); let too_many_type_args = CustomType::new( "MyContainer", - vec![usize_t().into(), 3u64.into()], + vec![ + TypeArg::Type { ty: usize_t() }, + TypeArg::BoundedNat { n: 3 }, + ], EXT_ID, - TypeBound::Linear, + TypeBound::Any, &Arc::downgrade(&ext), ); assert_eq!( validate_to_sig_error(too_many_type_args), - SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(2, 1)) + SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(2, 1)) ); } @@ -399,8 +427,8 @@ fn typevars_declared() -> Result<(), Box> { let f = FunctionBuilder::new( "myfunc", PolyFuncType::new( - [TypeBound::Linear.into()], - Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Linear)]), + [TypeBound::Any.into()], + Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), ), )?; let [w] = f.input_wires_arr(); @@ -409,8 +437,8 @@ fn typevars_declared() -> Result<(), Box> { let f = FunctionBuilder::new( "myfunc", PolyFuncType::new( - [TypeBound::Linear.into()], - Signature::new_endo(vec![Type::new_var_use(1, TypeBound::Linear)]), + [TypeBound::Any.into()], + Signature::new_endo(vec![Type::new_var_use(1, TypeBound::Any)]), ), )?; let [w] = f.input_wires_arr(); @@ -419,7 +447,7 @@ fn typevars_declared() -> Result<(), Box> { let f = FunctionBuilder::new( "myfunc", PolyFuncType::new( - [TypeBound::Linear.into()], + [TypeBound::Any.into()], Signature::new_endo(vec![Type::new_var_use(1, TypeBound::Copyable)]), ), )?; @@ -428,39 +456,51 @@ fn typevars_declared() -> Result<(), Box> { Ok(()) } -/// Test that `FuncDefns` cannot be nested. +/// Test that nested `FuncDefns` cannot use Type Variables declared by enclosing `FuncDefns` #[test] -fn no_nested_funcdefns() -> Result<(), Box> { - let mut outer = FunctionBuilder::new("outer", Signature::new_endo(usize_t()))?; - let inner = outer - .add_hugr({ - let inner = FunctionBuilder::new("inner", Signature::new_endo(bool_t()))?; - let [w] = inner.input_wires_arr(); - inner.finish_hugr_with_outputs([w])? 
- }) - .inserted_entrypoint; - let [w] = outer.input_wires_arr(); - let outer_node = outer.container_node(); - let hugr = outer.finish_hugr_with_outputs([w]); +fn nested_typevars() -> Result<(), Box> { + const OUTER_BOUND: TypeBound = TypeBound::Any; + const INNER_BOUND: TypeBound = TypeBound::Copyable; + fn build(t: Type) -> Result { + let mut outer = FunctionBuilder::new( + "outer", + PolyFuncType::new( + [OUTER_BOUND.into()], + Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + ), + )?; + let inner = outer.define_function( + "inner", + PolyFuncType::new([INNER_BOUND.into()], Signature::new_endo(vec![t])), + )?; + let [w] = inner.input_wires_arr(); + inner.finish_with_outputs([w])?; + let [w] = outer.input_wires_arr(); + outer.finish_hugr_with_outputs([w]) + } + assert!(build(Type::new_var_use(0, INNER_BOUND)).is_ok()); assert_matches!( - hugr.unwrap_err(), - BuildError::InvalidHUGR(ValidationError::InvalidParentOp { - child_optype, - allowed_children: OpTag::DataflowChild, - parent_optype, - child, parent - }) if matches!(*child_optype, OpType::FuncDefn(_)) && matches!(*parent_optype, OpType::FuncDefn(_)) => { - assert_eq!(child, inner); - assert_eq!(parent, outer_node); - } + build(Type::new_var_use(1, OUTER_BOUND)).unwrap_err(), + BuildError::InvalidHUGR(ValidationError::SignatureError { + cause: SignatureError::FreeTypeVar { + idx: 1, + num_decls: 1 + }, + .. + }) ); + assert_matches!(build(Type::new_var_use(0, OUTER_BOUND)).unwrap_err(), + BuildError::InvalidHUGR(ValidationError::SignatureError { cause: SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached }, .. }) => + {assert_eq!(actual, INNER_BOUND.into()); assert_eq!(cached, OUTER_BOUND.into())}); Ok(()) } #[test] fn no_polymorphic_consts() -> Result<(), Box> { use crate::std_extensions::collections::list; - const BOUND: TypeParam = TypeParam::RuntimeType(TypeBound::Copyable); + const BOUND: TypeParam = TypeParam::Type { + b: TypeBound::Copyable, + }; let list_of_var = Type::new_extension( list::EXTENSION .get_type(&list::LIST_TYPENAME) @@ -493,10 +533,10 @@ fn no_polymorphic_consts() -> Result<(), Box> { } pub(crate) fn extension_with_eval_parallel() -> Arc { - let rowp = TypeParam::new_list_type(TypeBound::Linear); + let rowp = TypeParam::new_list(TypeBound::Any); Extension::new_test_arc(EXT_ID, |ext, extension_ref| { - let inputs = TypeRV::new_row_var_use(0, TypeBound::Linear); - let outputs = TypeRV::new_row_var_use(1, TypeBound::Linear); + let inputs = TypeRV::new_row_var_use(0, TypeBound::Any); + let outputs = TypeRV::new_row_var_use(1, TypeBound::Any); let evaled_fn = TypeRV::new_function(FuncValueType::new(inputs.clone(), outputs.clone())); let pf = PolyFuncTypeRV::new( [rowp.clone(), rowp.clone()], @@ -505,7 +545,7 @@ pub(crate) fn extension_with_eval_parallel() -> Arc { ext.add_op("eval".into(), String::new(), pf, extension_ref) .unwrap(); - let rv = |idx| TypeRV::new_row_var_use(idx, TypeBound::Linear); + let rv = |idx| TypeRV::new_row_var_use(idx, TypeBound::Any); let pf = PolyFuncTypeRV::new( [rowp.clone(), rowp.clone(), rowp.clone(), rowp.clone()], Signature::new( @@ -523,8 +563,8 @@ pub(crate) fn extension_with_eval_parallel() -> Arc { #[test] fn instantiate_row_variables() -> Result<(), Box> { - fn uint_seq(i: usize) -> Term { - vec![usize_t().into(); i].into() + fn uint_seq(i: usize) -> TypeArg { + vec![TypeArg::Type { ty: usize_t() }; i].into() } let e = extension_with_eval_parallel(); let mut dfb = DFGBuilder::new(inout_sig( @@ -548,49 +588,124 @@ fn instantiate_row_variables() -> 
Result<(), Box> { Ok(()) } -fn list1ty(t: TypeRV) -> Term { - Term::new_list([t.into()]) +fn seq1ty(t: TypeRV) -> TypeArg { + TypeArg::Sequence { + elems: vec![t.into()], + } } #[test] fn row_variables() -> Result<(), Box> { let e = extension_with_eval_parallel(); - let tv = TypeRV::new_row_var_use(0, TypeBound::Linear); + let tv = TypeRV::new_row_var_use(0, TypeBound::Any); let inner_ft = Type::new_function(FuncValueType::new_endo(tv.clone())); let ft_usz = Type::new_function(FuncValueType::new_endo(vec![tv.clone(), usize_t().into()])); let mut fb = FunctionBuilder::new( "id", PolyFuncType::new( - [TypeParam::new_list_type(TypeBound::Linear)], + [TypeParam::new_list(TypeBound::Any)], Signature::new(inner_ft.clone(), ft_usz), ), )?; // All the wires here are carrying higher-order Function values let [func_arg] = fb.input_wires_arr(); let id_usz = { - let mut mb = fb.module_root_builder(); - let bldr = mb.define_function("id_usz", Signature::new_endo(usize_t()))?; + let bldr = fb.define_function("id_usz", Signature::new_endo(usize_t()))?; let vals = bldr.input_wires(); - let helper_def = bldr.finish_with_outputs(vals)?; - fb.load_func(helper_def.handle(), &[])? + let inner_def = bldr.finish_with_outputs(vals)?; + fb.load_func(inner_def.handle(), &[])? }; let par = e.instantiate_extension_op( "parallel", - [tv.clone(), usize_t().into(), tv.clone(), usize_t().into()].map(list1ty), + [tv.clone(), usize_t().into(), tv.clone(), usize_t().into()].map(seq1ty), )?; let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?; fb.finish_hugr_with_outputs(par_func.outputs())?; Ok(()) } +#[test] +fn test_polymorphic_call() -> Result<(), Box> { + // TODO: This tests a function call that is polymorphic in an extension set. + // Should this be rewritten to be polymorphic in something else or removed? + + let e = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + let params: Vec = vec![TypeBound::Any.into(), TypeBound::Any.into()]; + let evaled_fn = Type::new_function(Signature::new( + Type::new_var_use(0, TypeBound::Any), + Type::new_var_use(1, TypeBound::Any), + )); + // Single-input/output version of the higher-order "eval" operation, with extension param. + // Note the extension-delta of the eval node includes that of the input function. 
+ ext.add_op( + "eval".into(), + String::new(), + PolyFuncTypeRV::new( + params.clone(), + Signature::new( + vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], + Type::new_var_use(1, TypeBound::Any), + ), + ), + extension_ref, + )?; + + Ok(()) + })?; + + fn utou() -> Type { + Type::new_function(Signature::new_endo(usize_t())) + } + + let int_pair = Type::new_tuple(vec![usize_t(); 2]); + // Root DFG: applies a function int-->int to each element of a pair of two ints + let mut d = DFGBuilder::new(inout_sig( + vec![utou(), int_pair.clone()], + vec![int_pair.clone()], + ))?; + // ....by calling a function (int-->int, int_pair) -> int_pair + let f = { + let mut f = d.define_function( + "two_ints", + PolyFuncType::new( + vec![], + Signature::new(vec![utou(), int_pair.clone()], int_pair.clone()), + ), + )?; + let [func, tup] = f.input_wires_arr(); + let mut c = f.conditional_builder( + (vec![vec![usize_t(); 2].into()], tup), + vec![], + vec![usize_t(); 2].into(), + )?; + let mut cc = c.case_builder(0)?; + let [i1, i2] = cc.input_wires_arr(); + let op = e.instantiate_extension_op("eval", vec![usize_t().into(), usize_t().into()])?; + let [f1] = cc.add_dataflow_op(op.clone(), [func, i1])?.outputs_arr(); + let [f2] = cc.add_dataflow_op(op, [func, i2])?.outputs_arr(); + cc.finish_with_outputs([f1, f2])?; + let res = c.finish_sub_container()?.outputs(); + let tup = f.make_tuple(res)?; + f.finish_with_outputs([tup])? + }; + + let [func, tup] = d.input_wires_arr(); + let call = d.call(f.handle(), &[], [func, tup])?; + let h = d.finish_hugr_with_outputs(call.outputs())?; + let call_ty = h.get_optype(call.node()).dataflow_signature().unwrap(); + let exp_fun_ty = Signature::new(vec![utou(), int_pair.clone()], int_pair); + assert_eq!(call_ty.as_ref(), &exp_fun_ty); + Ok(()) +} + #[test] fn test_polymorphic_load() -> Result<(), Box> { let mut m = ModuleBuilder::new(); let id = m.declare( "id", PolyFuncType::new( - vec![TypeBound::Linear.into()], - Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Linear)]), + vec![TypeBound::Any.into()], + Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), ), )?; let sig = Signature::new( @@ -737,7 +852,7 @@ fn cfg_connections() -> Result<(), Box> { Err(ValidationError::TooManyConnections { node: middle.node(), port: Port::new(Direction::Outgoing, 0), - port_kind: Box::new(EdgeKind::ControlFlow) + port_kind: EdgeKind::ControlFlow }) ); Ok(()) @@ -761,74 +876,3 @@ fn cfg_entry_io_bug() -> Result<(), Box> { Ok(()) } - -fn sig1() -> Signature { - Signature::new_endo(bool_t()) -} - -fn sig2() -> Signature { - Signature::new_endo(usize_t()) -} - -#[rstest] -// Private FuncDefns never conflict even if different sig -#[case( - FuncDefn::new_vis("foo", sig1(), Visibility::Public), - FuncDefn::new("foo", sig2()), - None -)] -#[case(FuncDefn::new("foo", sig1()), FuncDecl::new("foo", sig2()), None)] -// Public FuncDefn conflicts with anything Public even if same sig -#[case( - FuncDefn::new_vis("foo", sig1(), Visibility::Public), - FuncDefn::new_vis("foo", sig1(), Visibility::Public), - Some("foo") -)] -#[case( - FuncDefn::new_vis("foo", sig1(), Visibility::Public), - FuncDecl::new("foo", sig1()), - Some("foo") -)] -// Two public FuncDecls are ok with same sig -#[case(FuncDecl::new("foo", sig1()), FuncDecl::new("foo", sig1()), None)] -// But two public FuncDecls not ok if different sigs -#[case( - FuncDecl::new("foo", sig1()), - FuncDecl::new("foo", sig2()), - Some("foo") -)] -fn validate_linkage( - #[case] f1: impl Into, - #[case] f2: impl Into, - #[case] err: 
Option<&str>, -) { - let mut h = Hugr::new(); - let [n1, n2] = [f1.into(), f2.into()].map(|f| { - let def_sig = f - .as_func_defn() - .map(FuncDefn::inner_signature) - .map(Cow::into_owned); - let n = h.add_node_with_parent(h.module_root(), f); - if let Some(Signature { input, output }) = def_sig { - let i = h.add_node_with_parent(n, ops::Input::new(input)); - let o = h.add_node_with_parent(n, ops::Output::new(output)); - h.connect(i, 0, o, 0); // Assume all sig's used in test are 1-ary endomorphic - } - n - }); - let r = h.validate(); - match err { - None => r.unwrap(), - Some(name) => { - let Err(ValidationError::DuplicateExport { - link_name, - children, - }) = r - else { - panic!("validate() should have produced DuplicateExport error not {r:?}") - }; - assert_eq!(link_name, name); - assert!(children == [n1, n2] || children == [n2, n1]); - } - } -} diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index 3a7b435cf7..1704d79f65 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -421,7 +421,7 @@ pub trait HugrView: HugrInternals { let config = match RenderConfig::try_from(formatter) { Ok(config) => config, Err(e) => { - panic!("Unsupported format option: {e}"); + panic!("Unsupported format option: {}", e); } }; #[allow(deprecated)] diff --git a/hugr-core/src/hugr/views/impls.rs b/hugr-core/src/hugr/views/impls.rs index dbc9ad7b8f..75311f2f73 100644 --- a/hugr-core/src/hugr/views/impls.rs +++ b/hugr-core/src/hugr/views/impls.rs @@ -115,7 +115,6 @@ macro_rules! hugr_mut_methods { fn disconnect(&mut self, node: Self::Node, port: impl Into); fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (crate::OutgoingPort, crate::IncomingPort); fn insert_hugr(&mut self, root: Self::Node, other: crate::Hugr) -> crate::hugr::hugrmut::InsertionResult; - fn insert_region(&mut self, root: Self::Node, other: crate::Hugr, region: crate::Node) -> crate::hugr::hugrmut::InsertionResult; fn insert_from_view(&mut self, root: Self::Node, other: &Other) -> crate::hugr::hugrmut::InsertionResult; fn insert_subgraph(&mut self, root: Self::Node, other: &Other, subgraph: &crate::hugr::views::SiblingSubgraph) -> std::collections::HashMap; fn use_extension(&mut self, extension: impl Into>); diff --git a/hugr-core/src/hugr/views/render.rs b/hugr-core/src/hugr/views/render.rs index 3f8e48c963..b787a9e383 100644 --- a/hugr-core/src/hugr/views/render.rs +++ b/hugr-core/src/hugr/views/render.rs @@ -340,8 +340,8 @@ pub(in crate::hugr) fn edge_style<'a>( config: MermaidFormatter<'_>, ) -> Box< dyn FnMut( - as LinkView>::LinkEndpoint, - as LinkView>::LinkEndpoint, + ::LinkEndpoint, + ::LinkEndpoint, ) -> EdgeStyle + 'a, > { @@ -417,5 +417,15 @@ mod tests { { assert!(RenderConfig::try_from(config).is_err()); } + + #[allow(deprecated)] + let config = RenderConfig { + entrypoint: Some(h.entrypoint()), + ..Default::default() + }; + assert_eq!( + MermaidFormatter::from_render_config(config, &h), + h.mermaid_format() + ) } } diff --git a/hugr-core/src/hugr/views/rerooted.rs b/hugr-core/src/hugr/views/rerooted.rs index 18821cfe29..8c84abdc71 100644 --- a/hugr-core/src/hugr/views/rerooted.rs +++ b/hugr-core/src/hugr/views/rerooted.rs @@ -138,7 +138,6 @@ impl HugrMut for Rerooted { fn disconnect(&mut self, node: Self::Node, port: impl Into); fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (crate::OutgoingPort, crate::IncomingPort); fn insert_hugr(&mut self, root: Self::Node, other: crate::Hugr) -> crate::hugr::hugrmut::InsertionResult; - fn insert_region(&mut 
self, root: Self::Node, other: crate::Hugr, region: crate::Node) -> crate::hugr::hugrmut::InsertionResult; fn insert_from_view(&mut self, root: Self::Node, other: &Other) -> crate::hugr::hugrmut::InsertionResult; fn insert_subgraph(&mut self, root: Self::Node, other: &Other, subgraph: &crate::hugr::views::SiblingSubgraph) -> std::collections::HashMap; fn use_extension(&mut self, extension: impl Into>); diff --git a/hugr-core/src/hugr/views/root_checked/dfg.rs b/hugr-core/src/hugr/views/root_checked/dfg.rs index 160b4a3116..f9681c9ca7 100644 --- a/hugr-core/src/hugr/views/root_checked/dfg.rs +++ b/hugr-core/src/hugr/views/root_checked/dfg.rs @@ -8,147 +8,137 @@ use thiserror::Error; use crate::{ IncomingPort, OutgoingPort, PortIndex, hugr::HugrMut, - ops::{ - OpTrait, OpType, - handle::{DataflowParentID, DfgID}, - }, + ops::{DFG, FuncDefn, Input, OpTrait, OpType, Output, dataflow::IOTrait, handle::DfgID}, types::{NoRV, Signature, TypeBase}, }; use super::RootChecked; -macro_rules! impl_dataflow_parent_methods { - ($handle_type:ident) => { - impl RootChecked> { - /// Get the input and output nodes of the DFG at the entrypoint node. - pub fn get_io(&self) -> [H::Node; 2] { - self.hugr() - .get_io(self.hugr().entrypoint()) - .expect("valid DFG graph") - } +impl RootChecked> { + /// Get the input and output nodes of the DFG at the entrypoint node. + pub fn get_io(&self) -> [H::Node; 2] { + self.hugr() + .get_io(self.hugr().entrypoint()) + .expect("valid DFG graph") + } - /// Rewire the inputs and outputs of the nested DFG to modify its signature. - /// - /// Reorder the outgoing resp. incoming wires at the input resp. output - /// node of the DFG to modify the signature of the DFG HUGR. This will - /// recursively update the signatures of all ancestors of the entrypoint. - /// - /// ### Arguments - /// - /// * `new_inputs`: The new input signature. After the map, the i-th input - /// wire will be connected to the ports connected to the - /// `new_inputs[i]`-th input of the old DFG. - /// * `new_outputs`: The new output signature. After the map, the i-th - /// output wire will be connected to the ports connected to the - /// `new_outputs[i]`-th output of the old DFG. - /// - /// Returns an `InvalidSignature` error if the new_inputs and new_outputs - /// map are not valid signatures. - /// - /// ### Panics - /// - /// Panics if the DFG is not trivially nested, i.e. if there is an ancestor - /// DFG of the entrypoint that has more than one inner DFG. 
- pub fn map_function_type( - &mut self, - new_inputs: &[usize], - new_outputs: &[usize], - ) -> Result<(), InvalidSignature> { - let [inp, out] = self.get_io(); - let Self(hugr, _) = self; - - // Record the old connections from and to the input and output nodes - let old_inputs_incoming = hugr - .node_outputs(inp) - .map(|p| hugr.linked_inputs(inp, p).collect_vec()) - .collect_vec(); - let old_outputs_outgoing = hugr - .node_inputs(out) - .map(|p| hugr.linked_outputs(out, p).collect_vec()) - .collect_vec(); - - // The old signature types - let old_inp_sig = hugr - .get_optype(inp) - .dataflow_signature() - .expect("input has signature"); - let old_inp_sig = old_inp_sig.output_types(); - let old_out_sig = hugr - .get_optype(out) - .dataflow_signature() - .expect("output has signature"); - let old_out_sig = old_out_sig.input_types(); - - // Check if the signature map is valid - check_valid_inputs(&old_inputs_incoming, old_inp_sig, new_inputs)?; - check_valid_outputs(old_out_sig, new_outputs)?; - - // The new signature types - let new_inp_sig = new_inputs - .iter() - .map(|&i| old_inp_sig[i].clone()) - .collect_vec(); - let new_out_sig = new_outputs - .iter() - .map(|&i| old_out_sig[i].clone()) - .collect_vec(); - let new_sig = Signature::new(new_inp_sig, new_out_sig); - - // Remove all edges of the input and output nodes - disconnect_all(hugr, inp); - disconnect_all(hugr, out); - - // Update the signatures of the IO and their ancestors - let mut is_ancestor = false; - let mut node = hugr.entrypoint(); - while matches!(hugr.get_optype(node), OpType::FuncDefn(_) | OpType::DFG(_)) { - let [inner_inp, inner_out] = hugr.get_io(node).expect("valid DFG graph"); - for node in [node, inner_inp, inner_out] { - update_signature(hugr, node, &new_sig); - } - if is_ancestor { - update_inner_dfg_links(hugr, node); - } - if let Some(parent) = hugr.get_parent(node) { - node = parent; - is_ancestor = true; - } else { - break; - } - } + /// Rewire the inputs and outputs of the DFG to modify its signature. + /// + /// Reorder the outgoing resp. incoming wires at the input resp. output + /// node of the DFG to modify the signature of the DFG HUGR. This will + /// recursively update the signatures of all ancestors of the entrypoint. + /// + /// ### Arguments + /// + /// * `new_inputs`: The new input signature. After the map, the i-th input + /// wire will be connected to the ports connected to the + /// `new_inputs[i]`-th input of the old DFG. + /// * `new_outputs`: The new output signature. After the map, the i-th + /// output wire will be connected to the ports connected to the + /// `new_outputs[i]`-th output of the old DFG. + /// + /// Returns an `InvalidSignature` error if the new_inputs and new_outputs + /// map are not valid signatures. + /// + /// ### Panics + /// + /// Panics if the DFG is not trivially nested, i.e. if there is an ancestor + /// DFG of the entrypoint that has more than one inner DFG. 
+ pub fn map_function_type( + &mut self, + new_inputs: &[usize], + new_outputs: &[usize], + ) -> Result<(), InvalidSignature> { + let [inp, out] = self.get_io(); + let Self(hugr, _) = self; + + // Record the old connections from and to the input and output nodes + let old_inputs_incoming = hugr + .node_outputs(inp) + .map(|p| hugr.linked_inputs(inp, p).collect_vec()) + .collect_vec(); + let old_outputs_outgoing = hugr + .node_inputs(out) + .map(|p| hugr.linked_outputs(out, p).collect_vec()) + .collect_vec(); + + // The old signature types + let old_inp_sig = hugr + .get_optype(inp) + .dataflow_signature() + .expect("input has signature"); + let old_inp_sig = old_inp_sig.output_types(); + let old_out_sig = hugr + .get_optype(out) + .dataflow_signature() + .expect("output has signature"); + let old_out_sig = old_out_sig.input_types(); + + // Check if the signature map is valid + check_valid_inputs(&old_inputs_incoming, old_inp_sig, new_inputs)?; + check_valid_outputs(old_out_sig, new_outputs)?; + + // The new signature types + let new_inp_sig = new_inputs + .iter() + .map(|&i| old_inp_sig[i].clone()) + .collect_vec(); + let new_out_sig = new_outputs + .iter() + .map(|&i| old_out_sig[i].clone()) + .collect_vec(); + let new_sig = Signature::new(new_inp_sig, new_out_sig); + + // Remove all edges of the input and output nodes + disconnect_all(hugr, inp); + disconnect_all(hugr, out); + + // Update the signatures of the IO and their ancestors + let mut is_ancestor = false; + let mut node = hugr.entrypoint(); + while matches!(hugr.get_optype(node), OpType::FuncDefn(_) | OpType::DFG(_)) { + let [inner_inp, inner_out] = hugr.get_io(node).expect("valid DFG graph"); + for node in [node, inner_inp, inner_out] { + update_signature(hugr, node, &new_sig); + } + if is_ancestor { + update_inner_dfg_links(hugr, node); + } + if let Some(parent) = hugr.get_parent(node) { + node = parent; + is_ancestor = true; + } else { + break; + } + } - // Insert the new edges at the input - let mut old_output_to_new_input = BTreeMap::::new(); - for (inp_pos, &old_pos) in new_inputs.iter().enumerate() { - for &(node, port) in &old_inputs_incoming[old_pos] { - if node != out { - hugr.connect(inp, inp_pos, node, port); - } else { - old_output_to_new_input.insert(port, inp_pos.into()); - } - } + // Insert the new edges at the input + let mut old_output_to_new_input = BTreeMap::::new(); + for (inp_pos, &old_pos) in new_inputs.iter().enumerate() { + for &(node, port) in &old_inputs_incoming[old_pos] { + if node != out { + hugr.connect(inp, inp_pos, node, port); + } else { + old_output_to_new_input.insert(port, inp_pos.into()); } + } + } - // Insert the new edges at the output - for (out_pos, &old_pos) in new_outputs.iter().enumerate() { - for &(node, port) in &old_outputs_outgoing[old_pos] { - if node != inp { - hugr.connect(node, port, out, out_pos); - } else { - let &inp_pos = old_output_to_new_input.get(&old_pos.into()).unwrap(); - hugr.connect(inp, inp_pos, out, out_pos); - } - } + // Insert the new edges at the output + for (out_pos, &old_pos) in new_outputs.iter().enumerate() { + for &(node, port) in &old_outputs_outgoing[old_pos] { + if node != inp { + hugr.connect(node, port, out, out_pos); + } else { + let &inp_pos = old_output_to_new_input.get(&old_pos.into()).unwrap(); + hugr.connect(inp, inp_pos, out, out_pos); } - - Ok(()) } } - }; -} -impl_dataflow_parent_methods!(DataflowParentID); -impl_dataflow_parent_methods!(DfgID); + Ok(()) + } +} /// Panics if the DFG within `node` is not a single inner DFG. 
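A minimal usage sketch of the restored `map_function_type`, mirroring the `DfgID`-based tests further down in this hunk. `new_empty_dfg` is the test helper defined in that test module, and the snippet assumes the same imports as those tests; the index vectors are illustrative only.

    // Sketch only: wrap a two-qubit identity DFG and swap its two inputs.
    let sig = Signature::new_endo(vec![qb_t(), qb_t()]);
    let mut hugr = new_empty_dfg(sig);
    let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();
    // Inputs are reordered as [1, 0]; outputs keep their original order.
    dfg_view.map_function_type(&[1, 0], &[0, 1]).unwrap();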
fn update_inner_dfg_links(hugr: &mut H, node: H::Node) { @@ -178,19 +168,20 @@ fn disconnect_all(hugr: &mut H, node: H::Node) { } fn update_signature(hugr: &mut H, node: H::Node, new_sig: &Signature) { - match hugr.optype_mut(node) { - OpType::DFG(dfg) => { - dfg.signature = new_sig.clone(); + let new_op: OpType = match hugr.get_optype(node) { + OpType::DFG(_) => DFG { + signature: new_sig.clone(), } - OpType::FuncDefn(fn_def_op) => *fn_def_op.signature_mut() = new_sig.clone().into(), - OpType::Input(inp) => { - inp.types = new_sig.input().clone(); + .into(), + OpType::FuncDefn(fn_def_op) => { + FuncDefn::new(fn_def_op.func_name().clone(), new_sig.clone()).into() } - OpType::Output(out) => out.types = new_sig.output().clone(), + OpType::Input(_) => Input::new(new_sig.input().clone()).into(), + OpType::Output(_) => Output::new(new_sig.output().clone()).into(), _ => panic!("only update signature of DFG, FuncDefn, Input, or Output"), }; - let new_op = hugr.get_optype(node); hugr.set_num_ports(node, new_op.input_count(), new_op.output_count()); + hugr.replace_op(node, new_op); } fn check_valid_inputs( @@ -277,11 +268,11 @@ mod test { use super::*; use crate::builder::{ - DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, endo_sig, + Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, endo_sig, }; use crate::extension::prelude::{bool_t, qb_t}; use crate::hugr::views::root_checked::RootChecked; - use crate::ops::handle::NodeHandle; + use crate::ops::handle::{DfgID, NodeHandle}; use crate::ops::{NamedOp, OpParent}; use crate::types::Signature; use crate::utils::test_quantum_extension::cx_gate; @@ -299,51 +290,6 @@ mod test { let sig = Signature::new_endo(vec![qb_t(), qb_t()]); let mut hugr = new_empty_dfg(sig); - // Wrap in RootChecked - let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap(); - - // Test mapping inputs: [0,1] -> [1,0] - let input_map = vec![1, 0]; - let output_map = vec![0, 1]; - - // Map the I/O - dfg_view.map_function_type(&input_map, &output_map).unwrap(); - - // Verify the new signature - let dfg_hugr = dfg_view.hugr(); - let new_sig = dfg_hugr - .get_optype(dfg_hugr.entrypoint()) - .dataflow_signature() - .unwrap(); - assert_eq!(new_sig.input_count(), 2); - assert_eq!(new_sig.output_count(), 2); - - // Test invalid mapping - missing input - let invalid_input_map = vec![0, 0]; - let err = dfg_view.map_function_type(&invalid_input_map, &output_map); - assert!(matches!(err, Err(InvalidSignature::MissingIO(1, "input")))); - - // Test invalid mapping - duplicate input - let invalid_input_map = vec![0, 0, 1]; - assert!(matches!( - dfg_view.map_function_type(&invalid_input_map, &output_map), - Err(InvalidSignature::DuplicateInput(0)) - )); - - // Test invalid mapping - unknown output - let invalid_output_map = vec![0, 2]; - assert!(matches!( - dfg_view.map_function_type(&input_map, &invalid_output_map), - Err(InvalidSignature::UnknownIO(2, "output")) - )); - } - - #[test] - fn test_map_io_dfg_id() { - // Create a DFG with 2 inputs and 2 outputs - let sig = Signature::new_endo(vec![qb_t(), qb_t()]); - let mut hugr = new_empty_dfg(sig); - // Wrap in RootChecked let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap(); @@ -391,7 +337,7 @@ mod test { let mut hugr = new_empty_dfg(sig); // Wrap in RootChecked - let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap(); + let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut 
hugr).unwrap(); // Test mapping outputs: [0] -> [0,0] (duplicating the output) let input_map = vec![0]; @@ -431,7 +377,7 @@ mod test { .unwrap(); // Wrap in RootChecked - let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap(); + let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap(); // Test mapping inputs: [0,1] -> [1,0] (swapping inputs) let input_map = vec![1, 0]; diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index bb7a53ad15..03dff67f51 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -428,7 +428,7 @@ impl SiblingSubgraph { if !OpTag::DataflowParent.is_superset(dfg_optype.tag()) { return Err(InvalidReplacement::InvalidDataflowGraph { node: rep_root, - op: Box::new(dfg_optype.clone()), + op: dfg_optype.clone(), }); } let [rep_input, rep_output] = replacement @@ -575,7 +575,7 @@ fn pick_parent<'a, N: HugrNode>( } fn make_boundary<'a, H: HugrView>( - region: &impl LinkView, + region: &impl LinkView, node_map: &H::RegionPortgraphNodes, inputs: &'a IncomingPorts, outputs: &'a OutgoingPorts, @@ -881,7 +881,7 @@ pub enum InvalidReplacement { /// The node ID of the root node. node: Node, /// The op type of the root node. - op: Box, + op: OpType, }, /// Replacement graph type mismatch. #[error( @@ -890,9 +890,9 @@ pub enum InvalidReplacement { ] InvalidSignature { /// The expected signature. - expected: Box, + expected: Signature, /// The actual signature. - actual: Option>, + actual: Option, }, /// `SiblingSubgraph` is not convex. #[error("SiblingSubgraph is not convex.")] diff --git a/hugr-core/src/hugr/views/tests.rs b/hugr-core/src/hugr/views/tests.rs index 28d304d2ee..3a83b5b9d8 100644 --- a/hugr-core/src/hugr/views/tests.rs +++ b/hugr-core/src/hugr/views/tests.rs @@ -4,8 +4,7 @@ use rstest::{fixture, rstest}; use crate::{ Hugr, HugrView, builder::{ - BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, HugrBuilder, - endo_sig, inout_sig, + BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, endo_sig, inout_sig, }, extension::prelude::qb_t, ops::{ @@ -184,9 +183,8 @@ fn test_dataflow_ports_only() { let mut dfg = DFGBuilder::new(endo_sig(bool_t())).unwrap(); let local_and = { - let mut mb = dfg.module_root_builder(); - let local_and = mb - .define_function("and", Signature::new(vec![bool_t(); 2], bool_t())) + let local_and = dfg + .define_function("and", Signature::new(vec![bool_t(); 2], vec![bool_t()])) .unwrap(); let first_input = local_and.input().out_wire(0); local_and.finish_with_outputs([first_input]).unwrap() diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index b1a606da96..664f4d3ea1 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -7,9 +7,7 @@ use std::sync::Arc; use crate::{ Direction, Hugr, HugrView, Node, Port, - extension::{ - ExtensionId, ExtensionRegistry, SignatureError, resolution::ExtensionResolutionError, - }, + extension::{ExtensionId, ExtensionRegistry, SignatureError}, hugr::{HugrMut, NodeMetadata}, ops::{ AliasDecl, AliasDefn, CFG, Call, CallIndirect, Case, Conditional, Const, DFG, @@ -24,102 +22,67 @@ use crate::{ }, types::{ CustomType, FuncTypeBase, MaybeRV, PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, - Term, Type, TypeArg, TypeBase, TypeBound, TypeEnum, TypeName, TypeRow, - type_param::{SeqPart, TypeParam}, + Type, TypeArg, TypeBase, TypeBound, TypeEnum, TypeName, TypeRow, type_param::TypeParam, 
type_row::TypeRowBase, }, }; use fxhash::FxHashMap; +use hugr_model::v0 as model; use hugr_model::v0::table; -use hugr_model::v0::{self as model}; -use itertools::{Either, Itertools}; +use itertools::Either; use smol_str::{SmolStr, ToSmolStr}; use thiserror::Error; -fn gen_str(generator: &Option) -> String { - match generator { - Some(g) => format!(" generated by {g}"), - None => String::new(), - } -} - -/// An error that can occur during import. -#[derive(Debug, Clone, Error)] -#[error("failed to import hugr{}", gen_str(&self.generator))] -pub struct ImportError { - #[source] - inner: ImportErrorInner, - generator: Option, -} - +/// Error during import. #[derive(Debug, Clone, Error)] -enum ImportErrorInner { +#[non_exhaustive] +pub enum ImportError { /// The model contains a feature that is not supported by the importer yet. /// Errors of this kind are expected to be removed as the model format and /// the core HUGR representation converge. #[error("currently unsupported: {0}")] Unsupported(String), - /// The model contains implicit information that has not yet been inferred. /// This includes wildcards and application of functions with implicit parameters. #[error("uninferred implicit: {0}")] Uninferred(String), - - /// The model is not well-formed. - #[error("{0}")] - Invalid(String), - - /// An error with additional context. - #[error("import failed in context: {1}")] - Context(#[source] Box, String), - /// A signature mismatch was detected during import. - #[error("signature error")] + #[error("signature error: {0}")] Signature(#[from] SignatureError), - - /// An error relating to the loaded extension registry. - #[error("extension error")] - Extension(#[from] ExtensionError), - - /// Incorrect order hints. - #[error("incorrect order hint")] - OrderHint(#[from] OrderHintError), - - /// Extension resolution. - #[error("extension resolution error")] - ExtensionResolution(#[from] ExtensionResolutionError), -} - -#[derive(Debug, Clone, Error)] -enum ExtensionError { - /// An extension is missing. + /// A required extension is missing. #[error("Importing the hugr requires extension {missing_ext}, which was not found in the registry. The available extensions are: [{}]", available.iter().map(std::string::ToString::to_string).collect::>().join(", "))] - Missing { + Extension { /// The missing extension. missing_ext: ExtensionId, /// The available extensions in the registry. available: Vec, }, - /// An extension type is missing. #[error( "Importing the hugr requires extension {ext} to have a type named {name}, but it was not found." )] - MissingType { + ExtensionType { /// The extension that is missing the type. ext: ExtensionId, /// The name of the missing type. name: TypeName, }, + /// The model is not well-formed. + #[error("validate error: {0}")] + Model(#[from] table::ModelError), + /// Incorrect order hints. + #[error("incorrect order hint: {0}")] + OrderHint(#[from] OrderHintError), } /// Import error caused by incorrect order hints. #[derive(Debug, Clone, Error)] -enum OrderHintError { +#[non_exhaustive] +pub enum OrderHintError { /// Duplicate order hint key in the same region. #[error("duplicate order hint key {0}")] - DuplicateKey(table::RegionId, u64), + DuplicateKey(table::NodeId, u64), /// Order hint including a key not defined in the region. #[error("order hint with unknown key {0}")] UnknownKey(u64), @@ -128,28 +91,14 @@ enum OrderHintError { NoOrderPort(table::NodeId), } -/// Helper macro to create an `ImportErrorInner::Unsupported` error with a formatted message. 
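With this revert, `ImportError` and `OrderHintError` are public `#[non_exhaustive]` enums again rather than a generator-annotated wrapper struct, so callers can match on the variants directly. A hedged sketch of the caller side, assuming `import_hugr` returns `Result<Hugr, ImportError>` as in the hunk below and that the relevant items are in scope:

    // Sketch only: variant names are taken from the enum restored above; the
    // wildcard arm is needed because the enum is marked #[non_exhaustive].
    match import_hugr(&module, &extensions) {
        Ok(_hugr) => { /* use the imported HUGR */ }
        Err(ImportError::Unsupported(msg)) => eprintln!("unsupported feature: {msg}"),
        Err(ImportError::Uninferred(msg)) => eprintln!("uninferred implicit: {msg}"),
        Err(other) => eprintln!("import failed: {other}"),
    }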
+/// Helper macro to create an `ImportError::Unsupported` error with a formatted message. macro_rules! error_unsupported { - ($($e:expr),*) => { ImportErrorInner::Unsupported(format!($($e),*)) } + ($($e:expr),*) => { ImportError::Unsupported(format!($($e),*)) } } -/// Helper macro to create an `ImportErrorInner::Uninferred` error with a formatted message. +/// Helper macro to create an `ImportError::Uninferred` error with a formatted message. macro_rules! error_uninferred { - ($($e:expr),*) => { ImportErrorInner::Uninferred(format!($($e),*)) } -} - -/// Helper macro to create an `ImportErrorInner::Invalid` error with a formatted message. -macro_rules! error_invalid { - ($($e:expr),*) => { ImportErrorInner::Invalid(format!($($e),*)) } -} - -/// Helper macro to create an `ImportErrorInner::Context` error with a formatted message. -macro_rules! error_context { - ($err:expr, $($e:expr),*) => { - { - ImportErrorInner::Context(Box::new($err), format!($($e),*)) - } - } + ($($e:expr),*) => { ImportError::Uninferred(format!($($e),*)) } } /// Import a [`Package`] from its model representation. @@ -168,22 +117,6 @@ pub fn import_package( Ok(package) } -/// Get the name of the generator from the metadata of the module. -/// If no generator is found, `None` is returned. -fn get_generator(ctx: &Context<'_>) -> Option { - ctx.module - .get_region(ctx.module.root) - .map(|r| r.meta.iter()) - .into_iter() - .flatten() - .find_map(|meta| { - let (name, json_val) = ctx.decode_json_meta(*meta).ok()??; - - (name == crate::envelope::GENERATOR_KEY) - .then_some(crate::envelope::format_generator(&json_val)) - }) -} - /// Import a [`Hugr`] module from its model representation. pub fn import_hugr( module: &table::Module, @@ -203,26 +136,10 @@ pub fn import_hugr( region_scope: table::RegionId::default(), }; - let import_steps: [fn(&mut Context) -> _; 3] = [ - |ctx| ctx.import_root(), - |ctx| ctx.link_ports(), - |ctx| ctx.link_static_ports(), - ]; - - for step in import_steps { - if let Err(e) = step(&mut ctx) { - return Err(ImportError { - inner: e, - generator: get_generator(&ctx), - }); - } - } - ctx.hugr - .resolve_extension_defs(extensions) - .map_err(|e| ImportError { - inner: ImportErrorInner::ExtensionResolution(e), - generator: get_generator(&ctx), - })?; + ctx.import_root()?; + ctx.link_ports()?; + ctx.link_static_ports()?; + Ok(ctx.hugr) } @@ -256,7 +173,7 @@ struct Context<'a> { impl<'a> Context<'a> { /// Get the signature of the node with the given `NodeId`. - fn get_node_signature(&mut self, node: table::NodeId) -> Result { + fn get_node_signature(&mut self, node: table::NodeId) -> Result { let node_data = self.get_node(node)?; let signature = node_data .signature @@ -266,29 +183,26 @@ impl<'a> Context<'a> { /// Get the node with the given `NodeId`, or return an error if it does not exist. #[inline] - fn get_node(&self, node_id: table::NodeId) -> Result<&'a table::Node<'a>, ImportErrorInner> { + fn get_node(&self, node_id: table::NodeId) -> Result<&'a table::Node<'a>, ImportError> { self.module .get_node(node_id) - .ok_or_else(|| error_invalid!("unknown node {}", node_id)) + .ok_or_else(|| table::ModelError::NodeNotFound(node_id).into()) } /// Get the term with the given `TermId`, or return an error if it does not exist. 
#[inline] - fn get_term(&self, term_id: table::TermId) -> Result<&'a table::Term<'a>, ImportErrorInner> { + fn get_term(&self, term_id: table::TermId) -> Result<&'a table::Term<'a>, ImportError> { self.module .get_term(term_id) - .ok_or_else(|| error_invalid!("unknown term {}", term_id)) + .ok_or_else(|| table::ModelError::TermNotFound(term_id).into()) } /// Get the region with the given `RegionId`, or return an error if it does not exist. #[inline] - fn get_region( - &self, - region_id: table::RegionId, - ) -> Result<&'a table::Region<'a>, ImportErrorInner> { + fn get_region(&self, region_id: table::RegionId) -> Result<&'a table::Region<'a>, ImportError> { self.module .get_region(region_id) - .ok_or_else(|| error_invalid!("unknown region {}", region_id)) + .ok_or_else(|| table::ModelError::RegionNotFound(region_id).into()) } fn make_node( @@ -296,7 +210,7 @@ impl<'a> Context<'a> { node_id: table::NodeId, op: OpType, parent: Node, - ) -> Result { + ) -> Result { let node = self.hugr.add_node_with_parent(parent, op); self.nodes.insert(node_id, node); @@ -305,8 +219,7 @@ impl<'a> Context<'a> { self.record_links(node, Direction::Outgoing, node_data.outputs); for meta_item in node_data.meta { - self.import_node_metadata(node, *meta_item) - .map_err(|err| error_context!(err, "node metadata"))?; + self.import_node_metadata(node, *meta_item)?; } Ok(node) @@ -316,9 +229,21 @@ impl<'a> Context<'a> { &mut self, node: Node, meta_item: table::TermId, - ) -> Result<(), ImportErrorInner> { + ) -> Result<(), ImportError> { // Import the JSON metadata - if let Some((name, json_value)) = self.decode_json_meta(meta_item)? { + if let Some([name_arg, json_arg]) = self.match_symbol(meta_item, model::COMPAT_META_JSON)? { + let table::Term::Literal(model::Literal::Str(name)) = self.get_term(name_arg)? else { + return Err(table::ModelError::TypeError(meta_item).into()); + }; + + let table::Term::Literal(model::Literal::Str(json_str)) = self.get_term(json_arg)? + else { + return Err(table::ModelError::TypeError(meta_item).into()); + }; + + let json_value: NodeMetadata = serde_json::from_str(json_str) + .map_err(|_| table::ModelError::TypeError(meta_item))?; + self.hugr.set_metadata(node, name, json_value); } @@ -330,44 +255,6 @@ impl<'a> Context<'a> { Ok(()) } - fn decode_json_meta( - &self, - meta_item: table::TermId, - ) -> Result, ImportErrorInner> { - Ok( - if let Some([name_arg, json_arg]) = - self.match_symbol(meta_item, model::COMPAT_META_JSON)? - { - let table::Term::Literal(model::Literal::Str(name)) = self.get_term(name_arg)? - else { - return Err(error_invalid!( - "`{}` expects a string literal as its first argument", - model::COMPAT_META_JSON - )); - }; - - let table::Term::Literal(model::Literal::Str(json_str)) = - self.get_term(json_arg)? - else { - return Err(error_invalid!( - "`{}` expects a string literal as its second argument", - model::COMPAT_CONST_JSON - )); - }; - - let json_value: NodeMetadata = serde_json::from_str(json_str).map_err(|_| { - error_invalid!( - "failed to parse JSON string for `{}` metadata", - model::COMPAT_CONST_JSON - ) - })?; - Some((name.to_owned(), json_value)) - } else { - None - }, - ) - } - /// Associate links with the ports of the given node in the given direction. 
fn record_links(&mut self, node: Node, direction: Direction, links: &'a [table::LinkIndex]) { let optype = self.hugr.get_optype(node); @@ -384,7 +271,7 @@ impl<'a> Context<'a> { /// Link up the ports in the hugr graph, according to the connectivity information that /// has been gathered in the `link_ports` map. - fn link_ports(&mut self) -> Result<(), ImportErrorInner> { + fn link_ports(&mut self) -> Result<(), ImportError> { // For each edge, we group the ports by their direction. We reuse the `inputs` and // `outputs` vectors to avoid unnecessary allocations. let mut inputs = Vec::new(); @@ -432,7 +319,7 @@ impl<'a> Context<'a> { Ok(()) } - fn link_static_ports(&mut self) -> Result<(), ImportErrorInner> { + fn link_static_ports(&mut self) -> Result<(), ImportError> { for (src_id, dst_id) in std::mem::take(&mut self.static_edges) { // None of these lookups should fail given how we constructed `static_edges`. let src = self.nodes[&src_id]; @@ -445,40 +332,35 @@ impl<'a> Context<'a> { Ok(()) } - fn get_symbol_name(&self, node_id: table::NodeId) -> Result<&'a str, ImportErrorInner> { + fn get_symbol_name(&self, node_id: table::NodeId) -> Result<&'a str, ImportError> { let node_data = self.get_node(node_id)?; let name = node_data .operation .symbol() - .ok_or_else(|| error_invalid!("node {} is expected to be a symbol", node_id))?; + .ok_or(table::ModelError::InvalidSymbol(node_id))?; Ok(name) } fn get_func_signature( &mut self, func_node: table::NodeId, - ) -> Result { + ) -> Result { let symbol = match self.get_node(func_node)?.operation { table::Operation::DefineFunc(symbol) => symbol, table::Operation::DeclareFunc(symbol) => symbol, - _ => { - return Err(error_invalid!( - "node {} is expected to be a function declaration or definition", - func_node - )); - } + _ => return Err(table::ModelError::UnexpectedOperation(func_node).into()), }; self.import_poly_func_type(func_node, *symbol, |_, signature| Ok(signature)) } /// Import the root region of the module. 
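// In the restored `import_root` below, top-level children are attached under
// `self.hugr.entrypoint()`; the code being removed attached them under
// `self.hugr.module_root()` instead (for a freshly constructed, module-rooted
// HUGR these refer to the same node).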
- fn import_root(&mut self) -> Result<(), ImportErrorInner> { + fn import_root(&mut self) -> Result<(), ImportError> { self.region_scope = self.module.root; let region_data = self.get_region(self.module.root)?; for node in region_data.children { - self.import_node(*node, self.hugr.module_root())?; + self.import_node(*node, self.hugr.entrypoint())?; } for meta_item in region_data.meta { @@ -492,126 +374,250 @@ impl<'a> Context<'a> { &mut self, node_id: table::NodeId, parent: Node, - ) -> Result, ImportErrorInner> { + ) -> Result, ImportError> { let node_data = self.get_node(node_id)?; - let result = match node_data.operation { - table::Operation::Invalid => { - return Err(error_invalid!("tried to import an `invalid` operation")); + match node_data.operation { + table::Operation::Invalid => Err(table::ModelError::InvalidOperation(node_id).into()), + table::Operation::Dfg => { + let signature = self.get_node_signature(node_id)?; + let optype = OpType::DFG(DFG { signature }); + let node = self.make_node(node_id, optype, parent)?; + + let [region] = node_data.regions else { + return Err(table::ModelError::InvalidRegions(node_id).into()); + }; + + self.import_dfg_region(node_id, *region, node)?; + Ok(Some(node)) } - table::Operation::Dfg => Some( - self.import_node_dfg(node_id, parent, node_data) - .map_err(|err| error_context!(err, "`dfg` node with id {}", node_id))?, - ), - - table::Operation::Cfg => Some( - self.import_node_cfg(node_id, parent, node_data) - .map_err(|err| error_context!(err, "`cfg` node with id {}", node_id))?, - ), - - table::Operation::Block => Some( - self.import_node_block(node_id, parent) - .map_err(|err| error_context!(err, "`block` node with id {}", node_id))?, - ), - - table::Operation::DefineFunc(symbol) => Some( - self.import_node_define_func(node_id, symbol, node_data, parent) - .map_err(|err| error_context!(err, "`define-func` node with id {}", node_id))?, - ), - - table::Operation::DeclareFunc(symbol) => Some( - self.import_node_declare_func(node_id, symbol, parent) - .map_err(|err| { - error_context!(err, "`declare-func` node with id {}", node_id) - })?, - ), - - table::Operation::TailLoop => Some( - self.import_tail_loop(node_id, parent) - .map_err(|err| error_context!(err, "`tail-loop` node with id {}", node_id))?, - ), - - table::Operation::Conditional => Some( - self.import_conditional(node_id, parent) - .map_err(|err| error_context!(err, "`cond` node with id {}", node_id))?, - ), - - table::Operation::Custom(operation) => Some( - self.import_node_custom(node_id, operation, node_data, parent) - .map_err(|err| error_context!(err, "custom node with id {}", node_id))?, - ), - - table::Operation::DefineAlias(symbol, value) => Some( - self.import_node_define_alias(node_id, symbol, value, parent) - .map_err(|err| { - error_context!(err, "`define-alias` node with id {}", node_id) - })?, - ), - - table::Operation::DeclareAlias(symbol) => Some( - self.import_node_declare_alias(node_id, symbol, parent) - .map_err(|err| { - error_context!(err, "`declare-alias` node with id {}", node_id) - })?, - ), - - table::Operation::Import { .. } => None, - - table::Operation::DeclareConstructor { .. } => None, - table::Operation::DeclareOperation { .. 
} => None, - }; + table::Operation::Cfg => { + let signature = self.get_node_signature(node_id)?; + let optype = OpType::CFG(CFG { signature }); + let node = self.make_node(node_id, optype, parent)?; - Ok(result) - } + let [region] = node_data.regions else { + return Err(table::ModelError::InvalidRegions(node_id).into()); + }; - fn import_node_dfg( - &mut self, - node_id: table::NodeId, - parent: Node, - node_data: &'a table::Node<'a>, - ) -> Result { - let signature = self - .get_node_signature(node_id) - .map_err(|err| error_context!(err, "node signature"))?; + self.import_cfg_region(node_id, *region, node)?; + Ok(Some(node)) + } - let optype = OpType::DFG(DFG { signature }); - let node = self.make_node(node_id, optype, parent)?; + table::Operation::Block => { + let node = self.import_cfg_block(node_id, parent)?; + Ok(Some(node)) + } - let [region] = node_data.regions else { - return Err(error_invalid!("dfg region expects a single region")); - }; + table::Operation::DefineFunc(symbol) => { + self.import_poly_func_type(node_id, *symbol, |ctx, signature| { + let optype = OpType::FuncDefn(FuncDefn::new(symbol.name, signature)); - self.import_dfg_region(*region, node)?; - Ok(node) - } + let node = ctx.make_node(node_id, optype, parent)?; - fn import_node_cfg( - &mut self, - node_id: table::NodeId, - parent: Node, - node_data: &'a table::Node<'a>, - ) -> Result { - let signature = self - .get_node_signature(node_id) - .map_err(|err| error_context!(err, "node signature"))?; + let [region] = node_data.regions else { + return Err(table::ModelError::InvalidRegions(node_id).into()); + }; - let optype = OpType::CFG(CFG { signature }); - let node = self.make_node(node_id, optype, parent)?; + ctx.import_dfg_region(node_id, *region, node)?; - let [region] = node_data.regions else { - return Err(error_invalid!("cfg nodes expect a single region")); - }; + Ok(Some(node)) + }) + } - self.import_cfg_region(*region, node)?; - Ok(node) + table::Operation::DeclareFunc(symbol) => { + self.import_poly_func_type(node_id, *symbol, |ctx, signature| { + let optype = OpType::FuncDecl(FuncDecl::new(symbol.name, signature)); + + let node = ctx.make_node(node_id, optype, parent)?; + + Ok(Some(node)) + }) + } + + table::Operation::TailLoop => { + let node = self.import_tail_loop(node_id, parent)?; + Ok(Some(node)) + } + table::Operation::Conditional => { + let node = self.import_conditional(node_id, parent)?; + Ok(Some(node)) + } + + table::Operation::Custom(operation) => { + if let Some([_, _]) = self.match_symbol(operation, model::CORE_CALL_INDIRECT)? { + let signature = self.get_node_signature(node_id)?; + let optype = OpType::CallIndirect(CallIndirect { signature }); + let node = self.make_node(node_id, optype, parent)?; + return Ok(Some(node)); + } + + if let Some([_, _, func]) = self.match_symbol(operation, model::CORE_CALL)? { + let table::Term::Apply(symbol, args) = self.get_term(func)? else { + return Err(table::ModelError::TypeError(func).into()); + }; + + let func_sig = self.get_func_signature(*symbol)?; + + let type_args = args + .iter() + .map(|term| self.import_type_arg(*term)) + .collect::, _>>()?; + + self.static_edges.push((*symbol, node_id)); + let optype = OpType::Call(Call::try_new(func_sig, type_args)?); + + let node = self.make_node(node_id, optype, parent)?; + return Ok(Some(node)); + } + + if let Some([_, value]) = self.match_symbol(operation, model::CORE_LOAD_CONST)? { + // If the constant refers directly to a function, import this as the `LoadFunc` operation. 
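// Both the `core.call` branch above and the function-`LoadConstant` branch below
// only record the callee in `self.static_edges`; the corresponding static wires
// are created afterwards in `link_static_ports`.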
+ if let table::Term::Apply(symbol, args) = self.get_term(value)? { + let func_node_data = self + .module + .get_node(*symbol) + .ok_or(table::ModelError::NodeNotFound(*symbol))?; + + if let table::Operation::DefineFunc(_) | table::Operation::DeclareFunc(_) = + func_node_data.operation + { + let func_sig = self.get_func_signature(*symbol)?; + let type_args = args + .iter() + .map(|term| self.import_type_arg(*term)) + .collect::, _>>()?; + + self.static_edges.push((*symbol, node_id)); + + let optype = + OpType::LoadFunction(LoadFunction::try_new(func_sig, type_args)?); + + let node = self.make_node(node_id, optype, parent)?; + return Ok(Some(node)); + } + } + + // Otherwise use const nodes + let signature = node_data + .signature + .ok_or_else(|| error_uninferred!("node signature"))?; + let [_, outputs] = self.get_func_type(signature)?; + let outputs = self.import_closed_list(outputs)?; + let output = outputs + .first() + .ok_or(table::ModelError::TypeError(signature))?; + let datatype = self.import_type(*output)?; + + let imported_value = self.import_value(value, *output)?; + + let load_const_node = self.make_node( + node_id, + OpType::LoadConstant(LoadConstant { + datatype: datatype.clone(), + }), + parent, + )?; + + let const_node = self + .hugr + .add_node_with_parent(parent, OpType::Const(Const::new(imported_value))); + + self.hugr.connect(const_node, 0, load_const_node, 0); + + return Ok(Some(load_const_node)); + } + + if let Some([_, _, tag]) = self.match_symbol(operation, model::CORE_MAKE_ADT)? { + let table::Term::Literal(model::Literal::Nat(tag)) = self.get_term(tag)? else { + return Err(table::ModelError::TypeError(tag).into()); + }; + + let signature = node_data + .signature + .ok_or_else(|| error_uninferred!("node signature"))?; + let [_, outputs] = self.get_func_type(signature)?; + let (variants, _) = self.import_adt_and_rest(node_id, outputs)?; + let node = self.make_node( + node_id, + OpType::Tag(Tag { + variants, + tag: *tag as usize, + }), + parent, + )?; + return Ok(Some(node)); + } + + let table::Term::Apply(node, params) = self.get_term(operation)? else { + return Err(table::ModelError::TypeError(operation).into()); + }; + let name = self.get_symbol_name(*node)?; + let args = params + .iter() + .map(|param| self.import_type_arg(*param)) + .collect::, _>>()?; + let (extension, name) = self.import_custom_name(name)?; + let signature = self.get_node_signature(node_id)?; + + // TODO: Currently we do not have the description or any other metadata for + // the custom op. This will improve with declarative extensions being able + // to declare operations as a node, in which case the description will be attached + // to that node as metadata. 
+ + let optype = OpType::OpaqueOp(OpaqueOp::new(extension, name, args, signature)); + + let node = self.make_node(node_id, optype, parent)?; + + Ok(Some(node)) + } + + table::Operation::DefineAlias(symbol, value) => { + if !symbol.params.is_empty() { + return Err(error_unsupported!( + "parameters or constraints in alias definition" + )); + } + + let optype = OpType::AliasDefn(AliasDefn { + name: symbol.name.to_smolstr(), + definition: self.import_type(value)?, + }); + + let node = self.make_node(node_id, optype, parent)?; + Ok(Some(node)) + } + + table::Operation::DeclareAlias(symbol) => { + if !symbol.params.is_empty() { + return Err(error_unsupported!( + "parameters or constraints in alias declaration" + )); + } + + let optype = OpType::AliasDecl(AliasDecl { + name: symbol.name.to_smolstr(), + bound: TypeBound::Copyable, + }); + + let node = self.make_node(node_id, optype, parent)?; + Ok(Some(node)) + } + + table::Operation::Import { .. } => Ok(None), + + table::Operation::DeclareConstructor { .. } => Ok(None), + table::Operation::DeclareOperation { .. } => Ok(None), + } } fn import_dfg_region( &mut self, + node_id: table::NodeId, region: table::RegionId, node: Node, - ) -> Result<(), ImportErrorInner> { + ) -> Result<(), ImportError> { let region_data = self.get_region(region)?; let prev_region = self.region_scope; @@ -620,16 +626,14 @@ impl<'a> Context<'a> { } if region_data.kind != model::RegionKind::DataFlow { - return Err(error_invalid!("expected dfg region")); + return Err(table::ModelError::InvalidRegions(node_id).into()); } - let signature = self - .import_func_type( - region_data - .signature - .ok_or_else(|| error_uninferred!("region signature"))?, - ) - .map_err(|err| error_context!(err, "signature of dfg region with id {}", region))?; + let signature = self.import_func_type( + region_data + .signature + .ok_or_else(|| error_uninferred!("region signature"))?, + )?; // Create the input and output nodes let input = self.hugr.add_node_with_parent( @@ -653,7 +657,7 @@ impl<'a> Context<'a> { self.import_node(*child, node)?; } - self.create_order_edges(region, input, output)?; + self.create_order_edges(region)?; for meta_item in region_data.meta { self.import_node_metadata(node, *meta_item)?; @@ -667,18 +671,13 @@ impl<'a> Context<'a> { /// Create order edges between nodes of a dataflow region based on order hint metadata. /// /// This method assumes that the nodes for the children of the region have already been imported. - fn create_order_edges( - &mut self, - region_id: table::RegionId, - input: Node, - output: Node, - ) -> Result<(), ImportErrorInner> { + fn create_order_edges(&mut self, region_id: table::RegionId) -> Result<(), ImportError> { let region_data = self.get_region(region_id)?; debug_assert_eq!(region_data.kind, model::RegionKind::DataFlow); // Collect order hint keys // PERFORMANCE: It might be worthwhile to reuse the map to avoid allocations. - let mut order_keys = FxHashMap::::default(); + let mut order_keys = FxHashMap::::default(); for child_id in region_data.children { let child_data = self.get_node(*child_id)?; @@ -692,42 +691,8 @@ impl<'a> Context<'a> { continue; }; - // NOTE: The lookups here are expected to succeed since we only - // process the order metadata after we have imported the nodes. - let child_node = self.nodes[child_id]; - let child_optype = self.hugr.get_optype(child_node); - - // Check that the node has order ports. - // NOTE: This assumes that a node has an input order port iff it has an output one. 
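// With the revert, the missing-order-port check moves out of this keyed-children
// loop: the restored code only raises `OrderHintError::NoOrderPort` later, when an
// order edge is actually constructed from the hint.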
- if child_optype.other_output_port().is_none() { - return Err(OrderHintError::NoOrderPort(*child_id).into()); - } - - if order_keys.insert(*key, child_node).is_some() { - return Err(OrderHintError::DuplicateKey(region_id, *key).into()); - } - } - } - - // Collect the order hint keys for the input and output nodes - for meta_id in region_data.meta { - if let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_INPUT_KEY)? { - let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else { - continue; - }; - - if order_keys.insert(*key, input).is_some() { - return Err(OrderHintError::DuplicateKey(region_id, *key).into()); - } - } - - if let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_OUTPUT_KEY)? { - let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else { - continue; - }; - - if order_keys.insert(*key, output).is_some() { - return Err(OrderHintError::DuplicateKey(region_id, *key).into()); + if order_keys.insert(*key, *child_id).is_some() { + return Err(OrderHintError::DuplicateKey(*child_id, *key).into()); } } } @@ -749,13 +714,24 @@ impl<'a> Context<'a> { let a = order_keys.get(a).ok_or(OrderHintError::UnknownKey(*a))?; let b = order_keys.get(b).ok_or(OrderHintError::UnknownKey(*b))?; - // NOTE: The unwrap here must succeed: - // - For all ordinary nodes we checked that they have an order port. - // - Input and output nodes always have an order port. - let a_port = self.hugr.get_optype(*a).other_output_port().unwrap(); - let b_port = self.hugr.get_optype(*b).other_input_port().unwrap(); + // NOTE: The lookups here are expected to succeed since we only + // process the order metadata after we have imported the nodes. + let a_node = self.nodes[a]; + let b_node = self.nodes[b]; - self.hugr.connect(*a, a_port, *b, b_port); + let a_port = self + .hugr + .get_optype(a_node) + .other_output_port() + .ok_or(OrderHintError::NoOrderPort(*a))?; + + let b_port = self + .hugr + .get_optype(b_node) + .other_input_port() + .ok_or(OrderHintError::NoOrderPort(*b))?; + + self.hugr.connect(a_node, a_port, b_node, b_port); } Ok(()) @@ -763,12 +739,13 @@ impl<'a> Context<'a> { fn import_adt_and_rest( &mut self, + node_id: table::NodeId, list: table::TermId, - ) -> Result<(Vec, TypeRow), ImportErrorInner> { + ) -> Result<(Vec, TypeRow), ImportError> { let items = self.import_closed_list(list)?; let Some((first, rest)) = items.split_first() else { - return Err(error_invalid!("expected list to have at least one element")); + return Err(table::ModelError::InvalidRegions(node_id).into()); }; let sum_rows: Vec<_> = { @@ -789,40 +766,35 @@ impl<'a> Context<'a> { &mut self, node_id: table::NodeId, parent: Node, - ) -> Result { + ) -> Result { let node_data = self.get_node(node_id)?; debug_assert_eq!(node_data.operation, table::Operation::TailLoop); let [region] = node_data.regions else { - return Err(error_invalid!( - "loop node {} expects a single region", - node_id - )); + return Err(table::ModelError::InvalidRegions(node_id).into()); }; - let region_data = self.get_region(*region)?; - let (just_inputs, just_outputs, rest) = (|| { - let [_, region_outputs] = self.get_func_type( - region_data - .signature - .ok_or_else(|| error_uninferred!("region signature"))?, - )?; - let (sum_rows, rest) = self.import_adt_and_rest(region_outputs)?; - - if sum_rows.len() != 2 { - return Err(error_invalid!( - "loop nodes expect their first target to be an ADT with two variants" - )); - } + let [_, region_outputs] = self.get_func_type( + region_data + .signature + 
.ok_or_else(|| error_uninferred!("region signature"))?, + )?; + let (sum_rows, rest) = self.import_adt_and_rest(node_id, region_outputs)?; + let (just_inputs, just_outputs) = { let mut sum_rows = sum_rows.into_iter(); - let just_inputs = sum_rows.next().unwrap(); - let just_outputs = sum_rows.next().unwrap(); - Ok((just_inputs, just_outputs, rest)) - })() - .map_err(|err| error_context!(err, "region signature"))?; + let Some(just_inputs) = sum_rows.next() else { + return Err(table::ModelError::TypeError(region_outputs).into()); + }; + + let Some(just_outputs) = sum_rows.next() else { + return Err(table::ModelError::TypeError(region_outputs).into()); + }; + + (just_inputs, just_outputs) + }; let optype = OpType::TailLoop(TailLoop { just_inputs, @@ -832,7 +804,7 @@ impl<'a> Context<'a> { let node = self.make_node(node_id, optype, parent)?; - self.import_dfg_region(*region, node)?; + self.import_dfg_region(node_id, *region, node)?; Ok(node) } @@ -840,22 +812,16 @@ impl<'a> Context<'a> { &mut self, node_id: table::NodeId, parent: Node, - ) -> Result { + ) -> Result { let node_data = self.get_node(node_id)?; debug_assert_eq!(node_data.operation, table::Operation::Conditional); - - let (sum_rows, other_inputs, outputs) = (|| { - let [inputs, outputs] = self.get_func_type( - node_data - .signature - .ok_or_else(|| error_uninferred!("node signature"))?, - )?; - let (sum_rows, other_inputs) = self.import_adt_and_rest(inputs)?; - let outputs = self.import_type_row(outputs)?; - - Ok((sum_rows, other_inputs, outputs)) - })() - .map_err(|err| error_context!(err, "node signature"))?; + let [inputs, outputs] = self.get_func_type( + node_data + .signature + .ok_or_else(|| error_uninferred!("node signature"))?, + )?; + let (sum_rows, other_inputs) = self.import_adt_and_rest(node_id, inputs)?; + let outputs = self.import_type_row(outputs)?; let optype = OpType::Conditional(Conditional { sum_rows, @@ -877,7 +843,7 @@ impl<'a> Context<'a> { .hugr .add_node_with_parent(node, OpType::Case(Case { signature })); - self.import_dfg_region(*region, case_node)?; + self.import_dfg_region(node_id, *region, case_node)?; } Ok(node) @@ -885,13 +851,14 @@ impl<'a> Context<'a> { fn import_cfg_region( &mut self, + node_id: table::NodeId, region: table::RegionId, node: Node, - ) -> Result<(), ImportErrorInner> { + ) -> Result<(), ImportError> { let region_data = self.get_region(region)?; if region_data.kind != model::RegionKind::ControlFlow { - return Err(error_invalid!("expected cfg region")); + return Err(table::ModelError::InvalidRegions(node_id).into()); } let prev_region = self.region_scope; @@ -899,22 +866,19 @@ impl<'a> Context<'a> { self.region_scope = region; } - let region_target_types = (|| { - let [_, region_targets] = self.get_ctrl_type( - region_data - .signature - .ok_or_else(|| error_uninferred!("region signature"))?, - )?; + let [_, region_targets] = self.get_func_type( + region_data + .signature + .ok_or_else(|| error_uninferred!("region signature"))?, + )?; - self.import_closed_list(region_targets) - })() - .map_err(|err| error_context!(err, "signature of cfg region with id {}", region))?; + let region_target_types = self.import_closed_list(region_targets)?; // Identify the entry node of the control flow region by looking for // a block whose input is linked to the sole source port of the CFG region. 
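// hugr-core identifies the entry block purely by position, so it must be the first
// child imported; in the restored code the exit block is appended only after every
// other child, whereas the code being removed created it second, directly after the
// entry block.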
let entry_node = 'find_entry: { let [entry_link] = region_data.sources else { - return Err(error_invalid!("cfg region expects a single source")); + return Err(table::ModelError::InvalidRegions(node_id).into()); }; for child in region_data.children { @@ -930,22 +894,29 @@ impl<'a> Context<'a> { // directly from the source to the target of the region. This is // currently not allowed in hugr core directly, but may be simulated // by constructing an empty entry block. - return Err(error_invalid!("cfg region without entry node")); + return Err(table::ModelError::InvalidRegions(node_id).into()); }; // The entry node in core control flow regions is identified by being - // the first child node of the CFG node. We therefore import the entry node first. + // the first child node of the CFG node. We therefore import the entry + // node first and follow it up by every other node. self.import_node(entry_node, node)?; - // Create the exit node for the control flow region. This always needs - // to be second in the node list. + for child in region_data.children { + if *child != entry_node { + self.import_node(*child, node)?; + } + } + + // Create the exit node for the control flow region. { let cfg_outputs = { - let [target_types] = region_target_types.as_slice() else { - return Err(error_invalid!("cfg region expects a single target")); + let [ctrl_type] = region_target_types.as_slice() else { + return Err(table::ModelError::TypeError(region_targets).into()); }; - self.import_type_row(*target_types)? + let [types] = self.expect_symbol(*ctrl_type, model::CORE_CTRL)?; + self.import_type_row(types)? }; let exit = self @@ -954,16 +925,8 @@ impl<'a> Context<'a> { self.record_links(exit, Direction::Incoming, region_data.targets); } - // Finally we import all other nodes. 
- for child in region_data.children { - if *child != entry_node { - self.import_node(*child, node)?; - } - } - for meta_item in region_data.meta { - self.import_node_metadata(node, *meta_item) - .map_err(|err| error_context!(err, "node metadata"))?; + self.import_node_metadata(node, *meta_item)?; } self.region_scope = prev_region; @@ -971,16 +934,16 @@ impl<'a> Context<'a> { Ok(()) } - fn import_node_block( + fn import_cfg_block( &mut self, node_id: table::NodeId, parent: Node, - ) -> Result { + ) -> Result { let node_data = self.get_node(node_id)?; debug_assert_eq!(node_data.operation, table::Operation::Block); let [region] = node_data.regions else { - return Err(error_invalid!("basic block expects a single region")); + return Err(table::ModelError::InvalidRegions(node_id).into()); }; let region_data = self.get_region(*region)?; let [inputs, outputs] = self.get_func_type( @@ -989,7 +952,7 @@ impl<'a> Context<'a> { .ok_or_else(|| error_uninferred!("region signature"))?, )?; let inputs = self.import_type_row(inputs)?; - let (sum_rows, other_outputs) = self.import_adt_and_rest(outputs)?; + let (sum_rows, other_outputs) = self.import_adt_and_rest(node_id, outputs)?; let optype = OpType::DataflowBlock(DataflowBlock { inputs, @@ -998,545 +961,350 @@ impl<'a> Context<'a> { }); let node = self.make_node(node_id, optype, parent)?; - self.import_dfg_region(*region, node).map_err(|err| { - error_context!(err, "block body defined by region with id {}", *region) - })?; + self.import_dfg_region(node_id, *region, node)?; Ok(node) } - fn import_node_define_func( + fn import_poly_func_type( &mut self, - node_id: table::NodeId, - symbol: &'a table::Symbol<'a>, - node_data: &'a table::Node<'a>, - parent: Node, - ) -> Result { - let visibility = symbol.visibility.clone().ok_or(ImportErrorInner::Invalid( - "No visibility for FuncDefn".to_string(), - ))?; - self.import_poly_func_type(node_id, *symbol, |ctx, signature| { - let func_name = ctx.import_title_metadata(node_id)?.unwrap_or(symbol.name); - - let optype = - OpType::FuncDefn(FuncDefn::new_vis(func_name, signature, visibility.into())); - - let node = ctx.make_node(node_id, optype, parent)?; - - let [region] = node_data.regions else { - return Err(error_invalid!( - "function definition nodes expect a single region" - )); - }; + node: table::NodeId, + symbol: table::Symbol<'a>, + in_scope: impl FnOnce(&mut Self, PolyFuncTypeBase) -> Result, + ) -> Result { + let mut imported_params = Vec::with_capacity(symbol.params.len()); - ctx.import_dfg_region(*region, node).map_err(|err| { - error_context!(err, "function body defined by region with id {}", *region) - })?; + for (index, param) in symbol.params.iter().enumerate() { + self.local_vars + .insert(table::VarId(node, index as _), LocalVar::new(param.r#type)); + } - Ok(node) - }) - } + for constraint in symbol.constraints { + if let Some([term]) = self.match_symbol(*constraint, model::CORE_NON_LINEAR)? { + let table::Term::Var(var) = self.get_term(term)? 
else { + return Err(error_unsupported!( + "constraint on term that is not a variable" + )); + }; - fn import_node_declare_func( - &mut self, - node_id: table::NodeId, - symbol: &'a table::Symbol<'a>, - parent: Node, - ) -> Result { - let visibility = symbol.visibility.clone().ok_or(ImportErrorInner::Invalid( - "No visibility for FuncDecl".to_string(), - ))?; - self.import_poly_func_type(node_id, *symbol, |ctx, signature| { - let func_name = ctx.import_title_metadata(node_id)?.unwrap_or(symbol.name); - - let optype = - OpType::FuncDecl(FuncDecl::new_vis(func_name, signature, visibility.into())); - let node = ctx.make_node(node_id, optype, parent)?; - Ok(node) - }) + self.local_vars + .get_mut(var) + .ok_or(table::ModelError::InvalidVar(*var))? + .bound = TypeBound::Copyable; + } else { + return Err(error_unsupported!("constraint other than copy or discard")); + } + } + + for (index, param) in symbol.params.iter().enumerate() { + // NOTE: `PolyFuncType` only has explicit type parameters at present. + let bound = self.local_vars[&table::VarId(node, index as _)].bound; + imported_params.push(self.import_type_param(param.r#type, bound)?); + } + + let body = self.import_func_type::(symbol.signature)?; + in_scope(self, PolyFuncTypeBase::new(imported_params, body)) } - fn import_node_custom( + /// Import a [`TypeParam`] from a term that represents a static type. + fn import_type_param( &mut self, - node_id: table::NodeId, - operation: table::TermId, - node_data: &'a table::Node<'a>, - parent: Node, - ) -> Result { - if let Some([inputs, outputs]) = self.match_symbol(operation, model::CORE_CALL_INDIRECT)? { - let inputs = self.import_type_row(inputs)?; - let outputs = self.import_type_row(outputs)?; - let signature = Signature::new(inputs, outputs); - let optype = OpType::CallIndirect(CallIndirect { signature }); - let node = self.make_node(node_id, optype, parent)?; - return Ok(node); + term_id: table::TermId, + bound: TypeBound, + ) -> Result { + if let Some([]) = self.match_symbol(term_id, model::CORE_STR_TYPE)? { + return Ok(TypeParam::String); } - if let Some([_, _, func]) = self.match_symbol(operation, model::CORE_CALL)? { - let table::Term::Apply(symbol, args) = self.get_term(func)? else { - return Err(error_invalid!( - "expected a symbol application to be passed to `{}`", - model::CORE_CALL - )); - }; - - let func_sig = self.get_func_signature(*symbol)?; + if let Some([]) = self.match_symbol(term_id, model::CORE_NAT_TYPE)? { + return Ok(TypeParam::max_nat()); + } - let type_args = args - .iter() - .map(|term| self.import_term(*term)) - .collect::, _>>()?; + if let Some([]) = self.match_symbol(term_id, model::CORE_BYTES_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeParam`", + model::CORE_BYTES_TYPE + )); + } - self.static_edges.push((*symbol, node_id)); - let optype = OpType::Call( - Call::try_new(func_sig, type_args).map_err(ImportErrorInner::Signature)?, - ); + if let Some([]) = self.match_symbol(term_id, model::CORE_FLOAT_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeParam`", + model::CORE_FLOAT_TYPE + )); + } - let node = self.make_node(node_id, optype, parent)?; - return Ok(node); + if let Some([]) = self.match_symbol(term_id, model::CORE_TYPE)? { + return Ok(TypeParam::Type { b: bound }); } - if let Some([_, value]) = self.match_symbol(operation, model::CORE_LOAD_CONST)? { - // If the constant refers directly to a function, import this as the `LoadFunc` operation. - if let table::Term::Apply(symbol, args) = self.get_term(value)? 
{ - let func_node_data = self.get_node(*symbol)?; + if let Some([]) = self.match_symbol(term_id, model::CORE_STATIC)? { + return Err(error_unsupported!( + "`{}` as `TypeParam`", + model::CORE_STATIC + )); + } - if let table::Operation::DefineFunc(_) | table::Operation::DeclareFunc(_) = - func_node_data.operation - { - let func_sig = self.get_func_signature(*symbol)?; - let type_args = args - .iter() - .map(|term| self.import_term(*term)) - .collect::, _>>()?; + if let Some([]) = self.match_symbol(term_id, model::CORE_CONSTRAINT)? { + return Err(error_unsupported!( + "`{}` as `TypeParam`", + model::CORE_CONSTRAINT + )); + } - self.static_edges.push((*symbol, node_id)); + if let Some([]) = self.match_symbol(term_id, model::CORE_CONST)? { + return Err(error_unsupported!("`{}` as `TypeParam`", model::CORE_CONST)); + } - let optype = OpType::LoadFunction( - LoadFunction::try_new(func_sig, type_args) - .map_err(ImportErrorInner::Signature)?, - ); + if let Some([]) = self.match_symbol(term_id, model::CORE_CTRL_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeParam`", + model::CORE_CTRL_TYPE + )); + } - let node = self.make_node(node_id, optype, parent)?; - return Ok(node); - } - } + if let Some([item_type]) = self.match_symbol(term_id, model::CORE_LIST_TYPE)? { + // At present `hugr-model` has no way to express that the item + // type of a list must be copyable. Therefore we import it as `Any`. + let param = Box::new(self.import_type_param(item_type, TypeBound::Any)?); + return Ok(TypeParam::List { param }); + } - // Otherwise use const nodes - let signature = node_data - .signature - .ok_or_else(|| error_uninferred!("node signature"))?; - let [_, outputs] = self.get_func_type(signature)?; - let outputs = self.import_closed_list(outputs)?; - let output = outputs.first().ok_or_else(|| { - error_invalid!("`{}` expects a single output", model::CORE_LOAD_CONST) - })?; - let datatype = self.import_type(*output)?; - - let imported_value = self.import_value(value, *output)?; - - let load_const_node = self.make_node( - node_id, - OpType::LoadConstant(LoadConstant { - datatype: datatype.clone(), - }), - parent, - )?; + if let Some([_]) = self.match_symbol(term_id, model::CORE_TUPLE_TYPE)? { + // At present `hugr-model` has no way to express that the item + // types of a tuple must be copyable. Therefore we import it as `Any`. + todo!("import tuple type"); + } - let const_node = self - .hugr - .add_node_with_parent(parent, OpType::Const(Const::new(imported_value))); + match self.get_term(term_id)? { + table::Term::Wildcard => Err(error_uninferred!("wildcard")), - self.hugr.connect(const_node, 0, load_const_node, 0); + table::Term::Var { .. } => Err(error_unsupported!("type variable as `TypeParam`")), + table::Term::Apply(symbol, _) => { + let name = self.get_symbol_name(*symbol)?; + Err(error_unsupported!("custom type `{}` as `TypeParam`", name)) + } - return Ok(load_const_node); + table::Term::Tuple(_) + | table::Term::List { .. } + | table::Term::Func { .. } + | table::Term::Literal(_) => Err(table::ModelError::TypeError(term_id).into()), } + } - if let Some([_, _, tag]) = self.match_symbol(operation, model::CORE_MAKE_ADT)? { - let table::Term::Literal(model::Literal::Nat(tag)) = self.get_term(tag)? 
else { - return Err(error_invalid!( - "`{}` expects a nat literal tag", - model::CORE_MAKE_ADT - )); - }; - - let signature = node_data - .signature - .ok_or_else(|| error_uninferred!("node signature"))?; - let [_, outputs] = self.get_func_type(signature)?; - let (variants, _) = self.import_adt_and_rest(outputs)?; - let node = self.make_node( - node_id, - OpType::Tag(Tag { - variants, - tag: *tag as usize, - }), - parent, - )?; - return Ok(node); + /// Import a `TypeArg` from a term that represents a static type or value. + fn import_type_arg(&mut self, term_id: table::TermId) -> Result { + if let Some([]) = self.match_symbol(term_id, model::CORE_STR_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeArg`", + model::CORE_STR_TYPE + )); } - let table::Term::Apply(node, params) = self.get_term(operation)? else { - return Err(error_invalid!( - "custom operations expect a symbol application referencing an operation" + if let Some([]) = self.match_symbol(term_id, model::CORE_NAT_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeArg`", + model::CORE_NAT_TYPE )); - }; - let name = self.get_symbol_name(*node)?; - let args = params - .iter() - .map(|param| self.import_term(*param)) - .collect::, _>>()?; - let (extension, name) = self.import_custom_name(name)?; - let signature = self.get_node_signature(node_id)?; - - // TODO: Currently we do not have the description or any other metadata for - // the custom op. This will improve with declarative extensions being able - // to declare operations as a node, in which case the description will be attached - // to that node as metadata. - - let optype = OpType::OpaqueOp(OpaqueOp::new(extension, name, args, signature)); - self.make_node(node_id, optype, parent) - } + } - fn import_node_define_alias( - &mut self, - node_id: table::NodeId, - symbol: &'a table::Symbol<'a>, - value: table::TermId, - parent: Node, - ) -> Result { - if !symbol.params.is_empty() { + if let Some([]) = self.match_symbol(term_id, model::CORE_BYTES_TYPE)? { return Err(error_unsupported!( - "parameters or constraints in alias definition" + "`{}` as `TypeArg`", + model::CORE_BYTES_TYPE )); } - let optype = OpType::AliasDefn(AliasDefn { - name: symbol.name.to_smolstr(), - definition: self.import_type(value)?, - }); - - let node = self.make_node(node_id, optype, parent)?; - Ok(node) - } - - fn import_node_declare_alias( - &mut self, - node_id: table::NodeId, - symbol: &'a table::Symbol<'a>, - parent: Node, - ) -> Result { - if !symbol.params.is_empty() { + if let Some([]) = self.match_symbol(term_id, model::CORE_FLOAT_TYPE)? { return Err(error_unsupported!( - "parameters or constraints in alias declaration" + "`{}` as `TypeArg`", + model::CORE_FLOAT_TYPE )); } - let optype = OpType::AliasDecl(AliasDecl { - name: symbol.name.to_smolstr(), - bound: TypeBound::Copyable, - }); - - let node = self.make_node(node_id, optype, parent)?; - Ok(node) - } - - fn import_poly_func_type( - &mut self, - node: table::NodeId, - symbol: table::Symbol<'a>, - in_scope: impl FnOnce(&mut Self, PolyFuncTypeBase) -> Result, - ) -> Result { - (|| { - let mut imported_params = Vec::with_capacity(symbol.params.len()); - - for (index, param) in symbol.params.iter().enumerate() { - self.local_vars - .insert(table::VarId(node, index as _), LocalVar::new(param.r#type)); - } + if let Some([]) = self.match_symbol(term_id, model::CORE_TYPE)? 
{ + return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_TYPE)); + } - for constraint in symbol.constraints { - if let Some([term]) = self.match_symbol(*constraint, model::CORE_NON_LINEAR)? { - let table::Term::Var(var) = self.get_term(term)? else { - return Err(error_unsupported!( - "constraint on term that is not a variable" - )); - }; + if let Some([]) = self.match_symbol(term_id, model::CORE_CONSTRAINT)? { + return Err(error_unsupported!( + "`{}` as `TypeArg`", + model::CORE_CONSTRAINT + )); + } - self.local_vars - .get_mut(var) - .ok_or_else(|| error_invalid!("unknown variable {}", var))? - .bound = TypeBound::Copyable; - } else { - return Err(error_unsupported!("constraint other than copy or discard")); - } - } + if let Some([]) = self.match_symbol(term_id, model::CORE_STATIC)? { + return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_STATIC)); + } - for (index, param) in symbol.params.iter().enumerate() { - // NOTE: `PolyFuncType` only has explicit type parameters at present. - let bound = self.local_vars[&table::VarId(node, index as _)].bound; - imported_params.push( - self.import_term_with_bound(param.r#type, bound) - .map_err(|err| error_context!(err, "type of parameter `{}`", param.name))?, - ); - } + if let Some([]) = self.match_symbol(term_id, model::CORE_CTRL_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeArg`", + model::CORE_CTRL_TYPE + )); + } - let body = self.import_func_type::(symbol.signature)?; - in_scope(self, PolyFuncTypeBase::new(imported_params, body)) - })() - .map_err(|err| error_context!(err, "symbol `{}` defined by node {}", symbol.name, node)) - } + if let Some([]) = self.match_symbol(term_id, model::CORE_CONST)? { + return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_CONST)); + } - /// Import a [`Term`] from a term that represents a static type or value. - fn import_term(&mut self, term_id: table::TermId) -> Result { - self.import_term_with_bound(term_id, TypeBound::Linear) - } + if let Some([]) = self.match_symbol(term_id, model::CORE_LIST_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeArg`", + model::CORE_LIST_TYPE + )); + } - fn import_term_with_bound( - &mut self, - term_id: table::TermId, - bound: TypeBound, - ) -> Result { - (|| { - if let Some([]) = self.match_symbol(term_id, model::CORE_STR_TYPE)? { - return Ok(Term::StringType); - } + match self.get_term(term_id)? { + table::Term::Wildcard => Err(error_uninferred!("wildcard")), - if let Some([]) = self.match_symbol(term_id, model::CORE_NAT_TYPE)? { - return Ok(Term::max_nat_type()); + table::Term::Var(var) => { + let var_info = self + .local_vars + .get(var) + .ok_or(table::ModelError::InvalidVar(*var))?; + let decl = self.import_type_param(var_info.r#type, var_info.bound)?; + Ok(TypeArg::new_var_use(var.1 as _, decl)) } - if let Some([]) = self.match_symbol(term_id, model::CORE_BYTES_TYPE)? { - return Ok(Term::BytesType); - } + table::Term::List { .. } => { + let elems = self + .import_closed_list(term_id)? + .iter() + .map(|item| self.import_type_arg(*item)) + .collect::>()?; - if let Some([]) = self.match_symbol(term_id, model::CORE_FLOAT_TYPE)? { - return Ok(Term::FloatType); + Ok(TypeArg::Sequence { elems }) } - if let Some([]) = self.match_symbol(term_id, model::CORE_TYPE)? { - return Ok(TypeParam::RuntimeType(bound)); + table::Term::Tuple { .. } => { + // NOTE: While `TypeArg`s can represent tuples as + // `TypeArg::Sequence`s, this conflates lists and tuples. To + // avoid ambiguity we therefore report an error here for now. 
+ Err(error_unsupported!("tuples as `TypeArg`")) } - if let Some([]) = self.match_symbol(term_id, model::CORE_CONSTRAINT)? { - return Err(error_unsupported!("`{}`", model::CORE_CONSTRAINT)); - } + table::Term::Literal(model::Literal::Str(value)) => Ok(TypeArg::String { + arg: value.to_string(), + }), - if let Some([]) = self.match_symbol(term_id, model::CORE_STATIC)? { - return Ok(Term::StaticType); + table::Term::Literal(model::Literal::Nat(value)) => { + Ok(TypeArg::BoundedNat { n: *value }) } - if let Some([]) = self.match_symbol(term_id, model::CORE_CONST)? { - return Err(error_unsupported!("`{}`", model::CORE_CONST)); + table::Term::Literal(model::Literal::Bytes(_)) => { + Err(error_unsupported!("`(bytes ..)` as `TypeArg`")) } - - if let Some([item_type]) = self.match_symbol(term_id, model::CORE_LIST_TYPE)? { - // At present `hugr-model` has no way to express that the item - // type of a list must be copyable. Therefore we import it as `Any`. - let item_type = self - .import_term(item_type) - .map_err(|err| error_context!(err, "item type of list type"))?; - return Ok(TypeParam::new_list_type(item_type)); + table::Term::Literal(model::Literal::Float(_)) => { + Err(error_unsupported!("float literal as `TypeArg`")) } + table::Term::Func { .. } => Err(error_unsupported!("function constant as `TypeArg`")), - if let Some([item_types]) = self.match_symbol(term_id, model::CORE_TUPLE_TYPE)? { - // At present `hugr-model` has no way to express that the item - // types of a tuple must be copyable. Therefore we import it as `Any`. - let item_types = self - .import_term(item_types) - .map_err(|err| error_context!(err, "item types of tuple type"))?; - return Ok(TypeParam::new_tuple_type(item_types)); + table::Term::Apply { .. } => { + let ty = self.import_type(term_id)?; + Ok(TypeArg::Type { ty }) } - - match self.get_term(term_id)? { - table::Term::Wildcard => Err(error_uninferred!("wildcard")), - - table::Term::Var(var) => { - let var_info = self - .local_vars - .get(var) - .ok_or_else(|| error_invalid!("unknown variable {}", var))?; - let decl = self.import_term_with_bound(var_info.r#type, var_info.bound)?; - Ok(Term::new_var_use(var.1 as _, decl)) - } - - table::Term::List(parts) => { - // PERFORMANCE: Can we do this without the additional allocation? - let parts: Vec<_> = parts - .iter() - .map(|part| self.import_seq_part(part)) - .collect::>() - .map_err(|err| error_context!(err, "list parts"))?; - Ok(TypeArg::new_list_from_parts(parts)) - } - - table::Term::Tuple(parts) => { - // PERFORMANCE: Can we do this without the additional allocation? - let parts: Vec<_> = parts - .iter() - .map(|part| self.import_seq_part(part)) - .try_collect() - .map_err(|err| error_context!(err, "tuple parts"))?; - Ok(TypeArg::new_tuple_from_parts(parts)) - } - - table::Term::Literal(model::Literal::Str(value)) => { - Ok(Term::String(value.to_string())) - } - - table::Term::Literal(model::Literal::Nat(value)) => Ok(Term::BoundedNat(*value)), - - table::Term::Literal(model::Literal::Bytes(value)) => { - Ok(Term::Bytes(value.clone())) - } - table::Term::Literal(model::Literal::Float(value)) => Ok(Term::Float(*value)), - table::Term::Func { .. } => Err(error_unsupported!("function constant")), - - table::Term::Apply { .. 
} => { - let ty: Type = self.import_type(term_id)?; - Ok(ty.into()) - } - } - })() - .map_err(|err| error_context!(err, "term {}", term_id)) - } - - fn import_seq_part( - &mut self, - seq_part: &'a table::SeqPart, - ) -> Result, ImportErrorInner> { - Ok(match seq_part { - table::SeqPart::Item(term_id) => SeqPart::Item(self.import_term(*term_id)?), - table::SeqPart::Splice(term_id) => SeqPart::Splice(self.import_term(*term_id)?), - }) + } } /// Import a `Type` from a term that represents a runtime type. fn import_type( &mut self, term_id: table::TermId, - ) -> Result, ImportErrorInner> { - (|| { - if let Some([_, _]) = self.match_symbol(term_id, model::CORE_FN)? { - let func_type = self.import_func_type::(term_id)?; - return Ok(TypeBase::new_function(func_type)); - } - - if let Some([variants]) = self.match_symbol(term_id, model::CORE_ADT)? { - let variants = (|| { - self.import_closed_list(variants)? - .iter() - .map(|variant| self.import_type_row::(*variant)) - .collect::, _>>() - })() - .map_err(|err| error_context!(err, "adt variants"))?; + ) -> Result, ImportError> { + if let Some([_, _]) = self.match_symbol(term_id, model::CORE_FN)? { + let func_type = self.import_func_type::(term_id)?; + return Ok(TypeBase::new_function(func_type)); + } - return Ok(TypeBase::new_sum(variants)); - } + if let Some([variants]) = self.match_symbol(term_id, model::CORE_ADT)? { + let variants = self.import_closed_list(variants)?; + let variants = variants + .iter() + .map(|variant| self.import_type_row::(*variant)) + .collect::, _>>()?; + return Ok(TypeBase::new_sum(variants)); + } - match self.get_term(term_id)? { - table::Term::Wildcard => Err(error_uninferred!("wildcard")), + match self.get_term(term_id)? { + table::Term::Wildcard => Err(error_uninferred!("wildcard")), - table::Term::Apply(symbol, args) => { - let name = self.get_symbol_name(*symbol)?; + table::Term::Apply(symbol, args) => { + let args = args + .iter() + .map(|arg| self.import_type_arg(*arg)) + .collect::, _>>()?; + + let name = self.get_symbol_name(*symbol)?; + let (extension, id) = self.import_custom_name(name)?; + + let extension_ref = + self.extensions + .get(&extension) + .ok_or_else(|| ImportError::Extension { + missing_ext: extension.clone(), + available: self.extensions.ids().cloned().collect(), + })?; - let args = args - .iter() - .map(|arg| self.import_term(*arg)) - .collect::, _>>() - .map_err(|err| { - error_context!(err, "type argument of custom type `{}`", name) + let ext_type = + extension_ref + .get_type(&id) + .ok_or_else(|| ImportError::ExtensionType { + ext: extension.clone(), + name: id.clone(), })?; - let (extension, id) = self.import_custom_name(name)?; - - let extension_ref = - self.extensions - .get(&extension) - .ok_or_else(|| ExtensionError::Missing { - missing_ext: extension.clone(), - available: self.extensions.ids().cloned().collect(), - })?; - - let ext_type = - extension_ref - .get_type(&id) - .ok_or_else(|| ExtensionError::MissingType { - ext: extension.clone(), - name: id.clone(), - })?; - - let bound = ext_type.bound(&args); - - Ok(TypeBase::new_extension(CustomType::new( - id, - args, - extension, - bound, - &Arc::downgrade(extension_ref), - ))) - } + let bound = ext_type.bound(&args); - table::Term::Var(var @ table::VarId(_, index)) => { - let local_var = self - .local_vars - .get(var) - .ok_or(error_invalid!("unknown var {}", var))?; - Ok(TypeBase::new_var_use(*index as _, local_var.bound)) - } + Ok(TypeBase::new_extension(CustomType::new( + id, + args, + extension, + bound, + 
&Arc::downgrade(extension_ref), + ))) + } - // The following terms are not runtime types, but the core `Type` only contains runtime types. - // We therefore report a type error here. - table::Term::List { .. } - | table::Term::Tuple { .. } - | table::Term::Literal(_) - | table::Term::Func { .. } => Err(error_invalid!("expected a runtime type")), + table::Term::Var(var @ table::VarId(_, index)) => { + let local_var = self + .local_vars + .get(var) + .ok_or(table::ModelError::InvalidVar(*var))?; + Ok(TypeBase::new_var_use(*index as _, local_var.bound)) } - })() - .map_err(|err| error_context!(err, "term {} as `Type`", term_id)) - } - fn get_func_type( - &mut self, - term_id: table::TermId, - ) -> Result<[table::TermId; 2], ImportErrorInner> { - self.match_symbol(term_id, model::CORE_FN)? - .ok_or(error_invalid!("expected a function type")) + // The following terms are not runtime types, but the core `Type` only contains runtime types. + // We therefore report a type error here. + table::Term::List { .. } + | table::Term::Tuple { .. } + | table::Term::Literal(_) + | table::Term::Func { .. } => Err(table::ModelError::TypeError(term_id).into()), + } } - fn get_ctrl_type( - &mut self, - term_id: table::TermId, - ) -> Result<[table::TermId; 2], ImportErrorInner> { - self.match_symbol(term_id, model::CORE_CTRL)? - .ok_or(error_invalid!("expected a control type")) + fn get_func_type(&mut self, term_id: table::TermId) -> Result<[table::TermId; 2], ImportError> { + self.match_symbol(term_id, model::CORE_FN)? + .ok_or(table::ModelError::TypeError(term_id).into()) } fn import_func_type( &mut self, term_id: table::TermId, - ) -> Result, ImportErrorInner> { - (|| { - let [inputs, outputs] = self.get_func_type(term_id)?; - let inputs = self - .import_type_row(inputs) - .map_err(|err| error_context!(err, "function inputs"))?; - let outputs = self - .import_type_row(outputs) - .map_err(|err| error_context!(err, "function outputs"))?; - Ok(FuncTypeBase::new(inputs, outputs)) - })() - .map_err(|err| error_context!(err, "function type")) + ) -> Result, ImportError> { + let [inputs, outputs] = self.get_func_type(term_id)?; + let inputs = self.import_type_row(inputs)?; + let outputs = self.import_type_row(outputs)?; + Ok(FuncTypeBase::new(inputs, outputs)) } fn import_closed_list( &mut self, term_id: table::TermId, - ) -> Result, ImportErrorInner> { + ) -> Result, ImportError> { fn import_into( ctx: &mut Context, term_id: table::TermId, types: &mut Vec, - ) -> Result<(), ImportErrorInner> { + ) -> Result<(), ImportError> { match ctx.get_term(term_id)? { table::Term::List(parts) => { types.reserve(parts.len()); @@ -1552,7 +1320,7 @@ impl<'a> Context<'a> { } } } - _ => return Err(error_invalid!("expected a closed list")), + _ => return Err(table::ModelError::TypeError(term_id).into()), } Ok(()) @@ -1566,12 +1334,12 @@ impl<'a> Context<'a> { fn import_closed_tuple( &mut self, term_id: table::TermId, - ) -> Result, ImportErrorInner> { + ) -> Result, ImportError> { fn import_into( ctx: &mut Context, term_id: table::TermId, types: &mut Vec, - ) -> Result<(), ImportErrorInner> { + ) -> Result<(), ImportError> { match ctx.get_term(term_id)? 
{ table::Term::Tuple(parts) => { types.reserve(parts.len()); @@ -1587,7 +1355,7 @@ impl<'a> Context<'a> { } } } - _ => return Err(error_invalid!("expected a closed tuple")), + _ => return Err(table::ModelError::TypeError(term_id).into()), } Ok(()) @@ -1601,7 +1369,7 @@ impl<'a> Context<'a> { fn import_type_rows( &mut self, term_id: table::TermId, - ) -> Result>, ImportErrorInner> { + ) -> Result>, ImportError> { self.import_closed_list(term_id)? .into_iter() .map(|term_id| self.import_type_row::(term_id)) @@ -1611,12 +1379,12 @@ impl<'a> Context<'a> { fn import_type_row( &mut self, term_id: table::TermId, - ) -> Result, ImportErrorInner> { + ) -> Result, ImportError> { fn import_into( ctx: &mut Context, term_id: table::TermId, types: &mut Vec>, - ) -> Result<(), ImportErrorInner> { + ) -> Result<(), ImportError> { match ctx.get_term(term_id)? { table::Term::List(parts) => { types.reserve(parts.len()); @@ -1633,11 +1401,11 @@ impl<'a> Context<'a> { } } table::Term::Var(table::VarId(_, index)) => { - let var = RV::try_from_rv(RowVariable(*index as _, TypeBound::Linear)) - .map_err(|_| error_invalid!("expected a closed list"))?; + let var = RV::try_from_rv(RowVariable(*index as _, TypeBound::Any)) + .map_err(|_| table::ModelError::TypeError(term_id))?; types.push(TypeBase::new(TypeEnum::RowVar(var))); } - _ => return Err(error_invalid!("expected a list")), + _ => return Err(table::ModelError::TypeError(term_id).into()), } Ok(()) @@ -1651,17 +1419,17 @@ impl<'a> Context<'a> { fn import_custom_name( &mut self, symbol: &'a str, - ) -> Result<(ExtensionId, SmolStr), ImportErrorInner> { + ) -> Result<(ExtensionId, SmolStr), ImportError> { use std::collections::hash_map::Entry; match self.custom_name_cache.entry(symbol) { Entry::Occupied(occupied_entry) => Ok(occupied_entry.get().clone()), Entry::Vacant(vacant_entry) => { let qualified_name = ExtensionId::new(symbol) - .map_err(|_| error_invalid!("`{}` is not a valid symbol name", symbol))?; + .map_err(|_| table::ModelError::MalformedName(symbol.to_smolstr()))?; let (extension, id) = qualified_name .split_last() - .ok_or_else(|| error_invalid!("`{}` is not a valid symbol name", symbol))?; + .ok_or_else(|| table::ModelError::MalformedName(symbol.to_smolstr()))?; vacant_entry.insert((extension.clone(), id.clone())); Ok((extension, id)) @@ -1673,7 +1441,7 @@ impl<'a> Context<'a> { &mut self, term_id: table::TermId, type_id: table::TermId, - ) -> Result { + ) -> Result { let term_data = self.get_term(term_id)?; // NOTE: We have special cased arrays, integers, and floats for now. @@ -1681,10 +1449,7 @@ impl<'a> Context<'a> { if let Some([runtime_type, json]) = self.match_symbol(term_id, model::COMPAT_CONST_JSON)? { let table::Term::Literal(model::Literal::Str(json)) = self.get_term(json)? else { - return Err(error_invalid!( - "`{}` expects a string literal", - model::COMPAT_CONST_JSON - )); + return Err(table::ModelError::TypeError(term_id).into()); }; // We attempt to deserialize as the custom const directly. 
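For reference, the `compat.const-json` fallback in the preceding hunk wraps any JSON payload that does not resolve to a known `CustomConst` implementation in a `CustomSerialized` constant. Below is a minimal illustrative sketch of the value shape this produces; the `hugr_core` re-export paths and the example payload are assumptions, and the two-argument `CustomSerialized::new` matches the call shown in the hunk above.

```rust
use hugr_core::extension::prelude::usize_t;
use hugr_core::ops::Value;
use hugr_core::ops::constant::{CustomSerialized, OpaqueValue};

// Sketch only: build the same shape of value the importer falls back to when
// the JSON payload cannot be deserialized as a known `CustomConst`.
fn json_backed_const() -> Value {
    // Hypothetical payload, standing in for the string literal carried by the
    // `compat.const-json` term.
    let payload = serde_json::json!({ "value": 42 });
    // `CustomSerialized` keeps the runtime type together with the raw JSON.
    let custom = CustomSerialized::new(usize_t(), payload);
    Value::Extension {
        e: OpaqueValue::new(custom),
    }
}
```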
@@ -1697,12 +1462,8 @@ impl<'a> Context<'a> { return Ok(Value::Extension { e: opaque_value }); } else { let runtime_type = self.import_type(runtime_type)?; - let value: serde_json::Value = serde_json::from_str(json).map_err(|_| { - error_invalid!( - "unable to parse JSON string for `{}`", - model::COMPAT_CONST_JSON - ) - })?; + let value: serde_json::Value = serde_json::from_str(json) + .map_err(|_| table::ModelError::TypeError(term_id))?; let custom_const = CustomSerialized::new(runtime_type, value); let opaque_value = OpaqueValue::new(custom_const); return Ok(Value::Extension { e: opaque_value }); @@ -1726,42 +1487,29 @@ impl<'a> Context<'a> { let table::Term::Literal(model::Literal::Nat(bitwidth)) = self.get_term(bitwidth)? else { - return Err(error_invalid!( - "`{}` expects a nat literal in its `bitwidth` argument", - ConstInt::CTR_NAME - )); + return Err(table::ModelError::TypeError(term_id).into()); }; if *bitwidth > 6 { - return Err(error_invalid!( - "`{}` expects the bitwidth to be at most 6, got {}", - ConstInt::CTR_NAME, - bitwidth - )); + return Err(table::ModelError::TypeError(term_id).into()); } *bitwidth as u8 }; let value = { let table::Term::Literal(model::Literal::Nat(value)) = self.get_term(value)? else { - return Err(error_invalid!( - "`{}` expects a nat literal value", - ConstInt::CTR_NAME - )); + return Err(table::ModelError::TypeError(term_id).into()); }; *value }; return Ok(ConstInt::new_u(bitwidth, value) - .map_err(|_| error_invalid!("failed to create int constant"))? + .map_err(|_| table::ModelError::TypeError(term_id))? .into()); } if let Some([value]) = self.match_symbol(term_id, ConstF64::CTR_NAME)? { let table::Term::Literal(model::Literal::Float(value)) = self.get_term(value)? else { - return Err(error_invalid!( - "`{}` expects a float literal value", - ConstF64::CTR_NAME - )); + return Err(table::ModelError::TypeError(term_id).into()); }; return Ok(ConstF64::new(value.into_inner()).into()); @@ -1773,16 +1521,12 @@ impl<'a> Context<'a> { let variants = self.import_closed_list(variants)?; let table::Term::Literal(model::Literal::Nat(tag)) = self.get_term(tag)? else { - return Err(error_invalid!( - "`{}` expects a nat literal tag", - model::CORE_ADT - )); + return Err(table::ModelError::TypeError(term_id).into()); }; - let variant = variants.get(*tag as usize).ok_or(error_invalid!( - "the tag of a `{}` must be a valid index into the list of variants", - model::CORE_CONST_ADT - ))?; + let variant = variants + .get(*tag as usize) + .ok_or(table::ModelError::TypeError(term_id))?; let variant = self.import_closed_list(*variant)?; @@ -1820,7 +1564,7 @@ impl<'a> Context<'a> { } table::Term::List { .. } | table::Term::Tuple(_) | table::Term::Literal(_) => { - Err(error_invalid!("expected constant")) + Err(table::ModelError::TypeError(term_id).into()) } table::Term::Func { .. } => Err(error_unsupported!("constant function value")), @@ -1831,7 +1575,7 @@ impl<'a> Context<'a> { &self, term_id: table::TermId, name: &str, - ) -> Result, ImportErrorInner> { + ) -> Result, ImportError> { let term = self.get_term(term_id)?; // TODO: Follow alias chains? @@ -1865,36 +1609,9 @@ impl<'a> Context<'a> { &self, term_id: table::TermId, name: &str, - ) -> Result<[table::TermId; N], ImportErrorInner> { - self.match_symbol(term_id, name)?.ok_or(error_invalid!( - "expected symbol `{}` with arity {}", - name, - N - )) - } - - /// Searches for `core.title` metadata on the given node. 
- fn import_title_metadata( - &self, - node_id: table::NodeId, - ) -> Result, ImportErrorInner> { - let node_data = self.get_node(node_id)?; - for meta in node_data.meta { - let Some([name]) = self.match_symbol(*meta, model::CORE_TITLE)? else { - continue; - }; - - let table::Term::Literal(model::Literal::Str(name)) = self.get_term(name)? else { - return Err(error_invalid!( - "`{}` metadata expected a string literal as argument", - model::CORE_TITLE - )); - }; - - return Ok(Some(name.as_str())); - } - - Ok(None) + ) -> Result<[table::TermId; N], ImportError> { + self.match_symbol(term_id, name)? + .ok_or(table::ModelError::TypeError(term_id).into()) } } @@ -1911,7 +1628,7 @@ impl LocalVar { pub fn new(r#type: table::TermId) -> Self { Self { r#type, - bound: TypeBound::Linear, + bound: TypeBound::Any, } } } diff --git a/hugr-core/src/lib.rs b/hugr-core/src/lib.rs index 862b8dee8a..e5f57d2a8f 100644 --- a/hugr-core/src/lib.rs +++ b/hugr-core/src/lib.rs @@ -24,8 +24,7 @@ pub mod types; pub mod utils; pub use crate::core::{ - CircuitUnit, Direction, IncomingPort, Node, NodeIndex, OutgoingPort, Port, PortIndex, - Visibility, Wire, + CircuitUnit, Direction, IncomingPort, Node, NodeIndex, OutgoingPort, Port, PortIndex, Wire, }; pub use crate::extension::Extension; pub use crate::hugr::{Hugr, HugrView, SimpleReplacement}; diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index deebe5434f..d27a4a0ad8 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -1,7 +1,6 @@ //! Constant value definitions. mod custom; -mod serialize; use std::borrow::Cow; use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. @@ -12,7 +11,6 @@ use super::{OpTag, OpType}; use crate::envelope::serde_with::AsStringEnvelope; use crate::types::{CustomType, EdgeKind, Signature, SumType, SumTypeError, Type, TypeRow}; use crate::{Hugr, HugrView}; -use serialize::SerialSum; use delegate::delegate; use itertools::Itertools; @@ -109,6 +107,16 @@ impl AsRef for Const { } } +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +struct SerialSum { + #[serde(default)] + tag: usize, + #[serde(rename = "vs")] + values: Vec, + #[serde(default, rename = "typ")] + sum_type: Option, +} + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[serde(try_from = "SerialSum")] #[serde(into = "SerialSum")] @@ -152,6 +160,43 @@ pub(crate) fn maybe_hash_values(vals: &[Value], st: &mut H) -> bool { } } +impl TryFrom for Sum { + type Error = &'static str; + + fn try_from(value: SerialSum) -> Result { + let SerialSum { + tag, + values, + sum_type, + } = value; + + let sum_type = if let Some(sum_type) = sum_type { + sum_type + } else { + if tag != 0 { + return Err("Sum type must be provided if tag is not 0"); + } + SumType::new_tuple(values.iter().map(Value::get_type).collect_vec()) + }; + + Ok(Self { + tag, + values, + sum_type, + }) + } +} + +impl From for SerialSum { + fn from(value: Sum) -> Self { + Self { + tag: value.tag, + values: value.values, + sum_type: Some(value.sum_type), + } + } +} + #[serde_as] #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[serde(tag = "v")] @@ -282,9 +327,9 @@ pub enum CustomCheckFailure { #[error("Expected type: {expected} but value was of type: {found}")] TypeMismatch { /// The expected custom type. - expected: Box, + expected: CustomType, /// The custom type found when checking. 
- found: Box, + found: Type, }, /// Any other message #[error("{0}")] @@ -304,11 +349,11 @@ pub enum ConstTypeError { )] NotMonomorphicFunction { /// The root node type of the Hugr that (claims to) define the function constant. - hugr_root_type: Box, + hugr_root_type: OpType, }, /// A mismatch between the type expected and the value. #[error("Value {1:?} does not match expected type {0}")] - ConstCheckFail(Box, Value), + ConstCheckFail(Type, Value), /// Error when checking a custom value. #[error("Error when checking custom type: {0}")] CustomCheckFail(#[from] CustomCheckFailure), @@ -317,7 +362,7 @@ pub enum ConstTypeError { /// Hugrs (even functions) inside Consts must be monomorphic fn mono_fn_type(h: &Hugr) -> Result, ConstTypeError> { let err = || ConstTypeError::NotMonomorphicFunction { - hugr_root_type: Box::new(h.entrypoint_optype().clone()), + hugr_root_type: h.entrypoint_optype().clone(), }; if let Some(pf) = h.poly_func_type() { match pf.try_into() { @@ -683,7 +728,7 @@ pub(crate) mod test { index: 1, expected, found, - })) if *expected == float64_type() && *found == const_usize() + })) if expected == float64_type() && found == const_usize() ); } @@ -815,7 +860,7 @@ pub(crate) mod test { let ex_id: ExtensionId = "my_extension".try_into().unwrap(); let typ_int = CustomType::new( "my_type", - vec![TypeArg::BoundedNat(8)], + vec![TypeArg::BoundedNat { n: 8 }], ex_id.clone(), TypeBound::Copyable, // Dummy extension reference. diff --git a/hugr-core/src/ops/constant/serialize.rs b/hugr-core/src/ops/constant/serialize.rs deleted file mode 100644 index 1ccfe523b3..0000000000 --- a/hugr-core/src/ops/constant/serialize.rs +++ /dev/null @@ -1,59 +0,0 @@ -//! Helper definitions used to serialize constant values and ops. - -use itertools::Itertools; - -use crate::ops::Value; -use crate::types::SumType; -use crate::types::serialize::SerSimpleType; - -use super::Sum; - -/// Helper struct to serialize constant [`Sum`] values with a custom layout. -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub(super) struct SerialSum { - #[serde(default)] - tag: usize, - #[serde(rename = "vs")] - values: Vec, - /// Uses the `SerSimpleType` wrapper here instead of a direct `SumType`, - /// to ensure it gets correctly tagged with the `t` discriminant field. 
- #[serde(default, rename = "typ")] - sum_type: Option, -} - -impl From for SerialSum { - fn from(value: Sum) -> Self { - Self { - tag: value.tag, - values: value.values, - sum_type: Some(SerSimpleType::Sum(value.sum_type)), - } - } -} - -impl TryFrom for Sum { - type Error = &'static str; - - fn try_from(value: SerialSum) -> Result { - let SerialSum { - tag, - values, - sum_type, - } = value; - - let sum_type = if let Some(SerSimpleType::Sum(sum_type)) = sum_type { - sum_type - } else { - if tag != 0 { - return Err("Sum type must be provided if tag is not 0"); - } - SumType::new_tuple(values.iter().map(Value::get_type).collect_vec()) - }; - - Ok(Self { - tag, - values, - sum_type, - }) - } -} diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index 874a2b6ce9..ca358c624b 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -359,7 +359,7 @@ mod test { #[test] fn test_subst_dataflow_block() { use crate::ops::OpTrait; - let tv0 = Type::new_var_use(0, TypeBound::Linear); + let tv0 = Type::new_var_use(0, TypeBound::Any); let dfb = DataflowBlock { inputs: vec![usize_t(), tv0.clone()].into(), other_outputs: vec![tv0.clone()].into(), @@ -375,18 +375,16 @@ mod test { #[test] fn test_subst_conditional() { - let tv1 = Type::new_var_use(1, TypeBound::Linear); + let tv1 = Type::new_var_use(1, TypeBound::Any); let cond = Conditional { sum_rows: vec![usize_t().into(), tv1.clone().into()], - other_inputs: vec![Type::new_tuple(TypeRV::new_row_var_use( - 0, - TypeBound::Linear, - ))] - .into(), + other_inputs: vec![Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Any))].into(), outputs: vec![usize_t(), tv1].into(), }; let cond2 = cond.substitute(&Substitution::new(&[ - TypeArg::new_list([usize_t().into(), usize_t().into(), usize_t().into()]), + TypeArg::Sequence { + elems: vec![usize_t().into(); 3], + }, qb_t().into(), ])); let st = Type::new_sum(vec![usize_t(), qb_t()]); //both single-element variants diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 139c87505b..f639584789 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -7,8 +7,7 @@ use thiserror::Error; #[cfg(test)] use { crate::extension::test::SimpleOpDef, crate::proptest::any_nonempty_smolstr, - crate::types::proptest_utils::any_serde_type_arg_vec, ::proptest::prelude::*, - ::proptest_derive::Arbitrary, + ::proptest::prelude::*, ::proptest_derive::Arbitrary, }; use crate::core::HugrNode; @@ -36,7 +35,6 @@ pub struct ExtensionOp { proptest(strategy = "any::().prop_map(|x| Arc::new(x.into()))") )] def: Arc, - #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))] args: Vec, signature: Signature, // Cache } @@ -237,7 +235,6 @@ pub struct OpaqueOp { extension: ExtensionId, #[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))] name: OpName, - #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))] args: Vec, // note that the `signature` field might not include `extension`. 
Thus this must // remain private, and should be accessed through @@ -356,8 +353,8 @@ pub enum OpaqueOpError { node: N, extension: ExtensionId, op: OpName, - stored: Box, - computed: Box, + stored: Signature, + computed: Signature, }, /// An error in computing the signature of the `ExtensionOp` #[error("Error in signature of operation '{name}' in {node}: {cause}")] @@ -409,11 +406,11 @@ mod test { let op = OpaqueOp::new( "res".try_into().unwrap(), "op", - vec![usize_t().into()], + vec![TypeArg::Type { ty: usize_t() }], sig.clone(), ); assert_eq!(op.name(), "OpaqueOp:res.op"); - assert_eq!(op.args(), &[usize_t().into()]); + assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]); assert_eq!(op.signature().as_ref(), &sig); } diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index 2a09fef5c8..66aa4144b6 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -10,7 +10,7 @@ use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeAr use crate::{IncomingPort, type_row}; #[cfg(test)] -use {crate::types::proptest_utils::any_serde_type_arg_vec, proptest_derive::Arbitrary}; +use proptest_derive::Arbitrary; /// Trait implemented by all dataflow operations. pub trait DataflowOpTrait: Sized { @@ -191,7 +191,6 @@ pub struct Call { /// Signature of function being called. pub func_sig: PolyFuncType, /// The type arguments that instantiate `func_sig`. - #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))] pub type_args: Vec, /// The instantiation of `func_sig`. pub instantiation: Signature, // Cache, so we can fail in try_new() not in signature() @@ -285,8 +284,8 @@ impl Call { Ok(()) } else { Err(SignatureError::CallIncorrectlyAppliesType { - cached: Box::new(self.instantiation.clone()), - expected: Box::new(other.instantiation.clone()), + cached: self.instantiation.clone(), + expected: other.instantiation.clone(), }) } } @@ -392,7 +391,6 @@ pub struct LoadFunction { /// Signature of the function pub func_sig: PolyFuncType, /// The type arguments that instantiate `func_sig`. - #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))] pub type_args: Vec, /// The instantiation of `func_sig`. pub instantiation: Signature, // Cache, so we can fail in try_new() not in signature() @@ -476,8 +474,8 @@ impl LoadFunction { Ok(()) } else { Err(SignatureError::LoadFunctionIncorrectlyAppliesType { - cached: Box::new(self.instantiation.clone()), - expected: Box::new(other.instantiation.clone()), + cached: self.instantiation.clone(), + expected: other.instantiation.clone(), }) } } diff --git a/hugr-core/src/ops/module.rs b/hugr-core/src/ops/module.rs index eda121f235..db2b81f9f3 100644 --- a/hugr-core/src/ops/module.rs +++ b/hugr-core/src/ops/module.rs @@ -9,11 +9,12 @@ use { ::proptest_derive::Arbitrary, }; -use crate::Visibility; -use crate::types::{EdgeKind, PolyFuncType, Signature, Type, TypeBound}; +use crate::types::{EdgeKind, PolyFuncType, Signature}; +use crate::types::{Type, TypeBound}; +use super::StaticTag; use super::dataflow::DataflowParent; -use super::{OpTag, OpTrait, StaticTag, impl_op_name}; +use super::{OpTag, OpTrait, impl_op_name}; /// The root of a module, parent of all other `OpType`s. 
#[derive(Debug, Clone, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)] @@ -56,31 +57,14 @@ pub struct FuncDefn { #[cfg_attr(test, proptest(strategy = "any_nonempty_string()"))] name: String, signature: PolyFuncType, - #[serde(default = "priv_vis")] // sadly serde does not pick this up from the schema - visibility: Visibility, -} - -fn priv_vis() -> Visibility { - Visibility::Private } impl FuncDefn { - /// Create a new, [Visibility::Private], instance with the given name and signature. - /// See also [Self::new_vis]. + /// Create a new instance with the given name and signature pub fn new(name: impl Into, signature: impl Into) -> Self { - Self::new_vis(name, signature, Visibility::Private) - } - - /// Create a new instance with the specified name and visibility - pub fn new_vis( - name: impl Into, - signature: impl Into, - visibility: Visibility, - ) -> Self { Self { name: name.into(), signature: signature.into(), - visibility, } } @@ -103,16 +87,6 @@ impl FuncDefn { pub fn signature_mut(&mut self) -> &mut PolyFuncType { &mut self.signature } - - /// The visibility of the function, e.g. for linking - pub fn visibility(&self) -> &Visibility { - &self.visibility - } - - /// Allows changing [Self::visibility] - pub fn visibility_mut(&mut self) -> &mut Visibility { - &mut self.visibility - } } impl_op_name!(FuncDefn); @@ -149,32 +123,14 @@ pub struct FuncDecl { #[cfg_attr(test, proptest(strategy = "any_nonempty_string()"))] name: String, signature: PolyFuncType, - // (again) sadly serde does not pick this up from the schema - #[serde(default = "pub_vis")] // Note opposite of FuncDefn - visibility: Visibility, -} - -fn pub_vis() -> Visibility { - Visibility::Public } impl FuncDecl { - /// Create a new [Visibility::Public] instance with the given name and signature. - /// See also [Self::new_vis] + /// Create a new instance with the given name and signature pub fn new(name: impl Into, signature: impl Into) -> Self { - Self::new_vis(name, signature, Visibility::Public) - } - - /// Create a new instance with the given name, signature and visibility - pub fn new_vis( - name: impl Into, - signature: impl Into, - visibility: Visibility, - ) -> Self { Self { name: name.into(), signature: signature.into(), - visibility, } } @@ -183,21 +139,11 @@ impl FuncDecl { &self.name } - /// The visibility of the function, e.g. for linking - pub fn visibility(&self) -> &Visibility { - &self.visibility - } - /// Allows mutating the name of the function (as per [Self::func_name]) pub fn func_name_mut(&mut self) -> &mut String { &mut self.name } - /// Allows mutating the [Self::visibility] of the function - pub fn visibility_mut(&mut self) -> &mut Visibility { - &mut self.visibility - } - /// Gets the signature of the function pub fn signature(&self) -> &PolyFuncType { &self.signature diff --git a/hugr-core/src/ops/tag.rs b/hugr-core/src/ops/tag.rs index 2834cd94eb..bed7e47370 100644 --- a/hugr-core/src/ops/tag.rs +++ b/hugr-core/src/ops/tag.rs @@ -57,8 +57,6 @@ pub enum OpTag { /// A function load operation. LoadFunc, /// A definition that could be at module level or inside a DSG. - /// Note that this means only Constants, as all other defn/decls - /// must be at Module level. ScopedDefn, /// A tail-recursive loop. 
TailLoop, @@ -114,8 +112,8 @@ impl OpTag { OpTag::Input => &[OpTag::DataflowChild], OpTag::Output => &[OpTag::DataflowChild], OpTag::Function => &[OpTag::ModuleOp, OpTag::StaticOutput], - OpTag::Alias => &[OpTag::ModuleOp], - OpTag::FuncDefn => &[OpTag::Function, OpTag::DataflowParent], + OpTag::Alias => &[OpTag::ScopedDefn], + OpTag::FuncDefn => &[OpTag::Function, OpTag::ScopedDefn, OpTag::DataflowParent], OpTag::DataflowBlock => &[OpTag::ControlFlowChild, OpTag::DataflowParent], OpTag::BasicBlockExit => &[OpTag::ControlFlowChild], OpTag::Case => &[OpTag::Any, OpTag::DataflowParent], diff --git a/hugr-core/src/ops/validate.rs b/hugr-core/src/ops/validate.rs index 9bb4ebbe89..f19ae5e079 100644 --- a/hugr-core/src/ops/validate.rs +++ b/hugr-core/src/ops/validate.rs @@ -103,7 +103,7 @@ impl ValidateOp for super::Conditional { if sig.input != self.case_input_row(i).unwrap() || sig.output != self.outputs { return Err(ChildrenValidationError::ConditionalCaseSignature { child, - optype: Box::new(optype.clone()), + optype: optype.clone(), }); } } @@ -177,7 +177,7 @@ pub enum ChildrenValidationError { #[error("A {optype} operation is only allowed as a {expected_position} child")] InternalIOChildren { child: N, - optype: Box, + optype: OpType, expected_position: &'static str, }, /// The signature of the contained dataflow graph does not match the one of the container. @@ -193,7 +193,7 @@ pub enum ChildrenValidationError { }, /// The signature of a child case in a conditional operation does not match the container's signature. #[error("A conditional case has optype {sig}, which differs from the signature of Conditional container", sig=optype.dataflow_signature().unwrap_or_default())] - ConditionalCaseSignature { child: N, optype: Box }, + ConditionalCaseSignature { child: N, optype: OpType }, /// The conditional container's branching value does not match the number of children. #[error("The conditional container's branch Sum input should be a sum with {expected_count} elements, but it had {} elements. Sum rows: {actual_sum_rows:?}", actual_sum_rows.len())] @@ -227,9 +227,9 @@ pub enum EdgeValidationError { source_ty = source_types.clone().unwrap_or_default(), )] CFGEdgeSignatureMismatch { - edge: Box>, - source_types: Option>, - target_types: Box, + edge: ChildrenEdgeData, + source_types: Option, + target_types: TypeRow, }, } @@ -323,14 +323,14 @@ fn validate_io_nodes<'a, N: HugrNode>( OpTag::Input => { return Err(ChildrenValidationError::InternalIOChildren { child, - optype: Box::new(optype.clone()), + optype: optype.clone(), expected_position: "first", }); } OpTag::Output => { return Err(ChildrenValidationError::InternalIOChildren { child, - optype: Box::new(optype.clone()), + optype: optype.clone(), expected_position: "second", }); } @@ -357,9 +357,9 @@ fn validate_cfg_edge(edge: ChildrenEdgeData) -> Result<(), EdgeV if source_types.as_ref() != Some(target_input) { let target_types = target_input.clone(); return Err(EdgeValidationError::CFGEdgeSignatureMismatch { - edge: Box::new(edge), - source_types: source_types.map(Box::new), - target_types: Box::new(target_types), + edge, + source_types, + target_types, }); } diff --git a/hugr-core/src/package.rs b/hugr-core/src/package.rs index e50639ec05..7be90da270 100644 --- a/hugr-core/src/package.rs +++ b/hugr-core/src/package.rs @@ -1,5 +1,6 @@ //! Bundles of hugr modules along with the extension required to load them. 
+use derive_more::{Display, Error, From}; use std::io; use crate::envelope::{EnvelopeConfig, EnvelopeError, read_envelope, write_envelope}; @@ -7,7 +8,6 @@ use crate::extension::ExtensionRegistry; use crate::hugr::{HugrView, ValidationError}; use crate::std_extensions::STD_REG; use crate::{Hugr, Node}; -use thiserror::Error; #[derive(Debug, Default, Clone, PartialEq)] /// Package of module HUGRs. @@ -131,12 +131,11 @@ impl AsRef<[Hugr]> for Package { } /// Error raised while validating a package. -#[derive(Debug, Error)] +#[derive(Debug, Display, From, Error)] #[non_exhaustive] -#[error("Package validation error.")] pub enum PackageValidationError { /// Error raised while validating the package hugrs. - Validation(#[from] ValidationError), + Validation(ValidationError), } #[cfg(test)] diff --git a/hugr-core/src/std_extensions.rs b/hugr-core/src/std_extensions.rs index d57663391d..1d49ea4e1e 100644 --- a/hugr-core/src/std_extensions.rs +++ b/hugr-core/src/std_extensions.rs @@ -21,7 +21,6 @@ pub fn std_reg() -> ExtensionRegistry { arithmetic::float_types::EXTENSION.to_owned(), collections::array::EXTENSION.to_owned(), collections::list::EXTENSION.to_owned(), - collections::borrow_array::EXTENSION.to_owned(), collections::static_array::EXTENSION.to_owned(), collections::value_array::EXTENSION.to_owned(), logic::EXTENSION.to_owned(), diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 71eb8fa91e..5db32d55ed 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -4,14 +4,14 @@ use std::num::NonZeroU64; use std::sync::{Arc, Weak}; use crate::ops::constant::ValueName; -use crate::types::{Term, TypeName}; +use crate::types::TypeName; use crate::{ Extension, extension::ExtensionId, ops::constant::CustomConst, types::{ ConstTypeError, CustomType, Type, TypeBound, - type_param::{TermTypeError, TypeArg, TypeParam}, + type_param::{TypeArg, TypeArgError, TypeParam}, }, }; use lazy_static::lazy_static; @@ -49,7 +49,7 @@ pub fn int_type(width_arg: impl Into) -> Type { lazy_static! { /// Array of valid integer types, indexed by log width of the integer. pub static ref INT_TYPES: [Type; LOG_WIDTH_BOUND as usize] = (0..LOG_WIDTH_BOUND) - .map(|i| int_type(Term::from(u64::from(i)))) + .map(|i| int_type(TypeArg::BoundedNat { n: u64::from(i) })) .collect::>() .try_into() .unwrap(); @@ -69,25 +69,27 @@ pub const LOG_WIDTH_BOUND: u8 = LOG_WIDTH_MAX + 1; /// Type parameter for the log width of the integer. #[allow(clippy::assertions_on_constants)] -pub const LOG_WIDTH_TYPE_PARAM: TypeParam = TypeParam::bounded_nat_type({ +pub const LOG_WIDTH_TYPE_PARAM: TypeParam = TypeParam::bounded_nat({ assert!(LOG_WIDTH_BOUND > 0); NonZeroU64::MIN.saturating_add(LOG_WIDTH_BOUND as u64 - 1) }); /// Get the log width of the specified type argument or error if the argument /// is invalid. 
-pub(super) fn get_log_width(arg: &TypeArg) -> Result { +pub(super) fn get_log_width(arg: &TypeArg) -> Result { match arg { - TypeArg::BoundedNat(n) if is_valid_log_width(*n as u8) => Ok(*n as u8), - _ => Err(TermTypeError::TypeMismatch { - term: Box::new(arg.clone()), - type_: Box::new(LOG_WIDTH_TYPE_PARAM), + TypeArg::BoundedNat { n } if is_valid_log_width(*n as u8) => Ok(*n as u8), + _ => Err(TypeArgError::TypeMismatch { + arg: arg.clone(), + param: LOG_WIDTH_TYPE_PARAM, }), } } const fn type_arg(log_width: u8) -> TypeArg { - TypeArg::BoundedNat(log_width as u64) + TypeArg::BoundedNat { + n: log_width as u64, + } } /// An integer (either signed or unsigned) @@ -237,13 +239,13 @@ mod test { #[test] fn test_int_widths() { - let type_arg_32 = TypeArg::BoundedNat(5); + let type_arg_32 = TypeArg::BoundedNat { n: 5 }; assert_matches!(get_log_width(&type_arg_32), Ok(5)); - let type_arg_128 = TypeArg::BoundedNat(7); + let type_arg_128 = TypeArg::BoundedNat { n: 7 }; assert_matches!( get_log_width(&type_arg_128), - Err(TermTypeError::TypeMismatch { .. }) + Err(TypeArgError::TypeMismatch { .. }) ); } diff --git a/hugr-core/src/std_extensions/arithmetic/mod.rs b/hugr-core/src/std_extensions/arithmetic/mod.rs index fbf3531ee7..dc26ac4b0b 100644 --- a/hugr-core/src/std_extensions/arithmetic/mod.rs +++ b/hugr-core/src/std_extensions/arithmetic/mod.rs @@ -20,7 +20,7 @@ mod test { for i in 0..LOG_WIDTH_BOUND { assert_eq!( INT_TYPES[i as usize], - int_type(TypeArg::BoundedNat(u64::from(i))) + int_type(TypeArg::BoundedNat { n: u64::from(i) }) ); } } diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 0c52ad94d6..efd53c805e 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -1,7 +1,6 @@ //! List type and operations. pub mod array; -pub mod borrow_array; pub mod list; pub mod static_array; pub mod value_array; diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index de55a41947..eb31441453 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -96,7 +96,7 @@ lazy_static! { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { extension.add_type( ARRAY_TYPENAME, - vec![ TypeParam::max_nat_type(), TypeBound::Linear.into()], + vec![ TypeParam::max_nat(), TypeBound::Any.into()], "Fixed-length array".into(), // Default array is linear, even if the elements are copyable TypeDefBound::any(), @@ -223,7 +223,7 @@ pub trait ArrayOpBuilder: GenericArrayOpBuilder { self.add_generic_array_unpack::(elem_ty, size, input) } /// Adds an array clone operation to the dataflow graph and return the wires - /// representing the original and cloned array. + /// representing the originala and cloned array. /// /// # Arguments /// diff --git a/hugr-core/src/std_extensions/collections/array/array_clone.rs b/hugr-core/src/std_extensions/collections/array/array_clone.rs index 2575a32c26..2a3de6d6d9 100644 --- a/hugr-core/src/std_extensions/collections/array/array_clone.rs +++ b/hugr-core/src/std_extensions/collections/array/array_clone.rs @@ -51,8 +51,8 @@ impl FromStr for GenericArrayCloneDef { impl GenericArrayCloneDef { /// To avoid recursion when defining the extension, take the type definition as an argument. 
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![TypeParam::max_nat_type(), TypeBound::Copyable.into()]; - let size = TypeArg::new_var_use(0, TypeParam::max_nat_type()); + let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; + let size = TypeArg::new_var_use(0, TypeParam::max_nat()); let element_ty = Type::new_var_use(1, TypeBound::Copyable); let array_ty = AK::instantiate_ty(array_def, size, element_ty) .expect("Array type instantiation failed"); @@ -157,7 +157,10 @@ impl MakeExtensionOp for GenericArrayClone { } fn type_args(&self) -> Vec { - vec![self.size.into(), self.elem_ty.clone().into()] + vec![ + TypeArg::BoundedNat { n: self.size }, + self.elem_ty.clone().into(), + ] } } @@ -180,7 +183,7 @@ impl HasConcrete for GenericArrayCloneDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if ty.copyable() => { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] if ty.copyable() => { Ok(GenericArrayClone::new(ty.clone(), *n).unwrap()) } _ => Err(SignatureError::InvalidTypeArgs.into()), @@ -194,7 +197,6 @@ mod tests { use crate::extension::prelude::bool_t; use crate::std_extensions::collections::array::Array; - use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::{ extension::prelude::qb_t, ops::{OpTrait, OpType}, @@ -204,7 +206,6 @@ mod tests { #[rstest] #[case(Array)] - #[case(BorrowArray)] fn test_clone_def(#[case] _kind: AK) { let op = GenericArrayClone::::new(bool_t(), 2).unwrap(); let optype: OpType = op.clone().into(); @@ -219,7 +220,6 @@ mod tests { #[rstest] #[case(Array)] - #[case(BorrowArray)] fn test_clone(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); diff --git a/hugr-core/src/std_extensions/collections/array/array_conversion.rs b/hugr-core/src/std_extensions/collections/array/array_conversion.rs index 015b968002..21544dfd15 100644 --- a/hugr-core/src/std_extensions/collections/array/array_conversion.rs +++ b/hugr-core/src/std_extensions/collections/array/array_conversion.rs @@ -76,9 +76,9 @@ impl { /// To avoid recursion when defining the extension, take the type definition as an argument. 
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![TypeParam::max_nat_type(), TypeBound::Linear.into()]; - let size = TypeArg::new_var_use(0, TypeParam::max_nat_type()); - let element_ty = Type::new_var_use(1, TypeBound::Linear); + let params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; + let size = TypeArg::new_var_use(0, TypeParam::max_nat()); + let element_ty = Type::new_var_use(1, TypeBound::Any); let this_ty = AK::instantiate_ty(array_def, size.clone(), element_ty.clone()) .expect("Array type instantiation failed"); @@ -202,7 +202,10 @@ impl MakeExtensionOp } fn type_args(&self) -> Vec { - vec![TypeArg::BoundedNat(self.size), self.elem_ty.clone().into()] + vec![ + TypeArg::BoundedNat { n: self.size }, + self.elem_ty.clone().into(), + ] } } @@ -231,7 +234,7 @@ impl HasConcrete fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] => { Ok(GenericArrayConvert::new(ty.clone(), *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), @@ -246,14 +249,12 @@ mod tests { use crate::extension::prelude::bool_t; use crate::ops::{OpTrait, OpType}; use crate::std_extensions::collections::array::Array; - use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::std_extensions::collections::value_array::ValueArray; use super::*; #[rstest] #[case(ValueArray, Array)] - #[case(BorrowArray, Array)] fn test_convert_from_def( #[case] _kind: AK, #[case] _other_kind: OtherAK, @@ -266,7 +267,6 @@ mod tests { #[rstest] #[case(ValueArray, Array)] - #[case(BorrowArray, Array)] fn test_convert_into_def( #[case] _kind: AK, #[case] _other_kind: OtherAK, @@ -279,7 +279,6 @@ mod tests { #[rstest] #[case(ValueArray, Array)] - #[case(BorrowArray, Array)] fn test_convert_from( #[case] _kind: AK, #[case] _other_kind: OtherAK, @@ -300,7 +299,6 @@ mod tests { #[rstest] #[case(ValueArray, Array)] - #[case(BorrowArray, Array)] fn test_convert_into( #[case] _kind: AK, #[case] _other_kind: OtherAK, diff --git a/hugr-core/src/std_extensions/collections/array/array_discard.rs b/hugr-core/src/std_extensions/collections/array/array_discard.rs index 7e7a6599e0..67e2281f72 100644 --- a/hugr-core/src/std_extensions/collections/array/array_discard.rs +++ b/hugr-core/src/std_extensions/collections/array/array_discard.rs @@ -51,8 +51,8 @@ impl FromStr for GenericArrayDiscardDef { impl GenericArrayDiscardDef { /// To avoid recursion when defining the extension, take the type definition as an argument. 
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![TypeParam::max_nat_type(), TypeBound::Copyable.into()]; - let size = TypeArg::new_var_use(0, TypeParam::max_nat_type()); + let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; + let size = TypeArg::new_var_use(0, TypeParam::max_nat()); let element_ty = Type::new_var_use(1, TypeBound::Copyable); let array_ty = AK::instantiate_ty(array_def, size, element_ty) .expect("Array type instantiation failed"); @@ -141,7 +141,10 @@ impl MakeExtensionOp for GenericArrayDiscard { } fn type_args(&self) -> Vec { - vec![self.size.into(), self.elem_ty.clone().into()] + vec![ + TypeArg::BoundedNat { n: self.size }, + self.elem_ty.clone().into(), + ] } } @@ -164,7 +167,7 @@ impl HasConcrete for GenericArrayDiscardDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if ty.copyable() => { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] if ty.copyable() => { Ok(GenericArrayDiscard::new(ty.clone(), *n).unwrap()) } _ => Err(SignatureError::InvalidTypeArgs.into()), @@ -178,7 +181,6 @@ mod tests { use crate::extension::prelude::bool_t; use crate::std_extensions::collections::array::Array; - use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::{ extension::prelude::qb_t, ops::{OpTrait, OpType}, @@ -188,7 +190,6 @@ mod tests { #[rstest] #[case(Array)] - #[case(BorrowArray)] fn test_discard_def(#[case] _kind: AK) { let op = GenericArrayDiscard::::new(bool_t(), 2).unwrap(); let optype: OpType = op.clone().into(); @@ -200,7 +201,6 @@ mod tests { #[rstest] #[case(Array)] - #[case(BorrowArray)] fn test_discard(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); diff --git a/hugr-core/src/std_extensions/collections/array/array_op.rs b/hugr-core/src/std_extensions/collections/array/array_op.rs index dc7cf3d940..915603c1da 100644 --- a/hugr-core/src/std_extensions/collections/array/array_op.rs +++ b/hugr-core/src/std_extensions/collections/array/array_op.rs @@ -16,7 +16,7 @@ use crate::extension::{ use crate::ops::{ExtensionOp, OpName}; use crate::type_row; use crate::types::type_param::{TypeArg, TypeParam}; -use crate::types::{FuncValueType, PolyFuncTypeRV, Term, Type, TypeBound}; +use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound}; use crate::utils::Never; use super::array_kind::ArrayKind; @@ -65,16 +65,16 @@ pub enum GenericArrayOpDef { } /// Static parameters for array operations. Includes array size. Type is part of the type scheme. -const STATIC_SIZE_PARAM: &[TypeParam; 1] = &[TypeParam::max_nat_type()]; +const STATIC_SIZE_PARAM: &[TypeParam; 1] = &[TypeParam::max_nat()]; impl SignatureFromArgs for GenericArrayOpDef { fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { - let [TypeArg::BoundedNat(n)] = *arg_values else { + let [TypeArg::BoundedNat { n }] = *arg_values else { return Err(SignatureError::InvalidTypeArgs); }; - let elem_ty_var = Type::new_var_use(0, TypeBound::Linear); + let elem_ty_var = Type::new_var_use(0, TypeBound::Any); let array_ty = AK::ty(n, elem_ty_var.clone()); - let params = vec![TypeBound::Linear.into()]; + let params = vec![TypeBound::Any.into()]; let poly_func_ty = match self { GenericArrayOpDef::new_array => PolyFuncTypeRV::new( params, @@ -139,11 +139,11 @@ impl GenericArrayOpDef { // signature computed dynamically, so can rely on type definition in extension. 
(*self).into() } else { - let size_var = TypeArg::new_var_use(0, TypeParam::max_nat_type()); - let elem_ty_var = Type::new_var_use(1, TypeBound::Linear); + let size_var = TypeArg::new_var_use(0, TypeParam::max_nat()); + let elem_ty_var = Type::new_var_use(1, TypeBound::Any); let array_ty = AK::instantiate_ty(array_def, size_var.clone(), elem_ty_var.clone()) .expect("Array type instantiation failed"); - let standard_params = vec![TypeParam::max_nat_type(), TypeBound::Linear.into()]; + let standard_params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; // We can assume that the prelude has ben loaded at this point, // since it doesn't depend on the array extension. @@ -151,7 +151,7 @@ impl GenericArrayOpDef { match self { get => { - let params = vec![TypeParam::max_nat_type(), TypeBound::Copyable.into()]; + let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; let copy_elem_ty = Type::new_var_use(1, TypeBound::Copyable); let copy_array_ty = AK::instantiate_ty(array_def, size_var, copy_elem_ty.clone()) @@ -184,9 +184,9 @@ impl GenericArrayOpDef { ) } discard_empty => PolyFuncTypeRV::new( - vec![TypeBound::Linear.into()], + vec![TypeBound::Any.into()], FuncValueType::new( - AK::instantiate_ty(array_def, 0, Type::new_var_use(0, TypeBound::Linear)) + AK::instantiate_ty(array_def, 0, Type::new_var_use(0, TypeBound::Any)) .expect("Array type instantiation failed"), type_row![], ), @@ -282,11 +282,13 @@ impl MakeExtensionOp for GenericArrayOp { def.instantiate(ext_op.args()) } - fn type_args(&self) -> Vec { + fn type_args(&self) -> Vec { use GenericArrayOpDef::{ _phantom, discard_empty, get, new_array, pop_left, pop_right, set, swap, unpack, }; - let ty_arg = self.elem_ty.clone().into(); + let ty_arg = TypeArg::Type { + ty: self.elem_ty.clone(), + }; match self.def { discard_empty => { debug_assert_eq!( @@ -296,7 +298,7 @@ impl MakeExtensionOp for GenericArrayOp { vec![ty_arg] } new_array | unpack | pop_left | pop_right | get | set | swap => { - vec![self.size.into(), ty_arg] + vec![TypeArg::BoundedNat { n: self.size }, ty_arg] } _phantom(_, never) => match never {}, } @@ -320,10 +322,10 @@ impl HasDef for GenericArrayOp { impl HasConcrete for GenericArrayOpDef { type Concrete = GenericArrayOp; - fn instantiate(&self, type_args: &[Term]) -> Result { + fn instantiate(&self, type_args: &[TypeArg]) -> Result { let (ty, size) = match (self, type_args) { - (GenericArrayOpDef::discard_empty, [Term::Runtime(ty)]) => (ty.clone(), 0), - (_, [Term::BoundedNat(n), Term::Runtime(ty)]) => (ty.clone(), *n), + (GenericArrayOpDef::discard_empty, [TypeArg::Type { ty }]) => (ty.clone(), 0), + (_, [TypeArg::BoundedNat { n }, TypeArg::Type { ty }]) => (ty.clone(), *n), _ => return Err(SignatureError::InvalidTypeArgs.into()), }; @@ -339,7 +341,6 @@ mod tests { use crate::extension::prelude::usize_t; use crate::std_extensions::arithmetic::float_types::float64_type; use crate::std_extensions::collections::array::Array; - use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::std_extensions::collections::value_array::ValueArray; use crate::{ builder::{DFGBuilder, Dataflow, DataflowHugr, inout_sig}, @@ -352,7 +353,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] fn test_array_ops(#[case] _kind: AK) { for def in GenericArrayOpDef::::iter() { let ty = if def == GenericArrayOpDef::get { @@ -375,7 +375,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] /// Test building a HUGR involving a new_array operation. 
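The reverted `signature_from_def` bodies above all follow the same v0.20.2 recipe: declare the binders (a `max_nat` length and an `Any`-bounded element type), create matching variable uses, and wrap a `FuncValueType` in a `PolyFuncTypeRV`. A condensed sketch, assuming the array type has already been instantiated into `array_ty` (standing in for `AK::instantiate_ty(..)`); the function name and the particular input/output rows are illustrative:

    use crate::types::type_param::TypeParam;
    use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound};

    // A hypothetical op that consumes the array and yields an element plus the array.
    fn example_poly_signature(array_ty: Type) -> PolyFuncTypeRV {
        // Binder 0: the array length; binder 1: the element type.
        let params = vec![TypeParam::max_nat(), TypeBound::Any.into()];
        let elem_ty_var = Type::new_var_use(1, TypeBound::Any);
        PolyFuncTypeRV::new(
            params,
            FuncValueType::new(vec![array_ty.clone()], vec![elem_ty_var, array_ty]),
        )
    }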
fn test_new_array(#[case] _kind: AK) { let mut b = DFGBuilder::new(inout_sig(vec![qb_t(), qb_t()], AK::ty(2, qb_t()))).unwrap(); @@ -392,7 +391,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] /// Test building a HUGR involving an unpack operation. fn test_unpack(#[case] _kind: AK) { let mut b = DFGBuilder::new(inout_sig(AK::ty(2, qb_t()), vec![qb_t(), qb_t()])).unwrap(); @@ -409,7 +407,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] fn test_get(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); @@ -435,7 +432,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] fn test_set(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); @@ -458,7 +454,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] fn test_swap(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); @@ -480,7 +475,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] fn test_pops(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); @@ -513,7 +507,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] fn test_discard_empty(#[case] _kind: AK) { let size = 0; let element_ty = bool_t(); @@ -532,7 +525,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] /// Initialize an array operation where the element type is not from the prelude. fn test_non_prelude_op(#[case] _kind: AK) { let size = 2; diff --git a/hugr-core/src/std_extensions/collections/array/array_repeat.rs b/hugr-core/src/std_extensions/collections/array/array_repeat.rs index e2c77ef21c..d3302d253a 100644 --- a/hugr-core/src/std_extensions/collections/array/array_repeat.rs +++ b/hugr-core/src/std_extensions/collections/array/array_repeat.rs @@ -52,9 +52,9 @@ impl FromStr for GenericArrayRepeatDef { impl GenericArrayRepeatDef { /// To avoid recursion when defining the extension, take the type definition as an argument. 
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![TypeParam::max_nat_type(), TypeBound::Linear.into()]; - let n = TypeArg::new_var_use(0, TypeParam::max_nat_type()); - let t = Type::new_var_use(1, TypeBound::Linear); + let params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; + let n = TypeArg::new_var_use(0, TypeParam::max_nat()); + let t = Type::new_var_use(1, TypeBound::Any); let func = Type::new_function(Signature::new(vec![], vec![t.clone()])); let array_ty = AK::instantiate_ty(array_def, n, t).expect("Array type instantiation failed"); @@ -147,7 +147,10 @@ impl MakeExtensionOp for GenericArrayRepeat { } fn type_args(&self) -> Vec { - vec![self.size.into(), self.elem_ty.clone().into()] + vec![ + TypeArg::BoundedNat { n: self.size }, + self.elem_ty.clone().into(), + ] } } @@ -170,7 +173,7 @@ impl HasConcrete for GenericArrayRepeatDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] => { Ok(GenericArrayRepeat::new(ty.clone(), *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), @@ -183,7 +186,6 @@ mod tests { use rstest::rstest; use crate::std_extensions::collections::array::Array; - use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::std_extensions::collections::value_array::ValueArray; use crate::{ extension::prelude::qb_t, @@ -196,7 +198,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] fn test_repeat_def(#[case] _kind: AK) { let op = GenericArrayRepeat::::new(qb_t(), 2); let optype: OpType = op.clone().into(); @@ -207,7 +208,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] fn test_repeat(#[case] _kind: AK) { let size = 2; let element_ty = qb_t(); diff --git a/hugr-core/src/std_extensions/collections/array/array_scan.rs b/hugr-core/src/std_extensions/collections/array/array_scan.rs index 416777f436..2dc5d2f734 100644 --- a/hugr-core/src/std_extensions/collections/array/array_scan.rs +++ b/hugr-core/src/std_extensions/collections/array/array_scan.rs @@ -56,15 +56,15 @@ impl GenericArrayScanDef { fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { // array, (T1, *A -> T2, *A), *A, -> array, *A let params = vec![ - TypeParam::max_nat_type(), - TypeBound::Linear.into(), - TypeBound::Linear.into(), - TypeParam::new_list_type(TypeBound::Linear), + TypeParam::max_nat(), + TypeBound::Any.into(), + TypeBound::Any.into(), + TypeParam::new_list(TypeBound::Any), ]; - let n = TypeArg::new_var_use(0, TypeParam::max_nat_type()); - let t1 = Type::new_var_use(1, TypeBound::Linear); - let t2 = Type::new_var_use(2, TypeBound::Linear); - let s = TypeRV::new_row_var_use(3, TypeBound::Linear); + let n = TypeArg::new_var_use(0, TypeParam::max_nat()); + let t1 = Type::new_var_use(1, TypeBound::Any); + let t2 = Type::new_var_use(2, TypeBound::Any); + let s = TypeRV::new_row_var_use(3, TypeBound::Any); PolyFuncTypeRV::new( params, FuncTypeBase::::new( @@ -185,10 +185,12 @@ impl MakeExtensionOp for GenericArrayScan { fn type_args(&self) -> Vec { vec![ - self.size.into(), + TypeArg::BoundedNat { n: self.size }, self.src_ty.clone().into(), self.tgt_ty.clone().into(), - TypeArg::new_list(self.acc_tys.clone().into_iter().map_into()), + TypeArg::Sequence { + elems: self.acc_tys.clone().into_iter().map_into().collect(), + }, ] } } @@ -213,15 +215,15 @@ impl HasConcrete for GenericArrayScanDef { fn instantiate(&self, 
type_args: &[TypeArg]) -> Result { match type_args { [ - TypeArg::BoundedNat(n), - TypeArg::Runtime(src_ty), - TypeArg::Runtime(tgt_ty), - TypeArg::List(acc_tys), + TypeArg::BoundedNat { n }, + TypeArg::Type { ty: src_ty }, + TypeArg::Type { ty: tgt_ty }, + TypeArg::Sequence { elems: acc_tys }, ] => { let acc_tys: Result<_, OpLoadError> = acc_tys .iter() .map(|acc_ty| match acc_ty { - TypeArg::Runtime(ty) => Ok(ty.clone()), + TypeArg::Type { ty } => Ok(ty.clone()), _ => Err(SignatureError::InvalidTypeArgs.into()), }) .collect(); @@ -243,7 +245,6 @@ mod tests { use crate::extension::prelude::usize_t; use crate::std_extensions::collections::array::Array; - use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::std_extensions::collections::value_array::ValueArray; use crate::{ extension::prelude::{bool_t, qb_t}, @@ -256,7 +257,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] fn test_scan_def(#[case] _kind: AK) { let op = GenericArrayScan::::new(bool_t(), qb_t(), vec![usize_t()], 2); let optype: OpType = op.clone().into(); @@ -267,7 +267,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] fn test_scan_map(#[case] _kind: AK) { let size = 2; let src_ty = qb_t(); @@ -293,7 +292,6 @@ mod tests { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] fn test_scan_accs(#[case] _kind: AK) { let size = 2; let src_ty = qb_t(); diff --git a/hugr-core/src/std_extensions/collections/array/array_value.rs b/hugr-core/src/std_extensions/collections/array/array_value.rs index 916f218739..8828acd982 100644 --- a/hugr-core/src/std_extensions/collections/array/array_value.rs +++ b/hugr-core/src/std_extensions/collections/array/array_value.rs @@ -94,7 +94,9 @@ impl GenericArrayValue { // constant can only hold classic type. let ty = match typ.args() { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if *n as usize == self.values.len() => { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] + if *n as usize == self.values.len() => + { ty } _ => { @@ -146,7 +148,6 @@ mod test { use crate::std_extensions::arithmetic::float_types::ConstF64; use crate::std_extensions::collections::array::Array; - use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::std_extensions::collections::value_array::ValueArray; use super::*; @@ -154,7 +155,6 @@ mod test { #[rstest] #[case(Array)] #[case(ValueArray)] - #[case(BorrowArray)] fn test_array_value(#[case] _kind: AK) { let array_value = GenericArrayValue::::new(usize_t(), vec![ConstUsize::new(3).into()]); array_value.validate().unwrap(); diff --git a/hugr-core/src/std_extensions/collections/array/op_builder.rs b/hugr-core/src/std_extensions/collections/array/op_builder.rs index b408e1a3de..2740673f80 100644 --- a/hugr-core/src/std_extensions/collections/array/op_builder.rs +++ b/hugr-core/src/std_extensions/collections/array/op_builder.rs @@ -1,7 +1,6 @@ //! Builder trait for array operations in the dataflow graph. use crate::std_extensions::collections::array::GenericArrayOpDef; -use crate::std_extensions::collections::borrow_array::BorrowArray; use crate::std_extensions::collections::value_array::ValueArray; use crate::{ Wire, @@ -391,11 +390,6 @@ pub fn build_all_value_array_ops(builder: B) -> B { build_all_array_ops_generic::(builder) } -/// Helper function to build a Hugr that contains all basic array operations. 
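On v0.20.2 a list-shaped type argument, as used above for the scan accumulator row, is spelled `TypeArg::Sequence { elems }` with each element itself a `TypeArg`. A small sketch of packing a row of types into such an argument and unpacking it again, mirroring the reverted `type_args`/`instantiate` pair of `GenericArrayScan`; the two function names are illustrative:

    use crate::types::{Type, TypeArg};

    // Pack a row of accumulator types into one sequence argument.
    fn example_pack_row(acc_tys: &[Type]) -> TypeArg {
        TypeArg::Sequence {
            elems: acc_tys
                .iter()
                .map(|ty| TypeArg::Type { ty: ty.clone() })
                .collect(),
        }
    }

    // Unpack it again, rejecting anything that is not a runtime type.
    fn example_unpack_row(arg: &TypeArg) -> Option<Vec<Type>> {
        let TypeArg::Sequence { elems } = arg else {
            return None;
        };
        elems
            .iter()
            .map(|e| match e {
                TypeArg::Type { ty } => Some(ty.clone()),
                _ => None,
            })
            .collect()
    }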
-pub fn build_all_borrow_array_ops(builder: B) -> B { - build_all_array_ops_generic::(builder) -} - /// Testing utilities to generate Hugrs that contain array operations. #[cfg(test)] mod test { @@ -417,11 +411,4 @@ mod test { let builder = DFGBuilder::new(sig).unwrap(); build_all_value_array_ops(builder).finish_hugr().unwrap(); } - - #[test] - fn all_borrow_array_ops() { - let sig = Signature::new_endo(Type::EMPTY_TYPEROW); - let builder = DFGBuilder::new(sig).unwrap(); - build_all_borrow_array_ops(builder).finish_hugr().unwrap(); - } } diff --git a/hugr-core/src/std_extensions/collections/borrow_array.rs b/hugr-core/src/std_extensions/collections/borrow_array.rs deleted file mode 100644 index 52982d0833..0000000000 --- a/hugr-core/src/std_extensions/collections/borrow_array.rs +++ /dev/null @@ -1,797 +0,0 @@ -//! A version of the standard fixed-length array extension that includes unsafe -//! operations for borrowing and returning that may panic. - -use std::sync::{self, Arc}; - -use delegate::delegate; -use lazy_static::lazy_static; - -use crate::extension::{ExtensionId, SignatureError, TypeDef, TypeDefBound}; -use crate::ops::constant::{CustomConst, ValueName}; -use crate::type_row; -use crate::types::type_param::{TypeArg, TypeParam}; -use crate::types::{CustomCheckFailure, Term, Type, TypeBound, TypeName}; -use crate::{Extension, Wire}; -use crate::{ - builder::{BuildError, Dataflow}, - extension::SignatureFunc, -}; -use crate::{ - extension::simple_op::{HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp}, - ops::ExtensionOp, -}; -use crate::{ - extension::{ - OpDef, - prelude::usize_t, - resolution::{ExtensionResolutionError, WeakExtensionRegistry}, - simple_op::{OpLoadError, try_from_name}, - }, - ops::OpName, - types::{FuncValueType, PolyFuncTypeRV}, -}; - -use super::array::op_builder::GenericArrayOpBuilder; -use super::array::{ - Array, ArrayKind, FROM, GenericArrayClone, GenericArrayCloneDef, GenericArrayConvert, - GenericArrayConvertDef, GenericArrayDiscard, GenericArrayDiscardDef, GenericArrayOp, - GenericArrayOpDef, GenericArrayRepeat, GenericArrayRepeatDef, GenericArrayScan, - GenericArrayScanDef, GenericArrayValue, INTO, -}; - -/// Reported unique name of the borrow array type. -pub const BORROW_ARRAY_TYPENAME: TypeName = TypeName::new_inline("borrow_array"); -/// Reported unique name of the borrow array value. -pub const BORROW_ARRAY_VALUENAME: TypeName = TypeName::new_inline("borrow_array"); -/// Reported unique name of the extension -pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.borrow_arr"); -/// Extension version. -pub const VERSION: semver::Version = semver::Version::new(0, 1, 1); - -/// A linear, unsafe, fixed-length collection of values. -/// -/// Borrow arrays are linear, even if their elements are copyable. -#[derive(Clone, Copy, Debug, derive_more::Display, Eq, PartialEq, Default)] -pub struct BorrowArray; - -impl ArrayKind for BorrowArray { - const EXTENSION_ID: ExtensionId = EXTENSION_ID; - const TYPE_NAME: TypeName = BORROW_ARRAY_TYPENAME; - const VALUE_NAME: ValueName = BORROW_ARRAY_VALUENAME; - - fn extension() -> &'static Arc { - &EXTENSION - } - - fn type_def() -> &'static TypeDef { - EXTENSION.get_type(&BORROW_ARRAY_TYPENAME).unwrap() - } -} - -/// Borrow array operation definitions. -pub type BArrayOpDef = GenericArrayOpDef; -/// Borrow array clone operation definition. -pub type BArrayCloneDef = GenericArrayCloneDef; -/// Borrow array discard operation definition. 
-pub type BArrayDiscardDef = GenericArrayDiscardDef; -/// Borrow array repeat operation definition. -pub type BArrayRepeatDef = GenericArrayRepeatDef; -/// Borrow array scan operation definition. -pub type BArrayScanDef = GenericArrayScanDef; -/// Borrow array to default array conversion operation definition. -pub type BArrayToArrayDef = GenericArrayConvertDef; -/// Borrow array from default array conversion operation definition. -pub type BArrayFromArrayDef = GenericArrayConvertDef; - -/// Borrow array operations. -pub type BArrayOp = GenericArrayOp; -/// The borrow array clone operation. -pub type BArrayClone = GenericArrayClone; -/// The borrow array discard operation. -pub type BArrayDiscard = GenericArrayDiscard; -/// The borrow array repeat operation. -pub type BArrayRepeat = GenericArrayRepeat; -/// The borrow array scan operation. -pub type BArrayScan = GenericArrayScan; -/// The borrow array to default array conversion operation. -pub type BArrayToArray = GenericArrayConvert; -/// The borrow array from default array conversion operation. -pub type BArrayFromArray = GenericArrayConvert; - -/// A borrow array extension value. -pub type BArrayValue = GenericArrayValue; - -#[derive( - Clone, - Copy, - Debug, - Hash, - PartialEq, - Eq, - strum::EnumIter, - strum::IntoStaticStr, - strum::EnumString, -)] -#[allow(non_camel_case_types, missing_docs)] -#[non_exhaustive] -pub enum BArrayUnsafeOpDef { - /// `borrow: borrow_array, index -> elem_ty, borrow_array` - borrow, - /// `return: borrow_array, index, elem_ty -> borrow_array` - #[strum(serialize = "return")] - r#return, - /// `discard_all_borrowed: borrow_array -> ()` - discard_all_borrowed, - /// `new_all_borrowed: () -> borrow_array` - new_all_borrowed, -} - -impl BArrayUnsafeOpDef { - /// Instantiate a new unsafe borrow array operation with the given element type and array size. 
- #[must_use] - pub fn to_concrete(self, elem_ty: Type, size: u64) -> BArrayUnsafeOp { - BArrayUnsafeOp { - def: self, - elem_ty, - size, - } - } - - fn signature_from_def(&self, def: &TypeDef, _: &sync::Weak) -> SignatureFunc { - let size_var = TypeArg::new_var_use(0, TypeParam::max_nat_type()); - let elem_ty_var = Type::new_var_use(1, TypeBound::Linear); - let array_ty: Type = def - .instantiate(vec![size_var, elem_ty_var.clone().into()]) - .unwrap() - .into(); - - let params = vec![TypeParam::max_nat_type(), TypeBound::Linear.into()]; - - let usize_t: Type = usize_t(); - - match self { - Self::borrow => PolyFuncTypeRV::new( - params, - FuncValueType::new(vec![array_ty.clone(), usize_t], vec![elem_ty_var, array_ty]), - ), - Self::r#return => PolyFuncTypeRV::new( - params, - FuncValueType::new( - vec![array_ty.clone(), usize_t, elem_ty_var.clone()], - vec![array_ty], - ), - ), - Self::discard_all_borrowed => { - PolyFuncTypeRV::new(params, FuncValueType::new(vec![array_ty], type_row![])) - } - Self::new_all_borrowed => { - PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![array_ty])) - } - } - .into() - } -} - -impl MakeOpDef for BArrayUnsafeOpDef { - fn opdef_id(&self) -> OpName { - <&'static str>::from(self).into() - } - - fn from_def(op_def: &OpDef) -> Result - where - Self: Sized, - { - try_from_name(op_def.name(), op_def.extension_id()) - } - - fn init_signature(&self, extension_ref: &sync::Weak) -> SignatureFunc { - self.signature_from_def( - EXTENSION.get_type(&BORROW_ARRAY_TYPENAME).unwrap(), - extension_ref, - ) - } - - fn extension_ref(&self) -> sync::Weak { - Arc::downgrade(&EXTENSION) - } - - fn extension(&self) -> ExtensionId { - EXTENSION_ID.clone() - } - - fn description(&self) -> String { - match self { - Self::borrow => { - "Take an element from a borrow array (panicking if it was already taken before)" - } - Self::r#return => { - "Put an element into a borrow array (panicking if there is an element already)" - } - Self::discard_all_borrowed => { - "Discard a borrow array where all elements have been borrowed" - } - Self::new_all_borrowed => "Create a new borrow array that contains no elements", - } - .into() - } - - // This method is re-defined here to avoid recursive loops initializing the extension. - fn add_to_extension( - &self, - extension: &mut Extension, - extension_ref: &sync::Weak, - ) -> Result<(), crate::extension::ExtensionBuildError> { - let sig = self.signature_from_def( - extension.get_type(&BORROW_ARRAY_TYPENAME).unwrap(), - extension_ref, - ); - let def = extension.add_op(self.opdef_id(), self.description(), sig, extension_ref)?; - - self.post_opdef(def); - - Ok(()) - } -} - -#[derive(Clone, Debug, PartialEq)] -/// Concrete array operation. -pub struct BArrayUnsafeOp { - /// The operation definition. - pub def: BArrayUnsafeOpDef, - /// The element type of the array. - pub elem_ty: Type, - /// The size of the array. 
- pub size: u64, -} - -impl MakeExtensionOp for BArrayUnsafeOp { - fn op_id(&self) -> OpName { - self.def.opdef_id() - } - - fn from_extension_op(ext_op: &ExtensionOp) -> Result - where - Self: Sized, - { - let def = BArrayUnsafeOpDef::from_def(ext_op.def())?; - def.instantiate(ext_op.args()) - } - - fn type_args(&self) -> Vec { - vec![self.size.into(), self.elem_ty.clone().into()] - } -} - -impl HasDef for BArrayUnsafeOp { - type Def = BArrayUnsafeOpDef; -} - -impl HasConcrete for BArrayUnsafeOpDef { - type Concrete = BArrayUnsafeOp; - - fn instantiate(&self, type_args: &[TypeArg]) -> Result { - match type_args { - [Term::BoundedNat(n), Term::Runtime(ty)] => Ok(self.to_concrete(ty.clone(), *n)), - _ => Err(SignatureError::InvalidTypeArgs.into()), - } - } -} - -impl MakeRegisteredOp for BArrayUnsafeOp { - fn extension_id(&self) -> ExtensionId { - EXTENSION_ID.clone() - } - - fn extension_ref(&self) -> sync::Weak { - Arc::downgrade(&EXTENSION) - } -} - -lazy_static! { - /// Extension for borrow array operations. - pub static ref EXTENSION: Arc = { - Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { - extension.add_type( - BORROW_ARRAY_TYPENAME, - vec![ TypeParam::max_nat_type(), TypeBound::Linear.into()], - "Fixed-length borrow array".into(), - // Borrow array is linear, even if the elements are copyable. - TypeDefBound::any(), - extension_ref, - ) - .unwrap(); - - BArrayOpDef::load_all_ops(extension, extension_ref).unwrap(); - BArrayCloneDef::new().add_to_extension(extension, extension_ref).unwrap(); - BArrayDiscardDef::new().add_to_extension(extension, extension_ref).unwrap(); - BArrayRepeatDef::new().add_to_extension(extension, extension_ref).unwrap(); - BArrayScanDef::new().add_to_extension(extension, extension_ref).unwrap(); - BArrayToArrayDef::new().add_to_extension(extension, extension_ref).unwrap(); - BArrayFromArrayDef::new().add_to_extension(extension, extension_ref).unwrap(); - - BArrayUnsafeOpDef::load_all_ops(extension, extension_ref).unwrap(); - }) - }; -} - -#[typetag::serde(name = "BArrayValue")] -impl CustomConst for BArrayValue { - delegate! { - to self { - fn name(&self) -> ValueName; - fn validate(&self) -> Result<(), CustomCheckFailure>; - fn update_extensions( - &mut self, - extensions: &WeakExtensionRegistry, - ) -> Result<(), ExtensionResolutionError>; - fn get_type(&self) -> Type; - } - } - - fn equal_consts(&self, other: &dyn CustomConst) -> bool { - crate::ops::constant::downcast_equal_consts(self, other) - } -} - -/// Gets the [`TypeDef`] for borrow arrays. Note that instantiations are more easily -/// created via [`borrow_array_type`] and [`borrow_array_type_parametric`] -#[must_use] -pub fn borrow_array_type_def() -> &'static TypeDef { - BorrowArray::type_def() -} - -/// Instantiate a new borrow array type given a size argument and element type. -/// -/// This method is equivalent to [`borrow_array_type_parametric`], but uses concrete -/// arguments types to ensure no errors are possible. -#[must_use] -pub fn borrow_array_type(size: u64, element_ty: Type) -> Type { - BorrowArray::ty(size, element_ty) -} - -/// Instantiate a new borrow array type given the size and element type parameters. -/// -/// This is a generic version of [`borrow_array_type`]. -pub fn borrow_array_type_parametric( - size: impl Into, - element_ty: impl Into, -) -> Result { - BorrowArray::ty_parametric(size, element_ty) -} - -/// Trait for building borrow array operations in a dataflow graph. 
-pub trait BArrayOpBuilder: GenericArrayOpBuilder { - /// Adds a new array operation to the dataflow graph and return the wire - /// representing the new array. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `values` - An iterator over the values to initialize the array with. - /// - /// # Errors - /// - /// If building the operation fails. - /// - /// # Returns - /// - /// The wire representing the new array. - fn add_new_borrow_array( - &mut self, - elem_ty: Type, - values: impl IntoIterator, - ) -> Result { - self.add_new_generic_array::(elem_ty, values) - } - /// Adds an array unpack operation to the dataflow graph. - /// - /// This operation unpacks an array into individual elements. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `size` - The size of the array. - /// * `input` - The wire representing the array to unpack. - /// - /// # Errors - /// - /// If building the operation fails. - /// - /// # Returns - /// - /// A vector of wires representing the individual elements from the array. - fn add_borrow_array_unpack( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - ) -> Result, BuildError> { - self.add_generic_array_unpack::(elem_ty, size, input) - } - /// Adds an array clone operation to the dataflow graph and return the wires - /// representing the original and cloned array. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `size` - The size of the array. - /// * `input` - The wire representing the array. - /// - /// # Errors - /// - /// If building the operation fails. - /// - /// # Returns - /// - /// The wires representing the original and cloned array. - fn add_borrow_array_clone( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - ) -> Result<(Wire, Wire), BuildError> { - self.add_generic_array_clone::(elem_ty, size, input) - } - - /// Adds an array discard operation to the dataflow graph. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `size` - The size of the array. - /// * `input` - The wire representing the array. - /// - /// # Errors - /// - /// If building the operation fails. - fn add_borrow_array_discard( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - ) -> Result<(), BuildError> { - self.add_generic_array_discard::(elem_ty, size, input) - } - - /// Adds an array get operation to the dataflow graph. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `size` - The size of the array. - /// * `input` - The wire representing the array. - /// * `index` - The wire representing the index to get. - /// - /// # Errors - /// - /// If building the operation fails. - /// - /// # Returns - /// - /// * The wire representing the value at the specified index in the array - /// * The wire representing the array - fn add_borrow_array_get( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - index: Wire, - ) -> Result<(Wire, Wire), BuildError> { - self.add_generic_array_get::(elem_ty, size, input, index) - } - - /// Adds an array set operation to the dataflow graph. - /// - /// This operation sets the value at a specified index in the array. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `size` - The size of the array. - /// * `input` - The wire representing the array. - /// * `index` - The wire representing the index to set. 
- /// * `value` - The wire representing the value to set at the specified index. - /// - /// # Errors - /// - /// Returns an error if building the operation fails. - /// - /// # Returns - /// - /// The wire representing the updated array after the set operation. - fn add_borrow_array_set( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - index: Wire, - value: Wire, - ) -> Result { - self.add_generic_array_set::(elem_ty, size, input, index, value) - } - - /// Adds an array swap operation to the dataflow graph. - /// - /// This operation swaps the values at two specified indices in the array. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `size` - The size of the array. - /// * `input` - The wire representing the array. - /// * `index1` - The wire representing the first index to swap. - /// * `index2` - The wire representing the second index to swap. - /// - /// # Errors - /// - /// Returns an error if building the operation fails. - /// - /// # Returns - /// - /// The wire representing the updated array after the swap operation. - fn add_borrow_array_swap( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - index1: Wire, - index2: Wire, - ) -> Result { - let op = - GenericArrayOpDef::::swap.instantiate(&[size.into(), elem_ty.into()])?; - let [out] = self - .add_dataflow_op(op, vec![input, index1, index2])? - .outputs_arr(); - Ok(out) - } - - /// Adds an array pop-left operation to the dataflow graph. - /// - /// This operation removes the leftmost element from the array. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `size` - The size of the array. - /// * `input` - The wire representing the array. - /// - /// # Errors - /// - /// Returns an error if building the operation fails. - /// - /// # Returns - /// - /// The wire representing the Option> - fn add_borrow_array_pop_left( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - ) -> Result { - self.add_generic_array_pop_left::(elem_ty, size, input) - } - - /// Adds an array pop-right operation to the dataflow graph. - /// - /// This operation removes the rightmost element from the array. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `size` - The size of the array. - /// * `input` - The wire representing the array. - /// - /// # Errors - /// - /// Returns an error if building the operation fails. - /// - /// # Returns - /// - /// The wire representing the Option> - fn add_borrow_array_pop_right( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - ) -> Result { - self.add_generic_array_pop_right::(elem_ty, size, input) - } - - /// Adds an operation to discard an empty array from the dataflow graph. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `input` - The wire representing the array. - /// - /// # Errors - /// - /// Returns an error if building the operation fails. - fn add_borrow_array_discard_empty( - &mut self, - elem_ty: Type, - input: Wire, - ) -> Result<(), BuildError> { - self.add_generic_array_discard_empty::(elem_ty, input) - } - - /// Adds a borrow array borrow operation to the dataflow graph. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `size` - The size of the array. - /// * `input` - The wire representing the array. - /// * `index` - The wire representing the index to get. 
- /// - /// # Errors - /// - /// Returns an error if building the operation fails. - fn add_borrow_array_borrow( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - index: Wire, - ) -> Result<(Wire, Wire), BuildError> { - let op = BArrayUnsafeOpDef::borrow.instantiate(&[size.into(), elem_ty.into()])?; - let [out, arr] = self - .add_dataflow_op(op.to_extension_op().unwrap(), vec![input, index])? - .outputs_arr(); - Ok((out, arr)) - } - - /// Adds a borrow array put operation to the dataflow graph. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `size` - The size of the array. - /// * `input` - The wire representing the array. - /// * `index` - The wire representing the index to set. - /// * `value` - The wire representing the value to set at the specified index. - /// - /// # Errors - /// - /// Returns an error if building the operation fails. - fn add_borrow_array_return( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - index: Wire, - value: Wire, - ) -> Result { - let op = BArrayUnsafeOpDef::r#return.instantiate(&[size.into(), elem_ty.into()])?; - let [arr] = self - .add_dataflow_op(op.to_extension_op().unwrap(), vec![input, index, value])? - .outputs_arr(); - Ok(arr) - } - - /// Adds an operation to discard a borrow array where all elements have been borrowed. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `size` - The size of the array. - /// * `input` - The wire representing the array. - /// - /// # Errors - /// - /// Returns an error if building the operation fails. - fn add_discard_all_borrowed( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - ) -> Result<(), BuildError> { - let op = - BArrayUnsafeOpDef::discard_all_borrowed.instantiate(&[size.into(), elem_ty.into()])?; - self.add_dataflow_op(op.to_extension_op().unwrap(), vec![input])?; - Ok(()) - } - - /// Adds an operation to create a new empty borrowed array in the dataflow graph. - /// - /// # Arguments - /// - /// * `elem_ty` - The type of the elements in the array. - /// * `size` - The size of the array. - /// - /// # Errors - /// - /// Returns an error if building the operation fails. - fn add_new_all_borrowed(&mut self, elem_ty: Type, size: u64) -> Result { - let op = BArrayUnsafeOpDef::new_all_borrowed.instantiate(&[size.into(), elem_ty.into()])?; - let [arr] = self - .add_dataflow_op(op.to_extension_op().unwrap(), vec![])? 
- .outputs_arr(); - Ok(arr) - } -} - -impl BArrayOpBuilder for D {} - -#[cfg(test)] -mod test { - use strum::IntoEnumIterator; - - use crate::{ - builder::{DFGBuilder, Dataflow, DataflowHugr as _}, - extension::prelude::{ConstUsize, qb_t, usize_t}, - ops::OpType, - std_extensions::collections::borrow_array::{ - BArrayOpBuilder, BArrayUnsafeOp, BArrayUnsafeOpDef, borrow_array_type, - }, - types::Signature, - }; - - #[test] - fn test_borrow_array_unsafe_ops() { - for def in BArrayUnsafeOpDef::iter() { - let op = def.to_concrete(qb_t(), 2); - let optype: OpType = op.clone().into(); - let new_op: BArrayUnsafeOp = optype.cast().unwrap(); - assert_eq!(new_op, op); - } - } - - #[test] - fn test_borrow_and_return() { - let size = 22; - let elem_ty = qb_t(); - let arr_ty = borrow_array_type(size, elem_ty.clone()); - let _ = { - let mut builder = DFGBuilder::new(Signature::new_endo(vec![arr_ty.clone()])).unwrap(); - let idx1 = builder.add_load_value(ConstUsize::new(11)); - let idx2 = builder.add_load_value(ConstUsize::new(11)); - let [arr] = builder.input_wires_arr(); - let (el, arr_with_take) = builder - .add_borrow_array_borrow(elem_ty.clone(), size, arr, idx1) - .unwrap(); - let arr_with_put = builder - .add_borrow_array_return(elem_ty, size, arr_with_take, idx2, el) - .unwrap(); - builder.finish_hugr_with_outputs([arr_with_put]).unwrap() - }; - } - - #[test] - fn test_discard_all_borrowed() { - let size = 1; - let elem_ty = qb_t(); - let arr_ty = borrow_array_type(size, elem_ty.clone()); - let _ = { - let mut builder = - DFGBuilder::new(Signature::new(vec![arr_ty.clone()], vec![qb_t()])).unwrap(); - let idx = builder.add_load_value(ConstUsize::new(0)); - let [arr] = builder.input_wires_arr(); - let (el, arr_with_borrowed) = builder - .add_borrow_array_borrow(elem_ty.clone(), size, arr, idx) - .unwrap(); - builder - .add_discard_all_borrowed(elem_ty, size, arr_with_borrowed) - .unwrap(); - builder.finish_hugr_with_outputs([el]).unwrap() - }; - } - - #[test] - fn test_new_all_borrowed() { - let size = 5; - let elem_ty = usize_t(); - let arr_ty = borrow_array_type(size, elem_ty.clone()); - let _ = { - let mut builder = - DFGBuilder::new(Signature::new(vec![], vec![arr_ty.clone()])).unwrap(); - let arr = builder.add_new_all_borrowed(elem_ty.clone(), size).unwrap(); - let idx = builder.add_load_value(ConstUsize::new(3)); - let val = builder.add_load_value(ConstUsize::new(202)); - let arr_with_put = builder - .add_borrow_array_return(elem_ty, size, arr, idx, val) - .unwrap(); - builder.finish_hugr_with_outputs([arr_with_put]).unwrap() - }; - } -} diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 817f90dba4..05d05048a6 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -21,7 +21,7 @@ use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc}; use crate::ops::constant::{TryHash, ValueName, maybe_hash_values}; use crate::ops::{OpName, Value}; -use crate::types::{Term, TypeName, TypeRowRV}; +use crate::types::{TypeName, TypeRowRV}; use crate::{ Extension, extension::{ @@ -112,7 +112,7 @@ impl CustomConst for ListValue { .map_err(|_| error())?; // constant can only hold classic type. - let [TypeArg::Runtime(ty)] = typ.args() else { + let [TypeArg::Type { ty }] = typ.args() else { return Err(error()); }; @@ -167,7 +167,7 @@ pub enum ListOp { impl ListOp { /// Type parameter used in the list types. 
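For context on the file deletion above: `borrow_array.rs` only exists on the newer tree, so the following condensed form of its own `test_borrow_and_return` (all calls taken verbatim from the deleted code) shows what stops compiling once this revert lands. It is a sketch against the pre-revert API, not something buildable on v0.20.2:

    use crate::builder::{DFGBuilder, Dataflow, DataflowHugr as _};
    use crate::extension::prelude::{ConstUsize, qb_t};
    use crate::std_extensions::collections::borrow_array::{BArrayOpBuilder, borrow_array_type};
    use crate::types::Signature;

    // Borrow element 11 out of a 22-element qubit array, then return it.
    fn example_borrow_roundtrip() {
        let size = 22;
        let elem_ty = qb_t();
        let arr_ty = borrow_array_type(size, elem_ty.clone());
        let mut builder = DFGBuilder::new(Signature::new_endo(vec![arr_ty])).unwrap();
        let idx1 = builder.add_load_value(ConstUsize::new(11));
        let idx2 = builder.add_load_value(ConstUsize::new(11));
        let [arr] = builder.input_wires_arr();
        let (el, arr) = builder
            .add_borrow_array_borrow(elem_ty.clone(), size, arr, idx1)
            .unwrap();
        let arr = builder
            .add_borrow_array_return(elem_ty, size, arr, idx2, el)
            .unwrap();
        let _hugr = builder.finish_hugr_with_outputs([arr]).unwrap();
    }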
- const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear); + const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; /// Instantiate a list operation with an `element_type`. #[must_use] @@ -181,7 +181,7 @@ impl ListOp { /// Compute the signature of the operation, given the list type definition. fn compute_signature(self, list_type_def: &TypeDef) -> SignatureFunc { use ListOp::{get, insert, length, pop, push, set}; - let e = Type::new_var_use(0, TypeBound::Linear); + let e = Type::new_var_use(0, TypeBound::Any); let l = self.list_type(list_type_def, 0); match self { pop => self @@ -325,7 +325,9 @@ pub fn list_type_def() -> &'static TypeDef { /// Get the type of a list of `elem_type` as a `CustomType`. #[must_use] pub fn list_custom_type(elem_type: Type) -> CustomType { - list_type_def().instantiate(vec![elem_type.into()]).unwrap() + list_type_def() + .instantiate(vec![TypeArg::Type { ty: elem_type }]) + .unwrap() } /// Get the `Type` of a list of `elem_type`. @@ -351,7 +353,7 @@ impl MakeExtensionOp for ListOpInst { fn from_extension_op( ext_op: &ExtensionOp, ) -> Result { - let [Term::Runtime(ty)] = ext_op.args() else { + let [TypeArg::Type { ty }] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs.into()); }; let name = ext_op.unqualified_id(); @@ -365,8 +367,10 @@ impl MakeExtensionOp for ListOpInst { }) } - fn type_args(&self) -> Vec { - vec![self.elem_type.clone().into()] + fn type_args(&self) -> Vec { + vec![TypeArg::Type { + ty: self.elem_type.clone(), + }] } } @@ -409,9 +413,15 @@ mod test { fn test_list() { let list_def = list_type_def(); - let list_type = list_def.instantiate([usize_t().into()]).unwrap(); + let list_type = list_def + .instantiate([TypeArg::Type { ty: usize_t() }]) + .unwrap(); - assert!(list_def.instantiate([3u64.into()]).is_err()); + assert!( + list_def + .instantiate([TypeArg::BoundedNat { n: 3 }]) + .is_err() + ); list_def.check_custom(&list_type).unwrap(); let list_value = ListValue(vec![ConstUsize::new(3).into()], usize_t()); diff --git a/hugr-core/src/std_extensions/collections/static_array.rs b/hugr-core/src/std_extensions/collections/static_array.rs index c99e7617b2..6f3e889e68 100644 --- a/hugr-core/src/std_extensions/collections/static_array.rs +++ b/hugr-core/src/std_extensions/collections/static_array.rs @@ -38,7 +38,7 @@ use crate::{ types::{ ConstTypeError, CustomCheckFailure, CustomType, PolyFuncType, Signature, Type, TypeArg, TypeBound, TypeName, - type_param::{TermTypeError, TypeParam}, + type_param::{TypeArgError, TypeParam}, }, }; @@ -309,12 +309,12 @@ impl HasConcrete for StaticArrayOpDef { match type_args { [arg] => { let elem_ty = arg - .as_runtime() + .as_type() .filter(|t| Copyable.contains(t.least_upper_bound())) .ok_or(SignatureError::TypeArgMismatch( - TermTypeError::TypeMismatch { - type_: Box::new(Copyable.into()), - term: Box::new(arg.clone()), + TypeArgError::TypeMismatch { + param: Copyable.into(), + arg: arg.clone(), }, ))?; @@ -324,7 +324,7 @@ impl HasConcrete for StaticArrayOpDef { }) } _ => Err( - SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(type_args.len(), 1)) + SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(type_args.len(), 1)) .into(), ), } diff --git a/hugr-core/src/std_extensions/collections/value_array.rs b/hugr-core/src/std_extensions/collections/value_array.rs index 947fef9188..fe89824d77 100644 --- a/hugr-core/src/std_extensions/collections/value_array.rs +++ b/hugr-core/src/std_extensions/collections/value_array.rs @@ -102,7 +102,7 @@ lazy_static! 
{ Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { extension.add_type( VALUE_ARRAY_TYPENAME, - vec![ TypeParam::max_nat_type(), TypeBound::Linear.into()], + vec![ TypeParam::max_nat(), TypeBound::Any.into()], "Fixed-length value array".into(), // Value arrays are copyable iff their elements are TypeDefBound::from_params(vec![1]), diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index 74ecf63fc1..3955c3a972 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -89,7 +89,9 @@ impl MakeOpDef for PtrOpDef { pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("ptr"); /// Name of pointer type. pub const PTR_TYPE_ID: TypeName = TypeName::new_inline("ptr"); -const TYPE_PARAMS: [TypeParam; 1] = [TypeParam::RuntimeType(TypeBound::Copyable)]; +const TYPE_PARAMS: [TypeParam; 1] = [TypeParam::Type { + b: TypeBound::Copyable, +}]; /// Extension version. pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); @@ -207,7 +209,7 @@ impl HasConcrete for PtrOpDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { let ty = match type_args { - [TypeArg::Runtime(ty)] => ty.clone(), + [TypeArg::Type { ty }] => ty.clone(), _ => return Err(SignatureError::InvalidTypeArgs.into()), }; diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index cdfd1012a8..2b36233133 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -4,7 +4,7 @@ mod check; pub mod custom; mod poly_func; mod row_var; -pub(crate) mod serialize; +mod serialize; mod signature; pub mod type_param; pub mod type_row; @@ -15,14 +15,14 @@ use crate::extension::resolution::{ ExtensionCollectionError, WeakExtensionRegistry, collect_type_exts, }; pub use crate::ops::constant::{ConstTypeError, CustomCheckFailure}; -use crate::types::type_param::check_term_type; +use crate::types::type_param::check_type_arg; use crate::utils::display_list_with_separator; pub use check::SumTypeError; pub use custom::CustomType; pub use poly_func::{PolyFuncType, PolyFuncTypeRV}; pub use signature::{FuncTypeBase, FuncValueType, Signature}; use smol_str::SmolStr; -pub use type_param::{Term, TypeArg}; +pub use type_param::TypeArg; pub use type_row::{TypeRow, TypeRowRV}; pub(crate) use poly_func::PolyFuncTypeBase; @@ -131,11 +131,9 @@ pub enum TypeBound { #[serde(rename = "C", alias = "E")] // alias to read in legacy Eq variants Copyable, /// No bound on the type. - /// - /// It cannot be copied nor discarded. #[serde(rename = "A")] #[default] - Linear, + Any, } impl TypeBound { @@ -154,16 +152,16 @@ impl TypeBound { /// Report if this bound contains another. #[must_use] pub const fn contains(&self, other: TypeBound) -> bool { - use TypeBound::{Copyable, Linear}; - matches!((self, other), (Linear, _) | (_, Copyable)) + use TypeBound::{Any, Copyable}; + matches!((self, other), (Any, _) | (_, Copyable)) } } /// Calculate the least upper bound for an iterator of bounds pub(crate) fn least_upper_bound(mut tags: impl Iterator) -> TypeBound { tags.fold_while(TypeBound::Copyable, |acc, new| { - if acc == TypeBound::Linear || new == TypeBound::Linear { - Done(TypeBound::Linear) + if acc == TypeBound::Any || new == TypeBound::Any { + Done(TypeBound::Any) } else { Continue(acc.union(new)) } @@ -492,7 +490,7 @@ impl TypeBase { /// New use (occurrence) of the type variable with specified index. /// `bound` must be exactly that with which the variable was declared - /// (i.e. 
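With the rename back to `TypeBound::Any`, the bound lattice shown in the hunk above reads: `Any` contains every bound, `Copyable` contains only itself, and `least_upper_bound` collapses to `Any` as soon as any element is `Any`. A few assertions spelling that out, derived directly from the `contains` body above:

    use crate::types::TypeBound;

    fn example_bound_lattice() {
        assert!(TypeBound::Any.contains(TypeBound::Copyable));
        assert!(TypeBound::Any.contains(TypeBound::Any));
        assert!(TypeBound::Copyable.contains(TypeBound::Copyable));
        assert!(!TypeBound::Copyable.contains(TypeBound::Any));
    }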
as a [`Term::RuntimeType`]`(bound)`), which may be narrower + /// (i.e. as a [`TypeParam::Type`]`(bound)`), which may be narrower /// than required for the use. #[must_use] pub const fn new_var_use(idx: usize, bound: TypeBound) -> Self { @@ -577,7 +575,7 @@ impl TypeBase { TypeEnum::RowVar(rv) => rv.substitute(t), TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone()], TypeEnum::Variable(idx, bound) => { - let TypeArg::Runtime(ty) = t.apply_var(*idx, &((*bound).into())) else { + let TypeArg::Type { ty } = t.apply_var(*idx, &((*bound).into())) else { panic!("Variable was not a type - try validate() first") }; vec![ty.into_()] @@ -655,7 +653,7 @@ impl TypeRV { /// New use (occurrence) of the row variable with specified index. /// `bound` must match that with which the variable was declared - /// (i.e. as a list of runtime types of that bound). + /// (i.e. as a [TypeParam::List]` of a `[TypeParam::Type]` of that bound). /// For use in [OpDef], not [FuncDefn], type schemes only. /// /// [OpDef]: crate::extension::OpDef @@ -742,7 +740,7 @@ impl<'a> Substitution<'a> { .0 .get(idx) .expect("Undeclared type variable - call validate() ?"); - debug_assert_eq!(check_term_type(arg, decl), Ok(())); + debug_assert_eq!(check_type_arg(arg, decl), Ok(())); arg.clone() } @@ -751,14 +749,14 @@ impl<'a> Substitution<'a> { .0 .get(idx) .expect("Undeclared type variable - call validate() ?"); - debug_assert!(check_term_type(arg, &TypeParam::new_list_type(bound)).is_ok()); + debug_assert!(check_type_arg(arg, &TypeParam::new_list(bound)).is_ok()); match arg { - TypeArg::List(elems) => elems + TypeArg::Sequence { elems } => elems .iter() .map(|ta| { match ta { - Term::Runtime(ty) => return ty.clone().into(), - Term::Variable(v) => { + TypeArg::Type { ty } => return ty.clone().into(), + TypeArg::Variable { v } => { if let Some(b) = v.bound_if_row_var() { return TypeRV::new_row_var_use(v.index(), b); } @@ -768,7 +766,7 @@ impl<'a> Substitution<'a> { panic!("Not a list of types - call validate() ?") }) .collect(), - Term::Runtime(ty) if matches!(ty.0, TypeEnum::RowVar(_)) => { + TypeArg::Type { ty } if matches!(ty.0, TypeEnum::RowVar(_)) => { // Standalone "Type" can be used iff its actually a Row Variable not an actual (single) Type vec![ty.clone().into()] } @@ -779,11 +777,11 @@ impl<'a> Substitution<'a> { /// A transformation that can be applied to a [Type] or [`TypeArg`]. /// More general in some ways than a Substitution: can fail with a -/// [`Self::Err`], may change [`TypeBound::Copyable`] to [`TypeBound::Linear`], +/// [`Self::Err`], may change [`TypeBound::Copyable`] to [`TypeBound::Any`], /// and applies to arbitrary extension types rather than type variables. pub trait TypeTransformer { /// Error returned when a [`CustomType`] cannot be transformed, or a type - /// containing it (e.g. if changing a runtime type from copyable to + /// containing it (e.g. if changing a [`TypeArg::Type`] from copyable to /// linear invalidates a parameterized type). 
type Err: std::error::Error + From; @@ -841,8 +839,8 @@ pub(crate) fn check_typevar_decl( Ok(()) } else { Err(SignatureError::TypeVarDoesNotMatchDeclaration { - cached: Box::new(cached_decl.clone()), - actual: Box::new(actual.clone()), + cached: cached_decl.clone(), + actual: actual.clone(), }) } } @@ -859,7 +857,7 @@ pub(crate) mod test { use crate::extension::prelude::{option_type, qb_t, usize_t}; use crate::std_extensions::collections::array::{array_type, array_type_parametric}; use crate::std_extensions::collections::list::list_type; - use crate::types::type_param::TermTypeError; + use crate::types::type_param::TypeArgError; use crate::{Extension, hugr::IdentList, type_row}; #[test] @@ -932,7 +930,7 @@ pub(crate) mod test { fn sum_variants() { let variants: Vec = vec![ TypeRV::UNIT.into(), - vec![TypeRV::new_row_var_use(0, TypeBound::Linear)].into(), + vec![TypeRV::new_row_var_use(0, TypeBound::Any)].into(), ]; let t = SumType::new(variants.clone()); assert_eq!(variants, t.variants().cloned().collect_vec()); @@ -979,7 +977,7 @@ pub(crate) mod test { |t| array_type(10, t), |t| { array_type_parametric( - TypeArg::new_var_use(0, TypeParam::bounded_nat_type(3.try_into().unwrap())), + TypeArg::new_var_use(0, TypeParam::bounded_nat(3.try_into().unwrap())), t, ) .unwrap() @@ -1003,7 +1001,7 @@ pub(crate) mod test { .unwrap(); e.add_type( COLN, - vec![TypeParam::new_list_type(TypeBound::Copyable)], + vec![TypeParam::new_list(TypeBound::Copyable)], String::new(), TypeDefBound::copyable(), w, @@ -1022,27 +1020,31 @@ pub(crate) mod test { let coln = e.get_type(&COLN).unwrap(); let c_of_cpy = coln - .instantiate([Term::new_list([Type::from(cpy.clone()).into()])]) + .instantiate([TypeArg::Sequence { + elems: vec![Type::from(cpy.clone()).into()], + }]) .unwrap(); let mut t = Type::new_extension(c_of_cpy.clone()); assert_eq!( t.transform(&cpy_to_qb), - Err(SignatureError::from(TermTypeError::TypeMismatch { - type_: Box::new(TypeBound::Copyable.into()), - term: Box::new(qb_t().into()) + Err(SignatureError::from(TypeArgError::TypeMismatch { + param: TypeBound::Copyable.into(), + arg: qb_t().into() })) ); let mut t = Type::new_extension( - coln.instantiate([Term::new_list([mk_opt(Type::from(cpy.clone())).into()])]) - .unwrap(), + coln.instantiate([TypeArg::Sequence { + elems: vec![mk_opt(Type::from(cpy.clone())).into()], + }]) + .unwrap(), ); assert_eq!( t.transform(&cpy_to_qb), - Err(SignatureError::from(TermTypeError::TypeMismatch { - type_: Box::new(TypeBound::Copyable.into()), - term: Box::new(mk_opt(qb_t()).into()) + Err(SignatureError::from(TypeArgError::TypeMismatch { + param: TypeBound::Copyable.into(), + arg: mk_opt(qb_t()).into() })) ); @@ -1052,15 +1054,19 @@ pub(crate) mod test { (ct == &c_of_cpy).then_some(usize_t()) }); let mut t = Type::new_extension( - coln.instantiate([Term::new_list(vec![Type::from(c_of_cpy.clone()).into(); 2])]) - .unwrap(), + coln.instantiate([TypeArg::Sequence { + elems: vec![Type::from(c_of_cpy.clone()).into(); 2], + }]) + .unwrap(), ); assert_eq!(t.transform(&cpy_to_qb2), Ok(true)); assert_eq!( t, Type::new_extension( - coln.instantiate([Term::new_list([usize_t().into(), usize_t().into()])]) - .unwrap() + coln.instantiate([TypeArg::Sequence { + elems: vec![usize_t().into(); 2] + }]) + .unwrap() ) ); } @@ -1109,82 +1115,3 @@ pub(crate) mod test { } } } - -#[cfg(test)] -pub(super) mod proptest_utils { - use proptest::collection::vec; - use proptest::prelude::{Strategy, any_with}; - - use super::serialize::{TermSer, TypeArgSer, TypeParamSer}; - use 
super::type_param::Term; - - use crate::proptest::RecursionDepth; - use crate::types::serialize::ArrayOrTermSer; - - fn term_is_serde_type_arg(t: &Term) -> bool { - let TermSer::TypeArg(arg) = TermSer::from(t.clone()) else { - return false; - }; - match arg { - TypeArgSer::List { elems: terms } - | TypeArgSer::ListConcat { lists: terms } - | TypeArgSer::Tuple { elems: terms } - | TypeArgSer::TupleConcat { tuples: terms } => terms.iter().all(term_is_serde_type_arg), - TypeArgSer::Variable { v } => term_is_serde_type_param(&v.cached_decl), - TypeArgSer::Type { ty } => { - if let Some(cty) = ty.as_extension() { - cty.args().iter().all(term_is_serde_type_arg) - } else { - true - } - } // Do we need to inspect inside function types? sum types? - TypeArgSer::BoundedNat { .. } - | TypeArgSer::String { .. } - | TypeArgSer::Bytes { .. } - | TypeArgSer::Float { .. } => true, - } - } - - fn term_is_serde_type_param(t: &Term) -> bool { - let TermSer::TypeParam(parm) = TermSer::from(t.clone()) else { - return false; - }; - match parm { - TypeParamSer::Type { .. } - | TypeParamSer::BoundedNat { .. } - | TypeParamSer::String - | TypeParamSer::Bytes - | TypeParamSer::Float - | TypeParamSer::StaticType => true, - TypeParamSer::List { param } => term_is_serde_type_param(¶m), - TypeParamSer::Tuple { params } => { - match ¶ms { - ArrayOrTermSer::Array(terms) => terms.iter().all(term_is_serde_type_param), - ArrayOrTermSer::Term(b) => match &**b { - Term::List(_) => panic!("Should be represented as ArrayOrTermSer::Array"), - // This might be well-typed, but does not fit the (TODO: update) JSON schema - Term::Variable(_) => false, - // Similarly, but not produced by our `impl Arbitrary`: - Term::ListConcat(_) => todo!("Update schema"), - - // The others do not fit the JSON schema, and are not well-typed, - // but can be produced by our impl of Arbitrary, so we must filter out: - _ => false, - }, - } - } - } - } - - pub fn any_serde_type_arg(depth: RecursionDepth) -> impl Strategy { - any_with::(depth).prop_filter("Term was not a TypeArg", term_is_serde_type_arg) - } - - pub fn any_serde_type_arg_vec() -> impl Strategy> { - vec(any_serde_type_arg(RecursionDepth::default()), 1..3) - } - - pub fn any_serde_type_param(depth: RecursionDepth) -> impl Strategy { - any_with::(depth).prop_filter("Term was not a TypeParam", term_is_serde_type_param) - } -} diff --git a/hugr-core/src/types/check.rs b/hugr-core/src/types/check.rs index 072da5884e..2146ee41ba 100644 --- a/hugr-core/src/types/check.rs +++ b/hugr-core/src/types/check.rs @@ -17,9 +17,9 @@ pub enum SumTypeError { /// The element in the tuple that was wrong. index: usize, /// The expected type. - expected: Box, + expected: Type, /// The value that was found. 
- found: Box, + found: Value, }, /// The type of the variant we were trying to convert into contained type variables #[error("Sum variant #{tag} contained a variable #{varidx}")] @@ -88,8 +88,8 @@ impl super::SumType { Err(SumTypeError::InvalidValueType { tag, index, - expected: Box::new(t.clone()), - found: Box::new(v.clone()), + expected: t.clone(), + found: v.clone(), })?; } } diff --git a/hugr-core/src/types/custom.rs b/hugr-core/src/types/custom.rs index 248e0f6253..02ab188338 100644 --- a/hugr-core/src/types/custom.rs +++ b/hugr-core/src/types/custom.rs @@ -188,7 +188,7 @@ mod test { use crate::extension::ExtensionId; use crate::proptest::RecursionDepth; use crate::proptest::any_nonempty_string; - use crate::types::proptest_utils::any_serde_type_arg; + use crate::types::type_param::TypeArg; use crate::types::{CustomType, TypeBound}; use ::proptest::collection::vec; use ::proptest::prelude::*; @@ -224,7 +224,7 @@ mod test { Just(vec![]).boxed() } else { // a TypeArg may contain a CustomType, so we descend here - vec(any_serde_type_arg(depth.descend()), 0..3).boxed() + vec(any_with::(depth.descend()), 0..3).boxed() }; (any_nonempty_string(), args, any::(), bound) .prop_map(|(id, args, extension, bound)| { diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 0de6b1b029..8121741bf8 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -7,14 +7,13 @@ use itertools::Itertools; use crate::extension::SignatureError; #[cfg(test)] use { - super::proptest_utils::any_serde_type_param, crate::proptest::RecursionDepth, ::proptest::{collection::vec, prelude::*}, proptest_derive::Arbitrary, }; use super::Substitution; -use super::type_param::{TypeArg, TypeParam, check_term_types}; +use super::type_param::{TypeArg, TypeParam, check_type_args}; use super::{MaybeRV, NoRV, RowVariable, signature::FuncTypeBase}; /// A polymorphic type scheme, i.e. of a [`FuncDecl`], [`FuncDefn`] or [`OpDef`]. @@ -32,7 +31,7 @@ pub struct PolyFuncTypeBase { /// The declared type parameters, i.e., these must be instantiated with /// the same number of [`TypeArg`]s before the function can be called. This /// defines the indices used by variables inside the body. - #[cfg_attr(test, proptest(strategy = "vec(any_serde_type_param(params), 0..3)"))] + #[cfg_attr(test, proptest(strategy = "vec(any_with::(params), 0..3)"))] params: Vec, /// Template for the function. May contain variables up to length of [`Self::params`] #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] @@ -123,7 +122,7 @@ impl PolyFuncTypeBase { pub fn instantiate(&self, args: &[TypeArg]) -> Result, SignatureError> { // Check that args are applicable, and that we have a value for each binder, // i.e. each possible free variable within the body. 
- check_term_types(args, &self.params)?; + check_type_args(args, &self.params)?; Ok(self.body.substitute(&Substitution(args))) } @@ -167,9 +166,9 @@ pub(crate) mod test { use crate::std_extensions::collections::array::{self, array_type_parametric}; use crate::std_extensions::collections::list; use crate::types::signature::FuncTypeBase; - use crate::types::type_param::{TermTypeError, TypeArg, TypeParam}; + use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::{ - CustomType, FuncValueType, MaybeRV, Signature, Term, Type, TypeBound, TypeName, TypeRV, + CustomType, FuncValueType, MaybeRV, Signature, Type, TypeBound, TypeName, TypeRV, }; use super::PolyFuncTypeBase; @@ -193,19 +192,21 @@ pub(crate) mod test { #[test] fn test_opaque() -> Result<(), SignatureError> { let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); - let tyvar = TypeArg::new_var_use(0, TypeBound::Linear.into()); + let tyvar = TypeArg::new_var_use(0, TypeBound::Any.into()); let list_of_var = Type::new_extension(list_def.instantiate([tyvar.clone()])?); let list_len = PolyFuncTypeBase::new_validated( - [TypeBound::Linear.into()], + [TypeBound::Any.into()], Signature::new(vec![list_of_var], vec![usize_t()]), )?; - let t = list_len.instantiate(&[usize_t().into()])?; + let t = list_len.instantiate(&[TypeArg::Type { ty: usize_t() }])?; assert_eq!( t, Signature::new( vec![Type::new_extension( - list_def.instantiate([usize_t().into()]).unwrap() + list_def + .instantiate([TypeArg::Type { ty: usize_t() }]) + .unwrap() )], vec![usize_t()] ) @@ -216,9 +217,9 @@ pub(crate) mod test { #[test] fn test_mismatched_args() -> Result<(), SignatureError> { - let size_var = TypeArg::new_var_use(0, TypeParam::max_nat_type()); - let ty_var = TypeArg::new_var_use(1, TypeBound::Linear.into()); - let type_params = [TypeParam::max_nat_type(), TypeBound::Linear.into()]; + let size_var = TypeArg::new_var_use(0, TypeParam::max_nat()); + let ty_var = TypeArg::new_var_use(1, TypeBound::Any.into()); + let type_params = [TypeParam::max_nat(), TypeBound::Any.into()]; // Valid schema... 
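Condensing the `test_opaque` case above into its essence: `instantiate` first checks the supplied arguments against the declared binders (`check_type_args`) and then substitutes them into the body. A sketch with a single `Any`-bounded binder instantiated at `usize_t()`; it assumes `PolyFuncType::new` and the `pub fn instantiate` shown above are available at this point in the crate:

    use crate::extension::SignatureError;
    use crate::extension::prelude::usize_t;
    use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound};

    fn example_instantiate() -> Result<Signature, SignatureError> {
        // One binder: a runtime type of bound `Any`; the body uses it as variable #0.
        let pf = PolyFuncType::new(
            [TypeBound::Any.into()],
            Signature::new(vec![Type::new_var_use(0, TypeBound::Any)], vec![usize_t()]),
        );
        // Arguments are checked against the binders, then substituted into the body.
        pf.instantiate(&[TypeArg::Type { ty: usize_t() }])
    }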
let good_array = array_type_parametric(size_var.clone(), ty_var.clone())?; @@ -226,23 +227,29 @@ pub(crate) mod test { PolyFuncTypeBase::new_validated(type_params.clone(), Signature::new_endo(good_array))?; // Sanity check (good args) - good_ts.instantiate(&[5u64.into(), usize_t().into()])?; - - let wrong_args = good_ts.instantiate(&[usize_t().into(), 5u64.into()]); + good_ts.instantiate(&[ + TypeArg::BoundedNat { n: 5 }, + TypeArg::Type { ty: usize_t() }, + ])?; + + let wrong_args = good_ts.instantiate(&[ + TypeArg::Type { ty: usize_t() }, + TypeArg::BoundedNat { n: 5 }, + ]); assert_eq!( wrong_args, Err(SignatureError::TypeArgMismatch( - TermTypeError::TypeMismatch { - type_: Box::new(type_params[0].clone()), - term: Box::new(usize_t().into()), + TypeArgError::TypeMismatch { + param: type_params[0].clone(), + arg: TypeArg::Type { ty: usize_t() } } )) ); // (Try to) make a schema with the args in the wrong order - let arg_err = SignatureError::TypeArgMismatch(TermTypeError::TypeMismatch { - type_: Box::new(type_params[0].clone()), - term: Box::new(ty_var.clone()), + let arg_err = SignatureError::TypeArgMismatch(TypeArgError::TypeMismatch { + param: type_params[0].clone(), + arg: ty_var.clone(), }); assert_eq!( array_type_parametric(ty_var.clone(), size_var.clone()), @@ -253,7 +260,7 @@ pub(crate) mod test { "array", [ty_var, size_var], array::EXTENSION_ID, - TypeBound::Linear, + TypeBound::Any, &Arc::downgrade(&array::EXTENSION), )); let bad_ts = @@ -270,16 +277,20 @@ pub(crate) mod test { let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); let body_type = Signature::new_endo(Type::new_extension(list_def.instantiate([tv])?)); for decl in [ - Term::new_list_type(Term::max_nat_type()), - Term::StringType, - Term::new_tuple_type([TypeBound::Linear.into(), Term::max_nat_type()]), + TypeParam::List { + param: Box::new(TypeParam::max_nat()), + }, + TypeParam::String, + TypeParam::Tuple { + params: vec![TypeBound::Any.into(), TypeParam::max_nat()], + }, ] { let invalid_ts = PolyFuncTypeBase::new_validated([decl.clone()], body_type.clone()); assert_eq!( invalid_ts.err(), Some(SignatureError::TypeVarDoesNotMatchDeclaration { - cached: Box::new(TypeBound::Copyable.into()), - actual: Box::new(decl) + cached: TypeBound::Copyable.into(), + actual: decl }) ); } @@ -325,7 +336,7 @@ pub(crate) mod test { TYPE_NAME, [TypeArg::new_var_use(0, tp)], EXT_ID, - TypeBound::Linear, + TypeBound::Any, &Arc::downgrade(&ext), ))), ) @@ -337,9 +348,9 @@ pub(crate) mod test { assert_eq!( make_scheme(decl.clone()).err(), Some(SignatureError::TypeArgMismatch( - TermTypeError::TypeMismatch { - type_: Box::new(bound.clone()), - term: Box::new(TypeArg::new_var_use(0, decl.clone())) + TypeArgError::TypeMismatch { + param: bound.clone(), + arg: TypeArg::new_var_use(0, decl.clone()) } )) ); @@ -352,33 +363,38 @@ pub(crate) mod test { decl_accepts_rejects_var( TypeBound::Copyable.into(), &[TypeBound::Copyable.into()], - &[TypeBound::Linear.into()], + &[TypeBound::Any.into()], )?; + let list_of_tys = |b: TypeBound| TypeParam::List { + param: Box::new(b.into()), + }; decl_accepts_rejects_var( - Term::new_list_type(TypeBound::Copyable), - &[Term::new_list_type(TypeBound::Copyable)], - &[Term::new_list_type(TypeBound::Linear)], + list_of_tys(TypeBound::Copyable), + &[list_of_tys(TypeBound::Copyable)], + &[list_of_tys(TypeBound::Any)], )?; decl_accepts_rejects_var( - TypeParam::max_nat_type(), - &[TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap())], + TypeParam::max_nat(), + 
&[TypeParam::bounded_nat(NonZeroU64::new(5).unwrap())], &[], )?; decl_accepts_rejects_var( - TypeParam::bounded_nat_type(NonZeroU64::new(10).unwrap()), - &[TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap())], - &[TypeParam::max_nat_type()], + TypeParam::bounded_nat(NonZeroU64::new(10).unwrap()), + &[TypeParam::bounded_nat(NonZeroU64::new(5).unwrap())], + &[TypeParam::max_nat()], )?; Ok(()) } - const TP_ANY: TypeParam = TypeParam::RuntimeType(TypeBound::Linear); + const TP_ANY: TypeParam = TypeParam::Type { b: TypeBound::Any }; #[test] fn row_variables_bad_schema() { // Mismatched TypeBound (Copyable vs Any) - let decl = Term::new_list_type(TP_ANY); + let decl = TypeParam::List { + param: Box::new(TP_ANY), + }; let e = PolyFuncTypeBase::new_validated( [decl.clone()], FuncValueType::new( @@ -388,26 +404,26 @@ pub(crate) mod test { ) .unwrap_err(); assert_matches!(e, SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached } => { - assert_eq!(*actual, decl); - assert_eq!(*cached, TypeParam::new_list_type(TypeBound::Copyable)); + assert_eq!(actual, decl); + assert_eq!(cached, TypeParam::List {param: Box::new(TypeParam::Type {b: TypeBound::Copyable})}); }); // Declared as row variable, used as type variable let e = PolyFuncTypeBase::new_validated( [decl.clone()], - Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Linear)]), + Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), ) .unwrap_err(); assert_matches!(e, SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached } => { - assert_eq!(*actual, decl); - assert_eq!(*cached, TP_ANY); + assert_eq!(actual, decl); + assert_eq!(cached, TP_ANY); }); } #[test] fn row_variables() { - let rty = TypeRV::new_row_var_use(0, TypeBound::Linear); + let rty = TypeRV::new_row_var_use(0, TypeBound::Any); let pf = PolyFuncTypeBase::new_validated( - [TypeParam::new_list_type(TP_ANY)], + [TypeParam::new_list(TP_ANY)], FuncValueType::new( vec![usize_t().into(), rty.clone()], vec![TypeRV::new_tuple(rty)], @@ -418,11 +434,16 @@ pub(crate) mod test { fn seq2() -> Vec { vec![usize_t().into(), bool_t().into()] } - pf.instantiate(&[usize_t().into()]).unwrap_err(); - pf.instantiate(&[Term::new_list([usize_t().into(), Term::new_list(seq2())])]) + pf.instantiate(&[TypeArg::Type { ty: usize_t() }]) .unwrap_err(); + pf.instantiate(&[TypeArg::Sequence { + elems: vec![usize_t().into(), TypeArg::Sequence { elems: seq2() }], + }]) + .unwrap_err(); - let t2 = pf.instantiate(&[Term::new_list(seq2())]).unwrap(); + let t2 = pf + .instantiate(&[TypeArg::Sequence { elems: seq2() }]) + .unwrap(); assert_eq!( t2, Signature::new( @@ -439,18 +460,20 @@ pub(crate) mod test { TypeBound::Copyable, ))); let pf = PolyFuncTypeBase::new_validated( - [Term::new_list_type(TypeBound::Copyable)], + [TypeParam::List { + param: Box::new(TypeParam::Type { + b: TypeBound::Copyable, + }), + }], Signature::new(vec![usize_t(), inner_fty.clone()], vec![inner_fty]), ) .unwrap(); let inner3 = Type::new_function(Signature::new_endo(vec![usize_t(), bool_t(), usize_t()])); let t3 = pf - .instantiate(&[Term::new_list([ - usize_t().into(), - bool_t().into(), - usize_t().into(), - ])]) + .instantiate(&[TypeArg::Sequence { + elems: vec![usize_t().into(), bool_t().into(), usize_t().into()], + }]) .unwrap(); assert_eq!( t3, diff --git a/hugr-core/src/types/row_var.rs b/hugr-core/src/types/row_var.rs index 086ab7b076..106870003b 100644 --- a/hugr-core/src/types/row_var.rs +++ b/hugr-core/src/types/row_var.rs @@ -6,7 +6,7 @@ use crate::extension::SignatureError; #[cfg(test)] use 
proptest::prelude::{BoxedStrategy, Strategy, any}; -/// Describes a row variable - a type variable bound with a list of runtime types +/// Describes a row variable - a type variable bound with a [`TypeParam::List`] of [`TypeParam::Type`] /// of the specified bound (checked in validation) // The serde derives here are not used except as markers // so that other types containing this can also #derive-serde the same way. @@ -70,7 +70,7 @@ impl MaybeRV for RowVariable { } fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { - check_typevar_decl(var_decls, self.0, &TypeParam::new_list_type(self.1)) + check_typevar_decl(var_decls, self.0, &TypeParam::new_list(self.1)) } #[allow(private_interfaces)] diff --git a/hugr-core/src/types/serialize.rs b/hugr-core/src/types/serialize.rs index c0a35dfd5e..198c0c1eda 100644 --- a/hugr-core/src/types/serialize.rs +++ b/hugr-core/src/types/serialize.rs @@ -1,7 +1,3 @@ -use std::sync::Arc; - -use ordered_float::OrderedFloat; - use super::{FuncValueType, MaybeRV, RowVariable, SumType, TypeBase, TypeBound, TypeEnum}; use super::custom::CustomType; @@ -9,12 +5,10 @@ use super::custom::CustomType; use crate::extension::SignatureError; use crate::extension::prelude::{qb_t, usize_t}; use crate::ops::AliasDecl; -use crate::types::type_param::{TermVar, UpperBound}; -use crate::types::{Term, Type}; #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] #[serde(tag = "t")] -pub(crate) enum SerSimpleType { +pub(super) enum SerSimpleType { Q, I, G(Box), @@ -66,167 +60,3 @@ impl TryFrom for TypeBase { }) } } - -#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] -#[non_exhaustive] -#[serde(tag = "tp")] -pub(super) enum TypeParamSer { - Type { b: TypeBound }, - BoundedNat { bound: UpperBound }, - String, - Bytes, - Float, - StaticType, - List { param: Box }, - Tuple { params: ArrayOrTermSer }, -} - -#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] -#[non_exhaustive] -#[serde(tag = "tya")] -pub(super) enum TypeArgSer { - Type { - ty: Type, - }, - BoundedNat { - n: u64, - }, - String { - arg: String, - }, - Bytes { - #[serde(with = "base64")] - value: Arc<[u8]>, - }, - Float { - value: OrderedFloat, - }, - List { - elems: Vec, - }, - ListConcat { - lists: Vec, - }, - Tuple { - elems: Vec, - }, - TupleConcat { - tuples: Vec, - }, - Variable { - #[serde(flatten)] - v: TermVar, - }, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[serde(untagged)] -pub(super) enum TermSer { - TypeArg(TypeArgSer), - TypeParam(TypeParamSer), -} - -impl From for TermSer { - fn from(value: Term) -> Self { - match value { - Term::RuntimeType(b) => TermSer::TypeParam(TypeParamSer::Type { b }), - Term::StaticType => TermSer::TypeParam(TypeParamSer::StaticType), - Term::BoundedNatType(bound) => TermSer::TypeParam(TypeParamSer::BoundedNat { bound }), - Term::StringType => TermSer::TypeParam(TypeParamSer::String), - Term::BytesType => TermSer::TypeParam(TypeParamSer::Bytes), - Term::FloatType => TermSer::TypeParam(TypeParamSer::Float), - Term::ListType(param) => TermSer::TypeParam(TypeParamSer::List { param }), - Term::Runtime(ty) => TermSer::TypeArg(TypeArgSer::Type { ty }), - Term::TupleType(params) => TermSer::TypeParam(TypeParamSer::Tuple { - params: (*params).into(), - }), - Term::BoundedNat(n) => TermSer::TypeArg(TypeArgSer::BoundedNat { n }), - Term::String(arg) => TermSer::TypeArg(TypeArgSer::String { arg }), - Term::Bytes(value) => TermSer::TypeArg(TypeArgSer::Bytes { value }), - Term::Float(value) => 
TermSer::TypeArg(TypeArgSer::Float { value }), - Term::List(elems) => TermSer::TypeArg(TypeArgSer::List { elems }), - Term::Tuple(elems) => TermSer::TypeArg(TypeArgSer::Tuple { elems }), - Term::Variable(v) => TermSer::TypeArg(TypeArgSer::Variable { v }), - Term::ListConcat(lists) => TermSer::TypeArg(TypeArgSer::ListConcat { lists }), - Term::TupleConcat(tuples) => TermSer::TypeArg(TypeArgSer::TupleConcat { tuples }), - } - } -} - -impl From for Term { - fn from(value: TermSer) -> Self { - match value { - TermSer::TypeParam(param) => match param { - TypeParamSer::Type { b } => Term::RuntimeType(b), - TypeParamSer::StaticType => Term::StaticType, - TypeParamSer::BoundedNat { bound } => Term::BoundedNatType(bound), - TypeParamSer::String => Term::StringType, - TypeParamSer::Bytes => Term::BytesType, - TypeParamSer::Float => Term::FloatType, - TypeParamSer::List { param } => Term::ListType(param), - TypeParamSer::Tuple { params } => Term::TupleType(Box::new(params.into())), - }, - TermSer::TypeArg(arg) => match arg { - TypeArgSer::Type { ty } => Term::Runtime(ty), - TypeArgSer::BoundedNat { n } => Term::BoundedNat(n), - TypeArgSer::String { arg } => Term::String(arg), - TypeArgSer::Bytes { value } => Term::Bytes(value), - TypeArgSer::Float { value } => Term::Float(value), - TypeArgSer::List { elems } => Term::List(elems), - TypeArgSer::Tuple { elems } => Term::Tuple(elems), - TypeArgSer::Variable { v } => Term::Variable(v), - TypeArgSer::ListConcat { lists } => Term::ListConcat(lists), - TypeArgSer::TupleConcat { tuples } => Term::TupleConcat(tuples), - }, - } - } -} - -/// Helper type that serialises lists as JSON arrays for compatibility. -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[serde(untagged)] -pub(super) enum ArrayOrTermSer { - Array(Vec), - Term(Box), // TODO JSON Schema does not really support this yet -} - -impl From for Term { - fn from(value: ArrayOrTermSer) -> Self { - match value { - ArrayOrTermSer::Array(terms) => Term::new_list(terms), - ArrayOrTermSer::Term(term) => *term, - } - } -} - -impl From for ArrayOrTermSer { - fn from(term: Term) -> Self { - match term { - Term::List(terms) => ArrayOrTermSer::Array(terms), - term => ArrayOrTermSer::Term(Box::new(term)), - } - } -} - -/// Helper for to serialize and deserialize the byte string in [`TypeArg::Bytes`] via base64. -mod base64 { - use std::sync::Arc; - - use base64::Engine as _; - use base64::prelude::BASE64_STANDARD; - use serde::{Deserialize, Serialize}; - use serde::{Deserializer, Serializer}; - - pub fn serialize(v: &Arc<[u8]>, s: S) -> Result { - let base64 = BASE64_STANDARD.encode(v); - base64.serialize(s) - } - - pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { - let base64 = String::deserialize(d)?; - BASE64_STANDARD - .decode(base64.as_bytes()) - .map(|v| v.into()) - .map_err(serde::de::Error::custom) - } -} diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index a1cbe4bcea..0b6d19fa7d 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -4,15 +4,11 @@ //! //! 
[`TypeDef`]: crate::extension::TypeDef -use ordered_float::OrderedFloat; +use itertools::Itertools; #[cfg(test)] use proptest_derive::Arbitrary; -use smallvec::{SmallVec, smallvec}; -use std::iter::FusedIterator; use std::num::NonZeroU64; -use std::sync::Arc; use thiserror::Error; -use tracing::warn; use super::row_var::MaybeRV; use super::{ @@ -52,286 +48,242 @@ impl UpperBound { } } -/// A [`Term`] that is a static argument to an operation or constructor. -pub type TypeArg = Term; - -/// A [`Term`] that is the static type of an operation or constructor parameter. -pub type TypeParam = Term; - -/// A term in the language of static parameters in HUGR. +/// A *kind* of [`TypeArg`]. Thus, a parameter declared by a [`PolyFuncType`] or [`PolyFuncTypeRV`], +/// specifying a value that must be provided statically in order to instantiate it. +/// +/// [`PolyFuncType`]: super::PolyFuncType +/// [`PolyFuncTypeRV`]: super::PolyFuncTypeRV #[derive( Clone, Debug, PartialEq, Eq, Hash, derive_more::Display, serde::Deserialize, serde::Serialize, )] #[non_exhaustive] -#[serde( - from = "crate::types::serialize::TermSer", - into = "crate::types::serialize::TermSer" -)] -pub enum Term { - /// The type of runtime types. - #[display("Type{}", match _0 { - TypeBound::Linear => String::new(), - _ => format!("[{_0}]") +#[serde(tag = "tp")] +pub enum TypeParam { + /// Argument is a [`TypeArg::Type`]. + #[display("Type{}", match b { + TypeBound::Any => String::new(), + _ => format!("[{b}]") })] - RuntimeType(TypeBound), - /// The type of static data. - StaticType, - /// The type of static natural numbers up to a given bound. - #[display("{}", match _0.value() { + Type { + /// Bound for the type parameter. + b: TypeBound, + }, + /// Argument is a [`TypeArg::BoundedNat`] that is less than the upper bound. + #[display("{}", match bound.value() { Some(v) => format!("BoundedNat[{v}]"), None => "Nat".to_string() })] - BoundedNatType(UpperBound), - /// The type of static strings. See [`Term::String`]. - StringType, - /// The type of static byte strings. See [`Term::Bytes`]. - BytesType, - /// The type of static floating point numbers. See [`Term::Float`]. - FloatType, - /// The type of static lists of indeterminate size containing terms of the - /// specified static type. - #[display("ListType[{_0}]")] - ListType(Box), - /// The type of static tuples. - #[display("TupleType[{_0}]")] - TupleType(Box), - /// A runtime type as a term. Instance of [`Term::RuntimeType`]. - #[display("{_0}")] - Runtime(Type), - /// A 64bit unsigned integer literal. Instance of [`Term::BoundedNatType`]. - #[display("{_0}")] - BoundedNat(u64), - /// UTF-8 encoded string literal. Instance of [`Term::StringType`]. - #[display("\"{_0}\"")] - String(String), - /// Byte string literal. Instance of [`Term::BytesType`]. - #[display("bytes")] - Bytes(Arc<[u8]>), - /// A 64-bit floating point number. Instance of [`Term::FloatType`]. - #[display("{}", _0.into_inner())] - Float(OrderedFloat), - /// A list of static terms. Instance of [`Term::ListType`]. - #[display("[{}]", { - use itertools::Itertools as _; - _0.iter().map(|t|t.to_string()).join(",") - })] - List(Vec), - /// Instance of [`TypeParam::List`] defined by a sequence of concatenated lists of the same type. - #[display("[{}]", { - use itertools::Itertools as _; - _0.iter().map(|t| format!("... {t}")).join(",") - })] - ListConcat(Vec), - /// Instance of [`TypeParam::Tuple`] defined by a sequence of elements of varying type. 
- #[display("({})", { - use itertools::Itertools as _; - _0.iter().map(std::string::ToString::to_string).join(",") - })] - Tuple(Vec), - /// Instance of [`TypeParam::Tuple`] defined by a sequence of concatenated tuples. - #[display("({})", { - use itertools::Itertools as _; - _0.iter().map(|tuple| format!("... {tuple}")).join(",") - })] - TupleConcat(Vec), - /// Variable (used in type schemes or inside polymorphic functions), - /// but not a runtime type (not even a row variable i.e. list of runtime types) - /// - see [`Term::new_var_use`] - #[display("{_0}")] - Variable(TermVar), + BoundedNat { + /// Upper bound for the Nat parameter. + bound: UpperBound, + }, + /// Argument is a [`TypeArg::String`]. + String, + /// Argument is a [`TypeArg::Sequence`]. A list of indeterminate size containing + /// parameters all of the (same) specified element type. + #[display("List[{param}]")] + List { + /// The [`TypeParam`] describing each element of the list. + param: Box, + }, + /// Argument is a [`TypeArg::Sequence`]. A tuple of parameters. + #[display("Tuple[{}]", params.iter().map(std::string::ToString::to_string).join(", "))] + Tuple { + /// The [`TypeParam`]s contained in the tuple. + params: Vec, + }, } -impl Term { - /// Creates a [`Term::BoundedNatType`] with the maximum bound (`u64::MAX` + 1). +impl TypeParam { + /// [`TypeParam::BoundedNat`] with the maximum bound (`u64::MAX` + 1) #[must_use] - pub const fn max_nat_type() -> Self { - Self::BoundedNatType(UpperBound(None)) + pub const fn max_nat() -> Self { + Self::BoundedNat { + bound: UpperBound(None), + } } - /// Creates a [`Term::BoundedNatType`] with the stated upper bound (non-exclusive). + /// [`TypeParam::BoundedNat`] with the stated upper bound (non-exclusive) #[must_use] - pub const fn bounded_nat_type(upper_bound: NonZeroU64) -> Self { - Self::BoundedNatType(UpperBound(Some(upper_bound))) - } - - /// Creates a new [`Term::List`] given a sequence of its items. - pub fn new_list(items: impl IntoIterator) -> Self { - Self::List(items.into_iter().collect()) - } - - /// Creates a new [`Term::ListType`] given the type of its elements. - pub fn new_list_type(elem: impl Into) -> Self { - Self::ListType(Box::new(elem.into())) + pub const fn bounded_nat(upper_bound: NonZeroU64) -> Self { + Self::BoundedNat { + bound: UpperBound(Some(upper_bound)), + } } - /// Creates a new [`Term::TupleType`] given the type of its elements. - pub fn new_tuple_type(item_types: impl Into) -> Self { - Self::TupleType(Box::new(item_types.into())) + /// Make a new `TypeParam::List` (an arbitrary-length homogeneous list) + pub fn new_list(elem: impl Into) -> Self { + Self::List { + param: Box::new(elem.into()), + } } - /// Checks if this term is a supertype of another. - /// - /// The subtyping relation applies primarily to terms that represent static - /// types. For consistency the relation is extended to a partial order on - /// all terms; in particular it is reflexive so that every term (even if it - /// is not a static type) is considered a subtype of itself. 
- fn is_supertype(&self, other: &Term) -> bool { + fn contains(&self, other: &TypeParam) -> bool { match (self, other) { - (Term::RuntimeType(b1), Term::RuntimeType(b2)) => b1.contains(*b2), - (Term::BoundedNatType(b1), Term::BoundedNatType(b2)) => b1.contains(b2), - (Term::StringType, Term::StringType) => true, - (Term::StaticType, Term::StaticType) => true, - (Term::ListType(e1), Term::ListType(e2)) => e1.is_supertype(e2), - (Term::TupleType(es1), Term::TupleType(es2)) => es1.is_supertype(es2), - (Term::BytesType, Term::BytesType) => true, - (Term::FloatType, Term::FloatType) => true, - (Term::Runtime(t1), Term::Runtime(t2)) => t1 == t2, - (Term::BoundedNat(n1), Term::BoundedNat(n2)) => n1 == n2, - (Term::String(s1), Term::String(s2)) => s1 == s2, - (Term::Bytes(v1), Term::Bytes(v2)) => v1 == v2, - (Term::Float(f1), Term::Float(f2)) => f1 == f2, - (Term::Variable(v1), Term::Variable(v2)) => v1 == v2, - (Term::List(es1), Term::List(es2)) => { - es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) + (TypeParam::Type { b: b1 }, TypeParam::Type { b: b2 }) => b1.contains(*b2), + (TypeParam::BoundedNat { bound: b1 }, TypeParam::BoundedNat { bound: b2 }) => { + b1.contains(b2) } - (Term::Tuple(es1), Term::Tuple(es2)) => { - es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) + (TypeParam::String, TypeParam::String) => true, + (TypeParam::List { param: e1 }, TypeParam::List { param: e2 }) => e1.contains(e2), + (TypeParam::Tuple { params: es1 }, TypeParam::Tuple { params: es2 }) => { + es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.contains(e2)) } _ => false, } } } -impl From for Term { +impl From for TypeParam { fn from(bound: TypeBound) -> Self { - Self::RuntimeType(bound) + Self::Type { b: bound } } } -impl From for Term { +impl From for TypeParam { fn from(bound: UpperBound) -> Self { - Self::BoundedNatType(bound) + Self::BoundedNat { bound } } } -impl From> for Term { +/// A statically-known argument value to an operation. +#[derive( + Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize, derive_more::Display, +)] +#[non_exhaustive] +#[serde(tag = "tya")] +pub enum TypeArg { + /// Where the (Type/Op)Def declares that an argument is a [`TypeParam::Type`] + #[display("{ty}")] + Type { + /// The concrete type for the parameter. + ty: Type, + }, + /// Instance of [`TypeParam::BoundedNat`]. 64-bit unsigned integer. + #[display("{n}")] + BoundedNat { + /// The integer value for the parameter. + n: u64, + }, + ///Instance of [`TypeParam::String`]. UTF-8 encoded string argument. + #[display("\"{arg}\"")] + String { + /// The string value for the parameter. + arg: String, + }, + /// Instance of [`TypeParam::List`] or [`TypeParam::Tuple`], defined by a + /// sequence of elements. + #[display("({})", { + use itertools::Itertools as _; + elems.iter().map(std::string::ToString::to_string).join(",") + })] + Sequence { + /// List of element types + elems: Vec, + }, + /// Variable (used in type schemes or inside polymorphic functions), + /// but not a [`TypeArg::Type`] (not even a row variable i.e. 
[`TypeParam::List`] of type) + /// - see [`TypeArg::new_var_use`] + #[display("{v}")] + Variable { + #[allow(missing_docs)] + #[serde(flatten)] + v: TypeArgVariable, + }, +} + +impl From> for TypeArg { fn from(value: TypeBase) -> Self { match value.try_into_type() { - Ok(ty) => Term::Runtime(ty), - Err(RowVariable(idx, bound)) => Term::new_var_use(idx, TypeParam::new_list_type(bound)), + Ok(ty) => TypeArg::Type { ty }, + Err(RowVariable(idx, bound)) => TypeArg::new_var_use(idx, TypeParam::new_list(bound)), } } } -impl From for Term { +impl From for TypeArg { fn from(n: u64) -> Self { - Self::BoundedNat(n) + Self::BoundedNat { n } } } -impl From for Term { +impl From for TypeArg { fn from(arg: String) -> Self { - Term::String(arg) + TypeArg::String { arg } } } -impl From<&str> for Term { +impl From<&str> for TypeArg { fn from(arg: &str) -> Self { - Term::String(arg.to_string()) - } -} - -impl From> for Term { - fn from(elems: Vec) -> Self { - Self::new_list(elems) + TypeArg::String { + arg: arg.to_string(), + } } } -impl From<[Term; N]> for Term { - fn from(value: [Term; N]) -> Self { - Self::new_list(value) +impl From> for TypeArg { + fn from(elems: Vec) -> Self { + Self::Sequence { elems } } } -/// Variable in a [`Term`], that is not a single runtime type (i.e. not a [`Type::new_var_use`] +/// Variable in a `TypeArg`, that is not a single [`TypeArg::Type`] (i.e. not a [`Type::new_var_use`] /// - it might be a [`Type::new_row_var_use`]). #[derive( Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize, derive_more::Display, )] #[display("#{idx}")] -pub struct TermVar { +pub struct TypeArgVariable { idx: usize, - pub(in crate::types) cached_decl: Box, + cached_decl: TypeParam, } -impl Term { - /// [`Type::UNIT`] as a [`Term::Runtime`] - pub const UNIT: Self = Self::Runtime(Type::UNIT); +impl TypeArg { + /// [`Type::UNIT`] as a [`TypeArg::Type`] + pub const UNIT: Self = Self::Type { ty: Type::UNIT }; /// Makes a `TypeArg` representing a use (occurrence) of the type variable /// with the specified index. /// `decl` must be exactly that with which the variable was declared. #[must_use] - pub fn new_var_use(idx: usize, decl: Term) -> Self { + pub fn new_var_use(idx: usize, decl: TypeParam) -> Self { match decl { // Note a TypeParam::List of TypeParam::Type *cannot* be represented // as a TypeArg::Type because the latter stores a Type i.e. only a single type, // not a RowVariable. - Term::RuntimeType(b) => Type::new_var_use(idx, b).into(), - _ => Term::Variable(TermVar { - idx, - cached_decl: Box::new(decl), - }), + TypeParam::Type { b } => Type::new_var_use(idx, b).into(), + _ => TypeArg::Variable { + v: TypeArgVariable { + idx, + cached_decl: decl, + }, + }, } } - /// Creates a new string literal. - #[inline] - pub fn new_string(str: impl ToString) -> Self { - Self::String(str.to_string()) - } - - /// Creates a new concatenated list. - #[inline] - pub fn new_list_concat(lists: impl IntoIterator) -> Self { - Self::ListConcat(lists.into_iter().collect()) - } - - /// Creates a new tuple from its items. - #[inline] - pub fn new_tuple(items: impl IntoIterator) -> Self { - Self::Tuple(items.into_iter().collect()) - } - - /// Creates a new concatenated tuple. - #[inline] - pub fn new_tuple_concat(tuples: impl IntoIterator) -> Self { - Self::TupleConcat(tuples.into_iter().collect()) - } - - /// Returns an integer if the [`Term`] is a natural number literal. + /// Returns an integer if the `TypeArg` is an instance of `BoundedNat`. 
#[must_use] pub fn as_nat(&self) -> Option { match self { - TypeArg::BoundedNat(n) => Some(*n), + TypeArg::BoundedNat { n } => Some(*n), _ => None, } } - /// Returns a [`Type`] if the [`Term`] is a runtime type. + /// Returns a type if the `TypeArg` is an instance of Type. #[must_use] - pub fn as_runtime(&self) -> Option> { + pub fn as_type(&self) -> Option> { match self { - TypeArg::Runtime(ty) => Some(ty.clone()), + TypeArg::Type { ty } => Some(ty.clone()), _ => None, } } - /// Returns a string if the [`Term`] is a string literal. + /// Returns a string if the `TypeArg` is an instance of String. #[must_use] pub fn as_string(&self) -> Option { match self { - TypeArg::String(arg) => Some(arg.clone()), + TypeArg::String { arg } => Some(arg.clone()), _ => None, } } @@ -340,264 +292,75 @@ impl Term { /// is valid and closed. pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { match self { - Term::Runtime(ty) => ty.validate(var_decls), - Term::List(elems) => { - // TODO: Full validation would check that the type of the elements agrees - elems.iter().try_for_each(|a| a.validate(var_decls)) - } - Term::Tuple(elems) => elems.iter().try_for_each(|a| a.validate(var_decls)), - Term::BoundedNat(_) | Term::String { .. } | Term::Float(_) | Term::Bytes(_) => Ok(()), - TypeArg::ListConcat(lists) => { - // TODO: Full validation would check that each of the lists is indeed a - // list or list variable of the correct types. - lists.iter().try_for_each(|a| a.validate(var_decls)) - } - TypeArg::TupleConcat(tuples) => tuples.iter().try_for_each(|a| a.validate(var_decls)), - Term::Variable(TermVar { idx, cached_decl }) => { + TypeArg::Type { ty } => ty.validate(var_decls), + TypeArg::BoundedNat { .. } | TypeArg::String { .. } => Ok(()), + TypeArg::Sequence { elems } => elems.iter().try_for_each(|a| a.validate(var_decls)), + TypeArg::Variable { + v: TypeArgVariable { idx, cached_decl }, + } => { assert!( - !matches!(&**cached_decl, TypeParam::RuntimeType { .. }), + !matches!(cached_decl, TypeParam::Type { .. }), "Malformed TypeArg::Variable {cached_decl} - should be inconstructible" ); check_typevar_decl(var_decls, *idx, cached_decl) } - Term::RuntimeType { .. } => Ok(()), - Term::BoundedNatType { .. } => Ok(()), - Term::StringType => Ok(()), - Term::BytesType => Ok(()), - Term::FloatType => Ok(()), - Term::ListType(item_type) => item_type.validate(var_decls), - Term::TupleType(item_types) => item_types.validate(var_decls), - Term::StaticType => Ok(()), } } pub(crate) fn substitute(&self, t: &Substitution) -> Self { match self { - Term::Runtime(ty) => { - // RowVariables are represented as Term::Variable + TypeArg::Type { ty } => { + // RowVariables are represented as TypeArg::Variable ty.substitute1(t).into() } - TypeArg::BoundedNat(_) | TypeArg::String(_) | TypeArg::Bytes(_) | TypeArg::Float(_) => { - self.clone() - } // We do not allow variables as bounds on BoundedNat's - TypeArg::List(elems) => { - // NOTE: This implements a hack allowing substitutions to - // replace `TypeArg::Variable`s representing "row variables" - // with a list that is to be spliced into the containing list. - // We won't need this code anymore once we stop conflating types - // with lists of types. - - fn is_type(type_arg: &TypeArg) -> bool { - match type_arg { - TypeArg::Runtime(_) => true, - TypeArg::Variable(v) => v.bound_if_row_var().is_some(), - _ => false, + TypeArg::BoundedNat { .. } | TypeArg::String { .. 
} => self.clone(), // We do not allow variables as bounds on BoundedNat's + TypeArg::Sequence { elems } => { + let mut are_types = elems.iter().map(|ta| match ta { + TypeArg::Type { .. } => true, + TypeArg::Variable { v } => v.bound_if_row_var().is_some(), + _ => false, + }); + let elems = match are_types.next() { + Some(true) => { + assert!(are_types.all(|b| b)); // If one is a Type, so must the rest be + // So, anything that doesn't produce a Type, was a row variable => multiple Types + elems + .iter() + .flat_map(|ta| match ta.substitute(t) { + ty @ TypeArg::Type { .. } => vec![ty], + TypeArg::Sequence { elems } => elems, + _ => panic!("Expected Type or row of Types"), + }) + .collect() } - } - - let are_types = elems.first().map(is_type).unwrap_or(false); - - Self::new_list_from_parts(elems.iter().map(|elem| match elem.substitute(t) { - list @ TypeArg::List { .. } if are_types => SeqPart::Splice(list), - list @ TypeArg::ListConcat { .. } if are_types => SeqPart::Splice(list), - elem => SeqPart::Item(elem), - })) - } - TypeArg::ListConcat(lists) => { - // When a substitution instantiates spliced list variables, we - // may be able to merge the concatenated lists. - Self::new_list_from_parts( - lists.iter().map(|list| SeqPart::Splice(list.substitute(t))), - ) - } - Term::Tuple(elems) => { - Term::Tuple(elems.iter().map(|elem| elem.substitute(t)).collect()) - } - TypeArg::TupleConcat(tuples) => { - // When a substitution instantiates spliced tuple variables, - // we may be able to merge the concatenated tuples. - Self::new_tuple_from_parts( - tuples - .iter() - .map(|tuple| SeqPart::Splice(tuple.substitute(t))), - ) - } - TypeArg::Variable(TermVar { idx, cached_decl }) => t.apply_var(*idx, cached_decl), - Term::RuntimeType(_) => self.clone(), - Term::BoundedNatType(_) => self.clone(), - Term::StringType => self.clone(), - Term::BytesType => self.clone(), - Term::FloatType => self.clone(), - Term::ListType(item_type) => Term::new_list_type(item_type.substitute(t)), - Term::TupleType(item_types) => Term::new_list_type(item_types.substitute(t)), - Term::StaticType => self.clone(), - } - } - - /// Helper method for [`TypeArg::new_list_from_parts`] and [`TypeArg::new_tuple_from_parts`]. - fn new_seq_from_parts( - parts: impl IntoIterator>, - make_items: impl Fn(Vec) -> Self, - make_concat: impl Fn(Vec) -> Self, - ) -> Self { - let mut items = Vec::new(); - let mut seqs = Vec::new(); - - for part in parts { - match part { - SeqPart::Item(item) => items.push(item), - SeqPart::Splice(seq) => { - if !items.is_empty() { - seqs.push(make_items(std::mem::take(&mut items))); + _ => { + // not types, no need to flatten (and mustn't, in case of nested Sequences) + elems.iter().map(|ta| ta.substitute(t)).collect() } - seqs.push(seq); - } + }; + TypeArg::Sequence { elems } } + TypeArg::Variable { + v: TypeArgVariable { idx, cached_decl }, + } => t.apply_var(*idx, cached_decl), } - - if seqs.is_empty() { - make_items(items) - } else if items.is_empty() { - make_concat(seqs) - } else { - seqs.push(make_items(items)); - make_concat(seqs) - } - } - - /// Creates a new list from a sequence of [`SeqPart`]s. - pub fn new_list_from_parts(parts: impl IntoIterator>) -> Self { - Self::new_seq_from_parts( - parts.into_iter().flat_map(ListPartIter::new), - TypeArg::List, - TypeArg::ListConcat, - ) - } - - /// Iterates over the [`SeqPart`]s of a list. 
- /// - /// # Examples - /// - /// The parts of a closed list are the items of that list wrapped in [`SeqPart::Item`]: - /// - /// ``` - /// # use hugr_core::types::type_param::{Term, SeqPart}; - /// # let a = Term::new_string("a"); - /// # let b = Term::new_string("b"); - /// let term = Term::new_list([a.clone(), b.clone()]); - /// - /// assert_eq!( - /// term.into_list_parts().collect::>(), - /// vec![SeqPart::Item(a), SeqPart::Item(b)] - /// ); - /// ``` - /// - /// Parts of a concatenated list that are not closed lists are wrapped in [`SeqPart::Splice`]: - /// - /// ``` - /// # use hugr_core::types::type_param::{Term, SeqPart}; - /// # let a = Term::new_string("a"); - /// # let b = Term::new_string("b"); - /// # let c = Term::new_string("c"); - /// let var = Term::new_var_use(0, Term::new_list_type(Term::StringType)); - /// let term = Term::new_list_concat([ - /// Term::new_list([a.clone(), b.clone()]), - /// var.clone(), - /// Term::new_list([c.clone()]) - /// ]); - /// - /// assert_eq!( - /// term.into_list_parts().collect::>(), - /// vec![SeqPart::Item(a), SeqPart::Item(b), SeqPart::Splice(var), SeqPart::Item(c)] - /// ); - /// ``` - /// - /// Nested concatenations are traversed recursively: - /// - /// ``` - /// # use hugr_core::types::type_param::{Term, SeqPart}; - /// # let a = Term::new_string("a"); - /// # let b = Term::new_string("b"); - /// # let c = Term::new_string("c"); - /// let term = Term::new_list_concat([ - /// Term::new_list_concat([ - /// Term::new_list([a.clone()]), - /// Term::new_list([b.clone()]) - /// ]), - /// Term::new_list([]), - /// Term::new_list([c.clone()]) - /// ]); - /// - /// assert_eq!( - /// term.into_list_parts().collect::>(), - /// vec![SeqPart::Item(a), SeqPart::Item(b), SeqPart::Item(c)] - /// ); - /// ``` - /// - /// When invoked on a type argument that is not a list, a single - /// [`SeqPart::Splice`] is returned that wraps the type argument. - /// This is the expected behaviour for type variables that stand for lists. - /// This behaviour also allows this method not to fail on ill-typed type arguments. - /// ``` - /// # use hugr_core::types::type_param::{Term, SeqPart}; - /// let term = Term::new_string("not a list"); - /// assert_eq!( - /// term.clone().into_list_parts().collect::>(), - /// vec![SeqPart::Splice(term)] - /// ); - /// ``` - #[inline] - pub fn into_list_parts(self) -> ListPartIter { - ListPartIter::new(SeqPart::Splice(self)) - } - - /// Creates a new tuple from a sequence of [`SeqPart`]s. - /// - /// Analogous to [`TypeArg::new_list_from_parts`]. - pub fn new_tuple_from_parts(parts: impl IntoIterator>) -> Self { - Self::new_seq_from_parts( - parts.into_iter().flat_map(TuplePartIter::new), - TypeArg::Tuple, - TypeArg::TupleConcat, - ) - } - - /// Iterates over the [`SeqPart`]s of a tuple. - /// - /// Analogous to [`TypeArg::into_list_parts`]. - #[inline] - pub fn into_tuple_parts(self) -> TuplePartIter { - TuplePartIter::new(SeqPart::Splice(self)) } } -impl Transformable for Term { +impl Transformable for TypeArg { fn transform(&mut self, tr: &T) -> Result { match self { - Term::Runtime(ty) => ty.transform(tr), - Term::List(elems) => elems.transform(tr), - Term::Tuple(elems) => elems.transform(tr), - Term::BoundedNat(_) - | Term::String(_) - | Term::Variable(_) - | Term::Float(_) - | Term::Bytes(_) => Ok(false), - Term::RuntimeType { .. } => Ok(false), - Term::BoundedNatType { .. 
} => Ok(false), - Term::StringType => Ok(false), - Term::BytesType => Ok(false), - Term::FloatType => Ok(false), - Term::ListType(item_type) => item_type.transform(tr), - Term::TupleType(item_types) => item_types.transform(tr), - Term::StaticType => Ok(false), - TypeArg::ListConcat(lists) => lists.transform(tr), - TypeArg::TupleConcat(tuples) => tuples.transform(tr), + TypeArg::Type { ty } => ty.transform(tr), + TypeArg::Sequence { elems } => elems.transform(tr), + TypeArg::BoundedNat { .. } | TypeArg::String { .. } | TypeArg::Variable { .. } => { + Ok(false) + } } } } -impl TermVar { +impl TypeArgVariable { /// Return the index. #[must_use] pub fn index(&self) -> usize { @@ -608,8 +371,8 @@ impl TermVar { /// the [`TypeBound`] of the individual types it might stand for. #[must_use] pub fn bound_if_row_var(&self) -> Option { - if let Term::ListType(item_type) = &*self.cached_decl { - if let Term::RuntimeType(b) = **item_type { + if let TypeParam::List { param } = &self.cached_decl { + if let TypeParam::Type { b } = **param { return Some(b); } } @@ -617,103 +380,80 @@ impl TermVar { } } -/// Checks that a [`Term`] is valid for a given type. -pub fn check_term_type(term: &Term, type_: &Term) -> Result<(), TermTypeError> { - match (term, type_) { - (Term::Variable(TermVar { cached_decl, .. }), _) if type_.is_supertype(cached_decl) => { +/// Checks a [`TypeArg`] is as expected for a [`TypeParam`] +pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgError> { + match (arg, param) { + ( + TypeArg::Variable { + v: TypeArgVariable { cached_decl, .. }, + }, + _, + ) if param.contains(cached_decl) => Ok(()), + (TypeArg::Type { ty }, TypeParam::Type { b: bound }) + if bound.contains(ty.least_upper_bound()) => + { Ok(()) } - (Term::Runtime(ty), Term::RuntimeType(bound)) if bound.contains(ty.least_upper_bound()) => { - Ok(()) - } - (Term::List(elems), Term::ListType(item_type)) => { - elems.iter().try_for_each(|term| { + (TypeArg::Sequence { elems }, TypeParam::List { param }) => { + elems.iter().try_for_each(|arg| { // Also allow elements that are RowVars if fitting into a List of Types - if let (Term::Variable(v), Term::RuntimeType(param_bound)) = (term, &**item_type) { + if let (TypeArg::Variable { v }, TypeParam::Type { b: param_bound }) = + (arg, &**param) + { if v.bound_if_row_var() .is_some_and(|arg_bound| param_bound.contains(arg_bound)) { return Ok(()); } } - check_term_type(term, item_type) + check_type_arg(arg, param) }) } - (Term::ListConcat(lists), Term::ListType(item_type)) => lists - .iter() - .try_for_each(|list| check_term_type(list, item_type)), - (TypeArg::Tuple(_) | TypeArg::TupleConcat(_), TypeParam::TupleType(item_types)) => { - let term_parts: Vec<_> = term.clone().into_tuple_parts().collect(); - let type_parts: Vec<_> = item_types.clone().into_list_parts().collect(); - - for (term, type_) in term_parts.iter().zip(&type_parts) { - match (term, type_) { - (SeqPart::Item(term), SeqPart::Item(type_)) => { - check_term_type(term, type_)?; - } - (_, SeqPart::Splice(_)) | (SeqPart::Splice(_), _) => { - // TODO: Checking tuples with splicing requires more - // sophisticated validation infrastructure to do well. - warn!( - "Validation for open tuples is not implemented yet, succeeding regardless..." 
- ); - return Ok(()); - } - } + (TypeArg::Sequence { elems: items }, TypeParam::Tuple { params: types }) => { + if items.len() == types.len() { + items + .iter() + .zip(types.iter()) + .try_for_each(|(arg, param)| check_type_arg(arg, param)) + } else { + Err(TypeArgError::WrongNumberTuple(items.len(), types.len())) } - - if term_parts.len() != type_parts.len() { - return Err(TermTypeError::WrongNumberTuple( - term_parts.len(), - type_parts.len(), - )); - } - + } + (TypeArg::BoundedNat { n: val }, TypeParam::BoundedNat { bound }) + if bound.valid_value(*val) => + { Ok(()) } - (Term::BoundedNat(val), Term::BoundedNatType(bound)) if bound.valid_value(*val) => Ok(()), - (Term::String { .. }, Term::StringType) => Ok(()), - (Term::Bytes(_), Term::BytesType) => Ok(()), - (Term::Float(_), Term::FloatType) => Ok(()), - - // Static types - (Term::StaticType, Term::StaticType) => Ok(()), - (Term::StringType, Term::StaticType) => Ok(()), - (Term::BytesType, Term::StaticType) => Ok(()), - (Term::BoundedNatType { .. }, Term::StaticType) => Ok(()), - (Term::FloatType, Term::StaticType) => Ok(()), - (Term::ListType { .. }, Term::StaticType) => Ok(()), - (Term::TupleType(_), Term::StaticType) => Ok(()), - (Term::RuntimeType(_), Term::StaticType) => Ok(()), - _ => Err(TermTypeError::TypeMismatch { - term: Box::new(term.clone()), - type_: Box::new(type_.clone()), + (TypeArg::String { .. }, TypeParam::String) => Ok(()), + _ => Err(TypeArgError::TypeMismatch { + arg: arg.clone(), + param: param.clone(), }), } } -/// Check a list of [`Term`]s is valid for a list of types. -pub fn check_term_types(terms: &[Term], types: &[Term]) -> Result<(), TermTypeError> { - if terms.len() != types.len() { - return Err(TermTypeError::WrongNumberArgs(terms.len(), types.len())); +/// Check a list of type arguments match a list of required type parameters +pub fn check_type_args(args: &[TypeArg], params: &[TypeParam]) -> Result<(), TypeArgError> { + if args.len() != params.len() { + return Err(TypeArgError::WrongNumberArgs(args.len(), params.len())); } - for (term, type_) in terms.iter().zip(types.iter()) { - check_term_type(term, type_)?; + for (a, p) in args.iter().zip(params.iter()) { + check_type_arg(a, p)?; } Ok(()) } -/// Errors that can occur when checking that a [`Term`] has an expected type. +/// Errors that can occur fitting a [`TypeArg`] into a [`TypeParam`] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[non_exhaustive] -pub enum TermTypeError { +pub enum TypeArgError { #[allow(missing_docs)] - /// For now, general case of a term not fitting a type. + /// For now, general case of a type arg not fitting a param. /// We'll have more cases when we allow general Containers. // TODO It may become possible to combine this with ConstTypeError. - #[error("Term {term} does not fit declared type {type_}")] - TypeMismatch { term: Box, type_: Box }, + #[error("Type argument {arg} does not fit declared parameter {param}")] + TypeMismatch { param: TypeParam, arg: TypeArg }, /// Wrong number of type arguments (actual vs expected). // For now this only happens at the top level (TypeArgs of op/type vs TypeParams of Op/TypeDef). // However in the future it may be applicable to e.g. contents of Tuples too. @@ -730,173 +470,35 @@ pub enum TermTypeError { OpaqueTypeMismatch(#[from] crate::types::CustomCheckFailure), /// Invalid value #[error("Invalid value of type argument")] - InvalidValue(Box), -} - -/// Part of a sequence. 
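// ---------------------------------------------------------------------------
// [Editor's sketch — illustration only, not part of this patch.]
// How the restored `check_type_arg` / `check_type_args` are meant to be used,
// mirroring the `+` hunk above and the tests later in this diff. Import paths
// are assumptions; `usize_t` is the prelude helper already used by those tests.
#[cfg(test)]
mod editor_sketch_checking {
    use hugr_core::extension::prelude::usize_t;
    use hugr_core::types::TypeBound;
    use hugr_core::types::type_param::{
        TypeArg, TypeArgError, TypeParam, check_type_arg, check_type_args,
    };

    #[test]
    fn args_against_params() {
        // A nat literal fits an (unbounded) nat parameter...
        check_type_arg(&TypeArg::BoundedNat { n: 5 }, &TypeParam::max_nat()).unwrap();
        // ...but not a runtime-type parameter; the mismatch is reported as
        // `TypeArgError::TypeMismatch { param, arg }`.
        let err = check_type_arg(&TypeArg::BoundedNat { n: 5 }, &TypeBound::Copyable.into())
            .unwrap_err();
        assert!(matches!(err, TypeArgError::TypeMismatch { .. }));

        // Whole argument lists are checked positionally against the parameter list.
        check_type_args(
            &[TypeArg::BoundedNat { n: 5 }, TypeArg::Type { ty: usize_t() }],
            &[TypeParam::max_nat(), TypeBound::Copyable.into()],
        )
        .unwrap();
    }
}
// ---------------------------------------------------------------------------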
-#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum SeqPart { - /// An individual item in the sequence. - Item(T), - /// A subsequence that is spliced into the parent sequence. - Splice(T), -} - -/// Iterator created by [`TypeArg::into_list_parts`]. -#[derive(Debug, Clone)] -pub struct ListPartIter { - parts: SmallVec<[SeqPart; 1]>, + InvalidValue(TypeArg), } -impl ListPartIter { - #[inline] - fn new(part: SeqPart) -> Self { - Self { - parts: smallvec![part], - } - } -} - -impl Iterator for ListPartIter { - type Item = SeqPart; - - fn next(&mut self) -> Option { - loop { - match self.parts.pop()? { - SeqPart::Splice(TypeArg::List(elems)) => self - .parts - .extend(elems.into_iter().rev().map(SeqPart::Item)), - SeqPart::Splice(TypeArg::ListConcat(lists)) => self - .parts - .extend(lists.into_iter().rev().map(SeqPart::Splice)), - part => return Some(part), - } - } - } -} - -impl FusedIterator for ListPartIter {} - -/// Iterator created by [`TypeArg::into_tuple_parts`]. -#[derive(Debug, Clone)] -pub struct TuplePartIter { - parts: SmallVec<[SeqPart; 1]>, -} - -impl TuplePartIter { - #[inline] - fn new(part: SeqPart) -> Self { - Self { - parts: smallvec![part], - } - } -} - -impl Iterator for TuplePartIter { - type Item = SeqPart; - - fn next(&mut self) -> Option { - loop { - match self.parts.pop()? { - SeqPart::Splice(TypeArg::Tuple(elems)) => self - .parts - .extend(elems.into_iter().rev().map(SeqPart::Item)), - SeqPart::Splice(TypeArg::TupleConcat(tuples)) => self - .parts - .extend(tuples.into_iter().rev().map(SeqPart::Splice)), - part => return Some(part), - } - } - } -} - -impl FusedIterator for TuplePartIter {} - #[cfg(test)] mod test { use itertools::Itertools; - use super::{Substitution, TypeArg, TypeParam, check_term_type}; + use super::{Substitution, TypeArg, TypeParam, check_type_arg}; use crate::extension::prelude::{bool_t, usize_t}; - use crate::types::Term; - use crate::types::type_param::SeqPart; - use crate::types::{TypeBound, TypeRV, type_param::TermTypeError}; - - #[test] - fn new_list_from_parts_items() { - let a = TypeArg::new_string("a"); - let b = TypeArg::new_string("b"); - - let parts = [SeqPart::Item(a.clone()), SeqPart::Item(b.clone())]; - let items = [a, b]; - - assert_eq!( - TypeArg::new_list_from_parts(parts.clone()), - TypeArg::new_list(items.clone()) - ); - - assert_eq!( - TypeArg::new_tuple_from_parts(parts), - TypeArg::new_tuple(items) - ); - } - - #[test] - fn new_list_from_parts_flatten() { - let a = Term::new_string("a"); - let b = Term::new_string("b"); - let c = Term::new_string("c"); - let d = Term::new_string("d"); - let var = Term::new_var_use(0, Term::new_list_type(Term::StringType)); - let parts = [ - SeqPart::Splice(Term::new_list([a.clone(), b.clone()])), - SeqPart::Splice(Term::new_list_concat([Term::new_list([c.clone()])])), - SeqPart::Item(d.clone()), - SeqPart::Splice(var.clone()), - ]; - assert_eq!( - Term::new_list_from_parts(parts), - Term::new_list_concat([Term::new_list([a, b, c, d]), var]) - ); - } - - #[test] - fn new_tuple_from_parts_flatten() { - let a = Term::new_string("a"); - let b = Term::new_string("b"); - let c = Term::new_string("c"); - let d = Term::new_string("d"); - let var = Term::new_var_use(0, Term::new_tuple([Term::StringType])); - let parts = [ - SeqPart::Splice(Term::new_tuple([a.clone(), b.clone()])), - SeqPart::Splice(Term::new_tuple_concat([Term::new_tuple([c.clone()])])), - SeqPart::Item(d.clone()), - SeqPart::Splice(var.clone()), - ]; - assert_eq!( - Term::new_tuple_from_parts(parts), - 
Term::new_tuple_concat([Term::new_tuple([a, b, c, d]), var]) - ); - } + use crate::types::{TypeBound, TypeRV, type_param::TypeArgError}; #[test] fn type_arg_fits_param() { let rowvar = TypeRV::new_row_var_use; - fn check(arg: impl Into, param: &TypeParam) -> Result<(), TermTypeError> { - check_term_type(&arg.into(), param) + fn check(arg: impl Into, param: &TypeParam) -> Result<(), TypeArgError> { + check_type_arg(&arg.into(), param) } fn check_seq>( args: &[T], param: &TypeParam, - ) -> Result<(), TermTypeError> { + ) -> Result<(), TypeArgError> { let arg = args.iter().cloned().map_into().collect_vec().into(); - check_term_type(&arg, param) + check_type_arg(&arg, param) } - // Simple cases: a Term::Type is a Term::RuntimeType but singleton sequences are lists + // Simple cases: a TypeArg::Type is a TypeParam::Type but singleton sequences are lists check(usize_t(), &TypeBound::Copyable.into()).unwrap(); - let seq_param = TypeParam::new_list_type(TypeBound::Copyable); + let seq_param = TypeParam::new_list(TypeBound::Copyable); check(usize_t(), &seq_param).unwrap_err(); - check_seq(&[usize_t()], &TypeBound::Linear.into()).unwrap_err(); + check_seq(&[usize_t()], &TypeBound::Any.into()).unwrap_err(); // Into a list of type, we can fit a single row var check(rowvar(0, TypeBound::Copyable), &seq_param).unwrap(); @@ -905,17 +507,17 @@ mod test { check_seq(&[rowvar(0, TypeBound::Copyable)], &seq_param).unwrap(); check_seq( &[ - rowvar(1, TypeBound::Linear), + rowvar(1, TypeBound::Any), usize_t().into(), rowvar(0, TypeBound::Copyable), ], - &TypeParam::new_list_type(TypeBound::Linear), + &TypeParam::new_list(TypeBound::Any), ) .unwrap(); // Next one fails because a list of Eq is required check_seq( &[ - rowvar(1, TypeBound::Linear), + rowvar(1, TypeBound::Any), usize_t().into(), rowvar(0, TypeBound::Copyable), ], @@ -930,9 +532,9 @@ mod test { .unwrap_err(); // Similar for nats (but no equivalent of fancy row vars) - check(5, &TypeParam::max_nat_type()).unwrap(); - check_seq(&[5], &TypeParam::max_nat_type()).unwrap_err(); - let list_of_nat = TypeParam::new_list_type(TypeParam::max_nat_type()); + check(5, &TypeParam::max_nat()).unwrap(); + check_seq(&[5], &TypeParam::max_nat()).unwrap_err(); + let list_of_nat = TypeParam::new_list(TypeParam::max_nat()); check(5, &list_of_nat).unwrap_err(); check_seq(&[5], &list_of_nat).unwrap(); check(TypeArg::new_var_use(0, list_of_nat.clone()), &list_of_nat).unwrap(); @@ -943,23 +545,15 @@ mod test { ) .unwrap_err(); - // `Term::TupleType` requires a `Term::Tuple` of the same number of elems - let usize_and_ty = - TypeParam::new_tuple_type([TypeParam::max_nat_type(), TypeBound::Copyable.into()]); - check( - TypeArg::Tuple(vec![5.into(), usize_t().into()]), - &usize_and_ty, - ) - .unwrap(); - check( - TypeArg::Tuple(vec![usize_t().into(), 5.into()]), - &usize_and_ty, - ) - .unwrap_err(); // Wrong way around - let two_types = TypeParam::new_tuple_type(Term::new_list([ - TypeBound::Linear.into(), - TypeBound::Linear.into(), - ])); + // TypeParam::Tuples require a TypeArg::Seq of the same number of elems + let usize_and_ty = TypeParam::Tuple { + params: vec![TypeParam::max_nat(), TypeBound::Copyable.into()], + }; + check(vec![5.into(), usize_t().into()], &usize_and_ty).unwrap(); + check(vec![usize_t().into(), 5.into()], &usize_and_ty).unwrap_err(); // Wrong way around + let two_types = TypeParam::Tuple { + params: vec![TypeBound::Any.into(), TypeBound::Any.into()], + }; check(TypeArg::new_var_use(0, two_types.clone()), &two_types).unwrap(); // not a Row Var which 
could have any number of elems check(TypeArg::new_var_use(0, seq_param), &two_types).unwrap_err(); @@ -967,143 +561,116 @@ mod test { #[test] fn type_arg_subst_row() { - let row_param = Term::new_list_type(TypeBound::Copyable); - let row_arg: Term = vec![bool_t().into(), Term::UNIT].into(); - check_term_type(&row_arg, &row_param).unwrap(); + let row_param = TypeParam::new_list(TypeBound::Copyable); + let row_arg: TypeArg = vec![bool_t().into(), TypeArg::UNIT].into(); + check_type_arg(&row_arg, &row_param).unwrap(); // Now say a row variable referring to *that* row was used // to instantiate an outer "row parameter" (list of type). - let outer_param = Term::new_list_type(TypeBound::Linear); - let outer_arg = Term::new_list([ - TypeRV::new_row_var_use(0, TypeBound::Copyable).into(), - usize_t().into(), - ]); - check_term_type(&outer_arg, &outer_param).unwrap(); + let outer_param = TypeParam::new_list(TypeBound::Any); + let outer_arg = TypeArg::Sequence { + elems: vec![ + TypeRV::new_row_var_use(0, TypeBound::Copyable).into(), + usize_t().into(), + ], + }; + check_type_arg(&outer_arg, &outer_param).unwrap(); let outer_arg2 = outer_arg.substitute(&Substitution(&[row_arg])); assert_eq!( outer_arg2, - vec![bool_t().into(), Term::UNIT, usize_t().into()].into() + vec![bool_t().into(), TypeArg::UNIT, usize_t().into()].into() ); // Of course this is still valid (as substitution is guaranteed to preserve validity) - check_term_type(&outer_arg2, &outer_param).unwrap(); + check_type_arg(&outer_arg2, &outer_param).unwrap(); } #[test] fn subst_list_list() { - let outer_param = Term::new_list_type(Term::new_list_type(TypeBound::Linear)); - let row_var_decl = Term::new_list_type(TypeBound::Copyable); - let row_var_use = Term::new_var_use(0, row_var_decl.clone()); - let good_arg = Term::new_list([ - // The row variables here refer to `row_var_decl` above - vec![usize_t().into()].into(), - row_var_use.clone(), - vec![row_var_use, usize_t().into()].into(), - ]); - check_term_type(&good_arg, &outer_param).unwrap(); + let outer_param = TypeParam::new_list(TypeParam::new_list(TypeBound::Any)); + let row_var_decl = TypeParam::new_list(TypeBound::Copyable); + let row_var_use = TypeArg::new_var_use(0, row_var_decl.clone()); + let good_arg = TypeArg::Sequence { + elems: vec![ + // The row variables here refer to `row_var_decl` above + vec![usize_t().into()].into(), + row_var_use.clone(), + vec![row_var_use, usize_t().into()].into(), + ], + }; + check_type_arg(&good_arg, &outer_param).unwrap(); // Outer list cannot include single types: - let Term::List(mut elems) = good_arg.clone() else { + let TypeArg::Sequence { mut elems } = good_arg.clone() else { panic!() }; elems.push(usize_t().into()); assert_eq!( - check_term_type(&Term::new_list(elems), &outer_param), - Err(TermTypeError::TypeMismatch { - term: Box::new(usize_t().into()), + check_type_arg(&TypeArg::Sequence { elems }, &outer_param), + Err(TypeArgError::TypeMismatch { + arg: usize_t().into(), // The error reports the type expected for each element of the list: - type_: Box::new(TypeParam::new_list_type(TypeBound::Linear)) + param: TypeParam::new_list(TypeBound::Any) }) ); // Now substitute a list of two types for that row-variable let row_var_arg = vec![usize_t().into(), bool_t().into()].into(); - check_term_type(&row_var_arg, &row_var_decl).unwrap(); + check_type_arg(&row_var_arg, &row_var_decl).unwrap(); let subst_arg = good_arg.substitute(&Substitution(&[row_var_arg.clone()])); - check_term_type(&subst_arg, &outer_param).unwrap(); // invariance of 
substitution + check_type_arg(&subst_arg, &outer_param).unwrap(); // invariance of substitution assert_eq!( subst_arg, - Term::new_list([ - Term::new_list([usize_t().into()]), - row_var_arg, - Term::new_list([usize_t().into(), bool_t().into(), usize_t().into()]) - ]) + TypeArg::Sequence { + elems: vec![ + vec![usize_t().into()].into(), + row_var_arg, + vec![usize_t().into(), bool_t().into(), usize_t().into()].into() + ] + } ); } - #[test] - fn bytes_json_roundtrip() { - let bytes_arg = Term::Bytes(vec![0, 1, 2, 3, 255, 254, 253, 252].into()); - let serialized = serde_json::to_string(&bytes_arg).unwrap(); - let deserialized: Term = serde_json::from_str(&serialized).unwrap(); - assert_eq!(deserialized, bytes_arg); - } - mod proptest { use proptest::prelude::*; - use super::super::{TermVar, UpperBound}; + use super::super::{TypeArg, TypeArgVariable, TypeParam, UpperBound}; use crate::proptest::RecursionDepth; - use crate::types::{Term, Type, TypeBound, proptest_utils::any_serde_type_param}; + use crate::types::{Type, TypeBound}; - impl Arbitrary for TermVar { + impl Arbitrary for TypeArgVariable { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { - (any::(), any_serde_type_param(depth)) - .prop_map(|(idx, cached_decl)| Self { - idx, - cached_decl: Box::new(cached_decl), - }) + (any::(), any_with::(depth)) + .prop_map(|(idx, cached_decl)| Self { idx, cached_decl }) .boxed() } } - impl Arbitrary for Term { + impl Arbitrary for TypeParam { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { use prop::collection::vec; use prop::strategy::Union; let mut strat = Union::new([ - Just(Self::StringType).boxed(), - Just(Self::BytesType).boxed(), - Just(Self::FloatType).boxed(), - Just(Self::StringType).boxed(), - any::().prop_map(Self::from).boxed(), - any::().prop_map(Self::from).boxed(), - any::().prop_map(Self::from).boxed(), - any::().prop_map(Self::from).boxed(), - any::>() - .prop_map(|bytes| Self::Bytes(bytes.into())) + Just(Self::String).boxed(), + any::().prop_map(|b| Self::Type { b }).boxed(), + any::() + .prop_map(|bound| Self::BoundedNat { bound }) .boxed(), - any::() - .prop_map(|value| Self::Float(value.into())) - .boxed(), - any_with::(depth).prop_map(Self::from).boxed(), ]); if !depth.leaf() { - // we descend here because we these constructors contain Terms + // we descend here because we these constructors contain TypeParams strat = strat - .or( - // TODO this is a bit dodgy, TypeArgVariables are supposed - // to be constructed from TypeArg::new_var_use. We are only - // using this instance for serialization now, but if we want - // to generate valid TypeArgs this will need to change. - any_with::(depth.descend()) - .prop_map(Self::Variable) - .boxed(), - ) - .or(any_with::(depth.descend()) - .prop_map(Self::new_list_type) - .boxed()) .or(any_with::(depth.descend()) - .prop_map(Self::new_tuple_type) + .prop_map(|x| Self::List { param: Box::new(x) }) .boxed()) .or(vec(any_with::(depth.descend()), 0..3) - .prop_map(Self::new_list) + .prop_map(|params| Self::Tuple { params }) .boxed()); } @@ -1111,10 +678,33 @@ mod test { } } - proptest! 
{ - #[test] - fn term_contains_itself(term: Term) { - assert!(term.is_supertype(&term)); + impl Arbitrary for TypeArg { + type Parameters = RecursionDepth; + type Strategy = BoxedStrategy; + fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { + use prop::collection::vec; + use prop::strategy::Union; + let mut strat = Union::new([ + any::().prop_map(|n| Self::BoundedNat { n }).boxed(), + any::().prop_map(|arg| Self::String { arg }).boxed(), + any_with::(depth) + .prop_map(|ty| Self::Type { ty }) + .boxed(), + // TODO this is a bit dodgy, TypeArgVariables are supposed + // to be constructed from TypeArg::new_var_use. We are only + // using this instance for serialization now, but if we want + // to generate valid TypeArgs this will need to change. + any_with::(depth) + .prop_map(|v| Self::Variable { v }) + .boxed(), + ]); + if !depth.leaf() { + // We descend here because this constructor contains TypeArg> + strat = strat.or(vec(any_with::(depth.descend()), 0..3) + .prop_map(|elems| Self::Sequence { elems }) + .boxed()); + } + strat.boxed() } } } diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index 7b9e24d282..c458b8c181 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -8,8 +8,8 @@ use std::{ }; use super::{ - MaybeRV, NoRV, RowVariable, Substitution, Term, Transformable, Type, TypeArg, TypeBase, TypeRV, - TypeTransformer, type_param::TypeParam, + MaybeRV, NoRV, RowVariable, Substitution, Transformable, Type, TypeBase, TypeTransformer, + type_param::TypeParam, }; use crate::{extension::SignatureError, utils::display_list}; use delegate::delegate; @@ -28,8 +28,7 @@ pub struct TypeRowBase { /// Row of single types i.e. of known length, for node inputs/outputs pub type TypeRow = TypeRowBase; -/// Row of types and/or row variables, the number of actual types is thus -/// unknown +/// Row of types and/or row variables, the number of actual types is thus unknown pub type TypeRowRV = TypeRowBase; impl PartialEq> for TypeRowBase { @@ -196,81 +195,6 @@ impl From for TypeRow { } } -// Fallibly convert a [Term] to a [TypeRV]. -// -// This will fail if `arg` is of non-type kind (e.g. String). -impl TryFrom for TypeRV { - type Error = SignatureError; - - fn try_from(value: Term) -> Result { - match value { - TypeArg::Runtime(ty) => Ok(ty.into()), - TypeArg::Variable(v) => Ok(TypeRV::new_row_var_use( - v.index(), - v.bound_if_row_var() - .ok_or(SignatureError::InvalidTypeArgs)?, - )), - _ => Err(SignatureError::InvalidTypeArgs), - } - } -} - -// Fallibly convert a [Term] to a [TypeRow]. -// -// This will fail if `arg` is of non-sequence kind (e.g. Type) -// or if the sequence contains row variables. -impl TryFrom for TypeRow { - type Error = SignatureError; - - fn try_from(value: TypeArg) -> Result { - match value { - TypeArg::List(elems) => elems - .into_iter() - .map(|ta| ta.as_runtime()) - .collect::>>() - .map(|x| x.into()) - .ok_or(SignatureError::InvalidTypeArgs), - _ => Err(SignatureError::InvalidTypeArgs), - } - } -} - -// Fallibly convert a [TypeArg] to a [TypeRowRV]. -// -// This will fail if `arg` is of non-sequence kind (e.g. Type). 
-impl TryFrom for TypeRowRV { - type Error = SignatureError; - - fn try_from(value: Term) -> Result { - match value { - TypeArg::List(elems) => elems - .into_iter() - .map(TypeRV::try_from) - .collect::, _>>() - .map(|vec| vec.into()), - TypeArg::Variable(v) => Ok(vec![TypeRV::new_row_var_use( - v.index(), - v.bound_if_row_var() - .ok_or(SignatureError::InvalidTypeArgs)?, - )] - .into()), - _ => Err(SignatureError::InvalidTypeArgs), - } - } -} - -impl From for Term { - fn from(value: TypeRow) -> Self { - Term::List(value.into_owned().into_iter().map_into().collect()) - } -} - -impl From for Term { - fn from(value: TypeRowRV) -> Self { - Term::List(value.into_owned().into_iter().map_into().collect()) - } -} - impl Deref for TypeRowBase { type Target = [TypeBase]; @@ -287,12 +211,6 @@ impl DerefMut for TypeRowBase { #[cfg(test)] mod test { - use super::*; - use crate::{ - extension::prelude::bool_t, - types::{Type, TypeArg, TypeRV}, - }; - mod proptest { use crate::proptest::RecursionDepth; use crate::types::{MaybeRV, TypeBase, TypeRowBase}; @@ -313,78 +231,4 @@ mod test { } } } - - #[test] - fn test_try_from_term_to_typerv() { - // Test successful conversion with Runtime type - let runtime_type = Type::UNIT; - let term = TypeArg::Runtime(runtime_type.clone()); - let result = TypeRV::try_from(term); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), TypeRV::from(runtime_type)); - - // Test failure with non-type kind - let term = Term::String("test".to_string()); - let result = TypeRV::try_from(term); - assert!(result.is_err()); - } - - #[test] - fn test_try_from_term_to_typerow() { - // Test successful conversion with List - let types = vec![Type::new_unit_sum(1), bool_t()]; - let type_args = types.iter().map(|t| TypeArg::Runtime(t.clone())).collect(); - let term = TypeArg::List(type_args); - let result = TypeRow::try_from(term); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), TypeRow::from(types)); - - // Test failure with non-list - let term = TypeArg::Runtime(Type::UNIT); - let result = TypeRow::try_from(term); - assert!(result.is_err()); - } - - #[test] - fn test_try_from_term_to_typerowrv() { - // Test successful conversion with List - let types = [TypeRV::from(Type::UNIT), TypeRV::from(bool_t())]; - let type_args = types.iter().map(|t| t.clone().into()).collect(); - let term = TypeArg::List(type_args); - let result = TypeRowRV::try_from(term); - assert!(result.is_ok()); - - // Test failure with non-sequence kind - let term = Term::String("test".to_string()); - let result = TypeRowRV::try_from(term); - assert!(result.is_err()); - } - - #[test] - fn test_from_typerow_to_term() { - let types = vec![Type::UNIT, bool_t()]; - let type_row = TypeRow::from(types); - let term = Term::from(type_row); - - match term { - Term::List(elems) => { - assert_eq!(elems.len(), 2); - } - _ => panic!("Expected Term::List"), - } - } - - #[test] - fn test_from_typerowrv_to_term() { - let types = vec![TypeRV::from(Type::UNIT), TypeRV::from(bool_t())]; - let type_row_rv = TypeRowRV::from(types); - let term = Term::from(type_row_rv); - - match term { - TypeArg::List(elems) => { - assert_eq!(elems.len(), 2); - } - _ => panic!("Expected Term::List"), - } - } } diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 3e451e41df..a312b16441 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -1,85 +1,105 @@ #![allow(missing_docs)] -use anyhow::Result; use std::str::FromStr; use hugr::std_extensions::std_reg; use hugr_core::{export::export_package, 
import::import_package}; use hugr_model::v0 as model; -fn roundtrip(source: &str) -> Result { +fn roundtrip(source: &str) -> String { let bump = model::bumpalo::Bump::new(); - let package_ast = model::ast::Package::from_str(source)?; - let package_table = package_ast.resolve(&bump)?; - let core = import_package(&package_table, &std_reg())?; + let package_ast = model::ast::Package::from_str(source).unwrap(); + let package_table = package_ast.resolve(&bump).unwrap(); + let core = import_package(&package_table, &std_reg()).unwrap(); let exported_table = export_package(&core.modules, &core.extensions, &bump); let exported_ast = exported_table.as_ast().unwrap(); - - Ok(exported_ast.to_string()) + exported_ast.to_string() } -macro_rules! test_roundtrip { - ($name: ident, $file: expr) => { - #[test] - #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri - pub fn $name() { - let ast = roundtrip(include_str!($file)).unwrap_or_else(|err| panic!("{:?}", err)); - insta::assert_snapshot!(ast) - } - }; +#[test] +#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri +pub fn test_roundtrip_add() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-add.edn" + ))); } -test_roundtrip!( - test_roundtrip_add, - "../../hugr-model/tests/fixtures/model-add.edn" -); - -test_roundtrip!( - test_roundtrip_call, - "../../hugr-model/tests/fixtures/model-call.edn" -); +#[test] +#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri +pub fn test_roundtrip_call() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-call.edn" + ))); +} -test_roundtrip!( - test_roundtrip_alias, - "../../hugr-model/tests/fixtures/model-alias.edn" -); +#[test] +#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri +pub fn test_roundtrip_alias() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-alias.edn" + ))); +} -test_roundtrip!( - test_roundtrip_cfg, - "../../hugr-model/tests/fixtures/model-cfg.edn" -); +#[test] +#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri +pub fn test_roundtrip_cfg() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-cfg.edn" + ))); +} -test_roundtrip!( - test_roundtrip_cond, - "../../hugr-model/tests/fixtures/model-cond.edn" -); +#[test] +#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri +pub fn test_roundtrip_cond() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-cond.edn" + ))); +} -test_roundtrip!( - test_roundtrip_loop, - "../../hugr-model/tests/fixtures/model-loop.edn" -); +#[test] +#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri +pub fn test_roundtrip_loop() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-loop.edn" + ))); +} -test_roundtrip!( - test_roundtrip_params, - "../../hugr-model/tests/fixtures/model-params.edn" -); +#[test] +#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri +pub fn test_roundtrip_params() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-params.edn" + ))); +} -test_roundtrip!( - test_roundtrip_constraints, - "../../hugr-model/tests/fixtures/model-constraints.edn" -); +#[test] +#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri +pub fn 
test_roundtrip_constraints() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-constraints.edn" + ))); +} -test_roundtrip!( - test_roundtrip_const, - "../../hugr-model/tests/fixtures/model-const.edn" -); +#[test] +#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri +pub fn test_roundtrip_const() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-const.edn" + ))); +} -test_roundtrip!( - test_roundtrip_order, - "../../hugr-model/tests/fixtures/model-order.edn" -); +#[test] +#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri +pub fn test_roundtrip_order() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-order.edn" + ))); +} -test_roundtrip!( - test_roundtrip_entrypoint, - "../../hugr-model/tests/fixtures/model-entrypoint.edn" -); +#[test] +#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri +pub fn test_roundtrip_entrypoint() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-entrypoint.edn" + ))); +} diff --git a/hugr-persistent/tests/persistent_walker_example.rs b/hugr-core/tests/persistent_walker_example.rs similarity index 62% rename from hugr-persistent/tests/persistent_walker_example.rs rename to hugr-core/tests/persistent_walker_example.rs index 19a02bac6a..8da20df657 100644 --- a/hugr-persistent/tests/persistent_walker_example.rs +++ b/hugr-core/tests/persistent_walker_example.rs @@ -2,37 +2,42 @@ use std::collections::{BTreeSet, VecDeque}; -use itertools::{Either, Itertools}; +use hugr::types::EdgeKind; +use itertools::Itertools; use hugr_core::{ - Hugr, HugrView, IncomingPort, OutgoingPort, Port, PortIndex, + Hugr, HugrView, PortIndex, SimpleReplacement, builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig}, extension::prelude::qb_t, - ops::OpType, - types::EdgeKind, + hugr::{ + persistent::{CommitStateSpace, PersistentReplacement, PinnedWire, Walker}, + views::SiblingSubgraph, + }, }; -use hugr_persistent::{Commit, CommitStateSpace, PersistentWire, PinnedSubgraph, Walker}; - /// The maximum commit depth that we will consider in this example -const MAX_COMMITS: usize = 4; +const MAX_COMMITS: usize = 2; // We define a HUGR extension within this file, with CZ and H gates. Normally, // you would use an existing extension (e.g. as provided by tket2). 
-use walker_example_extension::cz_gate; +use walker_example_extension::{cz_gate, h_gate}; mod walker_example_extension { use std::sync::Arc; - use hugr_core::Extension; - use hugr_core::extension::ExtensionId; - use hugr_core::ops::{ExtensionOp, OpName}; - use hugr_core::types::{FuncValueType, PolyFuncTypeRV}; + use hugr::Extension; + use hugr::extension::ExtensionId; + use hugr::ops::{ExtensionOp, OpName}; + use hugr::types::{FuncValueType, PolyFuncTypeRV}; use lazy_static::lazy_static; use semver::Version; use super::*; + fn one_qb_func() -> PolyFuncTypeRV { + FuncValueType::new_endo(qb_t()).into() + } + fn two_qb_func() -> PolyFuncTypeRV { FuncValueType::new_endo(vec![qb_t(), qb_t()]).into() } @@ -44,6 +49,15 @@ mod walker_example_extension { EXTENSION_ID, Version::new(0, 0, 0), |extension, extension_ref| { + extension + .add_op( + OpName::new_inline("H"), + "Hadamard".into(), + one_qb_func(), + extension_ref, + ) + .unwrap(); + extension .add_op( OpName::new_inline("CZ"), @@ -61,6 +75,10 @@ mod walker_example_extension { static ref EXTENSION: Arc = extension(); } + pub fn h_gate() -> ExtensionOp { + EXTENSION.instantiate_extension_op("H", []).unwrap() + } + pub fn cz_gate() -> ExtensionOp { EXTENSION.instantiate_extension_op("CZ", []).unwrap() } @@ -91,12 +109,15 @@ fn dfg_hugr() -> Hugr { builder.finish_hugr_with_outputs(vec![q0, q1, q2]).unwrap() } -fn empty_2qb_hugr(flip_args: bool) -> Hugr { - let builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()])).unwrap(); - let [mut q0, mut q1] = builder.input_wires_arr(); - if flip_args { - (q0, q1) = (q1, q0); - } +// TODO: currently empty replacements are buggy, so we have temporarily added +// a single Hadamard gate on each qubit. +fn empty_2qb_hugr() -> Hugr { + let mut builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()])).unwrap(); + let [q0, q1] = builder.input_wires_arr(); + let h0 = builder.add_dataflow_op(h_gate(), vec![q0]).unwrap(); + let [q0] = h0.outputs_arr(); + let h1 = builder.add_dataflow_op(h_gate(), vec![q1]).unwrap(); + let [q1] = h1.outputs_arr(); builder.finish_hugr_with_outputs(vec![q0, q1]).unwrap() } @@ -113,7 +134,7 @@ fn two_cz_3qb_hugr() -> Hugr { /// Traverse all commits in state space, enqueueing all outgoing wires of /// CZ nodes fn enqueue_all( - queue: &mut VecDeque<(PersistentWire, Walker<'static>)>, + queue: &mut VecDeque<(PinnedWire, Walker<'static>)>, state_space: &CommitStateSpace, ) { for id in state_space.all_commit_ids() { @@ -149,10 +170,10 @@ fn build_state_space() -> CommitStateSpace { enqueue_all(&mut wire_queue, &state_space); while let Some((wire, walker)) = wire_queue.pop_front() { - if !walker.is_complete(&wire, None) { + if !wire.is_complete(None) { // expand the wire in all possible ways - let (pinned_node, pinned_port) = walker - .wire_pinned_ports(&wire, None) + let (pinned_node, pinned_port) = wire + .all_pinned_ports() .next() .expect("at least one port was already pinned"); assert!( @@ -162,7 +183,7 @@ fn build_state_space() -> CommitStateSpace { for subwalker in walker.expand(&wire, None) { assert!( subwalker.as_hugr_view().contains_node(pinned_node), - "pinned node {pinned_node:?} is deleted", + "pinned node is deleted" ); wire_queue.push_back((subwalker.get_wire(pinned_node, pinned_port), subwalker)); } @@ -170,10 +191,7 @@ fn build_state_space() -> CommitStateSpace { // we have a complete wire, so we can commute the CZ gates (or // cancel them out) - let patch_nodes: BTreeSet<_> = walker - .wire_pinned_ports(&wire, None) - .map(|(n, _)| n) - .collect(); + let patch_nodes: 
BTreeSet<_> = wire.all_pinned_ports().map(|(n, _)| n).collect(); // check that the patch applies to more than one commit (or the base), // otherwise we have infinite commutations back and forth let patch_owners: BTreeSet<_> = patch_nodes.iter().map(|n| n.0).collect(); @@ -186,16 +204,22 @@ fn build_state_space() -> CommitStateSpace { continue; } - let Some(new_commit) = create_commit(wire, &walker) else { + let Some(repl) = create_replacement(wire, &walker) else { continue; }; assert_eq!( - new_commit.deleted_nodes().collect::>(), + repl.subgraph() + .nodes() + .iter() + .copied() + .collect::>(), patch_nodes ); - state_space.try_add_commit(new_commit).unwrap(); + state_space + .try_add_replacement(repl) + .expect("repl acts on non-empty subgraph"); // enqueue new wires added by the replacement // (this will also add a lot of already visited wires, but they will @@ -207,14 +231,14 @@ fn build_state_space() -> CommitStateSpace { state_space } -fn create_commit(wire: PersistentWire, walker: &Walker) -> Option { +fn create_replacement(wire: PinnedWire, walker: &Walker) -> Option { let hugr = walker.clone().into_persistent_hugr(); let (out_node, _) = wire - .single_outgoing_port(&hugr) + .pinned_outport() .expect("outgoing port was already pinned (and is unique)"); let (in_node, _) = wire - .all_incoming_ports(&hugr) + .pinned_inports() .exactly_one() .ok() .expect("all our wires have exactly one incoming port"); @@ -232,30 +256,13 @@ fn create_commit(wire: PersistentWire, walker: &Walker) -> Option { let all_edges = hugr.node_connections(out_node, in_node).collect_vec(); let n_shared_qubits = all_edges.len(); - match n_shared_qubits { + let (repl_hugr, subgraph) = match n_shared_qubits { 2 => { // out_node and in_node act on the same qubits - // => replace the two CZ gates with the empty 2qb HUGR - - // If the two CZ gates have flipped port ordering, we need to insert - // a swap gate - let add_swap = all_edges[0][0].index() != all_edges[0][1].index(); - - // Get the wires between the two CZ gates - let wires = all_edges - .into_iter() - .map(|[out_port, _]| walker.get_wire(out_node, out_port)); - - // Create the commit - walker.try_create_commit( - PinnedSubgraph::try_from_wires(wires, walker).unwrap(), - empty_2qb_hugr(add_swap), - |_, port| { - // the incoming/outgoing ports of the subgraph map trivially to the empty 2qb - // HUGR - let dir = port.direction(); - Port::new(dir.reverse(), port.index()) - }, + // => cancel out the two CZ gates + ( + empty_2qb_hugr(), + SiblingSubgraph::try_from_nodes([out_node, in_node], &hugr).ok()?, ) } 1 => { @@ -266,53 +273,32 @@ fn create_commit(wire: PersistentWire, walker: &Walker) -> Option { // Need to figure out the permutation of the qubits // => establish which qubit is shared between the two CZ gates let [out_port, in_port] = all_edges.into_iter().exactly_one().unwrap(); - let shared_qb_out = out_port.index(); - let shared_qb_in = in_port.index(); - - walker.try_create_commit( - PinnedSubgraph::try_from_wires([wire], walker).unwrap(), - repl_hugr, - |node, port| { - // map the incoming/outgoing ports of the subgraph to the replacement as - // follows: - // - the first qubit is the one that is shared between the two CZ gates - // - the second qubit only touches the first CZ (out_node) - // - the third qubit only touches the second CZ (in_node) - match port.as_directed() { - Either::Left(incoming) => { - let in_boundary: [(_, IncomingPort); 3] = [ - (out_node, shared_qb_out.into()), - (out_node, (1 - shared_qb_out).into()), - (in_node, (1 - 
shared_qb_in).into()), - ]; - let out_index = in_boundary - .iter() - .position(|&(n, p)| n == node && p == incoming) - .expect("invalid input port"); - OutgoingPort::from(out_index).into() - } - Either::Right(outgoing) => { - let out_boundary: [(_, OutgoingPort); 3] = [ - (in_node, shared_qb_in.into()), - (out_node, (1 - shared_qb_out).into()), - (in_node, (1 - shared_qb_in).into()), - ]; - let in_index = out_boundary - .iter() - .position(|&(n, p)| n == node && p == outgoing) - .expect("invalid output port"); - IncomingPort::from(in_index).into() - } - } - }, + let shared_qb_on_out_node = out_port.index(); + let shared_qb_on_in_node = in_port.index(); + + let subgraph = SiblingSubgraph::try_new( + vec![ + vec![(out_node, shared_qb_on_out_node.into())], + vec![(out_node, (1 - shared_qb_on_out_node).into())], + vec![(in_node, (1 - shared_qb_on_in_node).into())], + ], + vec![ + (in_node, shared_qb_on_in_node.into()), + (out_node, (1 - shared_qb_on_out_node).into()), + (in_node, (1 - shared_qb_on_in_node).into()), + ], + &hugr, ) + .ok()?; + + (repl_hugr, subgraph) } _ => unreachable!(), - } - .ok() + }; + + SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).ok() } -#[ignore = "takes 10s (todo: optimise)"] #[test] fn walker_example() { let state_space = build_state_space(); @@ -338,16 +324,26 @@ fn walker_example() { ); } + // assert_eq!(state_space.all_commit_ids().count(), 13); + let empty_commits = state_space .all_commit_ids() - .filter(|&id| state_space.inserted_nodes(id).count() == 0) + // .filter(|&id| state_space.commit_hugr(id).num_nodes() == 3) + .filter(|&id| { + state_space + .inserted_nodes(id) + .filter(|&n| state_space.get_optype(n) == &h_gate().into()) + .count() + == 2 + }) .collect_vec(); // there should be a combination of three empty commits that are compatible // and such that the resulting HUGR is empty let mut empty_hugr = None; - for cs in empty_commits.iter().combinations(3) { - let cs = cs.into_iter().copied(); + // for cs in empty_commits.iter().combinations(3) { + for cs in empty_commits.iter().combinations(2) { + let cs = cs.into_iter().copied().collect_vec(); if let Ok(hugr) = state_space.try_extract_hugr(cs) { empty_hugr = Some(hugr); } @@ -355,23 +351,16 @@ fn walker_example() { let empty_hugr = empty_hugr.unwrap().to_hugr(); - // The empty hugr should have 7 nodes: - // module root, funcdef, 2 func IO, DFG root, 2 DFG IO - assert_eq!(empty_hugr.num_nodes(), 7); - assert_eq!( - empty_hugr - .nodes() - .filter(|&n| { - !matches!( - empty_hugr.get_optype(n), - OpType::Input(_) - | OpType::Output(_) - | OpType::FuncDefn(_) - | OpType::Module(_) - | OpType::DFG(_) - ) - }) - .count(), - 0 - ); + // assert_eq!(empty_hugr.num_nodes(), 3); + + let n_cz = empty_hugr + .nodes() + .filter(|&n| empty_hugr.get_optype(n) == &cz_gate().into()) + .count(); + let n_h = empty_hugr + .nodes() + .filter(|&n| empty_hugr.get_optype(n) == &h_gate().into()) + .count(); + assert_eq!(n_cz, 2); + assert_eq!(n_h, 4); } diff --git a/hugr-core/tests/snapshots/model__roundtrip_add.snap b/hugr-core/tests/snapshots/model__roundtrip_add.snap index 625b621f09..4547fb3ebd 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_add.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_add.snap @@ -1,42 +1,29 @@ --- source: hugr-core/tests/model.rs -expression: ast +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-add.edn\"))" --- (hugr 0) (mod) -(import core.meta.description) - -(import core.nat) - (import core.fn) -(import arithmetic.int.types.int) +(import 
arithmetic.int.iadd) -(declare-operation - arithmetic.int.iadd - (param ?0 core.nat) - (core.fn - [(arithmetic.int.types.int ?0) (arithmetic.int.types.int ?0)] - [(arithmetic.int.types.int ?0)]) - (meta - (core.meta.description - "addition modulo 2^N (signed and unsigned versions are the same op)"))) +(import arithmetic.int.types.int) (define-func - public example.add (core.fn - [(arithmetic.int.types.int 6) (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)]) + [arithmetic.int.types.int arithmetic.int.types.int] + [arithmetic.int.types.int]) (dfg [%0 %1] [%2] (signature (core.fn - [(arithmetic.int.types.int 6) (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)])) - ((arithmetic.int.iadd 6) [%0 %1] [%2] + [arithmetic.int.types.int arithmetic.int.types.int] + [arithmetic.int.types.int])) + (arithmetic.int.iadd [%0 %1] [%2] (signature (core.fn - [(arithmetic.int.types.int 6) (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)]))))) + [arithmetic.int.types.int arithmetic.int.types.int] + [arithmetic.int.types.int]))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_alias.snap b/hugr-core/tests/snapshots/model__roundtrip_alias.snap index 6174d8c744..5f0b44daf4 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_alias.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_alias.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: ast +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-alias.edn\"))" --- (hugr 0) diff --git a/hugr-core/tests/snapshots/model__roundtrip_call.snap b/hugr-core/tests/snapshots/model__roundtrip_call.snap index 8681cf372c..50d9c55c33 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_call.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_call.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: ast +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call.edn\"))" --- (hugr 0) @@ -17,14 +17,12 @@ expression: ast (import arithmetic.int.types.int) (declare-func - public example.callee (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]) (meta (compat.meta_json "description" "\"This is a function declaration.\"")) (meta (compat.meta_json "title" "\"Callee\""))) (define-func - public example.caller (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]) (meta @@ -43,7 +41,6 @@ expression: ast (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))))) (define-func - public example.load (core.fn [] [(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])]) (dfg [] [%0] diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap index da2fe4851f..7a6136bdb2 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: ast +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg.edn\"))" --- (hugr 0) @@ -16,35 +16,36 @@ expression: ast (import core.adt) -(define-func public example.cfg_loop (param ?0 core.type) (core.fn [?0] [?0]) +(define-func example.cfg_loop (param ?0 core.type) (core.fn [?0] [?0]) (dfg [%0] [%1] (signature (core.fn [?0] [?0])) (cfg [%0] [%1] (signature (core.fn [?0] [?0])) (cfg [%2] [%3] - (signature (core.ctrl [[?0]] [[?0]])) + (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])])) (block [%2] [%3 %2] - (signature (core.ctrl [[?0]] [[?0] [?0]])) + (signature + (core.fn 
[(core.ctrl [?0])] [(core.ctrl [?0]) (core.ctrl [?0])])) (dfg [%4] [%5] (signature (core.fn [?0] [(core.adt [[?0] [?0]])])) ((core.make_adt 0) [%4] [%5] (signature (core.fn [?0] [(core.adt [[?0] [?0]])]))))))))) -(define-func public example.cfg_order (param ?0 core.type) (core.fn [?0] [?0]) +(define-func example.cfg_order (param ?0 core.type) (core.fn [?0] [?0]) (dfg [%0] [%1] (signature (core.fn [?0] [?0])) (cfg [%0] [%1] (signature (core.fn [?0] [?0])) (cfg [%2] [%3] - (signature (core.ctrl [[?0]] [[?0]])) + (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])])) (block [%2] [%6] - (signature (core.ctrl [[?0]] [[?0]])) + (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])])) (dfg [%4] [%5] (signature (core.fn [?0] [(core.adt [[?0]])])) ((core.make_adt 0) [%4] [%5] (signature (core.fn [?0] [(core.adt [[?0]])]))))) (block [%6] [%3] - (signature (core.ctrl [[?0]] [[?0]])) + (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])])) (dfg [%7] [%8] (signature (core.fn [?0] [(core.adt [[?0]])])) ((core.make_adt 0) [%7] [%8] diff --git a/hugr-core/tests/snapshots/model__roundtrip_cond.snap b/hugr-core/tests/snapshots/model__roundtrip_cond.snap index c51323db5c..e4c49f1193 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cond.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cond.snap @@ -1,57 +1,42 @@ --- source: hugr-core/tests/model.rs -expression: ast +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cond.edn\"))" --- (hugr 0) (mod) -(import core.meta.description) - -(import core.nat) - (import core.fn) (import core.adt) (import arithmetic.int.types.int) -(declare-operation - arithmetic.int.ineg - (param ?0 core.nat) - (core.fn [(arithmetic.int.types.int ?0)] [(arithmetic.int.types.int ?0)]) - (meta - (core.meta.description - "negation modulo 2^N (signed and unsigned versions are the same op)"))) +(import arithmetic.int.ineg) (define-func - public example.cond (core.fn - [(core.adt [[] []]) (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)]) + [(core.adt [[] []]) arithmetic.int.types.int] + [arithmetic.int.types.int]) (dfg [%0 %1] [%2] (signature (core.fn - [(core.adt [[] []]) (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)])) + [(core.adt [[] []]) arithmetic.int.types.int] + [arithmetic.int.types.int])) (cond [%0 %1] [%2] (signature (core.fn - [(core.adt [[] []]) (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)])) + [(core.adt [[] []]) arithmetic.int.types.int] + [arithmetic.int.types.int])) (dfg [%3] [%3] (signature - (core.fn - [(arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)]))) + (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))) (dfg [%4] [%5] (signature - (core.fn - [(arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)])) - ((arithmetic.int.ineg 6) [%4] [%5] + (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) + (arithmetic.int.ineg [%4] [%5] (signature (core.fn - [(arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)]))))))) + [arithmetic.int.types.int] + [arithmetic.int.types.int]))))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_const.snap b/hugr-core/tests/snapshots/model__roundtrip_const.snap index 3b386275ba..99cfdb55e9 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_const.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_const.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: ast +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-const.edn\"))" --- (hugr 0) @@ 
-28,10 +28,7 @@ expression: ast (import core.adt) -(define-func - public - example.bools - (core.fn [] [(core.adt [[] []]) (core.adt [[] []])]) +(define-func example.bools (core.fn [] [(core.adt [[] []]) (core.adt [[] []])]) (dfg [] [%0 %1] (signature (core.fn [] [(core.adt [[] []]) (core.adt [[] []])])) ((core.load_const (core.const.adt [[] []] _ 0 [])) [] [%0] @@ -40,7 +37,6 @@ expression: ast (signature (core.fn [] [(core.adt [[] []])]))))) (define-func - public example.make-pair (core.fn [] @@ -77,10 +73,7 @@ expression: ast [[(collections.array.array 5 (arithmetic.int.types.int 6)) arithmetic.float.types.float64]])]))))) -(define-func - public - example.f64-json - (core.fn [] [arithmetic.float.types.float64]) +(define-func example.f64-json (core.fn [] [arithmetic.float.types.float64]) (dfg [] [%0 %1] (signature (core.fn diff --git a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap index 2c50e73489..b9b406f3c5 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: ast +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-constraints.edn\"))" --- (hugr 0) @@ -10,25 +10,20 @@ expression: ast (import core.nat) -(import core.nonlinear) - (import core.type) -(import core.fn) +(import core.nonlinear) -(import core.title) +(import core.fn) (declare-func - private - _1 + array.replicate (param ?0 core.nat) (param ?1 core.type) (where (core.nonlinear ?1)) - (core.fn [?1] [(collections.array.array ?0 ?1)]) - (meta (core.title "array.replicate"))) + (core.fn [?1] [(collections.array.array ?0 ?1)])) (declare-func - public array.copy (param ?0 core.nat) (param ?1 core.type) @@ -38,7 +33,6 @@ expression: ast [(collections.array.array ?0 ?1) (collections.array.array ?0 ?1)])) (define-func - public util.copy (param ?0 core.type) (where (core.nonlinear ?0)) diff --git a/hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap b/hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap index 1340bb6b02..1db0b9d1cd 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: ast +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-entrypoint.edn\"))" --- (hugr 0) @@ -10,9 +10,8 @@ expression: ast (import core.entrypoint) -(define-func public main (core.fn [] []) - (meta core.entrypoint) - (dfg (signature (core.fn [] [])))) +(define-func main (core.fn [] []) + (dfg (signature (core.fn [] [])) (meta core.entrypoint))) (mod) @@ -20,9 +19,8 @@ expression: ast (import core.entrypoint) -(define-func public wrapper_dfg (core.fn [] []) - (meta core.entrypoint) - (dfg (signature (core.fn [] [])))) +(define-func wrapper_dfg (core.fn [] []) + (dfg (signature (core.fn [] [])) (meta core.entrypoint))) (mod) @@ -36,17 +34,16 @@ expression: ast (import core.adt) -(define-func public wrapper_cfg (core.fn [] []) +(define-func wrapper_cfg (core.fn [] []) (dfg (signature (core.fn [] [])) (cfg (signature (core.fn [] [])) - (meta core.entrypoint) (cfg [%0] [%1] - (signature (core.ctrl [[]] [[]])) + (signature (core.fn [(core.ctrl [])] [(core.ctrl [])])) (meta core.entrypoint) (block [%0] [%1] - (signature (core.ctrl [[]] [[]])) + (signature (core.fn [(core.ctrl [])] [(core.ctrl [])])) (dfg [] [%2] (signature (core.fn 
[] [(core.adt [[]])])) ((core.make_adt 0) [] [%2] diff --git a/hugr-core/tests/snapshots/model__roundtrip_loop.snap b/hugr-core/tests/snapshots/model__roundtrip_loop.snap index e6991516f3..50035a637c 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_loop.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_loop.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: ast +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-loop.edn\"))" --- (hugr 0) @@ -14,10 +14,7 @@ expression: ast (import core.adt) -(import core.title) - -(define-func private _1 (param ?0 core.type) (core.fn [?0] [?0]) - (meta (core.title "example.loop")) +(define-func example.loop (param ?0 core.type) (core.fn [?0] [?0]) (dfg [%0] [%1] (signature (core.fn [?0] [?0])) (tail-loop [%0] [%1] diff --git a/hugr-core/tests/snapshots/model__roundtrip_order.snap b/hugr-core/tests/snapshots/model__roundtrip_order.snap index dda51b71cf..e358670e1a 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_order.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_order.snap @@ -1,80 +1,60 @@ --- source: hugr-core/tests/model.rs -expression: ast +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-order.edn\"))" --- (hugr 0) (mod) -(import core.meta.description) +(import core.order_hint.key) -(import core.order_hint.input_key) +(import core.fn) (import core.order_hint.order) (import arithmetic.int.types.int) -(import core.nat) - -(import core.order_hint.key) - -(import core.order_hint.output_key) - -(import core.fn) - -(declare-operation - arithmetic.int.ineg - (param ?0 core.nat) - (core.fn [(arithmetic.int.types.int ?0)] [(arithmetic.int.types.int ?0)]) - (meta - (core.meta.description - "negation modulo 2^N (signed and unsigned versions are the same op)"))) +(import arithmetic.int.ineg) (define-func - public main (core.fn - [(arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6)]) + [arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int] + [arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int]) (dfg [%0 %1 %2 %3] [%4 %5 %6 %7] (signature (core.fn - [(arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6)])) - (meta (core.order_hint.input_key 2)) - (meta (core.order_hint.order 2 4)) - (meta (core.order_hint.order 2 3)) - (meta (core.order_hint.output_key 3)) + [arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int] + [arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int])) (meta (core.order_hint.order 4 7)) (meta (core.order_hint.order 5 6)) (meta (core.order_hint.order 5 4)) - (meta (core.order_hint.order 5 3)) (meta (core.order_hint.order 6 7)) - ((arithmetic.int.ineg 6) [%0] [%4] + (arithmetic.int.ineg [%0] [%4] (signature - (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) + (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) (meta (core.order_hint.key 4))) - ((arithmetic.int.ineg 6) [%1] [%5] + (arithmetic.int.ineg [%1] [%5] (signature - 
(core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) + (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) (meta (core.order_hint.key 5))) - ((arithmetic.int.ineg 6) [%2] [%6] + (arithmetic.int.ineg [%2] [%6] (signature - (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) + (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) (meta (core.order_hint.key 6))) - ((arithmetic.int.ineg 6) [%3] [%7] + (arithmetic.int.ineg [%3] [%7] (signature - (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) + (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) (meta (core.order_hint.key 7))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_params.snap b/hugr-core/tests/snapshots/model__roundtrip_params.snap index c5c5eac95b..77d6d9cc77 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_params.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_params.snap @@ -1,54 +1,18 @@ --- source: hugr-core/tests/model.rs -expression: ast +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-params.edn\"))" --- (hugr 0) (mod) -(import core.call) - -(import core.type) - -(import core.bytes) - -(import core.nat) - (import core.fn) -(import core.str) - -(import core.float) - -(import core.title) +(import core.type) (define-func - public example.swap (param ?0 core.type) (param ?1 core.type) (core.fn [?0 ?1] [?1 ?0]) (dfg [%0 %1] [%1 %0] (signature (core.fn [?0 ?1] [?1 ?0])))) - -(declare-func - public - example.literals - (param ?0 core.str) - (param ?1 core.nat) - (param ?2 core.bytes) - (param ?3 core.float) - (core.fn [] [])) - -(define-func private _5 (core.fn [] []) - (meta (core.title "example.call_literals")) - (dfg - (signature (core.fn [] [])) - ((core.call - [] - [] - (example.literals - "string" - 42 - (bytes "SGVsbG8gd29ybGQg8J+Yig==") - 6.023e23)) - (signature (core.fn [] []))))) diff --git a/hugr-llvm/CHANGELOG.md b/hugr-llvm/CHANGELOG.md index 7ab27c08a4..0ed47be85a 100644 --- a/hugr-llvm/CHANGELOG.md +++ b/hugr-llvm/CHANGELOG.md @@ -1,5 +1,3 @@ -# Changelog - # Changelog All notable changes to this project will be documented in this file. @@ -7,16 +5,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.21.0](https://github.com/CQCL/hugr/compare/hugr-llvm-v0.20.2...hugr-llvm-v0.21.0) - 2025-07-09 - -### New Features - -- [**breaking**] No nested FuncDefns (or AliasDefns) ([#2256](https://github.com/CQCL/hugr/pull/2256)) -- [**breaking**] Split `TypeArg::Sequence` into tuples and lists. 
([#2140](https://github.com/CQCL/hugr/pull/2140)) -- [**breaking**] More helpful error messages in model import ([#2272](https://github.com/CQCL/hugr/pull/2272)) -- [**breaking**] Merge `TypeParam` and `TypeArg` into one `Term` type in Rust ([#2309](https://github.com/CQCL/hugr/pull/2309)) -- Add `MakeError` op ([#2377](https://github.com/CQCL/hugr/pull/2377)) - ## [0.20.2](https://github.com/CQCL/hugr/compare/hugr-llvm-v0.20.1...hugr-llvm-v0.20.2) - 2025-06-25 ### New Features diff --git a/hugr-llvm/Cargo.toml b/hugr-llvm/Cargo.toml index 0a906686db..31bd533960 100644 --- a/hugr-llvm/Cargo.toml +++ b/hugr-llvm/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hugr-llvm" -version = "0.22.1" +version = "0.20.2" description = "A general and extensible crate for lowering HUGRs into LLVM IR" edition.workspace = true @@ -26,8 +26,8 @@ workspace = true [dependencies] inkwell = { version = "0.6.0", default-features = false } -hugr-core = { path = "../hugr-core", version = "0.22.1" } -anyhow.workspace = true +hugr-core = { path = "../hugr-core", version = "0.20.2" } +anyhow = "1.0.98" itertools.workspace = true delegate.workspace = true petgraph.workspace = true diff --git a/hugr-llvm/src/emit/ops/cfg.rs b/hugr-llvm/src/emit/ops/cfg.rs index 790110fe2a..bdf27389a2 100644 --- a/hugr-llvm/src/emit/ops/cfg.rs +++ b/hugr-llvm/src/emit/ops/cfg.rs @@ -217,7 +217,7 @@ impl<'c, 'hugr, H: HugrView> CfgEmitter<'c, 'hugr, H> { #[cfg(test)] mod test { - use hugr_core::builder::{Dataflow, DataflowHugr, SubContainer}; + use hugr_core::builder::{Dataflow, DataflowSubContainer, SubContainer}; use hugr_core::extension::ExtensionRegistry; use hugr_core::extension::prelude::{self, bool_t}; use hugr_core::ops::Value; @@ -279,7 +279,7 @@ mod test { cfg_builder.branch(&b1, 1, &exit_block).unwrap(); let cfg = cfg_builder.finish_sub_container().unwrap(); let [cfg_out] = cfg.outputs_arr(); - builder.finish_hugr_with_outputs([cfg_out]).unwrap() + builder.finish_with_outputs([cfg_out]).unwrap() }); llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); check_emission!(hugr, llvm_ctx); @@ -395,7 +395,7 @@ mod test { .unwrap() .outputs_arr() }; - builder.finish_hugr_with_outputs([outer_cfg_out]).unwrap() + builder.finish_with_outputs([outer_cfg_out]).unwrap() }); check_emission!(hugr, llvm_ctx); } diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@llvm14.snap index d673a4b73e..9ea0d09e8d 100644 --- a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@llvm14.snap +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@llvm14.snap @@ -13,12 +13,21 @@ entry_block: ; preds = %alloca_block br label %0 0: ; preds = %entry_block - switch i1 false, label %1 [ + %1 = call i1 @_hl.scoped_func.7() + switch i1 false, label %2 [ ] -1: ; preds = %0 - br label %2 +2: ; preds = %0 + br label %3 -2: ; preds = %1 +3: ; preds = %2 + ret i1 %1 +} + +define i1 @_hl.scoped_func.7() { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block ret i1 false } diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@pre-mem2reg@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@pre-mem2reg@llvm14.snap index 025b85a9ac..c38ac33f4d 100644 --- 
a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@pre-mem2reg@llvm14.snap @@ -10,30 +10,31 @@ alloca_block: %"0" = alloca i1, align 1 %"4_0" = alloca i1, align 1 %"01" = alloca i1, align 1 - %"11_0" = alloca {}, align 8 - %"12_0" = alloca i1, align 1 + %"15_0" = alloca {}, align 8 + %"16_0" = alloca i1, align 1 br label %entry_block entry_block: ; preds = %alloca_block br label %0 0: ; preds = %entry_block - store i1 false, i1* %"12_0", align 1 - store {} undef, {}* %"11_0", align 1 - %"11_02" = load {}, {}* %"11_0", align 1 - %"12_03" = load i1, i1* %"12_0", align 1 - store {} %"11_02", {}* %"11_0", align 1 - store i1 %"12_03", i1* %"12_0", align 1 - %"11_04" = load {}, {}* %"11_0", align 1 - %"12_05" = load i1, i1* %"12_0", align 1 - switch i1 false, label %1 [ + %1 = call i1 @_hl.scoped_func.7() + store i1 %1, i1* %"16_0", align 1 + store {} undef, {}* %"15_0", align 1 + %"15_02" = load {}, {}* %"15_0", align 1 + %"16_03" = load i1, i1* %"16_0", align 1 + store {} %"15_02", {}* %"15_0", align 1 + store i1 %"16_03", i1* %"16_0", align 1 + %"15_04" = load {}, {}* %"15_0", align 1 + %"16_05" = load i1, i1* %"16_0", align 1 + switch i1 false, label %2 [ ] -1: ; preds = %0 - store i1 %"12_05", i1* %"01", align 1 - br label %2 +2: ; preds = %0 + store i1 %"16_05", i1* %"01", align 1 + br label %3 -2: ; preds = %1 +3: ; preds = %2 %"06" = load i1, i1* %"01", align 1 store i1 %"06", i1* %"4_0", align 1 %"4_07" = load i1, i1* %"4_0", align 1 @@ -41,3 +42,17 @@ entry_block: ; preds = %alloca_block %"08" = load i1, i1* %"0", align 1 ret i1 %"08" } + +define i1 @_hl.scoped_func.7() { +alloca_block: + %"0" = alloca i1, align 1 + %"10_0" = alloca i1, align 1 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i1 false, i1* %"10_0", align 1 + %"10_01" = load i1, i1* %"10_0", align 1 + store i1 %"10_01", i1* %"0", align 1 + %"02" = load i1, i1* %"0", align 1 + ret i1 %"02" +} diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@llvm14.snap new file mode 100644 index 0000000000..ea9074b87b --- /dev/null +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@llvm14.snap @@ -0,0 +1,23 @@ +--- +source: hugr-llvm/src/emit/test.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i1 @_hl.main.1() { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %0 = call i1 @_hl.scoped_func.8() + ret i1 %0 +} + +define i1 @_hl.scoped_func.8() { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + ret i1 false +} diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@pre-mem2reg@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@pre-mem2reg@llvm14.snap new file mode 100644 index 0000000000..f990db641b --- /dev/null +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@pre-mem2reg@llvm14.snap @@ -0,0 +1,38 @@ +--- +source: hugr-llvm/src/emit/test.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i1 @_hl.main.1() { +alloca_block: + %"0" = alloca i1, align 1 + %"4_0" = alloca i1, align 1 + %"12_0" = alloca i1, 
align 1 + br label %entry_block + +entry_block: ; preds = %alloca_block + %0 = call i1 @_hl.scoped_func.8() + store i1 %0, i1* %"12_0", align 1 + %"12_01" = load i1, i1* %"12_0", align 1 + store i1 %"12_01", i1* %"4_0", align 1 + %"4_02" = load i1, i1* %"4_0", align 1 + store i1 %"4_02", i1* %"0", align 1 + %"03" = load i1, i1* %"0", align 1 + ret i1 %"03" +} + +define i1 @_hl.scoped_func.8() { +alloca_block: + %"0" = alloca i1, align 1 + %"11_0" = alloca i1, align 1 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i1 false, i1* %"11_0", align 1 + %"11_01" = load i1, i1* %"11_0", align 1 + store i1 %"11_01", i1* %"0", align 1 + %"02" = load i1, i1* %"0", align 1 + ret i1 %"02" +} diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index d79ac361cd..d5194bf47b 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -1,7 +1,9 @@ use crate::types::HugrFuncType; use crate::utils::fat::FatNode; use anyhow::{Result, anyhow}; -use hugr_core::builder::{BuildHandle, DFGWrapper, FunctionBuilder}; +use hugr_core::builder::{ + BuildHandle, Container, DFGWrapper, HugrBuilder, ModuleBuilder, SubContainer, +}; use hugr_core::extension::ExtensionRegistry; use hugr_core::ops::handle::FuncID; use hugr_core::types::TypeRow; @@ -13,7 +15,7 @@ use inkwell::values::GenericValue; use super::EmitHugr; #[allow(clippy::upper_case_acronyms)] -pub type DFGW = DFGWrapper>>; +pub type DFGW<'a> = DFGWrapper<&'a mut Hugr, BuildHandle>>; pub struct SimpleHugrConfig { ins: TypeRow, @@ -129,13 +131,31 @@ impl SimpleHugrConfig { self } - pub fn finish(self, make: impl FnOnce(DFGW) -> Hugr) -> Hugr { + pub fn finish( + self, + make: impl for<'a> FnOnce(DFGW<'a>) -> as SubContainer>::ContainerHandle, + ) -> Hugr { self.finish_with_exts(|builder, _| make(builder)) } + pub fn finish_with_exts( + self, + make: impl for<'a> FnOnce( + DFGW<'a>, + &ExtensionRegistry, + ) -> as SubContainer>::ContainerHandle, + ) -> Hugr { + let mut mod_b = ModuleBuilder::new(); + let func_b = mod_b + .define_function("main", HugrFuncType::new(self.ins, self.outs)) + .unwrap(); + make(func_b, &self.extensions); + + // Intentionally left as a debugging aid. If the HUGR you construct + // fails validation, uncomment the following line to print it out + // unvalidated. + // println!("{}", mod_b.hugr().mermaid_string()); - pub fn finish_with_exts(self, make: impl FnOnce(DFGW, &ExtensionRegistry) -> Hugr) -> Hugr { - let func_b = FunctionBuilder::new("main", HugrFuncType::new(self.ins, self.outs)).unwrap(); - make(func_b, &self.extensions) + mod_b.finish_hugr().unwrap_or_else(|e| panic!("{e}")) } } @@ -167,7 +187,11 @@ pub use insta; macro_rules! check_emission { // Call the macro with a snapshot name. 
($snapshot_name:expr, $hugr: ident, $test_ctx:ident) => {{ - let root = $crate::utils::fat::FatExt::fat_root(&$hugr).unwrap(); + let root = + $crate::utils::fat::FatExt::fat_root::<$crate::emit::test::hugr_core::ops::Module>( + &$hugr, + ) + .unwrap(); let emission = $crate::emit::test::Emission::emit_hugr(root, $test_ctx.get_emit_hugr()).unwrap(); @@ -213,8 +237,8 @@ mod test_fns { use crate::custom::CodegenExtsBuilder; use crate::types::{HugrFuncType, HugrSumType}; + use hugr_core::builder::DataflowSubContainer; use hugr_core::builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}; - use hugr_core::builder::{DataflowHugr, DataflowSubContainer}; use hugr_core::extension::PRELUDE_REGISTRY; use hugr_core::extension::prelude::{ConstUsize, bool_t, usize_t}; use hugr_core::ops::constant::CustomConst; @@ -242,7 +266,7 @@ mod test_fns { builder.input_wires(), ) .unwrap(); - builder.finish_hugr_with_outputs(tag.outputs()).unwrap() + builder.finish_with_outputs(tag.outputs()).unwrap() }); let _ = check_emission!(hugr, llvm_ctx); } @@ -260,7 +284,7 @@ mod test_fns { let w = b.input_wires(); b.finish_with_outputs(w).unwrap() }; - builder.finish_hugr_with_outputs(dfg.outputs()).unwrap() + builder.finish_with_outputs(dfg.outputs()).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -305,7 +329,7 @@ mod test_fns { cond_b.finish_sub_container().unwrap() }; let [o1, o2] = cond.outputs_arr(); - builder.finish_hugr_with_outputs([o1, o2]).unwrap() + builder.finish_with_outputs([o1, o2]).unwrap() }) }; check_emission!(hugr, llvm_ctx); @@ -325,7 +349,7 @@ mod test_fns { .with_extensions(STD_REG.to_owned()) .finish(|mut builder: DFGW| { let konst = builder.add_load_value(v); - builder.finish_hugr_with_outputs([konst]).unwrap() + builder.finish_with_outputs([konst]).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -387,7 +411,7 @@ mod test_fns { .instantiate_extension_op("iadd", [4.into()]) .unwrap(); let add = builder.add_dataflow_op(ext_op, [k1, k2]).unwrap(); - builder.finish_hugr_with_outputs(add.outputs()).unwrap() + builder.finish_with_outputs(add.outputs()).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -429,6 +453,34 @@ mod test_fns { check_emission!(hugr, llvm_ctx); } + #[rstest] + fn diverse_dfg_children(llvm_ctx: TestContext) { + let hugr = SimpleHugrConfig::new() + .with_outs(bool_t()) + .finish(|mut builder: DFGW| { + let [r] = { + let mut builder = builder + .dfg_builder(HugrFuncType::new(type_row![], bool_t()), []) + .unwrap(); + let konst = builder.add_constant(Value::false_val()); + let func = { + let mut builder = builder + .define_function( + "scoped_func", + HugrFuncType::new(type_row![], bool_t()), + ) + .unwrap(); + let w = builder.load_const(&konst); + builder.finish_with_outputs([w]).unwrap() + }; + let [r] = builder.call(func.handle(), &[], []).unwrap().outputs_arr(); + builder.finish_with_outputs([r]).unwrap().outputs_arr() + }; + builder.finish_with_outputs([r]).unwrap() + }); + check_emission!(hugr, llvm_ctx); + } + #[rstest] fn diverse_cfg_children(llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() @@ -437,19 +489,29 @@ mod test_fns { let [r] = { let mut builder = builder.cfg_builder([], vec![bool_t()].into()).unwrap(); let konst = builder.add_constant(Value::false_val()); + let func = { + let mut builder = builder + .define_function( + "scoped_func", + HugrFuncType::new(type_row![], bool_t()), + ) + .unwrap(); + let w = builder.load_const(&konst); + builder.finish_with_outputs([w]).unwrap() + }; let entry = { let mut builder = builder 
.entry_builder([type_row![]], vec![bool_t()].into()) .unwrap(); let control = builder.add_load_value(Value::unary_unit_sum()); - let r = builder.load_const(&konst); + let [r] = builder.call(func.handle(), &[], []).unwrap().outputs_arr(); builder.finish_with_outputs(control, [r]).unwrap() }; let exit = builder.exit_block(); builder.branch(&entry, 0, &exit).unwrap(); builder.finish_sub_container().unwrap().outputs_arr() }; - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -513,7 +575,7 @@ mod test_fns { .finish_with_outputs(sum_inp_w, []) .unwrap() .outputs_arr(); - builder.finish_hugr_with_outputs(outs).unwrap() + builder.finish_with_outputs(outs).unwrap() }) }; llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); @@ -634,7 +696,7 @@ mod test_fns { }; let [out_int] = tail_l.outputs_arr(); builder - .finish_hugr_with_outputs([out_int]) + .finish_with_outputs([out_int]) .unwrap_or_else(|e| panic!("{e}")) }) } @@ -669,7 +731,7 @@ mod test_fns { .with_extensions(PRELUDE_REGISTRY.to_owned()) .finish(|mut builder: DFGW| { let konst = builder.add_load_value(ConstUsize::new(42)); - builder.finish_hugr_with_outputs([konst]).unwrap() + builder.finish_with_outputs([konst]).unwrap() }); exec_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); assert_eq!(42, exec_ctx.exec_hugr_u64(hugr, "main")); diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index da5141d72f..725fbe2724 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -214,7 +214,7 @@ impl CodegenExtension for ArrayCodegenExtension { .custom_type((array::EXTENSION_ID, array::ARRAY_TYPENAME), { let ccg = self.0.clone(); move |ts, hugr_type| { - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = hugr_type.args() else { + let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = hugr_type.args() else { return Err(anyhow!("Invalid type args for array type")); }; let elem_ty = ts.llvm_type(ty)?; @@ -908,7 +908,7 @@ pub fn emit_scan_op<'c, H: HugrView>( #[cfg(test)] mod test { - use hugr_core::builder::{DataflowHugr, HugrBuilder}; + use hugr_core::builder::Container as _; use hugr_core::extension::prelude::either_type; use hugr_core::ops::Tag; use hugr_core::std_extensions::STD_REG; @@ -952,7 +952,7 @@ mod test { build_all_array_ops(builder.dfg_builder_endo([]).unwrap()) .finish_sub_container() .unwrap(); - builder.finish_hugr().unwrap() + builder.finish_sub_container().unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -971,7 +971,7 @@ mod test { let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); let (_, arr) = builder.add_array_get(usize_t(), 2, arr, us1).unwrap(); builder.add_array_discard(usize_t(), 2, arr).unwrap(); - builder.finish_hugr_with_outputs([]).unwrap() + builder.finish_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -991,7 +991,7 @@ mod test { let (arr1, arr2) = builder.add_array_clone(usize_t(), 2, arr).unwrap(); builder.add_array_discard(usize_t(), 2, arr1).unwrap(); builder.add_array_discard(usize_t(), 2, arr2).unwrap(); - builder.finish_hugr_with_outputs([]).unwrap() + builder.finish_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1008,7 +1008,7 @@ mod test { .finish(|mut builder| { let vs = vec![ConstUsize::new(1).into(), 
ConstUsize::new(2).into()]; let arr = builder.add_load_value(array::ArrayValue::new(usize_t(), vs)); - builder.finish_hugr_with_outputs([arr]).unwrap() + builder.finish_with_outputs([arr]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1102,7 +1102,7 @@ mod test { } builder.finish_sub_container().unwrap().out_wire(0) }; - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1207,7 +1207,7 @@ mod test { .unwrap(); conditional.finish_sub_container().unwrap().out_wire(0) }; - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1314,7 +1314,7 @@ mod test { conditional.finish_sub_container().unwrap().out_wire(0) }; builder.add_array_discard(int_ty.clone(), 2, arr).unwrap(); - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1371,7 +1371,7 @@ mod test { builder .add_array_discard(int_ty.clone(), 2, arr_clone) .unwrap(); - builder.finish_hugr_with_outputs([elem]).unwrap() + builder.finish_with_outputs([elem]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1441,7 +1441,7 @@ mod test { arr, ) .unwrap(); - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1486,7 +1486,7 @@ mod test { r = builder.add_iadd(6, r, elem).unwrap(); } - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1518,8 +1518,7 @@ mod test { .with_outs(int_ty.clone()) .with_extensions(exec_registry()) .finish(|mut builder| { - let mut mb = builder.module_root_builder(); - let mut func = mb + let mut func = builder .define_function("foo", Signature::new(vec![], vec![int_ty.clone()])) .unwrap(); let v = func.add_load_value(ConstInt::new_u(6, value).unwrap()); @@ -1540,7 +1539,7 @@ mod test { builder .add_array_discard(int_ty.clone(), size, arr) .unwrap(); - builder.finish_hugr_with_outputs([elem]).unwrap() + builder.finish_with_outputs([elem]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1573,8 +1572,7 @@ mod test { .add_new_array(int_ty.clone(), new_array_args) .unwrap(); - let mut mb = builder.module_root_builder(); - let mut func = mb + let mut func = builder .define_function( "foo", Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]), @@ -1612,7 +1610,7 @@ mod test { builder .add_array_discard_empty(int_ty.clone(), arr) .unwrap(); - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1644,8 +1642,7 @@ mod test { .add_new_array(int_ty.clone(), new_array_args) .unwrap(); - let mut mb = builder.module_root_builder(); - let mut func = mb + let mut func = builder .define_function( "foo", Signature::new( @@ -1669,7 +1666,7 @@ mod test { .unwrap() .outputs_arr(); builder.add_array_discard(Type::UNIT, size, arr).unwrap(); - builder.finish_hugr_with_outputs([sum]).unwrap() + builder.finish_with_outputs([sum]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() diff --git 
a/hugr-llvm/src/extension/collections/list.rs b/hugr-llvm/src/extension/collections/list.rs index 746ce11bc0..e1ff76e2a8 100644 --- a/hugr-llvm/src/extension/collections/list.rs +++ b/hugr-llvm/src/extension/collections/list.rs @@ -203,7 +203,7 @@ fn emit_list_op<'c, H: HugrView>( op: ListOp, ) -> Result<()> { let hugr_elem_ty = match args.node().args() { - [TypeArg::Runtime(ty)] => ty.clone(), + [TypeArg::Type { ty }] => ty.clone(), _ => { bail!("Collections: invalid type args for list op"); } @@ -366,7 +366,7 @@ fn build_load_i8_ptr<'c, H: HugrView>( #[cfg(test)] mod test { use hugr_core::{ - builder::{Dataflow, DataflowHugr}, + builder::{Dataflow, DataflowSubContainer}, extension::{ ExtensionRegistry, prelude::{self, ConstUsize, qb_t, usize_t}, @@ -407,7 +407,7 @@ mod test { .add_dataflow_op(ext_op, hugr_builder.input_wires()) .unwrap() .outputs(); - hugr_builder.finish_hugr_with_outputs(outputs).unwrap() + hugr_builder.finish_with_outputs(outputs).unwrap() }); llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_list_extensions); @@ -427,7 +427,7 @@ mod test { .with_extensions(es) .finish(|mut hugr_builder| { let list = hugr_builder.add_load_value(ListValue::new(elem_ty, contents)); - hugr_builder.finish_hugr_with_outputs(vec![list]).unwrap() + hugr_builder.finish_with_outputs(vec![list]).unwrap() }); llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap index ad17a2c59f..1af774422e 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap @@ -15,14 +15,14 @@ source_filename = "test_context" @sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } @sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } @sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.e55b610a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to 
{ i64, [0 x i64] }*)] } +@sa.outer.c4a5911a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } define i64 @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.e55b610a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), i32 0, i32 0 + %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.c4a5911a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), i32 0, i32 0 %1 = load i64, i64* %0, align 4 ret i64 %1 } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap index b0f0741226..be8b63018c 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap @@ -15,7 +15,7 @@ source_filename = "test_context" @sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } @sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } @sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.e55b610a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 
x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } +@sa.outer.c4a5911a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } define i64 @_hl.main.1() { alloca_block: @@ -25,7 +25,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.e55b610a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 + store { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.c4a5911a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 %"5_01" = load { i64, [0 x { i64, [0 x i64] }*] }*, { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* %"5_01", i32 0, i32 0 %1 = load i64, i64* %0, align 4 diff --git a/hugr-llvm/src/extension/collections/stack_array.rs b/hugr-llvm/src/extension/collections/stack_array.rs index 297f539511..9361a298cd 100644 --- a/hugr-llvm/src/extension/collections/stack_array.rs +++ b/hugr-llvm/src/extension/collections/stack_array.rs @@ -126,7 +126,7 @@ impl CodegenExtension for ArrayCodegenExtension { .custom_type((array::EXTENSION_ID, array::ARRAY_TYPENAME), { let ccg = self.0.clone(); move |ts, hugr_type| { - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = hugr_type.args() else { + let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = hugr_type.args() else { return Err(anyhow!("Invalid type args for array type")); }; let elem_ty = ts.llvm_type(ty)?; @@ -726,7 +726,7 @@ fn emit_scan_op<'c, H: HugrView>( #[cfg(test)] mod test { - use hugr_core::builder::{DataflowHugr as _, HugrBuilder}; + use hugr_core::builder::Container as _; use hugr_core::extension::prelude::either_type; use hugr_core::ops::Tag; use hugr_core::std_extensions::STD_REG; @@ -770,7 +770,7 @@ mod test { build_all_array_ops(builder.dfg_builder_endo([]).unwrap()) .finish_sub_container() .unwrap(); - builder.finish_hugr().unwrap() + builder.finish_sub_container().unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -789,7 +789,7 @@ mod test { let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); let (_, arr) = builder.add_array_get(usize_t(), 2, arr, us1).unwrap(); builder.add_array_discard(usize_t(), 2, arr).unwrap(); - builder.finish_hugr_with_outputs([]).unwrap() + 
builder.finish_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -809,7 +809,7 @@ mod test { let (arr1, arr2) = builder.add_array_clone(usize_t(), 2, arr).unwrap(); builder.add_array_discard(usize_t(), 2, arr1).unwrap(); builder.add_array_discard(usize_t(), 2, arr2).unwrap(); - builder.finish_hugr_with_outputs([]).unwrap() + builder.finish_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -826,7 +826,7 @@ mod test { .finish(|mut builder| { let vs = vec![ConstUsize::new(1).into(), ConstUsize::new(2).into()]; let arr = builder.add_load_value(array::ArrayValue::new(usize_t(), vs)); - builder.finish_hugr_with_outputs([arr]).unwrap() + builder.finish_with_outputs([arr]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -885,7 +885,7 @@ mod test { } builder.finish_sub_container().unwrap().out_wire(0) }; - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -990,7 +990,7 @@ mod test { .unwrap(); conditional.finish_sub_container().unwrap().out_wire(0) }; - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1097,7 +1097,7 @@ mod test { conditional.finish_sub_container().unwrap().out_wire(0) }; builder.add_array_discard(int_ty.clone(), 2, arr).unwrap(); - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1154,7 +1154,7 @@ mod test { builder .add_array_discard(int_ty.clone(), 2, arr_clone) .unwrap(); - builder.finish_hugr_with_outputs([elem]).unwrap() + builder.finish_with_outputs([elem]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1224,7 +1224,7 @@ mod test { arr, ) .unwrap(); - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1269,7 +1269,7 @@ mod test { r = builder.add_iadd(6, r, elem).unwrap(); } - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1301,8 +1301,7 @@ mod test { .with_outs(int_ty.clone()) .with_extensions(exec_registry()) .finish(|mut builder| { - let mut mb = builder.module_root_builder(); - let mut func = mb + let mut func = builder .define_function("foo", Signature::new(vec![], vec![int_ty.clone()])) .unwrap(); let v = func.add_load_value(ConstInt::new_u(6, value).unwrap()); @@ -1323,7 +1322,7 @@ mod test { builder .add_array_discard(int_ty.clone(), size, arr) .unwrap(); - builder.finish_hugr_with_outputs([elem]).unwrap() + builder.finish_with_outputs([elem]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1356,8 +1355,7 @@ mod test { .add_new_array(int_ty.clone(), new_array_args) .unwrap(); - let mut mb = builder.module_root_builder(); - let mut func = mb + let mut func = builder .define_function( "foo", Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]), @@ -1395,7 +1393,7 @@ mod test { builder .add_array_discard_empty(int_ty.clone(), arr) .unwrap(); - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { 
cge.add_default_prelude_extensions() @@ -1414,6 +1412,7 @@ mod test { // We build a HUGR that: // - Creates an array [1, 2, 3, ..., size] // - Sums up the elements of the array using a scan and returns that sum + let int_ty = int_type(6); let hugr = SimpleHugrConfig::new() .with_outs(int_ty.clone()) @@ -1426,8 +1425,7 @@ mod test { .add_new_array(int_ty.clone(), new_array_args) .unwrap(); - let mut mb = builder.module_root_builder(); - let mut func = mb + let mut func = builder .define_function( "foo", Signature::new( @@ -1451,7 +1449,7 @@ mod test { .unwrap() .outputs_arr(); builder.add_array_discard(Type::UNIT, size, arr).unwrap(); - builder.finish_hugr_with_outputs([sum]).unwrap() + builder.finish_with_outputs([sum]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() diff --git a/hugr-llvm/src/extension/collections/static_array.rs b/hugr-llvm/src/extension/collections/static_array.rs index 50ac99b723..e9df520ee3 100644 --- a/hugr-llvm/src/extension/collections/static_array.rs +++ b/hugr-llvm/src/extension/collections/static_array.rs @@ -370,7 +370,7 @@ impl CodegenExtension for StaticArrayCodegenE let sac = self.0.clone(); move |ts, custom_type| { let element_type = custom_type.args()[0] - .as_runtime() + .as_type() .expect("Type argument for static array must be a type"); sac.static_array_type(ts, &element_type) } @@ -394,7 +394,6 @@ impl CodegenExtension for StaticArrayCodegenE mod test { use super::*; use float_types::float64_type; - use hugr_core::builder::DataflowHugr; use hugr_core::extension::prelude::ConstUsize; use hugr_core::ops::OpType; use hugr_core::ops::Value; @@ -460,7 +459,7 @@ mod test { ])) .finish(|mut builder| { let a = builder.add_load_value(value); - builder.finish_hugr_with_outputs([a]).unwrap() + builder.finish_with_outputs([a]).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -513,7 +512,7 @@ mod test { } cond.finish_sub_container().unwrap().outputs_arr() }; - builder.finish_hugr_with_outputs([out]).unwrap() + builder.finish_with_outputs([out]).unwrap() }); exec_ctx.add_extensions(|ceb| { @@ -535,7 +534,7 @@ mod test { let arr = builder .add_load_value(StaticArrayValue::try_new("empty", usize_t(), vec![]).unwrap()); let len = builder.add_static_array_len(usize_t(), arr).unwrap(); - builder.finish_hugr_with_outputs([len]).unwrap() + builder.finish_with_outputs([len]).unwrap() }); exec_ctx.add_extensions(|ceb| { @@ -575,7 +574,7 @@ mod test { let len = builder .add_static_array_len(inner_arr_ty, outer_arr) .unwrap(); - builder.finish_hugr_with_outputs([len]).unwrap() + builder.finish_with_outputs([len]).unwrap() }); check_emission!(hugr, llvm_ctx); } diff --git a/hugr-llvm/src/extension/conversions.rs b/hugr-llvm/src/extension/conversions.rs index 0ed8ec88c2..cbc036719b 100644 --- a/hugr-llvm/src/extension/conversions.rs +++ b/hugr-llvm/src/extension/conversions.rs @@ -275,7 +275,7 @@ mod test { use crate::check_emission; use crate::emit::test::{DFGW, SimpleHugrConfig}; use crate::test::{TestContext, exec_ctx, llvm_ctx}; - use hugr_core::builder::{DataflowHugr, SubContainer}; + use hugr_core::builder::SubContainer; use hugr_core::std_extensions::STD_REG; use hugr_core::std_extensions::arithmetic::float_types::ConstF64; use hugr_core::std_extensions::arithmetic::int_types::ConstInt; @@ -311,7 +311,7 @@ mod test { .add_dataflow_op(ext_op, [in1]) .unwrap() .outputs(); - hugr_builder.finish_hugr_with_outputs(outputs).unwrap() + hugr_builder.finish_with_outputs(outputs).unwrap() }) } @@ -381,7 +381,7 @@ mod test { 
.add_dataflow_op(ext_op, [in1]) .unwrap() .outputs_arr(); - hugr_builder.finish_hugr_with_outputs([out1]).unwrap() + hugr_builder.finish_with_outputs([out1]).unwrap() }); check_emission!(op_name, hugr, llvm_ctx); } @@ -393,7 +393,7 @@ mod test { .with_extensions(PRELUDE_REGISTRY.to_owned()) .finish(|mut builder: DFGW| { let konst = builder.add_load_value(ConstUsize::new(42)); - builder.finish_hugr_with_outputs([konst]).unwrap() + builder.finish_with_outputs([konst]).unwrap() }); exec_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); assert_eq!(42, exec_ctx.exec_hugr_u64(hugr, "main")); @@ -417,7 +417,7 @@ mod test { .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [int]) .unwrap() .outputs_arr(); - builder.finish_hugr_with_outputs([usize_]).unwrap() + builder.finish_with_outputs([usize_]).unwrap() }); exec_ctx.add_extensions(|builder| { builder @@ -481,7 +481,7 @@ mod test { .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [cond_result]) .unwrap() .outputs_arr(); - builder.finish_hugr_with_outputs([usize_]).unwrap() + builder.finish_with_outputs([usize_]).unwrap() }) } @@ -613,7 +613,7 @@ mod test { let true_result = case_true.add_load_value(ConstUsize::new(6)); case_true.finish_with_outputs([true_result]).unwrap(); let res = cond.finish_sub_container().unwrap(); - builder.finish_hugr_with_outputs(res.outputs()).unwrap() + builder.finish_with_outputs(res.outputs()).unwrap() }); exec_ctx.add_extensions(|builder| { builder @@ -635,7 +635,7 @@ mod test { let [b] = builder.add_dataflow_op(i2b, [i]).unwrap().outputs_arr(); let b2i = EXTENSION.instantiate_extension_op("ifrombool", []).unwrap(); let [i] = builder.add_dataflow_op(b2i, [b]).unwrap().outputs_arr(); - builder.finish_hugr_with_outputs([i]).unwrap() + builder.finish_with_outputs([i]).unwrap() }); exec_ctx.add_extensions(|builder| { builder @@ -663,7 +663,7 @@ mod test { .instantiate_extension_op("bytecast_int64_to_float64", []) .unwrap(); let [f] = builder.add_dataflow_op(i2f, [i]).unwrap().outputs_arr(); - builder.finish_hugr_with_outputs([f]).unwrap() + builder.finish_with_outputs([f]).unwrap() }); exec_ctx.add_extensions(|builder| { builder @@ -690,7 +690,7 @@ mod test { .instantiate_extension_op("bytecast_float64_to_int64", []) .unwrap(); let [i] = builder.add_dataflow_op(f2i, [f]).unwrap().outputs_arr(); - builder.finish_hugr_with_outputs([i]).unwrap() + builder.finish_with_outputs([i]).unwrap() }); exec_ctx.add_extensions(|builder| { builder diff --git a/hugr-llvm/src/extension/float.rs b/hugr-llvm/src/extension/float.rs index 968ae3f585..b95a698b18 100644 --- a/hugr-llvm/src/extension/float.rs +++ b/hugr-llvm/src/extension/float.rs @@ -149,14 +149,13 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { #[cfg(test)] mod test { use hugr_core::Hugr; - use hugr_core::builder::DataflowHugr; use hugr_core::extension::SignatureFunc; use hugr_core::extension::simple_op::MakeOpDef; use hugr_core::std_extensions::STD_REG; use hugr_core::std_extensions::arithmetic::float_ops::FloatOps; use hugr_core::types::TypeRow; use hugr_core::{ - builder::Dataflow, + builder::{Dataflow, DataflowSubContainer}, std_extensions::arithmetic::float_types::{ConstF64, float64_type}, }; use rstest::rstest; @@ -185,7 +184,7 @@ mod test { .add_dataflow_op(op, builder.input_wires()) .unwrap() .outputs(); - builder.finish_hugr_with_outputs(outputs).unwrap() + builder.finish_with_outputs(outputs).unwrap() }) } @@ -197,7 +196,7 @@ mod test { .with_extensions(STD_REG.to_owned()) .finish(|mut builder| { let c = 
builder.add_load_value(ConstF64::new(3.12)); - builder.finish_hugr_with_outputs([c]).unwrap() + builder.finish_with_outputs([c]).unwrap() }); check_emission!(hugr, llvm_ctx); } diff --git a/hugr-llvm/src/extension/int.rs b/hugr-llvm/src/extension/int.rs index bea508d774..315c7c7296 100644 --- a/hugr-llvm/src/extension/int.rs +++ b/hugr-llvm/src/extension/int.rs @@ -668,7 +668,7 @@ fn emit_int_op<'c, H: HugrView>( ]) }), IntOpDef::inarrow_s => { - let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned() + let Some(TypeArg::BoundedNat { n: out_log_width }) = args.node().args().last().cloned() else { bail!("Type arg to inarrow_s wasn't a Nat"); }; @@ -686,7 +686,7 @@ fn emit_int_op<'c, H: HugrView>( }) } IntOpDef::inarrow_u => { - let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned() + let Some(TypeArg::BoundedNat { n: out_log_width }) = args.node().args().last().cloned() else { bail!("Type arg to inarrow_u wasn't a Nat"); }; @@ -756,7 +756,7 @@ pub(crate) fn get_width_arg>( args: &EmitOpArgs<'_, '_, ExtensionOp, H>, op: &impl MakeExtensionOp, ) -> Result { - let [TypeArg::BoundedNat(log_width)] = args.node.args() else { + let [TypeArg::BoundedNat { n: log_width }] = args.node.args() else { bail!( "Expected exactly one BoundedNat parameter to {}", op.op_id() @@ -1094,7 +1094,7 @@ fn llvm_type<'c>( context: TypingSession<'c, '_>, hugr_type: &CustomType, ) -> Result> { - if let [TypeArg::BoundedNat(n)] = hugr_type.args() { + if let [TypeArg::BoundedNat { n }] = hugr_type.args() { let m = *n as usize; if m < int_types::INT_TYPES.len() && int_types::INT_TYPES[m] == hugr_type.clone().into() { return Ok(match m { @@ -1141,7 +1141,6 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { #[cfg(test)] mod test { use anyhow::Result; - use hugr_core::builder::DataflowHugr; use hugr_core::extension::prelude::{ConstError, UnwrapBuilder, error_type}; use hugr_core::std_extensions::STD_REG; use hugr_core::{ @@ -1243,9 +1242,7 @@ mod test { .unwrap() .outputs(); let processed_outputs = process(&mut hugr_builder, outputs).unwrap(); - hugr_builder - .finish_hugr_with_outputs(processed_outputs) - .unwrap() + hugr_builder.finish_with_outputs(processed_outputs).unwrap() }) } @@ -1581,7 +1578,7 @@ mod test { .add_dataflow_op(iu_to_s, [unsigned]) .unwrap() .outputs_arr(); - hugr_builder.finish_hugr_with_outputs([signed]).unwrap() + hugr_builder.finish_with_outputs([signed]).unwrap() }); let act = int_exec_ctx.exec_hugr_i64(hugr, "main"); assert_eq!(act, val as i64); @@ -1608,7 +1605,7 @@ mod test { .add_dataflow_op(make_int_op("iadd", log_width), [unsigned, num]) .unwrap() .outputs_arr(); - hugr_builder.finish_hugr_with_outputs([res]).unwrap() + hugr_builder.finish_with_outputs([res]).unwrap() }); let act = int_exec_ctx.exec_hugr_u64(hugr, "main"); assert_eq!(act, (val as u64) + 42); diff --git a/hugr-llvm/src/extension/logic.rs b/hugr-llvm/src/extension/logic.rs index b382a21408..50dd2bd17c 100644 --- a/hugr-llvm/src/extension/logic.rs +++ b/hugr-llvm/src/extension/logic.rs @@ -76,7 +76,7 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { mod test { use hugr_core::{ Hugr, - builder::{Dataflow, DataflowHugr}, + builder::{Dataflow, DataflowSubContainer}, extension::{ExtensionRegistry, prelude::bool_t}, std_extensions::logic::{self, LogicOp}, }; @@ -99,7 +99,7 @@ mod test { .add_dataflow_op(op, builder.input_wires()) .unwrap() .outputs(); - builder.finish_hugr_with_outputs(outputs).unwrap() + builder.finish_with_outputs(outputs).unwrap() }) } diff --git 
a/hugr-llvm/src/extension/prelude.rs b/hugr-llvm/src/extension/prelude.rs index d4b918b559..62a00527c8 100644 --- a/hugr-llvm/src/extension/prelude.rs +++ b/hugr-llvm/src/extension/prelude.rs @@ -117,37 +117,6 @@ pub trait PreludeCodegen: Clone { Ok(err.into()) } - /// Emit instructions to construct an error value from a signal and message. - /// - /// The type of the returned value must match [`Self::error_type`]. - /// - /// The default implementation constructs a struct with the given signal and message. - fn emit_make_error<'c, H: HugrView>( - &self, - ctx: &mut EmitFuncContext<'c, '_, H>, - signal: BasicValueEnum<'c>, - message: BasicValueEnum<'c>, - ) -> Result> { - let builder = ctx.builder(); - - // The usize signal is an i64 but error struct stores an i32. - let i32_type = ctx.typing_session().iw_context().i32_type(); - let signal_int = signal.into_int_value(); - let signal_truncated = builder.build_int_truncate(signal_int, i32_type, "")?; - - // Construct the error struct as runtime value. - let err_ty = ctx.llvm_type(&error_type())?.into_struct_type(); - let undef = err_ty.get_undef(); - let err_with_sig = builder - .build_insert_value(undef, signal_truncated, 0, "")? - .into_struct_value(); - let err_complete = builder - .build_insert_value(err_with_sig, message, 1, "")? - .into_struct_value(); - - Ok(err_complete.into()) - } - /// Emit instructions to halt execution with the error `err`. /// /// The type of `err` must match that returned from [`Self::error_type`]. @@ -376,22 +345,6 @@ pub fn add_prelude_extensions<'a, H: HugrView + 'a>( args.outputs.finish(context.builder(), []) } }) - .extension_op(prelude::PRELUDE_ID, prelude::MAKE_ERROR_OP_ID, { - let pcg = pcg.clone(); - move |context, args| { - let signal = args.inputs[0]; - let message = args.inputs[1]; - ensure!( - message.get_type() - == pcg - .string_type(&context.typing_session())? 
- .as_basic_type_enum(), - signal.get_type() == pcg.usize_type(&context.typing_session()).into() - ); - let err = pcg.emit_make_error(context, signal, message)?; - args.outputs.finish(context.builder(), [err]) - } - }) .extension_op(prelude::PRELUDE_ID, prelude::PANIC_OP_ID, { let pcg = pcg.clone(); move |context, args| { @@ -436,7 +389,7 @@ pub fn add_prelude_extensions<'a, H: HugrView + 'a>( move |context, args| { let load_nat = LoadNat::from_extension_op(args.node().as_ref())?; let v = match load_nat.get_nat() { - TypeArg::BoundedNat(n) => pcg + TypeArg::BoundedNat { n } => pcg .usize_type(&context.typing_session()) .const_int(n, false), arg => bail!("Unexpected type arg for LoadNat: {}", arg), @@ -452,10 +405,10 @@ pub fn add_prelude_extensions<'a, H: HugrView + 'a>( #[cfg(test)] mod test { - use hugr_core::builder::{Dataflow, DataflowHugr}; + use hugr_core::builder::{Dataflow, DataflowSubContainer}; use hugr_core::extension::PRELUDE; - use hugr_core::extension::prelude::{EXIT_OP_ID, MAKE_ERROR_OP_ID, Noop}; - use hugr_core::types::{Term, Type}; + use hugr_core::extension::prelude::{EXIT_OP_ID, Noop}; + use hugr_core::types::{Type, TypeArg}; use hugr_core::{Hugr, type_row}; use prelude::{PANIC_OP_ID, PRINT_OP_ID, bool_t, qb_t, usize_t}; use rstest::{fixture, rstest}; @@ -526,7 +479,7 @@ mod test { .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) .finish(|mut builder| { let k = builder.add_load_value(ConstUsize::new(17)); - builder.finish_hugr_with_outputs([k]).unwrap() + builder.finish_with_outputs([k]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -549,7 +502,7 @@ mod test { .finish(|mut builder| { let k1 = builder.add_load_value(konst1); let k2 = builder.add_load_value(konst2); - builder.finish_hugr_with_outputs([k1, k2]).unwrap() + builder.finish_with_outputs([k1, k2]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -566,7 +519,7 @@ mod test { .add_dataflow_op(Noop::new(usize_t()), in_wires) .unwrap() .outputs(); - builder.finish_hugr_with_outputs(r).unwrap() + builder.finish_with_outputs(r).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -580,7 +533,7 @@ mod test { .finish(|mut builder| { let in_wires = builder.input_wires(); let r = builder.make_tuple(in_wires).unwrap(); - builder.finish_hugr_with_outputs([r]).unwrap() + builder.finish_with_outputs([r]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -598,7 +551,7 @@ mod test { builder.input_wires(), ) .unwrap(); - builder.finish_hugr_with_outputs(unpack.outputs()).unwrap() + builder.finish_with_outputs(unpack.outputs()).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -606,8 +559,10 @@ mod test { #[rstest] fn prelude_panic(prelude_llvm_ctx: TestContext) { let error_val = ConstError::new(42, "PANIC"); - let type_arg_q: Term = qb_t().into(); - let type_arg_2q = Term::new_list([type_arg_q.clone(), type_arg_q]); + let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() }; + let type_arg_2q: TypeArg = TypeArg::Sequence { + elems: vec![type_arg_q.clone(), type_arg_q], + }; let panic_op = PRELUDE .instantiate_extension_op(&PANIC_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) .unwrap(); @@ -623,7 +578,7 @@ mod test { .add_dataflow_op(panic_op, [err, q0, q1]) .unwrap() .outputs_arr(); - builder.finish_hugr_with_outputs([q0, q1]).unwrap() + builder.finish_with_outputs([q0, q1]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); @@ -632,8 +587,10 @@ mod test { #[rstest] fn prelude_exit(prelude_llvm_ctx: TestContext) { let error_val = ConstError::new(42, "EXIT"); - let 
type_arg_q: Term = qb_t().into(); - let type_arg_2q = Term::new_list([type_arg_q.clone(), type_arg_q]); + let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() }; + let type_arg_2q: TypeArg = TypeArg::Sequence { + elems: vec![type_arg_q.clone(), type_arg_q], + }; let exit_op = PRELUDE .instantiate_extension_op(&EXIT_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) .unwrap(); @@ -649,7 +606,7 @@ mod test { .add_dataflow_op(exit_op, [err, q0, q1]) .unwrap() .outputs_arr(); - builder.finish_hugr_with_outputs([q0, q1]).unwrap() + builder.finish_with_outputs([q0, q1]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); @@ -665,61 +622,7 @@ mod test { .finish(|mut builder| { let greeting_out = builder.add_load_value(greeting); builder.add_dataflow_op(print_op, [greeting_out]).unwrap(); - builder.finish_hugr_with_outputs([]).unwrap() - }); - - check_emission!(hugr, prelude_llvm_ctx); - } - - #[rstest] - fn prelude_make_error(prelude_llvm_ctx: TestContext) { - let sig: ConstUsize = ConstUsize::new(100); - let msg: ConstString = ConstString::new("Error!".into()); - - let make_error_op = PRELUDE - .instantiate_extension_op(&MAKE_ERROR_OP_ID, []) - .unwrap(); - - let hugr = SimpleHugrConfig::new() - .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) - .with_outs(error_type()) - .finish(|mut builder| { - let sig_out = builder.add_load_value(sig); - let msg_out = builder.add_load_value(msg); - let [err] = builder - .add_dataflow_op(make_error_op, [sig_out, msg_out]) - .unwrap() - .outputs_arr(); - builder.finish_hugr_with_outputs([err]).unwrap() - }); - - check_emission!(hugr, prelude_llvm_ctx); - } - - #[rstest] - fn prelude_make_error_and_panic(prelude_llvm_ctx: TestContext) { - let sig: ConstUsize = ConstUsize::new(100); - let msg: ConstString = ConstString::new("Error!".into()); - - let make_error_op = PRELUDE - .instantiate_extension_op(&MAKE_ERROR_OP_ID, []) - .unwrap(); - - let panic_op = PRELUDE - .instantiate_extension_op(&PANIC_OP_ID, [Term::new_list([]), Term::new_list([])]) - .unwrap(); - - let hugr = SimpleHugrConfig::new() - .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) - .finish(|mut builder| { - let sig_out = builder.add_load_value(sig); - let msg_out = builder.add_load_value(msg); - let [err] = builder - .add_dataflow_op(make_error_op, [sig_out, msg_out]) - .unwrap() - .outputs_arr(); - builder.add_dataflow_op(panic_op, [err]).unwrap(); - builder.finish_hugr_with_outputs([]).unwrap() + builder.finish_with_outputs([]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); @@ -732,10 +635,10 @@ mod test { .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) .finish(|mut builder| { let v = builder - .add_dataflow_op(LoadNat::new(42u64.into()), vec![]) + .add_dataflow_op(LoadNat::new(TypeArg::BoundedNat { n: 42 }), vec![]) .unwrap() .out_wire(0); - builder.finish_hugr_with_outputs([v]).unwrap() + builder.finish_with_outputs([v]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -748,7 +651,7 @@ mod test { .finish(|mut builder| { let i = builder.add_load_value(ConstUsize::new(42)); let [w1, _w2] = builder.add_barrier([i, i]).unwrap().outputs_arr(); - builder.finish_hugr_with_outputs([w1]).unwrap() + builder.finish_with_outputs([w1]).unwrap() }) } diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@llvm14.snap deleted file mode 100644 index 2a543d0e11..0000000000 --- 
a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@llvm14.snap +++ /dev/null @@ -1,19 +0,0 @@ ---- -source: hugr-llvm/src/extension/prelude.rs -expression: mod_str ---- -; ModuleID = 'test_context' -source_filename = "test_context" - -@0 = private unnamed_addr constant [7 x i8] c"Error!\00", align 1 - -define { i32, i8* } @_hl.main.1() { -alloca_block: - br label %entry_block - -entry_block: ; preds = %alloca_block - %0 = trunc i64 100 to i32 - %1 = insertvalue { i32, i8* } undef, i32 %0, 0 - %2 = insertvalue { i32, i8* } %1, i8* getelementptr inbounds ([7 x i8], [7 x i8]* @0, i32 0, i32 0), 1 - ret { i32, i8* } %2 -} diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@pre-mem2reg@llvm14.snap deleted file mode 100644 index d061dc36a3..0000000000 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@pre-mem2reg@llvm14.snap +++ /dev/null @@ -1,31 +0,0 @@ ---- -source: hugr-llvm/src/extension/prelude.rs -expression: mod_str ---- -; ModuleID = 'test_context' -source_filename = "test_context" - -@0 = private unnamed_addr constant [7 x i8] c"Error!\00", align 1 - -define { i32, i8* } @_hl.main.1() { -alloca_block: - %"0" = alloca { i32, i8* }, align 8 - %"7_0" = alloca i8*, align 8 - %"5_0" = alloca i64, align 8 - %"8_0" = alloca { i32, i8* }, align 8 - br label %entry_block - -entry_block: ; preds = %alloca_block - store i8* getelementptr inbounds ([7 x i8], [7 x i8]* @0, i32 0, i32 0), i8** %"7_0", align 8 - store i64 100, i64* %"5_0", align 4 - %"5_01" = load i64, i64* %"5_0", align 4 - %"7_02" = load i8*, i8** %"7_0", align 8 - %0 = trunc i64 %"5_01" to i32 - %1 = insertvalue { i32, i8* } undef, i32 %0, 0 - %2 = insertvalue { i32, i8* } %1, i8* %"7_02", 1 - store { i32, i8* } %2, { i32, i8* }* %"8_0", align 8 - %"8_03" = load { i32, i8* }, { i32, i8* }* %"8_0", align 8 - store { i32, i8* } %"8_03", { i32, i8* }* %"0", align 8 - %"04" = load { i32, i8* }, { i32, i8* }* %"0", align 8 - ret { i32, i8* } %"04" -} diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@llvm14.snap deleted file mode 100644 index fdaae15e98..0000000000 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@llvm14.snap +++ /dev/null @@ -1,28 +0,0 @@ ---- -source: hugr-llvm/src/extension/prelude.rs -expression: mod_str ---- -; ModuleID = 'test_context' -source_filename = "test_context" - -@0 = private unnamed_addr constant [7 x i8] c"Error!\00", align 1 -@prelude.panic_template = private unnamed_addr constant [34 x i8] c"Program panicked (signal %i): %s\0A\00", align 1 - -define void @_hl.main.1() { -alloca_block: - br label %entry_block - -entry_block: ; preds = %alloca_block - %0 = trunc i64 100 to i32 - %1 = insertvalue { i32, i8* } undef, i32 %0, 0 - %2 = insertvalue { i32, i8* } %1, i8* getelementptr inbounds ([7 x i8], [7 x i8]* @0, i32 0, i32 0), 1 - %3 = extractvalue { i32, i8* } %2, 0 - %4 = extractvalue { i32, i8* } %2, 1 - %5 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template, i32 0, i32 0), i32 %3, i8* %4) - call void @abort() - ret void -} - -declare i32 @printf(i8*, ...) 
- -declare void @abort() diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@pre-mem2reg@llvm14.snap deleted file mode 100644 index 8ff4526e04..0000000000 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@pre-mem2reg@llvm14.snap +++ /dev/null @@ -1,37 +0,0 @@ ---- -source: hugr-llvm/src/extension/prelude.rs -expression: mod_str ---- -; ModuleID = 'test_context' -source_filename = "test_context" - -@0 = private unnamed_addr constant [7 x i8] c"Error!\00", align 1 -@prelude.panic_template = private unnamed_addr constant [34 x i8] c"Program panicked (signal %i): %s\0A\00", align 1 - -define void @_hl.main.1() { -alloca_block: - %"7_0" = alloca i8*, align 8 - %"5_0" = alloca i64, align 8 - %"8_0" = alloca { i32, i8* }, align 8 - br label %entry_block - -entry_block: ; preds = %alloca_block - store i8* getelementptr inbounds ([7 x i8], [7 x i8]* @0, i32 0, i32 0), i8** %"7_0", align 8 - store i64 100, i64* %"5_0", align 4 - %"5_01" = load i64, i64* %"5_0", align 4 - %"7_02" = load i8*, i8** %"7_0", align 8 - %0 = trunc i64 %"5_01" to i32 - %1 = insertvalue { i32, i8* } undef, i32 %0, 0 - %2 = insertvalue { i32, i8* } %1, i8* %"7_02", 1 - store { i32, i8* } %2, { i32, i8* }* %"8_0", align 8 - %"8_03" = load { i32, i8* }, { i32, i8* }* %"8_0", align 8 - %3 = extractvalue { i32, i8* } %"8_03", 0 - %4 = extractvalue { i32, i8* } %"8_03", 1 - %5 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template, i32 0, i32 0), i32 %3, i8* %4) - call void @abort() - ret void -} - -declare i32 @printf(i8*, ...) - -declare void @abort() diff --git a/hugr-llvm/src/test.rs b/hugr-llvm/src/test.rs index 59919baad4..9864ae12e1 100644 --- a/hugr-llvm/src/test.rs +++ b/hugr-llvm/src/test.rs @@ -2,7 +2,7 @@ use std::rc::Rc; use hugr_core::{ Hugr, - builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, + builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, ops::{OpTrait, OpType}, types::PolyFuncType, }; diff --git a/hugr-llvm/src/utils/fat.rs b/hugr-llvm/src/utils/fat.rs index 1476bcb484..1b046ddf02 100644 --- a/hugr-llvm/src/utils/fat.rs +++ b/hugr-llvm/src/utils/fat.rs @@ -8,7 +8,7 @@ use hugr_core::hugr::views::Rerooted; use hugr_core::{ Hugr, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, core::HugrNode, - ops::{CFG, DataflowBlock, ExitBlock, Input, Module, OpType, Output}, + ops::{CFG, DataflowBlock, ExitBlock, Input, OpType, Output}, types::Type, }; use itertools::Itertools as _; @@ -373,12 +373,7 @@ pub trait FatExt: HugrView { } /// Try to create a specific [`FatNode`] for the root of a [`HugrView`]. - fn fat_root(&self) -> Option> { - self.try_fat(self.module_root()) - } - - /// Try to create a specific [`FatNode`] for the entrypoint of a [`HugrView`]. - fn fat_entrypoint(&self) -> Option> + fn fat_root(&self) -> Option> where for<'a> &'a OpType: TryInto<&'a OT>, { diff --git a/hugr-model/CHANGELOG.md b/hugr-model/CHANGELOG.md index 4a2e4ccb35..901f077d14 100644 --- a/hugr-model/CHANGELOG.md +++ b/hugr-model/CHANGELOG.md @@ -1,29 +1,5 @@ # Changelog - -## [0.22.0](https://github.com/CQCL/hugr/compare/hugr-model-v0.21.0...hugr-model-v0.22.0) - 2025-07-24 - -### New Features - -- Names of private functions become `core.title` metadata. 
([#2448](https://github.com/CQCL/hugr/pull/2448)) -- include generator metatada in model import and cli validate errors ([#2452](https://github.com/CQCL/hugr/pull/2452)) -- Version number in hugr binary format. ([#2468](https://github.com/CQCL/hugr/pull/2468)) -- Use semver crate for -model version, and include in docs ([#2471](https://github.com/CQCL/hugr/pull/2471)) -## [0.21.0](https://github.com/CQCL/hugr/compare/hugr-model-v0.20.2...hugr-model-v0.21.0) - 2025-07-09 - -### Bug Fixes - -- Model import should perform extension resolution ([#2326](https://github.com/CQCL/hugr/pull/2326)) -- [**breaking**] Fixed bugs in model CFG handling and improved CFG signatures ([#2334](https://github.com/CQCL/hugr/pull/2334)) -- [**breaking**] Fix panic in model resolver when variable is used outside of symbol. ([#2362](https://github.com/CQCL/hugr/pull/2362)) -- Order hints on input and output nodes. ([#2422](https://github.com/CQCL/hugr/pull/2422)) - -### New Features - -- [**breaking**] Added float and bytes literal to core and python bindings. ([#2289](https://github.com/CQCL/hugr/pull/2289)) -- [**breaking**] Add Visibility to FuncDefn/FuncDecl. ([#2143](https://github.com/CQCL/hugr/pull/2143)) -- [**breaking**] hugr-model use explicit Option, with ::Unspecified in capnp ([#2424](https://github.com/CQCL/hugr/pull/2424)) - ## [0.20.2](https://github.com/CQCL/hugr/compare/hugr-model-v0.20.1...hugr-model-v0.20.2) - 2025-06-25 ### New Features diff --git a/hugr-model/Cargo.toml b/hugr-model/Cargo.toml index fb67977250..24c1f9e42e 100644 --- a/hugr-model/Cargo.toml +++ b/hugr-model/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hugr-model" -version = "0.22.1" +version = "0.20.2" readme = "README.md" documentation = "https://docs.rs/hugr-model/" description = "Data model for Quantinuum's HUGR intermediate representation" @@ -27,7 +27,6 @@ ordered-float = { workspace = true } pest = { workspace = true } pest_derive = { workspace = true } pretty = { workspace = true } -semver = { workspace = true } smol_str = { workspace = true, features = ["serde"] } thiserror.workspace = true pyo3 = { workspace = true, optional = true, features = ["extension-module"] } diff --git a/hugr-model/FORMAT_VERSION b/hugr-model/FORMAT_VERSION deleted file mode 100644 index 3eefcb9dd5..0000000000 --- a/hugr-model/FORMAT_VERSION +++ /dev/null @@ -1 +0,0 @@ -1.0.0 diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 7891b7f245..f69beb18f7 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -20,12 +20,6 @@ using LinkIndex = UInt32; struct Package { modules @0 :List(Module); - version @1 :Version; -} - -struct Version { - major @0 :UInt32; - minor @1 :UInt32; } struct Module { @@ -67,7 +61,6 @@ struct Operation { } struct Symbol { - visibility @4 :Visibility; name @0 :Text; params @1 :List(Param); constraints @2 :List(TermId); @@ -127,9 +120,3 @@ struct Param { name @0 :Text; type @1 :TermId; } - -enum Visibility { - unspecified @0; - private @1; - public @2; -} diff --git a/hugr-model/src/capnp/hugr_v0_capnp.rs b/hugr-model/src/capnp/hugr_v0_capnp.rs index 9e0a6ed7b3..aea608cfde 100644 --- a/hugr-model/src/capnp/hugr_v0_capnp.rs +++ b/hugr-model/src/capnp/hugr_v0_capnp.rs @@ -72,19 +72,11 @@ pub mod package { pub fn has_modules(&self) -> bool { 
!self.reader.get_pointer_field(0).is_null() } - #[inline] - pub fn get_version(self) -> ::capnp::Result> { - ::capnp::traits::FromPointerReader::get_from_pointer(&self.reader.get_pointer_field(1), ::core::option::Option::None) - } - #[inline] - pub fn has_version(&self) -> bool { - !self.reader.get_pointer_field(1).is_null() - } } pub struct Builder<'a> { builder: ::capnp::private::layout::StructBuilder<'a> } impl <> ::capnp::traits::HasStructSize for Builder<'_,> { - const STRUCT_SIZE: ::capnp::private::layout::StructSize = ::capnp::private::layout::StructSize { data: 0, pointers: 2 }; + const STRUCT_SIZE: ::capnp::private::layout::StructSize = ::capnp::private::layout::StructSize { data: 0, pointers: 1 }; } impl <> ::capnp::traits::HasTypeId for Builder<'_,> { const TYPE_ID: u64 = _private::TYPE_ID; @@ -150,22 +142,6 @@ pub mod package { pub fn has_modules(&self) -> bool { !self.builder.is_pointer_field_null(0) } - #[inline] - pub fn get_version(self) -> ::capnp::Result> { - ::capnp::traits::FromPointerBuilder::get_from_pointer(self.builder.get_pointer_field(1), ::core::option::Option::None) - } - #[inline] - pub fn set_version(&mut self, value: crate::hugr_v0_capnp::version::Reader<'_>) -> ::capnp::Result<()> { - ::capnp::traits::SetterInput::set_pointer_builder(self.builder.reborrow().get_pointer_field(1), value, false) - } - #[inline] - pub fn init_version(self, ) -> crate::hugr_v0_capnp::version::Builder<'a> { - ::capnp::traits::FromPointerBuilder::init_pointer(self.builder.get_pointer_field(1), 0) - } - #[inline] - pub fn has_version(&self) -> bool { - !self.builder.is_pointer_field_null(1) - } } pub struct Pipeline { _typeless: ::capnp::any_pointer::Pipeline } @@ -175,23 +151,19 @@ pub mod package { } } impl Pipeline { - pub fn get_version(&self) -> crate::hugr_v0_capnp::version::Pipeline { - ::capnp::capability::FromTypelessPipeline::new(self._typeless.get_pointer_field(1)) - } } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 53] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 37] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(56, 36, 26, 168, 243, 12, 207, 208), ::capnp::word(20, 0, 0, 0, 1, 0, 0, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), - ::capnp::word(2, 0, 7, 0, 0, 0, 0, 0), + ::capnp::word(1, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(93, 1, 0, 0, 166, 1, 0, 0), ::capnp::word(21, 0, 0, 0, 226, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(29, 0, 0, 0, 119, 0, 0, 0), + ::capnp::word(29, 0, 0, 0, 63, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(99, 97, 112, 110, 112, 47, 104, 117), @@ -199,21 +171,14 @@ pub mod package { ::capnp::word(112, 110, 112, 58, 80, 97, 99, 107), ::capnp::word(97, 103, 101, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 1, 0, 1, 0), - ::capnp::word(8, 0, 0, 0, 3, 0, 4, 0), + ::capnp::word(4, 0, 0, 0, 3, 0, 4, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 1, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(41, 0, 0, 0, 66, 0, 0, 0), + ::capnp::word(13, 0, 0, 0, 66, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(36, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(64, 0, 0, 0, 2, 0, 1, 0), - ::capnp::word(1, 0, 0, 0, 1, 0, 0, 0), - ::capnp::word(0, 0, 1, 0, 1, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(61, 0, 0, 0, 66, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - 
::capnp::word(56, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(68, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(8, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(36, 0, 0, 0, 2, 0, 1, 0), ::capnp::word(109, 111, 100, 117, 108, 101, 115, 0), ::capnp::word(14, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -226,24 +191,15 @@ pub mod package { ::capnp::word(14, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(118, 101, 114, 115, 105, 111, 110, 0), - ::capnp::word(16, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(167, 171, 245, 145, 177, 155, 108, 182), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(16, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ]; pub fn get_field_types(index: u16) -> ::capnp::introspect::Type { match index { 0 => <::capnp::struct_list::Owned as ::capnp::introspect::Introspect>::introspect(), - 1 => ::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), + _ => panic!("invalid field index {}", index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -251,237 +207,13 @@ pub mod package { members_by_discriminant: MEMBERS_BY_DISCRIMINANT, members_by_name: MEMBERS_BY_NAME, }; - pub static NONUNION_MEMBERS : &[u16] = &[0,1]; + pub static NONUNION_MEMBERS : &[u16] = &[0]; pub static MEMBERS_BY_DISCRIMINANT : &[u16] = &[]; - pub static MEMBERS_BY_NAME : &[u16] = &[0,1]; + pub static MEMBERS_BY_NAME : &[u16] = &[0]; pub const TYPE_ID: u64 = 0xd0cf_0cf3_a81a_2438; } } -pub mod version { - #[derive(Copy, Clone)] - pub struct Owned(()); - impl ::capnp::introspect::Introspect for Owned { fn introspect() -> ::capnp::introspect::Type { ::capnp::introspect::TypeVariant::Struct(::capnp::introspect::RawBrandedStructSchema { generic: &_private::RAW_SCHEMA, field_types: _private::get_field_types, annotation_types: _private::get_annotation_types }).into() } } - impl ::capnp::traits::Owned for Owned { type Reader<'a> = Reader<'a>; type Builder<'a> = Builder<'a>; } - impl ::capnp::traits::OwnedStruct for Owned { type Reader<'a> = Reader<'a>; type Builder<'a> = Builder<'a>; } - impl ::capnp::traits::Pipelined for Owned { type Pipeline = Pipeline; } - - pub struct Reader<'a> { reader: ::capnp::private::layout::StructReader<'a> } - impl <> ::core::marker::Copy for Reader<'_,> {} - impl <> ::core::clone::Clone for Reader<'_,> { - fn clone(&self) -> Self { *self } - } - - impl <> ::capnp::traits::HasTypeId for Reader<'_,> { - const TYPE_ID: u64 = _private::TYPE_ID; - } - impl <'a,> ::core::convert::From<::capnp::private::layout::StructReader<'a>> for Reader<'a,> { - fn from(reader: ::capnp::private::layout::StructReader<'a>) -> Self { - Self { reader, } - } - } - - impl <'a,> ::core::convert::From> for ::capnp::dynamic_value::Reader<'a> { - fn from(reader: Reader<'a,>) -> Self { - Self::Struct(::capnp::dynamic_struct::Reader::new(reader.reader, ::capnp::schema::StructSchema::new(::capnp::introspect::RawBrandedStructSchema { generic: &_private::RAW_SCHEMA, field_types: _private::get_field_types::<>, annotation_types: _private::get_annotation_types::<>}))) - } - } - - impl <> ::core::fmt::Debug for 
Reader<'_,> { - fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::result::Result<(), ::core::fmt::Error> { - core::fmt::Debug::fmt(&::core::convert::Into::<::capnp::dynamic_value::Reader<'_>>::into(*self), f) - } - } - - impl <'a,> ::capnp::traits::FromPointerReader<'a> for Reader<'a,> { - fn get_from_pointer(reader: &::capnp::private::layout::PointerReader<'a>, default: ::core::option::Option<&'a [::capnp::Word]>) -> ::capnp::Result { - ::core::result::Result::Ok(reader.get_struct(default)?.into()) - } - } - - impl <'a,> ::capnp::traits::IntoInternalStructReader<'a> for Reader<'a,> { - fn into_internal_struct_reader(self) -> ::capnp::private::layout::StructReader<'a> { - self.reader - } - } - - impl <'a,> ::capnp::traits::Imbue<'a> for Reader<'a,> { - fn imbue(&mut self, cap_table: &'a ::capnp::private::layout::CapTable) { - self.reader.imbue(::capnp::private::layout::CapTableReader::Plain(cap_table)) - } - } - - impl <> Reader<'_,> { - pub fn reborrow(&self) -> Reader<'_,> { - Self { .. *self } - } - - pub fn total_size(&self) -> ::capnp::Result<::capnp::MessageSize> { - self.reader.total_size() - } - #[inline] - pub fn get_major(self) -> u32 { - self.reader.get_data_field::(0) - } - #[inline] - pub fn get_minor(self) -> u32 { - self.reader.get_data_field::(1) - } - } - - pub struct Builder<'a> { builder: ::capnp::private::layout::StructBuilder<'a> } - impl <> ::capnp::traits::HasStructSize for Builder<'_,> { - const STRUCT_SIZE: ::capnp::private::layout::StructSize = ::capnp::private::layout::StructSize { data: 1, pointers: 0 }; - } - impl <> ::capnp::traits::HasTypeId for Builder<'_,> { - const TYPE_ID: u64 = _private::TYPE_ID; - } - impl <'a,> ::core::convert::From<::capnp::private::layout::StructBuilder<'a>> for Builder<'a,> { - fn from(builder: ::capnp::private::layout::StructBuilder<'a>) -> Self { - Self { builder, } - } - } - - impl <'a,> ::core::convert::From> for ::capnp::dynamic_value::Builder<'a> { - fn from(builder: Builder<'a,>) -> Self { - Self::Struct(::capnp::dynamic_struct::Builder::new(builder.builder, ::capnp::schema::StructSchema::new(::capnp::introspect::RawBrandedStructSchema { generic: &_private::RAW_SCHEMA, field_types: _private::get_field_types::<>, annotation_types: _private::get_annotation_types::<>}))) - } - } - - impl <'a,> ::capnp::traits::ImbueMut<'a> for Builder<'a,> { - fn imbue_mut(&mut self, cap_table: &'a mut ::capnp::private::layout::CapTable) { - self.builder.imbue(::capnp::private::layout::CapTableBuilder::Plain(cap_table)) - } - } - - impl <'a,> ::capnp::traits::FromPointerBuilder<'a> for Builder<'a,> { - fn init_pointer(builder: ::capnp::private::layout::PointerBuilder<'a>, _size: u32) -> Self { - builder.init_struct(::STRUCT_SIZE).into() - } - fn get_from_pointer(builder: ::capnp::private::layout::PointerBuilder<'a>, default: ::core::option::Option<&'a [::capnp::Word]>) -> ::capnp::Result { - ::core::result::Result::Ok(builder.get_struct(::STRUCT_SIZE, default)?.into()) - } - } - - impl <> ::capnp::traits::SetterInput> for Reader<'_,> { - fn set_pointer_builder(mut pointer: ::capnp::private::layout::PointerBuilder<'_>, value: Self, canonicalize: bool) -> ::capnp::Result<()> { pointer.set_struct(&value.reader, canonicalize) } - } - - impl <'a,> Builder<'a,> { - pub fn into_reader(self) -> Reader<'a,> { - self.builder.into_reader().into() - } - pub fn reborrow(&mut self) -> Builder<'_,> { - Builder { builder: self.builder.reborrow() } - } - pub fn reborrow_as_reader(&self) -> Reader<'_,> { - self.builder.as_reader().into() - } - - pub fn 
total_size(&self) -> ::capnp::Result<::capnp::MessageSize> { - self.builder.as_reader().total_size() - } - #[inline] - pub fn get_major(self) -> u32 { - self.builder.get_data_field::(0) - } - #[inline] - pub fn set_major(&mut self, value: u32) { - self.builder.set_data_field::(0, value); - } - #[inline] - pub fn get_minor(self) -> u32 { - self.builder.get_data_field::(1) - } - #[inline] - pub fn set_minor(&mut self, value: u32) { - self.builder.set_data_field::(1, value); - } - } - - pub struct Pipeline { _typeless: ::capnp::any_pointer::Pipeline } - impl ::capnp::capability::FromTypelessPipeline for Pipeline { - fn new(typeless: ::capnp::any_pointer::Pipeline) -> Self { - Self { _typeless: typeless, } - } - } - impl Pipeline { - } - mod _private { - pub static ENCODED_NODE: [::capnp::Word; 49] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), - ::capnp::word(167, 171, 245, 145, 177, 155, 108, 182), - ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), - ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), - ::capnp::word(0, 0, 7, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(168, 1, 0, 0, 230, 1, 0, 0), - ::capnp::word(21, 0, 0, 0, 226, 0, 0, 0), - ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(29, 0, 0, 0, 119, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(99, 97, 112, 110, 112, 47, 104, 117), - ::capnp::word(103, 114, 45, 118, 48, 46, 99, 97), - ::capnp::word(112, 110, 112, 58, 86, 101, 114, 115), - ::capnp::word(105, 111, 110, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 1, 0, 1, 0), - ::capnp::word(8, 0, 0, 0, 3, 0, 4, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 1, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(41, 0, 0, 0, 50, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(36, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(48, 0, 0, 0, 2, 0, 1, 0), - ::capnp::word(1, 0, 0, 0, 1, 0, 0, 0), - ::capnp::word(0, 0, 1, 0, 1, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(45, 0, 0, 0, 50, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(40, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(52, 0, 0, 0, 2, 0, 1, 0), - ::capnp::word(109, 97, 106, 111, 114, 0, 0, 0), - ::capnp::word(8, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(8, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(109, 105, 110, 111, 114, 0, 0, 0), - ::capnp::word(8, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(8, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ]; - pub fn get_field_types(index: u16) -> ::capnp::introspect::Type { - match index { - 0 => ::introspect(), - 1 => ::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), - } - } - pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) - } - pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { - encoded_node: &ENCODED_NODE, - nonunion_members: NONUNION_MEMBERS, - members_by_discriminant: MEMBERS_BY_DISCRIMINANT, - members_by_name: 
MEMBERS_BY_NAME, - }; - pub static NONUNION_MEMBERS : &[u16] = &[0,1]; - pub static MEMBERS_BY_DISCRIMINANT : &[u16] = &[]; - pub static MEMBERS_BY_NAME : &[u16] = &[0,1]; - pub const TYPE_ID: u64 = 0xb66c_9bb1_91f5_aba7; - } -} - pub mod module { #[derive(Copy, Clone)] pub struct Owned(()); @@ -692,14 +424,13 @@ pub mod module { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 91] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 90] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(167, 107, 35, 13, 152, 216, 48, 189), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(3, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(232, 1, 0, 0, 98, 2, 0, 0), ::capnp::word(21, 0, 0, 0, 218, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -791,11 +522,11 @@ pub mod module { 1 => <::capnp::struct_list::Owned as ::capnp::introspect::Introspect>::introspect(), 2 => <::capnp::struct_list::Owned as ::capnp::introspect::Introspect>::introspect(), 3 => <::capnp::struct_list::Owned as ::capnp::introspect::Introspect>::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), + _ => panic!("invalid field index {}", index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -1071,14 +802,13 @@ pub mod node { } } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 127] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 126] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(108, 130, 159, 249, 96, 124, 57, 228), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(5, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(100, 2, 0, 0, 46, 3, 0, 0), ::capnp::word(21, 0, 0, 0, 202, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -1208,11 +938,11 @@ pub mod node { 3 => <::capnp::primitive_list::Owned as ::capnp::introspect::Introspect>::introspect(), 4 => <::capnp::primitive_list::Owned as ::capnp::introspect::Introspect>::introspect(), 5 => ::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), + _ => panic!("invalid field index {}", index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -1663,14 +1393,13 @@ pub mod operation { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 230] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 229] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(216, 191, 119, 93, 53, 241, 240, 155), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(1, 0, 7, 0, 0, 0, 14, 0), ::capnp::word(2, 0, 0, 0, 0, 0, 0, 0), - 
::capnp::word(48, 3, 0, 0, 38, 5, 0, 0), ::capnp::word(21, 0, 0, 0, 242, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -1911,11 +1640,11 @@ pub mod operation { 11 => <::capnp::text::Owned as ::capnp::introspect::Introspect>::introspect(), 12 => ::introspect(), 13 => ::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), + _ => panic!("invalid field index {}", index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -2112,14 +1841,13 @@ pub mod operation { } } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 49] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 48] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(156, 202, 42, 93, 60, 14, 161, 193), ::capnp::word(30, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(216, 191, 119, 93, 53, 241, 240, 155), ::capnp::word(1, 0, 7, 0, 1, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(21, 0, 0, 0, 66, 1, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -2167,11 +1895,11 @@ pub mod operation { match index { 0 => ::introspect(), 1 => ::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), + _ => panic!("invalid field index {}", index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -2276,10 +2004,6 @@ pub mod symbol { pub fn get_signature(self) -> u32 { self.reader.get_data_field::(0) } - #[inline] - pub fn get_visibility(self) -> ::core::result::Result { - ::core::convert::TryInto::try_into(self.reader.get_data_field::(2)) - } } pub struct Builder<'a> { builder: ::capnp::private::layout::StructBuilder<'a> } @@ -2390,14 +2114,6 @@ pub mod symbol { pub fn set_signature(&mut self, value: u32) { self.builder.set_data_field::(0, value); } - #[inline] - pub fn get_visibility(self) -> ::core::result::Result { - ::core::convert::TryInto::try_into(self.builder.get_data_field::(2)) - } - #[inline] - pub fn set_visibility(&mut self, value: crate::hugr_v0_capnp::Visibility) { - self.builder.set_data_field::(2, value as u16); - } } pub struct Pipeline { _typeless: ::capnp::any_pointer::Pipeline } @@ -2409,18 +2125,17 @@ pub mod symbol { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 105] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 88] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(63, 209, 84, 70, 225, 154, 206, 223), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(3, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(40, 5, 0, 0, 195, 5, 0, 0), ::capnp::word(21, 0, 0, 0, 218, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(29, 0, 0, 0, 31, 1, 0, 0), + 
::capnp::word(29, 0, 0, 0, 231, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(99, 97, 112, 110, 112, 47, 104, 117), @@ -2428,42 +2143,35 @@ pub mod symbol { ::capnp::word(112, 110, 112, 58, 83, 121, 109, 98), ::capnp::word(111, 108, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 1, 0, 1, 0), - ::capnp::word(20, 0, 0, 0, 3, 0, 4, 0), - ::capnp::word(1, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(16, 0, 0, 0, 3, 0, 4, 0), + ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 1, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(125, 0, 0, 0, 42, 0, 0, 0), + ::capnp::word(97, 0, 0, 0, 42, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(120, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(132, 0, 0, 0, 2, 0, 1, 0), - ::capnp::word(2, 0, 0, 0, 1, 0, 0, 0), + ::capnp::word(92, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(104, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(1, 0, 0, 0, 1, 0, 0, 0), ::capnp::word(0, 0, 1, 0, 1, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(129, 0, 0, 0, 58, 0, 0, 0), + ::capnp::word(101, 0, 0, 0, 58, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(124, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(152, 0, 0, 0, 2, 0, 1, 0), - ::capnp::word(3, 0, 0, 0, 2, 0, 0, 0), + ::capnp::word(96, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(124, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(2, 0, 0, 0, 2, 0, 0, 0), ::capnp::word(0, 0, 1, 0, 2, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(149, 0, 0, 0, 98, 0, 0, 0), + ::capnp::word(121, 0, 0, 0, 98, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(148, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(176, 0, 0, 0, 2, 0, 1, 0), - ::capnp::word(4, 0, 0, 0, 0, 0, 0, 0), + ::capnp::word(120, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(148, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(3, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 1, 0, 3, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(173, 0, 0, 0, 82, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(172, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(184, 0, 0, 0, 2, 0, 1, 0), - ::capnp::word(0, 0, 0, 0, 2, 0, 0, 0), - ::capnp::word(0, 0, 1, 0, 4, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(181, 0, 0, 0, 90, 0, 0, 0), + ::capnp::word(145, 0, 0, 0, 82, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(180, 0, 0, 0, 3, 0, 1, 0), - ::capnp::word(192, 0, 0, 0, 2, 0, 1, 0), + ::capnp::word(144, 0, 0, 0, 3, 0, 1, 0), + ::capnp::word(156, 0, 0, 0, 2, 0, 1, 0), ::capnp::word(110, 97, 109, 101, 0, 0, 0, 0), ::capnp::word(12, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -2506,15 +2214,6 @@ pub mod symbol { ::capnp::word(8, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(118, 105, 115, 105, 98, 105, 108, 105), - ::capnp::word(116, 121, 0, 0, 0, 0, 0, 0), - ::capnp::word(15, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(1, 131, 104, 122, 242, 21, 131, 141), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(15, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ]; pub fn get_field_types(index: u16) -> ::capnp::introspect::Type { match index { @@ -2522,12 +2221,11 @@ pub mod symbol { 1 => <::capnp::struct_list::Owned as ::capnp::introspect::Introspect>::introspect(), 2 => <::capnp::primitive_list::Owned as ::capnp::introspect::Introspect>::introspect(), 3 => 
::introspect(), - 4 => ::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), + _ => panic!("invalid field index {}", index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -2535,9 +2233,9 @@ pub mod symbol { members_by_discriminant: MEMBERS_BY_DISCRIMINANT, members_by_name: MEMBERS_BY_NAME, }; - pub static NONUNION_MEMBERS : &[u16] = &[0,1,2,3,4]; + pub static NONUNION_MEMBERS : &[u16] = &[0,1,2,3]; pub static MEMBERS_BY_DISCRIMINANT : &[u16] = &[]; - pub static MEMBERS_BY_NAME : &[u16] = &[2,0,1,3,4]; + pub static MEMBERS_BY_NAME : &[u16] = &[2,0,1,3]; pub const TYPE_ID: u64 = 0xdfce_9ae1_4654_d13f; } } @@ -2815,14 +2513,13 @@ pub mod region { } } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 142] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 141] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(225, 113, 253, 231, 231, 39, 130, 153), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(5, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(197, 5, 0, 0, 168, 6, 0, 0), ::capnp::word(21, 0, 0, 0, 218, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -2968,11 +2665,11 @@ pub mod region { 4 => <::capnp::primitive_list::Owned as ::capnp::introspect::Introspect>::introspect(), 5 => ::introspect(), 6 => ::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), + _ => panic!("invalid field index {}", index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -3137,14 +2834,13 @@ pub mod region_scope { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 49] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 48] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(163, 135, 81, 30, 243, 205, 148, 170), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(0, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(170, 6, 0, 0, 236, 6, 0, 0), ::capnp::word(21, 0, 0, 0, 2, 1, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -3192,11 +2888,11 @@ pub mod region_scope { match index { 0 => ::introspect(), 1 => ::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), + _ => panic!("invalid field index {}", index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -3244,14 +2940,13 @@ impl ::capnp::traits::HasTypeId for 
RegionKind { const TYPE_ID: u64 = 0xe457_1af6_23a3_76b4u64; } mod region_kind { -pub static ENCODED_NODE: [::capnp::Word; 33] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), +pub static ENCODED_NODE: [::capnp::Word; 32] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(180, 118, 163, 35, 246, 26, 87, 228), ::capnp::word(20, 0, 0, 0, 2, 0, 0, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(238, 6, 0, 0, 53, 7, 0, 0), ::capnp::word(21, 0, 0, 0, 250, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -3280,7 +2975,7 @@ pub static ENCODED_NODE: [::capnp::Word; 33] = [ ::capnp::word(109, 111, 100, 117, 108, 101, 0, 0), ]; pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } } @@ -3637,14 +3332,13 @@ pub mod term { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 168] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 167] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(178, 107, 91, 137, 60, 121, 191, 207), ::capnp::word(20, 0, 0, 0, 1, 0, 2, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(1, 0, 7, 0, 0, 0, 10, 0), ::capnp::word(2, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(55, 7, 0, 0, 105, 9, 0, 0), ::capnp::word(21, 0, 0, 0, 202, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 23, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -3819,11 +3513,11 @@ pub mod term { 7 => ::introspect(), 8 => <() as ::capnp::introspect::Introspect>::introspect(), 9 => <::capnp::struct_list::Owned as ::capnp::introspect::Introspect>::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), + _ => panic!("invalid field index {}", index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -4021,14 +3715,13 @@ pub mod term { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 50] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 49] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(136, 151, 188, 135, 237, 57, 73, 141), ::capnp::word(25, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(178, 107, 91, 137, 60, 121, 191, 207), ::capnp::word(0, 0, 7, 0, 0, 0, 2, 0), ::capnp::word(2, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(251, 8, 0, 0, 103, 9, 0, 0), ::capnp::word(21, 0, 0, 0, 10, 1, 0, 0), ::capnp::word(37, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -4077,11 +3770,11 @@ pub mod term { match index { 0 => ::introspect(), 1 => ::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), + _ => panic!("invalid field index {}", index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { 
encoded_node: &ENCODED_NODE, @@ -4264,14 +3957,13 @@ pub mod term { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 52] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 51] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(150, 98, 109, 181, 159, 123, 122, 222), ::capnp::word(25, 0, 0, 0, 1, 0, 2, 0), ::capnp::word(178, 107, 91, 137, 60, 121, 191, 207), ::capnp::word(1, 0, 7, 0, 1, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(21, 0, 0, 0, 250, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -4322,11 +4014,11 @@ pub mod term { match index { 0 => ::introspect(), 1 => <::capnp::primitive_list::Owned as ::capnp::introspect::Introspect>::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), + _ => panic!("invalid field index {}", index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -4491,14 +4183,13 @@ pub mod term { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 49] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 48] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(55, 205, 218, 56, 109, 17, 119, 134), ::capnp::word(25, 0, 0, 0, 1, 0, 2, 0), ::capnp::word(178, 107, 91, 137, 60, 121, 191, 207), ::capnp::word(1, 0, 7, 0, 1, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(21, 0, 0, 0, 18, 1, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -4546,11 +4237,11 @@ pub mod term { match index { 0 => ::introspect(), 1 => ::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), + _ => panic!("invalid field index {}", index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -4728,14 +4419,13 @@ pub mod param { impl Pipeline { } mod _private { - pub static ENCODED_NODE: [::capnp::Word; 49] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), + pub static ENCODED_NODE: [::capnp::Word; 48] = [ + ::capnp::word(0, 0, 0, 0, 5, 0, 6, 0), ::capnp::word(232, 73, 199, 85, 129, 167, 53, 211), ::capnp::word(20, 0, 0, 0, 1, 0, 1, 0), ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), ::capnp::word(1, 0, 7, 0, 0, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(107, 9, 0, 0, 163, 9, 0, 0), ::capnp::word(21, 0, 0, 0, 210, 0, 0, 0), ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), @@ -4783,11 +4473,11 @@ pub mod param { match index { 0 => <::capnp::text::Owned as ::capnp::introspect::Introspect>::introspect(), 1 => ::introspect(), - _ => ::capnp::introspect::panic_invalid_field_index(index), + _ => panic!("invalid field index {}", index), } } pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - 
::capnp::introspect::panic_invalid_annotation_indices(child_index, index) + panic!("invalid annotation indices ({:?}, {}) ", child_index, index) } pub static RAW_SCHEMA: ::capnp::introspect::RawStructSchema = ::capnp::introspect::RawStructSchema { encoded_node: &ENCODED_NODE, @@ -4801,75 +4491,3 @@ pub mod param { pub const TYPE_ID: u64 = 0xd335_a781_55c7_49e8; } } - -#[repr(u16)] -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Visibility { - Unspecified = 0, - Private = 1, - Public = 2, -} - -impl ::capnp::introspect::Introspect for Visibility { - fn introspect() -> ::capnp::introspect::Type { ::capnp::introspect::TypeVariant::Enum(::capnp::introspect::RawEnumSchema { encoded_node: &visibility::ENCODED_NODE, annotation_types: visibility::get_annotation_types }).into() } -} -impl ::core::convert::From for ::capnp::dynamic_value::Reader<'_> { - fn from(e: Visibility) -> Self { ::capnp::dynamic_value::Enum::new(e.into(), ::capnp::introspect::RawEnumSchema { encoded_node: &visibility::ENCODED_NODE, annotation_types: visibility::get_annotation_types }.into()).into() } -} -impl ::core::convert::TryFrom for Visibility { - type Error = ::capnp::NotInSchema; - fn try_from(value: u16) -> ::core::result::Result>::Error> { - match value { - 0 => ::core::result::Result::Ok(Self::Unspecified), - 1 => ::core::result::Result::Ok(Self::Private), - 2 => ::core::result::Result::Ok(Self::Public), - n => ::core::result::Result::Err(::capnp::NotInSchema(n)), - } - } -} -impl From for u16 { - #[inline] - fn from(x: Visibility) -> u16 { x as u16 } -} -impl ::capnp::traits::HasTypeId for Visibility { - const TYPE_ID: u64 = 0x8d83_15f2_7a68_8301u64; -} -mod visibility { -pub static ENCODED_NODE: [::capnp::Word; 32] = [ - ::capnp::word(0, 0, 0, 0, 6, 0, 6, 0), - ::capnp::word(1, 131, 104, 122, 242, 21, 131, 141), - ::capnp::word(20, 0, 0, 0, 2, 0, 0, 0), - ::capnp::word(1, 150, 80, 40, 197, 50, 43, 224), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(165, 9, 0, 0, 235, 9, 0, 0), - ::capnp::word(21, 0, 0, 0, 250, 0, 0, 0), - ::capnp::word(33, 0, 0, 0, 7, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(29, 0, 0, 0, 79, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(99, 97, 112, 110, 112, 47, 104, 117), - ::capnp::word(103, 114, 45, 118, 48, 46, 99, 97), - ::capnp::word(112, 110, 112, 58, 86, 105, 115, 105), - ::capnp::word(98, 105, 108, 105, 116, 121, 0, 0), - ::capnp::word(0, 0, 0, 0, 1, 0, 1, 0), - ::capnp::word(12, 0, 0, 0, 1, 0, 2, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(29, 0, 0, 0, 98, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(1, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(25, 0, 0, 0, 66, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(2, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(17, 0, 0, 0, 58, 0, 0, 0), - ::capnp::word(0, 0, 0, 0, 0, 0, 0, 0), - ::capnp::word(117, 110, 115, 112, 101, 99, 105, 102), - ::capnp::word(105, 101, 100, 0, 0, 0, 0, 0), - ::capnp::word(112, 114, 105, 118, 97, 116, 101, 0), - ::capnp::word(112, 117, 98, 108, 105, 99, 0, 0), -]; -pub fn get_annotation_types(child_index: Option, index: u32) -> ::capnp::introspect::Type { - ::capnp::introspect::panic_invalid_annotation_indices(child_index, index) -} -} diff --git a/hugr-model/src/lib.rs b/hugr-model/src/lib.rs index b4b0b8fce2..c139cefc42 100644 --- a/hugr-model/src/lib.rs +++ b/hugr-model/src/lib.rs @@ -1,29 +1,10 @@ //! 
The data model of the HUGR intermediate representation. -//! //! This crate defines data structures that capture the structure of a HUGR graph and //! all its associated information in a form that can be stored on disk. The data structures //! are not designed for efficient traversal or modification, but for simplicity and serialization. -//! -//! This crate supports version ` -#![doc = include_str!("../FORMAT_VERSION")] -//! ` of the HUGR model format. mod capnp; pub mod v0; -use std::sync::LazyLock; - // This is required here since the generated code assumes it's in the package root. use capnp::hugr_v0_capnp; - -/// The current version of the HUGR model format. -pub static CURRENT_VERSION: LazyLock = LazyLock::new(|| { - // We allow non-zero patch versions, but ignore them for compatibility checks. - let v = semver::Version::parse(include_str!("../FORMAT_VERSION").trim()) - .expect("`FORMAT_VERSION` in `hugr-model` contains version that fails to parse"); - assert!( - v.pre.is_empty(), - "`FORMAT_VERSION` in `hugr-model` should not have a pre-release version" - ); - v -}); diff --git a/hugr-model/src/v0/ast/hugr.pest b/hugr-model/src/v0/ast/hugr.pest index d960cb3a40..698f32b056 100644 --- a/hugr-model/src/v0/ast/hugr.pest +++ b/hugr-model/src/v0/ast/hugr.pest @@ -16,8 +16,6 @@ reserved = @{ | "list" | "meta" | "signature" - | "public" - | "private" | "dfg" | "cfg" | "block" @@ -81,9 +79,7 @@ node_cond = { "(" ~ "cond" ~ port_lists? ~ signature? ~ meta* ~ reg node_import = { "(" ~ "import" ~ symbol_name ~ meta* ~ ")" } node_custom = { "(" ~ term ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } -visibility = { "public" | "private" } - -symbol = { visibility? ~ symbol_name ~ param* ~ where_clause* ~ term } +symbol = { symbol_name ~ param* ~ where_clause* ~ term } signature = { "(" ~ "signature" ~ term ~ ")" } param = { "(" ~ "param" ~ term_var ~ term ~ ")" } diff --git a/hugr-model/src/v0/ast/mod.rs b/hugr-model/src/v0/ast/mod.rs index b6e817b990..faee6f8276 100644 --- a/hugr-model/src/v0/ast/mod.rs +++ b/hugr-model/src/v0/ast/mod.rs @@ -25,7 +25,7 @@ use std::sync::Arc; use bumpalo::Bump; use super::table::{self}; -use super::{LinkName, Literal, RegionKind, SymbolName, VarName, Visibility}; +use super::{LinkName, Literal, RegionKind, SymbolName, VarName}; mod parse; mod print; @@ -194,8 +194,6 @@ impl Operation { /// [`table::Symbol`]: crate::v0::table::Symbol #[derive(Debug, Clone, PartialEq, Eq)] pub struct Symbol { - /// The visibility of the symbol. - pub visibility: Option, /// The name of the symbol. pub name: SymbolName, /// The parameters of the symbol. 
diff --git a/hugr-model/src/v0/ast/parse.rs b/hugr-model/src/v0/ast/parse.rs index a2c9a5cd9b..d1d5dff741 100644 --- a/hugr-model/src/v0/ast/parse.rs +++ b/hugr-model/src/v0/ast/parse.rs @@ -28,7 +28,7 @@ use thiserror::Error; use crate::v0::ast::{LinkName, Module, Operation, SeqPart}; use crate::v0::{Literal, RegionKind}; -use super::{Node, Package, Param, Region, Symbol, VarName, Visibility}; +use super::{Node, Package, Param, Region, Symbol, VarName}; use super::{SymbolName, Term}; mod pest_parser { @@ -292,23 +292,13 @@ fn parse_param(pair: Pair) -> ParseResult { fn parse_symbol(pair: Pair) -> ParseResult { debug_assert_eq!(Rule::symbol, pair.as_rule()); - let mut pairs = pair.into_inner(); - let visibility = take_rule(&mut pairs, Rule::visibility) - .next() - .map(|pair| match pair.as_str() { - "public" => Ok(Visibility::Public), - "private" => Ok(Visibility::Private), - _ => unreachable!("Expected 'public' or 'private', got {}", pair.as_str()), - }) - .transpose()?; let name = parse_symbol_name(pairs.next().unwrap())?; let params = parse_params(&mut pairs)?; let constraints = parse_constraints(&mut pairs)?; let signature = parse_term(pairs.next().unwrap())?; Ok(Symbol { - visibility, name, params, constraints, diff --git a/hugr-model/src/v0/ast/print.rs b/hugr-model/src/v0/ast/print.rs index 071146dedd..dd47602a4b 100644 --- a/hugr-model/src/v0/ast/print.rs +++ b/hugr-model/src/v0/ast/print.rs @@ -7,7 +7,7 @@ use crate::v0::{Literal, RegionKind}; use super::{ LinkName, Module, Node, Operation, Package, Param, Region, SeqPart, Symbol, SymbolName, Term, - VarName, Visibility, + VarName, }; struct Printer<'a> { @@ -369,12 +369,6 @@ fn print_region<'a>(printer: &mut Printer<'a>, region: &'a Region) { } fn print_symbol<'a>(printer: &mut Printer<'a>, symbol: &'a Symbol) { - match symbol.visibility { - None => (), - Some(Visibility::Private) => printer.text("private"), - Some(Visibility::Public) => printer.text("public"), - } - print_symbol_name(printer, &symbol.name); for param in &symbol.params { diff --git a/hugr-model/src/v0/ast/python.rs b/hugr-model/src/v0/ast/python.rs index b70d9447c9..90ef22a814 100644 --- a/hugr-model/src/v0/ast/python.rs +++ b/hugr-model/src/v0/ast/python.rs @@ -1,7 +1,5 @@ use std::sync::Arc; -use crate::v0::Visibility; - use super::{Module, Node, Operation, Package, Param, Region, SeqPart, Symbol, Term}; use pyo3::{ Bound, PyAny, PyResult, @@ -141,41 +139,13 @@ impl<'py> pyo3::IntoPyObject<'py> for &Param { } } -impl<'py> pyo3::FromPyObject<'py> for Visibility { - fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { - match ob.str()?.to_str()? 
{ - "Public" => Ok(Visibility::Public), - "Private" => Ok(Visibility::Private), - s => Err(PyTypeError::new_err(format!( - "Expected \"Public\" or \"Private\", got {s}", - ))), - } - } -} - -impl<'py> pyo3::IntoPyObject<'py> for &Visibility { - type Target = pyo3::types::PyAny; - type Output = pyo3::Bound<'py, Self::Target>; - type Error = pyo3::PyErr; - - fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { - let s = match self { - Visibility::Private => "Private", - Visibility::Public => "Public", - }; - Ok(pyo3::types::PyString::new(py, s).into_any()) - } -} - impl<'py> pyo3::FromPyObject<'py> for Symbol { fn extract_bound(symbol: &Bound<'py, PyAny>) -> PyResult { let name = symbol.getattr("name")?.extract()?; let params: Vec<_> = symbol.getattr("params")?.extract()?; - let visibility = symbol.getattr("visibility")?.extract()?; let constraints: Vec<_> = symbol.getattr("constraints")?.extract()?; let signature = symbol.getattr("signature")?.extract()?; Ok(Self { - visibility, name, signature, params: params.into(), @@ -194,7 +164,6 @@ impl<'py> pyo3::IntoPyObject<'py> for &Symbol { let py_class = py_module.getattr("Symbol")?; py_class.call1(( self.name.as_ref(), - &self.visibility, self.params.as_ref(), self.constraints.as_ref(), &self.signature, @@ -456,6 +425,5 @@ impl_into_pyobject_owned!(Symbol); impl_into_pyobject_owned!(Module); impl_into_pyobject_owned!(Package); impl_into_pyobject_owned!(Node); -impl_into_pyobject_owned!(Visibility); impl_into_pyobject_owned!(Region); impl_into_pyobject_owned!(Operation); diff --git a/hugr-model/src/v0/ast/resolve.rs b/hugr-model/src/v0/ast/resolve.rs index f126d3e69e..d691de0f01 100644 --- a/hugr-model/src/v0/ast/resolve.rs +++ b/hugr-model/src/v0/ast/resolve.rs @@ -289,13 +289,11 @@ impl<'a> Context<'a> { fn resolve_symbol(&mut self, symbol: &'a Symbol) -> BuildResult<&'a table::Symbol<'a>> { let name = symbol.name.as_ref(); - let visibility = &symbol.visibility; let params = self.resolve_params(&symbol.params)?; let constraints = self.resolve_terms(&symbol.constraints)?; let signature = self.resolve_term(&symbol.signature)?; Ok(self.bump.alloc(table::Symbol { - visibility, name, params, constraints, @@ -365,7 +363,6 @@ impl<'a> Context<'a> { /// Error that may occur in [`Module::resolve`]. #[derive(Debug, Clone, Error)] #[non_exhaustive] -#[error("Error resolving model module")] pub enum ResolveError { /// Unknown variable. 
#[error("unknown var: {0}")] @@ -391,18 +388,3 @@ fn try_alloc_slice( } Ok(vec.into_bump_slice()) } - -#[cfg(test)] -mod test { - use crate::v0::ast; - use bumpalo::Bump; - use std::str::FromStr as _; - - #[test] - fn vars_in_root_scope() { - let text = "(hugr 0) (mod) (meta ?x)"; - let ast = ast::Package::from_str(text).unwrap(); - let bump = Bump::new(); - assert!(ast.resolve(&bump).is_err()); - } -} diff --git a/hugr-model/src/v0/ast/view.rs b/hugr-model/src/v0/ast/view.rs index 8c38038100..8feb158539 100644 --- a/hugr-model/src/v0/ast/view.rs +++ b/hugr-model/src/v0/ast/view.rs @@ -91,13 +91,11 @@ impl<'a> View<'a, table::SeqPart> for SeqPart { impl<'a> View<'a, table::Symbol<'a>> for Symbol { fn view(module: &'a table::Module<'a>, id: table::Symbol<'a>) -> Option { - let visibility = id.visibility.clone(); let name = SymbolName::new(id.name); let params = module.view(id.params)?; let constraints = module.view(id.constraints)?; let signature = module.view(id.signature)?; Some(Symbol { - visibility, name, params, constraints, diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index c5f65463ba..001a805cc9 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -1,6 +1,6 @@ use crate::capnp::hugr_v0_capnp as hugr_capnp; +use crate::v0 as model; use crate::v0::table; -use crate::{CURRENT_VERSION, v0 as model}; use bumpalo::Bump; use bumpalo::collections::Vec as BumpVec; use std::io::BufRead; @@ -8,20 +8,10 @@ use std::io::BufRead; /// An error encountered while deserialising a model. #[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)] #[non_exhaustive] -#[display("Error reading a HUGR model payload.")] pub enum ReadError { #[from(forward)] /// An error encountered while decoding a model from a `capnproto` buffer. DecodingError(capnp::Error), - - /// The file could not be read due to a version mismatch. - #[display("Can not read file with version {actual} (tooling version {current}).")] - VersionError { - /// The current version of the hugr-model format. - current: semver::Version, - /// The version of the hugr-model format in the file. - actual: semver::Version, - }, } type ReadResult = Result; @@ -67,15 +57,6 @@ fn read_package<'a>( bump: &'a Bump, reader: hugr_capnp::package::Reader, ) -> ReadResult> { - let version = read_version(reader.get_version()?)?; - - if version.major != CURRENT_VERSION.major || version.minor > CURRENT_VERSION.minor { - return Err(ReadError::VersionError { - current: CURRENT_VERSION.clone(), - actual: version, - }); - } - let modules = reader .get_modules()? 
.iter() @@ -85,12 +66,6 @@ fn read_package<'a>( Ok(table::Package { modules }) } -fn read_version(reader: hugr_capnp::version::Reader) -> ReadResult { - let major = reader.get_major(); - let minor = reader.get_minor(); - Ok(semver::Version::new(major as u64, minor as u64, 0)) -} - fn read_module<'a>( bump: &'a Bump, reader: hugr_capnp::module::Reader, @@ -151,21 +126,88 @@ fn read_operation<'a>( Which::Dfg(()) => table::Operation::Dfg, Which::Cfg(()) => table::Operation::Cfg, Which::Block(()) => table::Operation::Block, - Which::FuncDefn(reader) => table::Operation::DefineFunc(read_symbol(bump, reader?, None)?), - Which::FuncDecl(reader) => table::Operation::DeclareFunc(read_symbol(bump, reader?, None)?), + Which::FuncDefn(reader) => { + let reader = reader?; + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let params = read_list!(bump, reader.get_params()?, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId); + let signature = table::TermId(reader.get_signature()); + let symbol = bump.alloc(table::Symbol { + name, + params, + constraints, + signature, + }); + table::Operation::DefineFunc(symbol) + } + Which::FuncDecl(reader) => { + let reader = reader?; + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let params = read_list!(bump, reader.get_params()?, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId); + let signature = table::TermId(reader.get_signature()); + let symbol = bump.alloc(table::Symbol { + name, + params, + constraints, + signature, + }); + table::Operation::DeclareFunc(symbol) + } Which::AliasDefn(reader) => { let symbol = reader.get_symbol()?; let value = table::TermId(reader.get_value()); - table::Operation::DefineAlias(read_symbol(bump, symbol, Some(&[]))?, value) + let name = bump.alloc_str(symbol.get_name()?.to_str()?); + let params = read_list!(bump, symbol.get_params()?, read_param); + let signature = table::TermId(symbol.get_signature()); + let symbol = bump.alloc(table::Symbol { + name, + params, + constraints: &[], + signature, + }); + table::Operation::DefineAlias(symbol, value) } Which::AliasDecl(reader) => { - table::Operation::DeclareAlias(read_symbol(bump, reader?, Some(&[]))?) + let reader = reader?; + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let params = read_list!(bump, reader.get_params()?, read_param); + let signature = table::TermId(reader.get_signature()); + let symbol = bump.alloc(table::Symbol { + name, + params, + constraints: &[], + signature, + }); + table::Operation::DeclareAlias(symbol) } Which::ConstructorDecl(reader) => { - table::Operation::DeclareConstructor(read_symbol(bump, reader?, None)?) + let reader = reader?; + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let params = read_list!(bump, reader.get_params()?, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId); + let signature = table::TermId(reader.get_signature()); + let symbol = bump.alloc(table::Symbol { + name, + params, + constraints, + signature, + }); + table::Operation::DeclareConstructor(symbol) } Which::OperationDecl(reader) => { - table::Operation::DeclareOperation(read_symbol(bump, reader?, None)?) 
+ let reader = reader?; + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let params = read_list!(bump, reader.get_params()?, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId); + let signature = table::TermId(reader.get_signature()); + let symbol = bump.alloc(table::Symbol { + name, + params, + constraints, + signature, + }); + table::Operation::DeclareOperation(symbol) } Which::Custom(operation) => table::Operation::Custom(table::TermId(operation)), Which::TailLoop(()) => table::Operation::TailLoop, @@ -215,40 +257,6 @@ fn read_region_scope(reader: hugr_capnp::region_scope::Reader) -> ReadResult for Option { - fn from(value: hugr_capnp::Visibility) -> Self { - match value { - hugr_capnp::Visibility::Unspecified => None, - hugr_capnp::Visibility::Private => Some(model::Visibility::Private), - hugr_capnp::Visibility::Public => Some(model::Visibility::Public), - } - } -} - -/// (Only) if `constraints` are None, then they are read from the `reader` -fn read_symbol<'a>( - bump: &'a Bump, - reader: hugr_capnp::symbol::Reader, - constraints: Option<&'a [table::TermId]>, -) -> ReadResult<&'a mut table::Symbol<'a>> { - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let visibility = reader.get_visibility()?.into(); - let visibility = bump.alloc(visibility); - let params = read_list!(bump, reader.get_params()?, read_param); - let constraints = match constraints { - Some(cs) => cs, - None => read_scalar_list!(bump, reader, get_constraints, table::TermId), - }; - let signature = table::TermId(reader.get_signature()); - Ok(bump.alloc(table::Symbol { - visibility, - name, - params, - constraints, - signature, - })) -} - fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult> { use hugr_capnp::term::Which; Ok(match reader.which()? { diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index 49919dc481..e9b76eca1a 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -1,15 +1,13 @@ use std::io::Write; -use crate::CURRENT_VERSION; use crate::capnp::hugr_v0_capnp as hugr_capnp; -use crate::v0::{self as model, table}; +use crate::v0 as model; +use crate::v0::table; /// An error encounter while serializing a model. #[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)] #[non_exhaustive] -#[display("Error encoding a package in HUGR model format.")] pub enum WriteError { - #[from(forward)] /// An error encountered while encoding a `capnproto` buffer. 
EncodingError(capnp::Error), } @@ -47,12 +45,6 @@ pub fn write_to_vec(package: &table::Package) -> Vec { fn write_package(mut builder: hugr_capnp::package::Builder, package: &table::Package) { write_list!(builder, init_modules, write_module, package.modules); - write_version(builder.init_version(), &CURRENT_VERSION); -} - -fn write_version(mut builder: hugr_capnp::version::Builder, version: &semver::Version) { - builder.set_major(version.major as u32); - builder.set_minor(version.minor as u32); } fn write_module(mut builder: hugr_capnp::module::Builder, module: &table::Module) { @@ -118,12 +110,6 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &tabl fn write_symbol(mut builder: hugr_capnp::symbol::Builder, symbol: &table::Symbol) { builder.set_name(symbol.name); - if let Some(vis) = symbol.visibility { - builder.set_visibility(match vis { - model::Visibility::Private => hugr_capnp::Visibility::Private, - model::Visibility::Public => hugr_capnp::Visibility::Public, - }) - } // else, None -> use capnp default == Unspecified write_list!(builder, init_params, write_param, symbol.params); let _ = builder.set_constraints(table::TermId::unwrap_slice(symbol.constraints)); builder.set_signature(symbol.signature.0); diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index c74201a910..15e29f3bde 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -91,15 +91,6 @@ use smol_str::SmolStr; use std::sync::Arc; use table::LinkIndex; -/// Describes how a function or symbol should be acted upon by a linker -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub enum Visibility { - /// The linker should ignore this function or symbol - Private, - /// The linker should act upon this function or symbol - Public, -} - /// Core function types. /// /// - **Parameter:** `?inputs : (core.list core.type)` @@ -172,13 +163,17 @@ pub const CORE_BYTES_TYPE: &str = "core.bytes"; /// - **Result:** `core.static` pub const CORE_FLOAT_TYPE: &str = "core.float"; -/// Type of control flow regions. +/// Type of a control flow edge. /// -/// - **Parameter:** `?inputs : (core.list (core.list core.type))` -/// - **Parameter:** `?outputs : (core.list (core.list core.type))` -/// - **Result:** `core.type` +/// - **Parameter:** `?types : (core.list core.type)` +/// - **Result:** `core.ctrl_type` pub const CORE_CTRL: &str = "core.ctrl"; +/// The type of the types for control flow edges. +/// +/// - **Result:** `?type : core.static` +pub const CORE_CTRL_TYPE: &str = "core.ctrl_type"; + /// The type for runtime constants. /// /// - **Parameter:** `?type : core.type` @@ -287,26 +282,6 @@ pub const COMPAT_CONST_JSON: &str = "compat.const_json"; /// - **Result:** `core.meta` pub const ORDER_HINT_KEY: &str = "core.order_hint.key"; -/// Metadata constructor for order hint keys on input nodes. -/// -/// When the sources of a dataflow region are represented by an input operation -/// within the region, this metadata can be attached the region to give the -/// input node an order hint key. -/// -/// - **Parameter:** `?key : core.nat` -/// - **Result:** `core.meta` -pub const ORDER_HINT_INPUT_KEY: &str = "core.order_hint.input_key"; - -/// Metadata constructor for order hint keys on output nodes. -/// -/// When the targets of a dataflow region are represented by an output operation -/// within the region, this metadata can be attached the region to give the -/// output node an order hint key. 
-/// -/// - **Parameter:** `?key : core.nat` -/// - **Result:** `core.meta` -pub const ORDER_HINT_OUTPUT_KEY: &str = "core.order_hint.output_key"; - /// Metadata constructor for order hints. /// /// When this metadata is attached to a dataflow region, it can indicate a @@ -322,18 +297,6 @@ pub const ORDER_HINT_OUTPUT_KEY: &str = "core.order_hint.output_key"; /// - **Result:** `core.meta` pub const ORDER_HINT_ORDER: &str = "core.order_hint.order"; -/// Metadata constructor for symbol titles. -/// -/// The names of functions in `hugr-core` are currently not used for symbol -/// resolution, but rather serve as a short description of the function. -/// As such, there is no requirement for uniqueness or formatting. -/// This metadata can be used to preserve that name when serializing through -/// `hugr-model`. -/// -/// - **Parameter:** `?title: core.str` -/// - **Result:** `core.meta` -pub const CORE_TITLE: &str = "core.title"; - pub mod ast; pub mod binary; pub mod scope; diff --git a/hugr-model/src/v0/scope/vars.rs b/hugr-model/src/v0/scope/vars.rs index b7085e2ee8..e35d8812c3 100644 --- a/hugr-model/src/v0/scope/vars.rs +++ b/hugr-model/src/v0/scope/vars.rs @@ -78,23 +78,28 @@ impl<'a> VarTable<'a> { /// # Errors /// /// Returns an error if the variable is not defined in the current scope. + /// + /// # Panics + /// + /// Panics if there are no open scopes. pub fn resolve(&self, name: &'a str) -> Result> { - let scope = self.scopes.last().ok_or(UnknownVarError::Root(name))?; + let scope = self.scopes.last().unwrap(); let set_index = self .vars .get_index_of(&(scope.node, name)) - .ok_or(UnknownVarError::WithinNode(scope.node, name))?; + .ok_or(UnknownVarError(scope.node, name))?; let var_index = (set_index - scope.var_stack) as u16; Ok(VarId(scope.node, var_index)) } /// Check if a variable is visible in the current scope. + /// + /// # Panics + /// + /// Panics if there are no open scopes. #[must_use] pub fn is_visible(&self, var: VarId) -> bool { - let Some(scope) = self.scopes.last() else { - return false; - }; - + let scope = self.scopes.last().unwrap(); scope.node == var.0 && var.1 < scope.var_count } @@ -144,11 +149,5 @@ pub struct DuplicateVarError<'a>(NodeId, &'a str); /// Error that occurs when a variable is not defined in the current scope. #[derive(Debug, Clone, Error)] -pub enum UnknownVarError<'a> { - /// Failed to resolve a variable when in scope of a node. - #[error("can not resolve variable `{1}` in node {0}")] - WithinNode(NodeId, &'a str), - /// Failed to resolve a variable when in the root scope. - #[error("can not resolve variable `{0}` in the root scope")] - Root(&'a str), -} +#[error("can not resolve variable `{1}` in node {0}")] +pub struct UnknownVarError<'a>(NodeId, &'a str); diff --git a/hugr-model/src/v0/table/mod.rs b/hugr-model/src/v0/table/mod.rs index 6ca6370f8f..501305510b 100644 --- a/hugr-model/src/v0/table/mod.rs +++ b/hugr-model/src/v0/table/mod.rs @@ -29,7 +29,7 @@ use smol_str::SmolStr; use thiserror::Error; mod view; -use super::{Literal, RegionKind, Visibility, ast}; +use super::{Literal, RegionKind, ast}; pub use view::View; /// A package consisting of a sequence of [`Module`]s. @@ -303,8 +303,6 @@ pub struct RegionScope { /// [`ast::Symbol`]: crate::v0::ast::Symbol #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Symbol<'a> { - /// The visibility of the symbol. - pub visibility: &'a Option, /// The name of the symbol. pub name: &'a str, /// The static parameters. 
diff --git a/hugr-model/tests/fixtures/model-add.edn b/hugr-model/tests/fixtures/model-add.edn index 5b02678744..93b3a1a5b3 100644 --- a/hugr-model/tests/fixtures/model-add.edn +++ b/hugr-model/tests/fixtures/model-add.edn @@ -2,19 +2,14 @@ (mod) -(define-func - public - example.add +(define-func example.add (core.fn - [(arithmetic.int.types.int 6) (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)]) - (dfg [%0 %1] [%2] - (signature - (core.fn - [(arithmetic.int.types.int 6) (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)])) - ((arithmetic.int.iadd 6) [%0 %1] [%2] - (signature - (core.fn - [(arithmetic.int.types.int 6) (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)]))))) + [arithmetic.int.types.int arithmetic.int.types.int] + [arithmetic.int.types.int]) + (dfg + [%0 %1] + [%2] + (signature (core.fn [arithmetic.int.types.int arithmetic.int.types.int] [arithmetic.int.types.int])) + (arithmetic.int.iadd + [%0 %1] [%2] + (signature (core.fn [arithmetic.int.types.int arithmetic.int.types.int] [arithmetic.int.types.int]))))) diff --git a/hugr-model/tests/fixtures/model-call.edn b/hugr-model/tests/fixtures/model-call.edn index 8fa25feaf6..4bf2eaaac4 100644 --- a/hugr-model/tests/fixtures/model-call.edn +++ b/hugr-model/tests/fixtures/model-call.edn @@ -2,13 +2,13 @@ (mod) -(declare-func public +(declare-func example.callee (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]) (meta (compat.meta_json "title" "\"Callee\"")) (meta (compat.meta_json "description" "\"This is a function declaration.\""))) -(define-func public example.caller +(define-func example.caller (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]) (meta (compat.meta_json "title" "\"Caller\"")) (meta (compat.meta_json "description" "\"This defines a function that calls the function which we declared earlier.\"")) @@ -17,7 +17,7 @@ ((core.call _ _ example.callee) [%3] [%4] (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))))) -(define-func public +(define-func example.load (core.fn [] [(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])]) (dfg diff --git a/hugr-model/tests/fixtures/model-cfg.edn b/hugr-model/tests/fixtures/model-cfg.edn index bfce67854b..e2a760f5c0 100644 --- a/hugr-model/tests/fixtures/model-cfg.edn +++ b/hugr-model/tests/fixtures/model-cfg.edn @@ -2,23 +2,23 @@ (mod) -(define-func public example.cfg_loop +(define-func example.cfg_loop (param ?a core.type) (core.fn [?a] [?a]) (dfg [%0] [%1] - (signature (core.fn [?a] [?a])) - (cfg [%0] [%1] - (signature (core.fn [?a] [?a])) - (cfg [%2] [%3] - (signature (core.ctrl [[?a]] [[?a]])) - (block [%2] [%3 %2] - (signature (core.ctrl [[?a]] [[?a] [?a]])) - (dfg [%4] [%5] - (signature (core.fn [?a] [(core.adt [[?a] [?a]])])) - ((core.make_adt 0) [%4] [%5] - (signature (core.fn [?a] [(core.adt [[?a] [?a]])]))))))))) + (signature (core.fn [?a] [?a])) + (cfg [%0] [%1] + (signature (core.fn [?a] [?a])) + (cfg [%2] [%4] + (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])])) + (block [%2] [%4 %2] + (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a]) (core.ctrl [?a])])) + (dfg [%5] [%6] + (signature (core.fn [?a] [(core.adt [[?a] [?a]])])) + ((core.make_adt 0) [%5] [%6] + (signature (core.fn [?a] [(core.adt [[?a] [?a]])]))))))))) -(define-func public example.cfg_order +(define-func example.cfg_order (param ?a core.type) (core.fn [?a] [?a]) (dfg [%0] [%1] @@ -26,15 +26,15 @@ (cfg [%0] [%1] (signature (core.fn [?a] [?a])) (cfg [%2] [%4] - (signature (core.ctrl [[?a]] 
[[?a]])) + (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])])) (block [%3] [%4] - (signature (core.ctrl [[?a]] [[?a]])) + (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])])) (dfg [%5] [%6] (signature (core.fn [?a] [(core.adt [[?a]])])) ((core.make_adt _ _ 0) [%5] [%6] (signature (core.fn [?a] [(core.adt [[?a]])]))))) (block [%2] [%3] - (signature (core.ctrl [[?a]] [[?a]])) + (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])])) (dfg [%7] [%8] (signature (core.fn [?a] [(core.adt [[?a]])])) ((core.make_adt _ _ 0) [%7] [%8] diff --git a/hugr-model/tests/fixtures/model-cond.edn b/hugr-model/tests/fixtures/model-cond.edn index 9f49446d6d..fd6fadcc86 100644 --- a/hugr-model/tests/fixtures/model-cond.edn +++ b/hugr-model/tests/fixtures/model-cond.edn @@ -2,29 +2,16 @@ (mod) -(define-func public - example.cond - (core.fn - [(core.adt [[] []]) (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)]) +(define-func example.cond + (core.fn [(core.adt [[] []]) arithmetic.int.types.int] + [arithmetic.int.types.int]) (dfg [%0 %1] [%2] - (signature - (core.fn - [(core.adt [[] []]) (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)])) - (cond [%0 %1] [%2] - (signature - (core.fn - [(core.adt [[] []]) (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)])) - (dfg [%3] [%3] - (signature - (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)]))) - (dfg [%4] [%5] - (signature - (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) - ((arithmetic.int.ineg 6) [%4] [%5] - (signature - (core.fn - [(arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6)]))))))) + (signature (core.fn [(core.adt [[] []]) arithmetic.int.types.int] [arithmetic.int.types.int])) + (cond [%0 %1] [%2] + (signature (core.fn [(core.adt [[] []]) arithmetic.int.types.int] [arithmetic.int.types.int])) + (dfg [%3] [%3] + (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))) + (dfg [%4] [%5] + (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) + (arithmetic.int.ineg [%4] [%5] + (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))))))) diff --git a/hugr-model/tests/fixtures/model-const.edn b/hugr-model/tests/fixtures/model-const.edn index 959afe2ef4..5d9bcb49b4 100644 --- a/hugr-model/tests/fixtures/model-const.edn +++ b/hugr-model/tests/fixtures/model-const.edn @@ -2,7 +2,7 @@ (mod) -(define-func public example.bools +(define-func example.bools (core.fn [] [(core.adt [[] []]) (core.adt [[] []])]) (dfg [] [%false %true] @@ -12,7 +12,7 @@ ((core.load_const (core.const.adt 1 (tuple))) [] [%true] (signature (core.fn [] [(core.adt [[] []])]))))) -(define-func public example.make-pair +(define-func example.make-pair (core.fn [] [(core.adt [[(collections.array.array 5 (arithmetic.int.types.int 6)) @@ -45,7 +45,7 @@ [[(collections.array.array 5 (arithmetic.int.types.int 6)) arithmetic.float.types.float64]])]))))) -(define-func public example.f64-json +(define-func example.f64-json (core.fn [] [arithmetic.float.types.float64]) (dfg [] [%0 %1] diff --git a/hugr-model/tests/fixtures/model-constraints.edn b/hugr-model/tests/fixtures/model-constraints.edn index 761a33c058..6884e55936 100644 --- a/hugr-model/tests/fixtures/model-constraints.edn +++ b/hugr-model/tests/fixtures/model-constraints.edn @@ -2,14 +2,13 @@ (mod) -(declare-func private array.replicate +(declare-func array.replicate (param ?n core.nat) (param ?t core.type) (where (core.nonlinear ?t)) (core.fn [?t] [(collections.array.array ?n 
?t)])) (declare-func - public array.copy (param ?n core.nat) (param ?t core.type) @@ -19,7 +18,7 @@ [(collections.array.array ?n ?t) (collections.array.array ?n ?t)])) -(define-func public util.copy +(define-func util.copy (param ?t core.type) (where (core.nonlinear ?t)) (core.fn [?t] [?t ?t]) diff --git a/hugr-model/tests/fixtures/model-entrypoint.edn b/hugr-model/tests/fixtures/model-entrypoint.edn index fb70d10309..10cab9173b 100644 --- a/hugr-model/tests/fixtures/model-entrypoint.edn +++ b/hugr-model/tests/fixtures/model-entrypoint.edn @@ -2,7 +2,7 @@ (mod) -(define-func public main +(define-func main (core.fn [] []) (meta core.entrypoint) (dfg [] [] @@ -10,7 +10,7 @@ (mod) -(define-func public wrapper_dfg +(define-func wrapper_dfg (core.fn [] []) (dfg [] [] (signature (core.fn [] [])) @@ -18,17 +18,17 @@ (mod) -(define-func public wrapper_cfg +(define-func wrapper_cfg (core.fn [] []) (dfg [] [] (signature (core.fn [] [])) (cfg [] [] (signature (core.fn [] [])) (cfg [%entry] [%exit] - (signature (core.ctrl [[]] [[]])) + (signature (core.fn [(core.ctrl [])] [(core.ctrl [])])) (meta core.entrypoint) (block [%entry] [%exit] - (signature (core.ctrl [[]] [[]])) + (signature (core.fn [(core.ctrl [])] [(core.ctrl [])])) (dfg [] [%value] (signature (core.fn [] [(core.adt [[]])])) ((core.make_adt _ _ 0) [] [%value] diff --git a/hugr-model/tests/fixtures/model-loop.edn b/hugr-model/tests/fixtures/model-loop.edn index 8276ed74ba..5c4a6779e3 100644 --- a/hugr-model/tests/fixtures/model-loop.edn +++ b/hugr-model/tests/fixtures/model-loop.edn @@ -2,9 +2,7 @@ (mod) -(define-func - private - example.loop +(define-func example.loop (param ?a core.type) (core.fn [?a] [?a]) (dfg [%0] [%1] diff --git a/hugr-model/tests/fixtures/model-order.edn b/hugr-model/tests/fixtures/model-order.edn index ed5c1e69e9..76bf7b0ba6 100644 --- a/hugr-model/tests/fixtures/model-order.edn +++ b/hugr-model/tests/fixtures/model-order.edn @@ -2,54 +2,49 @@ (mod) -(define-func public main +(define-func main (core.fn - [(arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6)]) + [arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int] + [arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int]) (dfg [%0 %1 %2 %3] [%4 %5 %6 %7] (signature (core.fn - [(arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6)] - [(arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6) - (arithmetic.int.types.int 6)])) + [arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int] + [arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int + arithmetic.int.types.int])) (meta (core.order_hint.order 1 2)) (meta (core.order_hint.order 1 0)) (meta (core.order_hint.order 2 3)) (meta (core.order_hint.order 0 3)) - (meta (core.order_hint.input_key 4)) - (meta (core.order_hint.order 4 0)) - (meta (core.order_hint.order 4 5)) - (meta (core.order_hint.order 1 5)) - (meta (core.order_hint.output_key 5)) - ((arithmetic.int.ineg 6) + (arithmetic.int.ineg [%0] [%4] - (signature (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) + (signature (core.fn [arithmetic.int.types.int] 
[arithmetic.int.types.int])) (meta (core.order_hint.key 0))) - ((arithmetic.int.ineg 6) + (arithmetic.int.ineg [%1] [%5] - (signature (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) + (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) (meta (core.order_hint.key 1))) - ((arithmetic.int.ineg 6) + (arithmetic.int.ineg [%2] [%6] - (signature (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) + (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) (meta (core.order_hint.key 2))) - ((arithmetic.int.ineg 6) + (arithmetic.int.ineg [%3] [%7] - (signature (core.fn [(arithmetic.int.types.int 6)] [(arithmetic.int.types.int 6)])) + (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int])) (meta (core.order_hint.key 3))))) diff --git a/hugr-model/tests/fixtures/model-params.edn b/hugr-model/tests/fixtures/model-params.edn index c4e29f933c..ba81fa6007 100644 --- a/hugr-model/tests/fixtures/model-params.edn +++ b/hugr-model/tests/fixtures/model-params.edn @@ -2,25 +2,10 @@ (mod) -(define-func public example.swap +(define-func example.swap ; The types of the values to be swapped are passed as implicit parameters. (param ?a core.type) (param ?b core.type) (core.fn [?a ?b] [?b ?a]) (dfg [%a %b] [%b %a] (signature (core.fn [?a ?b] [?b ?a])))) - -(declare-func public example.literals - (param ?a core.str) - (param ?b core.nat) - (param ?c core.bytes) - (param ?d core.float) - (core.fn [] [])) - -(define-func private example.call_literals - (core.fn [] []) - (dfg [] [] - (signature (core.fn [] [])) - ((core.call - (example.literals "string" 42 (bytes "SGVsbG8gd29ybGQg8J+Yig==") 6.023e23)) - (signature (core.fn [] []))))) diff --git a/hugr-passes/CHANGELOG.md b/hugr-passes/CHANGELOG.md index 7ae8e089b1..d5f2921845 100644 --- a/hugr-passes/CHANGELOG.md +++ b/hugr-passes/CHANGELOG.md @@ -1,37 +1,6 @@ # Changelog -## [0.22.1](https://github.com/CQCL/hugr/compare/hugr-passes-v0.22.0...hugr-passes-v0.22.1) - 2025-07-28 - -### New Features - -- Include copy_discard_array in DelegatingLinearizer::default ([#2479](https://github.com/CQCL/hugr/pull/2479)) -- Inline calls to functions not on cycles in the call graph ([#2450](https://github.com/CQCL/hugr/pull/2450)) - -## [0.22.0](https://github.com/CQCL/hugr/compare/hugr-passes-v0.21.0...hugr-passes-v0.22.0) - 2025-07-24 - -### New Features - -- ReplaceTypes allows linearizing inside Op replacements ([#2435](https://github.com/CQCL/hugr/pull/2435)) -- Add pass for DFG inlining ([#2460](https://github.com/CQCL/hugr/pull/2460)) - -## [0.21.0](https://github.com/CQCL/hugr/compare/hugr-passes-v0.20.2...hugr-passes-v0.21.0) - 2025-07-09 - -### Bug Fixes - -- DeadFuncElimPass+CallGraph w/ non-module-child entrypoint ([#2390](https://github.com/CQCL/hugr/pull/2390)) - -### New Features - -- [**breaking**] No nested FuncDefns (or AliasDefns) ([#2256](https://github.com/CQCL/hugr/pull/2256)) -- [**breaking**] Split `TypeArg::Sequence` into tuples and lists. 
([#2140](https://github.com/CQCL/hugr/pull/2140)) -- [**breaking**] Merge `TypeParam` and `TypeArg` into one `Term` type in Rust ([#2309](https://github.com/CQCL/hugr/pull/2309)) -- [**breaking**] Rename 'Any' type bound to 'Linear' ([#2421](https://github.com/CQCL/hugr/pull/2421)) - -### Refactor - -- [**breaking**] Reduce error type sizes ([#2420](https://github.com/CQCL/hugr/pull/2420)) - ## [0.20.2](https://github.com/CQCL/hugr/compare/hugr-passes-v0.20.1...hugr-passes-v0.20.2) - 2025-06-25 ### Bug Fixes diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index adf019fcc5..8c1daafbcd 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hugr-passes" -version = "0.22.1" +version = "0.20.2" edition = { workspace = true } rust-version = { workspace = true } license = { workspace = true } @@ -19,7 +19,7 @@ workspace = true bench = false [dependencies] -hugr-core = { path = "../hugr-core", version = "0.22.1" } +hugr-core = { path = "../hugr-core", version = "0.20.2" } portgraph = { workspace = true } ascent = { version = "0.8.0" } derive_more = { workspace = true, features = ["display", "error", "from"] } diff --git a/hugr-passes/src/call_graph.rs b/hugr-passes/src/call_graph.rs index 7baf8530dd..e33881b1f7 100644 --- a/hugr-passes/src/call_graph.rs +++ b/hugr-passes/src/call_graph.rs @@ -1,3 +1,4 @@ +#![warn(missing_docs)] //! Data structure for call graphs of a Hugr use std::collections::HashMap; @@ -5,7 +6,6 @@ use hugr_core::{HugrView, Node, core::HugrNode, ops::OpType}; use petgraph::Graph; /// Weight for an edge in a [`CallGraph`] -#[derive(Clone, Debug, PartialEq, Eq)] pub enum CallGraphEdge { /// Edge corresponds to a [Call](OpType::Call) node (specified) in the Hugr Call(N), @@ -48,20 +48,19 @@ impl CallGraph { /// Makes a new `CallGraph` for a Hugr. pub fn new(hugr: &impl HugrView) -> Self { let mut g = Graph::default(); - let mut node_to_g = hugr + let non_func_root = + (!hugr.get_optype(hugr.entrypoint()).is_module()).then_some(hugr.entrypoint()); + let node_to_g = hugr .children(hugr.module_root()) .filter_map(|n| { let weight = match hugr.get_optype(n) { OpType::FuncDecl(_) => CallGraphNode::FuncDecl(n), OpType::FuncDefn(_) => CallGraphNode::FuncDefn(n), - _ => return None, + _ => (Some(n) == non_func_root).then_some(CallGraphNode::NonFuncRoot)?, }; Some((n, g.add_node(weight))) }) .collect::>(); - if !hugr.entrypoint_optype().is_module() && !node_to_g.contains_key(&hugr.entrypoint()) { - node_to_g.insert(hugr.entrypoint(), g.add_node(CallGraphNode::NonFuncRoot)); - } for (func, cg_node) in &node_to_g { traverse(hugr, *cg_node, *func, &mut g, &node_to_g); } diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs index bda5e66cf7..d7f44fcebb 100644 --- a/hugr-passes/src/composable.rs +++ b/hugr-passes/src/composable.rs @@ -9,16 +9,11 @@ use itertools::Either; /// An optimization pass that can be sequenced with another and/or wrapped /// e.g. by [`ValidatingPass`] pub trait ComposablePass: Sized { - /// Error thrown by this pass. type Error: Error; - /// Result returned by this pass. type Result; // Would like to default to () but currently unstable - /// Run the pass on the given HUGR. fn run(&self, hugr: &mut H) -> Result; - /// Apply a function to the error type of this pass, returning a new - /// [`ComposablePass`] that has the same result type. 
fn map_err( self, f: impl Fn(Self::Error) -> E2, @@ -57,9 +52,7 @@ pub trait ComposablePass: Sized { /// Trait for combining the error types from two different passes /// into a single error. pub trait ErrorCombiner: Error { - /// Create a combined error from the first pass's error. fn from_first(a: A) -> Self; - /// Create a combined error from the second pass's error. fn from_second(b: B) -> Self; } @@ -120,33 +113,20 @@ pub enum ValidatePassError where N: HugrNode + 'static, { - /// Validation failed on the initial HUGR. #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] Input { - /// The validation error that occurred. #[source] - err: Box>, - /// A pretty-printed representation of the HUGR that failed validation. + err: ValidationError, pretty_hugr: String, }, - /// Validation failed on the final HUGR. #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")] Output { - /// The validation error that occurred. #[source] - err: Box>, - /// A pretty-printed representation of the HUGR that failed validation. + err: ValidationError, pretty_hugr: String, }, - /// An error from the underlying pass. #[error(transparent)] - Underlying(Box), -} - -impl From for ValidatePassError { - fn from(err: E) -> Self { - Self::Underlying(Box::new(err)) - } + Underlying(#[from] E), } /// Runs an underlying pass, but with validation of the Hugr @@ -154,7 +134,6 @@ impl From for ValidatePassError { pub struct ValidatingPass(P, PhantomData); impl, H: HugrMut> ValidatingPass { - /// Return a new [`ValidatingPass`] that wraps the given underlying pass. pub fn new(underlying: P) -> Self { Self(underlying, PhantomData) } @@ -178,12 +157,12 @@ where fn run(&self, hugr: &mut H) -> Result { self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input { - err: Box::new(err), + err, pretty_hugr, })?; - let res = self.0.run(hugr)?; + let res = self.0.run(hugr).map_err(ValidatePassError::Underlying)?; self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output { - err: Box::new(err), + err, pretty_hugr, })?; Ok(res) @@ -234,7 +213,7 @@ pub(crate) fn validate_if_test, H: HugrMut>( if cfg!(test) { ValidatingPass::new(pass).run(hugr) } else { - Ok(pass.run(hugr)?) + pass.run(hugr).map_err(ValidatePassError::Underlying) } } @@ -244,7 +223,8 @@ mod test { use std::convert::Infallible; use hugr_core::builder::{ - Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, + Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, + ModuleBuilder, }; use hugr_core::extension::prelude::{ConstUsize, MakeTuple, UnpackTuple, bool_t, usize_t}; use hugr_core::hugr::hugrmut::HugrMut; @@ -324,7 +304,7 @@ mod test { assert_eq!(h, backup); // Did nothing let r = ValidatingPass::new(cfold).run(&mut h); - assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if *e == err)); + assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err)); } #[test] diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 9c450b0aca..11a92faa48 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -1,3 +1,4 @@ +#![warn(missing_docs)] //! Constant-folding pass. //! An (example) use of the [dataflow analysis framework](super::dataflow). 
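For context on how the `ComposablePass` / `ValidatingPass` API touched above is typically driven, here is a minimal sketch. It assumes the 0.20-series `hugr-core` / `hugr-passes` crates this patch reverts to; `ConstantFoldPass::default()` and the exact module paths are assumptions for illustration and are not taken verbatim from this diff.

```rust
// Sketch only: wraps a pass in `ValidatingPass`, as `validate_if_test` does above.
// Crate paths follow the hugr 0.20-series; `ConstantFoldPass::default()` is assumed.
use hugr_core::builder::{DFGBuilder, Dataflow, DataflowHugr, inout_sig};
use hugr_core::extension::prelude::{ConstUsize, usize_t};
use hugr_core::type_row;
use hugr_passes::composable::{ComposablePass, ValidatingPass};
use hugr_passes::const_fold::ConstantFoldPass;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build a minimal dataflow HUGR: no inputs, one usize output fed by a constant.
    let mut dfg = DFGBuilder::new(inout_sig(type_row![], usize_t()))?;
    let c = dfg.add_load_value(ConstUsize::new(7));
    let mut hugr = dfg.finish_hugr_with_outputs([c])?;

    // Validate the HUGR before and after the underlying pass runs.
    ValidatingPass::new(ConstantFoldPass::default()).run(&mut hugr)?;
    Ok(())
}
```

In test builds, `validate_if_test` applies the same wrapping automatically (`if cfg!(test) { ValidatingPass::new(pass).run(hugr) } ...`), so the validation cost is only paid during testing.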
diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index a60684ec07..f4165676b2 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1591,7 +1591,7 @@ fn test_module() -> Result<(), Box> { // Define a top-level constant, (only) the second of which can be removed let c7 = mb.add_constant(Value::from(ConstInt::new_u(5, 7)?)); let c17 = mb.add_constant(Value::from(ConstInt::new_u(5, 17)?)); - let ad1 = mb.add_alias_declare("unused", TypeBound::Linear)?; + let ad1 = mb.add_alias_declare("unused", TypeBound::Any)?; let ad2 = mb.add_alias_def("unused2", INT_TYPES[3].clone())?; let mut main = mb.define_function( "main", diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 3311409655..a97901c61b 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -1,3 +1,4 @@ +#![warn(missing_docs)] //! Dataflow analysis of Hugrs. mod datalog; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 0368f931bc..0eadbcdc10 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -51,11 +51,6 @@ impl Machine { /// or [Conditional](hugr_core::ops::Conditional)). /// Any inputs not given values by `in_values`, are set to [`PartialValue::Top`]. /// Multiple calls for the same `parent` will `join` values for corresponding ports. - #[expect( - clippy::result_large_err, - reason = "Not called recursively and not a performance bottleneck" - )] - #[inline] pub fn prepopulate_inputs( &mut self, parent: H::Node, diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 915f6f3425..205d9ba4fa 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -2,7 +2,7 @@ use std::convert::Infallible; use ascent::{Lattice, lattice::BoundedLattice}; -use hugr_core::builder::{CFGBuilder, DataflowHugr, ModuleBuilder, inout_sig}; +use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder, inout_sig}; use hugr_core::ops::{CallIndirect, TailLoop}; use hugr_core::types::{ConstTypeError, TypeRow}; use hugr_core::{Hugr, Node, Wire}; @@ -409,14 +409,11 @@ fn test_call( #[case] out: PartialValue, ) { let mut builder = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap(); - let func_defn = { - let mut mb = builder.module_root_builder(); - let func_bldr = mb - .define_function("id", Signature::new_endo(bool_t())) - .unwrap(); - let [v] = func_bldr.input_wires_arr(); - func_bldr.finish_with_outputs([v]).unwrap() - }; + let func_bldr = builder + .define_function("id", Signature::new_endo(bool_t())) + .unwrap(); + let [v] = func_bldr.input_wires_arr(); + let func_defn = func_bldr.finish_with_outputs([v]).unwrap(); let [a, b] = builder.input_wires_arr(); let [a2] = builder .call(func_defn.handle(), &[], [a]) @@ -557,8 +554,7 @@ fn call_indirect(#[case] inp1: PartialValue, #[case] inp2: PartialValue> ComposablePass for InlineDFGsPass { - type Error = Infallible; - type Result = (); - - fn run(&self, h: &mut H) -> Result<(), Self::Error> { - let dfgs = h - .entry_descendants() - .skip(1) // Skip the entrypoint itself - .filter(|&n| h.get_optype(n).is_dfg()) - .collect_vec(); - for dfg in dfgs { - h.apply_patch(InlineDFG(dfg.into())) - .map_err(|err| -> Infallible { - match err { - InlineDFGError::CantInlineEntrypoint { .. } => { - unreachable!("We skipped the entrypoint") - } - InlineDFGError::NotDFG { .. 
} => unreachable!("Should be a DFG"), - _ => unreachable!("No other error cases"), - } - }) - .unwrap(); - } - Ok(()) - } -} - -#[cfg(test)] -mod test { - use hugr_core::{ - HugrView, - builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}, - extension::prelude::qb_t, - types::Signature, - }; - - use crate::ComposablePass; - - use super::InlineDFGsPass; - - #[test] - fn inline_dfgs() -> Result<(), Box> { - let mut outer = DFGBuilder::new(Signature::new_endo(vec![qb_t(), qb_t()]))?; - let [a, b] = outer.input_wires_arr(); - - let inner1 = outer.dfg_builder_endo([(qb_t(), a)])?; - let [inner1_a] = inner1.input_wires_arr(); - let [a] = inner1.finish_with_outputs([inner1_a])?.outputs_arr(); - - let mut inner2 = outer.dfg_builder_endo([(qb_t(), b)])?; - let [inner2_b] = inner2.input_wires_arr(); - let inner2_inner = inner2.dfg_builder_endo([(qb_t(), inner2_b)])?; - let [inner2_inner_b] = inner2_inner.input_wires_arr(); - let [inner2_b] = inner2_inner - .finish_with_outputs([inner2_inner_b])? - .outputs_arr(); - let [b] = inner2.finish_with_outputs([inner2_b])?.outputs_arr(); - - let inner3 = outer.dfg_builder_endo([(qb_t(), a), (qb_t(), b)])?; - let [inner3_a, inner3_b] = inner3.input_wires_arr(); - let [a, b] = inner3 - .finish_with_outputs([inner3_a, inner3_b])? - .outputs_arr(); - - let mut h = outer.finish_hugr_with_outputs([a, b])?; - assert_eq!(h.num_nodes(), 5 * 3 + 4); // 5 DFGs with I/O + 4 nodes for module/func roots - InlineDFGsPass.run(&mut h).unwrap(); - - // Root should be the only remaining DFG - assert!(h.get_optype(h.entrypoint()).is_dfg()); - assert!( - h.entry_descendants() - .skip(1) - .all(|n| !h.get_optype(n).is_dfg()) - ); - assert_eq!(h.num_nodes(), 3 + 4); // 1 DFG with I/O + 4 nodes for module/func roots - Ok(()) - } -} diff --git a/hugr-passes/src/inline_funcs.rs b/hugr-passes/src/inline_funcs.rs deleted file mode 100644 index b999560f45..0000000000 --- a/hugr-passes/src/inline_funcs.rs +++ /dev/null @@ -1,229 +0,0 @@ -//! Contains a pass to inline calls to selected functions in a Hugr. -use std::collections::{HashSet, VecDeque}; - -use hugr_core::hugr::hugrmut::HugrMut; -use hugr_core::hugr::patch::inline_call::InlineCall; -use itertools::Itertools; -use petgraph::algo::tarjan_scc; - -use crate::call_graph::{CallGraph, CallGraphNode}; - -/// Error raised by [inline_acyclic] -#[derive(Clone, Debug, thiserror::Error, PartialEq)] -#[non_exhaustive] -pub enum InlineFuncsError {} - -/// Inline (a subset of) [Call]s whose target [FuncDefn]s are not in cycles of the call -/// graph. -/// -/// The function `call_predicate` is passed each such [Call] node and can return -/// `false` to prevent that Call from being inlined. (Note the [Call] may be created as -/// a result of previous inlinings so may not have existed in the original Hugr). 
-/// -/// [Call]: hugr_core::ops::Call -/// [FuncDefn]: hugr_core::ops::FuncDefn -pub fn inline_acyclic( - h: &mut H, - call_predicate: impl Fn(&H, H::Node) -> bool, -) -> Result<(), InlineFuncsError> { - let cg = CallGraph::new(&*h); - let g = cg.graph(); - let all_funcs_in_cycles = tarjan_scc(g) - .into_iter() - .flat_map(|mut ns| { - if let Ok(n) = ns.iter().exactly_one() { - if g.edges_connecting(*n, *n).next().is_none() { - ns.clear(); // Single-node SCC has no self edge, so discard - } - } - ns.into_iter().map(|n| { - let CallGraphNode::FuncDefn(fd) = g.node_weight(n).unwrap() else { - panic!("Expected only FuncDefns in sccs") - }; - *fd - }) - }) - .collect::>(); - let target_funcs: HashSet = h - .children(h.module_root()) - .filter(|n| h.get_optype(*n).is_func_defn() && !all_funcs_in_cycles.contains(n)) - .collect(); - let mut q = VecDeque::from([h.entrypoint()]); - while let Some(n) = q.pop_front() { - if h.get_optype(n).is_call() { - if let Some(t) = h.static_source(n) { - if target_funcs.contains(&t) && call_predicate(h, n) { - // We've already checked all error conditions - h.apply_patch(InlineCall::new(n)).unwrap(); - } - } - } - // Traverse children - including any resulting from turning Call into DFG - q.extend(h.children(n)); - } - Ok(()) -} - -#[cfg(test)] -mod test { - use std::collections::HashSet; - - use hugr_core::core::HugrNode; - use hugr_core::ops::OpType; - use itertools::Itertools; - use petgraph::visit::EdgeRef; - - use hugr_core::HugrView; - use hugr_core::builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; - use hugr_core::{Hugr, extension::prelude::qb_t, types::Signature}; - use rstest::rstest; - - use crate::call_graph::{CallGraph, CallGraphNode}; - use crate::inline_funcs::inline_acyclic; - - /// /->-\ - /// main -> f g -> b -> c - /// / \-<-/ - /// / - /// \-> a -> x - fn make_test_hugr() -> Hugr { - let sig = || Signature::new_endo(qb_t()); - let mut mb = ModuleBuilder::new(); - let x = mb.declare("x", sig().into()).unwrap(); - let a = { - let mut fb = mb.define_function("a", sig()).unwrap(); - let ins = fb.input_wires(); - let res = fb.call(&x, &[], ins).unwrap(); - fb.finish_with_outputs(res.outputs()).unwrap() - }; - let c = { - let fb = mb.define_function("c", sig()).unwrap(); - let ins = fb.input_wires(); - fb.finish_with_outputs(ins).unwrap() - }; - let b = { - let mut fb = mb.define_function("b", sig()).unwrap(); - let ins = fb.input_wires(); - let res = fb.call(c.handle(), &[], ins).unwrap().outputs(); - fb.finish_with_outputs(res).unwrap() - }; - let f = mb.declare("f", sig().into()).unwrap(); - let g = { - let mut fb = mb.define_function("g", sig()).unwrap(); - let ins = fb.input_wires(); - let c1 = fb.call(&f, &[], ins).unwrap(); - let c2 = fb.call(b.handle(), &[], c1.outputs()).unwrap(); - fb.finish_with_outputs(c2.outputs()).unwrap() - }; - let _f = { - let mut fb = mb.define_declaration(&f).unwrap(); - let ins = fb.input_wires(); - let c1 = fb.call(g.handle(), &[], ins).unwrap(); - let c2 = fb.call(a.handle(), &[], c1.outputs()).unwrap(); - fb.finish_with_outputs(c2.outputs()).unwrap() - }; - mb.finish_hugr().unwrap() - } - - fn find_func(h: &H, name: &str) -> H::Node { - h.children(h.module_root()) - .find(|n| { - h.get_optype(*n) - .as_func_defn() - .is_some_and(|fd| fd.func_name() == name) - }) - .unwrap() - } - - #[rstest] - #[case(["a", "b", "c"], ["a", "b", "c"], [vec!["g", "x"], vec!["f"], vec!["x"], vec![], vec![]])] - #[case(["a", "b"], ["a", "b"], [vec!["g", "x"], vec!["f", "c"], vec!["x"], vec!["c"], vec![]])] - 
#[case(["c"], ["c"], [vec!["g", "a"], vec!("f", "b"), vec!["x"], vec![], vec![]])] - fn test_inline( - #[case] req: impl IntoIterator, - #[case] check_not_called: impl IntoIterator, - #[case] calls_fgabc: [Vec<&'static str>; 5], - ) { - let mut h = make_test_hugr(); - let target_funcs = req - .into_iter() - .map(|name| find_func(&h, name)) - .collect::>(); - inline_acyclic(&mut h, |h, call| { - let tgt = h.static_source(call).unwrap(); - // Check the callback is never asked about an impossible inlining - assert!(["a", "b", "c"].contains(&func_name(h, tgt).as_str())); - target_funcs.contains(&tgt) - }) - .unwrap(); - let cg = CallGraph::new(&h); - for fname in check_not_called { - let fnode = find_func(&h, fname); - let fnode = cg.node_index(fnode).unwrap(); - assert_eq!( - None, - cg.graph() - .edges_directed(fnode, petgraph::Direction::Incoming) - .next() - ); - } - for (fname, tgts) in ["f", "g", "a", "b", "c"].into_iter().zip_eq(calls_fgabc) { - let fnode = find_func(&h, fname); - assert_eq!( - outgoing_calls(&cg, fnode) - .into_iter() - .map(|n| func_name(&h, n).as_str()) - .collect::>(), - HashSet::from_iter(tgts), - "Calls from {fname}" - ); - } - } - - fn outgoing_calls(cg: &CallGraph, src: N) -> Vec { - let src = cg.node_index(src).unwrap(); - cg.graph() - .edges_directed(src, petgraph::Direction::Outgoing) - .map(|e| func_node(cg.graph().node_weight(e.target()).unwrap())) - .collect() - } - - #[test] - fn test_filter_caller() { - let mut h = make_test_hugr(); - let [g, b, c] = ["g", "b", "c"].map(|n| find_func(&h, n)); - // Inline calls contained within `g` - inline_acyclic(&mut h, |h, mut call| { - loop { - if call == g { - return true; - }; - let Some(parent) = h.get_parent(call) else { - return false; - }; - call = parent; - } - }) - .unwrap(); - let cg = CallGraph::new(&h); - // b and then c should have been inlined into g, leaving only cyclic call to f - assert_eq!(outgoing_calls(&cg, g), [find_func(&h, "f")]); - // But c should not have been inlined into b: - assert_eq!(outgoing_calls(&cg, b), [c]); - } - - fn func_node(cgn: &CallGraphNode) -> N { - match cgn { - CallGraphNode::FuncDecl(n) | CallGraphNode::FuncDefn(n) => *n, - CallGraphNode::NonFuncRoot => panic!(), - } - } - - fn func_name(h: &H, n: H::Node) -> &String { - match h.get_optype(n) { - OpType::FuncDecl(fd) => fd.func_name(), - OpType::FuncDefn(fd) => fd.func_name(), - _ => panic!(), - } - } -} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 6e97f3422e..c82fc5abe6 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,4 +1,5 @@ //! Compilation passes acting on the HUGR program representation. +#![expect(missing_docs)] // TODO: Fix... 
pub mod call_graph; pub mod composable; @@ -11,9 +12,6 @@ mod dead_funcs; pub use dead_funcs::{RemoveDeadFuncsError, RemoveDeadFuncsPass, remove_dead_funcs}; pub mod force_order; mod half_node; -pub mod inline_dfgs; -pub mod inline_funcs; -pub use inline_funcs::inline_acyclic; pub mod linearize_array; pub use linearize_array::LinearizeArrayPass; pub mod lower; diff --git a/hugr-passes/src/linearize_array.rs b/hugr-passes/src/linearize_array.rs index 07fbc6e958..4f8da110c9 100644 --- a/hugr-passes/src/linearize_array.rs +++ b/hugr-passes/src/linearize_array.rs @@ -10,7 +10,7 @@ use hugr_core::{ std_extensions::collections::{ array::{ ARRAY_REPEAT_OP_ID, ARRAY_SCAN_OP_ID, Array, ArrayKind, ArrayOpDef, ArrayRepeatDef, - ArrayScanDef, ArrayValue, array_type_parametric, + ArrayScanDef, ArrayValue, array_type_def, array_type_parametric, }, value_array::{self, VArrayFromArrayDef, VArrayToArrayDef, VArrayValue, ValueArray}, }, @@ -21,7 +21,9 @@ use strum::IntoEnumIterator; use crate::{ ComposablePass, ReplaceTypes, - replace_types::{DelegatingLinearizer, NodeTemplate, ReplaceTypesError}, + replace_types::{ + DelegatingLinearizer, NodeTemplate, ReplaceTypesError, handlers::copy_discard_array, + }, }; /// A HUGR -> HUGR pass that turns 'value_array`s into regular linear `array`s. @@ -64,7 +66,7 @@ impl Default for LinearizeArrayPass { // error out and make sure we're not emitting `get`s for nested value // arrays. assert!( - op_def != ArrayOpDef::get || args[1].as_runtime().unwrap().copyable(), + op_def != ArrayOpDef::get || args[1].as_type().unwrap().copyable(), "Cannot linearise arrays in this Hugr: \ Contains a `get` operation on nested value arrays" ); @@ -112,6 +114,8 @@ impl Default for LinearizeArrayPass { )) }, ); + pass.linearizer() + .register_callback(array_type_def(), copy_discard_array); Self(pass) } } diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index db9e60e135..45b5ce9080 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -1,5 +1,3 @@ -//! Passes to lower operations in a HUGR. - use hugr_core::{ Hugr, Node, hugr::{hugrmut::HugrMut, views::SiblingSubgraph}, diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 2d5abd5eb1..2a97f75240 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -133,7 +133,25 @@ fn instantiate( mono_sig: Signature, cache: &mut Instantiations, ) -> Node { - let for_func = cache.entry(poly_func).or_default(); + let for_func = cache.entry(poly_func).or_insert_with(|| { + // First time we've instantiated poly_func. Lift any nested FuncDefn's out to the same level. + let outer_name = h + .get_optype(poly_func) + .as_func_defn() + .unwrap() + .func_name() + .clone(); + let mut to_scan = Vec::from_iter(h.children(poly_func)); + while let Some(n) = to_scan.pop() { + if let OpType::FuncDefn(fd) = h.optype_mut(n) { + *fd.func_name_mut() = mangle_inner_func(&outer_name, fd.func_name()); + h.move_after_sibling(n, poly_func); + } else { + to_scan.extend(h.children(n)); + } + } + HashMap::new() + }); let ve = match for_func.entry(type_args.clone()) { Entry::Occupied(n) => return *n.get(), @@ -213,10 +231,9 @@ impl> ComposablePass for MonomorphizePass { } } -/// Helper to create mangled representations of lists of [TypeArg]s. 
-struct TypeArgsSeq<'a>(&'a [TypeArg]); +struct TypeArgsList<'a>(&'a [TypeArg]); -impl std::fmt::Display for TypeArgsSeq<'_> { +impl std::fmt::Display for TypeArgsList<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { for arg in self.0 { f.write_char('$')?; @@ -232,14 +249,13 @@ fn escape_dollar(str: impl AsRef) -> String { fn write_type_arg_str(arg: &TypeArg, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match arg { - TypeArg::Runtime(ty) => f.write_fmt(format_args!("t({})", escape_dollar(ty.to_string()))), - TypeArg::BoundedNat(n) => f.write_fmt(format_args!("n({n})")), - TypeArg::String(arg) => f.write_fmt(format_args!("s({})", escape_dollar(arg))), - TypeArg::List(elems) => f.write_fmt(format_args!("list({})", TypeArgsSeq(elems))), - TypeArg::Tuple(elems) => f.write_fmt(format_args!("tuple({})", TypeArgsSeq(elems))), + TypeArg::Type { ty } => f.write_fmt(format_args!("t({})", escape_dollar(ty.to_string()))), + TypeArg::BoundedNat { n } => f.write_fmt(format_args!("n({n})")), + TypeArg::String { arg } => f.write_fmt(format_args!("s({})", escape_dollar(arg))), + TypeArg::Sequence { elems } => f.write_fmt(format_args!("seq({})", TypeArgsList(elems))), // We are monomorphizing. We will never monomorphize to a signature // containing a variable. - TypeArg::Variable(_) => panic!("type_arg_str variable: {arg}"), + TypeArg::Variable { .. } => panic!("type_arg_str variable: {arg}"), _ => panic!("unknown type arg: {arg}"), } } @@ -259,7 +275,11 @@ fn write_type_arg_str(arg: &TypeArg, f: &mut std::fmt::Formatter<'_>) -> std::fm /// is used as "t({arg})" for the string representation of that arg. pub fn mangle_name(name: &str, type_args: impl AsRef<[TypeArg]>) -> String { let name = escape_dollar(name); - format!("${name}${}", TypeArgsSeq(type_args.as_ref())) + format!("${name}${}", TypeArgsList(type_args.as_ref())) +} + +fn mangle_inner_func(outer_name: &str, inner_name: &str) -> String { + format!("${outer_name}${inner_name}") } #[cfg(test)] @@ -268,7 +288,6 @@ mod test { use std::iter; use hugr_core::extension::simple_op::MakeRegisteredOp as _; - use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections; use hugr_core::std_extensions::collections::array::ArrayKind; @@ -289,7 +308,7 @@ mod test { use crate::{monomorphize, remove_dead_funcs}; - use super::{is_polymorphic, mangle_name}; + use super::{is_polymorphic, mangle_inner_func, mangle_name}; fn pair_type(ty: Type) -> Type { Type::new_tuple(vec![ty.clone(), ty]) @@ -406,12 +425,13 @@ mod test { } #[test] - fn test_multiargs_nats() { + fn test_flattening_multiargs_nats() { //pf1 contains pf2 contains mono_func -> pf1 and pf1 share pf2's and they share mono_func let tv = |i| Type::new_var_use(i, TypeBound::Copyable); - let sv = |i| TypeArg::new_var_use(i, TypeParam::max_nat_type()); - let sa = |n| TypeArg::BoundedNat(n); + let sv = |i| TypeArg::new_var_use(i, TypeParam::max_nat()); + let sa = |n| TypeArg::BoundedNat { n }; + let n: u64 = 5; let mut outer = FunctionBuilder::new( "mainish", @@ -427,23 +447,32 @@ mod test { .unwrap(); let arr2u = || ValueArray::ty_parametric(sa(2), usize_t()).unwrap(); + let pf1t = PolyFuncType::new( + [TypeParam::max_nat()], + Signature::new( + ValueArray::ty_parametric(sv(0), arr2u()).unwrap(), + usize_t(), + ), + ); + let mut pf1 = outer.define_function("pf1", pf1t).unwrap(); - let mut mb = outer.module_root_builder(); + let pf2t = PolyFuncType::new( + [TypeParam::max_nat(), 
TypeBound::Copyable.into()], + Signature::new( + vec![ValueArray::ty_parametric(sv(0), tv(1)).unwrap()], + tv(1), + ), + ); + let mut pf2 = pf1.define_function("pf2", pf2t).unwrap(); let mono_func = { - let mut fb = mb + let mut fb = pf2 .define_function("get_usz", Signature::new(vec![], usize_t())) .unwrap(); let cst0 = fb.add_load_value(ConstUsize::new(1)); fb.finish_with_outputs([cst0]).unwrap() }; - let pf2 = { - let pf2t = PolyFuncType::new( - [TypeParam::max_nat_type(), TypeBound::Copyable.into()], - Signature::new(ValueArray::ty_parametric(sv(0), tv(1)).unwrap(), tv(1)), - ); - let mut pf2 = mb.define_function("pf2", pf2t).unwrap(); let [inw] = pf2.input_wires_arr(); let [idx] = pf2.call(mono_func.handle(), &[], []).unwrap().outputs_arr(); let op_def = collections::value_array::EXTENSION.get_op("get").unwrap(); @@ -455,16 +484,6 @@ mod test { .unwrap(); pf2.finish_with_outputs([got]).unwrap() }; - - let pf1t = PolyFuncType::new( - [TypeParam::max_nat_type()], - Signature::new( - ValueArray::ty_parametric(sv(0), arr2u()).unwrap(), - usize_t(), - ), - ); - let mut pf1 = mb.define_function("pf1", pf1t).unwrap(); - // pf1: Two calls to pf2, one depending on pf1's TypeArg, the other not let inner = pf1 .call(pf2.handle(), &[sv(0), arr2u().into()], pf1.input_wires()) @@ -472,12 +491,11 @@ mod test { let elem = pf1 .call( pf2.handle(), - &[TypeArg::BoundedNat(2), usize_t().into()], + &[TypeArg::BoundedNat { n: 2 }, usize_t().into()], inner.outputs(), ) .unwrap(); let pf1 = pf1.finish_with_outputs(elem.outputs()).unwrap(); - // Outer: two calls to pf1 with different TypeArgs let [e1] = outer .call(pf1.handle(), &[sa(n)], outer.input_wires()) @@ -498,24 +516,23 @@ mod test { .call(pf1.handle(), &[sa(n - 1)], [ar2_unwrapped]) .unwrap() .outputs_arr(); - let outer_func = outer.container_node(); let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); - hugr.set_entrypoint(hugr.module_root()); // We want to act on everything, not just `main` monomorphize(&mut hugr).unwrap(); let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); + let pf2_name = mangle_inner_func("pf1", "pf2"); assert_eq!( funcs.keys().copied().sorted().collect_vec(), vec![ - &mangle_name("pf1", &[TypeArg::BoundedNat(5)]), - &mangle_name("pf1", &[TypeArg::BoundedNat(4)]), - &mangle_name("pf2", &[TypeArg::BoundedNat(5), arr2u().into()]), // from pf1<5> - &mangle_name("pf2", &[TypeArg::BoundedNat(4), arr2u().into()]), // from pf1<4> - &mangle_name("pf2", &[TypeArg::BoundedNat(2), usize_t().into()]), // from both pf1<4> and <5> - "get_usz", - "pf2", + &mangle_name("pf1", &[TypeArg::BoundedNat { n: 5 }]), + &mangle_name("pf1", &[TypeArg::BoundedNat { n: 4 }]), + &mangle_name(&pf2_name, &[TypeArg::BoundedNat { n: 5 }, arr2u().into()]), // from pf1<5> + &mangle_name(&pf2_name, &[TypeArg::BoundedNat { n: 4 }, arr2u().into()]), // from pf1<4> + &mangle_name(&pf2_name, &[TypeArg::BoundedNat { n: 2 }, usize_t().into()]), // from both pf1<4> and <5> + &mangle_inner_func(&pf2_name, "get_usz"), + &pf2_name, "mainish", "pf1" ] @@ -523,10 +540,13 @@ mod test { .sorted() .collect_vec() ); - #[allow(clippy::unnecessary_to_owned)] // it is necessary - let (n, fd) = *funcs.get(&"mainish".to_string()).unwrap(); - assert_eq!(n, outer_func); - assert_eq!(fd.func_name(), "mainish"); // just a sanity check on list_funcs + for (n, fd) in funcs.into_values() { + if n == mono_hugr.entrypoint() { + assert_eq!(fd.func_name(), "mainish"); + } else { + assert_ne!(fd.func_name(), "mainish"); + } + } } fn 
list_funcs(h: &Hugr) -> HashMap<&String, (Node, &FuncDefn)> { @@ -539,6 +559,50 @@ mod test { .collect::>() } + #[test] + fn test_no_flatten_out_of_mono_func() -> Result<(), Box> { + let ity = || INT_TYPES[4].clone(); + let sig = Signature::new_endo(vec![usize_t(), ity()]); + let mut dfg = DFGBuilder::new(sig.clone()).unwrap(); + let mut mono = dfg.define_function("id2", sig).unwrap(); + let pf = mono + .define_function( + "id", + PolyFuncType::new( + [TypeBound::Any.into()], + Signature::new_endo(Type::new_var_use(0, TypeBound::Any)), + ), + ) + .unwrap(); + let outs = pf.input_wires(); + let pf = pf.finish_with_outputs(outs).unwrap(); + let [a, b] = mono.input_wires_arr(); + let [a] = mono + .call(pf.handle(), &[usize_t().into()], [a]) + .unwrap() + .outputs_arr(); + let [b] = mono + .call(pf.handle(), &[ity().into()], [b]) + .unwrap() + .outputs_arr(); + let mono = mono.finish_with_outputs([a, b]).unwrap(); + let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap(); + let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); + monomorphize(&mut hugr)?; + let mono_hugr = hugr; + + let mut funcs = list_funcs(&mono_hugr); + #[allow(clippy::unnecessary_to_owned)] // It is necessary + let (m, _) = funcs.remove(&"id2".to_string()).unwrap(); + assert_eq!(m, mono.handle().node()); + assert_eq!(mono_hugr.get_parent(m), Some(mono_hugr.entrypoint())); + for t in [usize_t(), ity()] { + let (n, _) = funcs.remove(&mangle_name("id", &[t.into()])).unwrap(); + assert_eq!(mono_hugr.get_parent(n), Some(m)); // Not lifted to top + } + Ok(()) + } + #[test] fn load_function() { let mut hugr = { @@ -548,8 +612,8 @@ mod test { .define_function( "foo", PolyFuncType::new( - [TypeBound::Linear.into()], - Signature::new_endo(Type::new_var_use(0, TypeBound::Linear)), + [TypeBound::Any.into()], + Signature::new_endo(Type::new_var_use(0, TypeBound::Any)), ), ) .unwrap(); @@ -593,10 +657,9 @@ mod test { #[case::type_int(vec![INT_TYPES[2].clone().into()], "$foo$$t(int(2))")] #[case::string(vec!["arg".into()], "$foo$$s(arg)")] #[case::dollar_string(vec!["$arg".into()], "$foo$$s(\\$arg)")] - #[case::sequence(vec![vec![0.into(), Type::UNIT.into()].into()], "$foo$$list($n(0)$t(Unit))")] - #[case::sequence(vec![TypeArg::Tuple(vec![0.into(),Type::UNIT.into()])], "$foo$$tuple($n(0)$t(Unit))")] + #[case::sequence(vec![vec![0.into(), Type::UNIT.into()].into()], "$foo$$seq($n(0)$t(Unit))")] #[should_panic] - #[case::typeargvariable(vec![TypeArg::new_var_use(1, TypeParam::StringType)], + #[case::typeargvariable(vec![TypeArg::new_var_use(1, TypeParam::String)], "$foo$$v(1)")] #[case::multiple(vec![0.into(), "arg".into()], "$foo$$n(0)$s(arg)")] fn test_mangle_name(#[case] args: Vec, #[case] expected: String) { diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index df276a1ff9..75bbea399e 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -1,5 +1,6 @@ //! This module provides functions for finding non-local edges //! in a Hugr and converting them to local edges. +#![warn(missing_docs)] use itertools::Itertools as _; use hugr_core::{ diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 0b5cca8f6a..ac19094c19 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -1,4 +1,5 @@ #![allow(clippy::type_complexity)] +#![warn(missing_docs)] //! Replace types with other types across the Hugr. See [`ReplaceTypes`] and [Linearizer]. //! 
use std::borrow::Cow; @@ -107,9 +108,9 @@ impl NodeTemplate { } } - fn replace(self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { + fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { assert_eq!(hugr.children(n).count(), 0); - let new_optype = match self { + let new_optype = match self.clone() { NodeTemplate::SingleOp(op_type) => op_type, NodeTemplate::CompoundOp(new_h) => { let new_entrypoint = hugr.insert_hugr(n, *new_h).inserted_entrypoint; @@ -171,23 +172,6 @@ fn call>( Ok(Call::try_new(func_sig, type_args)?) } -/// Options for how the replacement for an op is processed. May be specified by -/// [ReplaceTypes::replace_op_with] and [ReplaceTypes::replace_parametrized_op_with]. -/// Otherwise (the default), replacements are inserted as is (without further processing). -#[derive(Clone, Default, PartialEq, Eq)] // More derives might inhibit future extension -pub struct ReplacementOptions { - linearize: bool, -} - -impl ReplacementOptions { - /// Specifies that all operations within the replacement should have their - /// output ports linearized. - pub fn with_linearization(mut self, lin: bool) -> Self { - self.linearize = lin; - self - } -} - /// A configuration of what types, ops, and constants should be replaced with what. /// May be applied to a Hugr via [`Self::run`]. /// @@ -220,14 +204,8 @@ pub struct ReplaceTypes { type_map: HashMap, param_types: HashMap Option>>, linearize: DelegatingLinearizer, - op_map: HashMap, - param_ops: HashMap< - ParametricOp, - ( - Arc Option>, - ReplacementOptions, - ), - >, + op_map: HashMap, + param_ops: HashMap Option>>, consts: HashMap< CustomType, Arc Result>, @@ -281,7 +259,7 @@ pub enum ReplaceTypesError { #[error(transparent)] LinearizeError(#[from] LinearizeError), #[error("Replacement op for {0} could not be added because {1}")] - AddTemplateError(Node, Box), + AddTemplateError(Node, BuildError), } impl ReplaceTypes { @@ -360,36 +338,13 @@ impl ReplaceTypes { } /// Configures this instance to change occurrences of `src` to `dest`. - /// Equivalent to [Self::replace_op_with] with default [ReplacementOptions]. - pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) { - self.replace_op_with(src, dest, ReplacementOptions::default()) - } - - /// Configures this instance to change occurrences of `src` to `dest`. - /// /// Note that if `src` is an instance of a *parametrized* [`OpDef`], this takes /// precedence over [`Self::replace_parametrized_op`] where the `src`s overlap. Thus, /// this should only be used on already-*[monomorphize](super::monomorphize())d* /// Hugrs, as substitution (parametric polymorphism) happening later will not respect /// this replacement. - pub fn replace_op_with( - &mut self, - src: &ExtensionOp, - dest: NodeTemplate, - opts: ReplacementOptions, - ) { - self.op_map.insert(OpHashWrapper::from(src), (dest, opts)); - } - - /// Configures this instance to change occurrences of a parametrized op `src` - /// via a callback that builds the replacement type given the [`TypeArg`]s. - /// Equivalent to [Self::replace_parametrized_op_with] with default [ReplacementOptions]. 
- pub fn replace_parametrized_op( - &mut self, - src: &OpDef, - dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, - ) { - self.replace_parametrized_op_with(src, dest_fn, ReplacementOptions::default()) + pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) { + self.op_map.insert(OpHashWrapper::from(src), dest); } /// Configures this instance to change occurrences of a parametrized op `src` @@ -398,13 +353,12 @@ impl ReplaceTypes { /// fit the bounds of the original op). /// /// If the Callback returns None, the new typeargs will be applied to the original op. - pub fn replace_parametrized_op_with( + pub fn replace_parametrized_op( &mut self, src: &OpDef, dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, - opts: ReplacementOptions, ) { - self.param_ops.insert(src.into(), (Arc::new(dest_fn), opts)); + self.param_ops.insert(src.into(), Arc::new(dest_fn)); } /// Configures this instance to change [Const]s of type `src_ty`, using @@ -494,40 +448,34 @@ impl ReplaceTypes { | rest.transform(self)?), OpType::Const(Const { value, .. }) => self.change_value(value), - OpType::ExtensionOp(ext_op) => Ok({ - let def = ext_op.def_arc(); - let mut changed = false; - let replacement = match self.op_map.get(&OpHashWrapper::from(&*ext_op)) { - r @ Some(_) => r.cloned(), - None => { - let mut args = ext_op.args().to_vec(); - changed = args.transform(self)?; - let r2 = self - .param_ops - .get(&def.as_ref().into()) - .and_then(|(rep_fn, opts)| rep_fn(&args).map(|nt| (nt, opts.clone()))); - if r2.is_none() && changed { - *ext_op = ExtensionOp::new(def.clone(), args)?; - } - r2 - } - }; - if let Some((replacement, opts)) = replacement { + OpType::ExtensionOp(ext_op) => Ok( + // Copy/discard insertion done by caller + if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) { replacement .replace(hugr, n) - .map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?; - if opts.linearize { - for d in hugr.descendants(n).collect::>() { - if d != n { - self.linearize_outputs(hugr, d)?; - } - } - } + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { - changed - } - }), + let def = ext_op.def_arc(); + let mut args = ext_op.args().to_vec(); + let ch = args.transform(self)?; + if let Some(replacement) = self + .param_ops + .get(&def.as_ref().into()) + .and_then(|rep_fn| rep_fn(&args)) + { + replacement + .replace(hugr, n) + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; + true + } else { + if ch { + *ext_op = ExtensionOp::new(def.clone(), args)?; + } + ch + } + }, + ), OpType::OpaqueOp(_) => panic!("OpaqueOp should not be in a Hugr"), @@ -571,27 +519,6 @@ impl ReplaceTypes { Value::Function { hugr } => self.run(&mut **hugr), } } - - fn linearize_outputs>( - &self, - hugr: &mut H, - n: H::Node, - ) -> Result<(), LinearizeError> { - if let Some(new_sig) = hugr.get_optype(n).dataflow_signature() { - let new_sig = new_sig.into_owned(); - for outp in new_sig.output_ports() { - if !new_sig.out_port_type(outp).unwrap().copyable() { - let targets = hugr.linked_inputs(n, outp).collect::>(); - if targets.len() != 1 { - hugr.disconnect(n, outp); - let src = Wire::new(n, outp); - self.linearize.insert_copy_discard(hugr, src, &targets)?; - } - } - } - } - Ok(()) - } } impl> ComposablePass for ReplaceTypes { @@ -602,8 +529,21 @@ impl> ComposablePass for ReplaceTypes { let mut changed = false; for n in hugr.entry_descendants().collect::>() { changed |= self.change_node(hugr, n)?; - if n != hugr.entrypoint() && changed { - self.linearize_outputs(hugr, n)?; + let 
new_dfsig = hugr.get_optype(n).dataflow_signature(); + if let Some(new_sig) = new_dfsig + .filter(|_| changed && n != hugr.entrypoint()) + .map(Cow::into_owned) + { + for outp in new_sig.output_ports() { + if !new_sig.out_port_type(outp).unwrap().copyable() { + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() != 1 { + hugr.disconnect(n, outp); + let src = Wire::new(n, outp); + self.linearize.insert_copy_discard(hugr, src, &targets)?; + } + } + } } } Ok(changed) @@ -701,7 +641,7 @@ mod test { } fn just_elem_type(args: &[TypeArg]) -> &Type { - let [TypeArg::Runtime(ty)] = args else { + let [TypeArg::Type { ty }] = args else { panic!("Expected just elem type") }; ty @@ -715,7 +655,7 @@ mod test { let pv_of_var = ext .add_type( PACKED_VEC.into(), - vec![TypeBound::Linear.into()], + vec![TypeBound::Any.into()], String::new(), TypeDefBound::from_params(vec![0]), w, @@ -730,7 +670,7 @@ mod test { vec![TypeBound::Copyable.into()], Signature::new( vec![pv_of_var.into(), i64_t()], - Type::new_var_use(0, TypeBound::Linear), + Type::new_var_use(0, TypeBound::Any), ), ), w, @@ -808,9 +748,9 @@ mod test { let c_int = Type::from(coln.instantiate([i64_t().into()]).unwrap()); let c_bool = Type::from(coln.instantiate([bool_t().into()]).unwrap()); let mut mb = ModuleBuilder::new(); - let sig = Signature::new_endo(Type::new_var_use(0, TypeBound::Linear)); + let sig = Signature::new_endo(Type::new_var_use(0, TypeBound::Any)); let fb = mb - .define_function("id", PolyFuncType::new([TypeBound::Linear.into()], sig)) + .define_function("id", PolyFuncType::new([TypeBound::Any.into()], sig)) .unwrap(); let inps = fb.input_wires(); let id = fb.finish_with_outputs(inps).unwrap(); @@ -1027,8 +967,8 @@ mod test { IdentList::new_unchecked("NoBoundsCheck"), Version::new(0, 0, 0), |e, w| { - let params = vec![TypeBound::Linear.into()]; - let tv = Type::new_var_use(0, TypeBound::Linear); + let params = vec![TypeBound::Any.into()]; + let tv = Type::new_var_use(0, TypeBound::Any); let list_of_var = list_type(tv.clone()); e.add_op( READ.into(), diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index 7c0fe5f550..25abb846bc 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -106,7 +106,7 @@ pub fn linearize_generic_array( ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = args else { + let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = args else { panic!("Illegal TypeArgs to array: {args:?}") }; if num_outports == 0 { @@ -116,9 +116,7 @@ pub fn linearize_generic_array( let [to_discard] = dfb.input_wires_arr(); lin.copy_discard_op(ty, 0)? .add(&mut dfb, [to_discard]) - .map_err(|e| { - LinearizeError::NestedTemplateError(Box::new(ty.clone()), Box::new(e)) - })?; + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))?; let ret = dfb.add_load_value(Value::unary_unit_sum()); dfb.finish_hugr_with_outputs([ret]).unwrap() }; @@ -191,7 +189,7 @@ pub fn linearize_generic_array( let mut copies = lin .copy_discard_op(ty, num_outports)? .add(&mut dfb, [elem]) - .map_err(|e| LinearizeError::NestedTemplateError(Box::new(ty.clone()), Box::new(e)))? + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))? 
.outputs(); let copy0 = copies.next().unwrap(); // We'll return this directly @@ -309,7 +307,7 @@ pub fn copy_discard_array( ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = args else { + let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = args else { panic!("Illegal TypeArgs to array: {args:?}") }; if ty.copyable() { diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 4227c5d817..bc12e730bd 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -5,19 +5,17 @@ use hugr_core::builder::{ HugrBuilder, inout_sig, }; use hugr_core::extension::{SignatureError, TypeDef}; -use hugr_core::std_extensions::collections::array::array_type_def; use hugr_core::std_extensions::collections::value_array::value_array_type_def; use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeEnum, TypeRow}; use hugr_core::{HugrView, IncomingPort, Node, Wire, hugr::hugrmut::HugrMut, ops::Tag}; use itertools::Itertools; -use super::handlers::{copy_discard_array, linearize_value_array}; -use super::{NodeTemplate, ParametricType}; +use super::{NodeTemplate, ParametricType, handlers::linearize_value_array}; /// Trait for things that know how to wire up linear outports to other than one /// target. Used to restore Hugr validity when a [`ReplaceTypes`](super::ReplaceTypes) /// results in types of such outports changing from [Copyable] to linear (i.e. -/// [`hugr_core::types::TypeBound::Linear`]). +/// [`hugr_core::types::TypeBound::Any`]). /// /// Note that this is not really effective before [monomorphization]: if a /// function polymorphic over a [Copyable] becomes called with a @@ -54,10 +52,12 @@ pub trait Linearizer { src: Wire, targets: &[(Node, IncomingPort)], ) -> Result<(), LinearizeError> { + let sig = hugr.signature(src.node()).unwrap(); + let typ = sig.port_type(src.source()).unwrap(); let (tgt_node, tgt_inport) = if targets.len() == 1 { *targets.first().unwrap() } else { - // Fail fast if the edges are nonlocal. + // Fail fast if the edges are nonlocal. (TODO transform to local edges!) let src_parent = hugr .get_parent(src.node()) .expect("Root node cannot have out edges"); @@ -74,12 +74,11 @@ pub trait Linearizer { tgt_parent, }); } - let sig = hugr.signature(src.node()).unwrap(); - let typ = sig.port_type(src.source()).unwrap().clone(); + let typ = typ.clone(); // Stop borrowing hugr in order to add_hugr to it let copy_discard_op = self .copy_discard_op(&typ, targets.len())? 
.add_hugr(hugr, src_parent) - .map_err(|e| LinearizeError::NestedTemplateError(Box::new(typ), Box::new(e)))?; + .map_err(|e| LinearizeError::NestedTemplateError(typ, e))?; for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); } @@ -126,7 +125,6 @@ impl Default for DelegatingLinearizer { fn default() -> Self { let mut res = Self::new_empty(); res.register_callback(value_array_type_def(), linearize_value_array); - res.register_callback(array_type_def(), copy_discard_array); res } } @@ -142,16 +140,15 @@ pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); #[non_exhaustive] pub enum LinearizeError { #[error("Need copy/discard op for {_0}")] - NeedCopyDiscard(Box), + NeedCopyDiscard(Type), #[error("Copy/discard op for {typ} with {num_outports} outputs had wrong signature {sig:?}")] WrongSignature { - typ: Box, + typ: Type, num_outports: usize, - sig: Option>, + sig: Option, }, #[error( - "Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent}). - Try using LocalizeEdges pass first." + "Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent})" )] NoLinearNonLocalEdges { src: Node, @@ -166,14 +163,14 @@ pub enum LinearizeError { /// [Variable](TypeEnum::Variable)s, [Row variables](TypeEnum::RowVar), /// or [Alias](TypeEnum::Alias)es. #[error("Cannot linearize type {_0}")] - UnsupportedType(Box), + UnsupportedType(Type), /// Neither does linearization make sense for copyable types #[error("Type {_0} is copyable")] - CopyableType(Box), + CopyableType(Type), /// Error may be returned by a callback for e.g. a container because it could /// not generate a [`NodeTemplate`] because of a problem with an element #[error("Could not generate NodeTemplate for contained type {0} because {1}")] - NestedTemplateError(Box, Box), + NestedTemplateError(Type, BuildError), } impl DelegatingLinearizer { @@ -209,7 +206,7 @@ impl DelegatingLinearizer { ) -> Result<(), LinearizeError> { let typ = Type::new_extension(cty.clone()); if typ.copyable() { - return Err(LinearizeError::CopyableType(Box::new(typ))); + return Err(LinearizeError::CopyableType(typ)); } check_sig(©, &typ, 2)?; check_sig(&discard, &typ, 0)?; @@ -250,9 +247,9 @@ impl DelegatingLinearizer { fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), LinearizeError> { tmpl.check_signature(&typ.clone().into(), &vec![typ.clone(); num_outports].into()) .map_err(|sig| LinearizeError::WrongSignature { - typ: Box::new(typ.clone()), + typ: typ.clone(), num_outports, - sig: sig.map(Box::new), + sig, }) } @@ -263,7 +260,7 @@ impl Linearizer for DelegatingLinearizer { num_outports: usize, ) -> Result { if typ.copyable() { - return Err(LinearizeError::CopyableType(Box::new(typ.clone()))); + return Err(LinearizeError::CopyableType(typ.clone())); } assert!(num_outports != 1); @@ -341,14 +338,14 @@ impl Linearizer for DelegatingLinearizer { let copy_discard_fn = self .copy_discard_parametric .get(&cty.into()) - .ok_or_else(|| LinearizeError::NeedCopyDiscard(Box::new(typ.clone())))?; + .ok_or_else(|| LinearizeError::NeedCopyDiscard(typ.clone()))?; let tmpl = copy_discard_fn(cty.args(), num_outports, &CallbackHandler(self))?; check_sig(&tmpl, typ, num_outports)?; Ok(tmpl) } } TypeEnum::Function(_) => panic!("Ruled out above as copyable"), - _ => Err(LinearizeError::UnsupportedType(Box::new(typ.clone()))), + _ => 
Err(LinearizeError::UnsupportedType(typ.clone())), } } } @@ -374,7 +371,7 @@ mod test { HugrBuilder, inout_sig, }; - use hugr_core::extension::prelude::{option_type, qb_t, usize_t}; + use hugr_core::extension::prelude::{option_type, usize_t}; use hugr_core::extension::simple_op::MakeExtensionOp; use hugr_core::extension::{ CustomSignatureFunc, OpDef, SignatureError, SignatureFunc, TypeDefBound, Version, @@ -388,16 +385,14 @@ mod test { }; use hugr_core::types::type_param::TypeParam; use hugr_core::types::{ - FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeEnum, TypeRow, + FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeEnum, TypeRow, }; use hugr_core::{Extension, Hugr, HugrView, Node, hugr::IdentList, type_row}; use itertools::Itertools; use rstest::rstest; use crate::replace_types::handlers::linearize_value_array; - use crate::replace_types::{ - LinearizeError, NodeTemplate, ReplaceTypesError, ReplacementOptions, - }; + use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; use crate::{ComposablePass, ReplaceTypes}; const LIN_T: &str = "Lin"; @@ -409,7 +404,7 @@ mod test { arg_values: &[TypeArg], _def: &'o OpDef, ) -> Result { - let [TypeArg::BoundedNat(n)] = arg_values else { + let [TypeArg::BoundedNat { n }] = arg_values else { panic!() }; let outs = vec![self.0.clone(); *n as usize]; @@ -417,7 +412,7 @@ mod test { } fn static_params(&self) -> &[TypeParam] { - const JUST_NAT: &[TypeParam] = &[TypeParam::max_nat_type()]; + const JUST_NAT: &[TypeParam] = &[TypeParam::max_nat()]; JUST_NAT } } @@ -653,9 +648,9 @@ mod test { assert_eq!( bad_copy, Err(LinearizeError::WrongSignature { - typ: Box::new(lin_t.clone()), + typ: lin_t.clone(), num_outports: 2, - sig: sig3.clone().map(Box::new) + sig: sig3.clone() }) ); @@ -668,9 +663,9 @@ mod test { assert_eq!( bad_discard, Err(LinearizeError::WrongSignature { - typ: Box::new(lin_t.clone()), + typ: lin_t.clone(), num_outports: 0, - sig: sig3.clone().map(Box::new) + sig: sig3.clone() }) ); @@ -690,9 +685,9 @@ mod test { replacer.run(&mut h), Err(ReplaceTypesError::LinearizeError( LinearizeError::WrongSignature { - typ: Box::new(lin_t.clone()), + typ: lin_t.clone(), num_outports: 2, - sig: sig3.clone().map(Box::new) + sig: sig3.clone() } )) ); @@ -805,8 +800,7 @@ mod test { // A simple Hugr that discards a usize_t, with a "drop" function let mut dfb = DFGBuilder::new(inout_sig(usize_t(), type_row![])).unwrap(); let discard_fn = { - let mut mb = dfb.module_root_builder(); - let mut fb = mb + let mut fb = dfb .define_function("drop", Signature::new(lin_t.clone(), type_row![])) .unwrap(); let ins = fb.input_wires(); @@ -821,11 +815,12 @@ mod test { let backup = dfb.finish_hugr().unwrap(); let mut lower_discard_to_call = ReplaceTypes::default(); + // The `copy_fn` here will break completely, but we don't use it lower_discard_to_call .linearizer() .register_simple( lin_ct.clone(), - NodeTemplate::Call(backup.entrypoint(), vec![]), // Arbitrary, unused + NodeTemplate::Call(backup.entrypoint(), vec![]), NodeTemplate::Call(discard_fn, vec![]), ) .unwrap(); @@ -839,85 +834,20 @@ mod test { assert_eq!(h.output_neighbours(discard_fn).count(), 1); } - // But if we lower usize_t to array, the call will fail. 
+ // But if we lower usize_t to array, the call will fail lower_discard_to_call.replace_type( usize_t().as_extension().unwrap().clone(), value_array_type(4, lin_ct.into()), ); let r = lower_discard_to_call.run(&mut backup.clone()); - // Note the error (or success) can be quite fragile, according to what the `discard_fn` - // Node points at in the (hidden here) inner Hugr built by the array linearization helper. - if let Err(ReplaceTypesError::LinearizeError(LinearizeError::NestedTemplateError( - nested_t, - build_err, - ))) = r - { - assert_eq!(*nested_t, lin_t); - assert!(matches!( - *build_err, BuildError::NodeNotFound { node } if node == discard_fn - )); - } else { - panic!("Expected error"); - } - } - - #[test] - fn use_in_op_callback() { - let (e, mut lowerer) = ext_lowerer(); - let drop_ext = Extension::new_arc( - IdentList::new_unchecked("DropExt"), - Version::new(0, 0, 0), - |e, w| { - e.add_op( - "drop".into(), - String::new(), - PolyFuncTypeRV::new( - [TypeBound::Linear.into()], // It won't *lower* for any type tho! - Signature::new(Type::new_var_use(0, TypeBound::Linear), vec![]), - ), - w, + assert!(matches!( + r, + Err(ReplaceTypesError::LinearizeError( + LinearizeError::NestedTemplateError( + nested_t, + BuildError::NodeNotFound { node } ) - .unwrap(); - }, - ); - let drop_op = drop_ext.get_op("drop").unwrap(); - lowerer.replace_parametrized_op_with( - drop_op, - |args| { - let [TypeArg::Runtime(ty)] = args else { - panic!("Expected just one type") - }; - // The Hugr here is invalid, so we have to pull it out manually - let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); - let h = std::mem::take(dfb.hugr_mut()); - Some(NodeTemplate::CompoundOp(Box::new(h))) - }, - ReplacementOptions::default().with_linearization(true), - ); - - let build_hugr = |ty: Type| { - let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); - let [inp] = dfb.input_wires_arr(); - let drop_op = drop_ext - .instantiate_extension_op("drop", [ty.into()]) - .unwrap(); - dfb.add_dataflow_op(drop_op, [inp]).unwrap(); - dfb.finish_hugr().unwrap() - }; - // We can drop a tuple of 2* lin_t - let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); - let mut h = build_hugr(Type::new_tuple(vec![lin_t; 2])); - lowerer.run(&mut h).unwrap(); - h.validate().unwrap(); - let mut exts = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()); - assert_eq!(exts.clone().count(), 2); - assert!(exts.all(|eo| eo.qualified_id() == "TestExt.discard")); - - // We cannot drop a qubit - let mut h = build_hugr(qb_t()); - assert_eq!( - lowerer.run(&mut h).unwrap_err(), - ReplaceTypesError::LinearizeError(LinearizeError::NeedCopyDiscard(Box::new(qb_t()))) - ); + )) if nested_t == lin_t && node == discard_fn + )); } } diff --git a/hugr-persistent/CHANGELOG.md b/hugr-persistent/CHANGELOG.md deleted file mode 100644 index a69263ca63..0000000000 --- a/hugr-persistent/CHANGELOG.md +++ /dev/null @@ -1,11 +0,0 @@ -# Changelog - - -## [0.2.0](https://github.com/CQCL/hugr/compare/hugr-persistent-v0.1.0...hugr-persistent-v0.2.0) - 2025-07-24 - -### New Features - -- [**breaking**] Update portgraph dependency to 0.15 ([#2455](https://github.com/CQCL/hugr/pull/2455)) -## 0.1.0 (2025-07-10) - -Initial release. 
diff --git a/hugr-persistent/Cargo.toml b/hugr-persistent/Cargo.toml deleted file mode 100644 index 75aa61f240..0000000000 --- a/hugr-persistent/Cargo.toml +++ /dev/null @@ -1,43 +0,0 @@ -[package] -name = "hugr-persistent" -version = "0.2.1" -edition = { workspace = true } -rust-version = { workspace = true } -license = { workspace = true } -readme = "README.md" -documentation = "https://docs.rs/hugr-persistent/" -homepage = { workspace = true } -repository = { workspace = true } -description = "Persistent IR structure for Quantinuum's HUGR" -keywords = ["Quantum", "Quantinuum"] -categories = ["compilers"] - -[[test]] -name = "persistent_walker_example" - -[dependencies] -hugr-core = { path = "../hugr-core", version = "0.22.1" } - -derive_more = { workspace = true, features = ["display", "error", "from"] } -delegate.workspace = true -itertools.workspace = true -petgraph.workspace = true -portgraph.workspace = true -relrc = { workspace = true, features = ["petgraph", "serde"] } -serde.workspace = true -serde_json.workspace = true -thiserror.workspace = true -wyhash.workspace = true - -[lints] -workspace = true - -[lib] -bench = false - -[dev-dependencies] -rstest.workspace = true -lazy_static.workspace = true -semver.workspace = true -serde_with.workspace = true -insta.workspace = true diff --git a/hugr-persistent/README.md b/hugr-persistent/README.md deleted file mode 100644 index 95386fa664..0000000000 --- a/hugr-persistent/README.md +++ /dev/null @@ -1,59 +0,0 @@ -![](/hugr/assets/hugr_logo.svg) - -# hugr-persistent - -[![build_status][]](https://github.com/CQCL/hugr/actions) -[![crates][]](https://crates.io/crates/hugr-persistent) -[![msrv][]](https://github.com/CQCL/hugr) -[![codecov][]](https://codecov.io/gh/CQCL/hugr) - -The Hierarchical Unified Graph Representation (HUGR, pronounced _hugger_) is the -common representation of quantum circuits and operations in the Quantinuum -ecosystem. - -It provides a high-fidelity representation of operations, that facilitates -compilation and encodes runnable programs. - -The HUGR specification is [here](https://github.com/CQCL/hugr/blob/main/specification/hugr.md). - -## Overview - -This crate provides a persistent data structure for HUGR mutations; mutations to -the data are stored persistently as a set of `Commit`s along with the -dependencies between them. - -As a result of persistency, the entire mutation history of a HUGR can be -traversed and references to previous versions of the data remain valid even -as the HUGR graph is "mutated" by applying patches: the patches are in -effect added to the history as new commits. - -## Usage - -Add the dependency to your project: - -```bash -cargo add hugr-persistent -``` - -Please read the [API documentation here][]. - -## Recent Changes - -See [CHANGELOG][] for a list of changes. The minimum supported rust -version will only change on major releases. - -## Development - -See [DEVELOPMENT.md](https://github.com/CQCL/hugr/blob/main/DEVELOPMENT.md) for instructions on setting up the development environment. - -## License - -This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). 
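As a rough illustration of the workflow this README describes, here is a usage sketch. It is hedged: it relies only on the method names given in the crate's module documentation removed later in this patch (`PersistentHugr::with_base`, `add_replacement`, `to_hugr`) and on the `PersistentReplacement` alias; exact signatures, generics, and return types are assumptions not checked against the published crate, and return values are ignored for brevity.

```rust
// Rough sketch only; see the hedging note above.
use hugr_core::Hugr;
use hugr_persistent::{PersistentHugr, PersistentReplacement};

fn apply_history(base: Hugr, patches: Vec<PersistentReplacement>) -> Hugr {
    // Start a commit history from the base HUGR.
    let mut phugr = PersistentHugr::with_base(base);
    for patch in patches {
        // Each replacement is recorded as a new commit in the history rather
        // than destructively rewriting the graph.
        phugr.add_replacement(patch);
    }
    // Materialize the HUGR obtained by applying every commit in the history.
    phugr.to_hugr()
}
```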
- - [API documentation here]: https://docs.rs/hugr-persistent/ - [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/crates/msrv/hugr-persistent - [crates]: https://img.shields.io/crates/v/hugr-persistent - [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov - [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE - [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-persistent/CHANGELOG.md diff --git a/hugr-persistent/src/lib.rs b/hugr-persistent/src/lib.rs deleted file mode 100644 index 694e83e23e..0000000000 --- a/hugr-persistent/src/lib.rs +++ /dev/null @@ -1,98 +0,0 @@ -#![doc(hidden)] // TODO: remove when stable - -//! Persistent data structure for HUGR mutations. -//! -//! This crate provides a persistent data structure [`PersistentHugr`] that -//! implements [`HugrView`](hugr_core::HugrView); mutations to the data are -//! stored persistently as a set of [`Commit`]s along with the dependencies -//! between the commits. -//! -//! As a result of persistency, the entire mutation history of a HUGR can be -//! traversed and references to previous versions of the data remain valid even -//! as the HUGR graph is "mutated" by applying patches: the patches are in -//! effect added to the history as new commits. -//! -//! The data structure underlying [`PersistentHugr`], which stores the history -//! of all commits, is [`CommitStateSpace`]. Multiple [`PersistentHugr`] can be -//! stored within a single [`CommitStateSpace`], which allows for the efficient -//! exploration of the space of all possible graph rewrites. -//! -//! ## Overlapping commits -//! -//! In general, [`CommitStateSpace`] may contain overlapping commits. Such -//! mutations are mutually exclusive as they modify the same nodes. It is -//! therefore not possible to apply all commits in a [`CommitStateSpace`] -//! simultaneously. A [`PersistentHugr`] on the other hand always corresponds to -//! a subgraph of a [`CommitStateSpace`] that is guaranteed to contain only -//! non-overlapping, compatible commits. By applying all commits in a -//! [`PersistentHugr`], we can materialize a [`Hugr`](hugr_core::Hugr). -//! Traversing the materialized HUGR is equivalent to using the -//! [`HugrView`](hugr_core::HugrView) implementation of the corresponding -//! [`PersistentHugr`]. -//! -//! ## Summary of data types -//! -//! - [`Commit`] A modification to a [`Hugr`](hugr_core::Hugr) (currently a -//! [`SimpleReplacement`](hugr_core::SimpleReplacement)) that forms the atomic -//! unit of change for a [`PersistentHugr`] (like a commit in git). This is a -//! reference-counted value that is cheap to clone and will be freed when the -//! last reference is dropped. -//! - [`PersistentHugr`] A data structure that implements -//! [`HugrView`][hugr_core::HugrView] and can be used as a drop-in replacement -//! for a [`Hugr`][hugr_core::Hugr] for read-only access and mutations through -//! the [`PatchVerification`](hugr_core::hugr::patch::PatchVerification) and -//! [`Patch`](hugr_core::hugr::Patch) traits. Mutations are stored as a -//! history of commits. Unlike [`CommitStateSpace`], it maintains the -//! invariant that all contained commits are compatible with eachother. -//! - [`CommitStateSpace`] Stores commits, recording the dependencies between -//! them. Includes the base HUGR and any number of possibly incompatible -//! (overlapping) commits. 
Unlike a [`PersistentHugr`], a state space can -//! contain mutually exclusive commits. -//! -//! ## Usage -//! -//! A [`PersistentHugr`] can be created from a base HUGR using -//! [`PersistentHugr::with_base`]. Replacements can then be applied to it -//! using [`PersistentHugr::add_replacement`]. Alternatively, if you already -//! have a populated state space, use [`PersistentHugr::try_new`] to create a -//! new HUGR with those commits. -//! -//! Add a sequence of commits to a state space by merging a [`PersistentHugr`] -//! into it using [`CommitStateSpace::extend`] or directly using -//! [`CommitStateSpace::try_add_commit`]. -//! -//! To obtain a [`PersistentHugr`] from your state space, use -//! [`CommitStateSpace::try_extract_hugr`]. A [`PersistentHugr`] can always be -//! materialized into a [`Hugr`][hugr_core::Hugr] type using -//! [`PersistentHugr::to_hugr`]. - -mod parents_view; -mod persistent_hugr; -mod resolver; -pub mod state_space; -pub mod subgraph; -mod trait_impls; -pub mod walker; -mod wire; - -pub use persistent_hugr::{Commit, PersistentHugr}; -pub use resolver::{PointerEqResolver, Resolver, SerdeHashResolver}; -pub use state_space::{CommitId, CommitStateSpace, InvalidCommit, PatchNode}; -pub use subgraph::PinnedSubgraph; -pub use walker::Walker; -pub use wire::PersistentWire; - -/// A replacement operation that can be applied to a [`PersistentHugr`]. -pub type PersistentReplacement = hugr_core::SimpleReplacement; - -use persistent_hugr::find_conflicting_node; -use state_space::CommitData; - -pub mod serial { - //! Serialized formats for commits, state spaces and persistent HUGRs. - pub use super::persistent_hugr::serial::*; - pub use super::state_space::serial::*; -} - -#[cfg(test)] -mod tests; diff --git a/hugr-persistent/src/persistent_hugr/serial.rs b/hugr-persistent/src/persistent_hugr/serial.rs deleted file mode 100644 index 9a41e4acef..0000000000 --- a/hugr-persistent/src/persistent_hugr/serial.rs +++ /dev/null @@ -1,75 +0,0 @@ -//! Serialized format for [`PersistentHugr`] - -use hugr_core::Hugr; - -use crate::{CommitStateSpace, Resolver, state_space::serial::SerialCommitStateSpace}; - -use super::PersistentHugr; - -/// Serialized format for [`PersistentHugr`] -#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub struct SerialPersistentHugr { - /// The state space of all commits. 
- state_space: SerialCommitStateSpace, -} - -impl PersistentHugr { - /// Create a new [`CommitStateSpace`] from its serialized format - pub fn from_serial>(value: SerialPersistentHugr) -> Self { - let SerialPersistentHugr { state_space } = value; - let state_space = CommitStateSpace::from_serial(state_space); - Self { state_space } - } - - /// Convert a [`CommitStateSpace`] into its serialized format - pub fn into_serial>(self) -> SerialPersistentHugr { - let Self { state_space } = self; - let state_space = state_space.into_serial(); - SerialPersistentHugr { state_space } - } - - /// Create a serialized format from a reference to [`CommitStateSpace`] - pub fn to_serial>(&self) -> SerialPersistentHugr { - let Self { state_space } = self; - let state_space = state_space.to_serial(); - SerialPersistentHugr { state_space } - } -} - -impl, R: Resolver> From> for SerialPersistentHugr { - fn from(value: PersistentHugr) -> Self { - value.into_serial() - } -} - -impl, R: Resolver> From> for PersistentHugr { - fn from(value: SerialPersistentHugr) -> Self { - PersistentHugr::from_serial(value) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - CommitId, SerdeHashResolver, - tests::{WrappedHugr, test_state_space}, - }; - - use rstest::rstest; - - #[rstest] - fn test_serde_persistent_hugr( - test_state_space: ( - CommitStateSpace>, - [CommitId; 4], - ), - ) { - let (state_space, [cm1, cm2, _, cm4]) = test_state_space; - - let per_hugr = state_space.try_extract_hugr([cm1, cm2, cm4]).unwrap(); - let ser_per_hugr = per_hugr.to_serial::(); - - insta::assert_snapshot!(serde_json::to_string_pretty(&ser_per_hugr).unwrap()); - } -} diff --git a/hugr-persistent/src/persistent_hugr/snapshots/hugr_persistent__persistent_hugr__serial__tests__serde_persistent_hugr.snap b/hugr-persistent/src/persistent_hugr/snapshots/hugr_persistent__persistent_hugr__serial__tests__serde_persistent_hugr.snap deleted file mode 100644 index e7f544586a..0000000000 --- a/hugr-persistent/src/persistent_hugr/snapshots/hugr_persistent__persistent_hugr__serial__tests__serde_persistent_hugr.snap +++ /dev/null @@ -1,184 +0,0 @@ ---- -source: hugr-persistent/src/persistent_hugr/serial.rs -expression: "serde_json::to_string_pretty(&ser_per_hugr).unwrap()" ---- -{ - "state_space": { - "graph": { - "nodes": { - "3fd58bd8c5f2494a": { - "value": { - "Base": { - "hugr": 
"HUGRiHJv?@{\"modules\":[{\"version\":\"live\",\"nodes\":[{\"parent\":0,\"op\":\"Module\"},{\"parent\":0,\"op\":\"FuncDefn\",\"name\":\"main\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"visibility\":\"Private\"},{\"parent\":1,\"op\":\"Input\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":1,\"op\":\"Output\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":1,\"op\":\"DFG\",\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},{\"parent\":4,\"op\":\"Input\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":4,\"op\":\"Output\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":4,\"op\":\"Extension\",\"extension\":\"logic\",\"name\":\"Not\",\"args\":[],\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},{\"parent\":4,\"op\":\"Extension\",\"extension\":\"logic\",\"name\":\"Not\",\"args\":[],\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},{\"parent\":4,\"op\":\"Extension\",\"extension\":\"logic\",\"name\":\"And\",\"args\":[],\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}}],\"edges\":[[[2,0],[4,0]],[[2,1],[4,1]],[[4,0],[3,0]],[[5,0],[7,0]],[[5,1],[8,0]],[[7,0],[9,0]],[[8,0],[9,1]],[[9,0],[6,0]]],\"metadata\":[null,null,null,null,null,null,null,null,null,null],\"entrypoint\":4}],\"extensions\":[{\"version\":\"0.1.0\",\"name\":\"arithmetic.conversions\",\"types\":{},\"operations\":{\"bytecast_float64_to_int64\":{\"extension\":\"arithmetic.conversions\",\"name\":\"bytecast_float64_to_int64\",\"description\":\"reinterpret an float64 as an int based on its bytes, with the same endianness\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}]}},\"binary\":false},\"bytecast_int64_to_float64\":{\"extension\":\"arithmetic.conversions\",\"name\":\"bytecast_int64_to_float64\",\"description\":\"reinterpret an int64 as a float64 based on its bytes, with the same endianness\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"convert_s\":{\"extension\":\"arithmetic.conversions\",\"name\":\"convert_s\",\"description\":\"signed int to 
float\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"convert_u\":{\"extension\":\"arithmetic.conversions\",\"name\":\"convert_u\",\"description\":\"unsigned int to float\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"ifrombool\":{\"extension\":\"arithmetic.conversions\",\"name\":\"ifrombool\",\"description\":\"convert from bool into a 1-bit integer (1 is true, 0 is false)\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":0}],\"bound\":\"C\"}]}},\"binary\":false},\"ifromusize\":{\"extension\":\"arithmetic.conversions\",\"name\":\"ifromusize\",\"description\":\"convert a usize to a 64b unsigned integer\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"I\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}]}},\"binary\":false},\"itobool\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itobool\",\"description\":\"convert a 1-bit integer to bool (1 is true, 0 is false)\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":0}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"itostring_s\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itostring_s\",\"description\":\"convert a signed integer to its string representation\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"string\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"itostring_u\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itostring_u\",\"description\":\"convert an unsigned integer to its string representation\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"string\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"itousize\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itousize\",\"description\":\"convert a 64b unsigned integer to its usize 
representation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}],\"output\":[{\"t\":\"I\"}]}},\"binary\":false},\"trunc_s\":{\"extension\":\"arithmetic.conversions\",\"name\":\"trunc_s\",\"description\":\"float to signed int\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false},\"trunc_u\":{\"extension\":\"arithmetic.conversions\",\"name\":\"trunc_u\",\"description\":\"float to unsigned int\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false}}},{\"version\":\"0.1.0\",\"name\":\"arithmetic.float\",\"types\":{},\"operations\":{\"fabs\":{\"extension\":\"arithmetic.float\",\"name\":\"fabs\",\"description\":\"absolute value\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fadd\":{\"extension\":\"arithmetic.float\",\"name\":\"fadd\",\"description\":\"addition\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fceil\":{\"extension\":\"arithmetic.float\",\"name\":\"fceil\",\"description\":\"ceiling\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fdiv\":{\"extension\":\"arithmetic.float\",\"name\":\"fdiv\",\"description\":\"division\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"feq\":{\"extension\":\"arithmetic.float\",\"name\":\"feq\",\"description\":\"equality 
test\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"ffloor\":{\"extension\":\"arithmetic.float\",\"name\":\"ffloor\",\"description\":\"floor\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fge\":{\"extension\":\"arithmetic.float\",\"name\":\"fge\",\"description\":\"\\\"greater than or equal\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fgt\":{\"extension\":\"arithmetic.float\",\"name\":\"fgt\",\"description\":\"\\\"greater than\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fle\":{\"extension\":\"arithmetic.float\",\"name\":\"fle\",\"description\":\"\\\"less than or equal\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"flt\":{\"extension\":\"arithmetic.float\",\"name\":\"flt\",\"description\":\"\\\"less 
than\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fmax\":{\"extension\":\"arithmetic.float\",\"name\":\"fmax\",\"description\":\"maximum\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fmin\":{\"extension\":\"arithmetic.float\",\"name\":\"fmin\",\"description\":\"minimum\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fmul\":{\"extension\":\"arithmetic.float\",\"name\":\"fmul\",\"description\":\"multiplication\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fne\":{\"extension\":\"arithmetic.float\",\"name\":\"fne\",\"description\":\"inequality 
test\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fneg\":{\"extension\":\"arithmetic.float\",\"name\":\"fneg\",\"description\":\"negation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fpow\":{\"extension\":\"arithmetic.float\",\"name\":\"fpow\",\"description\":\"exponentiation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fround\":{\"extension\":\"arithmetic.float\",\"name\":\"fround\",\"description\":\"round\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fsub\":{\"extension\":\"arithmetic.float\",\"name\":\"fsub\",\"description\":\"subtraction\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"ftostring\":{\"extension\":\"arithmetic.float\",\"name\":\"ftostring\",\"description\":\"string representation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"string\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false}}},{\"version\":\"0.1.0\",\"name\":\"arithmetic.float.types\",\"types\":{\"float64\":{\"extension\":\"arithmetic.float.types\",\"name\":\"float64\",\"params\":[],\"description\":\"64-bit IEEE 754-2019 floating-point value\",\"bound\":{\"b\":\"Explicit\",\"bound\":\"C\"}}},\"operations\":{}},{\"version\":\"0.1.0\",\"name\":\"arithmetic.int\",\"types\":{},\"operations\":{\"iabs\":{\"extension\":\"arithmetic.int\",\"name\":\"iabs\",\"description\":\"convert signed to unsigned by taking absolute 
value\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"iadd\":{\"extension\":\"arithmetic.int\",\"name\":\"iadd\",\"description\":\"addition modulo 2^N (signed and unsigned versions are the same op)\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"iand\":{\"extension\":\"arithmetic.int\",\"name\":\"iand\",\"description\":\"bitwise AND\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"idiv_checked_s\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_checked_s\",\"description\":\"as idivmod_checked_s but discarding the second output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false},\"idiv_checked_u\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_checked_u\",\"description\":\"as idivmod_checked_u but discarding the second 
output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false},\"idiv_s\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_s\",\"description\":\"as idivmod_s but discarding the second output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"idiv_u\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_u\",\"description\":\"as idivmod_u but discarding the second output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"idivmod_checked_s\":{\"extension\":\"arithmetic.int\",\"name\":\"idivmod_checked_s\",\"description\":\"given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^N, generates signed q and unsigned r where q*m+r=n, 0<=r {} -impl> Resolver for T {} - -/// A resolver that considers two nodes equivalent if they are the same pointer. -/// -/// Resolvers determine when two patches are equivalent and should be merged -/// in the patch history. -/// -/// This is a trivial resolver (to be expanded on later), that considers two -/// patches equivalent if they point to the same data in memory. 
-#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub struct PointerEqResolver; - -impl EquivalenceResolver for PointerEqResolver { - type MergeMapping = (); - - type DedupKey = *const N; - - fn id(&self) -> String { - "PointerEqResolver".to_string() - } - - fn dedup_key(&self, value: &N, _incoming_edges: &[&E]) -> Self::DedupKey { - value as *const N - } - - fn try_merge_mapping( - &self, - a_value: &N, - _a_incoming_edges: &[&E], - b_value: &N, - _b_incoming_edges: &[&E], - ) -> Result { - if std::ptr::eq(a_value, b_value) { - Ok(()) - } else { - Err(relrc::resolver::NotEquivalent) - } - } - - fn move_edge_source(&self, _mapping: &Self::MergeMapping, edge: &E) -> E { - edge.clone() - } -} - -/// A resolver that considers two nodes equivalent if the hashes of their -/// serialisation is the same. -/// -/// ### Generic type parameter -/// -/// This is parametrised over a serializable type `H`, which must implement -/// [`From`]. This type is used to serialise the commit data before -/// hashing it. -/// -/// Resolvers determine when two patches are equivalent and should be merged -/// in the patch history. -#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub struct SerdeHashResolver(#[serde(skip)] PhantomData); - -impl Default for SerdeHashResolver { - fn default() -> Self { - Self(PhantomData) - } -} - -impl SerdeHashResolver { - fn hash(value: &impl serde::Serialize) -> u64 { - let bytes = serde_json::to_vec(value).unwrap(); - const SEED: u64 = 0; - wyhash(&bytes, SEED) - } -} - -impl> EquivalenceResolver - for SerdeHashResolver -{ - type MergeMapping = (); - - type DedupKey = u64; - - fn id(&self) -> String { - "SerdeHashResolver".to_string() - } - - fn dedup_key(&self, value: &CommitData, _incoming_edges: &[&()]) -> Self::DedupKey { - let ser_value = value.clone().into_serial::(); - Self::hash(&ser_value) - } - - fn try_merge_mapping( - &self, - a_value: &CommitData, - _a_incoming_edges: &[&()], - b_value: &CommitData, - _b_incoming_edges: &[&()], - ) -> Result { - let a_ser_value = a_value.clone().into_serial::(); - let b_ser_value = b_value.clone().into_serial::(); - if Self::hash(&a_ser_value) == Self::hash(&b_ser_value) { - Ok(()) - } else { - Err(relrc::resolver::NotEquivalent) - } - } - - fn move_edge_source(&self, _mapping: &Self::MergeMapping, _edge: &()) {} -} - -#[cfg(test)] -mod tests { - use hugr_core::{builder::endo_sig, ops::FuncDefn}; - - use super::*; - use crate::{CommitData, tests::WrappedHugr}; - - #[test] - fn test_serde_hash_resolver_equality() { - let resolver = SerdeHashResolver::::default(); - - // Create a base CommitData - let base_data = CommitData::Base(Hugr::new()); - - // Clone the data to create an equivalent copy - let cloned_data = base_data.clone(); - - // Check that original and cloned data are considered equivalent - let result = resolver.try_merge_mapping(&base_data, &[], &cloned_data, &[]); - // Verify that the merge succeeds since the data is equivalent - assert!(result.is_ok()); - - // Check that the original and replacement data are considered different - let repl_data = CommitData::Base( - Hugr::new_with_entrypoint(FuncDefn::new("dummy", endo_sig(vec![]))).unwrap(), - ); - let result = resolver.try_merge_mapping(&base_data, &[], &repl_data, &[]); - assert!(result.is_err()); - } -} diff --git a/hugr-persistent/src/state_space/snapshots/hugr_persistent__state_space__serial__tests__serialize_state_space.snap 
b/hugr-persistent/src/state_space/snapshots/hugr_persistent__state_space__serial__tests__serialize_state_space.snap deleted file mode 100644 index b415f9d784..0000000000 --- a/hugr-persistent/src/state_space/snapshots/hugr_persistent__state_space__serial__tests__serialize_state_space.snap +++ /dev/null @@ -1,244 +0,0 @@ ---- -source: hugr-persistent/src/state_space/serial.rs -expression: "serde_json::to_string_pretty(&serialized).unwrap()" ---- -{ - "graph": { - "nodes": { - "3fd58bd8c5f2494a": { - "value": { - "Base": { - "hugr": "HUGRiHJv?@{\"modules\":[{\"version\":\"live\",\"nodes\":[{\"parent\":0,\"op\":\"Module\"},{\"parent\":0,\"op\":\"FuncDefn\",\"name\":\"main\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"visibility\":\"Private\"},{\"parent\":1,\"op\":\"Input\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":1,\"op\":\"Output\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":1,\"op\":\"DFG\",\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},{\"parent\":4,\"op\":\"Input\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":4,\"op\":\"Output\",\"types\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]},{\"parent\":4,\"op\":\"Extension\",\"extension\":\"logic\",\"name\":\"Not\",\"args\":[],\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},{\"parent\":4,\"op\":\"Extension\",\"extension\":\"logic\",\"name\":\"Not\",\"args\":[],\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},{\"parent\":4,\"op\":\"Extension\",\"extension\":\"logic\",\"name\":\"And\",\"args\":[],\"signature\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2},{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}}],\"edges\":[[[2,0],[4,0]],[[2,1],[4,1]],[[4,0],[3,0]],[[5,0],[7,0]],[[5,1],[8,0]],[[7,0],[9,0]],[[8,0],[9,1]],[[9,0],[6,0]]],\"metadata\":[null,null,null,null,null,null,null,null,null,null],\"entrypoint\":4}],\"extensions\":[{\"version\":\"0.1.0\",\"name\":\"arithmetic.conversions\",\"types\":{},\"operations\":{\"bytecast_float64_to_int64\":{\"extension\":\"arithmetic.conversions\",\"name\":\"bytecast_float64_to_int64\",\"description\":\"reinterpret an float64 as an int based on its bytes, with the same endianness\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}]}},\"binary\":false},\"bytecast_int64_to_float64\":{\"extension\":\"arithmetic.conversions\",\"name\":\"bytecast_int64_to_float64\",\"description\":\"reinterpret an int64 as a float64 based on its bytes, with the same 
endianness\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"convert_s\":{\"extension\":\"arithmetic.conversions\",\"name\":\"convert_s\",\"description\":\"signed int to float\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"convert_u\":{\"extension\":\"arithmetic.conversions\",\"name\":\"convert_u\",\"description\":\"unsigned int to float\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"ifrombool\":{\"extension\":\"arithmetic.conversions\",\"name\":\"ifrombool\",\"description\":\"convert from bool into a 1-bit integer (1 is true, 0 is false)\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":0}],\"bound\":\"C\"}]}},\"binary\":false},\"ifromusize\":{\"extension\":\"arithmetic.conversions\",\"name\":\"ifromusize\",\"description\":\"convert a usize to a 64b unsigned integer\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"I\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}]}},\"binary\":false},\"itobool\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itobool\",\"description\":\"convert a 1-bit integer to bool (1 is true, 0 is false)\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":0}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"itostring_s\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itostring_s\",\"description\":\"convert a signed integer to its string representation\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"string\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"itostring_u\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itostring_u\",\"description\":\"convert an unsigned integer to its string 
representation\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"string\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"itousize\":{\"extension\":\"arithmetic.conversions\",\"name\":\"itousize\",\"description\":\"convert a 64b unsigned integer to its usize representation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"BoundedNat\",\"n\":6}],\"bound\":\"C\"}],\"output\":[{\"t\":\"I\"}]}},\"binary\":false},\"trunc_s\":{\"extension\":\"arithmetic.conversions\",\"name\":\"trunc_s\",\"description\":\"float to signed int\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false},\"trunc_u\":{\"extension\":\"arithmetic.conversions\",\"name\":\"trunc_u\",\"description\":\"float to unsigned int\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false}}},{\"version\":\"0.1.0\",\"name\":\"arithmetic.float\",\"types\":{},\"operations\":{\"fabs\":{\"extension\":\"arithmetic.float\",\"name\":\"fabs\",\"description\":\"absolute 
value\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fadd\":{\"extension\":\"arithmetic.float\",\"name\":\"fadd\",\"description\":\"addition\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fceil\":{\"extension\":\"arithmetic.float\",\"name\":\"fceil\",\"description\":\"ceiling\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fdiv\":{\"extension\":\"arithmetic.float\",\"name\":\"fdiv\",\"description\":\"division\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"feq\":{\"extension\":\"arithmetic.float\",\"name\":\"feq\",\"description\":\"equality test\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"ffloor\":{\"extension\":\"arithmetic.float\",\"name\":\"ffloor\",\"description\":\"floor\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fge\":{\"extension\":\"arithmetic.float\",\"name\":\"fge\",\"description\":\"\\\"greater than or equal\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fgt\":{\"extension\":\"arithmetic.float\",\"name\":\"fgt\",\"description\":\"\\\"greater than\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fle\":{\"extension\":\"arithmetic.float\",\"name\":\"fle\",\"description\":\"\\\"less than or 
equal\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"flt\":{\"extension\":\"arithmetic.float\",\"name\":\"flt\",\"description\":\"\\\"less than\\\"\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fmax\":{\"extension\":\"arithmetic.float\",\"name\":\"fmax\",\"description\":\"maximum\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fmin\":{\"extension\":\"arithmetic.float\",\"name\":\"fmin\",\"description\":\"minimum\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fmul\":{\"extension\":\"arithmetic.float\",\"name\":\"fmul\",\"description\":\"multiplication\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fne\":{\"extension\":\"arithmetic.float\",\"name\":\"fne\",\"description\":\"inequality 
test\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"Unit\",\"size\":2}]}},\"binary\":false},\"fneg\":{\"extension\":\"arithmetic.float\",\"name\":\"fneg\",\"description\":\"negation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fpow\":{\"extension\":\"arithmetic.float\",\"name\":\"fpow\",\"description\":\"exponentiation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fround\":{\"extension\":\"arithmetic.float\",\"name\":\"fround\",\"description\":\"round\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"fsub\":{\"extension\":\"arithmetic.float\",\"name\":\"fsub\",\"description\":\"subtraction\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false},\"ftostring\":{\"extension\":\"arithmetic.float\",\"name\":\"ftostring\",\"description\":\"string representation\",\"signature\":{\"params\":[],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.float.types\",\"id\":\"float64\",\"args\":[],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"string\",\"args\":[],\"bound\":\"C\"}]}},\"binary\":false}}},{\"version\":\"0.1.0\",\"name\":\"arithmetic.float.types\",\"types\":{\"float64\":{\"extension\":\"arithmetic.float.types\",\"name\":\"float64\",\"params\":[],\"description\":\"64-bit IEEE 754-2019 floating-point value\",\"bound\":{\"b\":\"Explicit\",\"bound\":\"C\"}}},\"operations\":{}},{\"version\":\"0.1.0\",\"name\":\"arithmetic.int\",\"types\":{},\"operations\":{\"iabs\":{\"extension\":\"arithmetic.int\",\"name\":\"iabs\",\"description\":\"convert signed to unsigned by taking absolute 
value\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"iadd\":{\"extension\":\"arithmetic.int\",\"name\":\"iadd\",\"description\":\"addition modulo 2^N (signed and unsigned versions are the same op)\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"iand\":{\"extension\":\"arithmetic.int\",\"name\":\"iand\",\"description\":\"bitwise AND\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"idiv_checked_s\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_checked_s\",\"description\":\"as idivmod_checked_s but discarding the second output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false},\"idiv_checked_u\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_checked_u\",\"description\":\"as idivmod_checked_u but discarding the second 
output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Sum\",\"s\":\"General\",\"rows\":[[{\"t\":\"Opaque\",\"extension\":\"prelude\",\"id\":\"error\",\"args\":[],\"bound\":\"C\"}],[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]]}]}},\"binary\":false},\"idiv_s\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_s\",\"description\":\"as idivmod_s but discarding the second output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"idiv_u\":{\"extension\":\"arithmetic.int\",\"name\":\"idiv_u\",\"description\":\"as idivmod_u but discarding the second output\",\"signature\":{\"params\":[{\"tp\":\"BoundedNat\",\"bound\":7}],\"body\":{\"input\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"},{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}],\"output\":[{\"t\":\"Opaque\",\"extension\":\"arithmetic.int.types\",\"id\":\"int\",\"args\":[{\"tya\":\"Variable\",\"idx\":0,\"cached_decl\":{\"tp\":\"BoundedNat\",\"bound\":7}}],\"bound\":\"C\"}]}},\"binary\":false},\"idivmod_checked_s\":{\"extension\":\"arithmetic.int\",\"name\":\"idivmod_checked_s\",\"description\":\"given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^N, generates signed q and unsigned r where q*m+r=n, 0<=r, - /// The input ports of the subgraph. - /// - /// Grouped by input parameter. Each port must be unique and belong to a - /// node in `nodes`. - inputs: Vec>, - /// The output ports of the subgraph. - /// - /// Repeated ports are allowed and correspond to copying the output. Every - /// port must belong to a node in `nodes`. - outputs: Vec<(PatchNode, OutgoingPort)>, - /// The commits that must be selected in the host for the subgraph to be - /// valid. - selected_commits: BTreeSet, -} - -impl From> for PinnedSubgraph { - fn from(subgraph: SiblingSubgraph) -> Self { - Self { - inputs: subgraph.incoming_ports().clone(), - outputs: subgraph.outgoing_ports().clone(), - nodes: BTreeSet::from_iter(subgraph.nodes().iter().copied()), - selected_commits: BTreeSet::new(), - } - } -} - -impl PinnedSubgraph { - /// Create a new subgraph from a set of pinned nodes and wires. 
- /// - /// All nodes must be pinned and all wires must be complete in the given - /// `walker`. - /// - /// Nodes that are not isolated, i.e. are attached to at least one wire in - /// `wires` will be added implicitly to the graph and do not need to be - /// explicitly listed in `nodes`. - pub fn try_from_pinned( - nodes: impl IntoIterator, - wires: impl IntoIterator, - walker: &Walker, - ) -> Result { - let mut selected_commits = BTreeSet::new(); - let host = walker.as_hugr_view(); - let wires = wires.into_iter().collect_vec(); - let nodes = nodes.into_iter().collect_vec(); - - for w in wires.iter() { - if !walker.is_complete(w, None) { - return Err(InvalidPinnedSubgraph::IncompleteWire(w.clone())); - } - for id in w.owners() { - if host.contains_id(id) { - selected_commits.insert(id); - } else { - return Err(InvalidPinnedSubgraph::InvalidCommit(id)); - } - } - } - - if let Some(&unpinned) = nodes.iter().find(|&&n| !walker.is_pinned(n)) { - return Err(InvalidPinnedSubgraph::UnpinnedNode(unpinned)); - } - - let (inputs, outputs, all_nodes) = Self::compute_io_ports(nodes, wires, host); - - Ok(Self { - selected_commits, - nodes: all_nodes, - inputs, - outputs, - }) - } - - /// Create a new subgraph from a set of complete wires in `walker`. - pub fn try_from_wires( - wires: impl IntoIterator, - walker: &Walker, - ) -> Result { - Self::try_from_pinned(std::iter::empty(), wires, walker) - } - - /// Compute the input and output ports for the given pinned nodes and wires. - /// - /// Return the input boundary ports, output boundary ports as well as the - /// set of all nodes in the subgraph. - pub fn compute_io_ports( - nodes: impl IntoIterator, - wires: impl IntoIterator, - host: &PersistentHugr, - ) -> ( - IncomingPorts, - OutgoingPorts, - BTreeSet, - ) { - let mut wire_ports_incoming = BTreeSet::new(); - let mut wire_ports_outgoing = BTreeSet::new(); - - for w in wires { - wire_ports_incoming.extend(w.all_incoming_ports(host)); - wire_ports_outgoing.extend(w.single_outgoing_port(host)); - } - - let mut all_nodes = BTreeSet::from_iter(nodes); - all_nodes.extend(wire_ports_incoming.iter().map(|&(n, _)| n)); - all_nodes.extend(wire_ports_outgoing.iter().map(|&(n, _)| n)); - - // (in/out) boundary: all in/out ports on the nodes of the wire, minus ports - // that are part of the wires - let inputs = all_nodes - .iter() - .flat_map(|&n| host.input_value_ports(n)) - .filter(|node_port| !wire_ports_incoming.contains(node_port)) - .map(|np| vec![np]) - .collect_vec(); - let outputs = all_nodes - .iter() - .flat_map(|&n| host.output_value_ports(n)) - .filter(|node_port| !wire_ports_outgoing.contains(node_port)) - .collect_vec(); - - (inputs, outputs, all_nodes) - } - - /// Convert the pinned subgraph to a [`SiblingSubgraph`] for the given - /// `host`. - /// - /// This will fail if any of the required selected commits are not in the - /// host, if any of the nodes are invalid in the host (e.g. deleted by - /// another commit in host), or if the subgraph is not convex. - pub fn to_sibling_subgraph( - &self, - host: &PersistentHugr, - ) -> Result, InvalidPinnedSubgraph> { - if let Some(&unselected) = self - .selected_commits - .iter() - .find(|&&id| !host.contains_id(id)) - { - return Err(InvalidPinnedSubgraph::InvalidCommit(unselected)); - } - - if let Some(invalid) = self.nodes.iter().find(|&&n| !host.contains_node(n)) { - return Err(InvalidPinnedSubgraph::InvalidNode(*invalid)); - } - - Ok(SiblingSubgraph::try_new( - self.inputs.clone(), - self.outputs.clone(), - host, - )?) 
- } - - /// Iterate over all the commits required by this pinned subgraph. - pub fn selected_commits(&self) -> impl Iterator + '_ { - self.selected_commits.iter().copied() - } - - /// Iterate over all the nodes in this pinned subgraph. - pub fn nodes(&self) -> impl Iterator + '_ { - self.nodes.iter().copied() - } - - /// Returns the computed [`IncomingPorts`] of the subgraph. - #[must_use] - pub fn incoming_ports(&self) -> &IncomingPorts { - &self.inputs - } - - /// Returns the computed [`OutgoingPorts`] of the subgraph. - #[must_use] - pub fn outgoing_ports(&self) -> &OutgoingPorts { - &self.outputs - } -} - -#[derive(Debug, Clone, Error)] -#[non_exhaustive] -pub enum InvalidPinnedSubgraph { - #[error("Invalid subgraph: {0}")] - InvalidSubgraph(#[from] InvalidSubgraph), - #[error("Invalid commit in host: {0}")] - InvalidCommit(CommitId), - #[error("Wire is not complete: {0:?}")] - IncompleteWire(PersistentWire), - #[error("Node is not pinned: {0}")] - UnpinnedNode(PatchNode), - #[error("Invalid node in host: {0}")] - InvalidNode(PatchNode), -} diff --git a/hugr-persistent/src/wire.rs b/hugr-persistent/src/wire.rs deleted file mode 100644 index a84d4e6923..0000000000 --- a/hugr-persistent/src/wire.rs +++ /dev/null @@ -1,303 +0,0 @@ -use std::collections::{BTreeSet, VecDeque}; - -use hugr_core::{ - Direction, HugrView, IncomingPort, OutgoingPort, Port, Wire, - hugr::patch::simple_replace::BoundaryMode, -}; -use itertools::Itertools; - -use crate::{CommitId, PatchNode, PersistentHugr, Resolver, Walker}; - -/// A wire in a [`PersistentHugr`]. -/// -/// A wire may be composed of multiple wires in the underlying commits -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct PersistentWire { - wires: BTreeSet, -} - -/// A wire within a commit HUGR of a [`PersistentHugr`]. -/// -/// Also stores the ID of the commit that contains the wire; -/// equivalent to (indeed contains) a `Wire`. -/// -/// Note that it does not correspond to a valid wire in a [`PersistentHugr`] -/// (see [`PersistentWire`]): some of its connected ports may be on deleted or -/// IO nodes that are not valid in the [`PersistentHugr`]. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -struct CommitWire(Wire); - -impl CommitWire { - fn from_connected_port( - PatchNode(commit_id, node): PatchNode, - port: impl Into, - hugr: &PersistentHugr, - ) -> Self { - let commit_hugr = hugr.commit_hugr(commit_id); - let wire = Wire::from_connected_port(node, port, commit_hugr); - Self(Wire::new(PatchNode(commit_id, wire.node()), wire.source())) - } - - fn all_connected_ports<'h, R>( - &self, - hugr: &'h PersistentHugr, - ) -> impl Iterator + use<'h, R> { - let wire = Wire::new(self.0.node().1, self.0.source()); - let commit_id = self.commit_id(); - wire.all_connected_ports(hugr.commit_hugr(commit_id)) - .map(move |(node, port)| (hugr.to_persistent_node(node, commit_id), port)) - } - - fn commit_id(&self) -> CommitId { - self.0.node().0 - } - - delegate::delegate! { - to self.0 { - fn node(&self) -> PatchNode; - } - } -} - -/// A node in a commit of a [`PersistentHugr`] is either a valid node of the -/// HUGR, a node deleted by a child commit in that [`PersistentHugr`], or an -/// input or output node in a replacement graph. -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -enum NodeStatus { - /// A node deleted by a child commit in that [`PersistentHugr`]. - /// - /// The ID of the child commit is stored in the variant. 
- Deleted(CommitId), - /// An input or output node in the replacement graph of a Commit - ReplacementIO, - /// A valid node in the [`PersistentHugr`] - Valid, -} - -impl PersistentHugr { - pub fn get_wire(&self, node: PatchNode, port: impl Into) -> PersistentWire { - PersistentWire::from_port(node, port, self) - } - - /// Whether a node is valid in `self`, is deleted or is an IO node in a - /// replacement graph. - fn node_status(&self, per_node @ PatchNode(commit_id, node): PatchNode) -> NodeStatus { - debug_assert!(self.contains_id(commit_id), "unknown commit"); - if self - .replacement(commit_id) - .is_some_and(|repl| repl.get_replacement_io().contains(&node)) - { - NodeStatus::ReplacementIO - } else if let Some(commit_id) = self.find_deleting_commit(per_node) { - NodeStatus::Deleted(commit_id) - } else { - NodeStatus::Valid - } - } -} - -impl PersistentWire { - /// Get the wire connected to a specified port of a pinned node in `hugr`. - fn from_port(node: PatchNode, port: impl Into, per_hugr: &PersistentHugr) -> Self { - assert!(per_hugr.contains_node(node), "node not in hugr"); - - // Queue of wires within each commit HUGR, that combined will form the - // persistent wire. - let mut commit_wires = - BTreeSet::from_iter([CommitWire::from_connected_port(node, port, per_hugr)]); - let mut queue = VecDeque::from_iter(commit_wires.iter().copied()); - - while let Some(wire) = queue.pop_front() { - let commit_id = wire.commit_id(); - let commit_hugr = per_hugr.commit_hugr(commit_id); - let all_ports = wire.all_connected_ports(per_hugr); - - for (per_node @ PatchNode(_, node), port) in all_ports { - match per_hugr.node_status(per_node) { - NodeStatus::Deleted(deleted_by) => { - // If node is deleted, check if there are wires between - // ports on the opposite end of the wire and boundary - // ports in the child commit that deleted the node. - for (opp_node, opp_port) in commit_hugr.linked_ports(node, port) { - let opp_node = per_hugr.to_persistent_node(opp_node, commit_id); - for (child_node, child_port) in - per_hugr.as_state_space().linked_child_ports( - opp_node, - opp_port, - deleted_by, - BoundaryMode::IncludeIO, - ) - { - debug_assert_eq!(child_node.owner(), deleted_by); - let w = CommitWire::from_connected_port( - child_node, child_port, per_hugr, - ); - if commit_wires.insert(w) { - queue.push_back(w); - } - } - } - } - NodeStatus::ReplacementIO => { - // If node is an input (resp. output) node in a replacement graph, there - // must be (at least) one wire between the incoming (resp. outgoing) - // boundary ports of the commit (i.e. the ports connected to - // the input resp. output) and ports in a parent commit. - for (opp_node, opp_port) in commit_hugr.linked_ports(node, port) { - let opp_node = per_hugr.to_persistent_node(opp_node, commit_id); - for (parent_node, parent_port) in per_hugr - .as_state_space() - .linked_parent_ports(opp_node, opp_port) - { - let w = CommitWire::from_connected_port( - parent_node, - parent_port, - per_hugr, - ); - if commit_wires.insert(w) { - queue.push_back(w); - } - } - } - } - NodeStatus::Valid => {} - } - } - } - - Self { - wires: commit_wires, - } - } - - /// Get all ports attached to a wire in `hugr`. - /// - /// All ports returned are on nodes that are contained in `hugr`. - pub fn all_ports( - &self, - hugr: &PersistentHugr, - dir: impl Into>, - ) -> impl Iterator { - all_ports_impl(self.wires.iter().copied(), dir.into(), hugr) - } - - /// All commit IDs that the wire traverses. 
- pub fn owners(&self) -> impl Iterator { - self.wires.iter().map(|w| w.node().owner()).unique() - } - - /// Consume the wire and return all ports attached to a wire in `hugr`. - /// - /// All ports returned are on nodes that are contained in `hugr`. - pub fn into_all_ports( - self, - hugr: &PersistentHugr, - dir: impl Into>, - ) -> impl Iterator { - all_ports_impl(self.wires.into_iter(), dir.into(), hugr) - } - - pub fn single_outgoing_port( - &self, - hugr: &PersistentHugr, - ) -> Option<(PatchNode, OutgoingPort)> { - single_outgoing(self.all_ports(hugr, Direction::Outgoing)) - } - - pub fn all_incoming_ports( - &self, - hugr: &PersistentHugr, - ) -> impl Iterator { - self.all_ports(hugr, Direction::Incoming) - .map(|(node, port)| (node, port.as_incoming().unwrap())) - } -} - -impl Walker<'_, R> { - /// Get all ports on a wire that are not pinned in `self`. - pub(crate) fn wire_unpinned_ports( - &self, - wire: &PersistentWire, - dir: impl Into>, - ) -> impl Iterator { - let ports = wire.all_ports(self.as_hugr_view(), dir); - ports.filter(|(node, _)| !self.is_pinned(*node)) - } - - /// Get the ports of the wire that are on pinned nodes of `self`. - pub fn wire_pinned_ports( - &self, - wire: &PersistentWire, - dir: impl Into>, - ) -> impl Iterator { - let ports = wire.all_ports(self.as_hugr_view(), dir); - ports.filter(|(node, _)| self.is_pinned(*node)) - } - - /// Get the outgoing port of a wire if it is pinned in `walker`. - pub fn wire_pinned_outport(&self, wire: &PersistentWire) -> Option<(PatchNode, OutgoingPort)> { - single_outgoing(self.wire_pinned_ports(wire, Direction::Outgoing)) - } - - /// Get all pinned incoming ports of a wire. - pub fn wire_pinned_inports( - &self, - wire: &PersistentWire, - ) -> impl Iterator { - self.wire_pinned_ports(wire, Direction::Incoming) - .map(|(node, port)| (node, port.as_incoming().expect("incoming port"))) - } - - /// Whether a wire is complete in the specified direction, i.e. there are no - /// unpinned ports left. - pub fn is_complete(&self, wire: &PersistentWire, dir: impl Into>) -> bool { - self.wire_unpinned_ports(wire, dir).next().is_none() - } -} - -/// Implementation of the (shared) body of [`PersistentWire::all_ports`] and -/// [`PersistentWire::into_all_ports`]. 
-fn all_ports_impl( - wires: impl Iterator, - dir: Option, - per_hugr: &PersistentHugr, -) -> impl Iterator { - let all_ports = wires.flat_map(move |w| w.all_connected_ports(per_hugr)); - - // Filter out invalid and wrong direction ports - all_ports - .filter(move |(_, port)| dir.is_none_or(|dir| port.direction() == dir)) - .filter(|&(node, _)| per_hugr.node_status(node) == NodeStatus::Valid) -} - -fn single_outgoing(iter: impl Iterator) -> Option<(N, OutgoingPort)> { - let (node, port) = iter.exactly_one().ok()?; - Some((node, port.as_outgoing().ok()?)) -} - -#[cfg(test)] -mod tests { - use std::collections::BTreeSet; - - use crate::{CommitId, CommitStateSpace, PatchNode, tests::test_state_space}; - use hugr_core::{HugrView, OutgoingPort}; - use itertools::Itertools; - use rstest::rstest; - - #[rstest] - fn test_all_ports(test_state_space: (CommitStateSpace, [CommitId; 4])) { - let (state_space, [_, _, cm3, cm4]) = test_state_space; - let hugr = state_space.try_extract_hugr([cm3, cm4]).unwrap(); - let cm4_not = { - let hugr4 = state_space.commit_hugr(cm4); - let out = state_space.replacement(cm4).unwrap().get_replacement_io()[1]; - let node = hugr4.input_neighbours(out).exactly_one().ok().unwrap(); - PatchNode(cm4, node) - }; - let w = hugr.get_wire(cm4_not, OutgoingPort::from(0)); - assert_eq!( - BTreeSet::from_iter(w.wires.iter().map(|w| w.0.node().0)), - BTreeSet::from_iter([cm3, cm4, state_space.base(),]) - ); - } -} diff --git a/hugr-py/CHANGELOG.md b/hugr-py/CHANGELOG.md index 9400909bb1..7cce076974 100644 --- a/hugr-py/CHANGELOG.md +++ b/hugr-py/CHANGELOG.md @@ -1,108 +1,5 @@ # Changelog -## [0.13.0rc1](https://github.com/CQCL/hugr/compare/hugr-py-v0.12.5...hugr-py-v0.13.0rc1) (2025-07-24) - - -### ⚠ BREAKING CHANGES - -* Lowering functions in extension operations are now encoded as binary envelopes. Older hugr versions will error out when trying to load them. -* **py:** `EnvelopeConfig::BINARY` now uses the model binary encoding. `EnvelopeFormat.MODULE` is now `EnvelopeFormat.MODEL`. `EnvelopeFormat.MODULE_WITH_EXTS` is now `EnvelopeFormat.MODEL_WITH_EXTS` -* hugr-model: Symbol has an extra field -* Renamed the `Any` type bound to `Linear` -* The model CFG signature types were changed. -* Added `TypeParam`s and `TypeArg`s corresponding to floats and bytes. -* `TypeArg::Sequence` needs to be replaced with -* FuncDefns must be moved to beneath Module. `Container::define_function` is gone, use `HugrBuilder::module_root_builder`; similarly in hugr-py `DefinitionBuilder` (`define_function` -> `module_root_builder().define_function`). In hugr-llvm, some uses of - -### Features - -* Add `BorrowArray` extension ([#2395](https://github.com/CQCL/hugr/issues/2395)) ([782687e](https://github.com/CQCL/hugr/commit/782687ed917c3e4295c2c3c59a17d784fc6f932d)) -* Add `MakeError` op ([#2377](https://github.com/CQCL/hugr/issues/2377)) ([909a794](https://github.com/CQCL/hugr/commit/909a7948c1465aab5528895bdee0e49958a416b6)), closes [#1863](https://github.com/CQCL/hugr/issues/1863) -* add toposort to HUGR-py ([#2367](https://github.com/CQCL/hugr/issues/2367)) ([34eed34](https://github.com/CQCL/hugr/commit/34eed3422c9aa34bd6b8ad868dcbab733eb5d14c)) -* Add Visibility to FuncDefn/FuncDecl. 
([#2143](https://github.com/CQCL/hugr/issues/2143)) ([5bbe0cd](https://github.com/CQCL/hugr/commit/5bbe0cdc60625b4047f0cddc9598d6652ed6f736)) -* Added float and bytes literal to core and python bindings. ([#2289](https://github.com/CQCL/hugr/issues/2289)) ([e9c5e91](https://github.com/CQCL/hugr/commit/e9c5e914d4fd9ee270dee8e43875d8a413b02926)) -* **core, llvm:** add array unpack operations ([#2339](https://github.com/CQCL/hugr/issues/2339)) ([a1a70f1](https://github.com/CQCL/hugr/commit/a1a70f1afb5d8d57082269d167816c7a90497dcf)), closes [#1947](https://github.com/CQCL/hugr/issues/1947) -* Detect and fail on unrecognised envelope flags ([#2453](https://github.com/CQCL/hugr/issues/2453)) ([5e36770](https://github.com/CQCL/hugr/commit/5e36770895b79e878c1bbdf22e67e8cbff6513b6)) -* Export entrypoint metadata in Python and fix bug in import ([#2434](https://github.com/CQCL/hugr/issues/2434)) ([d17b245](https://github.com/CQCL/hugr/commit/d17b245c41d943da1c338094c31a75b55efe4061)) -* Expose `BorrowArray` in `hugr-py` ([#2425](https://github.com/CQCL/hugr/issues/2425)) ([fdb675f](https://github.com/CQCL/hugr/commit/fdb675f1473a9bf349fce0824c56539e239c11f3)), closes [#2406](https://github.com/CQCL/hugr/issues/2406) -* include generator metatada in model import and cli validate errors ([#2452](https://github.com/CQCL/hugr/issues/2452)) ([f7cedb4](https://github.com/CQCL/hugr/commit/f7cedb4f39b67a77b4c6a55ec00b624b54668eaa)) -* Names of private functions become `core.title` metadata. ([#2448](https://github.com/CQCL/hugr/issues/2448)) ([4bc7f65](https://github.com/CQCL/hugr/commit/4bc7f65338d9a8b37d3a5625aeaf093970d97926)) -* No nested FuncDefns (or AliasDefns) ([#2256](https://github.com/CQCL/hugr/issues/2256)) ([214b8df](https://github.com/CQCL/hugr/commit/214b8df837537b8ac15c3b60845350c3818a6ac7)) -* Non-region entrypoints in `hugr-model`. ([#2467](https://github.com/CQCL/hugr/issues/2467)) ([7b42da6](https://github.com/CQCL/hugr/commit/7b42da6f62de9fe36187512dba428fe3db8d6120)) -* Open lists and tuples in `Term` ([#2360](https://github.com/CQCL/hugr/issues/2360)) ([292af80](https://github.com/CQCL/hugr/commit/292af8010dba6b4c2ea5bb69edae31cbf1e0cb6a)) -* **py:** enable Model as default BINARY envelope format ([#2317](https://github.com/CQCL/hugr/issues/2317)) ([f089931](https://github.com/CQCL/hugr/commit/f08993124e48093c2328096a93cec8a9ad67a41c)) -* **py:** Helper methods to get the neighbours of a node ([#2370](https://github.com/CQCL/hugr/issues/2370)) ([bb6fa50](https://github.com/CQCL/hugr/commit/bb6fa50957ac5121bebc78a06335262a6559e695)) -* **py:** Use SumValue serialization for tuples ([#2466](https://github.com/CQCL/hugr/issues/2466)) ([f615037](https://github.com/CQCL/hugr/commit/f615037621aa0eeb37de8f1126fa9020705cb565)) -* Rename 'Any' type bound to 'Linear' ([#2421](https://github.com/CQCL/hugr/issues/2421)) ([c2f8b30](https://github.com/CQCL/hugr/commit/c2f8b30afd3a1b75f6babe77a90b13211e45e3a7)) -* Split `TypeArg::Sequence` into tuples and lists. 
([#2140](https://github.com/CQCL/hugr/issues/2140)) ([cc4997f](https://github.com/CQCL/hugr/commit/cc4997f12dad4dfecc37be564712cae18dfce159)) -* Standarize the string formating of sum types and values ([#2432](https://github.com/CQCL/hugr/issues/2432)) ([ec207e7](https://github.com/CQCL/hugr/commit/ec207e7dbe6dbaa9f40421eb0836c9de7e3ea240)) -* Use binary envelopes for operation lower_func encoding ([#2447](https://github.com/CQCL/hugr/issues/2447)) ([2c16a77](https://github.com/CQCL/hugr/commit/2c16a7797a3b5800c5540d1e6a767dd38ad8ca6b)) - - -### Bug Fixes - -* Ensure SumTypes have the same json encoding in -rs and -py ([#2465](https://github.com/CQCL/hugr/issues/2465)) ([7f97e6f](https://github.com/CQCL/hugr/commit/7f97e6f84f0bb2b441fe3e2589e91f19de50198e)) -* Escape html-like labels in DotRenderer ([#2383](https://github.com/CQCL/hugr/issues/2383)) ([eaa7dfe](https://github.com/CQCL/hugr/commit/eaa7dfe35eb08dbd20d5f5353e92b58850e0f31f)) -* Export metadata in Python ([#2342](https://github.com/CQCL/hugr/issues/2342)) ([7be52db](https://github.com/CQCL/hugr/commit/7be52db4f63d7ce8556a5ba0d8d245ebb567e7ed)) -* Fix model export of `Opaque` types. ([#2446](https://github.com/CQCL/hugr/issues/2446)) ([3943499](https://github.com/CQCL/hugr/commit/39434996ba18db83a50455fda90c60aea11a8387)) -* Fixed bug in python model export name mangling. ([#2323](https://github.com/CQCL/hugr/issues/2323)) ([041342f](https://github.com/CQCL/hugr/commit/041342f58a3dcd9f73dbbaab102221c5d9ff5f61)) -* Fixed bugs in model CFG handling and improved CFG signatures ([#2334](https://github.com/CQCL/hugr/issues/2334)) ([ccd2eb2](https://github.com/CQCL/hugr/commit/ccd2eb226358b44aede7dd9e9217448c7e6c0f3a)) -* Fixed export of `Call` and `LoadConst` nodes in `hugr-py`. ([#2429](https://github.com/CQCL/hugr/issues/2429)) ([6a0e270](https://github.com/CQCL/hugr/commit/6a0e270e7edbea4cc08e2948d3f8a16b9e763af7)) -* Fixed invalid extension name in test. ([#2319](https://github.com/CQCL/hugr/issues/2319)) ([c58ddbf](https://github.com/CQCL/hugr/commit/c58ddbfcc0a557a1644fc8094370e6c62a7ce129)) -* Fixed two bugs in import/export of function operations ([#2324](https://github.com/CQCL/hugr/issues/2324)) ([1ad450f](https://github.com/CQCL/hugr/commit/1ad450f807485f7ef6083270aaa4523cb95b2490)) -* map IntValue to unsigned repr when serializing ([#2413](https://github.com/CQCL/hugr/issues/2413)) ([26d426e](https://github.com/CQCL/hugr/commit/26d426ee7ffdc38063a337e66458b8d797131bca)), closes [#2409](https://github.com/CQCL/hugr/issues/2409) -* Order hints on input and output nodes. 
([#2422](https://github.com/CQCL/hugr/issues/2422)) ([a31ccbc](https://github.com/CQCL/hugr/commit/a31ccbcaaa7561f8d221269262cd9ca9e89ad67b)) -* **py:** correct ConstString JSON encoding ([#2325](https://github.com/CQCL/hugr/issues/2325)) ([9649a48](https://github.com/CQCL/hugr/commit/9649a48d376aff27e475c70072aecd55ae7a4ccb)) -* StaticArrayVal payload encoding, improve roundtrip checker ([#2444](https://github.com/CQCL/hugr/issues/2444)) ([1a301eb](https://github.com/CQCL/hugr/commit/1a301eb818401c314d4d7bac40698ec2e73babe7)) -* stringify metadata before escaping in renderer ([#2405](https://github.com/CQCL/hugr/issues/2405)) ([8d67420](https://github.com/CQCL/hugr/commit/8d67420e8fd2e979256ff64bcf0b2813ed19ac00)) - -## [0.12.5](https://github.com/CQCL/hugr/compare/hugr-py-v0.12.4...hugr-py-v0.12.5) (2025-07-08) - - -### Bug Fixes - -* map IntValue to unsigned repr when serializing ([#2413](https://github.com/CQCL/hugr/issues/2413)) ([4ad1d4e](https://github.com/CQCL/hugr/commit/4ad1d4e010eca07207306320b3cf74396f1f8181)), closes [#2409](https://github.com/CQCL/hugr/issues/2409) - -## [0.12.4](https://github.com/CQCL/hugr/compare/hugr-py-v0.12.3...hugr-py-v0.12.4) (2025-07-03) - - -### Bug Fixes - -* stringify metadata before escaping in renderer ([#2405](https://github.com/CQCL/hugr/issues/2405)) ([1f01e97](https://github.com/CQCL/hugr/commit/1f01e97696afe02b46eedb2c6e3e2f2369a4ac7b)) - -## [0.12.3](https://github.com/CQCL/hugr/compare/hugr-py-v0.12.2...hugr-py-v0.12.3) (2025-07-03) - - -### Features - -* add toposort to HUGR-py ([#2367](https://github.com/CQCL/hugr/issues/2367)) ([ba8988e](https://github.com/CQCL/hugr/commit/ba8988e87c2a3d64953838e9a1cff4989740cf05)) -* **core, llvm:** add array unpack operations ([#2339](https://github.com/CQCL/hugr/issues/2339)) ([74b25aa](https://github.com/CQCL/hugr/commit/74b25aa3a704c082f84a0c34fad2654e3392ff50)), closes [#1947](https://github.com/CQCL/hugr/issues/1947) -* **py:** Helper methods to get the neighbours of a node ([#2370](https://github.com/CQCL/hugr/issues/2370)) ([1ed6440](https://github.com/CQCL/hugr/commit/1ed64409aaf7e8f26fb5928051245e560881a621)) - - -### Bug Fixes - -* Escape html-like labels in DotRenderer ([#2383](https://github.com/CQCL/hugr/issues/2383)) ([c7a43a6](https://github.com/CQCL/hugr/commit/c7a43a69878e1271251b570070f192ebf57aaadd)) -* Fixed invalid extension name in test. 
([#2319](https://github.com/CQCL/hugr/issues/2319)) ([fbe1d9c](https://github.com/CQCL/hugr/commit/fbe1d9c061768360144f5463dcf357fb59ac736f)) -* **py:** correct ConstString JSON encoding ([#2325](https://github.com/CQCL/hugr/issues/2325)) ([325168b](https://github.com/CQCL/hugr/commit/325168b50b5e40e884127ad89d7acb5ab3a412f8)) - -## [0.12.2](https://github.com/CQCL/hugr/compare/hugr-py-v0.12.1...hugr-py-v0.12.2) (2025-06-03) - - -### Bug Fixes - -* use envelopes for `FixedHugr` encoding ([#2283](https://github.com/CQCL/hugr/issues/2283)) ([2c8cbb9](https://github.com/CQCL/hugr/commit/2c8cbb99bc74d5d43956b5f75c89f17748b5ee39)), closes [#2282](https://github.com/CQCL/hugr/issues/2282) - - -### Performance Improvements - -* **py:** mutable `Node` to avoid linear update cost ([#2288](https://github.com/CQCL/hugr/issues/2288)) ([84fb200](https://github.com/CQCL/hugr/commit/84fb2002dc835f6b98ceb95bd80a7bcff9eecdd8)) - - -### Documentation - -* **py:** fix `TypeDef` example ([#2268](https://github.com/CQCL/hugr/issues/2268)) ([ede8e7b](https://github.com/CQCL/hugr/commit/ede8e7b087591303038ecc5b449bb85bf39c948b)) - ## [0.12.1](https://github.com/CQCL/hugr/compare/hugr-py-v0.12.0...hugr-py-v0.12.1) (2025-05-20) diff --git a/hugr-py/Cargo.toml b/hugr-py/Cargo.toml index c864e3bf5b..27020c8122 100644 --- a/hugr-py/Cargo.toml +++ b/hugr-py/Cargo.toml @@ -21,6 +21,6 @@ bench = false [dependencies] bumpalo = { workspace = true, features = ["collections"] } -hugr-model = { version = "0.22.1", path = "../hugr-model", features = ["pyo3"] } +hugr-model = { version = "0.20.2", path = "../hugr-model", features = ["pyo3"] } paste.workspace = true pyo3 = { workspace = true, features = ["extension-module", "abi3-py310"] } diff --git a/hugr-py/pyproject.toml b/hugr-py/pyproject.toml index 93739961da..8ecd43912e 100644 --- a/hugr-py/pyproject.toml +++ b/hugr-py/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "hugr" -version = "0.13.0rc1" +version = "0.12.1" requires-python = ">=3.10" description = "Quantinuum's common representation for quantum programs" license = { file = "LICENCE" } diff --git a/hugr-py/rust/lib.rs b/hugr-py/rust/lib.rs index bf7f0f1cbb..5a6c705d23 100644 --- a/hugr-py/rust/lib.rs +++ b/hugr-py/rust/lib.rs @@ -50,16 +50,6 @@ fn bytes_to_package(bytes: &[u8]) -> PyResult { Ok(package) } -/// Returns the current version of the HUGR model format as a tuple of (major, minor, patch). 
-#[pyfunction] -fn current_model_version() -> (u64, u64, u64) { - ( - hugr_model::CURRENT_VERSION.major, - hugr_model::CURRENT_VERSION.minor, - hugr_model::CURRENT_VERSION.patch, - ) -} - #[pymodule] fn _hugr(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(term_to_string, m)?)?; @@ -78,6 +68,5 @@ fn _hugr(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(string_to_param, m)?)?; m.add_function(wrap_pyfunction!(symbol_to_string, m)?)?; m.add_function(wrap_pyfunction!(string_to_symbol, m)?)?; - m.add_function(wrap_pyfunction!(current_model_version, m)?)?; Ok(()) } diff --git a/hugr-py/src/hugr/__init__.py b/hugr-py/src/hugr/__init__.py index 679d47d112..267d6f9e6a 100644 --- a/hugr-py/src/hugr/__init__.py +++ b/hugr-py/src/hugr/__init__.py @@ -18,4 +18,4 @@ # This is updated by our release-please workflow, triggered by this # annotation: x-release-please-version -__version__ = "0.13.0rc1" +__version__ = "0.12.1" diff --git a/hugr-py/src/hugr/_hugr/__init__.pyi b/hugr-py/src/hugr/_hugr/__init__.pyi index efcc99f910..68605037f3 100644 --- a/hugr-py/src/hugr/_hugr/__init__.pyi +++ b/hugr-py/src/hugr/_hugr/__init__.pyi @@ -18,4 +18,3 @@ def package_to_string(package: hugr.model.Package) -> str: ... def string_to_package(string: str) -> hugr.model.Package: ... def package_to_bytes(package: hugr.model.Package) -> bytes: ... def bytes_to_package(binary: bytes) -> hugr.model.Package: ... -def current_model_version() -> tuple[int, int, int]: ... diff --git a/hugr-py/src/hugr/_serialization/extension.py b/hugr-py/src/hugr/_serialization/extension.py index fed975fa61..5ffdae2ff9 100644 --- a/hugr-py/src/hugr/_serialization/extension.py +++ b/hugr-py/src/hugr/_serialization/extension.py @@ -66,26 +66,12 @@ def deserialize(self, extension: ext.Extension) -> ext.TypeDef: class FixedHugr(ConfiguredBaseModel): - """Fixed HUGR used to define the lowering of an operation. - - Args: - extensions: Extensions used in the HUGR. - hugr: Base64-encoded HUGR envelope. - """ - extensions: ExtensionSet hugr: str def deserialize(self) -> ext.FixedHugr: - # Loading fixed HUGRs requires reading hugr-model envelopes, - # which is not currently supported in Python. - # TODO: Add support for loading fixed HUGRs in Python. - # https://github.com/CQCL/hugr/issues/2287 - msg = ( - "Loading extensions with operation lowering functions is not " - + "supported in Python" - ) - raise NotImplementedError(msg) + hugr = Hugr.from_str(self.hugr) + return ext.FixedHugr(extensions=self.extensions, hugr=hugr) class OpDef(ConfiguredBaseModel, populate_by_name=True): @@ -105,21 +91,13 @@ def deserialize(self, extension: ext.Extension) -> ext.OpDef: self.binary, ) - # Loading fixed HUGRs requires reading hugr-model envelopes, - # which is not currently supported in Python. - # We currently ignore any lower functions instead of raising an error. - # - # TODO: Add support for loading fixed HUGRs in Python. 
- # https://github.com/CQCL/hugr/issues/2287 - lower_funcs: list[ext.FixedHugr] = [] - return extension.add_op_def( ext.OpDef( name=self.name, description=self.description, misc=self.misc or {}, signature=signature, - lower_funcs=lower_funcs, + lower_funcs=[f.deserialize() for f in self.lower_funcs], ) ) diff --git a/hugr-py/src/hugr/_serialization/ops.py b/hugr-py/src/hugr/_serialization/ops.py index 8f602ac94c..cde3bd6160 100644 --- a/hugr-py/src/hugr/_serialization/ops.py +++ b/hugr-py/src/hugr/_serialization/ops.py @@ -7,7 +7,6 @@ from pydantic import ConfigDict, Field, RootModel -from hugr import tys from hugr.hugr.node_port import ( NodeIdx, # noqa: TCH001 # pydantic needs this alias in scope ) @@ -76,16 +75,11 @@ class FuncDefn(BaseOp): name: str signature: PolyFuncType - visibility: tys.Visibility = Field(default="Private") def deserialize(self) -> ops.FuncDefn: poly_func = self.signature.deserialize() return ops.FuncDefn( - self.name, - params=poly_func.params, - inputs=poly_func.body.input, - _outputs=poly_func.body.output, - visibility=self.visibility, + self.name, inputs=poly_func.body.input, _outputs=poly_func.body.output ) @@ -95,12 +89,9 @@ class FuncDecl(BaseOp): op: Literal["FuncDecl"] = "FuncDecl" name: str signature: PolyFuncType - visibility: tys.Visibility = Field(default="Public") def deserialize(self) -> ops.FuncDecl: - return ops.FuncDecl( - self.name, self.signature.deserialize(), visibility=self.visibility - ) + return ops.FuncDecl(self.name, self.signature.deserialize()) class CustomConst(ConfiguredBaseModel): @@ -132,13 +123,24 @@ class FunctionValue(BaseValue): """A higher-order function value.""" v: Literal["Function"] = Field(default="Function", title="ValueTag") - hugr: str + hugr: Any def deserialize(self) -> val.Value: + from hugr._serialization.serial_hugr import SerialHugr from hugr.hugr import Hugr # pydantic stores the serialized dictionary because of the "Any" annotation - return val.Function(Hugr.from_str(self.hugr)) + return val.Function(Hugr._from_serial(SerialHugr(**self.hugr))) + + +class TupleValue(BaseValue): + """A constant tuple value.""" + + v: Literal["Tuple"] = Field(default="Tuple", title="ValueTag") + vs: list[Value] + + def deserialize(self) -> val.Value: + return val.Tuple(*deser_it(v.root for v in self.vs)) class SumValue(BaseValue): @@ -147,9 +149,9 @@ class SumValue(BaseValue): For any Sum type where this value meets the type of the variant indicated by the tag """ - v: Literal["Sum", "Tuple"] = Field(default="Sum", title="ValueTag") - tag: int = Field(default=0, title="VariantTag") - typ: SumType | None = Field(default=None, title="SumType") + v: Literal["Sum"] = Field(default="Sum", title="ValueTag") + tag: int + typ: SumType vs: list[Value] model_config = ConfigDict( json_schema_extra={ @@ -161,22 +163,15 @@ class SumValue(BaseValue): ) def deserialize(self) -> val.Value: - if self.typ is None: - # Backwards compatibility of "Tuple" values - assert self.tag == 0, "Sum type must be provided if tag is not 0" - vs = deser_it(v.root for v in self.vs) - typ = tys.Sum(variant_rows=[[v.type_() for v in vs]]) - return val.Sum(0, typ, vs) - else: - return val.Sum( - self.tag, self.typ.deserialize(), deser_it(v.root for v in self.vs) - ) + return val.Sum( + self.tag, self.typ.deserialize(), deser_it(v.root for v in self.vs) + ) class Value(RootModel): """A constant Value.""" - root: CustomValue | FunctionValue | SumValue = Field(discriminator="v") + root: CustomValue | FunctionValue | TupleValue | SumValue = 
Field(discriminator="v") model_config = ConfigDict(json_schema_extra={"required": ["v"]}) @@ -603,5 +598,6 @@ class OpType(RootModel): from hugr import ( # noqa: E402 # needed to avoid circular imports ops, + tys, val, ) diff --git a/hugr-py/src/hugr/_serialization/tys.py b/hugr-py/src/hugr/_serialization/tys.py index 1ed869c56c..c00a733751 100644 --- a/hugr-py/src/hugr/_serialization/tys.py +++ b/hugr-py/src/hugr/_serialization/tys.py @@ -1,6 +1,5 @@ from __future__ import annotations -import base64 import inspect import sys from abc import ABC, abstractmethod @@ -95,20 +94,6 @@ def deserialize(self) -> tys.StringParam: return tys.StringParam() -class BytesParam(BaseTypeParam): - tp: Literal["Bytes"] = "Bytes" - - def deserialize(self) -> tys.BytesParam: - return tys.BytesParam() - - -class FloatParam(BaseTypeParam): - tp: Literal["Float"] = "Float" - - def deserialize(self) -> tys.FloatParam: - return tys.FloatParam() - - class ListParam(BaseTypeParam): tp: Literal["List"] = "List" param: TypeParam @@ -129,13 +114,7 @@ class TypeParam(RootModel): """A type parameter.""" root: Annotated[ - TypeTypeParam - | BoundedNatParam - | StringParam - | FloatParam - | BytesParam - | ListParam - | TupleParam, + TypeTypeParam | BoundedNatParam | StringParam | ListParam | TupleParam, WrapValidator(_json_custom_error_validator), ] = Field(discriminator="tp") @@ -179,56 +158,12 @@ def deserialize(self) -> tys.StringArg: return tys.StringArg(value=self.arg) -class FloatArg(BaseTypeArg): - tya: Literal["Float"] = "Float" - value: float - - def deserialize(self) -> tys.FloatArg: - return tys.FloatArg(value=self.value) - - -class BytesArg(BaseTypeArg): - tya: Literal["Bytes"] = "Bytes" - value: str = Field( - description="Base64-encoded byte string", - json_schema_extra={"contentEncoding": "base64"}, - ) - - def deserialize(self) -> tys.BytesArg: - value = base64.b64decode(self.value) - return tys.BytesArg(value=value) - - -class ListArg(BaseTypeArg): - tya: Literal["List"] = "List" +class SequenceArg(BaseTypeArg): + tya: Literal["Sequence"] = "Sequence" elems: list[TypeArg] - def deserialize(self) -> tys.ListArg: - return tys.ListArg(elems=deser_it(self.elems)) - - -class ListConcatArg(BaseTypeArg): - tya: Literal["ListConcat"] = "ListConcat" - lists: list[TypeArg] - - def deserialize(self) -> tys.ListConcatArg: - return tys.ListConcatArg(lists=deser_it(self.lists)) - - -class TupleArg(BaseTypeArg): - tya: Literal["Tuple"] = "Tuple" - elems: list[TypeArg] - - def deserialize(self) -> tys.TupleArg: - return tys.TupleArg(elems=deser_it(self.elems)) - - -class TupleConcatArg(BaseTypeArg): - tya: Literal["TupleConcat"] = "TupleConcat" - tuples: list[TypeArg] - - def deserialize(self) -> tys.TupleConcatArg: - return tys.TupleConcatArg(tuples=deser_it(self.tuples)) + def deserialize(self) -> tys.SequenceArg: + return tys.SequenceArg(elems=deser_it(self.elems)) class VariableArg(BaseTypeArg): @@ -244,14 +179,7 @@ class TypeArg(RootModel): """A type argument.""" root: Annotated[ - TypeTypeArg - | BoundedNatArg - | StringArg - | BytesArg - | FloatArg - | ListArg - | TupleArg - | VariableArg, + TypeTypeArg | BoundedNatArg | StringArg | SequenceArg | VariableArg, WrapValidator(_json_custom_error_validator), ] = Field(discriminator="tya") @@ -412,15 +340,15 @@ def deserialize(self) -> tys.PolyFuncType: class TypeBound(Enum): Copyable = "C" - Linear = "A" + Any = "A" @staticmethod def join(*bs: TypeBound) -> TypeBound: """Computes the least upper bound for a sequence of bounds.""" res = TypeBound.Copyable for b in bs: - 
if b == TypeBound.Linear: - return TypeBound.Linear + if b == TypeBound.Any: + return TypeBound.Any if res == TypeBound.Copyable: res = b return res @@ -429,8 +357,8 @@ def __str__(self) -> str: match self: case TypeBound.Copyable: return "Copyable" - case TypeBound.Linear: - return "Linear" + case TypeBound.Any: + return "Any" class Opaque(BaseType): diff --git a/hugr-py/src/hugr/build/dfg.py b/hugr-py/src/hugr/build/dfg.py index ee4b917353..786723a606 100644 --- a/hugr-py/src/hugr/build/dfg.py +++ b/hugr-py/src/hugr/build/dfg.py @@ -21,9 +21,8 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence - from hugr.build.function import Module from hugr.hugr.node_port import Node, OutPort, PortOffset, ToNode, Wire - from hugr.tys import TypeParam, TypeRow + from hugr.tys import Type, TypeParam, TypeRow from .cfg import Cfg from .cond_loop import Conditional, If, TailLoop @@ -37,21 +36,40 @@ class DataflowError(Exception): @dataclass() class DefinitionBuilder(Generic[OpVar]): - """Base class for builders that can define constants, and allow access - to the `Module` for declaring/defining functions and aliases. + """Base class for builders that can define functions, constants, and aliases. As this class may be a root node, it does not extend `ParentBuilder`. """ hugr: Hugr[OpVar] - def module_root_builder(self) -> Module: - """Allows access to the `Module` at the root of the Hugr - (outside the scope of this builder, perhaps outside the entrypoint). - """ - from hugr.build.function import Module # Avoid circular import + def define_function( + self, + name: str, + input_types: TypeRow, + output_types: TypeRow | None = None, + type_params: list[TypeParam] | None = None, + parent: ToNode | None = None, + ) -> Function: + """Start building a function definition in the graph. - return Module(self.hugr) + Args: + name: The name of the function. + input_types: The input types for the function. + output_types: The output types for the function. + If not provided, it will be inferred after the function is built. + type_params: The type parameters for the function, if polymorphic. + parent: The parent node of the constant. Defaults to the entrypoint node. + + Returns: + The new function builder. + """ + parent_node = parent or self.hugr.entrypoint + parent_op = ops.FuncDefn(name, input_types, type_params or []) + func = Function.new_nested(parent_op, self.hugr, parent_node) + if output_types is not None: + func.declare_outputs(output_types) + return func def add_const(self, value: val.Value, parent: ToNode | None = None) -> Node: """Add a static constant to the graph. 
@@ -72,6 +90,11 @@ def add_const(self, value: val.Value, parent: ToNode | None = None) -> Node: parent_node = parent or self.hugr.entrypoint return self.hugr.add_node(ops.Const(value), parent_node) + def add_alias_defn(self, name: str, ty: Type, parent: ToNode | None = None) -> Node: + """Add a type alias definition.""" + parent_node = parent or self.hugr.entrypoint + return self.hugr.add_node(ops.AliasDefn(name, ty), parent_node) + DP = TypeVar("DP", bound=ops.DfParentOp) @@ -132,15 +155,8 @@ def new_nested( """ new = cls.__new__(cls) - try: - num_outs = parent_op.num_out - except ops.IncompleteOp: - num_outs = None - new.hugr = hugr - new.parent_node = hugr.add_node( - parent_op, parent or hugr.entrypoint, num_outs=num_outs - ) + new.parent_node = hugr.add_node(parent_op, parent or hugr.entrypoint) new._init_io_nodes(parent_op) return new @@ -212,14 +228,7 @@ def add_op( >>> dfg.add_op(ops.Noop(), dfg.inputs()[0]) Node(3) """ - try: - num_outs = op.num_out - except ops.IncompleteOp: - num_outs = None - - new_n = self.hugr.add_node( - op, self.parent_node, metadata=metadata, num_outs=num_outs - ) + new_n = self.hugr.add_node(op, self.parent_node, metadata=metadata) self._wire_up(new_n, args) new_n._num_out_ports = op.num_out return new_n @@ -746,6 +755,7 @@ def declare_outputs(self, output_types: TypeRow) -> None: defined yet. The wires passed to :meth:`set_outputs` must match the declared output types. """ + self._set_parent_output_count(len(output_types)) self.parent_op._set_out_types(output_types) def set_outputs(self, *args: Wire) -> None: diff --git a/hugr-py/src/hugr/build/function.py b/hugr-py/src/hugr/build/function.py index 6c7fa29249..b5d8b8c1ff 100644 --- a/hugr-py/src/hugr/build/function.py +++ b/hugr-py/src/hugr/build/function.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from hugr.hugr.node_port import Node - from hugr.tys import PolyFuncType, Type, TypeBound, TypeParam, TypeRow + from hugr.tys import PolyFuncType, TypeBound, TypeRow __all__ = ["Function", "Module"] @@ -28,39 +28,13 @@ class Module(DefinitionBuilder[ops.Module]): hugr: Hugr[ops.Module] - def __init__(self, hugr: Hugr | None = None) -> None: - self.hugr = Hugr(ops.Module()) if hugr is None else hugr + def __init__(self) -> None: + self.hugr = Hugr(ops.Module()) def define_main(self, input_types: TypeRow) -> Function: """Define the 'main' function in the module. See :meth:`define_function`.""" return self.define_function("main", input_types) - def define_function( - self, - name: str, - input_types: TypeRow, - output_types: TypeRow | None = None, - type_params: list[TypeParam] | None = None, - ) -> Function: - """Start building a function definition in the graph. - - Args: - name: The name of the function. - input_types: The input types for the function. - output_types: The output types for the function. - If not provided, it will be inferred after the function is built. - type_params: The type parameters for the function, if polymorphic. - parent: The parent node of the constant. Defaults to the entrypoint node. - - Returns: - The new function builder. - """ - parent_op = ops.FuncDefn(name, input_types, type_params or []) - func = Function.new_nested(parent_op, self.hugr, self.hugr.module_root) - if output_types is not None: - func.declare_outputs(output_types) - return func - def declare_function(self, name: str, signature: PolyFuncType) -> Node: """Add a function declaration to the module. 
@@ -78,17 +52,11 @@ def declare_function(self, name: str, signature: PolyFuncType) -> Node: >>> m.declare_function("f", sig) Node(1) """ - return self.hugr.add_node( - ops.FuncDecl(name, signature), self.hugr.entrypoint, num_outs=1 - ) - - def add_alias_defn(self, name: str, ty: Type) -> Node: - """Add a type alias definition.""" - return self.hugr.add_node(ops.AliasDefn(name, ty), self.hugr.module_root) + return self.hugr.add_node(ops.FuncDecl(name, signature), self.hugr.entrypoint) def add_alias_decl(self, name: str, bound: TypeBound) -> Node: """Add a type alias declaration.""" - return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.module_root) + return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.entrypoint) @property def metadata(self) -> dict[str, object]: diff --git a/hugr-py/src/hugr/envelope.py b/hugr-py/src/hugr/envelope.py index 07d199ad4f..643c6b3ab4 100644 --- a/hugr-py/src/hugr/envelope.py +++ b/hugr-py/src/hugr/envelope.py @@ -46,12 +46,6 @@ # This is a hard-coded magic number that identifies the start of a HUGR envelope. MAGIC_NUMBERS = b"HUGRiHJv" -# The all-unset header flags configuration. -# Bit 7 is always set to ensure we have a printable ASCII character. -_DEFAULT_FLAGS = 0b0100_0000 -# The ZSTD flag bit in the header's flags. -_ZSTD_FLAG = 0b0000_0001 - def make_envelope(package: Package | Hugr, config: EnvelopeConfig) -> bytes: """Encode a HUGR or Package into an envelope, using the given configuration.""" @@ -71,10 +65,10 @@ def make_envelope(package: Package | Hugr, config: EnvelopeConfig) -> bytes: # `make_envelope_str`, but we prioritize speed for binary formats. payload = json_str.encode("utf-8") - case EnvelopeFormat.MODEL: + case EnvelopeFormat.MODULE: payload = bytes(package.to_model()) - case EnvelopeFormat.MODEL_WITH_EXTS: + case EnvelopeFormat.MODULE_WITH_EXTS: package_bytes = bytes(package.to_model()) extension_str = json.dumps( [ext._to_serial().model_dump(mode="json") for ext in package.extensions] @@ -111,7 +105,7 @@ def read_envelope(envelope: bytes) -> Package: match header.format: case EnvelopeFormat.JSON: return ext_s.Package.model_validate_json(payload).deserialize() - case EnvelopeFormat.MODEL | EnvelopeFormat.MODEL_WITH_EXTS: + case EnvelopeFormat.MODULE | EnvelopeFormat.MODULE_WITH_EXTS: msg = "Decoding HUGR envelopes in MODULE format is not supported yet." raise ValueError(msg) @@ -156,10 +150,10 @@ def read_envelope_hugr_str(envelope: str) -> Hugr: class EnvelopeFormat(Enum): """Format used to encode a HUGR envelope.""" - MODEL = 1 - """A capnp-encoded hugr-model.""" - MODEL_WITH_EXTS = 2 - """A capnp-encoded hugr-model, immediately followed by a json-encoded + MODULE = 1 + """A capnp-encoded hugr-module.""" + MODULE_WITH_EXTS = 2 + """A capnp-encoded hugr-module, immediately followed by a json-encoded extension registry.""" JSON = 63 # '?' in ASCII """A json-encoded hugr-package. 
This format is ASCII-printable.""" @@ -186,9 +180,9 @@ class EnvelopeHeader: def to_bytes(self) -> bytes: header_bytes = bytearray(MAGIC_NUMBERS) header_bytes.append(self.format.value) - flags = _DEFAULT_FLAGS + flags = 0b01000000 if self.zstd: - flags |= _ZSTD_FLAG + flags |= 0b00000001 header_bytes.append(flags) return bytes(header_bytes) @@ -210,15 +204,7 @@ def from_bytes(data: bytes) -> EnvelopeHeader: format: EnvelopeFormat = EnvelopeFormat(data[8]) flags = data[9] - zstd = bool(flags & _ZSTD_FLAG) - other_flags = (flags ^ _DEFAULT_FLAGS) & ~_ZSTD_FLAG - if other_flags: - flag_ids = [i for i in range(8) if other_flags & (1 << i)] - msg = ( - f"Unrecognised Envelope flags {flag_ids}." - + " Please update your HUGR version." - ) - raise ValueError(msg) + zstd = bool(flags & 0b00000001) return EnvelopeHeader(format=format, zstd=zstd) @@ -246,4 +232,4 @@ def _make_header(self) -> EnvelopeHeader: # Set EnvelopeConfig's class variables. # These can only be initialized _after_ the class is defined. EnvelopeConfig.TEXT = EnvelopeConfig(format=EnvelopeFormat.JSON, zstd=None) -EnvelopeConfig.BINARY = EnvelopeConfig(format=EnvelopeFormat.MODEL_WITH_EXTS, zstd=0) +EnvelopeConfig.BINARY = EnvelopeConfig(format=EnvelopeFormat.JSON, zstd=None) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 53def975d3..8123fa556b 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -2,7 +2,6 @@ from __future__ import annotations -import base64 from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, TypeVar @@ -155,8 +154,7 @@ class FixedHugr: hugr: Hugr def _to_serial(self) -> ext_s.FixedHugr: - hugr_64: str = base64.b64encode(self.hugr.to_bytes()).decode() - return ext_s.FixedHugr(extensions=self.extensions, hugr=hugr_64) + return ext_s.FixedHugr(extensions=self.extensions, hugr=self.hugr.to_str()) @dataclass diff --git a/hugr-py/src/hugr/hugr/base.py b/hugr-py/src/hugr/hugr/base.py index d0fe111ee6..f5ca7ff8e2 100644 --- a/hugr-py/src/hugr/hugr/base.py +++ b/hugr-py/src/hugr/hugr/base.py @@ -35,16 +35,13 @@ Conditional, Const, Custom, - DataflowBlock, DataflowOp, - ExitBlock, FuncDefn, IncompleteOp, Module, Op, - is_dataflow_op, ) -from hugr.tys import Kind, OrderKind, Type, ValueKind +from hugr.tys import Kind, Type, ValueKind from hugr.utils import BiMap from hugr.val import Value @@ -101,8 +98,7 @@ class Hugr(Mapping[Node, NodeData], Generic[OpVarCov]): """The core HUGR datastructure. Args: - entrypoint_op: The operation for the entrypoint node. Defaults to a Module - (which will then also be the root). + root_op: The operation for the root node. Defaults to a Module. Examples: >>> h = Hugr() @@ -152,9 +148,7 @@ def __init__(self, entrypoint_op: OpVarCov | None = None) -> None: case None | Module(): pass case ops.FuncDefn(): - self.entrypoint = self.add_node( - entrypoint_op, self.module_root, num_outs=1 - ) + self.entrypoint = self.add_node(entrypoint_op, self.module_root) case _: from hugr.build import Function @@ -232,65 +226,6 @@ def nodes(self) -> Iterable[tuple[Node, NodeData]]: """ return self.items() - def sorted_region_nodes(self, parent: Node) -> Iterator[Node]: - """Iterator over a topological ordering of all the hugr nodes. - - Note that the sort is performed within a hugr region and non-local - edges are ignored. - - Args: - parent: The parent node of the region to sort. - - Raises: - ValueError: If the region contains a cycle. 
- - Examples: - >>> from hugr.build.tracked_dfg import TrackedDfg - >>> from hugr.std.logic import Not - >>> dfg = TrackedDfg(tys.Bool) - >>> [b] = dfg.track_inputs() - >>> for _ in range(6): - ... _= dfg.add(Not(b)); - >>> dfg.set_tracked_outputs() - >>> nodes = list(dfg.hugr) - >>> list(dfg.hugr.sorted_region_nodes(nodes[4])) - [Node(5), Node(7), Node(8), Node(9), Node(10), Node(11), Node(12), Node(6)] - """ - # A dict to keep track of how many times we see a node. - # Store the Nodes with the input degrees as values. - # Implementation uses Kahn's algorithm - # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm - visit_dict: dict[Node, int] = {} - queue: Queue[Node] = Queue() - for node in self.children(parent): - incoming = 0 - for n in self.input_neighbours(node): - same_region = self[n].parent == parent - # Only update the degree of the node if edge is within the same region. - # We do not count non-local edges. - if same_region: - incoming += 1 - if incoming: - visit_dict[node] = incoming - # If a Node has no dependencies, add it to the queue. - else: - queue.put(node) - - while not queue.empty(): - new_node = queue.get() - yield new_node - - for neigh in self.output_neighbours(new_node): - visit_dict[neigh] -= 1 - if visit_dict[neigh] == 0: - del visit_dict[neigh] - queue.put(neigh) - - # If our dict is non-empty here then our graph contains a cycle - if visit_dict: - err = "Graph contains a cycle. No topological ordering exists." - raise ValueError(err) - def links(self) -> Iterator[tuple[OutPort, InPort]]: """Iterator over all the links in the HUGR. @@ -547,12 +482,6 @@ def add_order_link(self, src: ToNode, dst: ToNode) -> None: """ source = src.out(-1) target = dst.inp(-1) - assert ( - self.port_kind(source) == OrderKind() - ), f"Operation {self[src].op.name()} does not support order edges" - assert ( - self.port_kind(target) == OrderKind() - ), f"Operation {self[dst].op.name()} does not support order edges" if not self.has_link(source, target): self.add_link(source, target) @@ -598,20 +527,15 @@ def num_ports(self, node: ToNode, direction: Direction) -> int: Not necessarily the number of connected ports - if port `i` is connected, then all ports `0..i` are assumed to exist. - This value includes order ports. - Args: node: Node to query. direction: Direction of ports to count. Examples: - >>> from hugr.std.logic import Not >>> h = Hugr() - >>> n1 = h.add_node(Not) - >>> n2 = h.add_node(Not) - >>> # Passing offset `2` here allocates new ports automatically - >>> h.add_link(n1.out(0), n2.inp(2)) - >>> h.add_order_link(n1, n2) + >>> n1 = h.add_const(val.TRUE) + >>> n2 = h.add_const(val.FALSE) + >>> h.add_link(n1.out(0), n2.inp(2)) # not a valid link! >>> h.num_ports(n1, Direction.OUTGOING) 1 >>> h.num_ports(n2, Direction.INCOMING) @@ -624,17 +548,11 @@ def num_ports(self, node: ToNode, direction: Direction) -> int: ) def num_in_ports(self, node: ToNode) -> int: - """The number of incoming ports of a node. See :meth:`num_ports`. - - This value does not include order ports. - """ + """The number of incoming ports of a node. See :meth:`num_ports`.""" return self[node]._num_inps def num_out_ports(self, node: ToNode) -> int: - """The number of outgoing ports of a node. See :meth:`num_ports`. - - This value cound does not include order ports. - """ + """The number of outgoing ports of a node. 
See :meth:`num_ports`.""" return self[node]._num_outs def _linked_ports( @@ -716,16 +634,9 @@ def _node_links( port = cast("P", node.port(offset, direction)) yield port, list(self._linked_ports(port, links)) - order_port = cast("P", node.port(-1, direction)) - linked_order = list(self._linked_ports(order_port, links)) - if linked_order: - yield order_port, linked_order - def outgoing_links(self, node: ToNode) -> Iterable[tuple[OutPort, list[InPort]]]: """Iterator over outgoing links from a given node. - This number includes order ports. - Args: node: Node to query. @@ -737,17 +648,14 @@ def outgoing_links(self, node: ToNode) -> Iterable[tuple[OutPort, list[InPort]]] >>> df = dfg.Dfg() >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0)) >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1)) - >>> df.hugr.add_order_link(df.input_node, df.output_node) >>> list(df.hugr.outgoing_links(df.input_node)) - [(OutPort(Node(5), 0), [InPort(Node(6), 0), InPort(Node(6), 1)]), (OutPort(Node(5), -1), [InPort(Node(6), -1)])] - """ # noqa: E501 + [(OutPort(Node(5), 0), [InPort(Node(6), 0), InPort(Node(6), 1)])] + """ return self._node_links(node, self._links.fwd) def incoming_links(self, node: ToNode) -> Iterable[tuple[InPort, list[OutPort]]]: """Iterator over incoming links to a given node. - This number includes order ports. - Args: node: Node to query. @@ -759,81 +667,11 @@ def incoming_links(self, node: ToNode) -> Iterable[tuple[InPort, list[OutPort]]] >>> df = dfg.Dfg() >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0)) >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1)) - >>> df.hugr.add_order_link(df.input_node, df.output_node) >>> list(df.hugr.incoming_links(df.output_node)) - [(InPort(Node(6), 0), [OutPort(Node(5), 0)]), (InPort(Node(6), 1), [OutPort(Node(5), 0)]), (InPort(Node(6), -1), [OutPort(Node(5), -1)])] + [(InPort(Node(6), 0), [OutPort(Node(5), 0)]), (InPort(Node(6), 1), [OutPort(Node(5), 0)])] """ # noqa: E501 return self._node_links(node, self._links.bck) - def neighbours( - self, node: ToNode, direction: Direction | None = None - ) -> Iterable[Node]: - """Iterator over the neighbours of a node. - - Args: - node: Node to query. - direction: If given, only return neighbours in that direction. - - Returns: - Iterator of nodes connected to `node`, ordered by direction and port - offset. Nodes connected via multiple links will be returned multiple times. - - Examples: - >>> df = dfg.Dfg() - >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0)) - >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1)) - >>> list(df.hugr.neighbours(df.input_node)) - [Node(6), Node(6)] - >>> list(df.hugr.neighbours(df.output_node, Direction.OUTGOING)) - [] - """ - if direction is None or direction == Direction.INCOMING: - for _, linked_outputs in self.incoming_links(node): - for out_port in linked_outputs: - yield out_port.node - if direction is None or direction == Direction.OUTGOING: - for _, linked_inputs in self.outgoing_links(node): - for in_port in linked_inputs: - yield in_port.node - - def input_neighbours(self, node: ToNode) -> Iterable[Node]: - """Iterator over the input neighbours of a node. - - Args: - node: Node to query. - - Returns: - Iterator of nodes connected to `node` via incoming links. - Nodes connected via multiple links will be returned multiple times. 
- - Examples: - >>> df = dfg.Dfg() - >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0)) - >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1)) - >>> list(df.hugr.input_neighbours(df.output_node)) - [Node(5), Node(5)] - """ - return self.neighbours(node, Direction.INCOMING) - - def output_neighbours(self, node: ToNode) -> Iterable[Node]: - """Iterator over the output neighbours of a node. - - Args: - node: Node to query. - - Returns: - Iterator of nodes connected to `node` via outgoing links. - Nodes connected via multiple links will be returned multiple times. - - Examples: - >>> df = dfg.Dfg() - >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0)) - >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1)) - >>> list(df.hugr.output_neighbours(df.input_node)) - [Node(6), Node(6)] - """ - return self.neighbours(node, Direction.OUTGOING) - def num_incoming(self, node: Node) -> int: """The number of incoming links to a `node`. @@ -843,7 +681,7 @@ def num_incoming(self, node: Node) -> int: >>> df.hugr.num_incoming(df.output_node) 1 """ - return sum(len(links) for (_, links) in self.incoming_links(node)) + return sum(1 for _ in self.incoming_links(node)) def num_outgoing(self, node: ToNode) -> int: """The number of outgoing links from a `node`. @@ -854,7 +692,7 @@ def num_outgoing(self, node: ToNode) -> int: >>> df.hugr.num_outgoing(df.input_node) 1 """ - return sum(len(links) for (_, links) in self.outgoing_links(node)) + return sum(1 for _ in self.outgoing_links(node)) # TODO: num_links and _linked_ports @@ -939,9 +777,7 @@ def _to_serial(self) -> SerialHugr: def _serialize_link( link: tuple[_SO, _SI], - ) -> tuple[ - tuple[NodeIdx, PortOffset | None], tuple[NodeIdx, PortOffset | None] - ]: + ) -> tuple[tuple[NodeIdx, PortOffset], tuple[NodeIdx, PortOffset]]: src, dst = link s, d = self._constrain_offset(src.port), self._constrain_offset(dst.port) return (src.port.node.idx, s), (dst.port.node.idx, d) @@ -968,16 +804,16 @@ def _serialize_link( entrypoint=entrypoint, ) - def _constrain_offset(self, p: P) -> PortOffset | None: - """Constrain an offset to be a valid encoded port offset. - - Order edges and control flow edges should be encoded without an offset. - """ + def _constrain_offset(self, p: P) -> PortOffset: + # An offset of -1 is a special case, indicating an order edge, + # not counted in the number of ports. 
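# --- Editorial sketch (not part of the patch): the two port-offset encodings ---
# An order edge lives on the special offset -1. The newer scheme (removed above)
# serializes it as None; the reverted scheme encodes it as
# num_ports(node, direction), i.e. one past the last value port.
# `num_value_ports` stands in for that count and is illustrative only.

def encode_offset_newer(offset: int) -> int | None:
    return None if offset < 0 else offset

def encode_offset_reverted(offset: int, num_value_ports: int) -> int:
    return num_value_ports if offset < 0 else offset

# An order edge out of a node with a single value output port:
assert encode_offset_newer(-1) is None
assert encode_offset_reverted(-1, num_value_ports=1) == 1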
if p.offset < 0: assert p.offset == -1, "Only order edges are allowed with offset < 0" - return None + offset = self.num_ports(p.node, p.direction) else: - return p.offset + offset = p.offset + + return offset def resolve_extensions(self, registry: ext.ExtensionRegistry) -> Hugr: """Resolve extension types and operations in the HUGR by matching them to @@ -997,7 +833,7 @@ def _connect_df_entrypoint_outputs(self) -> None: """ from hugr.build import Function - if not is_dataflow_op(self.entrypoint_op()): + if not isinstance(self.entrypoint_op(), DataflowOp): return func_node = self[self.entrypoint].parent @@ -1042,8 +878,9 @@ def get_meta(idx: int) -> dict[str, Any]: parent: Node | None = Node(serial_node.root.parent) serial_node.root.parent = -1 - op = serial_node.root.deserialize() - n = hugr._add_node(op, parent, metadata=node_meta, num_outs=op.num_out) + n = hugr._add_node( + serial_node.root.deserialize(), parent, metadata=node_meta + ) assert ( n.idx == idx + boilerplate_nodes ), "Nodes should be added contiguously" @@ -1052,21 +889,11 @@ def get_meta(idx: int) -> dict[str, Any]: hugr.entrypoint = n for (src_node, src_offset), (dst_node, dst_offset) in serial.edges: - src = Node(src_node, _metadata=get_meta(src_node)) - dst = Node(dst_node, _metadata=get_meta(dst_node)) if src_offset is None or dst_offset is None: - src_op = hugr[src].op - if isinstance(src_op, DataflowBlock | ExitBlock): - # Control flow edge - src_offset = 0 - dst_offset = 0 - else: - # Order edge - hugr.add_order_link(src, dst) - continue + continue hugr.add_link( - src.out(src_offset), - dst.inp(dst_offset), + Node(src_node, _metadata=get_meta(src_node)).out(src_offset), + Node(dst_node, _metadata=get_meta(dst_node)).inp(dst_offset), ) return hugr diff --git a/hugr-py/src/hugr/hugr/render.py b/hugr-py/src/hugr/hugr/render.py index ce72b7cb09..059ed2c03f 100644 --- a/hugr-py/src/hugr/hugr/render.py +++ b/hugr-py/src/hugr/hugr/render.py @@ -1,6 +1,5 @@ """Visualise HUGR using graphviz.""" -import html from collections.abc import Iterable from dataclasses import dataclass, field @@ -102,9 +101,7 @@ def render(self, hugr: Hugr) -> Digraph: "margin": "0", "bgcolor": self.config.palette.background, } - if name := hugr[hugr.module_root].metadata.get("name", None): - name = html.escape(str(name)) - else: + if not (name := hugr[hugr.module_root].metadata.get("name", None)): name = "" graph = gv.Digraph(name, strict=False) @@ -218,8 +215,7 @@ def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None: meta = hugr[node].metadata if len(meta) > 0: data = "
<br/><br/>" + "<br/>
".join( - html.escape(key) + ": " + html.escape(str(value)) - for key, value in meta.items() + f"{key}: {value}" for key, value in meta.items() ) else: data = "" @@ -240,7 +236,6 @@ def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None: op_name = op.op_def().name else: op_name = op.name() - op_name = html.escape(op_name) label_config = { "node_back_color": self.config.palette.node, @@ -291,7 +286,7 @@ def _viz_link( label = "" match kind: case ValueKind(ty): - label = html.escape(str(ty)) + label = str(ty) color = self.config.palette.edge case OrderKind(): color = self.config.palette.dark diff --git a/hugr-py/src/hugr/model/__init__.py b/hugr-py/src/hugr/model/__init__.py index ac97177cd2..d7688e0112 100644 --- a/hugr-py/src/hugr/model/__init__.py +++ b/hugr-py/src/hugr/model/__init__.py @@ -5,20 +5,7 @@ from enum import Enum from typing import Protocol -from semver import Version - import hugr._hugr as rust -from hugr.tys import Visibility - - -def _current_version() -> Version: - """Get the current version of the HUGR model.""" - (major, minor, patch) = rust.current_model_version() - return Version(major=major, minor=minor, patch=patch) - - -# The current version of the HUGR model. -CURRENT_VERSION: Version = _current_version() class Term(Protocol): @@ -114,7 +101,6 @@ class Symbol: """A named symbol.""" name: str - visibility: Visibility params: Sequence[Param] = field(default_factory=list) constraints: Sequence[Term] = field(default_factory=list) signature: Term = field(default_factory=Wildcard) @@ -303,8 +289,3 @@ def from_str(s: str) -> "Package": def from_bytes(b: bytes) -> "Package": """Read a package from its binary representation.""" return rust.bytes_to_package(b) - - @property - def version(self) -> Version: - """Returns the model version used to encode this package.""" - return CURRENT_VERSION diff --git a/hugr-py/src/hugr/model/export.py b/hugr-py/src/hugr/model/export.py index 8bef32e452..d93713d2a9 100644 --- a/hugr-py/src/hugr/model/export.py +++ b/hugr-py/src/hugr/model/export.py @@ -29,15 +29,7 @@ Tag, TailLoop, ) -from hugr.tys import ( - ConstKind, - FunctionKind, - Type, - TypeBound, - TypeParam, - TypeTypeParam, - Visibility, -) +from hugr.tys import ConstKind, FunctionKind, Type, TypeBound, TypeParam, TypeTypeParam class ModelExport: @@ -47,7 +39,8 @@ def __init__(self, hugr: Hugr): self.hugr = hugr self.link_ports: _UnionFind[InPort | OutPort] = _UnionFind() self.link_names: dict[InPort | OutPort, str] = {} - self.link_next = 0 + + # TODO: Store the hugr entrypoint for a, b in self.hugr.links(): self.link_ports.union(a, b) @@ -59,26 +52,20 @@ def link_name(self, port: InPort | OutPort) -> str: if root in self.link_names: return self.link_names[root] else: - index = str(self.link_next) - self.link_next += 1 + index = str(len(self.link_names)) self.link_names[root] = index return index - def export_node( - self, node: Node, virtual_input_links: Sequence[str] = [] - ) -> model.Node | None: + def export_node(self, node: Node) -> model.Node | None: """Export the node with the given node id.""" node_data = self.hugr[node] inputs = [self.link_name(InPort(node, i)) for i in range(node_data._num_inps)] - inputs = [*inputs, *virtual_input_links] - outputs = [self.link_name(OutPort(node, i)) for i in range(node_data._num_outs)] meta = self.export_json_meta(node) - meta += self.export_entrypoint_meta(node) # Add an order hint key to the node if necessary - if _has_order_links(self.hugr, node): + if _needs_order_key(self.hugr, node): 
meta.append(model.Apply("core.order_hint.key", [model.Literal(node.idx)])) match node_data.op: @@ -124,8 +111,7 @@ def export_node( case Conditional() as op: regions = [ - self.export_region_dfg(child, entrypoint_meta=True) - for child in node_data.children + self.export_region_dfg(child) for child in node_data.children ] signature = op.outer_signature().to_model() @@ -152,45 +138,30 @@ def export_node( ) case FuncDefn() as op: - name = _mangle_name(node, op.f_name, op.visibility) + name = _mangle_name(node, op.f_name) symbol = self.export_symbol( - name, op.visibility, op.signature.params, op.signature.body + name, op.signature.params, op.signature.body ) region = self.export_region_dfg(node) - if op.visibility == "Private": - meta.append(model.Apply("core.title", [model.Literal(op.f_name)])) - return model.Node( operation=model.DefineFunc(symbol), regions=[region], meta=meta ) case FuncDecl() as op: - name = _mangle_name(node, op.f_name, op.visibility) + name = _mangle_name(node, op.f_name) symbol = self.export_symbol( - name, op.visibility, op.signature.params, op.signature.body + name, op.signature.params, op.signature.body ) - - if op.visibility == "Private": - meta.append(model.Apply("core.title", [model.Literal(op.f_name)])) - return model.Node(operation=model.DeclareFunc(symbol), meta=meta) case AliasDecl() as op: - symbol = model.Symbol( - name=op.alias, - visibility="Public", - signature=model.Apply("core.type"), - ) + symbol = model.Symbol(name=op.alias, signature=model.Apply("core.type")) return model.Node(operation=model.DeclareAlias(symbol), meta=meta) case AliasDefn() as op: - symbol = model.Symbol( - name=op.alias, - visibility="Public", - signature=model.Apply("core.type"), - ) + symbol = model.Symbol(name=op.alias, signature=model.Apply("core.type")) alias_value = cast(model.Term, op.definition.to_model()) @@ -211,11 +182,6 @@ def export_node( error = f"Call node {node} is not connected to a function." raise ValueError(error) - # We ignore the static input edge since the function is passed - # as an argument instead. 
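# --- Editorial example (not part of the patch) ---
# The lines removed just below dropped a Call node's trailing static link (the
# edge to the callee), since the function is referenced through the `core.call`
# term instead. Link names here are illustrative only.
inputs = ["0", "1", "2"]  # two value links plus the static function link
input_types = ["bool", "bool"]
assert len(inputs) == len(input_types) + 1
inputs = inputs[0 : len(inputs) - 1]
assert inputs == ["0", "1"]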
- assert len(inputs) == len(input_types) + 1 - inputs = inputs[0 : len(inputs) - 1] - func = model.Apply(func_name, func_args) return model.Node( @@ -236,8 +202,7 @@ def export_node( ) case LoadFunc() as op: - signature = op.outer_signature().to_model() - instantiation = op.instantiation.to_model() + signature = op.instantiation.to_model() func_args = cast( list[model.Term], [type.to_model() for type in op.type_args] ) @@ -251,10 +216,10 @@ def export_node( return model.Node( operation=model.CustomOp( - model.Apply("core.load_const", [instantiation, func]) + model.Apply("core.load_const", [signature, func]) ), signature=signature, - inputs=[], + inputs=inputs, outputs=outputs, meta=meta, ) @@ -307,7 +272,7 @@ def export_node( model.Apply("core.load_const", [type, value]) ), signature=signature, - inputs=[], + inputs=inputs, outputs=outputs, meta=meta, ) @@ -331,21 +296,31 @@ def export_node( case DataflowBlock() as op: region = self.export_region_dfg(node) - input_types = [model.List([type.to_model() for type in op.inputs])] + input_types = [ + model.Apply( + "core.ctrl", + [model.List([type.to_model() for type in op.inputs])], + ) + ] other_output_types = [type.to_model() for type in op.other_outputs] output_types = [ - model.List( + model.Apply( + "core.ctrl", [ - *[type.to_model() for type in row], - *other_output_types, - ] + model.List( + [ + *[type.to_model() for type in row], + *other_output_types, + ] + ) + ], ) for row in op.sum_ty.variant_rows ] signature = model.Apply( - "core.ctrl", + "core.fn", [model.List(input_types), model.List(output_types)], ) @@ -404,18 +379,10 @@ def export_json_meta(self, node: Node) -> list[model.Term]: return meta - def export_entrypoint_meta(self, node: Node) -> list[model.Term]: - """Export entrypoint metadata if the node is the module entrypoint.""" - if self.hugr.entrypoint == node: - return [model.Apply("core.entrypoint")] - else: - return [] - def export_region_module(self, node: Node) -> model.Region: """Export a module node as a module region.""" node_data = self.hugr[node] meta = self.export_json_meta(node) - meta += self.export_entrypoint_meta(node) children = [] for child in node_data.children: @@ -426,7 +393,7 @@ def export_region_module(self, node: Node) -> model.Region: return model.Region(kind=model.RegionKind.MODULE, children=children, meta=meta) - def export_region_dfg(self, node: Node, entrypoint_meta=False) -> model.Region: + def export_region_dfg(self, node: Node) -> model.Region: """Export the children of a node as a dataflow region.""" node_data = self.hugr[node] children: list[model.Node] = [] @@ -436,9 +403,6 @@ def export_region_dfg(self, node: Node, entrypoint_meta=False) -> model.Region: targets = [] meta = [] - if entrypoint_meta: - meta += self.export_entrypoint_meta(node) - for child in node_data.children: child_data = self.hugr[child] @@ -450,13 +414,6 @@ def export_region_dfg(self, node: Node, entrypoint_meta=False) -> model.Region: for i in range(child_data._num_outs) ] - if _has_order_links(self.hugr, child): - meta.append( - model.Apply( - "core.order_hint.input_key", [model.Literal(child.idx)] - ) - ) - case Output() as op: target_types = model.List([type.to_model() for type in op.types]) targets = [ @@ -464,13 +421,6 @@ def export_region_dfg(self, node: Node, entrypoint_meta=False) -> model.Region: for i in range(child_data._num_inps) ] - if _has_order_links(self.hugr, child): - meta.append( - model.Apply( - "core.order_hint.output_key", [model.Literal(child.idx)] - ) - ) - case _: child_node = 
self.export_node(child) @@ -479,13 +429,14 @@ def export_region_dfg(self, node: Node, entrypoint_meta=False) -> model.Region: children.append(child_node) - meta += [ - model.Apply( - "core.order_hint.order", - [model.Literal(child.idx), model.Literal(successor.idx)], - ) - for successor in self.hugr.outgoing_order_links(child) - ] + meta += [ + model.Apply( + "core.order_hint.order", + [model.Literal(child.idx), model.Literal(successor.idx)], + ) + for successor in self.hugr.outgoing_order_links(child) + if not isinstance(self.hugr[successor].op, Output) + ] signature = model.Apply("core.fn", [source_types, target_types]) @@ -507,7 +458,6 @@ def export_region_cfg(self, node: Node) -> model.Region: source_types: model.Term = model.Wildcard() target_types: model.Term = model.Wildcard() children = [] - meta = self.export_entrypoint_meta(node) for child in node_data.children: child_data = self.hugr[child] @@ -526,14 +476,9 @@ def export_region_cfg(self, node: Node) -> model.Region: source_types = model.List( [type.to_model() for type in op.inputs] ) - source = str(self.link_next) - self.link_next += 1 + source = self.link_name(OutPort(child, 0)) - child_node = self.export_node( - child, virtual_input_links=[source] - ) - else: - child_node = self.export_node(child) + child_node = self.export_node(child) if child_node is not None: children.append(child_node) @@ -545,13 +490,7 @@ def export_region_cfg(self, node: Node) -> model.Region: error = f"CFG {node} has no entry block." raise ValueError(error) - signature = model.Apply( - "core.ctrl", - [ - model.List([source_types]), - model.List([target_types]), - ], - ) + signature = model.Apply("core.fn", [source_types, target_types]) return model.Region( kind=model.RegionKind.CONTROL_FLOW, @@ -559,15 +498,10 @@ def export_region_cfg(self, node: Node) -> model.Region: sources=[source], signature=signature, children=children, - meta=meta, ) def export_symbol( - self, - name: str, - visibility: Visibility, - param_types: Sequence[TypeParam], - body: Type, + self, name: str, param_types: Sequence[TypeParam], body: Type ) -> model.Symbol: """Export a symbol.""" constraints = [] @@ -588,14 +522,13 @@ def export_symbol( return model.Symbol( name=name, - visibility=visibility, params=params, constraints=constraints, signature=cast(model.Term, body.to_model()), ) def find_func_input(self, node: Node) -> str | None: - """Find the symbol name of the function that a node is connected to, if any.""" + """Find the name of the function that a node is connected to, if any.""" try: func_node = next( out_port.node @@ -609,14 +542,12 @@ def find_func_input(self, node: Node) -> str | None: match self.hugr[func_node].op: case FuncDecl() as func_op: name = func_op.f_name - visibility = func_op.visibility case FuncDefn() as func_op: name = func_op.f_name - visibility = func_op.visibility case _: return None - return _mangle_name(func_node, name, visibility) + return _mangle_name(func_node, name) def find_const_input(self, node: Node) -> model.Term | None: """Find and export the constant that a node is connected to, if any.""" @@ -637,17 +568,10 @@ def find_const_input(self, node: Node) -> model.Term | None: return None -def _mangle_name(node: Node, name: str, visibility: Visibility) -> str: - match visibility: - case "Private": - # Until we come to an agreement on the uniqueness of names, - # we mangle the names by replacing id with the node id. 
- return f"_{node.idx}" - case "Public": - return name - case _: - error = f"Unexpected visibility {visibility}" - raise ValueError(error) +def _mangle_name(node: Node, name: str) -> str: + # Until we come to an agreement on the uniqueness of names, we mangle the names + # by adding the node id. + return f"_{name}_{node.idx}" T = TypeVar("T") @@ -686,12 +610,19 @@ def union(self, a: T, b: T): self.sizes[a] += self.sizes[b] -def _has_order_links(hugr: Hugr, node: Node) -> bool: - """Checks whether the node has any order links.""" - for _succ in hugr.outgoing_order_links(node): - return True - - for _pred in hugr.incoming_order_links(node): - return True +def _needs_order_key(hugr: Hugr, node: Node) -> bool: + """Checks whether the node has any order links for the purposes of + exporting order hint metadata. Order links to `Input` or `Output` + operations are ignored, since they are not present in the model format. + """ + for succ in hugr.outgoing_order_links(node): + succ_op = hugr[succ].op + if not isinstance(succ_op, Output): + return True + + for pred in hugr.incoming_order_links(node): + pred_op = hugr[pred].op + if not isinstance(pred_op, Input): + return True return False diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index 23ef810d84..fdcdd89082 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -19,14 +19,13 @@ import hugr._serialization.ops as sops from hugr import tys, val from hugr.hugr.node_port import Direction, InPort, Node, OutPort, PortOffset, Wire -from hugr.utils import comma_sep_repr, comma_sep_str, ser_it +from hugr.utils import comma_sep_str, ser_it if TYPE_CHECKING: from collections.abc import Sequence from hugr import ext from hugr._serialization.ops import BaseOp - from hugr.tys import Visibility @dataclass @@ -132,12 +131,11 @@ def port_type(self, port: InPort | OutPort) -> tys.Type: Bool """ + sig = self.outer_signature() if port.offset == -1: # Order port msg = "Order port has no type." 
raise ValueError(msg) - - sig = self.outer_signature() try: if port.direction == Direction.INCOMING: return sig.input[port.offset] @@ -242,12 +240,6 @@ def _to_serial(self, parent: Node) -> sops.Input: def _inputs(self) -> tys.TypeRow: return [] - def port_kind(self, port: InPort | OutPort) -> tys.Kind: - # Input only allows order edges on outgoing ports - if port.offset == -1 and port.direction == Direction.OUTGOING: - return tys.OrderKind() - return tys.ValueKind(self.port_type(port)) - def outer_signature(self) -> tys.FunctionType: return tys.FunctionType(input=[], output=self.types) @@ -265,10 +257,7 @@ class Output(DataflowOp, _PartialOp): """ _types: tys.TypeRow | None = field(default=None, repr=False) - - @property - def num_out(self) -> int: - return 0 + num_out: int = field(default=0, repr=False) @property def types(self) -> tys.TypeRow: @@ -280,12 +269,6 @@ def _to_serial(self, parent: Node) -> sops.Output: def _inputs(self) -> tys.TypeRow: return self.types - def port_kind(self, port: InPort | OutPort) -> tys.Kind: - # Output only allows order edges on incoming ports - if port.offset == -1 and port.direction == Direction.INCOMING: - return tys.OrderKind() - return tys.ValueKind(self.port_type(port)) - def outer_signature(self) -> tys.FunctionType: return tys.FunctionType(input=self.types, output=[]) @@ -551,7 +534,7 @@ def cached_signature(self) -> tys.FunctionType | None: ) def type_args(self) -> list[tys.TypeArg]: - return [tys.ListArg([t.type_arg() for t in self.types])] + return [tys.SequenceArg([t.type_arg() for t in self.types])] def __call__(self, *elements: ComWire) -> Command: return super().__call__(*elements) @@ -593,7 +576,7 @@ def cached_signature(self) -> tys.FunctionType | None: ) def type_args(self) -> list[tys.TypeArg]: - return [tys.ListArg([t.type_arg() for t in self.types])] + return [tys.SequenceArg([t.type_arg() for t in self.types])] @property def num_out(self) -> int: @@ -621,7 +604,7 @@ def name(self) -> str: return "UnpackTuple" -@dataclass(frozen=True) +@dataclass() class Tag(DataflowOp): """Tag a row of incoming values to make them a variant of a sum type. @@ -649,67 +632,63 @@ def outer_signature(self) -> tys.FunctionType: ) def __repr__(self) -> str: - if len(self.sum_ty.variant_rows) == 2: - left, right = self.sum_ty.variant_rows - if len(left) == 0 and self.tag == 1: - return f"Some({comma_sep_repr(right)})" - elif self.tag == 0: - return f"Left({left!r}, {right!r})" - else: - return f"Right({left!r}, {right!r})" - return f"Tag(tag={self.tag}, sum_ty={self.sum_ty!r})" - - def __str__(self) -> str: - if len(self.sum_ty.variant_rows) == 2: - left, right = self.sum_ty.variant_rows - if len(left) == 0 and self.tag == 1: - return "Some" - elif self.tag == 0: - return "Left" - else: - return "Right" return f"Tag({self.tag})" -@dataclass(frozen=True, eq=False, repr=False) +@dataclass class Some(Tag): """Tag operation for the `Some` variant of an Option type. 
Example: # construct a Some variant holding a row of Bool and Unit types >>> Some(tys.Bool, tys.Unit) - Some(Bool, Unit) + Some """ def __init__(self, *some_tys: tys.Type) -> None: super().__init__(1, tys.Option(*some_tys)) + def __repr__(self) -> str: + return "Some" + -@dataclass(frozen=True, eq=False, repr=False) +@dataclass class Right(Tag): """Tag operation for the `Right` variant of an type.""" def __init__(self, either_type: tys.Either) -> None: super().__init__(1, either_type) + def __repr__(self) -> str: + return "Right" + -@dataclass(frozen=True, eq=False, repr=False) +@dataclass class Left(Tag): """Tag operation for the `Left` variant of an type.""" def __init__(self, either_type: tys.Either) -> None: super().__init__(0, either_type) + def __repr__(self) -> str: + return "Left" + class Continue(Left): """Tag operation for the `Continue` variant of a TailLoop controlling Either type. """ + def __repr__(self) -> str: + return "Continue" + class Break(Right): """Tag operation for the `Break` variant of a TailLoop controlling Either type.""" + def __repr__(self) -> str: + return "Break" + class DfParentOp(Op, Protocol): """Abstract parent of dataflow graph operations. Can be queried for the @@ -840,7 +819,7 @@ def _inputs(self) -> tys.TypeRow: @dataclass class DataflowBlock(DfParentOp): - """Parent of non-exit basic block in a control flow graph.""" + """Parent of non-entry basic block in a control flow graph.""" #: Inputs types of the inner dataflow graph. inputs: tys.TypeRow @@ -1171,8 +1150,6 @@ class FuncDefn(DfParentOp): params: list[tys.TypeParam] = field(default_factory=list) _outputs: tys.TypeRow | None = field(default=None, repr=False) num_out: int = field(default=1, repr=False) - #: Visibility (for linking). - visibility: Visibility = "Private" @property def outputs(self) -> tys.TypeRow: @@ -1199,7 +1176,6 @@ def _to_serial(self, parent: Node) -> sops.FuncDefn: parent=parent.idx, name=self.f_name, signature=self.signature._to_serial(), - visibility=self.visibility, ) def inner_signature(self) -> tys.FunctionType: @@ -1231,15 +1207,12 @@ class FuncDecl(Op): #: polymorphic function signature signature: tys.PolyFuncType num_out: int = field(default=1, repr=False) - #: Visibility (for linking). - visibility: Visibility = "Public" def _to_serial(self, parent: Node) -> sops.FuncDecl: return sops.FuncDecl( parent=parent.idx, name=self.f_name, signature=self.signature._to_serial(), - visibility=self.visibility, ) def port_kind(self, port: InPort | OutPort) -> tys.Kind: @@ -1418,9 +1391,7 @@ class LoadFunc(_CallOrLoad, DataflowOp): is provided. 
""" - @property - def num_out(self) -> int: - return 1 + num_out: int = field(default=1, repr=False) def _to_serial(self, parent: Node) -> sops.LoadFunction: return sops.LoadFunction( diff --git a/hugr-py/src/hugr/std/_json_defs/collections/borrow_arr.json b/hugr-py/src/hugr/std/_json_defs/collections/borrow_arr.json deleted file mode 100644 index 1774b4aea6..0000000000 --- a/hugr-py/src/hugr/std/_json_defs/collections/borrow_arr.json +++ /dev/null @@ -1,1139 +0,0 @@ -{ - "version": "0.1.1", - "name": "collections.borrow_arr", - "types": { - "borrow_array": { - "extension": "collections.borrow_arr", - "name": "borrow_array", - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "description": "Fixed-length borrow array", - "bound": { - "b": "Explicit", - "bound": "A" - } - } - }, - "operations": { - "borrow": { - "extension": "collections.borrow_arr", - "name": "borrow", - "description": "Take an element from a borrow array (panicking if it was already taken before)", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - }, - { - "t": "I" - } - ], - "output": [ - { - "t": "V", - "i": 1, - "b": "A" - }, - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "clone": { - "extension": "collections.borrow_arr", - "name": "clone", - "description": "Clones an array with copyable elements", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "C" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "C" - } - } - ], - "bound": "A" - } - ], - "output": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "C" - } - } - ], - "bound": "A" - }, - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "C" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "discard": { - "extension": "collections.borrow_arr", - "name": "discard", - "description": "Discards an array with copyable elements", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "C" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": 
"BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "C" - } - } - ], - "bound": "A" - } - ], - "output": [] - } - }, - "binary": false - }, - "discard_all_borrowed": { - "extension": "collections.borrow_arr", - "name": "discard_all_borrowed", - "description": "Discard a borrow array where all elements have been borrowed", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ], - "output": [] - } - }, - "binary": false - }, - "discard_empty": { - "extension": "collections.borrow_arr", - "name": "discard_empty", - "description": "Discard an empty array", - "signature": { - "params": [ - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "BoundedNat", - "n": 0 - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 0, - "b": "A" - } - } - ], - "bound": "A" - } - ], - "output": [] - } - }, - "binary": false - }, - "from_array": { - "extension": "collections.borrow_arr", - "name": "from_array", - "description": "Turns `array` into `borrow_array`", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.array", - "id": "array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ], - "output": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "get": { - "extension": "collections.borrow_arr", - "name": "get", - "description": "Get an element from an array", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "C" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "C" - } - } - ], - "bound": "A" - }, - { - "t": "I" - } - ], - "output": [ - { - "t": "Sum", - "s": "General", - "rows": [ - [], - [ - { - "t": "V", - "i": 1, - "b": "C" - } - ] - ] - }, - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "C" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "new_all_borrowed": { - "extension": "collections.borrow_arr", - "name": "new_all_borrowed", - "description": "Create a new borrow array that contains no elements", - "signature": { - "params": [ - { - "tp": 
"BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [], - "output": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "new_array": { - "extension": "collections.borrow_arr", - "name": "new_array", - "description": "Create a new array from elements", - "signature": null, - "binary": true - }, - "pop_left": { - "extension": "collections.borrow_arr", - "name": "pop_left", - "description": "Pop an element from the left of an array", - "signature": null, - "binary": true - }, - "pop_right": { - "extension": "collections.borrow_arr", - "name": "pop_right", - "description": "Pop an element from the right of an array", - "signature": null, - "binary": true - }, - "repeat": { - "extension": "collections.borrow_arr", - "name": "repeat", - "description": "Creates a new array whose elements are initialised by calling the given function n times", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "G", - "input": [], - "output": [ - { - "t": "V", - "i": 1, - "b": "A" - } - ] - } - ], - "output": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "return": { - "extension": "collections.borrow_arr", - "name": "return", - "description": "Put an element into a borrow array (panicking if there is an element already)", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - }, - { - "t": "I" - }, - { - "t": "V", - "i": 1, - "b": "A" - } - ], - "output": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "scan": { - "extension": "collections.borrow_arr", - "name": "scan", - "description": "A combination of map and foldl. Applies a function to each element of the array with an accumulator that is passed through from start to finish. 
Returns the resulting array and the final state of the accumulator.", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - }, - { - "tp": "Type", - "b": "A" - }, - { - "tp": "List", - "param": { - "tp": "Type", - "b": "A" - } - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - }, - { - "t": "G", - "input": [ - { - "t": "V", - "i": 1, - "b": "A" - }, - { - "t": "R", - "i": 3, - "b": "A" - } - ], - "output": [ - { - "t": "V", - "i": 2, - "b": "A" - }, - { - "t": "R", - "i": 3, - "b": "A" - } - ] - }, - { - "t": "R", - "i": 3, - "b": "A" - } - ], - "output": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 2, - "b": "A" - } - } - ], - "bound": "A" - }, - { - "t": "R", - "i": 3, - "b": "A" - } - ] - } - }, - "binary": false - }, - "set": { - "extension": "collections.borrow_arr", - "name": "set", - "description": "Set an element in an array", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - }, - { - "t": "I" - }, - { - "t": "V", - "i": 1, - "b": "A" - } - ], - "output": [ - { - "t": "Sum", - "s": "General", - "rows": [ - [ - { - "t": "V", - "i": 1, - "b": "A" - }, - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ], - [ - { - "t": "V", - "i": 1, - "b": "A" - }, - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - ] - } - ] - } - }, - "binary": false - }, - "swap": { - "extension": "collections.borrow_arr", - "name": "swap", - "description": "Swap two elements in an array", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - }, - { - "t": "I" - }, - { - "t": "I" - } - ], - "output": [ - { - "t": "Sum", - "s": "General", - "rows": [ - [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - 
"bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ], - [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - ] - } - ] - } - }, - "binary": false - }, - "to_array": { - "extension": "collections.borrow_arr", - "name": "to_array", - "description": "Turns `borrow_array` into `array`", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ], - "output": [ - { - "t": "Opaque", - "extension": "collections.array", - "id": "array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "unpack": { - "extension": "collections.borrow_arr", - "name": "unpack", - "description": "Unpack an array into its elements", - "signature": null, - "binary": true - } - } -} diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index 81c2f948a0..7cf1d02c70 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -1,5 +1,5 @@ { - "version": "0.2.1", + "version": "0.2.0", "name": "prelude", "types": { "error": { @@ -77,38 +77,6 @@ }, "binary": false }, - "MakeError": { - "extension": "prelude", - "name": "MakeError", - "description": "Create an error value", - "signature": { - "params": [], - "body": { - "input": [ - { - "t": "I" - }, - { - "t": "Opaque", - "extension": "prelude", - "id": "string", - "args": [], - "bound": "C" - } - ], - "output": [ - { - "t": "Opaque", - "extension": "prelude", - "id": "error", - "args": [], - "bound": "C" - } - ] - } - }, - "binary": false - }, "MakeTuple": { "extension": "prelude", "name": "MakeTuple", diff --git a/hugr-py/src/hugr/std/collections/array.py b/hugr-py/src/hugr/std/collections/array.py index d7f70a2318..958b826502 100644 --- a/hugr-py/src/hugr/std/collections/array.py +++ b/hugr-py/src/hugr/std/collections/array.py @@ -54,7 +54,7 @@ def size(self) -> int | None: return None def type_bound(self) -> tys.TypeBound: - return tys.TypeBound.Linear + return tys.TypeBound.Any @dataclass diff --git a/hugr-py/src/hugr/std/collections/borrow_array.py b/hugr-py/src/hugr/std/collections/borrow_array.py deleted file mode 100644 index 9d01f86e6a..0000000000 --- a/hugr-py/src/hugr/std/collections/borrow_array.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Borrow array types and operations.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import cast - -import hugr.model as model -from hugr import tys, val -from hugr.std import _load_extension -from hugr.utils import comma_sep_str - -EXTENSION = _load_extension("collections.borrow_arr") - - -@dataclass(eq=False) -class BorrowArray(tys.ExtType): - """Fixed `size` borrow array of `ty` elements.""" - - def 
__init__(self, ty: tys.Type, size: int | tys.TypeArg) -> None: - if isinstance(size, int): - size = tys.BoundedNatArg(size) - - err_msg = ( - f"Borrow array size must be a bounded natural or a nat variable, not {size}" - ) - match size: - case tys.BoundedNatArg(_n): - pass - case tys.VariableArg(_idx, param): - if not isinstance(param, tys.BoundedNatParam): - raise ValueError(err_msg) # noqa: TRY004 - case _: - raise ValueError(err_msg) - - ty_arg = tys.TypeTypeArg(ty) - - self.type_def = EXTENSION.types["borrow_array"] - self.args = [size, ty_arg] - - @property - def ty(self) -> tys.Type: - assert isinstance( - self.args[1], tys.TypeTypeArg - ), "Borrow array elements must have a valid type" - return self.args[1].ty - - @property - def size(self) -> int | None: - """If the borrow array has a concrete size, return it. - - Otherwise, return None. - """ - if isinstance(self.args[0], tys.BoundedNatArg): - return self.args[0].n - return None - - def type_bound(self) -> tys.TypeBound: - return tys.TypeBound.Linear - - -# Note that only borrow array values with no elements borrowed should be emitted. -@dataclass -class BorrowArrayVal(val.ExtensionValue): - """Constant value for a statically sized borrow array of elements.""" - - v: list[val.Value] - ty: BorrowArray - - def __init__(self, v: list[val.Value], elem_ty: tys.Type) -> None: - self.v = v - self.ty = BorrowArray(elem_ty, len(v)) - - def to_value(self) -> val.Extension: - name = "BorrowArrayValue" - # The value list must be serialized at this point, otherwise the - # `Extension` value would not be serializable. - vs = [v._to_serial_root() for v in self.v] - element_ty = self.ty.ty._to_serial_root() - serial_val = {"values": vs, "typ": element_ty} - return val.Extension(name, typ=self.ty, val=serial_val) - - def __str__(self) -> str: - return f"borrow_array({comma_sep_str(self.v)})" - - def to_model(self) -> model.Term: - return model.Apply( - "collections.borrow_array.const", - [ - model.Literal(len(self.v)), - cast(model.Term, self.ty.ty.to_model()), - model.List([value.to_model() for value in self.v]), - ], - ) diff --git a/hugr-py/src/hugr/std/collections/static_array.py b/hugr-py/src/hugr/std/collections/static_array.py index 84731ee1bf..60975b336d 100644 --- a/hugr-py/src/hugr/std/collections/static_array.py +++ b/hugr-py/src/hugr/std/collections/static_array.py @@ -50,9 +50,6 @@ def __init__(self, v: list[val.Value], elem_ty: tys.Type, name: str) -> None: self.name = name def to_value(self) -> val.Extension: - # Encode the nested values as JSON strings directly, to mirror what - # happens when loading (where we can't decode the constant payload back - # into specialized `Value`s). serial_val = { "value": { "values": [v._to_serial_root() for v in self.v], diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 4107fcd284..f58bc9e3eb 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -52,39 +52,16 @@ def _int_tv(index: int) -> tys.ExtType: INT_T = int_t(5) -def _to_unsigned(val: int, bits: int) -> int: - """Convert a signed integer to its unsigned representation - in twos-complement form. - - Positive integers are unchanged, while negative integers - are converted by adding 2^bits to the value. - - Raises ValueError if the value is out of range for the given bit width - (valid range is [-2^(bits-1), 2^(bits-1)-1]). 
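# --- Editorial example (not part of the patch): two's-complement encoding ---
# A minimal version of the `_to_unsigned` helper removed here (range checking
# omitted): negatives are shifted into [0, 2^bits) by adding 2^bits. With this
# revert, IntVal serializes `self.v` directly instead of the unsigned form.
def to_unsigned(val: int, bits: int) -> int:
    return (1 << bits) + val if val < 0 else val

assert to_unsigned(-1, 8) == 255
assert to_unsigned(-128, 8) == 128
assert to_unsigned(5, 8) == 5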
- """ - half_max = 1 << (bits - 1) - min_val = -half_max - max_val = half_max - 1 - if val < min_val or val > max_val: - msg = f"Value {val} out of range for {bits}-bit signed integer." - raise ValueError(msg) # - - if val < 0: - return (1 << bits) + val - return val - - @dataclass class IntVal(val.ExtensionValue): - """Custom value for a signed integer.""" + """Custom value for an integer.""" v: int width: int = field(default=5) def to_value(self) -> val.Extension: name = "ConstInt" - unsigned = _to_unsigned(self.v, 1 << self.width) - payload = {"log_width": self.width, "value": unsigned} + payload = {"log_width": self.width, "value": self.v} return val.Extension( name, typ=int_t(self.width), @@ -95,9 +72,8 @@ def __str__(self) -> str: return f"{self.v}" def to_model(self) -> model.Term: - unsigned = _to_unsigned(self.v, 1 << self.width) return model.Apply( - "arithmetic.int.const", [model.Literal(self.width), model.Literal(unsigned)] + "arithmetic.int.const", [model.Literal(self.width), model.Literal(self.v)] ) diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index a59a9a90bf..8411f19bfa 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -2,13 +2,12 @@ from __future__ import annotations -import base64 from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Literal, Protocol, cast, runtime_checkable +from typing import TYPE_CHECKING, Protocol, cast, runtime_checkable import hugr._serialization.tys as stys import hugr.model as model -from hugr.utils import comma_sep_repr, comma_sep_str, comma_sep_str_paren, ser_it +from hugr.utils import comma_sep_repr, comma_sep_str, ser_it if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -19,7 +18,6 @@ ExtensionId = stys.ExtensionId ExtensionSet = stys.ExtensionSet TypeBound = stys.TypeBound -Visibility = Literal["Public", "Private"] @runtime_checkable @@ -69,7 +67,7 @@ def type_bound(self) -> stys.TypeBound: >>> Tuple(Bool, Bool).type_bound() >>> Tuple(Qubit, Bool).type_bound() - + """ ... 
# pragma: no cover @@ -156,34 +154,6 @@ def to_model(self) -> model.Term: return model.Apply("core.str") -@dataclass(frozen=True) -class FloatParam(TypeParam): - """Float type parameter.""" - - def _to_serial(self) -> stys.FloatParam: - return stys.FloatParam() - - def __str__(self) -> str: - return "Float" - - def to_model(self) -> model.Term: - return model.Apply("core.float") - - -@dataclass(frozen=True) -class BytesParam(TypeParam): - """Bytes type parameter.""" - - def _to_serial(self) -> stys.BytesParam: - return stys.BytesParam() - - def __str__(self) -> str: - return "Bytes" - - def to_model(self) -> model.Term: - return model.Apply("core.bytes") - - @dataclass(frozen=True) class ListParam(TypeParam): """Type parameter which requires a list of type arguments.""" @@ -275,118 +245,24 @@ def to_model(self) -> model.Term: @dataclass(frozen=True) -class FloatArg(TypeArg): - """A floating point type argument.""" - - value: float - - def _to_serial(self) -> stys.FloatArg: - return stys.FloatArg(value=self.value) - - def __str__(self) -> str: - return f"{self.value}" - - def to_model(self) -> model.Term: - return model.Literal(self.value) - - -@dataclass(frozen=True) -class BytesArg(TypeArg): - """A bytes type argument.""" - - value: bytes - - def _to_serial(self) -> stys.BytesArg: - value = base64.b64encode(self.value).decode() - return stys.BytesArg(value=value) - - def __str__(self) -> str: - return "bytes" - - def to_model(self) -> model.Term: - return model.Literal(self.value) - - -@dataclass(frozen=True) -class ListArg(TypeArg): - """Sequence of type arguments for a :class:`ListParam`.""" +class SequenceArg(TypeArg): + """Sequence of type arguments, for a :class:`ListParam` or :class:`TupleParam`.""" elems: list[TypeArg] - def _to_serial(self) -> stys.ListArg: - return stys.ListArg(elems=ser_it(self.elems)) - - def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: - return ListArg([arg.resolve(registry) for arg in self.elems]) - - def __str__(self) -> str: - return f"[{comma_sep_str(self.elems)}]" - - def to_model(self) -> model.Term: - return model.List([elem.to_model() for elem in self.elems]) - - -@dataclass(frozen=True) -class ListConcatArg(TypeArg): - """Sequence of lists to concatenate for a :class:`ListParam`.""" - - lists: list[TypeArg] - - def _to_serial(self) -> stys.ListConcatArg: - return stys.ListConcatArg(lists=ser_it(self.lists)) + def _to_serial(self) -> stys.SequenceArg: + return stys.SequenceArg(elems=ser_it(self.elems)) def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: - return ListConcatArg([arg.resolve(registry) for arg in self.lists]) - - def __str__(self) -> str: - lists = comma_sep_str(f"... 
{list}" for list in self.lists) - return f"[{lists}]" - - def to_model(self) -> model.Term: - return model.List( - [model.Splice(cast(model.Term, elem.to_model())) for elem in self.lists] - ) - - -@dataclass(frozen=True) -class TupleArg(TypeArg): - """Sequence of type arguments for a :class:`TupleParam`.""" - - elems: list[TypeArg] - - def _to_serial(self) -> stys.TupleArg: - return stys.TupleArg(elems=ser_it(self.elems)) - - def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: - return TupleArg([arg.resolve(registry) for arg in self.elems]) + return SequenceArg([arg.resolve(registry) for arg in self.elems]) def __str__(self) -> str: return f"({comma_sep_str(self.elems)})" def to_model(self) -> model.Term: - return model.Tuple([elem.to_model() for elem in self.elems]) - - -@dataclass(frozen=True) -class TupleConcatArg(TypeArg): - """Sequence of tuples to concatenate for a :class:`TupleParam`.""" - - tuples: list[TypeArg] - - def _to_serial(self) -> stys.TupleConcatArg: - return stys.TupleConcatArg(tuples=ser_it(self.tuples)) - - def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: - return TupleConcatArg([arg.resolve(registry) for arg in self.tuples]) - - def __str__(self) -> str: - tuples = comma_sep_str(f"... {tuple}" for tuple in self.tuples) - return f"({tuples})" - - def to_model(self) -> model.Term: - return model.Tuple( - [model.Splice(cast(model.Term, elem.to_model())) for elem in self.tuples] - ) + # TODO: We should separate lists and tuples. + # For now we assume that this is a list. + return model.List([elem.to_model() for elem in self.elems]) @dataclass(frozen=True) @@ -430,38 +306,7 @@ def as_tuple(self) -> Tuple: return Tuple(*self.variant_rows[0]) def __repr__(self) -> str: - if self == Bool: - return "Bool" - elif self == Unit: - return "Unit" - elif all(len(row) == 0 for row in self.variant_rows): - return f"UnitSum({len(self.variant_rows)})" - elif len(self.variant_rows) == 1: - return f"Tuple{tuple(self.variant_rows[0])}" - elif len(self.variant_rows) == 2 and len(self.variant_rows[0]) == 0: - return f"Option({comma_sep_repr(self.variant_rows[1])})" - elif len(self.variant_rows) == 2: - left, right = self.variant_rows - return f"Either(left={left}, right={right})" - else: - return f"Sum({self.variant_rows})" - - def __str__(self) -> str: - if self == Bool: - return "Bool" - elif self == Unit: - return "Unit" - elif all(len(row) == 0 for row in self.variant_rows): - return f"UnitSum({len(self.variant_rows)})" - elif len(self.variant_rows) == 1: - return f"Tuple{tuple(self.variant_rows[0])}" - elif len(self.variant_rows) == 2 and len(self.variant_rows[0]) == 0: - return f"Option({comma_sep_str(self.variant_rows[1])})" - elif len(self.variant_rows) == 2: - left, right = self.variant_rows - return f"Either({comma_sep_str_paren(left)}, {comma_sep_str_paren(right)})" - else: - return f"Sum({self.variant_rows})" + return f"Sum({self.variant_rows})" def __eq__(self, other: object) -> bool: return isinstance(other, Sum) and self.variant_rows == other.variant_rows @@ -480,7 +325,7 @@ def to_model(self) -> model.Term: return model.Apply("core.adt", [variants]) -@dataclass(eq=False, repr=False) +@dataclass(eq=False) class UnitSum(Sum): """Simple :class:`Sum` type with `size` variants of empty rows.""" @@ -493,14 +338,18 @@ def __init__(self, size: int): def _to_serial(self) -> stys.UnitSum: # type: ignore[override] return stys.UnitSum(size=self.size) + def __repr__(self) -> str: + if self == Bool: + return "Bool" + elif self == Unit: + return "Unit" + return 
f"UnitSum({self.size})" + def resolve(self, registry: ext.ExtensionRegistry) -> UnitSum: return self - def __str__(self) -> str: - return self.__repr__() - -@dataclass(eq=False, repr=False) +@dataclass(eq=False) class Tuple(Sum): """Product type with `tys` elements. Instances of this type correspond to :class:`Sum` with a single variant. @@ -509,8 +358,11 @@ class Tuple(Sum): def __init__(self, *tys: Type): self.variant_rows = [list(tys)] + def __repr__(self) -> str: + return f"Tuple{tuple(self.variant_rows[0])}" + -@dataclass(eq=False, repr=False) +@dataclass(eq=False) class Option(Sum): """Optional tuple of elements. @@ -521,8 +373,11 @@ class Option(Sum): def __init__(self, *tys: Type): self.variant_rows = [[], list(tys)] + def __repr__(self) -> str: + return f"Option({comma_sep_repr(self.variant_rows[1])})" + -@dataclass(eq=False, repr=False) +@dataclass(eq=False) class Either(Sum): """Two-variant tuple of elements. @@ -535,6 +390,16 @@ class Either(Sum): def __init__(self, left: Iterable[Type], right: Iterable[Type]): self.variant_rows = [list(left), list(right)] + def __repr__(self) -> str: # pragma: no cover + left, right = self.variant_rows + return f"Either(left={left}, right={right})" + + def __str__(self) -> str: + left, right = self.variant_rows + left_str = left[0] if len(left) == 1 else tuple(left) + right_str = right[0] if len(right) == 1 else tuple(right) + return f"Either({left_str}, {right_str})" + @dataclass(frozen=True) class Variable(Type): @@ -767,7 +632,15 @@ def __eq__(self, value): return super().__eq__(value) def to_model(self) -> model.Term: - return self._to_opaque().to_model() + # This cast is only neccessary because `Type` can both be an + # actual type or a row variable. + args = [cast(model.Term, arg.to_model()) for arg in self.args] + + extension_name = self.type_def.get_extension().name + type_name = self.type_def.name + name = f"{extension_name}.{type_name}" + + return model.Apply(name, args) def _type_str(name: str, args: Sequence[TypeArg]) -> str: @@ -814,17 +687,17 @@ def __str__(self) -> str: return _type_str(self.id, self.args) def to_model(self) -> model.Term: - # This cast is only necessary because `Type` can both be an + # This cast is only neccessary because `Type` can both be an # actual type or a row variable. args = [cast(model.Term, arg.to_model()) for arg in self.args] - return model.Apply(f"{self.extension}.{self.id}", args) + return model.Apply(self.id, args) @dataclass class _QubitDef(Type): def type_bound(self) -> TypeBound: - return TypeBound.Linear + return TypeBound.Any def _to_serial(self) -> stys.Qubit: return stys.Qubit() diff --git a/hugr-py/src/hugr/utils.py b/hugr-py/src/hugr/utils.py index 0c6048ec32..480f3337b9 100644 --- a/hugr-py/src/hugr/utils.py +++ b/hugr-py/src/hugr/utils.py @@ -215,27 +215,3 @@ def comma_sep_str(items: Iterable[T]) -> str: def comma_sep_repr(items: Iterable[T]) -> str: """Join items with commas and repr.""" return ", ".join(map(repr, items)) - - -def comma_sep_str_paren(items: Iterable[T]) -> str: - """Join items with commas and str, wrapping them in parentheses if more than one.""" - items = list(items) - if len(items) == 0: - return "()" - elif len(items) == 1: - return f"{items[0]}" - else: - return f"({comma_sep_str(items)})" - - -def comma_sep_repr_paren(items: Iterable[T]) -> str: - """Join items with commas and repr, wrapping them in parentheses if more - than one. 
- """ - items = list(items) - if len(items) == 0: - return "()" - elif len(items) == 1: - return f"{items[0]}" - else: - return f"({comma_sep_repr(items)})" diff --git a/hugr-py/src/hugr/val.py b/hugr-py/src/hugr/val.py index a929969edd..925c91f989 100644 --- a/hugr-py/src/hugr/val.py +++ b/hugr-py/src/hugr/val.py @@ -45,8 +45,8 @@ class Sum(Value): """Sum-of-product value. Example: - >>> Sum(0, tys.Sum([[tys.Bool], [tys.Unit], [tys.Bool]]), [TRUE]) - Sum(tag=0, typ=Sum([[Bool], [Unit], [Bool]]), vals=[TRUE]) + >>> Sum(0, tys.Sum([[tys.Bool], [tys.Unit]]), [TRUE]) + Sum(tag=0, typ=Sum([[Bool], [Unit]]), vals=[TRUE]) """ #: Tag identifying the variant. @@ -70,59 +70,6 @@ def _to_serial(self) -> sops.SumValue: vs=ser_it(self.vals), ) - def __repr__(self) -> str: - if self == TRUE: - return "TRUE" - elif self == FALSE: - return "FALSE" - elif self == Unit: - return "Unit" - elif all(len(row) == 0 for row in self.typ.variant_rows): - return f"UnitSum({self.tag}, {self.n_variants})" - elif len(self.typ.variant_rows) == 1: - return f"Tuple({comma_sep_repr(self.vals)})" - elif len(self.typ.variant_rows) == 2 and len(self.typ.variant_rows[0]) == 0: - # Option - if self.tag == 0: - return f"None({comma_sep_str(self.typ.variant_rows[1])})" - else: - return f"Some({comma_sep_repr(self.vals)})" - elif len(self.typ.variant_rows) == 2: - # Either - left_typ, right_typ = self.typ.variant_rows - if self.tag == 0: - return f"Left(vals={self.vals}, right_typ={list(right_typ)})" - else: - return f"Right(left_typ={list(left_typ)}, vals={self.vals})" - else: - return f"Sum(tag={self.tag}, typ={self.typ}, vals={self.vals})" - - def __str__(self) -> str: - if self == TRUE: - return "TRUE" - elif self == FALSE: - return "FALSE" - elif self == Unit: - return "Unit" - elif all(len(row) == 0 for row in self.typ.variant_rows): - return f"UnitSum({self.tag}, {self.n_variants})" - elif len(self.typ.variant_rows) == 1: - return f"Tuple({comma_sep_str(self.vals)})" - elif len(self.typ.variant_rows) == 2 and len(self.typ.variant_rows[0]) == 0: - # Option - if self.tag == 0: - return "None" - else: - return f"Some({comma_sep_str(self.vals)})" - elif len(self.typ.variant_rows) == 2: - # Either - if self.tag == 0: - return f"Left({comma_sep_str(self.vals)})" - else: - return f"Right({comma_sep_str(self.vals)})" - else: - return f"Sum({self.tag}, {self.typ}, {self.vals})" - def __eq__(self, other: object) -> bool: return ( isinstance(other, Sum) @@ -153,7 +100,6 @@ def to_model(self) -> model.Term: ) -@dataclass(eq=False, repr=False) class UnitSum(Sum): """Simple :class:`Sum` with each variant being an empty row. @@ -173,6 +119,15 @@ def __init__(self, tag: int, size: int): vals=[], ) + def __repr__(self) -> str: + if self == TRUE: + return "TRUE" + if self == FALSE: + return "FALSE" + if self == Unit: + return "Unit" + return f"UnitSum({self.tag}, {self.n_variants})" + def bool_value(b: bool) -> UnitSum: """Convert a python bool to a HUGR boolean value. @@ -194,7 +149,7 @@ def bool_value(b: bool) -> UnitSum: FALSE = bool_value(False) -@dataclass(eq=False, repr=False) +@dataclass(eq=False) class Tuple(Sum): """Tuple or product value, defined by a list of values. Internally a :class:`Sum` with a single variant row. 
@@ -214,18 +169,18 @@ def __init__(self, *vals: Value): tag=0, typ=tys.Tuple(*(v.type_() for v in val_list)), vals=val_list ) - def _to_serial(self) -> sops.SumValue: - return sops.SumValue( - tag=0, - typ=stys.SumType(root=self.type_()._to_serial()), + # sops.TupleValue isn't an instance of sops.SumValue + # so mypy doesn't like the override of Sum._to_serial + def _to_serial(self) -> sops.TupleValue: # type: ignore[override] + return sops.TupleValue( vs=ser_it(self.vals), ) def __repr__(self) -> str: - return super().__repr__() + return f"Tuple({comma_sep_repr(self.vals)})" -@dataclass(eq=False, repr=False) +@dataclass(eq=False) class Some(Sum): """Optional tuple of value, containing a list of values. @@ -244,8 +199,11 @@ def __init__(self, *vals: Value): tag=1, typ=tys.Option(*(v.type_() for v in val_list)), vals=val_list ) + def __repr__(self) -> str: + return f"Some({comma_sep_repr(self.vals)})" + -@dataclass(eq=False, repr=False) +@dataclass(eq=False) class None_(Sum): """Optional tuple of value, containing no values. @@ -261,8 +219,14 @@ class None_(Sum): def __init__(self, *types: tys.Type): super().__init__(tag=0, typ=tys.Option(*types), vals=[]) + def __repr__(self) -> str: + return f"None({comma_sep_str(self.typ.variant_rows[1])})" + + def __str__(self) -> str: + return "None" + -@dataclass(eq=False, repr=False) +@dataclass(eq=False) class Left(Sum): """Left variant of a :class:`tys.Either` type, containing a list of values. @@ -284,8 +248,15 @@ def __init__(self, vals: Iterable[Value], right_typ: Iterable[tys.Type]): vals=val_list, ) + def __repr__(self) -> str: + _, right_typ = self.typ.variant_rows + return f"Left(vals={self.vals}, right_typ={list(right_typ)})" + + def __str__(self) -> str: + return f"Left({comma_sep_str(self.vals)})" -@dataclass(eq=False, repr=False) + +@dataclass(eq=False) class Right(Sum): """Right variant of a :class:`tys.Either` type, containing a list of values. @@ -309,6 +280,13 @@ def __init__(self, left_typ: Iterable[tys.Type], vals: Iterable[Value]): vals=val_list, ) + def __repr__(self) -> str: + left_typ, _ = self.typ.variant_rows + return f"Right(left_typ={list(left_typ)}, vals={self.vals})" + + def __str__(self) -> str: + return f"Right({comma_sep_str(self.vals)})" + @dataclass class Function(Value): @@ -320,7 +298,9 @@ def type_(self) -> tys.FunctionType: return self.body.entrypoint_op().inner_signature() def _to_serial(self) -> sops.FunctionValue: - return sops.FunctionValue(hugr=self.body.to_str()) + return sops.FunctionValue( + hugr=self.body._to_serial(), + ) def to_model(self) -> model.Term: module = self.body.to_model() diff --git a/hugr-py/tests/__snapshots__/test_hugr_build.ambr b/hugr-py/tests/__snapshots__/test_hugr_build.ambr index 240953c02b..733ba214c5 100644 --- a/hugr-py/tests/__snapshots__/test_hugr_build.ambr +++ b/hugr-py/tests/__snapshots__/test_hugr_build.ambr @@ -191,16 +191,6 @@ - - - - - - -
0
- - - > shape=plain] color="#1CADE4" label="" margin=10 penwidth=1 @@ -490,16 +480,6 @@ - - - - - - -
0
- - - > shape=plain] color="#1CADE4" label="" margin=10 penwidth=1 @@ -814,16 +794,6 @@ - - - - - - -
0
- - - > shape=plain] color="#1CADE4" label="" margin=10 penwidth=1 @@ -855,129 +825,6 @@ ''' # --- -# name: test_fndef_output_ports - ''' - digraph { - bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1 - subgraph cluster0 { - subgraph cluster1 { - 2 [label=< - - - - - - -
-            [node-label tables elided: 2 Input · 3 Output (in 0–3) · 4 MakeTuple (out 0) · 1 FuncDefn(main) (out 0) · 0 [Module]]
- > shape=plain] - color="#F4A261" label="" margin=10 penwidth=2 - } - 4:"out.0" -> 3:"in.0" [label=Unit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 4:"out.0" -> 3:"in.1" [label=Unit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 4:"out.0" -> 3:"in.2" [label=Unit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 4:"out.0" -> 3:"in.3" [label=Unit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - } - - ''' -# --- # name: test_higher_order ''' digraph { @@ -1093,7 +940,7 @@ + COLOR="black">Const(Function(body=Hugr(module_root=Node(0), entrypoint=Node(4), _nodes=[NodeData(op=Module(), parent=None, metadata={}), NodeData(op=FuncDefn(f_name='main', inputs=[Qubit], params=[]), parent=Node(0), metadata={}), NodeData(op=Input(types=[Qubit]), parent=Node(1), metadata={}), NodeData(op=Output(), parent=Node(1), metadata={}), NodeData(op=DFG(inputs=[Qubit]), parent=Node(1), metadata={}), NodeData(op=Input(types=[Qubit]), parent=Node(4), metadata={}), NodeData(op=Output(), parent=Node(4), metadata={}), NodeData(op=Noop(Qubit), parent=Node(4), metadata={})], _links=BiMap({_SubPort(port=OutPort(Node(2), 0), sub_offset=0): _SubPort(port=InPort(Node(4), 0), sub_offset=0), _SubPort(port=OutPort(Node(5), 0), sub_offset=0): _SubPort(port=InPort(Node(7), 0), sub_offset=0), _SubPort(port=OutPort(Node(7), 0), sub_offset=0): _SubPort(port=InPort(Node(6), 0), sub_offset=0), _SubPort(port=OutPort(Node(4), 0), sub_offset=0): _SubPort(port=InPort(Node(3), 0), sub_offset=0)}), _free_nodes=[])))
Const(Function(body=Hugr(module_root=Node(0), entrypoint=Node(4), _nodes=[NodeData(op=Module(), parent=None, metadata={}), NodeData(op=FuncDefn(f_name='main', inputs=[Qubit], params=[], visibility='Private'), parent=Node(0), metadata={}), NodeData(op=Input(types=[Qubit]), parent=Node(1), metadata={}), NodeData(op=Output(), parent=Node(1), metadata={}), NodeData(op=DFG(inputs=[Qubit]), parent=Node(1), metadata={}), NodeData(op=Input(types=[Qubit]), parent=Node(4), metadata={}), NodeData(op=Output(), parent=Node(4), metadata={}), NodeData(op=Noop(Qubit), parent=Node(4), metadata={})], _links=BiMap({_SubPort(port=OutPort(Node(2), 0), sub_offset=0): _SubPort(port=InPort(Node(4), 0), sub_offset=0), _SubPort(port=OutPort(Node(5), 0), sub_offset=0): _SubPort(port=InPort(Node(7), 0), sub_offset=0), _SubPort(port=OutPort(Node(7), 0), sub_offset=0): _SubPort(port=InPort(Node(6), 0), sub_offset=0), _SubPort(port=OutPort(Node(4), 0), sub_offset=0): _SubPort(port=InPort(Node(3), 0), sub_offset=0)}), _free_nodes=[])))
@@ -1253,7 +1100,7 @@ } 2:"out.0" -> 4:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] 7:"out.0" -> 8:"in.0" [label="" arrowhead=none arrowsize=1.0 color="#77CEEF" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 8:"out.0" -> 9:"in.0" [label="Qubit -> Qubit" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 8:"out.0" -> 9:"in.0" [label="Qubit -> Qubit" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] 5:"out.0" -> 9:"in.1" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] 5:"out.-1" -> 8:"in.-1" [label="" arrowhead=none arrowsize=1.0 color=black fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] 9:"out.0" -> 6:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] @@ -1262,147 +1109,6 @@ ''' # --- -# name: test_html_labels - ''' - digraph "<i>Module Root</i>" { - bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1 - subgraph cluster0 { - subgraph cluster1 { - 2 [label=< - - - - - - - - - - -
-            [node-label tables elided: 2 Input (out 0) · 3 Output (in 0) · 4 Some (in 0, out 0) · 1 [FuncDefn(<jupyter-notebook>)] with metadata "label: <b>Bold Label</b>", "<other-label>: <i>Italic Label</i>", "meta_can_be_anything: [42, 'string', 3.14, True]" · 0 Module with metadata "name: <i>Module Root</i>"]
- > shape=plain] - color="#1CADE4" label="" margin=10 penwidth=1 - } - 2:"out.0" -> 4:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 2:"out.0" -> 3:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - } - - ''' -# --- # name: test_insert_nested ''' digraph { @@ -1683,16 +1389,6 @@ - - - - - - -
0
- - - > shape=plain] color="#1CADE4" label="" margin=10 penwidth=1 @@ -1916,16 +1612,6 @@ - - - - - - -
0
- - - > shape=plain] color="#1CADE4" label="" margin=10 penwidth=1 @@ -2147,16 +1833,6 @@ - - - - - - -
0
- - - > shape=plain] color="#1CADE4" label="" margin=10 penwidth=1 @@ -2178,14 +1854,14 @@ > shape=plain] color="#1CADE4" label="" margin=10 penwidth=1 } - 2:"out.0" -> 4:"in.0" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 2:"out.1" -> 4:"in.1" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 5:"out.0" -> 7:"in.0" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 5:"out.1" -> 7:"in.1" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 7:"out.0" -> 6:"in.0" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 7:"out.1" -> 6:"in.1" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 4:"out.0" -> 3:"in.0" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 4:"out.1" -> 3:"in.1" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 2:"out.0" -> 4:"in.0" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 2:"out.1" -> 4:"in.1" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 5:"out.0" -> 7:"in.0" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 5:"out.1" -> 7:"in.1" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 7:"out.0" -> 6:"in.0" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 7:"out.1" -> 6:"in.1" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 4:"out.0" -> 3:"in.0" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 4:"out.1" -> 3:"in.1" [label="int<5>" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] } ''' @@ -2347,16 +2023,6 @@ - - - - - - -
0
- - - > shape=plain] color="#1CADE4" label="" margin=10 penwidth=1 @@ -2686,16 +2352,6 @@ - - - - - - -
0
- - - > shape=plain] color="#1CADE4" label="" margin=10 penwidth=1 @@ -2954,16 +2610,6 @@ - - - - - - -
0
- - - > shape=plain] color="#1CADE4" label="" margin=10 penwidth=1 diff --git a/hugr-py/tests/__snapshots__/test_order_edges.ambr b/hugr-py/tests/__snapshots__/test_order_edges.ambr deleted file mode 100644 index 3dae0b5530..0000000000 --- a/hugr-py/tests/__snapshots__/test_order_edges.ambr +++ /dev/null @@ -1,258 +0,0 @@ -# serializer version: 1 -# name: test_order_unconnected - ''' - digraph { - bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1 - subgraph cluster0 { - subgraph cluster1 { - 2 [label=< - - - - - - - - - - -
-            [node-label tables elided: 2 Input · 3 Output · 5 Input · 6 Output · 7 MeasureFree · 8 QAlloc · 4 [DFG] · 1 FuncDefn(main) · 0 Module]
- > shape=plain] - color="#1CADE4" label="" margin=10 penwidth=1 - } - 2:"out.0" -> 4:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 5:"out.0" -> 7:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 7:"out.-1" -> 8:"in.-1" [label="" arrowhead=none arrowsize=1.0 color=black fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 8:"out.0" -> 6:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - 4:"out.0" -> 3:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] - } - - ''' -# --- diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 909d0a8bfd..f4b2617a01 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -12,16 +12,13 @@ from hugr import ext, tys from hugr.envelope import EnvelopeConfig from hugr.hugr import Hugr -from hugr.ops import AsExtOp, Command, Const, Custom, DataflowOp, ExtOp, RegisteredOp +from hugr.ops import AsExtOp, Command, DataflowOp, ExtOp, RegisteredOp from hugr.package import Package from hugr.std.float import FLOAT_T if TYPE_CHECKING: - import typing - from syrupy.assertion import SnapshotAssertion - from hugr.hugr.node_port import Node from hugr.ops import ComWire QUANTUM_EXT = ext.Extension("pytest.quantum", ext.Version(0, 1, 0)) @@ -109,32 +106,6 @@ def __call__(self, q: ComWire) -> Command: Measure = MeasureDef() -@QUANTUM_EXT.register_op( - "MeasureFree", - signature=tys.FunctionType([tys.Qubit], [tys.Bool]), -) -@dataclass(frozen=True) -class MeasureFreeDef(RegisteredOp): - def __call__(self, q: ComWire) -> Command: - return super().__call__(q) - - -MeasureFree = MeasureFreeDef() - - -@QUANTUM_EXT.register_op( - "QAlloc", - signature=tys.FunctionType([], [tys.Qubit]), -) -@dataclass(frozen=True) -class QAllocDef(RegisteredOp): - def __call__(self) -> Command: - return super().__call__() - - -QAlloc = QAllocDef() - - @QUANTUM_EXT.register_op( "Rz", signature=tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit]), @@ -175,156 +146,42 @@ def validate( snapshot: A hugr render snapshot. If not None, it will be compared against the rendered HUGR. Pass `--snapshot-update` to pytest to update the snapshot file. """ - if snap is not None: - dot = h.render_dot() if isinstance(h, Hugr) else h.modules[0].render_dot() - assert snap == dot.source - if os.environ.get("HUGR_RENDER_DOT"): - dot.pipe("svg") - - # Encoding formats to test, indexed by the format name as used by - # `hugr convert --format`. - FORMATS = { - "json": EnvelopeConfig.TEXT, - "model-exts": EnvelopeConfig.BINARY, - } - # Envelope formats used when exporting test hugrs. - WRITE_FORMATS = ["json", "model-exts"] - # Envelope formats used as target for `hugr convert` before loading back the - # test hugrs. - # - # Model envelopes cannot currently be loaded from python. - # TODO: Add model envelope loading to python, and add it to the list. 
- LOAD_FORMATS = ["json"] - + # TODO: Use envelopes instead of legacy hugr-json cmd = [*_base_command(), "validate", "-"] - # validate text and binary formats - for write_fmt in WRITE_FORMATS: - serial = h.to_bytes(FORMATS[write_fmt]) - _run_hugr_cmd(serial, cmd) - - if roundtrip: - # Roundtrip tests: - # Try converting to all possible LOAD_FORMATS, load them back in, - # and check that the loaded HUGR corresponds to the original using - # a node hash comparison. - # - # Run `pytest` with `-vv` to see the hash diff. - for load_fmt in LOAD_FORMATS: - if load_fmt != write_fmt: - cmd = [*_base_command(), "convert", "--format", load_fmt, "-"] - out = _run_hugr_cmd(serial, cmd) - converted_serial = out.stdout - else: - converted_serial = serial - loaded = Package.from_bytes(converted_serial) - - modules = [h] if isinstance(h, Hugr) else h.modules - - assert len(loaded.modules) == len(modules) - for m1, m2 in zip(loaded.modules, modules, strict=True): - h1_hash = _NodeHash.hash_hugr(m1, "original") - h2_hash = _NodeHash.hash_hugr(m2, "loaded") - assert ( - h1_hash == h2_hash - ), f"HUGRs are not the same for {write_fmt} -> {load_fmt}" - - # Lowering functions are currently ignored in Python, - # because we don't support loading -model envelopes yet. - for ext in loaded.extensions: - for op in ext.operations.values(): - assert op.lower_funcs == [] - - -@dataclass(frozen=True, order=True) -class _NodeHash: - op: _OpHash - entrypoint: bool - input_neighbours: int - output_neighbours: int - input_ports: int - output_ports: int - input_order_edges: int - output_order_edges: int - is_region: bool - node_depth: int - children_hashes: list[_NodeHash] - metadata: dict[str, str] - - @classmethod - def hash_hugr(cls, h: Hugr, name: str) -> _NodeHash: - """Returns an order-independent hash of a HUGR.""" - return cls._hash_node(h, h.module_root, 0, name) - - @classmethod - def _hash_node(cls, h: Hugr, n: Node, depth: int, name: str) -> _NodeHash: - children = h.children(n) - child_hashes = sorted(cls._hash_node(h, c, depth + 1, name) for c in children) - metadata = {k: str(v) for k, v in h[n].metadata.items()} - - # Pick a normalized representation of the op name. - op_type = h[n].op - if isinstance(op_type, AsExtOp): - op_type = op_type.ext_op.to_custom_op() - op = _OpHash(f"{op_type.extension}.{op_type.op_name}") - elif isinstance(op_type, Custom): - op = _OpHash(f"{op_type.extension}.{op_type.op_name}") - elif isinstance(op_type, Const): - # We need every custom value to have the same repr if they compare - # equal. For example, an `IntVal(42)` should be the same as the - # equivalent `Extension` value. This needs a lot of extra - # unwrapping, since each class implements different `__repr__` - # methods. - # - # Our solution here is to encode the value into JSON and compare those. - # This may miss some errors, but it's the best we can do for now. Note that - # roundtripping via `sops.Value` is not enough, since nested - # specialized values don't get serialized straight away. (e.g. - # StaticArrayVal's dictionary payload containing a SumValue - # internally, see `test_val_static_array`). 
- value_dict = op_type.val._to_serial_root().model_dump(mode="json") - op = _OpHash("Const", value_dict) - else: - op = _OpHash(op_type.name()) - - return _NodeHash( - entrypoint=n == h.entrypoint, - op=op, - input_neighbours=h.num_incoming(n), - output_neighbours=h.num_outgoing(n), - input_ports=h.num_in_ports(n), - output_ports=h.num_out_ports(n), - input_order_edges=len(list(h.incoming_order_links(n))), - output_order_edges=len(list(h.outgoing_order_links(n))), - is_region=len(children) > 0, - node_depth=depth, - children_hashes=child_hashes, - metadata=metadata, - ) - - -@dataclass(frozen=True) -class _OpHash: - name: str - payload: None | typing.Any = None - - def __lt__(self, other: _OpHash) -> bool: - """Compare two op hashes by name and payload.""" - return (self.name, repr(self.payload)) < (other.name, repr(other.payload)) - - -def _get_mermaid(serial: bytes) -> str: # - """Render a HUGR as a mermaid diagram using the CLI.""" - return _run_hugr_cmd(serial, [*_base_command(), "mermaid", "-"]).stdout.decode() - - -def _run_hugr_cmd(serial: bytes, cmd: list[str]) -> subprocess.CompletedProcess[bytes]: + serial = h.to_bytes(EnvelopeConfig.BINARY) + _run_hugr_cmd(serial, cmd) + + if not roundtrip: + return + + # Roundtrip checks + if isinstance(h, Hugr): + starting_json = h.to_str() + h2 = Hugr.from_str(starting_json) + roundtrip_json = h2.to_str() + assert roundtrip_json == starting_json + + if snap is not None: + dot = h.render_dot() + assert snap == dot.source + if os.environ.get("HUGR_RENDER_DOT"): + dot.pipe("svg") + else: + # Package + encoded = h.to_str(EnvelopeConfig.TEXT) + loaded = Package.from_str(encoded) + roundtrip_encoded = loaded.to_str(EnvelopeConfig.TEXT) + assert encoded == roundtrip_encoded + + +def _run_hugr_cmd(serial: bytes, cmd: list[str]): """Run a HUGR command. The `serial` argument is the serialized HUGR to pass to the command via stdin. 
""" try: - return subprocess.run(cmd, check=True, input=serial, capture_output=True) # noqa: S603 + subprocess.run(cmd, check=True, input=serial, capture_output=True) # noqa: S603 except subprocess.CalledProcessError as e: error = e.stderr.decode() raise RuntimeError(error) from e diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index 3eed38e106..14753e2f44 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -8,7 +8,7 @@ def build_basic_cfg(cfg: Cfg) -> None: with cfg.add_entry() as entry: entry.set_single_succ_outputs(*entry.inputs()) - cfg.branch_exit(entry[0]) + cfg.branch(entry[0], cfg.exit) def test_basic_cfg() -> None: diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index 1bb980915f..48f57de7a7 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -139,7 +139,7 @@ def test_custom_bad_eq(): ext.TypeDef( "List", description="A list of elements.", - params=[tys.TypeTypeParam(tys.TypeBound.Linear)], + params=[tys.TypeTypeParam(tys.TypeBound.Any)], bound=ext.FromParamsBound([0]), ) ) diff --git a/hugr-py/tests/test_envelope.py b/hugr-py/tests/test_envelope.py index c97cd074bc..1e667e8eef 100644 --- a/hugr-py/tests/test_envelope.py +++ b/hugr-py/tests/test_envelope.py @@ -1,17 +1,10 @@ -from pathlib import Path - -import pytest -import semver - -from hugr import ops, tys +from hugr import tys from hugr.build.function import Module from hugr.envelope import EnvelopeConfig, EnvelopeFormat -from hugr.hugr.node_port import Node from hugr.package import Package -@pytest.fixture -def package() -> Package: +def test_envelope(): mod = Module() f_id = mod.define_function("id", [tys.Qubit]) f_id.set_outputs(f_id.input_node[0]) @@ -24,10 +17,8 @@ def package() -> Package: q = f_main.input_node[0] call = f_main.call(f_id_decl, q) f_main.set_outputs(call) - return Package([mod.hugr, mod2.hugr]) + package = Package([mod.hugr, mod2.hugr]) - -def test_envelope(package: Package): # Binary compression roundtrip for format in [EnvelopeFormat.JSON]: for compression in [None, 0]: @@ -36,30 +27,6 @@ def test_envelope(package: Package): assert decoded == package # String roundtrip - encoded_str = package.to_str(EnvelopeConfig.TEXT) - decoded = Package.from_str(encoded_str) + encoded = package.to_str(EnvelopeConfig.TEXT) + decoded = Package.from_str(encoded) assert decoded == package - - -def test_model(package: Package): - model_pkg = package.to_model() - - # This value is statically defined in the rust bindings. 
- assert model_pkg.version >= semver.Version(major=1, minor=0, patch=0) - - -def test_legacy_funcdefn(): - p = Path(__file__).parents[2] / "resources" / "test" / "hugr-no-visibility.hugr" - try: - with p.open("rb") as f: - pkg_bytes = f.read() - except FileNotFoundError: - pytest.skip("Missing test file") - decoded = Package.from_bytes(pkg_bytes) - h = decoded.modules[0] - op1 = h[Node(1)].op - assert isinstance(op1, ops.FuncDecl) - assert op1.visibility == "Public" - op2 = h[Node(2)].op - assert isinstance(op2, ops.FuncDefn) - assert op2.visibility == "Private" diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 0a6677145b..74c8018f9e 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -5,16 +5,15 @@ import hugr.ops as ops import hugr.tys as tys import hugr.val as val -from hugr.build.dfg import Dfg, Function, _ancestral_sibling +from hugr.build.dfg import Dfg, _ancestral_sibling from hugr.build.function import Module from hugr.hugr import Hugr from hugr.hugr.node_port import Node, _SubPort from hugr.ops import NoConcreteFunc -from hugr.package import Package from hugr.std.int import INT_T, DivMod, IntVal from hugr.std.logic import Not -from .conftest import QUANTUM_EXT, H, validate +from .conftest import validate def test_stable_indices(): @@ -197,7 +196,7 @@ def test_build_inter_graph(snapshot): validate(h.hugr, snap=snapshot) assert _SubPort(h.input_node.out(-1)) in h.hugr._links - assert h.hugr.num_outgoing(h.input_node) == 3 + assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order assert len(list(h.hugr.outgoing_order_links(h.input_node))) == 1 assert len(list(h.hugr.incoming_order_links(nested))) == 1 assert len(list(h.hugr.incoming_order_links(h.output_node))) == 0 @@ -215,6 +214,7 @@ def test_ancestral_sibling(): @pytest.mark.parametrize( "val", [ + val.Function(simple_id().hugr), val.Sum(1, tys.Sum([[INT_T], [tys.Bool, INT_T]]), [val.TRUE, IntVal(34)]), val.Tuple(val.TRUE, IntVal(23)), ], @@ -232,8 +232,8 @@ def test_poly_function(direct_call: bool) -> None: f_id = mod.declare_function( "id", tys.PolyFuncType( - [tys.TypeTypeParam(tys.TypeBound.Linear)], - tys.FunctionType.endo([tys.Variable(0, tys.TypeBound.Linear)]), + [tys.TypeTypeParam(tys.TypeBound.Any)], + tys.FunctionType.endo([tys.Variable(0, tys.TypeBound.Any)]), ), ) @@ -259,39 +259,6 @@ def test_poly_function(direct_call: bool) -> None: validate(mod.hugr) -def test_literals() -> None: - mod = Module() - - func = mod.declare_function( - "literals", - tys.PolyFuncType( - [ - tys.StringParam(), - tys.BoundedNatParam(), - tys.BytesParam(), - tys.FloatParam(), - ], - tys.FunctionType.endo([tys.Qubit]), - ), - ) - - caller = mod.define_function("caller", [tys.Qubit], [tys.Qubit]) - call = caller.call( - func, - caller.inputs()[0], - instantiation=tys.FunctionType.endo([tys.Qubit]), - type_args=[ - tys.StringArg("string"), - tys.BoundedNatArg(42), - tys.BytesArg(b"HUGR"), - tys.FloatArg(0.9), - ], - ) - caller.set_outputs(call) - - validate(mod.hugr) - - @pytest.mark.parametrize("direct_call", [True, False]) def test_mono_function(direct_call: bool) -> None: mod = Module() @@ -311,37 +278,6 @@ def test_mono_function(direct_call: bool) -> None: validate(mod.hugr) -def test_static_output() -> None: - mod = Module() - - mod.declare_function( - "declared", - tys.PolyFuncType( - [], - tys.FunctionType.endo([]), - ), - ) - - func = mod.define_function("defined", [], []) - func.declare_outputs([]) - func.set_outputs() - - validate(mod.hugr) - - 
-def test_function_dfg() -> None: - d = Dfg(tys.Qubit) - - f_id = d.module_root_builder().define_function("id", [tys.Qubit]) - f_id.set_outputs(f_id.input_node[0]) - - (q,) = d.inputs() - call = d.call(f_id, q) - d.set_outputs(call) - - validate(d.hugr) - - def test_recursive_function(snapshot) -> None: mod = Module() @@ -363,7 +299,6 @@ def test_invalid_recursive_function() -> None: f_recursive.set_outputs(f_recursive.input_node[0]) -@pytest.mark.skip("Value::Function is deprecated and not supported by model encoding.") def test_higher_order(snapshot) -> None: noop_fn = Dfg(tys.Qubit) noop_fn.set_outputs(noop_fn.add(ops.Noop()(noop_fn.input_node[0]))) @@ -423,76 +358,3 @@ def test_option() -> None: dfg.set_outputs(b) validate(dfg.hugr) - - -# a helper for the toposort tests -@pytest.fixture -def simple_fn() -> Function: - f = Function("prepare_qubit", [tys.Bool, tys.Qubit]) - [b, q] = f.inputs() - - h = f.add_op(H, q) - q = h.out(0) - - nnot = f.add_op(Not, b) - - f.set_outputs(q, nnot, b) - validate(Package([f.hugr], [QUANTUM_EXT])) - return f - - -# https://github.com/CQCL/hugr/issues/2350 -def test_toposort(simple_fn: Function) -> None: - nodes = list(simple_fn.hugr) - func_node = nodes[1] - - sorted_nodes = list(simple_fn.hugr.sorted_region_nodes(func_node)) - assert set(sorted_nodes) == set(simple_fn.hugr.children(simple_fn)) - assert sorted_nodes[0] == simple_fn.input_node - assert sorted_nodes[-1] == simple_fn.output_node - - -def test_toposort_error(simple_fn: Function) -> None: - # Test that we get an error if we toposort an invalid hugr containing a cycle - nodes = list(simple_fn.hugr) - func_node = nodes[1] - - # Add a loop, invalidating the HUGR - simple_fn.hugr.add_link(nodes[4].out_port(), nodes[4].inp(0)) - with pytest.raises( - ValueError, match="Graph contains a cycle. No topological ordering exists." - ): - list(simple_fn.hugr.sorted_region_nodes(func_node)) - - -def test_html_labels(snapshot) -> None: - """Ensures that HTML-like labels can be processed correctly by both the builder and - the renderer. 
- """ - f = Function( - "", - [tys.Bool], - ) - f.metadata["label"] = "Bold Label" - f.metadata[""] = "Italic Label" - f.metadata["meta_can_be_anything"] = [42, "string", 3.14, True] - - f.hugr[f.hugr.module_root].metadata["name"] = "Module Root" - - b = f.inputs()[0] - f.add_op(ops.Some(tys.Bool), b) - f.set_outputs(b) - - validate(f.hugr, snap=snapshot) - - -# https://github.com/CQCL/hugr/issues/2438 -def test_fndef_output_ports(snapshot): - mod = Module() - main = mod.define_function("main", [], [tys.Unit, tys.Unit, tys.Unit, tys.Unit]) - unit = main.add_op(ops.MakeTuple()) - main.set_outputs(*4 * [unit]) - - assert mod.hugr.num_out_ports(main) == 1 - - validate(mod.hugr, snap=snapshot) diff --git a/hugr-py/tests/test_ops.py b/hugr-py/tests/test_ops.py index fd6b8870db..3d5d444772 100644 --- a/hugr-py/tests/test_ops.py +++ b/hugr-py/tests/test_ops.py @@ -1,6 +1,5 @@ import pytest -from hugr import tys from hugr.hugr.node_port import InPort, Node, OutPort from hugr.ops import ( CFG, @@ -43,8 +42,7 @@ (DivMod, "arithmetic.int.idivmod_u<5>"), (MakeTuple(), "MakeTuple"), (UnpackTuple(), "UnpackTuple"), - (Tag(0, Bool), "Left"), - (Tag(0, tys.Sum([[Bool, Bool, Bool]])), "Tag(0)"), + (Tag(0, Bool), "Tag(0)"), (CFG([]), "CFG"), (DFG([]), "DFG"), (DataflowBlock([]), "DataflowBlock"), @@ -61,7 +59,7 @@ (FuncDecl("bar", PolyFuncType.empty()), "FuncDecl(bar)"), (Const(TRUE), "Const(TRUE)"), (Noop(), "Noop"), - (AliasDecl("baz", TypeBound.Linear), "AliasDecl(baz)"), + (AliasDecl("baz", TypeBound.Any), "AliasDecl(baz)"), (AliasDefn("baz", Bool), "AliasDefn(baz)"), ], ) diff --git a/hugr-py/tests/test_order_edges.py b/hugr-py/tests/test_order_edges.py deleted file mode 100644 index 78e585a2b6..0000000000 --- a/hugr-py/tests/test_order_edges.py +++ /dev/null @@ -1,49 +0,0 @@ -from hugr import tys -from hugr.build.dfg import Dfg -from hugr.package import Package - -from .conftest import QUANTUM_EXT, MeasureFree, QAlloc, validate - - -def test_order_links(): - dfg = Dfg(tys.Bool) - inp_0 = dfg.input_node.out(0) - inp_order = dfg.input_node.out(-1) - out_0 = dfg.output_node.inp(0) - out_1 = dfg.output_node.inp(1) - out_order = dfg.output_node.inp(-1) - - dfg.hugr.add_link(inp_0, out_0) - dfg.hugr.add_link(inp_0, out_1) - assert list(dfg.hugr.outgoing_links(dfg.input_node)) == [ - (inp_0, [out_0, out_1]), - ] - assert list(dfg.hugr.incoming_links(dfg.output_node)) == [ - (out_0, [inp_0]), - (out_1, [inp_0]), - ] - - # Now add an order link - dfg.hugr.add_order_link(dfg.input_node, dfg.output_node) - assert list(dfg.hugr.incoming_order_links(dfg.output_node)) == [dfg.input_node] - assert list(dfg.hugr.outgoing_order_links(dfg.input_node)) == [dfg.output_node] - assert list(dfg.hugr.outgoing_links(dfg.input_node)) == [ - (inp_0, [out_0, out_1]), - (inp_order, [out_order]), - ] - assert list(dfg.hugr.incoming_links(dfg.output_node)) == [ - (out_0, [inp_0]), - (out_1, [inp_0]), - (out_order, [inp_order]), - ] - - -# https://github.com/CQCL/hugr/issues/2439 -def test_order_unconnected(snapshot): - dfg = Dfg(tys.Qubit) - meas = dfg.add(MeasureFree(*dfg.inputs())) - alloc = dfg.add_op(QAlloc) - dfg.hugr.add_order_link(meas, alloc) - dfg.set_outputs(alloc) - - validate(Package([dfg.hugr], [QUANTUM_EXT]), snap=snapshot) diff --git a/hugr-py/tests/test_prelude.py b/hugr-py/tests/test_prelude.py index 71af3d4e06..c2ef2cbeec 100644 --- a/hugr-py/tests/test_prelude.py +++ b/hugr-py/tests/test_prelude.py @@ -1,7 +1,4 @@ -import pytest - from hugr.build.dfg import Dfg -from hugr.std.int 
import IntVal, int_t from hugr.std.prelude import STRING_T, StringVal from .conftest import validate @@ -19,38 +16,3 @@ def test_string_val(): dfg.set_outputs(v) validate(dfg.hugr) - - -@pytest.mark.parametrize( - ("log_width", "v", "unsigned"), - [ - (5, 1, 1), - (4, 0, 0), - (6, 42, 42), - (2, -1, 15), - (1, -2, 2), - (3, -23, 233), - (3, -256, None), - (2, 16, None), - ], -) -def test_int_val(log_width: int, v: int, unsigned: int | None): - val = IntVal(v, log_width) - if unsigned is None: - with pytest.raises( - ValueError, - match=f"Value {v} out of range for {1<"), (StaticArray(Bool), "static_array"), (ValueArray(Bool, 3), "value_array<3, Type(Bool)>"), - (BorrowArray(Bool, 3), "borrow_array<3, Type(Bool)>"), - (Variable(2, TypeBound.Linear), "$2"), + (Variable(2, TypeBound.Any), "$2"), (RowVariable(4, TypeBound.Copyable), "$4"), (USize(), "USize"), (INT_T, "int<5>"), @@ -147,10 +132,10 @@ def test_args_str(arg: TypeArg, string: str): (FunctionType([Bool, Qubit], [Qubit, Bool]), "Bool, Qubit -> Qubit, Bool"), ( PolyFuncType( - [TypeTypeParam(TypeBound.Linear), BoundedNatParam(7)], + [TypeTypeParam(TypeBound.Any), BoundedNatParam(7)], FunctionType([_int_tv(1)], [Variable(0, TypeBound.Copyable)]), ), - "∀ Linear, Nat(7). int<$1> -> $0", + "∀ Any, Nat(7). int<$1> -> $0", ), ], ) @@ -181,12 +166,12 @@ def test_array(): ls = Array(Bool, 3) assert ls.ty == Bool assert ls.size == 3 - assert ls.type_bound() == TypeBound.Linear + assert ls.type_bound() == TypeBound.Any ls = Array(ty_var, len_var) assert ls.ty == ty_var assert ls.size is None - assert ls.type_bound() == TypeBound.Linear + assert ls.type_bound() == TypeBound.Any ar_val = ArrayVal([val.TRUE, val.FALSE], Bool) assert ar_val.v == [val.TRUE, val.FALSE] @@ -194,7 +179,7 @@ def test_array(): def test_value_array(): - ty_var = Variable(0, TypeBound.Linear) + ty_var = Variable(0, TypeBound.Any) len_var = VariableArg(1, BoundedNatParam()) ls = ValueArray(Bool, 3) @@ -205,32 +190,13 @@ def test_value_array(): ls = ValueArray(ty_var, len_var) assert ls.ty == ty_var assert ls.size is None - assert ls.type_bound() == TypeBound.Linear + assert ls.type_bound() == TypeBound.Any ar_val = ValueArrayVal([val.TRUE, val.FALSE], Bool) assert ar_val.v == [val.TRUE, val.FALSE] assert ar_val.ty == ValueArray(Bool, 2) -def test_borrow_array(): - ty_var = Variable(0, TypeBound.Copyable) - len_var = VariableArg(1, BoundedNatParam()) - - ls = BorrowArray(Bool, 3) - assert ls.ty == Bool - assert ls.size == 3 - assert ls.type_bound() == TypeBound.Linear - - ls = BorrowArray(ty_var, len_var) - assert ls.ty == ty_var - assert ls.size is None - assert ls.type_bound() == TypeBound.Linear - - ar_val = BorrowArrayVal([val.TRUE, val.FALSE], Bool) - assert ar_val.v == [val.TRUE, val.FALSE] - assert ar_val.ty == BorrowArray(Bool, 2) - - def test_static_array(): ty_var = Variable(0, TypeBound.Copyable) diff --git a/hugr-py/tests/test_val.py b/hugr-py/tests/test_val.py index 5afab88400..11fd1a1194 100644 --- a/hugr-py/tests/test_val.py +++ b/hugr-py/tests/test_val.py @@ -14,6 +14,7 @@ Sum, Tuple, UnitSum, + Value, bool_value, ) @@ -43,9 +44,9 @@ def test_sums(): ("value", "string", "repr_str"), [ ( - Sum(0, tys.Sum([[tys.Bool], [tys.Qubit], [tys.Bool]]), [TRUE]), - "Sum(0, Sum([[Bool], [Qubit], [Bool]]), [TRUE])", - "Sum(tag=0, typ=Sum([[Bool], [Qubit], [Bool]]), vals=[TRUE])", + Sum(0, tys.Sum([[tys.Bool], [tys.Qubit]]), [TRUE, FALSE]), + "Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])", + "Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])", ), 
(UnitSum(0, size=1), "Unit", "Unit"), (UnitSum(0, size=2), "FALSE", "FALSE"), @@ -66,15 +67,10 @@ def test_sums(): ), ], ) -def test_val_sum_str(value: Sum, string: str, repr_str: str): +def test_val_sum_str(value: Value, string: str, repr_str: str): assert str(value) == string assert repr(value) == repr_str - # Make sure the corresponding `Sum` also renders the same - sum_val = Sum(value.tag, value.typ, value.vals) - assert str(sum_val) == string - assert repr(sum_val) == repr_str - def test_val_static_array(): from hugr.std.collections.static_array import StaticArrayVal diff --git a/hugr/CHANGELOG.md b/hugr/CHANGELOG.md index a29df3b196..46439ca695 100644 --- a/hugr/CHANGELOG.md +++ b/hugr/CHANGELOG.md @@ -1,94 +1,5 @@ # Changelog - -## [0.22.1](https://github.com/CQCL/hugr/compare/hugr-v0.22.0...hugr-v0.22.1) - 2025-07-28 - -### New Features - -- Include copy_discard_array in DelegatingLinearizer::default ([#2479](https://github.com/CQCL/hugr/pull/2479)) -- Inline calls to functions not on cycles in the call graph ([#2450](https://github.com/CQCL/hugr/pull/2450)) - -## [0.22.0](https://github.com/CQCL/hugr/compare/hugr-v0.21.0...hugr-v0.22.0) - 2025-07-24 - -This release fixes multiple inconsistencies between the serialization formats -and improves the error messages when loading unsupported envelopes. - -We now also support nodes with up to `2^32` connections to the same port (up from `2^16`). - -### Bug Fixes - -- Ensure SumTypes have the same json encoding in -rs and -py ([#2465](https://github.com/CQCL/hugr/pull/2465)) - -### New Features - -- ReplaceTypes allows linearizing inside Op replacements ([#2435](https://github.com/CQCL/hugr/pull/2435)) -- Add pass for DFG inlining ([#2460](https://github.com/CQCL/hugr/pull/2460)) -- Export entrypoint metadata in Python and fix bug in import ([#2434](https://github.com/CQCL/hugr/pull/2434)) -- Names of private functions become `core.title` metadata. ([#2448](https://github.com/CQCL/hugr/pull/2448)) -- [**breaking**] Use binary envelopes for operation lower_func encoding ([#2447](https://github.com/CQCL/hugr/pull/2447)) -- [**breaking**] Update portgraph dependency to 0.15 ([#2455](https://github.com/CQCL/hugr/pull/2455)) -- Detect and fail on unrecognised envelope flags ([#2453](https://github.com/CQCL/hugr/pull/2453)) -- include generator metatada in model import and cli validate errors ([#2452](https://github.com/CQCL/hugr/pull/2452)) -- [**breaking**] Add `insert_region` to HugrMut ([#2463](https://github.com/CQCL/hugr/pull/2463)) -- Non-region entrypoints in `hugr-model`. ([#2467](https://github.com/CQCL/hugr/pull/2467)) - -## [0.21.0](https://github.com/CQCL/hugr/compare/hugr-v0.20.2...hugr-v0.21.0) - 2025-07-09 - - -This release includes a long list of changes: - -- The HUGR model serialization format is now stable, and should be preferred over the old JSON format. -- Type parameters and type arguments are now unified into a single `Term` type. -- Function definitions can no longer be nested inside dataflow regions. Now they must be defined at the top level module. -- Function definitions and declarations now have a `Visibility` field, which define whether they are visible in the public API of the module. -- And many more fixes and improvements. 
- -### Bug Fixes - -- DeadFuncElimPass+CallGraph w/ non-module-child entrypoint ([#2390](https://github.com/CQCL/hugr/pull/2390)) -- Fixed two bugs in import/export of function operations ([#2324](https://github.com/CQCL/hugr/pull/2324)) -- Model import should perform extension resolution ([#2326](https://github.com/CQCL/hugr/pull/2326)) -- [**breaking**] Fixed bugs in model CFG handling and improved CFG signatures ([#2334](https://github.com/CQCL/hugr/pull/2334)) -- Use List instead of Tuple in conversions for TypeArg/TypeRow ([#2378](https://github.com/CQCL/hugr/pull/2378)) -- Do extension resolution on loaded extensions from the model format ([#2389](https://github.com/CQCL/hugr/pull/2389)) -- Make JSON Schema checks actually work again ([#2412](https://github.com/CQCL/hugr/pull/2412)) -- Order hints on input and output nodes. ([#2422](https://github.com/CQCL/hugr/pull/2422)) - -### Documentation - -- Hide hugr-persistent docs ([#2357](https://github.com/CQCL/hugr/pull/2357)) - -### New Features - -- [**breaking**] Split `TypeArg::Sequence` into tuples and lists. ([#2140](https://github.com/CQCL/hugr/pull/2140)) -- [**breaking**] Added float and bytes literal to core and python bindings. ([#2289](https://github.com/CQCL/hugr/pull/2289)) -- [**breaking**] More helpful error messages in model import ([#2272](https://github.com/CQCL/hugr/pull/2272)) -- [**breaking**] Better error reporting in `hugr-cli`. ([#2318](https://github.com/CQCL/hugr/pull/2318)) -- [**breaking**] Merge `TypeParam` and `TypeArg` into one `Term` type in Rust ([#2309](https://github.com/CQCL/hugr/pull/2309)) -- *(persistent)* Add serialisation for CommitStateSpace ([#2344](https://github.com/CQCL/hugr/pull/2344)) -- add TryFrom impls for TypeArg/TypeRow ([#2366](https://github.com/CQCL/hugr/pull/2366)) -- Add `MakeError` op ([#2377](https://github.com/CQCL/hugr/pull/2377)) -- Open lists and tuples in `Term` ([#2360](https://github.com/CQCL/hugr/pull/2360)) -- Call `FunctionBuilder::add_{in,out}put` for any AsMut ([#2376](https://github.com/CQCL/hugr/pull/2376)) -- Add Root checked methods to DataflowParentID ([#2382](https://github.com/CQCL/hugr/pull/2382)) -- Add PersistentWire type ([#2361](https://github.com/CQCL/hugr/pull/2361)) -- Add `BorrowArray` extension ([#2395](https://github.com/CQCL/hugr/pull/2395)) -- [**breaking**] Add Visibility to FuncDefn/FuncDecl. 
([#2143](https://github.com/CQCL/hugr/pull/2143)) -- *(per)* [**breaking**] Support empty wires in commits ([#2349](https://github.com/CQCL/hugr/pull/2349)) -- [**breaking**] hugr-model use explicit Option, with ::Unspecified in capnp ([#2424](https://github.com/CQCL/hugr/pull/2424)) -- [**breaking**] No nested FuncDefns (or AliasDefns) ([#2256](https://github.com/CQCL/hugr/pull/2256)) -- [**breaking**] Rename 'Any' type bound to 'Linear' ([#2421](https://github.com/CQCL/hugr/pull/2421)) - -### Refactor - -- [**breaking**] remove deprecated runtime extension errors ([#2369](https://github.com/CQCL/hugr/pull/2369)) -- [**breaking**] Reduce error type sizes ([#2420](https://github.com/CQCL/hugr/pull/2420)) -- [**breaking**] move PersistentHugr into separate crate ([#2277](https://github.com/CQCL/hugr/pull/2277)) - -### Testing - -- Check hugr json serializations against the schema (again) ([#2216](https://github.com/CQCL/hugr/pull/2216)) - ## [0.20.2](https://github.com/CQCL/hugr/compare/hugr-v0.20.1...hugr-v0.20.2) - 2025-06-25 ### Bug Fixes diff --git a/hugr/Cargo.toml b/hugr/Cargo.toml index 895c147a60..d43ee4ae90 100644 --- a/hugr/Cargo.toml +++ b/hugr/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hugr" -version = "0.22.1" +version = "0.20.2" edition = { workspace = true } rust-version = { workspace = true } @@ -28,14 +28,12 @@ declarative = ["hugr-core/declarative"] llvm = ["hugr-llvm/llvm14-0"] llvm-test = ["hugr-llvm/llvm14-0", "hugr-llvm/test-utils"] zstd = ["hugr-core/zstd"] -persistent_unstable = ["hugr-persistent"] [dependencies] -hugr-model = { path = "../hugr-model", version = "0.22.1" } -hugr-core = { path = "../hugr-core", version = "0.22.1" } -hugr-passes = { path = "../hugr-passes", version = "0.22.1" } -hugr-llvm = { path = "../hugr-llvm", version = "0.22.1", optional = true } -hugr-persistent = { path = "../hugr-persistent", version = "0.2.1", optional = true } +hugr-model = { path = "../hugr-model", version = "0.20.2" } +hugr-core = { path = "../hugr-core", version = "0.20.2" } +hugr-passes = { path = "../hugr-passes", version = "0.20.2" } +hugr-llvm = { path = "../hugr-llvm", version = "0.20.2", optional = true } [dev-dependencies] lazy_static = { workspace = true } diff --git a/hugr/benches/benchmarks/hugr/examples.rs b/hugr/benches/benchmarks/hugr/examples.rs index 8e274ac8ff..2fdd1a762e 100644 --- a/hugr/benches/benchmarks/hugr/examples.rs +++ b/hugr/benches/benchmarks/hugr/examples.rs @@ -3,8 +3,8 @@ use std::sync::{Arc, LazyLock}; use hugr::builder::{ - BuildError, CFGBuilder, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, - ModuleBuilder, + BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + HugrBuilder, ModuleBuilder, }; use hugr::extension::ExtensionRegistry; use hugr::extension::prelude::{bool_t, qb_t, usize_t}; diff --git a/hugr/benches/benchmarks/types.rs b/hugr/benches/benchmarks/types.rs index 5b564bddde..0ed0a12a05 100644 --- a/hugr/benches/benchmarks/types.rs +++ b/hugr/benches/benchmarks/types.rs @@ -13,7 +13,7 @@ fn make_complex_type() -> Type { let int = usize_t(); let q_register = Type::new_tuple(vec![qb; 8]); let b_register = Type::new_tuple(vec![int; 8]); - let q_alias = Type::new_alias(AliasDecl::new("QReg", TypeBound::Linear)); + let q_alias = Type::new_alias(AliasDecl::new("QReg", TypeBound::Any)); let sum = 
Type::new_sum([q_register, q_alias]); Type::new_function(Signature::new(vec![sum], vec![b_register])) } diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index ce460ce404..341d74f12b 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -140,10 +140,6 @@ pub use hugr_passes as algorithms; #[doc(inline)] pub use hugr_llvm as llvm; -#[cfg(feature = "persistent_unstable")] -#[doc(hidden)] // TODO: remove when stable -pub use hugr_persistent as persistent; - // Modules with hand-picked re-exports. pub mod hugr; diff --git a/justfile b/justfile index a092d45bf9..87771f9803 100644 --- a/justfile +++ b/justfile @@ -10,7 +10,7 @@ setup: # Run the pre-commit checks. check: - HUGR_TEST_SCHEMA=1 uv run pre-commit run --all-files + uv run pre-commit run --all-files # Run all the tests. test: test-rust test-python @@ -20,7 +20,7 @@ test-rust *TEST_ARGS: @# built into a binary build (without using `maturin`) @# @# This feature list should be kept in sync with the `hugr-py/pyproject.toml` - HUGR_TEST_SCHEMA=1 cargo test \ + cargo test \ --workspace \ --exclude 'hugr-py' \ --features 'hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' {{TEST_ARGS}} diff --git a/release-plz.toml b/release-plz.toml index ebedf1be07..4bc9f71047 100644 --- a/release-plz.toml +++ b/release-plz.toml @@ -22,11 +22,6 @@ pr_labels = ["release"] release_always = false [changelog] - -header = """# Changelog - -""" - sort_commits = "oldest" # Allowed conventional commit types @@ -74,7 +69,3 @@ version_group = "hugr" name = "hugr-llvm" release = true version_group = "hugr" - -[[package]] -name = "hugr-persistent" -release = true diff --git a/resources/test/hugr-no-visibility.hugr b/resources/test/hugr-no-visibility.hugr deleted file mode 100644 index e61f933966..0000000000 --- a/resources/test/hugr-no-visibility.hugr +++ /dev/null @@ -1,52 +0,0 @@ -HUGRiHJv?@{ - "modules": [ - { - "version": "live", - "nodes": [ - { - "parent": 0, - "op": "Module" - }, - { - "name":"polyfunc1", - "op":"FuncDecl", - "parent":0, - "signature":{ - "body":{ - "input":[], - "output":[] - }, - "params":[ - ] - } - }, - { - "name":"polyfunc2", - "op":"FuncDefn", - "parent":0, - "signature":{ - "body":{ - "input":[], - "output":[] - }, - "params":[ - ] - } - }, - { - "op": "Input", - "parent": 2, - "types": [] - }, - { - "op": "Output", - "parent": 2, - "types": [] - } - ], - "edges": [], - "encoder": null - } - ], - "extensions": [] -} diff --git a/scripts/check_extension_versions.py b/scripts/check_extension_versions.py deleted file mode 100644 index f871e23fea..0000000000 --- a/scripts/check_extension_versions.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python - -import json -import subprocess -import sys -from pathlib import Path - - -def get_changed_files(target: str) -> list[Path]: - """Get list of changed extension files in the PR""" - # Use git to get the list of files changed compared to target - cmd = [ - "git", - "diff", - "--name-only", - target, - "--", - "specification/std_extensions/", - ] - result = subprocess.run(cmd, capture_output=True, text=True, check=True) # noqa: S603 - changed_files = [Path(f) for f in result.stdout.splitlines() if f.endswith(".json")] - return changed_files - - -def check_version_changes(changed_files: list[Path], target: str) -> list[str]: - """Check if versions have been updated in changed files""" - errors = [] - - for file_path in changed_files: - # Skip files that don't exist anymore (deleted files) - if not file_path.exists(): - continue - - # Get the version in the current branch - with file_path.open("r") 
as f: - current = json.load(f) - current_version = current.get("version") - - # Get the version in the target branch - try: - cmd = ["git", "show", f"{target}:{file_path}"] - result = subprocess.run(cmd, capture_output=True, text=True) # noqa: S603 - - if result.returncode == 0: - # File exists in target - target_content = json.loads(result.stdout) - target_version = target_content.get("version") - - if current_version == target_version: - errors.append( - f"Error: {file_path} was modified but version {current_version}" - " was not updated." - ) - else: - print( - f"Version updated in {file_path}: {target_version}" - f" -> {current_version}" - ) - - else: - # New file - no version check needed - pass - - except json.JSONDecodeError: - # File is new or not valid JSON in target - pass - return errors - - -def main() -> int: - target = sys.argv[1] if len(sys.argv) > 1 else "origin/main" - changed_files = get_changed_files(target) - if not changed_files: - print("No extension files changed.") - return 0 - - print(f"Changed extension files: {', '.join(map(str, changed_files))}") - - errors = check_version_changes(changed_files, target) - if errors: - for error in errors: - sys.stderr.write(error) - return 1 - - print("All changed extension files have updated versions.") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/generate_schema.py b/scripts/generate_schema.py index eea90351ad..98661914ce 100644 --- a/scripts/generate_schema.py +++ b/scripts/generate_schema.py @@ -14,7 +14,7 @@ from pathlib import Path from pydantic import ConfigDict -from pydantic.json_schema import DEFAULT_REF_TEMPLATE, models_json_schema +from pydantic.json_schema import models_json_schema from hugr._serialization.extension import Extension, Package from hugr._serialization.serial_hugr import SerialHugr @@ -38,9 +38,6 @@ def write_schema( _, top_level_schema = models_json_schema( [(s, "validation") for s in schemas], title="HUGR schema" ) - top_level_schema["oneOf"] = [ - {"$ref": DEFAULT_REF_TEMPLATE.format(model=s.__name__)} for s in schemas - ] with path.open("w") as f: json.dump(top_level_schema, f, indent=4) diff --git a/specification/hugr.md b/specification/hugr.md index 2cf934fd97..dc8251f30a 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -248,11 +248,6 @@ edges. The following operations are *only* valid as immediate children of a - `AliasDecl`: an external type alias declaration. At link time this can be replaced with the definition. An alias declared with `AliasDecl` is equivalent to a named opaque type. -- `FuncDefn` : a function definition. Like `FuncDecl` but with a function body. - The function body is defined by the sibling graph formed by its children. - At link time `FuncDecl` nodes are replaced by `FuncDefn`. -- `AliasDefn`: type alias definition. At link time `AliasDecl` can be replaced with - `AliasDefn`. There may also be other [scoped definitions](#scoped-definitions). @@ -263,6 +258,11 @@ regions and control-flow regions: - `Const` : a static constant value of type T stored in the node weight. Like `FuncDecl` and `FuncDefn` this has one `Const` out-edge per use. +- `FuncDefn` : a function definition. Like `FuncDecl` but with a function body. + The function body is defined by the sibling graph formed by its children. + At link time `FuncDecl` nodes are replaced by `FuncDefn`. +- `AliasDefn`: type alias definition. At link time `AliasDecl` can be replaced with + `AliasDefn`. 
A **loadable HUGR** is a module HUGR where all input ports are connected and there are no `FuncDecl/AliasDecl` nodes. @@ -552,8 +552,11 @@ parent(n2) when the edge's locality is: Each of these localities have additional constraints as follows: 1. For Ext edges, we require parent(n1) == - parenti(n2) for some i\>1, *and* for Value edges only there must be a order edge from n1 to - parenti-1(n2). + parenti(n2) for some i\>1, *and* for Value edges only: + * there must be a order edge from n1 to + parenti-1(n2). + * None of the parentj(n2), for i\>j\>=1, + may be a FuncDefn node The order edge records the ordering requirement that results, i.e. it must be possible to @@ -566,6 +569,9 @@ Each of these localities have additional constraints as follows: For Static edges this order edge is not required since the source is guaranteed to causally precede the target. + The FuncDefn restriction means that FuncDefn really are static, + and do not capture runtime values from their environment. + 2. For Dom edges, we must have that parent2(n1) == parenti(n2) is a CFG-node, for some i\>1, **and** parent(n1) strictly dominates @@ -574,6 +580,8 @@ Each of these localities have additional constraints as follows: i\>1 allows the node to target an arbitrarily-deep descendant of the dominated block, similar to an Ext edge.) + The same FuncDefn restriction also applies here, on the parent(j)(n2) for i\>j\>=1 (of course j=i is the CFG and j=i-1 is the basic block). + Specifically, these rules allow for edges where in a given execution of the HUGR the source of the edge executes once, but the target may execute \>=0 times. @@ -771,7 +779,7 @@ existing metadata, given the node ID. engine)? Reserved metadata keys used by the HUGR tooling are prefixed with `core.`. -Use of this prefix by external tooling may cause issues. +Use of this prefix by external tooling may cause issues. #### Generator Metadata Tooling generating HUGR can specify some reserved metadata keys to be used for debugging @@ -824,7 +832,7 @@ copied or discarded (multiple or 0 links from on output port respectively): allows multiple (or 0) outgoing edges from an outport; also these types can be sent down `Const` edges. -Note that all dataflow inputs (`Value`, `Const` and `Function`) always require a single connection, regardless of whether the type is `Linear` or `Copyable`. +Note that all dataflow inputs (`Value`, `Const` and `Function`) always require a single connection, regardless of whether the type is `AnyType` or `Copyable`. **Rows** The `#` is a *row* which is a sequence of zero or more types. Types in the row can optionally be given names in metadata i.e. this does not affect behaviour of the HUGR. When writing literal types, we use `#` to distinguish between tuples and rows, e.g. `(int<1>,int<2>)` is a tuple while `Sum(#(int<1>),#(int<2>))` contains two rows. @@ -858,9 +866,6 @@ such declarations may include (bind) any number of type parameters, of kinds as TypeParam ::= Type(Any|Copyable) | BoundedUSize(u64|) -- note optional bound | Extensions - | String - | Bytes - | Float | List(TypeParam) -- homogeneous, any sized | Tuple([TypeParam]) -- heterogenous, fixed size | Opaque(Name, [TypeArg]) -- e.g. 
Opaque("Array", [5, Opaque("usize", [])]) @@ -878,26 +883,22 @@ TypeArgs appropriate for the function's TypeParams: ```haskell TypeArg ::= Type(Type) -- could be a variable of kind Type, or contain variable(s) | BoundedUSize(u64) - | String(String) - | Bytes([u8]) - | Float(f64) | Extensions(Extensions) -- may contain TypeArg's of kind Extensions - | List([TypeArg]) - | Tuple([TypeArg]) + | Sequence([TypeArg]) -- fits either a List or Tuple TypeParam | Opaque(Value) | Variable -- refers to an enclosing TypeParam (binder) of any kind above ``` For example, a Function node declaring a `TypeParam::Opaque("Array", [5, TypeArg::Type(Type::Opaque("usize"))])` means that any `Call` to it must statically provide a *value* that is an array of 5 `usize`s; -or a Function node declaring a `TypeParam::BoundedUSize(5)` and a `TypeParam::Type(Linear)` requires two TypeArgs, +or a Function node declaring a `TypeParam::BoundedUSize(5)` and a `TypeParam::Type(Any)` requires two TypeArgs, firstly a non-negative integer less than 5, secondly a type (which might be from an extension, e.g. `usize`). Given TypeArgs, the body of the Function node's type can be converted to a monomorphic signature by substitution, i.e. replacing each type variable in the body with the corresponding TypeArg. This is guaranteed to produce a valid type as long as the TypeArgs match the declared TypeParams, which can be checked in advance. -(Note that within a polymorphic type scheme, type variables of kind `List`, `Tuple` or `Opaque` will only be usable +(Note that within a polymorphic type scheme, type variables of kind `Sequence` or `Opaque` will only be usable as arguments to Opaque types---see [Extension System](#extension-system).) #### Row Variables @@ -909,16 +910,16 @@ treatment, as follows: but also a single `TypeArg::Type`. (This is purely a notational convenience.) For example, `Type::Function(usize, unit, )` is equivalent shorthand for `Type::Function(#(usize), #(unit), )`. -* When a `TypeArg::List` is provided as argument for such a TypeParam, we allow +* When a `TypeArg::Sequence` is provided as argument for such a TypeParam, we allow elements to be a mixture of both types (including variables of kind `TypeParam::Type(_)`) and also row variables. When such variables are instantiated - (with other `List`s) the elements of the inner `List` are spliced directly into - the outer (concatenating their elements), eliding the inner (`List`) wrapper. + (with other Sequences) the elements of the inner Sequence are spliced directly into + the outer (concatenating their elements), eliding the inner (Sequence) wrapper. For example, a polymorphic FuncDefn might declare a row variable X of kind `TypeParam::List(TypeParam::Type(Copyable))` and have as output a (tuple) type `Sum([#(X, usize)])`. A call that instantiates said type-parameter with -`TypeArg::List([usize, unit])` would then have output `Sum([#(usize, unit, usize)])`. +`TypeArg::Sequence([usize, unit])` would then have output `Sum([#(usize, unit, usize)])`. See [Declarative Format](#declarative-format) for more examples. 
diff --git a/specification/schema/hugr_schema_live.json b/specification/schema/hugr_schema_live.json index a0fd06f72a..4a5c38b0e6 100644 --- a/specification/schema/hugr_schema_live.json +++ b/specification/schema/hugr_schema_live.json @@ -130,41 +130,6 @@ "title": "BoundedNatParam", "type": "object" }, - "BytesArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "Bytes", - "default": "Bytes", - "title": "Tya", - "type": "string" - }, - "value": { - "contentEncoding": "base64", - "description": "Base64-encoded byte string", - "title": "Value", - "type": "string" - } - }, - "required": [ - "value" - ], - "title": "BytesArg", - "type": "object" - }, - "BytesParam": { - "additionalProperties": true, - "properties": { - "tp": { - "const": "Bytes", - "default": "Bytes", - "title": "Tp", - "type": "string" - } - }, - "title": "BytesParam", - "type": "object" - }, "CFG": { "additionalProperties": true, "description": "A dataflow node which is defined by a child CFG.", @@ -581,7 +546,6 @@ "type": "object" }, "FixedHugr": { - "description": "Fixed HUGR used to define the lowering of an operation.\n\nArgs:\n extensions: Extensions used in the HUGR.\n hugr: Base64-encoded HUGR envelope.", "properties": { "extensions": { "items": { @@ -602,39 +566,6 @@ "title": "FixedHugr", "type": "object" }, - "FloatArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "Float", - "default": "Float", - "title": "Tya", - "type": "string" - }, - "value": { - "title": "Value", - "type": "number" - } - }, - "required": [ - "value" - ], - "title": "FloatArg", - "type": "object" - }, - "FloatParam": { - "additionalProperties": true, - "properties": { - "tp": { - "const": "Float", - "default": "Float", - "title": "Tp", - "type": "string" - } - }, - "title": "FloatParam", - "type": "object" - }, "FromParamsBound": { "properties": { "b": { @@ -677,15 +608,6 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" - }, - "visibility": { - "default": "Public", - "enum": [ - "Public", - "Private" - ], - "title": "Visibility", - "type": "string" } }, "required": [ @@ -716,15 +638,6 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" - }, - "visibility": { - "default": "Private", - "enum": [ - "Public", - "Private" - ], - "title": "Visibility", - "type": "string" } }, "required": [ @@ -778,8 +691,7 @@ "type": "string" }, "hugr": { - "title": "Hugr", - "type": "string" + "title": "Hugr" } }, "required": [ @@ -849,29 +761,6 @@ "title": "Input", "type": "object" }, - "ListArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "List", - "default": "List", - "title": "Tya", - "type": "string" - }, - "elems": { - "items": { - "$ref": "#/$defs/TypeArg" - }, - "title": "Elems", - "type": "array" - } - }, - "required": [ - "elems" - ], - "title": "ListArg", - "type": "object" - }, "ListParam": { "additionalProperties": true, "properties": { @@ -1282,6 +1171,29 @@ "title": "RowVar", "type": "object" }, + "SequenceArg": { + "additionalProperties": true, + "properties": { + "tya": { + "const": "Sequence", + "default": "Sequence", + "title": "Tya", + "type": "string" + }, + "elems": { + "items": { + "$ref": "#/$defs/TypeArg" + }, + "title": "Elems", + "type": "array" + } + }, + "required": [ + "elems" + ], + "title": "SequenceArg", + "type": "object" + }, "SerialHugr": { "additionalProperties": true, "description": "A serializable representation of a Hugr.", @@ -1464,30 +1376,17 @@ "description": "A Sum variant For any Sum type where this value meets the type of the 
variant indicated by the tag.", "properties": { "v": { + "const": "Sum", "default": "Sum", - "enum": [ - "Sum", - "Tuple" - ], "title": "ValueTag", "type": "string" }, "tag": { - "default": 0, - "title": "VariantTag", + "title": "Tag", "type": "integer" }, "typ": { - "anyOf": [ - { - "$ref": "#/$defs/SumType" - }, - { - "type": "null" - } - ], - "default": null, - "title": "SumType" + "$ref": "#/$defs/SumType" }, "vs": { "items": { @@ -1498,6 +1397,8 @@ } }, "required": [ + "tag", + "typ", "vs" ], "title": "SumValue", @@ -1582,50 +1483,51 @@ "title": "TailLoop", "type": "object" }, - "TupleArg": { + "TupleParam": { "additionalProperties": true, "properties": { - "tya": { + "tp": { "const": "Tuple", "default": "Tuple", - "title": "Tya", + "title": "Tp", "type": "string" }, - "elems": { + "params": { "items": { - "$ref": "#/$defs/TypeArg" + "$ref": "#/$defs/TypeParam" }, - "title": "Elems", + "title": "Params", "type": "array" } }, "required": [ - "elems" + "params" ], - "title": "TupleArg", + "title": "TupleParam", "type": "object" }, - "TupleParam": { + "TupleValue": { "additionalProperties": true, + "description": "A constant tuple value.", "properties": { - "tp": { + "v": { "const": "Tuple", "default": "Tuple", - "title": "Tp", + "title": "ValueTag", "type": "string" }, - "params": { + "vs": { "items": { - "$ref": "#/$defs/TypeParam" + "$ref": "#/$defs/Value" }, - "title": "Params", + "title": "Vs", "type": "array" } }, "required": [ - "params" + "vs" ], - "title": "TupleParam", + "title": "TupleValue", "type": "object" }, "Type": { @@ -1679,11 +1581,8 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Bytes": "#/$defs/BytesArg", - "Float": "#/$defs/FloatArg", - "List": "#/$defs/ListArg", + "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", - "Tuple": "#/$defs/TupleArg", "Type": "#/$defs/TypeTypeArg", "Variable": "#/$defs/VariableArg" }, @@ -1700,16 +1599,7 @@ "$ref": "#/$defs/StringArg" }, { - "$ref": "#/$defs/BytesArg" - }, - { - "$ref": "#/$defs/FloatArg" - }, - { - "$ref": "#/$defs/ListArg" - }, - { - "$ref": "#/$defs/TupleArg" + "$ref": "#/$defs/SequenceArg" }, { "$ref": "#/$defs/VariableArg" @@ -1786,8 +1676,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Bytes": "#/$defs/BytesParam", - "Float": "#/$defs/FloatParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1805,12 +1693,6 @@ { "$ref": "#/$defs/StringParam" }, - { - "$ref": "#/$defs/FloatParam" - }, - { - "$ref": "#/$defs/BytesParam" - }, { "$ref": "#/$defs/ListParam" }, @@ -1909,7 +1791,7 @@ "Extension": "#/$defs/CustomValue", "Function": "#/$defs/FunctionValue", "Sum": "#/$defs/SumValue", - "Tuple": "#/$defs/SumValue" + "Tuple": "#/$defs/TupleValue" }, "propertyName": "v" }, @@ -1920,6 +1802,9 @@ { "$ref": "#/$defs/FunctionValue" }, + { + "$ref": "#/$defs/TupleValue" + }, { "$ref": "#/$defs/SumValue" } @@ -1979,16 +1864,5 @@ "type": "object" } }, - "title": "HUGR schema", - "oneOf": [ - { - "$ref": "#/$defs/SerialHugr" - }, - { - "$ref": "#/$defs/Extension" - }, - { - "$ref": "#/$defs/Package" - } - ] + "title": "HUGR schema" } \ No newline at end of file diff --git a/specification/schema/hugr_schema_strict_live.json b/specification/schema/hugr_schema_strict_live.json index cd37e262cd..419eb86d43 100644 --- a/specification/schema/hugr_schema_strict_live.json +++ b/specification/schema/hugr_schema_strict_live.json @@ -130,41 +130,6 @@ "title": "BoundedNatParam", "type": "object" }, - "BytesArg": { - 
"additionalProperties": false, - "properties": { - "tya": { - "const": "Bytes", - "default": "Bytes", - "title": "Tya", - "type": "string" - }, - "value": { - "contentEncoding": "base64", - "description": "Base64-encoded byte string", - "title": "Value", - "type": "string" - } - }, - "required": [ - "value" - ], - "title": "BytesArg", - "type": "object" - }, - "BytesParam": { - "additionalProperties": false, - "properties": { - "tp": { - "const": "Bytes", - "default": "Bytes", - "title": "Tp", - "type": "string" - } - }, - "title": "BytesParam", - "type": "object" - }, "CFG": { "additionalProperties": false, "description": "A dataflow node which is defined by a child CFG.", @@ -581,7 +546,6 @@ "type": "object" }, "FixedHugr": { - "description": "Fixed HUGR used to define the lowering of an operation.\n\nArgs:\n extensions: Extensions used in the HUGR.\n hugr: Base64-encoded HUGR envelope.", "properties": { "extensions": { "items": { @@ -602,39 +566,6 @@ "title": "FixedHugr", "type": "object" }, - "FloatArg": { - "additionalProperties": false, - "properties": { - "tya": { - "const": "Float", - "default": "Float", - "title": "Tya", - "type": "string" - }, - "value": { - "title": "Value", - "type": "number" - } - }, - "required": [ - "value" - ], - "title": "FloatArg", - "type": "object" - }, - "FloatParam": { - "additionalProperties": false, - "properties": { - "tp": { - "const": "Float", - "default": "Float", - "title": "Tp", - "type": "string" - } - }, - "title": "FloatParam", - "type": "object" - }, "FromParamsBound": { "properties": { "b": { @@ -677,15 +608,6 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" - }, - "visibility": { - "default": "Public", - "enum": [ - "Public", - "Private" - ], - "title": "Visibility", - "type": "string" } }, "required": [ @@ -716,15 +638,6 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" - }, - "visibility": { - "default": "Private", - "enum": [ - "Public", - "Private" - ], - "title": "Visibility", - "type": "string" } }, "required": [ @@ -778,8 +691,7 @@ "type": "string" }, "hugr": { - "title": "Hugr", - "type": "string" + "title": "Hugr" } }, "required": [ @@ -849,29 +761,6 @@ "title": "Input", "type": "object" }, - "ListArg": { - "additionalProperties": false, - "properties": { - "tya": { - "const": "List", - "default": "List", - "title": "Tya", - "type": "string" - }, - "elems": { - "items": { - "$ref": "#/$defs/TypeArg" - }, - "title": "Elems", - "type": "array" - } - }, - "required": [ - "elems" - ], - "title": "ListArg", - "type": "object" - }, "ListParam": { "additionalProperties": false, "properties": { @@ -1282,6 +1171,29 @@ "title": "RowVar", "type": "object" }, + "SequenceArg": { + "additionalProperties": false, + "properties": { + "tya": { + "const": "Sequence", + "default": "Sequence", + "title": "Tya", + "type": "string" + }, + "elems": { + "items": { + "$ref": "#/$defs/TypeArg" + }, + "title": "Elems", + "type": "array" + } + }, + "required": [ + "elems" + ], + "title": "SequenceArg", + "type": "object" + }, "SerialHugr": { "additionalProperties": false, "description": "A serializable representation of a Hugr.", @@ -1464,30 +1376,17 @@ "description": "A Sum variant For any Sum type where this value meets the type of the variant indicated by the tag.", "properties": { "v": { + "const": "Sum", "default": "Sum", - "enum": [ - "Sum", - "Tuple" - ], "title": "ValueTag", "type": "string" }, "tag": { - "default": 0, - "title": "VariantTag", + "title": "Tag", "type": "integer" }, "typ": { - "anyOf": [ - { - "$ref": "#/$defs/SumType" - }, - { 
- "type": "null" - } - ], - "default": null, - "title": "SumType" + "$ref": "#/$defs/SumType" }, "vs": { "items": { @@ -1498,6 +1397,8 @@ } }, "required": [ + "tag", + "typ", "vs" ], "title": "SumValue", @@ -1582,50 +1483,51 @@ "title": "TailLoop", "type": "object" }, - "TupleArg": { + "TupleParam": { "additionalProperties": false, "properties": { - "tya": { + "tp": { "const": "Tuple", "default": "Tuple", - "title": "Tya", + "title": "Tp", "type": "string" }, - "elems": { + "params": { "items": { - "$ref": "#/$defs/TypeArg" + "$ref": "#/$defs/TypeParam" }, - "title": "Elems", + "title": "Params", "type": "array" } }, "required": [ - "elems" + "params" ], - "title": "TupleArg", + "title": "TupleParam", "type": "object" }, - "TupleParam": { + "TupleValue": { "additionalProperties": false, + "description": "A constant tuple value.", "properties": { - "tp": { + "v": { "const": "Tuple", "default": "Tuple", - "title": "Tp", + "title": "ValueTag", "type": "string" }, - "params": { + "vs": { "items": { - "$ref": "#/$defs/TypeParam" + "$ref": "#/$defs/Value" }, - "title": "Params", + "title": "Vs", "type": "array" } }, "required": [ - "params" + "vs" ], - "title": "TupleParam", + "title": "TupleValue", "type": "object" }, "Type": { @@ -1679,11 +1581,8 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Bytes": "#/$defs/BytesArg", - "Float": "#/$defs/FloatArg", - "List": "#/$defs/ListArg", + "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", - "Tuple": "#/$defs/TupleArg", "Type": "#/$defs/TypeTypeArg", "Variable": "#/$defs/VariableArg" }, @@ -1700,16 +1599,7 @@ "$ref": "#/$defs/StringArg" }, { - "$ref": "#/$defs/BytesArg" - }, - { - "$ref": "#/$defs/FloatArg" - }, - { - "$ref": "#/$defs/ListArg" - }, - { - "$ref": "#/$defs/TupleArg" + "$ref": "#/$defs/SequenceArg" }, { "$ref": "#/$defs/VariableArg" @@ -1786,8 +1676,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Bytes": "#/$defs/BytesParam", - "Float": "#/$defs/FloatParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1805,12 +1693,6 @@ { "$ref": "#/$defs/StringParam" }, - { - "$ref": "#/$defs/FloatParam" - }, - { - "$ref": "#/$defs/BytesParam" - }, { "$ref": "#/$defs/ListParam" }, @@ -1909,7 +1791,7 @@ "Extension": "#/$defs/CustomValue", "Function": "#/$defs/FunctionValue", "Sum": "#/$defs/SumValue", - "Tuple": "#/$defs/SumValue" + "Tuple": "#/$defs/TupleValue" }, "propertyName": "v" }, @@ -1920,6 +1802,9 @@ { "$ref": "#/$defs/FunctionValue" }, + { + "$ref": "#/$defs/TupleValue" + }, { "$ref": "#/$defs/SumValue" } @@ -1979,16 +1864,5 @@ "type": "object" } }, - "title": "HUGR schema", - "oneOf": [ - { - "$ref": "#/$defs/SerialHugr" - }, - { - "$ref": "#/$defs/Extension" - }, - { - "$ref": "#/$defs/Package" - } - ] + "title": "HUGR schema" } \ No newline at end of file diff --git a/specification/schema/testing_hugr_schema_live.json b/specification/schema/testing_hugr_schema_live.json index 157facb661..a9f483d3c4 100644 --- a/specification/schema/testing_hugr_schema_live.json +++ b/specification/schema/testing_hugr_schema_live.json @@ -130,41 +130,6 @@ "title": "BoundedNatParam", "type": "object" }, - "BytesArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "Bytes", - "default": "Bytes", - "title": "Tya", - "type": "string" - }, - "value": { - "contentEncoding": "base64", - "description": "Base64-encoded byte string", - "title": "Value", - "type": "string" - } - }, - "required": [ - "value" - ], - 
"title": "BytesArg", - "type": "object" - }, - "BytesParam": { - "additionalProperties": true, - "properties": { - "tp": { - "const": "Bytes", - "default": "Bytes", - "title": "Tp", - "type": "string" - } - }, - "title": "BytesParam", - "type": "object" - }, "CFG": { "additionalProperties": true, "description": "A dataflow node which is defined by a child CFG.", @@ -581,7 +546,6 @@ "type": "object" }, "FixedHugr": { - "description": "Fixed HUGR used to define the lowering of an operation.\n\nArgs:\n extensions: Extensions used in the HUGR.\n hugr: Base64-encoded HUGR envelope.", "properties": { "extensions": { "items": { @@ -602,39 +566,6 @@ "title": "FixedHugr", "type": "object" }, - "FloatArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "Float", - "default": "Float", - "title": "Tya", - "type": "string" - }, - "value": { - "title": "Value", - "type": "number" - } - }, - "required": [ - "value" - ], - "title": "FloatArg", - "type": "object" - }, - "FloatParam": { - "additionalProperties": true, - "properties": { - "tp": { - "const": "Float", - "default": "Float", - "title": "Tp", - "type": "string" - } - }, - "title": "FloatParam", - "type": "object" - }, "FromParamsBound": { "properties": { "b": { @@ -677,15 +608,6 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" - }, - "visibility": { - "default": "Public", - "enum": [ - "Public", - "Private" - ], - "title": "Visibility", - "type": "string" } }, "required": [ @@ -716,15 +638,6 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" - }, - "visibility": { - "default": "Private", - "enum": [ - "Public", - "Private" - ], - "title": "Visibility", - "type": "string" } }, "required": [ @@ -778,8 +691,7 @@ "type": "string" }, "hugr": { - "title": "Hugr", - "type": "string" + "title": "Hugr" } }, "required": [ @@ -849,29 +761,6 @@ "title": "Input", "type": "object" }, - "ListArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "List", - "default": "List", - "title": "Tya", - "type": "string" - }, - "elems": { - "items": { - "$ref": "#/$defs/TypeArg" - }, - "title": "Elems", - "type": "array" - } - }, - "required": [ - "elems" - ], - "title": "ListArg", - "type": "object" - }, "ListParam": { "additionalProperties": true, "properties": { @@ -1282,6 +1171,29 @@ "title": "RowVar", "type": "object" }, + "SequenceArg": { + "additionalProperties": true, + "properties": { + "tya": { + "const": "Sequence", + "default": "Sequence", + "title": "Tya", + "type": "string" + }, + "elems": { + "items": { + "$ref": "#/$defs/TypeArg" + }, + "title": "Elems", + "type": "array" + } + }, + "required": [ + "elems" + ], + "title": "SequenceArg", + "type": "object" + }, "SerialHugr": { "description": "A serializable representation of a Hugr.", "properties": { @@ -1463,30 +1375,17 @@ "description": "A Sum variant For any Sum type where this value meets the type of the variant indicated by the tag.", "properties": { "v": { + "const": "Sum", "default": "Sum", - "enum": [ - "Sum", - "Tuple" - ], "title": "ValueTag", "type": "string" }, "tag": { - "default": 0, - "title": "VariantTag", + "title": "Tag", "type": "integer" }, "typ": { - "anyOf": [ - { - "$ref": "#/$defs/SumType" - }, - { - "type": "null" - } - ], - "default": null, - "title": "SumType" + "$ref": "#/$defs/SumType" }, "vs": { "items": { @@ -1497,6 +1396,8 @@ } }, "required": [ + "tag", + "typ", "vs" ], "title": "SumValue", @@ -1660,50 +1561,51 @@ "title": "TestingHugr", "type": "object" }, - "TupleArg": { + "TupleParam": { "additionalProperties": true, 
"properties": { - "tya": { + "tp": { "const": "Tuple", "default": "Tuple", - "title": "Tya", + "title": "Tp", "type": "string" }, - "elems": { + "params": { "items": { - "$ref": "#/$defs/TypeArg" + "$ref": "#/$defs/TypeParam" }, - "title": "Elems", + "title": "Params", "type": "array" } }, "required": [ - "elems" + "params" ], - "title": "TupleArg", + "title": "TupleParam", "type": "object" }, - "TupleParam": { + "TupleValue": { "additionalProperties": true, + "description": "A constant tuple value.", "properties": { - "tp": { + "v": { "const": "Tuple", "default": "Tuple", - "title": "Tp", + "title": "ValueTag", "type": "string" }, - "params": { + "vs": { "items": { - "$ref": "#/$defs/TypeParam" + "$ref": "#/$defs/Value" }, - "title": "Params", + "title": "Vs", "type": "array" } }, "required": [ - "params" + "vs" ], - "title": "TupleParam", + "title": "TupleValue", "type": "object" }, "Type": { @@ -1757,11 +1659,8 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Bytes": "#/$defs/BytesArg", - "Float": "#/$defs/FloatArg", - "List": "#/$defs/ListArg", + "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", - "Tuple": "#/$defs/TupleArg", "Type": "#/$defs/TypeTypeArg", "Variable": "#/$defs/VariableArg" }, @@ -1778,16 +1677,7 @@ "$ref": "#/$defs/StringArg" }, { - "$ref": "#/$defs/BytesArg" - }, - { - "$ref": "#/$defs/FloatArg" - }, - { - "$ref": "#/$defs/ListArg" - }, - { - "$ref": "#/$defs/TupleArg" + "$ref": "#/$defs/SequenceArg" }, { "$ref": "#/$defs/VariableArg" @@ -1864,8 +1754,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Bytes": "#/$defs/BytesParam", - "Float": "#/$defs/FloatParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1883,12 +1771,6 @@ { "$ref": "#/$defs/StringParam" }, - { - "$ref": "#/$defs/FloatParam" - }, - { - "$ref": "#/$defs/BytesParam" - }, { "$ref": "#/$defs/ListParam" }, @@ -1987,7 +1869,7 @@ "Extension": "#/$defs/CustomValue", "Function": "#/$defs/FunctionValue", "Sum": "#/$defs/SumValue", - "Tuple": "#/$defs/SumValue" + "Tuple": "#/$defs/TupleValue" }, "propertyName": "v" }, @@ -1998,6 +1880,9 @@ { "$ref": "#/$defs/FunctionValue" }, + { + "$ref": "#/$defs/TupleValue" + }, { "$ref": "#/$defs/SumValue" } @@ -2057,16 +1942,5 @@ "type": "object" } }, - "title": "HUGR schema", - "oneOf": [ - { - "$ref": "#/$defs/TestingHugr" - }, - { - "$ref": "#/$defs/Extension" - }, - { - "$ref": "#/$defs/Package" - } - ] + "title": "HUGR schema" } \ No newline at end of file diff --git a/specification/schema/testing_hugr_schema_strict_live.json b/specification/schema/testing_hugr_schema_strict_live.json index 33244f3ed7..108f69f2f4 100644 --- a/specification/schema/testing_hugr_schema_strict_live.json +++ b/specification/schema/testing_hugr_schema_strict_live.json @@ -130,41 +130,6 @@ "title": "BoundedNatParam", "type": "object" }, - "BytesArg": { - "additionalProperties": false, - "properties": { - "tya": { - "const": "Bytes", - "default": "Bytes", - "title": "Tya", - "type": "string" - }, - "value": { - "contentEncoding": "base64", - "description": "Base64-encoded byte string", - "title": "Value", - "type": "string" - } - }, - "required": [ - "value" - ], - "title": "BytesArg", - "type": "object" - }, - "BytesParam": { - "additionalProperties": false, - "properties": { - "tp": { - "const": "Bytes", - "default": "Bytes", - "title": "Tp", - "type": "string" - } - }, - "title": "BytesParam", - "type": "object" - }, "CFG": { "additionalProperties": false, 
"description": "A dataflow node which is defined by a child CFG.", @@ -581,7 +546,6 @@ "type": "object" }, "FixedHugr": { - "description": "Fixed HUGR used to define the lowering of an operation.\n\nArgs:\n extensions: Extensions used in the HUGR.\n hugr: Base64-encoded HUGR envelope.", "properties": { "extensions": { "items": { @@ -602,39 +566,6 @@ "title": "FixedHugr", "type": "object" }, - "FloatArg": { - "additionalProperties": false, - "properties": { - "tya": { - "const": "Float", - "default": "Float", - "title": "Tya", - "type": "string" - }, - "value": { - "title": "Value", - "type": "number" - } - }, - "required": [ - "value" - ], - "title": "FloatArg", - "type": "object" - }, - "FloatParam": { - "additionalProperties": false, - "properties": { - "tp": { - "const": "Float", - "default": "Float", - "title": "Tp", - "type": "string" - } - }, - "title": "FloatParam", - "type": "object" - }, "FromParamsBound": { "properties": { "b": { @@ -677,15 +608,6 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" - }, - "visibility": { - "default": "Public", - "enum": [ - "Public", - "Private" - ], - "title": "Visibility", - "type": "string" } }, "required": [ @@ -716,15 +638,6 @@ }, "signature": { "$ref": "#/$defs/PolyFuncType" - }, - "visibility": { - "default": "Private", - "enum": [ - "Public", - "Private" - ], - "title": "Visibility", - "type": "string" } }, "required": [ @@ -778,8 +691,7 @@ "type": "string" }, "hugr": { - "title": "Hugr", - "type": "string" + "title": "Hugr" } }, "required": [ @@ -849,29 +761,6 @@ "title": "Input", "type": "object" }, - "ListArg": { - "additionalProperties": false, - "properties": { - "tya": { - "const": "List", - "default": "List", - "title": "Tya", - "type": "string" - }, - "elems": { - "items": { - "$ref": "#/$defs/TypeArg" - }, - "title": "Elems", - "type": "array" - } - }, - "required": [ - "elems" - ], - "title": "ListArg", - "type": "object" - }, "ListParam": { "additionalProperties": false, "properties": { @@ -1282,6 +1171,29 @@ "title": "RowVar", "type": "object" }, + "SequenceArg": { + "additionalProperties": false, + "properties": { + "tya": { + "const": "Sequence", + "default": "Sequence", + "title": "Tya", + "type": "string" + }, + "elems": { + "items": { + "$ref": "#/$defs/TypeArg" + }, + "title": "Elems", + "type": "array" + } + }, + "required": [ + "elems" + ], + "title": "SequenceArg", + "type": "object" + }, "SerialHugr": { "description": "A serializable representation of a Hugr.", "properties": { @@ -1463,30 +1375,17 @@ "description": "A Sum variant For any Sum type where this value meets the type of the variant indicated by the tag.", "properties": { "v": { + "const": "Sum", "default": "Sum", - "enum": [ - "Sum", - "Tuple" - ], "title": "ValueTag", "type": "string" }, "tag": { - "default": 0, - "title": "VariantTag", + "title": "Tag", "type": "integer" }, "typ": { - "anyOf": [ - { - "$ref": "#/$defs/SumType" - }, - { - "type": "null" - } - ], - "default": null, - "title": "SumType" + "$ref": "#/$defs/SumType" }, "vs": { "items": { @@ -1497,6 +1396,8 @@ } }, "required": [ + "tag", + "typ", "vs" ], "title": "SumValue", @@ -1660,50 +1561,51 @@ "title": "TestingHugr", "type": "object" }, - "TupleArg": { + "TupleParam": { "additionalProperties": false, "properties": { - "tya": { + "tp": { "const": "Tuple", "default": "Tuple", - "title": "Tya", + "title": "Tp", "type": "string" }, - "elems": { + "params": { "items": { - "$ref": "#/$defs/TypeArg" + "$ref": "#/$defs/TypeParam" }, - "title": "Elems", + "title": "Params", "type": "array" } }, 
"required": [ - "elems" + "params" ], - "title": "TupleArg", + "title": "TupleParam", "type": "object" }, - "TupleParam": { + "TupleValue": { "additionalProperties": false, + "description": "A constant tuple value.", "properties": { - "tp": { + "v": { "const": "Tuple", "default": "Tuple", - "title": "Tp", + "title": "ValueTag", "type": "string" }, - "params": { + "vs": { "items": { - "$ref": "#/$defs/TypeParam" + "$ref": "#/$defs/Value" }, - "title": "Params", + "title": "Vs", "type": "array" } }, "required": [ - "params" + "vs" ], - "title": "TupleParam", + "title": "TupleValue", "type": "object" }, "Type": { @@ -1757,11 +1659,8 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Bytes": "#/$defs/BytesArg", - "Float": "#/$defs/FloatArg", - "List": "#/$defs/ListArg", + "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", - "Tuple": "#/$defs/TupleArg", "Type": "#/$defs/TypeTypeArg", "Variable": "#/$defs/VariableArg" }, @@ -1778,16 +1677,7 @@ "$ref": "#/$defs/StringArg" }, { - "$ref": "#/$defs/BytesArg" - }, - { - "$ref": "#/$defs/FloatArg" - }, - { - "$ref": "#/$defs/ListArg" - }, - { - "$ref": "#/$defs/TupleArg" + "$ref": "#/$defs/SequenceArg" }, { "$ref": "#/$defs/VariableArg" @@ -1864,8 +1754,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Bytes": "#/$defs/BytesParam", - "Float": "#/$defs/FloatParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1883,12 +1771,6 @@ { "$ref": "#/$defs/StringParam" }, - { - "$ref": "#/$defs/FloatParam" - }, - { - "$ref": "#/$defs/BytesParam" - }, { "$ref": "#/$defs/ListParam" }, @@ -1987,7 +1869,7 @@ "Extension": "#/$defs/CustomValue", "Function": "#/$defs/FunctionValue", "Sum": "#/$defs/SumValue", - "Tuple": "#/$defs/SumValue" + "Tuple": "#/$defs/TupleValue" }, "propertyName": "v" }, @@ -1998,6 +1880,9 @@ { "$ref": "#/$defs/FunctionValue" }, + { + "$ref": "#/$defs/TupleValue" + }, { "$ref": "#/$defs/SumValue" } @@ -2057,16 +1942,5 @@ "type": "object" } }, - "title": "HUGR schema", - "oneOf": [ - { - "$ref": "#/$defs/TestingHugr" - }, - { - "$ref": "#/$defs/Extension" - }, - { - "$ref": "#/$defs/Package" - } - ] + "title": "HUGR schema" } \ No newline at end of file diff --git a/specification/std_extensions/collections/borrow_arr.json b/specification/std_extensions/collections/borrow_arr.json deleted file mode 100644 index 1774b4aea6..0000000000 --- a/specification/std_extensions/collections/borrow_arr.json +++ /dev/null @@ -1,1139 +0,0 @@ -{ - "version": "0.1.1", - "name": "collections.borrow_arr", - "types": { - "borrow_array": { - "extension": "collections.borrow_arr", - "name": "borrow_array", - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "description": "Fixed-length borrow array", - "bound": { - "b": "Explicit", - "bound": "A" - } - } - }, - "operations": { - "borrow": { - "extension": "collections.borrow_arr", - "name": "borrow", - "description": "Take an element from a borrow array (panicking if it was already taken before)", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - }, - { - "t": "I" 
- } - ], - "output": [ - { - "t": "V", - "i": 1, - "b": "A" - }, - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "clone": { - "extension": "collections.borrow_arr", - "name": "clone", - "description": "Clones an array with copyable elements", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "C" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "C" - } - } - ], - "bound": "A" - } - ], - "output": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "C" - } - } - ], - "bound": "A" - }, - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "C" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "discard": { - "extension": "collections.borrow_arr", - "name": "discard", - "description": "Discards an array with copyable elements", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "C" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "C" - } - } - ], - "bound": "A" - } - ], - "output": [] - } - }, - "binary": false - }, - "discard_all_borrowed": { - "extension": "collections.borrow_arr", - "name": "discard_all_borrowed", - "description": "Discard a borrow array where all elements have been borrowed", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ], - "output": [] - } - }, - "binary": false - }, - "discard_empty": { - "extension": "collections.borrow_arr", - "name": "discard_empty", - "description": "Discard an empty array", - "signature": { - "params": [ - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "BoundedNat", - "n": 0 - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 0, - "b": "A" - } - } - ], - "bound": "A" - } - ], - "output": [] - } - }, - "binary": false - }, - "from_array": { - "extension": "collections.borrow_arr", - "name": "from_array", - "description": "Turns 
`array` into `borrow_array`", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.array", - "id": "array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ], - "output": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "get": { - "extension": "collections.borrow_arr", - "name": "get", - "description": "Get an element from an array", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "C" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "C" - } - } - ], - "bound": "A" - }, - { - "t": "I" - } - ], - "output": [ - { - "t": "Sum", - "s": "General", - "rows": [ - [], - [ - { - "t": "V", - "i": 1, - "b": "C" - } - ] - ] - }, - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "C" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "new_all_borrowed": { - "extension": "collections.borrow_arr", - "name": "new_all_borrowed", - "description": "Create a new borrow array that contains no elements", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [], - "output": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "new_array": { - "extension": "collections.borrow_arr", - "name": "new_array", - "description": "Create a new array from elements", - "signature": null, - "binary": true - }, - "pop_left": { - "extension": "collections.borrow_arr", - "name": "pop_left", - "description": "Pop an element from the left of an array", - "signature": null, - "binary": true - }, - "pop_right": { - "extension": "collections.borrow_arr", - "name": "pop_right", - "description": "Pop an element from the right of an array", - "signature": null, - "binary": true - }, - "repeat": { - "extension": "collections.borrow_arr", - "name": "repeat", - "description": "Creates a new array whose elements are initialised by calling the given function n times", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "G", - "input": [], - "output": [ - { - "t": "V", - "i": 1, - "b": "A" - } - ] - } - ], - "output": [ - { - "t": "Opaque", - "extension": 
"collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "return": { - "extension": "collections.borrow_arr", - "name": "return", - "description": "Put an element into a borrow array (panicking if there is an element already)", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - }, - { - "t": "I" - }, - { - "t": "V", - "i": 1, - "b": "A" - } - ], - "output": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "scan": { - "extension": "collections.borrow_arr", - "name": "scan", - "description": "A combination of map and foldl. Applies a function to each element of the array with an accumulator that is passed through from start to finish. Returns the resulting array and the final state of the accumulator.", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - }, - { - "tp": "Type", - "b": "A" - }, - { - "tp": "List", - "param": { - "tp": "Type", - "b": "A" - } - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - }, - { - "t": "G", - "input": [ - { - "t": "V", - "i": 1, - "b": "A" - }, - { - "t": "R", - "i": 3, - "b": "A" - } - ], - "output": [ - { - "t": "V", - "i": 2, - "b": "A" - }, - { - "t": "R", - "i": 3, - "b": "A" - } - ] - }, - { - "t": "R", - "i": 3, - "b": "A" - } - ], - "output": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 2, - "b": "A" - } - } - ], - "bound": "A" - }, - { - "t": "R", - "i": 3, - "b": "A" - } - ] - } - }, - "binary": false - }, - "set": { - "extension": "collections.borrow_arr", - "name": "set", - "description": "Set an element in an array", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - }, - { - "t": "I" - }, - { - "t": "V", - "i": 1, - "b": "A" - } - ], - "output": [ - { - "t": "Sum", - "s": "General", - "rows": [ - [ - { - "t": "V", - "i": 1, - "b": 
"A" - }, - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ], - [ - { - "t": "V", - "i": 1, - "b": "A" - }, - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - ] - } - ] - } - }, - "binary": false - }, - "swap": { - "extension": "collections.borrow_arr", - "name": "swap", - "description": "Swap two elements in an array", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - }, - { - "t": "I" - }, - { - "t": "I" - } - ], - "output": [ - { - "t": "Sum", - "s": "General", - "rows": [ - [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ], - [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - ] - } - ] - } - }, - "binary": false - }, - "to_array": { - "extension": "collections.borrow_arr", - "name": "to_array", - "description": "Turns `borrow_array` into `array`", - "signature": { - "params": [ - { - "tp": "BoundedNat", - "bound": null - }, - { - "tp": "Type", - "b": "A" - } - ], - "body": { - "input": [ - { - "t": "Opaque", - "extension": "collections.borrow_arr", - "id": "borrow_array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ], - "output": [ - { - "t": "Opaque", - "extension": "collections.array", - "id": "array", - "args": [ - { - "tya": "Variable", - "idx": 0, - "cached_decl": { - "tp": "BoundedNat", - "bound": null - } - }, - { - "tya": "Type", - "ty": { - "t": "V", - "i": 1, - "b": "A" - } - } - ], - "bound": "A" - } - ] - } - }, - "binary": false - }, - "unpack": { - "extension": "collections.borrow_arr", - "name": "unpack", - "description": "Unpack an array into its elements", - "signature": null, - "binary": true - } - } -} diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index 81c2f948a0..7cf1d02c70 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -1,5 +1,5 @@ { - "version": "0.2.1", + "version": "0.2.0", "name": "prelude", "types": { "error": { @@ -77,38 +77,6 @@ }, "binary": false }, - "MakeError": { - "extension": "prelude", - "name": 
"MakeError", - "description": "Create an error value", - "signature": { - "params": [], - "body": { - "input": [ - { - "t": "I" - }, - { - "t": "Opaque", - "extension": "prelude", - "id": "string", - "args": [], - "bound": "C" - } - ], - "output": [ - { - "t": "Opaque", - "extension": "prelude", - "id": "error", - "args": [], - "bound": "C" - } - ] - } - }, - "binary": false - }, "MakeTuple": { "extension": "prelude", "name": "MakeTuple", diff --git a/uv.lock b/uv.lock index e2a327aaae..35a6661182 100644 --- a/uv.lock +++ b/uv.lock @@ -281,7 +281,7 @@ wheels = [ [[package]] name = "hugr" -version = "0.13.0rc1" +version = "0.12.1" source = { editable = "hugr-py" } dependencies = [ { name = "graphviz" }, From 640629d394c033ca06a38e19387d891edde22f46 Mon Sep 17 00:00:00 2001 From: Jenny Chen Date: Mon, 28 Jul 2025 14:02:46 -0600 Subject: [PATCH 2/6] added func getter for EmitFuncContext; based on hugr 0.20.2, which is currently compatible with eldarion --- hugr-llvm/src/emit/func.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/hugr-llvm/src/emit/func.rs b/hugr-llvm/src/emit/func.rs index 77d865540f..3c5eed2ef1 100644 --- a/hugr-llvm/src/emit/func.rs +++ b/hugr-llvm/src/emit/func.rs @@ -109,6 +109,11 @@ impl<'c, 'a, H: HugrView> EmitFuncContext<'c, 'a, H> { self.todo.insert(node.node()); } + /// Returns the current [FunctionValue] being emitted. + pub fn func(&self) -> FunctionValue<'c> { + self.func + } + /// Returns the internal [Builder]. Callers must ensure that it is /// positioned at the end of a basic block. This invariant is not checked(it /// doesn't seem possible to check it). From bdbb0f185e3776e6587d731cfc4de2504a94085e Mon Sep 17 00:00:00 2001 From: Jenny Chen Date: Mon, 28 Jul 2025 14:04:23 -0600 Subject: [PATCH 3/6] reverts back to the latest main --- .config/nextest.toml | 7 + .github/change-filters.yml | 7 + .github/workflows/ci-py.yml | 45 +- .github/workflows/ci-rs.yml | 36 +- .github/workflows/notify-coverage.yml | 2 +- .github/workflows/release-please.yml | 2 + .github/workflows/release-plz.yml | 8 +- .github/workflows/semver-checks.yml | 31 +- .github/workflows/unsoundness.yml | 49 +- .release-please-manifest.json | 4 +- Cargo.lock | 239 ++- Cargo.toml | 20 +- hugr-cli/CHANGELOG.md | 12 + hugr-cli/Cargo.toml | 6 +- hugr-cli/src/extensions.rs | 13 +- hugr-cli/src/lib.rs | 62 +- hugr-cli/src/main.rs | 126 +- hugr-cli/src/mermaid.rs | 25 +- hugr-cli/src/validate.rs | 46 +- hugr-cli/tests/validate.rs | 43 +- hugr-core/CHANGELOG.md | 60 + hugr-core/Cargo.toml | 13 +- hugr-core/src/builder.rs | 12 +- hugr-core/src/builder/build_traits.rs | 106 +- hugr-core/src/builder/circuit.rs | 16 +- hugr-core/src/builder/dataflow.rs | 89 +- hugr-core/src/builder/module.rs | 143 +- hugr-core/src/core.rs | 88 +- hugr-core/src/envelope.rs | 94 +- hugr-core/src/envelope/header.rs | 57 +- hugr-core/src/envelope/package_json.rs | 26 +- hugr-core/src/envelope/serde_with.rs | 383 +++- hugr-core/src/export.rs | 306 ++-- hugr-core/src/extension.rs | 37 +- hugr-core/src/extension/declarative/types.rs | 4 +- hugr-core/src/extension/op_def.rs | 101 +- hugr-core/src/extension/prelude.rs | 128 +- hugr-core/src/extension/prelude/generic.rs | 8 +- hugr-core/src/extension/resolution.rs | 4 +- hugr-core/src/extension/resolution/ops.rs | 4 +- hugr-core/src/extension/resolution/test.rs | 8 +- hugr-core/src/extension/resolution/types.rs | 65 +- .../src/extension/resolution/types_mut.rs | 49 +- hugr-core/src/extension/simple_op.rs | 7 +- hugr-core/src/extension/type_def.rs | 45 +- hugr-core/src/hugr.rs 
| 25 +- hugr-core/src/hugr/hugrmut.rs | 45 +- hugr-core/src/hugr/internal.rs | 20 +- hugr-core/src/hugr/patch/inline_call.rs | 29 +- hugr-core/src/hugr/patch/outline_cfg.rs | 4 +- hugr-core/src/hugr/patch/peel_loop.rs | 20 +- hugr-core/src/hugr/patch/simple_replace.rs | 125 +- hugr-core/src/hugr/persistent/resolver.rs | 43 - .../src/hugr/persistent/walker/pinned.rs | 164 -- hugr-core/src/hugr/serialize.rs | 9 +- hugr-core/src/hugr/serialize/test.rs | 165 +- .../upgrade/testcases/hugr_with_named_op.json | 114 +- hugr-core/src/hugr/validate.rs | 160 +- hugr-core/src/hugr/validate/test.rs | 346 ++-- hugr-core/src/hugr/views.rs | 2 +- hugr-core/src/hugr/views/impls.rs | 1 + hugr-core/src/hugr/views/render.rs | 14 +- hugr-core/src/hugr/views/rerooted.rs | 1 + hugr-core/src/hugr/views/root_checked/dfg.rs | 316 ++-- hugr-core/src/hugr/views/sibling_subgraph.rs | 10 +- hugr-core/src/hugr/views/tests.rs | 8 +- hugr-core/src/import.rs | 1597 ++++++++++------- hugr-core/src/lib.rs | 3 +- hugr-core/src/ops/constant.rs | 63 +- hugr-core/src/ops/constant/serialize.rs | 59 + hugr-core/src/ops/controlflow.rs | 14 +- hugr-core/src/ops/custom.rs | 13 +- hugr-core/src/ops/dataflow.rs | 12 +- hugr-core/src/ops/module.rs | 66 +- hugr-core/src/ops/tag.rs | 6 +- hugr-core/src/ops/validate.rs | 22 +- hugr-core/src/package.rs | 7 +- hugr-core/src/std_extensions.rs | 1 + .../std_extensions/arithmetic/int_types.rs | 28 +- .../src/std_extensions/arithmetic/mod.rs | 2 +- hugr-core/src/std_extensions/collections.rs | 1 + .../src/std_extensions/collections/array.rs | 4 +- .../collections/array/array_clone.rs | 14 +- .../collections/array/array_conversion.rs | 18 +- .../collections/array/array_discard.rs | 14 +- .../collections/array/array_op.rs | 46 +- .../collections/array/array_repeat.rs | 16 +- .../collections/array/array_scan.rs | 36 +- .../collections/array/array_value.rs | 6 +- .../collections/array/op_builder.rs | 13 + .../collections/borrow_array.rs | 797 ++++++++ .../src/std_extensions/collections/list.rs | 30 +- .../collections/static_array.rs | 12 +- .../std_extensions/collections/value_array.rs | 2 +- hugr-core/src/std_extensions/ptr.rs | 6 +- hugr-core/src/types.rs | 165 +- hugr-core/src/types/check.rs | 8 +- hugr-core/src/types/custom.rs | 4 +- hugr-core/src/types/poly_func.rs | 141 +- hugr-core/src/types/row_var.rs | 4 +- hugr-core/src/types/serialize.rs | 172 +- hugr-core/src/types/type_param.rs | 1090 +++++++---- hugr-core/src/types/type_row.rs | 162 +- hugr-core/tests/model.rs | 142 +- .../tests/snapshots/model__roundtrip_add.snap | 33 +- .../snapshots/model__roundtrip_alias.snap | 2 +- .../snapshots/model__roundtrip_call.snap | 5 +- .../tests/snapshots/model__roundtrip_cfg.snap | 17 +- .../snapshots/model__roundtrip_cond.snap | 41 +- .../snapshots/model__roundtrip_const.snap | 13 +- .../model__roundtrip_constraints.snap | 16 +- .../model__roundtrip_entrypoint.snap | 19 +- .../snapshots/model__roundtrip_loop.snap | 7 +- .../snapshots/model__roundtrip_order.snap | 76 +- .../snapshots/model__roundtrip_params.snap | 40 +- hugr-llvm/CHANGELOG.md | 12 + hugr-llvm/Cargo.toml | 6 +- hugr-llvm/src/emit/func.rs | 5 - hugr-llvm/src/emit/ops/cfg.rs | 6 +- ...test_fns__diverse_cfg_children@llvm14.snap | 17 +- ...verse_cfg_children@pre-mem2reg@llvm14.snap | 45 +- ...test_fns__diverse_dfg_children@llvm14.snap | 23 - ...verse_dfg_children@pre-mem2reg@llvm14.snap | 38 - hugr-llvm/src/emit/test.rs | 98 +- hugr-llvm/src/extension/collections/array.rs | 39 +- hugr-llvm/src/extension/collections/list.rs | 8 +- 
...t_static_array_of_static_array@llvm14.snap | 4 +- ...ay_of_static_array@pre-mem2reg@llvm14.snap | 4 +- .../src/extension/collections/stack_array.rs | 40 +- .../src/extension/collections/static_array.rs | 11 +- hugr-llvm/src/extension/conversions.rs | 20 +- hugr-llvm/src/extension/float.rs | 7 +- hugr-llvm/src/extension/int.rs | 17 +- hugr-llvm/src/extension/logic.rs | 4 +- hugr-llvm/src/extension/prelude.rs | 143 +- ...lude__test__prelude_make_error@llvm14.snap | 19 + ...prelude_make_error@pre-mem2reg@llvm14.snap | 31 + ...__prelude_make_error_and_panic@llvm14.snap | 28 + ...ke_error_and_panic@pre-mem2reg@llvm14.snap | 37 + hugr-llvm/src/test.rs | 2 +- hugr-llvm/src/utils/fat.rs | 9 +- hugr-model/CHANGELOG.md | 24 + hugr-model/Cargo.toml | 3 +- hugr-model/FORMAT_VERSION | 1 + hugr-model/capnp/hugr-v0.capnp | 13 + hugr-model/src/capnp/hugr_v0_capnp.rs | 550 +++++- hugr-model/src/lib.rs | 19 + hugr-model/src/v0/ast/hugr.pest | 6 +- hugr-model/src/v0/ast/mod.rs | 4 +- hugr-model/src/v0/ast/parse.rs | 12 +- hugr-model/src/v0/ast/print.rs | 8 +- hugr-model/src/v0/ast/python.rs | 32 + hugr-model/src/v0/ast/resolve.rs | 18 + hugr-model/src/v0/ast/view.rs | 2 + hugr-model/src/v0/binary/read.rs | 140 +- hugr-model/src/v0/binary/write.rs | 18 +- hugr-model/src/v0/mod.rs | 53 +- hugr-model/src/v0/scope/vars.rs | 27 +- hugr-model/src/v0/table/mod.rs | 4 +- hugr-model/tests/fixtures/model-add.edn | 25 +- hugr-model/tests/fixtures/model-call.edn | 6 +- hugr-model/tests/fixtures/model-cfg.edn | 32 +- hugr-model/tests/fixtures/model-cond.edn | 37 +- hugr-model/tests/fixtures/model-const.edn | 6 +- .../tests/fixtures/model-constraints.edn | 5 +- .../tests/fixtures/model-entrypoint.edn | 10 +- hugr-model/tests/fixtures/model-loop.edn | 4 +- hugr-model/tests/fixtures/model-order.edn | 55 +- hugr-model/tests/fixtures/model-params.edn | 17 +- hugr-passes/CHANGELOG.md | 31 + hugr-passes/Cargo.toml | 4 +- hugr-passes/src/call_graph.rs | 11 +- hugr-passes/src/composable.rs | 40 +- hugr-passes/src/const_fold.rs | 1 - hugr-passes/src/const_fold/test.rs | 2 +- hugr-passes/src/dataflow.rs | 1 - hugr-passes/src/dataflow/datalog.rs | 5 + hugr-passes/src/dataflow/test.rs | 18 +- hugr-passes/src/dead_funcs.rs | 14 +- hugr-passes/src/inline_dfgs.rs | 99 + hugr-passes/src/inline_funcs.rs | 229 +++ hugr-passes/src/lib.rs | 4 +- hugr-passes/src/linearize_array.rs | 10 +- hugr-passes/src/lower.rs | 2 + hugr-passes/src/monomorphize.rs | 171 +- hugr-passes/src/non_local.rs | 1 - hugr-passes/src/replace_types.rs | 172 +- hugr-passes/src/replace_types/handlers.rs | 10 +- hugr-passes/src/replace_types/linearize.rs | 156 +- hugr-persistent/CHANGELOG.md | 11 + hugr-persistent/Cargo.toml | 43 + hugr-persistent/README.md | 59 + hugr-persistent/src/lib.rs | 98 + .../src}/parents_view.rs | 14 +- .../src/persistent_hugr.rs | 421 ++--- hugr-persistent/src/persistent_hugr/serial.rs | 75 + ..._serial__tests__serde_persistent_hugr.snap | 184 ++ hugr-persistent/src/resolver.rs | 147 ++ .../src}/state_space.rs | 264 ++- .../src}/state_space/serial.rs | 64 +- ..._serial__tests__serialize_state_space.snap | 244 +++ hugr-persistent/src/subgraph.rs | 215 +++ .../src}/tests.rs | 137 +- .../src}/trait_impls.rs | 45 +- .../src}/walker.rs | 526 ++++-- hugr-persistent/src/wire.rs | 303 ++++ .../tests/persistent_walker_example.rs | 227 +-- hugr-py/CHANGELOG.md | 103 ++ hugr-py/Cargo.toml | 2 +- hugr-py/pyproject.toml | 2 +- hugr-py/rust/lib.rs | 11 + hugr-py/src/hugr/__init__.py | 2 +- hugr-py/src/hugr/_hugr/__init__.pyi | 1 + 
hugr-py/src/hugr/_serialization/extension.py | 28 +- hugr-py/src/hugr/_serialization/ops.py | 50 +- hugr-py/src/hugr/_serialization/tys.py | 94 +- hugr-py/src/hugr/build/dfg.py | 62 +- hugr-py/src/hugr/build/function.py | 42 +- hugr-py/src/hugr/envelope.py | 36 +- hugr-py/src/hugr/ext.py | 4 +- hugr-py/src/hugr/hugr/base.py | 229 ++- hugr-py/src/hugr/hugr/render.py | 11 +- hugr-py/src/hugr/model/__init__.py | 19 + hugr-py/src/hugr/model/export.py | 201 ++- hugr-py/src/hugr/ops.py | 83 +- .../_json_defs/collections/borrow_arr.json | 1139 ++++++++++++ hugr-py/src/hugr/std/_json_defs/prelude.json | 34 +- hugr-py/src/hugr/std/collections/array.py | 2 +- .../src/hugr/std/collections/borrow_array.py | 94 + .../src/hugr/std/collections/static_array.py | 3 + hugr-py/src/hugr/std/int.py | 30 +- hugr-py/src/hugr/tys.py | 229 ++- hugr-py/src/hugr/utils.py | 24 + hugr-py/src/hugr/val.py | 114 +- .../tests/__snapshots__/test_hugr_build.ambr | 374 +++- .../tests/__snapshots__/test_order_edges.ambr | 258 +++ hugr-py/tests/conftest.py | 203 ++- hugr-py/tests/test_cfg.py | 2 +- hugr-py/tests/test_custom.py | 2 +- hugr-py/tests/test_envelope.py | 43 +- hugr-py/tests/test_hugr_build.py | 150 +- hugr-py/tests/test_ops.py | 6 +- hugr-py/tests/test_order_edges.py | 49 + hugr-py/tests/test_prelude.py | 38 + hugr-py/tests/test_tys.py | 58 +- hugr-py/tests/test_val.py | 14 +- hugr/CHANGELOG.md | 89 + hugr/Cargo.toml | 12 +- hugr/benches/benchmarks/hugr/examples.rs | 4 +- hugr/benches/benchmarks/types.rs | 2 +- hugr/src/lib.rs | 4 + justfile | 4 +- release-plz.toml | 9 + resources/test/hugr-no-visibility.hugr | 52 + scripts/check_extension_versions.py | 90 + scripts/generate_schema.py | 5 +- specification/hugr.md | 47 +- specification/schema/hugr_schema_live.json | 232 ++- .../schema/hugr_schema_strict_live.json | 232 ++- .../schema/testing_hugr_schema_live.json | 232 ++- .../testing_hugr_schema_strict_live.json | 232 ++- .../collections/borrow_arr.json | 1139 ++++++++++++ specification/std_extensions/prelude.json | 34 +- uv.lock | 2 +- 264 files changed, 16221 insertions(+), 5230 deletions(-) create mode 100644 .config/nextest.toml delete mode 100644 hugr-core/src/hugr/persistent/resolver.rs delete mode 100644 hugr-core/src/hugr/persistent/walker/pinned.rs create mode 100644 hugr-core/src/ops/constant/serialize.rs create mode 100644 hugr-core/src/std_extensions/collections/borrow_array.rs delete mode 100644 hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@llvm14.snap delete mode 100644 hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@pre-mem2reg@llvm14.snap create mode 100644 hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@llvm14.snap create mode 100644 hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error@pre-mem2reg@llvm14.snap create mode 100644 hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@llvm14.snap create mode 100644 hugr-llvm/src/extension/snapshots/hugr_llvm__extension__prelude__test__prelude_make_error_and_panic@pre-mem2reg@llvm14.snap create mode 100644 hugr-model/FORMAT_VERSION create mode 100644 hugr-passes/src/inline_dfgs.rs create mode 100644 hugr-passes/src/inline_funcs.rs create mode 100644 hugr-persistent/CHANGELOG.md create mode 100644 hugr-persistent/Cargo.toml create mode 100644 hugr-persistent/README.md create mode 100644 hugr-persistent/src/lib.rs rename {hugr-core/src/hugr/persistent => 
hugr-persistent/src}/parents_view.rs (95%) rename hugr-core/src/hugr/persistent.rs => hugr-persistent/src/persistent_hugr.rs (59%) create mode 100644 hugr-persistent/src/persistent_hugr/serial.rs create mode 100644 hugr-persistent/src/persistent_hugr/snapshots/hugr_persistent__persistent_hugr__serial__tests__serde_persistent_hugr.snap create mode 100644 hugr-persistent/src/resolver.rs rename {hugr-core/src/hugr/persistent => hugr-persistent/src}/state_space.rs (70%) rename {hugr-core/src/hugr/persistent => hugr-persistent/src}/state_space/serial.rs (66%) create mode 100644 hugr-persistent/src/state_space/snapshots/hugr_persistent__state_space__serial__tests__serialize_state_space.snap create mode 100644 hugr-persistent/src/subgraph.rs rename {hugr-core/src/hugr/persistent => hugr-persistent/src}/tests.rs (81%) rename {hugr-core/src/hugr/persistent => hugr-persistent/src}/trait_impls.rs (92%) rename {hugr-core/src/hugr/persistent => hugr-persistent/src}/walker.rs (52%) create mode 100644 hugr-persistent/src/wire.rs rename {hugr-core => hugr-persistent}/tests/persistent_walker_example.rs (62%) create mode 100644 hugr-py/src/hugr/std/_json_defs/collections/borrow_arr.json create mode 100644 hugr-py/src/hugr/std/collections/borrow_array.py create mode 100644 hugr-py/tests/__snapshots__/test_order_edges.ambr create mode 100644 hugr-py/tests/test_order_edges.py create mode 100644 resources/test/hugr-no-visibility.hugr create mode 100644 scripts/check_extension_versions.py create mode 100644 specification/std_extensions/collections/borrow_arr.json diff --git a/.config/nextest.toml b/.config/nextest.toml new file mode 100644 index 0000000000..e267109367 --- /dev/null +++ b/.config/nextest.toml @@ -0,0 +1,7 @@ + + +[profile.default-miri] +# Fail if tests take more than 5 mins. +# Those tests should be skipped in `.github/workflows/unsondness.yml`. +slow-timeout = { period = "60s", terminate-after = 5 } +fail-fast = false diff --git a/.github/change-filters.yml b/.github/change-filters.yml index f7f7864564..4fe700efc2 100644 --- a/.github/change-filters.yml +++ b/.github/change-filters.yml @@ -1,5 +1,8 @@ # Filters used by [dorny/path-filters](https://github.com/dorny/paths-filter) # to detect changes in each subproject, and only run the corresponding jobs. +# +# We use a composable action to add some additional checks. +# When adding a new category here, make sure to also update `.github/actions/check-changes/action.yml` # Dependencies and common workspace configuration. rust-config: &rust-config @@ -24,8 +27,12 @@ rust: &rust - "hugr-cli/**" - "hugr-core/**" - "hugr-passes/**" + - "hugr-persistent/**" - "specification/schema/**" +std-extensions: + - "specification/std_extensions/**" + python: - *rust - "hugr-py/**" diff --git a/.github/workflows/ci-py.yml b/.github/workflows/ci-py.yml index b71cbdca04..c4f88b3f2e 100644 --- a/.github/workflows/ci-py.yml +++ b/.github/workflows/ci-py.yml @@ -21,22 +21,19 @@ env: jobs: # Check if changes were made to the relevant files. - # Always returns true if running on the default branch, to ensure all changes are throughly checked. + # Always returns true if running on the default branch, to ensure all changes are thoroughly checked. 
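Patch 2/6 above adds a `func()` getter to `EmitFuncContext`, exposing the inkwell `FunctionValue` currently being emitted. A minimal sketch of how downstream emission code might use it; the import paths and the helper itself are assumptions, and only standard inkwell accessors are called on the returned value:

    use hugr::HugrView;
    use hugr_llvm::emit::func::EmitFuncContext; // assumed public path for the type touched in patch 2/6
    use inkwell::values::FunctionValue;

    /// Hypothetical helper: describe the LLVM function an emission context is working on.
    fn describe_current_function<'c, H: HugrView>(ctx: &EmitFuncContext<'c, '_, H>) -> String {
        // New in patch 2/6: the FunctionValue being emitted.
        let func: FunctionValue<'c> = ctx.func();
        format!(
            "emitting `{}` with {} parameter(s)",
            func.get_name().to_string_lossy(),
            func.count_params()
        )
    }
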
changes: - name: Check for changes in Python files + name: Check for changes runs-on: ubuntu-latest - # Required permissions permissions: pull-requests: read - # Set job outputs to values from filter step outputs: - python: ${{ github.ref_name == github.event.repository.default_branch || steps.filter.outputs.python }} + python: ${{ steps.filter.outputs.python }} + extensions: ${{ steps.filter.outputs.llvm }} steps: - uses: actions/checkout@v4 - - uses: dorny/paths-filter@v3 + - uses: ./.github/actions/check-changes id: filter - with: - filters: .github/change-filters.yml check: needs: changes @@ -179,11 +176,41 @@ jobs: exit 1 fi + extension-versions: + runs-on: ubuntu-latest + needs: [changes] + if: ${{ needs.changes.outputs.extensions == 'true' }} + name: Check std extensions versions + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Need full history to compare with main + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Check if extension versions are updated + run: | + # Check against latest tag on the target branch + # When not on a pull request, base_ref should be empty so we default to HEAD + if [ -z "$TARGET_REF" ]; then + BASE_SHA="HEAD~1" + else + BASE_SHA=$(git rev-parse origin/$TARGET_REF) + fi + echo "Comparing to ref: $BASE_SHA" + + python ./scripts/check_extension_versions.py $BASE_SHA + env: + TARGET_REF: ${{ github.base_ref }} + # This is a meta job to mark successful completion of the required checks, # even if they are skipped due to no changes in the relevant files. required-checks: name: Required checks 🐍 - needs: [changes, check, test, serialization-schema] + needs: [changes, check, test, serialization-schema, extension-versions] if: ${{ !cancelled() }} runs-on: ubuntu-latest steps: diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index 40e0046e06..cc7a19635c 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -19,6 +19,7 @@ env: CI: true # insta snapshots behave differently on ci SCCACHE_GHA_ENABLED: "true" RUSTC_WRAPPER: "sccache" + HUGR_TEST_SCHEMA: "1" # different strings for install action and feature name # adapted from https://github.com/TheDan64/inkwell/blob/master/.github/workflows/test.yml LLVM_VERSION: "14.0" @@ -30,36 +31,16 @@ jobs: changes: name: Check for changes runs-on: ubuntu-latest - # Required permissions permissions: pull-requests: read - # Set job outputs to values from filter step - # These outputs are always true when running after a merge to main, or if the PR has a `run-ci-checks` label. 
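The `ci-rs` workflow above now exports `HUGR_TEST_SCHEMA: "1"`, opting the Rust test suite into the JSON-schema comparison that the hugr-core changelog later in this patch describes as re-enabled. The variable's consumer is not shown here, so the guard below is a hypothetical sketch of how such an opt-in flag is typically read in a test:

    /// Hypothetical helper: schema checks only run when CI sets HUGR_TEST_SCHEMA.
    fn schema_checks_enabled() -> bool {
        std::env::var("HUGR_TEST_SCHEMA").is_ok_and(|v| v == "1")
    }

    #[test]
    fn serialization_matches_schema() {
        if !schema_checks_enabled() {
            eprintln!("HUGR_TEST_SCHEMA not set; skipping schema validation");
            return;
        }
        // ... compare the serialized HUGR against specification/schema here ...
    }
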
outputs: - rust: ${{ steps.filter.outputs.rust == 'true' || steps.override.outputs.out == 'true' }} - python: ${{ steps.filter.outputs.python == 'true' || steps.override.outputs.out == 'true' }} - model: ${{ steps.filter.outputs.model == 'true' || steps.override.outputs.out == 'true' }} - llvm: ${{ steps.filter.outputs.llvm == 'true' || steps.override.outputs.out == 'true' }} + rust: ${{ steps.filter.outputs.rust }} + llvm: ${{ steps.filter.outputs.llvm }} + model: ${{ steps.filter.outputs.model }} steps: - uses: actions/checkout@v4 - - name: Override label - id: override - run: | - echo "Label contains run-ci-checks: $OVERRIDE_LABEL" - if [ "$OVERRIDE_LABEL" == "true" ]; then - echo "Overriding due to label 'run-ci-checks'" - echo "out=true" >> $GITHUB_OUTPUT - elif [ "$DEFAULT_BRANCH" == "true" ]; then - echo "Overriding due to running on the default branch" - echo "out=true" >> $GITHUB_OUTPUT - fi - env: - OVERRIDE_LABEL: ${{ github.event_name == 'pull_request' && contains( github.event.pull_request.labels.*.name, 'run-ci-checks') }} - DEFAULT_BRANCH: ${{ github.ref_name == github.event.repository.default_branch }} - - uses: dorny/paths-filter@v3 + - uses: ./.github/actions/check-changes id: filter - with: - filters: .github/change-filters.yml check: needs: changes @@ -170,6 +151,12 @@ jobs: - name: Tests hugr-llvm if: ${{ needs.changes.outputs.llvm == 'true'}} run: cargo test -p hugr-llvm --verbose --features llvm${{ env.LLVM_FEATURE_NAME }} + - name: Build hugr-persistent + if: ${{ needs.changes.outputs.rust == 'true'}} + run: cargo test -p hugr-persistent --verbose --no-run + - name: Tests hugr-persistent + if: ${{ needs.changes.outputs.rust == 'true'}} + run: cargo test -p hugr-persistent --verbose - name: Build HUGR binary run: cargo build -p hugr-cli - name: Upload the binary to the artifacts @@ -355,6 +342,7 @@ jobs: cargo llvm-cov --no-report --no-default-features --doctests cargo llvm-cov --no-report --all-features --doctests cargo llvm-cov --no-report -p hugr-llvm --features llvm14-0 --doctests + cargo llvm-cov --no-report -p hugr-persistent --doctests - name: Generate coverage report run: cargo llvm-cov --all-features report --codecov --output-path coverage.json - name: Upload coverage to codecov.io diff --git a/.github/workflows/notify-coverage.yml b/.github/workflows/notify-coverage.yml index 25c8f21702..7eae317ab0 100644 --- a/.github/workflows/notify-coverage.yml +++ b/.github/workflows/notify-coverage.yml @@ -22,7 +22,7 @@ jobs: if: needs.coverage-trend.outputs.should_notify == 'true' steps: - name: Send notification - uses: slackapi/slack-github-action@v2.1.0 + uses: slackapi/slack-github-action@v2.1.1 with: method: chat.postMessage token: ${{ secrets.SLACK_BOT_TOKEN }} diff --git a/.github/workflows/release-please.yml b/.github/workflows/release-please.yml index 16534f141c..35f2c6711b 100644 --- a/.github/workflows/release-please.yml +++ b/.github/workflows/release-please.yml @@ -6,6 +6,7 @@ on: push: branches: - main + - release/* permissions: contents: write @@ -21,3 +22,4 @@ jobs: # Using a personal access token so releases created by this workflow can trigger the deployment workflow token: ${{ secrets.HUGRBOT_PAT }} config-file: release-please-config.json + target-branch: ${{ github.ref_name }} diff --git a/.github/workflows/release-plz.yml b/.github/workflows/release-plz.yml index 5eacec276b..9139b9eb32 100644 --- a/.github/workflows/release-plz.yml +++ b/.github/workflows/release-plz.yml @@ -13,6 +13,9 @@ jobs: release-plz: name: Release-plz runs-on: 
ubuntu-latest + environment: crate-release + permissions: + id-token: write # Required for OIDC token exchange steps: - name: Checkout repository uses: actions/checkout@v4 @@ -32,8 +35,11 @@ jobs: # otherwise release-plz fails due to uncommitted changes. directory: ${{ runner.temp }}/llvm + - uses: rust-lang/crates-io-auth-action@v1 + id: auth + - name: Run release-plz uses: MarcoIeni/release-plz-action@v0.5 env: GITHUB_TOKEN: ${{ secrets.HUGRBOT_PAT }} - CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} + CARGO_REGISTRY_TOKEN: ${{ steps.auth.outputs.token }} diff --git a/.github/workflows/semver-checks.yml b/.github/workflows/semver-checks.yml index 2c410aa85d..881ec2227d 100644 --- a/.github/workflows/semver-checks.yml +++ b/.github/workflows/semver-checks.yml @@ -6,38 +6,19 @@ on: jobs: # Check if changes were made to the relevant files. - # Always returns true if running on the default branch, to ensure all changes are throughly checked. + # Always returns true if running on the default branch, to ensure all changes are thoroughly checked. changes: name: Check for changes runs-on: ubuntu-latest - # Required permissions permissions: pull-requests: read - # Set job outputs to values from filter step - # These outputs are always true when running after a merge to main, or if the PR has a `run-ci-checks` label. outputs: - rust: ${{ steps.filter.outputs.rust == 'true' || steps.override.outputs.out == 'true' }} - python: ${{ steps.filter.outputs.python == 'true' || steps.override.outputs.out == 'true' }} + rust: ${{ steps.filter.outputs.rust }} + python: ${{ steps.filter.outputs.python }} steps: - - uses: actions/checkout@v4 - - name: Override label - id: override - run: | - echo "Label contains run-ci-checks: $OVERRIDE_LABEL" - if [ "$OVERRIDE_LABEL" == "true" ]; then - echo "Overriding due to label 'run-ci-checks'" - echo "out=true" >> $GITHUB_OUTPUT - elif [ "$DEFAULT_BRANCH" == "true" ]; then - echo "Overriding due to running on the default branch" - echo "out=true" >> $GITHUB_OUTPUT - fi - env: - OVERRIDE_LABEL: ${{ github.event_name == 'pull_request' && contains( github.event.pull_request.labels.*.name, 'run-ci-checks') }} - DEFAULT_BRANCH: ${{ github.ref_name == github.event.repository.default_branch }} - - uses: dorny/paths-filter@v3 - id: filter - with: - filters: .github/change-filters.yml + - uses: actions/checkout@v4 + - uses: ./.github/actions/check-changes + id: filter rs-semver-checks: needs: [changes] diff --git a/.github/workflows/unsoundness.yml b/.github/workflows/unsoundness.yml index 3ba0b4cf0d..f110fe168b 100644 --- a/.github/workflows/unsoundness.yml +++ b/.github/workflows/unsoundness.yml @@ -31,12 +31,53 @@ jobs: rustup toolchain install nightly --component miri rustup override set nightly cargo miri setup - - uses: Swatinem/rust-cache@v2 + + - uses: taiki-e/install-action@v2 with: - prefix-key: v0-miri - - name: Test with Miri - run: cargo miri test + tool: nextest + # Run miri unsoundness checks. + # + # The default "zstd" feature requires FFI to the zstd library encode/decode envelopes. + # As this is not supported in miri, we must disable it here. + # + # We also skip tests that take over 5mins in CI. 
+ - name: Test with Miri + run: | + cargo miri nextest run --no-default-features -- \ + --skip "builder::circuit::test::with_nonlinear_and_outputs" \ + --skip "extension::op_def::test::check_ext_id_wellformed" \ + --skip "extension::resolution::test::register_new_cyclic" \ + --skip "extension::simple_op::test::check_ext_id_wellformed" \ + --skip "extension::test::test_register_update" \ + --skip "extension::type_def::test::test_instantiate_typedef" \ + --skip "hugr::ident::test::proptest::arbitrary_identlist_valid" \ + --skip "hugr::ident::test::test_idents" \ + --skip "hugr::patch::replace::test::test_invalid" \ + --skip "hugr::validate::test::check_ext_id_wellformed" \ + --skip "ops::constant::test::test_json_const" \ + --skip "ops::custom::test::resolve_missing" \ + --skip "ops::custom::test::new_opaque_op" \ + --skip "std_extensions::arithmetic::int_types::test::proptest::valid_signed_int" \ + --skip "types::test::construct" \ + --skip "types::test::transform" \ + --skip "types::test::transform_copyable_to_linear" \ + --skip "types::type_param::test::proptest::term_contains_itself" \ + `# -------- hugr-model` \ + --skip "v0::test::test_literal_text" \ + `# -------- hugr-passes` \ + --skip "dataflow::partial_value::test::bounded_lattice" \ + --skip "dataflow::partial_value::test::lattice" \ + --skip "dataflow::partial_value::test::lattice_associative" \ + --skip "dataflow::partial_value::test::meet_join_self_noop" \ + --skip "dataflow::partial_value::test::partial_value_type" \ + --skip "dataflow::partial_value::test::partial_value_valid" \ + --skip "merge_bbs::test::check_ext_id_wellformed" \ + --skip "monomorphize::test::test_recursion_module" \ + --skip "replace_types::test::dfg_conditional_case" \ + --skip "replace_types::test::module_func_cfg_call" \ + --skip "replace_types::test::op_to_call" \ + # create-issue: uses: CQCL/hugrverse-actions/.github/workflows/create-issue.yml@main diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 7f59cf2351..8e5bbf9e12 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - "hugr-py": "0.12.1" -} + "hugr-py": "0.13.0rc1" +} \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 4981ad7340..a9d98367a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -327,9 +327,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.18.1" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "bytecount" @@ -351,9 +351,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "capnp" -version = "0.20.6" +version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "053b81915c2ce1629b8fb964f578b18cb39b23ef9d5b24120d0dfc959569a1d9" +checksum = "d55799fdec2a55eee8c267430d7464eb9c27ad2e5c8a49b433ff213b56852c7f" dependencies = [ "embedded-io", ] @@ -440,9 +440,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" +checksum = "be92d32e80243a54711e5d7ce823c35c41c9d929dc4ab58e1276f625841aadf9" dependencies = [ "clap_builder", "clap_derive", @@ -460,9 +460,9 @@ dependencies = [ 
[[package]] name = "clap_builder" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" +checksum = "707eab41e9622f9139419d573eca0900137718000c517d47da73045f54331c3d" dependencies = [ "anstream", "anstyle", @@ -472,9 +472,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce" +checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491" dependencies = [ "heck", "proc-macro2", @@ -676,9 +676,9 @@ dependencies = [ [[package]] name = "delegate" -version = "0.13.3" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9b6483c2bbed26f97861cf57651d4f2b731964a28cd2257f934a4b452480d21" +checksum = "6178a82cf56c836a3ba61a7935cdb1c49bfaa6fa4327cd5bf554a503087de26b" dependencies = [ "proc-macro2", "quote", @@ -1201,7 +1201,7 @@ checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "hugr" -version = "0.20.2" +version = "0.22.1" dependencies = [ "bumpalo", "criterion", @@ -1209,13 +1209,14 @@ dependencies = [ "hugr-llvm", "hugr-model", "hugr-passes", + "hugr-persistent", "lazy_static", "serde_json", ] [[package]] name = "hugr-cli" -version = "0.20.2" +version = "0.22.1" dependencies = [ "anyhow", "assert_cmd", @@ -1230,12 +1231,16 @@ dependencies = [ "serde_json", "tempfile", "thiserror 2.0.12", + "tracing", + "tracing-subscriber", ] [[package]] name = "hugr-core" -version = "0.20.2" +version = "0.22.1" dependencies = [ + "anyhow", + "base64", "cgmath", "cool_asserts", "delegate", @@ -1246,11 +1251,12 @@ dependencies = [ "html-escape", "hugr", "hugr-model", - "indexmap 2.9.0", + "indexmap 2.10.0", "insta", "itertools 0.14.0", "jsonschema", "lazy_static", + "ordered-float", "paste", "petgraph 0.8.2", "portgraph", @@ -1264,17 +1270,19 @@ dependencies = [ "serde_json", "serde_with", "serde_yaml", + "smallvec", "smol_str", "static_assertions", "strum", "thiserror 2.0.12", + "tracing", "typetag", "zstd", ] [[package]] name = "hugr-llvm" -version = "0.20.2" +version = "0.22.1" dependencies = [ "anyhow", "delegate", @@ -1293,14 +1301,14 @@ dependencies = [ [[package]] name = "hugr-model" -version = "0.20.2" +version = "0.22.1" dependencies = [ "base64", "bumpalo", "capnp", "derive_more 1.0.0", "fxhash", - "indexmap 2.9.0", + "indexmap 2.10.0", "insta", "itertools 0.14.0", "ordered-float", @@ -1311,13 +1319,14 @@ dependencies = [ "proptest", "proptest-derive", "pyo3", + "semver", "smol_str", "thiserror 2.0.12", ] [[package]] name = "hugr-passes" -version = "0.20.2" +version = "0.22.1" dependencies = [ "ascent", "derive_more 1.0.0", @@ -1334,6 +1343,28 @@ dependencies = [ "thiserror 2.0.12", ] +[[package]] +name = "hugr-persistent" +version = "0.2.1" +dependencies = [ + "delegate", + "derive_more 1.0.0", + "hugr-core", + "insta", + "itertools 0.14.0", + "lazy_static", + "petgraph 0.8.2", + "portgraph", + "relrc", + "rstest", + "semver", + "serde", + "serde_json", + "serde_with", + "thiserror 2.0.12", + "wyhash", +] + [[package]] name = "hugr-py" version = "0.1.0" @@ -1573,9 +1604,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.9.0" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" +checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ "equivalent", "hashbrown 0.15.4", @@ -1848,6 +1879,16 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num" version = "0.4.3" @@ -1967,6 +2008,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2c1f9f56e534ac6a9b8a4600bdf0f530fb393b5f393e7b4d03489c3cf0c3f01" dependencies = [ "num-traits", + "rand 0.8.5", + "serde", ] [[package]] @@ -1975,6 +2018,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking_lot" version = "0.12.4" @@ -2061,7 +2110,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset 0.4.2", - "indexmap 2.9.0", + "indexmap 2.10.0", ] [[package]] @@ -2072,7 +2121,7 @@ checksum = "54acf3a685220b533e437e264e4d932cfbdc4cc7ec0cd232ed73c08d03b8a7ca" dependencies = [ "fixedbitset 0.5.7", "hashbrown 0.15.4", - "indexmap 2.9.0", + "indexmap 2.10.0", "serde", ] @@ -2130,13 +2179,14 @@ checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" [[package]] name = "portgraph" -version = "0.14.1" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fdce52d51ec359351ff3c209fafb6f133562abf52d951ce5821c0184798d979" +checksum = "61fb905fbfbc9abf3bd37853bbd4b25d31dffd5631994f8df528f85455085657" dependencies = [ "bitvec", "delegate", "itertools 0.14.0", + "num-traits", "petgraph 0.8.2", "serde", "thiserror 2.0.12", @@ -2246,7 +2296,7 @@ dependencies = [ "bitflags", "lazy_static", "num-traits", - "rand", + "rand 0.9.1", "rand_chacha", "rand_xorshift", "regex-syntax", @@ -2366,6 +2416,16 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "rand_core 0.6.4", + "serde", +] + [[package]] name = "rand" version = "0.9.1" @@ -2391,6 +2451,9 @@ name = "rand_core" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "serde", +] [[package]] name = "rand_core" @@ -2675,6 +2738,18 @@ dependencies = [ "serde_json", ] +[[package]] +name = "schemars" +version = "1.0.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1375ba8ef45a6f15d83fa8748f1079428295d403d6ea991d09ab100155fbc06d" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2736,16 +2811,17 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf65a400f8f66fb7b0552869ad70157166676db75ed8181f8104ea91cf9d0b42" +checksum = "f2c45cd61fefa9db6f254525d46e392b852e0e61d9a1fd36e5bd183450a556d5" dependencies = [ "base64", "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.9.0", - "schemars", + "indexmap 2.10.0", + "schemars 0.9.0", + "schemars 1.0.3", "serde", "serde_derive", "serde_json", @@ -2755,9 +2831,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81679d9ed988d5e9a5e6531dc3f2c28efbd639cbd1dfb628df08edea6004da77" +checksum = "de90945e6565ce0d9a25098082ed4ee4002e047cb59892c318d66821e14bb30f" dependencies = [ "darling", "proc-macro2", @@ -2771,7 +2847,7 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.9.0", + "indexmap 2.10.0", "itoa", "ryu", "serde", @@ -2789,6 +2865,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" @@ -2985,6 +3070,15 @@ dependencies = [ "syn", ] +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + [[package]] name = "time" version = "0.3.41" @@ -3062,7 +3156,7 @@ version = "0.22.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" dependencies = [ - "indexmap 2.9.0", + "indexmap 2.10.0", "toml_datetime", "winnow", ] @@ -3119,9 +3213,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b1ffbcf9c6f6b99d386e7444eb608ba646ae452a36b39737deb9663b610f662" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tracing-core" version = "0.1.34" @@ -3129,6 +3235,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +dependencies = [ + "nu-ansi-term", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", ] [[package]] @@ -3271,6 +3403,12 @@ dependencies = [ "vsimd", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "version_check" version = "0.9.5" @@ -3407,6 +3545,22 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.9" @@ -3416,6 +3570,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.61.2" @@ -3638,6 +3798,15 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" +[[package]] +name = "wyhash" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca4d373340c479fd1e779f7a763acee85da3e423b1a9a9acccf97babcc92edbb" +dependencies = [ + "rand_core 0.9.3", +] + [[package]] name = "wyz" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index b123c9897d..56d2033786 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "hugr-model", "hugr-llvm", "hugr-py", + "hugr-persistent", ] default-members = ["hugr", "hugr-core", "hugr-passes", "hugr-cli", "hugr-model"] @@ -37,18 +38,14 @@ missing_docs = "warn" # https://github.com/rust-lang/rust-clippy/issues/5112 debug_assert_with_mut_call = "warn" -# TODO: Reduce the size of error types. 
-result_large_err = "allow" -large_enum_variant = "allow" - [workspace.dependencies] anyhow = "1.0.98" insta = { version = "1.43.1" } bitvec = "1.0.1" -capnp = "0.20.6" +capnp = "0.21.3" cgmath = "0.18.0" cool_asserts = "2.0.3" -delegate = "0.13.3" +delegate = "0.13.4" derive_more = "1.0.0" downcast-rs = "2.0.1" enum_dispatch = "0.3.11" @@ -66,7 +63,7 @@ rstest = "0.24.0" semver = "1.0.26" serde = "1.0.219" serde_json = "1.0.140" -serde_with = "3.13.0" +serde_with = "3.14.0" serde_yaml = "0.9.34" smol_str = "0.3.1" static_assertions = "1.1.0" @@ -74,15 +71,15 @@ strum = "0.27.0" tempfile = "3.20" thiserror = "2.0.12" typetag = "0.2.20" -clap = { version = "4.5.40" } +clap = { version = "4.5.41" } clio = "0.3.5" clap-verbosity-flag = "3.0.3" assert_cmd = "2.0.17" assert_fs = "1.1.3" predicates = "3.1.0" -indexmap = "2.9.0" +indexmap = "2.10.0" fxhash = "0.2.1" -bumpalo = "3.18.1" +bumpalo = "3.19.0" pathsearch = "0.2.0" base64 = "0.22.1" ordered-float = "5.0.0" @@ -92,11 +89,12 @@ pretty = "0.12.4" pretty_assertions = "1.4.1" zstd = "0.13.2" relrc = "0.4.6" +wyhash = "0.6.0" # These public dependencies usually require breaking changes downstream, so we # try to be as permissive as possible. pyo3 = ">= 0.23.4, < 0.25" -portgraph = { version = "0.14.1" } +portgraph = { version = "0.15.1" } petgraph = { version = ">= 0.8.1, < 0.9", default-features = false } [profile.dev.package] diff --git a/hugr-cli/CHANGELOG.md b/hugr-cli/CHANGELOG.md index 5bbc26a608..090b49ea50 100644 --- a/hugr-cli/CHANGELOG.md +++ b/hugr-cli/CHANGELOG.md @@ -1,6 +1,18 @@ # Changelog +## [0.22.0](https://github.com/CQCL/hugr/compare/hugr-cli-v0.21.0...hugr-cli-v0.22.0) - 2025-07-24 + +### New Features + +- include generator metatada in model import and cli validate errors ([#2452](https://github.com/CQCL/hugr/pull/2452)) + +## [0.21.0](https://github.com/CQCL/hugr/compare/hugr-cli-v0.20.2...hugr-cli-v0.21.0) - 2025-07-09 + +### New Features + +- [**breaking**] Better error reporting in `hugr-cli`. ([#2318](https://github.com/CQCL/hugr/pull/2318)) + ## [0.20.2](https://github.com/CQCL/hugr/compare/hugr-cli-v0.20.1...hugr-cli-v0.20.2) - 2025-06-25 ### New Features diff --git a/hugr-cli/Cargo.toml b/hugr-cli/Cargo.toml index a05666cc0a..aba4984530 100644 --- a/hugr-cli/Cargo.toml +++ b/hugr-cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hugr-cli" -version = "0.20.2" +version = "0.22.1" edition = { workspace = true } rust-version = { workspace = true } license = { workspace = true } @@ -19,11 +19,13 @@ bench = false clap = { workspace = true, features = ["derive", "cargo"] } clap-verbosity-flag.workspace = true derive_more = { workspace = true, features = ["display", "error", "from"] } -hugr = { path = "../hugr", version = "0.20.2" } +hugr = { path = "../hugr", version = "0.22.1" } serde_json.workspace = true clio = { workspace = true, features = ["clap-parse"] } anyhow.workspace = true thiserror.workspace = true +tracing = "0.1.41" +tracing-subscriber = { version = "0.3.19", features = ["fmt"] } [lints] workspace = true diff --git a/hugr-cli/src/extensions.rs b/hugr-cli/src/extensions.rs index ff4862634b..1fc31e8571 100644 --- a/hugr-cli/src/extensions.rs +++ b/hugr-cli/src/extensions.rs @@ -1,4 +1,5 @@ //! Dump standard extensions in serialized form. 
+use anyhow::Result; use clap::Parser; use hugr::extension::ExtensionRegistry; use std::{io::Write, path::PathBuf}; @@ -25,7 +26,7 @@ impl ExtArgs { /// Write out the standard extensions in serialized form. /// Qualified names of extensions used to generate directories under the specified output directory. /// E.g. extension "foo.bar.baz" will be written to "OUTPUT/foo/bar/baz.json". - pub fn run_dump(&self, registry: &ExtensionRegistry) { + pub fn run_dump(&self, registry: &ExtensionRegistry) -> Result<()> { let base_dir = &self.outdir; for ext in registry { @@ -35,15 +36,17 @@ impl ExtArgs { } path.set_extension("json"); - std::fs::create_dir_all(path.clone().parent().unwrap()).unwrap(); + std::fs::create_dir_all(path.clone().parent().unwrap())?; // file buffer - let mut file = std::fs::File::create(&path).unwrap(); + let mut file = std::fs::File::create(&path)?; - serde_json::to_writer_pretty(&mut file, &ext).unwrap(); + serde_json::to_writer_pretty(&mut file, &ext)?; // write newline, for pre-commit end of file check that edits the file to // add newlines if missing. - file.write_all(b"\n").unwrap(); + file.write_all(b"\n")?; } + + Ok(()) } } diff --git a/hugr-cli/src/lib.rs b/hugr-cli/src/lib.rs index 0b91ed547b..d4c269d811 100644 --- a/hugr-cli/src/lib.rs +++ b/hugr-cli/src/lib.rs @@ -57,11 +57,11 @@ //! ``` use clap::{Parser, crate_version}; -use clap_verbosity_flag::log::Level; use clap_verbosity_flag::{InfoLevel, Verbosity}; use hugr::envelope::EnvelopeError; use hugr::package::PackageValidationError; use std::ffi::OsString; +use thiserror::Error; pub mod convert; pub mod extensions; @@ -74,8 +74,19 @@ pub mod validate; #[clap(version = crate_version!(), long_about = None)] #[clap(about = "HUGR CLI tools.")] #[group(id = "hugr")] +pub struct CliArgs { + /// The command to be run. + #[command(subcommand)] + pub command: CliCommand, + /// Verbosity. + #[command(flatten)] + pub verbose: Verbosity, +} + +/// The CLI subcommands. +#[derive(Debug, clap::Subcommand)] #[non_exhaustive] -pub enum CliArgs { +pub enum CliCommand { /// Validate and visualize a HUGR file. Validate(validate::ValArgs), /// Write standard extensions out in serialized form. @@ -90,45 +101,38 @@ pub enum CliArgs { } /// Error type for the CLI. -#[derive(Debug, derive_more::Display, thiserror::Error, derive_more::From)] +#[derive(Debug, Error)] #[non_exhaustive] pub enum CliError { /// Error reading input. - #[display("Error reading from path: {_0}")] - InputFile(std::io::Error), + #[error("Error reading from path.")] + InputFile(#[from] std::io::Error), /// Error parsing input. - #[display("Error parsing package: {_0}")] - Parse(serde_json::Error), - #[display("Error validating HUGR: {_0}")] + #[error("Error parsing package.")] + Parse(#[from] serde_json::Error), + #[error("Error validating HUGR.")] /// Errors produced by the `validate` subcommand. - Validate(PackageValidationError), - #[display("Error decoding HUGR envelope: {_0}")] + Validate(#[from] PackageValidationError), + #[error("Error decoding HUGR envelope.")] /// Errors produced by the `validate` subcommand. - Envelope(EnvelopeError), + Envelope(#[from] EnvelopeError), /// Pretty error when the user passes a non-envelope file. - #[display( + #[error( "Input file is not a HUGR envelope. Invalid magic number.\n\nUse `--hugr-json` to read a raw HUGR JSON file instead." )] NotAnEnvelope, /// Invalid format string for conversion. - #[display( + #[error( "Invalid format: '{_0}'. 
Valid formats are: json, model, model-exts, model-text, model-text-exts" )] InvalidFormat(String), -} - -/// Other arguments affecting the HUGR CLI runtime. -#[derive(Parser, Debug)] -pub struct OtherArgs { - /// Verbosity. - #[command(flatten)] - pub verbose: Verbosity, -} - -impl OtherArgs { - /// Test whether a `level` message should be output. - #[must_use] - pub fn verbosity(&self, level: Level) -> bool { - self.verbose.log_level_filter() >= level - } + #[error("Error validating HUGR generated by {generator}")] + /// Errors produced by the `validate` subcommand, with a known generator of the HUGR. + ValidateKnownGenerator { + #[source] + /// The inner validation error. + inner: PackageValidationError, + /// The generator of the HUGR. + generator: Box, + }, } diff --git a/hugr-cli/src/main.rs b/hugr-cli/src/main.rs index 8063f25916..28e64020ff 100644 --- a/hugr-cli/src/main.rs +++ b/hugr-cli/src/main.rs @@ -1,84 +1,72 @@ //! Validate serialized HUGR on the command line -use clap::Parser as _; - -use hugr_cli::{CliArgs, convert, mermaid, validate}; +use std::ffi::OsString; -use clap_verbosity_flag::log::Level; +use anyhow::{Result, anyhow}; +use clap::Parser as _; +use clap_verbosity_flag::VerbosityFilter; +use hugr_cli::{CliArgs, CliCommand}; +use tracing::{error, metadata::LevelFilter}; fn main() { - match CliArgs::parse() { - CliArgs::Validate(args) => run_validate(args), - CliArgs::GenExtensions(args) => args.run_dump(&hugr::std_extensions::STD_REG), - CliArgs::Mermaid(args) => run_mermaid(args), - CliArgs::Convert(args) => run_convert(args), - CliArgs::External(args) => { - // External subcommand support: invoke `hugr-` - if args.is_empty() { - eprintln!("No external subcommand specified."); - std::process::exit(1); - } - let subcmd = args[0].to_string_lossy(); - let exe = format!("hugr-{}", subcmd); - let rest: Vec<_> = args[1..] - .iter() - .map(|s| s.to_string_lossy().to_string()) - .collect(); - match std::process::Command::new(&exe).args(&rest).status() { - Ok(status) => { - if !status.success() { - std::process::exit(status.code().unwrap_or(1)); - } - } - Err(e) if e.kind() == std::io::ErrorKind::NotFound => { - eprintln!( - "error: no such subcommand: '{subcmd}'.\nCould not find '{exe}' in PATH." - ); - std::process::exit(1); - } - Err(e) => { - eprintln!("error: failed to invoke '{exe}': {e}"); - std::process::exit(1); - } - } - } - _ => { - eprintln!("Unknown command"); - std::process::exit(1); - } - } -} + let cli_args = CliArgs::parse(); -/// Run the `validate` subcommand. -fn run_validate(mut args: validate::ValArgs) { - let result = args.run(); + let level = match cli_args.verbose.filter() { + VerbosityFilter::Off => LevelFilter::OFF, + VerbosityFilter::Error => LevelFilter::ERROR, + VerbosityFilter::Warn => LevelFilter::WARN, + VerbosityFilter::Info => LevelFilter::INFO, + VerbosityFilter::Debug => LevelFilter::DEBUG, + VerbosityFilter::Trace => LevelFilter::TRACE, + }; + tracing_subscriber::fmt() + .with_writer(std::io::stderr) + .with_max_level(level) + .pretty() + .init(); - if let Err(e) = result { - if args.verbosity(Level::Error) { - eprintln!("{e}"); - } - std::process::exit(1); - } -} - -/// Run the `mermaid` subcommand. 
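The reworked `hugr-cli/src/main.rs` above maps the clap-verbosity flag onto a `tracing` level filter, installs a stderr subscriber once at startup, and lets every subcommand return `anyhow::Result<()>` so a single `error!`-and-exit path handles failures. A condensed, standalone sketch of that pattern, using only calls that appear in the diff (the subcommand dispatch is stubbed out):

    use anyhow::Result;
    use clap_verbosity_flag::{InfoLevel, Verbosity, VerbosityFilter};
    use tracing::{error, metadata::LevelFilter};

    fn init_logging(verbose: &Verbosity<InfoLevel>) {
        let level = match verbose.filter() {
            VerbosityFilter::Off => LevelFilter::OFF,
            VerbosityFilter::Error => LevelFilter::ERROR,
            VerbosityFilter::Warn => LevelFilter::WARN,
            VerbosityFilter::Info => LevelFilter::INFO,
            VerbosityFilter::Debug => LevelFilter::DEBUG,
            VerbosityFilter::Trace => LevelFilter::TRACE,
        };
        // Write logs to stderr so stdout stays free for command output.
        tracing_subscriber::fmt()
            .with_writer(std::io::stderr)
            .with_max_level(level)
            .pretty()
            .init();
    }

    fn main() {
        let verbose = Verbosity::<InfoLevel>::new(1, 0); // stand-in for the parsed CLI flag
        init_logging(&verbose);
        let result: Result<()> = Ok(()); // stand-in for dispatching a CliCommand
        if let Err(err) = result {
            error!("{:?}", err);
            std::process::exit(1);
        }
    }
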
-fn run_mermaid(mut args: mermaid::MermaidArgs) { - let result = args.run_print(); + let result = match cli_args.command { + CliCommand::Validate(mut args) => args.run(), + CliCommand::GenExtensions(args) => args.run_dump(&hugr::std_extensions::STD_REG), + CliCommand::Mermaid(mut args) => args.run_print(), + CliCommand::Convert(mut args) => args.run_convert(), + CliCommand::External(args) => run_external(args), + _ => Err(anyhow!("Unknown command")), + }; - if let Err(e) = result { - if args.other_args.verbosity(Level::Error) { - eprintln!("{e}"); - } + if let Err(err) = result { + error!("{:?}", err); std::process::exit(1); } } -/// Run the `convert` subcommand. -fn run_convert(mut args: convert::ConvertArgs) { - let result = args.run_convert(); - - if let Err(e) = result { - eprintln!("{e}"); +fn run_external(args: Vec) -> Result<()> { + // External subcommand support: invoke `hugr-` + if args.is_empty() { + eprintln!("No external subcommand specified."); std::process::exit(1); } + let subcmd = args[0].to_string_lossy(); + let exe = format!("hugr-{subcmd}"); + let rest: Vec<_> = args[1..] + .iter() + .map(|s| s.to_string_lossy().to_string()) + .collect(); + match std::process::Command::new(&exe).args(&rest).status() { + Ok(status) => { + if !status.success() { + std::process::exit(status.code().unwrap_or(1)); + } + } + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + eprintln!("error: no such subcommand: '{subcmd}'.\nCould not find '{exe}' in PATH."); + std::process::exit(1); + } + Err(e) => { + eprintln!("error: failed to invoke '{exe}': {e}"); + std::process::exit(1); + } + } + + Ok(()) } diff --git a/hugr-cli/src/mermaid.rs b/hugr-cli/src/mermaid.rs index 0ce5c9b8a4..fbe4a09a4e 100644 --- a/hugr-cli/src/mermaid.rs +++ b/hugr-cli/src/mermaid.rs @@ -1,15 +1,14 @@ //! Render mermaid diagrams. use std::io::Write; +use crate::CliError; +use crate::hugr_io::HugrInputArgs; +use anyhow::Result; use clap::Parser; -use clap_verbosity_flag::log::Level; use clio::Output; use hugr::HugrView; use hugr::package::PackageValidationError; -use crate::OtherArgs; -use crate::hugr_io::HugrInputArgs; - /// Dump the standard extensions. #[derive(Parser, Debug)] #[clap(version = "1.0", long_about = None)] @@ -30,15 +29,11 @@ pub struct MermaidArgs { /// Output file '-' for stdout #[clap(long, short, value_parser, default_value = "-")] output: Output, - - /// Additional arguments - #[command(flatten)] - pub other_args: OtherArgs, } impl MermaidArgs { /// Write the mermaid diagram to the output. - pub fn run_print(&mut self) -> Result<(), crate::CliError> { + pub fn run_print(&mut self) -> Result<()> { if self.input_args.hugr_json { self.run_print_hugr() } else { @@ -47,11 +42,11 @@ impl MermaidArgs { } /// Write the mermaid diagram for a HUGR envelope. - pub fn run_print_envelope(&mut self) -> Result<(), crate::CliError> { + pub fn run_print_envelope(&mut self) -> Result<()> { let package = self.input_args.get_package()?; if self.validate { - package.validate()?; + package.validate().map_err(CliError::Validate)?; } for hugr in package.modules { @@ -61,7 +56,7 @@ impl MermaidArgs { } /// Write the mermaid diagram for a legacy HUGR json. - pub fn run_print_hugr(&mut self) -> Result<(), crate::CliError> { + pub fn run_print_hugr(&mut self) -> Result<()> { let hugr = self.input_args.get_hugr()?; if self.validate { @@ -72,10 +67,4 @@ impl MermaidArgs { writeln!(self.output, "{}", hugr.mermaid_string())?; Ok(()) } - - /// Test whether a `level` message should be output. 
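Throughout `hugr-cli`, `CliError` now derives `thiserror::Error` instead of using `derive_more`, so the `#[from]` variants keep converting I/O, JSON, envelope and validation errors automatically under `?`, while `map_err(CliError::Validate)` (as in `run_print_envelope` above) covers the cases where the conversion is made explicit. A small self-contained sketch of the same pattern, with stand-in names:

    use thiserror::Error;

    #[derive(Debug, Error)]
    enum CliishError {
        /// `?` converts std::io::Error into this variant via the generated From impl.
        #[error("Error reading from path.")]
        InputFile(#[from] std::io::Error),
        /// Parsing failures, converted the same way.
        #[error("Error parsing package.")]
        Parse(#[from] serde_json::Error),
    }

    fn read_manifest(path: &std::path::Path) -> Result<serde_json::Value, CliishError> {
        let text = std::fs::read_to_string(path)?; // io::Error -> InputFile
        Ok(serde_json::from_str(&text)?) // serde_json::Error -> Parse
    }
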
- #[must_use] - pub fn verbosity(&self, level: Level) -> bool { - self.other_args.verbosity(level) - } } diff --git a/hugr-cli/src/validate.rs b/hugr-cli/src/validate.rs index ddf51d135c..2ec19e8b38 100644 --- a/hugr-cli/src/validate.rs +++ b/hugr-cli/src/validate.rs @@ -1,12 +1,13 @@ //! The `validate` subcommand. +use anyhow::Result; use clap::Parser; -use clap_verbosity_flag::log::Level; +use hugr::HugrView; use hugr::package::PackageValidationError; -use hugr::{Hugr, HugrView}; +use tracing::info; +use crate::CliError; use crate::hugr_io::HugrInputArgs; -use crate::{CliError, OtherArgs}; /// Validate and visualise a HUGR file. #[derive(Parser, Debug)] @@ -18,10 +19,6 @@ pub struct ValArgs { /// Hugr input. #[command(flatten)] pub input_args: HugrInputArgs, - - /// Additional arguments - #[command(flatten)] - pub other_args: OtherArgs, } /// String to print when validation is successful. @@ -29,28 +26,35 @@ pub const VALID_PRINT: &str = "HUGR valid!"; impl ValArgs { /// Run the HUGR cli and validate against an extension registry. - pub fn run(&mut self) -> Result, CliError> { - let result = if self.input_args.hugr_json { + pub fn run(&mut self) -> Result<()> { + if self.input_args.hugr_json { let hugr = self.input_args.get_hugr()?; + let generator = hugr::envelope::get_generator(&[&hugr]); + hugr.validate() - .map_err(PackageValidationError::Validation)?; - vec![hugr] + .map_err(PackageValidationError::Validation) + .map_err(|val_err| wrap_generator(generator, val_err))?; } else { let package = self.input_args.get_package()?; - package.validate()?; - package.modules + let generator = hugr::envelope::get_generator(&package.modules); + package + .validate() + .map_err(|val_err| wrap_generator(generator, val_err))?; }; - if self.verbosity(Level::Info) { - eprintln!("{VALID_PRINT}"); - } + info!("{VALID_PRINT}"); - Ok(result) + Ok(()) } +} - /// Test whether a `level` message should be output. 
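`ValArgs::run` above now calls `hugr::envelope::get_generator` on the modules being validated and, when generator metadata is present, wraps any validation failure so the producing tool is named in the error (the `wrap_generator` helper appears next in the diff; the element type of its `Option` parameter and the `Box`ed field are elided in the diff as rendered here). A sketch of just that wrapping step, with plain `String` standing in for the elided types and a stand-in validation error:

    /// Stand-in for PackageValidationError.
    #[derive(Debug)]
    struct ValidationFailure(String);

    /// Stand-ins for the two CliError validation variants.
    #[derive(Debug)]
    enum CliValidationError {
        Validate(ValidationFailure),
        ValidateKnownGenerator {
            inner: ValidationFailure,
            generator: String,
        },
    }

    /// Attach generator metadata to a validation error when it is known.
    fn wrap_generator(generator: Option<String>, err: ValidationFailure) -> CliValidationError {
        match generator {
            Some(g) => CliValidationError::ValidateKnownGenerator {
                inner: err,
                generator: g,
            },
            None => CliValidationError::Validate(err),
        }
    }

    fn validate(generator: Option<String>) -> Result<(), CliValidationError> {
        // Pretend validation found an unconnected port.
        Err(ValidationFailure("unconnected port".into())).map_err(|e| wrap_generator(generator, e))
    }
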
- #[must_use] - pub fn verbosity(&self, level: Level) -> bool { - self.other_args.verbosity(level) +fn wrap_generator(generator: Option, val_err: PackageValidationError) -> CliError { + if let Some(g) = generator { + CliError::ValidateKnownGenerator { + inner: val_err, + generator: Box::new(g.to_string()), + } + } else { + CliError::Validate(val_err) } } diff --git a/hugr-cli/tests/validate.rs b/hugr-cli/tests/validate.rs index 5fb09ca091..d7d552decc 100644 --- a/hugr-cli/tests/validate.rs +++ b/hugr-cli/tests/validate.rs @@ -13,12 +13,15 @@ use hugr::types::Type; use hugr::{ builder::{Container, Dataflow}, extension::prelude::{bool_t, qb_t}, + hugr::HugrView, + hugr::hugrmut::HugrMut, std_extensions::arithmetic::float_types::float64_type, types::Signature, }; use hugr_cli::validate::VALID_PRINT; use predicates::{prelude::*, str::contains}; use rstest::{fixture, rstest}; +use serde_json::json; #[fixture] fn cmd() -> Command { @@ -123,6 +126,7 @@ fn test_mermaid_invalid(bad_hugr_string: String, mut cmd: Command) { cmd.write_stdin(bad_hugr_string); cmd.assert() .failure() + .stderr(contains("unconnected port")) .stderr(contains("Error validating HUGR")); } @@ -134,6 +138,7 @@ fn test_bad_hugr(bad_hugr_string: String, mut val_cmd: Command) { val_cmd .assert() .failure() + .stderr(contains("unconnected port")) .stderr(contains("Error validating HUGR")); } @@ -146,7 +151,8 @@ fn test_bad_json(mut val_cmd: Command) { val_cmd .assert() .failure() - .stderr(contains("Error decoding HUGR envelope")); + .stderr(contains("Error decoding HUGR envelope")) + .stderr(contains("missing field")); } #[rstest] @@ -199,3 +205,38 @@ fn test_package_validation(package_string: String, mut val_cmd: Command) { val_cmd.assert().success().stderr(contains(VALID_PRINT)); } + +/// Create a deliberately invalid HUGR with a known generator +#[fixture] +fn invalid_hugr_with_generator() -> Vec { + // Create an invalid HUGR (missing outputs in a dataflow) + let df = DFGBuilder::new(Signature::new_endo(vec![qb_t()])).unwrap(); + let mut bad_hugr = df.hugr().clone(); // Missing outputs makes this invalid + bad_hugr.set_metadata( + bad_hugr.module_root(), + hugr::envelope::GENERATOR_KEY, + json!({"name": "test-generator", "version": "1.0.1"}), + ); + // Create envelope with a specific generator + let envelope_config = EnvelopeConfig::binary(); + + let mut buff = Vec::new(); + // Serialize to string + bad_hugr.store(&mut buff, envelope_config).unwrap(); + buff +} + +#[rstest] +fn test_validate_known_generator(invalid_hugr_with_generator: Vec, mut val_cmd: Command) { + // Write the invalid HUGR to stdin + val_cmd.write_stdin(invalid_hugr_with_generator); + val_cmd.arg("-"); + + // Expect a failure with the generator name in the error message + val_cmd + .assert() + .failure() + .stderr(contains("Error validating HUGR")) + .stderr(contains("unconnected port")) + .stderr(contains("generated by test-generator-v1.0.1")); +} diff --git a/hugr-core/CHANGELOG.md b/hugr-core/CHANGELOG.md index 9026431a7a..2202aac02a 100644 --- a/hugr-core/CHANGELOG.md +++ b/hugr-core/CHANGELOG.md @@ -1,5 +1,65 @@ # Changelog + +## [0.22.0](https://github.com/CQCL/hugr/compare/hugr-core-v0.21.0...hugr-core-v0.22.0) - 2025-07-24 + +### Bug Fixes + +- Ensure SumTypes have the same json encoding in -rs and -py ([#2465](https://github.com/CQCL/hugr/pull/2465)) + +### New Features + +- Export entrypoint metadata in Python and fix bug in import ([#2434](https://github.com/CQCL/hugr/pull/2434)) +- Names of 
private functions become `core.title` metadata. ([#2448](https://github.com/CQCL/hugr/pull/2448)) +- [**breaking**] Use binary envelopes for operation lower_func encoding ([#2447](https://github.com/CQCL/hugr/pull/2447)) +- [**breaking**] Update portgraph dependency to 0.15 ([#2455](https://github.com/CQCL/hugr/pull/2455)) +- Detect and fail on unrecognised envelope flags ([#2453](https://github.com/CQCL/hugr/pull/2453)) +- include generator metatada in model import and cli validate errors ([#2452](https://github.com/CQCL/hugr/pull/2452)) +- [**breaking**] Add `insert_region` to HugrMut ([#2463](https://github.com/CQCL/hugr/pull/2463)) +- Non-region entrypoints in `hugr-model`. ([#2467](https://github.com/CQCL/hugr/pull/2467)) +## [0.21.0](https://github.com/CQCL/hugr/compare/hugr-core-v0.20.2...hugr-core-v0.21.0) - 2025-07-09 + +### Bug Fixes + +- Fixed two bugs in import/export of function operations ([#2324](https://github.com/CQCL/hugr/pull/2324)) +- Model import should perform extension resolution ([#2326](https://github.com/CQCL/hugr/pull/2326)) +- [**breaking**] Fixed bugs in model CFG handling and improved CFG signatures ([#2334](https://github.com/CQCL/hugr/pull/2334)) +- Use List instead of Tuple in conversions for TypeArg/TypeRow ([#2378](https://github.com/CQCL/hugr/pull/2378)) +- Do extension resolution on loaded extensions from the model format ([#2389](https://github.com/CQCL/hugr/pull/2389)) +- Make JSON Schema checks actually work again ([#2412](https://github.com/CQCL/hugr/pull/2412)) +- Order hints on input and output nodes. ([#2422](https://github.com/CQCL/hugr/pull/2422)) + +### New Features + +- [**breaking**] No nested FuncDefns (or AliasDefns) ([#2256](https://github.com/CQCL/hugr/pull/2256)) +- [**breaking**] Split `TypeArg::Sequence` into tuples and lists. ([#2140](https://github.com/CQCL/hugr/pull/2140)) +- [**breaking**] Added float and bytes literal to core and python bindings. ([#2289](https://github.com/CQCL/hugr/pull/2289)) +- [**breaking**] More helpful error messages in model import ([#2272](https://github.com/CQCL/hugr/pull/2272)) +- [**breaking**] Better error reporting in `hugr-cli`. 
([#2318](https://github.com/CQCL/hugr/pull/2318)) +- [**breaking**] Merge `TypeParam` and `TypeArg` into one `Term` type in Rust ([#2309](https://github.com/CQCL/hugr/pull/2309)) +- *(persistent)* Add serialisation for CommitStateSpace ([#2344](https://github.com/CQCL/hugr/pull/2344)) +- add TryFrom impls for TypeArg/TypeRow ([#2366](https://github.com/CQCL/hugr/pull/2366)) +- Add `MakeError` op ([#2377](https://github.com/CQCL/hugr/pull/2377)) +- Open lists and tuples in `Term` ([#2360](https://github.com/CQCL/hugr/pull/2360)) +- Call `FunctionBuilder::add_{in,out}put` for any AsMut ([#2376](https://github.com/CQCL/hugr/pull/2376)) +- Add Root checked methods to DataflowParentID ([#2382](https://github.com/CQCL/hugr/pull/2382)) +- Add PersistentWire type ([#2361](https://github.com/CQCL/hugr/pull/2361)) +- Add `BorrowArray` extension ([#2395](https://github.com/CQCL/hugr/pull/2395)) +- [**breaking**] Rename 'Any' type bound to 'Linear' ([#2421](https://github.com/CQCL/hugr/pull/2421)) +- [**breaking**] Add Visibility to FuncDefn/FuncDecl. ([#2143](https://github.com/CQCL/hugr/pull/2143)) +- *(per)* [**breaking**] Support empty wires in commits ([#2349](https://github.com/CQCL/hugr/pull/2349)) +- [**breaking**] hugr-model use explicit Option, with ::Unspecified in capnp ([#2424](https://github.com/CQCL/hugr/pull/2424)) + +### Refactor + +- [**breaking**] move PersistentHugr into separate crate ([#2277](https://github.com/CQCL/hugr/pull/2277)) +- [**breaking**] remove deprecated runtime extension errors ([#2369](https://github.com/CQCL/hugr/pull/2369)) +- [**breaking**] Reduce error type sizes ([#2420](https://github.com/CQCL/hugr/pull/2420)) + +### Testing + +- Check hugr json serializations against the schema (again) ([#2216](https://github.com/CQCL/hugr/pull/2216)) + ## [0.20.2](https://github.com/CQCL/hugr/compare/hugr-core-v0.20.1...hugr-core-v0.20.2) - 2025-06-25 ### Documentation diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index a365c73e19..02ea45885f 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hugr-core" -version = "0.20.2" +version = "0.22.1" edition = { workspace = true } rust-version = { workspace = true } @@ -19,6 +19,7 @@ workspace = true [features] declarative = ["serde_yaml"] zstd = ["dep:zstd"] +default = [] [lib] bench = false @@ -26,11 +27,8 @@ bench = false [[test]] name = "model" -[[test]] -name = "persistent_walker_example" - [dependencies] -hugr-model = { version = "0.20.2", path = "../hugr-model" } +hugr-model = { version = "0.22.1", path = "../hugr-model" } cgmath = { workspace = true, features = ["serde"] } delegate = { workspace = true } @@ -63,7 +61,11 @@ thiserror = { workspace = true } typetag = { workspace = true } semver = { workspace = true, features = ["serde"] } zstd = { workspace = true, optional = true } +ordered-float = { workspace = true, features = ["serde"] } +base64.workspace = true relrc = { workspace = true, features = ["petgraph", "serde"] } +smallvec = "1.15.0" +tracing = "0.1.41" [dev-dependencies] rstest = { workspace = true } @@ -77,3 +79,4 @@ proptest-derive = { workspace = true } # Required for documentation examples hugr = { path = "../hugr" } serde_yaml = 
"0.9.34" +anyhow = { workspace = true } diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index aa2d949056..ee5046dd5c 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -189,7 +189,7 @@ pub enum BuildError { #[error("Found an error while setting the outputs of a {} container, {container_node}. {error}", .container_op.name())] #[allow(missing_docs)] OutputWiring { - container_op: OpType, + container_op: Box, container_node: Node, #[source] error: BuilderWiringError, @@ -201,7 +201,7 @@ pub enum BuildError { #[error("Got an input wire while adding a {} to the circuit. {error}", .op.name())] #[allow(missing_docs)] OperationWiring { - op: OpType, + op: Box, #[source] error: BuilderWiringError, }, @@ -219,7 +219,7 @@ pub enum BuilderWiringError { #[error("Cannot copy linear type {typ} from output {src_offset} of node {src}")] #[allow(missing_docs)] NoCopyLinear { - typ: Type, + typ: Box, src: Node, src_offset: Port, }, @@ -244,7 +244,7 @@ pub enum BuilderWiringError { src_offset: Port, dst: Node, dst_offset: Port, - typ: Type, + typ: Box, }, } @@ -261,8 +261,8 @@ pub(crate) mod test { use super::handle::BuildHandle; use super::{ - BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, FuncID, - FunctionBuilder, ModuleBuilder, + BuildError, CFGBuilder, DFGBuilder, Dataflow, DataflowHugr, FuncID, FunctionBuilder, + ModuleBuilder, }; use super::{DataflowSubContainer, HugrBuilder}; diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index 03731bb7da..ac4d46645d 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -9,7 +9,7 @@ use crate::{Extension, IncomingPort, Node, OutgoingPort}; use std::iter; use std::sync::Arc; -use super::{BuilderWiringError, FunctionBuilder}; +use super::{BuilderWiringError, ModuleBuilder}; use super::{ CircuitBuilder, handle::{BuildHandle, Outputs}, @@ -21,7 +21,7 @@ use crate::{ }; use crate::extension::ExtensionRegistry; -use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeRow}; +use crate::types::{Signature, Type, TypeArg, TypeRow}; use itertools::Itertools; @@ -82,37 +82,20 @@ pub trait Container { self.add_child_node(constant.into()).into() } - /// Add a [`ops::FuncDefn`] node and returns a builder to define the function - /// body graph. + /// Insert a HUGR's entrypoint region as a child of the container. /// - /// # Errors - /// - /// This function will return an error if there is an error in adding the - /// [`ops::FuncDefn`] node. - fn define_function( - &mut self, - name: impl Into, - signature: impl Into, - ) -> Result, BuildError> { - let signature: PolyFuncType = signature.into(); - let body = signature.body().clone(); - let f_node = self.add_child_node(ops::FuncDefn::new(name, signature)); - - // Add the extensions used by the function types. - self.use_extensions( - body.used_extensions().unwrap_or_else(|e| { - panic!("Build-time signatures should have valid extensions. {e}") - }), - ); - - let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; - Ok(FunctionBuilder::from_dfg_builder(db)) + /// To insert an arbitrary region of a HUGR, use [`Container::add_hugr_region`]. + fn add_hugr(&mut self, child: Hugr) -> InsertionResult { + let region = child.entrypoint(); + self.add_hugr_region(child, region) } - /// Insert a HUGR as a child of the container. - fn add_hugr(&mut self, child: Hugr) -> InsertionResult { + /// Insert a HUGR region as a child of the container. 
+ /// + /// To insert the entrypoint region of a HUGR, use [`Container::add_hugr`]. + fn add_hugr_region(&mut self, child: Hugr, region: Node) -> InsertionResult { let parent = self.container_node(); - self.hugr_mut().insert_hugr(parent, child) + self.hugr_mut().insert_region(parent, child, region) } /// Insert a copy of a HUGR as a child of the container. @@ -155,8 +138,19 @@ pub trait Container { } /// Types implementing this trait can be used to build complete HUGRs -/// (with varying root node types) +/// (with varying entrypoint node types) pub trait HugrBuilder: Container { + /// Allows adding definitions to the module root of which + /// this builder is building a part + fn module_root_builder(&mut self) -> ModuleBuilder<&mut Hugr> { + debug_assert!( + self.hugr() + .get_optype(self.hugr().module_root()) + .is_module() + ); + ModuleBuilder(self.hugr_mut()) + } + /// Finish building the HUGR, perform any validation checks and return it. fn finish_hugr(self) -> Result>; } @@ -216,6 +210,10 @@ pub trait Dataflow: Container { /// Insert a hugr-defined op to the sibling graph, wiring up the /// `input_wires` to the incoming ports of the resulting root node. /// + /// Inserts everything from the entrypoint region of the HUGR. + /// See [`Dataflow::add_hugr_region_with_wires`] for a generic version that allows + /// inserting a region other than the entrypoint. + /// /// # Errors /// /// This function will return an error if there is an error when adding the @@ -225,12 +223,34 @@ pub trait Dataflow: Container { hugr: Hugr, input_wires: impl IntoIterator, ) -> Result, BuildError> { - let optype = hugr.get_optype(hugr.entrypoint()).clone(); + let region = hugr.entrypoint(); + self.add_hugr_region_with_wires(hugr, region, input_wires) + } + + /// Insert a hugr-defined op to the sibling graph, wiring up the + /// `input_wires` to the incoming ports of the resulting root node. + /// + /// `region` must be a node in the `hugr`. See [`Dataflow::add_hugr_with_wires`] + /// for a helper that inserts the entrypoint region to the HUGR. + /// + /// # Errors + /// + /// This function will return an error if there is an error when adding the + /// node. 
+ fn add_hugr_region_with_wires( + &mut self, + hugr: Hugr, + region: Node, + input_wires: impl IntoIterator, + ) -> Result, BuildError> { + let optype = hugr.get_optype(region).clone(); let num_outputs = optype.value_output_count(); - let node = self.add_hugr(hugr).inserted_entrypoint; + let node = self.add_hugr_region(hugr, region).inserted_entrypoint; - wire_up_inputs(input_wires, node, self) - .map_err(|error| BuildError::OperationWiring { op: optype, error })?; + wire_up_inputs(input_wires, node, self).map_err(|error| BuildError::OperationWiring { + op: Box::new(optype), + error, + })?; Ok((node, num_outputs).into()) } @@ -251,8 +271,10 @@ pub trait Dataflow: Container { let optype = hugr.get_optype(hugr.entrypoint()).clone(); let num_outputs = optype.value_output_count(); - wire_up_inputs(input_wires, node, self) - .map_err(|error| BuildError::OperationWiring { op: optype, error })?; + wire_up_inputs(input_wires, node, self).map_err(|error| BuildError::OperationWiring { + op: Box::new(optype), + error, + })?; Ok((node, num_outputs).into()) } @@ -269,7 +291,7 @@ pub trait Dataflow: Container { let [_, out] = self.io(); wire_up_inputs(output_wires.into_iter().collect_vec(), out, self).map_err(|error| { BuildError::OutputWiring { - container_op: self.hugr().get_optype(self.container_node()).clone(), + container_op: Box::new(self.hugr().get_optype(self.container_node()).clone()), container_node: self.container_node(), error, } @@ -678,8 +700,10 @@ fn add_node_with_wires( let num_outputs = op.value_output_count(); let op_node = data_builder.add_child_node(op.clone()); - wire_up_inputs(inputs, op_node, data_builder) - .map_err(|error| BuildError::OperationWiring { op, error })?; + wire_up_inputs(inputs, op_node, data_builder).map_err(|error| BuildError::OperationWiring { + op: Box::new(op), + error, + })?; Ok((op_node, num_outputs)) } @@ -731,7 +755,7 @@ fn wire_up( src_offset: src_port.into(), dst, dst_offset: dst_port.into(), - typ, + typ: Box::new(typ), }); } @@ -762,7 +786,7 @@ fn wire_up( } else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() { // Don't copy linear edges. return Err(BuilderWiringError::NoCopyLinear { - typ, + typ: Box::new(typ), src, src_offset: src_port.into(), }); diff --git a/hugr-core/src/builder/circuit.rs b/hugr-core/src/builder/circuit.rs index 58f388f439..193bfe0675 100644 --- a/hugr-core/src/builder/circuit.rs +++ b/hugr-core/src/builder/circuit.rs @@ -27,10 +27,10 @@ pub struct CircuitBuilder<'a, T: ?Sized> { #[non_exhaustive] pub enum CircuitBuildError { /// Invalid index for stored wires. - #[error("Invalid wire index {invalid_index} while attempting to add operation {}.", .op.as_ref().map(NamedOp::name).unwrap_or_default())] + #[error("Invalid wire index {invalid_index} while attempting to add operation {}.", .op.as_ref().map(|op| op.name()).unwrap_or_default())] InvalidWireIndex { /// The operation. - op: Option, + op: Option>, /// The invalid indices. invalid_index: usize, }, @@ -38,7 +38,7 @@ pub enum CircuitBuildError { #[error("The linear inputs {:?} had no corresponding output wire in operation {}.", .index.as_slice(), .op.name())] MismatchedLinearInputs { /// The operation. - op: OpType, + op: Box, /// The index of the input that had no corresponding output wire. 
index: Vec, }, @@ -143,7 +143,7 @@ impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> { let input_wires = input_wires.map_err(|invalid_index| CircuitBuildError::InvalidWireIndex { - op: Some(op.clone()), + op: Some(Box::new(op.clone())), invalid_index, })?; @@ -169,7 +169,7 @@ impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> { if !linear_inputs.is_empty() { return Err(CircuitBuildError::MismatchedLinearInputs { - op, + op: Box::new(op), index: linear_inputs.values().copied().collect(), } .into()); @@ -245,7 +245,7 @@ mod test { use cool_asserts::assert_matches; use crate::Extension; - use crate::builder::{Container, HugrBuilder, ModuleBuilder}; + use crate::builder::{HugrBuilder, ModuleBuilder}; use crate::extension::ExtensionId; use crate::extension::prelude::{qb_t, usize_t}; use crate::std_extensions::arithmetic::float_types::ConstF64; @@ -389,7 +389,7 @@ mod test { assert_matches!( circ.append(cx_gate(), [q0, invalid_index]), Err(BuildError::CircuitError(CircuitBuildError::InvalidWireIndex { op, invalid_index: idx })) - if op == Some(cx_gate().into()) && idx == invalid_index, + if op == Some(Box::new(cx_gate().into())) && idx == invalid_index, ); // Untracking an invalid index returns an error @@ -403,7 +403,7 @@ mod test { assert_matches!( circ.append(q_discard(), [q1]), Err(BuildError::CircuitError(CircuitBuildError::MismatchedLinearInputs { op, index })) - if op == q_discard().into() && index == [q1], + if *op == q_discard().into() && index == [q1], ); let outs = circ.finish(); diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index 2a0fdf9315..d1131116f3 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -8,14 +8,9 @@ use std::marker::PhantomData; use crate::hugr::internal::HugrMutInternals; use crate::hugr::{HugrView, ValidationError}; -use crate::ops::{self, OpParent}; -use crate::ops::{DataflowParent, Input, Output}; -use crate::{Direction, IncomingPort, OutgoingPort, Wire}; - +use crate::ops::{self, DataflowParent, FuncDefn, Input, OpParent, Output}; use crate::types::{PolyFuncType, Signature, Type}; - -use crate::Node; -use crate::{Hugr, hugr::HugrMut}; +use crate::{Direction, Hugr, IncomingPort, Node, OutgoingPort, Visibility, Wire, hugr::HugrMut}; /// Builder for a [`ops::DFG`] node. #[derive(Debug, Clone, PartialEq)] @@ -152,7 +147,9 @@ impl DFGWrapper { pub type FunctionBuilder = DFGWrapper>>; impl FunctionBuilder { - /// Initialize a builder for a `FuncDefn` rooted HUGR + /// Initialize a builder for a [`FuncDefn`](ops::FuncDefn)-rooted HUGR; + /// the function will be private. (See also [Self::new_vis].) + /// /// # Errors /// /// Error in adding DFG child nodes. @@ -160,9 +157,25 @@ impl FunctionBuilder { name: impl Into, signature: impl Into, ) -> Result { - let signature: PolyFuncType = signature.into(); - let body = signature.body().clone(); - let op = ops::FuncDefn::new(name, signature); + Self::new_with_op(FuncDefn::new(name, signature)) + } + + /// Initialize a builder for a FuncDefn-rooted HUGR, with the specified + /// [Visibility]. + /// + /// # Errors + /// + /// Error in adding DFG child nodes. 
+ pub fn new_vis( + name: impl Into, + signature: impl Into, + visibility: Visibility, + ) -> Result { + Self::new_with_op(FuncDefn::new_vis(name, signature, visibility)) + } + + fn new_with_op(op: FuncDefn) -> Result { + let body = op.signature().body().clone(); let base = Hugr::new_with_entrypoint(op).expect("FuncDefn entrypoint should be valid"); let root = base.entrypoint(); @@ -170,6 +183,31 @@ impl FunctionBuilder { let db = DFGBuilder::create_with_io(base, root, body)?; Ok(Self::from_dfg_builder(db)) } +} + +impl + AsRef> FunctionBuilder { + /// Initialize a new function definition on the root module of an existing HUGR. + /// + /// The HUGR's entrypoint will **not** be modified. + /// + /// # Errors + /// + /// Error in adding DFG child nodes. + pub fn with_hugr( + mut hugr: B, + name: impl Into, + signature: impl Into, + ) -> Result { + let signature: PolyFuncType = signature.into(); + let body = signature.body().clone(); + let op = ops::FuncDefn::new(name, signature); + + let module = hugr.as_ref().module_root(); + let func = hugr.as_mut().add_node_with_parent(module, op); + + let db = DFGBuilder::create_with_io(hugr, func, body)?; + Ok(Self::from_dfg_builder(db)) + } /// Add a new input to the function being constructed. /// @@ -259,31 +297,6 @@ impl FunctionBuilder { } } -impl + AsRef> FunctionBuilder { - /// Initialize a new function definition on the root module of an existing HUGR. - /// - /// The HUGR's entrypoint will **not** be modified. - /// - /// # Errors - /// - /// Error in adding DFG child nodes. - pub fn with_hugr( - mut hugr: B, - name: impl Into, - signature: impl Into, - ) -> Result { - let signature: PolyFuncType = signature.into(); - let body = signature.body().clone(); - let op = ops::FuncDefn::new(name, signature); - - let module = hugr.as_ref().module_root(); - let func = hugr.as_mut().add_node_with_parent(module, op); - - let db = DFGBuilder::create_with_io(hugr, func, body)?; - Ok(Self::from_dfg_builder(db)) - } -} - impl + AsRef, T> Container for DFGWrapper { #[inline] fn container_node(&self) -> Node { @@ -437,7 +450,7 @@ pub(crate) mod test { error: BuilderWiringError::NoCopyLinear { typ, .. }, .. }) - if typ == qb_t() + if *typ == qb_t() ); } @@ -652,7 +665,7 @@ pub(crate) mod test { FunctionBuilder::new( "bad_eval", PolyFuncType::new( - [TypeParam::new_list(TypeBound::Copyable)], + [TypeParam::new_list_type(TypeBound::Copyable)], Signature::new( Type::new_function(FuncValueType::new(usize_t(), tv.clone())), vec![], diff --git a/hugr-core/src/builder/module.rs b/hugr-core/src/builder/module.rs index 543b9f2c1e..5499abae92 100644 --- a/hugr-core/src/builder/module.rs +++ b/hugr-core/src/builder/module.rs @@ -4,25 +4,24 @@ use super::{ dataflow::{DFGBuilder, FunctionBuilder}, }; -use crate::hugr::ValidationError; use crate::hugr::internal::HugrMutInternals; use crate::hugr::views::HugrView; use crate::ops; -use crate::types::{PolyFuncType, Type, TypeBound}; - use crate::ops::handle::{AliasID, FuncID, NodeHandle}; +use crate::types::{PolyFuncType, Type, TypeBound}; +use crate::{Hugr, Node, Visibility}; +use crate::{hugr::ValidationError, ops::FuncDefn}; -use crate::{Hugr, Node}; use smol_str::SmolStr; /// Builder for a HUGR module. 
-#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Default, Clone, PartialEq)] pub struct ModuleBuilder(pub(super) T); impl + AsRef> Container for ModuleBuilder { #[inline] fn container_node(&self) -> Node { - self.0.as_ref().entrypoint() + self.0.as_ref().module_root() } #[inline] @@ -39,13 +38,7 @@ impl ModuleBuilder { /// Begin building a new module. #[must_use] pub fn new() -> Self { - Self(Default::default()) - } -} - -impl Default for ModuleBuilder { - fn default() -> Self { - Self::new() + Self::default() } } @@ -75,25 +68,61 @@ impl + AsRef> ModuleBuilder { f_id: &FuncID, ) -> Result, BuildError> { let f_node = f_id.node(); - let decl = - self.hugr() - .get_optype(f_node) - .as_func_decl() - .ok_or(BuildError::UnexpectedType { - node: f_node, - op_desc: "crate::ops::OpType::FuncDecl", - })?; - let name = decl.func_name().clone(); - let sig = decl.signature().clone(); - let body = sig.body().clone(); - self.hugr_mut() - .replace_op(f_node, ops::FuncDefn::new(name, sig)); + let opty = self.hugr_mut().optype_mut(f_node); + let ops::OpType::FuncDecl(decl) = opty else { + return Err(BuildError::UnexpectedType { + node: f_node, + op_desc: "crate::ops::OpType::FuncDecl", + }); + }; + + let body = decl.signature().body().clone(); + *opty = ops::FuncDefn::new_vis( + decl.func_name(), + decl.signature().clone(), + decl.visibility().clone(), + ) + .into(); + + let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; + Ok(FunctionBuilder::from_dfg_builder(db)) + } + + /// Add a [`ops::FuncDefn`] node of the specified visibility. + /// Returns a builder to define the function body graph. + /// + /// # Errors + /// + /// This function will return an error if there is an error in adding the + /// [`ops::FuncDefn`] node. + pub fn define_function_vis( + &mut self, + name: impl Into, + signature: impl Into, + visibility: Visibility, + ) -> Result, BuildError> { + self.define_function_op(FuncDefn::new_vis(name, signature, visibility)) + } + + fn define_function_op( + &mut self, + op: FuncDefn, + ) -> Result, BuildError> { + let body = op.signature().body().clone(); + let f_node = self.add_child_node(op); + + // Add the extensions used by the function types. + self.use_extensions( + body.used_extensions().unwrap_or_else(|e| { + panic!("Build-time signatures should have valid extensions. {e}") + }), + ); let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; Ok(FunctionBuilder::from_dfg_builder(db)) } - /// Declare a function with `signature` and return a handle to the declaration. + /// Declare a [Visibility::Public] function with `signature` and return a handle to the declaration. /// /// # Errors /// @@ -103,10 +132,26 @@ impl + AsRef> ModuleBuilder { &mut self, name: impl Into, signature: PolyFuncType, + ) -> Result, BuildError> { + self.declare_vis(name, signature, Visibility::Public) + } + + /// Declare a function with the specified `signature` and [Visibility], + /// and return a handle to the declaration. + /// + /// # Errors + /// + /// This function will return an error if there is an error in adding the + /// [`crate::ops::OpType::FuncDecl`] node. + pub fn declare_vis( + &mut self, + name: impl Into, + signature: PolyFuncType, + visibility: Visibility, ) -> Result, BuildError> { let body = signature.body().clone(); // TODO add param names to metadata - let declare_n = self.add_child_node(ops::FuncDecl::new(name, signature)); + let declare_n = self.add_child_node(ops::FuncDecl::new_vis(name, signature, visibility)); // Add the extensions used by the function types. 
self.use_extensions( @@ -118,6 +163,21 @@ impl + AsRef> ModuleBuilder { Ok(declare_n.into()) } + /// Adds a [`ops::FuncDefn`] node and returns a builder to define the function + /// body graph. The function will be private. (See [Self::define_function_vis].) + /// + /// # Errors + /// + /// This function will return an error if there is an error in adding the + /// [`ops::FuncDefn`] node. + pub fn define_function( + &mut self, + name: impl Into, + signature: impl Into, + ) -> Result, BuildError> { + self.define_function_op(FuncDefn::new(name, signature)) + } + /// Add a [`crate::ops::OpType::AliasDefn`] node and return a handle to the Alias. /// /// # Errors @@ -199,7 +259,7 @@ mod test { let mut module_builder = ModuleBuilder::new(); let qubit_state_type = - module_builder.add_alias_declare("qubit_state", TypeBound::Any)?; + module_builder.add_alias_declare("qubit_state", TypeBound::Linear)?; let f_build = module_builder.define_function( "main", @@ -215,31 +275,6 @@ mod test { Ok(()) } - #[test] - fn local_def() -> Result<(), BuildError> { - let build_result = { - let mut module_builder = ModuleBuilder::new(); - - let mut f_build = module_builder.define_function( - "main", - Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]), - )?; - let local_build = f_build.define_function( - "local", - Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]), - )?; - let [wire] = local_build.input_wires_arr(); - let f_id = local_build.finish_with_outputs([wire, wire])?; - - let call = f_build.call(f_id.handle(), &[], f_build.input_wires())?; - - f_build.finish_with_outputs(call.outputs())?; - module_builder.finish_hugr() - }; - assert_matches!(build_result, Ok(_)); - Ok(()) - } - #[test] fn builder_from_existing() -> Result<(), BuildError> { let hugr = Hugr::new(); diff --git a/hugr-core/src/core.rs b/hugr-core/src/core.rs index 06366822ae..4578aa5357 100644 --- a/hugr-core/src/core.rs +++ b/hugr-core/src/core.rs @@ -7,7 +7,7 @@ pub use itertools::Either; use derive_more::From; use itertools::Either::{Left, Right}; -use crate::hugr::HugrError; +use crate::{HugrView, hugr::HugrError}; /// A handle to a node in the HUGR. #[derive( @@ -34,7 +34,7 @@ pub struct Node { )] #[serde(transparent)] pub struct Port { - offset: portgraph::PortOffset, + offset: portgraph::PortOffset, } /// A trait for getting the undirected index of a port. @@ -139,7 +139,7 @@ impl Port { /// Returns the port as a portgraph `PortOffset`. #[inline] - pub(crate) fn pg_offset(self) -> portgraph::PortOffset { + pub(crate) fn pg_offset(self) -> portgraph::PortOffset { self.offset } } @@ -219,17 +219,55 @@ impl Wire { Self(node, port.into()) } - /// The node that this wire is connected to. + /// Create a new wire from a node and a port that is connected to the wire. + /// + /// If `port` is an incoming port, the wire is traversed to find the unique + /// outgoing port that is connected to the wire. Otherwise, this is + /// equivalent to constructing a wire using [`Wire::new`]. + /// + /// ## Panics + /// + /// This will panic if the wire is not connected to a unique outgoing port. + #[inline] + pub fn from_connected_port( + node: N, + port: impl Into, + hugr: &impl HugrView, + ) -> Self { + let (node, outgoing) = match port.into().as_directed() { + Either::Left(incoming) => hugr + .single_linked_output(node, incoming) + .expect("invalid dfg port"), + Either::Right(outgoing) => (node, outgoing), + }; + Self::new(node, outgoing) + } + + /// The node of the unique outgoing port that the wire is connected to. 
#[inline] pub fn node(&self) -> N { self.0 } - /// The output port that this wire is connected to. + /// The unique outgoing port that the wire is connected to. #[inline] pub fn source(&self) -> OutgoingPort { self.1 } + + /// Get all ports connected to the wire. + /// + /// Return a chained iterator of the unique outgoing port, followed by all + /// incoming ports connected to the wire. + pub fn all_connected_ports<'h, H: HugrView>( + &self, + hugr: &'h H, + ) -> impl Iterator + use<'h, N, H> { + let node = self.node(); + let out_port = self.source(); + + std::iter::once((node, out_port.into())).chain(hugr.linked_ports(node, out_port)) + } } impl std::fmt::Display for Wire { @@ -238,6 +276,46 @@ impl std::fmt::Display for Wire { } } +/// Marks [FuncDefn](crate::ops::FuncDefn)s and [FuncDecl](crate::ops::FuncDecl)s as +/// to whether they should be considered for linking. +#[derive( + Clone, + Debug, + derive_more::Display, + PartialEq, + Eq, + PartialOrd, + Ord, + serde::Serialize, + serde::Deserialize, +)] +#[cfg_attr(test, derive(proptest_derive::Arbitrary))] +#[non_exhaustive] +pub enum Visibility { + /// Function is visible or exported + Public, + /// Function is hidden, for use within the hugr only + Private, +} + +impl From for Visibility { + fn from(value: hugr_model::v0::Visibility) -> Self { + match value { + hugr_model::v0::Visibility::Private => Self::Private, + hugr_model::v0::Visibility::Public => Self::Public, + } + } +} + +impl From for hugr_model::v0::Visibility { + fn from(value: Visibility) -> Self { + match value { + Visibility::Public => hugr_model::v0::Visibility::Public, + Visibility::Private => hugr_model::v0::Visibility::Private, + } + } +} + /// Enum for uniquely identifying the origin of linear wires in a circuit-like /// dataflow region. /// diff --git a/hugr-core/src/envelope.rs b/hugr-core/src/envelope.rs index 0223267b85..d07f4b3e41 100644 --- a/hugr-core/src/envelope.rs +++ b/hugr-core/src/envelope.rs @@ -73,11 +73,11 @@ pub const USED_EXTENSIONS_KEY: &str = "core.used_extensions"; /// If multiple modules have different generators, a comma-separated list is returned in /// module order. /// If no generator is found, `None` is returned. -fn get_generator(modules: &[H]) -> Option { +pub fn get_generator(modules: &[H]) -> Option { let generators: Vec = modules .iter() .filter_map(|hugr| hugr.get_metadata(hugr.module_root(), GENERATOR_KEY)) - .map(|v| v.to_string()) + .map(format_generator) .collect(); if generators.is_empty() { return None; @@ -86,6 +86,31 @@ fn get_generator(modules: &[H]) -> Option { Some(generators.join(", ")) } +/// Format a generator value from the metadata. 
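/// A minimal illustration of the rendering (hypothetical values; the object
/// form mirrors the `test_validate_known_generator` CLI fixture above, and the
/// `hugr_core::envelope::format_generator` path assumes the function remains
/// publicly reachable there):
///
/// ```
/// # use serde_json::json;
/// # use hugr_core::envelope::format_generator;
/// // Objects with "name" and "version" fields become "<name>-v<version>".
/// assert_eq!(
///     format_generator(&json!({"name": "test-generator", "version": "1.0.1"})),
///     "test-generator-v1.0.1",
/// );
/// // Bare strings are returned unchanged; other JSON values fall back to
/// // their JSON text.
/// assert_eq!(format_generator(&json!("my-tool")), "my-tool");
/// ```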
+pub fn format_generator(json_val: &serde_json::Value) -> String { + match json_val { + serde_json::Value::String(s) => s.clone(), + serde_json::Value::Object(obj) => { + if let (Some(name), version) = ( + obj.get("name").and_then(|v| v.as_str()), + obj.get("version").and_then(|v| v.as_str()), + ) { + if let Some(version) = version { + // Expected format: {"name": "generator", "version": "1.0.0"} + format!("{name}-v{version}") + } else { + name.to_string() + } + } else { + // just print the whole object as a string + json_val.to_string() + } + } + // Raw JSON string fallback + _ => json_val.to_string(), + } +} + fn gen_str(generator: &Option) -> String { match generator { Some(g) => format!("\ngenerated by {g}"), @@ -97,7 +122,7 @@ fn gen_str(generator: &Option) -> String { #[derive(Error, Debug)] #[error("{inner}{}", gen_str(&self.generator))] pub struct WithGenerator { - inner: E, + inner: Box, /// The name of the generator that produced the envelope, if any. generator: Option, } @@ -105,7 +130,7 @@ pub struct WithGenerator { impl WithGenerator { fn new(err: E, modules: &[impl HugrView]) -> Self { Self { - inner: err, + inner: Box::new(err), generator: get_generator(modules), } } @@ -179,16 +204,15 @@ pub(crate) fn write_envelope_impl<'h>( } /// Error type for envelope operations. -#[derive(derive_more::Display, derive_more::Error, Debug, derive_more::From)] +#[derive(Debug, Error)] #[non_exhaustive] pub enum EnvelopeError { /// Bad magic number. - #[display( + #[error( "Bad magic number. expected 0x{:X} found 0x{:X}", u64::from_be_bytes(*expected), u64::from_be_bytes(*found) )] - #[from(ignore)] MagicNumber { /// The expected magic number. /// @@ -198,20 +222,18 @@ pub enum EnvelopeError { found: [u8; 8], }, /// The specified payload format is invalid. - #[display("Format descriptor {descriptor} is invalid.")] - #[from(ignore)] + #[error("Format descriptor {descriptor} is invalid.")] InvalidFormatDescriptor { /// The unsupported format. descriptor: usize, }, /// The specified payload format is not supported. - #[display("Payload format {format} is not supported.{}", + #[error("Payload format {format} is not supported.{}", match feature { Some(f) => format!(" This requires the '{f}' feature for `hugr`."), None => String::new() }, )] - #[from(ignore)] FormatUnsupported { /// The unsupported format. format: EnvelopeFormat, @@ -221,68 +243,97 @@ pub enum EnvelopeError { /// Not all envelope formats can be represented as ASCII. /// /// This error is used when trying to store the envelope into a string. - #[display("Envelope format {format} cannot be represented as ASCII.")] - #[from(ignore)] + #[error("Envelope format {format} cannot be represented as ASCII.")] NonASCIIFormat { /// The unsupported format. format: EnvelopeFormat, }, /// Envelope encoding required zstd compression, but the feature is not enabled. - #[display("Zstd compression is not supported. This requires the 'zstd' feature for `hugr`.")] - #[from(ignore)] + #[error("Zstd compression is not supported. This requires the 'zstd' feature for `hugr`.")] ZstdUnsupported, /// Expected the envelope to contain a single HUGR. - #[display("Expected an envelope containing a single hugr, but it contained {}.", if *count == 0 { + #[error("Expected an envelope containing a single hugr, but it contained {}.", if *count == 0 { "none".to_string() } else { count.to_string() })] - #[from(ignore)] ExpectedSingleHugr { /// The number of HUGRs in the package. count: usize, }, /// JSON serialization error. 
+ #[error(transparent)] SerdeError { /// The source error. + #[from] source: serde_json::Error, }, /// IO read/write error. + #[error(transparent)] IO { /// The source error. + #[from] source: std::io::Error, }, /// Error writing a json package to the payload. + #[error(transparent)] PackageEncoding { /// The source error. + #[from] source: PackageEncodingError, }, /// Error importing a HUGR from a hugr-model payload. + #[error(transparent)] ModelImport { /// The source error. + #[from] source: ImportError, // TODO add generator to model import errors }, /// Error reading a HUGR model payload. + #[error(transparent)] ModelRead { /// The source error. + #[from] source: hugr_model::v0::binary::ReadError, }, /// Error writing a HUGR model payload. + #[error(transparent)] ModelWrite { /// The source error. + #[from] source: hugr_model::v0::binary::WriteError, }, /// Error reading a HUGR model payload. + #[error("Model text parsing error")] ModelTextRead { /// The source error. + #[from] source: hugr_model::v0::ast::ParseError, }, /// Error reading a HUGR model payload. + #[error(transparent)] ModelTextResolve { /// The source error. + #[from] source: hugr_model::v0::ast::ResolveError, }, + /// Error reading a list of extensions from the envelope. + #[error(transparent)] + ExtensionLoad { + /// The source error. + #[from] + source: crate::extension::ExtensionRegistryLoadError, + }, + /// The specified payload format is not supported. + #[error( + "The envelope configuration has unknown {}. Please update your HUGR version.", + if flag_ids.len() == 1 {format!("flag #{}", flag_ids[0])} else {format!("flags {}", flag_ids.iter().join(", "))} + )] + FlagUnsupported { + /// The unrecognized flag bits. + flag_ids: Vec, + }, } /// Internal implementation of [`read_envelope`] to call with/without the zstd decompression wrapper. @@ -329,11 +380,8 @@ fn decode_model( let mut extension_registry = extension_registry.clone(); if format == EnvelopeFormat::ModelWithExtensions { - let extra_extensions: Vec = - serde_json::from_reader::<_, Vec>(stream)?; - for ext in extra_extensions { - extension_registry.register_updated(ext); - } + let extra_extensions = ExtensionRegistry::load_json(stream, &extension_registry)?; + extension_registry.extend(extra_extensions); } Ok(import_package(&model_package, &extension_registry)?) @@ -803,6 +851,6 @@ pub(crate) mod test { let err_msg = with_gen.to_string(); assert!(err_msg.contains("Extension 'test' version mismatch")); - assert!(err_msg.contains(generator_name.to_string().as_str())); + assert!(err_msg.contains("TestGenerator-v1.2.3")); } } diff --git a/hugr-core/src/envelope/header.rs b/hugr-core/src/envelope/header.rs index 66af887454..54353e2f18 100644 --- a/hugr-core/src/envelope/header.rs +++ b/hugr-core/src/envelope/header.rs @@ -3,6 +3,8 @@ use std::io::{Read, Write}; use std::num::NonZeroU8; +use itertools::Itertools; + use super::EnvelopeError; /// Magic number identifying the start of an envelope. @@ -11,6 +13,12 @@ use super::EnvelopeError; /// to avoid accidental collisions with other file formats. pub const MAGIC_NUMBERS: &[u8] = "HUGRiHJv".as_bytes(); +/// The all-unset header flags configuration. +/// Bit 7 is always set to ensure we have a printable ASCII character. +const DEFAULT_FLAGS: u8 = 0b0100_0000u8; +/// The ZSTD flag bit in the header's flags. +const ZSTD_FLAG: u8 = 0b0000_0001; + /// Header at the start of a binary envelope file. /// /// See the [`crate::envelope`] module documentation for the binary format. 
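// A quick sketch (illustration only, not code from this patch) of the flag
// byte layout that the constants above define: one bit is always set so the
// byte stays printable ASCII, the lowest bit records zstd compression, and any
// other set bit is reported via `EnvelopeError::FlagUnsupported`. The
// constants are redefined locally so the sketch is self-contained.
fn split_flags(flags: u8) -> (bool, Vec<u8>) {
    const DEFAULT_FLAGS: u8 = 0b0100_0000;
    const ZSTD_FLAG: u8 = 0b0000_0001;
    let zstd = flags & ZSTD_FLAG != 0;
    // Clear the always-set bit and the zstd bit; whatever remains is unknown.
    let unknown = (flags ^ DEFAULT_FLAGS) & !ZSTD_FLAG;
    let flag_ids: Vec<u8> = (0..8u8).filter(|i| unknown & (1 << i) != 0).collect();
    (zstd, flag_ids)
}
// For example, `split_flags(0b0101_0010)` yields `(false, vec![1, 4])`, the
// same ids checked by the `header_errors` test later in this file's diff.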
@@ -224,8 +232,10 @@ impl EnvelopeHeader { let format_bytes = [self.format as u8]; writer.write_all(&format_bytes)?; // Next is the flags byte. - let mut flags = 0b01000000u8; - flags |= u8::from(self.zstd); + let mut flags = DEFAULT_FLAGS; + if self.zstd { + flags |= ZSTD_FLAG; + } writer.write_all(&[flags])?; Ok(()) @@ -259,7 +269,16 @@ impl EnvelopeHeader { // Next is the flags byte. let mut flags_bytes = [0; 1]; reader.read_exact(&mut flags_bytes)?; - let zstd = flags_bytes[0] & 0x1 != 0; + let flags: u8 = flags_bytes[0]; + + let zstd = flags & ZSTD_FLAG != 0; + + // Check if there's any unrecognized flags. + let other_flags = (flags ^ DEFAULT_FLAGS) & !ZSTD_FLAG; + if other_flags != 0 { + let flag_ids = (0..8).filter(|i| other_flags & (1 << i) != 0).collect_vec(); + return Err(EnvelopeError::FlagUnsupported { flag_ids }); + } Ok(Self { format, zstd }) } @@ -268,6 +287,7 @@ impl EnvelopeHeader { #[cfg(test)] mod tests { use super::*; + use cool_asserts::assert_matches; use rstest::rstest; #[rstest] @@ -296,4 +316,35 @@ mod tests { let read_header = EnvelopeHeader::read(&mut buffer.as_slice()).unwrap(); assert_eq!(header, read_header); } + + #[rstest] + fn header_errors() { + let header = EnvelopeHeader { + format: EnvelopeFormat::Model, + zstd: false, + }; + let mut buffer = Vec::new(); + header.write(&mut buffer).unwrap(); + + assert_eq!(buffer.len(), 10); + let flags = buffer[9]; + assert_eq!(flags, DEFAULT_FLAGS); + + // Invalid magic + let mut invalid_magic = buffer.clone(); + invalid_magic[7] = 0xFF; + assert_matches!( + EnvelopeHeader::read(&mut invalid_magic.as_slice()), + Err(EnvelopeError::MagicNumber { .. }) + ); + + // Unrecognised flags + let mut unrecognised_flags = buffer.clone(); + unrecognised_flags[9] |= 0b0001_0010; + assert_matches!( + EnvelopeHeader::read(&mut unrecognised_flags.as_slice()), + Err(EnvelopeError::FlagUnsupported { flag_ids }) + => assert_eq!(flag_ids, vec![1, 4]) + ); + } } diff --git a/hugr-core/src/envelope/package_json.rs b/hugr-core/src/envelope/package_json.rs index bbdf19d26e..2aa6c982f7 100644 --- a/hugr-core/src/envelope/package_json.rs +++ b/hugr-core/src/envelope/package_json.rs @@ -6,7 +6,6 @@ use std::io; use super::{ExtensionBreakingError, WithGenerator, check_breaking_extensions}; use crate::extension::ExtensionRegistry; use crate::extension::resolution::ExtensionResolutionError; -use crate::hugr::ExtensionError; use crate::package::Package; use crate::{Extension, Hugr}; @@ -57,6 +56,20 @@ pub(super) fn to_json_writer<'h>( modules: hugrs.into_iter().map(HugrSer).collect(), extensions: extensions.iter().map(std::convert::AsRef::as_ref).collect(), }; + + // Validate the hugr serializations against the schema. + // + // NOTE: The schema definition is currently broken, so this check always succeeds. + // See + #[cfg(all(test, not(miri)))] + if std::env::var("HUGR_TEST_SCHEMA").is_ok_and(|x| !x.is_empty()) { + use crate::hugr::serialize::test::check_hugr_serialization_schema; + + for hugr in &pkg_ser.modules { + check_hugr_serialization_schema(hugr.0); + } + } + serde_json::to_writer(writer, &pkg_ser)?; Ok(()) } @@ -64,17 +77,16 @@ pub(super) fn to_json_writer<'h>( /// Error raised while loading a package. #[derive(Debug, Display, Error, From)] #[non_exhaustive] +#[display("Error reading or writing a package in JSON format.")] pub enum PackageEncodingError { /// Error raised while parsing the package json. - JsonEncoding(serde_json::Error), + JsonEncoding(#[from] serde_json::Error), /// Error raised while reading from a file. 
- IOError(io::Error), + IOError(#[from] io::Error), /// Could not resolve the extension needed to encode the hugr. - ExtensionResolution(WithGenerator), + ExtensionResolution(#[from] WithGenerator), /// Error raised while checking for breaking extension version mismatch. - ExtensionVersion(WithGenerator), - /// Could not resolve the runtime extensions for the hugr. - RuntimeExtensionResolution(ExtensionError), + ExtensionVersion(#[from] WithGenerator), } /// A private package structure implementing the serde traits. diff --git a/hugr-core/src/envelope/serde_with.rs b/hugr-core/src/envelope/serde_with.rs index 7b9517d3e0..28d3cd3189 100644 --- a/hugr-core/src/envelope/serde_with.rs +++ b/hugr-core/src/envelope/serde_with.rs @@ -15,6 +15,9 @@ use crate::std_extensions::STD_REG; /// De/Serialize a package or hugr by encoding it into a textual Envelope and /// storing it as a string. /// +/// This is similar to [`AsBinaryEnvelope`], but uses a textual envelope instead +/// of a binary one. +/// /// Note that only PRELUDE extensions are used to decode the package's content. /// When serializing a HUGR, any additional extensions required to load it are /// embedded in the envelope. Packages should manually add any required @@ -45,9 +48,53 @@ use crate::std_extensions::STD_REG; /// When reading an encoded HUGR, the `AsStringEnvelope` deserializer will first /// try to decode the value as an string-encoded envelope. If that fails, it /// will fallback to decoding the legacy HUGR serde definition. This temporary -/// compatibility layer is meant to be removed in 0.21.0. +/// compatibility is required to support `hugr <= 0.19` and will be removed in +/// a future version. pub struct AsStringEnvelope; +/// De/Serialize a package or hugr by encoding it into a binary envelope and +/// storing it as a base64-encoded string. +/// +/// This is similar to [`AsStringEnvelope`], but uses a binary envelope instead +/// of a string. +/// When deserializing, if the string starts with the envelope magic 'HUGRiHJv' +/// it will be loaded as a string envelope without base64 decoding. +/// +/// Note that only PRELUDE extensions are used to decode the package's content. +/// When serializing a HUGR, any additional extensions required to load it are +/// embedded in the envelope. Packages should manually add any required +/// extensions before serializing. +/// +/// # Examples +/// +/// ```rust +/// # use serde::{Deserialize, Serialize}; +/// # use serde_json::json; +/// # use serde_with::{serde_as}; +/// # use hugr_core::Hugr; +/// # use hugr_core::package::Package; +/// # use hugr_core::envelope::serde_with::AsBinaryEnvelope; +/// # +/// #[serde_as] +/// #[derive(Deserialize, Serialize)] +/// struct A { +/// #[serde_as(as = "AsBinaryEnvelope")] +/// package: Package, +/// #[serde_as(as = "Vec")] +/// hugrs: Vec, +/// } +/// ``` +/// +/// # Backwards compatibility +/// +/// When reading an encoded HUGR, the `AsBinaryEnvelope` deserializer will first +/// try to decode the value as an binary-encoded envelope. If that fails, it +/// will fallback to decoding a string envelope instead, and then finally to +/// decoding the legacy HUGR serde definition. This temporary compatibility +/// layer is required to support `hugr <= 0.19` and will be removed in a future +/// version. +pub struct AsBinaryEnvelope; + /// Implements [`serde_with::DeserializeAs`] and [`serde_with::SerializeAs`] for /// the helper to deserialize `Hugr` and `Package` types, using the given /// extension registry. @@ -211,3 +258,337 @@ macro_rules! 
impl_serde_as_string_envelope { pub use impl_serde_as_string_envelope; impl_serde_as_string_envelope!(AsStringEnvelope, &STD_REG); + +/// Implements [`serde_with::DeserializeAs`] and [`serde_with::SerializeAs`] for +/// the helper to deserialize `Hugr` and `Package` types, using the given +/// extension registry. +/// +/// This macro is used to implement the default [`AsBinaryEnvelope`] wrapper. +/// +/// # Parameters +/// +/// - `$adaptor`: The name of the adaptor type to implement. +/// - `$extension_reg`: A reference to the extension registry to use for deserialization. +/// +/// # Examples +/// +/// ```rust +/// # use serde::{Deserialize, Serialize}; +/// # use serde_json::json; +/// # use serde_with::{serde_as}; +/// # use hugr_core::Hugr; +/// # use hugr_core::package::Package; +/// # use hugr_core::envelope::serde_with::AsBinaryEnvelope; +/// # use hugr_core::envelope::serde_with::impl_serde_as_binary_envelope; +/// # use hugr_core::extension::ExtensionRegistry; +/// # +/// struct CustomAsEnvelope; +/// +/// impl_serde_as_binary_envelope!(CustomAsEnvelope, &hugr_core::extension::EMPTY_REG); +/// +/// #[serde_as] +/// #[derive(Deserialize, Serialize)] +/// struct A { +/// #[serde_as(as = "CustomAsEnvelope")] +/// package: Package, +/// } +/// ``` +/// +#[macro_export] +macro_rules! impl_serde_as_binary_envelope { + ($adaptor:ident, $extension_reg:expr) => { + impl<'de> serde_with::DeserializeAs<'de, $crate::package::Package> for $adaptor { + fn deserialize_as(deserializer: D) -> Result<$crate::package::Package, D::Error> + where + D: serde::Deserializer<'de>, + { + struct Helper; + impl serde::de::Visitor<'_> for Helper { + type Value = $crate::package::Package; + + fn expecting( + &self, + formatter: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + formatter.write_str("a base64-encoded envelope") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + use $crate::envelope::serde_with::base64::{DecoderReader, STANDARD}; + + let extensions: &$crate::extension::ExtensionRegistry = $extension_reg; + + if value + .as_bytes() + .starts_with($crate::envelope::MAGIC_NUMBERS) + { + // If the string starts with the envelope magic 'HUGRiHJv', + // skip the base64 decoding. 
+ let reader = std::io::Cursor::new(value.as_bytes()); + $crate::package::Package::load(reader, Some(extensions)) + .map_err(|e| serde::de::Error::custom(format!("{e:?}"))) + } else { + let reader = DecoderReader::new(value.as_bytes(), &STANDARD); + let buf_reader = std::io::BufReader::new(reader); + $crate::package::Package::load(buf_reader, Some(extensions)) + .map_err(|e| serde::de::Error::custom(format!("{e:?}"))) + } + } + } + + deserializer.deserialize_str(Helper) + } + } + + impl<'de> serde_with::DeserializeAs<'de, $crate::Hugr> for $adaptor { + fn deserialize_as(deserializer: D) -> Result<$crate::Hugr, D::Error> + where + D: serde::Deserializer<'de>, + { + struct Helper; + impl<'vis> serde::de::Visitor<'vis> for Helper { + type Value = $crate::Hugr; + + fn expecting( + &self, + formatter: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + formatter.write_str("a base64-encoded envelope") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + use $crate::envelope::serde_with::base64::{DecoderReader, STANDARD}; + + let extensions: &$crate::extension::ExtensionRegistry = $extension_reg; + + if value + .as_bytes() + .starts_with($crate::envelope::MAGIC_NUMBERS) + { + // If the string starts with the envelope magic 'HUGRiHJv', + // skip the base64 decoding. + let reader = std::io::Cursor::new(value.as_bytes()); + $crate::Hugr::load(reader, Some(extensions)) + .map_err(|e| serde::de::Error::custom(format!("{e:?}"))) + } else { + let reader = DecoderReader::new(value.as_bytes(), &STANDARD); + let buf_reader = std::io::BufReader::new(reader); + $crate::Hugr::load(buf_reader, Some(extensions)) + .map_err(|e| serde::de::Error::custom(format!("{e:?}"))) + } + } + + fn visit_map
(self, map: A) -> Result + where + A: serde::de::MapAccess<'vis>, + { + // Backwards compatibility: If the encoded value is not a + // string, we may have a legacy HUGR serde structure instead. In that + // case, we can add an envelope header and try again. + // + // TODO: Remove this fallback in a breaking change + let deserializer = serde::de::value::MapAccessDeserializer::new(map); + #[allow(deprecated)] + let mut hugr = + $crate::hugr::serialize::serde_deserialize_hugr(deserializer) + .map_err(serde::de::Error::custom)?; + + let extensions: &$crate::extension::ExtensionRegistry = $extension_reg; + hugr.resolve_extension_defs(extensions) + .map_err(serde::de::Error::custom)?; + Ok(hugr) + } + } + + // TODO: Go back to `deserialize_str` once the fallback is removed. + deserializer.deserialize_any(Helper) + } + } + + impl serde_with::SerializeAs<$crate::package::Package> for $adaptor { + fn serialize_as( + source: &$crate::package::Package, + serializer: S, + ) -> Result + where + S: serde::Serializer, + { + use $crate::envelope::serde_with::base64::{EncoderStringWriter, STANDARD}; + + let mut writer = EncoderStringWriter::new(&STANDARD); + source + .store(&mut writer, $crate::envelope::EnvelopeConfig::binary()) + .map_err(serde::ser::Error::custom)?; + let str = writer.into_inner(); + serializer.collect_str(&str) + } + } + + impl serde_with::SerializeAs<$crate::Hugr> for $adaptor { + fn serialize_as(source: &$crate::Hugr, serializer: S) -> Result + where + S: serde::Serializer, + { + // Include any additional extension required to load the HUGR in the envelope. + let extensions: &$crate::extension::ExtensionRegistry = $extension_reg; + let mut extra_extensions = $crate::extension::ExtensionRegistry::default(); + for ext in $crate::hugr::views::HugrView::extensions(source).iter() { + if !extensions.contains(ext.name()) { + extra_extensions.register_updated(ext.clone()); + } + } + use $crate::envelope::serde_with::base64::{EncoderStringWriter, STANDARD}; + + let mut writer = EncoderStringWriter::new(&STANDARD); + source + .store_with_exts( + &mut writer, + $crate::envelope::EnvelopeConfig::binary(), + &extra_extensions, + ) + .map_err(serde::ser::Error::custom)?; + let str = writer.into_inner(); + serializer.collect_str(&str) + } + } + }; +} +pub use impl_serde_as_binary_envelope; + +impl_serde_as_binary_envelope!(AsBinaryEnvelope, &STD_REG); + +// Hidden re-export required to expand the binary envelope macros on external +// crates. 
+#[doc(hidden)] +pub mod base64 { + pub use base64::Engine; + pub use base64::engine::general_purpose::STANDARD; + pub use base64::read::DecoderReader; + pub use base64::write::EncoderStringWriter; +} + +#[cfg(test)] +mod test { + use rstest::rstest; + use serde::{Deserialize, Serialize}; + use serde_with::serde_as; + + use crate::Hugr; + use crate::package::Package; + + use super::*; + + #[serde_as] + #[derive(Deserialize, Serialize)] + struct TextPkg { + #[serde_as(as = "AsStringEnvelope")] + data: Package, + } + + #[serde_as] + #[derive(Default, Deserialize, Serialize)] + struct TextHugr { + #[serde_as(as = "AsStringEnvelope")] + data: Hugr, + } + + #[serde_as] + #[derive(Deserialize, Serialize)] + struct BinaryPkg { + #[serde_as(as = "AsBinaryEnvelope")] + data: Package, + } + + #[serde_as] + #[derive(Default, Deserialize, Serialize)] + struct BinaryHugr { + #[serde_as(as = "AsBinaryEnvelope")] + data: Hugr, + } + + #[derive(Default, Deserialize, Serialize)] + struct LegacyHugr { + #[serde(deserialize_with = "Hugr::serde_deserialize")] + #[serde(serialize_with = "Hugr::serde_serialize")] + data: Hugr, + } + + impl Default for TextPkg { + fn default() -> Self { + // Default package with a single hugr (so it can be decoded as a hugr too). + Self { + data: Package::from_hugr(Hugr::default()), + } + } + } + + impl Default for BinaryPkg { + fn default() -> Self { + // Default package with a single hugr (so it can be decoded as a hugr too). + Self { + data: Package::from_hugr(Hugr::default()), + } + } + } + + fn decode serde::Deserialize<'a>>(encoded: String) -> Result<(), serde_json::Error> { + let _: T = serde_json::de::from_str(&encoded)?; + Ok(()) + } + + #[rstest] + // Text formats are swappable + #[case::text_pkg_text_pkg(TextPkg::default(), decode::, false)] + #[case::text_pkg_text_hugr(TextPkg::default(), decode::, false)] + #[case::text_hugr_text_pkg(TextHugr::default(), decode::, false)] + #[case::text_hugr_text_hugr(TextHugr::default(), decode::, false)] + // Binary formats can read each other + #[case::bin_pkg_bin_pkg(BinaryPkg::default(), decode::, false)] + #[case::bin_pkg_bin_hugr(BinaryPkg::default(), decode::, false)] + #[case::bin_hugr_bin_pkg(BinaryHugr::default(), decode::, false)] + #[case::bin_hugr_bin_hugr(BinaryHugr::default(), decode::, false)] + // Binary formats can read text ones + #[case::text_pkg_bin_pkg(TextPkg::default(), decode::, false)] + #[case::text_pkg_bin_hugr(TextPkg::default(), decode::, false)] + #[case::text_hugr_bin_pkg(TextHugr::default(), decode::, false)] + #[case::text_hugr_bin_hugr(TextHugr::default(), decode::, false)] + // But text formats can't read binary + #[case::bin_pkg_text_pkg(BinaryPkg::default(), decode::, true)] + #[case::bin_pkg_text_hugr(BinaryPkg::default(), decode::, true)] + #[case::bin_hugr_text_pkg(BinaryHugr::default(), decode::, true)] + #[case::bin_hugr_text_hugr(BinaryHugr::default(), decode::, true)] + // We can read old hugrs into hugrs, but not packages + #[case::legacy_hugr_text_pkg(LegacyHugr::default(), decode::, true)] + #[case::legacy_hugr_text_hugr(LegacyHugr::default(), decode::, false)] + #[case::legacy_hugr_bin_pkg(LegacyHugr::default(), decode::, true)] + #[case::legacy_hugr_bin_hugr(LegacyHugr::default(), decode::, false)] + // Decoding any new format as legacy hugr always fails + #[case::text_pkg_legacy_hugr(TextPkg::default(), decode::, true)] + #[case::text_hugr_legacy_hugr(TextHugr::default(), decode::, true)] + #[case::bin_pkg_legacy_hugr(BinaryPkg::default(), decode::, true)] + 
#[case::bin_hugr_legacy_hugr(BinaryHugr::default(), decode::, true)] + #[cfg_attr(all(miri, feature = "zstd"), ignore)] // FFI calls (required to compress with zstd) are not supported in miri + fn check_format_compatibility( + #[case] encoder: impl serde::Serialize, + #[case] decoder: fn(String) -> Result<(), serde_json::Error>, + #[case] errors: bool, + ) { + let encoded = serde_json::to_string(&encoder).unwrap(); + let decoded = decoder(encoded); + match (errors, decoded) { + (false, Err(e)) => { + panic!("Decoding error: {e}"); + } + (true, Ok(_)) => { + panic!("Roundtrip should have failed"); + } + _ => {} + } + } +} diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index dff471cc59..5eaead792a 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -1,6 +1,8 @@ //! Exporting HUGR graphs to their `hugr-model` representation. +use crate::Visibility; use crate::extension::ExtensionRegistry; use crate::hugr::internal::HugrInternals; +use crate::types::type_param::Term; use crate::{ Direction, Hugr, HugrView, IncomingPort, Node, NodeIndex as _, Port, extension::{ExtensionId, OpDef, SignatureFunc}, @@ -14,19 +16,19 @@ use crate::{ }, types::{ CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, - TypeArg, TypeBase, TypeBound, TypeEnum, TypeRow, - type_param::{TypeArgVariable, TypeParam}, - type_row::TypeRowBase, + TypeBase, TypeBound, TypeEnum, type_param::TermVar, type_row::TypeRowBase, }, }; use fxhash::{FxBuildHasher, FxHashMap}; +use hugr_model::v0::bumpalo; use hugr_model::v0::{ self as model, bumpalo::{Bump, collections::String as BumpString, collections::Vec as BumpVec}, table, }; use petgraph::unionfind::UnionFind; +use smol_str::ToSmolStr; use std::fmt::Write; /// Exports a deconstructed `Package` to its representation in the model. @@ -95,6 +97,8 @@ struct Context<'a> { // that ensures that the `node_to_id` and `id_to_node` maps stay in sync. } +const NO_VIS: Option = None; + impl<'a> Context<'a> { pub fn new(hugr: &'a Hugr, bump: &'a Bump) -> Self { let mut module = table::Module::default(); @@ -231,16 +235,6 @@ impl<'a> Context<'a> { } } - /// Get the name of a function definition or declaration node. Returns `None` if not - /// one of those operations. - fn get_func_name(&self, func_node: Node) -> Option<&'a str> { - match self.hugr.get_optype(func_node) { - OpType::FuncDecl(func_decl) => Some(func_decl.func_name()), - OpType::FuncDefn(func_defn) => Some(func_defn.func_name()), - _ => None, - } - } - fn with_local_scope(&mut self, node: table::NodeId, f: impl FnOnce(&mut Self) -> T) -> T { let prev_local_scope = self.local_scope.replace(node); let prev_local_constraints = std::mem::take(&mut self.local_constraints); @@ -269,8 +263,12 @@ impl<'a> Context<'a> { // We record the name of the symbol defined by the node, if any. let symbol = match optype { - OpType::FuncDefn(func_defn) => Some(func_defn.func_name().as_str()), - OpType::FuncDecl(func_decl) => Some(func_decl.func_name().as_str()), + OpType::FuncDefn(_) | OpType::FuncDecl(_) => { + // Functions aren't exported using their core name but with a mangled + // name derived from their id. The function's core name will be recorded + // using `core.title` metadata. + Some(self.mangled_name(node)) + } OpType::AliasDecl(alias_decl) => Some(alias_decl.name.as_str()), OpType::AliasDefn(alias_defn) => Some(alias_defn.name.as_str()), _ => None, @@ -290,6 +288,7 @@ impl<'a> Context<'a> { // the node id. 
This is necessary to establish the correct node id for the // local scope introduced by some operations. We will overwrite this node later. let mut regions: &[_] = &[]; + let mut meta = Vec::new(); let node = self.id_to_node[&node_id]; let optype = self.hugr.get_optype(node); @@ -310,6 +309,7 @@ impl<'a> Context<'a> { node, model::ScopeClosure::Open, false, + false, )]); table::Operation::Dfg } @@ -334,24 +334,36 @@ impl<'a> Context<'a> { node, model::ScopeClosure::Open, false, + false, )]); table::Operation::Block } OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| { - let name = this.get_func_name(node).unwrap(); - let symbol = this.export_poly_func_type(name, func.signature()); + let symbol_name = this.export_func_name(node, &mut meta); + + let symbol = this.export_poly_func_type( + symbol_name, + Some(func.visibility().clone().into()), + func.signature(), + ); regions = this.bump.alloc_slice_copy(&[this.export_dfg( node, model::ScopeClosure::Closed, false, + false, )]); table::Operation::DefineFunc(symbol) }), OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| { - let name = this.get_func_name(node).unwrap(); - let symbol = this.export_poly_func_type(name, func.signature()); + let symbol_name = this.export_func_name(node, &mut meta); + + let symbol = this.export_poly_func_type( + symbol_name, + Some(func.visibility().clone().into()), + func.signature(), + ); table::Operation::DeclareFunc(symbol) }), @@ -359,6 +371,7 @@ impl<'a> Context<'a> { // TODO: We should support aliases with different types and with parameters let signature = this.make_term_apply(model::CORE_TYPE, &[]); let symbol = this.bump.alloc(table::Symbol { + visibility: &NO_VIS, // not spec'd in hugr-core name: &alias.name, params: &[], constraints: &[], @@ -372,6 +385,7 @@ impl<'a> Context<'a> { // TODO: We should support aliases with different types and with parameters let signature = this.make_term_apply(model::CORE_TYPE, &[]); let symbol = this.bump.alloc(table::Symbol { + visibility: &NO_VIS, // not spec'd in hugr-core name: &alias.name, params: &[], constraints: &[], @@ -385,7 +399,7 @@ impl<'a> Context<'a> { let node = self.connected_function(node).unwrap(); let symbol = self.node_to_id[&node]; let mut args = BumpVec::new_in(self.bump); - args.extend(call.type_args.iter().map(|arg| self.export_type_arg(arg))); + args.extend(call.type_args.iter().map(|arg| self.export_term(arg, None))); let args = args.into_bump_slice(); let func = self.make_term(table::Term::Apply(symbol, args)); @@ -401,7 +415,7 @@ impl<'a> Context<'a> { let node = self.connected_function(node).unwrap(); let symbol = self.node_to_id[&node]; let mut args = BumpVec::new_in(self.bump); - args.extend(load.type_args.iter().map(|arg| self.export_type_arg(arg))); + args.extend(load.type_args.iter().map(|arg| self.export_term(arg, None))); let args = args.into_bump_slice(); let func = self.make_term(table::Term::Apply(symbol, args)); let runtime_type = self.make_term(table::Term::Wildcard); @@ -451,6 +465,7 @@ impl<'a> Context<'a> { node, model::ScopeClosure::Open, false, + false, )]); table::Operation::TailLoop } @@ -464,7 +479,7 @@ impl<'a> Context<'a> { let node = self.export_opdef(op.def()); let params = self .bump - .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); + .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_term(arg, None))); let operation = self.make_term(table::Term::Apply(node, params)); table::Operation::Custom(operation) } @@ -473,7 +488,7 @@ impl<'a> Context<'a> { let 
node = self.make_named_global_ref(op.extension(), op.unqualified_id()); let params = self .bump - .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); + .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_term(arg, None))); let operation = self.make_term(table::Term::Apply(node, params)); table::Operation::Custom(operation) } @@ -502,12 +517,10 @@ impl<'a> Context<'a> { let inputs = self.make_ports(node, Direction::Incoming, num_inputs); let outputs = self.make_ports(node, Direction::Outgoing, num_outputs); - let meta = { - let mut meta = Vec::new(); - self.export_node_json_metadata(node, &mut meta); - self.export_node_order_metadata(node, &mut meta); - self.bump.alloc_slice_copy(&meta) - }; + self.export_node_json_metadata(node, &mut meta); + self.export_node_order_metadata(node, &mut meta); + self.export_node_entrypoint_metadata(node, &mut meta); + let meta = self.bump.alloc_slice_copy(&meta); self.module.nodes[node_id.index()] = table::Node { operation, @@ -546,7 +559,7 @@ impl<'a> Context<'a> { let symbol = self.with_local_scope(node, |this| { let name = this.make_qualified_name(opdef.extension_id(), opdef.name()); - this.export_poly_func_type(name, poly_func_type) + this.export_poly_func_type(name, None, poly_func_type) }); let meta = { @@ -578,7 +591,6 @@ impl<'a> Context<'a> { pub fn export_block_signature(&mut self, block: &DataflowBlock) -> table::TermId { let inputs = { let inputs = self.export_type_row(&block.inputs); - let inputs = self.make_term_apply(model::CORE_CTRL, &[inputs]); self.make_term(table::Term::List( self.bump.alloc_slice_copy(&[table::SeqPart::Item(inputs)]), )) @@ -590,13 +602,12 @@ impl<'a> Context<'a> { let mut outputs = BumpVec::with_capacity_in(block.sum_rows.len(), self.bump); for sum_row in &block.sum_rows { let variant = self.export_type_row_with_tail(sum_row, Some(tail)); - let control = self.make_term_apply(model::CORE_CTRL, &[variant]); - outputs.push(table::SeqPart::Item(control)); + outputs.push(table::SeqPart::Item(variant)); } self.make_term(table::Term::List(outputs.into_bump_slice())) }; - self.make_term_apply(model::CORE_FN, &[inputs, outputs]) + self.make_term_apply(model::CORE_CTRL, &[inputs, outputs]) } /// Creates a data flow region from the given node's children. 
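// --- Editor's hedged sketch (not part of the patch): how the exported symbol name
// --- for a function is chosen in the FuncDefn/FuncDecl hunks above. It mirrors the
// --- `export_func_name` helper added later in this file: public functions keep their
// --- core name as the symbol name, private ones get a mangled node-derived name and
// --- keep the core name only as `core.title` metadata. `choose_symbol_name` is a
// --- hypothetical free function; types and accessors are the ones used in this diff.
// use crate::{Hugr, Node, Visibility};
// use crate::hugr::views::HugrView;
fn choose_symbol_name(hugr: &Hugr, node: Node) -> String {
    let func = hugr.get_optype(node).as_func_defn().expect("FuncDefn node");
    match func.visibility() {
        Visibility::Public => func.func_name().to_string(),
        // Same mangling scheme as `mangled_name` in this file.
        Visibility::Private => format!("_{}", node.index()),
    }
}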
@@ -607,6 +618,7 @@ impl<'a> Context<'a> { node: Node, closure: model::ScopeClosure, export_json_meta: bool, + export_entrypoint_meta: bool, ) -> table::RegionId { let region = self.module.insert_region(table::Region::default()); @@ -625,46 +637,54 @@ impl<'a> Context<'a> { if export_json_meta { self.export_node_json_metadata(node, &mut meta); } - self.export_node_entrypoint_metadata(node, &mut meta); + if export_entrypoint_meta { + self.export_node_entrypoint_metadata(node, &mut meta); + } let children = self.hugr.children(node); let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 - 2, self.bump); - let mut output_node = None; - for child in children { match self.hugr.get_optype(child) { OpType::Input(input) => { sources = self.make_ports(child, Direction::Outgoing, input.types.len()); input_types = Some(&input.types); + + if has_order_edges(self.hugr, child) { + let key = self.make_term(model::Literal::Nat(child.index() as u64).into()); + meta.push(self.make_term_apply(model::ORDER_HINT_INPUT_KEY, &[key])); + } } OpType::Output(output) => { targets = self.make_ports(child, Direction::Incoming, output.types.len()); output_types = Some(&output.types); - output_node = Some(child); + + if has_order_edges(self.hugr, child) { + let key = self.make_term(model::Literal::Nat(child.index() as u64).into()); + meta.push(self.make_term_apply(model::ORDER_HINT_OUTPUT_KEY, &[key])); + } } - child_optype => { + _ => { if let Some(child_id) = self.export_node_shallow(child) { region_children.push(child_id); - - // Record all order edges that originate from this node in metadata. - let successors = child_optype - .other_output_port() - .into_iter() - .flat_map(|port| self.hugr.linked_inputs(child, port)) - .map(|(successor, _)| successor) - .filter(|successor| Some(*successor) != output_node); - - for successor in successors { - let a = - self.make_term(model::Literal::Nat(child.index() as u64).into()); - let b = self - .make_term(model::Literal::Nat(successor.index() as u64).into()); - meta.push(self.make_term_apply(model::ORDER_HINT_ORDER, &[a, b])); - } } } } + + // Record all order edges that originate from this node in metadata. 
+ let successors = self + .hugr + .get_optype(child) + .other_output_port() + .into_iter() + .flat_map(|port| self.hugr.linked_inputs(child, port)) + .map(|(successor, _)| successor); + + for successor in successors { + let a = self.make_term(model::Literal::Nat(child.index() as u64).into()); + let b = self.make_term(model::Literal::Nat(successor.index() as u64).into()); + meta.push(self.make_term_apply(model::ORDER_HINT_ORDER, &[a, b])); + } } for child_id in ®ion_children { @@ -740,18 +760,21 @@ impl<'a> Context<'a> { let signature = { let node_signature = self.hugr.signature(node).unwrap(); - let mut wrap_ctrl = |types: &TypeRow| { - let types = self.export_type_row(types); - let types_ctrl = self.make_term_apply(model::CORE_CTRL, &[types]); + let inputs = { + let types = self.export_type_row(node_signature.input()); self.make_term(table::Term::List( - self.bump - .alloc_slice_copy(&[table::SeqPart::Item(types_ctrl)]), + self.bump.alloc_slice_copy(&[table::SeqPart::Item(types)]), )) }; - let inputs = wrap_ctrl(node_signature.input()); - let outputs = wrap_ctrl(node_signature.output()); - Some(self.make_term_apply(model::CORE_FN, &[inputs, outputs])) + let outputs = { + let types = self.export_type_row(node_signature.output()); + self.make_term(table::Term::List( + self.bump.alloc_slice_copy(&[table::SeqPart::Item(types)]), + )) + }; + + Some(self.make_term_apply(model::CORE_CTRL, &[inputs, outputs])) }; let scope = match closure { @@ -786,7 +809,7 @@ impl<'a> Context<'a> { panic!("expected a `Case` node as a child of a `Conditional` node"); }; - regions.push(self.export_dfg(child, model::ScopeClosure::Open, true)); + regions.push(self.export_dfg(child, model::ScopeClosure::Open, true, true)); } regions.into_bump_slice() @@ -796,16 +819,17 @@ impl<'a> Context<'a> { pub fn export_poly_func_type( &mut self, name: &'a str, + visibility: Option, t: &PolyFuncTypeBase, ) -> &'a table::Symbol<'a> { let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); let scope = self .local_scope .expect("exporting poly func type outside of local scope"); - + let visibility = self.bump.alloc(visibility); for (i, param) in t.params().iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); - let r#type = self.export_type_param(param, Some((scope, i as _))); + let r#type = self.export_term(param, Some((scope, i as _))); let param = table::Param { name, r#type }; params.push(param); } @@ -814,6 +838,7 @@ impl<'a> Context<'a> { let body = self.export_func_type(t.body()); self.bump.alloc(table::Symbol { + visibility, name, params: params.into_bump_slice(), constraints, @@ -853,30 +878,12 @@ impl<'a> Context<'a> { let args = self .bump - .alloc_slice_fill_iter(t.args().iter().map(|p| self.export_type_arg(p))); + .alloc_slice_fill_iter(t.args().iter().map(|p| self.export_term(p, None))); let term = table::Term::Apply(symbol, args); self.make_term(term) } - pub fn export_type_arg(&mut self, t: &TypeArg) -> table::TermId { - match t { - TypeArg::Type { ty } => self.export_type(ty), - TypeArg::BoundedNat { n } => self.make_term(model::Literal::Nat(*n).into()), - TypeArg::String { arg } => self.make_term(model::Literal::Str(arg.into()).into()), - TypeArg::Sequence { elems } => { - // For now we assume that the sequence is meant to be a list. 
- let parts = self.bump.alloc_slice_fill_iter( - elems - .iter() - .map(|elem| table::SeqPart::Item(self.export_type_arg(elem))), - ); - self.make_term(table::Term::List(parts)) - } - TypeArg::Variable { v } => self.export_type_arg_var(v), - } - } - - pub fn export_type_arg_var(&mut self, var: &TypeArgVariable) -> table::TermId { + pub fn export_type_arg_var(&mut self, var: &TermVar) -> table::TermId { let node = self.local_scope.expect("local variable out of scope"); self.make_term(table::Term::Var(table::VarId(node, var.index() as _))) } @@ -942,19 +949,19 @@ impl<'a> Context<'a> { self.make_term(table::Term::List(parts)) } - /// Exports a `TypeParam` to a term. + /// Exports a term. /// - /// The `var` argument is set when the type parameter being exported is the + /// The `var` argument is set when the term being exported is the /// type of a parameter to a polymorphic definition. In that case we can /// generate a `nonlinear` constraint for the type of runtime types marked as /// `TypeBound::Copyable`. - pub fn export_type_param( + pub fn export_term( &mut self, - t: &TypeParam, + t: &Term, var: Option<(table::NodeId, table::VarIndex)>, ) -> table::TermId { match t { - TypeParam::Type { b } => { + Term::RuntimeType(b) => { if let (Some((node, index)), TypeBound::Copyable) = (var, b) { let term = self.make_term(table::Term::Var(table::VarId(node, index))); let non_linear = self.make_term_apply(model::CORE_NON_LINEAR, &[term]); @@ -963,22 +970,57 @@ impl<'a> Context<'a> { self.make_term_apply(model::CORE_TYPE, &[]) } - // This ignores the bound on the natural for now. - TypeParam::BoundedNat { .. } => self.make_term_apply(model::CORE_NAT_TYPE, &[]), - TypeParam::String => self.make_term_apply(model::CORE_STR_TYPE, &[]), - TypeParam::List { param } => { - let item_type = self.export_type_param(param, None); + Term::BoundedNatType(_) => self.make_term_apply(model::CORE_NAT_TYPE, &[]), + Term::StringType => self.make_term_apply(model::CORE_STR_TYPE, &[]), + Term::BytesType => self.make_term_apply(model::CORE_BYTES_TYPE, &[]), + Term::FloatType => self.make_term_apply(model::CORE_FLOAT_TYPE, &[]), + Term::ListType(item_type) => { + let item_type = self.export_term(item_type, None); self.make_term_apply(model::CORE_LIST_TYPE, &[item_type]) } - TypeParam::Tuple { params } => { + Term::TupleType(item_types) => { + let item_types = self.export_term(item_types, None); + self.make_term_apply(model::CORE_TUPLE_TYPE, &[item_types]) + } + Term::Runtime(ty) => self.export_type(ty), + Term::BoundedNat(value) => self.make_term(model::Literal::Nat(*value).into()), + Term::String(value) => self.make_term(model::Literal::Str(value.into()).into()), + Term::Float(value) => self.make_term(model::Literal::Float(*value).into()), + Term::Bytes(value) => self.make_term(model::Literal::Bytes(value.clone()).into()), + Term::List(elems) => { let parts = self.bump.alloc_slice_fill_iter( - params + elems + .iter() + .map(|elem| table::SeqPart::Item(self.export_term(elem, None))), + ); + self.make_term(table::Term::List(parts)) + } + Term::ListConcat(lists) => { + let parts = self.bump.alloc_slice_fill_iter( + lists + .iter() + .map(|elem| table::SeqPart::Splice(self.export_term(elem, None))), + ); + self.make_term(table::Term::List(parts)) + } + Term::Tuple(elems) => { + let parts = self.bump.alloc_slice_fill_iter( + elems .iter() - .map(|param| table::SeqPart::Item(self.export_type_param(param, None))), + .map(|elem| table::SeqPart::Item(self.export_term(elem, None))), ); - let types = 
self.make_term(table::Term::List(parts)); - self.make_term_apply(model::CORE_TUPLE_TYPE, &[types]) + self.make_term(table::Term::Tuple(parts)) } + Term::TupleConcat(tuples) => { + let parts = self.bump.alloc_slice_fill_iter( + tuples + .iter() + .map(|elem| table::SeqPart::Splice(self.export_term(elem, None))), + ); + self.make_term(table::Term::Tuple(parts)) + } + Term::Variable(v) => self.export_type_arg_var(v), + Term::StaticType => self.make_term_apply(model::CORE_STATIC, &[]), } } @@ -1042,7 +1084,7 @@ impl<'a> Context<'a> { let region = match hugr.entrypoint_optype() { OpType::DFG(_) => { - self.export_dfg(hugr.entrypoint(), model::ScopeClosure::Closed, true) + self.export_dfg(hugr.entrypoint(), model::ScopeClosure::Closed, true, true) } _ => panic!("Value::Function root must be a DFG"), }; @@ -1083,21 +1125,7 @@ impl<'a> Context<'a> { } fn export_node_order_metadata(&mut self, node: Node, meta: &mut Vec) { - fn is_relevant_node(hugr: &Hugr, node: Node) -> bool { - let optype = hugr.get_optype(node); - !optype.is_input() && !optype.is_output() - } - - let optype = self.hugr.get_optype(node); - - let has_order_edges = Direction::BOTH - .iter() - .filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder)) - .filter_map(|dir| optype.other_port(*dir)) - .flat_map(|port| self.hugr.linked_ports(node, port)) - .any(|(other, _)| is_relevant_node(self.hugr, other)); - - if has_order_edges { + if has_order_edges(self.hugr, node) { let key = self.make_term(model::Literal::Nat(node.index() as u64).into()); meta.push(self.make_term_apply(model::ORDER_HINT_KEY, &[key])); } @@ -1109,6 +1137,33 @@ impl<'a> Context<'a> { } } + /// Used when exporting function definitions or declarations. When the + /// function is public, its symbol name will be the core name. For private + /// functions, the symbol name is derived from the node id and the core name + /// is exported as `core.title` metadata. + /// + /// This is a hack, necessary due to core names for functions being + /// non-functional. Once functions have a "link name", that should be used as the symbol name here. + fn export_func_name(&mut self, node: Node, meta: &mut Vec) -> &'a str { + let (name, vis) = match self.hugr.get_optype(node) { + OpType::FuncDefn(func_defn) => (func_defn.func_name(), func_defn.visibility()), + OpType::FuncDecl(func_decl) => (func_decl.func_name(), func_decl.visibility()), + _ => panic!( + "`export_func_name` is only supposed to be used on function declarations and definitions" + ), + }; + + match vis { + Visibility::Public => name, + Visibility::Private => { + let literal = + self.make_term(table::Term::Literal(model::Literal::Str(name.to_smolstr()))); + meta.push(self.make_term_apply(model::CORE_TITLE, &[literal])); + self.mangled_name(node) + } + } + } + pub fn make_json_meta(&mut self, name: &str, value: &serde_json::Value) -> table::TermId { let value = serde_json::to_string(value).expect("json values are always serializable"); let value = self.make_term(model::Literal::Str(value.into()).into()); @@ -1135,6 +1190,11 @@ impl<'a> Context<'a> { let args = self.bump.alloc_slice_copy(args); self.make_term(table::Term::Apply(symbol, args)) } + + /// Creates a mangled name for a particular node. + fn mangled_name(&self, node: Node) -> &'a str { + bumpalo::format!(in &self.bump, "_{}", node.index()).into_bump_str() + } } type FxIndexSet = indexmap::IndexSet; @@ -1212,6 +1272,18 @@ impl Links { } } +/// Returns `true` if a node has any incident order edges. 
+fn has_order_edges(hugr: &Hugr, node: Node) -> bool { + let optype = hugr.get_optype(node); + Direction::BOTH + .iter() + .filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder)) + .filter_map(|dir| optype.other_port(*dir)) + .flat_map(|port| hugr.linked_ports(node, port)) + .next() + .is_some() +} + #[cfg(test)] mod test { use rstest::{fixture, rstest}; diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index bb5034e1b1..c6dc2be25a 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -22,7 +22,7 @@ use crate::hugr::IdentList; use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{OpName, OpNameRef}; use crate::types::RowVariable; -use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; +use crate::types::type_param::{TermTypeError, TypeArg, TypeParam}; use crate::types::{CustomType, TypeBound, TypeName}; use crate::types::{Signature, TypeNameRef}; @@ -36,7 +36,7 @@ mod type_def; pub use const_fold::{ConstFold, ConstFoldResult, Folder, fold_out_row}; pub use op_def::{ CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc, - ValidateJustArgs, ValidateTypeArgs, + ValidateJustArgs, ValidateTypeArgs, deserialize_lower_funcs, }; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; pub use type_def::{TypeDef, TypeDefBound}; @@ -136,8 +136,8 @@ impl ExtensionRegistry { match self.exts.entry(extension.name().clone()) { btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered( extension.name().clone(), - prev.get().version().clone(), - extension.version().clone(), + Box::new(prev.get().version().clone()), + Box::new(extension.version().clone()), )), btree_map::Entry::Vacant(ve) => { ve.insert(extension); @@ -387,7 +387,7 @@ pub enum SignatureError { ExtensionMismatch(ExtensionId, ExtensionId), /// When the type arguments of the node did not match the params declared by the `OpDef` #[error("Type arguments of node did not match params declared by definition: {0}")] - TypeArgMismatch(#[from] TypeArgError), + TypeArgMismatch(#[from] TermTypeError), /// Invalid type arguments #[error("Invalid type arguments for operation")] InvalidTypeArgs, @@ -408,8 +408,8 @@ pub enum SignatureError { /// A Type Variable's cache of its declared kind is incorrect #[error("Type Variable claims to be {cached} but actual declaration {actual}")] TypeVarDoesNotMatchDeclaration { - actual: TypeParam, - cached: TypeParam, + actual: Box, + cached: Box, }, /// A type variable that was used has not been declared #[error("Type variable {idx} was not declared ({num_decls} in scope)")] @@ -425,8 +425,8 @@ pub enum SignatureError { "Incorrect result of type application in Call - cached {cached} but expected {expected}" )] CallIncorrectlyAppliesType { - cached: Signature, - expected: Signature, + cached: Box, + expected: Box, }, /// The result of the type application stored in a [`LoadFunction`] /// is not what we get by applying the type-args to the polymorphic function @@ -436,8 +436,8 @@ pub enum SignatureError { "Incorrect result of type application in LoadFunction - cached {cached} but expected {expected}" )] LoadFunctionIncorrectlyAppliesType { - cached: Signature, - expected: Signature, + cached: Box, + expected: Box, }, /// Extension declaration specifies a binary compute signature function, but none @@ -697,7 +697,7 @@ pub enum ExtensionRegistryError { #[error( "The registry already contains an extension with id {0} and version {1}. New extension has version {2}." 
)] - AlreadyRegistered(ExtensionId, Version, Version), + AlreadyRegistered(ExtensionId, Box, Box), /// A registered extension has invalid signatures. #[error("The extension {0} contains an invalid signature, {1}.")] InvalidSignature(ExtensionId, #[source] SignatureError), @@ -706,13 +706,20 @@ pub enum ExtensionRegistryError { /// An error that can occur while loading an extension registry. #[derive(Debug, Error)] #[non_exhaustive] +#[error("Extension registry load error")] pub enum ExtensionRegistryLoadError { /// Deserialization error. #[error(transparent)] SerdeError(#[from] serde_json::Error), /// Error when resolving internal extension references. #[error(transparent)] - ExtensionResolutionError(#[from] ExtensionResolutionError), + ExtensionResolutionError(Box), +} + +impl From for ExtensionRegistryLoadError { + fn from(error: ExtensionResolutionError) -> Self { + Self::ExtensionResolutionError(Box::new(error)) + } } /// An error that can occur in building a new extension. @@ -889,8 +896,8 @@ pub mod test { reg.register(ext1_1.clone()), Err(ExtensionRegistryError::AlreadyRegistered( ext_1_id.clone(), - Version::new(1, 0, 0), - Version::new(1, 1, 0) + Box::new(Version::new(1, 0, 0)), + Box::new(Version::new(1, 1, 0)) )) ); diff --git a/hugr-core/src/extension/declarative/types.rs b/hugr-core/src/extension/declarative/types.rs index ebbf628d68..46224d48e8 100644 --- a/hugr-core/src/extension/declarative/types.rs +++ b/hugr-core/src/extension/declarative/types.rs @@ -100,7 +100,7 @@ impl From for TypeDefBound { bound: TypeBound::Copyable, }, TypeDefBoundDeclaration::Any => Self::Explicit { - bound: TypeBound::Any, + bound: TypeBound::Linear, }, } } @@ -129,6 +129,6 @@ impl TypeParamDeclaration { _extension: &Extension, _ctx: DeclarationContext<'_>, ) -> Result { - Ok(TypeParam::String) + Ok(TypeParam::StringType) } } diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 9c30cbdd47..6a2b5ab69f 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -12,9 +12,9 @@ use super::{ }; use crate::Hugr; -use crate::envelope::serde_with::AsStringEnvelope; +use crate::envelope::serde_with::AsBinaryEnvelope; use crate::ops::{OpName, OpNameRef}; -use crate::types::type_param::{TypeArg, TypeParam, check_type_args}; +use crate::types::type_param::{TypeArg, TypeParam, check_term_types}; use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature}; mod serialize_signature_func; @@ -239,7 +239,7 @@ impl SignatureFunc { let static_params = func.static_params(); let (static_args, other_args) = args.split_at(min(static_params.len(), args.len())); - check_type_args(static_args, static_params)?; + check_term_types(static_args, static_params)?; temp = func.compute_signature(static_args, def)?; (&temp, other_args) } @@ -268,8 +268,12 @@ impl Debug for SignatureFunc { /// Different ways that an [OpDef] can lower operation nodes i.e. provide a Hugr /// that implements the operation using a set of other extensions. +/// +/// Does not implement [`serde::Deserialize`] directly since the serde error for +/// untagged enums is unhelpful. Use [`deserialize_lower_funcs`] with +/// [`serde(deserialize_with = "deserialize_lower_funcs")] instead. #[serde_as] -#[derive(serde::Deserialize, serde::Serialize)] +#[derive(serde::Serialize)] #[serde(untagged)] pub enum LowerFunc { /// Lowering to a fixed Hugr. 
Since this cannot depend upon the [TypeArg]s, @@ -281,8 +285,8 @@ pub enum LowerFunc { /// [OpDef] /// /// [ExtensionOp]: crate::ops::ExtensionOp - #[serde_as(as = "AsStringEnvelope")] - hugr: Hugr, + #[serde_as(as = "Box")] + hugr: Box, }, /// Custom binary function that can (fallibly) compute a Hugr /// for the particular instance and set of available extensions. @@ -290,6 +294,34 @@ pub enum LowerFunc { CustomFunc(Box), } +/// A function for deserializing sequences of [`LowerFunc::FixedHugr`]. +/// +/// We could let serde deserialize [`LowerFunc`] as-is, but if the LowerFunc +/// deserialization fails it just returns an opaque "data did not match any +/// variant of untagged enum LowerFunc" error. This function will return the +/// internal errors instead. +pub fn deserialize_lower_funcs<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + #[serde_as] + #[derive(serde::Deserialize)] + struct FixedHugrDeserializer { + pub extensions: ExtensionSet, + #[serde_as(as = "Box")] + pub hugr: Box, + } + + let funcs: Vec = serde::Deserialize::deserialize(deserializer)?; + Ok(funcs + .into_iter() + .map(|f| LowerFunc::FixedHugr { + extensions: f.extensions, + hugr: f.hugr, + }) + .collect()) +} + impl Debug for LowerFunc { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -322,7 +354,11 @@ pub struct OpDef { signature_func: SignatureFunc, // Some operations cannot lower themselves and tools that do not understand them // can only treat them as opaque/black-box ops. - #[serde(default, skip_serializing_if = "Vec::is_empty")] + #[serde( + default, + skip_serializing_if = "Vec::is_empty", + deserialize_with = "deserialize_lower_funcs" + )] pub(crate) lower_funcs: Vec, /// Operations can optionally implement [`ConstFold`] to implement constant folding. 
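// --- Editor's hedged sketch (not part of the patch): wiring up the custom
// --- deserializer introduced above. Any struct holding a `Vec<LowerFunc>` can opt
// --- into the more informative per-variant errors by pointing serde at
// --- `deserialize_lower_funcs`, exactly as the `OpDef::lower_funcs` field does in
// --- this hunk. `MyDef` is a hypothetical example type; paths assumed from this file.
// use crate::extension::op_def::{LowerFunc, deserialize_lower_funcs};
#[derive(serde::Deserialize)]
struct MyDef {
    #[serde(default, deserialize_with = "deserialize_lower_funcs")]
    lower_funcs: Vec<LowerFunc>,
}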
@@ -347,7 +383,7 @@ impl OpDef { let (static_args, other_args) = args.split_at(min(custom.static_params().len(), args.len())); static_args.iter().try_for_each(|ta| ta.validate(&[]))?; - check_type_args(static_args, custom.static_params())?; + check_term_types(static_args, custom.static_params())?; temp = custom.compute_signature(static_args, self)?; (&temp, other_args) } @@ -357,7 +393,7 @@ impl OpDef { } }; args.iter().try_for_each(|ta| ta.validate(var_decls))?; - check_type_args(args, pf.params())?; + check_term_types(args, pf.params())?; Ok(()) } @@ -377,7 +413,7 @@ impl OpDef { .filter_map(|f| match f { LowerFunc::FixedHugr { extensions, hugr } => { if available_extensions.is_superset(extensions) { - Some(hugr.clone()) + Some(hugr.as_ref().clone()) } else { None } @@ -553,7 +589,7 @@ pub(super) mod test { use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; use crate::ops::OpName; use crate::std_extensions::collections::list; - use crate::types::type_param::{TypeArgError, TypeParam}; + use crate::types::type_param::{TermTypeError, TypeParam}; use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV}; use crate::{Extension, const_extension_ids}; @@ -656,7 +692,7 @@ pub(super) mod test { const OP_NAME: OpName = OpName::new_inline("Reverse"); let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { - const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; + const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear); let list_of_var = Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?); let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var])); @@ -664,7 +700,7 @@ pub(super) mod test { let def = ext.add_op(OP_NAME, "desc".into(), type_scheme, extension_ref)?; def.add_lower_func(LowerFunc::FixedHugr { extensions: ExtensionSet::new(), - hugr: crate::builder::test::simple_dfg_hugr(), // this is nonsense, but we are not testing the actual lowering here + hugr: Box::new(crate::builder::test::simple_dfg_hugr()), // this is nonsense, but we are not testing the actual lowering here }); def.add_misc("key", Default::default()); assert_eq!(def.description(), "desc"); @@ -678,11 +714,10 @@ pub(super) mod test { reg.validate()?; let e = reg.get(&EXT_ID).unwrap(); - let list_usize = - Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: usize_t() }])?); + let list_usize = Type::new_extension(list_def.instantiate(vec![usize_t().into()])?); let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?; let rev = dfg.add_dataflow_op( - e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: usize_t() }]) + e.instantiate_extension_op(&OP_NAME, vec![usize_t().into()]) .unwrap(), dfg.input_wires(), )?; @@ -703,13 +738,13 @@ pub(super) mod test { &self, arg_values: &[TypeArg], ) -> Result { - const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; - let [TypeArg::BoundedNat { n }] = arg_values else { + const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear); + let [TypeArg::BoundedNat(n)] = arg_values else { return Err(SignatureError::InvalidTypeArgs); }; let n = *n as usize; let tvs: Vec = (0..n) - .map(|_| Type::new_var_use(0, TypeBound::Any)) + .map(|_| Type::new_var_use(0, TypeBound::Linear)) .collect(); Ok(PolyFuncTypeRV::new( vec![TP.clone()], @@ -718,7 +753,7 @@ pub(super) mod test { } fn static_params(&self) -> &[TypeParam] { - const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat()]; + const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat_type()]; MAX_NAT } } @@ 
-727,7 +762,7 @@ pub(super) mod test { ext.add_op("MyOp".into(), String::new(), SigFun(), extension_ref)?; // Base case, no type variables: - let args = [TypeArg::BoundedNat { n: 3 }, usize_t().into()]; + let args = [TypeArg::BoundedNat(3), usize_t().into()]; assert_eq!( def.compute_signature(&args), Ok(Signature::new( @@ -740,7 +775,7 @@ pub(super) mod test { // Second arg may be a variable (substitutable) let tyvar = Type::new_var_use(0, TypeBound::Copyable); let tyvars: Vec = vec![tyvar.clone(); 3]; - let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; + let args = [TypeArg::BoundedNat(3), tyvar.clone().into()]; assert_eq!( def.compute_signature(&args), Ok(Signature::new( @@ -753,15 +788,15 @@ pub(super) mod test { // quick sanity check that we are validating the args - note changed bound: assert_eq!( - def.validate_args(&args, &[TypeBound::Any.into()]), + def.validate_args(&args, &[TypeBound::Linear.into()]), Err(SignatureError::TypeVarDoesNotMatchDeclaration { - actual: TypeBound::Any.into(), - cached: TypeBound::Copyable.into() + actual: Box::new(TypeBound::Linear.into()), + cached: Box::new(TypeBound::Copyable.into()) }) ); // First arg must be concrete, not a variable - let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap()); + let kind = TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap()); let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()]; // We can't prevent this from getting into our compute_signature implementation: assert_eq!( @@ -792,13 +827,13 @@ pub(super) mod test { "SimpleOp".into(), String::new(), PolyFuncTypeRV::new( - vec![TypeBound::Any.into()], - Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + vec![TypeBound::Linear.into()], + Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Linear)]), ), extension_ref, )?; let tv = Type::new_var_use(0, TypeBound::Copyable); - let args = [TypeArg::Type { ty: tv.clone() }]; + let args = [tv.clone().into()]; let decls = [TypeBound::Copyable.into()]; def.validate_args(&args, &decls).unwrap(); assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo(tv))); @@ -807,9 +842,9 @@ pub(super) mod test { assert_eq!( def.compute_signature(&[arg.clone()]), Err(SignatureError::TypeArgMismatch( - TypeArgError::TypeMismatch { - param: TypeBound::Any.into(), - arg + TermTypeError::TypeMismatch { + type_: Box::new(TypeBound::Linear.into()), + term: Box::new(arg), } )) ); @@ -852,7 +887,7 @@ pub(super) mod test { any::() .prop_map(|extensions| LowerFunc::FixedHugr { extensions, - hugr: simple_dfg_hugr(), + hugr: Box::new(simple_dfg_hugr()), }) .boxed() } diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index 3af70b75b4..1b59d50ea7 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -18,8 +18,8 @@ use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName}; use crate::ops::{NamedOp, Value}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{ - CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeBound, - TypeName, TypeRV, TypeRow, TypeRowRV, + CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Term, Type, + TypeBound, TypeName, TypeRV, TypeRow, TypeRowRV, }; use crate::utils::sorted_consts; use crate::{Extension, type_row}; @@ -39,7 +39,7 @@ pub mod generic; /// Name of prelude extension. pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude"); /// Extension version. 
-pub const VERSION: semver::Version = semver::Version::new(0, 2, 0); +pub const VERSION: semver::Version = semver::Version::new(0, 2, 1); lazy_static! { /// Prelude extension, containing common types and operations. pub static ref PRELUDE: Arc = { @@ -52,6 +52,7 @@ lazy_static! { // would try to access the `PRELUDE` lazy static recursively, // causing a deadlock. let string_type: Type = string_custom_type(extension_ref).into(); + let usize_type: Type = usize_custom_t(extension_ref).into(); let error_type: CustomType = error_custom_type(extension_ref); prelude @@ -74,7 +75,7 @@ lazy_static! { prelude.add_op( PRINT_OP_ID, "Print the string to standard output".to_string(), - Signature::new(vec![string_type], type_row![]), + Signature::new(vec![string_type.clone()], type_row![]), extension_ref, ) .unwrap(); @@ -96,15 +97,23 @@ lazy_static! { extension_ref, ) .unwrap(); + prelude + .add_op( + MAKE_ERROR_OP_ID, + "Create an error value".to_string(), + Signature::new(vec![usize_type, string_type], vec![error_type.clone().into()]), + extension_ref, + ) + .unwrap(); prelude .add_op( PANIC_OP_ID, "Panic with input error".to_string(), PolyFuncTypeRV::new( - [TypeParam::new_list(TypeBound::Any), TypeParam::new_list(TypeBound::Any)], + [TypeParam::new_list_type(TypeBound::Linear), TypeParam::new_list_type(TypeBound::Linear)], FuncValueType::new( - vec![TypeRV::new_extension(error_type.clone()), TypeRV::new_row_var_use(0, TypeBound::Any)], - vec![TypeRV::new_row_var_use(1, TypeBound::Any)], + vec![TypeRV::new_extension(error_type.clone()), TypeRV::new_row_var_use(0, TypeBound::Linear)], + vec![TypeRV::new_row_var_use(1, TypeBound::Linear)], ), ), extension_ref, @@ -115,10 +124,10 @@ lazy_static! { EXIT_OP_ID, "Exit with input error".to_string(), PolyFuncTypeRV::new( - [TypeParam::new_list(TypeBound::Any), TypeParam::new_list(TypeBound::Any)], + [TypeParam::new_list_type(TypeBound::Linear), TypeParam::new_list_type(TypeBound::Linear)], FuncValueType::new( - vec![TypeRV::new_extension(error_type), TypeRV::new_row_var_use(0, TypeBound::Any)], - vec![TypeRV::new_row_var_use(1, TypeBound::Any)], + vec![TypeRV::new_extension(error_type), TypeRV::new_row_var_use(0, TypeBound::Linear)], + vec![TypeRV::new_row_var_use(1, TypeBound::Linear)], ), ), extension_ref, @@ -151,7 +160,7 @@ pub(crate) fn qb_custom_t(extension_ref: &Weak) -> CustomType { TypeName::new_inline("qubit"), vec![], PRELUDE_ID, - TypeBound::Any, + TypeBound::Linear, extension_ref, ) } @@ -172,10 +181,15 @@ pub fn bool_t() -> Type { Type::new_unit_sum(2) } +/// Name of the prelude `MakeError` operation. +/// +/// This operation can be used to dynamically create error values. +pub const MAKE_ERROR_OP_ID: OpName = OpName::new_inline("MakeError"); + /// Name of the prelude panic operation. /// /// This operation can have any input and any output wires; it is instantiated -/// with two [`TypeArg::Sequence`]s representing these. The first input to the +/// with two [`TypeArg::List`]s representing these. The first input to the /// operation is always an error type; the remaining inputs correspond to the /// first sequence of types in its instantiation; the outputs correspond to the /// second sequence of types in its instantiation. Note that the inputs and @@ -189,7 +203,7 @@ pub const PANIC_OP_ID: OpName = OpName::new_inline("panic"); /// Name of the prelude exit operation. /// /// This operation can have any input and any output wires; it is instantiated -/// with two [`TypeArg::Sequence`]s representing these. 
The first input to the +/// with two [`TypeArg::List`]s representing these. The first input to the /// operation is always an error type; the remaining inputs correspond to the /// first sequence of types in its instantiation; the outputs correspond to the /// second sequence of types in its instantiation. Note that the inputs and @@ -612,10 +626,10 @@ impl MakeOpDef for TupleOpDef { } fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { - let rv = TypeRV::new_row_var_use(0, TypeBound::Any); + let rv = TypeRV::new_row_var_use(0, TypeBound::Linear); let tuple_type = TypeRV::new_tuple(vec![rv.clone()]); - let param = TypeParam::new_list(TypeBound::Any); + let param = TypeParam::new_list_type(TypeBound::Linear); match self { TupleOpDef::MakeTuple => { PolyFuncTypeRV::new([param], FuncValueType::new(rv, tuple_type)) @@ -678,13 +692,13 @@ impl MakeExtensionOp for MakeTuple { if def != TupleOpDef::MakeTuple { return Err(OpLoadError::NotMember(ext_op.unqualified_id().to_string()))?; } - let [TypeArg::Sequence { elems }] = ext_op.args() else { + let [TypeArg::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; let tys: Result, _> = elems .iter() .map(|a| match a { - TypeArg::Type { ty } => Ok(ty.clone()), + TypeArg::Runtime(ty) => Ok(ty.clone()), _ => Err(SignatureError::InvalidTypeArgs), }) .collect(); @@ -692,13 +706,7 @@ impl MakeExtensionOp for MakeTuple { } fn type_args(&self) -> Vec { - vec![TypeArg::Sequence { - elems: self - .0 - .iter() - .map(|t| TypeArg::Type { ty: t.clone() }) - .collect(), - }] + vec![Term::new_list(self.0.iter().map(|t| t.clone().into()))] } } @@ -739,27 +747,21 @@ impl MakeExtensionOp for UnpackTuple { if def != TupleOpDef::UnpackTuple { return Err(OpLoadError::NotMember(ext_op.unqualified_id().to_string()))?; } - let [TypeArg::Sequence { elems }] = ext_op.args() else { + let [Term::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; let tys: Result, _> = elems .iter() .map(|a| match a { - TypeArg::Type { ty } => Ok(ty.clone()), + Term::Runtime(ty) => Ok(ty.clone()), _ => Err(SignatureError::InvalidTypeArgs), }) .collect(); Ok(Self(tys?.into())) } - fn type_args(&self) -> Vec { - vec![TypeArg::Sequence { - elems: self - .0 - .iter() - .map(|t| TypeArg::Type { ty: t.clone() }) - .collect(), - }] + fn type_args(&self) -> Vec { + vec![Term::new_list(self.0.iter().map(|t| t.clone().into()))] } } @@ -798,8 +800,8 @@ impl MakeOpDef for NoopDef { } fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { - let tv = Type::new_var_use(0, TypeBound::Any); - PolyFuncType::new([TypeBound::Any.into()], Signature::new_endo(tv)).into() + let tv = Type::new_var_use(0, TypeBound::Linear); + PolyFuncType::new([TypeBound::Linear.into()], Signature::new_endo(tv)).into() } fn description(&self) -> String { @@ -863,14 +865,14 @@ impl MakeExtensionOp for Noop { Self: Sized, { let _def = NoopDef::from_def(ext_op.def())?; - let [TypeArg::Type { ty }] = ext_op.args() else { + let [TypeArg::Runtime(ty)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; Ok(Self(ty.clone())) } fn type_args(&self) -> Vec { - vec![TypeArg::Type { ty: self.0.clone() }] + vec![self.0.clone().into()] } } @@ -910,8 +912,8 @@ impl MakeOpDef for BarrierDef { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { PolyFuncTypeRV::new( - vec![TypeParam::new_list(TypeBound::Any)], - FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Any)), + 
vec![TypeParam::new_list_type(TypeBound::Linear)], + FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Linear)), ) .into() } @@ -969,13 +971,13 @@ impl MakeExtensionOp for Barrier { { let _def = BarrierDef::from_def(ext_op.def())?; - let [TypeArg::Sequence { elems }] = ext_op.args() else { + let [TypeArg::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; let tys: Result, _> = elems .iter() .map(|a| match a { - TypeArg::Type { ty } => Ok(ty.clone()), + TypeArg::Runtime(ty) => Ok(ty.clone()), _ => Err(SignatureError::InvalidTypeArgs), }) .collect(); @@ -985,13 +987,9 @@ impl MakeExtensionOp for Barrier { } fn type_args(&self) -> Vec { - vec![TypeArg::Sequence { - elems: self - .type_row - .iter() - .map(|t| TypeArg::Type { ty: t.clone() }) - .collect(), - }] + vec![TypeArg::new_list( + self.type_row.iter().map(|t| t.clone().into()), + )] } } @@ -1009,6 +1007,7 @@ impl MakeRegisteredOp for Barrier { mod test { use crate::builder::inout_sig; use crate::std_extensions::arithmetic::float_types::{ConstF64, float64_type}; + use crate::types::Term; use crate::{ Hugr, Wire, builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig}, @@ -1021,6 +1020,8 @@ mod test { type_row, }; + use crate::hugr::views::HugrView; + #[test] fn test_make_tuple() { let op = MakeTuple::new(type_row![Type::UNIT]); @@ -1132,9 +1133,8 @@ mod test { let err = b.add_load_value(error_val); - const TYPE_ARG_NONE: TypeArg = TypeArg::Sequence { elems: vec![] }; let op = PRELUDE - .instantiate_extension_op(&EXIT_OP_ID, [TYPE_ARG_NONE, TYPE_ARG_NONE]) + .instantiate_extension_op(&EXIT_OP_ID, [Term::new_list([]), Term::new_list([])]) .unwrap(); b.add_dataflow_op(op, [err]).unwrap(); @@ -1142,14 +1142,32 @@ mod test { b.finish_hugr_with_outputs([]).unwrap(); } + #[test] + /// test the prelude make error op with the panic op. 
+ fn test_make_error() { + let err_op = PRELUDE + .instantiate_extension_op(&MAKE_ERROR_OP_ID, []) + .unwrap(); + let panic_op = PRELUDE + .instantiate_extension_op(&EXIT_OP_ID, [Term::new_list([]), Term::new_list([])]) + .unwrap(); + + let mut b = + DFGBuilder::new(Signature::new(vec![usize_t(), string_type()], type_row![])).unwrap(); + let [signal, message] = b.input_wires_arr(); + let err_value = b.add_dataflow_op(err_op, [signal, message]).unwrap(); + b.add_dataflow_op(panic_op, err_value.outputs()).unwrap(); + + let h = b.finish_hugr_with_outputs([]).unwrap(); + h.validate().unwrap(); + } + #[test] /// test the panic operation with input and output wires fn test_panic_with_io() { let error_val = ConstError::new(42, "PANIC"); - let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() }; - let type_arg_2q: TypeArg = TypeArg::Sequence { - elems: vec![type_arg_q.clone(), type_arg_q], - }; + let type_arg_q: Term = qb_t().into(); + let type_arg_2q: Term = Term::new_list([type_arg_q.clone(), type_arg_q]); let panic_op = PRELUDE .instantiate_extension_op(&PANIC_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) .unwrap(); diff --git a/hugr-core/src/extension/prelude/generic.rs b/hugr-core/src/extension/prelude/generic.rs index 9ea231e1bb..ca00c713fd 100644 --- a/hugr-core/src/extension/prelude/generic.rs +++ b/hugr-core/src/extension/prelude/generic.rs @@ -74,7 +74,7 @@ impl MakeOpDef for LoadNatDef { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { let usize_t: Type = usize_custom_t(_extension_ref).into(); - let params = vec![TypeParam::max_nat()]; + let params = vec![TypeParam::max_nat_type()]; PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![usize_t])).into() } @@ -166,7 +166,7 @@ mod tests { extension::prelude::{ConstUsize, usize_t}, ops::{OpType, constant}, type_row, - types::TypeArg, + types::Term, }; use super::LoadNat; @@ -175,7 +175,7 @@ mod tests { fn test_load_nat() { let mut b = DFGBuilder::new(inout_sig(type_row![], vec![usize_t()])).unwrap(); - let arg = TypeArg::BoundedNat { n: 4 }; + let arg = Term::from(4u64); let op = LoadNat::new(arg); let out = b.add_dataflow_op(op.clone(), []).unwrap(); @@ -195,7 +195,7 @@ mod tests { #[test] fn test_load_nat_fold() { - let arg = TypeArg::BoundedNat { n: 5 }; + let arg = Term::from(5u64); let op = LoadNat::new(arg); let optype: OpType = op.into(); diff --git a/hugr-core/src/extension/resolution.rs b/hugr-core/src/extension/resolution.rs index 0e7bfbbab8..52f2c5dbf5 100644 --- a/hugr-core/src/extension/resolution.rs +++ b/hugr-core/src/extension/resolution.rs @@ -26,7 +26,7 @@ pub(crate) use ops::{collect_op_extension, resolve_op_extensions}; pub(crate) use types::{collect_op_types_extensions, collect_signature_exts, collect_type_exts}; pub(crate) use types_mut::resolve_op_types_extensions; use types_mut::{ - resolve_custom_type_exts, resolve_type_exts, resolve_typearg_exts, resolve_value_exts, + resolve_custom_type_exts, resolve_term_exts, resolve_type_exts, resolve_value_exts, }; use derive_more::{Display, Error, From}; @@ -63,7 +63,7 @@ pub fn resolve_typearg_extensions( extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { let mut used_extensions = WeakExtensionRegistry::default(); - resolve_typearg_exts(None, arg, extensions, &mut used_extensions) + resolve_term_exts(None, arg, extensions, &mut used_extensions) } /// Update all weak Extension pointers inside a constant value. 
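// --- Editor's hedged sketch (not part of the patch): the `Term` constructors used in
// --- the test hunks above replace the old struct-variant `TypeArg` syntax. Only
// --- conversions that appear in this diff are used; `example_terms` is a hypothetical
// --- helper shown purely to summarise the mapping.
// use crate::extension::prelude::usize_t;
// use crate::types::Term;
fn example_terms() -> Vec<Term> {
    vec![
        Term::from(4u64),    // was `TypeArg::BoundedNat { n: 4 }`
        usize_t().into(),    // was `TypeArg::Type { ty: usize_t() }`
        Term::new_list([]),  // was `TypeArg::Sequence { elems: vec![] }`
    ]
}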
diff --git a/hugr-core/src/extension/resolution/ops.rs b/hugr-core/src/extension/resolution/ops.rs index a76ba47d8c..e6727fd834 100644 --- a/hugr-core/src/extension/resolution/ops.rs +++ b/hugr-core/src/extension/resolution/ops.rs @@ -98,8 +98,8 @@ pub(crate) fn resolve_op_extensions<'e>( node, extension: opaque.extension().clone(), op: def.name().clone(), - computed: ext_op.signature().into_owned(), - stored: opaque.signature().into_owned(), + computed: Box::new(ext_op.signature().into_owned()), + stored: Box::new(opaque.signature().into_owned()), } .into()); } diff --git a/hugr-core/src/extension/resolution/test.rs b/hugr-core/src/extension/resolution/test.rs index e73dd54fbd..43c64b561d 100644 --- a/hugr-core/src/extension/resolution/test.rs +++ b/hugr-core/src/extension/resolution/test.rs @@ -25,7 +25,7 @@ use crate::std_extensions::arithmetic::int_types::{self, int_type}; use crate::std_extensions::collections::list::ListValue; use crate::std_extensions::std_reg; use crate::types::type_param::TypeParam; -use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}; +use crate::types::{PolyFuncType, Signature, Type, TypeBound}; use crate::{Extension, Hugr, HugrView, type_row}; #[rstest] @@ -333,12 +333,12 @@ fn resolve_custom_const(#[case] custom_const: impl CustomConst) { #[rstest] fn resolve_call() { let dummy_fn_sig = PolyFuncType::new( - vec![TypeParam::Type { b: TypeBound::Any }], + vec![TypeParam::RuntimeType(TypeBound::Linear)], Signature::new(vec![], vec![bool_t()]), ); - let generic_type_1 = TypeArg::Type { ty: float64_type() }; - let generic_type_2 = TypeArg::Type { ty: int_type(6) }; + let generic_type_1 = float64_type().into(); + let generic_type_2 = int_type(6).into(); let expected_exts = [ float_types::EXTENSION_ID.clone(), int_types::EXTENSION_ID.clone(), diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index 531509d6ee..0ea6bd7007 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -11,7 +11,7 @@ use crate::Node; use crate::extension::{ExtensionRegistry, ExtensionSet}; use crate::ops::{DataflowOpTrait, OpType, Value}; use crate::types::type_row::TypeRowBase; -use crate::types::{FuncTypeBase, MaybeRV, SumType, TypeArg, TypeBase, TypeEnum}; +use crate::types::{FuncTypeBase, MaybeRV, SumType, Term, TypeBase, TypeEnum}; /// Collects every extension used to define the types in an operation. 
/// @@ -38,7 +38,7 @@ pub(crate) fn collect_op_types_extensions( match op { OpType::ExtensionOp(ext) => { for arg in ext.args() { - collect_typearg_exts(arg, &mut used, &mut missing); + collect_term_exts(arg, &mut used, &mut missing); } collect_signature_exts(&ext.signature(), &mut used, &mut missing); } @@ -55,7 +55,7 @@ pub(crate) fn collect_op_types_extensions( collect_signature_exts(c.func_sig.body(), &mut used, &mut missing); collect_signature_exts(&c.instantiation, &mut used, &mut missing); for arg in &c.type_args { - collect_typearg_exts(arg, &mut used, &mut missing); + collect_term_exts(arg, &mut used, &mut missing); } } OpType::CallIndirect(c) => collect_signature_exts(&c.signature, &mut used, &mut missing), @@ -64,13 +64,13 @@ pub(crate) fn collect_op_types_extensions( collect_signature_exts(lf.func_sig.body(), &mut used, &mut missing); collect_signature_exts(&lf.instantiation, &mut used, &mut missing); for arg in &lf.type_args { - collect_typearg_exts(arg, &mut used, &mut missing); + collect_term_exts(arg, &mut used, &mut missing); } } OpType::DFG(dfg) => collect_signature_exts(&dfg.signature, &mut used, &mut missing), OpType::OpaqueOp(op) => { for arg in op.args() { - collect_typearg_exts(arg, &mut used, &mut missing); + collect_term_exts(arg, &mut used, &mut missing); } collect_signature_exts(&op.signature(), &mut used, &mut missing); } @@ -172,7 +172,7 @@ pub(crate) fn collect_type_exts( match typ.as_type_enum() { TypeEnum::Extension(custom) => { for arg in custom.args() { - collect_typearg_exts(arg, used_extensions, missing_extensions); + collect_term_exts(arg, used_extensions, missing_extensions); } let ext_ref = custom.extension_ref(); // Check if the extension reference is still valid. @@ -202,29 +202,58 @@ pub(crate) fn collect_type_exts( } } -/// Collect the Extension pointers in the [`CustomType`]s inside a type argument. +/// Collect the Extension pointers in the [`CustomType`]s inside a [`Term`]. /// /// # Attributes /// -/// - `arg`: The type argument to collect the extensions from. +/// - `term`: The term argument to collect the extensions from. /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. -pub(super) fn collect_typearg_exts( - arg: &TypeArg, +pub(super) fn collect_term_exts( + term: &Term, used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { - match arg { - TypeArg::Type { ty } => collect_type_exts(ty, used_extensions, missing_extensions), - TypeArg::Sequence { elems } => { - for elem in elems { - collect_typearg_exts(elem, used_extensions, missing_extensions); + match term { + Term::Runtime(ty) => collect_type_exts(ty, used_extensions, missing_extensions), + Term::List(elems) => { + for elem in elems.iter() { + collect_term_exts(elem, used_extensions, missing_extensions); } } - // We ignore the `TypeArg::Extension` case, as it is not required to - // **define** the hugr. 
- _ => {} + Term::Tuple(elems) => { + for elem in elems.iter() { + collect_term_exts(elem, used_extensions, missing_extensions); + } + } + Term::ListType(item_type) => { + collect_term_exts(item_type, used_extensions, missing_extensions) + } + Term::TupleType(item_types) => { + collect_term_exts(item_types, used_extensions, missing_extensions) + } + Term::ListConcat(lists) => { + for list in lists { + collect_term_exts(list, used_extensions, missing_extensions); + } + } + Term::TupleConcat(tuples) => { + for tuple in tuples { + collect_term_exts(tuple, used_extensions, missing_extensions); + } + } + Term::Variable(_) + | Term::RuntimeType(_) + | Term::StaticType + | Term::BoundedNatType(_) + | Term::StringType + | Term::BytesType + | Term::FloatType + | Term::BoundedNat(_) + | Term::String(_) + | Term::Bytes(_) + | Term::Float(_) => {} } } diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index c4093a18c2..8135ca0b1b 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -10,7 +10,7 @@ use super::{ExtensionResolutionError, WeakExtensionRegistry}; use crate::extension::ExtensionSet; use crate::ops::{OpType, Value}; use crate::types::type_row::TypeRowBase; -use crate::types::{CustomType, FuncTypeBase, MaybeRV, SumType, TypeArg, TypeBase, TypeEnum}; +use crate::types::{CustomType, FuncTypeBase, MaybeRV, SumType, Term, TypeBase, TypeEnum}; use crate::{Extension, Node}; /// Replace the dangling extension pointer in the [`CustomType`]s inside an @@ -30,7 +30,7 @@ pub fn resolve_op_types_extensions( match op { OpType::ExtensionOp(ext) => { for arg in ext.args_mut() { - resolve_typearg_exts(node, arg, extensions, used_extensions)?; + resolve_term_exts(node, arg, extensions, used_extensions)?; } resolve_signature_exts(node, ext.signature_mut(), extensions, used_extensions)?; } @@ -61,7 +61,7 @@ pub fn resolve_op_types_extensions( resolve_signature_exts(node, c.func_sig.body_mut(), extensions, used_extensions)?; resolve_signature_exts(node, &mut c.instantiation, extensions, used_extensions)?; for arg in &mut c.type_args { - resolve_typearg_exts(node, arg, extensions, used_extensions)?; + resolve_term_exts(node, arg, extensions, used_extensions)?; } } OpType::CallIndirect(c) => { @@ -74,7 +74,7 @@ pub fn resolve_op_types_extensions( resolve_signature_exts(node, lf.func_sig.body_mut(), extensions, used_extensions)?; resolve_signature_exts(node, &mut lf.instantiation, extensions, used_extensions)?; for arg in &mut lf.type_args { - resolve_typearg_exts(node, arg, extensions, used_extensions)?; + resolve_term_exts(node, arg, extensions, used_extensions)?; } } OpType::DFG(dfg) => { @@ -82,7 +82,7 @@ pub fn resolve_op_types_extensions( } OpType::OpaqueOp(op) => { for arg in op.args_mut() { - resolve_typearg_exts(node, arg, extensions, used_extensions)?; + resolve_term_exts(node, arg, extensions, used_extensions)?; } resolve_signature_exts(node, op.signature_mut(), extensions, used_extensions)?; } @@ -195,7 +195,7 @@ pub(super) fn resolve_custom_type_exts( used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { for arg in custom.args_mut() { - resolve_typearg_exts(node, arg, extensions, used_extensions)?; + resolve_term_exts(node, arg, extensions, used_extensions)?; } let ext_id = custom.extension(); @@ -211,23 +211,42 @@ pub(super) fn resolve_custom_type_exts( Ok(()) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a type arg. 
+/// Update all weak Extension pointers in the [`CustomType`]s inside a [`Term`]. /// /// Adds the extensions used in the type to the `used_extensions` registry. -pub(super) fn resolve_typearg_exts( +pub(super) fn resolve_term_exts( node: Option, - arg: &mut TypeArg, + term: &mut Term, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - match arg { - TypeArg::Type { ty } => resolve_type_exts(node, ty, extensions, used_extensions)?, - TypeArg::Sequence { elems } => { - for elem in elems.iter_mut() { - resolve_typearg_exts(node, elem, extensions, used_extensions)?; + match term { + Term::Runtime(ty) => resolve_type_exts(node, ty, extensions, used_extensions)?, + Term::List(children) + | Term::ListConcat(children) + | Term::Tuple(children) + | Term::TupleConcat(children) => { + for child in children.iter_mut() { + resolve_term_exts(node, child, extensions, used_extensions)?; } } - _ => {} + Term::ListType(item_type) => { + resolve_term_exts(node, item_type.as_mut(), extensions, used_extensions)?; + } + Term::TupleType(item_types) => { + resolve_term_exts(node, item_types.as_mut(), extensions, used_extensions)?; + } + Term::Variable(_) + | Term::RuntimeType(_) + | Term::StaticType + | Term::BoundedNatType(_) + | Term::StringType + | Term::BytesType + | Term::FloatType + | Term::BoundedNat(_) + | Term::String(_) + | Term::Bytes(_) + | Term::Float(_) => {} } Ok(()) } diff --git a/hugr-core/src/extension/simple_op.rs b/hugr-core/src/extension/simple_op.rs index bf013ba5dc..8685b63325 100644 --- a/hugr-core/src/extension/simple_op.rs +++ b/hugr-core/src/extension/simple_op.rs @@ -308,7 +308,10 @@ impl From for OpType { mod test { use std::sync::Arc; - use crate::{const_extension_ids, type_row, types::Signature}; + use crate::{ + const_extension_ids, type_row, + types::{Signature, Term}, + }; use super::*; use lazy_static::lazy_static; @@ -393,7 +396,7 @@ mod test { assert_eq!(o.instantiate(&[]), Ok(o.clone())); assert_eq!( - o.instantiate(&[TypeArg::BoundedNat { n: 1 }]), + o.instantiate(&[Term::from(1u64)]), Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs)) ); } diff --git a/hugr-core/src/extension/type_def.rs b/hugr-core/src/extension/type_def.rs index fceb336b2f..b848c7528f 100644 --- a/hugr-core/src/extension/type_def.rs +++ b/hugr-core/src/extension/type_def.rs @@ -6,7 +6,7 @@ use super::{Extension, ExtensionId, SignatureError}; use crate::types::{CustomType, TypeName, least_upper_bound}; -use crate::types::type_param::{TypeArg, check_type_args}; +use crate::types::type_param::{TypeArg, check_term_types}; use crate::types::type_param::TypeParam; @@ -34,7 +34,7 @@ impl TypeDefBound { #[must_use] pub fn any() -> Self { TypeDefBound::Explicit { - bound: TypeBound::Any, + bound: TypeBound::Linear, } } @@ -79,7 +79,7 @@ pub struct TypeDef { impl TypeDef { /// Check provided type arguments are valid against parameters. pub fn check_args(&self, args: &[TypeArg]) -> Result<(), SignatureError> { - check_type_args(args, &self.params).map_err(SignatureError::TypeArgMismatch) + check_term_types(args, &self.params).map_err(SignatureError::TypeArgMismatch) } /// Check [`CustomType`] is a valid instantiation of this definition. 
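// --- Editor's hedged sketch (not part of the patch): `check_term_types` is used as a
// --- drop-in rename of the old `check_type_args`. It matches a slice of argument
// --- `Term`s against the declared `TypeParam`s and yields a `TermTypeError` on
// --- mismatch, which converts into `SignatureError::TypeArgMismatch` as in the hunks
// --- above. `check_my_args` is a hypothetical wrapper; paths assumed from this file.
// use crate::extension::SignatureError;
// use crate::types::type_param::{TypeArg, TypeParam, check_term_types};
fn check_my_args(args: &[TypeArg], params: &[TypeParam]) -> Result<(), SignatureError> {
    check_term_types(args, params)?;
    Ok(())
}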
@@ -102,7 +102,7 @@ impl TypeDef { )); } - check_type_args(custom.type_args(), &self.params)?; + check_term_types(custom.type_args(), &self.params)?; let calc_bound = self.bound(custom.args()); if calc_bound == custom.bound() { @@ -123,7 +123,7 @@ impl TypeDef { /// valid instances of the type parameters. pub fn instantiate(&self, args: impl Into>) -> Result { let args = args.into(); - check_type_args(&args, &self.params)?; + check_term_types(&args, &self.params)?; let bound = self.bound(&args); Ok(CustomType::new( self.name().clone(), @@ -142,12 +142,12 @@ impl TypeDef { let args: Vec<_> = args.iter().collect(); if indices.is_empty() { // Assume most general case - return TypeBound::Any; + return TypeBound::Linear; } least_upper_bound(indices.iter().map(|i| { let ta = args.get(*i); match ta { - Some(TypeArg::Type { ty: s }) => s.least_upper_bound(), + Some(TypeArg::Runtime(s)) => s.least_upper_bound(), _ => panic!("TypeArg index does not refer to a type."), } })) @@ -241,7 +241,7 @@ mod test { use crate::extension::SignatureError; use crate::extension::prelude::{qb_t, usize_t}; use crate::std_extensions::arithmetic::float_types::float64_type; - use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; + use crate::types::type_param::{TermTypeError, TypeParam}; use crate::types::{Signature, Type, TypeBound}; use super::{TypeDef, TypeDefBound}; @@ -250,9 +250,7 @@ mod test { fn test_instantiate_typedef() { let def = TypeDef { name: "MyType".into(), - params: vec![TypeParam::Type { - b: TypeBound::Copyable, - }], + params: vec![TypeParam::RuntimeType(TypeBound::Copyable)], extension: "MyRsrc".try_into().unwrap(), // Dummy extension. Will return `None` when trying to upgrade it into an `Arc`. extension_ref: Default::default(), @@ -260,9 +258,9 @@ mod test { bound: TypeDefBound::FromParams { indices: vec![0] }, }; let typ = Type::new_extension( - def.instantiate(vec![TypeArg::Type { - ty: Type::new_function(Signature::new(vec![], vec![])), - }]) + def.instantiate(vec![ + Type::new_function(Signature::new(vec![], vec![])).into(), + ]) .unwrap(), ); assert_eq!(typ.least_upper_bound(), TypeBound::Copyable); @@ -271,27 +269,24 @@ mod test { // And some bad arguments...firstly, wrong kind of TypeArg: assert_eq!( - def.instantiate([TypeArg::Type { ty: qb_t() }]), + def.instantiate([qb_t().into()]), Err(SignatureError::TypeArgMismatch( - TypeArgError::TypeMismatch { - arg: TypeArg::Type { ty: qb_t() }, - param: TypeBound::Copyable.into() + TermTypeError::TypeMismatch { + term: Box::new(qb_t().into()), + type_: Box::new(TypeBound::Copyable.into()) } )) ); // Too few arguments: assert_eq!( def.instantiate([]).unwrap_err(), - SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(0, 1)) + SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(0, 1)) ); // Too many arguments: assert_eq!( - def.instantiate([ - TypeArg::Type { ty: float64_type() }, - TypeArg::Type { ty: float64_type() }, - ]) - .unwrap_err(), - SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(2, 1)) + def.instantiate([float64_type().into(), float64_type().into(),]) + .unwrap_err(), + SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(2, 1)) ); } } diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index f9398b0cf8..78bdd88390 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -5,7 +5,6 @@ pub mod hugrmut; pub(crate) mod ident; pub mod internal; pub mod patch; -pub mod persistent; pub mod serialize; pub mod validate; pub mod views; @@ -42,7 +41,7 @@ use 
crate::{Direction, Node}; #[derive(Clone, Debug, PartialEq)] pub struct Hugr { /// The graph encoding the adjacency structure of the HUGR. - graph: MultiPortGraph, + graph: MultiPortGraph, /// The node hierarchy. hierarchy: Hierarchy, @@ -554,8 +553,8 @@ pub(crate) mod test { use crate::extension::prelude::bool_t; use crate::ops::OpaqueOp; use crate::ops::handle::NodeHandle; - use crate::test_file; use crate::types::Signature; + use crate::{Visibility, test_file}; use cool_asserts::assert_matches; use itertools::Either; use portgraph::LinkView; @@ -675,6 +674,26 @@ pub(crate) mod test { assert_matches!(&hugr, Ok(_)); } + #[test] + #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri + fn load_funcs_no_visibility() { + let hugr = Hugr::load( + BufReader::new(File::open(test_file!("hugr-no-visibility.hugr")).unwrap()), + None, + ) + .unwrap(); + + let [_mod, decl, defn] = hugr.nodes().take(3).collect_array().unwrap(); + assert_eq!( + hugr.get_optype(decl).as_func_decl().unwrap().visibility(), + &Visibility::Public + ); + assert_eq!( + hugr.get_optype(defn).as_func_defn().unwrap().visibility(), + &Visibility::Private + ); + } + fn hugr_failing_2262() -> Hugr { let sig = Signature::new(vec![bool_t(); 2], bool_t()); let mut mb = ModuleBuilder::new(); diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 0265acd59f..74a6d1461d 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -183,7 +183,23 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the root node is not in the graph. - fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult; + fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult { + let region = other.entrypoint(); + Self::insert_region(self, root, other, region) + } + + /// Insert a sub-region of another hugr into this one, under a given parent node. + /// + /// # Panics + /// + /// - If the root node is not in the graph. + /// - If the `region` node is not in `other`. + fn insert_region( + &mut self, + root: Self::Node, + other: Hugr, + region: Node, + ) -> InsertionResult; /// Copy another hugr into this one, under a given parent node. /// @@ -247,15 +263,17 @@ pub trait HugrMut: HugrMutInternals { ExtensionRegistry: Extend; } -/// Records the result of inserting a Hugr or view -/// via [`HugrMut::insert_hugr`] or [`HugrMut::insert_from_view`]. +/// Records the result of inserting a Hugr or view via [`HugrMut::insert_hugr`], +/// [`HugrMut::insert_from_view`], or [`HugrMut::insert_region`]. /// -/// Contains a map from the nodes in the source HUGR to the nodes in the -/// target HUGR, using their respective `Node` types. +/// Contains a map from the nodes in the source HUGR to the nodes in the target +/// HUGR, using their respective `Node` types. pub struct InsertionResult { - /// The node, after insertion, that was the entrypoint of the inserted Hugr. + /// The node, after insertion, that was the root of the inserted Hugr. /// - /// That is, the value in [`InsertionResult::node_map`] under the key that was the [`HugrView::entrypoint`]. + /// That is, the value in [`InsertionResult::node_map`] under the key that + /// was the the `region` passed to [`HugrMut::insert_region`] or the + /// [`HugrView::entrypoint`] in the other cases. pub inserted_entrypoint: TargetN, /// Map from nodes in the Hugr/view that was inserted, to their new /// positions in the Hugr into which said was inserted. 
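Editor's note: a hypothetical usage sketch (not part of this patch) for the `insert_region` / `InsertionResult` API documented above. `host`, `donor`, `parent`, and `region` are assumed to be pre-existing values; only calls and fields visible in this diff are used, and the import paths are assumptions.

use hugr_core::Node;
use hugr_core::hugr::{Hugr, HugrMut};

/// Insert the subtree of `donor` rooted at `region` under `parent` in `host`,
/// returning the copy of `region` in `host`. (Sketch only.)
fn splice_region(host: &mut Hugr, parent: Node, donor: Hugr, region: Node) -> Node {
    let result = host.insert_region(parent, donor, region);
    // `node_map` relates every inserted donor node to its new node in `host`;
    // `inserted_entrypoint` is the entry for `region` itself.
    debug_assert_eq!(result.node_map[&region], result.inserted_entrypoint);
    result.inserted_entrypoint
}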
@@ -394,17 +412,14 @@ impl HugrMut for Hugr { (src_port, dst_port) } - fn insert_hugr( + fn insert_region( &mut self, root: Self::Node, mut other: Hugr, + region: Node, ) -> InsertionResult { - let node_map = insert_hugr_internal(self, &other, other.entry_descendants(), |&n| { - if n == other.entrypoint() { - Some(root) - } else { - None - } + let node_map = insert_hugr_internal(self, &other, other.descendants(region), |&n| { + if n == region { Some(root) } else { None } }); // Merge the extension sets. self.extensions.extend(other.extensions()); @@ -420,7 +435,7 @@ impl HugrMut for Hugr { self.metadata.set(new_node_pg, meta); } InsertionResult { - inserted_entrypoint: node_map[&other.entrypoint()], + inserted_entrypoint: node_map[®ion], node_map, } } diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 523ddcb1b1..bb1a77f423 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -21,7 +21,7 @@ use crate::ops::handle::NodeHandle; /// view. pub trait HugrInternals { /// The portgraph graph structure returned by [`HugrInternals::region_portgraph`]. - type RegionPortgraph<'p>: LinkView + Clone + 'p + type RegionPortgraph<'p>: LinkView + Clone + 'p where Self: 'p; @@ -109,7 +109,7 @@ impl PortgraphNodeMap for std::collections::HashMap { impl HugrInternals for Hugr { type RegionPortgraph<'p> - = &'p MultiPortGraph + = &'p MultiPortGraph where Self: 'p; @@ -390,6 +390,22 @@ impl HugrMutInternals for Hugr { } } +impl Hugr { + /// Consumes the HUGR and return a flat portgraph view of the region rooted + /// at `parent`. + #[inline] + pub fn into_region_portgraph( + self, + parent: Node, + ) -> portgraph::view::FlatRegion<'static, MultiPortGraph> { + let root = parent.into_portgraph(); + let Self { + graph, hierarchy, .. + } = self; + portgraph::view::FlatRegion::new_without_root(graph, hierarchy, root) + } +} + #[cfg(test)] mod test { use crate::{ diff --git a/hugr-core/src/hugr/patch/inline_call.rs b/hugr-core/src/hugr/patch/inline_call.rs index 40eac06e0e..a4d8383847 100644 --- a/hugr-core/src/hugr/patch/inline_call.rs +++ b/hugr-core/src/hugr/patch/inline_call.rs @@ -291,29 +291,32 @@ mod test { fn test_polymorphic() -> Result<(), Box> { let tuple_ty = Type::new_tuple(vec![usize_t(); 2]); let mut fb = FunctionBuilder::new("mkpair", Signature::new(usize_t(), tuple_ty.clone()))?; - let inner = fb.define_function( - "id", - PolyFuncType::new( - [TypeBound::Copyable.into()], - Signature::new_endo(Type::new_var_use(0, TypeBound::Copyable)), - ), - )?; - let inps = inner.input_wires(); - let inner = inner.finish_with_outputs(inps)?; - let call1 = fb.call(inner.handle(), &[usize_t().into()], fb.input_wires())?; + let helper = { + let mut mb = fb.module_root_builder(); + let fb2 = mb.define_function( + "id", + PolyFuncType::new( + [TypeBound::Copyable.into()], + Signature::new_endo(Type::new_var_use(0, TypeBound::Copyable)), + ), + )?; + let inps = fb2.input_wires(); + fb2.finish_with_outputs(inps)? 
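// Editor's sketch (not part of this patch): the test above uses the
// `module_root_builder()` pattern, which lets a live `FunctionBuilder` add
// sibling module-level functions without building the module separately.
// A standalone, hypothetical version of that pattern follows; the imports,
// paths, and error handling are assumptions based on the builder calls
// visible elsewhere in this diff.
use hugr_core::builder::{
    Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder,
};
use hugr_core::extension::prelude::usize_t;
use hugr_core::hugr::Hugr;
use hugr_core::types::Signature;

fn build_with_declared_helper() -> Result<Hugr, Box<dyn std::error::Error>> {
    let mut fb = FunctionBuilder::new("main", Signature::new_endo(usize_t()))?;
    // Declare a module-level helper without finishing (or consuming) `fb`.
    let helper = fb
        .module_root_builder()
        .declare("helper", Signature::new_endo(usize_t()).into())?;
    let [w] = fb.input_wires_arr();
    // Call the declared helper; no type arguments since it is monomorphic.
    let call = fb.call(&helper, &[], [w])?;
    Ok(fb.finish_hugr_with_outputs(call.outputs())?)
}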
+ }; + let call1 = fb.call(helper.handle(), &[usize_t().into()], fb.input_wires())?; let [call1_out] = call1.outputs_arr(); let tup = fb.make_tuple([call1_out, call1_out])?; - let call2 = fb.call(inner.handle(), &[tuple_ty.into()], [tup])?; + let call2 = fb.call(helper.handle(), &[tuple_ty.into()], [tup])?; let mut hugr = fb.finish_hugr_with_outputs(call2.outputs()).unwrap(); assert_eq!( - hugr.output_neighbours(inner.node()).collect::>(), + hugr.output_neighbours(helper.node()).collect::>(), [call1.node(), call2.node()] ); hugr.apply_patch(InlineCall::new(call1.node()))?; assert_eq!( - hugr.output_neighbours(inner.node()).collect::>(), + hugr.output_neighbours(helper.node()).collect::>(), [call2.node()] ); assert!(hugr.get_optype(call1.node()).is_dfg()); diff --git a/hugr-core/src/hugr/patch/outline_cfg.rs b/hugr-core/src/hugr/patch/outline_cfg.rs index e0d5a27850..081ba24ea1 100644 --- a/hugr-core/src/hugr/patch/outline_cfg.rs +++ b/hugr-core/src/hugr/patch/outline_cfg.rs @@ -48,7 +48,7 @@ impl OutlineCfg { }; let o = h.get_optype(cfg_n); let OpType::CFG(_) = o else { - return Err(OutlineCfgError::ParentNotCfg(cfg_n, o.clone())); + return Err(OutlineCfgError::ParentNotCfg(cfg_n, Box::new(o.clone()))); }; let cfg_entry = h.children(cfg_n).next().unwrap(); let mut entry = None; @@ -215,7 +215,7 @@ pub enum OutlineCfgError { NotSiblings, /// The parent node was not a CFG node #[error("The parent node {0} was not a CFG but a {1}")] - ParentNotCfg(Node, OpType), + ParentNotCfg(Node, Box), /// Multiple blocks had incoming edges #[error("Multiple blocks had predecessors outside the set - at least {0} and {1}")] MultipleEntryNodes(Node, Node), diff --git a/hugr-core/src/hugr/patch/peel_loop.rs b/hugr-core/src/hugr/patch/peel_loop.rs index 9cf61290b6..ccb9218283 100644 --- a/hugr-core/src/hugr/patch/peel_loop.rs +++ b/hugr-core/src/hugr/patch/peel_loop.rs @@ -135,7 +135,7 @@ mod test { use crate::builder::test::simple_dfg_hugr; use crate::builder::{ - Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, + Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, }; use crate::extension::prelude::{bool_t, usize_t}; use crate::ops::{OpTag, OpTrait, Tag, TailLoop, handle::NodeHandle}; @@ -165,8 +165,13 @@ mod test { #[test] fn peel_loop_incoming_edges() { let i32_t = || INT_TYPES[5].clone(); - let mut mb = crate::builder::ModuleBuilder::new(); - let helper = mb + let mut fb = FunctionBuilder::new( + "main", + Signature::new(vec![bool_t(), usize_t(), i32_t()], usize_t()), + ) + .unwrap(); + let helper = fb + .module_root_builder() .declare( "helper", Signature::new( @@ -176,12 +181,6 @@ mod test { .into(), ) .unwrap(); - let mut fb = mb - .define_function( - "main", - Signature::new(vec![bool_t(), usize_t(), i32_t()], usize_t()), - ) - .unwrap(); let [b, u, i] = fb.input_wires_arr(); let (tl, call) = { let mut tlb = fb @@ -197,8 +196,7 @@ mod test { let [pred, other] = c.outputs_arr(); (tlb.finish_with_outputs(pred, [other]).unwrap(), c.node()) }; - let _ = fb.finish_with_outputs(tl.outputs()).unwrap(); - let mut h = mb.finish_hugr().unwrap(); + let mut h = fb.finish_hugr_with_outputs(tl.outputs()).unwrap(); h.apply_patch(PeelTailLoop::new(tl.node())).unwrap(); h.validate().unwrap(); diff --git a/hugr-core/src/hugr/patch/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs index 8fb22febda..46e5bde205 100644 --- a/hugr-core/src/hugr/patch/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -58,12 +58,12 @@ impl 
SimpleReplacement { .inner_function_type() .ok_or(InvalidReplacement::InvalidDataflowGraph { node: replacement.entrypoint(), - op: replacement.get_optype(replacement.entrypoint()).to_owned(), + op: Box::new(replacement.get_optype(replacement.entrypoint()).to_owned()), })?; if subgraph_sig != repl_sig { return Err(InvalidReplacement::InvalidSignature { - expected: subgraph_sig, - actual: Some(repl_sig.into_owned()), + expected: Box::new(subgraph_sig), + actual: Some(Box::new(repl_sig.into_owned())), }); } Ok(Self { @@ -126,11 +126,16 @@ impl SimpleReplacement { /// of `self`. /// /// The returned port will be in `replacement`, unless the wire in the - /// replacement is empty, in which case it will another `host` port. + /// replacement is empty and `boundary` is [`BoundaryMode::SnapToHost`] (the + /// default), in which case it will be another `host` port. If + /// [`BoundaryMode::IncludeIO`] is passed, the returned port will always + /// be in `replacement` even if it is invalid (i.e. it is an IO node in + /// the replacement). pub fn linked_replacement_output( &self, port: impl Into>, host: &impl HugrView, + boundary: BoundaryMode, ) -> Option> { let HostPort(node, port) = port.into(); let pos = self @@ -139,7 +144,7 @@ impl SimpleReplacement { .iter() .position(move |&(n, p)| host.linked_inputs(n, p).contains(&(node, port)))?; - Some(self.linked_replacement_output_by_position(pos, host)) + Some(self.linked_replacement_output_by_position(pos, host, boundary)) } /// The outgoing port linked to the i-th output boundary edge of `subgraph`. @@ -150,6 +155,7 @@ impl SimpleReplacement { &self, pos: usize, host: &impl HugrView, + boundary: BoundaryMode, ) -> BoundaryPort { debug_assert!(pos < self.subgraph().signature(host).output_count()); @@ -160,7 +166,7 @@ impl SimpleReplacement { .single_linked_output(repl_out, pos) .expect("valid dfg wire"); - if out_node != repl_inp { + if out_node != repl_inp || boundary == BoundaryMode::IncludeIO { BoundaryPort::Replacement(out_node, out_port) } else { let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()] @@ -207,11 +213,16 @@ impl SimpleReplacement { /// of `self`. /// /// The returned ports will be in `replacement`, unless the wires in the - /// replacement are empty, in which case they are other `host` ports. + /// replacement are empty and `boundary` is [`BoundaryMode::SnapToHost`] + /// (the default), in which case they will be other `host` ports. If + /// [`BoundaryMode::IncludeIO`] is passed, the returned ports will + /// always be in `replacement` even if they are invalid (i.e. they are + /// an IO node in the replacement). pub fn linked_replacement_inputs<'a>( &'a self, port: impl Into>, host: &'a impl HugrView, + boundary: BoundaryMode, ) -> impl Iterator> + 'a { let HostPort(node, port) = port.into(); let positions = self @@ -223,18 +234,16 @@ impl SimpleReplacement { host.single_linked_output(n, p).expect("valid dfg wire") == (node, port) }); - positions.flat_map(|pos| self.linked_replacement_inputs_by_position(pos, host)) + positions + .flat_map(move |pos| self.linked_replacement_inputs_by_position(pos, host, boundary)) } /// The incoming ports linked to the i-th input boundary edge of `subgraph`. - /// - /// The ports will be in `replacement` for all endpoints of the i-th input - /// wire that are not the output node of `replacement` and be in `host` - /// otherwise. 
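Editor's note: a hypothetical sketch (not part of this patch) of the now-public `map_host_nodes` shown above, porting a replacement from one host to another via a node mapping. Only the signature visible in this diff is relied on; the import paths, in particular the location of `InvalidReplacement`, are assumptions.

use std::collections::HashMap;

use hugr_core::Node;
use hugr_core::hugr::Hugr;
use hugr_core::hugr::patch::simple_replace::SimpleReplacement;
use hugr_core::hugr::views::sibling_subgraph::InvalidReplacement;

/// Port `repl`, computed against some old host, onto `new_host`, using
/// `old_to_new` to translate its subgraph nodes. (Sketch only.)
fn port_replacement(
    repl: &SimpleReplacement,
    old_to_new: &HashMap<Node, Node>,
    new_host: &Hugr,
) -> Result<SimpleReplacement, InvalidReplacement> {
    // Unlike the previous crate-private version, this re-validates the mapped
    // subgraph against `new_host` and reports failures via `InvalidReplacement`.
    repl.map_host_nodes(|n| old_to_new[&n], new_host)
}

As a design note grounded in the doc changes above: the `BoundaryMode` parameter added to the `linked_replacement_*` methods defaults to `SnapToHost`, which reproduces the previous behaviour of resolving empty replacement wires to host ports, while `IncludeIO` keeps the returned ports in the replacement even when they refer to its IO nodes.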
fn linked_replacement_inputs_by_position( &self, pos: usize, host: &impl HugrView, + boundary: BoundaryMode, ) -> impl Iterator> { debug_assert!(pos < self.subgraph().signature(host).input_count()); @@ -242,7 +251,7 @@ impl SimpleReplacement { self.replacement .linked_inputs(repl_inp, pos) .flat_map(move |(in_node, in_port)| { - if in_node != repl_out { + if in_node != repl_out || boundary == BoundaryMode::IncludeIO { Either::Left(std::iter::once(BoundaryPort::Replacement(in_node, in_port))) } else { let (out_node, out_port) = self.subgraph.outgoing_ports()[in_port.index()]; @@ -316,7 +325,7 @@ impl SimpleReplacement { subgraph_outgoing_ports .enumerate() .flat_map(|(pos, subg_np)| { - self.linked_replacement_inputs_by_position(pos, host) + self.linked_replacement_inputs_by_position(pos, host, BoundaryMode::SnapToHost) .filter_map(move |np| Some((np.as_replacement()?, subg_np))) }) .map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| { @@ -359,7 +368,7 @@ impl SimpleReplacement { .enumerate() .filter_map(|(pos, subg_all)| { let np = self - .linked_replacement_output_by_position(pos, host) + .linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost) .as_replacement()?; Some((np, subg_all)) }) @@ -406,7 +415,7 @@ impl SimpleReplacement { .enumerate() .filter_map(|(pos, subg_all)| { Some(( - self.linked_replacement_output_by_position(pos, host) + self.linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost) .as_host()?, subg_all, )) @@ -500,27 +509,25 @@ impl SimpleReplacement { /// Map the host nodes in `self` according to `node_map`. /// /// `node_map` must map nodes in the current HUGR of the subgraph to - /// its equivalent nodes in some `new_hugr`. + /// its equivalent nodes in some `new_host`. /// /// This converts a replacement that acts on nodes of type `HostNode` to - /// a replacement that acts on `new_hugr`, with nodes of type `N`. - /// - /// This does not check convexity. It is up to the caller to ensure that - /// the mapped replacement obtained from this applies on a convex subgraph - /// of the new HUGR. - pub(crate) fn map_host_nodes( + /// a replacement that acts on `new_host`, with nodes of type `N`. + pub fn map_host_nodes( &self, node_map: impl Fn(HostNode) -> N, - ) -> SimpleReplacement { + new_host: &impl HugrView, + ) -> Result, InvalidReplacement> { let Self { subgraph, replacement, } = self; let subgraph = subgraph.map_nodes(node_map); - SimpleReplacement::new_unchecked(subgraph, replacement.clone()) + SimpleReplacement::try_new(subgraph, new_host, replacement.clone()) } - /// Allows to get the [Self::invalidated_nodes] without requiring a [HugrView]. + /// Allows to get the [Self::invalidated_nodes] without requiring a + /// [HugrView]. pub fn invalidation_set(&self) -> impl Iterator { self.subgraph.nodes().iter().copied() } @@ -543,6 +550,24 @@ impl PatchVerification for SimpleReplacement { } } +/// In [`SimpleReplacement::replacement`], IO nodes marking the boundary will +/// not be valid nodes in the host after the replacement is applied. +/// +/// This enum allows specifying whether these invalid nodes on the boundary +/// should be returned or should be resolved to valid nodes in the host. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub enum BoundaryMode { + /// Only consider nodes that are valid after the replacement is applied. + /// + /// This means that nodes in hosts may be returned in places where nodes in + /// the replacement would be typically expected. 
+ #[default] + SnapToHost, + /// Include all nodes, including potentially invalid ones (inputs and + /// outputs of replacements). + IncludeIO, +} + /// Result of applying a [`SimpleReplacement`]. pub struct Outcome { /// Map from Node in replacement to corresponding Node in the result Hugr @@ -651,11 +676,11 @@ pub(in crate::hugr::patch) mod test { use crate::builder::test::n_identity; use crate::builder::{ - BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, ModuleBuilder, endo_sig, inout_sig, + BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, + ModuleBuilder, endo_sig, inout_sig, }; use crate::extension::prelude::{bool_t, qb_t}; - use crate::hugr::patch::simple_replace::Outcome; + use crate::hugr::patch::simple_replace::{BoundaryMode, Outcome}; use crate::hugr::patch::{BoundaryPort, HostPort, PatchVerification, ReplacementPort}; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Patch}; @@ -1148,7 +1173,11 @@ pub(in crate::hugr::patch) mod test { // Test linked_replacement_inputs with empty replacement let replacement_inputs: Vec<_> = repl - .linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr) + .linked_replacement_inputs( + (inp, OutgoingPort::from(0)), + &hugr, + BoundaryMode::SnapToHost, + ) .collect(); assert_eq!( @@ -1161,8 +1190,12 @@ pub(in crate::hugr::patch) mod test { // Test linked_replacement_output with empty replacement let replacement_output = (0..4) .map(|i| { - repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr) - .unwrap() + repl.linked_replacement_output( + (out, IncomingPort::from(i)), + &hugr, + BoundaryMode::SnapToHost, + ) + .unwrap() }) .collect_vec(); @@ -1194,7 +1227,11 @@ pub(in crate::hugr::patch) mod test { }; let replacement_inputs: Vec<_> = repl - .linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr) + .linked_replacement_inputs( + (inp, OutgoingPort::from(0)), + &hugr, + BoundaryMode::SnapToHost, + ) .collect(); assert_eq!( @@ -1206,8 +1243,12 @@ pub(in crate::hugr::patch) mod test { let replacement_output = (0..4) .map(|i| { - repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr) - .unwrap() + repl.linked_replacement_output( + (out, IncomingPort::from(i)), + &hugr, + BoundaryMode::SnapToHost, + ) + .unwrap() }) .collect_vec(); @@ -1244,7 +1285,11 @@ pub(in crate::hugr::patch) mod test { }; let replacement_inputs: Vec<_> = repl - .linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr) + .linked_replacement_inputs( + (inp, OutgoingPort::from(0)), + &hugr, + BoundaryMode::SnapToHost, + ) .collect(); assert_eq!( @@ -1260,8 +1305,12 @@ pub(in crate::hugr::patch) mod test { let replacement_output = (0..4) .map(|i| { - repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr) - .unwrap() + repl.linked_replacement_output( + (out, IncomingPort::from(i)), + &hugr, + BoundaryMode::SnapToHost, + ) + .unwrap() }) .collect_vec(); diff --git a/hugr-core/src/hugr/persistent/resolver.rs b/hugr-core/src/hugr/persistent/resolver.rs deleted file mode 100644 index 0a0d140ee5..0000000000 --- a/hugr-core/src/hugr/persistent/resolver.rs +++ /dev/null @@ -1,43 +0,0 @@ -use relrc::EquivalenceResolver; - -/// A resolver that considers two nodes equivalent if they are the same pointer. -/// -/// Resolvers determine when two patches are equivalent and should be merged -/// in the patch history. 
-/// -/// This is a trivial resolver (to be expanded on later), that considers two -/// patches equivalent if they point to the same data in memory. -#[derive(Clone, Debug, Default, PartialEq, Eq)] -pub struct PointerEqResolver; - -impl EquivalenceResolver for PointerEqResolver { - type MergeMapping = (); - - type DedupKey = *const N; - - fn id(&self) -> String { - "PointerEqResolver".to_string() - } - - fn dedup_key(&self, value: &N, _incoming_edges: &[&E]) -> Self::DedupKey { - value as *const N - } - - fn try_merge_mapping( - &self, - a_value: &N, - _a_incoming_edges: &[&E], - b_value: &N, - _b_incoming_edges: &[&E], - ) -> Result { - if std::ptr::eq(a_value, b_value) { - Ok(()) - } else { - Err(relrc::resolver::NotEquivalent) - } - } - - fn move_edge_source(&self, _mapping: &Self::MergeMapping, edge: &E) -> E { - edge.clone() - } -} diff --git a/hugr-core/src/hugr/persistent/walker/pinned.rs b/hugr-core/src/hugr/persistent/walker/pinned.rs deleted file mode 100644 index 02c4d4dcad..0000000000 --- a/hugr-core/src/hugr/persistent/walker/pinned.rs +++ /dev/null @@ -1,164 +0,0 @@ -//! Utilities for pinned ports and pinned wires. -//! -//! Encapsulation: we only ever expose pinned values publicly. - -use itertools::Either; - -use crate::{Direction, IncomingPort, OutgoingPort, Port, hugr::persistent::PatchNode}; - -use super::Walker; - -/// A wire in the current HUGR of a [`Walker`] with some of its endpoints -/// pinned. -/// -/// Just like a normal HUGR [`Wire`](crate::Wire), a [`PinnedWire`] has -/// endpoints: the ports that are linked together by the wire. A [`PinnedWire`] -/// however distinguishes itself in that each of its ports is specified either -/// as "pinned" or "unpinned". A port is pinned if and only if the node it is -/// attached to is pinned in the walker. -/// -/// A [`PinnedWire`] always has at least one pinned port. -/// -/// All pinned ports of a [`PinnedWire`] can be retrieved using -/// [`PinnedWire::pinned_inports`] and [`PinnedWire::pinned_outport`]. Unpinned -/// ports, on the other hand, represent undetermined connections, which may -/// still change as the walker is expanded (see [`Walker::expand`]). -/// -/// Whether all incoming or outgoing ports are pinned can be checked using -/// [`PinnedWire::is_complete`]. -#[derive(Debug, Clone)] -pub struct PinnedWire { - outgoing: MaybePinned, - incoming: Vec>, -} - -/// A private enum to track whether a port is pinned. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -enum MaybePinned