diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index 86c22647c..21b4fc067 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -24,7 +24,6 @@ use crate::{ codec::unsigned_varint::UnsignedVarint, error::{self, Error, ParseError, SubstreamError}, multistream_select::{ - drain_trailing_protocols, protocol::{ webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError, PROTO_MULTISTREAM_1_0, @@ -300,6 +299,12 @@ pub enum HandshakeResult { /// The returned tuple contains the negotiated protocol and response /// that must be sent to remote peer. Succeeded(ProtocolName), + + /// The proposed protocol was rejected by the remote peer. + /// + /// The caller should check if there are remaining fallback protocols to try + /// via [`WebRtcDialerState::propose_next_fallback()`]. + Rejected, } /// Handshake state. @@ -334,12 +339,9 @@ impl WebRtcDialerState { protocol: ProtocolName, fallback_names: Vec, ) -> crate::Result<(Self, Vec)> { - let message = webrtc_encode_multistream_message( - std::iter::once(protocol.clone()) - .chain(fallback_names.clone()) - .filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok()) - .map(Message::Protocol), - )? + let message = webrtc_encode_multistream_message(Message::Protocol( + Protocol::try_from(protocol.as_ref()).map_err(|_| Error::InvalidData)?, + ))? .freeze() .to_vec(); @@ -353,72 +355,83 @@ impl WebRtcDialerState { )) } + /// Propose the next fallback protocol to the remote peer. + /// + /// Returns `None` if there are no more fallback protocols to try. + /// Returns `Some(message)` with the encoded message to send, containing the protocol name. + pub fn propose_next_fallback(&mut self) -> crate::Result>> { + if self.fallback_names.is_empty() { + return Ok(None); + } + + let next = self.fallback_names.remove(0); + self.protocol = next; + self.state = HandshakeState::WaitingResponse; + + let message = webrtc_encode_multistream_message(Message::Protocol( + Protocol::try_from(self.protocol.as_ref()).map_err(|_| Error::InvalidData)?, + ))? + .freeze() + .to_vec(); + + Ok(Some(message)) + } + /// Register response to [`WebRtcDialerState`]. pub fn register_response( &mut self, payload: Vec, ) -> Result { - // All multistream-select messages are length-prefixed. Since this code path is not using - // multistream_select::protocol::MessageIO, we need to decode and remove the length here. - let remaining: &[u8] = &payload; - let (len, tail) = unsigned_varint::decode::usize(remaining).map_err(|error| { - tracing::debug!( + let bytes = Bytes::from(payload); + let mut remaining = bytes.clone(); + + while !remaining.is_empty() { + let (len, tail) = unsigned_varint::decode::usize(&remaining).map_err(|error| { + tracing::debug!( target: LOG_TARGET, ?error, - message = ?payload, - "Failed to decode length-prefix in multistream message"); - error::NegotiationError::ParseError(ParseError::InvalidData) - })?; + message = ?remaining, + "Failed to decode length-prefix in multistream message", + ); + error::NegotiationError::ParseError(ParseError::InvalidData) + })?; - let len_size = remaining.len() - tail.len(); - let bytes = Bytes::from(payload); - let payload = bytes.slice(len_size..len_size + len); - let remaining = bytes.slice(len_size + len..); - let message = Message::decode(payload); - - tracing::trace!( - target: LOG_TARGET, - ?message, - "Decoded message while registering response", - ); - - let mut protocols = match message { - Ok(Message::Header(HeaderLine::V1)) => { - vec![PROTO_MULTISTREAM_1_0] + let len_size = remaining.len() - tail.len(); + + if len > tail.len() { + tracing::debug!( + target: LOG_TARGET, + message = ?tail, + length_prefix = len, + actual_length = tail.len(), + "Truncated multistream message", + ); + return Err(error::NegotiationError::ParseError(ParseError::InvalidData)); } - Ok(Message::Protocol(protocol)) => vec![protocol], - Ok(Message::Protocols(protocols)) => protocols, - Ok(Message::NotAvailable) => - return match &self.state { - HandshakeState::WaitingProtocol => Err( - error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), - ), - _ => Err(error::NegotiationError::StateMismatch), - }, - Ok(Message::ListProtocols) => return Err(error::NegotiationError::StateMismatch), - Err(_) => return Err(error::NegotiationError::ParseError(ParseError::InvalidData)), - }; - match drain_trailing_protocols(remaining) { - Ok(protos) => protocols.extend(protos), - Err(error) => return Err(error), - } + let payload = remaining.slice(len_size..len_size + len); + remaining = remaining.slice(len_size + len..); + let message = Message::decode(payload); - let mut protocol_iter = protocols.into_iter(); - loop { - match (&self.state, protocol_iter.next()) { - (HandshakeState::WaitingResponse, None) => - return Err(crate::error::NegotiationError::StateMismatch), - (HandshakeState::WaitingResponse, Some(protocol)) => { - if protocol == PROTO_MULTISTREAM_1_0 { - self.state = HandshakeState::WaitingProtocol; - } else { - return Err(crate::error::NegotiationError::MultistreamSelectError( - NegotiationError::Failed, - )); - } + tracing::trace!( + target: LOG_TARGET, + ?message, + "Decoded message while registering response", + ); + + match (&self.state, message) { + (HandshakeState::WaitingResponse, Ok(Message::Header(HeaderLine::V1))) => { + self.state = HandshakeState::WaitingProtocol; } - (HandshakeState::WaitingProtocol, Some(protocol)) => { + (HandshakeState::WaitingResponse, Ok(Message::Protocol(_))) => { + return Err(crate::error::NegotiationError::MultistreamSelectError( + NegotiationError::Failed, + )); + } + (_, Ok(Message::NotAvailable)) => { + return Ok(HandshakeResult::Rejected); + } + (HandshakeState::WaitingProtocol, Ok(Message::Protocol(protocol))) => { if protocol == PROTO_MULTISTREAM_1_0 { return Err(crate::error::NegotiationError::StateMismatch); } @@ -437,11 +450,16 @@ impl WebRtcDialerState { NegotiationError::Failed, )); } - (HandshakeState::WaitingProtocol, None) => { - return Ok(HandshakeResult::NotReady); + _ => { + return Err(crate::error::NegotiationError::StateMismatch); } } } + + match &self.state { + HandshakeState::WaitingProtocol => Ok(HandshakeResult::NotReady), + HandshakeState::WaitingResponse => Err(crate::error::NegotiationError::StateMismatch), + } } } @@ -816,6 +834,7 @@ mod tests { ) .unwrap(); + // Initial message should only contain the main protocol, not the fallback. let mut bytes = BytesMut::with_capacity(32); bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8); let _ = Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); @@ -824,15 +843,52 @@ mod tests { bytes.put_u8((proto1.as_ref().len() + 1) as u8); // + 1 for \n let _ = Message::Protocol(proto1).encode(&mut bytes).unwrap(); - let proto2 = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name"); - bytes.put_u8((proto2.as_ref().len() + 1) as u8); // + 1 for \n - let _ = Message::Protocol(proto2).encode(&mut bytes).unwrap(); - let expected_message = bytes.freeze().to_vec(); assert_eq!(message, expected_message); } + #[test] + fn propose_next_fallback() { + let (mut dialer_state, _message) = WebRtcDialerState::propose( + ProtocolName::from("/13371338/proto/1"), + vec![ProtocolName::from("/sup/proto/1")], + ) + .unwrap(); + + // Simulate receiving header-only response, transitioning to WaitingProtocol. + let mut header_bytes = BytesMut::with_capacity(32); + header_bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8); + let _ = Message::Header(HeaderLine::V1).encode(&mut header_bytes).unwrap(); + // Append "na" to simulate rejection. + let na_bytes = b"na\n"; + header_bytes.put_u8(na_bytes.len() as u8); + header_bytes.put_slice(na_bytes); + + match dialer_state.register_response(header_bytes.freeze().to_vec()) { + Ok(HandshakeResult::Rejected) => {} + event => panic!("expected Rejected, got: {event:?}"), + } + + // Now propose the next fallback. + let fallback_message = dialer_state + .propose_next_fallback() + .expect("no error") + .expect("should have a fallback"); + + let mut expected = BytesMut::with_capacity(32); + expected.put_u8(MSG_MULTISTREAM_1_0.len() as u8); + let _ = Message::Header(HeaderLine::V1).encode(&mut expected).unwrap(); + let proto = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name"); + expected.put_u8((proto.as_ref().len() + 1) as u8); + let _ = Message::Protocol(proto).encode(&mut expected).unwrap(); + + assert_eq!(fallback_message, expected.freeze().to_vec()); + + // No more fallbacks. + assert!(dialer_state.propose_next_fallback().unwrap().is_none()); + } + #[test] fn register_response_header_only() { let mut bytes = BytesMut::with_capacity(32); @@ -875,9 +931,9 @@ mod tests { #[test] fn negotiate_main_protocol() { - let message = webrtc_encode_multistream_message(vec![Message::Protocol( + let message = webrtc_encode_multistream_message(Message::Protocol( Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), - )]) + )) .unwrap() .freeze(); @@ -897,9 +953,9 @@ mod tests { #[test] fn negotiate_fallback_protocol() { - let message = webrtc_encode_multistream_message(vec![Message::Protocol( + let message = webrtc_encode_multistream_message(Message::Protocol( Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), - )]) + )) .unwrap() .freeze(); diff --git a/src/multistream_select/listener_select.rs b/src/multistream_select/listener_select.rs index 6d73efcb7..3347f4550 100644 --- a/src/multistream_select/listener_select.rs +++ b/src/multistream_select/listener_select.rs @@ -374,9 +374,7 @@ pub fn webrtc_listener_negotiate<'a>( if protocol.as_ref() == supported.as_bytes() { return Ok(ListenerSelectResult::Accepted { protocol: supported.clone(), - message: webrtc_encode_multistream_message(std::iter::once( - Message::Protocol(protocol), - ))?, + message: webrtc_encode_multistream_message(Message::Protocol(protocol))?, }); } } @@ -388,7 +386,7 @@ pub fn webrtc_listener_negotiate<'a>( ); Ok(ListenerSelectResult::Rejected { - message: webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))?, + message: webrtc_encode_multistream_message(Message::NotAvailable)?, }) } @@ -407,10 +405,9 @@ mod tests { ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; - let message = webrtc_encode_multistream_message(vec![ - Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()), - Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()), - ]) + let message = webrtc_encode_multistream_message(Message::Protocol( + Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), + )) .unwrap() .freeze(); @@ -447,10 +444,10 @@ mod tests { // `webrtc_listener_negotiate()` should reject this invalid message. The error can either be // `InvalidData` because the message is malformed or `StateMismatch` because the message is // not expected at this point in the protocol. - let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocols(vec![ + let message = webrtc_encode_multistream_message(Message::Protocols(vec![ Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), - ]))) + ])) .unwrap() .freeze(); @@ -534,9 +531,9 @@ mod tests { ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; - let message = webrtc_encode_multistream_message(vec![Message::Protocol( + let message = webrtc_encode_multistream_message(Message::Protocol( Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap(), - )]) + )) .unwrap() .freeze(); @@ -545,8 +542,7 @@ mod tests { Ok(ListenerSelectResult::Rejected { message }) => { assert_eq!( message, - webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable)) - .unwrap() + webrtc_encode_multistream_message(Message::NotAvailable).unwrap() ); } Ok(ListenerSelectResult::Accepted { protocol, message }) => panic!("message accepted"), diff --git a/src/multistream_select/protocol.rs b/src/multistream_select/protocol.rs index 71775df9a..ebd21247b 100644 --- a/src/multistream_select/protocol.rs +++ b/src/multistream_select/protocol.rs @@ -234,24 +234,21 @@ impl Message { /// /// This implementation may not be compliant with the multistream-select protocol spec. /// The only purpose of this was to get the `multistream-select` protocol working with smoldot. -pub fn webrtc_encode_multistream_message( - messages: impl IntoIterator, -) -> crate::Result { +pub fn webrtc_encode_multistream_message(message: Message) -> crate::Result { // encode `/multistream-select/1.0.0` header let mut bytes = BytesMut::with_capacity(32); - let message = Message::Header(HeaderLine::V1); - message.encode(&mut bytes).map_err(|_| Litep2pError::InvalidData)?; - let mut header = UnsignedVarint::encode(bytes)?; - - // encode each message - for message in messages { - let mut proto_bytes = BytesMut::with_capacity(256); - message.encode(&mut proto_bytes).map_err(|_| Litep2pError::InvalidData)?; - let mut proto_bytes = UnsignedVarint::encode(proto_bytes)?; - header.append(&mut proto_bytes); - } - - Ok(BytesMut::from(&header[..])) + Message::Header(HeaderLine::V1) + .encode(&mut bytes) + .map_err(|_| Litep2pError::InvalidData)?; + let mut output = UnsignedVarint::encode(bytes)?; + + // encode the message + let mut msg_bytes = BytesMut::with_capacity(256); + message.encode(&mut msg_bytes).map_err(|_| Litep2pError::InvalidData)?; + let mut msg_bytes = UnsignedVarint::encode(msg_bytes)?; + output.append(&mut msg_bytes); + + Ok(BytesMut::from(&output[..])) } /// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s. diff --git a/src/schema/webrtc.proto b/src/schema/webrtc.proto index f36e04f94..852f3c6c1 100644 --- a/src/schema/webrtc.proto +++ b/src/schema/webrtc.proto @@ -12,6 +12,10 @@ message Message { // The sender abruptly terminates the sending part of the stream. The // receiver MAY discard any data that it already received on that stream. RESET_STREAM = 2; + // Sending the FIN_ACK flag acknowledges the previous receipt of a message + // with the FIN flag set. Receiving a FIN_ACK flag gives the recipient + // confidence that the remote has received all sent messages. + FIN_ACK = 3; } optional Flag flag = 1; diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index c47e0024d..08e2451eb 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -21,12 +21,14 @@ use crate::{ error::{Error, ParseError, SubstreamError}, multistream_select::{ - webrtc_listener_negotiate, HandshakeResult, ListenerSelectResult, WebRtcDialerState, + webrtc_listener_negotiate, HandshakeResult, ListenerSelectResult, NegotiationError, + WebRtcDialerState, }, protocol::{Direction, Permit, ProtocolCommand, ProtocolSet}, substream::Substream, transport::{ webrtc::{ + schema::webrtc::message::Flag, substream::{Event as SubstreamEvent, Substream as WebRtcSubstream, SubstreamHandle}, util::WebRtcMessage, }, @@ -263,7 +265,7 @@ impl WebRtcConnection { let fallback_names = std::mem::take(&mut context.fallback_names); let (dialer_state, message) = WebRtcDialerState::propose(context.protocol.clone(), fallback_names)?; - let message = WebRtcMessage::encode(message); + let message = WebRtcMessage::encode(message, None); self.rtc .channel(channel_id) @@ -301,16 +303,22 @@ impl WebRtcConnection { /// Handle data received to an opening inbound channel. /// /// The first message received over an inbound channel is the `multistream-select` handshake. - /// This handshake contains the protocol (and potentially fallbacks for that protocol) that - /// remote peer wants to use for this channel. Parse the handshake and check if any of the - /// proposed protocols are supported by the local node. If not, send rejection to remote peer - /// and close the channel. If the local node supports one of the protocols, send confirmation - /// for the protocol to remote peer and report an opened substream to the selected protocol. + /// This handshake contains the protocol the remote peer wants to use for this channel. Parse + /// the handshake and check whether the proposed protocol is supported by the local node. + /// If not, send rejection to remote peer and but keep the channel open so that the peer can + /// propose a fallback. If the local node support the protocol, send confirmation for the + /// protocol to remote peer and report an opened substream to the selected protocol. + /// + /// Returns `Ok(Some(...))` if the protocol was accepted and the substream opened, + /// `Ok(None)` if the proposed protocol was rejected (the `na` response has been sent + /// and the channel should remain in [`ChannelState::InboundOpening`] so the dialer can + /// propose another protocol per back-and-forth multistream-select negotiation), + /// or `Err(...)` on a fatal error (channel should be closed). async fn on_inbound_opening_channel_data( &mut self, channel_id: ChannelId, data: Vec, - ) -> crate::Result<(SubstreamId, SubstreamHandle, Permit)> { + ) -> crate::Result> { tracing::trace!( target: LOG_TARGET, peer = ?self.peer, @@ -330,10 +338,22 @@ impl WebRtcConnection { self.rtc .channel(channel_id) .ok_or(Error::ChannelDoesntExist)? - .write(true, WebRtcMessage::encode(response.to_vec()).as_ref()) + .write( + true, + WebRtcMessage::encode(response.to_vec(), None).as_ref(), + ) .map_err(Error::WebRtc)?; - let protocol = negotiated.ok_or(Error::SubstreamDoesntExist)?; + let Some(protocol) = negotiated else { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "inbound protocol rejected, keeping channel open for back-and-forth negotiation", + ); + return Ok(None); + }; + let substream_id = self.protocol_set.next_substream_id(); let codec = self.protocol_set.protocol_codec(&protocol); let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; @@ -352,7 +372,7 @@ impl WebRtcConnection { self.protocol_set .report_substream_open(self.peer, protocol.clone(), Direction::Inbound, substream) .await - .map(|_| (substream_id, handle, permit)) + .map(|_| Some((substream_id, handle, permit))) .map_err(Into::into) } @@ -392,23 +412,75 @@ impl WebRtcConnection { ParseError::InvalidData.into(), ))?; - let HandshakeResult::Succeeded(protocol) = dialer_state.register_response(message)? else { - tracing::trace!( - target: LOG_TARGET, - peer = ?self.peer, - ?channel_id, - "multistream-select handshake not ready", - ); + let protocol = match dialer_state.register_response(message)? { + HandshakeResult::Succeeded(protocol) => protocol, + HandshakeResult::NotReady => { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "multistream-select handshake not ready", + ); - self.channels.insert( - channel_id, - ChannelState::OutboundOpening { - context, - dialer_state, - }, - ); + self.channels.insert( + channel_id, + ChannelState::OutboundOpening { + context, + dialer_state, + }, + ); - return Ok(None); + return Ok(None); + } + HandshakeResult::Rejected => match dialer_state.propose_next_fallback() { + Ok(Some(message)) => { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "protocol rejected, trying next fallback", + ); + + let message = WebRtcMessage::encode(message, None); + self.rtc + .channel(channel_id) + .ok_or(Error::ChannelDoesntExist) + .map_err(|_| { + SubstreamError::NegotiationError(NegotiationError::Failed.into()) + })? + .write(true, message.as_ref()) + .map_err(|_| { + SubstreamError::NegotiationError(NegotiationError::Failed.into()) + })?; + + self.channels.insert( + channel_id, + ChannelState::OutboundOpening { + context, + dialer_state, + }, + ); + + return Ok(None); + } + Ok(None) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "all protocols rejected by remote peer", + ); + + return Err(SubstreamError::NegotiationError( + NegotiationError::Failed.into(), + )); + } + Err(_) => { + return Err(SubstreamError::NegotiationError( + NegotiationError::Failed.into(), + )); + } + }, }; let ChannelContext { @@ -448,13 +520,13 @@ impl WebRtcConnection { ) -> crate::Result<()> { let message = WebRtcMessage::decode(&data)?; - tracing::trace!( + tracing::debug!( target: LOG_TARGET, peer = ?self.peer, ?channel_id, - flags = message.flags, + flag = ?message.flag, data_len = message.payload.as_ref().map_or(0usize, |payload| payload.len()), - "handle inbound message", + "handle inbound message on open channel", ); self.handles @@ -475,6 +547,15 @@ impl WebRtcConnection { /// Handle data received from a channel. async fn on_inbound_data(&mut self, channel_id: ChannelId, data: Vec) -> crate::Result<()> { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + data_len = data.len(), + channel_state = ?self.channels.get(&channel_id), + "received channel data", + ); + let Some(state) = self.channels.remove(&channel_id) else { tracing::warn!( target: LOG_TARGET, @@ -489,7 +570,7 @@ impl WebRtcConnection { match state { ChannelState::InboundOpening => { match self.on_inbound_opening_channel_data(channel_id, data).await { - Ok((substream_id, handle, permit)) => { + Ok(Some((substream_id, handle, permit))) => { self.handles.insert(channel_id, handle); self.channels.insert( channel_id, @@ -500,6 +581,12 @@ impl WebRtcConnection { }, ); } + Ok(None) => { + // Protocol was rejected but `na` response was sent. Keep the + // channel open in `InboundOpening` so the dialer can propose + // another protocol (back-and-forth multistream-select). + self.channels.insert(channel_id, ChannelState::InboundOpening); + } Err(error) => { tracing::debug!( target: LOG_TARGET, @@ -598,20 +685,26 @@ impl WebRtcConnection { Ok(()) } - /// Handle outbound data. - fn on_outbound_data(&mut self, channel_id: ChannelId, data: Vec) -> crate::Result<()> { + /// Handle outbound data with optional flag. + fn on_outbound_data( + &mut self, + channel_id: ChannelId, + data: Vec, + flag: Option, + ) -> crate::Result<()> { tracing::trace!( target: LOG_TARGET, peer = ?self.peer, ?channel_id, data_len = ?data.len(), + ?flag, "send data", ); self.rtc .channel(channel_id) .ok_or(Error::ChannelDoesntExist)? - .write(true, WebRtcMessage::encode(data).as_ref()) + .write(true, WebRtcMessage::encode(data, flag).as_ref()) .map_err(Error::WebRtc) .map(|_| ()) } @@ -682,7 +775,19 @@ impl WebRtcConnection { loop { // poll output until we get a timeout - let timeout = match self.rtc.poll_output().unwrap() { + let output = match self.rtc.poll_output() { + Ok(output) => output, + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?error, + "poll_output failed, closing connection", + ); + return self.on_connection_closed().await; + } + }; + let timeout = match output { Output::Timeout(v) => v, Output::Transmit(v) => { tracing::trace!( @@ -788,7 +893,7 @@ impl WebRtcConnection { }, event = self.handles.next() => match event { None => unreachable!(), - Some((channel_id, None | Some(SubstreamEvent::Close))) => { + Some((channel_id, None)) => { tracing::trace!( target: LOG_TARGET, peer = ?self.peer, @@ -800,11 +905,12 @@ impl WebRtcConnection { self.channels.insert(channel_id, ChannelState::Closing); self.handles.remove(&channel_id); } - Some((channel_id, Some(SubstreamEvent::Message(data)))) => { - if let Err(error) = self.on_outbound_data(channel_id, data) { + Some((channel_id, Some(SubstreamEvent::Message { payload, flag }))) => { + if let Err(error) = self.on_outbound_data(channel_id, payload, flag) { tracing::debug!( target: LOG_TARGET, ?channel_id, + ?flag, ?error, "failed to send data to remote peer", ); @@ -823,6 +929,20 @@ impl WebRtcConnection { return self.on_connection_closed().await; } Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit, .. }) => { + // Check if the connection is still healthy before opening new substreams. + // This prevents panics when trying to open channels on a shutting-down + // SCTP association. + if !self.rtc.is_alive() || !self.rtc.is_connected() { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?protocol, + is_alive = self.rtc.is_alive(), + is_connected = self.rtc.is_connected(), + "rejecting substream open: connection not healthy", + ); + continue; + } self.on_open_substream(protocol, fallback_names, substream_id, permit); } }, diff --git a/src/transport/webrtc/opening.rs b/src/transport/webrtc/opening.rs index 5d0d46af4..f778ca843 100644 --- a/src/transport/webrtc/opening.rs +++ b/src/transport/webrtc/opening.rs @@ -207,7 +207,7 @@ impl OpeningWebRtcConnection { }; // create first noise handshake and send it to remote peer - let payload = WebRtcMessage::encode(context.first_message(Role::Dialer)?); + let payload = WebRtcMessage::encode(context.first_message(Role::Dialer)?, None); self.rtc .channel(self.noise_channel_id) @@ -300,7 +300,7 @@ impl OpeningWebRtcConnection { }; // create second noise handshake message and send it to remote - let payload = WebRtcMessage::encode(context.second_message()?); + let payload = WebRtcMessage::encode(context.second_message()?, None); let mut channel = self.rtc.channel(self.noise_channel_id).ok_or(Error::ChannelDoesntExist)?; diff --git a/src/transport/webrtc/substream.rs b/src/transport/webrtc/substream.rs index f839fd83e..5dcf54564 100644 --- a/src/transport/webrtc/substream.rs +++ b/src/transport/webrtc/substream.rs @@ -24,7 +24,7 @@ use crate::{ }; use bytes::{Buf, BufMut, BytesMut}; -use futures::Stream; +use futures::{task::AtomicWaker, Future, Stream}; use parking_lot::Mutex; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio_util::sync::PollSender; @@ -33,31 +33,46 @@ use std::{ pin::Pin, sync::Arc, task::{Context, Poll}, + time::Duration, }; /// Maximum frame size. const MAX_FRAME_SIZE: usize = 16384; +/// Timeout for waiting on FIN_ACK after sending FIN. +/// Matches go-libp2p's 5 second stream close timeout. +const FIN_ACK_TIMEOUT: Duration = Duration::from_secs(5); + /// Substream event. #[derive(Debug, PartialEq, Eq)] pub enum Event { /// Receiver closed. RecvClosed, - /// Send/receive message. - Message(Vec), - - /// Close substream. - Close, + /// Send/receive message with optional flag. + Message { + payload: Vec, + flag: Option, + }, } /// Substream stream. +#[derive(Debug, Clone, Copy)] enum State { /// Substream is fully open. Open, /// Remote is no longer interested in receiving anything. SendClosed, + + /// Shutdown initiated, flushing pending data before sending FIN. + Closing, + + /// We sent FIN, waiting for FIN_ACK. + FinSent, + + /// We received FIN_ACK, write half is closed. + FinAcked, } /// Channel-backed substream. Must be owned and polled by exactly one task at a time. @@ -74,6 +89,16 @@ pub struct Substream { /// RX channel for receiving messages from `peer`. rx: Receiver, + + /// Waker to notify when shutdown completes (FIN_ACK received). + shutdown_waker: Arc, + + /// Waker to notify when write state changes (e.g., STOP_SENDING received). + write_waker: Arc, + + /// Timeout for waiting on FIN_ACK after sending FIN. + /// Boxed to maintain Unpin for Substream while allowing the Sleep to be polled. + fin_ack_timeout: Option>>, } impl Substream { @@ -82,11 +107,18 @@ impl Substream { let (outbound_tx, outbound_rx) = channel(256); let (inbound_tx, inbound_rx) = channel(256); let state = Arc::new(Mutex::new(State::Open)); + let shutdown_waker = Arc::new(AtomicWaker::new()); + let write_waker = Arc::new(AtomicWaker::new()); let handle = SubstreamHandle { - tx: inbound_tx, + inbound_tx, + outbound_tx: outbound_tx.clone(), rx: outbound_rx, state: Arc::clone(&state), + shutdown_waker: Arc::clone(&shutdown_waker), + write_waker: Arc::clone(&write_waker), + read_closed: std::sync::atomic::AtomicBool::new(false), + reset_sent: false, }; ( @@ -95,6 +127,9 @@ impl Substream { tx: PollSender::new(outbound_tx), rx: inbound_rx, read_buffer: BytesMut::new(), + shutdown_waker, + write_waker, + fin_ack_timeout: None, }, handle, ) @@ -106,36 +141,115 @@ pub struct SubstreamHandle { state: Arc>, /// TX channel for sending inbound messages from `peer` to the associated `Substream`. - tx: Sender, + inbound_tx: Sender, + + /// TX channel for sending outbound messages to `peer` (e.g., FIN_ACK responses). + outbound_tx: Sender, /// RX channel for receiving outbound messages to `peer` from the associated `Substream`. rx: Receiver, + + /// Waker to notify when shutdown completes (FIN_ACK received). + shutdown_waker: Arc, + + /// Waker to notify when write state changes (e.g., STOP_SENDING received). + write_waker: Arc, + + /// Whether we've already sent RecvClosed to the inbound channel. + /// Prevents duplicate RecvClosed events if multiple FIN messages are received. + read_closed: std::sync::atomic::AtomicBool, + + /// Whether RESET_STREAM has been sent on abrupt close. + reset_sent: bool, } impl SubstreamHandle { /// Handle message received from a remote peer. /// - /// If the message contains any flags, handle them first and appropriately close the correct - /// side of the substream. If the message contained any payload, send it to the protocol for - /// further processing. + /// Process an incoming WebRTC message, handling any payload and flags. + /// + /// Payload is processed first (if present), then flags are handled. This ensures that + /// a FIN message containing final data will deliver that data before signaling closure. pub async fn on_message(&self, message: WebRtcMessage) -> crate::Result<()> { - if let Some(flags) = message.flags { - if flags == Flag::Fin as i32 { - self.tx.send(Event::RecvClosed).await?; - } - - if flags & 1 == Flag::StopSending as i32 { - *self.state.lock() = State::SendClosed; - } - - if flags & 2 == Flag::ResetStream as i32 { - return Err(Error::ConnectionClosed); + // Process payload first, before handling flags. + // This ensures that if a FIN message contains data, we deliver it before closing. + if let Some(payload) = message.payload { + if !payload.is_empty() { + self.inbound_tx + .send(Event::Message { + payload, + flag: None, + }) + .await?; } } - if let Some(payload) = message.payload { - if !payload.is_empty() { - return self.tx.send(Event::Message(payload)).await.map_err(From::from); + // Now handle flags + if let Some(flag) = message.flag { + match flag { + Flag::Fin => { + // Guard against duplicate FIN messages - only send RecvClosed once + if self.read_closed.swap(true, std::sync::atomic::Ordering::SeqCst) { + // Already processed FIN, ignore duplicate + tracing::debug!( + target: "litep2p::webrtc::substream", + "received duplicate FIN, ignoring" + ); + return Ok(()); + } + + // Received FIN from remote, close our read half + self.inbound_tx.send(Event::RecvClosed).await?; + + // Send FIN_ACK back to remote using try_send to avoid blocking. + // If the channel is full, the remote will timeout waiting for FIN_ACK + // and handle it gracefully. This prevents deadlock if the outbound + // channel is blocked due to backpressure. + if let Err(e) = self.outbound_tx.try_send(Event::Message { + payload: vec![], + flag: Some(Flag::FinAck), + }) { + tracing::warn!( + target: "litep2p::webrtc::substream", + ?e, + "failed to send FIN_ACK, remote will timeout" + ); + } + return Ok(()); + } + Flag::FinAck => { + // Received FIN_ACK, we can now fully close our write half + let mut state = self.state.lock(); + if matches!(*state, State::FinSent) { + *state = State::FinAcked; + // Wake up any task waiting on shutdown + self.shutdown_waker.wake(); + } else { + tracing::warn!( + target: "litep2p::webrtc::substream", + ?state, + "received FIN_ACK in unexpected state, ignoring" + ); + } + return Ok(()); + } + Flag::StopSending => { + *self.state.lock() = State::SendClosed; + // Wake any blocked poll_write so it can see the state change + self.write_waker.wake(); + return Ok(()); + } + Flag::ResetStream => { + // RESET_STREAM abruptly terminates both sides of the stream + // (matching go-libp2p behavior) + // Close the read side + let _ = self.inbound_tx.try_send(Event::RecvClosed); + // Close the write side + *self.state.lock() = State::SendClosed; + // Wake any blocked poll_write so it can see the state change + self.write_waker.wake(); + return Err(Error::ConnectionClosed); + } } } @@ -147,7 +261,47 @@ impl Stream for SubstreamHandle { type Item = Event; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.rx.poll_recv(cx) + // First, try to drain any pending outbound messages + match self.rx.poll_recv(cx) { + Poll::Ready(Some(event)) => return Poll::Ready(Some(event)), + Poll::Ready(None) => { + // Outbound channel closed (all senders dropped) + return Poll::Ready(None); + } + Poll::Pending => { + // No messages available, check if we should signal closure + } + } + + // Check if Substream has been dropped (inbound channel closed) + // When Substream is dropped, there will be no more outbound messages + // Since we've already tried to recv above and got Pending, we know the queue is empty + if self.inbound_tx.is_closed() { + let state = *self.state.lock(); + + // If shutdown completed gracefully (FinAcked), just close + if matches!(state, State::FinAcked) { + return Poll::Ready(None); + } + + // Abrupt close - send RESET_STREAM to notify remote peer + // This follows the libp2p WebRTC spec for non-graceful stream termination + if !self.reset_sent { + self.reset_sent = true; + tracing::debug!( + target: "litep2p::webrtc::substream", + "Substream dropped without graceful close, sending RESET_STREAM" + ); + return Poll::Ready(Some(Event::Message { + payload: vec![], + flag: Some(Flag::ResetStream), + })); + } + + return Poll::Ready(None); + } + + Poll::Pending } } @@ -169,19 +323,19 @@ impl tokio::io::AsyncRead for Substream { } match futures::ready!(self.rx.poll_recv(cx)) { - None | Some(Event::Close) | Some(Event::RecvClosed) => + None | Some(Event::RecvClosed) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), - Some(Event::Message(message)) => { - if message.len() > MAX_FRAME_SIZE { + Some(Event::Message { payload, flag: _ }) => { + if payload.len() > MAX_FRAME_SIZE { return Poll::Ready(Err(std::io::ErrorKind::PermissionDenied.into())); } - match buf.remaining() >= message.len() { - true => buf.put_slice(&message), + match buf.remaining() >= payload.len() { + true => buf.put_slice(&payload), false => { let remaining = buf.remaining(); - buf.put_slice(&message[..remaining]); - self.read_buffer.put_slice(&message[remaining..]); + buf.put_slice(&payload[..remaining]); + self.read_buffer.put_slice(&payload[remaining..]); } } @@ -197,8 +351,15 @@ impl tokio::io::AsyncWrite for Substream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - if let State::SendClosed = *self.state.lock() { - return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + // Register waker so we get notified on state changes (e.g., STOP_SENDING) + self.write_waker.register(cx.waker()); + + // Reject writes if we're closing or closed + match *self.state.lock() { + State::SendClosed | State::Closing | State::FinSent | State::FinAcked => { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + State::Open => {} } match futures::ready!(self.tx.poll_reserve(cx)) { @@ -206,10 +367,21 @@ impl tokio::io::AsyncWrite for Substream { Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), }; + // Re-check state after poll_reserve - it may have changed while we were waiting + match *self.state.lock() { + State::SendClosed | State::Closing | State::FinSent | State::FinAcked => { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + State::Open => {} + } + let num_bytes = std::cmp::min(MAX_FRAME_SIZE, buf.len()); let frame = buf[..num_bytes].to_vec(); - match self.tx.send_item(Event::Message(frame)) { + match self.tx.send_item(Event::Message { + payload: frame, + flag: None, + }) { Ok(()) => Poll::Ready(Ok(num_bytes)), Err(_) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), } @@ -223,13 +395,105 @@ impl tokio::io::AsyncWrite for Substream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { + // State machine for proper shutdown: + // 1. Transition to Closing (stops accepting new writes) + // 2. Flush pending data + // 3. Send FIN flag + // 4. Transition to FinSent + // 5. Wait for FIN_ACK + // 6. Transition to FinAcked and complete + + let current_state = *self.state.lock(); + + match current_state { + // Already received FIN_ACK, shutdown complete + State::FinAcked => return Poll::Ready(Ok(())), + + // Sent FIN, waiting for FIN_ACK - poll timeout and return Pending + State::FinSent => { + // Register waker FIRST to avoid race condition with on_message + self.shutdown_waker.register(cx.waker()); + + // Re-check state after waker registration in case FIN_ACK arrived + // between the initial state check and waker registration + if matches!(*self.state.lock(), State::FinAcked) { + return Poll::Ready(Ok(())); + } + + // Poll the timeout - if it fires, force shutdown completion + if let Some(timeout) = self.fin_ack_timeout.as_mut() { + if timeout.as_mut().poll(cx).is_ready() { + tracing::debug!( + target: "litep2p::webrtc::substream", + "FIN_ACK timeout exceeded, forcing shutdown completion" + ); + *self.state.lock() = State::FinAcked; + return Poll::Ready(Ok(())); + } + } + + return Poll::Pending; + } + + // First call to shutdown - transition to Closing + State::Open => { + *self.state.lock() = State::Closing; + } + + State::Closing => { + // Already in closing state, continue with shutdown process. + // Guard against duplicate FIN sends: if timeout is already set, we've + // already sent FIN and are waiting for FIN_ACK. This shouldn't happen + // with correct AsyncWrite usage (&mut self), but provides defense in depth. + if self.fin_ack_timeout.is_some() { + self.shutdown_waker.register(cx.waker()); + return Poll::Pending; + } + } + + State::SendClosed => { + // Remote closed send, we can still send FIN + } + } + + // Flush any pending data + // Note: Currently poll_flush is a no-op, but the channel backpressure + // provides implicit flushing since we wait for poll_reserve below + futures::ready!(self.as_mut().poll_flush(cx))?; + + // Reserve space to send FIN match futures::ready!(self.tx.poll_reserve(cx)) { Ok(()) => {} Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), }; - match self.tx.send_item(Event::Close) { - Ok(()) => Poll::Ready(Ok(())), + // Send message with FIN flag + match self.tx.send_item(Event::Message { + payload: vec![], + flag: Some(Flag::Fin), + }) { + Ok(()) => { + // Race condition mitigation strategy: + // 1. Transition to FinSent FIRST so on_message can recognize FIN_ACK (if waker + // registered first, FIN_ACK would be ignored since state != FinSent) + // 2. Register waker so we'll be notified on future FIN_ACK arrivals + // 3. Re-check state to catch FIN_ACK that arrived between steps 1 and 2 (wake() + // called before waker registered has no effect, but state changed) + *self.state.lock() = State::FinSent; + self.shutdown_waker.register(cx.waker()); + if matches!(*self.state.lock(), State::FinAcked) { + return Poll::Ready(Ok(())); + } + + // Initialize the timeout for FIN_ACK + let mut timeout = Box::pin(tokio::time::sleep(FIN_ACK_TIMEOUT)); + // Poll the timeout once to register it with tokio's timer + // This ensures we'll be woken when it expires + let _ = timeout.as_mut().poll(cx); + self.fin_ack_timeout = Some(timeout); + + Poll::Pending + } Err(_) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), } } @@ -247,7 +511,13 @@ mod tests { substream.write_all(&vec![0u8; 1337]).await.unwrap(); - assert_eq!(handle.next().await, Some(Event::Message(vec![0u8; 1337]))); + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![0u8; 1337], + flag: None + }) + ); futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { Poll::Pending => Poll::Ready(()), @@ -264,13 +534,25 @@ mod tests { assert_eq!( handle.rx.recv().await, - Some(Event::Message(vec![0u8; MAX_FRAME_SIZE])) + Some(Event::Message { + payload: vec![0u8; MAX_FRAME_SIZE], + flag: None, + }) + ); + assert_eq!( + handle.rx.recv().await, + Some(Event::Message { + payload: vec![0u8; MAX_FRAME_SIZE], + flag: None, + }) ); assert_eq!( handle.rx.recv().await, - Some(Event::Message(vec![0u8; MAX_FRAME_SIZE])) + Some(Event::Message { + payload: vec![0u8; 1], + flag: None, + }) ); - assert_eq!(handle.rx.recv().await, Some(Event::Message(vec![0u8; 1]))); futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { Poll::Pending => Poll::Ready(()), @@ -295,10 +577,38 @@ mod tests { let (mut substream, mut handle) = Substream::new(); substream.write_all(&vec![1u8; 1337]).await.unwrap(); - substream.shutdown().await.unwrap(); - assert_eq!(handle.next().await, Some(Event::Message(vec![1u8; 1337]))); - assert_eq!(handle.next().await, Some(Event::Close)); + // Spawn shutdown since it waits for FIN_ACK + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![1u8; 1337], + flag: None, + }) + ); + // After shutdown, should send FIN flag + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Send FIN_ACK to complete shutdown + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + shutdown_task.await.unwrap(); } #[tokio::test] @@ -307,7 +617,7 @@ mod tests { handle .on_message(WebRtcMessage { payload: None, - flags: Some(0i32), + flag: Some(Flag::Fin), }) .await .unwrap(); @@ -321,7 +631,14 @@ mod tests { #[tokio::test] async fn read_small_frame() { let (mut substream, handle) = Substream::new(); - handle.tx.send(Event::Message(vec![1u8; 256])).await.unwrap(); + handle + .inbound_tx + .send(Event::Message { + payload: vec![1u8; 256], + flag: None, + }) + .await + .unwrap(); let mut buf = vec![0u8; 2048]; @@ -349,7 +666,14 @@ mod tests { let mut first = vec![1u8; 256]; first.extend_from_slice(&vec![2u8; 256]); - handle.tx.send(Event::Message(first)).await.unwrap(); + handle + .inbound_tx + .send(Event::Message { + payload: first, + flag: None, + }) + .await + .unwrap(); let mut buf = vec![0u8; 256]; @@ -385,8 +709,22 @@ mod tests { let mut first = vec![1u8; 256]; first.extend_from_slice(&vec![2u8; 256]); - handle.tx.send(Event::Message(first)).await.unwrap(); - handle.tx.send(Event::Message(vec![4u8; 2048])).await.unwrap(); + handle + .inbound_tx + .send(Event::Message { + payload: first, + flag: None, + }) + .await + .unwrap(); + handle + .inbound_tx + .send(Event::Message { + payload: vec![4u8; 2048], + flag: None, + }) + .await + .unwrap(); let mut buf = vec![0u8; 256]; @@ -500,4 +838,775 @@ mod tests { .expect("writer task did not complete after capacity was freed") .expect("writer task panicked"); } + + #[tokio::test] + async fn fin_flag_sent_on_shutdown() { + let (mut substream, mut handle) = Substream::new(); + + // Spawn shutdown since it waits for FIN_ACK + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Should receive FIN flag + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Verify state is FinSent + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Send FIN_ACK to complete shutdown cleanly (avoids waiting for timeout) + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Wait for shutdown to complete + shutdown_task.await.unwrap(); + } + + #[tokio::test] + async fn fin_ack_response_on_receiving_fin() { + let (mut substream, mut handle) = Substream::new(); + + // Spawn task to consume inbound events sent to the substream + let consumer_task = tokio::spawn(async move { + // Substream should receive RecvClosed + let mut buf = vec![0u8; 1024]; + match substream.read(&mut buf).await { + Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => { + // Expected - read half closed + } + other => panic!("Unexpected result: {:?}", other), + } + }); + + // Simulate receiving FIN from remote + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::Fin), + }) + .await + .unwrap(); + + // Wait for consumer task to complete + consumer_task.await.unwrap(); + + // Verify FIN_ACK was sent outbound to network + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::FinAck) + }) + ); + } + + #[tokio::test] + async fn fin_ack_received_transitions_to_fin_acked() { + let (mut substream, handle) = Substream::new(); + + // Spawn shutdown since it waits for FIN_ACK + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Wait a bit for FIN to be sent + tokio::task::yield_now().await; + + // Verify we're in FinSent state + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Simulate receiving FIN_ACK from remote + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Should transition to FinAcked + assert!(matches!(*handle.state.lock(), State::FinAcked)); + + // Shutdown should now complete + shutdown_task.await.unwrap(); + } + + #[tokio::test] + async fn full_fin_handshake() { + let (mut substream, mut handle) = Substream::new(); + + // Write some data + substream.write_all(&vec![1u8; 100]).await.unwrap(); + + // Spawn shutdown in background since it will wait for FIN_ACK + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Verify data was sent + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![1u8; 100], + flag: None, + }) + ); + + // Verify FIN was sent + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Simulate receiving FIN_ACK + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Should be in FinAcked state + assert!(matches!(*handle.state.lock(), State::FinAcked)); + + // Shutdown should now complete + shutdown_task.await.unwrap(); + } + + #[tokio::test] + async fn stop_sending_flag_closes_send_half() { + let (mut substream, handle) = Substream::new(); + + // Simulate receiving STOP_SENDING + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::StopSending), + }) + .await + .unwrap(); + + // Should transition to SendClosed + assert!(matches!(*handle.state.lock(), State::SendClosed)); + + // Attempting to write should fail + match substream.write_all(&vec![0u8; 100]).await { + Err(error) => assert_eq!(error.kind(), std::io::ErrorKind::BrokenPipe), + _ => panic!("write should have failed"), + } + } + + #[tokio::test] + async fn reset_stream_flag_closes_both_sides() { + use tokio::io::AsyncWriteExt; + let (mut substream, handle) = Substream::new(); + + // Simulate receiving RESET_STREAM + let result = handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::ResetStream), + }) + .await; + + // Should return connection closed error + assert!(matches!(result, Err(Error::ConnectionClosed))); + + // Write side should be closed (state = SendClosed) + assert!(matches!(*handle.state.lock(), State::SendClosed)); + + // Attempting to write should fail + match substream.write_all(&vec![0u8; 100]).await { + Err(error) => assert_eq!(error.kind(), std::io::ErrorKind::BrokenPipe), + _ => panic!("write should have failed"), + } + + // Read side should also be closed (RecvClosed event was sent) + // The substream's rx channel should have RecvClosed + assert!(matches!(substream.rx.try_recv(), Ok(Event::RecvClosed))); + } + + #[tokio::test] + async fn fin_ack_does_not_trigger_other_flag() { + let (mut substream, handle) = Substream::new(); + + // Spawn shutdown since it waits for FIN_ACK + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Wait a bit for FIN to be sent + tokio::task::yield_now().await; + + // Verify we're in FinSent state + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Now simulate receiving FIN_ACK (value = 3) + // This should NOT trigger STOP_SENDING (value = 1) or RESET_STREAM (value = 2) + // even though 3 & 1 == 1 and 3 & 2 == 2 + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Should transition to FinAcked, not SendClosed + assert!(matches!(*handle.state.lock(), State::FinAcked)); + + // Shutdown should complete + shutdown_task.await.unwrap(); + + // Writing should still work (not closed by STOP_SENDING) + // Note: We already sent FIN, so write won't actually work, but the state check happens + // first + } + + #[tokio::test] + async fn flags_are_mutually_exclusive() { + let (_substream, handle) = Substream::new(); + + // Test that STOP_SENDING (1) is handled correctly + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::StopSending), + }) + .await + .unwrap(); + + assert!(matches!(*handle.state.lock(), State::SendClosed)); + + // Create a new substream for RESET_STREAM test + let (_substream2, handle2) = Substream::new(); + + // Test that RESET_STREAM (2) is handled correctly + let result = handle2 + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::ResetStream), + }) + .await; + + assert!(matches!(result, Err(Error::ConnectionClosed))); + + // Create a new substream for FIN test + let (mut substream3, handle3) = Substream::new(); + + // Spawn shutdown since it waits for FIN_ACK + let shutdown_task3 = tokio::spawn(async move { + substream3.shutdown().await.unwrap(); + }); + + // Wait a bit for FIN to be sent + tokio::task::yield_now().await; + + // Test that FIN_ACK (3) is handled correctly + handle3 + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + assert!(matches!(*handle3.state.lock(), State::FinAcked)); + + // Shutdown should complete + shutdown_task3.await.unwrap(); + } + + #[tokio::test] + async fn stop_sending_wakes_blocked_writer() { + use tokio::io::AsyncWriteExt; + let (mut substream, handle) = Substream::new(); + + // Fill up the channel to cause poll_write to return Pending + // Channel capacity is 256 + for _ in 0..256 { + substream.write_all(&[1u8; 100]).await.unwrap(); + } + + // Now the next write should block waiting for channel capacity + let write_task = tokio::spawn(async move { + // This write will block because channel is full + let result = substream.write_all(&[2u8; 100]).await; + // Should fail because STOP_SENDING was received + assert!(result.is_err()); + }); + + // Give the writer time to block on poll_reserve + tokio::time::sleep(Duration::from_millis(10)).await; + assert!(!write_task.is_finished(), "write should be blocked"); + + // Simulate receiving STOP_SENDING from remote + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::StopSending), + }) + .await + .unwrap(); + + // The write task should wake up and see the state change + tokio::time::timeout(Duration::from_secs(1), write_task) + .await + .expect("write task should complete after STOP_SENDING") + .unwrap(); + } + + #[tokio::test] + async fn reset_stream_wakes_blocked_writer() { + use tokio::io::AsyncWriteExt; + let (mut substream, handle) = Substream::new(); + + // Fill up the channel to cause poll_write to return Pending + // Channel capacity is 256 + for _ in 0..256 { + substream.write_all(&[1u8; 100]).await.unwrap(); + } + + // Now the next write should block waiting for channel capacity + let write_task = tokio::spawn(async move { + // This write will block because channel is full + let result = substream.write_all(&[2u8; 100]).await; + // Should fail because RESET_STREAM was received + assert!(result.is_err()); + }); + + // Give the writer time to block on poll_reserve + tokio::time::sleep(Duration::from_millis(10)).await; + assert!(!write_task.is_finished(), "write should be blocked"); + + // Simulate receiving RESET_STREAM from remote + let result = handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::ResetStream), + }) + .await; + // RESET_STREAM returns an error + assert!(result.is_err()); + + // The write task should wake up and see the state change + tokio::time::timeout(Duration::from_secs(1), write_task) + .await + .expect("write task should complete after RESET_STREAM") + .unwrap(); + } + + #[tokio::test] + async fn shutdown_rejects_new_writes() { + use tokio::io::AsyncWriteExt; + let (mut substream, mut handle) = Substream::new(); + + // Write some data + substream.write_all(&vec![1u8; 100]).await.unwrap(); + + // Spawn shutdown in background + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Wait for data and FIN to be sent + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![1u8; 100], + flag: None, + }) + ); + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Verify we transitioned through Closing to FinSent + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Send FIN_ACK to complete shutdown + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Shutdown should complete + shutdown_task.await.unwrap(); + } + + #[tokio::test] + async fn shutdown_idempotent() { + use tokio::io::AsyncWriteExt; + let (mut substream, mut handle) = Substream::new(); + + // Spawn first shutdown + let shutdown_task1 = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + substream + }); + + // Wait for FIN to be sent + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Send FIN_ACK to complete first shutdown + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // First shutdown should complete + let mut substream = shutdown_task1.await.unwrap(); + + // Second shutdown should succeed without error (already in FinAcked state) + substream.shutdown().await.unwrap(); + assert!(matches!(*handle.state.lock(), State::FinAcked)); + } + + #[tokio::test] + async fn shutdown_timeout_without_fin_ack() { + use tokio::time::{timeout, Duration}; + + let (mut substream, mut handle) = Substream::new(); + + // Spawn shutdown in background + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Wait for FIN to be sent + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Verify we're in FinSent state + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // DON'T send FIN_ACK - let it timeout + // The shutdown should complete after FIN_ACK_TIMEOUT (5 seconds) + // Add a bit of buffer to the timeout + let result = timeout(Duration::from_secs(7), shutdown_task).await; + + assert!(result.is_ok(), "Shutdown should complete after timeout"); + assert!( + result.unwrap().is_ok(), + "Shutdown should succeed after timeout" + ); + + // Should have transitioned to FinAcked after timeout + assert!(matches!(*handle.state.lock(), State::FinAcked)); + } + + #[tokio::test] + async fn closing_state_blocks_writes() { + use tokio::io::AsyncWriteExt; + + let (mut substream, handle) = Substream::new(); + + // Manually transition to Closing state + *handle.state.lock() = State::Closing; + + // Attempt to write should fail + let result = substream.write_all(&vec![1u8; 100]).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::BrokenPipe); + } + + #[tokio::test] + async fn handle_signals_closure_after_substream_dropped() { + use futures::StreamExt; + + let (mut substream, mut handle) = Substream::new(); + + // Complete shutdown handshake (client-initiated) + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + // Substream will be dropped here + }); + + // Receive FIN + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Send FIN_ACK + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Wait for shutdown to complete and Substream to drop + shutdown_task.await.unwrap(); + + // Verify handle signals closure (returns None) + assert_eq!( + handle.next().await, + None, + "SubstreamHandle should signal closure after Substream is dropped" + ); + } + + #[tokio::test] + async fn server_side_closure_after_receiving_fin() { + use futures::StreamExt; + + let (mut substream, mut handle) = Substream::new(); + + // Spawn task to consume from substream (server side) + let server_task = tokio::spawn(async move { + let mut buf = vec![0u8; 1024]; + // This should fail because we receive RecvClosed + match substream.read(&mut buf).await { + Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => { + // Expected - read half closed by FIN + } + other => panic!("Unexpected result: {:?}", other), + } + // Substream dropped here without calling shutdown() - this is an abrupt close + }); + + // Remote (client) sends FIN + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::Fin), + }) + .await + .unwrap(); + + // Verify FIN_ACK was sent back + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::FinAck) + }) + ); + + // Wait for server to close substream + server_task.await.unwrap(); + + // Since server didn't call shutdown(), this is an abrupt close - RESET_STREAM is sent + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::ResetStream) + }), + "SubstreamHandle should send RESET_STREAM when server drops without shutdown" + ); + + // Then closure + assert_eq!( + handle.next().await, + None, + "SubstreamHandle should signal closure after sending RESET_STREAM" + ); + } + + #[tokio::test] + async fn abrupt_close_sends_reset_stream() { + use futures::StreamExt; + + let (substream, mut handle) = Substream::new(); + + // Drop substream without calling shutdown() - this is an abrupt close + drop(substream); + + // Verify RESET_STREAM is sent before closure + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::ResetStream) + }), + "SubstreamHandle should send RESET_STREAM on abrupt close" + ); + + // Then verify handle signals closure + assert_eq!( + handle.next().await, + None, + "SubstreamHandle should signal closure after sending RESET_STREAM" + ); + } + + #[tokio::test] + async fn graceful_close_does_not_send_reset_stream() { + use futures::StreamExt; + + let (mut substream, mut handle) = Substream::new(); + + // Complete graceful shutdown + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + // Substream dropped after graceful shutdown + }); + + // Receive FIN + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Send FIN_ACK to complete handshake + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Wait for shutdown to complete + shutdown_task.await.unwrap(); + + // Verify handle signals closure directly (no RESET_STREAM) + assert_eq!( + handle.next().await, + None, + "SubstreamHandle should NOT send RESET_STREAM after graceful close" + ); + } + + #[tokio::test] + async fn simultaneous_close() { + // Test simultaneous close where both sides send FIN at the same time. + // This verifies that: + // 1. Both sides can be in FinSent state simultaneously + // 2. Both sides correctly respond to FIN with FIN_ACK even when in FinSent state + // 3. Both sides eventually transition to FinAcked + + let (mut substream, mut handle) = Substream::new(); + + // Local side initiates shutdown (sends FIN, transitions to FinSent) + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Wait for local FIN to be sent + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Verify local is in FinSent state + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Now simulate remote also sending FIN (simultaneous close) + // This should trigger FIN_ACK response even though we're in FinSent state + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::Fin), + }) + .await + .unwrap(); + + // Local should send FIN_ACK in response to remote's FIN + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::FinAck) + }) + ); + + // Local should still be in FinSent (waiting for FIN_ACK from remote) + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Now remote sends FIN_ACK (completing their side of the handshake) + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Local should now transition to FinAcked + assert!(matches!(*handle.state.lock(), State::FinAcked)); + + // Shutdown should complete successfully + shutdown_task.await.unwrap(); + } + + #[tokio::test] + async fn fin_with_payload_delivers_data_before_close() { + // Test that when a FIN message contains payload data, the data is delivered + // to the substream before the RecvClosed event. This is important because + // the spec allows a FIN message to contain final data. + + let (mut substream, handle) = Substream::new(); + + // Simulate receiving FIN with payload from remote + handle + .on_message(WebRtcMessage { + payload: Some(b"final data".to_vec()), + flag: Some(Flag::Fin), + }) + .await + .unwrap(); + + // First, we should receive the payload data + let mut buf = vec![0u8; 1024]; + let n = substream.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..n], b"final data"); + + // Then, subsequent read should fail with BrokenPipe (RecvClosed) + match substream.read(&mut buf).await { + Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => { + // Expected - read half closed after FIN + } + other => panic!("Expected BrokenPipe error, got: {:?}", other), + } + } } diff --git a/src/transport/webrtc/util.rs b/src/transport/webrtc/util.rs index 55917afc6..ae050d50d 100644 --- a/src/transport/webrtc/util.rs +++ b/src/transport/webrtc/util.rs @@ -18,74 +18,97 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{codec::unsigned_varint::UnsignedVarint, error::ParseError, transport::webrtc::schema}; +use crate::{ + error::ParseError, + transport::webrtc::schema::{self, webrtc::message::Flag}, +}; use prost::Message; -use tokio_util::codec::{Decoder, Encoder}; -/// WebRTC mesage. +/// WebRTC message. #[derive(Debug)] pub struct WebRtcMessage { /// Payload. pub payload: Option>, - // Flags. - pub flags: Option, + /// Flag. + pub flag: Option, } impl WebRtcMessage { - /// Encode WebRTC message. - pub fn encode(payload: Vec) -> Vec { + /// Encode WebRTC message with optional flag. + /// + /// Uses a single allocation by pre-calculating the total size and encoding + /// the varint length prefix and protobuf message directly into the output buffer. + pub fn encode(payload: Vec, flag: Option) -> Vec { let protobuf_payload = schema::webrtc::Message { message: (!payload.is_empty()).then_some(payload), - flag: None, + flag: flag.map(|f| f as i32), }; - let mut payload = Vec::with_capacity(protobuf_payload.encoded_len()); - protobuf_payload - .encode(&mut payload) - .expect("Vec to provide needed capacity"); - let mut out_buf = bytes::BytesMut::with_capacity(payload.len() + 4); - let mut codec = UnsignedVarint::new(None); - let _result = codec.encode(payload.into(), &mut out_buf); + // Calculate sizes upfront for single allocation with exact capacity + let protobuf_len = protobuf_payload.encoded_len(); + // Varint uses 7 bits per byte, so calculate exact length needed + // ilog2 gives the position of the highest set bit (0-indexed), divide by 7 for varint bytes + let varint_len = if protobuf_len == 0 { + 1 + } else { + (protobuf_len.ilog2() as usize / 7) + 1 + }; - out_buf.into() - } + // Single allocation for the entire output with exact size + let mut out_buf = Vec::with_capacity(varint_len + protobuf_len); - /// Encode WebRTC message with flags. - #[allow(unused)] - pub fn encode_with_flags(payload: Vec, flags: i32) -> Vec { - let protobuf_payload = schema::webrtc::Message { - message: (!payload.is_empty()).then_some(payload), - flag: Some(flags), - }; - let mut payload = Vec::with_capacity(protobuf_payload.encoded_len()); + // Encode varint length prefix directly + let mut varint_buf = unsigned_varint::encode::usize_buffer(); + let varint_slice = unsigned_varint::encode::usize(protobuf_len, &mut varint_buf); + out_buf.extend_from_slice(varint_slice); + + // Encode protobuf directly into output buffer protobuf_payload - .encode(&mut payload) + .encode(&mut out_buf) .expect("Vec to provide needed capacity"); - let mut out_buf = bytes::BytesMut::with_capacity(payload.len() + 4); - let mut codec = UnsignedVarint::new(None); - let _result = codec.encode(payload.into(), &mut out_buf); - - out_buf.into() + out_buf } /// Decode payload into [`WebRtcMessage`]. + /// + /// Decodes the varint length prefix directly from the slice without allocations, + /// then decodes the protobuf message from the remaining bytes. + /// + /// # Flag handling + /// + /// Unknown flag values (e.g., from a newer protocol version) are logged as warnings + /// and treated as `None` for forward compatibility. This allows the message payload + /// to still be processed even if the flag is not recognized. pub fn decode(payload: &[u8]) -> Result { - // TODO: https://github.com/paritytech/litep2p/issues/352 set correct size - let mut codec = UnsignedVarint::new(None); - let mut data = bytes::BytesMut::from(payload); - let result = codec - .decode(&mut data) - .map_err(|_| ParseError::InvalidData)? - .ok_or(ParseError::InvalidData)?; - - match schema::webrtc::Message::decode(result) { - Ok(message) => Ok(Self { - payload: message.message, - flags: message.flag, - }), + // Decode varint length prefix directly from slice (no allocation) + // Returns (decoded_length, remaining_bytes_after_varint) + let (len, remaining) = + unsigned_varint::decode::usize(payload).map_err(|_| ParseError::InvalidData)?; + + // Get exactly `len` bytes of protobuf data (no allocation) + let protobuf_data = remaining.get(..len).ok_or(ParseError::InvalidData)?; + + match schema::webrtc::Message::decode(protobuf_data) { + Ok(message) => { + let flag = message.flag.and_then(|f| match Flag::try_from(f) { + Ok(flag) => Some(flag), + Err(_) => { + tracing::warn!( + target: "litep2p::webrtc", + ?f, + "received message with unknown flag value, ignoring flag" + ); + None + } + }); + Ok(Self { + payload: message.message, + flag, + }) + } Err(_) => Err(ParseError::InvalidData), } } @@ -96,29 +119,30 @@ mod tests { use super::*; #[test] - fn with_payload_no_flags() { - let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec()); + fn with_payload_no_flag() { + let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec(), None); let decoded = WebRtcMessage::decode(&message).unwrap(); assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); - assert_eq!(decoded.flags, None); + assert_eq!(decoded.flag, None); } #[test] - fn with_payload_and_flags() { - let message = WebRtcMessage::encode_with_flags("Hello, world!".as_bytes().to_vec(), 1i32); + fn with_payload_and_flag() { + let message = + WebRtcMessage::encode("Hello, world!".as_bytes().to_vec(), Some(Flag::StopSending)); let decoded = WebRtcMessage::decode(&message).unwrap(); assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); - assert_eq!(decoded.flags, Some(1i32)); + assert_eq!(decoded.flag, Some(Flag::StopSending)); } #[test] - fn no_payload_with_flags() { - let message = WebRtcMessage::encode_with_flags(vec![], 2i32); + fn no_payload_with_flag() { + let message = WebRtcMessage::encode(vec![], Some(Flag::ResetStream)); let decoded = WebRtcMessage::decode(&message).unwrap(); assert_eq!(decoded.payload, None); - assert_eq!(decoded.flags, Some(2i32)); + assert_eq!(decoded.flag, Some(Flag::ResetStream)); } }