Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tket/src/serialize/pytket/config/decoder_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
//! A configuration struct contains a list of custom decoders that define
//! translations of legacy tket primitives into HUGR operations.

use hugr::builder::DFGBuilder;
use hugr::types::Type;
use hugr::{Hugr, Wire};
use itertools::Itertools;
use std::collections::HashMap;

Expand Down Expand Up @@ -107,4 +109,25 @@ impl PytketDecoderConfig {
pub fn type_to_pytket(&self, typ: &Type) -> Option<RegisterCount> {
self.type_translators.type_to_pytket(typ)
}

/// Returns `true` if the two types are isomorphic. I.e. they can be translated
/// into each other without losing information.
pub fn types_are_isomorphic(&self, typ1: &Type, typ2: &Type) -> bool {
self.type_translators.types_are_isomorphic(typ1, typ2)
}

/// Inserts the necessary operations to translate a type into an isomorphic
/// type.
///
/// This operation fails if [`Self::types_are_isomorphic`] returns `false`.
pub(in crate::serialize::pytket) fn transform_typed_value(
&self,
wire: Wire,
initial_type: &Type,
target_type: &Type,
builder: &mut DFGBuilder<&mut Hugr>,
) -> Result<Wire, PytketDecodeError> {
self.type_translators
.transform_typed_value(wire, initial_type, target_type, builder)
}
}
75 changes: 75 additions & 0 deletions tket/src/serialize/pytket/config/type_translators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
use std::collections::HashMap;
use std::sync::RwLock;

use hugr::builder::{BuildError, DFGBuilder, Dataflow};
use hugr::extension::prelude::bool_t;
use hugr::extension::ExtensionId;
use hugr::types::{Type, TypeEnum};
use hugr::{Hugr, Wire};
use itertools::Itertools;

use crate::extension::bool::BoolOp;
use crate::serialize::pytket::extension::{PytketTypeTranslator, RegisterCount};
use crate::serialize::pytket::{PytketDecodeError, PytketDecodeErrorInner};

/// A set of [`PytketTypeTranslator`]s that can be used to translate HUGR types
/// into pytket registers (qubits, bits, and parameter expressions).
Expand Down Expand Up @@ -128,6 +132,77 @@ impl TypeTranslatorSet {
.flatten()
.map(move |idx| &self.type_translators[*idx])
}

/// Returns `true` if the two types are isomorphic. I.e. they can be translated
/// into each other without losing information.
//
// TODO: We should allow custom TypeTranslators to expand this checks,
// and implement their own translations.
pub fn types_are_isomorphic(&self, typ1: &Type, typ2: &Type) -> bool {
if typ1 == typ2 {
return true;
}

// For now, we just hard-code this to the two kind of bits we support.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only works if the sum-bool is used linearly, right? Is this a problem?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or do we assume that pytket extraction only runs after bool linearisation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sum bool is marked as unsupported, so we won't see them here.

let native_bool = bool_t();
let tket_bool = crate::extension::bool::bool_type();
if (typ1 == &native_bool && typ2 == &tket_bool)
|| (typ1 == &tket_bool && typ2 == &native_bool)
{
return true;
}

false
}

/// Inserts the necessary operations to translate a type into an isomorphic
/// type.
///
/// This operation fails if [`Self::types_are_isomorphic`] returns `false`.
pub(super) fn transform_typed_value(
&self,
wire: Wire,
initial_type: &Type,
target_type: &Type,
builder: &mut DFGBuilder<&mut Hugr>,
) -> Result<Wire, PytketDecodeError> {
if initial_type == target_type {
return Ok(wire);
}

let map_build_error = |e: BuildError| PytketDecodeErrorInner::CannotTranslateWire {
wire,
initial_type: initial_type.to_string(),
target_type: target_type.to_string(),
context: Some(e.to_string()),
};

// Hard-coded transformations until customs calls are added to [`PytketTypeTranslator`].
let native_bool = bool_t();
let tket_bool = crate::extension::bool::bool_type();
if initial_type == &native_bool && target_type == &tket_bool {
let [wire] = builder
.add_dataflow_op(BoolOp::make_opaque, [wire])
.map_err(map_build_error)?
.outputs_arr();
return Ok(wire);
}
if initial_type == &tket_bool && target_type == &native_bool {
let [wire] = builder
.add_dataflow_op(BoolOp::read, [wire])
.map_err(map_build_error)?
.outputs_arr();
return Ok(wire);
}

Err(PytketDecodeErrorInner::CannotTranslateWire {
wire,
initial_type: initial_type.to_string(),
target_type: target_type.to_string(),
context: None,
}
.wrap())
}
}

#[cfg(test)]
Expand Down
22 changes: 15 additions & 7 deletions tket/src/serialize/pytket/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,8 @@ impl<'h> PytketDecoderContext<'h> {
let found_wire = self
.wire_tracker
.find_typed_wire(
self.config(),
&self.config,
&mut self.builder,
ty,
&mut qubits,
&mut bits,
Expand Down Expand Up @@ -417,7 +418,8 @@ impl<'h> PytketDecoderContext<'h> {
for q in qubits.iter() {
let mut qubit_args: &[TrackedQubit] = std::slice::from_ref(q);
let Ok(FoundWire::Register(wire)) = self.wire_tracker.find_typed_wire(
self.config(),
&self.config,
&mut self.builder,
&qb_type,
&mut qubit_args,
&mut bit_args,
Expand Down Expand Up @@ -540,14 +542,20 @@ impl<'h> PytketDecoderContext<'h> {
/// - [`PytketDecodeErrorInner::UnexpectedInputType`] if a type in `types` cannot be mapped to a [`RegisterCount`]
/// - [`PytketDecodeErrorInner::NoMatchingWire`] if there is no wire with the requested type for the given qubit/bit arguments.
pub fn find_typed_wires(
&self,
&mut self,
types: &[Type],
qubit_args: &[TrackedQubit],
bit_args: &[TrackedBit],
params: &[LoadedParameter],
) -> Result<TrackedWires, PytketDecodeError> {
self.wire_tracker
.find_typed_wires(self.config(), types, qubit_args, bit_args, params)
self.wire_tracker.find_typed_wires(
&self.config,
&mut self.builder,
types,
qubit_args,
bit_args,
params,
)
}

/// Connects the input ports of a node using a list of input qubits, bits,
Expand Down Expand Up @@ -663,8 +671,8 @@ impl<'h> PytketDecoderContext<'h> {
}

// Gather the input wires, with the types needed by the operation.
let input_wires =
self.find_typed_wires(sig.input_types(), input_qubits, input_bits, params)?;
let input_types = sig.input_types().to_vec();
let input_wires = self.find_typed_wires(&input_types, input_qubits, input_bits, params)?;
debug_assert_eq!(op_input_count, input_wires.register_count());

for (input_idx, wire) in input_wires.wires().enumerate() {
Expand Down
6 changes: 4 additions & 2 deletions tket/src/serialize/pytket/decoder/subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ impl<'h> PytketDecoderContext<'h> {
.zip_eq(subgraph.incoming_ports())
{
let found_wire = self.wire_tracker.find_typed_wire(
self.config(),
&self.config,
&mut self.builder,
ty,
&mut input_qubits,
&mut input_bits,
Expand Down Expand Up @@ -354,7 +355,8 @@ impl<'h> PytketDecoderContext<'h> {
// outputs.
for ((ty, edge_id), targets) in payload_inputs.iter().zip_eq(to_insert_inputs) {
let found_wire = self.wire_tracker.find_typed_wire(
self.config(),
&self.config,
&mut self.builder,
ty,
&mut input_qubits,
&mut input_bits,
Expand Down
55 changes: 44 additions & 11 deletions tket/src/serialize/pytket/decoder/wires.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,18 @@ impl WireTracker {
Ok(self.get_bit(id))
}

/// Mark all the values in a wire as outdated.
fn mark_wire_outdated(&mut self, wire: Wire) {
let wire_data = &self.wires[&wire];

for qubit in &wire_data.qubits {
self.qubits[qubit.0].mark_outdated();
}
for bit in &wire_data.bits {
self.bits[bit.0].mark_outdated();
}
}

/// Mark a qubit as outdated, without adding a new wire containing the fresh value.
///
/// This is used when a hugr operation consumes pytket registers as its inputs, but doesn't use them in the outputs.
Expand Down Expand Up @@ -598,6 +610,9 @@ impl WireTracker {
/// The qubit and bit arguments are only consumed as required by the type,
/// some registers may be left unused.
///
/// If the wire type require additional conversion, some operations will be
/// added to the Hugr to perform it.
///
/// # Arguments
///
/// * `config` - The configuration for the decoder, used to count the qubits
Expand All @@ -614,9 +629,11 @@ impl WireTracker {
/// # Errors
///
/// See [`WireTracker::find_typed_wires`] for possible errors.
#[allow(clippy::too_many_arguments)]
pub(in crate::serialize::pytket) fn find_typed_wire(
&self,
&mut self,
config: &PytketDecoderConfig,
builder: &mut DFGBuilder<&mut Hugr>,
ty: &Type,
qubit_args: &mut &[TrackedQubit],
bit_args: &mut &[TrackedBit],
Expand Down Expand Up @@ -668,18 +685,18 @@ impl WireTracker {
let wire_qubits = qubit_args
.iter()
.take(reg_count.qubits)
.map(|q| q.id())
.collect_vec();
let wire_bits = bit_args
.iter()
.take(reg_count.bits)
.map(|bit| bit.id())
.cloned()
.collect_vec();
let wire_qubit_ids = wire_qubits.iter().map(|q| q.id()).collect_vec();
let wire_bits = bit_args.iter().take(reg_count.bits).cloned().collect_vec();
let wire_bit_ids = wire_bits.iter().map(|bit| bit.id()).collect_vec();

// Find a wire that contains the correct type..
let check_wire = |w: &Wire| {
let wire_data = &self.wires[w];
wire_data.ty() == ty && wire_data.qubits == wire_qubits && wire_data.bits == wire_bits
wire_data.qubits == wire_qubit_ids
&& wire_data.bits == wire_bit_ids
&& config.types_are_isomorphic(wire_data.ty(), ty)
};
let Some(wire) = candidate.find(check_wire) else {
return Err(PytketDecodeErrorInner::NoMatchingWire {
Expand All @@ -695,6 +712,7 @@ impl WireTracker {
}
.wrap());
};
drop(candidate);

// Check that none of the selected qubit or bit has been marked as outdated.
if let Some(qubit) = qubit_args
Expand All @@ -719,11 +737,21 @@ impl WireTracker {
}

// Mark the qubits and bits as used.
// TODO: We can use the slice `split_off` method once MSRV is ≥1.87
*qubit_args = &qubit_args[reg_count.qubits..];
*bit_args = &bit_args[reg_count.bits..];

Ok(FoundWire::Register(self.wires[&wire].clone()))
// Convert the wire type, if needed.
let wire_data = &self.wires[&wire];
let new_wire = config.transform_typed_value(wire, wire_data.ty(), ty, builder)?;

if wire == new_wire {
Ok(FoundWire::Register(self.wires[&wire].clone()))
} else {
let ty: Arc<Type> = wire_data.ty.clone();
self.track_wire(new_wire, ty, wire_qubits, wire_bits)?;
self.mark_wire_outdated(wire);
Ok(FoundWire::Register(self.wires[&new_wire].clone()))
}
}

/// Returns a new [TrackedWires] set for a list of [`TrackedQubit`]s,
Expand All @@ -735,6 +763,9 @@ impl WireTracker {
/// The qubit and bit arguments are only consumed as required by the types.
/// Some registers may be left unused.
///
/// If the wire type require additional conversion, some operations will be
/// added to the Hugr to perform it.
///
/// # Arguments
///
/// * `config` - The configuration for the decoder, used to count the qubits and bits required by each type.
Expand All @@ -751,8 +782,9 @@ impl WireTracker {
/// - [`PytketDecodeErrorInner::UnexpectedInputType`] if a type in `types` cannot be mapped to a [`RegisterCount`]
/// - [`PytketDecodeErrorInner::NoMatchingWire`] if there is no wire with the requested type for the given qubit/bit arguments.
pub(super) fn find_typed_wires(
&self,
&mut self,
config: &PytketDecoderConfig,
builder: &mut DFGBuilder<&mut Hugr>,
types: &[Type],
mut qubit_args: &[TrackedQubit],
mut bit_args: &[TrackedBit],
Expand All @@ -768,6 +800,7 @@ impl WireTracker {
for ty in types {
match self.find_typed_wire(
config,
builder,
ty,
&mut qubit_args,
&mut bit_args,
Expand Down
14 changes: 14 additions & 0 deletions tket/src/serialize/pytket/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,20 @@ pub enum PytketDecodeErrorInner {
/// The envelope decoding error.
source: EnvelopeError,
},
/// Cannot translate a wire from one type to another.
#[display("Cannot translate {wire} from type {initial_type} to type {target_type}{}",
context.as_ref().map(|s| format!(". {s}")).unwrap_or_default()
)]
CannotTranslateWire {
/// The wire that couldn't be translated.
wire: Wire,
/// The initial type of the wire.
initial_type: String,
/// The target type of the wire.
target_type: String,
/// The error that occurred while translating the wire.
context: Option<String>,
},
}

impl PytketDecodeErrorInner {
Expand Down
Loading