Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
382 changes: 381 additions & 1 deletion hugr-core/src/envelope/serde_with.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ mod type_def;
pub use const_fold::{ConstFold, ConstFoldResult, Folder, fold_out_row};
pub use op_def::{
CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc,
ValidateJustArgs, ValidateTypeArgs,
ValidateJustArgs, ValidateTypeArgs, deserialize_lower_funcs,
};
pub use prelude::{PRELUDE, PRELUDE_REGISTRY};
pub use type_def::{TypeDef, TypeDefBound};
Expand Down
44 changes: 40 additions & 4 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::{
};

use crate::Hugr;
use crate::envelope::serde_with::AsStringEnvelope;
use crate::envelope::serde_with::AsBinaryEnvelope;
use crate::ops::{OpName, OpNameRef};
use crate::types::type_param::{TypeArg, TypeParam, check_term_types};
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
Expand Down Expand Up @@ -268,8 +268,12 @@ impl Debug for SignatureFunc {

/// Different ways that an [OpDef] can lower operation nodes i.e. provide a Hugr
/// that implements the operation using a set of other extensions.
///
/// Does not implement [`serde::Deserialize`] directly since the serde error for
/// untagged enums is unhelpful. Use [`deserialize_lower_funcs`] with
/// [`serde(deserialize_with = "deserialize_lower_funcs")] instead.
#[serde_as]
#[derive(serde::Deserialize, serde::Serialize)]
#[derive(serde::Serialize)]
#[serde(untagged)]
pub enum LowerFunc {
/// Lowering to a fixed Hugr. Since this cannot depend upon the [TypeArg]s,
Expand All @@ -281,7 +285,7 @@ pub enum LowerFunc {
/// [OpDef]
///
/// [ExtensionOp]: crate::ops::ExtensionOp
#[serde_as(as = "Box<AsStringEnvelope>")]
#[serde_as(as = "Box<AsBinaryEnvelope>")]
hugr: Box<Hugr>,
},
/// Custom binary function that can (fallibly) compute a Hugr
Expand All @@ -290,6 +294,34 @@ pub enum LowerFunc {
CustomFunc(Box<dyn CustomLowerFunc>),
}

/// A function for deserializing sequences of [`LowerFunc::FixedHugr`].
///
/// We could let serde deserialize [`LowerFunc`] as-is, but if the LowerFunc
/// deserialization fails it just returns an opaque "data did not match any
/// variant of untagged enum LowerFunc" error. This function will return the
/// internal errors instead.
pub fn deserialize_lower_funcs<'de, D>(deserializer: D) -> Result<Vec<LowerFunc>, D::Error>
where
D: serde::Deserializer<'de>,
{
#[serde_as]
#[derive(serde::Deserialize)]
struct FixedHugrDeserializer {
pub extensions: ExtensionSet,
#[serde_as(as = "Box<AsBinaryEnvelope>")]
pub hugr: Box<Hugr>,
}

let funcs: Vec<FixedHugrDeserializer> = serde::Deserialize::deserialize(deserializer)?;
Ok(funcs
.into_iter()
.map(|f| LowerFunc::FixedHugr {
extensions: f.extensions,
hugr: f.hugr,
})
.collect())
}

impl Debug for LowerFunc {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand Down Expand Up @@ -322,7 +354,11 @@ pub struct OpDef {
signature_func: SignatureFunc,
// Some operations cannot lower themselves and tools that do not understand them
// can only treat them as opaque/black-box ops.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
#[serde(
default,
skip_serializing_if = "Vec::is_empty",
deserialize_with = "deserialize_lower_funcs"
)]
pub(crate) lower_funcs: Vec<LowerFunc>,

/// Operations can optionally implement [`ConstFold`] to implement constant folding.
Expand Down
28 changes: 25 additions & 3 deletions hugr-py/src/hugr/_serialization/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,26 @@ def deserialize(self, extension: ext.Extension) -> ext.TypeDef:


class FixedHugr(ConfiguredBaseModel):
"""Fixed HUGR used to define the lowering of an operation.

Args:
extensions: Extensions used in the HUGR.
hugr: Base64-encoded HUGR envelope.
"""

extensions: ExtensionSet
hugr: str

def deserialize(self) -> ext.FixedHugr:
hugr = Hugr.from_str(self.hugr)
return ext.FixedHugr(extensions=self.extensions, hugr=hugr)
# Loading fixed HUGRs requires reading hugr-model envelopes,
# which is not currently supported in Python.
# TODO: Add support for loading fixed HUGRs in Python.
# https://github.com/CQCL/hugr/issues/2287
msg = (
"Loading extensions with operation lowering functions is not "
+ "supported in Python"
)
raise NotImplementedError(msg)


class OpDef(ConfiguredBaseModel, populate_by_name=True):
Expand All @@ -91,13 +105,21 @@ def deserialize(self, extension: ext.Extension) -> ext.OpDef:
self.binary,
)

# Loading fixed HUGRs requires reading hugr-model envelopes,
# which is not currently supported in Python.
# We currently ignore any lower functions instead of raising an error.
#
# TODO: Add support for loading fixed HUGRs in Python.
# https://github.com/CQCL/hugr/issues/2287
lower_funcs: list[ext.FixedHugr] = []

return extension.add_op_def(
ext.OpDef(
name=self.name,
description=self.description,
misc=self.misc or {},
signature=signature,
lower_funcs=[f.deserialize() for f in self.lower_funcs],
lower_funcs=lower_funcs,
)
)

Expand Down
4 changes: 3 additions & 1 deletion hugr-py/src/hugr/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import base64
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, TypeVar

Expand Down Expand Up @@ -154,7 +155,8 @@ class FixedHugr:
hugr: Hugr

def _to_serial(self) -> ext_s.FixedHugr:
return ext_s.FixedHugr(extensions=self.extensions, hugr=self.hugr.to_str())
hugr_64: str = base64.b64encode(self.hugr.to_bytes()).decode()
return ext_s.FixedHugr(extensions=self.extensions, hugr=hugr_64)


@dataclass
Expand Down
6 changes: 6 additions & 0 deletions hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ def validate(
h1_hash == h2_hash
), f"HUGRs are not the same for {write_fmt} -> {load_fmt}"

# Lowering functions are currently ignored in Python,
# because we don't support loading -model envelopes yet.
for ext in loaded.extensions:
for op in ext.operations.values():
assert op.lower_funcs == []


@dataclass(frozen=True, order=True)
class _NodeHash:
Expand Down
1 change: 1 addition & 0 deletions specification/schema/hugr_schema_live.json
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@
"type": "object"
},
"FixedHugr": {
"description": "Fixed HUGR used to define the lowering of an operation.\n\nArgs:\n extensions: Extensions used in the HUGR.\n hugr: Base64-encoded HUGR envelope.",
"properties": {
"extensions": {
"items": {
Expand Down
1 change: 1 addition & 0 deletions specification/schema/hugr_schema_strict_live.json
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@
"type": "object"
},
"FixedHugr": {
"description": "Fixed HUGR used to define the lowering of an operation.\n\nArgs:\n extensions: Extensions used in the HUGR.\n hugr: Base64-encoded HUGR envelope.",
"properties": {
"extensions": {
"items": {
Expand Down
1 change: 1 addition & 0 deletions specification/schema/testing_hugr_schema_live.json
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@
"type": "object"
},
"FixedHugr": {
"description": "Fixed HUGR used to define the lowering of an operation.\n\nArgs:\n extensions: Extensions used in the HUGR.\n hugr: Base64-encoded HUGR envelope.",
"properties": {
"extensions": {
"items": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@
"type": "object"
},
"FixedHugr": {
"description": "Fixed HUGR used to define the lowering of an operation.\n\nArgs:\n extensions: Extensions used in the HUGR.\n hugr: Base64-encoded HUGR envelope.",
"properties": {
"extensions": {
"items": {
Expand Down
Loading