diff --git a/Cargo.lock b/Cargo.lock index 36df61ca8f..e3cbc2a247 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1263,16 +1263,19 @@ dependencies = [ "proptest", "proptest-derive", "regex", + "relrc", "rstest", "semver", "serde", "serde_json", "serde_with", "serde_yaml", + "smallvec", "smol_str", "static_assertions", "strum", "thiserror 2.0.12", + "tracing", "typetag", "zstd", ] diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 3206a25212..c7d9eacfc9 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -63,6 +63,9 @@ semver = { workspace = true, features = ["serde"] } zstd = { workspace = true, optional = true } ordered-float = { workspace = true, features = ["serde"] } base64.workspace = true +relrc = { workspace = true, features = ["petgraph", "serde"] } +smallvec = "1.15.0" +tracing = "0.1.41" [dev-dependencies] rstest = { workspace = true } diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index ff6fac8b43..ecd6506933 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -945,7 +945,7 @@ impl<'a> Context<'a> { self.make_term_apply(model::CORE_TYPE, &[]) } - Term::BoundedNatType { .. 
} => self.make_term_apply(model::CORE_NAT_TYPE, &[]), + Term::BoundedNatType(_) => self.make_term_apply(model::CORE_NAT_TYPE, &[]), Term::StringType => self.make_term_apply(model::CORE_STR_TYPE, &[]), Term::BytesType => self.make_term_apply(model::CORE_BYTES_TYPE, &[]), Term::FloatType => self.make_term_apply(model::CORE_FLOAT_TYPE, &[]), @@ -953,14 +953,9 @@ impl<'a> Context<'a> { let item_type = self.export_term(item_type, None); self.make_term_apply(model::CORE_LIST_TYPE, &[item_type]) } - Term::TupleType(params) => { - let item_types = self.bump.alloc_slice_fill_iter( - params - .iter() - .map(|param| table::SeqPart::Item(self.export_term(param, None))), - ); - let types = self.make_term(table::Term::List(item_types)); - self.make_term_apply(model::CORE_TUPLE_TYPE, &[types]) + Term::TupleType(item_types) => { + let item_types = self.export_term(item_types, None); + self.make_term_apply(model::CORE_TUPLE_TYPE, &[item_types]) } Term::Runtime(ty) => self.export_type(ty), Term::BoundedNat(value) => self.make_term(model::Literal::Nat(*value).into()), @@ -975,6 +970,14 @@ impl<'a> Context<'a> { ); self.make_term(table::Term::List(parts)) } + Term::ListConcat(lists) => { + let parts = self.bump.alloc_slice_fill_iter( + lists + .iter() + .map(|elem| table::SeqPart::Splice(self.export_term(elem, None))), + ); + self.make_term(table::Term::List(parts)) + } Term::Tuple(elems) => { let parts = self.bump.alloc_slice_fill_iter( elems @@ -983,6 +986,14 @@ impl<'a> Context<'a> { ); self.make_term(table::Term::Tuple(parts)) } + Term::TupleConcat(tuples) => { + let parts = self.bump.alloc_slice_fill_iter( + tuples + .iter() + .map(|elem| table::SeqPart::Splice(self.export_term(elem, None))), + ); + self.make_term(table::Term::Tuple(parts)) + } Term::Variable(v) => self.export_type_arg_var(v), Term::StaticType => self.make_term_apply(model::CORE_STATIC, &[]), } diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index 
6f5799790a..0ea6bd7007 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -231,8 +231,16 @@ pub(super) fn collect_term_exts( collect_term_exts(item_type, used_extensions, missing_extensions) } Term::TupleType(item_types) => { - for item_type in item_types { - collect_term_exts(item_type, used_extensions, missing_extensions); + collect_term_exts(item_types, used_extensions, missing_extensions) + } + Term::ListConcat(lists) => { + for list in lists { + collect_term_exts(list, used_extensions, missing_extensions); + } + } + Term::TupleConcat(tuples) => { + for tuple in tuples { + collect_term_exts(tuple, used_extensions, missing_extensions); } } Term::Variable(_) diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index e54a21e5ac..8135ca0b1b 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -222,23 +222,19 @@ pub(super) fn resolve_term_exts( ) -> Result<(), ExtensionResolutionError> { match term { Term::Runtime(ty) => resolve_type_exts(node, ty, extensions, used_extensions)?, - Term::List(elems) => { - for elem in elems.iter_mut() { - resolve_term_exts(node, elem, extensions, used_extensions)?; - } - } - Term::Tuple(elems) => { - for elem in elems.iter_mut() { - resolve_term_exts(node, elem, extensions, used_extensions)?; + Term::List(children) + | Term::ListConcat(children) + | Term::Tuple(children) + | Term::TupleConcat(children) => { + for child in children.iter_mut() { + resolve_term_exts(node, child, extensions, used_extensions)?; } } Term::ListType(item_type) => { - resolve_term_exts(node, item_type, extensions, used_extensions)?; + resolve_term_exts(node, item_type.as_mut(), extensions, used_extensions)?; } Term::TupleType(item_types) => { - for item_type in item_types.iter_mut() { - resolve_term_exts(node, item_type, extensions, used_extensions)?; - } + resolve_term_exts(node, 
item_types.as_mut(), extensions, used_extensions)?; } Term::Variable(_) | Term::RuntimeType(_) diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 13e766f38f..892c176c99 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -482,7 +482,7 @@ fn polyfunctype2() -> PolyFuncTypeRV { #[case(PolyFuncType::new([TypeParam::StringType], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] #[case(PolyFuncType::new([TypeBound::Copyable.into()], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] #[case(PolyFuncType::new([TypeParam::new_list_type(TypeBound::Any)], Signature::new_endo(type_row![])))] -#[case(PolyFuncType::new([TypeParam::TupleType([TypeBound::Any.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())].into())], Signature::new_endo(type_row![])))] +#[case(PolyFuncType::new([TypeParam::new_tuple_type([TypeBound::Any.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())])], Signature::new_endo(type_row![])))] #[case(PolyFuncType::new( [TypeParam::new_list_type(TypeBound::Any)], Signature::new_endo(Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Any)))))] @@ -495,7 +495,7 @@ fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { #[case(PolyFuncTypeRV::new([TypeParam::StringType], FuncValueType::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] #[case(PolyFuncTypeRV::new([TypeBound::Copyable.into()], FuncValueType::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] #[case(PolyFuncTypeRV::new([TypeParam::new_list_type(TypeBound::Any)], FuncValueType::new_endo(type_row![])))] -#[case(PolyFuncTypeRV::new([TypeParam::TupleType([TypeBound::Any.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())].into())], FuncValueType::new_endo(type_row![])))] +#[case(PolyFuncTypeRV::new([TypeParam::new_tuple_type([TypeBound::Any.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())])], 
FuncValueType::new_endo(type_row![])))] #[case(PolyFuncTypeRV::new( [TypeParam::new_list_type(TypeBound::Any)], FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Any))))] diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index b8cda15463..fbad01434e 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -25,13 +25,14 @@ use crate::{ types::{ CustomType, FuncTypeBase, MaybeRV, PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, Term, Type, TypeArg, TypeBase, TypeBound, TypeEnum, TypeName, TypeRow, - type_param::TypeParam, type_row::TypeRowBase, + type_param::{SeqPart, TypeParam}, + type_row::TypeRowBase, }, }; use fxhash::FxHashMap; use hugr_model::v0 as model; use hugr_model::v0::table; -use itertools::Either; +use itertools::{Either, Itertools}; use smol_str::{SmolStr, ToSmolStr}; use thiserror::Error; @@ -1255,14 +1256,10 @@ impl<'a> Context<'a> { if let Some([item_types]) = self.match_symbol(term_id, model::CORE_TUPLE_TYPE)? { // At present `hugr-model` has no way to express that the item // types of a tuple must be copyable. Therefore we import it as `Any`. - let item_types = (|| { - self.import_closed_list(item_types)? - .into_iter() - .map(|param| self.import_term(param)) - .collect::>() - })() - .map_err(|err| error_context!(err, "item types of tuple type"))?; - return Ok(TypeParam::TupleType(item_types)); + let item_types = self + .import_term(item_types) + .map_err(|err| error_context!(err, "item types of tuple type"))?; + return Ok(TypeParam::new_tuple_type(item_types)); } match self.get_term(term_id)? { @@ -1277,28 +1274,24 @@ impl<'a> Context<'a> { Ok(Term::new_var_use(var.1 as _, decl)) } - table::Term::List { .. } => { - let elems = (|| { - self.import_closed_list(term_id)? 
- .iter() - .map(|item| self.import_term(*item)) - .collect::>() - })() - .map_err(|err| error_context!(err, "list items"))?; - - Ok(Term::List(elems)) + table::Term::List(parts) => { + // PERFORMANCE: Can we do this without the additional allocation? + let parts: Vec<_> = parts + .iter() + .map(|part| self.import_seq_part(part)) + .collect::>() + .map_err(|err| error_context!(err, "list parts"))?; + Ok(TypeArg::new_list_from_parts(parts)) } - table::Term::Tuple { .. } => { - let elems = (|| { - self.import_closed_list(term_id)? - .iter() - .map(|item| self.import_term(*item)) - .collect::>() - })() - .map_err(|err| error_context!(err, "tuple items"))?; - - Ok(Term::Tuple(elems)) + table::Term::Tuple(parts) => { + // PERFORMANCE: Can we do this without the additional allocation? + let parts: Vec<_> = parts + .iter() + .map(|part| self.import_seq_part(part)) + .try_collect() + .map_err(|err| error_context!(err, "tuple parts"))?; + Ok(TypeArg::new_tuple_from_parts(parts)) } table::Term::Literal(model::Literal::Str(value)) => { @@ -1322,6 +1315,16 @@ impl<'a> Context<'a> { .map_err(|err| error_context!(err, "term {}", term_id)) } + fn import_seq_part( + &mut self, + seq_part: &'a table::SeqPart, + ) -> Result, ImportError> { + Ok(match seq_part { + table::SeqPart::Item(term_id) => SeqPart::Item(self.import_term(*term_id)?), + table::SeqPart::Splice(term_id) => SeqPart::Splice(self.import_term(*term_id)?), + }) + } + /// Import a `Type` from a term that represents a runtime type. 
fn import_type( &mut self, diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 8425897b1d..35e3d80b24 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -271,7 +271,7 @@ pub(crate) mod test { for decl in [ Term::new_list_type(Term::max_nat_type()), Term::StringType, - Term::TupleType(vec![TypeBound::Any.into(), Term::max_nat_type()]), + Term::new_tuple_type([TypeBound::Any.into(), Term::max_nat_type()]), ] { let invalid_ts = PolyFuncTypeBase::new_validated([decl.clone()], body_type.clone()); assert_eq!( diff --git a/hugr-core/src/types/serialize.rs b/hugr-core/src/types/serialize.rs index f36f0d3081..1fb61bcde4 100644 --- a/hugr-core/src/types/serialize.rs +++ b/hugr-core/src/types/serialize.rs @@ -78,7 +78,7 @@ pub(super) enum TypeParamSer { Float, StaticType, List { param: Box }, - Tuple { params: Vec }, + Tuple { params: ArrayOrTermSer }, } #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] @@ -104,9 +104,15 @@ pub(super) enum TypeArgSer { List { elems: Vec, }, + ListConcat { + lists: Vec, + }, Tuple { elems: Vec, }, + TupleConcat { + tuples: Vec, + }, Variable { #[serde(flatten)] v: TermVar, @@ -130,8 +136,10 @@ impl From for TermSer { Term::BytesType => TermSer::TypeParam(TypeParamSer::Bytes), Term::FloatType => TermSer::TypeParam(TypeParamSer::Float), Term::ListType(param) => TermSer::TypeParam(TypeParamSer::List { param }), - Term::TupleType(params) => TermSer::TypeParam(TypeParamSer::Tuple { params }), Term::Runtime(ty) => TermSer::TypeArg(TypeArgSer::Type { ty }), + Term::TupleType(params) => TermSer::TypeParam(TypeParamSer::Tuple { + params: (*params).into(), + }), Term::BoundedNat(n) => TermSer::TypeArg(TypeArgSer::BoundedNat { n }), Term::String(arg) => TermSer::TypeArg(TypeArgSer::String { arg }), Term::Bytes(value) => TermSer::TypeArg(TypeArgSer::Bytes { value }), @@ -139,6 +147,8 @@ impl From for TermSer { Term::List(elems) => TermSer::TypeArg(TypeArgSer::List { elems 
}), Term::Tuple(elems) => TermSer::TypeArg(TypeArgSer::Tuple { elems }), Term::Variable(v) => TermSer::TypeArg(TypeArgSer::Variable { v }), + Term::ListConcat(lists) => TermSer::TypeArg(TypeArgSer::ListConcat { lists }), + Term::TupleConcat(tuples) => TermSer::TypeArg(TypeArgSer::TupleConcat { tuples }), } } } @@ -154,7 +164,7 @@ impl From for Term { TypeParamSer::Bytes => Term::BytesType, TypeParamSer::Float => Term::FloatType, TypeParamSer::List { param } => Term::ListType(param), - TypeParamSer::Tuple { params } => Term::TupleType(params), + TypeParamSer::Tuple { params } => Term::TupleType(Box::new(params.into())), }, TermSer::TypeArg(arg) => match arg { TypeArgSer::Type { ty } => Term::Runtime(ty), @@ -165,11 +175,39 @@ impl From for Term { TypeArgSer::List { elems } => Term::List(elems), TypeArgSer::Tuple { elems } => Term::Tuple(elems), TypeArgSer::Variable { v } => Term::Variable(v), + TypeArgSer::ListConcat { lists } => Term::ListConcat(lists), + TypeArgSer::TupleConcat { tuples } => Term::TupleConcat(tuples), }, } } } +/// Helper type that serialises lists as JSON arrays for compatibility. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(untagged)] +pub(super) enum ArrayOrTermSer { + Array(Vec), + Term(Box), +} + +impl From for Term { + fn from(value: ArrayOrTermSer) -> Self { + match value { + ArrayOrTermSer::Array(terms) => Term::new_list(terms), + ArrayOrTermSer::Term(term) => *term, + } + } +} + +impl From for ArrayOrTermSer { + fn from(term: Term) -> Self { + match term { + Term::List(terms) => ArrayOrTermSer::Array(terms), + term => ArrayOrTermSer::Term(Box::new(term)), + } + } +} + /// Helper for to serialize and deserialize the byte string in [`TypeArg::Bytes`] via base64. mod base64 { use std::sync::Arc; diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index c53a46a687..c9be168498 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -4,13 +4,15 @@ //! 
//! [`TypeDef`]: crate::extension::TypeDef
 
-use itertools::Itertools;
 use ordered_float::OrderedFloat;
 #[cfg(test)]
 use proptest_derive::Arbitrary;
+use smallvec::{SmallVec, smallvec};
+use std::iter::FusedIterator;
 use std::num::NonZeroU64;
 use std::sync::Arc;
 use thiserror::Error;
+use tracing::warn;
 
 use super::row_var::MaybeRV;
 use super::{
@@ -91,8 +93,8 @@ pub enum Term {
     #[display("ListType[{_0}]")]
     ListType(Box<Term>),
     /// The type of static tuples.
-    #[display("TupleType[{}]", _0.iter().map(std::string::ToString::to_string).join(", "))]
-    TupleType(Vec<Term>),
+    #[display("TupleType[{_0}]")]
+    TupleType(Box<Term>),
     /// A runtime type as a term. Instance of [`Term::RuntimeType`].
     #[display("{_0}")]
     Runtime(Type),
@@ -114,12 +116,24 @@ pub enum Term {
         _0.iter().map(|t|t.to_string()).join(",")
     })]
     List(Vec<Term>),
-    /// A tuple of static terms. Instance of [`Term::TupleType`].
+    /// Instance of [`Term::ListType`] defined by a sequence of concatenated lists of the same type.
+    #[display("[{}]", {
+        use itertools::Itertools as _;
+        _0.iter().map(|t| format!("... {}", t)).join(",")
+    })]
+    ListConcat(Vec<Term>),
+    /// Instance of [`Term::TupleType`] defined by a sequence of elements of varying type.
     #[display("({})", {
         use itertools::Itertools as _;
         _0.iter().map(std::string::ToString::to_string).join(",")
     })]
     Tuple(Vec<Term>),
+    /// Instance of [`Term::TupleType`] defined by a sequence of concatenated tuples.
+    #[display("({})", {
+        use itertools::Itertools as _;
+        _0.iter().map(|tuple| format!("... {}", tuple)).join(",")
+    })]
+    TupleConcat(Vec<Term>),
     /// Variable (used in type schemes or inside polymorphic functions),
     /// but not a runtime type (not even a row variable i.e. list of runtime types)
     /// - see [`Term::new_var_use`]
@@ -150,9 +164,9 @@ impl Term {
         Self::ListType(Box::new(elem.into()))
     }
 
-    /// Creates a new [`Term::TupleType`] given the types of its elements.
- pub fn new_tuple_type(item_types: impl IntoIterator) -> Self { - Self::TupleType(item_types.into_iter().collect()) + /// Creates a new [`Term::TupleType`] given the type of its elements. + pub fn new_tuple_type(item_types: impl Into) -> Self { + Self::TupleType(Box::new(item_types.into())) } /// Checks if this term is a supertype of another. @@ -168,9 +182,7 @@ impl Term { (Term::StringType, Term::StringType) => true, (Term::StaticType, Term::StaticType) => true, (Term::ListType(e1), Term::ListType(e2)) => e1.is_supertype(e2), - (Term::TupleType(es1), Term::TupleType(es2)) => { - es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) - } + (Term::TupleType(es1), Term::TupleType(es2)) => es1.is_supertype(es2), (Term::BytesType, Term::BytesType) => true, (Term::FloatType, Term::FloatType) => true, (Term::Runtime(t1), Term::Runtime(t2)) => t1 == t2, @@ -235,6 +247,12 @@ impl From> for Term { } } +impl From<[Term; N]> for Term { + fn from(value: [Term; N]) -> Self { + Self::new_list(value) + } +} + /// Variable in a [`Term`], that is not a single runtime type (i.e. not a [`Type::new_var_use`] /// - it might be a [`Type::new_row_var_use`]). #[derive( @@ -267,6 +285,30 @@ impl Term { } } + /// Creates a new string literal. + #[inline] + pub fn new_string(str: impl ToString) -> Self { + Self::String(str.to_string()) + } + + /// Creates a new concatenated list. + #[inline] + pub fn new_list_concat(lists: impl IntoIterator) -> Self { + Self::ListConcat(lists.into_iter().collect()) + } + + /// Creates a new tuple from its items. + #[inline] + pub fn new_tuple(items: impl IntoIterator) -> Self { + Self::Tuple(items.into_iter().collect()) + } + + /// Creates a new concatenated tuple. + #[inline] + pub fn new_tuple_concat(tuples: impl IntoIterator) -> Self { + Self::TupleConcat(tuples.into_iter().collect()) + } + /// Returns an integer if the [`Term`] is a natural number literal. 
#[must_use] pub fn as_nat(&self) -> Option { @@ -305,6 +347,12 @@ impl Term { } Term::Tuple(elems) => elems.iter().try_for_each(|a| a.validate(var_decls)), Term::BoundedNat(_) | Term::String { .. } | Term::Float(_) | Term::Bytes(_) => Ok(()), + TypeArg::ListConcat(lists) => { + // TODO: Full validation would check that each of the lists is indeed a + // list or list variable of the correct types. + lists.iter().try_for_each(|a| a.validate(var_decls)) + } + TypeArg::TupleConcat(tuples) => tuples.iter().try_for_each(|a| a.validate(var_decls)), Term::Variable(TermVar { idx, cached_decl }) => { assert!( !matches!(&**cached_decl, TypeParam::RuntimeType { .. }), @@ -319,7 +367,7 @@ impl Term { Term::BytesType => Ok(()), Term::FloatType => Ok(()), Term::ListType(item_type) => item_type.validate(var_decls), - Term::TupleType(params) => params.iter().try_for_each(|p| p.validate(var_decls)), + Term::TupleType(item_types) => item_types.validate(var_decls), Term::StaticType => Ok(()), } } @@ -330,51 +378,198 @@ impl Term { // RowVariables are represented as Term::Variable ty.substitute1(t).into() } - Term::BoundedNat(_) | Term::String { .. } | Term::Bytes(_) | Term::Float(_) => { + TypeArg::BoundedNat(_) | TypeArg::String(_) | TypeArg::Bytes(_) | TypeArg::Float(_) => { self.clone() - } - Term::List(elems) => { - let mut are_types = elems.iter().map(|ta| match ta { - Term::Runtime { .. } => true, - Term::Variable(v) => v.bound_if_row_var().is_some(), - _ => false, - }); - let elems = match are_types.next() { - Some(true) => { - assert!(are_types.all(|b| b)); // If one is a Type, so must the rest be - // So, anything that doesn't produce a Type, was a row variable => multiple Types - elems - .iter() - .flat_map(|ta| match ta.substitute(t) { - ty @ Term::Runtime { .. 
} => vec![ty], - Term::List(elems) => elems, - _ => panic!("Expected Type or row of Types"), - }) - .collect() - } - _ => { - // not types, no need to flatten (and mustn't, in case of nested Sequences) - elems.iter().map(|ta| ta.substitute(t)).collect() + } // We do not allow variables as bounds on BoundedNat's + TypeArg::List(elems) => { + // NOTE: This implements a hack allowing substitutions to + // replace `TypeArg::Variable`s representing "row variables" + // with a list that is to be spliced into the containing list. + // We won't need this code anymore once we stop conflating types + // with lists of types. + + fn is_type(type_arg: &TypeArg) -> bool { + match type_arg { + TypeArg::Runtime(_) => true, + TypeArg::Variable(v) => v.bound_if_row_var().is_some(), + _ => false, } - }; - Term::List(elems) + } + + let are_types = elems.first().map(is_type).unwrap_or(false); + + Self::new_list_from_parts(elems.iter().map(|elem| match elem.substitute(t) { + list @ TypeArg::List { .. } if are_types => SeqPart::Splice(list), + list @ TypeArg::ListConcat { .. } if are_types => SeqPart::Splice(list), + elem => SeqPart::Item(elem), + })) + } + TypeArg::ListConcat(lists) => { + // When a substitution instantiates spliced list variables, we + // may be able to merge the concatenated lists. + Self::new_list_from_parts( + lists.iter().map(|list| SeqPart::Splice(list.substitute(t))), + ) } Term::Tuple(elems) => { Term::Tuple(elems.iter().map(|elem| elem.substitute(t)).collect()) } - Term::Variable(TermVar { idx, cached_decl }) => t.apply_var(*idx, cached_decl), - Term::RuntimeType { .. } => self.clone(), - Term::BoundedNatType { .. } => self.clone(), + TypeArg::TupleConcat(tuples) => { + // When a substitution instantiates spliced tuple variables, + // we may be able to merge the concatenated tuples. 
+                Self::new_tuple_from_parts(
+                    tuples
+                        .iter()
+                        .map(|tuple| SeqPart::Splice(tuple.substitute(t))),
+                )
+            }
+            TypeArg::Variable(TermVar { idx, cached_decl }) => t.apply_var(*idx, cached_decl),
+            Term::RuntimeType(_) => self.clone(),
+            Term::BoundedNatType(_) => self.clone(),
             Term::StringType => self.clone(),
             Term::BytesType => self.clone(),
             Term::FloatType => self.clone(),
             Term::ListType(item_type) => Term::new_list_type(item_type.substitute(t)),
-            Term::TupleType(params) => {
-                Term::TupleType(params.iter().map(|p| p.substitute(t)).collect())
-            }
+            Term::TupleType(item_types) => Term::new_tuple_type(item_types.substitute(t)),
             Term::StaticType => self.clone(),
         }
     }
+
+    /// Helper method for [`TypeArg::new_list_from_parts`] and [`TypeArg::new_tuple_from_parts`].
+    fn new_seq_from_parts(
+        parts: impl IntoIterator<Item = SeqPart<Self>>,
+        make_items: impl Fn(Vec<Self>) -> Self,
+        make_concat: impl Fn(Vec<Self>) -> Self,
+    ) -> Self {
+        let mut items = Vec::new();
+        let mut seqs = Vec::new();
+
+        for part in parts {
+            match part {
+                SeqPart::Item(item) => items.push(item),
+                SeqPart::Splice(seq) => {
+                    if !items.is_empty() {
+                        seqs.push(make_items(std::mem::take(&mut items)));
+                    }
+                    seqs.push(seq);
+                }
+            }
+        }
+
+        if seqs.is_empty() {
+            make_items(items)
+        } else if items.is_empty() {
+            make_concat(seqs)
+        } else {
+            seqs.push(make_items(items));
+            make_concat(seqs)
+        }
+    }
+
+    /// Creates a new list from a sequence of [`SeqPart`]s.
+    pub fn new_list_from_parts(parts: impl IntoIterator<Item = SeqPart<Self>>) -> Self {
+        Self::new_seq_from_parts(
+            parts.into_iter().flat_map(ListPartIter::new),
+            TypeArg::List,
+            TypeArg::ListConcat,
+        )
+    }
+
+    /// Iterates over the [`SeqPart`]s of a list.
+ /// + /// # Examples + /// + /// The parts of a closed list are the items of that list wrapped in [`SeqPart::Item`]: + /// + /// ``` + /// # use hugr_core::types::type_param::{Term, SeqPart}; + /// # let a = Term::new_string("a"); + /// # let b = Term::new_string("b"); + /// let term = Term::new_list([a.clone(), b.clone()]); + /// + /// assert_eq!( + /// term.into_list_parts().collect::>(), + /// vec![SeqPart::Item(a), SeqPart::Item(b)] + /// ); + /// ``` + /// + /// Parts of a concatenated list that are not closed lists are wrapped in [`SeqPart::Splice`]: + /// + /// ``` + /// # use hugr_core::types::type_param::{Term, SeqPart}; + /// # let a = Term::new_string("a"); + /// # let b = Term::new_string("b"); + /// # let c = Term::new_string("c"); + /// let var = Term::new_var_use(0, Term::new_list_type(Term::StringType)); + /// let term = Term::new_list_concat([ + /// Term::new_list([a.clone(), b.clone()]), + /// var.clone(), + /// Term::new_list([c.clone()]) + /// ]); + /// + /// assert_eq!( + /// term.into_list_parts().collect::>(), + /// vec![SeqPart::Item(a), SeqPart::Item(b), SeqPart::Splice(var), SeqPart::Item(c)] + /// ); + /// ``` + /// + /// Nested concatenations are traversed recursively: + /// + /// ``` + /// # use hugr_core::types::type_param::{Term, SeqPart}; + /// # let a = Term::new_string("a"); + /// # let b = Term::new_string("b"); + /// # let c = Term::new_string("c"); + /// let term = Term::new_list_concat([ + /// Term::new_list_concat([ + /// Term::new_list([a.clone()]), + /// Term::new_list([b.clone()]) + /// ]), + /// Term::new_list([]), + /// Term::new_list([c.clone()]) + /// ]); + /// + /// assert_eq!( + /// term.into_list_parts().collect::>(), + /// vec![SeqPart::Item(a), SeqPart::Item(b), SeqPart::Item(c)] + /// ); + /// ``` + /// + /// When invoked on a type argument that is not a list, a single + /// [`SeqPart::Splice`] is returned that wraps the type argument. + /// This is the expected behaviour for type variables that stand for lists. 
+    /// This behaviour also allows this method not to fail on ill-typed type arguments.
+    /// ```
+    /// # use hugr_core::types::type_param::{Term, SeqPart};
+    /// let term = Term::new_string("not a list");
+    /// assert_eq!(
+    ///     term.clone().into_list_parts().collect::<Vec<_>>(),
+    ///     vec![SeqPart::Splice(term)]
+    /// );
+    /// ```
+    #[inline]
+    pub fn into_list_parts(self) -> ListPartIter {
+        ListPartIter::new(SeqPart::Splice(self))
+    }
+
+    /// Creates a new tuple from a sequence of [`SeqPart`]s.
+    ///
+    /// Analogous to [`TypeArg::new_list_from_parts`].
+    pub fn new_tuple_from_parts(parts: impl IntoIterator<Item = SeqPart<Self>>) -> Self {
+        Self::new_seq_from_parts(
+            parts.into_iter().flat_map(TuplePartIter::new),
+            TypeArg::Tuple,
+            TypeArg::TupleConcat,
+        )
+    }
+
+    /// Iterates over the [`SeqPart`]s of a tuple.
+    ///
+    /// Analogous to [`TypeArg::into_list_parts`].
+    #[inline]
+    pub fn into_tuple_parts(self) -> TuplePartIter {
+        TuplePartIter::new(SeqPart::Splice(self))
+    }
 }
 
 impl Transformable for Term {
@@ -396,6 +591,8 @@ impl Transformable for Term {
             Term::ListType(item_type) => item_type.transform(tr),
             Term::TupleType(item_types) => item_types.transform(tr),
             Term::StaticType => Ok(false),
+            TypeArg::ListConcat(lists) => lists.transform(tr),
+            TypeArg::TupleConcat(tuples) => tuples.transform(tr),
         }
     }
 }
@@ -442,18 +639,37 @@ pub fn check_term_type(term: &Term, type_: &Term) -> Result<(), TermTypeError> {
                 check_term_type(term, item_type)
             })
         }
-        (Term::Tuple(items), Term::TupleType(item_types)) => {
-            if items.len() != item_types.len() {
+        (Term::ListConcat(lists), Term::ListType(_)) => lists
+            .iter()
+            .try_for_each(|list| check_term_type(list, type_)),
+        (TypeArg::Tuple(_) | TypeArg::TupleConcat(_), TypeParam::TupleType(item_types)) => {
+            let term_parts: Vec<_> = term.clone().into_tuple_parts().collect();
+            let type_parts: Vec<_> = item_types.clone().into_list_parts().collect();
+
+            for (term, type_) in term_parts.iter().zip(&type_parts) {
+                match (term, type_) {
+                    (SeqPart::Item(term), SeqPart::Item(type_)) => {
+                        check_term_type(term, type_)?;
+                    }
+                    (_, SeqPart::Splice(_)) | (SeqPart::Splice(_), _) => {
+                        // TODO: Checking tuples with splicing requires more
+                        // sophisticated validation infrastructure to do well.
+                        warn!(
+                            "Validation for open tuples is not implemented yet, succeeding regardless..."
+                        );
+                        return Ok(());
+                    }
+                }
+            }
+
+            if term_parts.len() != type_parts.len() {
                 return Err(TermTypeError::WrongNumberTuple(
-                    items.len(),
-                    item_types.len(),
+                    term_parts.len(),
+                    type_parts.len(),
                 ));
             }
 
-            items
-                .iter()
-                .zip(item_types.iter())
-                .try_for_each(|(term, type_)| check_term_type(term, type_))
+            Ok(())
         }
         (Term::BoundedNat(val), Term::BoundedNatType(bound)) if bound.valid_value(*val) => Ok(()),
         (Term::String { .. }, Term::StringType) => Ok(()),
@@ -517,6 +733,85 @@ pub enum TermTypeError {
     InvalidValue(TypeArg),
 }
 
+/// Part of a sequence.
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub enum SeqPart<T> {
+    /// An individual item in the sequence.
+    Item(T),
+    /// A subsequence that is spliced into the parent sequence.
+    Splice(T),
+}
+
+/// Iterator created by [`TypeArg::into_list_parts`].
+#[derive(Debug, Clone)]
+pub struct ListPartIter {
+    parts: SmallVec<[SeqPart<TypeArg>; 1]>,
+}
+
+impl ListPartIter {
+    #[inline]
+    fn new(part: SeqPart<TypeArg>) -> Self {
+        Self {
+            parts: smallvec![part],
+        }
+    }
+}
+
+impl Iterator for ListPartIter {
+    type Item = SeqPart<TypeArg>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        loop {
+            match self.parts.pop()? {
+                SeqPart::Splice(TypeArg::List(elems)) => self
+                    .parts
+                    .extend(elems.into_iter().rev().map(SeqPart::Item)),
+                SeqPart::Splice(TypeArg::ListConcat(lists)) => self
+                    .parts
+                    .extend(lists.into_iter().rev().map(SeqPart::Splice)),
+                part => return Some(part),
+            }
+        }
+    }
+}
+
+impl FusedIterator for ListPartIter {}
+
+/// Iterator created by [`TypeArg::into_tuple_parts`].
+#[derive(Debug, Clone)] +pub struct TuplePartIter { + parts: SmallVec<[SeqPart; 1]>, +} + +impl TuplePartIter { + #[inline] + fn new(part: SeqPart) -> Self { + Self { + parts: smallvec![part], + } + } +} + +impl Iterator for TuplePartIter { + type Item = SeqPart; + + fn next(&mut self) -> Option { + loop { + match self.parts.pop()? { + SeqPart::Splice(TypeArg::Tuple(elems)) => self + .parts + .extend(elems.into_iter().rev().map(SeqPart::Item)), + SeqPart::Splice(TypeArg::TupleConcat(tuples)) => self + .parts + .extend(tuples.into_iter().rev().map(SeqPart::Splice)), + part => return Some(part), + } + } + } +} + +impl FusedIterator for TuplePartIter {} + #[cfg(test)] mod test { use itertools::Itertools; @@ -524,8 +819,66 @@ mod test { use super::{Substitution, TypeArg, TypeParam, check_term_type}; use crate::extension::prelude::{bool_t, usize_t}; use crate::types::Term; + use crate::types::type_param::SeqPart; use crate::types::{TypeBound, TypeRV, type_param::TermTypeError}; + #[test] + fn new_list_from_parts_items() { + let a = TypeArg::new_string("a"); + let b = TypeArg::new_string("b"); + + let parts = [SeqPart::Item(a.clone()), SeqPart::Item(b.clone())]; + let items = [a, b]; + + assert_eq!( + TypeArg::new_list_from_parts(parts.clone()), + TypeArg::new_list(items.clone()) + ); + + assert_eq!( + TypeArg::new_tuple_from_parts(parts), + TypeArg::new_tuple(items) + ); + } + + #[test] + fn new_list_from_parts_flatten() { + let a = Term::new_string("a"); + let b = Term::new_string("b"); + let c = Term::new_string("c"); + let d = Term::new_string("d"); + let var = Term::new_var_use(0, Term::new_list_type(Term::StringType)); + let parts = [ + SeqPart::Splice(Term::new_list([a.clone(), b.clone()])), + SeqPart::Splice(Term::new_list_concat([Term::new_list([c.clone()])])), + SeqPart::Item(d.clone()), + SeqPart::Splice(var.clone()), + ]; + assert_eq!( + Term::new_list_from_parts(parts), + Term::new_list_concat([Term::new_list([a, b, c, d]), var]) + ); + } + + #[test] + fn 
new_tuple_from_parts_flatten() { + let a = Term::new_string("a"); + let b = Term::new_string("b"); + let c = Term::new_string("c"); + let d = Term::new_string("d"); + let var = Term::new_var_use(0, Term::new_tuple([Term::StringType])); + let parts = [ + SeqPart::Splice(Term::new_tuple([a.clone(), b.clone()])), + SeqPart::Splice(Term::new_tuple_concat([Term::new_tuple([c.clone()])])), + SeqPart::Item(d.clone()), + SeqPart::Splice(var.clone()), + ]; + assert_eq!( + Term::new_tuple_from_parts(parts), + Term::new_tuple_concat([Term::new_tuple([a, b, c, d]), var]) + ); + } + #[test] fn type_arg_fits_param() { let rowvar = TypeRV::new_row_var_use; @@ -592,7 +945,7 @@ mod test { // `Term::TupleType` requires a `Term::Tuple` of the same number of elems let usize_and_ty = - TypeParam::TupleType(vec![TypeParam::max_nat_type(), TypeBound::Copyable.into()]); + TypeParam::new_tuple_type([TypeParam::max_nat_type(), TypeBound::Copyable.into()]); check( TypeArg::Tuple(vec![5.into(), usize_t().into()]), &usize_and_ty, @@ -603,7 +956,10 @@ mod test { &usize_and_ty, ) .unwrap_err(); // Wrong way around - let two_types = TypeParam::TupleType(vec![TypeBound::Any.into(), TypeBound::Any.into()]); + let two_types = TypeParam::new_tuple_type(Term::new_list([ + TypeBound::Any.into(), + TypeBound::Any.into(), + ])); check(TypeArg::new_var_use(0, two_types.clone()), &two_types).unwrap(); // not a Row Var which could have any number of elems check(TypeArg::new_var_use(0, seq_param), &two_types).unwrap_err(); @@ -743,11 +1099,11 @@ mod test { .or(any_with::(depth.descend()) .prop_map(Self::new_list_type) .boxed()) - .or(vec(any_with::(depth.descend()), 0..3) + .or(any_with::(depth.descend()) .prop_map(Self::new_tuple_type) .boxed()) .or(vec(any_with::(depth.descend()), 0..3) - .prop_map(Term::new_list) + .prop_map(Self::new_list) .boxed()); } diff --git a/hugr-py/src/hugr/_serialization/tys.py b/hugr-py/src/hugr/_serialization/tys.py index 5a500ede2e..fa9bd5b517 100644 --- 
a/hugr-py/src/hugr/_serialization/tys.py +++ b/hugr-py/src/hugr/_serialization/tys.py @@ -207,6 +207,14 @@ def deserialize(self) -> tys.ListArg: return tys.ListArg(elems=deser_it(self.elems)) +class ListConcatArg(BaseTypeArg): + tya: Literal["ListConcat"] = "ListConcat" + lists: list[TypeArg] + + def deserialize(self) -> tys.ListConcatArg: + return tys.ListConcatArg(lists=deser_it(self.lists)) + + class TupleArg(BaseTypeArg): tya: Literal["Tuple"] = "Tuple" elems: list[TypeArg] @@ -215,6 +223,14 @@ def deserialize(self) -> tys.TupleArg: return tys.TupleArg(elems=deser_it(self.elems)) +class TupleConcatArg(BaseTypeArg): + tya: Literal["TupleConcat"] = "TupleConcat" + tuples: list[TypeArg] + + def deserialize(self) -> tys.TupleConcatArg: + return tys.TupleConcatArg(tuples=deser_it(self.tuples)) + + class VariableArg(BaseTypeArg): tya: Literal["Variable"] = "Variable" idx: int diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index addf7301d7..3010a22853 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -325,6 +325,28 @@ def to_model(self) -> model.Term: return model.List([elem.to_model() for elem in self.elems]) +@dataclass(frozen=True) +class ListConcatArg(TypeArg): + """Sequence of lists to concatenate for a :class:`ListParam`.""" + + lists: list[TypeArg] + + def _to_serial(self) -> stys.ListConcatArg: + return stys.ListConcatArg(lists=ser_it(self.lists)) + + def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: + return ListConcatArg([arg.resolve(registry) for arg in self.lists]) + + def __str__(self) -> str: + lists = comma_sep_str(f"... 
{list}" for list in self.lists) + return f"[{lists}]" + + def to_model(self) -> model.Term: + return model.List( + [model.Splice(cast(model.Term, elem.to_model())) for elem in self.lists] + ) + + @dataclass(frozen=True) class TupleArg(TypeArg): """Sequence of type arguments for a :class:`TupleParam`.""" @@ -344,6 +366,28 @@ def to_model(self) -> model.Term: return model.Tuple([elem.to_model() for elem in self.elems]) +@dataclass(frozen=True) +class TupleConcatArg(TypeArg): + """Sequence of tuples to concatenate for a :class:`TupleParam`.""" + + tuples: list[TypeArg] + + def _to_serial(self) -> stys.TupleConcatArg: + return stys.TupleConcatArg(tuples=ser_it(self.tuples)) + + def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: + return TupleConcatArg([arg.resolve(registry) for arg in self.tuples]) + + def __str__(self) -> str: + tuples = comma_sep_str(f"... {tuple}" for tuple in self.tuples) + return f"({tuples})" + + def to_model(self) -> model.Term: + return model.Tuple( + [model.Splice(cast(model.Term, elem.to_model())) for elem in self.tuples] + ) + + @dataclass(frozen=True) class VariableArg(TypeArg): """A type argument variable."""