From d576fd2c1a1a7e9cc307ed21c33dc7b3b7d53b85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Wed, 5 Nov 2025 17:57:57 +0000 Subject: [PATCH 01/11] fix: Don't try to encode wires with no values --- .../src/serialize/pytket/config/type_translators.rs | 13 ++++++++++--- tket/src/serialize/pytket/extension/prelude.rs | 4 ++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/tket/src/serialize/pytket/config/type_translators.rs b/tket/src/serialize/pytket/config/type_translators.rs index fd585cfbd..431e04b24 100644 --- a/tket/src/serialize/pytket/config/type_translators.rs +++ b/tket/src/serialize/pytket/config/type_translators.rs @@ -61,6 +61,13 @@ impl TypeTranslatorSet { /// Only tuple sums, bools, and custom types are supported. /// Other types will return `None`. pub fn type_to_pytket(&self, typ: &Type) -> Option { + self.type_to_pytket_internal(typ).filter(|c| !c.is_empty()) + } + + /// Recursive call for [`Self::type_to_pytket`]. + /// + /// This allows returning empty register counts, for types that may be included inside other types. + fn type_to_pytket_internal(&self, typ: &Type) -> Option { let cache = self.type_cache.read().ok(); if let Some(count) = cache.and_then(|c| c.get(typ).cloned()) { return count; @@ -79,7 +86,7 @@ impl TypeTranslatorSet { .iter() .map(|ty| { match ty.clone().try_into() { - Ok(ty) => self.type_to_pytket(&ty), + Ok(ty) => self.type_to_pytket_internal(&ty), // Sum types with row variables (variable tuple lengths) are not supported. Err(_) => None, } @@ -161,10 +168,10 @@ mod tests { } #[rstest::rstest] - #[case::empty(SumType::new_unary(0).into(), Some(RegisterCount::default()))] + #[case::empty(SumType::new_unary(0).into(), None)] #[case::native_bool(SumType::new_unary(2).into(), Some(RegisterCount::only_bits(1)))] #[case::simple(bool_t(), Some(RegisterCount::only_bits(1)))] - #[case::tuple(SumType::new_tuple(vec![bool_t(), qb_t(), bool_t()]).into(), Some(RegisterCount::new(1, 2, 0)))] + #[case::tuple(SumType::new_tuple(vec![bool_t(), qb_t(), bool_t(), SumType::new_unary(1).into()]).into(), Some(RegisterCount::new(1, 2, 0)))] #[case::unsupported(SumType::new([vec![bool_t(), qb_t()], vec![bool_t()]]).into(), None)] fn test_translations( translator_set: TypeTranslatorSet, diff --git a/tket/src/serialize/pytket/extension/prelude.rs b/tket/src/serialize/pytket/extension/prelude.rs index 73fe344f4..7b7b2ebf7 100644 --- a/tket/src/serialize/pytket/extension/prelude.rs +++ b/tket/src/serialize/pytket/extension/prelude.rs @@ -93,6 +93,10 @@ impl PreludeEmitter { let args = op.args().first(); match args { Some(TypeArg::Tuple(elems)) | Some(TypeArg::List(elems)) => { + if elems.is_empty() { + return Ok(EncodeStatus::Unsupported); + } + for arg in elems { let TypeArg::Runtime(ty) = arg else { return Ok(EncodeStatus::Unsupported); From 035c76b69146abe1816f8c96d1d3c6920d49b629 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Wed, 5 Nov 2025 18:54:40 +0000 Subject: [PATCH 02/11] feat: Replace SiblingSubgraph with simpler OpaqueSubgraph --- tket/src/serialize/pytket/decoder/subgraph.rs | 59 +++--- tket/src/serialize/pytket/encoder.rs | 86 ++++----- .../pytket/encoder/unsupported_tracker.rs | 10 +- tket/src/serialize/pytket/opaque.rs | 46 ++--- tket/src/serialize/pytket/opaque/payload.rs | 34 ++-- tket/src/serialize/pytket/opaque/subgraph.rs | 177 ++++++++++++++++++ tket/src/serialize/pytket/tests.rs | 4 +- 7 files changed, 288 insertions(+), 128 deletions(-) create mode 100644 tket/src/serialize/pytket/opaque/subgraph.rs diff --git a/tket/src/serialize/pytket/decoder/subgraph.rs b/tket/src/serialize/pytket/decoder/subgraph.rs index 7ce965264..64f96ea63 100644 --- a/tket/src/serialize/pytket/decoder/subgraph.rs +++ b/tket/src/serialize/pytket/decoder/subgraph.rs @@ -5,9 +5,8 @@ use std::sync::Arc; use hugr::builder::Container; use hugr::hugr::hugrmut::{HugrMut, InsertedForest}; -use hugr::hugr::views::SiblingSubgraph; use hugr::ops::{OpTag, OpTrait}; -use hugr::types::{Signature, Type}; +use hugr::types::Type; use hugr::{Hugr, HugrView, Node, OutgoingPort, PortIndex, Wire}; use hugr_core::hugr::internal::HugrMutInternals; use itertools::Itertools; @@ -16,7 +15,9 @@ use crate::serialize::pytket::decoder::{ DecodeStatus, FoundWire, LoadedParameter, PytketDecoderContext, TrackedBit, TrackedQubit, }; use crate::serialize::pytket::extension::RegisterCount; -use crate::serialize::pytket::opaque::{EncodedEdgeID, OpaqueSubgraphPayload, SubgraphId}; +use crate::serialize::pytket::opaque::{ + EncodedEdgeID, OpaqueSubgraph, OpaqueSubgraphPayload, SubgraphId, +}; use crate::serialize::pytket::{PytketDecodeError, PytketDecodeErrorInner, PytketDecoderConfig}; impl<'h> PytketDecoderContext<'h> { @@ -64,13 +65,8 @@ impl<'h> PytketDecoderContext<'h> { let Some(subgraph) = self.opaque_subgraphs.and_then(|s| s.get(id)) else { return Err(PytketDecodeErrorInner::OpaqueSubgraphNotFound { id }.wrap()); }; - let signature = subgraph.signature(self.builder.hugr()); - let old_parent = self - .builder - .hugr() - .get_parent(subgraph.nodes()[0]) - .ok_or_else(|| PytketDecodeErrorInner::ExternalSubgraphWasModified { id }.wrap())?; + let old_parent = subgraph.region(); if !OpTag::DataflowParent.is_superset(self.builder.hugr().get_optype(old_parent).tag()) { return Err(PytketDecodeErrorInner::ExternalSubgraphWasModified { id }.wrap()); } @@ -82,12 +78,10 @@ impl<'h> PytketDecoderContext<'h> { } self.rewire_external_subgraph_inputs( - subgraph, qubits, bits, params, old_parent, new_parent, &signature, + subgraph, qubits, bits, params, old_parent, new_parent, )?; - self.rewire_external_subgraph_outputs( - subgraph, qubits, bits, old_parent, new_parent, &signature, - )?; + self.rewire_external_subgraph_outputs(subgraph, qubits, bits, old_parent, new_parent)?; Ok(DecodeStatus::Success) } @@ -95,22 +89,25 @@ impl<'h> PytketDecoderContext<'h> { /// Rewire the inputs of an external subgraph moved to the new region. /// /// Helper for [`Self::insert_external_subgraph`]. - #[expect(clippy::too_many_arguments)] fn rewire_external_subgraph_inputs( &mut self, - subgraph: &SiblingSubgraph, + subgraph: &OpaqueSubgraph, mut input_qubits: &[TrackedQubit], mut input_bits: &[TrackedBit], mut input_params: &[LoadedParameter], old_parent: Node, new_parent: Node, - signature: &Signature, ) -> Result<(), PytketDecodeError> { let old_input = self.builder.hugr().get_io(old_parent).unwrap()[0]; let new_input = self.builder.hugr().get_io(new_parent).unwrap()[0]; // Reconnect input wires from parts of/nodes in the region that have been encoded into pytket. - for (ty, targets) in signature.input().iter().zip_eq(subgraph.incoming_ports()) { + for (ty, (tgt_node, tgt_port)) in subgraph + .signature() + .input() + .iter() + .zip_eq(subgraph.incoming_ports()) + { let found_wire = self.wire_tracker.find_typed_wire( self.config(), ty, @@ -125,9 +122,11 @@ impl<'h> PytketDecoderContext<'h> { FoundWire::Parameter(param) => param.wire(), FoundWire::Unsupported { .. } => { // Input port with an unsupported type. - let Some((neigh, neigh_port)) = targets.first().and_then(|(tgt, port)| { - self.builder.hugr().single_linked_output(*tgt, *port) - }) else { + let Some((neigh, neigh_port)) = self + .builder + .hugr() + .single_linked_output(*tgt_node, *tgt_port) + else { // The input was disconnected. We just skip it. // (This is the case for unused other-ports) continue; @@ -143,11 +142,9 @@ impl<'h> PytketDecoderContext<'h> { } }; - for (tgt, port) in targets { - self.builder - .hugr_mut() - .connect(wire.node(), wire.source(), *tgt, *port); - } + self.builder + .hugr_mut() + .connect(wire.node(), wire.source(), *tgt_node, *tgt_port); } Ok(()) @@ -161,12 +158,11 @@ impl<'h> PytketDecoderContext<'h> { /// Helper for [`Self::insert_external_subgraph`]. fn rewire_external_subgraph_outputs( &mut self, - subgraph: &SiblingSubgraph, + subgraph: &OpaqueSubgraph, qubits: &[TrackedQubit], bits: &[TrackedBit], old_parent: Node, new_parent: Node, - signature: &Signature, ) -> Result<(), PytketDecodeError> { let old_output = self.builder.hugr().get_io(old_parent).unwrap()[1]; let new_output = self.builder.hugr().get_io(new_parent).unwrap()[1]; @@ -174,7 +170,12 @@ impl<'h> PytketDecoderContext<'h> { let mut output_qubits = qubits; let mut output_bits = bits; - for (ty, (src, src_port)) in signature.output().iter().zip_eq(subgraph.outgoing_ports()) { + for (ty, (src, src_port)) in subgraph + .signature() + .output() + .iter() + .zip_eq(subgraph.outgoing_ports()) + { // Output wire from the subgraph. Depending on the type, we may need // to track new qubits and bits, re-connect it to some output, or // leave it untouched. @@ -191,7 +192,7 @@ impl<'h> PytketDecoderContext<'h> { if wire_qubits.is_none() || wire_bits.is_none() { return Err(make_unexpected_node_out_error( self.config(), - signature.output().iter(), + subgraph.signature().output().iter(), qubits.len(), bits.len(), )); diff --git a/tket/src/serialize/pytket/encoder.rs b/tket/src/serialize/pytket/encoder.rs index bc47ee94b..96c8391a4 100644 --- a/tket/src/serialize/pytket/encoder.rs +++ b/tket/src/serialize/pytket/encoder.rs @@ -5,8 +5,6 @@ mod unsupported_tracker; mod value_tracker; use hugr::core::HugrNode; -use hugr::hugr::views::sibling_subgraph::{IncomingPorts, OutgoingPorts}; -use hugr::hugr::views::SiblingSubgraph; use hugr_core::hugr::internal::PortgraphNodeMap; use tket_json_rs::clexpr::InputClRegister; use tket_json_rs::opbox::BoxID; @@ -34,7 +32,9 @@ use crate::circuit::Circuit; use crate::serialize::pytket::circuit::EncodedCircuitInfo; use crate::serialize::pytket::config::PytketEncoderConfig; use crate::serialize::pytket::extension::RegisterCount; -use crate::serialize::pytket::opaque::{OpaqueSubgraphPayload, OPGROUP_OPAQUE_HUGR}; +use crate::serialize::pytket::opaque::{ + OpaqueSubgraph, OpaqueSubgraphPayload, OPGROUP_OPAQUE_HUGR, +}; /// The state of an in-progress [`SerialCircuit`] being built from a [`Circuit`]. #[derive(derive_more::Debug)] @@ -228,25 +228,25 @@ impl PytketEncoderContext { let mut extra_subgraph: Option> = None; while !self.unsupported.is_empty() { let node = self.unsupported.iter().next().unwrap(); - let opaque_subgraphs = self.unsupported.extract_component(node); - match self.emit_unsupported(opaque_subgraphs.clone(), circ) { + let opaque_subgraphs = self.unsupported.extract_component(node, circ.hugr())?; + match self.emit_unsupported(&opaque_subgraphs, circ) { Ok(()) => (), Err(PytketEncodeError::UnsupportedSubgraphHasNoRegisters {}) => { // We'll store the nodes in the `extra_subgraph` field of the `EncodedCircuitInfo`. // So the decoder can reconstruct the original subgraph. extra_subgraph .get_or_insert_default() - .extend(opaque_subgraphs); + .extend(opaque_subgraphs.nodes().iter().cloned()); } Err(e) => return Err(e), } } - let extra_subgraph = extra_subgraph.map(|nodes| { - let subgraph = - SiblingSubgraph::try_from_nodes(nodes.into_iter().collect_vec(), circ.hugr()) - .expect("Failed to create subgraph from unsupported nodes"); - self.opaque_subgraphs.register_opaque_subgraph(subgraph) - }); + let extra_subgraph = extra_subgraph + .map(|nodes| -> Result<_, PytketEncodeError> { + let subgraph = OpaqueSubgraph::try_from_nodes(nodes, circ.hugr())?; + Ok(self.opaque_subgraphs.register_opaque_subgraph(subgraph)) + }) + .transpose()?; let tracker_result = self.values.finish(circ, region)?; @@ -306,8 +306,10 @@ impl PytketEncoderContext { // // We need to emit the unsupported node here before returning the values. if self.unsupported.is_unsupported(wire.node()) { - let unsupported_nodes = self.unsupported.extract_component(wire.node()); - self.emit_unsupported(unsupported_nodes, circ)?; + let unsupported_nodes = self + .unsupported + .extract_component(wire.node(), circ.hugr())?; + self.emit_unsupported(&unsupported_nodes, circ)?; debug_assert!(!self.unsupported.is_unsupported(wire.node())); return self.get_wire_values(wire, circ); } @@ -340,7 +342,7 @@ impl PytketEncoderContext { node: H::Node, circ: &Circuit, ) -> Result> { - self.get_input_values_internal(node, circ, |_| true) + self.get_input_values_internal(node, circ, |_| true)? .try_into_tracked_values() } @@ -354,7 +356,7 @@ impl PytketEncoderContext { node: H::Node, circ: &Circuit, wire_filter: impl Fn(Wire) -> bool, - ) -> NodeInputValues { + ) -> Result, PytketEncodeError> { let mut tracked_values = TrackedValues::default(); let mut unknown_values = Vec::new(); @@ -380,15 +382,13 @@ impl PytketEncoderContext { Err(PytketEncodeError::OpEncoding(PytketEncodeOpError::WireHasNoValues { wire, })) => unknown_values.push(wire), - Err(e) => panic!( - "get_wire_values should only return WireHasNoValues errors, but got: {e}" - ), + Err(e) => return Err(e), } } - NodeInputValues { + Ok(NodeInputValues { tracked_values, unknown_values, - } + }) } /// Helper to emit a new tket1 command corresponding to a single HUGR node. @@ -575,34 +575,17 @@ impl PytketEncoderContext { /// /// ## Arguments /// - /// - `unsupported_nodes`: The list of nodes to encode as an opaque subgraph. + /// - `subgraph`: The subgraph of unsupported nodes to encode as an opaque subgraph. /// - `circ`: The circuit containing the unsupported nodes. fn emit_unsupported( &mut self, - unsupported_nodes: BTreeSet, + subgraph: &OpaqueSubgraph, circ: &Circuit, ) -> Result<(), PytketEncodeError> { - let subcircuit_id = format!("tk{}", unsupported_nodes.iter().min().unwrap()); - - // TODO: Use a cached topo checker here instead of traversing the full graph each time we create a `SiblingSubgraph`. - // - // TopoConvexChecker likes to borrow the hugr, so it'd be too invasive to store in the `Context`. - let subgraph = SiblingSubgraph::try_from_nodes( - unsupported_nodes.iter().cloned().collect_vec(), - circ.hugr(), - ) - .unwrap_or_else(|e| { - panic!( - "Failed to create subgraph from unsupported nodes [{}]: {e}", - unsupported_nodes.iter().join(", ") - ) - }); - let subgraph_incoming_ports: IncomingPorts = subgraph.incoming_ports().clone(); - let subgraph_outgoing_ports: OutgoingPorts = subgraph.outgoing_ports().clone(); - let subgraph_signature = subgraph.signature(circ.hugr()); - // Encode a payload referencing the subgraph in the Hugr. - let subgraph_id = self.opaque_subgraphs.register_opaque_subgraph(subgraph); + let subgraph_id = self + .opaque_subgraphs + .register_opaque_subgraph(subgraph.clone()); let payload = OpaqueSubgraphPayload::new_external(subgraph_id); // Collects the input values for the subgraph. @@ -610,14 +593,10 @@ impl PytketEncoderContext { // The [`UnsupportedTracker`] ensures that at this point all local input wires must come from // already-encoded nodes, and not from other unsupported nodes not in `unsupported_nodes`. let mut op_values = TrackedValues::default(); - for incoming in subgraph_incoming_ports.iter() { - let Some((first_node, first_port)) = incoming.first() else { - continue; - }; - + for (node, port) in subgraph.incoming_ports().iter() { let (neigh, neigh_out) = circ .hugr() - .single_linked_output(*first_node, *first_port) + .single_linked_output(*node, *port) .expect("Dataflow input port should have a single neighbour"); let wire = Wire::new(neigh, neigh_out); @@ -640,9 +619,10 @@ impl PytketEncoderContext { // Output parameters are mapped to a fresh variable, that can be tracked // back to the encoded subcircuit's function name. let mut out_param_count = 0; - for ((out_node, out_port), ty) in subgraph_outgoing_ports + for ((out_node, out_port), ty) in subgraph + .outgoing_ports() .iter() - .zip(subgraph_signature.output().iter()) + .zip(subgraph.signature().output().iter()) { if self.config().type_to_pytket(ty).is_none() { // Do not try to register ports with unsupported types. @@ -658,9 +638,7 @@ impl PytketEncoderContext { EmitCommandOptions::new().output_params(|p| { let range = out_param_count..out_param_count + p.expected_count; out_param_count += p.expected_count; - range - .map(|i| format!("{subcircuit_id}_out{i}")) - .collect_vec() + range.map(|i| format!("{subgraph_id}_out{i}")).collect_vec() }), )?; op_values.append(new_outputs); diff --git a/tket/src/serialize/pytket/encoder/unsupported_tracker.rs b/tket/src/serialize/pytket/encoder/unsupported_tracker.rs index 62133097e..8b5739b4a 100644 --- a/tket/src/serialize/pytket/encoder/unsupported_tracker.rs +++ b/tket/src/serialize/pytket/encoder/unsupported_tracker.rs @@ -6,6 +6,8 @@ use hugr::core::HugrNode; use hugr::HugrView; use petgraph::unionfind::UnionFind; +use crate::serialize::pytket::opaque::OpaqueSubgraph; +use crate::serialize::pytket::PytketEncodeError; use crate::Circuit; /// A structure for tracking nodes in the hugr that cannot be encoded as TKET1 @@ -75,7 +77,11 @@ impl UnsupportedTracker { /// Once a component has been extracted, no new nodes can be added to it and /// calling [`UnsupportedTracker::record_node`] will use a new component /// instead. - pub fn extract_component(&mut self, node: N) -> BTreeSet { + pub fn extract_component( + &mut self, + node: N, + hugr: &impl HugrView, + ) -> Result, PytketEncodeError> { let node_data = self.nodes.remove(&node).unwrap(); let component = node_data.component; let representative = self.components.find_mut(component); @@ -95,7 +101,7 @@ impl UnsupportedTracker { self.nodes.remove(n); } - nodes + OpaqueSubgraph::try_from_nodes(nodes, hugr) } /// Returns an iterator over the unextracted nodes in the tracker. diff --git a/tket/src/serialize/pytket/opaque.rs b/tket/src/serialize/pytket/opaque.rs index 66fdf1283..daec3b032 100644 --- a/tket/src/serialize/pytket/opaque.rs +++ b/tket/src/serialize/pytket/opaque.rs @@ -2,6 +2,9 @@ //! encoding as barrier metadata in pytket circuits. mod payload; +mod subgraph; + +pub use subgraph::OpaqueSubgraph; pub use payload::{EncodedEdgeID, OpaqueSubgraphPayload, OPGROUP_OPAQUE_HUGR}; @@ -10,10 +13,9 @@ use std::ops::Index; use crate::serialize::pytket::PytketEncodeError; use hugr::core::HugrNode; -use hugr::hugr::views::SiblingSubgraph; use hugr::HugrView; -/// The ID of a subgraph in the Hugr. +/// The ID of an [`OpaqueSubgraph`] registered in an `OpaqueSubgraphs` tracker. #[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[display("{tracker_id}.{local_id}")] pub struct SubgraphId { @@ -23,6 +25,22 @@ pub struct SubgraphId { local_id: usize, } +impl serde::Serialize for SubgraphId { + fn serialize(&self, s: S) -> Result { + (&self.tracker_id, &self.local_id).serialize(s) + } +} + +impl<'de> serde::Deserialize<'de> for SubgraphId { + fn deserialize>(d: D) -> Result { + let (tracker_id, local_id) = serde::Deserialize::deserialize(d)?; + Ok(Self { + tracker_id, + local_id, + }) + } +} + /// A set of subgraphs a HUGR that have been marked as _unsupported_ during a /// pytket encoding. #[derive(Debug, Clone)] @@ -41,23 +59,7 @@ pub(super) struct OpaqueSubgraphs { /// in the pytket circuit. /// /// Subcircuits are identified in the barrier metadata by their ID. See [`SubgraphId`]. - opaque_subgraphs: BTreeMap>, -} - -impl serde::Serialize for SubgraphId { - fn serialize(&self, s: S) -> Result { - (&self.tracker_id, &self.local_id).serialize(s) - } -} - -impl<'de> serde::Deserialize<'de> for SubgraphId { - fn deserialize>(d: D) -> Result { - let (tracker_id, local_id) = serde::Deserialize::deserialize(d)?; - Ok(Self { - tracker_id, - local_id, - }) - } + opaque_subgraphs: BTreeMap>, } impl OpaqueSubgraphs { @@ -78,7 +80,7 @@ impl OpaqueSubgraphs { /// Register a new opaque subgraph in the Hugr. /// /// Returns and ID that can be used to identify the subgraph in the pytket circuit. - pub fn register_opaque_subgraph(&mut self, subgraph: SiblingSubgraph) -> SubgraphId { + pub fn register_opaque_subgraph(&mut self, subgraph: OpaqueSubgraph) -> SubgraphId { let id = SubgraphId { local_id: self.next_local_id, tracker_id: self.id, @@ -93,7 +95,7 @@ impl OpaqueSubgraphs { /// # Panics /// /// Panics if the ID is invalid. - pub fn get(&self, id: SubgraphId) -> Option<&SiblingSubgraph> { + pub fn get(&self, id: SubgraphId) -> Option<&OpaqueSubgraph> { self.opaque_subgraphs.get(&id) } @@ -150,7 +152,7 @@ impl OpaqueSubgraphs { } impl Index for OpaqueSubgraphs { - type Output = SiblingSubgraph; + type Output = OpaqueSubgraph; fn index(&self, index: SubgraphId) -> &Self::Output { self.get(index) diff --git a/tket/src/serialize/pytket/opaque/payload.rs b/tket/src/serialize/pytket/opaque/payload.rs index 0991c3b12..b2a116046 100644 --- a/tket/src/serialize/pytket/opaque/payload.rs +++ b/tket/src/serialize/pytket/opaque/payload.rs @@ -4,11 +4,12 @@ use hugr::core::HugrNode; use hugr::envelope::{EnvelopeConfig, EnvelopeError}; use hugr::extension::resolution::{resolve_type_extensions, WeakExtensionRegistry}; use hugr::extension::{ExtensionRegistry, ExtensionRegistryLoadError}; -use hugr::hugr::views::SiblingSubgraph; use hugr::package::Package; use hugr::types::Type; use hugr::{HugrView, Wire}; +use itertools::Itertools; +use crate::serialize::pytket::opaque::OpaqueSubgraph; use crate::serialize::pytket::{ PytketDecodeError, PytketDecodeErrorInner, PytketEncodeError, PytketEncodeOpError, }; @@ -124,31 +125,27 @@ impl OpaqueSubgraphPayload { /// Returns an error if a node in the subgraph has children or calls a /// global function. pub fn new_inline( - subgraph: &SiblingSubgraph, + subgraph: &OpaqueSubgraph, hugr: &impl HugrView, ) -> Result> { - let signature = subgraph.signature(hugr); + let signature = subgraph.signature(); - if !subgraph.function_calls().is_empty() - || subgraph - .nodes() - .iter() - .any(|n| hugr.first_child(*n).is_some()) - { + let Some(opaque_hugr) = subgraph + .is_sibling_subgraph_compatible() + .then(|| subgraph.extract_subgraph(hugr).ok()) + .flatten() + else { return Err(PytketEncodeOpError::UnsupportedStandaloneSubgraph { - nodes: subgraph.nodes().to_vec(), + nodes: subgraph.nodes().iter().cloned().collect_vec(), } .into()); - } + }; - let mut inputs = Vec::with_capacity(subgraph.incoming_ports().iter().map(Vec::len).sum()); - for subgraph_inputs in subgraph.incoming_ports() { - let Some((inp_node, inp_port0)) = subgraph_inputs.first() else { - continue; - }; - let input_wire = Wire::from_connected_port(*inp_node, *inp_port0, hugr); + let mut inputs = Vec::with_capacity(subgraph.incoming_ports().len()); + for (inp_node, inp_port) in subgraph.incoming_ports() { + let input_wire = Wire::from_connected_port(*inp_node, *inp_port, hugr); let edge_id = EncodedEdgeID::new(input_wire); - inputs.extend(itertools::repeat_n(edge_id, subgraph_inputs.len())); + inputs.push(edge_id); } let outputs = subgraph @@ -156,7 +153,6 @@ impl OpaqueSubgraphPayload { .iter() .map(|(n, p)| EncodedEdgeID::new(Wire::new(*n, *p))); - let opaque_hugr = subgraph.extract_subgraph(hugr, ""); let hugr_envelope = Package::from_hugr(opaque_hugr) .store_str(EnvelopeConfig::text()) .unwrap(); diff --git a/tket/src/serialize/pytket/opaque/subgraph.rs b/tket/src/serialize/pytket/opaque/subgraph.rs new file mode 100644 index 000000000..a00e5f1b5 --- /dev/null +++ b/tket/src/serialize/pytket/opaque/subgraph.rs @@ -0,0 +1,177 @@ +//! Opaque subgraph definition. + +use hugr::ops::OpTrait; +use itertools::{Either, Itertools}; + +use std::collections::BTreeSet; + +use crate::serialize::pytket::PytketEncodeError; +use hugr::core::HugrNode; +use hugr::hugr::views::sibling_subgraph::InvalidSubgraph; +use hugr::hugr::views::SiblingSubgraph; +use hugr::types::Signature; +use hugr::{Direction, Hugr, HugrView, IncomingPort, OutgoingPort}; + +/// A subgraph of nodes in the Hugr that could not be encoded as TKET1 +/// operations. +/// +/// of subgraphs it can represent; in particular const edges and order edge are +/// not allowed in the boundary. +#[derive(Debug, Clone)] +pub struct OpaqueSubgraph { + /// The nodes in the subgraph. + nodes: BTreeSet, + /// The incoming ports of the subgraph. + incoming_ports: Vec<(N, IncomingPort)>, + /// The outgoing ports of the subgraph. + outgoing_ports: Vec<(N, OutgoingPort)>, + /// The signature of the subgraph. + signature: Signature, + /// The region containing the subgraph. + region: N, + /// Whether the subgraph can be represented as a [`SiblingSubgraph`] + /// with no external edges. + /// + /// Some cases where this is not possible are: + /// - Having non-local edges. + /// - Having const edges to global definitions. + /// - Having order edges to nodes outside the subgraph. + /// - Calling a global functions. + sibling_subgraph_compatible: bool, +} + +impl OpaqueSubgraph { + /// Create a new [`UnsupportedSubgraph`]. + pub(in crate::serialize::pytket) fn try_from_nodes( + nodes: BTreeSet, + hugr: &impl HugrView, + ) -> Result> { + let region = nodes + .first() + .and_then(|n| hugr.get_parent(*n)) + .unwrap_or_else(|| hugr.entrypoint()); + + // Traverse the nodes, collecting `EdgeKind::Value` boundary ports that connect to other nodes in the _same region_. + // Ignores all other ports. + let mut incoming_ports = Vec::new(); + let mut outgoing_ports = Vec::new(); + let mut input_types = Vec::new(); + let mut output_types = Vec::new(); + let mut sibling_subgraph_compatible = true; + + for &node in &nodes { + let op = hugr.get_optype(node); + let Some(signature) = op.dataflow_signature() else { + continue; + }; + // Check the value ports for boundary edges. + let mut has_nonlocal_boundary = false; + for port in signature + .ports(Direction::Incoming) + .chain(signature.ports(Direction::Outgoing)) + { + let ty = signature.port_type(port).unwrap(); + // If it's a value port to another node in the same region but outside the set, add it to the subgraph. + let mut is_local_boundary = false; + for (n, _) in hugr.linked_ports(node, port) { + if nodes.contains(&n) { + continue; + } + match hugr.get_parent(n) == Some(region) { + true => is_local_boundary = true, + false => has_nonlocal_boundary = true, + } + if is_local_boundary && has_nonlocal_boundary { + break; + } + } + if is_local_boundary { + match port.as_directed() { + Either::Left(inc) => { + incoming_ports.push((node, inc)); + input_types.push(ty.clone()); + } + Either::Right(out) => { + outgoing_ports.push((node, out)); + output_types.push(ty.clone()); + } + } + } + } + // If the node is a region parent, it cannot be contained in a `SiblingSubgraph`. + let is_region_parent = hugr.first_child(node).is_some(); + // If the node has static ports or _other_ ports that connect outside the set, it cannot be contained in a `SiblingSubgraph`. + let non_value_boundary = op + .static_port(Direction::Incoming) + .iter() + .chain(op.static_port(Direction::Outgoing).iter()) + .chain(op.other_port(Direction::Incoming).iter()) + .chain(op.other_port(Direction::Outgoing).iter()) + .any(|&p| hugr.linked_ports(node, p).any(|(n, _)| !nodes.contains(&n))); + + sibling_subgraph_compatible &= + !has_nonlocal_boundary && !is_region_parent && !non_value_boundary; + } + let signature = Signature::new(input_types, output_types); + + Ok(Self { + nodes, + incoming_ports, + outgoing_ports, + signature, + region, + sibling_subgraph_compatible, + }) + } + + /// Returns the nodes in the subgraph. + pub fn nodes(&self) -> &BTreeSet { + &self.nodes + } + + /// Returns the incoming ports of the subgraph. + pub fn incoming_ports(&self) -> &[(N, IncomingPort)] { + &self.incoming_ports + } + + /// Returns the outgoing ports of the subgraph. + pub fn outgoing_ports(&self) -> &[(N, OutgoingPort)] { + &self.outgoing_ports + } + + /// Returns the signature of the subgraph. + pub fn signature(&self) -> &Signature { + &self.signature + } + + /// Returns the region containing the subgraph. + pub fn region(&self) -> N { + self.region + } + + /// Returns whether the subgraph can be represented as a [`SiblingSubgraph`] + /// with no external edges. + /// + /// Some cases where this is not possible are: + /// - Having non-local edges. + /// - Having const edges to global definitions. + /// - Having order edges to nodes outside the subgraph. + /// - Calling a global functions. + pub fn is_sibling_subgraph_compatible(&self) -> bool { + self.sibling_subgraph_compatible + } + + /// Extract the subgraph as a standalone Hugr. + /// + /// Return an error if the subgraph cannot be represented as a + /// [`SiblingSubgraph`]. See + /// [`OpaqueSubgraph::is_sibling_subgraph_compatible`] for more details. + pub fn extract_subgraph( + &self, + hugr: &impl HugrView, + ) -> Result> { + let nodes = self.nodes().iter().cloned().collect_vec(); + let subgraph = SiblingSubgraph::try_from_nodes(nodes, hugr).unwrap(); + Ok(subgraph.extract_subgraph(hugr, "")) + } +} diff --git a/tket/src/serialize/pytket/tests.rs b/tket/src/serialize/pytket/tests.rs index 7ba958665..706b0f9b4 100644 --- a/tket/src/serialize/pytket/tests.rs +++ b/tket/src/serialize/pytket/tests.rs @@ -930,8 +930,8 @@ fn fail_on_modified_hugr(circ_tk1_ops: Circuit) { #[case::output_parameter_wire(circ_output_parameter_wire(), 1, CircuitRoundtripTestConfig::Default)] // TODO: fix edge case: non-local edge from an unsupported node inside a nested CircBox // to/from the input of the head region being encoded... -#[should_panic(expected = "Unsupported edge kind")] -#[case::non_local(circ_non_local(), 1, CircuitRoundtripTestConfig::Default)] +#[should_panic(expected = "has an unconnected port")] +#[case::non_local(circ_non_local(), 2, CircuitRoundtripTestConfig::Default)] fn encoded_circuit_roundtrip( #[case] circ: Circuit, From fb5278d182fed285006f09cd11569b0b787437f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Thu, 6 Nov 2025 12:28:39 +0000 Subject: [PATCH 03/11] Don't encode nodes with order edges --- tket/src/serialize/pytket/encoder.rs | 151 +++++++++++++++++---------- tket/src/serialize/pytket/tests.rs | 25 +++++ 2 files changed, 120 insertions(+), 56 deletions(-) diff --git a/tket/src/serialize/pytket/encoder.rs b/tket/src/serialize/pytket/encoder.rs index 96c8391a4..5268b68ed 100644 --- a/tket/src/serialize/pytket/encoder.rs +++ b/tket/src/serialize/pytket/encoder.rs @@ -17,9 +17,10 @@ use hugr::types::EdgeKind; use std::borrow::Cow; use std::collections::{BTreeSet, HashMap}; +use std::ops::RangeTo; use std::sync::{Arc, RwLock}; -use hugr::{HugrView, OutgoingPort, Wire}; +use hugr::{Direction, HugrView, OutgoingPort, Wire}; use itertools::Itertools; use tket_json_rs::circuit_json::{self, SerialCircuit}; use unsupported_tracker::UnsupportedTracker; @@ -619,6 +620,11 @@ impl PytketEncoderContext { // Output parameters are mapped to a fresh variable, that can be tracked // back to the encoded subcircuit's function name. let mut out_param_count = 0; + let input_qubits = op_values.qubits.clone(); + let input_bits = op_values.bits.clone(); + let mut out_qubits = input_qubits.as_slice(); + let mut out_bits = input_bits.as_slice(); + for ((out_node, out_port), ty) in subgraph .outgoing_ports() .iter() @@ -632,14 +638,14 @@ impl PytketEncoderContext { *out_node, *out_port, circ, - &op_values.qubits, - &op_values.bits, + &mut out_qubits, + &mut out_bits, &input_param_exprs, - EmitCommandOptions::new().output_params(|p| { + |p| { let range = out_param_count..out_param_count + p.expected_count; out_param_count += p.expected_count; range.map(|i| format!("{subgraph_id}_out{i}")).collect_vec() - }), + }, )?; op_values.append(new_outputs); } @@ -849,6 +855,10 @@ impl PytketEncoderContext { ) -> Result> { let optype = circ.hugr().get_optype(node); + // Try to register non-local inputs to nodes when possible (e.g. + // constants, function definitions). + // + // Otherwise, mark the node as unsupported. if self.encode_nonlocal_inputs(node, optype, circ)? == EncodeStatus::Unsupported { self.unsupported.record_node(node, circ); return Ok(EncodeStatus::Unsupported); @@ -860,9 +870,12 @@ impl PytketEncoderContext { // the unsupported tracker and move on. match optype { OpType::ExtensionOp(op) => { - let config = Arc::clone(&self.config); - if config.op_to_pytket(node, op, circ, self)? == EncodeStatus::Success { - return Ok(EncodeStatus::Success); + // Ignore nodes with order edges, as they cannot be represented in the pytket circuit. + if !self.has_order_edges(node, optype, circ) { + let config = Arc::clone(&self.config); + if config.op_to_pytket(node, op, circ, self)? == EncodeStatus::Success { + return Ok(EncodeStatus::Success); + } } } OpType::LoadConstant(constant) => { @@ -985,6 +998,18 @@ impl PytketEncoderContext { Ok(EncodeStatus::Success) } + /// Check if a node has order edges to nodes outside the region. + /// + /// If that's the case, we don't try to encode the node and report it as + /// unsupported instead. + fn has_order_edges(&mut self, node: H::Node, optype: &OpType, circ: &Circuit) -> bool { + optype + .other_port(Direction::Incoming) + .iter() + .chain(optype.other_port(Direction::Outgoing).iter()) + .any(|&p| circ.hugr().is_linked(node, p)) + } + /// Helper to register values for a node's output wires. /// /// Returns any new value associated with the output wires. @@ -1099,21 +1124,25 @@ impl PytketEncoderContext { /// /// - `node`: The node to register the outputs for. /// - `circ`: The circuit containing the node. - /// - `input_qubits`: The qubit inputs to the operation. - /// - `input_bits`: The bit inputs to the operation. + /// - `qubits`: The qubit registers to use for the output. Values are + /// consumed from this slice as needed, and dropped from the slice as they + /// are used. + /// - `bits`: The bit registers to use for the output. Values are consumed + /// from this slice as needed, and dropped from the slice as they are + /// used. /// - `input_params`: The list of input parameter expressions. - /// - `options`: Options for controlling the output qubit, bits, and - /// parameter expressions. + /// - `options_params_fn`: A function that computes the output parameter + /// expressions given the inputs. #[allow(clippy::too_many_arguments)] fn register_port_output( &mut self, node: H::Node, port: OutgoingPort, circ: &Circuit, - input_qubits: &[TrackedQubit], - input_bits: &[TrackedBit], + qubits: &mut &[TrackedQubit], + bits: &mut &[TrackedBit], input_params: &[String], - options: EmitCommandOptions, + output_params_fn: impl FnOnce(OutputParamArgs<'_>) -> Vec, ) -> Result> { let wire = Wire::new(node, port); @@ -1131,23 +1160,11 @@ impl PytketEncoderContext { ))); }; - let output_qubits = match options.reuse_qubits_fn { - Some(f) => f(input_qubits), - None => input_qubits.to_vec(), - }; - let output_bits = match options.reuse_bits_fn { - Some(f) => f(input_bits), - None => input_bits.to_vec(), - }; - // Compute all the output parameters at once - let out_params = match options.output_params_fn { - Some(f) => f(OutputParamArgs { - expected_count: count.params, - input_params, - }), - None => Vec::new(), - }; + let out_params = output_params_fn(OutputParamArgs { + expected_count: count.params, + input_params, + }); // Check that we got the expected number of outputs. if out_params.len() != count.params { @@ -1166,33 +1183,44 @@ impl PytketEncoderContext { let mut out_wire_values = Vec::with_capacity(count.total()); // Qubits - out_wire_values.extend( - output_qubits - .into_iter() - .take(count.qubits) - .map(TrackedValue::Qubit), - ); - for _ in out_wire_values.len()..count.qubits { - // If we already assigned all input qubit ids, get a fresh one. - let qb = self.values.new_qubit(); - new_outputs.qubits.push(qb); - out_wire_values.push(TrackedValue::Qubit(qb)); - } + // Reuse the ones from `qubits`, dropping them from the slice, + // and allocate new ones as needed. + let output_qubits = match split_off(qubits, ..count.qubits) { + Some(reused_qubits) => reused_qubits.to_vec(), + None => { + // Not enough qubits, allocate some fresh ones. + let mut head_qubits = qubits.to_vec(); + *qubits = &[]; + let new_qubits = (head_qubits.len()..count.qubits).map(|_| { + let q = self.values.new_qubit(); + new_outputs.qubits.push(q); + q + }); + head_qubits.extend(new_qubits); + head_qubits + } + }; + out_wire_values.extend(output_qubits.iter().map(|&q| TrackedValue::Qubit(q))); // Bits - let non_bit_count = out_wire_values.len(); - out_wire_values.extend( - output_bits - .into_iter() - .take(count.bits) - .map(TrackedValue::Bit), - ); - let reused_bit_count = out_wire_values.len() - non_bit_count; - for _ in reused_bit_count..count.bits { - let b = self.values.new_bit(); - new_outputs.bits.push(b); - out_wire_values.push(TrackedValue::Bit(b)); - } + // Reuse the ones from `bits`, dropping them from the slice, + // and allocate new ones as needed. + let output_bits = match split_off(bits, ..count.bits) { + Some(reused_bits) => reused_bits.to_vec(), + None => { + // Not enough bits, allocate some fresh ones. + let mut head_bits = bits.to_vec(); + *bits = &[]; + let new_bits = (head_bits.len()..count.bits).map(|_| { + let b = self.values.new_bit(); + new_outputs.bits.push(b); + b + }); + head_bits.extend(new_bits); + head_bits + } + }; + out_wire_values.extend(output_bits.iter().map(|&b| TrackedValue::Bit(b))); // Parameters for expr in out_params.into_iter().take(count.params) { @@ -1451,3 +1479,14 @@ pub fn make_tk1_classical_expression( op.classical_expr = Some(clexpr); op } + +// TODO: Replace with array's `split_off` method once MSRV is ≥1.87 +fn split_off<'a, T>(slice: &mut &'a [T], range: RangeTo) -> Option<&'a [T]> { + let split_index = range.end; + if split_index > slice.len() { + return None; + } + let (front, back) = slice.split_at(split_index); + *slice = back; + Some(front) +} diff --git a/tket/src/serialize/pytket/tests.rs b/tket/src/serialize/pytket/tests.rs index 706b0f9b4..a2755dab9 100644 --- a/tket/src/serialize/pytket/tests.rs +++ b/tket/src/serialize/pytket/tests.rs @@ -624,6 +624,30 @@ fn circ_unsupported_io_wire() -> Circuit { hugr.into() } +// Nodes with order edges should be marked as unsupported to preserve the connection. +#[fixture] +fn order_edge() -> Circuit { + let input_t = vec![qb_t(), qb_t()]; + let output_t = vec![qb_t(), qb_t()]; + let mut h = FunctionBuilder::new("order_edge", Signature::new(input_t, output_t)).unwrap(); + + let [q1, q2] = h.input_wires_arr(); + + let cx1 = h.add_dataflow_op(TketOp::CX, [q1, q2]).unwrap(); + let [q1, q2] = cx1.outputs_arr(); + + let cx2 = h.add_dataflow_op(TketOp::CX, [q1, q2]).unwrap(); + let [q1, q2] = cx2.outputs_arr(); + + let cx3 = h.add_dataflow_op(TketOp::CX, [q1, q2]).unwrap(); + let [q1, q2] = cx3.outputs_arr(); + + h.set_order(&cx1, &cx3); + + let hugr = h.finish_hugr_with_outputs([q1, q2]).unwrap(); + hugr.into() +} + /// A circuit that requires tracking info in `extra_subgraph` or `straight_through_wires` /// (see `EncodedCircuitInfo`), for a nested circuit in a CircBox. #[fixture] @@ -919,6 +943,7 @@ fn fail_on_modified_hugr(circ_tk1_ops: Circuit) { #[case::recursive(circ_recursive(), 1, CircuitRoundtripTestConfig::Default)] #[case::independent_subgraph(circ_independent_subgraph(), 3, CircuitRoundtripTestConfig::Default)] #[case::unsupported_io_wire(circ_unsupported_io_wire(), 1, CircuitRoundtripTestConfig::Default)] +#[case::order_edge(order_edge(), 1, CircuitRoundtripTestConfig::Default)] // TODO: We need to track [`EncodedCircuitInfo`] for nested CircBoxes too. We // have temporarily disabled encoding of DFG and function calls as CircBoxes to // avoid an error here. From 5de8616d371203bc2d9e696510a1f9cbbfef0b2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Thu, 6 Nov 2025 12:42:15 +0000 Subject: [PATCH 04/11] fix: Internal-only bits not appearing in pytket circuit's bit register list --- tket/src/serialize/pytket/encoder/value_tracker.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tket/src/serialize/pytket/encoder/value_tracker.rs b/tket/src/serialize/pytket/encoder/value_tracker.rs index b85d9e091..e78fec3ce 100644 --- a/tket/src/serialize/pytket/encoder/value_tracker.rs +++ b/tket/src/serialize/pytket/encoder/value_tracker.rs @@ -486,7 +486,7 @@ impl ValueTracker { Ok(ValueTrackerResult { qubits: self.qubits, - bits: bit_outputs, + bits: self.bits, params: param_outputs, qubit_permutation, input_params: self.input_params, From 56d2be507d78d572088e832e7e2ef42491b05e78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Thu, 6 Nov 2025 14:06:03 +0000 Subject: [PATCH 05/11] feat: Initial piping for translating isomorphic types --- .../serialize/pytket/config/decoder_config.rs | 23 ++++++ .../pytket/config/type_translators.rs | 75 +++++++++++++++++++ tket/src/serialize/pytket/decoder.rs | 22 ++++-- tket/src/serialize/pytket/decoder/subgraph.rs | 6 +- tket/src/serialize/pytket/decoder/wires.rs | 55 +++++++++++--- tket/src/serialize/pytket/error.rs | 14 ++++ tket/src/serialize/pytket/tests.rs | 29 ++++++- 7 files changed, 201 insertions(+), 23 deletions(-) diff --git a/tket/src/serialize/pytket/config/decoder_config.rs b/tket/src/serialize/pytket/config/decoder_config.rs index f4ea1b825..537ec12c6 100644 --- a/tket/src/serialize/pytket/config/decoder_config.rs +++ b/tket/src/serialize/pytket/config/decoder_config.rs @@ -4,7 +4,9 @@ //! A configuration struct contains a list of custom decoders that define //! translations of legacy tket primitives into HUGR operations. +use hugr::builder::DFGBuilder; use hugr::types::Type; +use hugr::{Hugr, Wire}; use itertools::Itertools; use std::collections::HashMap; @@ -107,4 +109,25 @@ impl PytketDecoderConfig { pub fn type_to_pytket(&self, typ: &Type) -> Option { self.type_translators.type_to_pytket(typ) } + + /// Returns `true` if the two types are isomorphic. I.e. they can be translated + /// into each other without losing information. + pub fn types_are_isomorphic(&self, typ1: &Type, typ2: &Type) -> bool { + self.type_translators.types_are_isomorphic(typ1, typ2) + } + + /// Inserts the necessary operations to translate a type into an isomorphic + /// type. + /// + /// This operation fails if [`Self::types_are_isomorphic`] returns `false`. + pub(in crate::serialize::pytket) fn transform_typed_value( + &self, + wire: Wire, + initial_type: &Type, + target_type: &Type, + builder: &mut DFGBuilder<&mut Hugr>, + ) -> Result { + self.type_translators + .transform_typed_value(wire, initial_type, target_type, builder) + } } diff --git a/tket/src/serialize/pytket/config/type_translators.rs b/tket/src/serialize/pytket/config/type_translators.rs index 431e04b24..2e651bd34 100644 --- a/tket/src/serialize/pytket/config/type_translators.rs +++ b/tket/src/serialize/pytket/config/type_translators.rs @@ -6,12 +6,16 @@ use std::collections::HashMap; use std::sync::RwLock; +use hugr::builder::{BuildError, DFGBuilder, Dataflow}; use hugr::extension::prelude::bool_t; use hugr::extension::ExtensionId; use hugr::types::{Type, TypeEnum}; +use hugr::{Hugr, Wire}; use itertools::Itertools; +use crate::extension::bool::BoolOp; use crate::serialize::pytket::extension::{PytketTypeTranslator, RegisterCount}; +use crate::serialize::pytket::{PytketDecodeError, PytketDecodeErrorInner}; /// A set of [`PytketTypeTranslator`]s that can be used to translate HUGR types /// into pytket registers (qubits, bits, and parameter expressions). @@ -128,6 +132,77 @@ impl TypeTranslatorSet { .flatten() .map(move |idx| &self.type_translators[*idx]) } + + /// Returns `true` if the two types are isomorphic. I.e. they can be translated + /// into each other without losing information. + // + // TODO: We should allow custom TypeTranslators to expand this checks, + // and implement their own translations. + pub fn types_are_isomorphic(&self, typ1: &Type, typ2: &Type) -> bool { + if typ1 == typ2 { + return true; + } + + // For now, we just hard-code this to the two kind of bits we support. + let native_bool = bool_t(); + let tket_bool = crate::extension::bool::bool_type(); + if (typ1 == &native_bool && typ2 == &tket_bool) + || (typ1 == &tket_bool && typ2 == &native_bool) + { + return true; + } + + false + } + + /// Inserts the necessary operations to translate a type into an isomorphic + /// type. + /// + /// This operation fails if [`Self::types_are_isomorphic`] returns `false`. + pub(super) fn transform_typed_value( + &self, + wire: Wire, + initial_type: &Type, + target_type: &Type, + builder: &mut DFGBuilder<&mut Hugr>, + ) -> Result { + if initial_type == target_type { + return Ok(wire); + } + + let map_build_error = |e: BuildError| PytketDecodeErrorInner::CannotTranslateWire { + wire, + initial_type: initial_type.to_string(), + target_type: target_type.to_string(), + context: Some(e.to_string()), + }; + + // Hard-coded transformations until customs calls are added to [`PytketTypeTranslator`]. + let native_bool = bool_t(); + let tket_bool = crate::extension::bool::bool_type(); + if initial_type == &native_bool && target_type == &tket_bool { + let [wire] = builder + .add_dataflow_op(BoolOp::make_opaque, [wire]) + .map_err(map_build_error)? + .outputs_arr(); + return Ok(wire); + } + if initial_type == &tket_bool && target_type == &native_bool { + let [wire] = builder + .add_dataflow_op(BoolOp::read, [wire]) + .map_err(map_build_error)? + .outputs_arr(); + return Ok(wire); + } + + Err(PytketDecodeErrorInner::CannotTranslateWire { + wire, + initial_type: initial_type.to_string(), + target_type: target_type.to_string(), + context: None, + } + .wrap()) + } } #[cfg(test)] diff --git a/tket/src/serialize/pytket/decoder.rs b/tket/src/serialize/pytket/decoder.rs index 1cfa02522..503750126 100644 --- a/tket/src/serialize/pytket/decoder.rs +++ b/tket/src/serialize/pytket/decoder.rs @@ -341,7 +341,8 @@ impl<'h> PytketDecoderContext<'h> { let found_wire = self .wire_tracker .find_typed_wire( - self.config(), + &self.config, + &mut self.builder, ty, &mut qubits, &mut bits, @@ -417,7 +418,8 @@ impl<'h> PytketDecoderContext<'h> { for q in qubits.iter() { let mut qubit_args: &[TrackedQubit] = std::slice::from_ref(q); let Ok(FoundWire::Register(wire)) = self.wire_tracker.find_typed_wire( - self.config(), + &self.config, + &mut self.builder, &qb_type, &mut qubit_args, &mut bit_args, @@ -540,14 +542,20 @@ impl<'h> PytketDecoderContext<'h> { /// - [`PytketDecodeErrorInner::UnexpectedInputType`] if a type in `types` cannot be mapped to a [`RegisterCount`] /// - [`PytketDecodeErrorInner::NoMatchingWire`] if there is no wire with the requested type for the given qubit/bit arguments. pub fn find_typed_wires( - &self, + &mut self, types: &[Type], qubit_args: &[TrackedQubit], bit_args: &[TrackedBit], params: &[LoadedParameter], ) -> Result { - self.wire_tracker - .find_typed_wires(self.config(), types, qubit_args, bit_args, params) + self.wire_tracker.find_typed_wires( + &self.config, + &mut self.builder, + types, + qubit_args, + bit_args, + params, + ) } /// Connects the input ports of a node using a list of input qubits, bits, @@ -663,8 +671,8 @@ impl<'h> PytketDecoderContext<'h> { } // Gather the input wires, with the types needed by the operation. - let input_wires = - self.find_typed_wires(sig.input_types(), input_qubits, input_bits, params)?; + let input_types = sig.input_types().to_vec(); + let input_wires = self.find_typed_wires(&input_types, input_qubits, input_bits, params)?; debug_assert_eq!(op_input_count, input_wires.register_count()); for (input_idx, wire) in input_wires.wires().enumerate() { diff --git a/tket/src/serialize/pytket/decoder/subgraph.rs b/tket/src/serialize/pytket/decoder/subgraph.rs index 64f96ea63..b74d2a2ec 100644 --- a/tket/src/serialize/pytket/decoder/subgraph.rs +++ b/tket/src/serialize/pytket/decoder/subgraph.rs @@ -109,7 +109,8 @@ impl<'h> PytketDecoderContext<'h> { .zip_eq(subgraph.incoming_ports()) { let found_wire = self.wire_tracker.find_typed_wire( - self.config(), + &self.config, + &mut self.builder, ty, &mut input_qubits, &mut input_bits, @@ -354,7 +355,8 @@ impl<'h> PytketDecoderContext<'h> { // outputs. for ((ty, edge_id), targets) in payload_inputs.iter().zip_eq(to_insert_inputs) { let found_wire = self.wire_tracker.find_typed_wire( - self.config(), + &self.config, + &mut self.builder, ty, &mut input_qubits, &mut input_bits, diff --git a/tket/src/serialize/pytket/decoder/wires.rs b/tket/src/serialize/pytket/decoder/wires.rs index dc3ace02f..2fe1f253e 100644 --- a/tket/src/serialize/pytket/decoder/wires.rs +++ b/tket/src/serialize/pytket/decoder/wires.rs @@ -508,6 +508,18 @@ impl WireTracker { Ok(self.get_bit(id)) } + /// Mark all the values in a wire as outdated. + fn mark_wire_outdated(&mut self, wire: Wire) { + let wire_data = &self.wires[&wire]; + + for qubit in &wire_data.qubits { + self.qubits[qubit.0].mark_outdated(); + } + for bit in &wire_data.bits { + self.bits[bit.0].mark_outdated(); + } + } + /// Mark a qubit as outdated, without adding a new wire containing the fresh value. /// /// This is used when a hugr operation consumes pytket registers as its inputs, but doesn't use them in the outputs. @@ -598,6 +610,9 @@ impl WireTracker { /// The qubit and bit arguments are only consumed as required by the type, /// some registers may be left unused. /// + /// If the wire type require additional conversion, some operations will be + /// added to the Hugr to perform it. + /// /// # Arguments /// /// * `config` - The configuration for the decoder, used to count the qubits @@ -614,9 +629,11 @@ impl WireTracker { /// # Errors /// /// See [`WireTracker::find_typed_wires`] for possible errors. + #[allow(clippy::too_many_arguments)] pub(in crate::serialize::pytket) fn find_typed_wire( - &self, + &mut self, config: &PytketDecoderConfig, + builder: &mut DFGBuilder<&mut Hugr>, ty: &Type, qubit_args: &mut &[TrackedQubit], bit_args: &mut &[TrackedBit], @@ -668,18 +685,18 @@ impl WireTracker { let wire_qubits = qubit_args .iter() .take(reg_count.qubits) - .map(|q| q.id()) - .collect_vec(); - let wire_bits = bit_args - .iter() - .take(reg_count.bits) - .map(|bit| bit.id()) + .cloned() .collect_vec(); + let wire_qubit_ids = wire_qubits.iter().map(|q| q.id()).collect_vec(); + let wire_bits = bit_args.iter().take(reg_count.bits).cloned().collect_vec(); + let wire_bit_ids = wire_bits.iter().map(|bit| bit.id()).collect_vec(); // Find a wire that contains the correct type.. let check_wire = |w: &Wire| { let wire_data = &self.wires[w]; - wire_data.ty() == ty && wire_data.qubits == wire_qubits && wire_data.bits == wire_bits + wire_data.qubits == wire_qubit_ids + && wire_data.bits == wire_bit_ids + && config.types_are_isomorphic(wire_data.ty(), ty) }; let Some(wire) = candidate.find(check_wire) else { return Err(PytketDecodeErrorInner::NoMatchingWire { @@ -695,6 +712,7 @@ impl WireTracker { } .wrap()); }; + drop(candidate); // Check that none of the selected qubit or bit has been marked as outdated. if let Some(qubit) = qubit_args @@ -719,11 +737,21 @@ impl WireTracker { } // Mark the qubits and bits as used. - // TODO: We can use the slice `split_off` method once MSRV is ≥1.87 *qubit_args = &qubit_args[reg_count.qubits..]; *bit_args = &bit_args[reg_count.bits..]; - Ok(FoundWire::Register(self.wires[&wire].clone())) + // Convert the wire type, if needed. + let wire_data = &self.wires[&wire]; + let new_wire = config.transform_typed_value(wire, wire_data.ty(), ty, builder)?; + + if wire == new_wire { + Ok(FoundWire::Register(self.wires[&wire].clone())) + } else { + let ty: Arc = wire_data.ty.clone(); + self.track_wire(new_wire, ty, wire_qubits, wire_bits)?; + self.mark_wire_outdated(wire); + Ok(FoundWire::Register(self.wires[&new_wire].clone())) + } } /// Returns a new [TrackedWires] set for a list of [`TrackedQubit`]s, @@ -735,6 +763,9 @@ impl WireTracker { /// The qubit and bit arguments are only consumed as required by the types. /// Some registers may be left unused. /// + /// If the wire type require additional conversion, some operations will be + /// added to the Hugr to perform it. + /// /// # Arguments /// /// * `config` - The configuration for the decoder, used to count the qubits and bits required by each type. @@ -751,8 +782,9 @@ impl WireTracker { /// - [`PytketDecodeErrorInner::UnexpectedInputType`] if a type in `types` cannot be mapped to a [`RegisterCount`] /// - [`PytketDecodeErrorInner::NoMatchingWire`] if there is no wire with the requested type for the given qubit/bit arguments. pub(super) fn find_typed_wires( - &self, + &mut self, config: &PytketDecoderConfig, + builder: &mut DFGBuilder<&mut Hugr>, types: &[Type], mut qubit_args: &[TrackedQubit], mut bit_args: &[TrackedBit], @@ -768,6 +800,7 @@ impl WireTracker { for ty in types { match self.find_typed_wire( config, + builder, ty, &mut qubit_args, &mut bit_args, diff --git a/tket/src/serialize/pytket/error.rs b/tket/src/serialize/pytket/error.rs index cca397d38..9c7924a95 100644 --- a/tket/src/serialize/pytket/error.rs +++ b/tket/src/serialize/pytket/error.rs @@ -432,6 +432,20 @@ pub enum PytketDecodeErrorInner { /// The envelope decoding error. source: EnvelopeError, }, + /// Cannot translate a wire from one type to another. + #[display("Cannot translate {wire} from type {initial_type} to type {target_type}{}", + context.as_ref().map(|s| format!(". {s}")).unwrap_or_default() + )] + CannotTranslateWire { + /// The wire that couldn't be translated. + wire: Wire, + /// The initial type of the wire. + initial_type: String, + /// The target type of the wire. + target_type: String, + /// The error that occurred while translating the wire. + context: Option, + }, } impl PytketDecodeErrorInner { diff --git a/tket/src/serialize/pytket/tests.rs b/tket/src/serialize/pytket/tests.rs index a2755dab9..e85dc1e66 100644 --- a/tket/src/serialize/pytket/tests.rs +++ b/tket/src/serialize/pytket/tests.rs @@ -27,7 +27,7 @@ use tket_json_rs::register; use super::{TKETDecode, METADATA_INPUT_PARAMETERS, METADATA_Q_REGISTERS}; use crate::circuit::Circuit; -use crate::extension::bool::BoolOp; +use crate::extension::bool::{bool_type, BoolOp}; use crate::extension::rotation::{rotation_type, ConstRotation, RotationOp}; use crate::extension::sympy::SympyOpDef; use crate::extension::TKET1_EXTENSION_ID; @@ -626,7 +626,7 @@ fn circ_unsupported_io_wire() -> Circuit { // Nodes with order edges should be marked as unsupported to preserve the connection. #[fixture] -fn order_edge() -> Circuit { +fn circ_order_edge() -> Circuit { let input_t = vec![qb_t(), qb_t()]; let output_t = vec![qb_t(), qb_t()]; let mut h = FunctionBuilder::new("order_edge", Signature::new(input_t, output_t)).unwrap(); @@ -648,6 +648,28 @@ fn order_edge() -> Circuit { hugr.into() } +// Nodes with order edges should be marked as unsupported to preserve the connection. +#[fixture] +fn circ_bool_conversion() -> Circuit { + let input_t = vec![bool_t(), bool_type()]; + let output_t = vec![bool_t(), bool_type()]; + let mut h = FunctionBuilder::new("bool_conversion", Signature::new(input_t, output_t)).unwrap(); + + let [native_b0, tket_b1] = h.input_wires_arr(); + + let [tket_b0] = h + .add_dataflow_op(BoolOp::make_opaque, [native_b0]) + .unwrap() + .outputs_arr(); + let [native_b1] = h + .add_dataflow_op(BoolOp::read, [tket_b1]) + .unwrap() + .outputs_arr(); + + let hugr = h.finish_hugr_with_outputs([native_b1, tket_b0]).unwrap(); + hugr.into() +} + /// A circuit that requires tracking info in `extra_subgraph` or `straight_through_wires` /// (see `EncodedCircuitInfo`), for a nested circuit in a CircBox. #[fixture] @@ -943,7 +965,8 @@ fn fail_on_modified_hugr(circ_tk1_ops: Circuit) { #[case::recursive(circ_recursive(), 1, CircuitRoundtripTestConfig::Default)] #[case::independent_subgraph(circ_independent_subgraph(), 3, CircuitRoundtripTestConfig::Default)] #[case::unsupported_io_wire(circ_unsupported_io_wire(), 1, CircuitRoundtripTestConfig::Default)] -#[case::order_edge(order_edge(), 1, CircuitRoundtripTestConfig::Default)] +#[case::order_edge(circ_order_edge(), 1, CircuitRoundtripTestConfig::Default)] +#[case::bool_conversion(circ_bool_conversion(), 1, CircuitRoundtripTestConfig::Default)] // TODO: We need to track [`EncodedCircuitInfo`] for nested CircBoxes too. We // have temporarily disabled encoding of DFG and function calls as CircBoxes to // avoid an error here. From 431b8e67dd8ff126c6ba0864ee29f24f58b7a6b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Thu, 6 Nov 2025 14:33:37 +0000 Subject: [PATCH 06/11] fix: Allocate qbs declared by pytket only if consumed --- tket/src/serialize/pytket/decoder.rs | 14 ++-- .../serialize/pytket/decoder/tracked_elem.rs | 10 +++ tket/src/serialize/pytket/decoder/wires.rs | 79 ++++++++++++++----- tket/src/serialize/pytket/tests.rs | 2 - 4 files changed, 79 insertions(+), 26 deletions(-) diff --git a/tket/src/serialize/pytket/decoder.rs b/tket/src/serialize/pytket/decoder.rs index 503750126..36b491a8b 100644 --- a/tket/src/serialize/pytket/decoder.rs +++ b/tket/src/serialize/pytket/decoder.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use hugr::builder::{BuildHandle, Container, DFGBuilder, Dataflow, FunctionBuilder, SubContainer}; use hugr::extension::prelude::{bool_t, qb_t}; use hugr::ops::handle::{DataflowOpID, NodeHandle}; -use hugr::ops::{OpParent, OpTrait, OpType, Value, DFG}; +use hugr::ops::{OpParent, OpTrait, OpType, DFG}; use hugr::types::{Signature, Type, TypeRow}; use hugr::{Hugr, HugrView, Node, OutgoingPort, Wire}; use tracked_elem::{TrackedBitId, TrackedQubitId}; @@ -272,14 +272,16 @@ impl<'h> PytketDecoderContext<'h> { wire_tracker.register_input_parameter(LoadedParameter::rotation(wire), param)?; } - // Any additional qubits or bits required by the circuit get initialized to |0> / false. + // Any additional qubits or bits required by the circuit are registered + // in the tracker without a wire being created. + // + // We'll lazily initialize them with a QAlloc or a LoadConstant + // operation if necessary. for q in qubits { - let q_wire = dfg.add_dataflow_op(TketOp::QAlloc, []).unwrap().out_wire(0); - wire_tracker.track_wire(q_wire, q.ty(), [q], [])?; + wire_tracker.track_qubit(q.pytket_register_arc(), Some(q.reg_hash()))?; } for b in bits { - let b_wire = dfg.add_load_value(Value::false_val()); - wire_tracker.track_wire(b_wire, b.ty(), [], [b])?; + wire_tracker.track_bit(b.pytket_register_arc(), Some(b.reg_hash()))?; } wire_tracker.compute_output_permutation(&serialcirc.implicit_permutation); diff --git a/tket/src/serialize/pytket/decoder/tracked_elem.rs b/tket/src/serialize/pytket/decoder/tracked_elem.rs index b1e9b93b2..95d2eb643 100644 --- a/tket/src/serialize/pytket/decoder/tracked_elem.rs +++ b/tket/src/serialize/pytket/decoder/tracked_elem.rs @@ -105,6 +105,11 @@ impl TrackedQubit { self.id } + /// Returns the hash of the pytket register for this tracked element. + pub(super) fn reg_hash(&self) -> RegisterHash { + self.reg_hash + } + /// Returns `true` if the element has been overwritten by a new value. pub fn is_outdated(&self) -> bool { self.outdated @@ -158,6 +163,11 @@ impl TrackedBit { self.id } + /// Returns the hash of the pytket register for this tracked element. + pub(super) fn reg_hash(&self) -> RegisterHash { + self.reg_hash + } + /// Returns `true` if the element has been overwritten by a new value. pub fn is_outdated(&self) -> bool { self.outdated diff --git a/tket/src/serialize/pytket/decoder/wires.rs b/tket/src/serialize/pytket/decoder/wires.rs index 2fe1f253e..d206245a5 100644 --- a/tket/src/serialize/pytket/decoder/wires.rs +++ b/tket/src/serialize/pytket/decoder/wires.rs @@ -4,6 +4,7 @@ use std::collections::{BTreeMap, VecDeque}; use std::sync::Arc; use hugr::builder::{DFGBuilder, Dataflow as _}; +use hugr::extension::prelude::{bool_t, qb_t}; use hugr::hugr::hugrmut::HugrMut; use hugr::ops::Value; use hugr::std_extensions::arithmetic::float_types::{float64_type, ConstF64}; @@ -14,6 +15,7 @@ use itertools::Itertools; use tket_json_rs::circuit_json::ImplicitPermutation; use tket_json_rs::register::ElementId as PytketRegister; +use crate::extension::bool::bool_type; use crate::extension::rotation::{rotation_type, ConstRotation}; use crate::serialize::pytket::decoder::param::parser::{parse_pytket_param, PytketParam}; use crate::serialize::pytket::decoder::{ @@ -25,7 +27,7 @@ use crate::serialize::pytket::opaque::EncodedEdgeID; use crate::serialize::pytket::{ PytketDecodeError, PytketDecodeErrorInner, PytketDecoderConfig, RegisterHash, }; -use crate::symbolic_constant_op; +use crate::{symbolic_constant_op, TketOp}; /// Tracked data for a wire in [`TrackedWires`]. #[derive(Debug, Clone, PartialEq)] @@ -481,6 +483,7 @@ impl WireTracker { if let Some(previous_id) = self.latest_qubit_tracker.insert(hash, id) { self.qubits[previous_id.0].mark_outdated(); } + self.qubit_wires.insert(id, Vec::new()); Ok(self.get_qubit(id)) } @@ -504,7 +507,7 @@ impl WireTracker { if let Some(previous_id) = self.latest_bit_tracker.insert(hash, id) { self.bits[previous_id.0].mark_outdated(); } - + self.bit_wires.insert(id, Vec::new()); Ok(self.get_bit(id)) } @@ -679,7 +682,7 @@ impl WireTracker { .first() .into_iter() .flat_map(|bit| self.bit_wires(bit)); - let mut candidate = qubit_candidates.chain(bit_candidates); + let candidates = qubit_candidates.chain(bit_candidates).collect_vec(); // The bits and qubits we expect the wire to contain. let wire_qubits = qubit_args @@ -698,21 +701,30 @@ impl WireTracker { && wire_data.bits == wire_bit_ids && config.types_are_isomorphic(wire_data.ty(), ty) }; - let Some(wire) = candidate.find(check_wire) else { - return Err(PytketDecodeErrorInner::NoMatchingWire { - ty: ty.to_string(), - qubit_args: qubit_args - .iter() - .map(|q| q.pytket_register().to_string()) - .collect(), - bit_args: bit_args - .iter() - .map(|bit| bit.pytket_register().to_string()) - .collect(), + let wire = match candidates.into_iter().find(check_wire) { + Some(wire) => wire, + // Handle lazy initialization of qubit and bit wires. These are + // normally qubits/bits present in the pytket circuit definition, + // but not in the region's input. + _ if ty == &qb_t() => self.initialize_qubit_wire(builder, qubit_args[0].clone())?, + _ if ty == &bool_t() || ty == &bool_type() => { + self.initialize_bit_wire(builder, bit_args[0].clone())? + } + _ => { + return Err(PytketDecodeErrorInner::NoMatchingWire { + ty: ty.to_string(), + qubit_args: qubit_args + .iter() + .map(|q| q.pytket_register().to_string()) + .collect(), + bit_args: bit_args + .iter() + .map(|bit| bit.pytket_register().to_string()) + .collect(), + } + .wrap()); } - .wrap()); }; - drop(candidate); // Check that none of the selected qubit or bit has been marked as outdated. if let Some(qubit) = qubit_args @@ -991,10 +1003,10 @@ impl WireTracker { .collect::>()?; for &q in &qubits { - self.qubit_wires.entry(q).or_default().push(wire); + self.qubit_wires[&q].push(wire); } for &b in &bits { - self.bit_wires.entry(b).or_default().push(wire); + self.bit_wires[&b].push(wire); } let wire_data = WireData { @@ -1085,6 +1097,37 @@ impl WireTracker { } } } + + /// Initialize a qubit wire that has been declared earlier. + /// + /// This is used when a qubit is declared in the pytket circuit definition, + /// but not in the region's input. + fn initialize_qubit_wire( + &mut self, + builder: &mut DFGBuilder<&mut Hugr>, + qubit: TrackedQubit, + ) -> Result { + let wire = builder + .add_dataflow_op(TketOp::QAlloc, []) + .unwrap() + .out_wire(0); + self.track_wire(wire, qubit.ty(), [qubit], [])?; + Ok(wire) + } + + /// Initialize a bit wire that has been declared earlier. + /// + /// This is used when a bit is declared in the pytket circuit definition, + /// but not in the region's input. + fn initialize_bit_wire( + &mut self, + builder: &mut DFGBuilder<&mut Hugr>, + bit: TrackedBit, + ) -> Result { + let wire = builder.add_load_const(Value::false_val()); + self.track_wire(wire, bit.ty(), [], [bit])?; + Ok(wire) + } } /// Only single-indexed registers are supported. diff --git a/tket/src/serialize/pytket/tests.rs b/tket/src/serialize/pytket/tests.rs index e85dc1e66..873f0e96f 100644 --- a/tket/src/serialize/pytket/tests.rs +++ b/tket/src/serialize/pytket/tests.rs @@ -958,8 +958,6 @@ fn fail_on_modified_hugr(circ_tk1_ops: Circuit) { #[case::preset_parameterized(circ_parameterized(), 1, CircuitRoundtripTestConfig::Default)] #[case::nested_dfgs(circ_nested_dfgs(), 2, CircuitRoundtripTestConfig::Default)] #[case::flat_opaque(circ_tk1_ops(), 1, CircuitRoundtripTestConfig::Default)] -// TODO: Fail due to eagerly emitting QAllocs that never get consumed. We should do that lazily. -#[should_panic(expected = "has an unconnected port")] #[case::unsupported_subtree(circ_unsupported_subtree(), 3, CircuitRoundtripTestConfig::Default)] #[case::global_defs(circ_global_defs(), 1, CircuitRoundtripTestConfig::Default)] #[case::recursive(circ_recursive(), 1, CircuitRoundtripTestConfig::Default)] From 85ccdeefe2af653c53652440cc48c2fe9c12a784 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Thu, 6 Nov 2025 15:46:21 +0000 Subject: [PATCH 07/11] fix: Make sure output params have the right type --- tket/src/serialize/pytket/decoder.rs | 10 +++++++++- tket/src/serialize/pytket/tests.rs | 10 ++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/tket/src/serialize/pytket/decoder.rs b/tket/src/serialize/pytket/decoder.rs index 36b491a8b..e15fa88eb 100644 --- a/tket/src/serialize/pytket/decoder.rs +++ b/tket/src/serialize/pytket/decoder.rs @@ -7,6 +7,7 @@ mod wires; use hugr::extension::ExtensionRegistry; use hugr::hugr::hugrmut::HugrMut; +use hugr::std_extensions::arithmetic::float_types::float64_type; pub use param::{LoadedParameter, ParameterType}; pub use tracked_elem::{TrackedBit, TrackedQubit}; pub use wires::TrackedWires; @@ -372,7 +373,14 @@ impl<'h> PytketDecoderContext<'h> { let wire = match found_wire { FoundWire::Register(wire) => wire.wire(), - FoundWire::Parameter(param) => param.wire(), + FoundWire::Parameter(param) => { + let param_ty = if ty == &float64_type() { + ParameterType::FloatHalfTurns + } else { + ParameterType::Rotation + }; + param.with_type(param_ty, &mut self.builder).wire() + } FoundWire::Unsupported { .. } => { // Disconnected port with an unsupported type. We just skip // it, since it must have been disconnected in the original diff --git a/tket/src/serialize/pytket/tests.rs b/tket/src/serialize/pytket/tests.rs index 873f0e96f..5e686ff4b 100644 --- a/tket/src/serialize/pytket/tests.rs +++ b/tket/src/serialize/pytket/tests.rs @@ -708,7 +708,7 @@ fn circ_unsupported_extras_in_circ_box() -> Circuit { #[fixture] fn circ_output_parameter_wire() -> Circuit { let input_t = vec![]; - let output_t = vec![float64_type()]; + let output_t = vec![float64_type(), rotation_type()]; let mut h = FunctionBuilder::new("output_parameter_wire", Signature::new(input_t, output_t)).unwrap(); @@ -718,8 +718,14 @@ fn circ_output_parameter_wire() -> Circuit { .add_dataflow_op(FloatOps::fmul, [pi, two]) .unwrap() .out_wire(0); + let two_pi_rotation = h + .add_dataflow_op(RotationOp::from_halfturns_unchecked, [two_pi]) + .unwrap() + .out_wire(0); - let hugr = h.finish_hugr_with_outputs([two_pi]).unwrap(); + let hugr = h + .finish_hugr_with_outputs([two_pi, two_pi_rotation]) + .unwrap(); hugr.into() } From a585069faf224370d9d03cac2f20fb7090455bca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Thu, 6 Nov 2025 17:44:24 +0000 Subject: [PATCH 08/11] fix: Don't try to encode wires with complex types that include a parameter --- .../pytket/config/type_translators.rs | 20 +++++++++++++++++-- tket/src/serialize/pytket/encoder.rs | 10 ++++++---- .../src/serialize/pytket/extension/prelude.rs | 2 +- tket/src/serialize/pytket/tests.rs | 18 ++++++++++++++++- 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/tket/src/serialize/pytket/config/type_translators.rs b/tket/src/serialize/pytket/config/type_translators.rs index 2e651bd34..961d46942 100644 --- a/tket/src/serialize/pytket/config/type_translators.rs +++ b/tket/src/serialize/pytket/config/type_translators.rs @@ -9,11 +9,13 @@ use std::sync::RwLock; use hugr::builder::{BuildError, DFGBuilder, Dataflow}; use hugr::extension::prelude::bool_t; use hugr::extension::ExtensionId; +use hugr::std_extensions::arithmetic::float_types; use hugr::types::{Type, TypeEnum}; use hugr::{Hugr, Wire}; use itertools::Itertools; use crate::extension::bool::BoolOp; +use crate::extension::rotation; use crate::serialize::pytket::extension::{PytketTypeTranslator, RegisterCount}; use crate::serialize::pytket::{PytketDecodeError, PytketDecodeErrorInner}; @@ -77,6 +79,14 @@ impl TypeTranslatorSet { return count; } + // We currently don't allow user types to contain parameters, + // so we handle rotations and floats manually here. + if typ.as_extension().is_some_and(|ext| { + [float_types::EXTENSION_ID, rotation::ROTATION_EXTENSION_ID].contains(ext.extension()) + }) { + return Some(RegisterCount::only_params(1)); + } + let res = match typ.as_type_enum() { TypeEnum::Sum(sum) => { if sum.num_variants() == 0 { @@ -96,7 +106,8 @@ impl TypeTranslatorSet { } }) .sum(); - count + // Don't allow parameters nested inside other types + count.filter(|c| c.params == 0) } else { None } @@ -105,7 +116,12 @@ impl TypeTranslatorSet { let type_ext = custom.extension(); for encoder in self.translators_for_extension(type_ext) { if let Some(count) = encoder.type_to_pytket(custom, self) { - break 'outer Some(count); + // Don't allow user types with nested parameters + if count.params == 0 { + break 'outer Some(count); + } else { + break 'outer None; + } } } None diff --git a/tket/src/serialize/pytket/encoder.rs b/tket/src/serialize/pytket/encoder.rs index 5268b68ed..fd52e3154 100644 --- a/tket/src/serialize/pytket/encoder.rs +++ b/tket/src/serialize/pytket/encoder.rs @@ -895,10 +895,12 @@ impl PytketEncoderContext { } OpType::Const(op) => { let config = Arc::clone(&self.config); - if let Some(values) = config.const_to_pytket(&op.value, self)? { - let wire = Wire::new(node, 0); - self.values.register_wire(wire, values.into_iter(), circ)?; - return Ok(EncodeStatus::Success); + if self.config().type_to_pytket(&op.get_type()).is_some() { + if let Some(values) = config.const_to_pytket(&op.value, self)? { + let wire = Wire::new(node, 0); + self.values.register_wire(wire, values.into_iter(), circ)?; + return Ok(EncodeStatus::Success); + } } } // TODO: DFG and function call emissions are temporarily disabled, diff --git a/tket/src/serialize/pytket/extension/prelude.rs b/tket/src/serialize/pytket/extension/prelude.rs index 7b7b2ebf7..becef3c2c 100644 --- a/tket/src/serialize/pytket/extension/prelude.rs +++ b/tket/src/serialize/pytket/extension/prelude.rs @@ -102,7 +102,7 @@ impl PreludeEmitter { return Ok(EncodeStatus::Unsupported); }; let count = encoder.config().type_to_pytket(ty); - if count.is_none() { + if count.is_none_or(|c| c.params > 0) { return Ok(EncodeStatus::Unsupported); } } diff --git a/tket/src/serialize/pytket/tests.rs b/tket/src/serialize/pytket/tests.rs index 5e686ff4b..87d7404a8 100644 --- a/tket/src/serialize/pytket/tests.rs +++ b/tket/src/serialize/pytket/tests.rs @@ -17,7 +17,7 @@ use hugr::hugr::hugrmut::HugrMut; use hugr::ops::handle::FuncID; use hugr::ops::{OpParent, OpType, Value}; use hugr::std_extensions::arithmetic::float_ops::FloatOps; -use hugr::types::Signature; +use hugr::types::{Signature, SumType}; use hugr::HugrView; use itertools::Itertools; use rstest::{fixture, rstest}; @@ -729,6 +729,21 @@ fn circ_output_parameter_wire() -> Circuit { hugr.into() } +// A circuit with a [float64] wire, which should be treated as unsupported. +#[fixture] +fn circ_complex_param_type() -> Circuit { + let input_t = vec![]; + let output_t = vec![SumType::new_tuple(vec![float64_type()]).into()]; + let mut h = + FunctionBuilder::new("complex_param_type", Signature::new(input_t, output_t)).unwrap(); + + let float64 = h.add_load_value(ConstF64::new(1.0)); + let float_tuple = h.make_tuple([float64]).unwrap(); + + let hugr = h.finish_hugr_with_outputs([float_tuple]).unwrap(); + hugr.into() +} + /// Check that all circuit ops have been translated to a native gate. /// /// Panics if there are tk1 ops in the circuit. @@ -971,6 +986,7 @@ fn fail_on_modified_hugr(circ_tk1_ops: Circuit) { #[case::unsupported_io_wire(circ_unsupported_io_wire(), 1, CircuitRoundtripTestConfig::Default)] #[case::order_edge(circ_order_edge(), 1, CircuitRoundtripTestConfig::Default)] #[case::bool_conversion(circ_bool_conversion(), 1, CircuitRoundtripTestConfig::Default)] +#[case::complex_param_type(circ_complex_param_type(), 1, CircuitRoundtripTestConfig::Default)] // TODO: We need to track [`EncodedCircuitInfo`] for nested CircBoxes too. We // have temporarily disabled encoding of DFG and function calls as CircBoxes to // avoid an error here. From 2475a905602d5c9b94ae90f358551599bbcedb4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Thu, 6 Nov 2025 18:10:40 +0000 Subject: [PATCH 09/11] fix: encode input parameters to extra_subgraphs --- tket/src/serialize/pytket.rs | 2 +- tket/src/serialize/pytket/circuit.rs | 34 +++++++++------ tket/src/serialize/pytket/decoder.rs | 48 +++++++++++++-------- tket/src/serialize/pytket/decoder/wires.rs | 11 ++++- tket/src/serialize/pytket/encoder.rs | 18 +++++--- tket/src/serialize/pytket/error.rs | 9 +++- tket/src/serialize/pytket/extension/core.rs | 2 +- 7 files changed, 82 insertions(+), 42 deletions(-) diff --git a/tket/src/serialize/pytket.rs b/tket/src/serialize/pytket.rs index be5dab9a8..07af6e238 100644 --- a/tket/src/serialize/pytket.rs +++ b/tket/src/serialize/pytket.rs @@ -133,7 +133,7 @@ impl TKETDecode for SerialCircuit { options: DecodeOptions, ) -> Result { let mut decoder = PytketDecoderContext::new(self, hugr, target, options, None)?; - decoder.run_decoder(&self.commands, None, &[])?; + decoder.run_decoder(&self.commands, None)?; Ok(decoder.finish(&[])?.node()) } diff --git a/tket/src/serialize/pytket/circuit.rs b/tket/src/serialize/pytket/circuit.rs index 6b3aa8937..7be63ae44 100644 --- a/tket/src/serialize/pytket/circuit.rs +++ b/tket/src/serialize/pytket/circuit.rs @@ -55,15 +55,8 @@ pub struct EncodedCircuit { pub(super) struct EncodedCircuitInfo { /// The serial circuit encoded from the region. pub serial_circuit: SerialCircuit, - /// A subgraph of the region that does not contain any operation encodable - /// as a pytket command, and has no qubit/bits in its boundary that could be - /// used to emit an opaque barrier command in the [`serial_circuit`]. - pub extra_subgraph: Option, - /// List of wires that directly connected the input node to the output node in the encoded region, - /// and were not encoded in [`serial_circuit`]. - /// - /// We just store the input nodes's output port and output node's input port here. - pub straight_through_wires: Vec, + /// Information about any unsupported nodes in the region that could not be encoded as a pytket command. + pub additional_nodes_and_wires: AdditionalNodesAndWires, /// List of parameters in the pytket circuit in the order they appear in the /// hugr input. /// @@ -77,6 +70,24 @@ pub(super) struct EncodedCircuitInfo { pub output_params: Vec, } +/// Nodes and edges from the original region that could not be encoded into the +/// pytket circuit, as they cannot be attached to a pytket command. +#[derive(Debug, Clone)] +pub(super) struct AdditionalNodesAndWires { + /// A subgraph of the region that does not contain any operation encodable + /// as a pytket command, and has no qubit/bits in its boundary that could be + /// used to emit an opaque barrier command in the [`serial_circuit`]. + pub extra_subgraph: Option, + /// Parameter expression inputs to the `extra_subgraph`. + /// These cannot be encoded either if there's no pytket command to attach them to. + pub extra_subgraph_params: Vec, + /// List of wires that directly connected the input node to the output node in the encoded region, + /// and were not encoded in [`serial_circuit`]. + /// + /// We just store the input nodes's output port and output node's input port here. + pub straight_through_wires: Vec, +} + /// A wire stored in the [`EncodedCircuitInfo`] that directly connected the /// input node to the output node in the encoded region, and was not encoded in /// the pytket circuit. @@ -181,8 +192,7 @@ impl EncodedCircuit { )?; decoder.run_decoder( &encoded.serial_circuit.commands, - encoded.extra_subgraph, - &encoded.straight_through_wires, + Some(&encoded.additional_nodes_and_wires), )?; let decoded_node = decoder.finish(&encoded.output_params)?.node(); @@ -333,7 +343,7 @@ impl EncodedCircuit { let mut decoder = PytketDecoderContext::new(serial_circuit, &mut hugr, target, options, None)?; - decoder.run_decoder(&serial_circuit.commands, None, &[])?; + decoder.run_decoder(&serial_circuit.commands, None)?; decoder.finish(&[])?; Ok(hugr) } diff --git a/tket/src/serialize/pytket/decoder.rs b/tket/src/serialize/pytket/decoder.rs index e15fa88eb..0b29d1a44 100644 --- a/tket/src/serialize/pytket/decoder.rs +++ b/tket/src/serialize/pytket/decoder.rs @@ -34,11 +34,11 @@ use super::{ METADATA_Q_REGISTERS, }; use crate::extension::rotation::rotation_type; -use crate::serialize::pytket::circuit::StraightThroughWire; +use crate::serialize::pytket::circuit::{AdditionalNodesAndWires, StraightThroughWire}; use crate::serialize::pytket::config::PytketDecoderConfig; use crate::serialize::pytket::decoder::wires::WireTracker; use crate::serialize::pytket::extension::{build_opaque_tket_op, RegisterCount}; -use crate::serialize::pytket::opaque::{EncodedEdgeID, OpaqueSubgraphs, SubgraphId}; +use crate::serialize::pytket::opaque::{EncodedEdgeID, OpaqueSubgraphs}; use crate::serialize::pytket::{ default_decoder_config, DecodeInsertionTarget, DecodeOptions, PytketDecodeErrorInner, }; @@ -460,8 +460,7 @@ impl<'h> PytketDecoderContext<'h> { pub(super) fn run_decoder( &mut self, commands: &[circuit_json::Command], - extra_subgraph: Option, - straight_through_wires: &[StraightThroughWire], + extra_nodes_and_wires: Option<&AdditionalNodesAndWires>, ) -> Result<(), PytketDecodeError> { let config = self.config().clone(); for com in commands { @@ -470,22 +469,33 @@ impl<'h> PytketDecoderContext<'h> { .map_err(|e| e.pytket_op(&op_type))?; } - // Add additional subgraphs if not encoded in commands. - if let Some(subgraph_id) = extra_subgraph { - self.insert_external_subgraph(subgraph_id, &[], &[], &[]) - .map_err(|e| e.hugr_op("External subgraph"))?; - } - - // Add wires from the input node to the output node that didn't get encoded in commands. + // Add additional subgraphs and wires not encoded in commands. let [input_node, output_node] = self.builder.io(); - for StraightThroughWire { - input_source, - output_target, - } in straight_through_wires - { - self.builder - .hugr_mut() - .connect(input_node, *input_source, output_node, *output_target); + if let Some(extras) = extra_nodes_and_wires { + if let Some(subgraph_id) = extras.extra_subgraph { + let params = extras + .extra_subgraph_params + .iter() + .map(|p| self.load_half_turns(p)) + .collect_vec(); + + self.insert_external_subgraph(subgraph_id, &[], &[], ¶ms) + .map_err(|e| e.hugr_op("External subgraph"))?; + } + + // Add wires from the input node to the output node that didn't get encoded in commands. + for StraightThroughWire { + input_source, + output_target, + } in &extras.straight_through_wires + { + self.builder.hugr_mut().connect( + input_node, + *input_source, + output_node, + *output_target, + ); + } } Ok(()) } diff --git a/tket/src/serialize/pytket/decoder/wires.rs b/tket/src/serialize/pytket/decoder/wires.rs index d206245a5..9867a49aa 100644 --- a/tket/src/serialize/pytket/decoder/wires.rs +++ b/tket/src/serialize/pytket/decoder/wires.rs @@ -657,7 +657,16 @@ impl WireTracker { PytketDecodeErrorInner::NoMatchingParameter { ty: ty.to_string() }.wrap(), ); }; - return Ok(FoundWire::Parameter(*param)); + if ty == param.wire_type() { + return Ok(FoundWire::Parameter(*param)); + } + // Convert between half-turn floats and rotations as needed. + let param_ty = if ty == &float64_type() { + ParameterType::FloatHalfTurns + } else { + ParameterType::Rotation + }; + return Ok(FoundWire::Parameter(param.with_type(param_ty, builder))); } // Translate the wire type to a pytket register count. diff --git a/tket/src/serialize/pytket/encoder.rs b/tket/src/serialize/pytket/encoder.rs index fd52e3154..e2e4b93d3 100644 --- a/tket/src/serialize/pytket/encoder.rs +++ b/tket/src/serialize/pytket/encoder.rs @@ -30,7 +30,7 @@ use super::{ PytketEncodeError, PytketEncodeOpError, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_REGISTERS, }; use crate::circuit::Circuit; -use crate::serialize::pytket::circuit::EncodedCircuitInfo; +use crate::serialize::pytket::circuit::{AdditionalNodesAndWires, EncodedCircuitInfo}; use crate::serialize::pytket::config::PytketEncoderConfig; use crate::serialize::pytket::extension::RegisterCount; use crate::serialize::pytket::opaque::{ @@ -227,17 +227,19 @@ impl PytketEncoderContext { ) -> Result<(EncodedCircuitInfo, OpaqueSubgraphs), PytketEncodeError> { // Add any remaining unsupported nodes let mut extra_subgraph: Option> = None; + let mut extra_subgraph_params = Vec::new(); while !self.unsupported.is_empty() { let node = self.unsupported.iter().next().unwrap(); let opaque_subgraphs = self.unsupported.extract_component(node, circ.hugr())?; match self.emit_unsupported(&opaque_subgraphs, circ) { Ok(()) => (), - Err(PytketEncodeError::UnsupportedSubgraphHasNoRegisters {}) => { + Err(PytketEncodeError::UnsupportedSubgraphHasNoRegisters { params }) => { // We'll store the nodes in the `extra_subgraph` field of the `EncodedCircuitInfo`. // So the decoder can reconstruct the original subgraph. extra_subgraph .get_or_insert_default() .extend(opaque_subgraphs.nodes().iter().cloned()); + extra_subgraph_params.extend(params); } Err(e) => return Err(e), } @@ -263,8 +265,11 @@ impl PytketEncoderContext { serial_circuit: ser, input_params: tracker_result.input_params, output_params: tracker_result.params, - extra_subgraph, - straight_through_wires: tracker_result.straight_through_wires, + additional_nodes_and_wires: AdditionalNodesAndWires { + extra_subgraph, + extra_subgraph_params, + straight_through_wires: tracker_result.straight_through_wires, + }, }; Ok((info, self.opaque_subgraphs)) @@ -605,7 +610,6 @@ impl PytketEncoderContext { // If the wire is not tracked, no need to consume it. continue; }; - op_values.extend(tracked_values.iter().cloned()); } @@ -654,7 +658,9 @@ impl PytketEncoderContext { // // This should only fail when looking at the "leftover" unsupported nodes at the end of the decoding process. if op_values.qubits.is_empty() && op_values.bits.is_empty() { - return Err(PytketEncodeError::UnsupportedSubgraphHasNoRegisters {}); + return Err(PytketEncodeError::UnsupportedSubgraphHasNoRegisters { + params: input_param_exprs.clone(), + }); } // Create pytket operation, and add the subcircuit as hugr diff --git a/tket/src/serialize/pytket/error.rs b/tket/src/serialize/pytket/error.rs index 9c7924a95..26d1fc342 100644 --- a/tket/src/serialize/pytket/error.rs +++ b/tket/src/serialize/pytket/error.rs @@ -101,8 +101,13 @@ pub enum PytketEncodeError { head_op: String, }, /// No qubits or bits to attach the barrier command to for unsupported nodes. - #[display("An unsupported subgraph has no qubits or bits to attach the barrier command to.")] - UnsupportedSubgraphHasNoRegisters {}, + #[display("An unsupported subgraph has no qubits or bits to attach the barrier command to{}", + if params.is_empty() {"".to_string()} else {format!(" alongside its parameters [{}]", params.iter().join(", "))} + )] + UnsupportedSubgraphHasNoRegisters { + /// Parameter inputs to the unsupported subgraph. + params: Vec, + }, } impl PytketEncodeError { diff --git a/tket/src/serialize/pytket/extension/core.rs b/tket/src/serialize/pytket/extension/core.rs index 55c6b38b6..c7b5128fb 100644 --- a/tket/src/serialize/pytket/extension/core.rs +++ b/tket/src/serialize/pytket/extension/core.rs @@ -86,7 +86,7 @@ impl PytketDecoder for CoreDecoder { options, decoder.opaque_subgraphs, )?; - nested_decoder.run_decoder(&serial_circuit.commands, None, &[])?; + nested_decoder.run_decoder(&serial_circuit.commands, None)?; let internal = nested_decoder.finish(&[])?.node(); decoder From 118b86f77c7620baae2e8280983d50e54fe4a98d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Thu, 6 Nov 2025 18:13:00 +0000 Subject: [PATCH 10/11] Review comments --- tket/src/serialize/pytket/opaque/subgraph.rs | 11 +++++++---- tket/src/serialize/pytket/tests.rs | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tket/src/serialize/pytket/opaque/subgraph.rs b/tket/src/serialize/pytket/opaque/subgraph.rs index a00e5f1b5..a001088cd 100644 --- a/tket/src/serialize/pytket/opaque/subgraph.rs +++ b/tket/src/serialize/pytket/opaque/subgraph.rs @@ -15,8 +15,11 @@ use hugr::{Direction, Hugr, HugrView, IncomingPort, OutgoingPort}; /// A subgraph of nodes in the Hugr that could not be encoded as TKET1 /// operations. /// -/// of subgraphs it can represent; in particular const edges and order edge are -/// not allowed in the boundary. +/// This is a simpler version of [`SiblingSubgraph`] that +/// - Always maps boundary ports to exactly one node port. +/// - Does not require a convex checker to verify convexity (since we always create regions using a toposort). +/// - Allows non-value edges at its boundaries, which can be left unmodified by the encoder/decoder. +/// - Keeps a flag indicating if it can be represented as a valid [`SiblingSubgraph`]. #[derive(Debug, Clone)] pub struct OpaqueSubgraph { /// The nodes in the subgraph. @@ -36,7 +39,7 @@ pub struct OpaqueSubgraph { /// - Having non-local edges. /// - Having const edges to global definitions. /// - Having order edges to nodes outside the subgraph. - /// - Calling a global functions. + /// - Calling global functions. sibling_subgraph_compatible: bool, } @@ -156,7 +159,7 @@ impl OpaqueSubgraph { /// - Having non-local edges. /// - Having const edges to global definitions. /// - Having order edges to nodes outside the subgraph. - /// - Calling a global functions. + /// - Calling global functions. pub fn is_sibling_subgraph_compatible(&self) -> bool { self.sibling_subgraph_compatible } diff --git a/tket/src/serialize/pytket/tests.rs b/tket/src/serialize/pytket/tests.rs index 87d7404a8..32083ec3b 100644 --- a/tket/src/serialize/pytket/tests.rs +++ b/tket/src/serialize/pytket/tests.rs @@ -648,7 +648,7 @@ fn circ_order_edge() -> Circuit { hugr.into() } -// Nodes with order edges should be marked as unsupported to preserve the connection. +// Bool types get converted automatically between native and tket representations. #[fixture] fn circ_bool_conversion() -> Circuit { let input_t = vec![bool_t(), bool_type()]; From 9600a72e53ef35f348533fd85d7b59545e26a21b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Thu, 6 Nov 2025 18:53:36 +0000 Subject: [PATCH 11/11] fix: Move non-local edges originating from decoded region's input --- tket/src/serialize/pytket/circuit.rs | 22 ++++++++++++++++++++++ tket/src/serialize/pytket/tests.rs | 3 --- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/tket/src/serialize/pytket/circuit.rs b/tket/src/serialize/pytket/circuit.rs index 7be63ae44..7a9e188c7 100644 --- a/tket/src/serialize/pytket/circuit.rs +++ b/tket/src/serialize/pytket/circuit.rs @@ -8,6 +8,7 @@ use hugr::core::{HugrNode, IncomingPort, OutgoingPort}; use hugr::hugr::hugrmut::HugrMut; use hugr::ops::handle::NodeHandle; use hugr::ops::{OpParent, OpTag, OpTrait}; +use hugr::types::EdgeKind; use hugr::{Hugr, HugrView, Node}; use hugr_core::hugr::internal::HugrMutInternals; use itertools::Itertools; @@ -196,6 +197,27 @@ impl EncodedCircuit { )?; let decoded_node = decoder.finish(&encoded.output_params)?.node(); + // Move any non-local edges from originating from the old input node. + let old_input = hugr.get_io(original_region).unwrap()[0]; + let input_optype = hugr.get_optype(old_input).clone(); + let new_input = hugr.get_io(decoded_node).unwrap()[0]; + for src_port in hugr.node_outputs(old_input).collect_vec() { + for (tgt_node, tgt_port) in hugr.linked_inputs(old_input, src_port).collect_vec() { + let tgt_parent = hugr.get_parent(tgt_node); + let is_local_wire = tgt_parent == Some(original_region); + let is_value_wire = + matches!(input_optype.port_kind(src_port), Some(EdgeKind::Value(_))); + let wire_to_decoded_region = tgt_parent == Some(decoded_node); + // Ignore local wires, as all nodes will be deleted. + // Also ignore value wires to the newly decoded region, + // as they come from transplanted opaque subgraphs that already + // re-connected their inputs. + if !(is_local_wire || (is_value_wire && wire_to_decoded_region)) { + hugr.connect(new_input, src_port, tgt_node, tgt_port); + } + } + } + // Replace the region with the decoded function. // // All descendant nodes that were re-used by the decoded circuit got diff --git a/tket/src/serialize/pytket/tests.rs b/tket/src/serialize/pytket/tests.rs index 32083ec3b..baa8e9349 100644 --- a/tket/src/serialize/pytket/tests.rs +++ b/tket/src/serialize/pytket/tests.rs @@ -996,9 +996,6 @@ fn fail_on_modified_hugr(circ_tk1_ops: Circuit) { CircuitRoundtripTestConfig::Default )] #[case::output_parameter_wire(circ_output_parameter_wire(), 1, CircuitRoundtripTestConfig::Default)] -// TODO: fix edge case: non-local edge from an unsupported node inside a nested CircBox -// to/from the input of the head region being encoded... -#[should_panic(expected = "has an unconnected port")] #[case::non_local(circ_non_local(), 2, CircuitRoundtripTestConfig::Default)] fn encoded_circuit_roundtrip(