diff --git a/Cargo.lock b/Cargo.lock index f22c497a77..4e9568ee17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +dependencies = [ + "gimli 0.29.0", +] + [[package]] name = "addr2line" version = "0.23.0" @@ -105,6 +114,9 @@ name = "anyhow" version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +dependencies = [ + "backtrace", +] [[package]] name = "arbitrary" @@ -133,6 +145,21 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "backtrace" +version = "0.3.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" +dependencies = [ + "addr2line 0.22.0", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + [[package]] name = "beef" version = "0.5.2" @@ -687,6 +714,12 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "gimli" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" + [[package]] name = "gimli" version = "0.30.0" @@ -1699,7 +1732,7 @@ dependencies = [ name = "wasm-tools" version = "1.214.0" dependencies = [ - "addr2line", + "addr2line 0.23.0", "anyhow", "arbitrary", "bitflags", diff --git a/Cargo.toml b/Cargo.toml index b40482b014..c0b3b68357 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,7 @@ version = "0.214.0" rust-version = "1.76.0" [workspace.dependencies] -anyhow = "1.0.58" +anyhow = { version = "1.0.58", features = ["backtrace"] } arbitrary = "1.1.0" clap = { version = "4.0.0", features = ["derive"] } clap_complete = "4.4.7" diff --git a/crates/wasm-compose/src/composer.rs b/crates/wasm-compose/src/composer.rs index c2bc1644b7..4f9c2fdb49 100644 --- a/crates/wasm-compose/src/composer.rs +++ b/crates/wasm-compose/src/composer.rs @@ -495,8 +495,6 @@ impl<'a> CompositionGraphBuilder<'a> { } } - self.graph.unify_imported_resources(); - Ok((self.instances[root_instance], self.graph)) } } diff --git a/crates/wasm-compose/src/encoding.rs b/crates/wasm-compose/src/encoding.rs index daa2bd98a9..b062c57e8c 100644 --- a/crates/wasm-compose/src/encoding.rs +++ b/crates/wasm-compose/src/encoding.rs @@ -599,6 +599,17 @@ impl<'a> TypeEncoder<'a> { return ret; } + if let Some((instance, name)) = state.cur.instance_exports.get(&key) { + let ret = state.cur.encodable.type_count(); + state.cur.encodable.alias(Alias::InstanceExport { + instance: *instance, + name, + kind: ComponentExportKind::Type, + }); + log::trace!("id defined in current instance"); + return ret; + } + match id.peel_alias(&self.0.types) { Some(next) => id = next, // If there's no more aliases then fall through to the @@ -611,15 +622,17 @@ impl<'a> TypeEncoder<'a> { return match id { AnyTypeId::Core(ComponentCoreTypeId::Sub(_)) => unreachable!(), AnyTypeId::Core(ComponentCoreTypeId::Module(id)) => self.module_type(state, id), - AnyTypeId::Component(id) => match id { - ComponentAnyTypeId::Resource(_) => { - unreachable!("should have been handled in `TypeEncoder::component_entity_type`") + AnyTypeId::Component(id) => { + match id { + ComponentAnyTypeId::Resource(r) => { + unreachable!("should have been handled in `TypeEncoder::component_entity_type`: {r:?}") + } + ComponentAnyTypeId::Defined(id) => self.defined_type(state, id), + ComponentAnyTypeId::Func(id) => self.component_func_type(state, id), + ComponentAnyTypeId::Instance(id) => self.component_instance_type(state, id), + ComponentAnyTypeId::Component(id) => self.component_type(state, id), } - ComponentAnyTypeId::Defined(id) => self.defined_type(state, id), - ComponentAnyTypeId::Func(id) => self.component_func_type(state, id), - ComponentAnyTypeId::Instance(id) => self.component_instance_type(state, id), - ComponentAnyTypeId::Component(id) => self.component_type(state, id), - }, + } }; } @@ -678,6 +691,9 @@ impl<'a> TypeEncoder<'a> { state.cur.encodable.ty().defined_type().borrow(ty); index } + wasmparser::types::ComponentDefinedType::Future(ty) => self.future(state, *ty), + wasmparser::types::ComponentDefinedType::Stream(ty) => self.stream(state, *ty), + wasmparser::types::ComponentDefinedType::Error => self.error(state), } } @@ -799,6 +815,32 @@ impl<'a> TypeEncoder<'a> { } export } + + fn future( + &self, + state: &mut TypeState<'a>, + ty: Option, + ) -> u32 { + let ty = ty.map(|ty| self.component_val_type(state, ty)); + + let index = state.cur.encodable.type_count(); + state.cur.encodable.ty().defined_type().future(ty); + index + } + + fn stream(&self, state: &mut TypeState<'a>, ty: wasmparser::types::ComponentValType) -> u32 { + let ty = self.component_val_type(state, ty); + + let index = state.cur.encodable.type_count(); + state.cur.encodable.ty().defined_type().stream(ty); + index + } + + fn error(&self, state: &mut TypeState<'a>) -> u32 { + let index = state.cur.encodable.type_count(); + state.cur.encodable.ty().defined_type().error(); + index + } } /// Represents an instance index in a composition graph. @@ -1226,10 +1268,11 @@ impl DependencyRegistrar<'_, '_> { match &self.types[ty] { types::ComponentDefinedType::Primitive(_) | types::ComponentDefinedType::Enum(_) - | types::ComponentDefinedType::Flags(_) => {} - types::ComponentDefinedType::List(t) | types::ComponentDefinedType::Option(t) => { - self.val_type(*t) - } + | types::ComponentDefinedType::Flags(_) + | types::ComponentDefinedType::Error => {} + types::ComponentDefinedType::List(t) + | types::ComponentDefinedType::Option(t) + | types::ComponentDefinedType::Stream(t) => self.val_type(*t), types::ComponentDefinedType::Own(r) | types::ComponentDefinedType::Borrow(r) => { self.ty(ComponentAnyTypeId::Resource(*r)) } @@ -1258,6 +1301,11 @@ impl DependencyRegistrar<'_, '_> { self.val_type(*err); } } + types::ComponentDefinedType::Future(ty) => { + if let Some(ty) = ty { + self.val_type(*ty); + } + } } } } @@ -1415,7 +1463,7 @@ impl<'a> CompositionGraphEncoder<'a> { state.push(Encodable::Instance(InstanceType::new())); for (name, types) in exports { let (component, ty) = types[0]; - log::trace!("export {name}"); + log::trace!("export {name}: {ty:?}"); let export = TypeEncoder::new(component).export(name, ty, state); let t = match &mut state.cur.encodable { Encodable::Instance(c) => c, @@ -1431,6 +1479,7 @@ impl<'a> CompositionGraphEncoder<'a> { } } } + let instance_type = match state.pop() { Encodable::Instance(c) => c, _ => unreachable!(), diff --git a/crates/wasm-compose/src/graph.rs b/crates/wasm-compose/src/graph.rs index 9e6387b842..2e7aa07cef 100644 --- a/crates/wasm-compose/src/graph.rs +++ b/crates/wasm-compose/src/graph.rs @@ -431,7 +431,7 @@ pub(crate) struct Instance { } /// The options for encoding a composition graph. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Default)] pub struct EncodeOptions { /// Whether or not to define instantiated components. /// @@ -440,7 +440,7 @@ pub struct EncodeOptions { /// The instance in the graph to export. /// - /// If `Some`, the instance's exports will be aliased and + /// If non-empty, the instance's exports will be aliased and /// exported from the resulting component. pub export: Option, @@ -500,9 +500,6 @@ impl ResourceMapping { if value.1 == export_resource { self.map.insert(export_resource, value); self.map.insert(import_resource, value); - } else { - // Can't set two different exports equal to each other -- give up. - return None; } } else { // Couldn't find an export with a name that matches this @@ -551,14 +548,19 @@ impl<'a> CompositionGraph<'a> { /// connected to exports, group them by name, and update the resource /// mapping to make all resources within each group equivalent. /// - /// This should be the last step prior to encoding, after all - /// inter-component connections have been made. It ensures that each set of - /// identical imports composed component can be merged into a single import - /// in the output component. + /// This ensures that each set of identical imports in the composed + /// components can be merged into a single import in the output component. + // + // TODO: How do we balance the need to call this early (so we can match up + // imports with exports which mutually import the same resources) with the + // need to delay decisions about where resources are coming from (so that we + // can match up imported resources with exported resources)? Right now I + // think we're erring on the side if the former at the expense of the + // latter. pub(crate) fn unify_imported_resources(&self) { let mut resource_mapping = self.resource_mapping.borrow_mut(); - let mut resource_imports = HashMap::<_, Vec<_>>::new(); + let mut resource_imports = IndexMap::<_, IndexSet<_>>::new(); for (component_id, component) in &self.components { let component = &component.component; for import_name in component.imports.keys() { @@ -575,12 +577,14 @@ impl<'a> CompositionGraph<'a> { .. } = ty { - if !resource_mapping.map.contains_key(&resource_id.resource()) { - resource_imports - .entry(vec![import_name.to_string(), export_name.to_string()]) - .or_default() - .push((*component_id, resource_id.resource())) + let set = resource_imports + .entry(vec![import_name.to_string(), export_name.to_string()]) + .or_default(); + + if let Some(pair) = resource_mapping.map.get(&resource_id.resource()) { + set.insert(*pair); } + set.insert((*component_id, resource_id.resource())); } } } @@ -588,7 +592,7 @@ impl<'a> CompositionGraph<'a> { } for resources in resource_imports.values() { - match &resources[..] { + match &resources.iter().copied().collect::>()[..] { [] => unreachable!(), [_] => {} [first, rest @ ..] => { @@ -644,10 +648,8 @@ impl<'a> CompositionGraph<'a> { .remap_component_entity(&mut import_type, remapping); remapping.reset_type_cache(); - if context - .component_entity_type(&export_type, &import_type, 0) - .is_ok() - { + let v = context.component_entity_type(&export_type, &import_type, 0); + if v.is_ok() { *self.resource_mapping.borrow_mut() = resource_mapping; true } else { @@ -697,6 +699,10 @@ impl<'a> CompositionGraph<'a> { assert!(self.components.insert(id, entry).is_none()); + if self.components.len() > 1 { + self.unify_imported_resources(); + } + Ok(id) } diff --git a/crates/wasm-encoder/src/component/builder.rs b/crates/wasm-encoder/src/component/builder.rs index 27c39ac8a0..3aa23220bc 100644 --- a/crates/wasm-encoder/src/component/builder.rs +++ b/crates/wasm-encoder/src/component/builder.rs @@ -316,7 +316,7 @@ impl ComponentBuilder { (inc(&mut self.types), self.types().function()) } - /// Declares a + /// Declares a new resource type within this component. pub fn type_resource(&mut self, rep: ValType, dtor: Option) -> u32 { self.types().resource(rep, dtor); inc(&mut self.types) @@ -372,6 +372,106 @@ impl ComponentBuilder { inc(&mut self.core_funcs) } + /// TODO: docs + pub fn async_start(&mut self, ty: u32) -> u32 { + self.canonical_functions().async_start(ty); + inc(&mut self.core_funcs) + } + + /// TODO: docs + pub fn async_return(&mut self, ty: u32) -> u32 { + self.canonical_functions().async_return(ty); + inc(&mut self.core_funcs) + } + + /// TODO: docs + pub fn future_new(&mut self, ty: u32, memory: u32) -> u32 { + self.canonical_functions().future_new(ty, memory); + inc(&mut self.core_funcs) + } + + /// TODO: docs + pub fn future_send(&mut self, ty: u32, options: O) -> u32 + where + O: IntoIterator, + O::IntoIter: ExactSizeIterator, + { + self.canonical_functions().future_send(ty, options); + inc(&mut self.core_funcs) + } + + /// TODO: docs + pub fn future_receive(&mut self, ty: u32, options: O) -> u32 + where + O: IntoIterator, + O::IntoIter: ExactSizeIterator, + { + self.canonical_functions().future_receive(ty, options); + inc(&mut self.core_funcs) + } + + /// TODO: docs + pub fn future_drop_sender(&mut self, ty: u32) -> u32 { + self.canonical_functions().future_drop_sender(ty); + inc(&mut self.core_funcs) + } + + /// TODO: docs + pub fn future_drop_receiver(&mut self, ty: u32) -> u32 { + self.canonical_functions().future_drop_receiver(ty); + inc(&mut self.core_funcs) + } + + /// TODO: docs + pub fn stream_new(&mut self, ty: u32, memory: u32) -> u32 { + self.canonical_functions().stream_new(ty, memory); + inc(&mut self.core_funcs) + } + + /// TODO: docs + pub fn stream_send(&mut self, ty: u32, options: O) -> u32 + where + O: IntoIterator, + O::IntoIter: ExactSizeIterator, + { + self.canonical_functions().stream_send(ty, options); + inc(&mut self.core_funcs) + } + + /// TODO: docs + pub fn stream_receive(&mut self, ty: u32, options: O) -> u32 + where + O: IntoIterator, + O::IntoIter: ExactSizeIterator, + { + self.canonical_functions().stream_receive(ty, options); + inc(&mut self.core_funcs) + } + + /// TODO: docs + pub fn stream_drop_sender(&mut self, ty: u32) -> u32 { + self.canonical_functions().stream_drop_sender(ty); + inc(&mut self.core_funcs) + } + + /// TODO: docs + pub fn stream_drop_receiver(&mut self, ty: u32) -> u32 { + self.canonical_functions().stream_drop_receiver(ty); + inc(&mut self.core_funcs) + } + + /// TODO: docs + pub fn error_drop(&mut self) -> u32 { + self.canonical_functions().error_drop(); + inc(&mut self.core_funcs) + } + + /// TODO: docs + pub fn task_wait(&mut self, memory: u32) -> u32 { + self.canonical_functions().task_wait(memory); + inc(&mut self.core_funcs) + } + /// Adds a new custom section to this component. pub fn custom_section(&mut self, section: &CustomSection<'_>) { self.flush(); diff --git a/crates/wasm-encoder/src/component/canonicals.rs b/crates/wasm-encoder/src/component/canonicals.rs index 340d9ca621..5bbe5d5df6 100644 --- a/crates/wasm-encoder/src/component/canonicals.rs +++ b/crates/wasm-encoder/src/component/canonicals.rs @@ -21,6 +21,10 @@ pub enum CanonicalOption { /// The post-return function to use if the lifting of a function requires /// cleanup after the function returns. PostReturn(u32), + /// TODO: docs + Async, + /// TODO: docs + Callback(u32), } impl Encode for CanonicalOption { @@ -41,6 +45,13 @@ impl Encode for CanonicalOption { sink.push(0x05); idx.encode(sink); } + Self::Async => { + sink.push(0x06); + } + Self::Callback(idx) => { + sink.push(0x07); + idx.encode(sink); + } } } } @@ -144,6 +155,155 @@ impl CanonicalFunctionSection { self.num_added += 1; self } + + /// TODO: docs + pub fn async_start(&mut self, ty: u32) -> &mut Self { + self.bytes.push(0x05); + ty.encode(&mut self.bytes); + self.num_added += 1; + self + } + + /// TODO: docs + pub fn async_return(&mut self, ty: u32) -> &mut Self { + self.bytes.push(0x06); + ty.encode(&mut self.bytes); + self.num_added += 1; + self + } + + /// TODO: docs + pub fn future_new(&mut self, ty: u32, memory: u32) -> &mut Self { + self.bytes.push(0x07); + ty.encode(&mut self.bytes); + memory.encode(&mut self.bytes); + self.num_added += 1; + self + } + + /// TODO: docs + pub fn future_send(&mut self, ty: u32, options: O) -> &mut Self + where + O: IntoIterator, + O::IntoIter: ExactSizeIterator, + { + self.bytes.push(0x08); + ty.encode(&mut self.bytes); + let options = options.into_iter(); + options.len().encode(&mut self.bytes); + for option in options { + option.encode(&mut self.bytes); + } + self.num_added += 1; + self + } + + /// TODO: docs + pub fn future_receive(&mut self, ty: u32, options: O) -> &mut Self + where + O: IntoIterator, + O::IntoIter: ExactSizeIterator, + { + self.bytes.push(0x09); + ty.encode(&mut self.bytes); + let options = options.into_iter(); + options.len().encode(&mut self.bytes); + for option in options { + option.encode(&mut self.bytes); + } + self.num_added += 1; + self + } + + /// TODO: docs + pub fn future_drop_sender(&mut self, ty: u32) -> &mut Self { + self.bytes.push(0x0a); + ty.encode(&mut self.bytes); + self.num_added += 1; + self + } + + /// TODO: docs + pub fn future_drop_receiver(&mut self, ty: u32) -> &mut Self { + self.bytes.push(0x0b); + ty.encode(&mut self.bytes); + self.num_added += 1; + self + } + + /// TODO: docs + pub fn stream_new(&mut self, ty: u32, memory: u32) -> &mut Self { + self.bytes.push(0x0c); + ty.encode(&mut self.bytes); + memory.encode(&mut self.bytes); + self.num_added += 1; + self + } + + /// TODO: docs + pub fn stream_send(&mut self, ty: u32, options: O) -> &mut Self + where + O: IntoIterator, + O::IntoIter: ExactSizeIterator, + { + self.bytes.push(0x0d); + ty.encode(&mut self.bytes); + let options = options.into_iter(); + options.len().encode(&mut self.bytes); + for option in options { + option.encode(&mut self.bytes); + } + self.num_added += 1; + self + } + + /// TODO: docs + pub fn stream_receive(&mut self, ty: u32, options: O) -> &mut Self + where + O: IntoIterator, + O::IntoIter: ExactSizeIterator, + { + self.bytes.push(0x0e); + ty.encode(&mut self.bytes); + let options = options.into_iter(); + options.len().encode(&mut self.bytes); + for option in options { + option.encode(&mut self.bytes); + } + self.num_added += 1; + self + } + + /// TODO: docs + pub fn stream_drop_sender(&mut self, ty: u32) -> &mut Self { + self.bytes.push(0x0f); + ty.encode(&mut self.bytes); + self.num_added += 1; + self + } + + /// TODO: docs + pub fn stream_drop_receiver(&mut self, ty: u32) -> &mut Self { + self.bytes.push(0x10); + ty.encode(&mut self.bytes); + self.num_added += 1; + self + } + + /// TODO: docs + pub fn error_drop(&mut self) -> &mut Self { + self.bytes.push(0x11); + self.num_added += 1; + self + } + + /// TODO: docs + pub fn task_wait(&mut self, memory: u32) -> &mut Self { + self.bytes.push(0x12); + memory.encode(&mut self.bytes); + self.num_added += 1; + self + } } impl Encode for CanonicalFunctionSection { diff --git a/crates/wasm-encoder/src/component/types.rs b/crates/wasm-encoder/src/component/types.rs index 08e37da245..ed248b7746 100644 --- a/crates/wasm-encoder/src/component/types.rs +++ b/crates/wasm-encoder/src/component/types.rs @@ -668,6 +668,23 @@ impl ComponentDefinedTypeEncoder<'_> { self.0.push(0x68); idx.encode(self.0); } + + /// TODO: docs + pub fn future(self, payload: Option) { + self.0.push(0x67); + payload.encode(self.0); + } + + /// TODO: docs + pub fn stream(self, payload: ComponentValType) { + self.0.push(0x66); + payload.encode(self.0); + } + + /// TODO: docs + pub fn error(self) { + self.0.push(0x65); + } } /// An encoder for the type section of WebAssembly components. diff --git a/crates/wasm-encoder/src/reencode.rs b/crates/wasm-encoder/src/reencode.rs index 6c60ba550f..d1b441b208 100644 --- a/crates/wasm-encoder/src/reencode.rs +++ b/crates/wasm-encoder/src/reencode.rs @@ -956,6 +956,10 @@ pub mod utils { wasmparser::CanonicalOption::PostReturn(u) => { crate::component::CanonicalOption::PostReturn(reencoder.function_index(u)) } + wasmparser::CanonicalOption::Async => crate::component::CanonicalOption::Async, + wasmparser::CanonicalOption::Callback(u) => { + crate::component::CanonicalOption::Callback(u) + } } } diff --git a/crates/wasmparser/src/readers/component/canonicals.rs b/crates/wasmparser/src/readers/component/canonicals.rs index 53bfe9da2b..f5a0c089c9 100644 --- a/crates/wasmparser/src/readers/component/canonicals.rs +++ b/crates/wasmparser/src/readers/component/canonicals.rs @@ -23,6 +23,10 @@ pub enum CanonicalOption { /// The post-return function to use if the lifting of a function requires /// cleanup after the function returns. PostReturn(u32), + /// TODO: docs + Async, + /// TODO: docs + Callback(u32), } /// Represents a canonical function in a WebAssembly component. @@ -60,6 +64,85 @@ pub enum CanonicalFunction { /// The type index of the resource that's being accessed. resource: u32, }, + /// TODO: docs + AsyncStart { + /// TODO: docs + component_type_index: u32, + }, + /// TODO: docs + AsyncReturn { + /// TODO: docs + component_type_index: u32, + }, + /// TODO: docs + FutureNew { + /// TODO: docs + ty: u32, + /// TODO: docs + memory: u32, + }, + /// TODO: docs + FutureSend { + /// TODO: docs + ty: u32, + /// TODO: docs + options: Box<[CanonicalOption]>, + }, + /// TODO: docs + FutureReceive { + /// TODO: docs + ty: u32, + /// TODO: docs + options: Box<[CanonicalOption]>, + }, + /// TODO: docs + FutureDropSender { + /// TODO: docs + ty: u32, + }, + /// TODO: docs + FutureDropReceiver { + /// TODO: docs + ty: u32, + }, + /// TODO: docs + StreamNew { + /// TODO: docs + ty: u32, + /// TODO: docs + memory: u32, + }, + /// TODO: docs + StreamSend { + /// TODO: docs + ty: u32, + /// TODO: docs + options: Box<[CanonicalOption]>, + }, + /// TODO: docs + StreamReceive { + /// TODO: docs + ty: u32, + /// TODO: docs + options: Box<[CanonicalOption]>, + }, + /// TODO: docs + StreamDropSender { + /// TODO: docs + ty: u32, + }, + /// TODO: docs + StreamDropReceiver { + /// TODO: docs + ty: u32, + }, + /// TODO: docs + ErrorDrop, + /// TODO: docs + TaskWait { + /// TODO: docs + memory: u32, + }, } /// A reader for the canonical section of a WebAssembly component. @@ -101,6 +184,52 @@ impl<'a> FromReader<'a> for CanonicalFunction { 0x04 => CanonicalFunction::ResourceRep { resource: reader.read()?, }, + 0x05 => CanonicalFunction::AsyncStart { + component_type_index: reader.read()?, + }, + 0x06 => CanonicalFunction::AsyncReturn { + component_type_index: reader.read()?, + }, + 0x07 => CanonicalFunction::FutureNew { + ty: reader.read()?, + memory: reader.read()?, + }, + 0x08 => CanonicalFunction::FutureSend { + ty: reader.read()?, + options: reader + .read_iter(MAX_WASM_CANONICAL_OPTIONS, "canonical options")? + .collect::>()?, + }, + 0x09 => CanonicalFunction::FutureReceive { + ty: reader.read()?, + options: reader + .read_iter(MAX_WASM_CANONICAL_OPTIONS, "canonical options")? + .collect::>()?, + }, + 0x0a => CanonicalFunction::FutureDropSender { ty: reader.read()? }, + 0x0b => CanonicalFunction::FutureDropReceiver { ty: reader.read()? }, + 0x0c => CanonicalFunction::StreamNew { + ty: reader.read()?, + memory: reader.read()?, + }, + 0x0d => CanonicalFunction::StreamSend { + ty: reader.read()?, + options: reader + .read_iter(MAX_WASM_CANONICAL_OPTIONS, "canonical options")? + .collect::>()?, + }, + 0x0e => CanonicalFunction::StreamReceive { + ty: reader.read()?, + options: reader + .read_iter(MAX_WASM_CANONICAL_OPTIONS, "canonical options")? + .collect::>()?, + }, + 0x0f => CanonicalFunction::StreamDropSender { ty: reader.read()? }, + 0x10 => CanonicalFunction::StreamDropReceiver { ty: reader.read()? }, + 0x11 => CanonicalFunction::ErrorDrop, + 0x12 => CanonicalFunction::TaskWait { + memory: reader.read()?, + }, x => return reader.invalid_leading_byte(x, "canonical function"), }) } @@ -115,6 +244,8 @@ impl<'a> FromReader<'a> for CanonicalOption { 0x03 => CanonicalOption::Memory(reader.read_var_u32()?), 0x04 => CanonicalOption::Realloc(reader.read_var_u32()?), 0x05 => CanonicalOption::PostReturn(reader.read_var_u32()?), + 0x06 => CanonicalOption::Async, + 0x07 => CanonicalOption::Callback(reader.read_var_u32()?), x => return reader.invalid_leading_byte(x, "canonical option"), }) } diff --git a/crates/wasmparser/src/readers/component/types.rs b/crates/wasmparser/src/readers/component/types.rs index 1e9c8fa554..8bc730c617 100644 --- a/crates/wasmparser/src/readers/component/types.rs +++ b/crates/wasmparser/src/readers/component/types.rs @@ -493,6 +493,12 @@ pub enum ComponentDefinedType<'a> { Own(u32), /// A borrowed handle to a resource. Borrow(u32), + /// TODO: docs + Future(Option), + /// TODO: docs + Stream(ComponentValType), + /// TODO: docs + Error, } impl<'a> ComponentDefinedType<'a> { @@ -532,6 +538,9 @@ impl<'a> ComponentDefinedType<'a> { }, 0x69 => ComponentDefinedType::Own(reader.read()?), 0x68 => ComponentDefinedType::Borrow(reader.read()?), + 0x67 => ComponentDefinedType::Future(reader.read()?), + 0x66 => ComponentDefinedType::Stream(reader.read()?), + 0x65 => ComponentDefinedType::Error, x => return reader.invalid_leading_byte(x, "component defined type"), }) } diff --git a/crates/wasmparser/src/validator.rs b/crates/wasmparser/src/validator.rs index 5425fad7d2..163372933d 100644 --- a/crates/wasmparser/src/validator.rs +++ b/crates/wasmparser/src/validator.rs @@ -1229,6 +1229,46 @@ impl Validator { crate::CanonicalFunction::ResourceRep { resource } => { current.resource_rep(resource, types, offset) } + crate::CanonicalFunction::AsyncStart { + component_type_index, + } => current.async_start(component_type_index, types, offset), + crate::CanonicalFunction::AsyncReturn { + component_type_index, + } => current.async_return(component_type_index, types, offset), + crate::CanonicalFunction::FutureNew { ty, memory } => { + current.future_new(ty, memory, types, offset) + } + crate::CanonicalFunction::FutureSend { ty, options } => { + current.future_send(ty, options.into_vec(), types, offset) + } + crate::CanonicalFunction::FutureReceive { ty, options } => { + current.future_receive(ty, options.into_vec(), types, offset) + } + crate::CanonicalFunction::FutureDropSender { ty } => { + current.future_drop_sender(ty, types, offset) + } + crate::CanonicalFunction::FutureDropReceiver { ty } => { + current.future_drop_receiver(ty, types, offset) + } + crate::CanonicalFunction::StreamNew { ty, memory } => { + current.stream_new(ty, memory, types, offset) + } + crate::CanonicalFunction::StreamSend { ty, options } => { + current.stream_send(ty, options.into_vec(), types, offset) + } + crate::CanonicalFunction::StreamReceive { ty, options } => { + current.stream_receive(ty, options.into_vec(), types, offset) + } + crate::CanonicalFunction::StreamDropSender { ty } => { + current.stream_drop_sender(ty, types, offset) + } + crate::CanonicalFunction::StreamDropReceiver { ty } => { + current.stream_drop_receiver(ty, types, offset) + } + crate::CanonicalFunction::ErrorDrop => current.error_drop(types, offset), + crate::CanonicalFunction::TaskWait { memory } => { + current.task_wait(memory, types, offset) + } } }, ) diff --git a/crates/wasmparser/src/validator/component.rs b/crates/wasmparser/src/validator/component.rs index 405a03c573..4b92cc9695 100644 --- a/crates/wasmparser/src/validator/component.rs +++ b/crates/wasmparser/src/validator/component.rs @@ -727,7 +727,8 @@ impl ComponentState { // named. ComponentDefinedType::Primitive(_) | ComponentDefinedType::Flags(_) - | ComponentDefinedType::Enum(_) => true, + | ComponentDefinedType::Enum(_) + | ComponentDefinedType::Error => true, // Referenced types of all these aggregates must all be // named. @@ -759,6 +760,12 @@ impl ComponentState { ComponentDefinedType::Own(id) | ComponentDefinedType::Borrow(id) => { set.contains(&ComponentAnyTypeId::from(*id)) } + + ComponentDefinedType::Future(ty) => ty + .as_ref() + .map(|ty| types.type_named_valtype(ty, set)) + .unwrap_or(true), + ComponentDefinedType::Stream(ty) => types.type_named_valtype(ty, set), } } @@ -961,7 +968,7 @@ impl ComponentState { // Lifting a function is for an export, so match the expected canonical ABI // export signature - let info = ty.lower(types, false); + let info = ty.lower(types, false, options.contains(&CanonicalOption::Async)); self.check_options(Some(core_ty), &info, &options, types, offset)?; if core_ty.params() != info.params.as_slice() { @@ -1001,7 +1008,7 @@ impl ComponentState { // Lowering a function is for an import, so use a function type that matches // the expected canonical ABI import signature. - let info = ty.lower(types, true); + let info = ty.lower(types, true, options.contains(&CanonicalOption::Async)); self.check_options(None, &info, &options, types, offset)?; @@ -1092,6 +1099,396 @@ impl ComponentState { Ok(()) } + pub fn async_start( + &mut self, + component_type_index: u32, + types: &mut TypeAlloc, + offset: usize, + ) -> Result<()> { + let mut component_type = self + .function_type_at(component_type_index, types, offset)? + .clone(); + component_type.results = Vec::from(mem::replace(&mut component_type.params, Box::new([]))) + .into_iter() + .map(|(k, v)| (Some(k), v)) + .collect(); + let info = component_type.lower(types, true, false); + + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new( + info.params.iter(), + info.results.iter(), + )), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + + pub fn async_return( + &mut self, + component_type_index: u32, + types: &mut TypeAlloc, + offset: usize, + ) -> Result<()> { + let mut component_type = self + .function_type_at(component_type_index, types, offset)? + .clone(); + component_type.params = Vec::from(mem::replace(&mut component_type.results, Box::new([]))) + .into_iter() + .enumerate() + .map(|(index, (k, v))| { + ( + k.unwrap_or_else(|| KebabString::new(format!("p{index}")).unwrap()), + v, + ) + }) + .collect(); + let info = component_type.lower(types, true, false); + + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new( + info.params.iter(), + info.results.iter(), + )), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + + pub fn future_new( + &mut self, + ty: u32, + memory: u32, + types: &mut TypeAlloc, + offset: usize, + ) -> Result<()> { + self.defined_type_at(ty, offset)?; + self.memory_at(memory, offset)?; + + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new([ValType::I32], [])), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + + pub fn future_send( + &mut self, + ty: u32, + options: Vec, + types: &mut TypeAlloc, + offset: usize, + ) -> Result<()> { + self.defined_type_at(ty, offset)?; + + let mut info = LoweringInfo::default(); + info.requires_memory = true; + info.requires_realloc = false; + self.check_options(None, &info, &options, types, offset)?; + + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new( + [ValType::I32; 3], + [ValType::I32], + )), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + + pub fn future_receive( + &mut self, + ty: u32, + options: Vec, + types: &mut TypeAlloc, + offset: usize, + ) -> Result<()> { + let ty = self.defined_type_at(ty, offset)?; + + let mut info = LoweringInfo::default(); + info.requires_memory = true; + info.requires_realloc = ComponentValType::Type(ty).contains_ptr(types); + self.check_options(None, &info, &options, types, offset)?; + + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new( + [ValType::I32; 3], + [ValType::I32], + )), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + + pub fn future_drop_sender( + &mut self, + ty: u32, + types: &mut TypeAlloc, + offset: usize, + ) -> Result<()> { + self.defined_type_at(ty, offset)?; + + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new([ValType::I32], [])), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + + pub fn future_drop_receiver( + &mut self, + ty: u32, + types: &mut TypeAlloc, + offset: usize, + ) -> Result<()> { + self.defined_type_at(ty, offset)?; + + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new([ValType::I32], [])), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + + pub fn stream_new( + &mut self, + ty: u32, + memory: u32, + types: &mut TypeAlloc, + offset: usize, + ) -> Result<()> { + self.defined_type_at(ty, offset)?; + self.memory_at(memory, offset)?; + + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new([ValType::I32], [])), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + + pub fn stream_send( + &mut self, + ty: u32, + options: Vec, + types: &mut TypeAlloc, + offset: usize, + ) -> Result<()> { + self.defined_type_at(ty, offset)?; + + let mut info = LoweringInfo::default(); + info.requires_memory = true; + info.requires_realloc = false; + self.check_options(None, &info, &options, types, offset)?; + + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new( + [ValType::I32; 3], + [ValType::I32], + )), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + + pub fn stream_receive( + &mut self, + ty: u32, + options: Vec, + types: &mut TypeAlloc, + offset: usize, + ) -> Result<()> { + self.defined_type_at(ty, offset)?; + + let mut info = LoweringInfo::default(); + info.requires_memory = true; + info.requires_realloc = true; + self.check_options(None, &info, &options, types, offset)?; + + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new( + [ValType::I32; 3], + [ValType::I32], + )), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + + pub fn stream_drop_sender( + &mut self, + ty: u32, + types: &mut TypeAlloc, + offset: usize, + ) -> Result<()> { + self.defined_type_at(ty, offset)?; + + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new([ValType::I32], [])), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + + pub fn stream_drop_receiver( + &mut self, + ty: u32, + types: &mut TypeAlloc, + offset: usize, + ) -> Result<()> { + self.defined_type_at(ty, offset)?; + + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new([ValType::I32], [])), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + + pub fn error_drop(&mut self, types: &mut TypeAlloc, offset: usize) -> Result<()> { + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new([ValType::I32], [])), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + + pub fn task_wait(&mut self, memory: u32, types: &mut TypeAlloc, offset: usize) -> Result<()> { + self.memory_at(memory, offset)?; + + let (_is_new, group_id) = types.intern_canonical_rec_group(RecGroup::implicit( + offset, + SubType { + is_final: true, + supertype_idx: None, + composite_type: CompositeType { + inner: CompositeInnerType::Func(FuncType::new([ValType::I32], [ValType::I32])), + shared: false, + }, + }, + )); + let id = types[group_id].start; + self.core_funcs.push(id); + Ok(()) + } + fn check_local_resource(&self, idx: u32, types: &TypeList, offset: usize) -> Result { let resource = self.resource_at(idx, types, offset)?; match self @@ -1274,6 +1671,8 @@ impl ComponentState { CanonicalOption::Memory(_) => "memory", CanonicalOption::Realloc(_) => "realloc", CanonicalOption::PostReturn(_) => "post-return", + CanonicalOption::Async => "async", + CanonicalOption::Callback(_) => "callback", } } @@ -1362,6 +1761,8 @@ impl ComponentState { } } } + // TODO + CanonicalOption::Async | CanonicalOption::Callback(_) => {} } } @@ -1796,7 +2197,7 @@ impl ComponentState { cx.entity_type(arg, expected, offset).with_context(|| { format!( "type mismatch for export `{name}` of module \ - instantiation argument `{module}`" + instantiation argument `{module}`" ) })?; } @@ -2577,6 +2978,14 @@ impl ComponentState { crate::ComponentDefinedType::Borrow(idx) => Ok(ComponentDefinedType::Borrow( self.resource_at(idx, types, offset)?, )), + crate::ComponentDefinedType::Future(ty) => Ok(ComponentDefinedType::Future( + ty.map(|ty| self.create_component_val_type(ty, offset)) + .transpose()?, + )), + crate::ComponentDefinedType::Stream(ty) => Ok(ComponentDefinedType::Stream( + self.create_component_val_type(ty, offset)?, + )), + crate::ComponentDefinedType::Error => Ok(ComponentDefinedType::Error), } } diff --git a/crates/wasmparser/src/validator/types.rs b/crates/wasmparser/src/validator/types.rs index d713fbeee7..4343984d2a 100644 --- a/crates/wasmparser/src/validator/types.rs +++ b/crates/wasmparser/src/validator/types.rs @@ -1183,9 +1183,23 @@ impl TypeData for ComponentFuncType { impl ComponentFuncType { /// Lowers the component function type to core parameter and result types for the /// canonical ABI. - pub(crate) fn lower(&self, types: &TypeList, is_lower: bool) -> LoweringInfo { + pub(crate) fn lower(&self, types: &TypeList, is_lower: bool, async_: bool) -> LoweringInfo { let mut info = LoweringInfo::default(); + if async_ { + if is_lower { + for _ in 0..3 { + info.params.push(ValType::I32); + } + info.results.push(ValType::I32); + info.requires_memory = true; + info.requires_realloc = self.results.iter().any(|(_, ty)| ty.contains_ptr(types)); + } else { + info.results.push(ValType::I32); + } + return info; + } + for (_, ty) in self.params.iter() { // Check to see if `ty` has a pointer somewhere in it, needed for // any type that transitively contains either a string or a list. @@ -1316,6 +1330,12 @@ pub enum ComponentDefinedType { Own(AliasableResourceId), /// The type is a borrowed handle to the specified resource. Borrow(AliasableResourceId), + /// TODO: docs + Future(Option), + /// TODO: docs + Stream(ComponentValType), + /// TODO: docs + Error, } impl TypeData for ComponentDefinedType { @@ -1323,7 +1343,13 @@ impl TypeData for ComponentDefinedType { fn type_info(&self, types: &TypeList) -> TypeInfo { match self { - Self::Primitive(_) | Self::Flags(_) | Self::Enum(_) | Self::Own(_) => TypeInfo::new(), + Self::Primitive(_) + | Self::Flags(_) + | Self::Enum(_) + | Self::Own(_) + | Self::Future(_) + | Self::Stream(_) + | Self::Error => TypeInfo::new(), Self::Borrow(_) => TypeInfo::borrow(), Self::Record(r) => r.info, Self::Variant(v) => v.info, @@ -1351,7 +1377,13 @@ impl ComponentDefinedType { .any(|case| case.ty.map(|ty| ty.contains_ptr(types)).unwrap_or(false)), Self::List(_) => true, Self::Tuple(t) => t.types.iter().any(|ty| ty.contains_ptr(types)), - Self::Flags(_) | Self::Enum(_) | Self::Own(_) | Self::Borrow(_) => false, + Self::Flags(_) + | Self::Enum(_) + | Self::Own(_) + | Self::Borrow(_) + | Self::Future(_) + | Self::Stream(_) + | Self::Error => false, Self::Option(ty) => ty.contains_ptr(types), Self::Result { ok, err } => { ok.map(|ty| ty.contains_ptr(types)).unwrap_or(false) @@ -1380,7 +1412,12 @@ impl ComponentDefinedType { Self::Flags(names) => { (0..(names.len() + 31) / 32).all(|_| lowered_types.push(ValType::I32)) } - Self::Enum(_) | Self::Own(_) | Self::Borrow(_) => lowered_types.push(ValType::I32), + Self::Enum(_) + | Self::Own(_) + | Self::Borrow(_) + | Self::Future(_) + | Self::Stream(_) + | Self::Error => lowered_types.push(ValType::I32), Self::Option(ty) => { Self::push_variant_wasm_types([ty].into_iter(), types, lowered_types) } @@ -1448,6 +1485,9 @@ impl ComponentDefinedType { ComponentDefinedType::Result { .. } => "result", ComponentDefinedType::Own(_) => "own", ComponentDefinedType::Borrow(_) => "borrow", + ComponentDefinedType::Future(_) => "future", + ComponentDefinedType::Stream(_) => "stream", + ComponentDefinedType::Error => "error", } } } @@ -1617,6 +1657,16 @@ impl<'a> TypesRef<'a> { } } + /// TODO: docs + pub fn component_val_type(&self, ty: crate::ComponentValType) -> ComponentValType { + match ty { + crate::ComponentValType::Primitive(ty) => ComponentValType::Primitive(ty), + crate::ComponentValType::Type(ty) => { + ComponentValType::Type(self.component_defined_type_at(ty)) + } + } + } + /// Returns the number of core types defined so far. pub fn core_type_count(&self) -> u32 { match &self.kind { @@ -3170,7 +3220,8 @@ impl TypeAlloc { match &self[id] { ComponentDefinedType::Primitive(_) | ComponentDefinedType::Flags(_) - | ComponentDefinedType::Enum(_) => {} + | ComponentDefinedType::Enum(_) + | ComponentDefinedType::Error => {} ComponentDefinedType::Record(r) => { for ty in r.fields.values() { self.free_variables_valtype(ty, set); @@ -3188,7 +3239,9 @@ impl TypeAlloc { } } } - ComponentDefinedType::List(ty) | ComponentDefinedType::Option(ty) => { + ComponentDefinedType::List(ty) + | ComponentDefinedType::Option(ty) + | ComponentDefinedType::Stream(ty) => { self.free_variables_valtype(ty, set); } ComponentDefinedType::Result { ok, err } => { @@ -3202,6 +3255,11 @@ impl TypeAlloc { ComponentDefinedType::Own(id) | ComponentDefinedType::Borrow(id) => { set.insert(id.resource()); } + ComponentDefinedType::Future(ty) => { + if let Some(ty) = ty { + self.free_variables_valtype(ty, set); + } + } } } @@ -3302,7 +3360,7 @@ impl TypeAlloc { let ty = &self[id]; match ty { // Primitives are always considered named - ComponentDefinedType::Primitive(_) => true, + ComponentDefinedType::Primitive(_) | ComponentDefinedType::Error => true, // These structures are never allowed to be anonymous, so they // themselves must be named. @@ -3325,15 +3383,20 @@ impl TypeAlloc { .map(|t| self.type_named_valtype(t, set)) .unwrap_or(true) } - ComponentDefinedType::List(ty) | ComponentDefinedType::Option(ty) => { - self.type_named_valtype(ty, set) - } + ComponentDefinedType::List(ty) + | ComponentDefinedType::Option(ty) + | ComponentDefinedType::Stream(ty) => self.type_named_valtype(ty, set), // own/borrow themselves don't have to be named, but the resource // they refer to must be named. ComponentDefinedType::Own(id) | ComponentDefinedType::Borrow(id) => { set.contains(&ComponentAnyTypeId::from(*id)) } + + ComponentDefinedType::Future(ty) => ty + .as_ref() + .map(|ty| self.type_named_valtype(ty, set)) + .unwrap_or(true), } } @@ -3484,7 +3547,8 @@ where match &mut tmp { ComponentDefinedType::Primitive(_) | ComponentDefinedType::Flags(_) - | ComponentDefinedType::Enum(_) => {} + | ComponentDefinedType::Enum(_) + | ComponentDefinedType::Error => {} ComponentDefinedType::Record(r) => { for ty in r.fields.values_mut() { any_changed |= self.remap_valtype(ty, map); @@ -3502,7 +3566,9 @@ where } } } - ComponentDefinedType::List(ty) | ComponentDefinedType::Option(ty) => { + ComponentDefinedType::List(ty) + | ComponentDefinedType::Option(ty) + | ComponentDefinedType::Stream(ty) => { any_changed |= self.remap_valtype(ty, map); } ComponentDefinedType::Result { ok, err } => { @@ -3516,6 +3582,11 @@ where ComponentDefinedType::Own(id) | ComponentDefinedType::Borrow(id) => { any_changed |= self.remap_resource_id(id, map); } + ComponentDefinedType::Future(ty) => { + if let Some(ty) = ty { + any_changed |= self.remap_valtype(ty, map); + } + } } self.insert_if_any_changed(map, any_changed, id, tmp) } @@ -4455,6 +4526,21 @@ impl<'a> SubtypeCx<'a> { } (Own(_), b) => bail!(offset, "expected {}, found own", b.desc()), (Borrow(_), b) => bail!(offset, "expected {}, found borrow", b.desc()), + (Future(a), Future(b)) => match (a, b) { + (None, None) => Ok(()), + (Some(a), Some(b)) => self + .component_val_type(a, b, offset) + .with_context(|| "type mismatch in future"), + (None, Some(_)) => bail!(offset, "expected future type, but found none"), + (Some(_), None) => bail!(offset, "expected future type to not be present"), + }, + (Future(_), b) => bail!(offset, "expected {}, found future", b.desc()), + (Stream(a), Stream(b)) => self + .component_val_type(a, b, offset) + .with_context(|| "type mismatch in stream"), + (Stream(_), b) => bail!(offset, "expected {}, found stream", b.desc()), + (Error, Error) => Ok(()), + (Error, b) => bail!(offset, "expected {}, found error", b.desc()), } } diff --git a/crates/wasmprinter/src/lib.rs b/crates/wasmprinter/src/lib.rs index f8945d5b84..9b6574c491 100644 --- a/crates/wasmprinter/src/lib.rs +++ b/crates/wasmprinter/src/lib.rs @@ -1783,6 +1783,28 @@ impl Printer<'_, '_> { Ok(()) } + fn print_future_type(&mut self, state: &State, ty: Option) -> Result<()> { + self.start_group("future")?; + + if let Some(ty) = ty { + self.result.write_str(" ")?; + self.print_component_val_type(state, &ty)?; + } + + self.end_group()?; + + Ok(()) + } + + fn print_stream_type(&mut self, state: &State, ty: ComponentValType) -> Result<()> { + self.start_group("stream")?; + self.result.write_str(" ")?; + self.print_component_val_type(state, &ty)?; + self.end_group()?; + + Ok(()) + } + fn print_defined_type(&mut self, state: &State, ty: &ComponentDefinedType) -> Result<()> { match ty { ComponentDefinedType::Primitive(ty) => self.print_primitive_val_type(ty)?, @@ -1804,6 +1826,9 @@ impl Printer<'_, '_> { self.print_idx(&state.component.type_names, *idx)?; self.end_group()?; } + ComponentDefinedType::Future(ty) => self.print_future_type(state, *ty)?, + ComponentDefinedType::Stream(ty) => self.print_stream_type(state, *ty)?, + ComponentDefinedType::Error => self.result.write_str("error")?, } Ok(()) @@ -2316,6 +2341,12 @@ impl Printer<'_, '_> { self.print_idx(&state.core.func_names, *idx)?; self.end_group()?; } + CanonicalOption::Async => self.result.write_str("async")?, + CanonicalOption::Callback(idx) => { + self.start_group("callback ")?; + self.print_idx(&state.core.func_names, *idx)?; + self.end_group()?; + } } } Ok(()) @@ -2397,6 +2428,161 @@ impl Printer<'_, '_> { self.end_group()?; state.core.funcs += 1; } + CanonicalFunction::AsyncStart { + component_type_index, + } => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon call.start ")?; + self.print_idx(&state.component.type_names, component_type_index)?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } + CanonicalFunction::AsyncReturn { + component_type_index, + } => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon call.return ")?; + self.print_idx(&state.component.type_names, component_type_index)?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } + CanonicalFunction::FutureNew { ty, memory } => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon future.new ")?; + self.print_idx(&state.component.type_names, ty)?; + self.result.write_str(" ")?; + self.print_idx(&state.component.type_names, memory)?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } + CanonicalFunction::FutureSend { ty, options } => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon future.send ")?; + self.print_idx(&state.component.type_names, ty)?; + self.result.write_str(" ")?; + self.print_canonical_options(state, &options)?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } + CanonicalFunction::FutureReceive { ty, options } => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon future.receive ")?; + self.print_idx(&state.component.type_names, ty)?; + self.result.write_str(" ")?; + self.print_canonical_options(state, &options)?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } + CanonicalFunction::FutureDropSender { ty } => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon future.drop_sender ")?; + self.print_idx(&state.component.type_names, ty)?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } + CanonicalFunction::FutureDropReceiver { ty } => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon future.drop_receiver ")?; + self.print_idx(&state.component.type_names, ty)?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } + CanonicalFunction::StreamNew { ty, memory } => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon stream.new ")?; + self.print_idx(&state.component.type_names, ty)?; + self.result.write_str(" ")?; + self.print_idx(&state.component.type_names, memory)?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } + CanonicalFunction::StreamSend { ty, options } => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon stream.send ")?; + self.print_idx(&state.component.type_names, ty)?; + self.result.write_str(" ")?; + self.print_canonical_options(state, &options)?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } + CanonicalFunction::StreamReceive { ty, options } => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon stream.receive ")?; + self.print_idx(&state.component.type_names, ty)?; + self.result.write_str(" ")?; + self.print_canonical_options(state, &options)?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } + CanonicalFunction::StreamDropSender { ty } => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon stream.drop_sender ")?; + self.print_idx(&state.component.type_names, ty)?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } + CanonicalFunction::StreamDropReceiver { ty } => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon stream.drop_receiver ")?; + self.print_idx(&state.component.type_names, ty)?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } + CanonicalFunction::ErrorDrop => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon error.drop")?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } + CanonicalFunction::TaskWait { memory } => { + self.start_group("core func ")?; + self.print_name(&state.core.func_names, state.core.funcs)?; + self.result.write_str(" ")?; + self.start_group("canon task.wait ")?; + self.print_idx(&state.component.type_names, memory)?; + self.end_group()?; + self.end_group()?; + state.core.funcs += 1; + } } } diff --git a/crates/wit-component/src/encoding.rs b/crates/wit-component/src/encoding.rs index 35599208ad..2a81686b2a 100644 --- a/crates/wit-component/src/encoding.rs +++ b/crates/wit-component/src/encoding.rs @@ -74,20 +74,21 @@ use crate::encoding::world::WorldAdapter; use crate::metadata::{self, Bindgen, ModuleMetadata}; use crate::validation::{ - ResourceInfo, ValidatedModule, BARE_FUNC_MODULE_NAME, MAIN_MODULE_IMPORT_NAME, - POST_RETURN_PREFIX, + AsyncExportInfo, PayloadInfo, ResourceInfo, ValidatedModule, BARE_FUNC_MODULE_NAME, + CALLBACK_PREFIX, MAIN_MODULE_IMPORT_NAME, POST_RETURN_PREFIX, }; use crate::StringEncoding; -use anyhow::{anyhow, bail, Context, Result}; +use anyhow::{anyhow, bail, Context, Error, Result}; use indexmap::{IndexMap, IndexSet}; +use std::borrow::Cow; use std::collections::HashMap; use std::hash::Hash; use wasm_encoder::*; use wasmparser::{Validator, WasmFeatures}; use wit_parser::{ abi::{AbiVariant, WasmSignature, WasmType}, - Function, FunctionKind, InterfaceId, LiveTypes, Resolve, Type, TypeDefKind, TypeId, TypeOwner, - WorldItem, WorldKey, + Function, FunctionKind, InterfaceId, LiveTypes, Resolve, Results, Type, TypeDefKind, TypeId, + TypeOwner, WorldItem, WorldKey, }; const INDIRECT_TABLE_NAME: &str = "$imports"; @@ -126,12 +127,13 @@ bitflags::bitflags! { /// A string encoding must be specified, which is always utf-8 for now /// today. const STRING_ENCODING = 1 << 2; + const ASYNC = 1 << 3; } } impl RequiredOptions { - fn for_import(resolve: &Resolve, func: &Function) -> RequiredOptions { - let sig = resolve.wasm_signature(AbiVariant::GuestImport, func); + fn for_import(resolve: &Resolve, func: &Function, abi: AbiVariant) -> RequiredOptions { + let sig = resolve.wasm_signature(abi, func); let mut ret = RequiredOptions::empty(); // Lift the params and lower the results for imports ret.add_lift(TypeContents::for_types( @@ -145,11 +147,14 @@ impl RequiredOptions { if sig.retptr || sig.indirect_params { ret |= RequiredOptions::MEMORY; } + if abi == AbiVariant::GuestImportAsync { + ret |= RequiredOptions::ASYNC; + } ret } - fn for_export(resolve: &Resolve, func: &Function) -> RequiredOptions { - let sig = resolve.wasm_signature(AbiVariant::GuestExport, func); + fn for_export(resolve: &Resolve, func: &Function, abi: AbiVariant) -> RequiredOptions { + let sig = resolve.wasm_signature(abi, func); let mut ret = RequiredOptions::empty(); // Lower the params and lift the results for exports ret.add_lower(TypeContents::for_types( @@ -167,6 +172,9 @@ impl RequiredOptions { ret |= RequiredOptions::REALLOC; } } + if abi == AbiVariant::GuestExportAsync { + ret |= RequiredOptions::ASYNC; + } ret } @@ -204,7 +212,7 @@ impl RequiredOptions { ) -> Result> { #[derive(Default)] struct Iter { - options: [Option; 3], + options: [Option; 5], current: usize, count: usize, } @@ -254,6 +262,10 @@ impl RequiredOptions { iter.push(encoding.into()); } + if self.contains(RequiredOptions::ASYNC) { + iter.push(CanonicalOption::Async); + } + Ok(iter) } } @@ -312,8 +324,9 @@ impl TypeContents { TypeDefKind::Enum(_) => Self::empty(), TypeDefKind::List(t) => Self::for_type(resolve, t) | Self::LIST, TypeDefKind::Type(t) => Self::for_type(resolve, t), - TypeDefKind::Future(_) => todo!("encoding for future"), - TypeDefKind::Stream(_) => todo!("encoding for stream"), + TypeDefKind::Future(_) => Self::empty(), + TypeDefKind::Stream(_) => Self::empty(), + TypeDefKind::Error => Self::empty(), TypeDefKind::Unknown => unreachable!(), }, Type::String => Self::STRING, @@ -485,7 +498,11 @@ impl<'a> EncodingState<'a> { // Next encode all required functions from this imported interface // into the instance type. for (_, func) in interface.functions.iter() { - if !info.lowerings.contains_key(&func.name) { + if !(info.lowerings.contains_key(&func.name) + || info + .lowerings + .contains_key(&format!("[async]{}", func.name))) + { continue; } log::trace!("encoding function type for `{}`", func.name); @@ -497,6 +514,10 @@ impl<'a> EncodingState<'a> { let ty = encoder.ty; // Don't encode empty instance types since they're not // meaningful to the runtime of the component anyway. + // + // TODO: Is this correct? What if another imported interface needs to + // alias a type exported by this interface but can't because we skipped + // encoding the import? if ty.is_empty() { return Ok(()); } @@ -564,7 +585,7 @@ impl<'a> EncodingState<'a> { &shims, info.metadata, ); - args.push((*core_wasm_name, ModuleArg::Instance(index))); + args.push((core_wasm_name.to_string(), ModuleArg::Instance(index))); } // For each adapter module instance imported into the core wasm module @@ -587,7 +608,7 @@ impl<'a> EncodingState<'a> { } let index = self.component.core_instantiate_exports(exports); - args.push((*adapter, ModuleArg::Instance(index))); + args.push((adapter.to_string(), ModuleArg::Instance(index))); } self.add_resource_funcs( @@ -597,11 +618,45 @@ impl<'a> EncodingState<'a> { &mut args, ); + self.add_async_funcs(&info.required_async_funcs, &mut args); + + self.add_payload_funcs( + CustomModule::Main, + &info.required_payload_funcs, + &shims, + &mut args, + )?; + + if info.needs_error_drop { + let index = self.component.error_drop(); + let index = self.component.core_instantiate_exports(vec![( + "[error-drop]", + ExportKind::Func, + index, + )]); + args.push(("$root".to_owned(), ModuleArg::Instance(index))); + } + + if info.needs_task_wait { + let index = self.component.core_alias_export( + self.shim_instance_index + .expect("shim should be instantiated"), + &shims.shim_names[&ShimKind::TaskWait], + ExportKind::Func, + ); + let index = self.component.core_instantiate_exports(vec![( + "[task-wait]", + ExportKind::Func, + index, + )]); + args.push(("$root".to_owned(), ModuleArg::Instance(index))); + } + // Instantiate the main module now that all of its arguments have been // prepared. With this we now have the main linear memory for // liftings/lowerings later on as well as the adapter modules, if any, // instantiated after the core wasm module. - self.instantiate_core_module(args, info); + self.instantiate_core_module(args.iter().map(|(a, b)| (a.as_str(), *b)), info); // Separate the adapters according which should be instantiated before // and after indirect lowerings are encoded. @@ -664,48 +719,62 @@ impl<'a> EncodingState<'a> { let mut exports = Vec::with_capacity(import.lowerings.len()); for (index, (name, lowering)) in import.lowerings.iter().enumerate() { - if !required_imports.funcs.contains(name.as_str()) { - continue; - } - let index = match lowering { - // All direct lowerings can be `canon lower`'d here immediately - // and passed as arguments. - Lowering::Direct => { - let func_index = match &import.interface { - Some(interface) => { - let instance_index = self.imported_instances[interface]; - self.component.alias_export( - instance_index, - name, - ComponentExportKind::Func, - ) - } - None => self.imported_funcs[name], - }; - self.component.lower_func(func_index, []) + let index = { + if !required_imports.funcs.contains(name) { + continue; } + let (async_, name) = if let Some(name) = name.strip_prefix("[async]") { + (true, name) + } else { + (false, name.as_str()) + }; + match lowering { + // All direct lowerings can be `canon lower`'d here immediately + // and passed as arguments. + Lowering::Direct => { + let func_index = match &import.interface { + Some(interface) => { + let instance_index = self.imported_instances[interface]; + self.component.alias_export( + instance_index, + name, + ComponentExportKind::Func, + ) + } + None => self.imported_funcs[name], + }; + self.component.lower_func( + func_index, + if async_ { + vec![CanonicalOption::Async] + } else { + Vec::new() + }, + ) + } - // Add an entry for all indirect lowerings which come as an - // export of the shim module. - Lowering::Indirect { .. } => { - let encoding = - metadata.import_encodings[&(core_wasm_name.to_string(), name.clone())]; - self.component.core_alias_export( - self.shim_instance_index - .expect("shim should be instantiated"), - &shims.shim_names[&ShimKind::IndirectLowering { - interface: interface.clone(), - index, - realloc: for_module, - encoding, - }], - ExportKind::Func, - ) - } + // Add an entry for all indirect lowerings which come as an + // export of the shim module. + Lowering::Indirect { .. } => { + let encoding = metadata.import_encodings + [&(core_wasm_name.to_string(), name.to_string())]; + self.component.core_alias_export( + self.shim_instance_index + .expect("shim should be instantiated"), + &shims.shim_names[&ShimKind::IndirectLowering { + interface: interface.clone(), + index, + realloc: for_module, + encoding, + }], + ExportKind::Func, + ) + } - Lowering::ResourceDrop(id) => { - let resource_idx = self.lookup_resource_index(*id); - self.component.resource_drop(resource_idx) + Lowering::ResourceDrop(id) => { + let resource_idx = self.lookup_resource_index(*id); + self.component.resource_drop(resource_idx) + } } }; exports.push((name.as_str(), ExportKind::Func, index)); @@ -738,6 +807,12 @@ impl<'a> EncodingState<'a> { CustomModule::Main => &self.info.encoder.main_module_exports, CustomModule::Adapter(name) => &self.info.encoder.adapters[name].required_exports, }; + let async_map = match module { + CustomModule::Main => &self.info.info.required_async_funcs, + CustomModule::Adapter(name) => &self.info.adapters[name].info.required_async_funcs, + } + .get(&format!("[export]{BARE_FUNC_MODULE_NAME}")); + let world = &resolve.worlds[self.info.encoder.metadata.world]; for export_name in exports { let export_string = resolve.name_world_key(export_name); @@ -746,8 +821,11 @@ impl<'a> EncodingState<'a> { let ty = self .root_import_type_encoder(None) .encode_func_type(resolve, func)?; + let async_ = async_map + .map(|m| m.get(&func.name).is_some()) + .unwrap_or(false); let core_name = func.core_export_name(None); - let idx = self.encode_lift(module, &core_name, func, ty)?; + let idx = self.encode_lift(module, &core_name, func, ty, async_)?; self.component .export(&export_string, ComponentExportKind::Func, idx, None); } @@ -770,6 +848,12 @@ impl<'a> EncodingState<'a> { log::trace!("encode interface export `{export_name}`"); let resolve = &self.info.encoder.metadata.resolve; + let async_map = match module { + CustomModule::Main => &self.info.info.required_async_funcs, + CustomModule::Adapter(name) => &self.info.adapters[name].info.required_async_funcs, + } + .get(&format!("[export]{export_name}")); + // First execute a `canon lift` for all the functions in this interface // from the core wasm export. This requires type information but notably // not exported type information since we don't want to export this @@ -779,9 +863,14 @@ impl<'a> EncodingState<'a> { let mut imports = Vec::new(); let mut root = self.root_export_type_encoder(Some(export)); for (_, func) in &resolve.interfaces[export].functions { + let async_ = async_map + .map(|m| m.get(&func.name).is_some()) + .unwrap_or(false); let core_name = func.core_export_name(Some(export_name)); let ty = root.encode_func_type(resolve, func)?; - let func_index = root.state.encode_lift(module, &core_name, func, ty)?; + let func_index = root + .state + .encode_lift(module, &core_name, func, ty, async_)?; imports.push(( import_func_name(func), ComponentExportKind::Func, @@ -1076,6 +1165,7 @@ impl<'a> EncodingState<'a> { core_name: &str, func: &Function, ty: u32, + async_: bool, ) -> Result { let resolve = &self.info.encoder.metadata.resolve; let metadata = match module { @@ -1086,15 +1176,32 @@ impl<'a> EncodingState<'a> { CustomModule::Main => &self.info.info.post_returns, CustomModule::Adapter(name) => &self.info.adapters[name].info.post_returns, }; + let callbacks = match module { + CustomModule::Main => &self.info.info.callbacks, + CustomModule::Adapter(name) => &self.info.adapters[name].info.callbacks, + }; let instance_index = match module { CustomModule::Main => self.instance_index.expect("instantiated by now"), CustomModule::Adapter(name) => self.adapter_instances[name], }; + let lift_name = if async_ { + Cow::Owned(format!("[async]{core_name}")) + } else { + Cow::Borrowed(core_name) + }; let core_func_index = self.component - .core_alias_export(instance_index, core_name, ExportKind::Func); + .core_alias_export(instance_index, &lift_name, ExportKind::Func); - let options = RequiredOptions::for_export(resolve, func); + let options = RequiredOptions::for_export( + resolve, + func, + if async_ { + AbiVariant::GuestExportAsync + } else { + AbiVariant::GuestExport + }, + ); let encoding = metadata.export_encodings[core_name]; // TODO: This realloc detection should probably be improved with @@ -1108,6 +1215,14 @@ impl<'a> EncodingState<'a> { .into_iter(encoding, self.memory_index, realloc_index)? .collect::>(); + let callback = format!("{CALLBACK_PREFIX}{core_name}"); + if callbacks.contains(&callback[..]) { + let callback = + self.component + .core_alias_export(instance_index, &callback, ExportKind::Func); + options.push(CanonicalOption::Callback(callback)); + } + let post_return = format!("{POST_RETURN_PREFIX}{core_name}"); if post_returns.contains(&post_return[..]) { let post_return = @@ -1175,6 +1290,30 @@ impl<'a> EncodingState<'a> { &mut ret, ); + self.encode_payload_funcs( + CustomModule::Adapter(adapter_name), + &adapter.info.required_payload_funcs, + &mut signatures, + &mut ret, + )?; + + if adapter.info.needs_task_wait { + let name = ret.list.len().to_string(); + let debug_name = format!("task.wait"); + signatures.push(WasmSignature { + params: vec![WasmType::I32], + results: vec![WasmType::I32], + indirect_params: false, + retptr: false, + }); + ret.list.push(Shim { + name, + debug_name, + options: RequiredOptions::empty(), + kind: ShimKind::TaskWait, + }); + } + let funcs = match self.info.info.adapters_required.get(adapter_name) { Some(funcs) => funcs, None => continue, @@ -1210,6 +1349,30 @@ impl<'a> EncodingState<'a> { &mut ret, ); + self.encode_payload_funcs( + CustomModule::Main, + &info.required_payload_funcs, + &mut signatures, + &mut ret, + )?; + + if info.needs_task_wait { + let name = ret.list.len().to_string(); + let debug_name = format!("task.wait"); + signatures.push(WasmSignature { + params: vec![WasmType::I32], + results: vec![WasmType::I32], + indirect_params: false, + retptr: false, + }); + ret.list.push(Shim { + name, + debug_name, + options: RequiredOptions::empty(), + kind: ShimKind::TaskWait, + }); + } + if ret.list.is_empty() { return Ok(ret); } @@ -1357,6 +1520,9 @@ impl<'a> EncodingState<'a> { ExportKind::Table, ); + let resolve = &self.info.encoder.metadata.resolve; + let (unit_result, payload_results) = resolve.find_future_and_stream_results(); + let mut exports = Vec::new(); exports.push((INDIRECT_TABLE_NAME, ExportKind::Table, table_index)); @@ -1376,6 +1542,7 @@ impl<'a> EncodingState<'a> { } => { let interface = &self.info.import_map[interface]; let (name, _) = interface.lowerings.get_index(*index).unwrap(); + let name = name.strip_prefix("[async]").unwrap_or(name); let func_index = match &interface.interface { Some(interface_id) => { let instance_index = self.imported_instances[interface_id]; @@ -1432,6 +1599,126 @@ impl<'a> EncodingState<'a> { ExportKind::Func, ) } + + ShimKind::PayloadFunc { + for_module, + module, + imported, + function, + ordinal, + kind, + } => { + let funcs = match for_module { + CustomModule::Main => &self.info.info.required_payload_funcs, + CustomModule::Adapter(name) => { + &self.info.adapters[name].info.required_payload_funcs + } + }; + let info = + &funcs[&(module.to_string(), *imported)][&(function.to_string(), *ordinal)]; + let realloc_index = match for_module { + CustomModule::Main => self.realloc_index, + CustomModule::Adapter(name) => self.adapter_import_reallocs[name], + }; + let metadata = match for_module { + CustomModule::Main => &self.info.encoder.metadata.metadata, + CustomModule::Adapter(name) => &self.info.encoder.adapters[*name].metadata, + }; + let encoding = metadata + .import_encodings + .get(&(module.to_string(), info.function.name.clone())) + .copied() + .unwrap_or(StringEncoding::UTF8); + let options = |me: &mut Self, params: Vec, results: Vec| { + Ok::<_, Error>( + (RequiredOptions::for_import( + resolve, + &Function { + name: String::new(), + kind: FunctionKind::Freestanding, + params: params + .into_iter() + .enumerate() + .map(|(i, v)| (format!("a{i}"), v)) + .collect(), + results: match &results[..] { + [] => Results::Named(Vec::new()), + [ty] => Results::Anon(*ty), + _ => unreachable!(), + }, + docs: Default::default(), + stability: wit_parser::Stability::default(), + }, + AbiVariant::GuestImportAsync, + ) & !RequiredOptions::ASYNC) + .into_iter(encoding, me.memory_index, realloc_index)? + .collect::>(), + ) + }; + let type_index = self.payload_type_index(info.ty)?; + + match kind { + PayloadFuncKind::FutureNew => self + .component + .future_new(type_index, self.memory_index.unwrap()), + + PayloadFuncKind::FutureSend => { + let TypeDefKind::Future(payload_type) = &resolve.types[info.ty].kind + else { + unreachable!() + }; + + let options = options( + self, + if let Some(payload_type) = payload_type { + vec![Type::U32, *payload_type] + } else { + vec![Type::U32] + }, + vec![Type::Id(unit_result.unwrap())], + )?; + self.component.future_send(type_index, options) + } + + PayloadFuncKind::FutureReceive => { + let options = options( + self, + vec![Type::U32], + vec![Type::Id(payload_results[&info.ty])], + )?; + self.component.future_receive(type_index, options) + } + + PayloadFuncKind::StreamNew => self + .component + .stream_new(type_index, self.memory_index.unwrap()), + + PayloadFuncKind::StreamSend => { + let TypeDefKind::Stream(payload_type) = &resolve.types[info.ty].kind + else { + unreachable!() + }; + + let options = options( + self, + vec![Type::U32, *payload_type], + vec![Type::Id(unit_result.unwrap())], + )?; + self.component.stream_send(type_index, options) + } + + PayloadFuncKind::StreamReceive => { + let options = options( + self, + vec![Type::Id(info.ty)], + vec![Type::Id(payload_results[&info.ty])], + )?; + self.component.stream_receive(type_index, options) + } + } + } + + ShimKind::TaskWait => self.component.task_wait(self.memory_index.unwrap()), }; exports.push((shim.name.as_str(), ExportKind::Func, core_func_index)); @@ -1513,12 +1800,134 @@ impl<'a> EncodingState<'a> { } } + fn payload_type_index(&mut self, ty: TypeId) -> Result { + let resolve = &self.info.encoder.metadata.resolve; + let ComponentValType::Type(type_index) = self + .root_import_type_encoder(None) + .encode_valtype(resolve, &Type::Id(ty))? + else { + unreachable!() + }; + Ok(type_index) + } + + fn encode_payload_funcs<'b>( + &mut self, + for_module: CustomModule<'b>, + funcs: &'b IndexMap<(String, bool), IndexMap<(String, usize), PayloadInfo<'a>>>, + signatures: &mut Vec, + shims: &mut Shims<'b>, + ) -> Result<()> { + for ((module, imported), info) in funcs { + for ((function, ordinal), info) in info { + if (info.future_new.is_some() + || info.future_send.is_some() + || info.future_receive.is_some() + || info.future_drop_sender.is_some() + || info.future_drop_receiver.is_some() + || info.future_new.is_some() + || info.future_send.is_some() + || info.future_receive.is_some() + || info.future_drop_sender.is_some() + || info.future_drop_receiver.is_some()) + && !imported + { + todo!( + "support referring to exported payload types in \ + stream and future intrinsic declarations" + ); + }; + + let mut add = |name, kind, params, results| { + let debug_name = format!("{module}-{name}"); + signatures.push(WasmSignature { + params, + results, + indirect_params: false, + retptr: false, + }); + let name = shims.list.len().to_string(); + shims.list.push(Shim { + name, + debug_name, + options: RequiredOptions::empty(), + kind: ShimKind::PayloadFunc { + for_module, + module, + imported: *imported, + function, + ordinal: *ordinal, + kind, + }, + }); + Ok::<_, Error>(()) + }; + + if let Some(name) = info.future_new.as_ref() { + add( + name, + PayloadFuncKind::FutureNew, + vec![WasmType::I32], + Vec::new(), + )?; + } + + if let Some(name) = info.future_send.as_ref() { + add( + name, + PayloadFuncKind::FutureSend, + vec![WasmType::I32; 3], + vec![WasmType::I32], + )?; + } + + if let Some(name) = info.future_receive.as_ref() { + add( + name, + PayloadFuncKind::FutureReceive, + vec![WasmType::I32; 3], + vec![WasmType::I32], + )?; + } + + if let Some(name) = info.stream_new.as_ref() { + add( + name, + PayloadFuncKind::StreamNew, + vec![WasmType::I32], + Vec::new(), + )?; + } + + if let Some(name) = info.stream_send.as_ref() { + add( + name, + PayloadFuncKind::StreamSend, + vec![WasmType::I32; 3], + vec![WasmType::I32], + )?; + } + + if let Some(name) = info.stream_receive.as_ref() { + add( + name, + PayloadFuncKind::StreamReceive, + vec![WasmType::I32; 3], + vec![WasmType::I32], + )?; + } + } + } + + Ok(()) + } + fn add_resource_funcs<'b>( &mut self, module: CustomModule<'b>, funcs: &'b IndexMap>, shims: &Shims, - args: &mut Vec<(&'b str, ModuleArg)>, + args: &mut Vec<(String, ModuleArg)>, ) { for (import, info) in funcs { let mut exports = Vec::new(); @@ -1558,18 +1967,166 @@ impl<'a> EncodingState<'a> { } if !exports.is_empty() { let index = self.component.core_instantiate_exports(exports); - args.push((import.as_str(), ModuleArg::Instance(index))); + args.push((import.clone(), ModuleArg::Instance(index))); + } + } + } + + fn add_async_funcs<'b>( + &mut self, + funcs: &'b IndexMap>>, + args: &mut Vec<(String, ModuleArg)>, + ) { + let resolve = &self.info.encoder.metadata.resolve; + for (import, info) in funcs { + let mut exports = Vec::new(); + for info in info.values() { + if info.start_import.is_some() || info.return_import.is_some() { + let type_index = self + .root_import_type_encoder(None) + .encode_func_type(resolve, &info.function) + .unwrap(); + + if let Some(name) = info.start_import.as_ref() { + let index = self.component.async_start(type_index); + exports.push((name.as_str(), ExportKind::Func, index)); + } + if let Some(name) = info.return_import.as_ref() { + let index = self.component.async_return(type_index); + exports.push((name.as_str(), ExportKind::Func, index)); + } + } + } + if !exports.is_empty() { + let index = self.component.core_instantiate_exports(exports); + args.push((import.clone(), ModuleArg::Instance(index))); } } } + fn add_payload_funcs<'b>( + &mut self, + for_module: CustomModule, + funcs: &'b IndexMap<(String, bool), IndexMap<(String, usize), PayloadInfo<'a>>>, + shims: &Shims, + args: &mut Vec<(String, ModuleArg)>, + ) -> Result<()> { + for ((module, imported), info) in funcs { + let mut exports = Vec::new(); + for ((function, ordinal), info) in info { + let indirect = |me: &mut Self, kind| { + me.component.core_alias_export( + me.shim_instance_index.expect("shim should be instantiated"), + &shims.shim_names[&ShimKind::PayloadFunc { + for_module, + module, + imported: *imported, + function, + ordinal: *ordinal, + kind, + }], + ExportKind::Func, + ) + }; + + if let Some(name) = info.future_new.as_ref() { + exports.push(( + name.as_str(), + ExportKind::Func, + indirect(self, PayloadFuncKind::FutureNew), + )); + } + + if let Some(name) = info.future_send.as_ref() { + exports.push(( + name.as_str(), + ExportKind::Func, + indirect(self, PayloadFuncKind::FutureSend), + )); + } + + if let Some(name) = info.future_receive.as_ref() { + exports.push(( + name.as_str(), + ExportKind::Func, + indirect(self, PayloadFuncKind::FutureReceive), + )); + } + + if let Some(name) = info.future_drop_sender.as_ref() { + let type_index = self.payload_type_index(info.ty)?; + let index = self.component.future_drop_sender(type_index); + exports.push((name.as_str(), ExportKind::Func, index)); + } + + if let Some(name) = info.future_drop_receiver.as_ref() { + let type_index = self.payload_type_index(info.ty)?; + let index = self.component.future_drop_receiver(type_index); + exports.push((name.as_str(), ExportKind::Func, index)); + } + + if let Some(name) = info.stream_new.as_ref() { + exports.push(( + name.as_str(), + ExportKind::Func, + indirect(self, PayloadFuncKind::StreamNew), + )); + } + + if let Some(name) = info.stream_send.as_ref() { + exports.push(( + name.as_str(), + ExportKind::Func, + indirect(self, PayloadFuncKind::StreamSend), + )); + } + + if let Some(name) = info.stream_receive.as_ref() { + exports.push(( + name.as_str(), + ExportKind::Func, + indirect(self, PayloadFuncKind::StreamReceive), + )); + } + + if let Some(name) = info.stream_drop_sender.as_ref() { + let type_index = self.payload_type_index(info.ty)?; + let index = self.component.stream_drop_sender(type_index); + exports.push((name.as_str(), ExportKind::Func, index)); + } + + if let Some(name) = info.stream_drop_receiver.as_ref() { + let type_index = self.payload_type_index(info.ty)?; + let index = self.component.stream_drop_receiver(type_index); + exports.push((name.as_str(), ExportKind::Func, index)); + } + } + if !exports.is_empty() { + let index = self.component.core_instantiate_exports(exports); + args.push(( + format!( + "{}{module}", + if *imported { + "[import-payload]" + } else { + "[export-payload]" + } + ), + ModuleArg::Instance(index), + )); + } + } + + Ok(()) + } + /// This function will instantiate the specified adapter module, which may /// depend on previously-instantiated modules. fn instantiate_adapter_module( &mut self, shims: &Shims<'_>, name: &'a str, - adapter: &WorldAdapter, + adapter: &'a WorldAdapter, ) { let mut args = Vec::new(); @@ -1593,7 +2150,10 @@ impl<'a> EncodingState<'a> { } if !core_exports.is_empty() { let instance = self.component.core_instantiate_exports(core_exports); - args.push((MAIN_MODULE_IMPORT_NAME, ModuleArg::Instance(instance))); + args.push(( + MAIN_MODULE_IMPORT_NAME.to_string(), + ModuleArg::Instance(instance), + )); } // The adapter may either be a library or a "minimal" adapter. If it's // the former, we use `LibraryInfo::arguments` to populate inter-module @@ -1614,7 +2174,7 @@ impl<'a> EncodingState<'a> { }; args.push(( - import_name, + import_name.clone(), ModuleArg::Instance(match instance { Instance::MainOrAdapter(which) => resolve(which), Instance::Items(items) => { @@ -1653,7 +2213,7 @@ impl<'a> EncodingState<'a> { ExportKind::Memory, memory, )]); - args.push((module.as_str(), ModuleArg::Instance(instance))); + args.push((module.clone(), ModuleArg::Instance(instance))); } } for (import_name, _) in adapter.info.required_imports.iter() { @@ -1663,7 +2223,7 @@ impl<'a> EncodingState<'a> { shims, adapter.info.metadata, ); - args.push((import_name, ModuleArg::Instance(instance))); + args.push((import_name.clone(), ModuleArg::Instance(instance))); } self.add_resource_funcs( @@ -1673,9 +2233,45 @@ impl<'a> EncodingState<'a> { &mut args, ); - let instance = self - .component - .core_instantiate(self.adapter_modules[name], args); + self.add_async_funcs(&adapter.info.required_async_funcs, &mut args); + + self.add_payload_funcs( + CustomModule::Adapter(name), + &adapter.info.required_payload_funcs, + shims, + &mut args, + ) + .unwrap(); + + if adapter.info.needs_error_drop { + let index = self.component.error_drop(); + let index = self.component.core_instantiate_exports(vec![( + "[error-drop]", + ExportKind::Func, + index, + )]); + args.push(("$root".to_owned(), ModuleArg::Instance(index))); + } + + if adapter.info.needs_task_wait { + let index = self.component.core_alias_export( + self.shim_instance_index + .expect("shim should be instantiated"), + &shims.shim_names[&ShimKind::TaskWait], + ExportKind::Func, + ); + let index = self.component.core_instantiate_exports(vec![( + "[task-wait]", + ExportKind::Func, + index, + )]); + args.push(("$root".to_owned(), ModuleArg::Instance(index))); + } + + let instance = self.component.core_instantiate( + self.adapter_modules[name], + args.iter().map(|(a, b)| (a.as_str(), *b)), + ); self.adapter_instances.insert(name, instance); let realloc = adapter.info.export_realloc.as_ref().map(|name| { @@ -1735,6 +2331,16 @@ struct Shim<'a> { kind: ShimKind<'a>, } +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +enum PayloadFuncKind { + FutureNew, + FutureSend, + FutureReceive, + StreamNew, + StreamSend, + StreamReceive, +} + #[derive(Debug, Clone, Hash, Eq, PartialEq)] enum ShimKind<'a> { /// This shim is a late indirect lowering of an imported function in a @@ -1768,6 +2374,15 @@ enum ShimKind<'a> { /// The name of the resource being destroyed. resource: &'a str, }, + PayloadFunc { + for_module: CustomModule<'a>, + module: &'a str, + imported: bool, + function: &'a str, + ordinal: usize, + kind: PayloadFuncKind, + }, + TaskWait, } /// Indicator for which module is being used for a lowering or where options @@ -1810,6 +2425,7 @@ impl<'a> Shims<'a> { if !required.contains(name.as_str()) { continue; } + let name = name.strip_prefix("[async]").unwrap_or(name); let shim_name = self.list.len().to_string(); log::debug!( "shim {shim_name} is import `{core_wasm_module}` lowering {index} `{name}`", @@ -1821,7 +2437,7 @@ impl<'a> Shims<'a> { sigs.push(sig.clone()); let encoding = *metadata .import_encodings - .get(&(core_wasm_module.to_string(), name.clone())) + .get(&(core_wasm_module.to_string(), name.to_string())) .ok_or_else(|| { anyhow::anyhow!( "missing component metadata for import of \ @@ -1986,6 +2602,7 @@ impl ComponentEncoder { library_info: Option, ) -> Result { let (wasm, metadata) = metadata::decode(bytes)?; + // Merge the adapter's document into our own document to have one large // document, and then afterwards merge worlds as well. // @@ -2075,11 +2692,13 @@ impl ComponentEncoder { } /// Encode the component and return the bytes. - pub fn encode(&self) -> Result> { + pub fn encode(&mut self) -> Result> { if self.module.is_empty() { bail!("a module is required when encoding a component"); } + self.metadata.resolve.add_future_and_stream_results(); + let world = ComponentWorld::new(self).context("failed to decode world from module")?; let mut state = EncodingState { component: ComponentBuilder::default(), @@ -2114,6 +2733,9 @@ impl ComponentEncoder { .raw_custom_section(&crate::base_producers().raw_custom_section()); let bytes = state.component.finish(); + // TODO dicej: remove this: + std::fs::write("/tmp/foo.wasm", &bytes).unwrap(); + if self.validate { let mut validator = Validator::new_with_features( WasmFeatures::default() | WasmFeatures::COMPONENT_MODEL, diff --git a/crates/wit-component/src/encoding/types.rs b/crates/wit-component/src/encoding/types.rs index 6a47c26cf8..6b0ba9d59b 100644 --- a/crates/wit-component/src/encoding/types.rs +++ b/crates/wit-component/src/encoding/types.rs @@ -132,7 +132,7 @@ pub trait ValtypeEncoder<'a> { // it as it was bound here with an alias. let ty = &resolve.types[id]; log::trace!("encode type name={:?} {:?}", ty.name, &ty.kind); - if let Some(index) = self.maybe_import_type(resolve, id) { + if let Some(index) = self.maybe_import_type(resolve, dealias(resolve, id)) { self.type_map().insert(id, index); return Ok(ComponentValType::Type(index)); } @@ -153,8 +153,9 @@ pub trait ValtypeEncoder<'a> { ComponentValType::Type(index) } TypeDefKind::Type(ty) => self.encode_valtype(resolve, ty)?, - TypeDefKind::Future(_) => todo!("encoding for future type"), - TypeDefKind::Stream(_) => todo!("encoding for stream type"), + TypeDefKind::Future(ty) => self.encode_future(resolve, ty)?, + TypeDefKind::Stream(ty) => self.encode_stream(resolve, ty)?, + TypeDefKind::Error => self.encode_error()?, TypeDefKind::Unknown => unreachable!(), TypeDefKind::Resource => { let name = ty.name.as_ref().expect("resources must be named"); @@ -309,6 +310,30 @@ pub trait ValtypeEncoder<'a> { encoder.enum_type(enum_.cases.iter().map(|c| c.name.as_str())); Ok(ComponentValType::Type(index)) } + + fn encode_future( + &mut self, + resolve: &'a Resolve, + payload: &Option, + ) -> Result { + let ty = self.encode_optional_valtype(resolve, payload.as_ref())?; + let (index, encoder) = self.defined_type(); + encoder.future(ty); + Ok(ComponentValType::Type(index)) + } + + fn encode_stream(&mut self, resolve: &'a Resolve, payload: &Type) -> Result { + let ty = self.encode_valtype(resolve, payload)?; + let (index, encoder) = self.defined_type(); + encoder.stream(ty); + Ok(ComponentValType::Type(index)) + } + + fn encode_error(&mut self) -> Result { + let (index, encoder) = self.defined_type(); + encoder.error(); + Ok(ComponentValType::Type(index)) + } } pub struct RootTypeEncoder<'state, 'a> { @@ -427,3 +452,12 @@ impl<'a> ValtypeEncoder<'a> for InstanceTypeEncoder<'_, 'a> { &mut self.func_type_map } } + +pub fn dealias(resolve: &Resolve, mut id: TypeId) -> TypeId { + loop { + match &resolve.types[id].kind { + TypeDefKind::Type(Type::Id(that_id)) => id = *that_id, + _ => break id, + } + } +} diff --git a/crates/wit-component/src/encoding/world.rs b/crates/wit-component/src/encoding/world.rs index f7cb669d57..4c236db8b8 100644 --- a/crates/wit-component/src/encoding/world.rs +++ b/crates/wit-component/src/encoding/world.rs @@ -460,19 +460,26 @@ impl<'a> ComponentWorld<'a> { impl ImportedInterface { fn add_func(&mut self, required: &IndexSet<&str>, resolve: &Resolve, func: &Function) { - if !required.contains(func.name.as_str()) { - return; - } + let (name, abi) = if required.contains(func.name.as_str()) { + (func.name.clone(), AbiVariant::GuestImport) + } else { + let async_name = format!("[async]{}", func.name); + if required.contains(async_name.as_str()) { + (async_name, AbiVariant::GuestImportAsync) + } else { + return; + } + }; log::trace!("add func {}", func.name); - let options = RequiredOptions::for_import(resolve, func); + let options = RequiredOptions::for_import(resolve, func, abi); let lowering = if options.is_empty() { Lowering::Direct } else { - let sig = resolve.wasm_signature(AbiVariant::GuestImport, func); + let sig = resolve.wasm_signature(abi, func); Lowering::Indirect { sig, options } }; - let prev = self.lowerings.insert(func.name.clone(), lowering); + let prev = self.lowerings.insert(name, lowering); assert!(prev.is_none()); } diff --git a/crates/wit-component/src/printing.rs b/crates/wit-component/src/printing.rs index 5c2ce5813f..2f621c5048 100644 --- a/crates/wit-component/src/printing.rs +++ b/crates/wit-component/src/printing.rs @@ -538,6 +538,7 @@ impl WitPrinter { TypeDefKind::Stream(_) => { todo!("document has an unnamed stream type") } + TypeDefKind::Error => self.output.push_str("error"), TypeDefKind::Unknown => unreachable!(), } } @@ -695,6 +696,7 @@ impl WitPrinter { }, TypeDefKind::Future(_) => todo!("declare future"), TypeDefKind::Stream(_) => todo!("declare stream"), + TypeDefKind::Error => todo!("declare error"), TypeDefKind::Unknown => unreachable!(), } } diff --git a/crates/wit-component/src/validation.rs b/crates/wit-component/src/validation.rs index fadd430c1f..a2d82bde60 100644 --- a/crates/wit-component/src/validation.rs +++ b/crates/wit-component/src/validation.rs @@ -48,7 +48,11 @@ pub const RESOURCE_DROP: &str = "[resource-drop]"; pub const RESOURCE_REP: &str = "[resource-rep]"; pub const RESOURCE_NEW: &str = "[resource-new]"; +pub const ASYNC_START: &str = "[async-start]"; +pub const ASYNC_RETURN: &str = "[async-return]"; + pub const POST_RETURN_PREFIX: &str = "cabi_post_"; +pub const CALLBACK_PREFIX: &str = "[callback][async]"; /// Metadata about a validated module and what was found internally. /// @@ -85,6 +89,14 @@ pub struct ValidatedModule<'a> { /// itself. pub required_resource_funcs: IndexMap>, + pub required_async_funcs: IndexMap>>, + + pub required_payload_funcs: + IndexMap<(String, bool), IndexMap<(String, usize), PayloadInfo<'a>>>, + + pub needs_error_drop: bool, + pub needs_task_wait: bool, + /// Whether or not this module exported a linear memory. pub has_memory: bool, @@ -100,6 +112,10 @@ pub struct ValidatedModule<'a> { /// Post-return functions annotated with `cabi_post_*` in their function /// name. pub post_returns: IndexSet, + + /// Callback functions annotated with `[callback]*` in their function + /// name. + pub callbacks: IndexSet, } #[derive(Default)] @@ -116,6 +132,29 @@ pub struct ResourceInfo { pub id: TypeId, } +pub struct AsyncExportInfo<'a> { + pub interface: Option, + pub function: &'a Function, + pub start_import: Option, + pub return_import: Option, +} + +pub struct PayloadInfo<'a> { + pub interface: Option, + pub function: &'a Function, + pub ty: TypeId, + pub future_new: Option, + pub future_send: Option, + pub future_receive: Option, + pub future_drop_sender: Option, + pub future_drop_receiver: Option, + pub stream_new: Option, + pub stream_send: Option, + pub stream_receive: Option, + pub stream_drop_sender: Option, + pub stream_drop_receiver: Option, +} + /// This function validates the following: /// /// * The `bytes` represent a valid core WebAssembly module. @@ -140,12 +179,17 @@ pub fn validate_module<'a>( let mut ret = ValidatedModule { required_imports: Default::default(), adapters_required: Default::default(), + needs_error_drop: false, + needs_task_wait: false, has_memory: false, realloc: None, adapter_realloc: None, metadata: &metadata.metadata, + required_async_funcs: Default::default(), required_resource_funcs: Default::default(), + required_payload_funcs: Default::default(), post_returns: Default::default(), + callbacks: Default::default(), }; for payload in Parser::new(0).parse_all(bytes) { @@ -211,21 +255,38 @@ pub fn validate_module<'a>( let types = types.unwrap(); let world = &metadata.resolve.worlds[metadata.world]; - let mut exported_resource_funcs = Vec::new(); + let mut exported_resource_and_async_funcs = Vec::new(); + let mut payload_funcs = Vec::new(); for (name, funcs) in &import_funcs { // An empty module name is indicative of the top-level import namespace, // so look for top-level functions here. if *name == BARE_FUNC_MODULE_NAME { - let required = - validate_imports_top_level(&metadata.resolve, metadata.world, funcs, &types)?; - let prev = ret.required_imports.insert(BARE_FUNC_MODULE_NAME, required); - assert!(prev.is_none()); + let Imports { + required, + needs_error_drop, + needs_task_wait, + } = validate_imports_top_level(&metadata.resolve, metadata.world, funcs, &types)?; + ret.needs_error_drop = needs_error_drop; + ret.needs_task_wait = needs_task_wait; + if !(required.funcs.is_empty() && required.resources.is_empty()) { + let prev = ret.required_imports.insert(BARE_FUNC_MODULE_NAME, required); + assert!(prev.is_none()); + } continue; } if let Some(interface_name) = name.strip_prefix("[export]") { - exported_resource_funcs.push((name, interface_name, &import_funcs[name])); + exported_resource_and_async_funcs.push((name, interface_name, &import_funcs[name])); + continue; + } + + if let Some((interface_name, imported)) = name + .strip_prefix("[import-payload]") + .map(|v| (v, true)) + .or_else(|| name.strip_prefix("[export-payload]").map(|v| (v, false))) + { + payload_funcs.push((name, imported, interface_name, &import_funcs[name])); continue; } @@ -244,6 +305,7 @@ pub fn validate_module<'a>( name, funcs, &types, + &mut ret.required_payload_funcs, ) .with_context(|| format!("failed to validate import interface `{name}`"))?; let prev = ret.required_imports.insert(name, required); @@ -265,20 +327,24 @@ pub fn validate_module<'a>( &export_funcs, &types, &mut ret.post_returns, + &mut ret.callbacks, + &mut ret.required_payload_funcs, + &mut ret.required_async_funcs, &mut ret.required_resource_funcs, )?; } - for (name, interface_name, funcs) in exported_resource_funcs { + for (name, interface_name, funcs) in exported_resource_and_async_funcs { let world_key = world_key(&metadata.resolve, interface_name); match world.exports.get(&world_key) { Some(WorldItem::Interface { id, .. }) => { - validate_exported_interface_resource_imports( + validate_exported_interface_resource_and_async_imports( &metadata.resolve, *id, name, funcs, &types, + &mut ret.required_async_funcs, &mut ret.required_resource_funcs, )?; } @@ -286,24 +352,60 @@ pub fn validate_module<'a>( } } + for (name, imported, interface_name, funcs) in payload_funcs { + let world_key = world_key(&metadata.resolve, interface_name); + let (item, direction) = if imported { + (world.imports.get(&world_key), "imported") + } else { + (world.exports.get(&world_key), "exported") + }; + match item { + Some(WorldItem::Interface { id, .. }) => { + validate_payload_imports( + &metadata.resolve, + *id, + name, + imported, + funcs, + &types, + &mut ret.required_payload_funcs, + )?; + } + _ => bail!("import from `{name}` does not correspond to {direction} interface"), + } + } + Ok(ret) } -fn validate_exported_interface_resource_imports<'a>( - resolve: &Resolve, +fn validate_exported_interface_resource_and_async_imports<'a, 'b>( + resolve: &'b Resolve, interface: InterfaceId, import_module: &str, funcs: &IndexMap<&'a str, u32>, types: &Types, + required_async_funcs: &mut IndexMap>>, required_resource_funcs: &mut IndexMap>, ) -> Result<()> { let is_resource = |name: &str| match resolve.interfaces[interface].types.get(name) { Some(ty) => matches!(resolve.types[*ty].kind, TypeDefKind::Resource), None => false, }; + let mut async_module = required_async_funcs.get_mut(import_module); for (func_name, ty) in funcs { + if let Some(ref mut info) = async_module { + if let Some(function_name) = func_name.strip_prefix(ASYNC_START) { + info[function_name].start_import = Some(func_name.to_string()); + continue; + } + if let Some(function_name) = func_name.strip_prefix(ASYNC_RETURN) { + info[function_name].return_import = Some(func_name.to_string()); + continue; + } + } + if valid_exported_resource_func(func_name, *ty, types, is_resource)?.is_none() { - bail!("import of `{func_name}` is not a valid resource function"); + bail!("import of `{func_name}` is not a valid resource or async function"); } let info = required_resource_funcs.get_mut(import_module).unwrap(); if let Some(resource_name) = func_name.strip_prefix(RESOURCE_DROP) { @@ -324,6 +426,69 @@ fn validate_exported_interface_resource_imports<'a>( Ok(()) } +fn match_payload_prefix(name: &str, prefix: &str) -> Option<(String, usize)> { + let suffix = name.strip_prefix(prefix)?; + let index = suffix.find(']')?; + Some(( + suffix[index + 1..].to_owned(), + suffix[..index].parse().ok()?, + )) +} + +fn validate_payload_imports<'a, 'b>( + _resolve: &'b Resolve, + _interface: InterfaceId, + module: &str, + import: bool, + funcs: &IndexMap<&'a str, u32>, + _types: &Types, + required_payload_funcs: &mut IndexMap< + (String, bool), + IndexMap<(String, usize), PayloadInfo<'b>>, + >, +) -> Result<()> { + // TODO: Verify that the core wasm function signatures match what we expect for each function found below. + // Presumably any issues will be caught anyway when the final component is validated, but it would be best to + // catch them early. + let module = module + .strip_prefix(if import { + "[import-payload]" + } else { + "[export-payload]" + }) + .unwrap(); + let info = &mut required_payload_funcs[&(module.to_owned(), import)]; + for (orig_func_name, _ty) in funcs { + let func_name = orig_func_name + .strip_prefix("[async]") + .unwrap_or(orig_func_name); + if let Some(key) = match_payload_prefix(func_name, "[future-new-") { + info[&key].future_new = Some(orig_func_name.to_string()); + } else if let Some(key) = match_payload_prefix(func_name, "[future-send-") { + info[&key].future_send = Some(orig_func_name.to_string()); + } else if let Some(key) = match_payload_prefix(func_name, "[future-receive-") { + info[&key].future_receive = Some(orig_func_name.to_string()); + } else if let Some(key) = match_payload_prefix(func_name, "[future-drop-sender-") { + info[&key].future_drop_sender = Some(orig_func_name.to_string()); + } else if let Some(key) = match_payload_prefix(func_name, "[future-drop-receiver-") { + info[&key].future_drop_receiver = Some(orig_func_name.to_string()); + } else if let Some(key) = match_payload_prefix(func_name, "[stream-new-") { + info[&key].stream_new = Some(orig_func_name.to_string()); + } else if let Some(key) = match_payload_prefix(func_name, "[stream-send-") { + info[&key].stream_send = Some(orig_func_name.to_string()); + } else if let Some(key) = match_payload_prefix(func_name, "[stream-receive-") { + info[&key].stream_receive = Some(orig_func_name.to_string()); + } else if let Some(key) = match_payload_prefix(func_name, "[stream-drop-sender-") { + info[&key].stream_drop_sender = Some(orig_func_name.to_string()); + } else if let Some(key) = match_payload_prefix(func_name, "[stream-drop-receiver-") { + info[&key].stream_drop_receiver = Some(orig_func_name.to_string()); + } else { + bail!("unrecognized payload import: {orig_func_name}"); + } + } + Ok(()) +} + /// Validation information from an "adapter module" which is distinct from a /// "main module" validated above. /// @@ -343,6 +508,14 @@ pub struct ValidatedAdapter<'a> { /// itself. pub required_resource_funcs: IndexMap>, + pub required_async_funcs: IndexMap>>, + + pub required_payload_funcs: + IndexMap<(String, bool), IndexMap<(String, usize), PayloadInfo<'a>>>, + + pub needs_error_drop: bool, + pub needs_task_wait: bool, + /// This is the module and field name of the memory import, if one is /// specified. /// @@ -368,6 +541,10 @@ pub struct ValidatedAdapter<'a> { /// Post-return functions annotated with `cabi_post_*` in their function /// name. pub post_returns: IndexSet, + + /// Callback functions annotated with `[callback]*` in their function + /// name. + pub callbacks: IndexSet, } /// This function will validate the `bytes` provided as a wasm adapter module. @@ -405,12 +582,17 @@ pub fn validate_adapter_module<'a>( let mut ret = ValidatedAdapter { required_imports: Default::default(), required_resource_funcs: Default::default(), + required_async_funcs: Default::default(), + required_payload_funcs: Default::default(), + needs_error_drop: false, + needs_task_wait: false, needs_memory: None, needs_core_exports: Default::default(), import_realloc: None, export_realloc: None, metadata, post_returns: Default::default(), + callbacks: Default::default(), }; let mut cabi_realloc = None; @@ -502,7 +684,9 @@ pub fn validate_adapter_module<'a>( } let types = types.unwrap(); - let mut exported_resource_funcs = Vec::new(); + let mut exported_resource_and_async_funcs = Vec::new(); + let mut payload_funcs = Vec::new(); + for (name, funcs) in &import_funcs { if *name == MAIN_MODULE_IMPORT_NAME { ret.needs_core_exports @@ -513,25 +697,48 @@ pub fn validate_adapter_module<'a>( // An empty module name is indicative of the top-level import namespace, // so look for top-level functions here. if *name == BARE_FUNC_MODULE_NAME { - let required = validate_imports_top_level(resolve, world, funcs, &types)?; - ret.required_imports - .insert(BARE_FUNC_MODULE_NAME.to_string(), required); + let Imports { + required, + needs_error_drop, + needs_task_wait, + } = validate_imports_top_level(resolve, world, funcs, &types)?; + ret.needs_error_drop = needs_error_drop; + ret.needs_task_wait = needs_task_wait; + if !(required.funcs.is_empty() && required.resources.is_empty()) { + let prev = ret + .required_imports + .insert(BARE_FUNC_MODULE_NAME.to_string(), required); + assert!(prev.is_none()); + } continue; } if let Some(interface_name) = name.strip_prefix("[export]") { - exported_resource_funcs.push((name, interface_name, &import_funcs[name])); + exported_resource_and_async_funcs.push((name, interface_name, &import_funcs[name])); + continue; + } + + if let Some((interface_name, imported)) = name + .strip_prefix("[import-payload]") + .map(|v| (v, true)) + .or_else(|| name.strip_prefix("[export-payload]").map(|v| (v, false))) + { + payload_funcs.push((name, imported, interface_name, &import_funcs[name])); continue; } if !(is_library && adapters.contains(name)) { match resolve.worlds[world].imports.get(&world_key(resolve, name)) { Some(WorldItem::Interface { id: interface, .. }) => { - let required = - validate_imported_interface(resolve, *interface, name, funcs, &types) - .with_context(|| { - format!("failed to validate import interface `{name}`") - })?; + let required = validate_imported_interface( + resolve, + *interface, + name, + funcs, + &types, + &mut ret.required_payload_funcs, + ) + .with_context(|| format!("failed to validate import interface `{name}`"))?; let prev = ret.required_imports.insert(name.to_string(), required); assert!(prev.is_none()); } @@ -569,20 +776,24 @@ pub fn validate_adapter_module<'a>( &export_funcs, &types, &mut ret.post_returns, + &mut ret.callbacks, + &mut ret.required_payload_funcs, + &mut ret.required_async_funcs, &mut ret.required_resource_funcs, )?; } - for (name, interface_name, funcs) in exported_resource_funcs { + for (name, interface_name, funcs) in exported_resource_and_async_funcs { let world_key = world_key(resolve, interface_name); match world.exports.get(&world_key) { Some(WorldItem::Interface { id, .. }) => { - validate_exported_interface_resource_imports( + validate_exported_interface_resource_and_async_imports( resolve, *id, name, funcs, &types, + &mut ret.required_async_funcs, &mut ret.required_resource_funcs, )?; } @@ -590,6 +801,29 @@ pub fn validate_adapter_module<'a>( } } + for (name, imported, interface_name, funcs) in payload_funcs { + let world_key = world_key(resolve, interface_name); + let (item, direction) = if imported { + (world.imports.get(&world_key), "imported") + } else { + (world.exports.get(&world_key), "exported") + }; + match item { + Some(WorldItem::Interface { id, .. }) => { + validate_payload_imports( + resolve, + *id, + name, + imported, + funcs, + &types, + &mut ret.required_payload_funcs, + )?; + } + _ => bail!("import from `{name}` does not correspond to {direction} interface"), + } + } + Ok(ret) } @@ -616,12 +850,19 @@ fn world_key(resolve: &Resolve, name: &str) -> WorldKey { } } +struct Imports { + required: RequiredImports, + needs_error_drop: bool, + needs_task_wait: bool, +} + fn validate_imports_top_level( resolve: &Resolve, world: WorldId, funcs: &IndexMap<&str, u32>, types: &Types, -) -> Result { +) -> Result { + // TODO: handle top-level required async, future, and stream built-in imports here let is_resource = |name: &str| match resolve.worlds[world] .imports .get(&WorldKey::Name(name.to_string())) @@ -631,24 +872,45 @@ fn validate_imports_top_level( } _ => false, }; - let mut required = RequiredImports::default(); + let mut imports = Imports { + required: RequiredImports::default(), + needs_error_drop: false, + needs_task_wait: false, + }; for (name, ty) in funcs { - match resolve.worlds[world].imports.get(&world_key(resolve, name)) { - Some(WorldItem::Function(func)) => { - let ty = types[types.core_type_at(*ty).unwrap_sub()].unwrap_func(); - validate_func(resolve, ty, func, AbiVariant::GuestImport)?; + { + if *name == "[error-drop]" { + imports.needs_error_drop = true; + continue; + } + + if *name == "[task-wait]" { + imports.needs_task_wait = true; + continue; } - Some(_) => bail!("expected world top-level import `{name}` to be a function"), - None => match valid_imported_resource_func(name, *ty, types, is_resource)? { - Some(name) => { - required.resources.insert(name.to_string()); + + let (abi, name) = if let Some(name) = name.strip_prefix("[async]") { + (AbiVariant::GuestImportAsync, name) + } else { + (AbiVariant::GuestImport, *name) + }; + match resolve.worlds[world].imports.get(&world_key(resolve, name)) { + Some(WorldItem::Function(func)) => { + let ty = types[types.core_type_at(*ty).unwrap_sub()].unwrap_func(); + validate_func(resolve, ty, &func, abi)?; } - None => bail!("no top-level imported function `{name}` specified"), - }, + Some(_) => bail!("expected world top-level import `{name}` to be a function"), + None => match valid_imported_resource_func(name, *ty, types, is_resource)? { + Some(name) => { + imports.required.resources.insert(name.to_string()); + } + None => bail!("no top-level imported function `{name}` specified"), + }, + } } - required.funcs.insert(name.to_string()); + imports.required.funcs.insert(name.to_string()); } - Ok(required) + Ok(imports) } fn valid_imported_resource_func<'a>( @@ -691,12 +953,45 @@ fn valid_exported_resource_func<'a>( Ok(None) } -fn validate_imported_interface( - resolve: &Resolve, +fn find_payloads<'a>( + resolve: &'a Resolve, + interface: Option, + function: &'a Function, + payload_map: &mut IndexMap<(String, usize), PayloadInfo<'a>>, +) { + let types = function.find_futures_and_streams(resolve); + for (index, ty) in types.iter().enumerate() { + payload_map.insert( + (function.name.clone(), index), + PayloadInfo { + interface, + function, + ty: *ty, + future_new: None, + future_send: None, + future_receive: None, + future_drop_sender: None, + future_drop_receiver: None, + stream_new: None, + stream_send: None, + stream_receive: None, + stream_drop_sender: None, + stream_drop_receiver: None, + }, + ); + } +} + +fn validate_imported_interface<'a>( + resolve: &'a Resolve, interface: InterfaceId, name: &str, imports: &IndexMap<&str, u32>, types: &Types, + required_payload_funcs: &mut IndexMap< + (String, bool), + IndexMap<(String, usize), PayloadInfo<'a>>, + >, ) -> Result { let mut required = RequiredImports::default(); let is_resource = |name: &str| { @@ -707,20 +1002,35 @@ fn validate_imported_interface( matches!(resolve.types[ty].kind, TypeDefKind::Resource) }; for (func_name, ty) in imports { - match resolve.interfaces[interface].functions.get(*func_name) { - Some(f) => { - let ty = types[types.core_type_at(*ty).unwrap_sub()].unwrap_func(); - validate_func(resolve, ty, f, AbiVariant::GuestImport)?; - } - None => match valid_imported_resource_func(func_name, *ty, types, is_resource)? { - Some(name) => { - required.resources.insert(name.to_string()); + { + let (abi, func_name) = if let Some(name) = func_name.strip_prefix("[async]") { + (AbiVariant::GuestImportAsync, name) + } else { + (AbiVariant::GuestImport, *func_name) + }; + match resolve.interfaces[interface].functions.get(func_name) { + Some(f) => { + let ty = types[types.core_type_at(*ty).unwrap_sub()].unwrap_func(); + validate_func(resolve, ty, f, abi)?; + find_payloads( + resolve, + Some(interface), + f, + required_payload_funcs + .entry((name.to_string(), true)) + .or_default(), + ); } - None => bail!( - "import interface `{name}` is missing function \ + None => match valid_imported_resource_func(func_name, *ty, types, is_resource)? { + Some(name) => { + required.resources.insert(name.to_string()); + } + None => bail!( + "import interface `{name}` is missing function \ `{func_name}` that is required by the module", - ), - }, + ), + }, + } } required.funcs.insert(func_name.to_string()); } @@ -760,6 +1070,14 @@ fn validate_post_return( ) } +fn validate_callback(ty: &FuncType, func: &Function) -> Result<()> { + validate_func_sig( + &format!("{} callback", func.name), + &FuncType::new([ValType::I32; 4], [ValType::I32]), + ty, + ) +} + fn validate_func_sig(name: &str, expected: &FuncType, ty: &wasmparser::FuncType) -> Result<()> { if ty != expected { bail!( @@ -782,40 +1100,99 @@ fn validate_exported_item<'a>( exports: &IndexMap<&str, u32>, types: &Types, post_returns: &mut IndexSet, + callbacks: &mut IndexSet, + required_payload_funcs: &mut IndexMap< + (String, bool), + IndexMap<(String, usize), PayloadInfo<'a>>, + >, + required_async_funcs: &mut IndexMap>>, required_resource_funcs: &mut IndexMap>, ) -> Result<()> { - let mut validate = |func: &Function, name: Option<&str>| { - let expected_export_name = func.core_export_name(name); - let func_index = match exports.get(expected_export_name.as_ref()) { - Some(func_index) => func_index, - None => bail!( - "module does not export required function `{}`", - expected_export_name - ), - }; - let id = types.core_function_at(*func_index); - let ty = types[id].unwrap_func(); - validate_func(resolve, ty, func, AbiVariant::GuestExport)?; - - let post_return = format!("{POST_RETURN_PREFIX}{expected_export_name}"); - if let Some(index) = exports.get(&post_return[..]) { - let ok = post_returns.insert(post_return); - assert!(ok); - let id = types.core_function_at(*index); + let mut validate = + |func: &'a Function, + interface: Option<(&str, InterfaceId)>, + async_map: &mut IndexMap>, + payload_map: &mut IndexMap<(String, usize), PayloadInfo<'a>>| { + let expected_export_name = func.core_export_name(interface.map(|(n, _)| n)); + let (abi, func_index) = match exports.get(expected_export_name.as_ref()) { + Some(func_index) => (AbiVariant::GuestExport, func_index), + None => match exports.get(format!("[async]{expected_export_name}").as_str()) { + Some(func_index) => { + async_map.insert( + func.name.clone(), + AsyncExportInfo { + interface: interface.map(|(_, id)| id), + function: func, + start_import: None, + return_import: None, + }, + ); + + (AbiVariant::GuestExportAsync, func_index) + } + None => bail!( + "module does not export required function `{}`", + expected_export_name + ), + }, + }; + let id = types.core_function_at(*func_index); let ty = types[id].unwrap_func(); - validate_post_return(resolve, ty, func)?; - } - Ok(()) - }; + validate_func(resolve, ty, func, abi)?; + find_payloads(resolve, interface.map(|(_, id)| id), func, payload_map); + + let post_return = format!("{POST_RETURN_PREFIX}{expected_export_name}"); + if let Some(index) = exports.get(&post_return[..]) { + let ok = post_returns.insert(post_return); + assert!(ok); + let id = types.core_function_at(*index); + let ty = types[id].unwrap_func(); + validate_post_return(resolve, ty, func)?; + } + + let callback = format!("{CALLBACK_PREFIX}{expected_export_name}"); + if let Some(index) = exports.get(&callback[..]) { + let ok = callbacks.insert(callback); + assert!(ok); + let id = types.core_function_at(*index); + let ty = types[id].unwrap_func(); + validate_callback(ty, func)?; + } + + Ok(()) + }; + match item { - WorldItem::Function(func) => validate(func, None)?, - WorldItem::Interface { id: interface, .. } => { - let interface = &resolve.interfaces[*interface]; + WorldItem::Function(func) => validate( + func, + None, + required_async_funcs + .entry(format!("[export]{BARE_FUNC_MODULE_NAME}")) + .or_default(), + required_payload_funcs + .entry((BARE_FUNC_MODULE_NAME.to_string(), false)) + .or_default(), + )?, + WorldItem::Interface { + id: interface_id, .. + } => { + let interface = &resolve.interfaces[*interface_id]; + let mut async_map = IndexMap::new(); for (_, f) in interface.functions.iter() { - validate(f, Some(export_name)).with_context(|| { + validate( + f, + Some((export_name, *interface_id)), + &mut async_map, + required_payload_funcs + .entry((export_name.to_string(), false)) + .or_default(), + ) + .with_context(|| { format!("failed to validate exported interface `{export_name}`") })?; } + let prev = required_async_funcs.insert(format!("[export]{export_name}"), async_map); + assert!(prev.is_none()); let mut map = IndexMap::new(); for (name, id) in interface.types.iter() { if !matches!(resolve.types[*id].kind, TypeDefKind::Resource) { diff --git a/crates/wit-parser/src/abi.rs b/crates/wit-parser/src/abi.rs index f3d7fc82ed..432e80e12f 100644 --- a/crates/wit-parser/src/abi.rs +++ b/crates/wit-parser/src/abi.rs @@ -127,6 +127,8 @@ pub enum AbiVariant { GuestImport, /// The guest is defining and exporting the function. GuestExport, + GuestImportAsync, + GuestExportAsync, } impl Resolve { @@ -135,6 +137,26 @@ impl Resolve { /// The first entry returned is the list of parameters and the second entry /// is the list of results for the wasm function signature. pub fn wasm_signature(&self, variant: AbiVariant, func: &Function) -> WasmSignature { + match variant { + AbiVariant::GuestExportAsync => { + return WasmSignature { + params: Vec::new(), + indirect_params: false, + results: vec![WasmType::Pointer], + retptr: false, + } + } + AbiVariant::GuestImportAsync => { + return WasmSignature { + params: vec![WasmType::Pointer; 3], + indirect_params: true, + results: vec![WasmType::I32], + retptr: true, + } + } + _ => {} + } + const MAX_FLAT_PARAMS: usize = 16; const MAX_FLAT_RESULTS: usize = 1; @@ -185,6 +207,7 @@ impl Resolve { AbiVariant::GuestExport => { results.push(WasmType::Pointer); } + _ => unreachable!(), } } @@ -274,6 +297,10 @@ impl Resolve { result.push(WasmType::I32); } + TypeDefKind::Error => { + result.push(WasmType::I32); + } + TypeDefKind::Unknown => unreachable!(), }, } diff --git a/crates/wit-parser/src/ast.rs b/crates/wit-parser/src/ast.rs index d65d5fed4f..2c5d6a5e7b 100644 --- a/crates/wit-parser/src/ast.rs +++ b/crates/wit-parser/src/ast.rs @@ -737,6 +737,7 @@ enum Type<'a> { Result(Result_<'a>), Future(Future<'a>), Stream(Stream<'a>), + Error(Span), } enum Handle<'a> { @@ -893,8 +894,7 @@ struct Result_<'a> { struct Stream<'a> { span: Span, - element: Option>>, - end: Option>>, + ty: Box>, } struct NamedFunc<'a> { @@ -1353,29 +1353,21 @@ impl<'a> Type<'a> { Ok(Type::Future(Future { span, ty })) } - // stream - // stream<_, Z> // stream // stream Some((span, Token::Stream)) => { - let mut element = None; - let mut end = None; - - if tokens.eat(Token::LessThan)? { - if tokens.eat(Token::Underscore)? { - tokens.expect(Token::Comma)?; - end = Some(Box::new(Type::parse(tokens)?)); - } else { - element = Some(Box::new(Type::parse(tokens)?)); - if tokens.eat(Token::Comma)? { - end = Some(Box::new(Type::parse(tokens)?)); - } - }; - tokens.expect(Token::GreaterThan)?; - }; - Ok(Type::Stream(Stream { span, element, end })) + tokens.expect(Token::LessThan)?; + let ty = Type::parse(tokens)?; + tokens.expect(Token::GreaterThan)?; + Ok(Type::Stream(Stream { + span, + ty: Box::new(ty), + })) } + // error + Some((span, Token::Error)) => Ok(Type::Error(span)), + // own Some((_span, Token::Own)) => { tokens.expect(Token::LessThan)?; @@ -1435,6 +1427,7 @@ impl<'a> Type<'a> { Type::Result(r) => r.span, Type::Future(f) => f.span, Type::Stream(s) => s.span, + Type::Error(span) => *span, } } } diff --git a/crates/wit-parser/src/ast/lex.rs b/crates/wit-parser/src/ast/lex.rs index 93ad600872..ceb427d8b7 100644 --- a/crates/wit-parser/src/ast/lex.rs +++ b/crates/wit-parser/src/ast/lex.rs @@ -79,6 +79,7 @@ pub enum Token { Result_, Future, Stream, + Error, List, Underscore, As, @@ -310,6 +311,7 @@ impl<'a> Tokenizer<'a> { "result" => Result_, "future" => Future, "stream" => Stream, + "error" => Error, "list" => List, "_" => Underscore, "as" => As, @@ -563,6 +565,7 @@ impl Token { Result_ => "keyword `result`", Future => "keyword `future`", Stream => "keyword `stream`", + Error => "keyword `error`", List => "keyword `list`", Underscore => "keyword `_`", Id => "an identifier", diff --git a/crates/wit-parser/src/ast/resolve.rs b/crates/wit-parser/src/ast/resolve.rs index 0d3465a602..611df392a2 100644 --- a/crates/wit-parser/src/ast/resolve.rs +++ b/crates/wit-parser/src/ast/resolve.rs @@ -95,7 +95,8 @@ enum Key { Option(Type), Result(Option, Option), Future(Option), - Stream(Option, Option), + Stream(Type), + Error, } enum TypeItem<'a, 'b> { @@ -1222,13 +1223,11 @@ impl<'a> Resolver<'a> { ok: self.resolve_optional_type(r.ok.as_deref(), stability)?, err: self.resolve_optional_type(r.err.as_deref(), stability)?, }), - ast::Type::Future(t) => { - TypeDefKind::Future(self.resolve_optional_type(t.ty.as_deref(), stability)?) - } - ast::Type::Stream(s) => TypeDefKind::Stream(Stream { - element: self.resolve_optional_type(s.element.as_deref(), stability)?, - end: self.resolve_optional_type(s.end.as_deref(), stability)?, - }), + ast::Type::Future(t) => TypeDefKind::Future( + self.resolve_optional_type(t.ty.as_ref().map(|bx| &**bx), stability)?, + ), + ast::Type::Stream(t) => TypeDefKind::Stream(self.resolve_type(&t.ty, stability)?), + ast::Type::Error(_) => TypeDefKind::Error, }) } @@ -1274,6 +1273,7 @@ impl<'a> Resolver<'a> { _ => {} } let kind = self.resolve_type_def(ty, stability)?; + Ok(self.anon_type_def( TypeDef { kind, @@ -1330,7 +1330,8 @@ impl<'a> Resolver<'a> { TypeDefKind::Option(t) => Key::Option(*t), TypeDefKind::Result(r) => Key::Result(r.ok, r.err), TypeDefKind::Future(ty) => Key::Future(*ty), - TypeDefKind::Stream(s) => Key::Stream(s.element, s.end), + TypeDefKind::Stream(ty) => Key::Stream(*ty), + TypeDefKind::Error => Key::Error, TypeDefKind::Unknown => unreachable!(), }; let id = self.anon_types.entry(key).or_insert_with(|| { @@ -1488,7 +1489,8 @@ fn collect_deps<'a>(ty: &ast::Type<'a>, deps: &mut Vec>) { | ast::Type::Char(_) | ast::Type::String(_) | ast::Type::Flags(_) - | ast::Type::Enum(_) => {} + | ast::Type::Enum(_) + | ast::Type::Error(_) => {} ast::Type::Name(name) => deps.push(name.clone()), ast::Type::List(list) => collect_deps(&list.ty, deps), ast::Type::Handle(handle) => match handle { @@ -1513,6 +1515,7 @@ fn collect_deps<'a>(ty: &ast::Type<'a>, deps: &mut Vec>) { } } } + ast::Type::Stream(ty) => collect_deps(&ty.ty, deps), ast::Type::Option(ty) => collect_deps(&ty.ty, deps), ast::Type::Result(r) => { if let Some(ty) = &r.ok { @@ -1527,13 +1530,5 @@ fn collect_deps<'a>(ty: &ast::Type<'a>, deps: &mut Vec>) { collect_deps(t, deps) } } - ast::Type::Stream(s) => { - if let Some(t) = &s.element { - collect_deps(t, deps); - } - if let Some(t) = &s.end { - collect_deps(t, deps); - } - } } } diff --git a/crates/wit-parser/src/decoding.rs b/crates/wit-parser/src/decoding.rs index 1c61ecfd48..c1dc6dba24 100644 --- a/crates/wit-parser/src/decoding.rs +++ b/crates/wit-parser/src/decoding.rs @@ -1264,15 +1264,16 @@ impl WitPackageDecoder<'_> { | TypeDefKind::Tuple(_) | TypeDefKind::Option(_) | TypeDefKind::Result(_) - | TypeDefKind::Handle(_) => {} + | TypeDefKind::Handle(_) + | TypeDefKind::Future(_) + | TypeDefKind::Stream(_) + | TypeDefKind::Error => {} TypeDefKind::Resource | TypeDefKind::Record(_) | TypeDefKind::Enum(_) | TypeDefKind::Variant(_) - | TypeDefKind::Flags(_) - | TypeDefKind::Future(_) - | TypeDefKind::Stream(_) => { + | TypeDefKind::Flags(_) => { bail!("unexpected unnamed type of kind '{}'", kind.as_str()); } TypeDefKind::Unknown => unreachable!(), @@ -1399,6 +1400,16 @@ impl WitPackageDecoder<'_> { let id = self.type_map[&(*id).into()]; Ok(TypeDefKind::Handle(Handle::Borrow(id))) } + + types::ComponentDefinedType::Future(ty) => Ok(TypeDefKind::Future( + ty.as_ref().map(|ty| self.convert_valtype(ty)).transpose()?, + )), + + types::ComponentDefinedType::Stream(ty) => { + Ok(TypeDefKind::Stream(self.convert_valtype(ty)?)) + } + + types::ComponentDefinedType::Error => Ok(TypeDefKind::Error), } } @@ -1669,11 +1680,34 @@ impl Registrar<'_> { Ok(()) } + types::ComponentDefinedType::Future(payload) => { + let ty = match &self.resolve.types[id].kind { + TypeDefKind::Future(p) => p, + TypeDefKind::Type(Type::Id(_)) => return Ok(()), + _ => bail!("expected a future"), + }; + match (payload, ty) { + (Some(a), Some(b)) => self.valtype(a, b), + (None, None) => Ok(()), + _ => bail!("disagreement on future payload"), + } + } + + types::ComponentDefinedType::Stream(payload) => { + let ty = match &self.resolve.types[id].kind { + TypeDefKind::Stream(p) => p, + TypeDefKind::Type(Type::Id(_)) => return Ok(()), + _ => bail!("expected a stream"), + }; + self.valtype(payload, ty) + } + // These have no recursive structure so they can bail out. types::ComponentDefinedType::Flags(_) | types::ComponentDefinedType::Enum(_) | types::ComponentDefinedType::Own(_) - | types::ComponentDefinedType::Borrow(_) => Ok(()), + | types::ComponentDefinedType::Borrow(_) + | types::ComponentDefinedType::Error => Ok(()), } } diff --git a/crates/wit-parser/src/lib.rs b/crates/wit-parser/src/lib.rs index 2ab630cc69..806aac6890 100644 --- a/crates/wit-parser/src/lib.rs +++ b/crates/wit-parser/src/lib.rs @@ -474,7 +474,8 @@ pub enum TypeDefKind { Result(Result_), List(Type), Future(Option), - Stream(Stream), + Stream(Type), + Error, Type(Type), /// This represents a type of unknown structure imported from a foreign @@ -503,6 +504,7 @@ impl TypeDefKind { TypeDefKind::List(_) => "list", TypeDefKind::Future(_) => "future", TypeDefKind::Stream(_) => "stream", + TypeDefKind::Error => "error", TypeDefKind::Type(_) => "type", TypeDefKind::Unknown => "unknown", } @@ -685,13 +687,6 @@ pub struct Result_ { pub err: Option, } -#[derive(Debug, Clone, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize))] -pub struct Stream { - pub element: Option, - pub end: Option, -} - #[derive(Clone, Default, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize))] pub struct Docs { @@ -824,6 +819,68 @@ impl Function { None => Cow::Borrowed(&self.name), } } + + pub fn find_futures_and_streams(&self, resolve: &Resolve) -> Vec { + let mut results = Vec::new(); + for (_, ty) in self.params.iter() { + find_futures_and_streams(resolve, *ty, &mut results); + } + for ty in self.results.iter_types() { + find_futures_and_streams(resolve, *ty, &mut results); + } + results + } +} + +fn find_futures_and_streams(resolve: &Resolve, ty: Type, results: &mut Vec) { + if let Type::Id(id) = ty { + match &resolve.types[id].kind { + TypeDefKind::Resource + | TypeDefKind::Handle(_) + | TypeDefKind::Flags(_) + | TypeDefKind::Enum(_) + | TypeDefKind::Error => {} + TypeDefKind::Record(r) => { + for Field { ty, .. } in &r.fields { + find_futures_and_streams(resolve, *ty, results); + } + } + TypeDefKind::Tuple(t) => { + for ty in &t.types { + find_futures_and_streams(resolve, *ty, results); + } + } + TypeDefKind::Variant(v) => { + for Case { ty, .. } in &v.cases { + if let Some(ty) = ty { + find_futures_and_streams(resolve, *ty, results); + } + } + } + TypeDefKind::Option(ty) | TypeDefKind::List(ty) | TypeDefKind::Type(ty) => { + find_futures_and_streams(resolve, *ty, results); + } + TypeDefKind::Result(r) => { + if let Some(ty) = r.ok { + find_futures_and_streams(resolve, ty, results); + } + if let Some(ty) = r.err { + find_futures_and_streams(resolve, ty, results); + } + } + TypeDefKind::Future(ty) => { + if let Some(ty) = ty { + find_futures_and_streams(resolve, *ty, results); + } + results.push(id); + } + TypeDefKind::Stream(ty) => { + find_futures_and_streams(resolve, *ty, results); + results.push(id); + } + TypeDefKind::Unknown => unreachable!(), + } + } } /// Representation of the stability attributes associated with a world, diff --git a/crates/wit-parser/src/live.rs b/crates/wit-parser/src/live.rs index 1b6986c839..dfca938ba3 100644 --- a/crates/wit-parser/src/live.rs +++ b/crates/wit-parser/src/live.rs @@ -72,7 +72,8 @@ impl LiveTypes { TypeDefKind::Type(t) | TypeDefKind::List(t) | TypeDefKind::Option(t) - | TypeDefKind::Future(Some(t)) => self.add_type(resolve, t), + | TypeDefKind::Future(Some(t)) + | TypeDefKind::Stream(t) => self.add_type(resolve, t), TypeDefKind::Handle(handle) => match handle { crate::Handle::Own(ty) => self.add_type_id(resolve, *ty), crate::Handle::Borrow(ty) => self.add_type_id(resolve, *ty), @@ -103,15 +104,10 @@ impl LiveTypes { self.add_type(resolve, ty); } } - TypeDefKind::Stream(s) => { - if let Some(ty) = &s.element { - self.add_type(resolve, ty); - } - if let Some(ty) = &s.end { - self.add_type(resolve, ty); - } - } - TypeDefKind::Flags(_) | TypeDefKind::Enum(_) | TypeDefKind::Future(None) => {} + TypeDefKind::Error + | TypeDefKind::Flags(_) + | TypeDefKind::Enum(_) + | TypeDefKind::Future(None) => {} TypeDefKind::Unknown => unreachable!(), } assert!(self.set.insert(ty)); diff --git a/crates/wit-parser/src/resolve.rs b/crates/wit-parser/src/resolve.rs index 0baf28f181..73b8ea4036 100644 --- a/crates/wit-parser/src/resolve.rs +++ b/crates/wit-parser/src/resolve.rs @@ -4,9 +4,9 @@ use crate::ast::{parse_use_path, ParsedUsePath}; use crate::serde_::{serialize_arena, serialize_id_map}; use crate::{ AstItem, Docs, Error, Function, FunctionKind, Handle, IncludeName, Interface, InterfaceId, - InterfaceSpan, PackageName, Results, SourceMap, Stability, Type, TypeDef, TypeDefKind, TypeId, - TypeOwner, UnresolvedPackage, UnresolvedPackageGroup, World, WorldId, WorldItem, WorldKey, - WorldSpan, + InterfaceSpan, PackageName, Result_, Results, SourceMap, Stability, Type, TypeDef, TypeDefKind, + TypeId, TypeOwner, UnresolvedPackage, UnresolvedPackageGroup, World, WorldId, WorldItem, + WorldKey, WorldSpan, }; use anyhow::{anyhow, bail, Context, Result}; use id_arena::{Arena, Id}; @@ -491,7 +491,8 @@ impl Resolve { | TypeDefKind::Option(_) | TypeDefKind::Result(_) | TypeDefKind::Future(_) - | TypeDefKind::Stream(_) => false, + | TypeDefKind::Stream(_) + | TypeDefKind::Error => false, TypeDefKind::Type(t) => self.all_bits_valid(t), TypeDefKind::Handle(h) => match h { @@ -1408,6 +1409,144 @@ impl Resolve { } } + pub fn add_future_and_stream_results(&mut self) { + let err = Some(Type::Id(self.types.alloc(TypeDef { + kind: TypeDefKind::Error, + name: None, + docs: Docs::default(), + owner: TypeOwner::None, + stability: Stability::Unknown, + }))); + + self.types.alloc(TypeDef { + kind: TypeDefKind::Result(Result_ { ok: None, err }), + name: None, + docs: Docs::default(), + owner: TypeOwner::None, + stability: Stability::Unknown, + }); + + let types = self + .types + .iter() + .filter_map(|(id, ty)| match &ty.kind { + TypeDefKind::Future(_) | TypeDefKind::Stream(_) => Some(id), + _ => None, + }) + .collect::>(); + + for id in types { + match self.types[id].kind.clone() { + TypeDefKind::Future(ty) => { + self.types.alloc(TypeDef { + kind: TypeDefKind::Result(Result_ { ok: ty, err }), + name: None, + docs: Docs::default(), + owner: TypeOwner::None, + stability: Stability::Unknown, + }); + } + TypeDefKind::Stream(ty) => { + let ok = Some(Type::Id(self.types.alloc(TypeDef { + kind: TypeDefKind::List(ty), + name: None, + docs: Docs::default(), + owner: TypeOwner::None, + stability: Stability::Unknown, + }))); + let result = Type::Id(self.types.alloc(TypeDef { + kind: TypeDefKind::Result(Result_ { ok, err }), + name: None, + docs: Docs::default(), + owner: TypeOwner::None, + stability: Stability::Unknown, + })); + self.types.alloc(TypeDef { + kind: TypeDefKind::Option(result), + name: None, + docs: Docs::default(), + owner: TypeOwner::None, + stability: Stability::Unknown, + }); + } + _ => unreachable!(), + } + } + } + + pub fn find_future_and_stream_results(&self) -> (Option, HashMap) { + #[derive(Copy, Clone, Hash, PartialEq, Eq)] + enum PayloadType { + Future(Option), + Stream(Type), + } + + let mut unit_result = None; + + let mut types = HashMap::<_, HashSet<_>>::new(); + for (id, ty) in &self.types { + match &ty.kind { + TypeDefKind::Future(ty) => { + types + .entry(PayloadType::Future(*ty)) + .or_default() + .insert(id); + } + TypeDefKind::Stream(ty) => { + types + .entry(PayloadType::Stream(*ty)) + .or_default() + .insert(id); + } + _ => {} + } + } + + let mut payload_results = HashMap::new(); + for (id, ty) in &self.types { + match &ty.kind { + TypeDefKind::Option(Type::Id(ty)) => { + if let TypeDefKind::Result(Result_ { + ok: Some(Type::Id(ok)), + err: Some(Type::Id(err)), + }) = &self.types[*ty].kind + { + if let (TypeDefKind::List(ok), TypeDefKind::Error) = + (&self.types[*ok].kind, &self.types[*err].kind) + { + if let Some(types) = types.get(&PayloadType::Stream(*ok)) { + for ty in types { + payload_results.insert(*ty, id); + } + } + } + } + } + TypeDefKind::Result(Result_ { + ok, + err: Some(Type::Id(err)), + }) => { + if let TypeDefKind::Error = &self.types[*err].kind { + if let Some(types) = types.get(&PayloadType::Future(*ok)) { + for ty in types { + payload_results.insert(*ty, id); + } + } + if ok.is_none() { + unit_result = Some(id); + } + } + } + _ => (), + } + } + + if unit_result.is_none() { + eprintln!("couldn't find unit result in {:#?}", self.types); + } + + (unit_result, payload_results) + } fn include_stability(&self, stability: &Stability) -> bool { match stability { Stability::Stable { .. } | Stability::Unknown => true, @@ -1856,7 +1995,7 @@ impl Remap { } } } - Option(t) => self.update_ty(resolve, t, span)?, + Option(t) | Stream(t) => self.update_ty(resolve, t, span)?, Result(r) => { if let Some(ty) = &mut r.ok { self.update_ty(resolve, ty, span)?; @@ -1865,16 +2004,9 @@ impl Remap { self.update_ty(resolve, ty, span)?; } } + Error => {} List(t) => self.update_ty(resolve, t, span)?, Future(Some(t)) => self.update_ty(resolve, t, span)?, - Stream(t) => { - if let Some(ty) = &mut t.element { - self.update_ty(resolve, ty, span)?; - } - if let Some(ty) = &mut t.end { - self.update_ty(resolve, ty, span)?; - } - } // Note that `update_ty` is specifically not used here as typedefs // because for the `type a = b` form that doesn't force `a` to be a @@ -2517,12 +2649,10 @@ impl Remap { .iter() .filter_map(|t| t.as_ref()) .any(|t| self.type_has_borrow(resolve, t)), - TypeDefKind::Stream(r) => [&r.element, &r.end] - .iter() - .filter_map(|t| t.as_ref()) - .any(|t| self.type_has_borrow(resolve, t)), + TypeDefKind::Stream(r) => self.type_has_borrow(resolve, r), TypeDefKind::Future(None) => false, TypeDefKind::Unknown => unreachable!(), + TypeDefKind::Error => todo!(), } } } diff --git a/crates/wit-parser/src/sizealign.rs b/crates/wit-parser/src/sizealign.rs index 5040e1a652..f659de3294 100644 --- a/crates/wit-parser/src/sizealign.rs +++ b/crates/wit-parser/src/sizealign.rs @@ -56,6 +56,8 @@ impl SizeAlign { TypeDefKind::Future(_) => (4, 4), // A stream is represented as an index. TypeDefKind::Stream(_) => (4, 4), + // An error is represented as an index. + TypeDefKind::Error => (4, 4), // This shouldn't be used for anything since raw resources aren't part of the ABI -- just handles to // them. TypeDefKind::Resource => (usize::MAX, usize::MAX), diff --git a/crates/wit-parser/tests/ui/comments.wit.json b/crates/wit-parser/tests/ui/comments.wit.json index f2a04f9d39..7c7a44e1c4 100644 --- a/crates/wit-parser/tests/ui/comments.wit.json +++ b/crates/wit-parser/tests/ui/comments.wit.json @@ -27,10 +27,7 @@ { "name": "bar", "kind": { - "stream": { - "element": 0, - "end": null - } + "stream": 0 }, "owner": { "interface": 0 diff --git a/crates/wit-parser/tests/ui/types.wit b/crates/wit-parser/tests/ui/types.wit index 4d007b64fe..dae4399148 100644 --- a/crates/wit-parser/tests/ui/types.wit +++ b/crates/wit-parser/tests/ui/types.wit @@ -47,10 +47,7 @@ interface types { type t45 = list>>; type t46 = t44; type t47 = %t44; - type t48 = stream; - type t49 = stream<_, u32>; type t50 = stream; - type t51 = stream; type t52 = future; type t53 = future; diff --git a/crates/wit-parser/tests/ui/types.wit.json b/crates/wit-parser/tests/ui/types.wit.json index 8874e3541d..f26954b601 100644 --- a/crates/wit-parser/tests/ui/types.wit.json +++ b/crates/wit-parser/tests/ui/types.wit.json @@ -50,14 +50,11 @@ "t45": 46, "t46": 47, "t47": 48, - "t48": 49, - "t49": 50, - "t50": 51, - "t51": 52, - "t52": 53, - "t53": 54, - "bar": 55, - "foo": 56 + "t50": 49, + "t52": 50, + "t53": 51, + "bar": 52, + "foo": 53 }, "functions": {}, "package": 0 @@ -674,49 +671,10 @@ "interface": 0 } }, - { - "name": "t48", - "kind": { - "stream": { - "element": "u32", - "end": "u32" - } - }, - "owner": { - "interface": 0 - } - }, - { - "name": "t49", - "kind": { - "stream": { - "element": null, - "end": "u32" - } - }, - "owner": { - "interface": 0 - } - }, { "name": "t50", "kind": { - "stream": { - "element": "u32", - "end": null - } - }, - "owner": { - "interface": 0 - } - }, - { - "name": "t51", - "kind": { - "stream": { - "element": null, - "end": null - } + "stream": "u32" }, "owner": { "interface": 0 @@ -752,7 +710,7 @@ { "name": "foo", "kind": { - "type": 55 + "type": 52 }, "owner": { "interface": 0 diff --git a/src/bin/wasm-tools/dump.rs b/src/bin/wasm-tools/dump.rs index 7c3539ec2d..deb9fb485e 100644 --- a/src/bin/wasm-tools/dump.rs +++ b/src/bin/wasm-tools/dump.rs @@ -417,9 +417,21 @@ impl<'a> Dump<'a> { CanonicalFunction::Lower { .. } | CanonicalFunction::ResourceNew { .. } | CanonicalFunction::ResourceDrop { .. } - | CanonicalFunction::ResourceRep { .. } => { - ("core func", &mut i.core_funcs) - } + | CanonicalFunction::ResourceRep { .. } + | CanonicalFunction::AsyncStart { .. } + | CanonicalFunction::AsyncReturn { .. } + | CanonicalFunction::FutureNew { .. } + | CanonicalFunction::FutureSend { .. } + | CanonicalFunction::FutureReceive { .. } + | CanonicalFunction::FutureDropSender { .. } + | CanonicalFunction::FutureDropReceiver { .. } + | CanonicalFunction::StreamNew { .. } + | CanonicalFunction::StreamSend { .. } + | CanonicalFunction::StreamReceive { .. } + | CanonicalFunction::StreamDropSender { .. } + | CanonicalFunction::StreamDropReceiver { .. } + | CanonicalFunction::TaskWait { .. } + | CanonicalFunction::ErrorDrop => ("core func", &mut i.core_funcs), }; write!(me.state, "[{} {}] {:?}", name, inc(col), f)?; diff --git a/src/bin/wasm-tools/main.rs b/src/bin/wasm-tools/main.rs index 8d2cce140d..bd7a50c51f 100644 --- a/src/bin/wasm-tools/main.rs +++ b/src/bin/wasm-tools/main.rs @@ -105,6 +105,11 @@ fn main() -> ExitCode { } fn print_error(color: ColorChoice, err: anyhow::Error) -> Result<()> { + if true { + eprintln!("{err:?}"); + return Ok(()); + } + let color = if color == ColorChoice::Auto && !io::stderr().is_terminal() { ColorChoice::Never } else {