diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index 86c22647..f62ef1dc 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -24,7 +24,6 @@ use crate::{ codec::unsigned_varint::UnsignedVarint, error::{self, Error, ParseError, SubstreamError}, multistream_select::{ - drain_trailing_protocols, protocol::{ webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError, PROTO_MULTISTREAM_1_0, @@ -300,6 +299,12 @@ pub enum HandshakeResult { /// The returned tuple contains the negotiated protocol and response /// that must be sent to remote peer. Succeeded(ProtocolName), + + /// The proposed protocol was rejected by the remote peer. + /// + /// The caller should check if there are remaining fallback protocols to try + /// via [`WebRtcDialerState::propose_next_fallback()`]. + Rejected, } /// Handshake state. @@ -335,10 +340,10 @@ impl WebRtcDialerState { fallback_names: Vec, ) -> crate::Result<(Self, Vec)> { let message = webrtc_encode_multistream_message( - std::iter::once(protocol.clone()) - .chain(fallback_names.clone()) - .filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok()) - .map(Message::Protocol), + Message::Protocol( + Protocol::try_from(protocol.as_ref()).map_err(|_| Error::InvalidData)?, + ), + true, )? .freeze() .to_vec(); @@ -353,72 +358,86 @@ impl WebRtcDialerState { )) } + /// Propose the next fallback protocol to the remote peer. + /// + /// Returns `None` if there are no more fallback protocols to try. + /// Returns `Some(message)` with the encoded message to send, containing the protocol name. + pub fn propose_next_fallback(&mut self) -> crate::Result>> { + if self.fallback_names.is_empty() { + return Ok(None); + } + + let next = self.fallback_names.remove(0); + self.protocol = next; + self.state = HandshakeState::WaitingResponse; + + let message = webrtc_encode_multistream_message( + Message::Protocol( + Protocol::try_from(self.protocol.as_ref()).map_err(|_| Error::InvalidData)?, + ), + true, + )? + .freeze() + .to_vec(); + + Ok(Some(message)) + } + /// Register response to [`WebRtcDialerState`]. pub fn register_response( &mut self, payload: Vec, ) -> Result { - // All multistream-select messages are length-prefixed. Since this code path is not using - // multistream_select::protocol::MessageIO, we need to decode and remove the length here. - let remaining: &[u8] = &payload; - let (len, tail) = unsigned_varint::decode::usize(remaining).map_err(|error| { - tracing::debug!( + let bytes = Bytes::from(payload); + let mut remaining = bytes.clone(); + + while !remaining.is_empty() { + let (len, tail) = unsigned_varint::decode::usize(&remaining).map_err(|error| { + tracing::debug!( target: LOG_TARGET, ?error, - message = ?payload, - "Failed to decode length-prefix in multistream message"); - error::NegotiationError::ParseError(ParseError::InvalidData) - })?; + message = ?remaining, + "Failed to decode length-prefix in multistream message", + ); + error::NegotiationError::ParseError(ParseError::InvalidData) + })?; - let len_size = remaining.len() - tail.len(); - let bytes = Bytes::from(payload); - let payload = bytes.slice(len_size..len_size + len); - let remaining = bytes.slice(len_size + len..); - let message = Message::decode(payload); - - tracing::trace!( - target: LOG_TARGET, - ?message, - "Decoded message while registering response", - ); - - let mut protocols = match message { - Ok(Message::Header(HeaderLine::V1)) => { - vec![PROTO_MULTISTREAM_1_0] + let len_size = remaining.len() - tail.len(); + + if len > tail.len() { + tracing::debug!( + target: LOG_TARGET, + message = ?tail, + length_prefix = len, + actual_length = tail.len(), + "Truncated multistream message", + ); + return Err(error::NegotiationError::ParseError(ParseError::InvalidData)); } - Ok(Message::Protocol(protocol)) => vec![protocol], - Ok(Message::Protocols(protocols)) => protocols, - Ok(Message::NotAvailable) => - return match &self.state { - HandshakeState::WaitingProtocol => Err( - error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), - ), - _ => Err(error::NegotiationError::StateMismatch), - }, - Ok(Message::ListProtocols) => return Err(error::NegotiationError::StateMismatch), - Err(_) => return Err(error::NegotiationError::ParseError(ParseError::InvalidData)), - }; - match drain_trailing_protocols(remaining) { - Ok(protos) => protocols.extend(protos), - Err(error) => return Err(error), - } + let payload = remaining.slice(len_size..len_size + len); + remaining = remaining.slice(len_size + len..); + let message = Message::decode(payload); - let mut protocol_iter = protocols.into_iter(); - loop { - match (&self.state, protocol_iter.next()) { - (HandshakeState::WaitingResponse, None) => - return Err(crate::error::NegotiationError::StateMismatch), - (HandshakeState::WaitingResponse, Some(protocol)) => { - if protocol == PROTO_MULTISTREAM_1_0 { - self.state = HandshakeState::WaitingProtocol; - } else { - return Err(crate::error::NegotiationError::MultistreamSelectError( - NegotiationError::Failed, - )); - } + tracing::trace!( + target: LOG_TARGET, + ?message, + "Decoded message while registering response", + ); + + match (&self.state, message) { + (HandshakeState::WaitingResponse, Ok(Message::Header(HeaderLine::V1))) => { + self.state = HandshakeState::WaitingProtocol; } - (HandshakeState::WaitingProtocol, Some(protocol)) => { + (HandshakeState::WaitingResponse, Ok(Message::Protocol(_))) => { + return Err(crate::error::NegotiationError::MultistreamSelectError( + NegotiationError::Failed, + )); + } + (_, Ok(Message::NotAvailable)) => { + return Ok(HandshakeResult::Rejected); + } + (HandshakeState::WaitingProtocol, Ok(Message::Protocol(protocol))) => { if protocol == PROTO_MULTISTREAM_1_0 { return Err(crate::error::NegotiationError::StateMismatch); } @@ -437,11 +456,16 @@ impl WebRtcDialerState { NegotiationError::Failed, )); } - (HandshakeState::WaitingProtocol, None) => { - return Ok(HandshakeResult::NotReady); + _ => { + return Err(crate::error::NegotiationError::StateMismatch); } } } + + match &self.state { + HandshakeState::WaitingProtocol => Ok(HandshakeResult::NotReady), + HandshakeState::WaitingResponse => Err(crate::error::NegotiationError::StateMismatch), + } } } @@ -816,6 +840,7 @@ mod tests { ) .unwrap(); + // Initial message should only contain the main protocol, not the fallback. let mut bytes = BytesMut::with_capacity(32); bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8); let _ = Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); @@ -824,15 +849,52 @@ mod tests { bytes.put_u8((proto1.as_ref().len() + 1) as u8); // + 1 for \n let _ = Message::Protocol(proto1).encode(&mut bytes).unwrap(); - let proto2 = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name"); - bytes.put_u8((proto2.as_ref().len() + 1) as u8); // + 1 for \n - let _ = Message::Protocol(proto2).encode(&mut bytes).unwrap(); - let expected_message = bytes.freeze().to_vec(); assert_eq!(message, expected_message); } + #[test] + fn propose_next_fallback() { + let (mut dialer_state, _message) = WebRtcDialerState::propose( + ProtocolName::from("/13371338/proto/1"), + vec![ProtocolName::from("/sup/proto/1")], + ) + .unwrap(); + + // Simulate receiving header-only response, transitioning to WaitingProtocol. + let mut header_bytes = BytesMut::with_capacity(32); + header_bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8); + let _ = Message::Header(HeaderLine::V1).encode(&mut header_bytes).unwrap(); + // Append "na" to simulate rejection. + let na_bytes = b"na\n"; + header_bytes.put_u8(na_bytes.len() as u8); + header_bytes.put_slice(na_bytes); + + match dialer_state.register_response(header_bytes.freeze().to_vec()) { + Ok(HandshakeResult::Rejected) => {} + event => panic!("expected Rejected, got: {event:?}"), + } + + // Now propose the next fallback. + let fallback_message = dialer_state + .propose_next_fallback() + .expect("no error") + .expect("should have a fallback"); + + let mut expected = BytesMut::with_capacity(32); + expected.put_u8(MSG_MULTISTREAM_1_0.len() as u8); + let _ = Message::Header(HeaderLine::V1).encode(&mut expected).unwrap(); + let proto = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name"); + expected.put_u8((proto.as_ref().len() + 1) as u8); + let _ = Message::Protocol(proto).encode(&mut expected).unwrap(); + + assert_eq!(fallback_message, expected.freeze().to_vec()); + + // No more fallbacks. + assert!(dialer_state.propose_next_fallback().unwrap().is_none()); + } + #[test] fn register_response_header_only() { let mut bytes = BytesMut::with_capacity(32); @@ -875,9 +937,10 @@ mod tests { #[test] fn negotiate_main_protocol() { - let message = webrtc_encode_multistream_message(vec![Message::Protocol( - Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), - )]) + let message = webrtc_encode_multistream_message( + Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()), + true, + ) .unwrap() .freeze(); @@ -897,9 +960,10 @@ mod tests { #[test] fn negotiate_fallback_protocol() { - let message = webrtc_encode_multistream_message(vec![Message::Protocol( - Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), - )]) + let message = webrtc_encode_multistream_message( + Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()), + true, + ) .unwrap() .freeze(); diff --git a/src/multistream_select/length_delimited.rs b/src/multistream_select/length_delimited.rs index 7052d629..9e8693a9 100644 --- a/src/multistream_select/length_delimited.rs +++ b/src/multistream_select/length_delimited.rs @@ -28,7 +28,7 @@ use std::{ }; const MAX_LEN_BYTES: u16 = 2; -const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1; +pub(super) const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1; const DEFAULT_BUFFER_SIZE: usize = 64; const LOG_TARGET: &str = "litep2p::multistream-select"; diff --git a/src/multistream_select/listener_select.rs b/src/multistream_select/listener_select.rs index 6faa2fe0..9b00c09a 100644 --- a/src/multistream_select/listener_select.rs +++ b/src/multistream_select/listener_select.rs @@ -25,10 +25,9 @@ use crate::{ codec::unsigned_varint::UnsignedVarint, error::{self, Error}, multistream_select::{ - drain_trailing_protocols, protocol::{ webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, - ProtocolError, PROTO_MULTISTREAM_1_0, + ProtocolError, }, Negotiated, NegotiationError, }, @@ -333,52 +332,121 @@ pub enum ListenerSelectResult { protocol: ProtocolName, /// `multistream-select` message. - message: BytesMut, + message: Bytes, }, /// Requested protocol is not available. Rejected { /// `multistream-select` message. - message: BytesMut, + message: Bytes, }, + + /// The multistream-select header was received but no protocol was proposed yet. + /// The caller should send the `message` (header echo) and wait for the next payload. + PendingProtocol { + /// `multistream-select` message (header echo). + message: Bytes, + }, +} + +/// Decode a single varint-length-prefixed multistream-select message from `data`, +/// advancing past the consumed bytes. +fn decode_multistream_message(data: &mut Bytes) -> Result { + let (len, tail) = unsigned_varint::decode::usize(data).map_err(|error| { + tracing::debug!( + target: LOG_TARGET, + ?error, + message = ?data, + "Failed to decode length-prefix in multistream message", + ); + error::NegotiationError::ParseError(error::ParseError::InvalidData) + })?; + + if len > tail.len() { + tracing::debug!( + target: LOG_TARGET, + length_prefix = len, + actual_length = tail.len(), + "Truncated multistream message", + ); + return Err(error::NegotiationError::ParseError( + error::ParseError::InvalidData, + )); + } + + let len_size = data.len() - tail.len(); + let payload = data.slice(len_size..len_size + len); + *data = data.slice(len_size + len..); + + Message::decode(payload).map_err(|error| { + tracing::debug!(target: LOG_TARGET, ?error, "Failed to decode multistream message"); + error::NegotiationError::ParseError(error::ParseError::InvalidData) + }) } /// Negotiate protocols for listener. /// -/// Parse protocols offered by the remote peer and check if any of the offered protocols match -/// locally available protocols. If a match is found, return an encoded multistream-select -/// response and the negotiated protocol. If parsing fails or no match is found, return an error. +/// Parse the protocol offered by the remote peer and check if it matches any locally available +/// protocol. The `header_received` parameter indicates whether the multistream-select header +/// has already been exchanged in a previous round. pub fn webrtc_listener_negotiate( supported_protocols: Vec, mut payload: Bytes, + header_received: bool, ) -> crate::Result { - let protocols = drain_trailing_protocols(payload)?; - let mut protocol_iter = protocols.into_iter(); + // Save for zero-copy header echo (Bytes::clone is O(1)). + let raw_payload = payload.clone(); + + let first_msg = decode_multistream_message(&mut payload)?; - // skip the multistream-select header because it's not part of user protocols but verify it's - // present - if protocol_iter.next() != Some(PROTO_MULTISTREAM_1_0) { + let (protocol, header_in_this_payload) = match first_msg { + Message::Header(HeaderLine::V1) => { + if payload.is_empty() { + // Header only — echo the exact received bytes back (zero alloc). + return Ok(ListenerSelectResult::PendingProtocol { + message: raw_payload, + }); + } + // Header + protocol in same payload. + match decode_multistream_message(&mut payload)? { + Message::Protocol(protocol) => (protocol, true), + _ => + return Err(Error::NegotiationError( + error::NegotiationError::ParseError(error::ParseError::InvalidData), + )), + } + } + // Protocol without header is only valid if the header was already exchanged. + Message::Protocol(protocol) if header_received => (protocol, false), + _ => + return Err(Error::NegotiationError( + error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), + )), + }; + + // Reject messages with unexpected trailing data. + if !payload.is_empty() { return Err(Error::NegotiationError( - error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), + error::NegotiationError::ParseError(error::ParseError::InvalidData), )); } - for protocol in protocol_iter { - tracing::trace!( - target: LOG_TARGET, - protocol = ?std::str::from_utf8(protocol.as_ref()), - "listener: checking protocol", - ); + tracing::trace!( + target: LOG_TARGET, + protocol = ?std::str::from_utf8(protocol.as_ref()), + "listener: checking protocol", + ); - for supported in supported_protocols.iter() { - if protocol.as_ref() == supported.as_bytes() { - return Ok(ListenerSelectResult::Accepted { - protocol: supported.clone(), - message: webrtc_encode_multistream_message(std::iter::once( - Message::Protocol(protocol), - ))?, - }); - } + for supported in supported_protocols.iter() { + if protocol.as_ref() == supported.as_bytes() { + return Ok(ListenerSelectResult::Accepted { + protocol: supported.clone(), + message: webrtc_encode_multistream_message( + Message::Protocol(protocol), + header_in_this_payload, + )? + .freeze(), + }); } } @@ -388,7 +456,8 @@ pub fn webrtc_listener_negotiate( ); Ok(ListenerSelectResult::Rejected { - message: webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))?, + message: webrtc_encode_multistream_message(Message::NotAvailable, header_in_this_payload)? + .freeze(), }) } @@ -407,17 +476,18 @@ mod tests { ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; - let message = webrtc_encode_multistream_message(vec![ + let message = webrtc_encode_multistream_message( Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()), - Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()), - ]) + true, + ) .unwrap() .freeze(); - match webrtc_listener_negotiate(local_protocols, message) { + match webrtc_listener_negotiate(local_protocols, message, false) { Err(error) => panic!("error received: {error:?}"), Ok(ListenerSelectResult::Rejected { .. }) => panic!("message rejected"), - Ok(ListenerSelectResult::Accepted { protocol, message }) => { + Ok(ListenerSelectResult::PendingProtocol { .. }) => panic!("unexpected pending"), + Ok(ListenerSelectResult::Accepted { protocol, .. }) => { assert_eq!(protocol, ProtocolName::from("/13371338/proto/1")); } } @@ -432,32 +502,19 @@ mod tests { ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; - // The invalid message is really two multistream-select messages inside one `WebRtcMessage`: - // 1. the multistream-select header - // 2. an "ls response" message (that does not contain another header) - // - // This is invalid for two reasons: - // 1. It is malformed. Either the header is followed by one or more `Message::Protocol` - // instances or the header is part of the "ls response". - // 2. This sequence of messages is not spec compliant. A listener receives one of the - // following on an inbound substream: - // - a multistream-select header followed by a `Message::Protocol` instance - // - a multistream-select header followed by an "ls" message (<\n>) - // - // `webrtc_listener_negotiate()` should reject this invalid message. The error can either be - // `InvalidData` because the message is malformed or `StateMismatch` because the message is - // not expected at this point in the protocol. - let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocols(vec![ - Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), - Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), - ]))) + let message = webrtc_encode_multistream_message( + Message::Protocols(vec![ + Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), + Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), + ]), + true, + ) .unwrap() .freeze(); - match webrtc_listener_negotiate(local_protocols, message) { + match webrtc_listener_negotiate(local_protocols, message, false) { Err(error) => assert!(std::matches!( error, - // something has gone off the rails here... Error::NegotiationError(error::NegotiationError::ParseError( error::ParseError::InvalidData )), @@ -476,18 +533,15 @@ mod tests { ProtocolName::from("/13371338/proto/4"), ]; - // send only header line + // Send only header line with varint length prefix. let mut bytes = BytesMut::with_capacity(32); - let message = Message::Header(HeaderLine::V1); - message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); + Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); + let payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); - match webrtc_listener_negotiate(local_protocols, bytes.freeze()) { - Err(error) => assert!(std::matches!( - error, - Error::NegotiationError(error::NegotiationError::ParseError( - error::ParseError::InvalidData - )), - )), + match webrtc_listener_negotiate(local_protocols, payload.clone(), false) { + Ok(ListenerSelectResult::PendingProtocol { message }) => { + assert_eq!(message, payload); + } event => panic!("invalid event: {event:?}"), } } @@ -502,19 +556,14 @@ mod tests { ProtocolName::from("/13371338/proto/4"), ]; - // header line missing - let mut bytes = BytesMut::with_capacity(256); - vec![&b"/13371338/proto/1"[..], &b"/sup/proto/1"[..]] - .into_iter() - .for_each(|proto| { - bytes.put_u8((proto.len() + 1) as u8); - - Message::Protocol(Protocol::try_from(proto).unwrap()) - .encode(&mut bytes) - .unwrap(); - }); + // Single protocol, no header. + let mut bytes = BytesMut::with_capacity(64); + Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()) + .encode(&mut bytes) + .unwrap(); + let payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); - match webrtc_listener_negotiate(local_protocols, bytes.freeze()) { + match webrtc_listener_negotiate(local_protocols, payload, false) { Err(error) => assert!(std::matches!( error, Error::NegotiationError(error::NegotiationError::MultistreamSelectError( @@ -527,29 +576,118 @@ mod tests { #[test] fn protocol_not_supported() { - let mut local_protocols = vec![ + let local_protocols = vec![ ProtocolName::from("/13371338/proto/1"), ProtocolName::from("/sup/proto/1"), ProtocolName::from("/13371338/proto/2"), ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; - let message = webrtc_encode_multistream_message(vec![Message::Protocol( - Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap(), - )]) + let message = webrtc_encode_multistream_message( + Message::Protocol(Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap()), + true, + ) .unwrap() .freeze(); - match webrtc_listener_negotiate(local_protocols, message) { + match webrtc_listener_negotiate(local_protocols, message, false) { Err(error) => panic!("error received: {error:?}"), Ok(ListenerSelectResult::Rejected { message }) => { assert_eq!( message, - webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable)) + webrtc_encode_multistream_message(Message::NotAvailable, true) .unwrap() + .freeze() ); } - Ok(ListenerSelectResult::Accepted { protocol, message }) => panic!("message accepted"), + Ok(ListenerSelectResult::Accepted { .. }) => panic!("message accepted"), + Ok(ListenerSelectResult::PendingProtocol { .. }) => panic!("unexpected pending"), + } + } + + #[test] + fn protocols_not_supported() { + let local_protocols = vec![ProtocolName::from("/13371338/proto/1")]; + + // Round 1: send header only → PendingProtocol (header echo). + let mut bytes = BytesMut::with_capacity(32); + Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); + let header_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); + + match webrtc_listener_negotiate(local_protocols.clone(), header_payload.clone(), false) { + Ok(ListenerSelectResult::PendingProtocol { message }) => { + assert_eq!(message, header_payload); + } + event => panic!("expected PendingProtocol, got {event:?}"), + } + + // Round 2: send first protocol (not supported) → Rejected (na, no header). + let mut bytes = BytesMut::with_capacity(64); + Message::Protocol(Protocol::try_from(&b"/unsupported/proto/1"[..]).unwrap()) + .encode(&mut bytes) + .unwrap(); + let proto1_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); + + match webrtc_listener_negotiate(local_protocols.clone(), proto1_payload, true) { + Ok(ListenerSelectResult::Rejected { message }) => { + assert_eq!( + message, + webrtc_encode_multistream_message(Message::NotAvailable, false) + .unwrap() + .freeze() + ); + } + event => panic!("expected Rejected, got {event:?}"), + } + + // Round 3: send second protocol (also not supported) → Rejected (na, no header). + let mut bytes = BytesMut::with_capacity(64); + Message::Protocol(Protocol::try_from(&b"/unsupported/proto/2"[..]).unwrap()) + .encode(&mut bytes) + .unwrap(); + let proto2_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); + + match webrtc_listener_negotiate(local_protocols, proto2_payload, true) { + Ok(ListenerSelectResult::Rejected { message }) => { + assert_eq!( + message, + webrtc_encode_multistream_message(Message::NotAvailable, false) + .unwrap() + .freeze() + ); + } + event => panic!("expected Rejected, got {event:?}"), + } + } + + #[test] + fn header_only_then_protocol() { + let local_protocols = vec![ProtocolName::from("/13371338/proto/1")]; + + // Call 1: header only → PendingProtocol. + let mut bytes = BytesMut::with_capacity(32); + Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); + let header_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); + + match webrtc_listener_negotiate(local_protocols.clone(), header_payload.clone(), false) { + Ok(ListenerSelectResult::PendingProtocol { message }) => { + assert_eq!(message, header_payload); + } + event => panic!("expected PendingProtocol, got {event:?}"), + } + + // Call 2: protocol only (header_received=true) → Accepted. + let mut bytes = BytesMut::with_capacity(64); + Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()) + .encode(&mut bytes) + .unwrap(); + let proto_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); + + match webrtc_listener_negotiate(local_protocols, proto_payload, true) { + Ok(ListenerSelectResult::Accepted { protocol, .. }) => { + assert_eq!(protocol, ProtocolName::from("/13371338/proto/1")); + } + event => panic!("expected Accepted, got {event:?}"), } } } diff --git a/src/multistream_select/mod.rs b/src/multistream_select/mod.rs index f195b1f3..762ba302 100644 --- a/src/multistream_select/mod.rs +++ b/src/multistream_select/mod.rs @@ -75,7 +75,7 @@ mod listener_select; mod negotiated; mod protocol; -use crate::error::{self, ParseError}; +use crate::error; pub use crate::multistream_select::{ dialer_select::{dialer_select_proto, DialerSelectFuture, HandshakeResult, WebRtcDialerState}, listener_select::{ @@ -86,10 +86,6 @@ pub use crate::multistream_select::{ protocol::{HeaderLine, Message, Protocol, ProtocolError, PROTO_MULTISTREAM_1_0}, }; -use bytes::Bytes; - -const LOG_TARGET: &str = "litep2p::multistream-select"; - /// Supported multistream-select versions. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Version { @@ -137,63 +133,3 @@ impl Default for Version { Version::V1 } } - -// This function is only used in the WebRTC transport. It expects one or more multistream-select -// messages in `remaining` and returns a list of protocols that were decoded from them. -fn drain_trailing_protocols( - mut remaining: Bytes, -) -> Result, error::NegotiationError> { - let mut protocols = vec![]; - - loop { - if remaining.is_empty() { - break; - } - - let (len, tail) = unsigned_varint::decode::usize(&remaining).map_err(|error| { - tracing::debug!( - target: LOG_TARGET, - ?error, - message = ?remaining, - "Failed to decode length-prefix in multistream message"); - error::NegotiationError::ParseError(ParseError::InvalidData) - })?; - - if len > tail.len() { - tracing::debug!( - target: LOG_TARGET, - message = ?tail, - length_prefix = len, - actual_length = tail.len(), - "Truncated multistream message", - ); - - return Err(error::NegotiationError::ParseError(ParseError::InvalidData)); - } - - let len_size = remaining.len() - tail.len(); - let payload = remaining.slice(len_size..len_size + len); - let res = Message::decode(payload); - - match res { - Ok(Message::Header(HeaderLine::V1)) => protocols.push(PROTO_MULTISTREAM_1_0), - Ok(Message::Protocol(protocol)) => protocols.push(protocol), - Ok(Message::Protocols(_)) => - return Err(error::NegotiationError::ParseError(ParseError::InvalidData)), - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?error, - message = ?tail[..len], - "Failed to decode multistream message", - ); - return Err(error::NegotiationError::ParseError(ParseError::InvalidData)); - } - _ => return Err(error::NegotiationError::StateMismatch), - } - - remaining = remaining.slice(len_size + len..); - } - - Ok(protocols) -} diff --git a/src/multistream_select/protocol.rs b/src/multistream_select/protocol.rs index 71775df9..27998fa4 100644 --- a/src/multistream_select/protocol.rs +++ b/src/multistream_select/protocol.rs @@ -26,7 +26,6 @@ //! `MessageReader`. use crate::{ - codec::unsigned_varint::UnsignedVarint, error::Error as Litep2pError, multistream_select::{ length_delimited::{LengthDelimited, LengthDelimitedReader}, @@ -132,6 +131,25 @@ pub enum Message { } impl Message { + /// Returns the exact encoded byte length of this message, without allocating. + pub fn encoded_len(&self) -> usize { + match self { + Message::Header(HeaderLine::V1) => MSG_MULTISTREAM_1_0.len(), + Message::Protocol(p) => p.0.as_ref().len() + 1, + Message::ListProtocols => MSG_LS.len(), + Message::NotAvailable => MSG_PROTOCOL_NA.len(), + Message::Protocols(ps) => { + let mut len = 1usize; // trailing \n + let mut buf = unsigned_varint::encode::usize_buffer(); + for p in ps { + let proto_len = p.0.as_ref().len() + 1; + len += unsigned_varint::encode::usize(proto_len, &mut buf).len() + proto_len; + } + len + } + } + } + /// Encodes a `Message` into its byte representation. pub fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> { match self { @@ -228,30 +246,43 @@ impl Message { } } -/// Create `multistream-select` message from an iterator of `Message`s. -/// -/// # Note +/// Encode a single `multistream-select` message, optionally preceded by the protocol header. /// -/// This implementation may not be compliant with the multistream-select protocol spec. -/// The only purpose of this was to get the `multistream-select` protocol working with smoldot. +/// When `prepend_header` is `true` the `/multistream/1.0.0` header line is written before the +/// message. Everything is written into a single `BytesMut` allocation. pub fn webrtc_encode_multistream_message( - messages: impl IntoIterator, + message: Message, + prepend_header: bool, ) -> crate::Result { - // encode `/multistream-select/1.0.0` header - let mut bytes = BytesMut::with_capacity(32); - let message = Message::Header(HeaderLine::V1); - message.encode(&mut bytes).map_err(|_| Litep2pError::InvalidData)?; - let mut header = UnsignedVarint::encode(bytes)?; - - // encode each message - for message in messages { - let mut proto_bytes = BytesMut::with_capacity(256); - message.encode(&mut proto_bytes).map_err(|_| Litep2pError::InvalidData)?; - let mut proto_bytes = UnsignedVarint::encode(proto_bytes)?; - header.append(&mut proto_bytes); + let msg_len = message.encoded_len(); + let header_len = MSG_MULTISTREAM_1_0.len(); + let mut varint_buf = unsigned_varint::encode::usize_buffer(); + + let capacity = { + let msg_varint_len = unsigned_varint::encode::usize(msg_len, &mut varint_buf).len(); + let total = if prepend_header { + let header_varint_len = + unsigned_varint::encode::usize(header_len, &mut varint_buf).len(); + header_varint_len + header_len + msg_varint_len + msg_len + } else { + msg_varint_len + msg_len + }; + total.min(super::length_delimited::MAX_FRAME_SIZE as usize) + }; + + let mut output = BytesMut::with_capacity(capacity); + + if prepend_header { + output.extend_from_slice(unsigned_varint::encode::usize(header_len, &mut varint_buf)); + Message::Header(HeaderLine::V1) + .encode(&mut output) + .map_err(|_| Litep2pError::InvalidData)?; } - Ok(BytesMut::from(&header[..])) + output.extend_from_slice(unsigned_varint::encode::usize(msg_len, &mut varint_buf)); + message.encode(&mut output).map_err(|_| Litep2pError::InvalidData)?; + + Ok(output) } /// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s. diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index f0152016..ffef8a7a 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -21,7 +21,8 @@ use crate::{ error::{Error, ParseError, SubstreamError}, multistream_select::{ - webrtc_listener_negotiate, HandshakeResult, ListenerSelectResult, WebRtcDialerState, + webrtc_listener_negotiate, HandshakeResult, ListenerSelectResult, NegotiationError, + WebRtcDialerState, }, protocol::{Direction, Permit, ProtocolCommand, ProtocolSet, SubstreamKeepAlive}, substream::Substream, @@ -148,7 +149,10 @@ enum ChannelState { Closing, /// Inbound channel is opening. - InboundOpening, + InboundOpening { + /// Whether the multistream-select header has already been received/sent. + header_received: bool, + }, /// Outbound channel is opening. OutboundOpening { @@ -263,7 +267,12 @@ impl WebRtcConnection { "inbound channel opened, wait for `multistream-select` message", ); - self.channels.insert(channel_id, ChannelState::InboundOpening); + self.channels.insert( + channel_id, + ChannelState::InboundOpening { + header_received: false, + }, + ); return Ok(()); }; @@ -308,16 +317,23 @@ impl WebRtcConnection { /// Handle data received to an opening inbound channel. /// /// The first message received over an inbound channel is the `multistream-select` handshake. - /// This handshake contains the protocol (and potentially fallbacks for that protocol) that - /// remote peer wants to use for this channel. Parse the handshake and check if any of the - /// proposed protocols are supported by the local node. If not, send rejection to remote peer - /// and close the channel. If the local node supports one of the protocols, send confirmation - /// for the protocol to remote peer and report an opened substream to the selected protocol. + /// This handshake contains the protocol the remote peer wants to use for this channel. Parse + /// the handshake and check whether the proposed protocol is supported by the local node. + /// If not, send rejection to remote peer and but keep the channel open so that the peer can + /// propose a fallback. If the local node support the protocol, send confirmation for the + /// protocol to remote peer and report an opened substream to the selected protocol. + /// + /// Returns `Ok(Some(...))` if the protocol was accepted and the substream opened, + /// `Ok(None)` if the proposed protocol was rejected (the `na` response has been sent + /// and the channel should remain in [`ChannelState::InboundOpening`] so the dialer can + /// propose another protocol per back-and-forth multistream-select negotiation), + /// or `Err(...)` on a fatal error (channel should be closed). async fn on_inbound_opening_channel_data( &mut self, channel_id: ChannelId, data: Vec, - ) -> crate::Result<(SubstreamId, SubstreamHandle, Option)> { + header_received: bool, + ) -> crate::Result)>> { tracing::trace!( target: LOG_TARGET, peer = ?self.peer, @@ -329,9 +345,10 @@ impl WebRtcConnection { let protocols = self.protocol_set.protocols_with_keep_alives(); let protocol_names = protocols.keys().cloned().collect(); let (response, negotiated) = - match webrtc_listener_negotiate(protocol_names, payload.into())? { + match webrtc_listener_negotiate(protocol_names, payload.into(), header_received)? { ListenerSelectResult::Accepted { protocol, message } => (message, Some(protocol)), - ListenerSelectResult::Rejected { message } => (message, None), + ListenerSelectResult::Rejected { message } + | ListenerSelectResult::PendingProtocol { message } => (message, None), }; self.rtc @@ -343,7 +360,16 @@ impl WebRtcConnection { ) .map_err(Error::WebRtc)?; - let protocol = negotiated.ok_or(Error::SubstreamDoesntExist)?; + let Some(protocol) = negotiated else { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "inbound protocol rejected, keeping channel open for back-and-forth negotiation", + ); + return Ok(None); + }; + let substream_id = self.protocol_set.next_substream_id(); let codec = self.protocol_set.protocol_codec(&protocol); let opening_permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; @@ -371,7 +397,7 @@ impl WebRtcConnection { opening_permit, ) .await - .map(|_| (substream_id, handle, lifetime_permit)) + .map(|_| Some((substream_id, handle, lifetime_permit))) .map_err(Into::into) } @@ -411,23 +437,75 @@ impl WebRtcConnection { ParseError::InvalidData.into(), ))?; - let HandshakeResult::Succeeded(protocol) = dialer_state.register_response(message)? else { - tracing::trace!( - target: LOG_TARGET, - peer = ?self.peer, - ?channel_id, - "multistream-select handshake not ready", - ); + let protocol = match dialer_state.register_response(message)? { + HandshakeResult::Succeeded(protocol) => protocol, + HandshakeResult::NotReady => { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "multistream-select handshake not ready", + ); - self.channels.insert( - channel_id, - ChannelState::OutboundOpening { - context, - dialer_state, - }, - ); + self.channels.insert( + channel_id, + ChannelState::OutboundOpening { + context, + dialer_state, + }, + ); - return Ok(None); + return Ok(None); + } + HandshakeResult::Rejected => match dialer_state.propose_next_fallback() { + Ok(Some(message)) => { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "protocol rejected, trying next fallback", + ); + + let message = WebRtcMessage::encode(message, None); + self.rtc + .channel(channel_id) + .ok_or(Error::ChannelDoesntExist) + .map_err(|_| { + SubstreamError::NegotiationError(NegotiationError::Failed.into()) + })? + .write(true, message.as_ref()) + .map_err(|_| { + SubstreamError::NegotiationError(NegotiationError::Failed.into()) + })?; + + self.channels.insert( + channel_id, + ChannelState::OutboundOpening { + context, + dialer_state, + }, + ); + + return Ok(None); + } + Ok(None) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "all protocols rejected by remote peer", + ); + + return Err(SubstreamError::NegotiationError( + NegotiationError::Failed.into(), + )); + } + Err(_) => { + return Err(SubstreamError::NegotiationError( + NegotiationError::Failed.into(), + )); + } + }, }; let ChannelContext { @@ -468,13 +546,13 @@ impl WebRtcConnection { ) -> crate::Result<()> { let message = WebRtcMessage::decode(&data)?; - tracing::trace!( + tracing::debug!( target: LOG_TARGET, peer = ?self.peer, ?channel_id, flag = ?message.flag, data_len = message.payload.as_ref().map_or(0usize, |payload| payload.len()), - "handle inbound message", + "handle inbound message on open channel", ); self.handles @@ -495,6 +573,15 @@ impl WebRtcConnection { /// Handle data received from a channel. async fn on_inbound_data(&mut self, channel_id: ChannelId, data: Vec) -> crate::Result<()> { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + data_len = data.len(), + channel_state = ?self.channels.get(&channel_id), + "received channel data", + ); + let Some(state) = self.channels.remove(&channel_id) else { tracing::warn!( target: LOG_TARGET, @@ -507,9 +594,10 @@ impl WebRtcConnection { }; match state { - ChannelState::InboundOpening => { - match self.on_inbound_opening_channel_data(channel_id, data).await { - Ok((substream_id, handle, lifetime_permit)) => { + ChannelState::InboundOpening { header_received } => { + match self.on_inbound_opening_channel_data(channel_id, data, header_received).await + { + Ok(Some((substream_id, handle, lifetime_permit))) => { self.handles.insert(channel_id, handle); self.channels.insert( channel_id, @@ -520,6 +608,15 @@ impl WebRtcConnection { }, ); } + Ok(None) => { + // Header has been exchanged after any successful round. + self.channels.insert( + channel_id, + ChannelState::InboundOpening { + header_received: true, + }, + ); + } Err(error) => { tracing::debug!( target: LOG_TARGET, @@ -700,7 +797,19 @@ impl WebRtcConnection { pub async fn run_event_loop(mut self) { loop { // poll output until we get a timeout - let timeout = match self.rtc.poll_output().unwrap() { + let output = match self.rtc.poll_output() { + Ok(output) => output, + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?error, + "poll_output failed, closing connection", + ); + return self.on_connection_closed().await; + } + }; + let timeout = match output { Output::Timeout(v) => v, Output::Transmit(v) => { tracing::trace!( @@ -849,6 +958,20 @@ impl WebRtcConnection { keep_alive, connection_id: _, }) => { + // Check if the connection is still healthy before opening new substreams. + // This prevents panics when trying to open channels on a shutting-down + // SCTP association. + if !self.rtc.is_alive() || !self.rtc.is_connected() { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?protocol, + is_alive = self.rtc.is_alive(), + is_connected = self.rtc.is_connected(), + "rejecting substream open: connection not healthy", + ); + continue; + } self.on_open_substream( protocol, fallback_names,