diff --git a/Cargo.lock b/Cargo.lock index 658484522..39efe70ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5793,9 +5793,9 @@ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "str0m" -version = "0.2.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee48572247f422dcbe68630c973f8296fbd5157119cd36a3223e48bf83d47727" +checksum = "d3f10d3f68e60168d81110410428a435dbde28cc5525f5f7c6fdec92dbdc2800" dependencies = [ "combine", "crc", diff --git a/Cargo.toml b/Cargo.toml index 89ad45663..bcd9fdb7e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ simple-dns = "0.5.3" smallvec = "1.10.0" snow = { version = "0.9.3", features = ["ring-resolver"], default-features = false } socket2 = { version = "0.5.5", features = ["all"] } -str0m = "0.2.0" +str0m = "0.4.1" thiserror = "1.0.39" tokio-stream = "0.1.12" tokio-tungstenite = { version = "0.20.0", features = ["rustls-tls-native-roots"] } diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index 2a8a025e2..9a65ef2bb 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -24,7 +24,9 @@ use crate::{ codec::unsigned_varint::UnsignedVarint, error::{self, Error}, multistream_select::{ - protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError}, + protocol::{ + encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError, + }, Negotiated, NegotiationError, Version, }, types::protocol::ProtocolName, @@ -224,7 +226,7 @@ where } /// `multistream-select` handshake result for dialer. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum HandshakeResult { /// Handshake is not complete, data missing. NotReady, @@ -259,7 +261,6 @@ pub struct DialerState { state: HandshakeState, } -// TODO: tests impl DialerState { /// Propose protocol to remote peer. /// @@ -269,21 +270,14 @@ impl DialerState { protocol: ProtocolName, fallback_names: Vec, ) -> crate::Result<(Self, Vec)> { - // encode `/multistream-select/1.0.0` header - let mut bytes = BytesMut::with_capacity(64); - let message = Message::Header(HeaderLine::V1); - let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData)?; - let mut header = UnsignedVarint::encode(bytes)?; - - // encode proposed protocol - let mut proto_bytes = BytesMut::with_capacity(512); - let message = Message::Protocol(Protocol::try_from(protocol.as_bytes()).unwrap()); - let _ = message.encode(&mut proto_bytes).map_err(|_| Error::InvalidData)?; - let proto_bytes = UnsignedVarint::encode(proto_bytes)?; - - // TODO: add fallback names - - header.append(&mut proto_bytes.into()); + let message = encode_multistream_message( + std::iter::once(protocol.clone()) + .chain(fallback_names.clone()) + .filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok()) + .map(|protocol| Message::Protocol(protocol)), + )? + .freeze() + .to_vec(); Ok(( Self { @@ -291,7 +285,7 @@ impl DialerState { fallback_names, state: HandshakeState::WaitingResponse, }, - header, + message, )) } @@ -328,10 +322,9 @@ impl DialerState { return Ok(HandshakeResult::Succeeded(self.protocol.clone())); } - // TODO: zzz for fallback in &self.fallback_names { if fallback.as_bytes() == protocol.as_ref() { - return Ok(HandshakeResult::Succeeded(self.protocol.clone())); + return Ok(HandshakeResult::Succeeded(fallback.clone())); } } @@ -346,3 +339,148 @@ impl DialerState { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn propose() { + let (mut dialer_state, message) = + DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); + let message = bytes::BytesMut::from(&message[..]).freeze(); + + let Message::Protocols(protocols) = Message::decode(message).unwrap() else { + panic!("invalid message type"); + }; + + assert_eq!(protocols.len(), 2); + assert_eq!( + protocols[0], + Protocol::try_from(&b"/multistream/1.0.0"[..]) + .expect("valid multitstream-select header") + ); + assert_eq!( + protocols[1], + Protocol::try_from(&b"/13371338/proto/1"[..]) + .expect("valid multitstream-select header") + ); + } + + #[test] + fn propose_with_fallback() { + let (mut dialer_state, message) = DialerState::propose( + ProtocolName::from("/13371338/proto/1"), + vec![ProtocolName::from("/sup/proto/1")], + ) + .unwrap(); + let message = bytes::BytesMut::from(&message[..]).freeze(); + + let Message::Protocols(protocols) = Message::decode(message).unwrap() else { + panic!("invalid message type"); + }; + + assert_eq!(protocols.len(), 3); + assert_eq!( + protocols[0], + Protocol::try_from(&b"/multistream/1.0.0"[..]) + .expect("valid multitstream-select header") + ); + assert_eq!( + protocols[1], + Protocol::try_from(&b"/13371338/proto/1"[..]) + .expect("valid multitstream-select header") + ); + assert_eq!( + protocols[2], + Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid multitstream-select header") + ); + } + + #[test] + fn register_response_invalid_message() { + // send only header line + let mut bytes = BytesMut::with_capacity(32); + let message = Message::Header(HeaderLine::V1); + let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); + + let (mut dialer_state, _message) = + DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); + + match dialer_state.register_response(bytes.freeze().to_vec()) { + Err(Error::NegotiationError(error::NegotiationError::MultistreamSelectError( + NegotiationError::Failed, + ))) => {} + event => panic!("invalid event: {event:?}"), + } + } + + #[test] + fn header_line_missing() { + // header line missing + let mut bytes = BytesMut::with_capacity(256); + let message = Message::Protocols(vec![ + Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), + Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), + ]); + let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); + + let (mut dialer_state, _message) = + DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); + + match dialer_state.register_response(bytes.freeze().to_vec()) { + Err(Error::NegotiationError(error::NegotiationError::MultistreamSelectError( + NegotiationError::Failed, + ))) => {} + event => panic!("invalid event: {event:?}"), + } + } + + #[test] + fn negotiate_main_protocol() { + let message = encode_multistream_message( + vec![Message::Protocol( + Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), + )] + .into_iter(), + ) + .unwrap() + .freeze(); + + let (mut dialer_state, _message) = DialerState::propose( + ProtocolName::from("/13371338/proto/1"), + vec![ProtocolName::from("/sup/proto/1")], + ) + .unwrap(); + + match dialer_state.register_response(message.to_vec()) { + Ok(HandshakeResult::Succeeded(negotiated)) => + assert_eq!(negotiated, ProtocolName::from("/13371338/proto/1")), + _ => panic!("invalid event"), + } + } + + #[test] + fn negotiate_fallback_protocol() { + let message = encode_multistream_message( + vec![Message::Protocol( + Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), + )] + .into_iter(), + ) + .unwrap() + .freeze(); + + let (mut dialer_state, _message) = DialerState::propose( + ProtocolName::from("/13371338/proto/1"), + vec![ProtocolName::from("/sup/proto/1")], + ) + .unwrap(); + + match dialer_state.register_response(message.to_vec()) { + Ok(HandshakeResult::Succeeded(negotiated)) => + assert_eq!(negotiated, ProtocolName::from("/sup/proto/1")), + _ => panic!("invalid event"), + } + } +} diff --git a/src/multistream_select/listener_select.rs b/src/multistream_select/listener_select.rs index 4217a332b..ae7d4d4b5 100644 --- a/src/multistream_select/listener_select.rs +++ b/src/multistream_select/listener_select.rs @@ -25,7 +25,9 @@ use crate::{ codec::unsigned_varint::UnsignedVarint, error::{self, Error}, multistream_select::{ - protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError}, + protocol::{ + encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError, + }, Negotiated, NegotiationError, }, types::protocol::ProtocolName, @@ -322,6 +324,25 @@ where } } +/// Result of [`listener_negotiate()`]. +#[derive(Debug)] +pub enum ListenerSelectResult { + /// Requested protocol is available and substream can be accepted. + Accepted { + /// Protocol that is confirmed. + protocol: ProtocolName, + + /// `multistream-select` message. + message: BytesMut, + }, + + /// Requested protocol is not available. + Rejected { + /// `multistream-select` message. + message: BytesMut, + }, +} + /// Negotiate protocols for listener. /// /// Parse protocols offered by the remote peer and check if any of the offered protocols match @@ -330,7 +351,7 @@ where pub fn listener_negotiate<'a>( supported_protocols: &'a mut impl Iterator, payload: Bytes, -) -> crate::Result<(ProtocolName, BytesMut)> { +) -> crate::Result { let Message::Protocols(protocols) = Message::decode(payload).map_err(|_| Error::InvalidData)? else { return Err(Error::NegotiationError( @@ -344,37 +365,39 @@ pub fn listener_negotiate<'a>( let header = Protocol::try_from(&b"/multistream/1.0.0"[..]).expect("valid multitstream-select header"); - if !std::matches!(protocol_iter.next(), Some(header)) { + if protocol_iter.next() != Some(header) { return Err(Error::NegotiationError( error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), )); } for protocol in protocol_iter { + tracing::trace!( + target: LOG_TARGET, + protocol = ?std::str::from_utf8(protocol.as_ref()), + "listener: checking protocol", + ); + for supported in &mut *supported_protocols { if protocol.as_ref() == supported.as_bytes() { - // encode `/multistream-select/1.0.0` header - let mut bytes = BytesMut::with_capacity(64); - let message = Message::Header(HeaderLine::V1); - let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData)?; - let mut header = UnsignedVarint::encode(bytes)?; - - // encode negotiated protocol - let mut proto_bytes = BytesMut::with_capacity(512); - let message = Message::Protocol(protocol); - let _ = message.encode(&mut proto_bytes).map_err(|_| Error::InvalidData)?; - let proto_bytes = UnsignedVarint::encode(proto_bytes)?; - - header.append(&mut proto_bytes.into()); - - return Ok((supported.clone(), BytesMut::from(&header[..]))); + return Ok(ListenerSelectResult::Accepted { + protocol: supported.clone(), + message: encode_multistream_message(std::iter::once(Message::Protocol( + protocol, + )))?, + }); } } } - Err(Error::NegotiationError( - error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), - )) + tracing::trace!( + target: LOG_TARGET, + "listener: handshake rejected, no supported protocol found", + ); + + Ok(ListenerSelectResult::Rejected { + message: encode_multistream_message(std::iter::once(Message::NotAvailable))?, + }) } #[cfg(test)] @@ -382,14 +405,137 @@ mod tests { use super::*; #[test] - fn listener_negotiate_works() {} + fn listener_negotiate_works() { + let mut local_protocols = vec![ + ProtocolName::from("/13371338/proto/1"), + ProtocolName::from("/sup/proto/1"), + ProtocolName::from("/13371338/proto/2"), + ProtocolName::from("/13371338/proto/3"), + ProtocolName::from("/13371338/proto/4"), + ]; + let message = encode_multistream_message( + vec![ + Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()), + Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()), + ] + .into_iter(), + ) + .unwrap() + .freeze(); + + match listener_negotiate(&mut local_protocols.iter(), message) { + Err(error) => panic!("error received: {error:?}"), + Ok(ListenerSelectResult::Rejected { .. }) => panic!("message rejected"), + Ok(ListenerSelectResult::Accepted { protocol, message }) => { + assert_eq!(protocol, ProtocolName::from("/13371338/proto/1")); + } + } + } #[test] - fn invalid_message_offered() {} + fn invalid_message() { + let mut local_protocols = vec![ + ProtocolName::from("/13371338/proto/1"), + ProtocolName::from("/sup/proto/1"), + ProtocolName::from("/13371338/proto/2"), + ProtocolName::from("/13371338/proto/3"), + ProtocolName::from("/13371338/proto/4"), + ]; + let message = encode_multistream_message(std::iter::once(Message::Protocols(vec![ + Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), + Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), + ]))) + .unwrap() + .freeze(); + + match listener_negotiate(&mut local_protocols.iter(), message) { + Err(error) => assert!(std::matches!(error, Error::InvalidData)), + _ => panic!("invalid event"), + } + } #[test] - fn no_supported_protocol() {} + fn only_header_line_received() { + let mut local_protocols = vec![ + ProtocolName::from("/13371338/proto/1"), + ProtocolName::from("/sup/proto/1"), + ProtocolName::from("/13371338/proto/2"), + ProtocolName::from("/13371338/proto/3"), + ProtocolName::from("/13371338/proto/4"), + ]; + + // send only header line + let mut bytes = BytesMut::with_capacity(32); + let message = Message::Header(HeaderLine::V1); + let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); + + match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) { + Err(error) => assert!(std::matches!( + error, + Error::NegotiationError(error::NegotiationError::MultistreamSelectError( + NegotiationError::Failed + )) + )), + event => panic!("invalid event: {event:?}"), + } + } #[test] - fn multistream_select_header_missing() {} + fn header_line_missing() { + let mut local_protocols = vec![ + ProtocolName::from("/13371338/proto/1"), + ProtocolName::from("/sup/proto/1"), + ProtocolName::from("/13371338/proto/2"), + ProtocolName::from("/13371338/proto/3"), + ProtocolName::from("/13371338/proto/4"), + ]; + + // header line missing + let mut bytes = BytesMut::with_capacity(256); + let message = Message::Protocols(vec![ + Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), + Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), + ]); + let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); + + match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) { + Err(error) => assert!(std::matches!( + error, + Error::NegotiationError(error::NegotiationError::MultistreamSelectError( + NegotiationError::Failed + )) + )), + event => panic!("invalid event: {event:?}"), + } + } + + #[test] + fn protocol_not_supported() { + let mut local_protocols = vec![ + ProtocolName::from("/13371338/proto/1"), + ProtocolName::from("/sup/proto/1"), + ProtocolName::from("/13371338/proto/2"), + ProtocolName::from("/13371338/proto/3"), + ProtocolName::from("/13371338/proto/4"), + ]; + let message = encode_multistream_message( + vec![Message::Protocol( + Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap(), + )] + .into_iter(), + ) + .unwrap() + .freeze(); + + match listener_negotiate(&mut local_protocols.iter(), message) { + Err(error) => panic!("error received: {error:?}"), + Ok(ListenerSelectResult::Rejected { message }) => { + assert_eq!( + message, + encode_multistream_message(std::iter::once(Message::NotAvailable)).unwrap() + ); + } + Ok(ListenerSelectResult::Accepted { protocol, message }) => panic!("message accepted"), + } + } } diff --git a/src/multistream_select/mod.rs b/src/multistream_select/mod.rs index a69493975..86abe0269 100644 --- a/src/multistream_select/mod.rs +++ b/src/multistream_select/mod.rs @@ -77,7 +77,9 @@ mod protocol; pub use crate::multistream_select::{ dialer_select::{dialer_select_proto, DialerSelectFuture, DialerState, HandshakeResult}, - listener_select::{listener_negotiate, listener_select_proto, ListenerSelectFuture}, + listener_select::{ + listener_negotiate, listener_select_proto, ListenerSelectFuture, ListenerSelectResult, + }, negotiated::{Negotiated, NegotiatedComplete, NegotiationError}, protocol::{HeaderLine, Message, Protocol, ProtocolError}, }; diff --git a/src/multistream_select/protocol.rs b/src/multistream_select/protocol.rs index bf710850a..cc196b6f1 100644 --- a/src/multistream_select/protocol.rs +++ b/src/multistream_select/protocol.rs @@ -25,9 +25,13 @@ //! `Stream` and `Sink` implementations of `MessageIO` and //! `MessageReader`. -use crate::multistream_select::{ - length_delimited::{LengthDelimited, LengthDelimitedReader}, - Version, +use crate::{ + codec::unsigned_varint::UnsignedVarint, + error::Error as Litep2pError, + multistream_select::{ + length_delimited::{LengthDelimited, LengthDelimitedReader}, + Version, + }, }; use bytes::{BufMut, Bytes, BytesMut}; @@ -223,6 +227,28 @@ impl Message { } } +/// Create `multistream-select` message from an iterator of `Message`s. +pub fn encode_multistream_message( + messages: impl IntoIterator, +) -> crate::Result { + // encode `/multistream-select/1.0.0` header + let mut bytes = BytesMut::with_capacity(32); + let message = Message::Header(HeaderLine::V1); + let _ = message.encode(&mut bytes).map_err(|_| Litep2pError::InvalidData)?; + let mut header = UnsignedVarint::encode(bytes)?; + + // encode each message + for message in messages { + let mut proto_bytes = BytesMut::with_capacity(256); + let _ = message.encode(&mut proto_bytes).map_err(|_| Litep2pError::InvalidData)?; + let proto_bytes = UnsignedVarint::encode(proto_bytes)?; + + header.append(&mut proto_bytes.into()); + } + + Ok(BytesMut::from(&header[..])) +} + /// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s. #[pin_project::pin_project] pub struct MessageIO { diff --git a/src/protocol/notification/mod.rs b/src/protocol/notification/mod.rs index 589555e49..b820c68f7 100644 --- a/src/protocol/notification/mod.rs +++ b/src/protocol/notification/mod.rs @@ -1652,6 +1652,29 @@ impl NotificationProtocol { .report_notification_stream_open_failure(peer, NotificationError::Rejected) .await; + // NOTE: this is used to work around an issue in Substrate where the protocol + // is not notified if an inbound substream is closed. That indicates that remote + // wishes the close the connection but `Notifications` still keeps the substream state + // as `Open` until the outbound substream is closed (even though the outbound substream + // is also closed at that point). This causes a further issue: inbound substreams + // are automatically opened when state is `Open`, even if the inbound substream belongs + // to a new "connection" (pair of substreams). + // + // basically what happens (from Substrate's PoV) is there are pair of substreams (`inbound1`, `outbound1`), + // litep2p closes both substreams so both `inbound1` and outbound1 become non-readable/writable. + // Substrate doesn't detect this an instead only marks `inbound1` is closed while still keeping + // the (now-closed) `outbound1` active and it will be detected closed only when Substrate tries to + // write something into that substream. If now litep2p tries to open a new connection to Substrate, + // the outbound substream from litep2p's PoV will be automatically accepted (https://github.com/paritytech/polkadot-sdk/blob/59b2661444de2a25f2125a831bd786035a9fac4b/substrate/client/network/src/protocol/notifications/handler.rs#L544-L556) + // but since Substrate thinks `outbound1` is still active, it won't open a new outbound substream + // and it ends up having (`inbound2`, `outbound1`) as its pair of substreams which doens't make sense. + // + // since litep2p is expecting to receive an inbound substream from Substrate and never receives it, + // it basically can't make progress with the substream open request because litep2p can't force Substrate + // to detect that `outbound1` is closed. Easiest (and very hacky at the same time) way to reset the substream + // state is to close the connection. This is not an appropriate way to fix the issue and causes issues with, + // e.g., smoldot which at the time of writing this doesn't support the transaction protocol. The only way to fix + // this cleanly is to make Substrate detect closed substreams correctly. if let Err(error) = self.service.force_close(peer) { tracing::debug!( target: LOG_TARGET, diff --git a/src/substream/mod.rs b/src/substream/mod.rs index 514b0c02e..8818ac557 100644 --- a/src/substream/mod.rs +++ b/src/substream/mod.rs @@ -24,7 +24,7 @@ use crate::{ codec::ProtocolCodec, error::{Error, SubstreamError}, - transport::{quic, tcp, websocket}, + transport::{quic, tcp, webrtc, websocket}, types::SubstreamId, PeerId, }; @@ -44,7 +44,7 @@ use std::{ }; /// Logging target for the file. -const LOG_TARGET: &str = "substream"; +const LOG_TARGET: &str = "litep2p::substream"; macro_rules! poll_flush { ($substream:expr, $cx:ident) => {{ @@ -52,6 +52,7 @@ macro_rules! poll_flush { SubstreamType::Tcp(substream) => Pin::new(substream).poll_flush($cx), SubstreamType::WebSocket(substream) => Pin::new(substream).poll_flush($cx), SubstreamType::Quic(substream) => Pin::new(substream).poll_flush($cx), + SubstreamType::WebRtc(substream) => Pin::new(substream).poll_flush($cx), #[cfg(test)] SubstreamType::Mock(_) => unreachable!(), } @@ -64,6 +65,7 @@ macro_rules! poll_write { SubstreamType::Tcp(substream) => Pin::new(substream).poll_write($cx, $frame), SubstreamType::WebSocket(substream) => Pin::new(substream).poll_write($cx, $frame), SubstreamType::Quic(substream) => Pin::new(substream).poll_write($cx, $frame), + SubstreamType::WebRtc(substream) => Pin::new(substream).poll_write($cx, $frame), #[cfg(test)] SubstreamType::Mock(_) => unreachable!(), } @@ -76,6 +78,7 @@ macro_rules! poll_read { SubstreamType::Tcp(substream) => Pin::new(substream).poll_read($cx, $buffer), SubstreamType::WebSocket(substream) => Pin::new(substream).poll_read($cx, $buffer), SubstreamType::Quic(substream) => Pin::new(substream).poll_read($cx, $buffer), + SubstreamType::WebRtc(substream) => Pin::new(substream).poll_read($cx, $buffer), #[cfg(test)] SubstreamType::Mock(_) => unreachable!(), } @@ -88,6 +91,7 @@ macro_rules! poll_shutdown { SubstreamType::Tcp(substream) => Pin::new(substream).poll_shutdown($cx), SubstreamType::WebSocket(substream) => Pin::new(substream).poll_shutdown($cx), SubstreamType::Quic(substream) => Pin::new(substream).poll_shutdown($cx), + SubstreamType::WebRtc(substream) => Pin::new(substream).poll_shutdown($cx), #[cfg(test)] SubstreamType::Mock(substream) => { let _ = Pin::new(substream).poll_close($cx); @@ -148,6 +152,7 @@ enum SubstreamType { Tcp(tcp::Substream), WebSocket(websocket::Substream), Quic(quic::Substream), + WebRtc(webrtc::Substream), #[cfg(test)] Mock(Box), } @@ -158,6 +163,7 @@ impl fmt::Debug for SubstreamType { Self::Tcp(_) => write!(f, "Tcp"), Self::WebSocket(_) => write!(f, "WebSocket"), Self::Quic(_) => write!(f, "Quic"), + Self::WebRtc(_) => write!(f, "WebRtc"), #[cfg(test)] Self::Mock(_) => write!(f, "Mock"), } @@ -276,6 +282,18 @@ impl Substream { Self::new(peer, substream_id, SubstreamType::Quic(substream), codec) } + /// Create new [`Substream`] for WebRTC. + pub(crate) fn new_webrtc( + peer: PeerId, + substream_id: SubstreamId, + substream: webrtc::Substream, + codec: ProtocolCodec, + ) -> Self { + tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for webrtc"); + + Self::new(peer, substream_id, SubstreamType::WebRtc(substream), codec) + } + /// Create new [`Substream`] for mocking. #[cfg(test)] pub(crate) fn new_mock( @@ -299,6 +317,7 @@ impl Substream { SubstreamType::Tcp(mut substream) => substream.shutdown().await, SubstreamType::WebSocket(mut substream) => substream.shutdown().await, SubstreamType::Quic(mut substream) => substream.shutdown().await, + SubstreamType::WebRtc(mut substream) => substream.shutdown().await, #[cfg(test)] SubstreamType::Mock(mut substream) => { let _ = futures::SinkExt::close(&mut substream).await; @@ -409,6 +428,29 @@ impl Substream { substream.write_all_chunks(&mut [len.freeze(), bytes]).await } }, + SubstreamType::WebRtc(ref mut substream) => match self.codec { + ProtocolCodec::Unspecified => panic!("codec is unspecified"), + ProtocolCodec::Identity(payload_size) => + Self::send_identity_payload(substream, payload_size, bytes).await, + ProtocolCodec::UnsignedVarint(max_size) => { + check_size!(max_size, bytes.len()); + + let mut buffer = [0u8; 10]; + let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer); + let mut offset = 0; + + while offset < len.len() { + offset += substream.write(&len[offset..]).await?; + } + + while bytes.has_remaining() { + let nwritten = substream.write(&bytes).await?; + bytes.advance(nwritten); + } + + substream.flush().await.map_err(From::from) + } + }, } } } @@ -625,6 +667,14 @@ impl Sink for Substream { // `MockSubstream` implements `Sink` so calls to `start_send()` must be delegated delegate_start_send!(&mut self.substream, item); + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + substream_id = ?self.substream_id, + data_len = item.len(), + "Substream::start_send()", + ); + match self.codec { ProtocolCodec::Identity(payload_size) => { if item.len() != payload_size { diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index 61d7bdd5d..a125b5f62 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -1576,7 +1576,7 @@ impl TransportManager { Ok(None) => {} } } - _ => panic!("event not supported"), + event => panic!("event not supported: {event:?}"), } }, } diff --git a/src/transport/webrtc/config.rs b/src/transport/webrtc/config.rs index 526829d28..b93140103 100644 --- a/src/transport/webrtc/config.rs +++ b/src/transport/webrtc/config.rs @@ -27,4 +27,20 @@ use multiaddr::Multiaddr; pub struct Config { /// WebRTC listening address. pub listen_addresses: Vec, + + /// Connection datagram buffer size. + /// + /// How many datagrams can the buffer between `WebRtcTransport` and a connection handler hold. + pub datagram_buffer_size: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + listen_addresses: vec!["/ip4/127.0.0.1/udp/8888/webrtc-direct" + .parse() + .expect("valid multiaddress")], + datagram_buffer_size: 2048, + } + } } diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index bf49fdedf..ef8200008 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -18,33 +18,27 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -#![allow(unused)] - use crate::{ - config::Role, - crypto::{ed25519::Keypair, noise::NoiseContext}, error::Error, - multistream_select::{listener_negotiate, DialerState, HandshakeResult}, + multistream_select::{listener_negotiate, DialerState, HandshakeResult, ListenerSelectResult}, protocol::{Direction, Permit, ProtocolCommand, ProtocolSet}, substream::Substream, transport::{ webrtc::{ - substream::SubstreamBackend, - util::{SubstreamContext, WebRtcMessage}, - WebRtcEvent, + substream::{Event as SubstreamEvent, Substream as WebRtcSubstream, SubstreamHandle}, + util::WebRtcMessage, }, Endpoint, }, - types::{protocol::ProtocolName, ConnectionId, SubstreamId}, + types::{protocol::ProtocolName, SubstreamId}, PeerId, }; -use futures::StreamExt; -use multiaddr::{multihash::Multihash, Multiaddr, Protocol}; +use futures::{Stream, StreamExt}; +use indexmap::IndexMap; use str0m::{ - change::Fingerprint, - channel::{ChannelConfig, ChannelData, ChannelId}, - net::Receive, + channel::{ChannelConfig, ChannelId}, + net::{Protocol as Str0mProtocol, Receive}, Event, IceConnectionState, Input, Output, Rtc, }; use tokio::{net::UdpSocket, sync::mpsc::Receiver}; @@ -52,84 +46,117 @@ use tokio::{net::UdpSocket, sync::mpsc::Receiver}; use std::{ collections::HashMap, net::SocketAddr, + pin::Pin, sync::Arc, - time::{Duration, Instant}, + task::{Context, Poll}, + time::Instant, }; /// Logging target for the file. const LOG_TARGET: &str = "litep2p::webrtc::connection"; -/// Create Noise prologue. -fn noise_prologue_new(local_fingerprint: Vec, remote_fingerprint: Vec) -> Vec { - const PREFIX: &[u8] = b"libp2p-webrtc-noise:"; - let mut prologue = - Vec::with_capacity(PREFIX.len() + local_fingerprint.len() + remote_fingerprint.len()); - prologue.extend_from_slice(PREFIX); - prologue.extend_from_slice(&remote_fingerprint); - prologue.extend_from_slice(&local_fingerprint); +/// Channel context. +#[derive(Debug)] +struct ChannelContext { + /// Protocol name. + protocol: ProtocolName, + + /// Fallback names. + fallback_names: Vec, + + /// Substream ID. + substream_id: SubstreamId, - prologue + /// Permit which keeps the connection open. + permit: Permit, } -/// WebRTC connection state. -#[derive(Debug)] -enum State { - /// Connection state is poisoned. - Poisoned, +/// Set of [`SubstreamHandle`]s. +struct SubstreamHandleSet { + /// Current index. + index: usize, - /// Connection state is closed. - Closed, + /// Substream handles. + handles: IndexMap, +} - /// Connection state is opened. - Opened { - /// Noise handshaker. - handshaker: NoiseContext, - }, +impl SubstreamHandleSet { + /// Create new [`SubstreamHandleSet`]. + pub fn new() -> Self { + Self { + index: 0usize, + handles: IndexMap::new(), + } + } - /// Handshake has been sent - HandshakeSent { - /// Noise handshaker. - handshaker: NoiseContext, - }, + /// Get mutable access to `SubstreamHandle`. + pub fn get_mut(&mut self, key: &ChannelId) -> Option<&mut SubstreamHandle> { + self.handles.get_mut(key) + } - /// Connection is open. - Open { - /// Remote peer ID. - peer: PeerId, - }, + /// Insert new handle to [`SubstreamHandleSet`]. + pub fn insert(&mut self, key: ChannelId, handle: SubstreamHandle) { + assert!(self.handles.insert(key, handle).is_none()); + } + + /// Remove handle from [`SubstreamHandleSet`]. + pub fn remove(&mut self, key: &ChannelId) -> Option { + self.handles.shift_remove(key) + } } -/// Substream state. +impl Stream for SubstreamHandleSet { + type Item = (ChannelId, Option); + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let len = match self.handles.len() { + 0 => return Poll::Pending, + len => len, + }; + let start_index = self.index; + + loop { + let index = self.index % len; + self.index += 1; + + let (key, stream) = self.handles.get_index_mut(index).expect("handle to exist"); + match stream.poll_next_unpin(cx) { + Poll::Pending => {} + Poll::Ready(event) => return Poll::Ready(Some((*key, event))), + } + + if self.index == start_index + len { + break Poll::Pending; + } + } + } +} + +/// Channel state. #[derive(Debug)] -enum SubstreamState { - /// Substream state is poisoned. - Poisoned, +enum ChannelState { + /// Channel is closing. + Closing, - /// Substream (outbound) is opening. - Opening { - /// Protocol. - protocol: ProtocolName, + /// Inbound channel is opening. + InboundOpening, - /// Negotiated fallback. - fallback: Option, + /// Outbound channel is opening. + OutboundOpening { + /// Channel context. + context: ChannelContext, /// `multistream-select` dialer state. dialer_state: DialerState, - - /// Substream ID, - substream_id: SubstreamId, - - /// Connection permit. - permit: Permit, }, - /// Substream is open. + /// Channel is open. Open { /// Substream ID. substream_id: SubstreamId, - /// Substream. - substream: SubstreamContext, + /// Channel ID. + channel_id: ChannelId, /// Connection permit. permit: Permit, @@ -137,26 +164,19 @@ enum SubstreamState { } /// WebRTC connection. -// TODO: too much stuff, refactor? -pub(super) struct WebRtcConnection { - /// Connection ID. - pub(super) connection_id: ConnectionId, - +pub struct WebRtcConnection { /// `str0m` WebRTC object. - pub(super) rtc: Rtc, - - /// Noise channel ID. - _noise_channel_id: ChannelId, - - /// Identity keypair. - id_keypair: Keypair, - - /// Connection state. - state: State, + rtc: Rtc, /// Protocol set. protocol_set: ProtocolSet, + /// Remote peer ID. + peer: PeerId, + + /// Endpoint. + endpoint: Endpoint, + /// Peer address peer_address: SocketAddr, @@ -169,430 +189,423 @@ pub(super) struct WebRtcConnection { /// RX channel for receiving datagrams from the transport. dgram_rx: Receiver>, - /// Substream backend. - backend: SubstreamBackend, - - /// Next substream ID. - substream_id: SubstreamId, + /// Pending outbound channels. + pending_outbound: HashMap, - /// Pending outbound substreams. - pending_outbound: HashMap, SubstreamId, Permit)>, + /// Open channels. + channels: HashMap, - /// Open substreams. - substreams: HashMap, + /// Substream handles. + handles: SubstreamHandleSet, } impl WebRtcConnection { - pub(super) fn new( + /// Create new [`WebRtcConnection`]. + pub fn new( rtc: Rtc, - connection_id: ConnectionId, - _noise_channel_id: ChannelId, - id_keypair: Keypair, - protocol_set: ProtocolSet, + peer: PeerId, peer_address: SocketAddr, local_address: SocketAddr, socket: Arc, + protocol_set: ProtocolSet, + endpoint: Endpoint, dgram_rx: Receiver>, - ) -> WebRtcConnection { - WebRtcConnection { + ) -> Self { + Self { rtc, - socket, - dgram_rx, protocol_set, - id_keypair, + peer, peer_address, local_address, - connection_id, - _noise_channel_id, - state: State::Closed, - substreams: HashMap::new(), - backend: SubstreamBackend::new(), - substream_id: SubstreamId::new(), + socket, + endpoint, + dgram_rx, pending_outbound: HashMap::new(), + channels: HashMap::new(), + handles: SubstreamHandleSet::new(), } } - pub(super) async fn poll_output(&mut self) -> crate::Result { - match self.rtc.poll_output() { - Ok(output) => self.handle_output(output).await, - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - connection_id = ?self.connection_id, - ?error, - "`WebRtcConnection::poll_output()` failed", - ); - return Err(Error::WebRtc(error)); - } - } - } - - /// Handle data received from peer. - pub(super) async fn on_input(&mut self, buffer: Vec) -> crate::Result<()> { - let message = Input::Receive( - Instant::now(), - Receive { - source: self.peer_address, - destination: self.local_address, - contents: buffer.as_slice().try_into().map_err(|_| Error::InvalidData)?, - }, + /// Handle opened channel. + /// + /// If the channel is inbound, nothing is done because we have to wait for data + /// `multistream-select` handshake to be received from remote peer before anything + /// else can be done. + /// + /// If the channel is outbound, send `multistream-select` handshake to remote peer. + async fn on_channel_opened( + &mut self, + channel_id: ChannelId, + channel_name: String, + ) -> crate::Result<()> { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + ?channel_name, + "channel opened", ); - match self.rtc.accepts(&message) { - true => self.rtc.handle_input(message).map_err(|error| { - tracing::debug!(target: LOG_TARGET, source = ?self.peer_address, ?error, "failed to handle data"); - Error::InputRejected - }), - false => return Err(Error::InputRejected), - } - } + let Some(mut context) = self.pending_outbound.remove(&channel_id) else { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "inbound channel opened, wait for `multistream-select` message", + ); - async fn handle_output(&mut self, output: Output) -> crate::Result { - match output { - Output::Transmit(transmit) => { - self.socket - .send_to(&transmit.contents, transmit.destination) - .await - .expect("send to succeed"); - Ok(WebRtcEvent::Noop) - } - Output::Timeout(t) => Ok(WebRtcEvent::Timeout(t)), - Output::Event(e) => match e { - Event::IceConnectionStateChange(v) => { - if v == IceConnectionState::Disconnected { - tracing::debug!(target: LOG_TARGET, "ice connection closed"); - return Err(Error::Disconnected); - } - Ok(WebRtcEvent::Noop) - } - Event::ChannelOpen(cid, name) => { - // TODO: remove, report issue to smoldot - tokio::time::sleep(std::time::Duration::from_millis(500)).await; - self.on_channel_open(cid, name).map(|_| WebRtcEvent::Noop) - } - Event::ChannelData(data) => self.on_channel_data(data).await, - Event::ChannelClose(channel_id) => { - // TODO: notify the protocol - tracing::debug!(target: LOG_TARGET, ?channel_id, "channel closed"); - Ok(WebRtcEvent::Noop) - } - Event::Connected => { - match std::mem::replace(&mut self.state, State::Poisoned) { - State::Closed => { - let remote_fingerprint = self.remote_fingerprint(); - let local_fingerprint = self.local_fingerprint(); - - let handshaker = NoiseContext::with_prologue( - &self.id_keypair, - noise_prologue_new(local_fingerprint, remote_fingerprint), - ); - - self.state = State::Opened { handshaker }; - } - state => { - tracing::debug!( - target: LOG_TARGET, - ?state, - "invalid state for connection" - ); - return Err(Error::InvalidState); - } - } - Ok(WebRtcEvent::Noop) - } - event => { - tracing::warn!(target: LOG_TARGET, ?event, "unhandled event"); - Ok(WebRtcEvent::Noop) - } - }, - } - } - - /// Get remote fingerprint to bytes. - fn remote_fingerprint(&mut self) -> Vec { - let fingerprint = self - .rtc - .direct_api() - .remote_dtls_fingerprint() - .clone() - .expect("fingerprint to exist"); - Self::fingerprint_to_bytes(&fingerprint) - } - - /// Get local fingerprint as bytes. - fn local_fingerprint(&mut self) -> Vec { - Self::fingerprint_to_bytes(&self.rtc.direct_api().local_dtls_fingerprint()) - } - - /// Convert `Fingerprint` to bytes. - fn fingerprint_to_bytes(fingerprint: &Fingerprint) -> Vec { - const MULTIHASH_SHA256_CODE: u64 = 0x12; - Multihash::wrap(MULTIHASH_SHA256_CODE, &fingerprint.bytes) - .expect("fingerprint's len to be 32 bytes") - .to_bytes() - } - - fn on_noise_channel_open(&mut self) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, "send initial noise handshake"); - - let State::Opened { mut handshaker } = std::mem::replace(&mut self.state, State::Poisoned) - else { - return Err(Error::InvalidState); + self.channels.insert(channel_id, ChannelState::InboundOpening); + return Ok(()); }; - // create first noise handshake and send it to remote peer - let payload = WebRtcMessage::encode(handshaker.first_message(Role::Dialer), None); + let fallback_names = std::mem::replace(&mut context.fallback_names, Vec::new()); + let (dialer_state, message) = + DialerState::propose(context.protocol.clone(), fallback_names)?; + let message = WebRtcMessage::encode(message); self.rtc - .channel(self._noise_channel_id) + .channel(channel_id) .ok_or(Error::ChannelDoesntExist)? - .write(true, payload.as_slice()) + .write(true, message.as_ref()) .map_err(|error| Error::WebRtc(error))?; - self.state = State::HandshakeSent { handshaker }; + self.channels.insert( + channel_id, + ChannelState::OutboundOpening { + context, + dialer_state, + }, + ); + Ok(()) } - fn on_channel_open(&mut self, channel_id: ChannelId, name: String) -> crate::Result<()> { - tracing::debug!(target: LOG_TARGET, ?channel_id, channel_name = ?name, "channel opened"); + /// Handle closed channel. + async fn on_channel_closed(&mut self, channel_id: ChannelId) -> crate::Result<()> { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "channel closed", + ); - if channel_id == self._noise_channel_id { - return self.on_noise_channel_open(); - } - - match self.pending_outbound.remove(&channel_id) { - None => { - tracing::trace!(target: LOG_TARGET, ?channel_id, "remote opened a substream"); - } - Some((protocol, fallback_names, substream_id, permit)) => { - tracing::trace!(target: LOG_TARGET, ?channel_id, "dialer negotiate protocol"); - - let (dialer_state, message) = - DialerState::propose(protocol.clone(), fallback_names)?; - let message = WebRtcMessage::encode(message, None); - - self.rtc - .channel(channel_id) - .ok_or(Error::ChannelDoesntExist)? - .write(true, message.as_ref()) - .map_err(|error| Error::WebRtc(error))?; - - self.substreams.insert( - channel_id, - SubstreamState::Opening { - protocol, - fallback: None, - substream_id, - dialer_state, - permit, - }, - ); - } - } + self.pending_outbound.remove(&channel_id); + self.channels.remove(&channel_id); + self.handles.remove(&channel_id); Ok(()) } - async fn on_noise_channel_data(&mut self, data: Vec) -> crate::Result { - tracing::trace!(target: LOG_TARGET, "handle noise handshake reply"); + /// Handle data received to an opening inbound channel. + /// + /// The first message received over an inbound channel is the `multistream-select` handshake. + /// This handshake contains the protocol (and potentially fallbacks for that protocol) that + /// remote peer wants to use for this channel. Parse the handshake and check if any of the + /// proposed protocols are supported by the local node. If not, send rejection to remote peer + /// and close the channel. If the local node supports one of the protocols, send confirmation + /// for the protocol to remote peer and report an opened substream to the selected protocol. + async fn on_inbound_opening_channel_data( + &mut self, + channel_id: ChannelId, + data: Vec, + ) -> crate::Result<(SubstreamId, SubstreamHandle, Permit)> { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "handle opening inbound substream", + ); - let State::HandshakeSent { mut handshaker } = - std::mem::replace(&mut self.state, State::Poisoned) - else { - return Err(Error::InvalidState); - }; + let payload = WebRtcMessage::decode(&data)?.payload.ok_or(Error::InvalidData)?; + let (response, negotiated) = + match listener_negotiate(&mut self.protocol_set.protocols().iter(), payload.into())? { + ListenerSelectResult::Accepted { protocol, message } => (message, Some(protocol)), + ListenerSelectResult::Rejected { message } => (message, None), + }; - let message = WebRtcMessage::decode(&data)?.payload.ok_or(Error::InvalidData)?; - let public_key = handshaker.get_remote_public_key(&message)?; - let remote_peer_id = PeerId::from_public_key(&public_key); + self.rtc + .channel(channel_id) + .ok_or(Error::ChannelDoesntExist)? + .write(true, WebRtcMessage::encode(response.to_vec()).as_ref()) + .map_err(|error| Error::WebRtc(error))?; + + let protocol = negotiated.ok_or(Error::SubstreamDoesntExist)?; + let substream_id = self.protocol_set.next_substream_id(); + let codec = self.protocol_set.protocol_codec(&protocol); + let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; + let (substream, handle) = WebRtcSubstream::new(); + let substream = Substream::new_webrtc(self.peer, substream_id, substream, codec); tracing::trace!( target: LOG_TARGET, - ?remote_peer_id, - "remote reply parsed successfully" + peer = ?self.peer, + ?channel_id, + ?substream_id, + ?protocol, + "inbound substream opened", ); - // create second noise handshake message and send it to remote - let payload = WebRtcMessage::encode(handshaker.second_message(), None); + self.protocol_set + .report_substream_open(self.peer, protocol.clone(), Direction::Inbound, substream) + .await + .map(|_| (substream_id, handle, permit)) + } - let mut channel = - self.rtc.channel(self._noise_channel_id).ok_or(Error::ChannelDoesntExist)?; + /// Handle data received to an opening outbound channel. + /// + /// When an outbound channel is opened, the first message the local node sends it the + /// `multistream-select` handshake which contains the protocol (and any fallbacks for that + /// protocol) that the local node wants to use to negotiate for the channel. When a message is + /// received from a remote peer for a channel in state [`ChannelState::OutboundOpening`], parse + /// the `multistream-select` handshake response. The response either contains a rejection which + /// causes the substream to be closed, a partial response, or a full response. If a partial + /// response is heard, e.g., only the header line is received, the handshake cannot be concluded + /// and the channel is placed back in the [`ChannelState::OutboundOpening`] state to wait for + /// the rest of the handshake. If a full response is received (or rest of the partial response), + /// the protocol confirmation is verified and the substream is reported to the protocol. + /// + /// If the substream fails to open for whatever reason, since this is an outbound substream, + /// the protocol is notified of the failure. + async fn on_outbound_opening_channel_data( + &mut self, + channel_id: ChannelId, + data: Vec, + mut dialer_state: DialerState, + context: ChannelContext, + ) -> crate::Result> { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "handle opening outbound substream", + ); - channel.write(true, payload.as_slice()).map_err(|error| Error::WebRtc(error))?; + let message = WebRtcMessage::decode(&data)?.payload.ok_or(Error::InvalidData)?; + let HandshakeResult::Succeeded(protocol) = dialer_state.register_response(message)? else { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "multisteam-select handshake not ready", + ); + + self.channels.insert( + channel_id, + ChannelState::OutboundOpening { + context, + dialer_state, + }, + ); - let remote_fingerprint = self - .rtc - .direct_api() - .remote_dtls_fingerprint() - .clone() - .expect("fingerprint to exist") - .bytes; + return Ok(None); + }; - const MULTIHASH_SHA256_CODE: u64 = 0x12; - let certificate = Multihash::wrap(MULTIHASH_SHA256_CODE, &remote_fingerprint) - .expect("fingerprint's len to be 32 bytes"); + let ChannelContext { + substream_id, + permit, + .. + } = context; + let codec = self.protocol_set.protocol_codec(&protocol); + let (substream, handle) = WebRtcSubstream::new(); + let substream = Substream::new_webrtc(self.peer, substream_id, substream, codec); - let address = Multiaddr::empty() - .with(Protocol::from(self.peer_address.ip())) - .with(Protocol::Udp(self.peer_address.port())) - .with(Protocol::WebRTC) - .with(Protocol::Certhash(certificate)) - .with(Protocol::P2p(PeerId::from(public_key).into())); + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + ?substream_id, + ?protocol, + "outbound substream opened", + ); self.protocol_set - .report_connection_established( - remote_peer_id, - Endpoint::listener(address, self.connection_id), + .report_substream_open( + self.peer, + protocol.clone(), + Direction::Outbound(substream_id), + substream, ) - .await?; - - self.state = State::Open { - peer: remote_peer_id, - }; - - Ok(WebRtcEvent::Noop) + .await + .map(|_| Some((substream_id, handle, permit))) } - /// Report open substream to the protocol. - async fn report_open_substream( + /// Handle data received from an open channel. + async fn on_open_channel_data( &mut self, channel_id: ChannelId, - protocol: ProtocolName, - ) -> crate::Result { - // let substream_id = self.substream_id.next(); - // let (mut substream, tx) = self.backend.substream(channel_id); - // let substream: Box = { - // substream.apply_codec(self.protocol_set.protocol_codec(&protocol)); - // Box::new(substream) - // }; - // let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; - - // self.substreams.insert( - // channel_id, - // SubstreamState::Open { - // substream_id, - // substream: SubstreamContext::new(channel_id, tx), - // permit, - // }, - // ); - // TODO: fix - - if let State::Open { peer, .. } = &mut self.state { - // let _ = self - // .protocol_set - // .report_substream_open(*peer, protocol.clone(), Direction::Inbound, substream) - // .await; - todo!(); - } - - Ok(WebRtcEvent::Noop) - } - - /// Negotiate protocol for the channel - async fn listener_negotiate_protocol(&mut self, d: ChannelData) -> crate::Result { - tracing::trace!(target: LOG_TARGET, channel_id = ?d.id, "negotiate protocol for the channel"); - - let payload = WebRtcMessage::decode(&d.data)?.payload.ok_or(Error::InvalidData)?; + data: Vec, + ) -> crate::Result<()> { + let message = WebRtcMessage::decode(&data)?; - let (protocol, response) = - listener_negotiate(&mut self.protocol_set.protocols().iter(), payload.into())?; - - let message = WebRtcMessage::encode(response.to_vec(), None); - - self.rtc - .channel(d.id) - .ok_or(Error::ChannelDoesntExist)? - .write(true, message.as_ref()) - .map_err(|error| Error::WebRtc(error))?; + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + flags = message.flags, + data_len = message.payload.as_ref().map_or(0usize, |payload| payload.len()), + "handle inbound message", + ); - self.report_open_substream(d.id, protocol).await - - // let substream_id = self.substream_id.next(); - // let (mut substream, tx) = self.backend.substream(d.id); - // let substream: Box = { - // substream.apply_codec(self.protocol_set.protocol_codec(&protocol)); - // Box::new(substream) - // }; - // let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; - - // self.substreams.insert( - // d.id, - // SubstreamState::Open { - // substream_id, - // substream: SubstreamContext::new(d.id, tx), - // permit, - // }, - // ); - - // if let State::Open { peer, .. } = &mut self.state { - // let _ = self - // .protocol_set - // .report_substream_open(*peer, protocol.clone(), Direction::Inbound, substream) - // .await; - // } - // Ok(WebRtcEvent::Noop) + self.handles + .get_mut(&channel_id) + .ok_or_else(|| { + tracing::warn!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "data received from an unknown channel", + ); + debug_assert!(false); + Error::InvalidState + })? + .on_message(message) + .await } - async fn on_channel_data(&mut self, d: ChannelData) -> crate::Result { - match &self.state { - State::HandshakeSent { .. } => self.on_noise_channel_data(d.data).await, - State::Open { .. } => { - match self.substreams.get_mut(&d.id) { - None => match self.listener_negotiate_protocol(d).await { - Ok(_) => { - tracing::debug!(target: LOG_TARGET, "protocol negotiated for the channel"); - - Ok(WebRtcEvent::Noop) - } - Err(error) => { - tracing::debug!(target: LOG_TARGET, ?error, "failed to negotiate protocol"); + /// Handle data received from a channel. + async fn on_inbound_data(&mut self, channel_id: ChannelId, data: Vec) -> crate::Result<()> { + let Some(state) = self.channels.remove(&channel_id) else { + tracing::warn!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "data received over a channel that doesn't exist", + ); + debug_assert!(false); + return Err(Error::InvalidState); + }; - // TODO: close channel - Ok(WebRtcEvent::Noop) - } - }, - Some(SubstreamState::Poisoned) => return Err(Error::ConnectionClosed), - Some(SubstreamState::Opening { - ref mut dialer_state, - .. - }) => { - tracing::info!(target: LOG_TARGET, "try to decode message"); - let message = - WebRtcMessage::decode(&d.data)?.payload.ok_or(Error::InvalidData)?; - tracing::info!(target: LOG_TARGET, "decoded successfully"); - - match dialer_state.register_response(message) { - Ok(HandshakeResult::NotReady) => {} - Ok(HandshakeResult::Succeeded(protocol)) => { - tracing::warn!(target: LOG_TARGET, ?protocol, "protocol negotiated, inform protocol handler"); - - return self.report_open_substream(d.id, protocol).await; - } - Err(error) => { - tracing::error!(target: LOG_TARGET, ?error, "failed to negotiate protocol"); - // TODO: close channel - } - } + match state { + ChannelState::InboundOpening => + match self.on_inbound_opening_channel_data(channel_id, data).await { + Ok((substream_id, handle, permit)) => { + self.handles.insert(channel_id, handle); + self.channels.insert( + channel_id, + ChannelState::Open { + substream_id, + channel_id, + permit, + }, + ); + } + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + ?error, + "failed to handle opening inbound substream", + ); - Ok(WebRtcEvent::Noop) + self.channels.insert(channel_id, ChannelState::Closing); + self.rtc.direct_api().close_data_channel(channel_id); + } + }, + ChannelState::OutboundOpening { + context, + dialer_state, + } => { + let protocol = context.protocol.clone(); + let substream_id = context.substream_id; + + match self + .on_outbound_opening_channel_data(channel_id, data, dialer_state, context) + .await + { + Ok(Some((substream_id, handle, permit))) => { + self.handles.insert(channel_id, handle); + self.channels.insert( + channel_id, + ChannelState::Open { + substream_id, + channel_id, + permit, + }, + ); } - Some(SubstreamState::Open { substream, .. }) => { - // TODO: might be empty message with flags - // TODO: if decoding fails, close the substream - let message = - WebRtcMessage::decode(&d.data)?.payload.ok_or(Error::InvalidData)?; - let _ = substream.tx.send(message).await; - - Ok(WebRtcEvent::Noop) + Ok(None) => {} + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + ?error, + "failed to handle opening outbound substream", + ); + + let _ = self + .protocol_set + .report_substream_open_failure(protocol, substream_id, error) + .await; + + self.rtc.direct_api().close_data_channel(channel_id); + self.channels.insert(channel_id, ChannelState::Closing); } } } - _ => Err(Error::InvalidState), + ChannelState::Open { + substream_id, + channel_id, + permit, + } => match self.on_open_channel_data(channel_id, data).await { + Ok(()) => { + self.channels.insert( + channel_id, + ChannelState::Open { + substream_id, + channel_id, + permit, + }, + ); + } + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + ?error, + "failed to handle data for an open channel", + ); + + self.rtc.direct_api().close_data_channel(channel_id); + self.channels.insert(channel_id, ChannelState::Closing); + } + }, + ChannelState::Closing => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "channel closing, discarding received data", + ); + self.channels.insert(channel_id, ChannelState::Closing); + } } + + Ok(()) + } + + /// Handle outbound data. + fn on_outbound_data(&mut self, channel_id: ChannelId, data: Vec) -> crate::Result<()> { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + data_len = ?data.len(), + "send data", + ); + + self.rtc + .channel(channel_id) + .ok_or(Error::ChannelDoesntExist)? + .write(true, WebRtcMessage::encode(data).as_ref()) + .map_err(|error| Error::WebRtc(error)) + .map(|_| ()) } /// Open outbound substream. - fn open_substream( + fn on_open_substream( &mut self, protocol: ProtocolName, fallback_names: Vec, @@ -600,7 +613,7 @@ impl WebRtcConnection { permit: Permit, ) { let channel_id = self.rtc.direct_api().create_data_channel(ChannelConfig { - label: protocol.to_string(), + label: "".to_string(), ordered: false, reliability: Default::default(), negotiated: None, @@ -609,114 +622,201 @@ impl WebRtcConnection { tracing::trace!( target: LOG_TARGET, + peer = ?self.peer, ?channel_id, ?substream_id, ?protocol, ?fallback_names, - "open data channel" + "open data channel", ); - self.pending_outbound - .insert(channel_id, (protocol, fallback_names, substream_id, permit)); + self.pending_outbound.insert( + channel_id, + ChannelContext { + protocol, + fallback_names, + substream_id, + permit, + }, + ); } - /// Run the event loop of a negotiated WebRTC connection. - pub(super) async fn run(mut self) -> crate::Result<()> { - loop { - if !self.rtc.is_alive() { - tracing::debug!( - target: LOG_TARGET, - "`Rtc` is not alive, closing `WebRtcConnection`" - ); - return Ok(()); - } + /// Connection to peer has been closed. + async fn on_connection_closed(&mut self) { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + "connection closed", + ); - let duration = match self.poll_output().await { - Ok(WebRtcEvent::Timeout(timeout)) => { - let timeout = - std::cmp::min(timeout, Instant::now() + Duration::from_millis(100)); - (timeout - Instant::now()).max(Duration::from_millis(1)) - } - Ok(WebRtcEvent::Noop) => continue, - Err(error) => { - tracing::debug!( + let _ = self + .protocol_set + .report_connection_closed(self.peer, self.endpoint.connection_id()) + .await; + } + + /// Start running event loop of [`WebRtcConnection`]. + pub async fn run(mut self) { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + "start webrtc connection event loop", + ); + + let _ = self + .protocol_set + .report_connection_established(self.peer, self.endpoint.clone()) + .await; + + loop { + // poll output until we get a timeout + let timeout = match self.rtc.poll_output().unwrap() { + Output::Timeout(v) => v, + Output::Transmit(v) => { + tracing::trace!( target: LOG_TARGET, - ?error, - "error occurred, closing connection" + peer = ?self.peer, + datagram_len = ?v.contents.len(), + "transmit data", ); - self.rtc.disconnect(); - return Ok(()); + + self.socket.try_send_to(&v.contents, v.destination).unwrap(); + continue; } + Output::Event(v) => match v { + Event::IceConnectionStateChange(IceConnectionState::Disconnected) => { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + "ice connection state changed to closed", + ); + return self.on_connection_closed().await; + } + Event::ChannelOpen(channel_id, name) => { + if let Err(error) = self.on_channel_opened(channel_id, name).await { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + ?error, + "failed to handle opened channel", + ); + } + + continue; + } + Event::ChannelClose(channel_id) => { + if let Err(error) = self.on_channel_closed(channel_id).await { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + ?error, + "failed to handle closed channel", + ); + } + + continue; + } + Event::ChannelData(info) => { + if let Err(error) = self.on_inbound_data(info.id, info.data).await { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + channel_id = ?info.id, + ?error, + "failed to handle channel data", + ); + } + + continue; + } + event => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?event, + "unhandled event", + ); + continue; + } + }, }; + let duration = timeout - Instant::now(); + if duration.is_zero() { + self.rtc.handle_input(Input::Timeout(Instant::now())).unwrap(); + continue; + } + tokio::select! { - message = self.dgram_rx.recv() => match message { - Some(message) => match self.on_input(message).await { - Ok(_) | Err(Error::InputRejected) => {}, - Err(error) => { - tracing::debug!(target: LOG_TARGET, ?error, "failed to handle input"); - return Err(error) - } + biased; + datagram = self.dgram_rx.recv() => match datagram { + Some(datagram) => { + let input = Input::Receive( + Instant::now(), + Receive { + proto: Str0mProtocol::Udp, + source: self.peer_address, + destination: self.local_address, + contents: datagram.as_slice().try_into().unwrap(), + }, + ); + + self.rtc.handle_input(input).unwrap(); } None => { - tracing::debug!( + tracing::trace!( target: LOG_TARGET, - source = ?self.peer_address, - "transport shut down, shutting down connection", + peer = ?self.peer, + "read `None` from `dgram_rx`", ); - return Ok(()); + return self.on_connection_closed().await; } }, - event = self.backend.next_event() => { - let (channel_id, message) = event.ok_or(Error::EssentialTaskClosed)?; + event = self.handles.next() => match event { + None => unreachable!(), + Some((channel_id, None | Some(SubstreamEvent::Close))) => { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "channel closed", + ); - match self.substreams.get_mut(&channel_id) { - None => { - tracing::debug!(target: LOG_TARGET, "protocol tried to send message over substream that doesn't exist"); - } - Some(SubstreamState::Poisoned) => {}, - Some(SubstreamState::Opening { .. }) => { - tracing::debug!(target: LOG_TARGET, "protocol tried to send message over substream that isn't open"); - } - Some(SubstreamState::Open { .. }) => { - tracing::trace!(target: LOG_TARGET, ?channel_id, ?message, "send message to remote peer"); - - self.rtc - .channel(channel_id) - .ok_or(Error::ChannelDoesntExist)? - .write(true, message.as_ref()) - .map_err(|error| Error::WebRtc(error))?; - } + self.rtc.direct_api().close_data_channel(channel_id); + self.channels.insert(channel_id, ChannelState::Closing); + self.handles.remove(&channel_id); } - } - event = self.protocol_set.next() => match event { - Some(event) => match event { - ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit } => { - self.open_substream(protocol, fallback_names, substream_id, permit); - } - ProtocolCommand::ForceClose => { - tracing::debug!(target: LOG_TARGET, "force closing connection"); - return Ok(()); + Some((channel_id, Some(SubstreamEvent::Message(data)))) => { + if let Err(error) = self.on_outbound_data(channel_id, data) { + tracing::debug!( + target: LOG_TARGET, + ?channel_id, + ?error, + "failed to send data to remote peer", + ); } } - None => { - tracing::debug!(target: LOG_TARGET, "handle to protocol closed, closing connection"); - return Ok(()); + Some((_, Some(SubstreamEvent::RecvClosed))) => {} + }, + command = self.protocol_set.next() => match command { + None | Some(ProtocolCommand::ForceClose) => { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?command, + "`ProtocolSet` instructed to close connection", + ); + return self.on_connection_closed().await; + } + Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit }) => { + self.on_open_substream(protocol, fallback_names, substream_id, permit); } }, - _ = tokio::time::sleep(duration) => {} - } - - // drive time forward in the client - if let Err(error) = self.rtc.handle_input(Input::Timeout(Instant::now())) { - tracing::debug!( - target: LOG_TARGET, - ?error, - "failed to handle timeout for `Rtc`" - ); - - self.rtc.disconnect(); - return Err(Error::Disconnected); + _ = tokio::time::sleep(duration) => { + self.rtc.handle_input(Input::Timeout(Instant::now())).unwrap(); + } } } } diff --git a/src/transport/webrtc/mod.rs b/src/transport/webrtc/mod.rs index 5ee3616fe..b0192d651 100644 --- a/src/transport/webrtc/mod.rs +++ b/src/transport/webrtc/mod.rs @@ -20,50 +20,52 @@ //! WebRTC transport. -#![allow(unused)] - use crate::{ error::{AddressError, Error}, transport::{ manager::TransportHandle, - webrtc::{config::Config, connection::WebRtcConnection}, - Transport, TransportBuilder, TransportEvent, + webrtc::{config::Config, connection::WebRtcConnection, opening::OpeningWebRtcConnection}, + Endpoint, Transport, TransportBuilder, TransportEvent, }, types::ConnectionId, PeerId, }; -use futures::{Stream, StreamExt}; +use futures::{future::BoxFuture, Future, Stream}; +use futures_timer::Delay; use multiaddr::{multihash::Multihash, Multiaddr, Protocol}; use socket2::{Domain, Socket, Type}; use str0m::{ change::{DtlsCert, IceCreds}, channel::{ChannelConfig, ChannelId}, - net::{DatagramRecv, Receive}, + net::{DatagramRecv, Protocol as Str0mProtocol, Receive}, Candidate, Input, Rtc, }; use tokio::{ io::ReadBuf, net::UdpSocket, - sync::mpsc::{channel, Sender}, + sync::mpsc::{channel, error::TrySendError, Sender}, }; use std::{ - collections::HashMap, + collections::{HashMap, VecDeque}, net::{IpAddr, SocketAddr}, pin::Pin, sync::Arc, task::{Context, Poll}, - time::Instant, + time::{Duration, Instant}, }; -pub mod config; +pub(crate) use substream::Substream; mod connection; +mod opening; mod substream; mod util; -mod schema { +pub mod config; + +pub(super) mod schema { pub(super) mod webrtc { include!(concat!(env!("OUT_DIR"), "/webrtc.rs")); } @@ -80,6 +82,40 @@ const LOG_TARGET: &str = "litep2p::webrtc"; const REMOTE_FINGERPRINT: &str = "sha-256 FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF:FF"; +/// Connection context. +struct ConnectionContext { + /// Remote peer ID. + peer: PeerId, + + /// Connection ID. + connection_id: ConnectionId, + + /// TX channel for sending datagrams to the connection event loop. + tx: Sender>, +} + +/// Events received from opening connections that are handled +/// by the [`WebRtcTransport`] event loop. +enum ConnectionEvent { + /// Connection established. + ConnectionEstablished { + /// Remote peer ID. + peer: PeerId, + + /// Endpoint. + endpoint: Endpoint, + }, + + /// Connection to peer closed. + ConnectionClosed, + + /// Timeout. + Timeout { + /// Timeout duration. + duration: Duration, + }, +} + /// WebRTC transport. pub(crate) struct WebRtcTransport { /// Transport context. @@ -94,8 +130,23 @@ pub(crate) struct WebRtcTransport { /// Assigned listen addresss. listen_address: SocketAddr, + /// Datagram buffer size. + datagram_buffer_size: usize, + /// Connected peers. - peers: HashMap>>, + open: HashMap, + + /// OpeningWebRtc connections. + opening: HashMap, + + /// `ConnectionId -> SocketAddr` mappings. + connections: HashMap, + + /// Pending timeouts. + timeouts: HashMap>, + + /// Pending events. + pending_events: VecDeque, } impl WebRtcTransport { @@ -174,8 +225,8 @@ impl WebRtcTransport { .set_dtls_cert(self.dtls_cert.clone()) .set_fingerprint_verification(false) .build(); - rtc.add_local_candidate(Candidate::host(destination).unwrap()); - rtc.add_remote_candidate(Candidate::host(source).unwrap()); + rtc.add_local_candidate(Candidate::host(destination, Str0mProtocol::Udp).unwrap()); + rtc.add_remote_candidate(Candidate::host(source, Str0mProtocol::Udp).unwrap()); rtc.direct_api() .set_remote_fingerprint(REMOTE_FINGERPRINT.parse().expect("parse() to succeed")); rtc.direct_api().set_remote_ice_credentials(IceCreds { @@ -201,18 +252,88 @@ impl WebRtcTransport { (rtc, noise_channel_id) } + /// Poll opening connection. + fn poll_connection(&mut self, source: &SocketAddr) -> ConnectionEvent { + let Some(connection) = self.opening.get_mut(source) else { + tracing::warn!( + target: LOG_TARGET, + ?source, + "connection doesn't exist", + ); + return ConnectionEvent::ConnectionClosed; + }; + + loop { + match connection.poll_process() { + opening::WebRtcEvent::Timeout { timeout } => { + let duration = timeout - Instant::now(); + + match duration.is_zero() { + true => match connection.on_timeout() { + Ok(()) => continue, + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?source, + ?error, + "failed to handle timeout", + ); + + return ConnectionEvent::ConnectionClosed; + } + }, + false => return ConnectionEvent::Timeout { duration }, + } + } + opening::WebRtcEvent::Transmit { + destination, + datagram, + } => + if let Err(error) = self.socket.try_send_to(&datagram, destination) { + tracing::warn!( + target: LOG_TARGET, + ?source, + ?error, + "failed to send datagram", + ); + }, + opening::WebRtcEvent::ConnectionClosed => return ConnectionEvent::ConnectionClosed, + opening::WebRtcEvent::ConnectionOpened { peer, endpoint } => { + return ConnectionEvent::ConnectionEstablished { peer, endpoint }; + } + } + } + } + /// Handle socket input. - fn on_socket_input(&mut self, source: SocketAddr, buffer: Vec) -> crate::Result<()> { - // if the `Rtc` object already exists for `souce`, pass the message directly to that - // connection. - if let Some(tx) = self.peers.get_mut(&source) { - // TODO: implement properly + /// + /// If the datagram was received from an active client, it's dispatched to the connection + /// handler, if there is space in the queue. If the datagram opened a new connection or it + /// belonged to a client who is opening, the event loop is instructed to poll the client + /// until it timeouts. + /// + /// Returns `true` if the client should be polled. + fn on_socket_input(&mut self, source: SocketAddr, buffer: Vec) -> crate::Result { + if let Some(ConnectionContext { + peer, + connection_id, + tx, + }) = self.open.get_mut(&source) + { match tx.try_send(buffer) { - Ok(()) => return Ok(()), - Err(error) => { - tracing::warn!(target: LOG_TARGET, ?error, "failed to send datagram to connection"); - return Ok(()); + Ok(_) => return Ok(false), + Err(TrySendError::Full(_)) => { + tracing::warn!( + target: LOG_TARGET, + ?source, + ?peer, + ?connection_id, + "channel full, dropping datagram", + ); + + return Ok(false); } + Err(TrySendError::Closed(_)) => return Ok(false), } } @@ -222,7 +343,7 @@ impl WebRtcTransport { buffer.as_slice().try_into().map_err(|_| Error::InvalidData)?; match contents { - DatagramRecv::Stun(message) => { + DatagramRecv::Stun(message) if !self.opening.contains_key(&source) => { if let Some((ufrag, pass)) = message.split_username() { tracing::debug!( target: LOG_TARGET, @@ -244,44 +365,38 @@ impl WebRtcTransport { Instant::now(), Receive { source, + proto: Str0mProtocol::Udp, destination: self.socket.local_addr().unwrap(), contents: DatagramRecv::Stun(message.clone()), }, )) .expect("client to handle input successfully"); - let (tx, rx) = channel(64); let connection_id = self.context.next_connection_id(); - - let connection = WebRtcConnection::new( + let connection = OpeningWebRtcConnection::new( rtc, connection_id, noise_channel_id, self.context.keypair.clone(), - self.context.protocol_set(connection_id), source, self.listen_address, - Arc::clone(&self.socket), - rx, ); - - self.context.executor.run(Box::pin(async move { - let _ = connection.run().await; - })); - self.peers.insert(source, tx); + self.opening.insert(source, connection); } } - message => { - tracing::error!( - target: LOG_TARGET, - ?source, - ?message, - "received unexpected message for a connection that doesn't eixst" - ); + msg => { + if let Err(error) = self.opening.get_mut(&source).expect("to exist").on_input(msg) { + tracing::error!( + target: LOG_TARGET, + ?error, + ?source, + "failed to handle inbound datagram" + ); + } } } - Ok(()) + Ok(true) } } @@ -303,18 +418,18 @@ impl TransportBuilder for WebRtcTransport { let (listen_address, _) = Self::get_socket_address(&config.listen_addresses[0])?; let socket = match listen_address.is_ipv4() { true => { - let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(socket2::Protocol::UDP))?; + let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(socket2::Protocol::UDP))?; socket.bind(&listen_address.into())?; socket } false => { - let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(socket2::Protocol::UDP))?; + let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(socket2::Protocol::UDP))?; socket.set_only_v6(true)?; socket.bind(&listen_address.into())?; socket } }; - socket.listen(1024)?; + socket.set_reuse_address(true)?; socket.set_nonblocking(true)?; #[cfg(unix)] @@ -343,8 +458,13 @@ impl TransportBuilder for WebRtcTransport { context, dtls_cert, listen_address, - peers: HashMap::new(), + open: HashMap::new(), + opening: HashMap::new(), + connections: HashMap::new(), socket: Arc::new(socket), + timeouts: HashMap::new(), + pending_events: VecDeque::new(), + datagram_buffer_size: config.datagram_buffer_size, }, listen_multi_addresses, )) @@ -363,12 +483,94 @@ impl Transport for WebRtcTransport { Err(Error::NotSupported(format!("webrtc cannot dial peers"))) } - fn accept(&mut self, _connection_id: ConnectionId) -> crate::Result<()> { + fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + "inbound connection accepted", + ); + + let (peer, source, endpoint) = + self.connections.remove(&connection_id).ok_or_else(|| { + tracing::warn!( + target: LOG_TARGET, + ?connection_id, + "pending connection doens't exist", + ); + + Error::InvalidState + })?; + + let connection = self.opening.remove(&source).ok_or_else(|| { + tracing::warn!( + target: LOG_TARGET, + ?connection_id, + "pending connection doens't exist", + ); + + Error::InvalidState + })?; + + let rtc = connection.on_accept()?; + let (tx, rx) = channel(self.datagram_buffer_size); + let protocol_set = self.context.protocol_set(connection_id); + let connection_id = endpoint.connection_id(); + + let connection = WebRtcConnection::new( + rtc, + peer, + source, + self.listen_address, + Arc::clone(&self.socket), + protocol_set, + endpoint, + rx, + ); + self.open.insert( + source, + ConnectionContext { + tx, + peer, + connection_id, + }, + ); + + self.context.executor.run(Box::pin(async move { + connection.run().await; + })); + Ok(()) } - fn reject(&mut self, _connection_id: ConnectionId) -> crate::Result<()> { - Ok(()) + fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + "inbound connection rejected", + ); + + let (_, source, _) = self.connections.remove(&connection_id).ok_or_else(|| { + tracing::warn!( + target: LOG_TARGET, + ?connection_id, + "pending connection doens't exist", + ); + + Error::InvalidState + })?; + + self.opening + .remove(&source) + .ok_or_else(|| { + tracing::warn!( + target: LOG_TARGET, + ?connection_id, + "pending connection doens't exist", + ); + + Error::InvalidState + }) + .map(|_| ()) } fn open( @@ -383,51 +585,130 @@ impl Transport for WebRtcTransport { Ok(()) } - /// Cancel opening connections. fn cancel(&mut self, _connection_id: ConnectionId) {} } impl Stream for WebRtcTransport { type Item = TransportEvent; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // TODO: optimizations - let mut buf = vec![0u8; 16384]; - let mut read_buf = ReadBuf::new(&mut buf); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); - match self.socket.poll_recv_from(cx, &mut read_buf) { - Poll::Pending => {} - Poll::Ready(Ok(source)) => { - let nread = read_buf.filled().len(); - buf.truncate(nread); + if let Some(event) = this.pending_events.pop_front() { + return Poll::Ready(Some(event)); + } - if let Err(error) = self.on_socket_input(source, buf) { - tracing::error!(target: LOG_TARGET, ?error, "failed to handle input"); - } - } - Poll::Ready(Err(error)) => { - tracing::debug!( - target: LOG_TARGET, - ?error, - "failed to read from webrtc socket", - ); + loop { + let mut buf = vec![0u8; 16384]; + let mut read_buf = ReadBuf::new(&mut buf); - return Poll::Ready(None); + match this.socket.poll_recv_from(cx, &mut read_buf) { + Poll::Pending => break, + Poll::Ready(Err(error)) => { + tracing::info!( + target: LOG_TARGET, + ?error, + "webrtc udp socket closed", + ); + + return Poll::Ready(None); + } + Poll::Ready(Ok(source)) => { + let nread = read_buf.filled().len(); + buf.truncate(nread); + + match this.on_socket_input(source, buf) { + Ok(false) => {} + Ok(true) => loop { + match this.poll_connection(&source) { + ConnectionEvent::ConnectionEstablished { peer, endpoint } => { + this.connections.insert( + endpoint.connection_id(), + (peer, source, endpoint.clone()), + ); + + // keep polling the connection until it registers a timeout + this.pending_events.push_back( + TransportEvent::ConnectionEstablished { peer, endpoint }, + ); + } + ConnectionEvent::ConnectionClosed => { + this.opening.remove(&source); + this.timeouts.remove(&source); + + break; + } + ConnectionEvent::Timeout { duration } => { + this.timeouts.insert( + source, + Box::pin(async move { Delay::new(duration).await }), + ); + + break; + } + } + }, + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?source, + ?error, + "failed to handle datagram", + ); + } + } + } } } - Poll::Pending - } -} + // go over all pending timeouts to see if any of them have expired + // and if any of them have, poll the connection until it registers another timeout + let pending_events = this + .timeouts + .iter_mut() + .filter_map(|(source, mut delay)| match Pin::new(&mut delay).poll(cx) { + Poll::Pending => None, + Poll::Ready(_) => Some(*source), + }) + .collect::>() + .into_iter() + .filter_map(|source| { + let mut pending_event = None; + + loop { + match this.poll_connection(&source) { + ConnectionEvent::ConnectionEstablished { peer, endpoint } => { + this.connections + .insert(endpoint.connection_id(), (peer, source, endpoint.clone())); + + // keep polling the connection until it registers a timeout + pending_event = + Some(TransportEvent::ConnectionEstablished { peer, endpoint }); + } + ConnectionEvent::ConnectionClosed => { + this.opening.remove(&source); + return None; + } + ConnectionEvent::Timeout { duration } => { + this.timeouts.insert( + source, + Box::pin(async move { + Delay::new(duration); + }), + ); + break; + } + } + } -// TODO: remove -/// Events propagated between client. -#[allow(clippy::large_enum_variant)] -#[derive(Debug)] -enum WebRtcEvent { - /// When we have nothing to propagate. - Noop, + return pending_event; + }) + .collect::>(); - /// Poll client has reached timeout. - Timeout(Instant), + this.timeouts.retain(|source, _| this.opening.contains_key(source)); + this.pending_events.extend(pending_events); + this.pending_events + .pop_front() + .map_or(Poll::Pending, |event| Poll::Ready(Some(event))) + } } diff --git a/src/transport/webrtc/opening.rs b/src/transport/webrtc/opening.rs new file mode 100644 index 000000000..e4e39ea9c --- /dev/null +++ b/src/transport/webrtc/opening.rs @@ -0,0 +1,473 @@ +// Copyright 2023-2024 litep2p developers +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! WebRTC handshaking code for an opening connection. + +use crate::{ + config::Role, + crypto::{ed25519::Keypair, noise::NoiseContext}, + transport::{webrtc::util::WebRtcMessage, Endpoint}, + types::ConnectionId, + Error, PeerId, +}; + +use multiaddr::{multihash::Multihash, Multiaddr, Protocol}; +use str0m::{ + change::Fingerprint, + channel::ChannelId, + net::{DatagramRecv, DatagramSend, Protocol as Str0mProtocol, Receive}, + Event, IceConnectionState, Input, Output, Rtc, +}; + +use std::{net::SocketAddr, time::Instant}; + +/// Logging target for the file. +const LOG_TARGET: &str = "litep2p::webrtc::connection"; + +/// Create Noise prologue. +fn noise_prologue(local_fingerprint: Vec, remote_fingerprint: Vec) -> Vec { + const PREFIX: &[u8] = b"libp2p-webrtc-noise:"; + let mut prologue = + Vec::with_capacity(PREFIX.len() + local_fingerprint.len() + remote_fingerprint.len()); + prologue.extend_from_slice(PREFIX); + prologue.extend_from_slice(&remote_fingerprint); + prologue.extend_from_slice(&local_fingerprint); + + prologue +} + +/// WebRTC connection event. +#[derive(Debug)] +pub enum WebRtcEvent { + /// Register timeout for the connection. + Timeout { + /// Timeout. + timeout: Instant, + }, + + /// Transmit data to remote peer. + Transmit { + /// Destination. + destination: SocketAddr, + + /// Datagram to transmit. + datagram: DatagramSend, + }, + + /// Connection closed. + ConnectionClosed, + + /// Connection established. + ConnectionOpened { + /// Remote peer ID. + peer: PeerId, + + /// Endpoint. + endpoint: Endpoint, + }, +} + +/// Opening WebRTC connection. +/// +/// This object is used to track an opening connection which starts with a Noise handshake. +/// After the handshake is done, this object is destroyed and a new WebRTC connection object +/// is created which implements a normal connection event loop dealing with substreams. +pub struct OpeningWebRtcConnection { + /// WebRTC object + rtc: Rtc, + + /// Connection state. + state: State, + + /// Connection ID. + connection_id: ConnectionId, + + /// Noise channel ID. + noise_channel_id: ChannelId, + + /// Local keypair. + id_keypair: Keypair, + + /// Peer address + peer_address: SocketAddr, + + /// Local address. + local_address: SocketAddr, +} + +/// Connection state. +#[derive(Debug)] +enum State { + /// Connection is poisoned. + Poisoned, + + /// Connection is closed. + Closed, + + /// Connection has been opened. + Opened { + /// Noise context. + context: NoiseContext, + }, + + /// Local Noise handshake has been sent to peer and the connection + /// is waiting for an answer. + HandshakeSent { + /// Noise context. + context: NoiseContext, + }, + + /// Response to local Noise handshake has been received and the connection + /// is being validated by `TransportManager`. + Validating { + /// Noise context. + context: NoiseContext, + }, +} + +impl OpeningWebRtcConnection { + /// Create new [`OpeningWebRtcConnection`]. + pub fn new( + rtc: Rtc, + connection_id: ConnectionId, + noise_channel_id: ChannelId, + id_keypair: Keypair, + peer_address: SocketAddr, + local_address: SocketAddr, + ) -> OpeningWebRtcConnection { + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + ?peer_address, + "new connection opened", + ); + + Self { + rtc, + state: State::Closed, + connection_id, + noise_channel_id, + id_keypair, + peer_address, + local_address, + } + } + + /// Get remote fingerprint to bytes. + fn remote_fingerprint(&mut self) -> Vec { + let fingerprint = self + .rtc + .direct_api() + .remote_dtls_fingerprint() + .clone() + .expect("fingerprint to exist"); + Self::fingerprint_to_bytes(&fingerprint) + } + + /// Get local fingerprint as bytes. + fn local_fingerprint(&mut self) -> Vec { + Self::fingerprint_to_bytes(&self.rtc.direct_api().local_dtls_fingerprint()) + } + + /// Convert `Fingerprint` to bytes. + fn fingerprint_to_bytes(fingerprint: &Fingerprint) -> Vec { + const MULTIHASH_SHA256_CODE: u64 = 0x12; + Multihash::wrap(MULTIHASH_SHA256_CODE, &fingerprint.bytes) + .expect("fingerprint's len to be 32 bytes") + .to_bytes() + } + + /// Once a Noise data channel has been opened, even though the light client was the dialer, + /// the WebRTC server will act as the dialer as per the specification. + /// + /// Create the first Noise handshake message and send it to remote peer. + fn on_noise_channel_open(&mut self) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, "send initial noise handshake"); + + let State::Opened { mut context } = std::mem::replace(&mut self.state, State::Poisoned) + else { + return Err(Error::InvalidState); + }; + + // create first noise handshake and send it to remote peer + let payload = WebRtcMessage::encode(context.first_message(Role::Dialer)); + + self.rtc + .channel(self.noise_channel_id) + .ok_or(Error::ChannelDoesntExist)? + .write(true, payload.as_slice()) + .map_err(|error| Error::WebRtc(error))?; + + self.state = State::HandshakeSent { context }; + Ok(()) + } + + /// Handle timeout. + pub fn on_timeout(&mut self) -> crate::Result<()> { + if let Err(error) = self.rtc.handle_input(Input::Timeout(Instant::now())) { + tracing::error!( + target: LOG_TARGET, + ?error, + "failed to handle timeout for `Rtc`" + ); + + self.rtc.disconnect(); + return Err(Error::Disconnected); + } + + Ok(()) + } + + /// Handle Noise handshake response. + /// + /// The message contains remote's peer ID which is used by the `TransportManager` to validate + /// the connection. Note the Noise handshake requires one more messages to be sent by the dialer + /// (us) but the inbound connection must first be verified by the `TransportManager` which will + /// either accept or reject the connection. + /// + /// If the peer is accepted, [`OpeningWebRtcConnection::on_accept()`] is called which creates + /// the final Noise message and sends it to the remote peer, concluding the handshake. + fn on_noise_channel_data(&mut self, data: Vec) -> crate::Result { + tracing::trace!(target: LOG_TARGET, "handle noise handshake reply"); + + let State::HandshakeSent { mut context } = + std::mem::replace(&mut self.state, State::Poisoned) + else { + return Err(Error::InvalidState); + }; + + let message = WebRtcMessage::decode(&data)?.payload.ok_or(Error::InvalidData)?; + let public_key = context.get_remote_public_key(&message)?; + let remote_peer_id = PeerId::from_public_key(&public_key); + + tracing::trace!( + target: LOG_TARGET, + ?remote_peer_id, + "remote reply parsed successfully", + ); + + self.state = State::Validating { context }; + + let remote_fingerprint = self + .rtc + .direct_api() + .remote_dtls_fingerprint() + .clone() + .expect("fingerprint to exist") + .bytes; + + const MULTIHASH_SHA256_CODE: u64 = 0x12; + let certificate = Multihash::wrap(MULTIHASH_SHA256_CODE, &remote_fingerprint) + .expect("fingerprint's len to be 32 bytes"); + + let address = Multiaddr::empty() + .with(Protocol::from(self.peer_address.ip())) + .with(Protocol::Udp(self.peer_address.port())) + .with(Protocol::WebRTC) + .with(Protocol::Certhash(certificate)) + .with(Protocol::P2p(PeerId::from(public_key).into())); + + Ok(WebRtcEvent::ConnectionOpened { + peer: remote_peer_id, + endpoint: Endpoint::listener(address, self.connection_id), + }) + } + + /// Accept connection by sending the final Noise handshake message + /// and return the `Rtc` object for further use. + pub fn on_accept(mut self) -> crate::Result { + tracing::trace!(target: LOG_TARGET, "accept webrtc connection"); + + let State::Validating { mut context } = std::mem::replace(&mut self.state, State::Poisoned) + else { + return Err(Error::InvalidState); + }; + + // create second noise handshake message and send it to remote + let payload = WebRtcMessage::encode(context.second_message()); + + let mut channel = + self.rtc.channel(self.noise_channel_id).ok_or(Error::ChannelDoesntExist)?; + + channel.write(true, payload.as_slice()).map_err(|error| Error::WebRtc(error))?; + self.rtc.direct_api().close_data_channel(self.noise_channel_id); + + Ok(self.rtc) + } + + /// Handle input from peer. + pub fn on_input(&mut self, buffer: DatagramRecv) -> crate::Result<()> { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer_address, + "handle input from peer", + ); + + let message = Input::Receive( + Instant::now(), + Receive { + source: self.peer_address, + proto: Str0mProtocol::Udp, + destination: self.local_address, + contents: buffer, + }, + ); + + match self.rtc.accepts(&message) { + true => self.rtc.handle_input(message).map_err(|error| { + tracing::debug!(target: LOG_TARGET, source = ?self.peer_address, ?error, "failed to handle data"); + Error::InputRejected + }), + false => { + tracing::warn!( + target: LOG_TARGET, + peer = ?self.peer_address, + "input rejected", + ); + return Err(Error::InputRejected); + } + } + } + + /// Progress the state of [`OpeningWebRtcConnection`]. + pub fn poll_process(&mut self) -> WebRtcEvent { + if !self.rtc.is_alive() { + tracing::debug!( + target: LOG_TARGET, + "`Rtc` is not alive, closing `WebRtcConnection`" + ); + + return WebRtcEvent::ConnectionClosed; + } + + loop { + let output = match self.rtc.poll_output() { + Ok(output) => output, + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + connection_id = ?self.connection_id, + ?error, + "`WebRtcConnection::poll_process()` failed", + ); + + return WebRtcEvent::ConnectionClosed; + } + }; + + match output { + Output::Transmit(transmit) => { + tracing::trace!( + target: LOG_TARGET, + "transmit data", + ); + + return WebRtcEvent::Transmit { + destination: transmit.destination, + datagram: transmit.contents, + }; + } + Output::Timeout(timeout) => return WebRtcEvent::Timeout { timeout }, + Output::Event(e) => match e { + Event::IceConnectionStateChange(v) => + if v == IceConnectionState::Disconnected { + tracing::trace!(target: LOG_TARGET, "ice connection closed"); + return WebRtcEvent::ConnectionClosed; + }, + Event::ChannelOpen(channel_id, name) => { + tracing::trace!( + target: LOG_TARGET, + connection_id = ?self.connection_id, + ?channel_id, + ?name, + "channel opened", + ); + + if channel_id != self.noise_channel_id { + tracing::warn!( + target: LOG_TARGET, + connection_id = ?self.connection_id, + ?channel_id, + "ignoring opened channel", + ); + continue; + } + + // TODO: no expect + self.on_noise_channel_open().expect("to succeed"); + } + Event::ChannelData(data) => { + tracing::trace!( + target: LOG_TARGET, + "data received over channel", + ); + + if data.id != self.noise_channel_id { + tracing::warn!( + target: LOG_TARGET, + channel_id = ?data.id, + connection_id = ?self.connection_id, + "ignoring data from channel", + ); + continue; + } + + // TODO: no expect + return self.on_noise_channel_data(data.data).expect("to succeed"); + } + Event::ChannelClose(channel_id) => { + tracing::debug!(target: LOG_TARGET, ?channel_id, "channel closed"); + } + Event::Connected => match std::mem::replace(&mut self.state, State::Poisoned) { + State::Closed => { + let remote_fingerprint = self.remote_fingerprint(); + let local_fingerprint = self.local_fingerprint(); + + let context = NoiseContext::with_prologue( + &self.id_keypair, + noise_prologue(local_fingerprint, remote_fingerprint), + ); + + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer_address, + "connection opened", + ); + + self.state = State::Opened { context }; + } + state => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer_address, + ?state, + "invalid state for connection" + ); + return WebRtcEvent::ConnectionClosed; + } + }, + event => { + tracing::warn!(target: LOG_TARGET, ?event, "unhandled event"); + } + }, + } + } + } +} diff --git a/src/transport/webrtc/substream.rs b/src/transport/webrtc/substream.rs index dad62cd47..b027b8af5 100644 --- a/src/transport/webrtc/substream.rs +++ b/src/transport/webrtc/substream.rs @@ -1,4 +1,4 @@ -// Copyright 2023 litep2p developers +// Copyright 2024 litep2p developers // // Permission is hereby granted, free of charge, to any person obtaining a // copy of this software and associated documentation files (the "Software"), @@ -18,132 +18,447 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -//! Channel-backed substream. - use crate::{ - codec::{identity::Identity, unsigned_varint::UnsignedVarint, ProtocolCodec}, - error::Error, + transport::webrtc::{schema::webrtc::message::Flag, util::WebRtcMessage}, + Error, }; -use bytes::BytesMut; -use futures::{Sink, Stream}; -use str0m::channel::ChannelId; +use bytes::{Buf, BufMut, BytesMut}; +use futures::{Future, Stream}; +use parking_lot::Mutex; use tokio::sync::mpsc::{channel, Receiver, Sender}; -use tokio_stream::wrappers::ReceiverStream; -use tokio_util::sync::PollSender; use std::{ pin::Pin, + sync::Arc, task::{Context, Poll}, }; -// TODO: use substream id +/// Maximum frame size. +const MAX_FRAME_SIZE: usize = 16384; + +/// Substream event. +#[derive(Debug, PartialEq, Eq)] +pub enum Event { + /// Receiver closed. + RecvClosed, + + /// Send/receive message. + Message(Vec), + + /// Close substream. + Close, +} + +/// Substream stream. +enum State { + /// Substream is fully open. + Open, -/// Channel-backed substream. -#[derive(Debug)] + /// Remote is no longer interested in receiving anything. + SendClosed, +} + +/// Channel-backedn substream. pub struct Substream { - /// Channel ID. - id: ChannelId, + /// Substream state. + state: Arc>, - /// TX channel for sending messages to transport. - tx: PollSender<(ChannelId, Vec)>, + /// Read buffer. + read_buffer: BytesMut, - /// RX channel for receiving messages from transport. - rx: ReceiverStream>, + /// TX channel for sending messages to `peer`. + tx: Sender, - /// Protocol codec. - codec: Option, + /// RX channel for receiving messages from `peer`. + rx: Receiver, } impl Substream { /// Create new [`Substream`]. - pub fn new(id: ChannelId, tx: Sender<(ChannelId, Vec)>) -> (Self, Sender>) { - let (to_protocol, rx) = channel(64); + pub fn new() -> (Self, SubstreamHandle) { + let (outbound_tx, outbound_rx) = channel(256); + let (inbound_tx, inbound_rx) = channel(256); + let state = Arc::new(Mutex::new(State::Open)); + let handle = SubstreamHandle { + tx: inbound_tx, + rx: outbound_rx, + state: Arc::clone(&state), + }; ( Self { - id, - codec: None, - tx: PollSender::new(tx), - rx: ReceiverStream::new(rx), + state, + tx: outbound_tx, + rx: inbound_rx, + read_buffer: BytesMut::new(), }, - to_protocol, + handle, ) } +} + +/// Substream handle that is given to the transport backend. +pub struct SubstreamHandle { + state: Arc>, + + /// TX channel for sending messages to `peer`. + tx: Sender, - /// Apply codec for the substream. - pub fn apply_codec(&mut self, codec: ProtocolCodec) { - self.codec = Some(codec); + /// RX channel for receiving messages from `peer`. + rx: Receiver, +} + +impl SubstreamHandle { + /// Handle message received from a remote peer. + /// + /// If the message contains any flags, handle them first and appropriately close the correct + /// side of the substream. If the message contained any payload, send it to the protocol for + /// further processing. + pub async fn on_message(&self, message: WebRtcMessage) -> crate::Result<()> { + if let Some(flags) = message.flags { + if flags == Flag::Fin as i32 { + let _ = self.tx.send(Event::RecvClosed).await?; + } + + if flags & 1 == Flag::StopSending as i32 { + *self.state.lock() = State::SendClosed; + } + + if flags & 2 == Flag::ResetStream as i32 { + return Err(Error::ConnectionClosed); + } + } + + if let Some(payload) = message.payload { + if !payload.is_empty() { + return self.tx.send(Event::Message(payload)).await.map_err(From::from); + } + } + + Ok(()) } } -impl Sink for Substream { - type Error = Error; +impl Stream for SubstreamHandle { + type Item = Event; - fn poll_ready<'a>(mut self: Pin<&mut Self>, cx: &mut Context<'a>) -> Poll> { - let pinned = Pin::new(&mut self.tx); - pinned.poll_ready(cx).map_err(|_| Error::Unknown) + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.rx.poll_recv(cx) } +} + +impl tokio::io::AsyncRead for Substream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + // if there are any remaining bytes from a previous read, consume them first + if self.read_buffer.remaining() > 0 { + let num_bytes = std::cmp::min(self.read_buffer.remaining(), buf.remaining()); - fn start_send(mut self: Pin<&mut Self>, item: bytes::Bytes) -> Result<(), Error> { - let item: Vec = match self.codec.as_ref().expect("codec to exist") { - ProtocolCodec::Identity(_) => Identity::encode(item)?.into(), - ProtocolCodec::UnsignedVarint(_) => UnsignedVarint::encode(item)?.into(), - ProtocolCodec::Unspecified => unreachable!(), // TODO: may not be correct + buf.put_slice(&self.read_buffer[..num_bytes]); + self.read_buffer.advance(num_bytes); + + // TODO: optimize by trying to read more data from substream and not exiting early + return Poll::Ready(Ok(())); + } + + loop { + match futures::ready!(self.rx.poll_recv(cx)) { + None | Some(Event::Close) | Some(Event::RecvClosed) => { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + Some(Event::Message(message)) => { + if message.len() > MAX_FRAME_SIZE { + return Poll::Ready(Err(std::io::ErrorKind::PermissionDenied.into())); + } + + match buf.remaining() >= message.len() { + true => buf.put_slice(&message), + false => { + let remaining = buf.remaining(); + buf.put_slice(&message[..remaining]); + self.read_buffer.put_slice(&message[remaining..]); + } + } + + return Poll::Ready(Ok(())); + } + } + } + } +} + +impl tokio::io::AsyncWrite for Substream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if let State::SendClosed = *self.state.lock() { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + + // TODO: try to coalesce multiple calls to `poll_write()` into single `Event::Message` + + let num_bytes = std::cmp::min(MAX_FRAME_SIZE, buf.len()); + let future = self.tx.reserve(); + futures::pin_mut!(future); + + let permit = match futures::ready!(future.poll(cx)) { + Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), + Ok(permit) => permit, }; - let id = self.id; - Pin::new(&mut self.tx).start_send((id, item)).map_err(|_| Error::Unknown) + let frame = buf[..num_bytes].to_vec(); + permit.send(Event::Message(frame)); + + Poll::Ready(Ok(num_bytes)) } - fn poll_flush<'a>(mut self: Pin<&mut Self>, cx: &mut Context<'a>) -> Poll> { - Pin::new(&mut self.tx).poll_flush(cx).map_err(|_| Error::Unknown) + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } - fn poll_close<'a>(mut self: Pin<&mut Self>, cx: &mut Context<'a>) -> Poll> { - Pin::new(&mut self.tx).poll_close(cx).map_err(|_| Error::Unknown) + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let future = self.tx.reserve(); + futures::pin_mut!(future); + + let permit = match futures::ready!(future.poll(cx)) { + Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), + Ok(permit) => permit, + }; + permit.send(Event::Close); + + Poll::Ready(Ok(())) } } -impl Stream for Substream { - type Item = crate::Result; +#[cfg(test)] +mod tests { + use super::*; + use futures::StreamExt; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; - fn poll_next<'a>( - mut self: Pin<&mut Self>, - cx: &mut Context<'a>, - ) -> Poll>> { - match Pin::new(&mut self.rx).poll_next(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(value)) => Poll::Ready(Some(Ok(BytesMut::from(value.as_slice())))), + #[tokio::test] + async fn write_small_frame() { + let (mut substream, mut handle) = Substream::new(); + + substream.write_all(&vec![0u8; 1337]).await.unwrap(); + + assert_eq!(handle.next().await, Some(Event::Message(vec![0u8; 1337]))); + + futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + Poll::Ready(_) => panic!("invalid event"), + }) + .await; + } + + #[tokio::test] + async fn write_large_frame() { + let (mut substream, mut handle) = Substream::new(); + + substream.write_all(&vec![0u8; (2 * MAX_FRAME_SIZE) + 1]).await.unwrap(); + + assert_eq!( + handle.rx.recv().await, + Some(Event::Message(vec![0u8; MAX_FRAME_SIZE])) + ); + assert_eq!( + handle.rx.recv().await, + Some(Event::Message(vec![0u8; MAX_FRAME_SIZE])) + ); + assert_eq!(handle.rx.recv().await, Some(Event::Message(vec![0u8; 1]))); + + futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + Poll::Ready(_) => panic!("invalid event"), + }) + .await; + } + + #[tokio::test] + async fn try_to_write_to_closed_substream() { + let (mut substream, handle) = Substream::new(); + *handle.state.lock() = State::SendClosed; + + match substream.write_all(&vec![0u8; 1337]).await { + Err(error) => assert_eq!(error.kind(), std::io::ErrorKind::BrokenPipe), + _ => panic!("invalid event"), } } -} -// TODO: rename? -pub struct SubstreamBackend { - /// TX channel for creating new [`Substream`] objects. - tx: Sender<(ChannelId, Vec)>, + #[tokio::test] + async fn substream_shutdown() { + let (mut substream, mut handle) = Substream::new(); - /// RX channel for receiving messages from protocols. - rx: Receiver<(ChannelId, Vec)>, -} + substream.write_all(&vec![1u8; 1337]).await.unwrap(); + substream.shutdown().await.unwrap(); + + assert_eq!(handle.next().await, Some(Event::Message(vec![1u8; 1337]))); + assert_eq!(handle.next().await, Some(Event::Close)); + } + + #[tokio::test] + async fn try_to_read_from_closed_substream() { + let (mut substream, handle) = Substream::new(); + handle + .on_message(WebRtcMessage { + payload: None, + flags: Some(0i32), + }) + .await + .unwrap(); + + match substream.read(&mut vec![0u8; 256]).await { + Err(error) => assert_eq!(error.kind(), std::io::ErrorKind::BrokenPipe), + _ => panic!("invalid event"), + } + } + + #[tokio::test] + async fn read_small_frame() { + let (mut substream, handle) = Substream::new(); + handle.tx.send(Event::Message(vec![1u8; 256])).await.unwrap(); + + let mut buf = vec![0u8; 2048]; + + match substream.read(&mut buf).await { + Ok(nread) => { + assert_eq!(nread, 256); + assert_eq!(buf[..nread], vec![1u8; 256]); + } + Err(error) => panic!("invalid event: {error:?}"), + } + + let mut read_buf = ReadBuf::new(&mut buf); + futures::future::poll_fn(|cx| { + match Pin::new(&mut substream).poll_read(cx, &mut read_buf) { + Poll::Pending => Poll::Ready(()), + _ => panic!("invalid event"), + } + }) + .await; + } + + #[tokio::test] + async fn read_small_frame_in_two_reads() { + let (mut substream, handle) = Substream::new(); + let mut first = vec![1u8; 256]; + first.extend_from_slice(&vec![2u8; 256]); -impl SubstreamBackend { - /// Create new [`SubstreamBackend`]. - pub fn new() -> Self { - let (tx, rx) = channel(1024); + handle.tx.send(Event::Message(first)).await.unwrap(); - Self { tx, rx } + let mut buf = vec![0u8; 256]; + + match substream.read(&mut buf).await { + Ok(nread) => { + assert_eq!(nread, 256); + assert_eq!(buf[..nread], vec![1u8; 256]); + } + Err(error) => panic!("invalid event: {error:?}"), + } + + match substream.read(&mut buf).await { + Ok(nread) => { + assert_eq!(nread, 256); + assert_eq!(buf[..nread], vec![2u8; 256]); + } + Err(error) => panic!("invalid event: {error:?}"), + } + + let mut read_buf = ReadBuf::new(&mut buf); + futures::future::poll_fn(|cx| { + match Pin::new(&mut substream).poll_read(cx, &mut read_buf) { + Poll::Pending => Poll::Ready(()), + _ => panic!("invalid event"), + } + }) + .await; } - /// Create new substream. - pub fn substream(&mut self, id: ChannelId) -> (Substream, Sender>) { - Substream::new(id, self.tx.clone()) + #[tokio::test] + async fn read_frames() { + let (mut substream, handle) = Substream::new(); + let mut first = vec![1u8; 256]; + first.extend_from_slice(&vec![2u8; 256]); + + handle.tx.send(Event::Message(first)).await.unwrap(); + handle.tx.send(Event::Message(vec![4u8; 2048])).await.unwrap(); + + let mut buf = vec![0u8; 256]; + + match substream.read(&mut buf).await { + Ok(nread) => { + assert_eq!(nread, 256); + assert_eq!(buf[..nread], vec![1u8; 256]); + } + Err(error) => panic!("invalid event: {error:?}"), + } + + let mut buf = vec![0u8; 128]; + + match substream.read(&mut buf).await { + Ok(nread) => { + assert_eq!(nread, 128); + assert_eq!(buf[..nread], vec![2u8; 128]); + } + Err(error) => panic!("invalid event: {error:?}"), + } + + let mut buf = vec![0u8; 128]; + + match substream.read(&mut buf).await { + Ok(nread) => { + assert_eq!(nread, 128); + assert_eq!(buf[..nread], vec![2u8; 128]); + } + Err(error) => panic!("invalid event: {error:?}"), + } + + let mut buf = vec![0u8; MAX_FRAME_SIZE]; + + match substream.read(&mut buf).await { + Ok(nread) => { + assert_eq!(nread, 2048); + assert_eq!(buf[..nread], vec![4u8; 2048]); + } + Err(error) => panic!("invalid event: {error:?}"), + } + + let mut read_buf = ReadBuf::new(&mut buf); + futures::future::poll_fn(|cx| { + match Pin::new(&mut substream).poll_read(cx, &mut read_buf) { + Poll::Pending => Poll::Ready(()), + _ => panic!("invalid event"), + } + }) + .await; } - /// Poll next event. - pub async fn next_event(&mut self) -> Option<(ChannelId, Vec)> { - self.rx.recv().await + #[tokio::test] + async fn backpressure_works() { + let (mut substream, _handle) = Substream::new(); + + // use all available bandwidth which by default is `256 * MAX_FRAME_SIZE`, + for _ in 0..128 { + substream.write_all(&vec![0u8; 2 * MAX_FRAME_SIZE]).await.unwrap(); + } + + // try to write one more byte but since all available bandwidth + // is taken the call will block + futures::future::poll_fn(|cx| { + match Pin::new(&mut substream).poll_write(cx, &vec![0u8; 1]) { + Poll::Pending => Poll::Ready(()), + _ => panic!("invalid event"), + } + }) + .await; } } diff --git a/src/transport/webrtc/util.rs b/src/transport/webrtc/util.rs index e985f4ae4..82939d736 100644 --- a/src/transport/webrtc/util.rs +++ b/src/transport/webrtc/util.rs @@ -21,27 +21,8 @@ use crate::{codec::unsigned_varint::UnsignedVarint, error::Error, transport::webrtc::schema}; use prost::Message; -use str0m::channel::ChannelId; -use tokio::sync::mpsc::Sender; use tokio_util::codec::{Decoder, Encoder}; -/// Substream context. -#[derive(Debug)] -pub struct SubstreamContext { - /// `str0m` channel id. - pub channel_id: ChannelId, - - /// TX channel for sending messages to the protocol. - pub tx: Sender>, -} - -impl SubstreamContext { - /// Create new [`SubstreamContext`]. - pub fn new(channel_id: ChannelId, tx: Sender>) -> Self { - Self { channel_id, tx } - } -} - /// WebRTC mesage. #[derive(Debug)] pub struct WebRtcMessage { @@ -54,10 +35,29 @@ pub struct WebRtcMessage { impl WebRtcMessage { /// Encode WebRTC message. - pub fn encode(payload: Vec, flag: Option) -> Vec { + pub fn encode(payload: Vec) -> Vec { let protobuf_payload = schema::webrtc::Message { message: (!payload.is_empty()).then_some(payload), - flag, + flag: None, + }; + let mut payload = Vec::with_capacity(protobuf_payload.encoded_len()); + protobuf_payload + .encode(&mut payload) + .expect("Vec to provide needed capacity"); + + let mut out_buf = bytes::BytesMut::with_capacity(payload.len() + 4); + let mut codec = UnsignedVarint::new(None); + let _result = codec.encode(payload.into(), &mut out_buf); + + out_buf.into() + } + + /// Encode WebRTC message with flags. + #[allow(unused)] + pub fn encode_with_flags(payload: Vec, flags: i32) -> Vec { + let protobuf_payload = schema::webrtc::Message { + message: (!payload.is_empty()).then_some(payload), + flag: Some(flags), }; let mut payload = Vec::with_capacity(protobuf_payload.encoded_len()); protobuf_payload @@ -65,7 +65,6 @@ impl WebRtcMessage { .expect("Vec to provide needed capacity"); let mut out_buf = bytes::BytesMut::with_capacity(payload.len() + 4); - // TODO: set correct size let mut codec = UnsignedVarint::new(None); let _result = codec.encode(payload.into(), &mut out_buf); @@ -95,7 +94,7 @@ mod tests { #[test] fn with_payload_no_flags() { - let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec(), None); + let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec()); let decoded = WebRtcMessage::decode(&message).unwrap(); assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); @@ -104,7 +103,7 @@ mod tests { #[test] fn with_payload_and_flags() { - let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec(), Some(1i32)); + let message = WebRtcMessage::encode_with_flags("Hello, world!".as_bytes().to_vec(), 1i32); let decoded = WebRtcMessage::decode(&message).unwrap(); assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); @@ -113,7 +112,7 @@ mod tests { #[test] fn no_payload_with_flags() { - let message = WebRtcMessage::encode(vec![], Some(2i32)); + let message = WebRtcMessage::encode_with_flags(vec![], 2i32); let decoded = WebRtcMessage::decode(&message).unwrap(); assert_eq!(decoded.payload, None); diff --git a/tests/webrtc.rs b/tests/webrtc.rs new file mode 100644 index 000000000..d80c6cb2c --- /dev/null +++ b/tests/webrtc.rs @@ -0,0 +1,80 @@ +// Copyright 2023 litep2p developers +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use futures::StreamExt; +use litep2p::{ + config::ConfigBuilder as Litep2pConfigBuilder, + crypto::ed25519::Keypair, + protocol::{libp2p::ping, notification::ConfigBuilder}, + transport::webrtc::config::Config, + types::protocol::ProtocolName, + Litep2p, +}; + +#[tokio::test] +#[ignore] +async fn webrtc_test() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config, mut ping_event_stream) = ping::Config::default(); + let (notif_config, mut notif_event_stream) = ConfigBuilder::new(ProtocolName::from( + // Westend block-announces protocol name. + "/e143f23803ac50e8f6f8e62695d1ce9e4e1d68aa36c1cd2cfd15340213f3423e/block-announces/1", + )) + .with_max_size(5 * 1024 * 1024) + .with_handshake(vec![1, 2, 3, 4]) + .with_auto_accept_inbound(true) + .build(); + + let config = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_webrtc(Config { + listen_addresses: vec!["/ip4/192.168.1.170/udp/8888/webrtc-direct".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config) + .with_notification_protocol(notif_config) + .build(); + + let mut litep2p = Litep2p::new(config).unwrap(); + let address = litep2p.listen_addresses().next().unwrap().clone(); + + tracing::info!("listen address: {address:?}"); + + loop { + tokio::select! { + event = litep2p.next_event() => { + tracing::debug!("litep2p event received: {event:?}"); + } + event = ping_event_stream.next() => { + if std::matches!(event, None) { + tracing::error!("ping event stream terminated"); + break + } + tracing::error!("ping event received: {event:?}"); + } + _event = notif_event_stream.next() => { + // tracing::error!("notification event received: {event:?}"); + } + } + } +}