Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tket/src/serialize/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ impl TKETDecode for SerialCircuit {
options: DecodeOptions,
) -> Result<Node, Self::DecodeError> {
let mut decoder = PytketDecoderContext::new(self, hugr, target, options, None)?;
decoder.run_decoder(&self.commands, None, &[])?;
decoder.run_decoder(&self.commands, None)?;
Ok(decoder.finish(&[])?.node())
}

Expand Down
56 changes: 44 additions & 12 deletions tket/src/serialize/pytket/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use hugr::core::{HugrNode, IncomingPort, OutgoingPort};
use hugr::hugr::hugrmut::HugrMut;
use hugr::ops::handle::NodeHandle;
use hugr::ops::{OpParent, OpTag, OpTrait};
use hugr::types::EdgeKind;
use hugr::{Hugr, HugrView, Node};
use hugr_core::hugr::internal::HugrMutInternals;
use itertools::Itertools;
Expand Down Expand Up @@ -55,15 +56,8 @@ pub struct EncodedCircuit<Node: HugrNode> {
pub(super) struct EncodedCircuitInfo {
/// The serial circuit encoded from the region.
pub serial_circuit: SerialCircuit,
/// A subgraph of the region that does not contain any operation encodable
/// as a pytket command, and has no qubit/bits in its boundary that could be
/// used to emit an opaque barrier command in the [`serial_circuit`].
pub extra_subgraph: Option<SubgraphId>,
/// List of wires that directly connected the input node to the output node in the encoded region,
/// and were not encoded in [`serial_circuit`].
///
/// We just store the input nodes's output port and output node's input port here.
pub straight_through_wires: Vec<StraightThroughWire>,
/// Information about any unsupported nodes in the region that could not be encoded as a pytket command.
pub additional_nodes_and_wires: AdditionalNodesAndWires,
/// List of parameters in the pytket circuit in the order they appear in the
/// hugr input.
///
Expand All @@ -77,6 +71,24 @@ pub(super) struct EncodedCircuitInfo {
pub output_params: Vec<String>,
}

/// Nodes and edges from the original region that could not be encoded into the
/// pytket circuit, as they cannot be attached to a pytket command.
#[derive(Debug, Clone)]
pub(super) struct AdditionalNodesAndWires {
/// A subgraph of the region that does not contain any operation encodable
/// as a pytket command, and has no qubit/bits in its boundary that could be
/// used to emit an opaque barrier command in the [`serial_circuit`].
pub extra_subgraph: Option<SubgraphId>,
/// Parameter expression inputs to the `extra_subgraph`.
/// These cannot be encoded either if there's no pytket command to attach them to.
pub extra_subgraph_params: Vec<String>,
/// List of wires that directly connected the input node to the output node in the encoded region,
/// and were not encoded in [`serial_circuit`].
///
/// We just store the input nodes's output port and output node's input port here.
pub straight_through_wires: Vec<StraightThroughWire>,
}

/// A wire stored in the [`EncodedCircuitInfo`] that directly connected the
/// input node to the output node in the encoded region, and was not encoded in
/// the pytket circuit.
Expand Down Expand Up @@ -181,11 +193,31 @@ impl EncodedCircuit<Node> {
)?;
decoder.run_decoder(
&encoded.serial_circuit.commands,
encoded.extra_subgraph,
&encoded.straight_through_wires,
Some(&encoded.additional_nodes_and_wires),
)?;
let decoded_node = decoder.finish(&encoded.output_params)?.node();

// Move any non-local edges from originating from the old input node.
let old_input = hugr.get_io(original_region).unwrap()[0];
let input_optype = hugr.get_optype(old_input).clone();
let new_input = hugr.get_io(decoded_node).unwrap()[0];
for src_port in hugr.node_outputs(old_input).collect_vec() {
for (tgt_node, tgt_port) in hugr.linked_inputs(old_input, src_port).collect_vec() {
let tgt_parent = hugr.get_parent(tgt_node);
let is_local_wire = tgt_parent == Some(original_region);
let is_value_wire =
matches!(input_optype.port_kind(src_port), Some(EdgeKind::Value(_)));
let wire_to_decoded_region = tgt_parent == Some(decoded_node);
// Ignore local wires, as all nodes will be deleted.
// Also ignore value wires to the newly decoded region,
// as they come from transplanted opaque subgraphs that already
// re-connected their inputs.
if !(is_local_wire || (is_value_wire && wire_to_decoded_region)) {
hugr.connect(new_input, src_port, tgt_node, tgt_port);
}
}
}

// Replace the region with the decoded function.
//
// All descendant nodes that were re-used by the decoded circuit got
Expand Down Expand Up @@ -333,7 +365,7 @@ impl<Node: HugrNode> EncodedCircuit<Node> {

let mut decoder =
PytketDecoderContext::new(serial_circuit, &mut hugr, target, options, None)?;
decoder.run_decoder(&serial_circuit.commands, None, &[])?;
decoder.run_decoder(&serial_circuit.commands, None)?;
decoder.finish(&[])?;
Ok(hugr)
}
Expand Down
23 changes: 23 additions & 0 deletions tket/src/serialize/pytket/config/decoder_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
//! A configuration struct contains a list of custom decoders that define
//! translations of legacy tket primitives into HUGR operations.

use hugr::builder::DFGBuilder;
use hugr::types::Type;
use hugr::{Hugr, Wire};
use itertools::Itertools;
use std::collections::HashMap;

Expand Down Expand Up @@ -107,4 +109,25 @@ impl PytketDecoderConfig {
pub fn type_to_pytket(&self, typ: &Type) -> Option<RegisterCount> {
self.type_translators.type_to_pytket(typ)
}

/// Returns `true` if the two types are isomorphic. I.e. they can be translated
/// into each other without losing information.
pub fn types_are_isomorphic(&self, typ1: &Type, typ2: &Type) -> bool {
self.type_translators.types_are_isomorphic(typ1, typ2)
}

/// Inserts the necessary operations to translate a type into an isomorphic
/// type.
///
/// This operation fails if [`Self::types_are_isomorphic`] returns `false`.
pub(in crate::serialize::pytket) fn transform_typed_value(
&self,
wire: Wire,
initial_type: &Type,
target_type: &Type,
builder: &mut DFGBuilder<&mut Hugr>,
) -> Result<Wire, PytketDecodeError> {
self.type_translators
.transform_typed_value(wire, initial_type, target_type, builder)
}
}
108 changes: 103 additions & 5 deletions tket/src/serialize/pytket/config/type_translators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@
use std::collections::HashMap;
use std::sync::RwLock;

use hugr::builder::{BuildError, DFGBuilder, Dataflow};
use hugr::extension::prelude::bool_t;
use hugr::extension::ExtensionId;
use hugr::std_extensions::arithmetic::float_types;
use hugr::types::{Type, TypeEnum};
use hugr::{Hugr, Wire};
use itertools::Itertools;

use crate::extension::bool::BoolOp;
use crate::extension::rotation;
use crate::serialize::pytket::extension::{PytketTypeTranslator, RegisterCount};
use crate::serialize::pytket::{PytketDecodeError, PytketDecodeErrorInner};

/// A set of [`PytketTypeTranslator`]s that can be used to translate HUGR types
/// into pytket registers (qubits, bits, and parameter expressions).
Expand Down Expand Up @@ -61,11 +67,26 @@ impl TypeTranslatorSet {
/// Only tuple sums, bools, and custom types are supported.
/// Other types will return `None`.
pub fn type_to_pytket(&self, typ: &Type) -> Option<RegisterCount> {
self.type_to_pytket_internal(typ).filter(|c| !c.is_empty())
}

/// Recursive call for [`Self::type_to_pytket`].
///
/// This allows returning empty register counts, for types that may be included inside other types.
fn type_to_pytket_internal(&self, typ: &Type) -> Option<RegisterCount> {
let cache = self.type_cache.read().ok();
if let Some(count) = cache.and_then(|c| c.get(typ).cloned()) {
return count;
}

// We currently don't allow user types to contain parameters,
// so we handle rotations and floats manually here.
if typ.as_extension().is_some_and(|ext| {
[float_types::EXTENSION_ID, rotation::ROTATION_EXTENSION_ID].contains(ext.extension())
}) {
return Some(RegisterCount::only_params(1));
}

let res = match typ.as_type_enum() {
TypeEnum::Sum(sum) => {
if sum.num_variants() == 0 {
Expand All @@ -79,13 +100,14 @@ impl TypeTranslatorSet {
.iter()
.map(|ty| {
match ty.clone().try_into() {
Ok(ty) => self.type_to_pytket(&ty),
Ok(ty) => self.type_to_pytket_internal(&ty),
// Sum types with row variables (variable tuple lengths) are not supported.
Err(_) => None,
}
})
.sum();
count
// Don't allow parameters nested inside other types
count.filter(|c| c.params == 0)
} else {
None
}
Expand All @@ -94,7 +116,12 @@ impl TypeTranslatorSet {
let type_ext = custom.extension();
for encoder in self.translators_for_extension(type_ext) {
if let Some(count) = encoder.type_to_pytket(custom, self) {
break 'outer Some(count);
// Don't allow user types with nested parameters
if count.params == 0 {
break 'outer Some(count);
} else {
break 'outer None;
}
}
}
None
Expand All @@ -121,6 +148,77 @@ impl TypeTranslatorSet {
.flatten()
.map(move |idx| &self.type_translators[*idx])
}

/// Returns `true` if the two types are isomorphic. I.e. they can be translated
/// into each other without losing information.
//
// TODO: We should allow custom TypeTranslators to expand this checks,
// and implement their own translations.
pub fn types_are_isomorphic(&self, typ1: &Type, typ2: &Type) -> bool {
if typ1 == typ2 {
return true;
}

// For now, we just hard-code this to the two kind of bits we support.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only works if the sum-bool is used linearly, right? Is this a problem?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or do we assume that pytket extraction only runs after bool linearisation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sum bool is marked as unsupported, so we won't see them here.

let native_bool = bool_t();
let tket_bool = crate::extension::bool::bool_type();
if (typ1 == &native_bool && typ2 == &tket_bool)
|| (typ1 == &tket_bool && typ2 == &native_bool)
{
return true;
}

false
}

/// Inserts the necessary operations to translate a type into an isomorphic
/// type.
///
/// This operation fails if [`Self::types_are_isomorphic`] returns `false`.
pub(super) fn transform_typed_value(
&self,
wire: Wire,
initial_type: &Type,
target_type: &Type,
builder: &mut DFGBuilder<&mut Hugr>,
) -> Result<Wire, PytketDecodeError> {
if initial_type == target_type {
return Ok(wire);
}

let map_build_error = |e: BuildError| PytketDecodeErrorInner::CannotTranslateWire {
wire,
initial_type: initial_type.to_string(),
target_type: target_type.to_string(),
context: Some(e.to_string()),
};

// Hard-coded transformations until customs calls are added to [`PytketTypeTranslator`].
let native_bool = bool_t();
let tket_bool = crate::extension::bool::bool_type();
if initial_type == &native_bool && target_type == &tket_bool {
let [wire] = builder
.add_dataflow_op(BoolOp::make_opaque, [wire])
.map_err(map_build_error)?
.outputs_arr();
return Ok(wire);
}
if initial_type == &tket_bool && target_type == &native_bool {
let [wire] = builder
.add_dataflow_op(BoolOp::read, [wire])
.map_err(map_build_error)?
.outputs_arr();
return Ok(wire);
}

Err(PytketDecodeErrorInner::CannotTranslateWire {
wire,
initial_type: initial_type.to_string(),
target_type: target_type.to_string(),
context: None,
}
.wrap())
}
}

#[cfg(test)]
Expand Down Expand Up @@ -161,10 +259,10 @@ mod tests {
}

#[rstest::rstest]
#[case::empty(SumType::new_unary(0).into(), Some(RegisterCount::default()))]
#[case::empty(SumType::new_unary(0).into(), None)]
#[case::native_bool(SumType::new_unary(2).into(), Some(RegisterCount::only_bits(1)))]
#[case::simple(bool_t(), Some(RegisterCount::only_bits(1)))]
#[case::tuple(SumType::new_tuple(vec![bool_t(), qb_t(), bool_t()]).into(), Some(RegisterCount::new(1, 2, 0)))]
#[case::tuple(SumType::new_tuple(vec![bool_t(), qb_t(), bool_t(), SumType::new_unary(1).into()]).into(), Some(RegisterCount::new(1, 2, 0)))]
#[case::unsupported(SumType::new([vec![bool_t(), qb_t()], vec![bool_t()]]).into(), None)]
fn test_translations(
translator_set: TypeTranslatorSet,
Expand Down
Loading
Loading