diff --git a/src/config.rs b/src/config.rs index 570be398a..75620a59d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -29,9 +29,9 @@ use crate::{ notification, request_response, UserProtocol, }, transport::{ - quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, - webrtc::config::Config as WebRtcConfig, websocket::config::Config as WebSocketConfig, - MAX_PARALLEL_DIALS, + manager::limits::ConnectionLimitsConfig, quic::config::Config as QuicConfig, + tcp::config::Config as TcpConfig, webrtc::config::Config as WebRtcConfig, + websocket::config::Config as WebSocketConfig, MAX_PARALLEL_DIALS, }, types::protocol::ProtocolName, PeerId, @@ -109,6 +109,9 @@ pub struct ConfigBuilder { /// Maximum number of parallel dial attempts. max_parallel_dials: usize, + + /// Connection limits config. + connection_limits: ConnectionLimitsConfig, } impl Default for ConfigBuilder { @@ -137,6 +140,7 @@ impl ConfigBuilder { notification_protocols: HashMap::new(), request_response_protocols: HashMap::new(), known_addresses: Vec::new(), + connection_limits: ConnectionLimitsConfig::default(), } } @@ -243,6 +247,12 @@ impl ConfigBuilder { self } + /// Set connection limits configuration. + pub fn with_connection_limits(mut self, config: ConnectionLimitsConfig) -> Self { + self.connection_limits = config; + self + } + /// Build [`Litep2pConfig`]. pub fn build(mut self) -> Litep2pConfig { let keypair = match self.keypair { @@ -267,6 +277,7 @@ impl ConfigBuilder { notification_protocols: self.notification_protocols, request_response_protocols: self.request_response_protocols, known_addresses: self.known_addresses, + connection_limits: self.connection_limits, } } } @@ -320,4 +331,7 @@ pub struct Litep2pConfig { /// Known addresses. pub(crate) known_addresses: Vec<(PeerId, Vec)>, + + /// Connection limits config. + pub(crate) connection_limits: ConnectionLimitsConfig, } diff --git a/src/error.rs b/src/error.rs index efb8ad6f9..05dbbd822 100644 --- a/src/error.rs +++ b/src/error.rs @@ -28,6 +28,7 @@ use crate::{ protocol::Direction, + transport::manager::limits::ConnectionLimitsError, types::{protocol::ProtocolName, ConnectionId, SubstreamId}, PeerId, }; @@ -118,6 +119,8 @@ pub enum Error { ChannelClogged, #[error("Connection doesn't exist: `{0:?}`")] ConnectionDoesntExist(ConnectionId), + #[error("Exceeded connection limits `{0:?}`")] + ConnectionLimit(ConnectionLimitsError), } #[derive(Debug, thiserror::Error)] @@ -243,6 +246,12 @@ impl From for Error { } } +impl From for Error { + fn from(error: ConnectionLimitsError) -> Self { + Error::ConnectionLimit(error) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/lib.rs b/src/lib.rs index dfa67b977..ef75a3b7f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -143,6 +143,7 @@ impl Litep2p { supported_transports, bandwidth_sink.clone(), litep2p_config.max_parallel_dials, + litep2p_config.connection_limits, ); // add known addresses to `TransportManager`, if any exist diff --git a/src/protocol/libp2p/kademlia/mod.rs b/src/protocol/libp2p/kademlia/mod.rs index a46e59f45..4904fee05 100644 --- a/src/protocol/libp2p/kademlia/mod.rs +++ b/src/protocol/libp2p/kademlia/mod.rs @@ -897,8 +897,11 @@ mod tests { use super::*; use crate::{ - codec::ProtocolCodec, crypto::ed25519::Keypair, transport::manager::TransportManager, - types::protocol::ProtocolName, BandwidthSink, + codec::ProtocolCodec, + crypto::ed25519::Keypair, + transport::manager::{limits::ConnectionLimitsConfig, TransportManager}, + types::protocol::ProtocolName, + BandwidthSink, }; use tokio::sync::mpsc::channel; @@ -914,6 +917,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let peer = PeerId::random(); diff --git a/src/protocol/mdns.rs b/src/protocol/mdns.rs index 685c34f76..d305a3f76 100644 --- a/src/protocol/mdns.rs +++ b/src/protocol/mdns.rs @@ -334,7 +334,11 @@ impl Mdns { #[cfg(test)] mod tests { use super::*; - use crate::{crypto::ed25519::Keypair, transport::manager::TransportManager, BandwidthSink}; + use crate::{ + crypto::ed25519::Keypair, + transport::manager::{limits::ConnectionLimitsConfig, TransportManager}, + BandwidthSink, + }; use futures::StreamExt; use multiaddr::Protocol; @@ -350,6 +354,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let mdns1 = Mdns::new( @@ -372,6 +377,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let mdns2 = Mdns::new( diff --git a/src/protocol/notification/tests/mod.rs b/src/protocol/notification/tests/mod.rs index 43dd9121e..0b275502f 100644 --- a/src/protocol/notification/tests/mod.rs +++ b/src/protocol/notification/tests/mod.rs @@ -29,7 +29,7 @@ use crate::{ }, InnerTransportEvent, ProtocolCommand, TransportService, }, - transport::manager::TransportManager, + transport::manager::{limits::ConnectionLimitsConfig, TransportManager}, types::protocol::ProtocolName, BandwidthSink, PeerId, }; @@ -53,6 +53,7 @@ fn make_notification_protocol() -> ( HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let peer = PeerId::random(); diff --git a/src/protocol/request_response/tests.rs b/src/protocol/request_response/tests.rs index 524b3b2da..7c57b4f9d 100644 --- a/src/protocol/request_response/tests.rs +++ b/src/protocol/request_response/tests.rs @@ -29,7 +29,7 @@ use crate::{ InnerTransportEvent, TransportService, }, substream::Substream, - transport::manager::TransportManager, + transport::manager::{limits::ConnectionLimitsConfig, TransportManager}, types::{RequestId, SubstreamId}, BandwidthSink, Error, PeerId, ProtocolName, }; @@ -51,6 +51,7 @@ fn protocol() -> ( HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let peer = PeerId::random(); diff --git a/src/transport/manager/limits.rs b/src/transport/manager/limits.rs new file mode 100644 index 000000000..b6838e8e1 --- /dev/null +++ b/src/transport/manager/limits.rs @@ -0,0 +1,204 @@ +// Copyright 2024 litep2p developers +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! Limits for the transport manager. + +use crate::types::ConnectionId; + +use std::collections::HashSet; + +/// Configuration for the connection limits. +#[derive(Debug, Clone, Default)] +pub struct ConnectionLimitsConfig { + /// Maximum number of incoming connections that can be established. + max_incoming_connections: Option, + /// Maximum number of outgoing connections that can be established. + max_outgoing_connections: Option, +} + +impl ConnectionLimitsConfig { + /// Configures the maximum number of incoming connections that can be established. + pub fn max_incoming_connections(mut self, limit: Option) -> Self { + self.max_incoming_connections = limit; + self + } + + /// Configures the maximum number of outgoing connections that can be established. + pub fn max_outgoing_connections(mut self, limit: Option) -> Self { + self.max_outgoing_connections = limit; + self + } +} + +/// Error type for connection limits. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConnectionLimitsError { + /// Maximum number of incoming connections exceeded. + MaxIncomingConnectionsExceeded, + /// Maximum number of outgoing connections exceeded. + MaxOutgoingConnectionsExceeded, +} + +/// Connection limits. +#[derive(Debug, Clone)] +pub struct ConnectionLimits { + /// Configuration for the connection limits. + config: ConnectionLimitsConfig, + + /// Established incoming connections. + incoming_connections: HashSet, + /// Established outgoing connections. + outgoing_connections: HashSet, +} + +impl ConnectionLimits { + /// Creates a new connection limits instance. + pub fn new(config: ConnectionLimitsConfig) -> Self { + let max_incoming_connections = config.max_incoming_connections.unwrap_or(0); + let max_outgoing_connections = config.max_outgoing_connections.unwrap_or(0); + + Self { + config, + incoming_connections: HashSet::with_capacity(max_incoming_connections), + outgoing_connections: HashSet::with_capacity(max_outgoing_connections), + } + } + + /// Called when dialing an address. + /// + /// Returns the number of outgoing connections permitted to be established. + /// It is guaranteed that at least one connection can be established if the method returns `Ok`. + /// The number of available outgoing connections can influence the maximum parallel dials to a + /// single address. + /// + /// If the maximum number of outgoing connections is not set, `Ok(usize::MAX)` is returned. + pub fn on_dial_address(&mut self) -> Result { + if let Some(max_outgoing_connections) = self.config.max_outgoing_connections { + if self.outgoing_connections.len() >= max_outgoing_connections { + return Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded); + } + + return Ok(max_outgoing_connections - self.outgoing_connections.len()); + } + + Ok(usize::MAX) + } + + /// Called when a new connection is established. + pub fn on_connection_established( + &mut self, + connection_id: ConnectionId, + is_listener: bool, + ) -> Result<(), ConnectionLimitsError> { + // Check connection limits. + if is_listener { + if let Some(max_incoming_connections) = self.config.max_incoming_connections { + if self.incoming_connections.len() >= max_incoming_connections { + return Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded); + } + } + } else { + if let Some(max_outgoing_connections) = self.config.max_outgoing_connections { + if self.outgoing_connections.len() >= max_outgoing_connections { + return Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded); + } + } + } + + // Keep track of the connection. + if is_listener { + if self.config.max_incoming_connections.is_some() { + self.incoming_connections.insert(connection_id); + } + } else { + if self.config.max_outgoing_connections.is_some() { + self.outgoing_connections.insert(connection_id); + } + } + + Ok(()) + } + + /// Called when a connection is closed. + pub fn on_connection_closed(&mut self, connection_id: ConnectionId) { + self.incoming_connections.remove(&connection_id); + self.outgoing_connections.remove(&connection_id); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::ConnectionId; + + #[test] + fn connection_limits() { + let config = ConnectionLimitsConfig::default() + .max_incoming_connections(Some(3)) + .max_outgoing_connections(Some(2)); + let mut limits = ConnectionLimits::new(config); + + let connection_id_in_1 = ConnectionId::random(); + let connection_id_in_2 = ConnectionId::random(); + let connection_id_out_1 = ConnectionId::random(); + let connection_id_out_2 = ConnectionId::random(); + let connection_id_in_3 = ConnectionId::random(); + let connection_id_out_3 = ConnectionId::random(); + + // Establish incoming connection. + assert!(limits.on_connection_established(connection_id_in_1, true).is_ok()); + assert_eq!(limits.incoming_connections.len(), 1); + + assert!(limits.on_connection_established(connection_id_in_2, true).is_ok()); + assert_eq!(limits.incoming_connections.len(), 2); + + assert!(limits.on_connection_established(connection_id_in_3, true).is_ok()); + assert_eq!(limits.incoming_connections.len(), 3); + + assert_eq!( + limits.on_connection_established(ConnectionId::random(), true).unwrap_err(), + ConnectionLimitsError::MaxIncomingConnectionsExceeded + ); + assert_eq!(limits.incoming_connections.len(), 3); + + // Establish outgoing connection. + assert!(limits.on_connection_established(connection_id_out_1, false).is_ok()); + assert_eq!(limits.incoming_connections.len(), 3); + assert_eq!(limits.outgoing_connections.len(), 1); + + assert!(limits.on_connection_established(connection_id_out_2, false).is_ok()); + assert_eq!(limits.incoming_connections.len(), 3); + assert_eq!(limits.outgoing_connections.len(), 2); + + assert_eq!( + limits.on_connection_established(connection_id_out_3, false).unwrap_err(), + ConnectionLimitsError::MaxOutgoingConnectionsExceeded + ); + + // Close connections with peer a. + limits.on_connection_closed(connection_id_in_1); + assert_eq!(limits.incoming_connections.len(), 2); + assert_eq!(limits.outgoing_connections.len(), 2); + + limits.on_connection_closed(connection_id_out_1); + assert_eq!(limits.incoming_connections.len(), 2); + assert_eq!(limits.outgoing_connections.len(), 1); + } +} diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index a7248811a..e39cd75a6 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -57,6 +57,7 @@ pub use handle::{TransportHandle, TransportManagerHandle}; pub use types::SupportedTransport; mod address; +pub mod limits; mod types; pub(crate) mod handle; @@ -75,7 +76,8 @@ const SCORE_CONNECT_SUCCESS: i32 = 100i32; /// Score for a non-working address. const SCORE_CONNECT_FAILURE: i32 = -100i32; -/// TODO: +/// The connection established result. +#[derive(Debug, Clone, Copy, Eq, PartialEq)] enum ConnectionEstablishedResult { /// Accept connection and inform `Litep2p` about the connection. Accept, @@ -242,6 +244,9 @@ pub struct TransportManager { /// Pending connections. pending_connections: HashMap, + + /// Connection limits. + connection_limits: limits::ConnectionLimits, } impl TransportManager { @@ -252,6 +257,7 @@ impl TransportManager { supported_transports: HashSet, bandwidth_sink: BandwidthSink, max_parallel_dials: usize, + connection_limits_config: limits::ConnectionLimitsConfig, ) -> (Self, TransportManagerHandle) { let local_peer_id = PeerId::from_public_key(&keypair.public().into()); let peers = Arc::new(RwLock::new(HashMap::new())); @@ -284,6 +290,7 @@ impl TransportManager { pending_connections: HashMap::new(), next_substream_id: Arc::new(AtomicUsize::new(0usize)), next_connection_id: Arc::new(AtomicUsize::new(0usize)), + connection_limits: limits::ConnectionLimits::new(connection_limits_config), }, handle, ) @@ -393,6 +400,12 @@ impl TransportManager { /// /// Returns an error if the peer is unknown or the peer is already connected. pub async fn dial(&mut self, peer: PeerId) -> crate::Result<()> { + // Don't alter the peer state if there's no capacity to dial. + let available_capacity = self.connection_limits.on_dial_address()?; + // The available capacity is the maximum number of connections that can be established, + // so we limit the number of parallel dials to the minimum of these values. + let limit = available_capacity.min(self.max_parallel_dials); + if peer == self.local_peer_id { return Err(Error::TriedToDialSelf); } @@ -435,7 +448,7 @@ impl TransportManager { tracing::debug!( target: LOG_TARGET, ?peer, - "peer is aready being dialed", + "peer is already being dialed", ); peers.insert( @@ -451,7 +464,7 @@ impl TransportManager { } let mut records: HashMap<_, _> = addresses - .take(self.max_parallel_dials) + .take(limit) .into_iter() .map(|record| (record.address().clone(), record)) .collect(); @@ -558,6 +571,8 @@ impl TransportManager { /// /// Returns an error if address it not valid. pub async fn dial_address(&mut self, address: Multiaddr) -> crate::Result<()> { + self.connection_limits.on_dial_address()?; + let mut record = AddressRecord::from_multiaddr(address) .ok_or(Error::AddressError(AddressError::PeerIdMissing))?; @@ -759,6 +774,8 @@ impl TransportManager { peer: PeerId, connection_id: ConnectionId, ) -> crate::Result> { + self.connection_limits.on_connection_closed(connection_id); + let mut peers = self.peers.write(); let Some(context) = peers.get_mut(&peer) else { tracing::warn!( @@ -911,6 +928,21 @@ impl TransportManager { } }; + // Reject the connection if exceeded limits. + if let Err(error) = self + .connection_limits + .on_connection_established(endpoint.connection_id(), endpoint.is_listener()) + { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?endpoint, + ?error, + "connection limit exceeded, rejecting connection", + ); + return Ok(ConnectionEstablishedResult::Reject); + } + let mut peers = self.peers.write(); match peers.get_mut(&peer) { Some(context) => match context.state { @@ -1051,7 +1083,7 @@ impl TransportManager { }); // since an inbound connection was removed, the outbound connection can be - // removed from pendind dials + // removed from pending dials // // all records have the same `ConnectionId` so it doens't matter which of them // is used to remove the pending dial @@ -1616,6 +1648,8 @@ impl TransportManager { #[cfg(test)] mod tests { + use limits::ConnectionLimitsConfig; + use super::*; use crate::{ crypto::ed25519::Keypair, executor::DefaultExecutor, transport::dummy::DummyTransport, @@ -1625,13 +1659,31 @@ mod tests { sync::Arc, }; + /// Setup TCP address and connection id. + fn setup_dial_addr(peer: PeerId, connection_id: u16) -> (Multiaddr, ConnectionId) { + let dial_address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888 + connection_id)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + let connection_id = ConnectionId::from(connection_id as usize); + + (dial_address, connection_id) + } + #[test] #[should_panic] #[cfg(debug_assertions)] fn duplicate_protocol() { let sink = BandwidthSink::new(); - let (mut manager, _handle) = - TransportManager::new(Keypair::generate(), HashSet::new(), sink, 8usize); + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + sink, + 8usize, + ConnectionLimitsConfig::default(), + ); manager.register_protocol( ProtocolName::from("/notif/1"), @@ -1650,8 +1702,13 @@ mod tests { #[cfg(debug_assertions)] fn fallback_protocol_as_duplicate_main_protocol() { let sink = BandwidthSink::new(); - let (mut manager, _handle) = - TransportManager::new(Keypair::generate(), HashSet::new(), sink, 8usize); + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + sink, + 8usize, + ConnectionLimitsConfig::default(), + ); manager.register_protocol( ProtocolName::from("/notif/1"), @@ -1673,8 +1730,13 @@ mod tests { #[cfg(debug_assertions)] fn duplicate_fallback_protocol() { let sink = BandwidthSink::new(); - let (mut manager, _handle) = - TransportManager::new(Keypair::generate(), HashSet::new(), sink, 8usize); + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + sink, + 8usize, + ConnectionLimitsConfig::default(), + ); manager.register_protocol( ProtocolName::from("/notif/1"), @@ -1699,8 +1761,13 @@ mod tests { #[cfg(debug_assertions)] fn duplicate_transport() { let sink = BandwidthSink::new(); - let (mut manager, _handle) = - TransportManager::new(Keypair::generate(), HashSet::new(), sink, 8usize); + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + sink, + 8usize, + ConnectionLimitsConfig::default(), + ); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1711,7 +1778,13 @@ mod tests { let keypair = Keypair::generate(); let local_peer_id = PeerId::from_public_key(&keypair.public().into()); let sink = BandwidthSink::new(); - let (mut manager, _handle) = TransportManager::new(keypair, HashSet::new(), sink, 8usize); + let (mut manager, _handle) = TransportManager::new( + keypair, + HashSet::new(), + sink, + 8usize, + ConnectionLimitsConfig::default(), + ); assert!(manager.dial(local_peer_id).await.is_err()); } @@ -1723,6 +1796,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1752,6 +1826,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let peer = PeerId::random(); let dial_address = Multiaddr::empty() @@ -1813,6 +1888,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1843,6 +1919,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1887,6 +1964,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1905,6 +1983,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1933,6 +2012,7 @@ mod tests { HashSet::from_iter([SupportedTransport::Tcp, SupportedTransport::Quic]), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); // ipv6 @@ -1991,6 +2071,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2057,6 +2138,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2143,6 +2225,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2227,6 +2310,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2331,6 +2415,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2433,6 +2518,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2539,6 +2625,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2667,6 +2754,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); manager.on_dial_failure(ConnectionId::random()).unwrap(); @@ -2685,6 +2773,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let connection_id = ConnectionId::random(); let peer = PeerId::random(); @@ -2705,6 +2794,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); manager.on_connection_closed(PeerId::random(), ConnectionId::random()).unwrap(); } @@ -2722,6 +2812,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); manager .on_connection_opened( @@ -2745,6 +2836,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let connection_id = ConnectionId::random(); let peer = PeerId::random(); @@ -2768,6 +2860,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let connection_id = ConnectionId::random(); let peer = PeerId::random(); @@ -2794,6 +2887,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); manager @@ -2814,6 +2908,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let connection_id = ConnectionId::random(); let peer = PeerId::random(); @@ -2833,6 +2928,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); assert!(manager.next().await.is_none()); @@ -2845,6 +2941,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let peer = { @@ -2892,6 +2989,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let peer = { @@ -2935,6 +3033,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let peer = { @@ -2978,6 +3077,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); // transport doesn't start with ip/dns @@ -3043,6 +3143,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); async fn call_manager(manager: &mut TransportManager, address: Multiaddr) { @@ -3096,6 +3197,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let peer = PeerId::random(); let dial_address = Multiaddr::empty() @@ -3187,6 +3289,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let peer = PeerId::random(); let dial_address = Multiaddr::empty() @@ -3266,4 +3369,168 @@ mod tests { state => panic!("invalid peer state: {state:?}"), } } + + #[tokio::test] + async fn manager_limits_incoming_connections() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ConnectionLimitsConfig::default() + .max_incoming_connections(Some(3)) + .max_outgoing_connections(Some(2)), + ); + // The connection limit is agnostic of the underlying transports. + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + let peer = PeerId::random(); + let second_peer = PeerId::random(); + + // Setup addresses. + let (first_addr, first_connection_id) = setup_dial_addr(peer, 0); + let (second_addr, second_connection_id) = setup_dial_addr(second_peer, 1); + let (_, third_connection_id) = setup_dial_addr(peer, 2); + let (_, remote_connection_id) = setup_dial_addr(peer, 3); + + // Peer established the first inbound connection. + let result = manager + .on_connection_established( + peer, + &Endpoint::listener(first_addr.clone(), first_connection_id), + ) + .unwrap(); + assert_eq!(result, ConnectionEstablishedResult::Accept); + + // The peer is allowed to dial us a second time. + let result = manager + .on_connection_established( + peer, + &Endpoint::listener(first_addr.clone(), second_connection_id), + ) + .unwrap(); + assert_eq!(result, ConnectionEstablishedResult::Accept); + + // Second peer calls us. + let result = manager + .on_connection_established( + second_peer, + &Endpoint::listener(second_addr.clone(), third_connection_id), + ) + .unwrap(); + assert_eq!(result, ConnectionEstablishedResult::Accept); + + // Limits of inbound connections are reached. + let result = manager + .on_connection_established( + second_peer, + &Endpoint::listener(second_addr.clone(), remote_connection_id), + ) + .unwrap(); + assert_eq!(result, ConnectionEstablishedResult::Reject); + + // Close one connection. + let _ = manager.on_connection_closed(peer, first_connection_id).unwrap(); + + // The second peer can establish 2 inbounds now. + let result = manager + .on_connection_established( + second_peer, + &Endpoint::listener(second_addr.clone(), remote_connection_id), + ) + .unwrap(); + assert_eq!(result, ConnectionEstablishedResult::Accept); + } + + #[tokio::test] + async fn manager_limits_outbound_connections() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ConnectionLimitsConfig::default() + .max_incoming_connections(Some(3)) + .max_outgoing_connections(Some(2)), + ); + // The connection limit is agnostic of the underlying transports. + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + let peer = PeerId::random(); + let second_peer = PeerId::random(); + let third_peer = PeerId::random(); + + // Setup addresses. + let (first_addr, first_connection_id) = setup_dial_addr(peer, 0); + let (second_addr, second_connection_id) = setup_dial_addr(second_peer, 1); + let (third_addr, third_connection_id) = setup_dial_addr(third_peer, 2); + + // First dial. + manager.dial_address(first_addr.clone()).await.unwrap(); + + // Second dial. + manager.dial_address(second_addr.clone()).await.unwrap(); + + // Third dial, we have a limit on 2 outbound connections. + manager.dial_address(third_addr.clone()).await.unwrap(); + + let result = manager + .on_connection_established( + peer, + &Endpoint::dialer(first_addr.clone(), first_connection_id), + ) + .unwrap(); + assert_eq!(result, ConnectionEstablishedResult::Accept); + + let result = manager + .on_connection_established( + second_peer, + &Endpoint::dialer(second_addr.clone(), second_connection_id), + ) + .unwrap(); + assert_eq!(result, ConnectionEstablishedResult::Accept); + + // We have reached the limit now. + let result = manager + .on_connection_established( + third_peer, + &Endpoint::dialer(third_addr.clone(), third_connection_id), + ) + .unwrap(); + assert_eq!(result, ConnectionEstablishedResult::Reject); + + // While we have 2 outbound connections active, any dials will fail immediately. + // We cannot perform this check for the non negotiated inbound connections yet, + // since the transport will eagerly accept and negotiate them. This requires + // a refactor into the transport manager, to not waste resources on + // negotiating connections that will be rejected. + let result = manager.dial(peer).await.unwrap_err(); + assert!(std::matches!( + result, + Error::ConnectionLimit(limits::ConnectionLimitsError::MaxOutgoingConnectionsExceeded) + )); + let result = manager.dial_address(first_addr.clone()).await.unwrap_err(); + assert!(std::matches!( + result, + Error::ConnectionLimit(limits::ConnectionLimitsError::MaxOutgoingConnectionsExceeded) + )); + + // Close one connection. + let _ = manager.on_connection_closed(peer, first_connection_id).unwrap(); + // We can now dial again. + manager.dial_address(first_addr.clone()).await.unwrap(); + + let result = manager + .on_connection_established(peer, &Endpoint::dialer(first_addr, first_connection_id)) + .unwrap(); + assert_eq!(result, ConnectionEstablishedResult::Accept); + } } diff --git a/src/transport/tcp/mod.rs b/src/transport/tcp/mod.rs index eb06a8633..62cc00d9f 100644 --- a/src/transport/tcp/mod.rs +++ b/src/transport/tcp/mod.rs @@ -490,7 +490,9 @@ mod tests { codec::ProtocolCodec, crypto::ed25519::Keypair, executor::DefaultExecutor, - transport::manager::{ProtocolContext, SupportedTransport, TransportManager}, + transport::manager::{ + limits::ConnectionLimitsConfig, ProtocolContext, SupportedTransport, TransportManager, + }, types::protocol::ProtocolName, BandwidthSink, PeerId, }; @@ -683,6 +685,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, + ConnectionLimitsConfig::default(), ); let handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(